
Conversation

@timmoon10
Collaborator

Description

This PR adds the register_forward_fusion and register_backward_fusion functions to the op fuser API, allowing users to register custom fusions.
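For concreteness, usage might look like the sketch below. This is a minimal illustration, not the PR's verbatim API: the fusion-function contract (current op list in, new op list out) is inferred from the sequence diagram further down, the FusedOperation import path and constructor are assumptions, and FusedBiasGELU is a hypothetical user-defined op.

from transformer_engine.pytorch import ops as te_ops
from transformer_engine.pytorch.ops import Bias, GELU
from transformer_engine.pytorch.ops.op import FusedOperation  # import path assumed

class FusedBiasGELU(FusedOperation):
    """Hypothetical user-defined fused op for a (Bias, GELU) pair."""

    def __init__(self, *, bias: Bias, activation: GELU):
        super().__init__((bias, activation))  # constructor signature assumed
    # A real op would also implement the fused forward/backward kernels.

def fuse_bias_gelu(ops):
    """Replace each adjacent (Bias, GELU) pair with FusedBiasGELU.

    Assumes the fuser passes its current list of ops and expects a new
    list back, with fused ops substituted in place.
    """
    out = []
    i = 0
    while i < len(ops):
        if (
            i + 1 < len(ops)
            and isinstance(ops[i], Bias)
            and isinstance(ops[i + 1], GELU)
        ):
            out.append(FusedBiasGELU(bias=ops[i], activation=ops[i + 1]))
            i += 2
        else:
            out.append(ops[i])
            i += 1
    return out

# Register so the fuser tries this pattern during forward-pass fusion
te_ops.register_forward_fusion(fuse_bias_gelu)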

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add function to register custom op fusions
  • Refactor op fuser to have consistent op order in forward and backward pass
  • Refactor op fusion functions to avoid index bookkeeping (see the sketch after this list)
  • Add tests for user-defined ops
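As a rough before/after illustration of the index-bookkeeping refactor (both signatures are assumptions inferred from the Greptile summary below; BackwardLinearAdd is one of the real fused ops, but its body here is elided):

from transformer_engine.pytorch.ops.op import FusedOperation  # import path assumed

# Before (assumed): a standalone function that had to track positions
def fuse_backward_linear_add(indexed_ops):
    """indexed_ops: list of (position, op) tuples; positions tracked manually."""
    ...

# After: a static method on the fused-op class that pattern-matches
# directly on operation objects, with no index bookkeeping
class BackwardLinearAdd(FusedOperation):
    @staticmethod
    def fuse_backward_ops(ops):
        """Plain list of ops in, list with fusions substituted out."""
        ...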

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <[email protected]>
@timmoon10 timmoon10 requested review from ksivaman and pggPL January 14, 2026 08:28
@timmoon10 timmoon10 added the enhancement New feature or request label Jan 14, 2026
@timmoon10
Collaborator Author

/te-ci pytorch L1

@greptile-apps
Contributor

greptile-apps bot commented Jan 14, 2026

Greptile Summary

  • Adds register_forward_fusion and register_backward_fusion functions to the op fuser API, enabling users to register custom operation fusions alongside existing built-in fusions
  • Refactors operation fusion from hardcoded function calls to a registry-based system with configurable fusion function lists stored in class-level registries
  • Eliminates index-bookkeeping complexity by converting fusion functions from standalone functions into static methods on operation classes that work directly with operation objects rather than tuples (see the sketch below)
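In code terms, the registry-based flow might look roughly like this; the registry and method names (forward_fusion_functions, backward_fusion_functions, maybe_fuse_ops) come from the sequence diagram below, while the bodies and attribute names are assumptions:

class OperationFuser:
    # Class-level registries of fusion functions, applied in order
    forward_fusion_functions = []
    backward_fusion_functions = []

    def maybe_fuse_ops(self):
        ops = list(self._basic_ops)  # attribute name assumed
        for fuse in self.forward_fusion_functions:
            ops = fuse(ops)
        self._forward_ops = ops

        ops = list(self._basic_ops)
        for fuse in self.backward_fusion_functions:
            ops = fuse(ops)
        self._backward_ops = ops

def register_forward_fusion(fuse_func):
    """Public API: append a custom fusion to the forward registry."""
    OperationFuser.forward_fusion_functions.append(fuse_func)

def register_backward_fusion(fuse_func):
    """Public API: append a custom fusion to the backward registry."""
    OperationFuser.backward_fusion_functions.append(fuse_func)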

Important Files Changed

  • transformer_engine/pytorch/ops/fuser.py: Adds the public API for fusion registration and refactors the core fusion logic with validation
  • transformer_engine/pytorch/ops/fused/backward_linear_add.py: Fixes critical indexing bugs in context/gradient handling that would cause runtime errors
  • transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py: Corrects index mismatches in gradient flow that could cause incorrect backward computations
  • tests/pytorch/test_fusible_ops.py: Adds comprehensive test coverage for custom operation fusion functionality

Confidence score: 3/5

  • This PR contains significant architectural changes and bug fixes that require careful review for correctness
  • The score reflects the complexity of the fusion-system refactoring and the several critical bugs fixed in the backward-pass implementations, which could otherwise cause incorrect gradients or runtime errors
  • Pay close attention to the indexing fixes in backward operation files and the new fusion validation logic

Sequence Diagram

sequenceDiagram
    participant User
    participant te_ops as "te.ops"
    participant OperationFuser
    participant CustomFusedOp as "Custom Fused Op"
    participant Sequential

    User->>te_ops: "register_forward_fusion(custom_fuse_func)"
    te_ops->>OperationFuser: "append to forward_fusion_functions"

    User->>te_ops: "register_backward_fusion(custom_fuse_func)"
    te_ops->>OperationFuser: "append to backward_fusion_functions"

    User->>Sequential: "create model with basic ops"
    Sequential->>OperationFuser: "initialize with ops list"

    User->>Sequential: "forward(input)"
    Sequential->>OperationFuser: "maybe_fuse_ops()"
    OperationFuser->>OperationFuser: "apply forward_fusion_functions"
    OperationFuser->>CustomFusedOp: "create fused operation"
    OperationFuser->>OperationFuser: "apply backward_fusion_functions"
    OperationFuser->>CustomFusedOp: "create backward fused operation"
    
    Sequential->>OperationFuser: "__call__(input)"
    OperationFuser->>CustomFusedOp: "fuser_forward()"
    CustomFusedOp->>Sequential: "return output"
    Sequential->>User: "return result"

    User->>Sequential: "backward(grad_output)"
    Sequential->>OperationFuser: "backward via autograd"
    OperationFuser->>CustomFusedOp: "fuser_backward()"
    CustomFusedOp->>Sequential: "return gradients"
    Sequential->>User: "gradients computed"

Contributor

@greptile-apps greptile-apps bot left a comment


12 files reviewed, 2 comments


Comment on lines +25 to +29
register_backward_fusion(UserbuffersBackwardLinear.fuse_backward_ops)
register_backward_fusion(BackwardLinearAdd.fuse_backward_ops)
register_backward_fusion(BackwardLinearScale.fuse_backward_ops)
register_backward_fusion(BackwardActivationBias.fuse_backward_ops)
register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops)

style: The order of backward fusion registrations differs from the forward order - consider maintaining consistent ordering for predictability

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +114 to +132
while ops:

    # Shift window
    while len(window) >= 3:
        out.append(window[0])
        window = window[1:]
    while ops and len(window) < 3:
        window.append(ops[0])
        ops = ops[1:]

    # Construct fused op if window matches pattern
    if (
        len(window) == 3
        and isinstance(window[2], _fusible_activations)
        and isinstance(window[1], Bias)
        and window[0].get_grad_output_quantizer() is not None
    ):
        op = BackwardActivationBias(bias=window[1], activation=window[2])
        window = [window[0], op]

logic: The sliding window logic has a potential infinite loop if the fusion condition is never met. The while ops: loop doesn't guarantee termination when the window size is maintained at 3 and no fusion occurs.

