12
12
import os
13
13
import re
14
14
import tempfile
15
- import typing
16
15
from abc import ABC , abstractmethod
17
16
from enum import auto , Enum
18
17
from itertools import chain
27
26
TYPE_CHECKING ,
28
27
Union ,
29
28
)
30
- from typing_extensions import TypeVar
29
+ from typing_extensions import Self , TypeVar
31
30
32
31
import sympy
33
32
@@ -408,7 +407,7 @@ def get_backend_features(
408
407
if isinstance (device , torch .device ):
409
408
device_type = device .type
410
409
else :
411
- assert isinstance (device , str )
410
+ assert isinstance (device , str ), type ( device )
412
411
device_type = device
413
412
device = torch .device (device_type )
414
413
scheduling_ctor = get_scheduling_for_device (device_type )
@@ -538,7 +537,7 @@ def register_device_op_overrides(
538
537
539
538
540
539
def get_device_op_overrides (device : str ) -> DeviceOpOverrides :
541
- assert isinstance (device , str )
540
+ assert isinstance (device , str ), type ( device )
542
541
543
542
if not device_op_overrides_dict :
544
543
from . import cpu_device_op_overrides , mps_device_op_overrides # noqa: F401
@@ -621,7 +620,7 @@ def check_dtype(
621
620
elif config .test_configs .static_cpp_dtype_assert and backend == "cpp" :
622
621
from .cpp_utils import CppCSEVariable , DTYPE_TO_CPP
623
622
624
- assert isinstance (var , CppCSEVariable )
623
+ assert isinstance (var , CppCSEVariable ), type ( var )
625
624
if dtype == torch .bool :
626
625
if var .is_vec :
627
626
is_same_dt = f"IsVecMaskType<decltype({ var } )>::value"
@@ -682,9 +681,11 @@ def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]:
682
681
return None
683
682
684
683
if node .target == operator .getitem :
685
- return self .deduce_node_dtype (node .args [0 ]) # type: ignore[arg-type]
684
+ node_arg = node .args [0 ]
685
+ assert isinstance (node_arg , torch .fx .Node ), type (node_arg )
686
+ return self .deduce_node_dtype (node_arg )
686
687
687
- assert isinstance (node .target , str )
688
+ assert isinstance (node .target , str ), type ( node . target )
688
689
689
690
if node .target .startswith ("masked_subblock" ):
690
691
return self .deduce_node_dtype_by_subgraph (node )
@@ -730,8 +731,8 @@ def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]:
730
731
from ..loop_body import LoopBody
731
732
from ..scheduler import SchedulerNode
732
733
733
- assert isinstance (node , SchedulerNode )
734
- assert isinstance (node ._body , LoopBody )
734
+ assert isinstance (node , SchedulerNode ), type ( node )
735
+ assert isinstance (node ._body , LoopBody ), type ( node . _body )
735
736
return DataTypePropagation .propagate_loopbody (node ._body )
736
737
737
738
@@ -1428,7 +1429,7 @@ def output(self, name: str) -> str:
1428
1429
def make_inplace (self , input_name : str , output_name : str ) -> None :
1429
1430
if input_name in V .graph .unaligned_buffers :
1430
1431
V .graph .unaligned_buffers .add (output_name )
1431
- assert output_name not in self .inplace_buffers
1432
+ assert output_name not in self .inplace_buffers , output_name
1432
1433
if input_name in self .inplace_buffers :
1433
1434
buf = self .inplace_buffers [input_name ]
1434
1435
assert not isinstance (buf , RemovedArg )
@@ -1490,7 +1491,7 @@ def workspace(self, nbytes: sympy.Expr, zero_fill: bool) -> tuple[str, int]:
1490
1491
assert (
1491
1492
existing_arg .inner_name != arg .inner_name
1492
1493
and existing_arg .outer_name != arg .outer_name
1493
- )
1494
+ ), existing_arg
1494
1495
self .workspace_args .append (arg )
1495
1496
return arg .inner_name , 0
1496
1497
@@ -1518,7 +1519,7 @@ def semaphores(self, min_size: sympy.Expr) -> str:
1518
1519
)
1519
1520
for existing_arg in self .workspace_args :
1520
1521
if existing_arg .inner_name == arg .inner_name :
1521
- assert arg == existing_arg
1522
+ assert arg == existing_arg , ( arg , existing_arg )
1522
1523
self .workspace_args .append (arg )
1523
1524
return arg .inner_name
1524
1525
@@ -1618,7 +1619,7 @@ def python_argdefs(
1618
1619
) -> tuple [list [ArgName ], list [str ], list [KernelArgType ], list [Any ]]:
1619
1620
arg_defs : list [ArgName ] = []
1620
1621
call_args : list [str ] = []
1621
- arg_types : list [torch . dtype ] = []
1622
+ arg_types : list [Any ] = []
1622
1623
precompile_args : list [KernelArgType ] = []
1623
1624
for inplaced in unique (self .inplace_buffers .values ()):
1624
1625
if isinstance (inplaced , RemovedArg ):
@@ -1651,7 +1652,7 @@ def python_argdefs(
1651
1652
for outer , inner in self .sizevars .items ():
1652
1653
arg_defs .append (ArgName (inner ))
1653
1654
call_args .append (outer )
1654
- arg_types .append (type (outer )) # type: ignore[arg-type]
1655
+ arg_types .append (type (outer ))
1655
1656
precompile_args .append (SizeArg (inner , outer ))
1656
1657
if V .graph .wrapper_code :
1657
1658
V .graph .wrapper_code .ensure_size_computed (outer )
@@ -1686,7 +1687,7 @@ def is_removed(self, name: str) -> bool:
1686
1687
# after you do a call into this kernel, which buffers actually contain
1687
1688
# updated data? Modeled off of python_argdefs.
1688
1689
def live_output_buffers (self ) -> OrderedSet [str ]:
1689
- live_outs = OrderedSet () # type: ignore[var-annotated]
1690
+ live_outs = OrderedSet [ str ]()
1690
1691
for inplaced in unique (self .inplace_buffers .values ()):
1691
1692
if isinstance (inplaced , RemovedArg ):
1692
1693
continue
@@ -1712,7 +1713,7 @@ def __init__(
1712
1713
dtype : Optional [torch .dtype ] = None ,
1713
1714
):
1714
1715
super ().__init__ ()
1715
- assert isinstance (bounds , ValueRanges )
1716
+ assert isinstance (bounds , ValueRanges ), type ( bounds )
1716
1717
self .name = name
1717
1718
self .bounds = bounds
1718
1719
self .use_count = 1 # track how many times this expression is used
@@ -1782,7 +1783,7 @@ def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None:
1782
1783
else :
1783
1784
self ._cache = {}
1784
1785
1785
- def clone (self ) -> typing . Self :
1786
+ def clone (self ) -> Self :
1786
1787
return type (self )(
1787
1788
prefix = self .prefix ,
1788
1789
suffix = self .suffix ,
@@ -1793,7 +1794,7 @@ def clone(self) -> typing.Self:
1793
1794
reduction_cache = self .reduction_cache ,
1794
1795
)
1795
1796
1796
- def scoped_copy (self ) -> typing . Self :
1797
+ def scoped_copy (self ) -> Self :
1797
1798
"""Return a copy of using ScopedDict so changes to *_cache aren't visible in self"""
1798
1799
new_cse = self .clone ()
1799
1800
new_cse ._cache = ScopedDict (self ._cache )
@@ -1918,7 +1919,7 @@ def __init__(self) -> None:
1918
1919
super ().__init__ ()
1919
1920
self .exit_stack = contextlib .ExitStack ()
1920
1921
1921
- def __enter__ (self ) -> typing . Self :
1922
+ def __enter__ (self ) -> Self :
1922
1923
self .exit_stack .__enter__ ()
1923
1924
return self
1924
1925
@@ -2084,7 +2085,7 @@ def indirect_assert(
2084
2085
) -> str :
2085
2086
if isinstance (var , CSEVariable ):
2086
2087
var = str (var )
2087
- assert isinstance (var , str )
2088
+ assert isinstance (var , str ), type ( var )
2088
2089
assert lower is None or isinstance (lower , str )
2089
2090
assert upper is None or isinstance (upper , str )
2090
2091
if lower and upper :
@@ -2113,7 +2114,7 @@ def check_bounds(
2113
2114
def index_to_str (self , index : sympy .Expr ) -> str :
2114
2115
raise NotImplementedError
2115
2116
2116
- def __enter__ (self ) -> typing . Self :
2117
+ def __enter__ (self ) -> Self :
2117
2118
super ().__enter__ ()
2118
2119
assert self .overrides
2119
2120
self .exit_stack .enter_context (
@@ -2184,7 +2185,7 @@ def rename_indexing(
2184
2185
# adds the necessary kernel args for index expressions
2185
2186
# and renames variables in index expressions to kernel arg names
2186
2187
if isinstance (index , (list , tuple )):
2187
- return [self .rename_indexing (x ) for x in index ] # type: ignore[return-value]
2188
+ return [self .rename_indexing (x ) for x in index ]
2188
2189
index = V .graph .sizevars .simplify (index )
2189
2190
sorted_symbols = sorted (index .free_symbols , key = lambda s : s .name )
2190
2191
replacements = {
@@ -2362,7 +2363,7 @@ def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
2362
2363
def _default (self , name : str , args : tuple [Any , ...], kwargs : dict [str , Any ]) -> Any :
2363
2364
bounds = self ._bound_variable (name , * args , ** kwargs )
2364
2365
2365
- value = getattr (self .parent_handler , name )(* args , ** kwargs ) # type: ignore[has-type]
2366
+ value = getattr (self .parent_handler , name )(* args , ** kwargs )
2366
2367
dtype_handler = DtypePropagationOpsHandler ()
2367
2368
2368
2369
backend = get_current_backend ()
@@ -2387,8 +2388,8 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) ->
2387
2388
def do_cse (v : str ) -> CSEVariable :
2388
2389
# we tree_map over the output, so we need to fetch corresponding dtype
2389
2390
nonlocal output_idx
2390
- var_dtype : torch .dtype = (
2391
- output_dtype [output_idx ] # type: ignore[assignment]
2391
+ var_dtype : Optional [ torch .dtype ] = (
2392
+ output_dtype [output_idx ]
2392
2393
if isinstance (output_dtype , (list , tuple ))
2393
2394
else output_dtype
2394
2395
)
@@ -2411,6 +2412,7 @@ def do_cse(v: str) -> CSEVariable:
2411
2412
config .test_configs .runtime_triton_dtype_assert
2412
2413
or config .test_configs .static_cpp_dtype_assert
2413
2414
):
2415
+ assert var_dtype is not None
2414
2416
check_dtype (V .kernel .compute , csevar , var_dtype )
2415
2417
return csevar
2416
2418
@@ -2433,7 +2435,9 @@ def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[A
2433
2435
2434
2436
fx_node = V .interpreter .current_node
2435
2437
if fx_node .target == name and self .kernel .node_to_bounds is not None :
2436
- assert isinstance (self .kernel .node_to_bounds , dict )
2438
+ assert isinstance (self .kernel .node_to_bounds , dict ), type (
2439
+ self .kernel .node_to_bounds
2440
+ )
2437
2441
return self .kernel .node_to_bounds .get (fx_node , ValueRanges .unknown ())
2438
2442
elif config .compute_all_bounds and hasattr (ValueRangeAnalysis , name ):
2439
2443
# These create lots of inner strings. We would need to compute the bounds at the ops
@@ -2468,14 +2472,14 @@ def indirect_indexing(
2468
2472
) -> sympy .Symbol :
2469
2473
if isinstance (size , int ):
2470
2474
size = sympy .Integer (size )
2471
- assert isinstance (size , sympy .Expr ), size
2475
+ assert isinstance (size , sympy .Expr ), ( type ( size ), size )
2472
2476
# Skip CSE since this doesn't return an expression
2473
2477
2474
- if var .bounds .lower < 0 : # type: ignore[operator]
2478
+ if var .bounds .lower < 0 :
2475
2479
if wrap_neg :
2476
2480
stm = ops .add (var , ops .index_expr (size , torch .long ))
2477
2481
# Mixed negative and non-negative
2478
- if var .bounds .upper >= 0 : # type: ignore[operator]
2482
+ if var .bounds .upper >= 0 :
2479
2483
lt = ops .lt (var , 0 )
2480
2484
stm = ops .where (lt , stm , var )
2481
2485
else :
@@ -2492,7 +2496,7 @@ def indirect_indexing(
2492
2496
neg_bounds .lower + size , neg_bounds .upper + size
2493
2497
)
2494
2498
# We don't have a good way of representing the empty range
2495
- if var .bounds .upper >= 0 : # type: ignore[operator]
2499
+ if var .bounds .upper >= 0 :
2496
2500
pos = var .bounds & ValueRanges (0 , int_oo )
2497
2501
new_bounds = new_bounds | pos
2498
2502
@@ -2544,8 +2548,7 @@ def store(
2544
2548
if mode is None :
2545
2549
self ._update_store_cache (name , value )
2546
2550
if name not in V .graph .removed_buffers :
2547
- return self .kernel .store (name , index , value , mode = mode )
2548
- return None # type: ignore[return-value]
2551
+ self .kernel .store (name , index , value , mode = mode )
2549
2552
2550
2553
def store_reduction (self , name : str , index : sympy .Expr , value : CSEVariable ) -> None :
2551
2554
self .kernel .store_buffer_names .add (name )
0 commit comments