Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Add stack_trace on make_fx #155155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
Loading
from
Open

Conversation

yushangdi
Copy link
Contributor

@yushangdi yushangdi commented Jun 4, 2025

Summary:
Previosuly, we only add stack trace in class _ModuleStackTracer(PythonKeyTracer) for non-strict export. I moved this stack trace logic to the parent class PythonKeyTracer, this way the graph traced from Module using make_fx will have stack_trace as well.

Motivation: we've observed some uses cases where users first use make_fx on the Module, and then run export on the resulting graph. If the result of make_fx doesn't have stack trace, the stack trace information is lost.

Test Plan:

buck run test:test_export -- -r  test_stack_trace

Rollback Plan:

Differential Revision: D75985427

cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv

Copy link

pytorch-bot bot commented Jun 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155155

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Cancelled Jobs, 2 Unrelated Failures

As of commit 3291286 with merge base 1ccc57e (image):

NEW FAILURE - The following job has failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D75985427

Copy link
Contributor

@angelayi angelayi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paraphrased from @zou3519: can we check that the stacktrace added here doesn't affect torch.compile node stacktraces, where dynamo generates the stacktrace, and then this codepath gets called in AOTAutograd?

pretty sure this code doesn't affect the torch.compile stacktraces because of the if "stack_trace" not in node.meta conditional, but let's just check to make sure

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 4, 2025
@yushangdi
Copy link
Contributor Author

yushangdi commented Jun 4, 2025

paraphrased from @zou3519: can we check that the stacktrace added here doesn't affect torch.compile node stacktraces, where dynamo generates the stacktrace, and then this codepath gets called in AOTAutograd?

pretty sure this code doesn't affect the torch.compile stacktraces because of the if "stack_trace" not in node.meta conditional, but let's just check to make sure

Yep it's still there!

        inp = torch.randn(4, 4)

        from torch._dynamo.backends.common import aot_autograd
        from functorch.compile import make_boxed_func

        def my_compiler(gm, example_inputs):
            gm.print_readable()
            return make_boxed_func(gm.forward)

        my_backend = aot_autograd(fw_compiler=my_compiler)

        result = torch.compile(Foo(), backend=my_backend)(inp)
class GraphModule(torch.nn.Module):
 def forward(self, primals_1: "f32[4, 4]", primals_2: "f32[4]", primals_3: "f32[4, 4]"):
      # File: /data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/a7ab72042f458866/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py:417 in forward, code: x = self.linear(x)
     t: "f32[4, 4]" = torch.ops.aten.t.default(primals_1);  primals_1 = None
     addmm: "f32[4, 4]" = torch.ops.aten.addmm.default(primals_2, primals_3, t);  primals_2 = t = None
     
      # File: /data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/a7ab72042f458866/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py:418 in forward, code: x *= 2.0
     mul: "f32[4, 4]" = torch.ops.aten.mul.Tensor(addmm, 2.0);  addmm = None
     return (mul, primals_3)

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D75985427

yushangdi added a commit to yushangdi/pytorch that referenced this pull request Jun 4, 2025
Summary:
Pull Request resolved: pytorch#155155

Previosuly, we only add stack trace in `class _ModuleStackTracer(PythonKeyTracer)` for non-strict export. I moved this stack trace logic to the parent class `PythonKeyTracer`, this way the graph traced from Module using make_fx will have stack_trace as well.

Motivation: we've observed some uses cases where users first use `make_fx` on the Module, and then run `export` on the resulting graph. If the result of `make_fx` doesn't have stack trace, the stack trace information is lost.

Test Plan:
```
buck run test:test_export -- -r  test_stack_trace
```

Rollback Plan:

Reviewed By: angelayi

Differential Revision: D75985427
Comment on lines 1133 to 1134
frame.filename.endswith("fx/_symbolic_trace.py")
or frame.filename.endswith("export/_trace.py")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
frame.filename.endswith("fx/_symbolic_trace.py")
or frame.filename.endswith("export/_trace.py")
frame.filename.endswith(("fx/_symbolic_trace.py", "export/_trace.py"))

More readable and more efficient (does trie based matching).

yushangdi added a commit to yushangdi/pytorch that referenced this pull request Jun 5, 2025
Summary:

Previosuly, we only add stack trace in `class _ModuleStackTracer(PythonKeyTracer)` for non-strict export. I moved this stack trace logic to the parent class `PythonKeyTracer`, this way the graph traced from Module using make_fx will have stack_trace as well.

Motivation: we've observed some uses cases where users first use `make_fx` on the Module, and then run `export` on the resulting graph. If the result of `make_fx` doesn't have stack trace, the stack trace information is lost.

Test Plan:
```
buck run test:test_export -- -r  test_stack_trace
```

Rollback Plan:

Reviewed By: angelayi

Differential Revision: D75985427
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D75985427

yushangdi added a commit to yushangdi/pytorch that referenced this pull request Jun 5, 2025
Summary:
Pull Request resolved: pytorch#155155

Previosuly, we only add stack trace in `class _ModuleStackTracer(PythonKeyTracer)` for non-strict export. I moved this stack trace logic to the parent class `PythonKeyTracer`, this way the graph traced from Module using make_fx will have stack_trace as well.

Motivation: we've observed some uses cases where users first use `make_fx` on the Module, and then run `export` on the resulting graph. If the result of `make_fx` doesn't have stack trace, the stack trace information is lost.

Test Plan:
```
buck run test:test_export -- -r  test_stack_trace
```

Rollback Plan:

Reviewed By: angelayi

Differential Revision: D75985427
Summary:

Previosuly, we only add stack trace in `class _ModuleStackTracer(PythonKeyTracer)` for non-strict export. I moved this stack trace logic to the parent class `PythonKeyTracer`, this way the graph traced from Module using make_fx will have stack_trace as well.

Motivation: we've observed some uses cases where users first use `make_fx` on the Module, and then run `export` on the resulting graph. If the result of `make_fx` doesn't have stack trace, the stack trace information is lost.

Test Plan:
```
buck run test:test_export -- -r  test_stack_trace
```

Rollback Plan:

Reviewed By: angelayi

Differential Revision: D75985427
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D75985427

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request fb-exported fx release notes: fx release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants
Morty Proxy This is a proxified and sanitized view of the page, visit original site.