From fcfa87106a71f1018c0f83dfc6c1dccfc433ff76 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Wed, 17 Dec 2025 16:47:01 -0800 Subject: [PATCH 1/2] Fix double-rounding bug in double -> (b)float16 casts --- src/CodeGen_X86.cpp | 3 +- src/EmulateFloat16Math.cpp | 98 ++++++++++++++++++++++++++++++---- src/Float16.cpp | 46 +++++++++++++--- src/Float16.h | 4 ++ src/IR.cpp | 1 + src/IR.h | 2 + src/StrictifyFloat.cpp | 12 +++++ test/correctness/float16_t.cpp | 76 ++++++++++++++++++++++++-- 8 files changed, 221 insertions(+), 21 deletions(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 0e63af410cce..3d2388fdf89c 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -524,7 +524,8 @@ void CodeGen_X86::visit(const Cast *op) { if (target.has_feature(Target::F16C) && dst.code() == Type::Float && src.code() == Type::Float && - (dst.bits() == 16 || src.bits() == 16)) { + (dst.bits() == 16 || src.bits() == 16) && + src.bits() <= 32) { // Don't use for narrowing casts from double - it results in a libm call // Node we use code() == Type::Float instead of is_float(), because we // don't want to catch bfloat casts. diff --git a/src/EmulateFloat16Math.cpp b/src/EmulateFloat16Math.cpp index 1fda58a838e9..b24ccffae096 100644 --- a/src/EmulateFloat16Math.cpp +++ b/src/EmulateFloat16Math.cpp @@ -9,27 +9,46 @@ namespace Halide { namespace Internal { Expr bfloat16_to_float32(Expr e) { + const int lanes = e.type().lanes(); if (e.type().is_bfloat()) { e = reinterpret(e.type().with_code(Type::UInt), e); } - e = cast(UInt(32, e.type().lanes()), e); + e = cast(UInt(32, lanes), e); e = e << 16; - e = reinterpret(Float(32, e.type().lanes()), e); + e = reinterpret(Float(32, lanes), e); e = strict_float(e); return e; } Expr float32_to_bfloat16(Expr e) { internal_assert(e.type().bits() == 32); + const int lanes = e.type().lanes(); e = strict_float(e); - e = reinterpret(UInt(32, e.type().lanes()), e); + e = reinterpret(UInt(32, lanes), e); // We want to round ties to even, so before truncating either // add 0x8000 (0.5) to odd numbers or 0x7fff (0.499999) to // even numbers. e += 0x7fff + ((e >> 16) & 1); e = (e >> 16); - e = cast(UInt(16, e.type().lanes()), e); - e = reinterpret(BFloat(16, e.type().lanes()), e); + e = cast(UInt(16, lanes), e); + e = reinterpret(BFloat(16, lanes), e); + return e; +} + +Expr float64_to_bfloat16(Expr e) { + internal_assert(e.type().bits() == 64); + const int lanes = e.type().lanes(); + e = strict_float(e); + + // First round to float and record any gain of loss of magnitude + Expr f = cast(Float(32, lanes), e); + Expr err = abs(e) - abs(f); + e = reinterpret(UInt(32, lanes), f); + // As above, but break ties using err, if non-zero + e += 0x7fff + (((err >= 0) & ((e >> 16) & 1)) | (err > 0)); + e = (e >> 16); + e = cast(UInt(16, lanes), e); + e = reinterpret(BFloat(16, lanes), e); return e; } @@ -96,10 +115,11 @@ Expr float32_to_float16(Expr value) { // 0.5 if the integer part is odd, or 0.4999999 if the // integer part is even, then truncate. 
bits += (bits >> 13) & 1; - bits += 0xfff; - bits = bits >> 13; + bits += make_const(UInt(32), ((uint32_t)1 << (13 - 1)) - 1); + bits = cast(u16_t, bits >> 13); + // Rebias the exponent - bits -= 0x1c000; + bits -= 0x4000; // Truncate the top bits of the exponent bits = bits & 0x7fff; bits = select(is_denorm, denorm_bits, @@ -111,6 +131,55 @@ Expr float32_to_float16(Expr value) { return common_subexpression_elimination(reinterpret(f16_t, bits)); } +Expr float64_to_float16(Expr value) { + value = strict_float(value); + + Type f64_t = Float(64, value.type().lanes()); + Type f16_t = Float(16, value.type().lanes()); + Type u64_t = UInt(64, value.type().lanes()); + Type u16_t = UInt(16, value.type().lanes()); + + Expr bits = reinterpret(u64_t, value); + + // Extract the sign bit + Expr sign = bits & make_const(u64_t, (uint64_t)(0x8000000000000000ULL)); + bits = bits ^ sign; + + // Test the endpoints + Expr is_denorm = (bits < make_const(u64_t, (uint64_t)(0x3f10000000000000ULL))); + Expr is_inf = (bits >= make_const(u64_t, (uint64_t)(0x40f0000000000000ULL))); + Expr is_nan = (bits > make_const(u64_t, (uint64_t)(0x7ff0000000000000ULL))); + + // Denorms are linearly spaced, so we can handle them by scaling up the + // input as a float or double by 2^24 and using the existing int-conversion + // rounding instructions. We can scale up by adding 24 to the exponent. + Expr denorm_bits = cast(u16_t, strict_float(round(strict_float(reinterpret(f64_t, bits + make_const(u64_t, (uint64_t)(0x0180000000000000ULL))))))); + Expr inf_bits = make_const(u16_t, 0x7c00); + Expr nan_bits = make_const(u16_t, 0x7fff); + + // We want to round to nearest even, so we add either 0.5 if after + // truncation the last bit would be 1, or 0.4999999 if after truncation the + // last bit would be zero, then truncate. 
+ bits += (bits >> 42) & 1; + bits += make_const(UInt(64), ((uint64_t)1 << (42 - 1)) - 1); + bits = bits >> 42; + + // We no longer need the high bits + bits = cast(u16_t, bits); + + // Rebias the exponent + bits -= 0x4000; + // Truncate the top bits of the exponent + bits = bits & 0x7fff; + bits = select(is_denorm, denorm_bits, + is_inf, inf_bits, + is_nan, nan_bits, + cast(u16_t, bits)); + // Recover the sign bit + bits = bits | cast(u16_t, sign >> 48); + return common_subexpression_elimination(reinterpret(f16_t, bits)); +} + namespace { const std::map transcendental_remapping = @@ -171,6 +240,7 @@ Expr lower_float16_cast(const Cast *op) { Type src = op->value.type(); Type dst = op->type; Type f32 = Float(32, dst.lanes()); + Type f64 = Float(64, dst.lanes()); Expr val = op->value; if (src.is_bfloat()) { @@ -183,10 +253,18 @@ Expr lower_float16_cast(const Cast *op) { if (dst.is_bfloat()) { internal_assert(dst.bits() == 16); - val = float32_to_bfloat16(cast(f32, val)); + if (src.bits() > 32) { + val = float64_to_bfloat16(cast(f64, val)); + } else { + val = float32_to_bfloat16(cast(f32, val)); + } } else if (dst.is_float() && dst.bits() < 32) { internal_assert(dst.bits() == 16); - val = float32_to_float16(cast(f32, val)); + if (src.bits() > 32) { + val = float64_to_float16(cast(f64, val)); + } else { + val = float32_to_float16(cast(f32, val)); + } } return cast(dst, val); diff --git a/src/Float16.cpp b/src/Float16.cpp index 80c96a38e6f1..6e7dbbe4b7c0 100644 --- a/src/Float16.cpp +++ b/src/Float16.cpp @@ -9,7 +9,10 @@ namespace Internal { // Conversion routines to and from float cribbed from Christian Rau's // half library (half.sourceforge.net) -uint16_t float_to_float16(float value) { +template +uint16_t float_to_float16(T value) { + static_assert(std::is_same_v || std::is_same_v, + "float_to_float16 only supports float and double types"); // Start by copying over the sign bit uint16_t bits = std::signbit(value) << 15; @@ -40,14 +43,14 @@ uint16_t float_to_float16(float value) { // We've normalized value as much as possible. Put the integer // portion of it into the mantissa. - float ival; - float frac = std::modf(value, &ival); + T ival; + T frac = std::modf(value, &ival); bits += (uint16_t)(std::abs((int)ival)); // Now consider the fractional part. We round to nearest with ties // going to even. frac = std::abs(frac); - bits += (frac > 0.5f) | ((frac == 0.5f) & bits); + bits += (frac > T(0.5)) | ((frac == T(0.5)) & bits); return bits; } @@ -341,6 +344,19 @@ uint16_t float_to_bfloat16(float f) { return ret >> 16; } +uint16_t float_to_bfloat16(double f) { + // Coming from double is a little tricker. We first narrow to float and + // record if any magnitude was lost of gained in the process. If so we'll + // use that to break ties instead of testing whether or not truncation would + // return odd. + float f32 = (float)f; + const double err = std::abs(f) - (double)std::abs(f32); + uint32_t ret; + memcpy(&ret, &f32, sizeof(float)); + ret += 0x7fff + (((err >= 0) & ((ret >> 16) & 1)) | (err > 0)); + return ret >> 16; +} + float bfloat16_to_float(uint16_t b) { // Assume little-endian floats uint16_t bits[2] = {0, b}; @@ -362,7 +378,17 @@ float16_t::float16_t(double value) } float16_t::float16_t(int value) - : data(float_to_float16(value)) { + : data(float_to_float16((float)value)) { + // integers of any size that map to finite float16s are all representable as + // float, so we can go via the float conversion method. 
+} + +float16_t::float16_t(int64_t value) + : data(float_to_float16((float)value)) { +} + +float16_t::float16_t(uint64_t value) + : data(float_to_float16((float)value)) { } float16_t::operator float() const { @@ -464,7 +490,15 @@ bfloat16_t::bfloat16_t(double value) } bfloat16_t::bfloat16_t(int value) - : data(float_to_bfloat16(value)) { + : data(float_to_bfloat16((double)value)) { +} + +bfloat16_t::bfloat16_t(int64_t value) + : data(float_to_bfloat16((double)value)) { +} + +bfloat16_t::bfloat16_t(uint64_t value) + : data(float_to_bfloat16((double)value)) { } bfloat16_t::operator float() const { diff --git a/src/Float16.h b/src/Float16.h index d3c285d6c09f..376813cbd507 100644 --- a/src/Float16.h +++ b/src/Float16.h @@ -32,6 +32,8 @@ struct float16_t { explicit float16_t(float value); explicit float16_t(double value); explicit float16_t(int value); + explicit float16_t(int64_t value); + explicit float16_t(uint64_t value); // @} /** Construct a float16_t with the bits initialised to 0. This represents @@ -175,6 +177,8 @@ struct bfloat16_t { explicit bfloat16_t(float value); explicit bfloat16_t(double value); explicit bfloat16_t(int value); + explicit bfloat16_t(int64_t value); + explicit bfloat16_t(uint64_t value); // @} /** Construct a bfloat16_t with the bits initialised to 0. This represents diff --git a/src/IR.cpp b/src/IR.cpp index c844c672656a..c82ae4ebd252 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -678,6 +678,7 @@ const char *const intrinsic_op_names[] = { "sliding_window_marker", "sorted_avg", "strict_add", + "strict_cast", "strict_div", "strict_eq", "strict_le", diff --git a/src/IR.h b/src/IR.h index 6dc0204b89ec..da27019a93c7 100644 --- a/src/IR.h +++ b/src/IR.h @@ -626,6 +626,7 @@ struct Call : public ExprNode { // them as reals and ignoring the existence of nan and inf. Using these // intrinsics instead prevents any such optimizations. 
strict_add, + strict_cast, strict_div, strict_eq, strict_le, @@ -792,6 +793,7 @@ struct Call : public ExprNode { bool is_strict_float_intrinsic() const { return is_intrinsic( {Call::strict_add, + Call::strict_cast, Call::strict_div, Call::strict_max, Call::strict_min, diff --git a/src/StrictifyFloat.cpp b/src/StrictifyFloat.cpp index 13dd0873bb12..8953ba035888 100644 --- a/src/StrictifyFloat.cpp +++ b/src/StrictifyFloat.cpp @@ -83,6 +83,16 @@ class Strictify : public IRMutator { return IRMutator::visit(op); } } + + Expr visit(const Cast *op) override { + if (op->value.type().is_float() && + op->type.is_float()) { + return Call::make(op->type, Call::strict_cast, + {mutate(op->value)}, Call::PureIntrinsic); + } else { + return IRMutator::visit(op); + } + } }; const std::set strict_externs = { @@ -142,6 +152,8 @@ Expr unstrictify_float(const Call *op) { return op->args[0] <= op->args[1]; } else if (op->is_intrinsic(Call::strict_eq)) { return op->args[0] == op->args[1]; + } else if (op->is_intrinsic(Call::strict_cast)) { + return cast(op->type, op->args[0]); } else { internal_error << "Missing lowering of strict float intrinsic: " << Expr(op) << "\n"; diff --git a/test/correctness/float16_t.cpp b/test/correctness/float16_t.cpp index 9e917a6216e6..d206e43202b0 100644 --- a/test/correctness/float16_t.cpp +++ b/test/correctness/float16_t.cpp @@ -301,6 +301,74 @@ int run_test() { } } + { + for (double f : {1.0, -1.0, 0.235, -0.235, 1e-7, -1e-7}) { + { + // Test double -> float16 doesn't have double-rounding issues + float16_t k{f}; + float16_t k_plus_eps = float16_t::make_from_bits(k.to_bits() + 1); + const bool k_is_odd = k.to_bits() & 1; + float16_t to_even = k_is_odd ? k_plus_eps : k; + float16_t to_odd = k_is_odd ? k : k_plus_eps; + float halfway = (float(k) + float(k_plus_eps)) / 2.f; + + printf("float16 k_is_odd = %d\n", k_is_odd); + + // We expect ties to round to even + assert(float16_t(halfway) == to_even); + // Now let's construct a case where it *should* have rounded to + // odd if rounding directly from double, but rounding via + // float does the wrong thing. + double halfway_plus_eps = std::nextafter(halfway, (double)to_odd); + assert(std::abs(halfway_plus_eps - (double)to_odd) < + std::abs(halfway_plus_eps - (double)to_even)); + + assert(float(halfway_plus_eps) == halfway); +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 + assert(_Float16(halfway_plus_eps) == _Float16(float(to_odd))); +#endif + assert(float16_t(halfway_plus_eps) == to_odd); + + // Now test the same thing in generated code. We need strict float to + // prevent Halide from fusing multiple float casts into one. + Param p; + p.set(halfway_plus_eps); + // halfway plus epsilon rounds to exactly halfway as a float + assert(evaluate(strict_float(cast(Float(32), p))) == halfway); + // So if we go via float we get the even outcome, because + // exactly halfway rounds to even + assert(evaluate(strict_float(cast(Float(16), cast(Float(32), p)))) == to_even); + // But if we go direct, we should go to odd, because it's closer + assert(evaluate(strict_float(cast(Float(16), p))) == to_odd); + } + + { + // Test the same things for bfloat + bfloat16_t k{f}; + bfloat16_t k_plus_eps = bfloat16_t::make_from_bits(k.to_bits() + 1); + const bool k_is_odd = k.to_bits() & 1; + + bfloat16_t to_even = k_is_odd ? k_plus_eps : k; + bfloat16_t to_odd = k_is_odd ? 
k : k_plus_eps; + float halfway = (float(k) + float(k_plus_eps)) / 2.f; + + assert(bfloat16_t(halfway) == to_even); + double halfway_plus_eps = std::nextafter(halfway, (double)to_odd); + assert(std::abs(halfway_plus_eps - (double)to_odd) < + std::abs(halfway_plus_eps - (double)to_even)); + + assert(float(halfway_plus_eps) == halfway); + assert(bfloat16_t(halfway_plus_eps) == to_odd); + + Param p; + p.set(halfway_plus_eps); + assert(evaluate(strict_float(cast(Float(32), p))) == halfway); + assert(evaluate(strict_float(cast(BFloat(16), cast(Float(32), p)))) == to_even); + assert(evaluate(strict_float(cast(BFloat(16), p))) == to_odd); + } + } + } + // Enable to read assembly generated by the conversion routines if ((false)) { // Intentional dead code. Extra parens to pacify clang-tidy. Func src, to_f16, from_f16; @@ -309,11 +377,11 @@ int run_test() { to_f16(x) = cast(src(x)); from_f16(x) = cast(to_f16(x)); - src.compute_root().vectorize(x, 8, TailStrategy::RoundUp); - to_f16.compute_root().vectorize(x, 8, TailStrategy::RoundUp); - from_f16.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + src.compute_root().vectorize(x, 16, TailStrategy::RoundUp); + to_f16.compute_root().vectorize(x, 16, TailStrategy::RoundUp); + from_f16.compute_root().vectorize(x, 16, TailStrategy::RoundUp); - from_f16.compile_to_assembly("/dev/stdout", {}, Target("host-no_asserts-no_bounds_query-no_runtime-disable_llvm_loop_unroll-disable_llvm_loop_vectorize")); + from_f16.compile_to_assembly("/dev/stdout", {}, Target("host-no_asserts-no_bounds_query-no_runtime")); } // Check infinity handling for both float16_t and Halide codegen. From 9b23ae688f012ac00fae314fcf7a55028ffdf294 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 18 Dec 2025 12:49:40 -0800 Subject: [PATCH 2/2] Share more code between coming from 64 and 32 bits Also add and fix some comments --- src/EmulateFloat16Math.cpp | 152 ++++++++++++++------------------- src/EmulateFloat16Math.h | 4 +- src/Float16.cpp | 2 +- test/correctness/float16_t.cpp | 2 - 4 files changed, 66 insertions(+), 94 deletions(-) diff --git a/src/EmulateFloat16Math.cpp b/src/EmulateFloat16Math.cpp index b24ccffae096..9ffca1bb3b54 100644 --- a/src/EmulateFloat16Math.cpp +++ b/src/EmulateFloat16Math.cpp @@ -20,32 +20,30 @@ Expr bfloat16_to_float32(Expr e) { return e; } -Expr float32_to_bfloat16(Expr e) { - internal_assert(e.type().bits() == 32); - const int lanes = e.type().lanes(); - e = strict_float(e); - e = reinterpret(UInt(32, lanes), e); - // We want to round ties to even, so before truncating either - // add 0x8000 (0.5) to odd numbers or 0x7fff (0.499999) to - // even numbers. 
- e += 0x7fff + ((e >> 16) & 1); - e = (e >> 16); - e = cast(UInt(16, lanes), e); - e = reinterpret(BFloat(16, lanes), e); - return e; -} - -Expr float64_to_bfloat16(Expr e) { - internal_assert(e.type().bits() == 64); +Expr float_to_bfloat16(Expr e) { const int lanes = e.type().lanes(); e = strict_float(e); + Expr err; // First round to float and record any gain of loss of magnitude - Expr f = cast(Float(32, lanes), e); - Expr err = abs(e) - abs(f); - e = reinterpret(UInt(32, lanes), f); - // As above, but break ties using err, if non-zero - e += 0x7fff + (((err >= 0) & ((e >> 16) & 1)) | (err > 0)); + if (e.type().bits() == 64) { + Expr f = cast(Float(32, lanes), e); + err = abs(e) - abs(f); + e = f; + } else { + internal_assert(e.type().bits() == 32); + } + e = reinterpret(UInt(32, lanes), e); + + // We want to round ties to even, so if we have no error recorded above, + // before truncating either add 0x8000 (0.5) to odd numbers or 0x7fff + // (0.499999) to even numbers. If we have error, break ties using that + // instead. + Expr tie_breaker = (e >> 16) & 1; // 1 when rounding down would go to odd + if (err.defined()) { + tie_breaker = ((err == 0) & tie_breaker) | (err > 0); + } + e += tie_breaker + 0x7fff; e = (e >> 16); e = cast(UInt(16, lanes), e); e = reinterpret(BFloat(16, lanes), e); @@ -82,90 +80,64 @@ Expr float16_to_float32(Expr value) { return f32; } -Expr float32_to_float16(Expr value) { +Expr float_to_float16(Expr value) { // We're about the sniff the bits of a float, so we should // guard it with strict float to ensure we don't do things // like assume it can't be denormal. value = strict_float(value); - Type f32_t = Float(32, value.type().lanes()); + const int src_bits = value.type().bits(); + + Type float_t = Float(src_bits, value.type().lanes()); Type f16_t = Float(16, value.type().lanes()); - Type u32_t = UInt(32, value.type().lanes()); + Type bits_t = UInt(src_bits, value.type().lanes()); Type u16_t = UInt(16, value.type().lanes()); - Expr bits = reinterpret(u32_t, value); + Expr bits = reinterpret(bits_t, value); // Extract the sign bit - Expr sign = bits & make_const(u32_t, 0x80000000); + Expr sign = bits & make_const(bits_t, (uint64_t)1 << (src_bits - 1)); bits = bits ^ sign; // Test the endpoints - Expr is_denorm = (bits < make_const(u32_t, 0x38800000)); - Expr is_inf = (bits >= make_const(u32_t, 0x47800000)); - Expr is_nan = (bits > make_const(u32_t, 0x7f800000)); + + // Smallest input representable as normal float16 (2^-14) + Expr two_to_the_minus_14 = src_bits == 32 ? + make_const(bits_t, 0x38800000) : + make_const(bits_t, (uint64_t)0x3f10000000000000ULL); + Expr is_denorm = bits < two_to_the_minus_14; + + // Smallest input too big to represent as a float16 (2^16) + Expr two_to_the_16 = src_bits == 32 ? + make_const(bits_t, 0x47800000) : + make_const(bits_t, (uint64_t)0x40f0000000000000ULL); + Expr is_inf = bits >= two_to_the_16; + + // Check if the input is a nan, which is anything bigger than an infinity bit pattern + Expr input_inf_bits = src_bits == 32 ? + make_const(bits_t, 0x7f800000) : + make_const(bits_t, (uint64_t)0x7ff0000000000000ULL); + Expr is_nan = bits > input_inf_bits; // Denorms are linearly spaced, so we can handle them // by scaling up the input as a float and using the // existing int-conversion rounding instructions. - Expr denorm_bits = cast(u16_t, strict_float(round(strict_float(reinterpret(f32_t, bits + 0x0c000000))))); + Expr two_to_the_24 = src_bits == 32 ? 
+ make_const(bits_t, 0x0c000000) : + make_const(bits_t, (uint64_t)0x0180000000000000ULL); + Expr denorm_bits = cast(u16_t, strict_float(round(reinterpret(float_t, bits + two_to_the_24)))); Expr inf_bits = make_const(u16_t, 0x7c00); Expr nan_bits = make_const(u16_t, 0x7fff); // We want to round to nearest even, so we add either // 0.5 if the integer part is odd, or 0.4999999 if the // integer part is even, then truncate. - bits += (bits >> 13) & 1; - bits += make_const(UInt(32), ((uint32_t)1 << (13 - 1)) - 1); - bits = cast(u16_t, bits >> 13); - - // Rebias the exponent - bits -= 0x4000; - // Truncate the top bits of the exponent - bits = bits & 0x7fff; - bits = select(is_denorm, denorm_bits, - is_inf, inf_bits, - is_nan, nan_bits, - cast(u16_t, bits)); - // Recover the sign bit - bits = bits | cast(u16_t, sign >> 16); - return common_subexpression_elimination(reinterpret(f16_t, bits)); -} - -Expr float64_to_float16(Expr value) { - value = strict_float(value); - - Type f64_t = Float(64, value.type().lanes()); - Type f16_t = Float(16, value.type().lanes()); - Type u64_t = UInt(64, value.type().lanes()); - Type u16_t = UInt(16, value.type().lanes()); - - Expr bits = reinterpret(u64_t, value); - - // Extract the sign bit - Expr sign = bits & make_const(u64_t, (uint64_t)(0x8000000000000000ULL)); - bits = bits ^ sign; - - // Test the endpoints - Expr is_denorm = (bits < make_const(u64_t, (uint64_t)(0x3f10000000000000ULL))); - Expr is_inf = (bits >= make_const(u64_t, (uint64_t)(0x40f0000000000000ULL))); - Expr is_nan = (bits > make_const(u64_t, (uint64_t)(0x7ff0000000000000ULL))); - - // Denorms are linearly spaced, so we can handle them by scaling up the - // input as a float or double by 2^24 and using the existing int-conversion - // rounding instructions. We can scale up by adding 24 to the exponent. - Expr denorm_bits = cast(u16_t, strict_float(round(strict_float(reinterpret(f64_t, bits + make_const(u64_t, (uint64_t)(0x0180000000000000ULL))))))); - Expr inf_bits = make_const(u16_t, 0x7c00); - Expr nan_bits = make_const(u16_t, 0x7fff); - - // We want to round to nearest even, so we add either 0.5 if after - // truncation the last bit would be 1, or 0.4999999 if after truncation the - // last bit would be zero, then truncate. - bits += (bits >> 42) & 1; - bits += make_const(UInt(64), ((uint64_t)1 << (42 - 1)) - 1); - bits = bits >> 42; - - // We no longer need the high bits - bits = cast(u16_t, bits); + const int float16_mantissa_bits = 10; + const int input_mantissa_bits = src_bits == 32 ? 
23 : 52; + const int bits_lost = input_mantissa_bits - float16_mantissa_bits; + bits += (bits >> bits_lost) & 1; + bits += make_const(bits_t, ((uint64_t)1 << (bits_lost - 1)) - 1); + bits = cast(u16_t, bits >> bits_lost); // Rebias the exponent bits -= 0x4000; @@ -176,7 +148,7 @@ Expr float64_to_float16(Expr value) { is_nan, nan_bits, cast(u16_t, bits)); // Recover the sign bit - bits = bits | cast(u16_t, sign >> 48); + bits = bits | cast(u16_t, sign >> (src_bits - 16)); return common_subexpression_elimination(reinterpret(f16_t, bits)); } @@ -226,7 +198,7 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *op) { Expr e = Call::make(t, it->second, new_args, op->call_type, op->func, op->value_index, op->image, op->param); if (op->type.is_float()) { - e = float32_to_float16(e); + e = float_to_float16(e); } internal_assert(e.type() == op->type); return e; @@ -254,17 +226,19 @@ Expr lower_float16_cast(const Cast *op) { if (dst.is_bfloat()) { internal_assert(dst.bits() == 16); if (src.bits() > 32) { - val = float64_to_bfloat16(cast(f64, val)); + val = cast(f64, val); } else { - val = float32_to_bfloat16(cast(f32, val)); + val = cast(f32, val); } + val = float_to_bfloat16(val); } else if (dst.is_float() && dst.bits() < 32) { internal_assert(dst.bits() == 16); if (src.bits() > 32) { - val = float64_to_float16(cast(f64, val)); + val = cast(f64, val); } else { - val = float32_to_float16(cast(f32, val)); + val = cast(f32, val); } + val = float_to_float16(val); } return cast(dst, val); diff --git a/src/EmulateFloat16Math.h b/src/EmulateFloat16Math.h index de1a5e091588..f61de7456fbc 100644 --- a/src/EmulateFloat16Math.h +++ b/src/EmulateFloat16Math.h @@ -19,8 +19,8 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *); /** Cast to/from float and bfloat using bitwise math. */ //@{ -Expr float32_to_bfloat16(Expr e); -Expr float32_to_float16(Expr e); +Expr float_to_bfloat16(Expr e); +Expr float_to_float16(Expr e); Expr float16_to_float32(Expr e); Expr bfloat16_to_float32(Expr e); Expr lower_float16_cast(const Cast *op); diff --git a/src/Float16.cpp b/src/Float16.cpp index 6e7dbbe4b7c0..1e9e789f476b 100644 --- a/src/Float16.cpp +++ b/src/Float16.cpp @@ -346,7 +346,7 @@ uint16_t float_to_bfloat16(float f) { uint16_t float_to_bfloat16(double f) { // Coming from double is a little tricker. We first narrow to float and - // record if any magnitude was lost of gained in the process. If so we'll + // record if any magnitude was lost or gained in the process. If so we'll // use that to break ties instead of testing whether or not truncation would // return odd. float f32 = (float)f; diff --git a/test/correctness/float16_t.cpp b/test/correctness/float16_t.cpp index d206e43202b0..120682cd86d8 100644 --- a/test/correctness/float16_t.cpp +++ b/test/correctness/float16_t.cpp @@ -312,8 +312,6 @@ int run_test() { float16_t to_odd = k_is_odd ? k : k_plus_eps; float halfway = (float(k) + float(k_plus_eps)) / 2.f; - printf("float16 k_is_odd = %d\n", k_is_odd); - // We expect ties to round to even assert(float16_t(halfway) == to_even); // Now let's construct a case where it *should* have rounded to
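
For reference, a minimal standalone C++ sketch (not part of the patch itself) of the double-rounding problem being fixed and of the error-based tie-break used by the new float_to_bfloat16(double) overload in src/Float16.cpp. The helper names naive_double_to_bfloat16 and fixed_double_to_bfloat16 are illustrative only; "naive" approximates the pre-patch lowering, which narrowed to float before rounding to bfloat16.

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Pre-patch style: narrow to float first, then round ties-to-even using only
// the float's bits. Any tie-break information from the original double has
// already been destroyed by the narrowing.
static uint16_t naive_double_to_bfloat16(double d) {
    float f = (float)d;
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    bits += 0x7fff + ((bits >> 16) & 1);
    return (uint16_t)(bits >> 16);
}

// Post-patch style: record whether narrowing to float lost or gained
// magnitude, and use that instead of the low bit to break exact ties.
static uint16_t fixed_double_to_bfloat16(double d) {
    float f = (float)d;
    double err = std::abs(d) - (double)std::abs(f);
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    bits += 0x7fff + (((err >= 0) & ((bits >> 16) & 1)) | (err > 0));
    return (uint16_t)(bits >> 16);
}

int main() {
    // bfloat16 1.0 is 0x3f80 (even); the next value up is 0x3f81 (odd).
    // Their midpoint, 1.00390625, is exactly representable as a float.
    double halfway = 1.00390625;

    // Nudge just above the midpoint. The nearest bfloat16 is now 0x3f81, but
    // narrowing to float first lands exactly on the midpoint, and ties-to-even
    // on the float bits alone then picks 0x3f80.
    double just_above = std::nextafter(halfway, 2.0);

    printf("naive: 0x%04x\n", naive_double_to_bfloat16(just_above)); // 0x3f80 (wrong)
    printf("fixed: 0x%04x\n", fixed_double_to_bfloat16(just_above)); // 0x3f81 (right)
    assert(fixed_double_to_bfloat16(just_above) == 0x3f81);
    return 0;
}
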
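Both float16 paths in EmulateFloat16Math.cpp reduce to the same integer trick for the mantissa: add the bit that will survive as the new lsb, add just under half of the field being dropped, then shift. With a 32-bit input that means adding (bits >> 13) & 1 and 0xfff before shifting right by 13; with a 64-bit input, 42 bits are dropped. A tiny self-contained sketch of that trick on raw integers (drop_bits_rne is a hypothetical name, not a Halide function):

#include <cassert>
#include <cstdint>

// Drop the low n bits of x, rounding to nearest with ties to even.
// This is the integer trick used in the patch with n = 13 (float -> half)
// and n = 42 (double -> half).
static uint64_t drop_bits_rne(uint64_t x, int n) {
    x += (x >> n) & 1;                  // +1 if the surviving lsb would be odd
    x += ((uint64_t)1 << (n - 1)) - 1;  // just under half of the dropped field
    return x >> n;
}

int main() {
    // Ties: the dropped bits are exactly 1000...0.
    assert(drop_bits_rne(0x20ull << 13 | 0x1000, 13) == 0x20); // even survivor stays
    assert(drop_bits_rne(0x21ull << 13 | 0x1000, 13) == 0x22); // odd survivor rounds up
    // Non-ties round to nearest as usual.
    assert(drop_bits_rne(0x21ull << 13 | 0x0fff, 13) == 0x21);
    assert(drop_bits_rne(0x21ull << 13 | 0x1001, 13) == 0x22);
    return 0;
}
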
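The denormal branch relies on float16 denormals encoding m * 2^-24 with m in [0, 1024), so the output bit pattern is just the nearest integer to |x| * 2^24, and the scaling by 2^24 can be done by adding 24 to the exponent field (the 0x0c000000 and 0x0180000000000000 constants above are 24 shifted into the float and double exponent positions). A small sketch of that identity, under simplifying assumptions: the input is a positive normal float strictly below 2^-14, and std::nearbyint with the default FP environment stands in for the rounding step. half_denorm_bits is a hypothetical helper, not part of the patch, which also handles zero, float denormals, and the sign bit.

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Bit pattern of the float16 denormal nearest to x, for a normal float x
// with 0 < x < 2^-14. Illustration only.
static uint16_t half_denorm_bits(float x) {
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits));
    bits += 24u << 23;  // scale by 2^24 by bumping the exponent field
    float scaled;
    memcpy(&scaled, &bits, sizeof(scaled));
    // Round to nearest (ties to even under the default rounding mode).
    return (uint16_t)std::nearbyint(scaled);
}

int main() {
    assert(half_denorm_bits(std::ldexp(1.0f, -24)) == 0x0001); // smallest half denormal
    assert(half_denorm_bits(std::ldexp(1.0f, -15)) == 0x0200); // 2^-15 = 512 * 2^-24
    assert(half_denorm_bits(std::ldexp(3.0f, -16)) == 0x0300); // 3 * 2^-16 = 768 * 2^-24
    return 0;
}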