Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions transformer_engine/common/swizzle/swizzle_block_scaling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
}

// pack the exponent bits of the scaling factors
- uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1);
+ uint32_t packed_exponents = ((sf.x >> 23) & 0xFF) | (((sf.y >> 23) & 0xFF) << 8) |
+ (((sf.z >> 23) & 0xFF) << 16) | (((sf.w >> 23) & 0xFF) << 24);

// partially swizzle the scaling factors
constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches
Expand Down Expand Up @@ -198,8 +199,9 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
uint32_t sf = *reinterpret_cast<const uint32_t*>(warp_src);

// broadcast it to four scaling factors for 1x32 tiles
- sf = (sf << 1) | (sf >> 7);
- sf = sf | (sf >> 16);
+ // extract and broadcast the exponent byte to four bytes for E8M0 format
+ uint32_t exp_byte = (sf >> 23) & 0xFF;
+ sf = exp_byte | (exp_byte << 8) | (exp_byte << 16) | (exp_byte << 24);

// broadcast it to sixteen scaling factors for 1x32 tiles
const uint4 sf4{sf, sf, sf, sf};
Expand Down
Loading