[PyTorch] Support user-defined op fusions #2597
base: main
Conversation
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order. Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch L1
Greptile Summary
Important Files Changed
Confidence score: 3/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant te_ops as "te.ops"
    participant OperationFuser
    participant CustomFusedOp as "Custom Fused Op"
    participant Sequential
    User->>te_ops: "register_forward_fusion(custom_fuse_func)"
    te_ops->>OperationFuser: "append to forward_fusion_functions"
    User->>te_ops: "register_backward_fusion(custom_fuse_func)"
    te_ops->>OperationFuser: "append to backward_fusion_functions"
    User->>Sequential: "create model with basic ops"
    Sequential->>OperationFuser: "initialize with ops list"
    User->>Sequential: "forward(input)"
    Sequential->>OperationFuser: "maybe_fuse_ops()"
    OperationFuser->>OperationFuser: "apply forward_fusion_functions"
    OperationFuser->>CustomFusedOp: "create fused operation"
    OperationFuser->>OperationFuser: "apply backward_fusion_functions"
    OperationFuser->>CustomFusedOp: "create backward fused operation"
    Sequential->>OperationFuser: "__call__(input)"
    OperationFuser->>CustomFusedOp: "fuser_forward()"
    CustomFusedOp->>Sequential: "return output"
    Sequential->>User: "return result"
    User->>Sequential: "backward(grad_output)"
    Sequential->>OperationFuser: "backward via autograd"
    OperationFuser->>CustomFusedOp: "fuser_backward()"
    CustomFusedOp->>Sequential: "return gradients"
    Sequential->>User: "gradients computed"
```
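The registration flow in the diagram can be sketched in plain Python. The function and class names (`register_forward_fusion`, `register_backward_fusion`, `OperationFuser`) mirror the PR; the registry and fuser internals here are simplified stand-ins, not Transformer Engine's actual implementation, and the fusion-function signature (op list in, op list out) is an assumption.

```python
# Simplified stand-ins for the fuser's fusion registries.
forward_fusion_functions = []
backward_fusion_functions = []


def register_forward_fusion(fn):
    """Append a user-defined forward fusion pass to the registry."""
    forward_fusion_functions.append(fn)


def register_backward_fusion(fn):
    """Append a user-defined backward fusion pass to the registry."""
    backward_fusion_functions.append(fn)


class OperationFuser:
    """Toy fuser: applies each registered pass to the op list in order."""

    def __init__(self, ops):
        self.ops = list(ops)

    def maybe_fuse_ops(self):
        for fn in forward_fusion_functions:
            self.ops = fn(self.ops)
        for fn in backward_fusion_functions:
            self.ops = fn(self.ops)
        return self.ops


# Example pass: fuse adjacent ("linear", "bias") pairs into one op.
def fuse_linear_bias(ops):
    out, i = [], 0
    while i < len(ops):
        if i + 1 < len(ops) and ops[i] == "linear" and ops[i + 1] == "bias":
            out.append("linear+bias")
            i += 2
        else:
            out.append(ops[i])
            i += 1
    return out


register_forward_fusion(fuse_linear_bias)
fuser = OperationFuser(["linear", "bias", "gelu"])
print(fuser.maybe_fuse_ops())  # ['linear+bias', 'gelu']
```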
12 files reviewed, 2 comments
```python
register_backward_fusion(UserbuffersBackwardLinear.fuse_backward_ops)
register_backward_fusion(BackwardLinearAdd.fuse_backward_ops)
register_backward_fusion(BackwardLinearScale.fuse_backward_ops)
register_backward_fusion(BackwardActivationBias.fuse_backward_ops)
register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops)
```
style: The order of backward fusion registrations differs from the forward order - consider maintaining consistent ordering for predictability
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
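Registration order is not just cosmetic: since the registered passes run sequentially over the op list, an earlier pass can consume ops that a later pass would otherwise match. A toy illustration (the op names and pair-fusion passes are invented for the sketch, not Transformer Engine's):

```python
passes = []


def register(fn):
    passes.append(fn)


def fuse_pair(a, b, fused):
    """Build a pass that replaces each adjacent (a, b) pair with `fused`."""

    def fn(ops):
        out, i = [], 0
        while i < len(ops):
            if i + 1 < len(ops) and (ops[i], ops[i + 1]) == (a, b):
                out.append(fused)
                i += 2
            else:
                out.append(ops[i])
                i += 1
        return out

    return fn


def run(ops):
    # Apply passes in registration order.
    for fn in passes:
        ops = fn(ops)
    return ops


register(fuse_pair("linear", "add", "linear+add"))
register(fuse_pair("add", "norm", "add+norm"))
# The first-registered pass claims the shared "add" op:
print(run(["linear", "add", "norm"]))  # ['linear+add', 'norm']
```

Reversing the registration order would instead yield `['linear', 'add+norm']`, which is why keeping forward and backward registrations in the same order makes behavior easier to predict.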
```python
while ops:

    # Shift window
    while len(window) >= 3:
        out.append(window[0])
        window = window[1:]
    while ops and len(window) < 3:
        window.append(ops[0])
        ops = ops[1:]

    # Construct fused op if window matches pattern
    if (
        len(window) == 3
        and isinstance(window[2], _fusible_activations)
        and isinstance(window[1], Bias)
        and window[0].get_grad_output_quantizer() is not None
    ):
        op = BackwardActivationBias(bias=window[1], activation=window[2])
        window = [window[0], op]
```
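A self-contained version of the sliding-window pass makes its mechanics easy to check. The TE op classes are replaced by plain strings, the `isinstance` checks by a stand-in pattern match, and a trailing flush of the window is added so the sketch returns the full op list; all of that is illustrative, not the PR's exact code.

```python
def fuse_window(ops):
    """Slide a 3-op window over `ops`, fusing windows that match a pattern."""
    ops = list(ops)
    out, window = [], []
    while ops:
        # Shift window: emit the oldest op, then refill to 3 from the input.
        while len(window) >= 3:
            out.append(window[0])
            window = window[1:]
        while ops and len(window) < 3:
            window.append(ops.pop(0))
        # Fuse if the window matches the pattern (stand-in for the
        # isinstance/quantizer checks in the real pass).
        if window[1:] == ["bias", "gelu"]:
            window = [window[0], "bias+gelu"]
    # Flush whatever remains in the window.
    out.extend(window)
    return out


print(fuse_window(["linear", "bias", "gelu", "add"]))  # ['linear', 'bias+gelu', 'add']
```

Note that each outer iteration either fuses or moves one op from `ops` into `out` via the shift step, so the sketch terminates even when the pattern never matches.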
logic: The sliding-window logic has a potential infinite loop if the fusion condition is never met. The `while ops:` loop doesn't guarantee termination when the window size stays at 3 and no fusion occurs.
Description

This PR adds the `register_forward_fusion` and `register_backward_fusion` functions to the op fuser API, allowing users to register custom fusions.

Type of change
Changes
Checklist: