Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR] Extend support for floating point attributes #572

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -285,13 +285,20 @@ def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
let summary = "An attribute containing a floating-point value";
let description = [{
An fp attribute is a literal attribute that represents a floating-point
value of the specified floating-point type.
value of the specified floating-point type. Supporting only CIR FP types.
}];
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APFloat":$value);
let parameters = (ins
AttributeSelfTypeParameter<"", "::mlir::cir::CIRFPTypeInterface">:$type,
APFloatParameter<"">:$value
);
let builders = [
AttrBuilderWithInferredContext<(ins "Type":$type,
"const APFloat &":$value), [{
return $_get(type.getContext(), type, value);
return $_get(type.getContext(), mlir::cast<CIRFPTypeInterface>(type), value);
}]>,
AttrBuilder<(ins "Type":$type,
"const APFloat &":$value), [{
return $_get($_ctxt, mlir::cast<CIRFPTypeInterface>(type), value);
}]>,
];
let extraClassDeclaration = [{
Expand Down Expand Up @@ -319,7 +326,7 @@ def ComplexAttr : CIR_Attr<"Complex", "complex", [TypedAttrInterface]> {
contains values of the same CIR type.
}];

let parameters = (ins
let parameters = (ins
AttributeSelfTypeParameter<"", "mlir::cir::ComplexType">:$type,
"mlir::TypedAttr":$real, "mlir::TypedAttr":$imag);

Expand Down Expand Up @@ -820,7 +827,7 @@ def AddressSpaceAttr : CIR_Attr<"AddressSpace", "addrspace"> {
let extraClassDeclaration = [{
static constexpr char kTargetKeyword[] = "}]#targetASCase.symbol#[{";
static constexpr int32_t kFirstTargetASValue = }]#targetASCase.value#[{;

bool isLang() const;
bool isTarget() const;
unsigned getTargetValue() const;
Expand Down Expand Up @@ -980,7 +987,7 @@ def ASTCallExprAttr : AST<"CallExpr", "call.expr",
// VisibilityAttr
//===----------------------------------------------------------------------===//

def VK_Default : I32EnumAttrCase<"Default", 1, "default">;
def VK_Default : I32EnumAttrCase<"Default", 1, "default">;
def VK_Hidden : I32EnumAttrCase<"Hidden", 2, "hidden">;
def VK_Protected : I32EnumAttrCase<"Protected", 3, "protected">;

Expand Down Expand Up @@ -1013,7 +1020,7 @@ def VisibilityAttr : CIR_Attr<"Visibility", "visibility"> {
bool isDefault() const { return getValue() == VisibilityKind::Default; };
bool isHidden() const { return getValue() == VisibilityKind::Hidden; };
bool isProtected() const { return getValue() == VisibilityKind::Protected; };
}];
}];
}


Expand Down Expand Up @@ -1160,7 +1167,7 @@ def AnnotationAttr : CIR_Attr<"Annotation", "annotation"> {
let parameters = (ins "StringAttr":$name,
"ArrayAttr":$args);

let assemblyFormat = "`<` struct($name, $args) `>`";
let assemblyFormat = "`<` struct($name, $args) `>`";

let extraClassDeclaration = [{
bool isNoArgs() const { return getArgs().empty(); };
Expand All @@ -1187,7 +1194,7 @@ def GlobalAnnotationValuesAttr : CIR_Attr<"GlobalAnnotationValues",
void *c __attribute__((annotate("noargvar")));
void foo(int i) __attribute__((annotate("noargfunc"))) {}
```
After CIR lowering prepare pass, compiler generates a
After CIR lowering prepare pass, compiler generates a
`GlobalAnnotationValuesAttr` like the following:
```
#cir<global_annotations [
Expand Down
52 changes: 17 additions & 35 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
mlir::Type ty);
static mlir::ParseResult
parseFloatLiteral(mlir::AsmParser &parser,
mlir::FailureOr<llvm::APFloat> &value, mlir::Type ty);
mlir::FailureOr<llvm::APFloat> &value,
mlir::cir::CIRFPTypeInterface fpType);

static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,
mlir::IntegerAttr &value);
Expand Down Expand Up @@ -311,50 +312,31 @@ LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
// FPAttr definitions
//===----------------------------------------------------------------------===//

static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
mlir::Type ty) {
static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) {
p << value;
}

static mlir::ParseResult
parseFloatLiteral(mlir::AsmParser &parser,
mlir::FailureOr<llvm::APFloat> &value, mlir::Type ty) {
double rawValue;
if (parser.parseFloat(rawValue)) {
return parser.emitError(parser.getCurrentLocation(),
"expected floating-point value");
}

auto losesInfo = false;
value.emplace(rawValue);
static ParseResult parseFloatLiteral(AsmParser &parser,
FailureOr<APFloat> &value,
CIRFPTypeInterface fpType) {

auto tyFpInterface = dyn_cast<cir::CIRFPTypeInterface>(ty);
if (!tyFpInterface) {
// Parsing of the current floating-point literal has succeeded, but the
// given attribute type is invalid. This error will be reported later when
// the attribute is being verified.
return success();
}
APFloat parsedValue(0.0);
if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue))
return failure();

value->convert(tyFpInterface.getFloatSemantics(),
llvm::RoundingMode::TowardZero, &losesInfo);
value.emplace(parsedValue);
return success();
}

cir::FPAttr cir::FPAttr::getZero(mlir::Type type) {
return get(
type, APFloat::getZero(
mlir::cast<cir::CIRFPTypeInterface>(type).getFloatSemantics()));
FPAttr FPAttr::getZero(Type type) {
return get(type,
APFloat::getZero(
mlir::cast<CIRFPTypeInterface>(type).getFloatSemantics()));
}

LogicalResult cir::FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, APFloat value) {
auto fltTypeInterface = mlir::dyn_cast<cir::CIRFPTypeInterface>(type);
if (!fltTypeInterface) {
emitError() << "expected floating-point type";
return failure();
}
if (APFloat::SemanticsToEnum(fltTypeInterface.getFloatSemantics()) !=
LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
CIRFPTypeInterface fpType, APFloat value) {
if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) !=
APFloat::SemanticsToEnum(value.getSemantics())) {
emitError() << "floating-point semantics mismatch";
return failure();
Expand Down
25 changes: 25 additions & 0 deletions clang/test/CIR/IR/attribute.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: cir-opt %s -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s

cir.func @float_attrs_pass() {
"test.float_attrs"() {
// CHECK: float_attr = #cir.fp<2.000000e+00> : !cir.float
float_attr = #cir.fp<2.> : !cir.float
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = #cir.fp<-2.000000e+00> : !cir.float
float_attr = #cir.fp<-2.> : !cir.float
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = #cir.fp<2.000000e+00> : !cir.double
float_attr = #cir.fp<2.> : !cir.double
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = #cir.fp<2.000000e+00> : !cir.long_double<!cir.f80>
float_attr = #cir.fp<2.> : !cir.long_double<!cir.f80>
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = #cir.fp<2.000000e+00> : !cir.long_double<!cir.double>
float_attr = #cir.fp<2.> : !cir.long_double<!cir.double>
} : () -> ()
cir.return
}
90 changes: 90 additions & 0 deletions clang/test/CIR/IR/float.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// RUN: cir-opt %s | FileCheck %s

// Adapted from mlir/test/IR/parser.mlir

// CHECK-LABEL: @f32_special_values
cir.func @f32_special_values() {
// F32 signaling NaNs.
// CHECK: cir.const #cir.fp<0x7F800001> : !cir.float
%0 = cir.const #cir.fp<0x7F800001> : !cir.float
// CHECK: cir.const #cir.fp<0x7FBFFFFF> : !cir.float
%1 = cir.const #cir.fp<0x7FBFFFFF> : !cir.float

// F32 quiet NaNs.
// CHECK: cir.const #cir.fp<0x7FC00000> : !cir.float
%2 = cir.const #cir.fp<0x7FC00000> : !cir.float
// CHECK: cir.const #cir.fp<0xFFFFFFFF> : !cir.float
%3 = cir.const #cir.fp<0xFFFFFFFF> : !cir.float

// F32 positive infinity.
// CHECK: cir.const #cir.fp<0x7F800000> : !cir.float
%4 = cir.const #cir.fp<0x7F800000> : !cir.float
// F32 negative infinity.
// CHECK: cir.const #cir.fp<0xFF800000> : !cir.float
%5 = cir.const #cir.fp<0xFF800000> : !cir.float

cir.return
}

// CHECK-LABEL: @f64_special_values
cir.func @f64_special_values() {
// F64 signaling NaNs.
// CHECK: cir.const #cir.fp<0x7FF0000000000001> : !cir.double
%0 = cir.const #cir.fp<0x7FF0000000000001> : !cir.double
// CHECK: cir.const #cir.fp<0x7FF8000000000000> : !cir.double
%1 = cir.const #cir.fp<0x7FF8000000000000> : !cir.double

// F64 quiet NaNs.
// CHECK: cir.const #cir.fp<0x7FF0000001000000> : !cir.double
%2 = cir.const #cir.fp<0x7FF0000001000000> : !cir.double
// CHECK: cir.const #cir.fp<0xFFF0000001000000> : !cir.double
%3 = cir.const #cir.fp<0xFFF0000001000000> : !cir.double

// F64 positive infinity.
// CHECK: cir.const #cir.fp<0x7FF0000000000000> : !cir.double
%4 = cir.const #cir.fp<0x7FF0000000000000> : !cir.double
// F64 negative infinity.
// CHECK: cir.const #cir.fp<0xFFF0000000000000> : !cir.double
%5 = cir.const #cir.fp<0xFFF0000000000000> : !cir.double

// Check that values that can't be represented with the default format, use
// hex instead.
// CHECK: cir.const #cir.fp<0xC1CDC00000000000> : !cir.double
%6 = cir.const #cir.fp<0xC1CDC00000000000> : !cir.double

cir.return
}

// CHECK-LABEL: @f80_special_values
cir.func @f80_special_values() {
// F80 signaling NaNs.
// CHECK: cir.const #cir.fp<0x7FFFE000000000000001> : !cir.long_double<!cir.f80>
%0 = cir.const #cir.fp<0x7FFFE000000000000001> : !cir.long_double<!cir.f80>
// CHECK: cir.const #cir.fp<0x7FFFB000000000000011> : !cir.long_double<!cir.f80>
%1 = cir.const #cir.fp<0x7FFFB000000000000011> : !cir.long_double<!cir.f80>

// F80 quiet NaNs.
// CHECK: cir.const #cir.fp<0x7FFFC000000000100000> : !cir.long_double<!cir.f80>
%2 = cir.const #cir.fp<0x7FFFC000000000100000> : !cir.long_double<!cir.f80>
// CHECK: cir.const #cir.fp<0x7FFFE000000001000000> : !cir.long_double<!cir.f80>
%3 = cir.const #cir.fp<0x7FFFE000000001000000> : !cir.long_double<!cir.f80>

// F80 positive infinity.
// CHECK: cir.const #cir.fp<0x7FFF8000000000000000> : !cir.long_double<!cir.f80>
%4 = cir.const #cir.fp<0x7FFF8000000000000000> : !cir.long_double<!cir.f80>
// F80 negative infinity.
// CHECK: cir.const #cir.fp<0xFFFF8000000000000000> : !cir.long_double<!cir.f80>
%5 = cir.const #cir.fp<0xFFFF8000000000000000> : !cir.long_double<!cir.f80>

cir.return
}

// We want to print floats in exponential notation with 6 significant digits,
// but it may lead to precision loss when parsing back, in which case we print
// the decimal form instead.
// CHECK-LABEL: @f32_potential_precision_loss()
cir.func @f32_potential_precision_loss() {
// CHECK: cir.const #cir.fp<1.23697901> : !cir.float
%0 = cir.const #cir.fp<1.23697901> : !cir.float
cir.return
}
59 changes: 59 additions & 0 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -1378,3 +1378,62 @@ module {
cir.return
}
}
// -----

// Type of the attribute must be a CIR floating point type

// expected-error @below {{invalid kind of type specified}}
cir.global external @f = #cir.fp<0.5> : !cir.int<s, 32>

// -----

// Value must be a floating point literal or integer literal

// expected-error @below {{expected floating point literal}}
cir.global external @f = #cir.fp<"blabla"> : !cir.float

// -----

// Integer value must be in the width of the floating point type

// expected-error @below {{hexadecimal float constant out of range for type}}
cir.global external @f = #cir.fp<0x7FC000000> : !cir.float

// -----

// Integer value must be in the width of the floating point type

// expected-error @below {{hexadecimal float constant out of range for type}}
cir.global external @f = #cir.fp<0x7FC000007FC0000000> : !cir.double

// -----

// Integer value must be in the width of the floating point type

// expected-error @below {{hexadecimal float constant out of range for type}}
cir.global external @f = #cir.fp<0x7FC0000007FC0000007FC000000> : !cir.long_double<!cir.f80>

// -----

// Long double with `double` semnatics should have a value that fits in a double.

// CHECK: cir.global external @f = #cir.fp<0x7FC000007FC000000000> : !cir.long_double<!cir.f80>
cir.global external @f = #cir.fp<0x7FC000007FC000000000> : !cir.long_double<!cir.f80>

// expected-error @below {{hexadecimal float constant out of range for type}}
cir.global external @f = #cir.fp<0x7FC000007FC000000000> : !cir.long_double<!cir.double>

// -----

// Verify no need for type inside the attribute

// expected-error @below {{expected '>'}}
cir.global external @f = #cir.fp<0x7FC00000 : !cir.float> : !cir.float

// -----

// Verify literal must be hex or float

// expected-error @below {{unexpected decimal integer literal for a floating point value}}
// expected-note @below {{add a trailing dot to make the literal a float}}
cir.global external @f = #cir.fp<42> : !cir.float
2 changes: 1 addition & 1 deletion clang/test/CIR/Lowering/class.cir
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ module {
// CHECK: %0 = llvm.mlir.undef : !llvm.struct<"class.S1", (i32, f32, ptr)>
// CHECK: %1 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %2 = llvm.insertvalue %1, %0[0] : !llvm.struct<"class.S1", (i32, f32, ptr)>
// CHECK: %3 = llvm.mlir.constant(0.099999994 : f32) : f32
// CHECK: %3 = llvm.mlir.constant(1.000000e-01 : f32) : f32
// CHECK: %4 = llvm.insertvalue %3, %2[1] : !llvm.struct<"class.S1", (i32, f32, ptr)>
// CHECK: %5 = llvm.mlir.zero : !llvm.ptr
// CHECK: %6 = llvm.insertvalue %5, %4[2] : !llvm.struct<"class.S1", (i32, f32, ptr)>
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/Lowering/struct.cir
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ module {
// CHECK: %0 = llvm.mlir.undef : !llvm.struct<"struct.S1", (i32, f32, ptr)>
// CHECK: %1 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %2 = llvm.insertvalue %1, %0[0] : !llvm.struct<"struct.S1", (i32, f32, ptr)>
// CHECK: %3 = llvm.mlir.constant(0.099999994 : f32) : f32
// CHECK: %3 = llvm.mlir.constant(1.000000e-01 : f32) : f32
// CHECK: %4 = llvm.insertvalue %3, %2[1] : !llvm.struct<"struct.S1", (i32, f32, ptr)>
// CHECK: %5 = llvm.mlir.zero : !llvm.ptr
// CHECK: %6 = llvm.insertvalue %5, %4[2] : !llvm.struct<"struct.S1", (i32, f32, ptr)>
Expand Down