Skip to content

Conversation

@lhb8125
Copy link
Contributor

@lhb8125 lhb8125 commented Jan 13, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.
nvbug: https://nvbugspro.nvidia.com/bug/5754531

  1. When computing scaling, a very small amax will return a scale of FP32 maximum representable value: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/recipe/recipe_common.cuh#L37
  2. Then the generated pow-2-scale will be 0x7F000000, where the mantissa bits are masked: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/recipe/recipe_common.cuh#L41
  3. The scale_inv is computed by 1 / scale so that it will be 0x400000, whose exp bits are 0, but its mantissa bits are not full zeros (100 0000 0000 0000 0000 0000): https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu#L302
  4. When converting the blockwise-scale_inv to mxfp8-scale_inv, we extract the exponent bits of four scale_invs and pack them together for further broadcast by following the code line:
    uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1);
    This assumes the mantissa bits of scale_invs are full zeros, which works in most cases because we mask out the mantissa bits in step 2. However, in this corner case, the scale_inv of 0x400000 has non-zero mantissa bits, which will modify the bits of other scale_invs. That's an unexpected behaviour.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 13, 2026

Greptile Overview

Greptile Summary

This PR fixes a critical accuracy bug in the blockwise-to-MXFP8 scaling conversion for Blackwell GPUs. The bug manifests in a specific corner case when quantizing tensors with extremely small amax values.

Root Cause

The bug occurs through this sequence:

  1. Very small amax values cause compute_scale_from_amax to produce FP32 maximum scales (~0x7F000000)
  2. After power-of-2 masking (scale &= 0xFF800000), mantissa bits are cleared
  3. Computing scale_inv = 1.0 / scale produces a subnormal number (e.g., 0x00400000) with exponent=0 but non-zero mantissa bits
  4. The old exponent-packing logic used non-uniform bit shifts that allowed mantissa bits to contaminate adjacent exponent values

What Changed

1D Kernel (lines 116-117):

  • Before: (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1)
  • After: ((sf.x >> 23) & 0xFF) | (((sf.y >> 23) & 0xFF) << 8) | (((sf.z >> 23) & 0xFF) << 16) | (((sf.w >> 23) & 0xFF) << 24)

2D Kernel (lines 202-204):

  • Before: sf = (sf << 1) | (sf >> 7); sf = sf | (sf >> 16);
  • After: uint32_t exp_byte = (sf >> 23) & 0xFF; sf = exp_byte | (exp_byte << 8) | (exp_byte << 16) | (exp_byte << 24);

Both kernels now properly extract only the 8 exponent bits [30:23] using a uniform extraction pattern with masking, preventing mantissa contamination.

Impact

Correctness: The fix resolves accuracy issues for edge cases with very small amax values
Completeness: Both 1D and 2D kernels are fixed
Performance: No performance regression (bit operations are equally fast)
Edge Cases: Handles subnormal numbers, normal numbers, zero, infinity, and NaN correctly

Minor Recommendations

  • Consider adding inline comments explaining why the & 0xFF mask is critical (see style comments)
  • No tests were added for this specific corner case - consider adding regression tests that exercise the subnormal scale_inv scenario

Confidence Score: 4/5

  • This PR is safe to merge with high confidence. The fix is mathematically correct and addresses the root cause of the mantissa contamination bug.
  • Score of 4/5 reflects that the fix is technically sound and complete, but lacks regression tests for the specific corner case (subnormal scale_inv values). The implementation correctly extracts only exponent bits for E8M0 format in both 1D and 2D kernels. All edge cases (subnormal, normal, zero, infinity, NaN) are handled correctly. The only minor concerns are: (1) no new tests added for this corner case, and (2) lack of inline documentation explaining the critical & 0xFF masking operation.
  • No files require special attention. The fix is isolated, mathematically correct, and complete for both affected kernels.

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/swizzle/swizzle_block_scaling.cu 4/5 Fixes mantissa contamination bug in both 1D and 2D blockwise-to-MXFP8 scaling conversion kernels. The fix correctly extracts only exponent bits for E8M0 format. Could benefit from tests and inline comments.

Sequence Diagram

sequenceDiagram
    participant Quantize as Quantization Kernel
    participant Recipe as compute_scale_from_amax
    participant Swizzle as Swizzle Kernel
    participant E8M0 as E8M0 Output
    
    Note over Quantize,Recipe: Corner case: Very small amax
    Quantize->>Recipe: amax ≈ 0 (very small)
    Recipe->>Recipe: scale = max_fp8 / amax → ∞
    Recipe->>Recipe: scale = 0x7F800000 (FP32_MAX)
    Recipe->>Recipe: scale &= 0xFF800000 (mask mantissa)
    Recipe-->>Quantize: scale = 0x7F000000
    
    Note over Quantize,Swizzle: Compute scale_inv
    Quantize->>Quantize: scale_inv = 1.0 / 0x7F000000
    Quantize->>Quantize: scale_inv = 0x00400000 (subnormal!)
    Note right of Quantize: Exponent=0, bit 22=1
    Quantize->>Swizzle: scale_inv = 0x00400000
    
    alt Old (Buggy) Implementation
        Note over Swizzle: 1D: (sf.y >> 15) includes bit 7!
        Swizzle->>Swizzle: 0x00400000 >> 15 = 0x00000080
        Swizzle->>Swizzle: Mantissa contamination!
        Swizzle->>E8M0: Wrong result: 0x00808080
    else New (Fixed) Implementation  
        Note over Swizzle: Extract exponent only
        Swizzle->>Swizzle: (sf.x >> 23) & 0xFF = 0x00
        Swizzle->>Swizzle: (sf.y >> 23) & 0xFF = 0x00
        Swizzle->>Swizzle: (sf.z >> 23) & 0xFF = 0x00
        Swizzle->>Swizzle: (sf.w >> 23) & 0xFF = 0x00
        Swizzle->>E8M0: Correct result: 0x00000000
    end
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile


// pack the exponent bits of the scaling factors
uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1);
uint32_t packed_exponents =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent parenthesization style: sf.y >> 23 has extra parentheses ((sf.y >> 23) << 8), while sf.z >> 23 and sf.w >> 23 don't have them before their left shifts.

While operator precedence makes this work correctly (shift operators bind tighter than OR), the inconsistent style suggests uncertainty about precedence and makes the code harder to read.

For consistency and clarity, either add parentheses to all:

Suggested change
uint32_t packed_exponents =
uint32_t packed_exponents = (sf.x >> 23) | ((sf.y >> 23) << 8) | ((sf.z >> 23) << 16) | ((sf.w >> 23) << 24);

Or remove them from all (though I recommend keeping them for clarity).

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 13, 2026

Greptile Overview

Greptile Summary

This PR fixes a corner-case accuracy bug in blockwise scaling with E8M0 format on Blackwell for the 1D scaling kernel. The issue occurs when very small amax values produce extremely large FP32 scales (0x7F000000), whose reciprocal (scale_inv) becomes a subnormal number (0x00400000) with non-zero mantissa bits despite the power-of-2 masking.

What Changed

The 1D kernel's exponent packing logic (line 116-117) was corrected from:

// OLD: Incorrect - includes mantissa bits in shifts
(sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1)

to:

// NEW: Correct - extracts only exponent+sign bits
(sf.x >> 23) | ((sf.y >> 23) << 8) | (sf.z >> 23) << 16 | (sf.w >> 23) << 24

The new approach properly extracts the 9 high-order bits (sign + 8 exponent bits) from each float and packs them into separate bytes, preventing mantissa bit contamination.

Critical Issue Found

⚠️ The 2D scaling kernel (lines 202-203) has the same bug but was not fixed in this PR. The broadcast operations sf = (sf << 1) | (sf >> 7) followed by sf = sf | (sf >> 16) will exhibit identical mantissa contamination when processing subnormal scale_inv values. This should be addressed for consistency.

Missing Elements

  • No tests added to verify the fix handles the corner case (very small amax → large scale → subnormal scale_inv)
  • No code comments explaining the bit manipulation or why the specific shifts are needed
  • PR checklist items remain unchecked

Confidence Score: 3/5

  • This PR fixes a critical corner-case bug in 1D scaling but leaves the same bug unfixed in 2D scaling, and lacks tests to prevent regression
  • The 1D kernel fix is mathematically correct and properly addresses the mantissa contamination issue. However, the score is reduced to 3/5 because: (1) the same bug exists in the 2D kernel (lines 202-203) but wasn't fixed, creating an inconsistency, (2) no tests were added to verify the corner case is handled correctly, (3) the code lacks comments explaining the complex bit manipulation, and (4) this is a subtle bug that could easily regress without proper test coverage
  • Pay close attention to lines 202-203 in swizzle_block_scaling.cu where the 2D kernel has the same unfixed bug. Also verify that tests are added before merging to prevent regression of this corner case.

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/swizzle/swizzle_block_scaling.cu 3/5 Fixed exponent packing bug in 1D kernel (line 117) but same bug remains in 2D kernel (line 202-203) - both need consistent fixes for subnormal scale_inv handling

Sequence Diagram

sequenceDiagram
    participant Quantize as Quantization Logic
    participant Recipe as recipe_common.cuh
    participant Transpose as quantize_transpose_vector
    participant Swizzle1D as Swizzle 1D Kernel (FIXED)
    participant Swizzle2D as Swizzle 2D Kernel (UNFIXED)
    participant Output as E8M0 Scale Output

    Note over Quantize,Recipe: Corner Case: Very Small Amax
    Quantize->>Recipe: compute_scale_from_amax(amax < epsilon)
    Recipe->>Recipe: scale = max_fp32 / amax → 0x7F000000
    Recipe->>Recipe: mask mantissa: scale &= 0xFF800000
    Note over Recipe: scale = 0x7F000000 (FP32 max, mantissa masked)
    
    Recipe->>Transpose: return scale
    Transpose->>Transpose: scale_inv = 1.0 / scale
    Note over Transpose: scale_inv = 0x00400000<br/>(subnormal with non-zero mantissa!)
    
    alt 1D Block Scaling Path
        Transpose->>Swizzle1D: scale_inv values (sf.x, sf.y, sf.z, sf.w)
        Note over Swizzle1D: OLD BUG: (sf.y >> 15)<br/>includes mantissa bits!
        Note over Swizzle1D: NEW FIX: (sf.y >> 23) << 8<br/>extracts only exponent bits
        Swizzle1D->>Swizzle1D: packed = (sf.x>>23) | ((sf.y>>23)<<8) |<br/>(sf.z>>23)<<16 | (sf.w>>23)<<24
        Swizzle1D->>Output: Clean E8M0 scales ✓
    else 2D Block Scaling Path
        Transpose->>Swizzle2D: scale_inv value (sf)
        Note over Swizzle2D: BUG STILL EXISTS!<br/>sf = (sf << 1) | (sf >> 7)
        Note over Swizzle2D: Mantissa bit 22 → bit 15<br/>contaminates output
        Swizzle2D->>Output: Corrupted E8M0 scales ✗
    end
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 13, 2026

Additional Comments (1)

transformer_engine/common/swizzle/swizzle_block_scaling.cu
The 2D kernel appears to have the same mantissa contamination bug that was fixed in the 1D kernel on line 117.

When sf contains a subnormal value like 0x00400000 (which can occur when scale_inv = 1/scale for very large scales), the mantissa bit at position 22 will contaminate the packed result through the shift operations:

  • sf = 0x00400000 (bit 22 set)
  • sf << 1 = 0x00800000 (bit 23 set)
  • sf >> 7 = 0x00008000 (bit 15 set - mantissa contamination!)
  • Combined: 0x00808000 (bits 15 and 23 set)

The correct approach would be to extract only the exponent bits (similar to the 1D fix):

  // extract and broadcast the exponent byte to four bytes for E8M0 format
  uint32_t exp_byte = (sf >> 23) & 0xFF;
  sf = exp_byte | (exp_byte << 8) | (exp_byte << 16) | (exp_byte << 24);

This ensures that only the sign+exponent bits are extracted and broadcast, preventing mantissa bits from corrupting the E8M0 scaling factors.

lhb8125 and others added 2 commits January 13, 2026 12:03
Mask to 8 bits to prevent potential bit overlap

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Hongbin Liu  <[email protected]>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 13, 2026

Additional Comments (1)

transformer_engine/common/swizzle/swizzle_block_scaling.cu
CRITICAL: 2D kernel has the same mantissa contamination bug that was fixed in the 1D kernel

The 2D kernel's bit-shifting logic suffers from the exact same bug described in the PR description. When sf contains a subnormal value like 0x00400000 (which occurs when scale_inv = 1/scale for very large scales like 0x7F000000), the non-zero mantissa bits will contaminate the result during the shift operations.

Bug demonstration:
For sf = 0x00400000 (subnormal with mantissa bit 22 set):

  • sf << 1 = 0x00800000 (bit 23 now set)
  • sf >> 7 = 0x00008000 (mantissa bit contaminated into bit 15!)
  • After line 202: sf = 0x00808000 (corrupted)
  • After line 203: sf = 0x80808080 (all bytes corrupted)

The fix:
Extract only the exponent+sign bits (top 9 bits, specifically bits 23-31) and broadcast them properly:

  // extract and broadcast the exponent byte to four bytes for E8M0 format
  uint32_t exp_byte = (sf >> 23) & 0xFF;
  sf = exp_byte | (exp_byte << 8) | (exp_byte << 16) | (exp_byte << 24);

This matches the approach used in the 1D kernel fix and ensures mantissa bits cannot contaminate the packed exponents.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 13, 2026

Additional Comments (1)

transformer_engine/common/swizzle/swizzle_block_scaling.cu
The 2D kernel has the same mantissa contamination bug that was fixed in the 1D kernel. When sf contains a subnormal value like 0x00400000 (which occurs when scale_inv = 1/scale for very large scales as described in the PR), the mantissa bits will contaminate the packed result.

For example, with sf = 0x00400000 (bit 22 set in mantissa):

  1. sf << 1 = 0x00800000 (bit 23 set)
  2. sf >> 7 = 0x00008000 (bit 15 set - mantissa contamination)
  3. Combined: 0x00808000 (bits 15 and 23 set)
  4. After sf | (sf >> 16): 0x00808080

The expected result for exponent bits of 0x00400000 should be 0x00000000 (all exponent bits are 0), but we get 0x00808080 instead.

The correct approach is to extract only the sign+exponent bits (top 9 bits) and broadcast them:

  // extract and broadcast the exponent byte to four bytes for E8M0 format
  uint32_t exp_byte = (sf >> 23) & 0xFF;
  sf = exp_byte | (exp_byte << 8) | (exp_byte << 16) | (exp_byte << 24);

This ensures only the exponent bits are extracted and broadcast, preventing mantissa bit contamination.

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Hongbin Liu  <[email protected]>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 13, 2026

Additional Comments (1)

transformer_engine/common/swizzle/swizzle_block_scaling.cu
CRITICAL: The 2D kernel has the same mantissa contamination bug that was fixed in the 1D kernel.

The bit-shifting logic in lines 202-203 suffers from the exact same corner-case bug described in the PR description. When sf contains a subnormal value like 0x00400000 (which occurs when scale_inv = 1/scale for very large scales like 0x7F000000), the non-zero mantissa bits will contaminate the broadcast operation.

Bug demonstration with sf = 0x00400000:

Initial:         sf = 0x00400000  (bit 22 set in mantissa, exponent = 0)
After line 202:  sf = (0x00400000 << 1) | (0x00400000 >> 7)
                    = 0x00800000 | 0x00008000
                    = 0x00808000  (mantissa contamination in bits 15 and 23!)
After line 203:  sf = 0x00808000 | (0x00808000 >> 16)
                    = 0x00808000 | 0x00000080
                    = 0x00808080  (all bytes corrupted)

Expected result: For a subnormal with exponent bits = 0, the result should be 0x00000000, but we get 0x00808080 instead.

The fix: Extract only the sign+exponent bits (top 9 bits, specifically bits 23-31) and broadcast them properly:

  // extract and broadcast the exponent byte to four bytes for E8M0 format
  uint32_t exp_byte = (sf >> 23) & 0xFF;
  sf = exp_byte | (exp_byte << 8) | (exp_byte << 16) | (exp_byte << 24);

This ensures mantissa bits cannot contaminate the exponent packing, matching the fix applied to the 1D kernel on lines 116-117.

Signed-off-by: Hongbin Liu <[email protected]>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

@lhb8125
Copy link
Contributor Author

lhb8125 commented Jan 13, 2026

/te-ci pytorch

@lhb8125
Copy link
Contributor Author

lhb8125 commented Jan 13, 2026

@ksivaman @zhongbozhu @yaox12 can you review this PR?

@yaox12
Copy link
Member

yaox12 commented Jan 14, 2026

cc @kwyss-nvidia

Copy link
Member

@yaox12 yaox12 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@zhongbozhu
Copy link
Collaborator

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants