Commit 8568dbc

rec authored and pytorchmergebot committed

[inductor] Clean typing in codegen/common.py and codecache.py (#150767)

Pull Request resolved: #150767
Approved by: https://github.com/aorenste

1 parent 27f7b65 commit 8568dbc
File tree

2 files changed: 46 additions, 39 deletions

torch/_inductor/codecache.py (+11 −7)
```diff
@@ -112,16 +112,16 @@
     )
 else:
 
-    def log_global_cache_errors(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+    def log_global_cache_errors(*args: Any, **kwargs: Any) -> None:
         pass
 
-    def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+    def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:
         pass
 
-    def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+    def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:
         pass
 
-    def use_global_cache() -> bool:  # type: ignore[misc]
+    def use_global_cache() -> bool:
         return False
 
```
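These no-op stubs appear to be the fallback branch for logging hooks defined elsewhere. Mypy checks that all conditional variants of a function share one signature and reports mismatches under the `[misc]` code; once both variants use the same `(*args, **kwargs) -> None` shape, the ignores become unnecessary. A minimal sketch of the pattern, with `HAS_REAL_IMPL` standing in for the real branch condition:

```python
from typing import Any

HAS_REAL_IMPL = False  # stand-in for the actual feature check in codecache.py

if HAS_REAL_IMPL:

    def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:
        print("stats:", args, kwargs)

else:

    def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:
        # No-op fallback; the signature matches the real variant, so no
        # "# type: ignore[misc]" is needed on the conditional redefinition.
        pass
```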
```diff
@@ -2451,7 +2451,8 @@ def _load_library_inner(cls, path: str, key: str) -> ModuleType:
         assert spec is not None
         module = importlib.util.module_from_spec(spec)
         sys.modules[module_name] = module
-        spec.loader.exec_module(module)  # type: ignore[union-attr]
+        assert spec.loader is not None
+        spec.loader.exec_module(module)
         return module
 
     @classmethod
```
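The new assert replaces a `[union-attr]` ignore: `ModuleSpec.loader` is typed `Optional[Loader]`, and asserting it is not None narrows the type for the checker while also failing loudly on a malformed spec. A self-contained sketch of the same pattern:

```python
import importlib.util
import sys
from types import ModuleType


def load_from_path(module_name: str, path: str) -> ModuleType:
    spec = importlib.util.spec_from_file_location(module_name, path)
    assert spec is not None
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    # Narrow Optional[Loader] -> Loader instead of "# type: ignore[union-attr]".
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module
```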
```diff
@@ -2945,6 +2946,7 @@ def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
                 job()
     except subprocess.SubprocessError as e:
         if os.environ.get("HALIDE_REPRO") == "1":
+            cmd: list[Any]
             python, script, *cmd = getattr(e, "cmd", ("", "", ""))
             if os.path.basename(python).startswith("python"):
                 code = open(script).read()
@@ -2955,7 +2957,9 @@ class Out:
                     def __repr__(self) -> str:
                         return "out"
 
-                cmd[cmd.index("-o") + 1] = Out()  # type: ignore[call-overload]
+                ci = cmd.index("-o")
+                assert isinstance(ci, int)
+                cmd[ci + 1] = Out()
                 repl = textwrap.indent(
                     textwrap.dedent(
                         f"""\
```
```diff
@@ -3565,7 +3569,7 @@ def __init__(
         self.result_fn = result_fn
         self.future = future
 
-    def result(self) -> Callable[..., Any]:  # type: ignore[override]
+    def result(self) -> Callable[..., Any]:
         return self.result_fn()
 
 
```
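The `[override]` ignore muted mypy's base-class compatibility check on `result()`; the commit drops it, presumably because the signature no longer conflicts with the parent class. For reference, the kind of mismatch `[override]` flags, shown on hypothetical classes:

```python
from typing import Optional


class Base:
    def result(self, timeout: Optional[float] = None) -> object:
        return None


class Child(Base):
    # Narrowing the parameter list breaks substitutability; mypy reports
    # this under [override] unless the signatures are compatible.
    def result(self) -> object:  # type: ignore[override]
        return "value"
```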
torch/_inductor/codegen/common.py (+35 −32)
```diff
@@ -12,7 +12,6 @@
 import os
 import re
 import tempfile
-import typing
 from abc import ABC, abstractmethod
 from enum import auto, Enum
 from itertools import chain
@@ -27,7 +26,7 @@
     TYPE_CHECKING,
     Union,
 )
-from typing_extensions import TypeVar
+from typing_extensions import Self, TypeVar
 
 import sympy
 
```
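`typing.Self` exists only on Python 3.11+, so the import moves to `typing_extensions`, which backports it to the older interpreters PyTorch still supports; the now-unused `import typing` can then be dropped. A short sketch of what `Self` buys over naming the class:

```python
from typing_extensions import Self


class Config:
    def __init__(self) -> None:
        self.flags: list[str] = []

    def enable(self, flag: str) -> Self:
        # Self keeps the return type precise in subclasses, unlike "Config".
        self.flags.append(flag)
        return self


class StrictConfig(Config):
    pass


cfg = StrictConfig().enable("fast-math")  # inferred as StrictConfig, not Config
```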
```diff
@@ -408,7 +407,7 @@ def get_backend_features(
     if isinstance(device, torch.device):
         device_type = device.type
     else:
-        assert isinstance(device, str)
+        assert isinstance(device, str), type(device)
         device_type = device
         device = torch.device(device_type)
     scheduling_ctor = get_scheduling_for_device(device_type)
```
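The recurring edit across this file attaches a message to each assert: on failure, the payload's repr lands in the `AssertionError`, so the traceback names the offending type (or value) instead of showing a bare assert. A tiny demonstration:

```python
def expect_str(device: object) -> str:
    assert isinstance(device, str), type(device)
    return device


expect_str("cuda")  # fine
# expect_str(0)     # AssertionError: <class 'int'>
```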
```diff
@@ -538,7 +537,7 @@ def register_device_op_overrides(
 
 
 def get_device_op_overrides(device: str) -> DeviceOpOverrides:
-    assert isinstance(device, str)
+    assert isinstance(device, str), type(device)
 
     if not device_op_overrides_dict:
         from . import cpu_device_op_overrides, mps_device_op_overrides  # noqa: F401
```
```diff
@@ -621,7 +620,7 @@ def check_dtype(
     elif config.test_configs.static_cpp_dtype_assert and backend == "cpp":
         from .cpp_utils import CppCSEVariable, DTYPE_TO_CPP
 
-        assert isinstance(var, CppCSEVariable)
+        assert isinstance(var, CppCSEVariable), type(var)
         if dtype == torch.bool:
             if var.is_vec:
                 is_same_dt = f"IsVecMaskType<decltype({var})>::value"
```
```diff
@@ -682,9 +681,11 @@ def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]:
             return None
 
         if node.target == operator.getitem:
-            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]
+            node_arg = node.args[0]
+            assert isinstance(node_arg, torch.fx.Node), type(node_arg)
+            return self.deduce_node_dtype(node_arg)
 
-        assert isinstance(node.target, str)
+        assert isinstance(node.target, str), type(node.target)
 
         if node.target.startswith("masked_subblock"):
             return self.deduce_node_dtype_by_subgraph(node)
```
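`node.args` elements are typed as a union of all FX argument kinds, so passing `args[0]` straight to `deduce_node_dtype` needed `[arg-type]`; binding it to a local and asserting the expected class narrows the union instead. A sketch of the narrowing step, using a hypothetical helper name:

```python
import torch.fx


def first_input_node(node: torch.fx.Node) -> torch.fx.Node:
    # args elements are Union[Node, str, int, ...]; narrow before use so the
    # checker accepts the Node-typed call without an ignore.
    node_arg = node.args[0]
    assert isinstance(node_arg, torch.fx.Node), type(node_arg)
    return node_arg
```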
```diff
@@ -730,8 +731,8 @@ def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]:
         from ..loop_body import LoopBody
         from ..scheduler import SchedulerNode
 
-        assert isinstance(node, SchedulerNode)
-        assert isinstance(node._body, LoopBody)
+        assert isinstance(node, SchedulerNode), type(node)
+        assert isinstance(node._body, LoopBody), type(node._body)
         return DataTypePropagation.propagate_loopbody(node._body)
 
 
```
```diff
@@ -1428,7 +1429,7 @@ def output(self, name: str) -> str:
     def make_inplace(self, input_name: str, output_name: str) -> None:
         if input_name in V.graph.unaligned_buffers:
             V.graph.unaligned_buffers.add(output_name)
-        assert output_name not in self.inplace_buffers
+        assert output_name not in self.inplace_buffers, output_name
         if input_name in self.inplace_buffers:
             buf = self.inplace_buffers[input_name]
             assert not isinstance(buf, RemovedArg)
```
```diff
@@ -1490,7 +1491,7 @@ def workspace(self, nbytes: sympy.Expr, zero_fill: bool) -> tuple[str, int]:
             assert (
                 existing_arg.inner_name != arg.inner_name
                 and existing_arg.outer_name != arg.outer_name
-            )
+            ), existing_arg
         self.workspace_args.append(arg)
         return arg.inner_name, 0
 
```
```diff
@@ -1518,7 +1519,7 @@ def semaphores(self, min_size: sympy.Expr) -> str:
         )
         for existing_arg in self.workspace_args:
             if existing_arg.inner_name == arg.inner_name:
-                assert arg == existing_arg
+                assert arg == existing_arg, (arg, existing_arg)
         self.workspace_args.append(arg)
         return arg.inner_name
 
```
```diff
@@ -1618,7 +1619,7 @@ def python_argdefs(
     ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]:
         arg_defs: list[ArgName] = []
         call_args: list[str] = []
-        arg_types: list[torch.dtype] = []
+        arg_types: list[Any] = []
         precompile_args: list[KernelArgType] = []
         for inplaced in unique(self.inplace_buffers.values()):
             if isinstance(inplaced, RemovedArg):
```
```diff
@@ -1651,7 +1652,7 @@ def python_argdefs(
         for outer, inner in self.sizevars.items():
             arg_defs.append(ArgName(inner))
             call_args.append(outer)
-            arg_types.append(type(outer))  # type: ignore[arg-type]
+            arg_types.append(type(outer))
             precompile_args.append(SizeArg(inner, outer))
             if V.graph.wrapper_code:
                 V.graph.wrapper_code.ensure_size_computed(outer)
```
```diff
@@ -1686,7 +1687,7 @@ def is_removed(self, name: str) -> bool:
     # after you do a call into this kernel, which buffers actually contain
     # updated data? Modeled off of python_argdefs.
     def live_output_buffers(self) -> OrderedSet[str]:
-        live_outs = OrderedSet()  # type: ignore[var-annotated]
+        live_outs = OrderedSet[str]()
         for inplaced in unique(self.inplace_buffers.values()):
             if isinstance(inplaced, RemovedArg):
                 continue
```
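`OrderedSet[str]()` parametrizes the generic at the call site, so the checker infers the element type without a separate annotation or the old `[var-annotated]` ignore. The same idiom works on stdlib generics, whose subscripted aliases are callable since Python 3.9:

```python
from collections import OrderedDict

# The subscripted alias fixes the key/value types for the checker at
# construction time; no annotation or ignore is needed.
live_outs = OrderedDict[str, None]()
live_outs["buf0"] = None
print(list(live_outs))  # ['buf0']
```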
```diff
@@ -1712,7 +1713,7 @@ def __init__(
         dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
-        assert isinstance(bounds, ValueRanges)
+        assert isinstance(bounds, ValueRanges), type(bounds)
         self.name = name
         self.bounds = bounds
         self.use_count = 1  # track how many times this expression is used
```
```diff
@@ -1782,7 +1783,7 @@ def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None:
         else:
             self._cache = {}
 
-    def clone(self) -> typing.Self:
+    def clone(self) -> Self:
         return type(self)(
             prefix=self.prefix,
             suffix=self.suffix,
@@ -1793,7 +1794,7 @@ def clone(self) -> typing.Self:
             reduction_cache=self.reduction_cache,
         )
 
-    def scoped_copy(self) -> typing.Self:
+    def scoped_copy(self) -> Self:
         """Return a copy of using ScopedDict so changes to *_cache aren't visible in self"""
         new_cse = self.clone()
         new_cse._cache = ScopedDict(self._cache)
@@ -1918,7 +1919,7 @@ def __init__(self) -> None:
         super().__init__()
         self.exit_stack = contextlib.ExitStack()
 
-    def __enter__(self) -> typing.Self:
+    def __enter__(self) -> Self:
         self.exit_stack.__enter__()
         return self
 
```
```diff
@@ -2084,7 +2085,7 @@ def indirect_assert(
     ) -> str:
         if isinstance(var, CSEVariable):
             var = str(var)
-        assert isinstance(var, str)
+        assert isinstance(var, str), type(var)
         assert lower is None or isinstance(lower, str)
         assert upper is None or isinstance(upper, str)
         if lower and upper:
```
```diff
@@ -2113,7 +2114,7 @@ def check_bounds(
     def index_to_str(self, index: sympy.Expr) -> str:
         raise NotImplementedError
 
-    def __enter__(self) -> typing.Self:
+    def __enter__(self) -> Self:
         super().__enter__()
         assert self.overrides
         self.exit_stack.enter_context(
```
```diff
@@ -2184,7 +2185,7 @@ def rename_indexing(
         # adds the necessary kernel args for index expressions
         # and renames variables in index expressions to kernel arg names
         if isinstance(index, (list, tuple)):
-            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
+            return [self.rename_indexing(x) for x in index]
         index = V.graph.sizevars.simplify(index)
         sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
         replacements = {
```
```diff
@@ -2362,7 +2363,7 @@ def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
         bounds = self._bound_variable(name, *args, **kwargs)
 
-        value = getattr(self.parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
+        value = getattr(self.parent_handler, name)(*args, **kwargs)
         dtype_handler = DtypePropagationOpsHandler()
 
         backend = get_current_backend()
```
```diff
@@ -2387,8 +2388,8 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) ->
         def do_cse(v: str) -> CSEVariable:
             # we tree_map over the output, so we need to fetch corresponding dtype
             nonlocal output_idx
-            var_dtype: torch.dtype = (
-                output_dtype[output_idx]  # type: ignore[assignment]
+            var_dtype: Optional[torch.dtype] = (
+                output_dtype[output_idx]
                 if isinstance(output_dtype, (list, tuple))
                 else output_dtype
             )
@@ -2411,6 +2412,7 @@ def do_cse(v: str) -> CSEVariable:
                 config.test_configs.runtime_triton_dtype_assert
                 or config.test_configs.static_cpp_dtype_assert
             ):
+                assert var_dtype is not None
                 check_dtype(V.kernel.compute, csevar, var_dtype)
             return csevar
 
```
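Widening `var_dtype` to `Optional[torch.dtype]` makes the annotation match what the expression can actually yield (`output_dtype` may be None), and the `assert var_dtype is not None` in the second hunk narrows it back at the one call that needs a concrete dtype. A reduced, runnable sketch under assumed input shapes:

```python
from typing import Optional, Union

import torch


def pick_dtype(
    output_dtype: Union[Optional[torch.dtype], list[Optional[torch.dtype]]],
    output_idx: int,
) -> torch.dtype:
    var_dtype: Optional[torch.dtype] = (
        output_dtype[output_idx]
        if isinstance(output_dtype, list)
        else output_dtype
    )
    # Narrow Optional -> concrete before use, as the diff does ahead of
    # the check_dtype call.
    assert var_dtype is not None
    return var_dtype


print(pick_dtype([torch.float32, torch.int64], 1))  # torch.int64
```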
```diff
@@ -2433,7 +2435,9 @@ def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[A
 
         fx_node = V.interpreter.current_node
         if fx_node.target == name and self.kernel.node_to_bounds is not None:
-            assert isinstance(self.kernel.node_to_bounds, dict)
+            assert isinstance(self.kernel.node_to_bounds, dict), type(
+                self.kernel.node_to_bounds
+            )
             return self.kernel.node_to_bounds.get(fx_node, ValueRanges.unknown())
         elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
             # These create lots of inner strings. We would need to compute the bounds at the ops
```
```diff
@@ -2468,14 +2472,14 @@ def indirect_indexing(
     ) -> sympy.Symbol:
         if isinstance(size, int):
             size = sympy.Integer(size)
-        assert isinstance(size, sympy.Expr), size
+        assert isinstance(size, sympy.Expr), (type(size), size)
         # Skip CSE since this doesn't return an expression
 
-        if var.bounds.lower < 0:  # type: ignore[operator]
+        if var.bounds.lower < 0:
             if wrap_neg:
                 stm = ops.add(var, ops.index_expr(size, torch.long))
                 # Mixed negative and non-negative
-                if var.bounds.upper >= 0:  # type: ignore[operator]
+                if var.bounds.upper >= 0:
                     lt = ops.lt(var, 0)
                     stm = ops.where(lt, stm, var)
                 else:
```
```diff
@@ -2492,7 +2496,7 @@
                     neg_bounds.lower + size, neg_bounds.upper + size
                 )
                 # We don't have a good way of representing the empty range
-                if var.bounds.upper >= 0:  # type: ignore[operator]
+                if var.bounds.upper >= 0:
                     pos = var.bounds & ValueRanges(0, int_oo)
                     new_bounds = new_bounds | pos
 
```
```diff
@@ -2544,8 +2548,7 @@ def store(
         if mode is None:
             self._update_store_cache(name, value)
         if name not in V.graph.removed_buffers:
-            return self.kernel.store(name, index, value, mode=mode)
-        return None  # type: ignore[return-value]
+            self.kernel.store(name, index, value, mode=mode)
 
     def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
         self.kernel.store_buffer_names.add(name)
```
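The old code returned `self.kernel.store(...)` from this method and then needed the ignore on the explicit `return None` fallthrough; calling the kernel store purely for its side effect removes both lines. A minimal sketch with hypothetical names, assuming the wrapped store is typed to return a non-None value:

```python
class Recorder:
    def __init__(self) -> None:
        self.stored: list[str] = []

    def store_impl(self, name: str) -> str:  # returns something non-None
        self.stored.append(name)
        return name

    def store(self, name: str, removed: frozenset[str] = frozenset()) -> None:
        if name not in removed:
            # Call for the side effect; returning store_impl's str from a
            # "-> None" method would need "# type: ignore[return-value]".
            self.store_impl(name)
```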

0 commit comments