Skip to content

Commit

Permalink
[CIR][CodeGen] Fix missing 'nsw' flag in add, sub, and mul in binop o…
Browse files Browse the repository at this point in the history
…perator (#677)

This PR is to fix the missing **nsw** flag in issue #664 regarding add,
mul arithmetic operations. there is also a problem with unary operations
such as **Inc ,Dec,Plus,Minus and Not** . which should also have 'nsw'
flag [example](https://godbolt.org/z/q3o3jsbe1). This part should need
to be fixed through lowering.
  • Loading branch information
mingshi2333 authored and lanza committed Jun 21, 2024
1 parent 9dbce79 commit c290c04
Show file tree
Hide file tree
Showing 17 changed files with 178 additions and 47 deletions.
39 changes: 37 additions & 2 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,21 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return createBinop(lhs, mlir::cir::BinOpKind::Or, rhs);
}

mlir::Value createMul(mlir::Value lhs, mlir::Value rhs) {
return createBinop(lhs, mlir::cir::BinOpKind::Mul, rhs);
mlir::Value createMul(mlir::Value lhs, mlir::Value rhs, bool hasNUW = false,
bool hasNSW = false) {
auto op = create<mlir::cir::BinOp>(lhs.getLoc(), lhs.getType(),
mlir::cir::BinOpKind::Mul, lhs, rhs);
if (hasNUW)
op.setNoUnsignedWrap(true);
if (hasNSW)
op.setNoSignedWrap(true);
return op;
}
mlir::Value createNSWMul(mlir::Value lhs, mlir::Value rhs) {
return createMul(lhs, rhs, false, true);
}
mlir::Value createNUWAMul(mlir::Value lhs, mlir::Value rhs) {
return createMul(lhs, rhs, true, false);
}

mlir::Value createMul(mlir::Value lhs, llvm::APInt rhs) {
Expand Down Expand Up @@ -235,6 +248,28 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return createSub(lhs, rhs, false, true);
}

mlir::Value createNUWSub(mlir::Value lhs, mlir::Value rhs) {
return createSub(lhs, rhs, true, false);
}

mlir::Value createAdd(mlir::Value lhs, mlir::Value rhs, bool hasNUW = false,
bool hasNSW = false) {
auto op = create<mlir::cir::BinOp>(lhs.getLoc(), lhs.getType(),
mlir::cir::BinOpKind::Add, lhs, rhs);
if (hasNUW)
op.setNoUnsignedWrap(true);
if (hasNSW)
op.setNoSignedWrap(true);
return op;
}

mlir::Value createNSWAdd(mlir::Value lhs, mlir::Value rhs) {
return createAdd(lhs, rhs, false, true);
}
mlir::Value createNUWAdd(mlir::Value lhs, mlir::Value rhs) {
return createAdd(lhs, rhs, true, false);
}

struct BinOpOverflowResults {
mlir::Value result;
mlir::Value overflow;
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class StructType
};

bool isAnyFloatingPointType(mlir::Type t);

bool isFPOrFPVectorTy(mlir::Type);
} // namespace cir
} // namespace mlir

Expand Down
19 changes: 19 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,25 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
lhs, rhs);
}

mlir::Value createFAdd(mlir::Value lhs, mlir::Value rhs) {
assert(!MissingFeatures::metaDataNode());
if (IsFPConstrained)
llvm_unreachable("Constrained FP NYI");

assert(!MissingFeatures::foldBinOpFMF());
return create<mlir::cir::BinOp>(lhs.getLoc(), mlir::cir::BinOpKind::Add,
lhs, rhs);
}
mlir::Value createFMul(mlir::Value lhs, mlir::Value rhs) {
assert(!MissingFeatures::metaDataNode());
if (IsFPConstrained)
llvm_unreachable("Constrained FP NYI");

assert(!MissingFeatures::foldBinOpFMF());
return create<mlir::cir::BinOp>(lhs.getLoc(), mlir::cir::BinOpKind::Mul,
lhs, rhs);
}

mlir::Value createDynCast(mlir::Location loc, mlir::Value src,
mlir::cir::PointerType destType, bool isRefCast,
mlir::cir::DynamicCastInfoAttr info) {
Expand Down
67 changes: 66 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,38 @@ static mlir::Value buildPointerArithmetic(CIRGenFunction &CGF,
}

mlir::Value ScalarExprEmitter::buildMul(const BinOpInfo &Ops) {
if (Ops.CompType->isSignedIntegerOrEnumerationType()) {
switch (CGF.getLangOpts().getSignedOverflowBehavior()) {
case LangOptions::SOB_Defined:
if (!CGF.SanOpts.has(SanitizerKind::SignedIntegerOverflow))
return Builder.createMul(Ops.LHS, Ops.RHS);
[[fallthrough]];
case LangOptions::SOB_Undefined:
if (!CGF.SanOpts.has(SanitizerKind::SignedIntegerOverflow))
return Builder.createNSWMul(Ops.LHS, Ops.RHS);
[[fallthrough]];
case LangOptions::SOB_Trapping:
if (CanElideOverflowCheck(CGF.getContext(), Ops))
return Builder.createNSWMul(Ops.LHS, Ops.RHS);
llvm_unreachable("NYI");
}
}
if (Ops.FullType->isConstantMatrixType()) {
llvm_unreachable("NYI");
}
if (Ops.CompType->isUnsignedIntegerType() &&
CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
!CanElideOverflowCheck(CGF.getContext(), Ops))
llvm_unreachable("NYI");

if (mlir::cir::isFPOrFPVectorTy(Ops.LHS.getType())) {
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
return Builder.createFMul(Ops.LHS, Ops.RHS);
}

if (Ops.isFixedPointOp())
llvm_unreachable("NYI");

return Builder.create<mlir::cir::BinOp>(
CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
mlir::cir::BinOpKind::Mul, Ops.LHS, Ops.RHS);
Expand All @@ -1308,6 +1340,39 @@ mlir::Value ScalarExprEmitter::buildAdd(const BinOpInfo &Ops) {
if (Ops.LHS.getType().isa<mlir::cir::PointerType>() ||
Ops.RHS.getType().isa<mlir::cir::PointerType>())
return buildPointerArithmetic(CGF, Ops, /*isSubtraction=*/false);
if (Ops.CompType->isSignedIntegerOrEnumerationType()) {
switch (CGF.getLangOpts().getSignedOverflowBehavior()) {
case LangOptions::SOB_Defined:
if (!CGF.SanOpts.has(SanitizerKind::SignedIntegerOverflow))
return Builder.createAdd(Ops.LHS, Ops.RHS);
[[fallthrough]];
case LangOptions::SOB_Undefined:
if (!CGF.SanOpts.has(SanitizerKind::SignedIntegerOverflow))
return Builder.createNSWAdd(Ops.LHS, Ops.RHS);
[[fallthrough]];
case LangOptions::SOB_Trapping:
if (CanElideOverflowCheck(CGF.getContext(), Ops))
return Builder.createNSWAdd(Ops.LHS, Ops.RHS);

llvm_unreachable("NYI");
}
}
if (Ops.FullType->isConstantMatrixType()) {
llvm_unreachable("NYI");
}

if (Ops.CompType->isUnsignedIntegerType() &&
CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
!CanElideOverflowCheck(CGF.getContext(), Ops))
llvm_unreachable("NYI");

if (mlir::cir::isFPOrFPVectorTy(Ops.LHS.getType())) {
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
return Builder.createFAdd(Ops.LHS, Ops.RHS);
}

if (Ops.isFixedPointOp())
llvm_unreachable("NYI");

return Builder.create<mlir::cir::BinOp>(
CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
Expand Down Expand Up @@ -1344,7 +1409,7 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) {
!CanElideOverflowCheck(CGF.getContext(), Ops))
llvm_unreachable("NYI");

if (Ops.CompType->isFloatingType()) {
if (mlir::cir::isFPOrFPVectorTy(Ops.LHS.getType())) {
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
return Builder.createFSub(Ops.LHS, Ops.RHS);
}
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,19 @@ bool mlir::cir::isAnyFloatingPointType(mlir::Type t) {
mlir::cir::LongDoubleType, mlir::cir::FP80Type>(t);
}

//===----------------------------------------------------------------------===//
// Floating-point and Float-point Vecotr type helpers
//===----------------------------------------------------------------------===//

bool mlir::cir::isFPOrFPVectorTy(mlir::Type t) {

if (isa<mlir::cir::VectorType>(t)) {
return isAnyFloatingPointType(
t.dyn_cast<mlir::cir::VectorType>().getEltType());
}
return isAnyFloatingPointType(t);
}

//===----------------------------------------------------------------------===//
// FuncType Definitions
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions clang/test/CIR/CodeGen/binop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ void b0(int a, int b) {
x = x | b;
}

// CHECK: = cir.binop(mul, %3, %4) : !s32i
// CHECK: = cir.binop(mul, %3, %4) nsw : !s32i
// CHECK: = cir.binop(div, %6, %7) : !s32i
// CHECK: = cir.binop(rem, %9, %10) : !s32i
// CHECK: = cir.binop(add, %12, %13) : !s32i
// CHECK: = cir.binop(add, %12, %13) nsw : !s32i
// CHECK: = cir.binop(sub, %15, %16) nsw : !s32i
// CHECK: = cir.shift( right, %18 : !s32i, %19 : !s32i) -> !s32i
// CHECK: = cir.shift(left, %21 : !s32i, %22 : !s32i) -> !s32i
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/bitint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ i10 test_arith(i10 lhs, i10 rhs) {
// CHECK: cir.func @_Z10test_arithDB10_S_(%arg0: !cir.int<s, 10> loc({{.+}}), %arg1: !cir.int<s, 10> loc({{.+}})) -> !cir.int<s, 10>
// CHECK: %[[#LHS:]] = cir.load %{{.+}} : !cir.ptr<!cir.int<s, 10>>, !cir.int<s, 10>
// CHECK-NEXT: %[[#RHS:]] = cir.load %{{.+}} : !cir.ptr<!cir.int<s, 10>>, !cir.int<s, 10>
// CHECK-NEXT: %{{.+}} = cir.binop(add, %[[#LHS]], %[[#RHS]]) : !cir.int<s, 10>
// CHECK-NEXT: %{{.+}} = cir.binop(add, %[[#LHS]], %[[#RHS]]) nsw : !cir.int<s, 10>
// CHECK: }

void Size1ExtIntParam(unsigned _BitInt(1) A) {
Expand Down
4 changes: 2 additions & 2 deletions clang/test/CIR/CodeGen/call.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void d(void) {
// CHECK: cir.store %arg1, %1 : !s32i, !cir.ptr<!s32i>
// CHECK: %3 = cir.load %0 : !cir.ptr<!s32i>, !s32i
// CHECK: %4 = cir.load %1 : !cir.ptr<!s32i>, !s32i
// CHECK: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK: %5 = cir.binop(add, %3, %4) nsw : !s32i
// CHECK: cir.store %5, %2 : !s32i, !cir.ptr<!s32i>
// CHECK: %6 = cir.load %2 : !cir.ptr<!s32i>, !s32i
// CHECK: cir.return %6
Expand Down Expand Up @@ -64,7 +64,7 @@ void d(void) {
// CXX-NEXT: cir.store %arg1, %1 : !s32i, !cir.ptr<!s32i>
// CXX-NEXT: %3 = cir.load %0 : !cir.ptr<!s32i>, !s32i
// CXX-NEXT: %4 = cir.load %1 : !cir.ptr<!s32i>, !s32i
// CXX-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CXX-NEXT: %5 = cir.binop(add, %3, %4) nsw : !s32i
// CXX-NEXT: cir.store %5, %2 : !s32i, !cir.ptr<!s32i>
// CXX-NEXT: %6 = cir.load %2 : !cir.ptr<!s32i>, !s32i
// CXX-NEXT: cir.return %6
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/comma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ int c0() {
// CHECK: %[[#A:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init]
// CHECK: %[[#B:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init]
// CHECK: %[[#LOADED_B:]] = cir.load %[[#B]] : !cir.ptr<!s32i>, !s32i
// CHECK: %[[#]] = cir.binop(add, %[[#LOADED_B]], %[[#]]) : !s32i
// CHECK: %[[#]] = cir.binop(add, %[[#LOADED_B]], %[[#]]) nsw : !s32i
// CHECK: %[[#LOADED_A:]] = cir.load %[[#A]] : !cir.ptr<!s32i>, !s32i
// CHECK: cir.store %[[#LOADED_A]], %[[#RET]] : !s32i, !cir.ptr<!s32i>

Expand Down
4 changes: 2 additions & 2 deletions clang/test/CIR/CodeGen/if-constexpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void if0() {
// CHECK-NEXT: cir.store %4, %2 : !s32i, !cir.ptr<!s32i> loc({{.*}})
// CHECK-NEXT: %5 = cir.const #cir.int<3> : !s32i loc({{.*}})
// CHECK-NEXT: %6 = cir.load %2 : !cir.ptr<!s32i>, !s32i loc({{.*}})
// CHECK-NEXT: %7 = cir.binop(mul, %5, %6) : !s32i loc({{.*}})
// CHECK-NEXT: %7 = cir.binop(mul, %5, %6) nsw : !s32i loc({{.*}})
// CHECK-NEXT: cir.store %7, %3 : !s32i, !cir.ptr<!s32i> loc({{.*}})
// CHECK-NEXT: } loc({{.*}})
// CHECK-NEXT: cir.scope {
Expand All @@ -84,7 +84,7 @@ void if0() {
// CHECK-NEXT: cir.store %4, %2 : !s32i, !cir.ptr<!s32i> loc({{.*}})
// CHECK-NEXT: %5 = cir.const #cir.int<10> : !s32i loc({{.*}})
// CHECK-NEXT: %6 = cir.load %2 : !cir.ptr<!s32i>, !s32i loc({{.*}})
// CHECK-NEXT: %7 = cir.binop(mul, %5, %6) : !s32i loc({{.*}})
// CHECK-NEXT: %7 = cir.binop(mul, %5, %6) nsw : !s32i loc({{.*}})
// CHECK-NEXT: cir.store %7, %3 : !s32i, !cir.ptr<!s32i> loc({{.*}})
// CHECK-NEXT: } loc({{.*}})
// CHECK-NEXT: cir.scope {
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/lambda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void l0() {
// CHECK: %3 = cir.load %2 : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
// CHECK: %4 = cir.load %3 : !cir.ptr<!s32i>, !s32i
// CHECK: %5 = cir.const #cir.int<1> : !s32i
// CHECK: %6 = cir.binop(add, %4, %5) : !s32i
// CHECK: %6 = cir.binop(add, %4, %5) nsw : !s32i
// CHECK: %7 = cir.get_member %1[0] {name = "i"} : !cir.ptr<!ty_22anon2E422> -> !cir.ptr<!cir.ptr<!s32i>>
// CHECK: %8 = cir.load %7 : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
// CHECK: cir.store %6, %8 : !s32i, !cir.ptr<!s32i>
Expand Down
18 changes: 9 additions & 9 deletions clang/test/CIR/CodeGen/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ void l1() {
// CHECK-NEXT: } body {
// CHECK-NEXT: %4 = cir.load %0 : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %5 = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %6 = cir.binop(add, %4, %5) : !s32i
// CHECK-NEXT: %6 = cir.binop(add, %4, %5) nsw : !s32i
// CHECK-NEXT: cir.store %6, %0 : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: } step {
// CHECK-NEXT: %4 = cir.load %2 : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %5 = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %6 = cir.binop(add, %4, %5) : !s32i
// CHECK-NEXT: %6 = cir.binop(add, %4, %5) nsw : !s32i
// CHECK-NEXT: cir.store %6, %2 : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }
Expand All @@ -59,7 +59,7 @@ void l2(bool cond) {
// CHECK-NEXT: } do {
// CHECK-NEXT: %3 = cir.load %1 : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) nsw : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }
Expand All @@ -71,7 +71,7 @@ void l2(bool cond) {
// CHECK-NEXT: } do {
// CHECK-NEXT: %3 = cir.load %1 : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) nsw : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }
Expand All @@ -84,7 +84,7 @@ void l2(bool cond) {
// CHECK-NEXT: } do {
// CHECK-NEXT: %3 = cir.load %1 : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) nsw : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }
Expand All @@ -108,7 +108,7 @@ void l3(bool cond) {
// CHECK-NEXT: cir.do {
// CHECK-NEXT: %3 = cir.load %1 : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) nsw : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: } while {
Expand All @@ -120,7 +120,7 @@ void l3(bool cond) {
// CHECK-NEXT: cir.do {
// CHECK-NEXT: %3 = cir.load %1 : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) nsw : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: } while {
Expand All @@ -132,7 +132,7 @@ void l3(bool cond) {
// CHECK-NEXT: cir.do {
// CHECK-NEXT: %3 = cir.load %1 : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) nsw : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: } while {
Expand All @@ -159,7 +159,7 @@ void l4() {
// CHECK-NEXT: } do {
// CHECK-NEXT: %4 = cir.load %0 : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %5 = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %6 = cir.binop(add, %4, %5) : !s32i
// CHECK-NEXT: %6 = cir.binop(add, %4, %5) nsw : !s32i
// CHECK-NEXT: cir.store %6, %0 : !s32i, !cir.ptr<!s32i>
// CHECK-NEXT: cir.scope {
// CHECK-NEXT: %10 = cir.load %0 : !cir.ptr<!s32i>, !s32i
Expand Down
Loading

0 comments on commit c290c04

Please sign in to comment.