diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 0e63af410cce..3d2388fdf89c 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -524,7 +524,8 @@ void CodeGen_X86::visit(const Cast *op) { if (target.has_feature(Target::F16C) && dst.code() == Type::Float && src.code() == Type::Float && - (dst.bits() == 16 || src.bits() == 16)) { + (dst.bits() == 16 || src.bits() == 16) && + src.bits() <= 32) { // Don't use for narrowing casts from double - it results in a libm call // Node we use code() == Type::Float instead of is_float(), because we // don't want to catch bfloat casts. diff --git a/src/EmulateFloat16Math.cpp b/src/EmulateFloat16Math.cpp index 1fda58a838e9..9ffca1bb3b54 100644 --- a/src/EmulateFloat16Math.cpp +++ b/src/EmulateFloat16Math.cpp @@ -9,27 +9,44 @@ namespace Halide { namespace Internal { Expr bfloat16_to_float32(Expr e) { + const int lanes = e.type().lanes(); if (e.type().is_bfloat()) { e = reinterpret(e.type().with_code(Type::UInt), e); } - e = cast(UInt(32, e.type().lanes()), e); + e = cast(UInt(32, lanes), e); e = e << 16; - e = reinterpret(Float(32, e.type().lanes()), e); + e = reinterpret(Float(32, lanes), e); e = strict_float(e); return e; } -Expr float32_to_bfloat16(Expr e) { - internal_assert(e.type().bits() == 32); +Expr float_to_bfloat16(Expr e) { + const int lanes = e.type().lanes(); e = strict_float(e); - e = reinterpret(UInt(32, e.type().lanes()), e); - // We want to round ties to even, so before truncating either - // add 0x8000 (0.5) to odd numbers or 0x7fff (0.499999) to - // even numbers. 
- e += 0x7fff + ((e >> 16) & 1); + + Expr err; + // First round to float and record any gain or loss of magnitude + if (e.type().bits() == 64) { + Expr f = cast(Float(32, lanes), e); + err = abs(e) - abs(f); + e = f; + } else { + internal_assert(e.type().bits() == 32); + } + e = reinterpret(UInt(32, lanes), e); + + // We want to round ties to even, so if we have no error recorded above, + // before truncating either add 0x8000 (0.5) to odd numbers or 0x7fff + // (0.499999) to even numbers. If we have an error, break ties using that + // instead. + Expr tie_breaker = (e >> 16) & 1; // 1 when rounding down would go to odd + if (err.defined()) { + tie_breaker = ((err == 0) & tie_breaker) | (err > 0); + } + e += tie_breaker + 0x7fff; e = (e >> 16); - e = cast(UInt(16, e.type().lanes()), e); - e = reinterpret(BFloat(16, e.type().lanes()), e); + e = cast(UInt(16, lanes), e); + e = reinterpret(BFloat(16, lanes), e); return e; } @@ -63,43 +80,67 @@ Expr float16_to_float32(Expr value) { return f32; } -Expr float32_to_float16(Expr value) { +Expr float_to_float16(Expr value) { // We're about the sniff the bits of a float, so we should // guard it with strict float to ensure we don't do things // like assume it can't be denormal. 
value = strict_float(value); - Type f32_t = Float(32, value.type().lanes()); + const int src_bits = value.type().bits(); + + Type float_t = Float(src_bits, value.type().lanes()); Type f16_t = Float(16, value.type().lanes()); - Type u32_t = UInt(32, value.type().lanes()); + Type bits_t = UInt(src_bits, value.type().lanes()); Type u16_t = UInt(16, value.type().lanes()); - Expr bits = reinterpret(u32_t, value); + Expr bits = reinterpret(bits_t, value); // Extract the sign bit - Expr sign = bits & make_const(u32_t, 0x80000000); + Expr sign = bits & make_const(bits_t, (uint64_t)1 << (src_bits - 1)); bits = bits ^ sign; // Test the endpoints - Expr is_denorm = (bits < make_const(u32_t, 0x38800000)); - Expr is_inf = (bits >= make_const(u32_t, 0x47800000)); - Expr is_nan = (bits > make_const(u32_t, 0x7f800000)); + + // Smallest input representable as normal float16 (2^-14) + Expr two_to_the_minus_14 = src_bits == 32 ? + make_const(bits_t, 0x38800000) : + make_const(bits_t, (uint64_t)0x3f10000000000000ULL); + Expr is_denorm = bits < two_to_the_minus_14; + + // Smallest input too big to represent as a float16 (2^16) + Expr two_to_the_16 = src_bits == 32 ? + make_const(bits_t, 0x47800000) : + make_const(bits_t, (uint64_t)0x40f0000000000000ULL); + Expr is_inf = bits >= two_to_the_16; + + // Check if the input is a nan, which is anything bigger than an infinity bit pattern + Expr input_inf_bits = src_bits == 32 ? + make_const(bits_t, 0x7f800000) : + make_const(bits_t, (uint64_t)0x7ff0000000000000ULL); + Expr is_nan = bits > input_inf_bits; // Denorms are linearly spaced, so we can handle them // by scaling up the input as a float and using the // existing int-conversion rounding instructions. - Expr denorm_bits = cast(u16_t, strict_float(round(strict_float(reinterpret(f32_t, bits + 0x0c000000))))); + Expr two_to_the_24 = src_bits == 32 ? 
+ make_const(bits_t, 0x0c000000) : + make_const(bits_t, (uint64_t)0x0180000000000000ULL); + Expr denorm_bits = cast(u16_t, strict_float(round(reinterpret(float_t, bits + two_to_the_24)))); Expr inf_bits = make_const(u16_t, 0x7c00); Expr nan_bits = make_const(u16_t, 0x7fff); // We want to round to nearest even, so we add either // 0.5 if the integer part is odd, or 0.4999999 if the // integer part is even, then truncate. - bits += (bits >> 13) & 1; - bits += 0xfff; - bits = bits >> 13; + const int float16_mantissa_bits = 10; + const int input_mantissa_bits = src_bits == 32 ? 23 : 52; + const int bits_lost = input_mantissa_bits - float16_mantissa_bits; + bits += (bits >> bits_lost) & 1; + bits += make_const(bits_t, ((uint64_t)1 << (bits_lost - 1)) - 1); + bits = cast(u16_t, bits >> bits_lost); + // Rebias the exponent - bits -= 0x1c000; + bits -= 0x4000; // Truncate the top bits of the exponent bits = bits & 0x7fff; bits = select(is_denorm, denorm_bits, @@ -107,7 +148,7 @@ Expr float32_to_float16(Expr value) { is_nan, nan_bits, cast(u16_t, bits)); // Recover the sign bit - bits = bits | cast(u16_t, sign >> 16); + bits = bits | cast(u16_t, sign >> (src_bits - 16)); return common_subexpression_elimination(reinterpret(f16_t, bits)); } @@ -157,7 +198,7 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *op) { Expr e = Call::make(t, it->second, new_args, op->call_type, op->func, op->value_index, op->image, op->param); if (op->type.is_float()) { - e = float32_to_float16(e); + e = float_to_float16(e); } internal_assert(e.type() == op->type); return e; @@ -171,6 +212,7 @@ Expr lower_float16_cast(const Cast *op) { Type src = op->value.type(); Type dst = op->type; Type f32 = Float(32, dst.lanes()); + Type f64 = Float(64, dst.lanes()); Expr val = op->value; if (src.is_bfloat()) { @@ -183,10 +225,20 @@ Expr lower_float16_cast(const Cast *op) { if (dst.is_bfloat()) { internal_assert(dst.bits() == 16); - val = float32_to_bfloat16(cast(f32, val)); + if 
(src.bits() > 32) { + val = cast(f64, val); + } else { + val = cast(f32, val); + } + val = float_to_bfloat16(val); } else if (dst.is_float() && dst.bits() < 32) { internal_assert(dst.bits() == 16); - val = float32_to_float16(cast(f32, val)); + if (src.bits() > 32) { + val = cast(f64, val); + } else { + val = cast(f32, val); + } + val = float_to_float16(val); } return cast(dst, val); diff --git a/src/EmulateFloat16Math.h b/src/EmulateFloat16Math.h index de1a5e091588..f61de7456fbc 100644 --- a/src/EmulateFloat16Math.h +++ b/src/EmulateFloat16Math.h @@ -19,8 +19,8 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *); /** Cast to/from float and bfloat using bitwise math. */ //@{ -Expr float32_to_bfloat16(Expr e); -Expr float32_to_float16(Expr e); +Expr float_to_bfloat16(Expr e); +Expr float_to_float16(Expr e); Expr float16_to_float32(Expr e); Expr bfloat16_to_float32(Expr e); Expr lower_float16_cast(const Cast *op); diff --git a/src/Float16.cpp b/src/Float16.cpp index 80c96a38e6f1..1e9e789f476b 100644 --- a/src/Float16.cpp +++ b/src/Float16.cpp @@ -9,7 +9,10 @@ namespace Internal { // Conversion routines to and from float cribbed from Christian Rau's // half library (half.sourceforge.net) -uint16_t float_to_float16(float value) { +template +uint16_t float_to_float16(T value) { + static_assert(std::is_same_v || std::is_same_v, + "float_to_float16 only supports float and double types"); // Start by copying over the sign bit uint16_t bits = std::signbit(value) << 15; @@ -40,14 +43,14 @@ uint16_t float_to_float16(float value) { // We've normalized value as much as possible. Put the integer // portion of it into the mantissa. - float ival; - float frac = std::modf(value, &ival); + T ival; + T frac = std::modf(value, &ival); bits += (uint16_t)(std::abs((int)ival)); // Now consider the fractional part. We round to nearest with ties // going to even. 
frac = std::abs(frac); - bits += (frac > 0.5f) | ((frac == 0.5f) & bits); + bits += (frac > T(0.5)) | ((frac == T(0.5)) & bits); return bits; } @@ -341,6 +344,19 @@ uint16_t float_to_bfloat16(float f) { return ret >> 16; } +uint16_t float_to_bfloat16(double f) { + // Coming from double is a little trickier. We first narrow to float and + // record if any magnitude was lost or gained in the process. If so we'll + // use that to break ties instead of testing whether or not truncation would + // return odd. + float f32 = (float)f; + const double err = std::abs(f) - (double)std::abs(f32); + uint32_t ret; + memcpy(&ret, &f32, sizeof(float)); + ret += 0x7fff + (((err >= 0) & ((ret >> 16) & 1)) | (err > 0)); + return ret >> 16; +} + float bfloat16_to_float(uint16_t b) { // Assume little-endian floats uint16_t bits[2] = {0, b}; @@ -362,7 +378,17 @@ float16_t::float16_t(double value) } float16_t::float16_t(int value) - : data(float_to_float16(value)) { + : data(float_to_float16((float)value)) { + // integers of any size that map to finite float16s are all representable as + // float, so we can go via the float conversion method. 
+} + +float16_t::float16_t(int64_t value) + : data(float_to_float16((float)value)) { +} + +float16_t::float16_t(uint64_t value) + : data(float_to_float16((float)value)) { } float16_t::operator float() const { @@ -464,7 +490,15 @@ bfloat16_t::bfloat16_t(double value) } bfloat16_t::bfloat16_t(int value) - : data(float_to_bfloat16(value)) { + : data(float_to_bfloat16((double)value)) { +} + +bfloat16_t::bfloat16_t(int64_t value) + : data(float_to_bfloat16((double)value)) { +} + +bfloat16_t::bfloat16_t(uint64_t value) + : data(float_to_bfloat16((double)value)) { } bfloat16_t::operator float() const { diff --git a/src/Float16.h b/src/Float16.h index d3c285d6c09f..376813cbd507 100644 --- a/src/Float16.h +++ b/src/Float16.h @@ -32,6 +32,8 @@ struct float16_t { explicit float16_t(float value); explicit float16_t(double value); explicit float16_t(int value); + explicit float16_t(int64_t value); + explicit float16_t(uint64_t value); // @} /** Construct a float16_t with the bits initialised to 0. This represents @@ -175,6 +177,8 @@ struct bfloat16_t { explicit bfloat16_t(float value); explicit bfloat16_t(double value); explicit bfloat16_t(int value); + explicit bfloat16_t(int64_t value); + explicit bfloat16_t(uint64_t value); // @} /** Construct a bfloat16_t with the bits initialised to 0. This represents diff --git a/src/IR.cpp b/src/IR.cpp index c844c672656a..c82ae4ebd252 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -678,6 +678,7 @@ const char *const intrinsic_op_names[] = { "sliding_window_marker", "sorted_avg", "strict_add", + "strict_cast", "strict_div", "strict_eq", "strict_le", diff --git a/src/IR.h b/src/IR.h index 6dc0204b89ec..da27019a93c7 100644 --- a/src/IR.h +++ b/src/IR.h @@ -626,6 +626,7 @@ struct Call : public ExprNode { // them as reals and ignoring the existence of nan and inf. Using these // intrinsics instead prevents any such optimizations. 
strict_add, + strict_cast, strict_div, strict_eq, strict_le, @@ -792,6 +793,7 @@ struct Call : public ExprNode { bool is_strict_float_intrinsic() const { return is_intrinsic( {Call::strict_add, + Call::strict_cast, Call::strict_div, Call::strict_max, Call::strict_min, diff --git a/src/StrictifyFloat.cpp b/src/StrictifyFloat.cpp index 13dd0873bb12..8953ba035888 100644 --- a/src/StrictifyFloat.cpp +++ b/src/StrictifyFloat.cpp @@ -83,6 +83,16 @@ class Strictify : public IRMutator { return IRMutator::visit(op); } } + + Expr visit(const Cast *op) override { + if (op->value.type().is_float() && + op->type.is_float()) { + return Call::make(op->type, Call::strict_cast, + {mutate(op->value)}, Call::PureIntrinsic); + } else { + return IRMutator::visit(op); + } + } }; const std::set strict_externs = { @@ -142,6 +152,8 @@ Expr unstrictify_float(const Call *op) { return op->args[0] <= op->args[1]; } else if (op->is_intrinsic(Call::strict_eq)) { return op->args[0] == op->args[1]; + } else if (op->is_intrinsic(Call::strict_cast)) { + return cast(op->type, op->args[0]); } else { internal_error << "Missing lowering of strict float intrinsic: " << Expr(op) << "\n"; diff --git a/test/correctness/float16_t.cpp b/test/correctness/float16_t.cpp index 9e917a6216e6..120682cd86d8 100644 --- a/test/correctness/float16_t.cpp +++ b/test/correctness/float16_t.cpp @@ -301,6 +301,72 @@ int run_test() { } } + { + for (double f : {1.0, -1.0, 0.235, -0.235, 1e-7, -1e-7}) { + { + // Test double -> float16 doesn't have double-rounding issues + float16_t k{f}; + float16_t k_plus_eps = float16_t::make_from_bits(k.to_bits() + 1); + const bool k_is_odd = k.to_bits() & 1; + float16_t to_even = k_is_odd ? k_plus_eps : k; + float16_t to_odd = k_is_odd ? 
k : k_plus_eps; + float halfway = (float(k) + float(k_plus_eps)) / 2.f; + + // We expect ties to round to even + assert(float16_t(halfway) == to_even); + // Now let's construct a case where it *should* have rounded to + // odd if rounding directly from double, but rounding via + // float does the wrong thing. + double halfway_plus_eps = std::nextafter(halfway, (double)to_odd); + assert(std::abs(halfway_plus_eps - (double)to_odd) < + std::abs(halfway_plus_eps - (double)to_even)); + + assert(float(halfway_plus_eps) == halfway); +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 + assert(_Float16(halfway_plus_eps) == _Float16(float(to_odd))); +#endif + assert(float16_t(halfway_plus_eps) == to_odd); + + // Now test the same thing in generated code. We need strict float to + // prevent Halide from fusing multiple float casts into one. + Param p; + p.set(halfway_plus_eps); + // halfway plus epsilon rounds to exactly halfway as a float + assert(evaluate(strict_float(cast(Float(32), p))) == halfway); + // So if we go via float we get the even outcome, because + // exactly halfway rounds to even + assert(evaluate(strict_float(cast(Float(16), cast(Float(32), p)))) == to_even); + // But if we go direct, we should go to odd, because it's closer + assert(evaluate(strict_float(cast(Float(16), p))) == to_odd); + } + + { + // Test the same things for bfloat + bfloat16_t k{f}; + bfloat16_t k_plus_eps = bfloat16_t::make_from_bits(k.to_bits() + 1); + const bool k_is_odd = k.to_bits() & 1; + + bfloat16_t to_even = k_is_odd ? k_plus_eps : k; + bfloat16_t to_odd = k_is_odd ? 
k : k_plus_eps; + float halfway = (float(k) + float(k_plus_eps)) / 2.f; + + assert(bfloat16_t(halfway) == to_even); + double halfway_plus_eps = std::nextafter(halfway, (double)to_odd); + assert(std::abs(halfway_plus_eps - (double)to_odd) < + std::abs(halfway_plus_eps - (double)to_even)); + + assert(float(halfway_plus_eps) == halfway); + assert(bfloat16_t(halfway_plus_eps) == to_odd); + + Param p; + p.set(halfway_plus_eps); + assert(evaluate(strict_float(cast(Float(32), p))) == halfway); + assert(evaluate(strict_float(cast(BFloat(16), cast(Float(32), p)))) == to_even); + assert(evaluate(strict_float(cast(BFloat(16), p))) == to_odd); + } + } + } + // Enable to read assembly generated by the conversion routines if ((false)) { // Intentional dead code. Extra parens to pacify clang-tidy. Func src, to_f16, from_f16; @@ -309,11 +375,11 @@ int run_test() { to_f16(x) = cast(src(x)); from_f16(x) = cast(to_f16(x)); - src.compute_root().vectorize(x, 8, TailStrategy::RoundUp); - to_f16.compute_root().vectorize(x, 8, TailStrategy::RoundUp); - from_f16.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + src.compute_root().vectorize(x, 16, TailStrategy::RoundUp); + to_f16.compute_root().vectorize(x, 16, TailStrategy::RoundUp); + from_f16.compute_root().vectorize(x, 16, TailStrategy::RoundUp); - from_f16.compile_to_assembly("/dev/stdout", {}, Target("host-no_asserts-no_bounds_query-no_runtime-disable_llvm_loop_unroll-disable_llvm_loop_vectorize")); + from_f16.compile_to_assembly("/dev/stdout", {}, Target("host-no_asserts-no_bounds_query-no_runtime")); } // Check infinity handling for both float16_t and Halide codegen.