Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {

LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
if (op.getType().getIntOrFloatBitWidth() > 64)
return rewriter.notifyMatchFailure(op,
"bitwidth > 64 bits is not supported");

// Get APFloat function from runtime library.
FailureOr<FuncOp> fn =
lookupOrCreateBinaryFn(rewriter, symTable, APFloatName);
Expand Down Expand Up @@ -148,6 +152,11 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {

LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
if (op.getType().getIntOrFloatBitWidth() > 64 ||
op.getOperand().getType().getIntOrFloatBitWidth() > 64)
return rewriter.notifyMatchFailure(op,
"bitwidth > 64 bits is not supported");

// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
Expand Down Expand Up @@ -195,9 +204,10 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {

LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
if (op.getType().getIntOrFloatBitWidth() > 64)
return rewriter.notifyMatchFailure(
op, "result type > 64 bits is not supported");
if (op.getType().getIntOrFloatBitWidth() > 64 ||
op.getOperand().getType().getIntOrFloatBitWidth() > 64)
return rewriter.notifyMatchFailure(op,
"bitwidth > 64 bits is not supported");

// Get APFloat function from runtime library.
auto i1Type = IntegerType::get(symTable->getContext(), 1);
Expand Down Expand Up @@ -252,11 +262,10 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {

LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
if (op.getIn().getType().getIntOrFloatBitWidth() > 64) {
return rewriter.notifyMatchFailure(
loc, "integer bitwidth > 64 is not supported");
}
if (op.getType().getIntOrFloatBitWidth() > 64 ||
op.getOperand().getType().getIntOrFloatBitWidth() > 64)
return rewriter.notifyMatchFailure(op,
"bitwidth > 64 bits is not supported");

// Get APFloat function from runtime library.
auto i1Type = IntegerType::get(symTable->getContext(), 1);
Expand All @@ -270,6 +279,7 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {

rewriter.setInsertionPoint(op);
// Cast operands to 64-bit integers.
Location loc = op.getLoc();
auto inIntTy = cast<IntegerType>(op.getOperand().getType());
Value operandBits = op.getOperand();
if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
Expand Down Expand Up @@ -317,6 +327,10 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {

LogicalResult matchAndRewrite(arith::CmpFOp op,
PatternRewriter &rewriter) const override {
if (op.getLhs().getType().getIntOrFloatBitWidth() > 64)
return rewriter.notifyMatchFailure(op,
"bitwidth > 64 bits is not supported");

// Get APFloat function from runtime library.
auto i1Type = IntegerType::get(symTable->getContext(), 1);
auto i8Type = IntegerType::get(symTable->getContext(), 8);
Expand Down Expand Up @@ -456,6 +470,10 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {

LogicalResult matchAndRewrite(arith::NegFOp op,
PatternRewriter &rewriter) const override {
if (op.getOperand().getType().getIntOrFloatBitWidth() > 64)
return rewriter.notifyMatchFailure(op,
"bitwidth > 64 bits is not supported");

// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
Expand Down
25 changes: 25 additions & 0 deletions mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,28 @@ func.func @maxnumf(%arg0: f32, %arg1: f32) {
%0 = arith.maxnumf %arg0, %arg1 : f32
return
}

// -----

// CHECK-LABEL: func.func @unsupported_bitwidth
// CHECK: arith.addf {{.*}} : f128
// CHECK: arith.negf {{.*}} : f128
// CHECK: arith.cmpf {{.*}} : f128
// CHECK: arith.extf {{.*}} : f32 to f128
// CHECK: arith.truncf {{.*}} : f128 to f32
// CHECK: arith.fptosi {{.*}} : f128 to i32
// CHECK: arith.fptosi {{.*}} : f32 to i92
// CHECK: arith.sitofp {{.*}} : i1 to f128
// CHECK: arith.sitofp {{.*}} : i92 to f32
func.func @unsupported_bitwidth(%arg0: f128, %arg1: f128, %arg2: f32) {
%0 = arith.addf %arg0, %arg1 : f128
%1 = arith.negf %arg0 : f128
%2 = arith.cmpf "ult", %arg0, %arg1 : f128
%3 = arith.extf %arg2 : f32 to f128
%4 = arith.truncf %arg0 : f128 to f32
%5 = arith.fptosi %arg0 : f128 to i32
%6 = arith.fptosi %arg2 : f32 to i92
%7 = arith.sitofp %2 : i1 to f128
%8 = arith.sitofp %6 : i92 to f32
return
}
Loading