[dynamic shapes] unbacked safe conv1d #154089
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154089
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure. As of commit 3fe085e with merge base ab6cb85, the following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
            return torch.contiguous_format
        elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format
```
I might be wrong, but it seems:

- the result is used in a `.to()` call here (pytorch/torch/_meta_registrations.py, line 2472 in 7053930):

  ```python
  out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
  ```

- `memory_format=None` is the same as `memory_format=torch.preserve_format`
- I'm not sure "contiguous according to `torch.preserve_format`" makes sense? That format is mostly used when copying a tensor, to say "keep the original format"; how can a tensor be contiguous w.r.t. that? There are some indications this isn't supported (pytorch/aten/src/ATen/native/TensorProperties.cpp, lines 115 to 124 in 7053930):

  ```cpp
  Tensor contiguous(const Tensor& self, MemoryFormat memory_format) {
    if (self.is_contiguous(memory_format)) {
      return self;
    }
    TORCH_CHECK(
        memory_format != MemoryFormat::Preserve,
        "preserve memory format is unsupported by the contiguous operator");
    return self.clone(memory_format);
  }
  ```
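For reference, that TORCH_CHECK is observable from eager Python. A minimal sketch (assuming current ATen behavior; note that an already-contiguous tensor returns early from `contiguous()` before the check is reached):

```python
import torch

# A non-contiguous tensor, so contiguous() cannot return early.
t = torch.randn(2, 3, 4).transpose(0, 1)

# preserve_format is accepted by clone()/to(), where it means
# "keep the source tensor's layout"...
u = t.clone(memory_format=torch.preserve_format)

# ...but contiguous() rejects it, per the TORCH_CHECK quoted above:
try:
    t.contiguous(memory_format=torch.preserve_format)
except RuntimeError as e:
    print(e)  # preserve memory format is unsupported by the contiguous operator
```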
@bdhirsh is that something you are familiar with?
```diff
@@ -2447,10 +2450,8 @@ def pick_memory_format():
     else:
         if is_channels_last(input_tensor):
             return torch.channels_last
-        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
+        if utils.definitely_contiguous_for_memory_format(input_tensor, memory_format=torch.contiguous_format):
```
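For context on what "definitely contiguous" means here: unlike `Tensor.is_contiguous`, which can install a guard (or raise a data-dependent error) when sizes or strides contain unbacked SymInts, the new helper is meant to answer conservatively. A rough sketch of the idea, not the PR's actual implementation, assuming the `guard_or_false` helper from `torch.fx.experimental.symbolic_shapes`:

```python
import torch
from torch.fx.experimental.symbolic_shapes import guard_or_false

# Sketch of a guard-free contiguity test for torch.contiguous_format.
# When a size/stride comparison depends on an unbacked symbol,
# guard_or_false returns False instead of guarding, so the check
# conservatively reports "not definitely contiguous".
def definitely_contiguous_sketch(t: torch.Tensor) -> bool:
    expected_stride = 1
    for size, stride in reversed(list(zip(t.shape, t.stride()))):
        # A size-1 dim is contiguous regardless of its stride; otherwise
        # the stride must provably equal the running product of sizes.
        if not (guard_or_false(size == 1) or guard_or_false(stride == expected_stride)):
            return False
        expected_stride = expected_stride * size
    return True
```

Returning False in the undecidable case means `pick_memory_format()` falls through to a default instead of crashing on a data-dependent guard, which is the "unbacked safe" part of the PR title.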
We are just changing the meta function. Is there a decomposition for conv that asserts matching behaviour with the meta for unbacked?
Or does the conv kernel generate the correct expected memory format anyway? If the latter, I wonder whether it's safe to make the change.
```diff
@@ -2447,10 +2450,8 @@ def pick_memory_format():
     else:
         if is_channels_last(input_tensor):
             return torch.channels_last
-        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
+        if utils.definitely_contiguous_for_memory_format(input_tensor, memory_format=torch.contiguous_format):
```
Do you want to return None explicitly?
We need the striding of convolution to be correct because we fall back to ATen convolution for it. If we don't guard and we get the striding wrong, the Inductor output code will be incorrect.
We would need a mechanism to force the contiguous/channels-last path for the convolution when we're not able to infer the meta striding, but that's not implemented here; we're just changing the meta.
Ran into this while exporting https://github.com/SWivid/F5-TTS.
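For anyone who wants to reproduce the failure class without F5-TTS, here is a hypothetical minimal sketch (the module, shapes, and mask are invented for illustration, and depending on the PyTorch version additional `torch._check` calls may be needed). Slicing with a length derived from `nonzero()` gives conv1d an input whose last dimension is an unbacked SymInt, so the old `is_contiguous` call in the meta has to compare a concrete stride against that symbol:

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(4, 8, kernel_size=3)

    def forward(self, x, mask):
        u = mask.nonzero().numel()  # u is an unbacked SymInt under export
        torch._check(u >= 3)        # conv1d needs length >= kernel_size
        # y has sizes (2, 4, u) and strides (64, 16, 1); checking contiguity
        # requires deciding 16 == u * 1, which is data-dependent.
        y = x[:, :, :u]
        return self.conv(y)

ep = torch.export.export(
    M(),
    (torch.randn(2, 4, 16), torch.tensor([1, 0, 1, 1, 0, 1, 1, 1])),
)
```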
cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv