diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu index c5ad1aed43..37993787a5 100644 --- a/transformer_engine/common/swizzle/swizzle_block_scaling.cu +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -113,7 +113,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) } // pack the exponent bits of the scaling factors - uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1); + uint32_t packed_exponents = ((sf.x >> 23) & 0xFF) | (((sf.y >> 23) & 0xFF) << 8) | + (((sf.z >> 23) & 0xFF) << 16) | (((sf.w >> 23) & 0xFF) << 24); // partially swizzle the scaling factors constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches @@ -198,8 +199,9 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) uint32_t sf = *reinterpret_cast(warp_src); // broadcast it to four scaling factors for 1x32 tiles - sf = (sf << 1) | (sf >> 7); - sf = sf | (sf >> 16); + // extract and broadcast the exponent byte to four bytes for E8M0 format + uint32_t exp_byte = (sf >> 23) & 0xFF; + sf = exp_byte | (exp_byte << 8) | (exp_byte << 16) | (exp_byte << 24); // broadcast it to sixteen scaling factors for 1x32 tiles const uint4 sf4{sf, sf, sf, sf};