-
Notifications
You must be signed in to change notification settings - Fork 809
Allow symints to be created for arguments #16620
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16620
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit f350cc2 with merge base 7492d0d ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Add test for creating args of SymInt type to be able to use them in view_copy nodes in the Arm TOSA backend together with the fix to make the pass work. Signed-off-by: Per Åstrand <per.astrand@arm.com> Change-Id: Ia947b8426af1b473df415a17e10f3db1582b84fd
00a24ad to
f350cc2
Compare
|
Hi @SS-JIA / @metascroy this is touching files outside Arm code and need a review if possible |
| if not hasattr(a, "constant") or a.constant is None: | ||
| raise ExportPassBaseError(f"Cannot add {a} to graph.") | ||
| a = a.constant | ||
| elif isinstance(a, torch.SymInt): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
elif isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool)):
and add corresponding unit test for symfloat and symbool please
thank you for finding this bug btw @per
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure how to trigger the SymFloat and SymBool paths here, since it comes from the dynamic shape export, which implies SymInts only, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, the proposed tests doesn't trigger the bug, since the symbool/symfloat aren't part of a list argument (as the shape argument is in view_copy). The tests you linked to are properly handled already (scalars and tensors).
The only alternative I've come up with is to manually construct a symbol with a shape_env, but that feels like a really constructed way to trigger the bug:
shape_env = ShapeEnv()
sym_bool = shape_env.create_unbacked_symbool()
tracer_owner = ExportPass()
tracer = tracer_owner.tracer
tracer.create_arg([sym_bool])
I've haven't been able to track down any operator that takes a list of bools or list of floats as argument, so IMHO it makes sense to keep the test as is, since that is what is exercised through the normal export flow. Or am I missing something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, thank you for explaining and walking me through the code
|
Other than Mergen's comment, LGTM 👍 |
|
@pytorchbot cherry-pick --onto release/1.1 -c fixnewfeature |
Cherry picking #16620The cherry pick PR is at #16774 and it is recommended to link a fixnewfeature cherry pick PR with an issue. Details for Dev Infra teamRaised by workflow job |
Summary
Add test for creating args of SymInt type to be able to use them in view_copy nodes together with the fix to make the test pass.
Test plan
Tested through CI tests.
cc @freddan80 @zingo @oscarandersson8218 @digantdesai