Improve torch.ops typing #154555


Open · wants to merge 1 commit into base: main

Conversation

aorenste
Contributor

@aorenste aorenste commented May 28, 2025

Summary:
Cloned #153558 from benjaminglass1 and fixed internal typing errors.

Fixes a longstanding issue where direct references to aten operations are seen as untyped by type checkers. This is accomplished by setting attributes on several classes more consistently, so that `__getattr__` can return a single type in all remaining cases.

Decisions made along the way:

  1. `torch.ops.higher_order` is now implemented by a single-purpose class. This was effectively true before, but the class implementing it was unnecessarily generalized. Fixing this simplified typing for the `_Ops` class.
  2. `__getattr__` is only called when all other lookup methods have failed, so several constant special cases in the function could be implemented as class variables (see the sketch after this summary).

The remainder of this PR fixes the bugs exposed by the updated typing, along with assorted nitpicky typing issues.
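
A minimal, hypothetical sketch of the pattern behind decision 2 and the single-return-type idea (illustration only, not the actual torch/_ops.py code): constant special cases become class attributes, so `__getattr__` only handles dynamic lookups and can promise a single return type.

# Illustration only; the real implementation lives in torch/_ops.py.
from types import ModuleType


class _FakeOpPacket:
    def __init__(self, qualified_name: str) -> None:
        self.qualified_name = qualified_name

    def __call__(self, *args: object, **kwargs: object) -> str:
        return f"called {self.qualified_name}"


class _FakeOpNamespace(ModuleType):
    # Constant special cases are ordinary class attributes, so normal attribute
    # lookup finds them and __getattr__ is never consulted for them.
    __file__ = "torch.ops"

    def __init__(self, name: str) -> None:
        super().__init__(f"torch.ops.{name}")
        self._name = name

    # __getattr__ runs only after every other lookup has failed, so it can be
    # annotated with a single return type that type checkers understand.
    def __getattr__(self, op_name: str) -> _FakeOpPacket:
        packet = _FakeOpPacket(f"{self._name}::{op_name}")
        setattr(self, op_name, packet)  # cache so later lookups bypass __getattr__
        return packet


aten_like = _FakeOpNamespace("aten")
print(aten_like.add(1, 2))  # a type checker sees _FakeOpPacket here, not Any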

Test Plan: CI

Differential Revision: D75497142

Co-authored-by: Benjamin Glass bglass@quansight.com

cc @ezyang @malfet @xuzhao9 @gramster @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov


pytorch-bot bot commented May 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154555

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 3 Unrelated Failures

As of commit 762b0b3 with merge base 093fd47:

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D75497142

@aorenste changed the title from TYPING: #153558 to Improve torch.ops typing on May 28, 2025
@pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label May 28, 2025
torch/_ops.py: review comment (outdated, resolved)
@Skylion007 added the module: typing (Related to mypy type annotations) label May 28, 2025
torch/_ops.py: review comment (outdated, resolved)
torch/_ops.py: review comment (outdated, resolved)
torch/_ops.py: review comment (outdated, resolved)
torch/_inductor/decomposition.py: review comment (outdated, resolved)
torch/_ops.py: review comment (outdated, resolved)
torch/_ops.py: review comment (outdated, resolved)
@benjaminglass1
Collaborator

@aorenste Would you like me to make the changes for this? I could send you a patch to apply.

@aorenste
Contributor Author

aorenste commented May 30, 2025

> @aorenste Would you like me to make the changes for this? I could send you a patch to apply.

@benjaminglass1 Could you send me a patch? I've been fixing up the fallout from #154515.

@benjaminglass1
Collaborator

@aorenste I've tested this patch locally, if only minimally. It appears to be functional, although we'll let the CI pipeline confirm. It contains fixes for all the suggestions above, as well as some additional genericization in `OpOverload`.

diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 4fcb847ddd6..a13d93e26e0 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1631,7 +1631,7 @@ class Generator:
 class _DispatchOperatorHandle:
     def schema(self) -> FunctionSchema: ...
     def debug(self) -> str: ...
-    def redispatch_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Stack: ...
+    def redispatch_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Any: ...
 
 class _DispatchModule:
     def reset(self) -> None: ...
diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py
index 1e03146bbcc..a4103eb8387 100644
--- a/torch/_dispatch/python.py
+++ b/torch/_dispatch/python.py
@@ -3,12 +3,15 @@ import itertools
 import unittest.mock
 from collections.abc import Iterator
 from contextlib import contextmanager
+from typing import Callable, TypeVar, Union
+from typing_extensions import ParamSpec
 
 import torch
 import torch._C
 import torch._ops
 import torch.utils._python_dispatch
 import torch.utils._pytree as pytree
+from torch._C import DispatchKey
 
 
 __all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
@@ -19,6 +22,9 @@ enable_pre_dispatch = torch._C._EnablePreDispatch
 
 CROSSREF_FUNCTIONALIZE = False
 
+_P = ParamSpec("_P")
+_T = TypeVar("_T")
+
 
 def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
     """
@@ -103,14 +109,16 @@ def _fmt(a: object) -> object:
         return a
 
 
-def make_crossref_functionalize(op, final_key):
+def make_crossref_functionalize(
+    op: torch._ops.OpOverload[_P, _T], final_key: DispatchKey
+) -> Union[Callable[_P, _T], DispatchKey]:
     from torch._subclasses.fake_tensor import FakeTensorMode
 
     # This case is pretty weird, suppress it for now
     if op == torch.ops.aten.lift_fresh.default:
         return final_key
 
-    def handler(*args, **kwargs):
+    def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
         fake_mode = FakeTensorMode()
 
         def fakeify_defun(t):
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 42a2c94d3b8..98a93fe6eb9 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -6,7 +6,7 @@ import operator
 import sys
 import typing
 from typing import Any, Callable, Optional, TypeVar, Union
-from typing_extensions import ParamSpec
+from typing_extensions import ParamSpec, TypeAlias
 
 import torch
 import torch._decomp as decomp
@@ -51,6 +51,10 @@ from .utils import (
 _T = TypeVar("_T")
 _P = ParamSpec("_P")
 
+_GenericOperator: TypeAlias = Union[
+    torch._ops.OperatorBase, torch._ops.OpOverloadPacket
+]
+
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
 prims = torch.ops.prims
@@ -132,9 +136,9 @@ remove_decompositions(decompositions, decomps_to_exclude)
 
 
 def register_decomposition(
-    ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]],
+    ops: Union[_GenericOperator, list[_GenericOperator]],
 ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
-    for op in [ops] if callable(ops) else ops:  # type: ignore[attr-defined]
+    for op in ops if isinstance(ops, list) else [ops]:
         if op in decompositions:
             log.warning("duplicate decomp: %s", ops)
     return decomp.register_decomposition(ops, decompositions)
@@ -523,7 +527,7 @@ def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
     return torch.where(torch.isnan(other) | (other < self), self, other)
 
 
-@register_decomposition([aten.amax])
+@register_decomposition(aten.amax)
 def amax(
     self: torch.Tensor,
     dim: Optional[int] = None,
@@ -534,7 +538,7 @@ def amax(
     return NotImplemented
 
 
-@register_decomposition([aten.amin])
+@register_decomposition(aten.amin)
 def amin(
     self: torch.Tensor,
     dim: Optional[int] = None,
@@ -582,7 +586,7 @@ def get_like_layout(
         return memory_format
 
 
-@register_decomposition([aten.rand_like])
+@register_decomposition(aten.rand_like)
 def rand_like(
     self: torch.Tensor,
     *,
@@ -599,7 +603,7 @@ def rand_like(
     ).to(memory_format=get_like_layout(self, memory_format))
 
 
-@register_decomposition([aten.randn_like])
+@register_decomposition(aten.randn_like)
 def randn_like(
     self: torch.Tensor,
     *,
@@ -616,7 +620,7 @@ def randn_like(
     ).to(memory_format=get_like_layout(self, memory_format))
 
 
-@register_decomposition([aten.full_like])
+@register_decomposition(aten.full_like)
 def full_like(
     self: torch.Tensor,
     fill_value: Union[int, float],
@@ -638,7 +642,7 @@ def full_like(
     ).to(memory_format=get_like_layout(self, memory_format))
 
 
-@register_decomposition([aten.randint_like.default])
+@register_decomposition(aten.randint_like.default)
 def randint_like(
     self: torch.Tensor,
     high: int,
@@ -658,7 +662,7 @@ def randint_like(
     ).to(memory_format=get_like_layout(self, memory_format))
 
 
-@register_decomposition([aten.randint_like.low_dtype])
+@register_decomposition(aten.randint_like.low_dtype)
 def randint_like_low(
     self: torch.Tensor,
     low: int,
@@ -679,7 +683,7 @@ def randint_like_low(
     ).to(memory_format=get_like_layout(self, memory_format))
 
 
-@register_decomposition([aten.randint.default])
+@register_decomposition(aten.randint.default)
 def randint(
     high: int,
     size: list[Union[int, torch.SymInt]],
@@ -688,7 +692,7 @@ def randint(
     return aten.randint.low(0, high, size, **kwargs)
 
 
-@register_decomposition([quantized.linear_dynamic_fp16_unpacked_weight.default])
+@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
 def linear_dynamic_fp16_unpacked_weight(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -700,7 +704,7 @@ def linear_dynamic_fp16_unpacked_weight(
     )
 
 
-@register_decomposition([_quantized.wrapped_quantized_linear.default])
+@register_decomposition(_quantized.wrapped_quantized_linear.default)
 def wrapped_quantized_linear(
     input: torch.Tensor,
     input_scale: torch.Tensor,
@@ -727,7 +731,7 @@ def wrapped_quantized_linear(
     )
 
 
-@register_decomposition([torch.ops.quantized.embedding_bag_byte_unpack])
+@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
 def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor:
     def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor:
         x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
@@ -772,7 +776,7 @@ def grid_sampler_2d(
     return output
 
 
-@register_decomposition([aten._foreach_addcmul.Scalar])
+@register_decomposition(aten._foreach_addcmul.Scalar)
 def _foreach_addcmul_scalar(
     self: list[torch.Tensor],
     left_tensors: list[torch.Tensor],
@@ -784,7 +788,7 @@ def _foreach_addcmul_scalar(
     )
 
 
-@register_decomposition([aten._foreach_addcdiv.Scalar])
+@register_decomposition(aten._foreach_addcdiv.Scalar)
 def _foreach_addcdiv_scalar(
     self: list[torch.Tensor],
     left_tensors: list[torch.Tensor],
@@ -796,7 +800,7 @@ def _foreach_addcdiv_scalar(
     )
 
 
-@register_decomposition([aten._foreach_lerp.Scalar])
+@register_decomposition(aten._foreach_lerp.Scalar)
 def _foreach_lerp_scalar(
     start_tensors: list[torch.Tensor],
     end_tensors: list[torch.Tensor],
@@ -810,7 +814,7 @@ def _foreach_lerp_scalar(
     )
 
 
-@register_decomposition([aten._foreach_lerp.ScalarList])
+@register_decomposition(aten._foreach_lerp.ScalarList)
 def _foreach_lerp_scalarlist(
     start_tensors: list[torch.Tensor],
     end_tensors: list[torch.Tensor],
@@ -825,7 +829,7 @@ def _foreach_lerp_scalarlist(
 
 
 @aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
-@register_decomposition([aten.miopen_batch_norm])
+@register_decomposition(aten.miopen_batch_norm)
 def miopen_batch_norm(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -870,7 +874,7 @@ def select_decomp_table() -> dict[Any, Callable[..., Any]]:
     return fast_random_decomps()
 
 
-@register_decomposition([aten.masked_scatter])
+@register_decomposition(aten.masked_scatter)
 def masked_scatter(
     self: torch.Tensor,
     mask: torch.Tensor,
@@ -889,7 +893,7 @@ def masked_scatter(
     return NotImplemented
 
 
-@register_decomposition([quantized_decomposed.choose_qparams.tensor])
+@register_decomposition(quantized_decomposed.choose_qparams.tensor)
 def choose_qparams_tensor(
     input: torch.Tensor,
     quant_min: int,
@@ -905,7 +909,7 @@ def choose_qparams_tensor(
     return scale.to(torch.float64), zero_point.to(torch.int64)
 
 
-@register_decomposition([aten.put])
+@register_decomposition(aten.put)
 def put(
     self: torch.Tensor,
     index: torch.Tensor,
@@ -919,7 +923,7 @@ def put(
     return flattened.reshape(self.shape)
 
 
-@register_decomposition([aten.put_])
+@register_decomposition(aten.put_)
 def put_(
     self: torch.Tensor,
     index: torch.Tensor,
@@ -930,7 +934,7 @@ def put_(
     return self.copy_(out)
 
 
-@register_decomposition([aten._softmax_backward_data.default])
+@register_decomposition(aten._softmax_backward_data.default)
 @pw_cast_for_opmath
 def _softmax_backward_data(
     grad_output: torch.Tensor,
@@ -952,7 +956,7 @@ def _softmax_backward_data(
     return grad_input.contiguous()
 
 
-@register_decomposition([aten.index_reduce])
+@register_decomposition(aten.index_reduce)
 def index_reduce(
     self: torch.Tensor,
     dim: int,
@@ -1057,7 +1061,7 @@ def _max_pool_with_indices(
     return vals, indices
 
 
-@register_decomposition([aten.max_pool2d_with_indices])
+@register_decomposition(aten.max_pool2d_with_indices)
 def max_pool2d_with_indices(
     x: torch.Tensor,
     kernel_size: list[int],
@@ -1071,7 +1075,7 @@ def max_pool2d_with_indices(
     )
 
 
-@register_decomposition([aten.max_pool3d_with_indices])
+@register_decomposition(aten.max_pool3d_with_indices)
 def max_pool3d_with_indices(
     x: torch.Tensor,
     kernel_size: list[int],
@@ -1085,7 +1089,7 @@ def max_pool3d_with_indices(
     )
 
 
-@register_decomposition([aten.adaptive_max_pool2d])
+@register_decomposition(aten.adaptive_max_pool2d)
 def adaptive_max_pool2d(
     x: torch.Tensor, output_size: list[int]
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -1103,7 +1107,7 @@ def adaptive_max_pool2d(
     return NotImplemented
 
 
-@register_decomposition([aten.searchsorted.Scalar])
+@register_decomposition(aten.searchsorted.Scalar)
 def searchsorted_scalar(
     sorted_sequence: torch.Tensor,
     self: torch.types.Number,
@@ -1123,7 +1127,7 @@ def searchsorted_scalar(
     )[0]
 
 
-@register_decomposition([aten.rrelu_with_noise_functional])
+@register_decomposition(aten.rrelu_with_noise_functional)
 def rrelu_with_noise_functional(
     self: torch.Tensor,
     noise: torch.Tensor,
diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py
index d5b8a86b868..dc48f0b804d 100644
--- a/torch/_inductor/fx_passes/reinplace.py
+++ b/torch/_inductor/fx_passes/reinplace.py
@@ -180,10 +180,9 @@ def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
     _, _, view_ops = node.args
     view_ops = cast(Sequence[torch.fx.node.Argument], view_ops)
     return any(
-        view.target in _ALWAYS_MUTATING_SCATTER_OPS
+        target in _ALWAYS_MUTATING_SCATTER_OPS
         for view in view_ops
-        if isinstance(view, torch.fx.Node)
-        and isinstance(view.target, torch._ops.OpOverload)
+        if isinstance(target := getattr(view, "target", None), torch._ops.OpOverload)
     )
 
 
diff --git a/torch/_ops.py b/torch/_ops.py
index 9a8a97926b2..b46d1636207 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -6,18 +6,19 @@ import importlib
 import inspect
 import sys
 import types
+from collections.abc import Iterator
 from functools import cached_property
 from typing import (
     Any,
     Callable,
     ClassVar,
     final,
+    Generic,
     Optional,
     TYPE_CHECKING,
-    TypeVar,
     Union,
 )
-from typing_extensions import Concatenate, ParamSpec
+from typing_extensions import Concatenate, ParamSpec, TypeVar
 
 import torch
 import torch.utils._pytree as pytree
@@ -31,8 +32,8 @@ if TYPE_CHECKING:
     from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
 
 
-_T = TypeVar("_T")
-_P = ParamSpec("_P")
+_T = TypeVar("_T", default=Any)
+_P = ParamSpec("_P", default=...)
 
 
 # Query `hasattr` only once.
@@ -744,7 +745,7 @@ def get_cached_ops():
 
 # Each OpOverload object contains pointer to a specific operator overload, a pointer to the parent `OpOverloadPacket` object.
 # You can obtain an OpOverload object through attribute query on OpOverloadPacket.
-class OpOverload(OperatorBase):
+class OpOverload(OperatorBase, Generic[_P, _T]):
     def __init__(
         self,
         overloadpacket: "OpOverloadPacket",
@@ -789,13 +790,13 @@ class OpOverload(OperatorBase):
                 is_write = a.alias_info.is_write or is_write
         self.is_view = is_write is not None and not is_write
 
-    @property
-    def _namespace(self):
-        return self._schema.name.split("::")[0]
+    @cached_property
+    def _namespace(self) -> str:
+        return self._schema.name.split("::", maxsplit=1)[0]
 
-    @property
-    def _opname(self):
-        return self._schema.name.split("::")[1]
+    @cached_property
+    def _opname(self) -> str:
+        return self._schema.name.split("::", maxsplit=1)[1]
 
     @cached_property
     def _handle(self) -> torch._C._DispatchOperatorHandle:
@@ -808,20 +809,18 @@ class OpOverload(OperatorBase):
         return self
 
     def __repr__(self):
-        return "<OpOverload(op='{}.{}', overload='{}')>".format(
-            *self._schema.name.split("::"), self._overloadname
-        )
+        return f"<OpOverload(op='{self._namespace}.{self._opname}', overload='{self._overloadname}')>"
 
     # Use positional-only argument to avoid naming collision with aten ops arguments
     # that are named "self". This way, all the aten ops can be called by kwargs.
-    def __call__(self, /, *args, **kwargs):
+    def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
         return self._op(*args, **kwargs)
 
     # Use positional-only argument to avoid naming collision with aten ops arguments
     # that are named "self". This way, all the aten ops can be called by kwargs.
     def redispatch(
-        self, /, keyset: torch._C.DispatchKeySet, *args, **kwargs
-    ) -> "torch._C.Stack":
+        self, /, keyset: torch._C.DispatchKeySet, *args: _P.args, **kwargs: _P.kwargs
+    ) -> _T:
         return self._handle.redispatch_boxed(keyset, *args, **kwargs)
 
     def __hash__(self):
@@ -831,27 +830,27 @@ class OpOverload(OperatorBase):
     def __str__(self):
         return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
 
-    def has_kernel_for_dispatch_key(self, k):
+    def has_kernel_for_dispatch_key(self, k: DispatchKey) -> bool:
         return super().has_kernel_for_dispatch_key(
             k
         ) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k)
 
-    def has_kernel_for_any_dispatch_key(self, ks):
+    def has_kernel_for_any_dispatch_key(self, ks: torch._C.DispatchKeySet) -> bool:
         return torch._C._dispatch_has_kernel_for_any_dispatch_key(
             self.name(), ks
         ) or super().has_kernel_for_any_dispatch_key(ks)
 
     @property
-    def namespace(self):
-        return self._schema.name.split("::")[0]
+    def namespace(self) -> str:
+        return self._namespace
 
-    def _can_decompose(self):
+    def _can_decompose(self) -> bool:
         dk = DispatchKey.CompositeImplicitAutograd
         return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key(
             self.name(), dk
         )
 
-    def decompose(self, *args, **kwargs):
+    def decompose(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
         dk = DispatchKey.CompositeImplicitAutograd
         if dk in self.py_kernels:
             # NB: This branch is not too necessary anymore, because we can
@@ -872,11 +871,11 @@ class OpOverload(OperatorBase):
     # registering Autograd affects AutogradCPU).  del_dispatch is to be used
     # only if you are specifically modifying how get_dispatch handles a
     # particular input 'key'.
-    def _uncache_dispatch(self, key):
+    def _uncache_dispatch(self, key: DispatchKey) -> None:
         self._dispatch_cache.pop(key, None)
 
     # This implements the pre-computation logic for the Python dispatcher.
-    def _get_dispatch(self, key):
+    def _get_dispatch(self, key: DispatchKey) -> Union[DispatchKey, Callable[_P, _T]]:
         # This is only called upon a cache miss
         assert key not in self._dispatch_cache, f"{self} {key}"
 
@@ -886,7 +885,7 @@ class OpOverload(OperatorBase):
                 add_cached_op(self)
                 return key
 
-            def handler(*args, **kwargs):
+            def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
                 from torch.utils._python_dispatch import _get_current_dispatch_mode
 
                 # TODO: We also need to handle tensor subclasses here
@@ -926,7 +925,7 @@ class OpOverload(OperatorBase):
                 )
             ):
 
-                def handler(*args, **kwargs):
+                def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
                     @contextlib.contextmanager
                     def _temporarily_pop_modes_from_pre_dispatch():
                         top_mode = _pop_mode_from_pre_dispatch()
@@ -959,7 +958,7 @@ class OpOverload(OperatorBase):
             import torch._dispatch.python as pydispatch
 
             if pydispatch.CROSSREF_FUNCTIONALIZE:
-                handler = pydispatch.make_crossref_functionalize(self, final_key)
+                handler = pydispatch.make_crossref_functionalize(self, final_key)  # type: ignore[assignment]
                 if cache_result:
                     self._dispatch_cache[key] = handler
                     add_cached_op(self)
@@ -993,7 +992,7 @@ class OpOverload(OperatorBase):
 # schema consists of torch.ScriptObject (i.e. custom class) input.
 # TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python
 # when its inputs contain FakeScriptObject in a similar way as higher order ops.
-class TorchBindOpOverload(OpOverload):
+class TorchBindOpOverload(OpOverload[_P, _T]):
     def _fallthrough_keys(self) -> list[DispatchKey]:
         # TODO: we should be calling the fallback for these, but a fallthrough is almost close
         # enough to the fallback in most cases that we care about.
@@ -1042,7 +1041,7 @@ class TorchBindOpOverload(OpOverload):
 
     # Use positional-only argument to avoid naming collision with aten ops arguments
     # that are named "self". This way, all the aten ops can be called by kwargs.
-    def __call__(self, /, *args, **kwargs):
+    def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
         if _must_dispatch_in_python(args, kwargs):
             # When any inputs are FakeScriptObject, we need to
             # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher
@@ -1055,10 +1054,14 @@ class TorchBindOpOverload(OpOverload):
             # 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might
             #    cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction.
             with self._register_as_effectful_op_temporarily():
-                return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())
+                return self._dispatch_in_python(
+                    self._fallthrough_keys(), *args, **kwargs
+                )
         return self._op(*args, **kwargs)
 
-    def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
+    def _dispatch_in_python(
+        self, fallthrough_keys: list[DispatchKey], *args: _P.args, **kwargs: _P.kwargs
+    ) -> _T:
         non_fallthrough_keys = torch._C._dispatch_keyset_full()
         for key in fallthrough_keys:
             non_fallthrough_keys = non_fallthrough_keys.remove(key)
@@ -1079,7 +1082,9 @@ class TorchBindOpOverload(OpOverload):
                 self.name(), dispatch_key
             ):
                 return self._dispatch_in_python(
-                    args, kwargs, fallthrough_keys + [dispatch_key]
+                    fallthrough_keys + [dispatch_key],
+                    *args,
+                    **kwargs,
                 )
 
             raise RuntimeError(
@@ -1111,14 +1116,14 @@ def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
 
 # OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
 # You can obtain an OpOverload object through attribute query.
-class OpOverloadPacket:
+class OpOverloadPacket(Generic[_P, _T]):
     __file__: ClassVar[str] = "torch.ops"
 
     def __init__(
         self,
         qualified_op_name: str,
         op_name: str,
-        op: Callable[..., Any],
+        op: Callable[_P, _T],
         overload_names: list[str],
     ) -> None:
         # These attributes are accessible on the object through the properties
@@ -1158,7 +1163,7 @@ class OpOverloadPacket:
             for overload_name in self._overload_names
         }
 
-    def __getattr__(self, key: str) -> OpOverload:
+    def __getattr__(self, key: str) -> OpOverload[_P, _T]:
         # ensure that query for dunder attributes that does not exist on
         # opoverloadpacket but instead exists on the self._op object does not unnecessarily call
         # `_get_operation_overload` (which is an expensive operation).
@@ -1193,7 +1198,7 @@ class OpOverloadPacket:
 
             op_, op_dk_, tags = op_dk_tags
             schema = torch._C._get_schema(self._qualified_op_name, use_key)
-            overload = (
+            overload: OpOverload[_P, _T] = (
                 OpOverload(self, op_, op_dk_, schema, tags)
                 if not _has_script_object_arg(schema)
                 else TorchBindOpOverload(self, op_, op_dk_, schema, tags)
@@ -1207,12 +1212,12 @@ class OpOverloadPacket:
                 f"The underlying op of '{str(self)}' has no overload name '{key}'"
             ) from None
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[str]:
         return iter(self._dir)
 
     # Use positional-only argument to avoid naming collision with aten ops arguments
     # that are named "self". This way, all the aten ops can be called by kwargs.
-    def __call__(self, /, *args, **kwargs):
+    def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
         # overloading __call__ to ensure torch.ops.foo.bar()
         # is still callable from JIT
         # We save the function ptr as the `op` attribute on
@@ -1222,8 +1227,8 @@ class OpOverloadPacket:
         # the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
         # intercept it here and call TorchBindOpverload instead.
         if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
-            return _call_overload_packet_from_python(self, args, kwargs)
-        return self._op(*args, **(kwargs or {}))
+            return _call_overload_packet_from_python(self, *args, **kwargs)
+        return self._op(*args, **kwargs)
 
     # TODO: use this to make a __dir__
     def overloads(self):
@@ -1232,10 +1237,12 @@ class OpOverloadPacket:
 
 # Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp
 # _jit_get_operations, which calls _get_operation_for_overload_or_packet.
-def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs):
+def _call_overload_packet_from_python(
+    op: OpOverloadPacket[_P, _T], *args: _P.args, **kwargs: _P.kwargs
+) -> _T:
     # Re-use the torch function handling logic in cpp
     torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet(
-        op, *args, **kwargs
+        op, args, kwargs
     )
 
     if torch_function_called:
@@ -1252,7 +1259,7 @@ def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs):
         op_overload = getattr(op, overload_name)
         try:
             _ = torch._C._check_schema_allow_fake_script_object(
-                op_overload._schema, *args, **kwargs
+                op_overload._schema, args, kwargs
             )
             found_op = op_overload
             break
@@ -1313,7 +1320,7 @@ class _OpNamespace(types.ModuleType):
         self.name = name
         self._dir: list[str] = []
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[str]:
         return iter(self._dir)
 
     def __getattr__(self, op_name: str) -> OpOverloadPacket:
@@ -1377,10 +1384,10 @@ class _HigherOrderNamespace(types.ModuleType):
         super().__init__("torch.ops.higher_order")
         self._dir: list[str] = []
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[str]:
         return iter(self._dir)
 
-    def __getattr__(self, name) -> HigherOrderOperator:
+    def __getattr__(self, name: str) -> HigherOrderOperator:
         # Following _OpNamespace.__getattr__, we cache the op on this object.
         op = _higher_order_ops.get(name, None)
         if op is None:
@@ -1401,14 +1408,14 @@ class _Ops(types.ModuleType):
         self.higher_order = _HigherOrderNamespace()
         self._dir = []
 
-    def __getattr__(self, name) -> _OpNamespace:
+    def __getattr__(self, name: str) -> _OpNamespace:
         # Here we are creating `torch.ops.my_namespace`
         namespace = _OpNamespace(name)
         setattr(self, name, namespace)
         self._dir.append(name)
         return namespace
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[str]:
         return iter(self._dir)
 
     def import_module(self, module):
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index e08b5937260..83cf384d511 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -370,7 +370,9 @@ def out_wrapper(
             annotation=out_type,
         )
         # Mark that the function now returns a tuple
-        assert isinstance(sig.return_annotation, str) or sig.return_annotation in (
+        assert isinstance(
+            sig.return_annotation, (str, TypeVar)
+        ) or sig.return_annotation in (
             sig.empty,
             out_type,
             bc_out_type,
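
As a rough illustration of what the genericized `OpOverload`/`OpOverloadPacket` typing in the patch above is meant to buy (a sketch assuming the patch lands as posted; the exact inferred types depend on the type checker):

# Sketch: with OpOverload/OpOverloadPacket made Generic[_P, _T] (with Any/...
# defaults), torch.ops attribute access is no longer opaque to type checkers.
import torch

packet = torch.ops.aten.add            # an OpOverloadPacket rather than Any
overload = torch.ops.aten.add.Tensor   # an OpOverload rather than Any

x = torch.ones(3)
y = overload(x, x)                     # __call__ typed via _P/_T (defaults here)
print(type(packet).__name__, type(overload).__name__, y)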

@aorenste aorenste force-pushed the export-D75497142 branch from 0f63c94 to 678314e Compare June 2, 2025 03:43
pytorch-bot bot pushed a commit that referenced this pull request Jun 2, 2025
Summary: (same as the PR description above)

Test Plan: CI

Differential Revision: D75497142
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D75497142

Contributor

@zou3519 zou3519 left a comment

thanks

@benjaminglass1
Collaborator

@aorenste Working on this new batch of failures locally; thankfully, they all appear to be failing more or less identically, so this should hopefully be the last patch.

@benjaminglass1
Collaborator

@aorenste Found it; I made a mistake thanks to some incorrect typing in torch._C that I've patched up in this diff as well. Hopefully this is the last patch.

diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 4e2b6b4c8df..27e3603872e 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -452,13 +452,13 @@ ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]]
 #        and torch/csrc/jit/python/init.cpp
 def _maybe_call_torch_function_for_op_packet(
     op_overload_packet: Any,
-    args: Any,
-    kwargs: Any,
+    *args: Any,
+    **kwargs: Any,
 ) -> Any: ...
 def _check_schema_allow_fake_script_object(
     schema: FunctionSchema,
-    args: Any,
-    kwargs: Any,
+    *args: Any,
+    **kwargs: Any,
 ) -> _bool: ...
 def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ...
 def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
diff --git a/torch/_ops.py b/torch/_ops.py
index 34e56eece11..8455707aa9a 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -1242,7 +1242,7 @@ def _call_overload_packet_from_python(
 ) -> _T:
     # Re-use the torch function handling logic in cpp
     torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet(
-        op, args, kwargs
+        op, *args, **kwargs
     )
 
     if torch_function_called:
@@ -1259,7 +1259,7 @@ def _call_overload_packet_from_python(
         op_overload = getattr(op, overload_name)
         try:
             _ = torch._C._check_schema_allow_fake_script_object(
-                op_overload._schema, args, kwargs
+                op_overload._schema, *args, **kwargs
             )
             found_op = op_overload
             break
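
To make the stub fix above concrete, here is a small self-contained illustration (toy code, not torch internals) of why it matters whether a binding takes args/kwargs as two positional values or as *args/**kwargs:

# Toy stand-in for a C++ binding declared to take *args and **kwargs.
def binding(op: str, *args: object, **kwargs: object) -> tuple:
    return args, kwargs

call_args, call_kwargs = (1, 2), {"alpha": 3}

# Matches the binding's real signature: unpack at the call site.
print(binding("aten::add", *call_args, **call_kwargs))
# -> ((1, 2), {'alpha': 3})

# What the old stub suggested: pass the tuple and dict positionally.
print(binding("aten::add", call_args, call_kwargs))
# -> (((1, 2), {'alpha': 3}), {})  ...the callee sees the wrong shape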

@aorenste aorenste force-pushed the export-D75497142 branch from 678314e to 5aeba39 Compare June 2, 2025 17:20
aorenste added a commit to aorenste/pytorch that referenced this pull request Jun 2, 2025
Summary:
X-link: pytorch/executorch#11276

(same as the PR description above)

Test Plan: CI

Differential Revision: D75497142
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D75497142

@benjaminglass1
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/re-exporting the PR!

Details for Dev Infra team: raised by workflow job

aorenste added a commit to aorenste/executorch that referenced this pull request Jun 3, 2025
Summary:

X-link: pytorch/pytorch#154555

(same as the PR description above)

Reviewed By: bobrenjc93

Differential Revision: D75497142
aorenste added a commit to aorenste/executorch that referenced this pull request Jun 3, 2025
Summary:
Pull Request resolved: pytorch#11276

X-link: pytorch/pytorch#154555

(same as the PR description above)

Reviewed By: bobrenjc93

Differential Revision: D75497142
@aorenste
Contributor Author

aorenste commented Jun 3, 2025

> @pytorchbot merge

@benjaminglass1 PRs that originate with internal diffs have to be landed internally.

There's an internal tool that should let me check for internal breaks before landing, but ironically it's currently broken. I'm waiting for it to be fixed so I can make sure I don't repeat last week, when I got a report of almost 1,000 failing tests after landing...

@benjaminglass1
Collaborator

@aorenste Is the testing tool back up? I'd like to rebase and merge this, if we can.

@aorenste
Contributor Author

> @aorenste Is the testing tool back up? I'd like to rebase and merge this, if we can.

Not yet. I was told either yesterday or today...

Summary:
X-link: pytorch/executorch#11276

Pull Request resolved: pytorch#154555

(same as the PR description above)

Test Plan: CI

Reviewed By: bobrenjc93, mergennachin

Differential Revision: D75497142
aorenste added a commit to aorenste/executorch that referenced this pull request Jun 12, 2025
Summary:
Pull Request resolved: pytorch#11276

X-link: pytorch/pytorch#154555

(same as the PR description above)

Reviewed By: bobrenjc93, mergennachin

Differential Revision: D75497142
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D75497142
