(Bug fix) Fix accuracy issue for blockwise scaling+E8 scale on Blackwell #2589
Signed-off-by: hongbinl <[email protected]>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR fixes a critical accuracy bug in the blockwise-to-MXFP8 scaling conversion for Blackwell GPUs. The bug manifests in a specific corner case when quantizing tensors with extremely small amax values.
Root Cause
The bug occurs through this sequence:
What Changed
1D Kernel (lines 116-117):
2D Kernel (lines 202-204):
Both kernels now properly extract only the 8 exponent bits [30:23] using a uniform extraction pattern with masking, preventing mantissa contamination.
Impact
✅ Correctness: The fix resolves accuracy issues for edge cases with very small amax values
Minor Recommendations
Confidence Score: 4/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant Quantize as Quantization Kernel
participant Recipe as compute_scale_from_amax
participant Swizzle as Swizzle Kernel
participant E8M0 as E8M0 Output
Note over Quantize,Recipe: Corner case: Very small amax
Quantize->>Recipe: amax ≈ 0 (very small)
Recipe->>Recipe: scale = max_fp8 / amax → ∞
Recipe->>Recipe: scale clamped to 0x7F7FFFFF (FP32_MAX)
Recipe->>Recipe: scale &= 0xFF800000 (mask mantissa)
Recipe-->>Quantize: scale = 0x7F000000
Note over Quantize,Swizzle: Compute scale_inv
Quantize->>Quantize: scale_inv = 1.0 / 0x7F000000
Quantize->>Quantize: scale_inv = 0x00400000 (subnormal!)
Note right of Quantize: Exponent=0, bit 22=1
Quantize->>Swizzle: scale_inv = 0x00400000
alt Old (Buggy) Implementation
Note over Swizzle: 1D: (sf.y >> 15) includes bit 7!
Swizzle->>Swizzle: 0x00400000 >> 15 = 0x00000080
Swizzle->>Swizzle: Mantissa contamination!
Swizzle->>E8M0: Wrong result: 0x00808080
else New (Fixed) Implementation
Note over Swizzle: Extract exponent only
Swizzle->>Swizzle: (sf.x >> 23) & 0xFF = 0x00
Swizzle->>Swizzle: (sf.y >> 23) & 0xFF = 0x00
Swizzle->>Swizzle: (sf.z >> 23) & 0xFF = 0x00
Swizzle->>Swizzle: (sf.w >> 23) & 0xFF = 0x00
Swizzle->>E8M0: Correct result: 0x00000000
end
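The corner case in the diagram can be reproduced on the host with plain integer arithmetic. The sketch below is only an illustration: `Scales4` is a hypothetical stand-in for the kernel's four-element vector type, and `e8m0_byte` is a helper introduced here, not something from the PR. The two packing expressions follow the old line quoted in the PR description and the masked extraction described in the summary.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the kernel's 4-element vector of FP32 bit patterns.
struct Scales4 {
  uint32_t x, y, z, w;
};

// Old packing quoted in the PR: correct only if every mantissa is zero.
inline uint32_t pack_exponents_old(Scales4 sf) {
  return (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1);
}

// Masked packing: isolate bits [30:23] of each value, then give each
// 8-bit exponent its own byte of the result.
inline uint32_t pack_exponents_fixed(Scales4 sf) {
  return ((sf.x >> 23) & 0xFFu) | (((sf.y >> 23) & 0xFFu) << 8) |
         (((sf.z >> 23) & 0xFFu) << 16) | (((sf.w >> 23) & 0xFFu) << 24);
}

// Helper for this example only: byte i (0..3) of a packed word.
inline uint8_t e8m0_byte(uint32_t packed, int i) {
  assert(i >= 0 && i < 4);
  return static_cast<uint8_t>(packed >> (8 * i));
}
```

With all four inputs equal to 0x00400000, the old expression yields 0x00808080 and the masked one yields 0x00000000, matching the diagram. A single subnormal neighbor is enough to corrupt a normal scale: for inputs (0x3F800000, 0x00400000, 0x3F800000, 0x3F800000), the old expression gives 0x7F7F00FF, so the byte for `sf.x` becomes 0xFF instead of 0x7F, while the masked version gives 0x7F7F007F.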
2 files reviewed, 2 comments
// pack the exponent bits of the scaling factors
uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1);
| uint32_t packed_exponents = |
Inconsistent parenthesization style: sf.y >> 23 has extra parentheses ((sf.y >> 23) << 8), while sf.z >> 23 and sf.w >> 23 don't have them before their left shifts.
While operator precedence makes this work correctly (shift operators bind tighter than OR), the inconsistent style suggests uncertainty about precedence and makes the code harder to read.
For consistency and clarity, either add parentheses to all:
uint32_t packed_exponents =
uint32_t packed_exponents = (sf.x >> 23) | ((sf.y >> 23) << 8) | ((sf.z >> 23) << 16) | ((sf.w >> 23) << 24);
Or remove them from all (though I recommend keeping them for clarity).
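To back up the precedence point, here is a small host-side check that the mixed and fully parenthesized spellings compute the same value. The function names are made up for this example; the expressions are the two spellings quoted in this review thread. In C++, `<<` and `>>` are left-associative and bind tighter than `|`, so `(z >> 23) << 16 | ...` parses as `((z >> 23) << 16) | ...`.

```cpp
#include <cassert>
#include <cstdint>

// Mixed style: outer parentheses only around the sf.y term.
inline uint32_t pack_mixed_parens(uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
  return (x >> 23) | ((y >> 23) << 8) | (z >> 23) << 16 | (w >> 23) << 24;
}

// Fully parenthesized style suggested in the review.
inline uint32_t pack_full_parens(uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
  return (x >> 23) | ((y >> 23) << 8) | ((z >> 23) << 16) | ((w >> 23) << 24);
}

// Sanity helper for this example: both spellings must agree on the inputs.
inline void check_styles_agree(uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
  assert(pack_mixed_parens(x, y, z, w) == pack_full_parens(x, y, z, w));
}
```

Neither spelling masks to 8 bits, so a set sign bit in `x`, `y`, or `z` would spill into the next byte. That is a separate concern from parenthesization and is what a `& 0xFF` mask addresses.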
Greptile Overview
Greptile Summary
This PR fixes a corner-case accuracy bug in blockwise scaling with E8M0 format on Blackwell for the 1D scaling kernel. The issue occurs when very small amax values produce extremely large FP32 scales (0x7F000000), whose reciprocal (scale_inv) becomes a subnormal number (0x00400000) with non-zero mantissa bits despite the power-of-2 masking.
What Changed
The 1D kernel's exponent packing logic (line 116-117) was corrected from:
// OLD: Incorrect - includes mantissa bits in shifts
(sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1)
to:
// NEW: Correct - extracts only exponent+sign bits
(sf.x >> 23) | ((sf.y >> 23) << 8) | (sf.z >> 23) << 16 | (sf.w >> 23) << 24
The new approach properly extracts the 9 high-order bits (sign + 8 exponent bits) from each float and packs them into separate bytes, preventing mantissa bit contamination.
Critical Issue Found
Missing Elements
Confidence Score: 3/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant Quantize as Quantization Logic
participant Recipe as recipe_common.cuh
participant Transpose as quantize_transpose_vector
participant Swizzle1D as Swizzle 1D Kernel (FIXED)
participant Swizzle2D as Swizzle 2D Kernel (UNFIXED)
participant Output as E8M0 Scale Output
Note over Quantize,Recipe: Corner Case: Very Small Amax
Quantize->>Recipe: compute_scale_from_amax(amax < epsilon)
Recipe->>Recipe: scale = max_fp8 / amax → ∞, clamped to FP32 max (0x7F7FFFFF)
Recipe->>Recipe: mask mantissa: scale &= 0xFF800000
Note over Recipe: scale = 0x7F000000 (FP32 max, mantissa masked)
Recipe->>Transpose: return scale
Transpose->>Transpose: scale_inv = 1.0 / scale
Note over Transpose: scale_inv = 0x00400000<br/>(subnormal with non-zero mantissa!)
alt 1D Block Scaling Path
Transpose->>Swizzle1D: scale_inv values (sf.x, sf.y, sf.z, sf.w)
Note over Swizzle1D: OLD BUG: (sf.y >> 15)<br/>includes mantissa bits!
Note over Swizzle1D: NEW FIX: (sf.y >> 23) << 8<br/>extracts only exponent bits
Swizzle1D->>Swizzle1D: packed = (sf.x>>23) | ((sf.y>>23)<<8) |<br/>(sf.z>>23)<<16 | (sf.w>>23)<<24
Swizzle1D->>Output: Clean E8M0 scales ✓
else 2D Block Scaling Path
Transpose->>Swizzle2D: scale_inv value (sf)
Note over Swizzle2D: BUG STILL EXISTS!<br/>sf = (sf << 1) | (sf >> 7)
Note over Swizzle2D: Mantissa bit 22 → bit 15<br/>contaminates output
Swizzle2D->>Output: Corrupted E8M0 scales ✗
end
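The 2D-path problem in the diagram can be checked the same way. In the sketch below, `rotate_step_old` is the single statement quoted in the diagram. `broadcast_exponent` is an assumption made for illustration: one way to extract bits [30:23] and replicate them into all four bytes. It is not claimed to be the code that was merged.

```cpp
#include <cassert>
#include <cstdint>

// Step quoted in the diagram. For a zero mantissa it copies the exponent
// into the two upper bytes; otherwise mantissa bits leak into the result.
inline uint32_t rotate_step_old(uint32_t sf) {
  return (sf << 1) | (sf >> 7);
}

// Hypothetical fix for illustration: mask out the 8 exponent bits first,
// then replicate that byte into every byte of the word.
inline uint32_t broadcast_exponent(uint32_t sf) {
  const uint32_t e = (sf >> 23) & 0xFFu;
  const uint32_t packed = e * 0x01010101u;  // e <= 0xFF, so no overflow
  assert((packed & 0xFFu) == e);            // low byte carries e unchanged
  return packed;
}
```

For a normal value such as 0x3F800000 the old step gives 0x7F7F0000, as intended. For the subnormal 0x00400000 it gives 0x00808000: mantissa bit 22 lands in bit 23 through the left shift and in bit 15 through the right shift. The masked broadcast returns 0x00000000 for the same input.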
1 file reviewed, 1 comment
Additional Comments (1)
When
The correct approach would be to extract only the exponent bits (similar to the 1D fix):
This ensures that only the sign+exponent bits are extracted and broadcast, preventing mantissa bits from corrupting the E8M0 scaling factors.
Mask to 8 bits to prevent potential bit overlap
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Hongbin Liu <[email protected]>
for more information, see https://pre-commit.ci
2 files reviewed, 2 comments
Additional Comments (1)
The 2D kernel's bit-shifting logic suffers from the exact same bug described in the PR description. When
Bug demonstration:
The fix:
This matches the approach used in the 1D kernel fix and ensures mantissa bits cannot contaminate the packed exponents.
2 files reviewed, 2 comments
Additional Comments (1)
For example, with
The expected result for exponent bits of
The correct approach is to extract only the sign+exponent bits (top 9 bits) and broadcast them:
This ensures only the exponent bits are extracted and broadcast, preventing mantissa bit contamination.
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Hongbin Liu <[email protected]>
1 file reviewed, 1 comment
Additional Comments (1)
The bit-shifting logic in lines 202-203 suffers from the exact same corner-case bug described in the PR description. When
Bug demonstration with
Expected result: For a subnormal with exponent bits = 0, the result should be
The fix: Extract only the sign+exponent bits (top 9 bits, specifically bits 23-31) and broadcast them properly:
This ensures mantissa bits cannot contaminate the exponent packing, matching the fix applied to the 1D kernel on lines 116-117.
Signed-off-by: Hongbin Liu <[email protected]>
2 files reviewed, 2 comments
/te-ci pytorch
@ksivaman @zhongbozhu @yaox12 can you review this PR?
yaox12 left a comment
LGTM.
LGTM
Description
Please include a brief summary of the changes, relevant motivation and context.
nvbug: https://nvbugspro.nvidia.com/bug/5754531
uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1);
This assumes the mantissa bits of every scale_inv are all zero, which holds in most cases because the mantissa bits are masked out in step 2. In this corner case, however, the scale_inv of 0x00400000 has a non-zero mantissa bit, which leaks into the bytes that belong to the other scale_invs in the packed result and changes their values.
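How a masked scale still ends up with a non-zero mantissa in its reciprocal can be shown with a few lines of host code. This is a sketch, assuming IEEE 754 binary32 arithmetic with subnormals enabled (no flush-to-zero); all function names are invented for the example and are not the library's API.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Bit-level views of a float, using memcpy to avoid aliasing issues.
inline uint32_t float_bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof u);
  return u;
}

inline float bits_float(uint32_t u) {
  float f;
  std::memcpy(&f, &u, sizeof f);
  return f;
}

// Keep only sign and exponent, i.e. round the scale down to a power of two.
inline float mask_to_power_of_two(float scale) {
  return bits_float(float_bits(scale) & 0xFF800000u);
}

// Bit pattern of the reciprocal of a positive scale.
inline uint32_t scale_inv_bits(float scale) {
  assert(scale > 0.0f);
  return float_bits(1.0f / scale);
}
```

Masking FP32 max (0x7F7FFFFF) gives 0x7F000000, which is 2^127. Its reciprocal, 2^-127, lies below the smallest normal value 2^-126, so it is stored as the subnormal 0x00400000: exponent field 0 and mantissa bit 22 set. On a device compiled with flush-to-zero, the reciprocal could instead come out as zero, which this host sketch does not model.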