29
29
from torch ._dynamo .device_interface import get_interface_for_device
30
30
from torch ._dynamo .testing import rand_strided
31
31
from torch ._dynamo .utils import counters , dynamo_timed , identity , preserve_rng_state
32
+ from torch ._inductor .codegen .cuda .cuda_kernel import CUDATemplateCaller
32
33
from torch ._inductor .utils import clear_on_fresh_inductor_cache
33
34
from torch .utils ._filelock import FileLock
34
35
from torch .utils ._ordered_set import OrderedSet
@@ -1834,8 +1835,6 @@ def __call__(
1834
1835
precompilation_timeout_seconds : int = 60 * 60 ,
1835
1836
return_multi_template = False ,
1836
1837
):
1837
- from .codegen .cuda .cuda_kernel import CUDATemplateCaller
1838
-
1839
1838
# Templates selected with input_gen_fns require specific input data to avoid IMA
1840
1839
# Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection
1841
1840
# TODO(jgong5): support multi-template on CPU
@@ -2141,10 +2140,6 @@ def wait_on_futures():
2141
2140
timeout = precompilation_timeout_seconds ,
2142
2141
):
2143
2142
if e := future .exception ():
2144
- from torch ._inductor .codegen .cuda .cuda_kernel import (
2145
- CUDATemplateCaller ,
2146
- )
2147
-
2148
2143
if isinstance (e , CUDACompileError ) and isinstance (
2149
2144
futures [future ], CUDATemplateCaller
2150
2145
):
@@ -2263,8 +2258,6 @@ def benchmark_choices(
2263
2258
try :
2264
2259
timing = cls .benchmark_choice (choice , autotune_args )
2265
2260
except CUDACompileError as e :
2266
- from torch ._inductor .codegen .cuda .cuda_kernel import CUDATemplateCaller
2267
-
2268
2261
if not isinstance (choice , CUDATemplateCaller ):
2269
2262
log .error (
2270
2263
"CUDA compilation error during autotuning: \n %s. \n Ignoring this choice." ,
@@ -2275,8 +2268,6 @@ def benchmark_choices(
2275
2268
log .warning ("Not yet implemented: %s" , e )
2276
2269
timing = float ("inf" )
2277
2270
except RuntimeError as e :
2278
- from torch ._inductor .codegen .cuda .cuda_kernel import CUDATemplateCaller
2279
-
2280
2271
msg = str (e )
2281
2272
if "invalid argument" in msg :
2282
2273
msg += "\n \n This may mean this GPU is too small for max_autotune mode.\n \n "
0 commit comments