Improve torch.ops typing #154555
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154555
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 3 Unrelated Failures as of commit 762b0b3 with merge base 093fd47. FLAKY: the following jobs failed but were likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D75497142
@aorenste Would you like me to make the changes for this? I could send you a patch to apply.
@benjaminglass1 Could you send me a patch? I've been fixing up the fallout from #154515.
@aorenste I've tested this patch, minimally, locally. It appears to be functional, although we'll let the pipeline see. It contains fixes for all the suggestions above, as well as some additional genericization:

diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 4fcb847ddd6..a13d93e26e0 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1631,7 +1631,7 @@ class Generator:
class _DispatchOperatorHandle:
def schema(self) -> FunctionSchema: ...
def debug(self) -> str: ...
- def redispatch_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Stack: ...
+ def redispatch_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Any: ...
class _DispatchModule:
def reset(self) -> None: ...
diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py
index 1e03146bbcc..a4103eb8387 100644
--- a/torch/_dispatch/python.py
+++ b/torch/_dispatch/python.py
@@ -3,12 +3,15 @@ import itertools
import unittest.mock
from collections.abc import Iterator
from contextlib import contextmanager
+from typing import Callable, TypeVar, Union
+from typing_extensions import ParamSpec
import torch
import torch._C
import torch._ops
import torch.utils._python_dispatch
import torch.utils._pytree as pytree
+from torch._C import DispatchKey
__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
@@ -19,6 +22,9 @@ enable_pre_dispatch = torch._C._EnablePreDispatch
CROSSREF_FUNCTIONALIZE = False
+_P = ParamSpec("_P")
+_T = TypeVar("_T")
+
def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
"""
@@ -103,14 +109,16 @@ def _fmt(a: object) -> object:
return a
-def make_crossref_functionalize(op, final_key):
+def make_crossref_functionalize(
+ op: torch._ops.OpOverload[_P, _T], final_key: DispatchKey
+) -> Union[Callable[_P, _T], DispatchKey]:
from torch._subclasses.fake_tensor import FakeTensorMode
# This case is pretty weird, suppress it for now
if op == torch.ops.aten.lift_fresh.default:
return final_key
- def handler(*args, **kwargs):
+ def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
fake_mode = FakeTensorMode()
def fakeify_defun(t):
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 42a2c94d3b8..98a93fe6eb9 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -6,7 +6,7 @@ import operator
import sys
import typing
from typing import Any, Callable, Optional, TypeVar, Union
-from typing_extensions import ParamSpec
+from typing_extensions import ParamSpec, TypeAlias
import torch
import torch._decomp as decomp
@@ -51,6 +51,10 @@ from .utils import (
_T = TypeVar("_T")
_P = ParamSpec("_P")
+_GenericOperator: TypeAlias = Union[
+ torch._ops.OperatorBase, torch._ops.OpOverloadPacket
+]
+
log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims
@@ -132,9 +136,9 @@ remove_decompositions(decompositions, decomps_to_exclude)
def register_decomposition(
- ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]],
+ ops: Union[_GenericOperator, list[_GenericOperator]],
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
- for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined]
+ for op in ops if isinstance(ops, list) else [ops]:
if op in decompositions:
log.warning("duplicate decomp: %s", ops)
return decomp.register_decomposition(ops, decompositions)
@@ -523,7 +527,7 @@ def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
return torch.where(torch.isnan(other) | (other < self), self, other)
-@register_decomposition([aten.amax])
+@register_decomposition(aten.amax)
def amax(
self: torch.Tensor,
dim: Optional[int] = None,
@@ -534,7 +538,7 @@ def amax(
return NotImplemented
-@register_decomposition([aten.amin])
+@register_decomposition(aten.amin)
def amin(
self: torch.Tensor,
dim: Optional[int] = None,
@@ -582,7 +586,7 @@ def get_like_layout(
return memory_format
-@register_decomposition([aten.rand_like])
+@register_decomposition(aten.rand_like)
def rand_like(
self: torch.Tensor,
*,
@@ -599,7 +603,7 @@ def rand_like(
).to(memory_format=get_like_layout(self, memory_format))
-@register_decomposition([aten.randn_like])
+@register_decomposition(aten.randn_like)
def randn_like(
self: torch.Tensor,
*,
@@ -616,7 +620,7 @@ def randn_like(
).to(memory_format=get_like_layout(self, memory_format))
-@register_decomposition([aten.full_like])
+@register_decomposition(aten.full_like)
def full_like(
self: torch.Tensor,
fill_value: Union[int, float],
@@ -638,7 +642,7 @@ def full_like(
).to(memory_format=get_like_layout(self, memory_format))
-@register_decomposition([aten.randint_like.default])
+@register_decomposition(aten.randint_like.default)
def randint_like(
self: torch.Tensor,
high: int,
@@ -658,7 +662,7 @@ def randint_like(
).to(memory_format=get_like_layout(self, memory_format))
-@register_decomposition([aten.randint_like.low_dtype])
+@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(
self: torch.Tensor,
low: int,
@@ -679,7 +683,7 @@ def randint_like_low(
).to(memory_format=get_like_layout(self, memory_format))
-@register_decomposition([aten.randint.default])
+@register_decomposition(aten.randint.default)
def randint(
high: int,
size: list[Union[int, torch.SymInt]],
@@ -688,7 +692,7 @@ def randint(
return aten.randint.low(0, high, size, **kwargs)
-@register_decomposition([quantized.linear_dynamic_fp16_unpacked_weight.default])
+@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
def linear_dynamic_fp16_unpacked_weight(
input: torch.Tensor,
weight: torch.Tensor,
@@ -700,7 +704,7 @@ def linear_dynamic_fp16_unpacked_weight(
)
-@register_decomposition([_quantized.wrapped_quantized_linear.default])
+@register_decomposition(_quantized.wrapped_quantized_linear.default)
def wrapped_quantized_linear(
input: torch.Tensor,
input_scale: torch.Tensor,
@@ -727,7 +731,7 @@ def wrapped_quantized_linear(
)
-@register_decomposition([torch.ops.quantized.embedding_bag_byte_unpack])
+@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor:
def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor:
x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
@@ -772,7 +776,7 @@ def grid_sampler_2d(
return output
-@register_decomposition([aten._foreach_addcmul.Scalar])
+@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(
self: list[torch.Tensor],
left_tensors: list[torch.Tensor],
@@ -784,7 +788,7 @@ def _foreach_addcmul_scalar(
)
-@register_decomposition([aten._foreach_addcdiv.Scalar])
+@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(
self: list[torch.Tensor],
left_tensors: list[torch.Tensor],
@@ -796,7 +800,7 @@ def _foreach_addcdiv_scalar(
)
-@register_decomposition([aten._foreach_lerp.Scalar])
+@register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar(
start_tensors: list[torch.Tensor],
end_tensors: list[torch.Tensor],
@@ -810,7 +814,7 @@ def _foreach_lerp_scalar(
)
-@register_decomposition([aten._foreach_lerp.ScalarList])
+@register_decomposition(aten._foreach_lerp.ScalarList)
def _foreach_lerp_scalarlist(
start_tensors: list[torch.Tensor],
end_tensors: list[torch.Tensor],
@@ -825,7 +829,7 @@ def _foreach_lerp_scalarlist(
@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
-@register_decomposition([aten.miopen_batch_norm])
+@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
input: torch.Tensor,
weight: torch.Tensor,
@@ -870,7 +874,7 @@ def select_decomp_table() -> dict[Any, Callable[..., Any]]:
return fast_random_decomps()
-@register_decomposition([aten.masked_scatter])
+@register_decomposition(aten.masked_scatter)
def masked_scatter(
self: torch.Tensor,
mask: torch.Tensor,
@@ -889,7 +893,7 @@ def masked_scatter(
return NotImplemented
-@register_decomposition([quantized_decomposed.choose_qparams.tensor])
+@register_decomposition(quantized_decomposed.choose_qparams.tensor)
def choose_qparams_tensor(
input: torch.Tensor,
quant_min: int,
@@ -905,7 +909,7 @@ def choose_qparams_tensor(
return scale.to(torch.float64), zero_point.to(torch.int64)
-@register_decomposition([aten.put])
+@register_decomposition(aten.put)
def put(
self: torch.Tensor,
index: torch.Tensor,
@@ -919,7 +923,7 @@ def put(
return flattened.reshape(self.shape)
-@register_decomposition([aten.put_])
+@register_decomposition(aten.put_)
def put_(
self: torch.Tensor,
index: torch.Tensor,
@@ -930,7 +934,7 @@ def put_(
return self.copy_(out)
-@register_decomposition([aten._softmax_backward_data.default])
+@register_decomposition(aten._softmax_backward_data.default)
@pw_cast_for_opmath
def _softmax_backward_data(
grad_output: torch.Tensor,
@@ -952,7 +956,7 @@ def _softmax_backward_data(
return grad_input.contiguous()
-@register_decomposition([aten.index_reduce])
+@register_decomposition(aten.index_reduce)
def index_reduce(
self: torch.Tensor,
dim: int,
@@ -1057,7 +1061,7 @@ def _max_pool_with_indices(
return vals, indices
-@register_decomposition([aten.max_pool2d_with_indices])
+@register_decomposition(aten.max_pool2d_with_indices)
def max_pool2d_with_indices(
x: torch.Tensor,
kernel_size: list[int],
@@ -1071,7 +1075,7 @@ def max_pool2d_with_indices(
)
-@register_decomposition([aten.max_pool3d_with_indices])
+@register_decomposition(aten.max_pool3d_with_indices)
def max_pool3d_with_indices(
x: torch.Tensor,
kernel_size: list[int],
@@ -1085,7 +1089,7 @@ def max_pool3d_with_indices(
)
-@register_decomposition([aten.adaptive_max_pool2d])
+@register_decomposition(aten.adaptive_max_pool2d)
def adaptive_max_pool2d(
x: torch.Tensor, output_size: list[int]
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -1103,7 +1107,7 @@ def adaptive_max_pool2d(
return NotImplemented
-@register_decomposition([aten.searchsorted.Scalar])
+@register_decomposition(aten.searchsorted.Scalar)
def searchsorted_scalar(
sorted_sequence: torch.Tensor,
self: torch.types.Number,
@@ -1123,7 +1127,7 @@ def searchsorted_scalar(
)[0]
-@register_decomposition([aten.rrelu_with_noise_functional])
+@register_decomposition(aten.rrelu_with_noise_functional)
def rrelu_with_noise_functional(
self: torch.Tensor,
noise: torch.Tensor,
diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py
index d5b8a86b868..dc48f0b804d 100644
--- a/torch/_inductor/fx_passes/reinplace.py
+++ b/torch/_inductor/fx_passes/reinplace.py
@@ -180,10 +180,9 @@ def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
_, _, view_ops = node.args
view_ops = cast(Sequence[torch.fx.node.Argument], view_ops)
return any(
- view.target in _ALWAYS_MUTATING_SCATTER_OPS
+ target in _ALWAYS_MUTATING_SCATTER_OPS
for view in view_ops
- if isinstance(view, torch.fx.Node)
- and isinstance(view.target, torch._ops.OpOverload)
+ if isinstance(target := getattr(view, "target", None), torch._ops.OpOverload)
)
diff --git a/torch/_ops.py b/torch/_ops.py
index 9a8a97926b2..b46d1636207 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -6,18 +6,19 @@ import importlib
import inspect
import sys
import types
+from collections.abc import Iterator
from functools import cached_property
from typing import (
Any,
Callable,
ClassVar,
final,
+ Generic,
Optional,
TYPE_CHECKING,
- TypeVar,
Union,
)
-from typing_extensions import Concatenate, ParamSpec
+from typing_extensions import Concatenate, ParamSpec, TypeVar
import torch
import torch.utils._pytree as pytree
@@ -31,8 +32,8 @@ if TYPE_CHECKING:
from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
-_T = TypeVar("_T")
-_P = ParamSpec("_P")
+_T = TypeVar("_T", default=Any)
+_P = ParamSpec("_P", default=...)
# Query `hasattr` only once.
@@ -744,7 +745,7 @@ def get_cached_ops():
# Each OpOverload object contains pointer to a specific operator overload, a pointer to the parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
-class OpOverload(OperatorBase):
+class OpOverload(OperatorBase, Generic[_P, _T]):
def __init__(
self,
overloadpacket: "OpOverloadPacket",
@@ -789,13 +790,13 @@ class OpOverload(OperatorBase):
is_write = a.alias_info.is_write or is_write
self.is_view = is_write is not None and not is_write
- @property
- def _namespace(self):
- return self._schema.name.split("::")[0]
+ @cached_property
+ def _namespace(self) -> str:
+ return self._schema.name.split("::", maxsplit=1)[0]
- @property
- def _opname(self):
- return self._schema.name.split("::")[1]
+ @cached_property
+ def _opname(self) -> str:
+ return self._schema.name.split("::", maxsplit=1)[1]
@cached_property
def _handle(self) -> torch._C._DispatchOperatorHandle:
@@ -808,20 +809,18 @@ class OpOverload(OperatorBase):
return self
def __repr__(self):
- return "<OpOverload(op='{}.{}', overload='{}')>".format(
- *self._schema.name.split("::"), self._overloadname
- )
+ return f"<OpOverload(op='{self._namespace}.{self._opname}', overload='{self._overloadname}')>"
# Use positional-only argument to avoid naming collision with aten ops arguments
# that are named "self". This way, all the aten ops can be called by kwargs.
- def __call__(self, /, *args, **kwargs):
+ def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
return self._op(*args, **kwargs)
# Use positional-only argument to avoid naming collision with aten ops arguments
# that are named "self". This way, all the aten ops can be called by kwargs.
def redispatch(
- self, /, keyset: torch._C.DispatchKeySet, *args, **kwargs
- ) -> "torch._C.Stack":
+ self, /, keyset: torch._C.DispatchKeySet, *args: _P.args, **kwargs: _P.kwargs
+ ) -> _T:
return self._handle.redispatch_boxed(keyset, *args, **kwargs)
def __hash__(self):
@@ -831,27 +830,27 @@ class OpOverload(OperatorBase):
def __str__(self):
return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
- def has_kernel_for_dispatch_key(self, k):
+ def has_kernel_for_dispatch_key(self, k: DispatchKey) -> bool:
return super().has_kernel_for_dispatch_key(
k
) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k)
- def has_kernel_for_any_dispatch_key(self, ks):
+ def has_kernel_for_any_dispatch_key(self, ks: torch._C.DispatchKeySet) -> bool:
return torch._C._dispatch_has_kernel_for_any_dispatch_key(
self.name(), ks
) or super().has_kernel_for_any_dispatch_key(ks)
@property
- def namespace(self):
- return self._schema.name.split("::")[0]
+ def namespace(self) -> str:
+ return self._namespace
- def _can_decompose(self):
+ def _can_decompose(self) -> bool:
dk = DispatchKey.CompositeImplicitAutograd
return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key(
self.name(), dk
)
- def decompose(self, *args, **kwargs):
+ def decompose(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
dk = DispatchKey.CompositeImplicitAutograd
if dk in self.py_kernels:
# NB: This branch is not too necessary anymore, because we can
@@ -872,11 +871,11 @@ class OpOverload(OperatorBase):
# registering Autograd affects AutogradCPU). del_dispatch is to be used
# only if you are specifically modifying how get_dispatch handles a
# particular input 'key'.
- def _uncache_dispatch(self, key):
+ def _uncache_dispatch(self, key: DispatchKey) -> None:
self._dispatch_cache.pop(key, None)
# This implements the pre-computation logic for the Python dispatcher.
- def _get_dispatch(self, key):
+ def _get_dispatch(self, key: DispatchKey) -> Union[DispatchKey, Callable[_P, _T]]:
# This is only called upon a cache miss
assert key not in self._dispatch_cache, f"{self} {key}"
@@ -886,7 +885,7 @@ class OpOverload(OperatorBase):
add_cached_op(self)
return key
- def handler(*args, **kwargs):
+ def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
from torch.utils._python_dispatch import _get_current_dispatch_mode
# TODO: We also need to handle tensor subclasses here
@@ -926,7 +925,7 @@ class OpOverload(OperatorBase):
)
):
- def handler(*args, **kwargs):
+ def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
@contextlib.contextmanager
def _temporarily_pop_modes_from_pre_dispatch():
top_mode = _pop_mode_from_pre_dispatch()
@@ -959,7 +958,7 @@ class OpOverload(OperatorBase):
import torch._dispatch.python as pydispatch
if pydispatch.CROSSREF_FUNCTIONALIZE:
- handler = pydispatch.make_crossref_functionalize(self, final_key)
+ handler = pydispatch.make_crossref_functionalize(self, final_key) # type: ignore[assignment]
if cache_result:
self._dispatch_cache[key] = handler
add_cached_op(self)
@@ -993,7 +992,7 @@ class OpOverload(OperatorBase):
# schema consists of torch.ScriptObject (i.e. custom class) input.
# TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python
# when its inputs contain FakeScriptObject in a similar way as higher order ops.
-class TorchBindOpOverload(OpOverload):
+class TorchBindOpOverload(OpOverload[_P, _T]):
def _fallthrough_keys(self) -> list[DispatchKey]:
# TODO: we should be calling the fallback for these, but a fallthrough is almost close
# enough to the fallback in most cases that we care about.
@@ -1042,7 +1041,7 @@ class TorchBindOpOverload(OpOverload):
# Use positional-only argument to avoid naming collision with aten ops arguments
# that are named "self". This way, all the aten ops can be called by kwargs.
- def __call__(self, /, *args, **kwargs):
+ def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
if _must_dispatch_in_python(args, kwargs):
# When any inputs are FakeScriptObject, we need to
# skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher
@@ -1055,10 +1054,14 @@ class TorchBindOpOverload(OpOverload):
# 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might
# cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction.
with self._register_as_effectful_op_temporarily():
- return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())
+ return self._dispatch_in_python(
+ self._fallthrough_keys(), *args, **kwargs
+ )
return self._op(*args, **kwargs)
- def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
+ def _dispatch_in_python(
+ self, fallthrough_keys: list[DispatchKey], *args: _P.args, **kwargs: _P.kwargs
+ ) -> _T:
non_fallthrough_keys = torch._C._dispatch_keyset_full()
for key in fallthrough_keys:
non_fallthrough_keys = non_fallthrough_keys.remove(key)
@@ -1079,7 +1082,9 @@ class TorchBindOpOverload(OpOverload):
self.name(), dispatch_key
):
return self._dispatch_in_python(
- args, kwargs, fallthrough_keys + [dispatch_key]
+ fallthrough_keys + [dispatch_key],
+ *args,
+ **kwargs,
)
raise RuntimeError(
@@ -1111,14 +1116,14 @@ def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
# You can obtain an OpOverload object through attribute query.
-class OpOverloadPacket:
+class OpOverloadPacket(Generic[_P, _T]):
__file__: ClassVar[str] = "torch.ops"
def __init__(
self,
qualified_op_name: str,
op_name: str,
- op: Callable[..., Any],
+ op: Callable[_P, _T],
overload_names: list[str],
) -> None:
# These attributes are accessible on the object through the properties
@@ -1158,7 +1163,7 @@ class OpOverloadPacket:
for overload_name in self._overload_names
}
- def __getattr__(self, key: str) -> OpOverload:
+ def __getattr__(self, key: str) -> OpOverload[_P, _T]:
# ensure that query for dunder attributes that does not exist on
# opoverloadpacket but instead exists on the self._op object does not unnecessarily call
# `_get_operation_overload` (which is an expensive operation).
@@ -1193,7 +1198,7 @@ class OpOverloadPacket:
op_, op_dk_, tags = op_dk_tags
schema = torch._C._get_schema(self._qualified_op_name, use_key)
- overload = (
+ overload: OpOverload[_P, _T] = (
OpOverload(self, op_, op_dk_, schema, tags)
if not _has_script_object_arg(schema)
else TorchBindOpOverload(self, op_, op_dk_, schema, tags)
@@ -1207,12 +1212,12 @@ class OpOverloadPacket:
f"The underlying op of '{str(self)}' has no overload name '{key}'"
) from None
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
return iter(self._dir)
# Use positional-only argument to avoid naming collision with aten ops arguments
# that are named "self". This way, all the aten ops can be called by kwargs.
- def __call__(self, /, *args, **kwargs):
+ def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
# overloading __call__ to ensure torch.ops.foo.bar()
# is still callable from JIT
# We save the function ptr as the `op` attribute on
@@ -1222,8 +1227,8 @@ class OpOverloadPacket:
# the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
# intercept it here and call TorchBindOpverload instead.
if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
- return _call_overload_packet_from_python(self, args, kwargs)
- return self._op(*args, **(kwargs or {}))
+ return _call_overload_packet_from_python(self, *args, **kwargs)
+ return self._op(*args, **kwargs)
# TODO: use this to make a __dir__
def overloads(self):
@@ -1232,10 +1237,12 @@ class OpOverloadPacket:
# Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp
# _jit_get_operations, which calls _get_operation_for_overload_or_packet.
-def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs):
+def _call_overload_packet_from_python(
+ op: OpOverloadPacket[_P, _T], *args: _P.args, **kwargs: _P.kwargs
+) -> _T:
# Re-use the torch function handling logic in cpp
torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet(
- op, *args, **kwargs
+ op, args, kwargs
)
if torch_function_called:
@@ -1252,7 +1259,7 @@ def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs):
op_overload = getattr(op, overload_name)
try:
_ = torch._C._check_schema_allow_fake_script_object(
- op_overload._schema, *args, **kwargs
+ op_overload._schema, args, kwargs
)
found_op = op_overload
break
@@ -1313,7 +1320,7 @@ class _OpNamespace(types.ModuleType):
self.name = name
self._dir: list[str] = []
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
return iter(self._dir)
def __getattr__(self, op_name: str) -> OpOverloadPacket:
@@ -1377,10 +1384,10 @@ class _HigherOrderNamespace(types.ModuleType):
super().__init__("torch.ops.higher_order")
self._dir: list[str] = []
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
return iter(self._dir)
- def __getattr__(self, name) -> HigherOrderOperator:
+ def __getattr__(self, name: str) -> HigherOrderOperator:
# Following _OpNamespace.__getattr__, we cache the op on this object.
op = _higher_order_ops.get(name, None)
if op is None:
@@ -1401,14 +1408,14 @@ class _Ops(types.ModuleType):
self.higher_order = _HigherOrderNamespace()
self._dir = []
- def __getattr__(self, name) -> _OpNamespace:
+ def __getattr__(self, name: str) -> _OpNamespace:
# Here we are creating `torch.ops.my_namespace`
namespace = _OpNamespace(name)
setattr(self, name, namespace)
self._dir.append(name)
return namespace
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
return iter(self._dir)
def import_module(self, module):
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index e08b5937260..83cf384d511 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -370,7 +370,9 @@ def out_wrapper(
annotation=out_type,
)
# Mark that the function now returns a tuple
- assert isinstance(sig.return_annotation, str) or sig.return_annotation in (
+ assert isinstance(
+ sig.return_annotation, (str, TypeVar)
+ ) or sig.return_annotation in (
sig.empty,
out_type,
bc_out_type,
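For reference, here is a minimal standalone sketch of the decorator-typing pattern the `register_decomposition` hunk above moves to: accept either a single operator-like object or a list of them, and return the decorated function unchanged so its `ParamSpec`/`TypeVar` signature is preserved. The `OperatorLike` alias and `_registry` dict below are illustrative stand-ins, not the Inductor implementation.

```python
from __future__ import annotations

from typing import Any, Callable, TypeVar, Union
from typing_extensions import ParamSpec, TypeAlias

_T = TypeVar("_T")
_P = ParamSpec("_P")

# Stand-in for Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket] in the real code.
OperatorLike: TypeAlias = str

_registry: dict[OperatorLike, Callable[..., Any]] = {}


def register_decomposition(
    ops: Union[OperatorLike, list[OperatorLike]],
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    # Normalize with isinstance(..., list) rather than callable(), as in the hunk above.
    op_list = ops if isinstance(ops, list) else [ops]

    def decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        for op in op_list:
            _registry[op] = fn
        # Returning fn unchanged is what lets checkers keep the original signature.
        return fn

    return decorator


@register_decomposition("aten.amax")  # a single op no longer needs to be wrapped in a list
def amax_decomp(xs: list[float]) -> float:
    return max(xs)
```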
Force-pushed from 0f63c94 to 678314e.
Summary:
Cloned #153558 from benjaminglass1 and fixed internal typing errors.

Fixes longstanding issue where direct references to aten operations are seen as untyped by type checkers. This is accomplished by setting attributes on several classes more consistently, so that `__getattr__` can return a single type in all other cases.

Decisions made along the way:
1. `torch.ops.higher_order` is now implemented by a single-purpose class. This was effectively true before, but the class implementing it attempted to be generalized unnecessarily. Fixing this simplified typing for the `_Ops` class.
2. `__getattr__` is only called when all other lookup methods have failed, so several constant special-cases in the function could be implemented as class variables.

The remainder of this PR is fixing up all the bugs exposed by the updated typing, as well as all the nitpicky typing issues.

Test Plan: CI
Differential Revision: D75497142
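As a rough illustration of that pattern (a self-contained sketch under simplifying assumptions — `_FakeOverload` and `_FakeNamespace` are hypothetical names, not the real classes in `torch/_ops.py`, which do considerably more):

```python
from __future__ import annotations

from typing import Any, Callable, ClassVar, Generic
from typing_extensions import ParamSpec, TypeVar  # typing_extensions TypeVar supports defaults (PEP 696)

# Defaults, as in the patch, let un-parameterized uses fall back to "any signature".
_P = ParamSpec("_P", default=...)
_T = TypeVar("_T", default=Any)


class _FakeOverload(Generic[_P, _T]):
    """Hypothetical stand-in for OpOverload[_P, _T]; calling it preserves the wrapped signature."""

    def __init__(self, fn: Callable[_P, _T]) -> None:
        self._fn = fn

    # Positional-only `self` mirrors the real code's workaround for ops with an argument named "self".
    def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
        return self._fn(*args, **kwargs)


class _FakeNamespace:
    # Constant special cases become plain class variables, so ordinary attribute
    # lookup resolves them and __getattr__ never has to handle them...
    __file__: ClassVar[str] = "torch.ops"

    def __getattr__(self, name: str) -> _FakeOverload:
        # ...which lets __getattr__ advertise a single return type to checkers.
        return _FakeOverload(lambda *a, **kw: (name, a, kw))
```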
This pull request was exported from Phabricator. Differential Revision: D75497142
thanks
@aorenste Working on this new batch of failures locally; thankfully, they all appear to be failing more or less identically, so this should hopefully be the last patch.
@aorenste Found it; I made a mistake thanks to some incorrect typing in the existing type stubs:

diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 4e2b6b4c8df..27e3603872e 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -452,13 +452,13 @@ ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]]
# and torch/csrc/jit/python/init.cpp
def _maybe_call_torch_function_for_op_packet(
op_overload_packet: Any,
- args: Any,
- kwargs: Any,
+ *args: Any,
+ **kwargs: Any,
) -> Any: ...
def _check_schema_allow_fake_script_object(
schema: FunctionSchema,
- args: Any,
- kwargs: Any,
+ *args: Any,
+ **kwargs: Any,
) -> _bool: ...
def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ...
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
diff --git a/torch/_ops.py b/torch/_ops.py
index 34e56eece11..8455707aa9a 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -1242,7 +1242,7 @@ def _call_overload_packet_from_python(
) -> _T:
# Re-use the torch function handling logic in cpp
torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet(
- op, args, kwargs
+ op, *args, **kwargs
)
if torch_function_called:
@@ -1259,7 +1259,7 @@ def _call_overload_packet_from_python(
op_overload = getattr(op, overload_name)
try:
_ = torch._C._check_schema_allow_fake_script_object(
- op_overload._schema, args, kwargs
+ op_overload._schema, *args, **kwargs
)
found_op = op_overload
break
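A tiny standalone illustration of the packed-vs-variadic distinction behind this fix (the helpers below are hypothetical, not the real `torch._C` bindings): a parameter list declared as `args, kwargs` expects one tuple and one dict, while `*args, **kwargs` expects them unpacked at the call site, so the stub and its callers have to agree.

```python
from typing import Any


def takes_packed(op: Any, args: tuple, kwargs: dict) -> None:
    # Shape the old stub declared: the caller passes a tuple and a dict explicitly.
    print(op, args, kwargs)


def takes_variadic(op: Any, *args: Any, **kwargs: Any) -> None:
    # Shape the corrected stub declares: the caller unpacks at the call site.
    print(op, args, kwargs)


takes_packed("aten::add", (1, 2), {"alpha": 3})
takes_variadic("aten::add", 1, 2, alpha=3)
```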
Force-pushed from 678314e to 5aeba39.
This pull request was exported from Phabricator. Differential Revision: D75497142
@pytorchbot merge
Merge failed. Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/re-exporting the PR!
Details for Dev Infra team: raised by workflow job.
@benjaminglass1 PRs that originate w/ internal diffs have to be landed internally. There's an internal tool which should allow me to check for internal breaks before landing - but ironically it's currently broken so I'm waiting for it to be fixed so I can make sure I don't have a repeat from last week where after landing I get a report of almost 1,000 failing tests...
@aorenste Is the testing tool back up? I'd like to rebase and merge this, if we can.
Not yet. I was told either yesterday or today...
Summary:
X-link: pytorch/executorch#11276
Pull Request resolved: pytorch#154555

Cloned pytorch#153558 from benjaminglass1 and fixed internal typing errors.

Fixes longstanding issue where direct references to aten operations are seen as untyped by type checkers. This is accomplished by setting attributes on several classes more consistently, so that `__getattr__` can return a single type in all other cases.

Decisions made along the way:
1. `torch.ops.higher_order` is now implemented by a single-purpose class. This was effectively true before, but the class implementing it attempted to be generalized unnecessarily. Fixing this simplified typing for the `_Ops` class.
2. `__getattr__` is only called when all other lookup methods have failed, so several constant special-cases in the function could be implemented as class variables.

The remainder of this PR is fixing up all the bugs exposed by the updated typing, as well as all the nitpicky typing issues.

Test Plan: CI
Reviewed By: bobrenjc93, mergennachin
Differential Revision: D75497142
This pull request was exported from Phabricator. Differential Revision: D75497142
Force-pushed from 5aeba39 to 762b0b3.
Summary:
Cloned #153558 from benjaminglass1 and fixed internal typing errors.

Fixes longstanding issue where direct references to aten operations are seen as untyped by type checkers. This is accomplished by setting attributes on several classes more consistently, so that `__getattr__` can return a single type in all other cases.

Decisions made along the way:
1. `torch.ops.higher_order` is now implemented by a single-purpose class. This was effectively true before, but the class implementing it attempted to be generalized unnecessarily. Fixing this simplified typing for the `_Ops` class.
2. `__getattr__` is only called when all other lookup methods have failed, so several constant special-cases in the function could be implemented as class variables.

The remainder of this PR is fixing up all the bugs exposed by the updated typing, as well as all the nitpicky typing issues.

Test Plan: CI
Differential Revision: D75497142
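As a usage-level illustration of the effect (a hypothetical snippet, not part of this PR): op references now participate in type checking instead of collapsing to untyped values.

```python
import torch

# After this change, a type checker sees these attribute chains as
# OpOverloadPacket / OpOverload objects rather than untyped values.
add_packet = torch.ops.aten.add           # an OpOverloadPacket
add_overload = torch.ops.aten.add.Tensor  # an OpOverload

out = add_overload(torch.ones(2), torch.ones(2))
print(out)  # tensor([2., 2.])
```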
Co-authored-by: Benjamin Glass <bglass@quansight.com>
cc @ezyang @malfet @xuzhao9 @gramster @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov