Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 1e19c9c

Browse filesBrowse files
committed
[python] fix enum ambiguity
1 parent 9d55e86 commit 1e19c9c
Copy full SHA for 1e19c9c
Expand file treeCollapse file tree

26 files changed

+578
-137
lines changed

‎mlir/cmake/modules/AddMLIRPython.cmake

Copy file name to clipboardExpand all lines: mlir/cmake/modules/AddMLIRPython.cmake
+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ function(declare_mlir_dialect_python_bindings)
318318
set(LLVM_TARGET_DEFINITIONS ${td_file})
319319
endif()
320320
set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
321-
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
321+
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
322322
list(APPEND _sources ${enum_filename})
323323
endif()
324324

@@ -390,7 +390,7 @@ function(declare_mlir_dialect_extension_python_bindings)
390390
set(LLVM_TARGET_DEFINITIONS ${td_file})
391391
endif()
392392
set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
393-
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
393+
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
394394
list(APPEND _sources ${enum_filename})
395395
endif()
396396

‎mlir/python/CMakeLists.txt

Copy file name to clipboardExpand all lines: mlir/python/CMakeLists.txt
+1-2Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ declare_mlir_dialect_python_bindings(
6363
TD_FILE dialects/AffineOps.td
6464
SOURCES
6565
dialects/affine.py
66-
DIALECT_NAME affine
67-
GEN_ENUM_BINDINGS)
66+
DIALECT_NAME affine)
6867

6968
declare_mlir_dialect_python_bindings(
7069
ADD_TO_PARENT MLIRPythonSources.Dialects

‎mlir/python/mlir/dialects/_ods_common.py

Copy file name to clipboardExpand all lines: mlir/python/mlir/dialects/_ods_common.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def get_op_result_or_op_results(
143143
else op
144144
)
145145

146+
146147
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
147148
ResultValueT = _Union[ResultValueTypeTuple]
148149
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]

‎mlir/python/mlir/dialects/amdgpu.py

Copy file name to clipboardExpand all lines: mlir/python/mlir/dialects/amdgpu.py
+16Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,21 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._amdgpu_ops_gen import *
67
from ._amdgpu_enum_gen import *
8+
9+
10+
@register_attribute_builder("builtin.AMDGPU_DPPPerm")
11+
def _amdgpu_dppperm(x, context):
12+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
13+
14+
15+
@register_attribute_builder("builtin.AMDGPU_MFMAPermB")
16+
def _amdgpu_mfmapermb(x, context):
17+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
18+
19+
20+
@register_attribute_builder("builtin.AMDGPU_SchedBarrierOpOpt")
21+
def _amdgpu_schedbarrieropopt(x, context):
22+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

‎mlir/python/mlir/dialects/arith.py

Copy file name to clipboardExpand all lines: mlir/python/mlir/dialects/arith.py
+35Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,38 @@ def constant(
108108
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
109109
) -> Value:
110110
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
111+
112+
113+
@register_attribute_builder("builtin.Arith_CmpFPredicateAttr")
114+
def _arith_cmpfpredicateattr(x, context):
115+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
116+
117+
118+
@register_attribute_builder("builtin.Arith_CmpIPredicateAttr")
119+
def _arith_cmpipredicateattr(x, context):
120+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
121+
122+
123+
@register_attribute_builder("builtin.Arith_DenormalMode")
124+
def _arith_denormalmode(x, context):
125+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
126+
127+
128+
@register_attribute_builder("builtin.Arith_IntegerOverflowFlags")
129+
def _arith_integeroverflowflags(x, context):
130+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
131+
132+
133+
@register_attribute_builder("builtin.Arith_RoundingModeAttr")
134+
def _arith_roundingmodeattr(x, context):
135+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
136+
137+
138+
@register_attribute_builder("builtin.AtomicRMWKindAttr")
139+
def _atomicrmwkindattr(x, context):
140+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
141+
142+
143+
@register_attribute_builder("builtin.FastMathFlags")
144+
def _fastmathflags(x, context):
145+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

‎mlir/python/mlir/dialects/bufferization.py

Copy file name to clipboardExpand all lines: mlir/python/mlir/dialects/bufferization.py
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,11 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._bufferization_ops_gen import *
67
from ._bufferization_enum_gen import *
8+
9+
10+
@register_attribute_builder("builtin.LayoutMapOption")
11+
def _layoutmapoption(x, context):
12+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

‎mlir/python/mlir/dialects/gpu/__init__.py

Copy file name to clipboardExpand all lines: mlir/python/mlir/dialects/gpu/__init__.py
+56Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,62 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ...ir import IntegerAttr, IntegerType, register_attribute_builder
56
from .._gpu_ops_gen import *
67
from .._gpu_enum_gen import *
78
from ..._mlir_libs._mlirDialectsGPU import *
9+
10+
11+
@register_attribute_builder("builtin.GPU_AddressSpaceEnum")
12+
def _gpu_addressspaceenum(x, context):
13+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
14+
15+
16+
@register_attribute_builder("builtin.GPU_AllReduceOperation")
17+
def _gpu_allreduceoperation(x, context):
18+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
19+
20+
21+
@register_attribute_builder("builtin.GPU_CompilationTargetEnum")
22+
def _gpu_compilationtargetenum(x, context):
23+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
24+
25+
26+
@register_attribute_builder("builtin.GPU_Dimension")
27+
def _gpu_dimension(x, context):
28+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
29+
30+
31+
@register_attribute_builder("builtin.GPU_Prune2To4SpMatFlag")
32+
def _gpu_prune2to4spmatflag(x, context):
33+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
34+
35+
36+
@register_attribute_builder("builtin.GPU_ShuffleMode")
37+
def _gpu_shufflemode(x, context):
38+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
39+
40+
41+
@register_attribute_builder("builtin.GPU_SpGEMMWorkEstimationOrComputeKind")
42+
def _gpu_spgemmworkestimationorcomputekind(x, context):
43+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
44+
45+
46+
@register_attribute_builder("builtin.GPU_TransposeMode")
47+
def _gpu_transposemode(x, context):
48+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
49+
50+
51+
@register_attribute_builder("builtin.MMAElementWise")
52+
def _mmaelementwise(x, context):
53+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
54+
55+
56+
@register_attribute_builder("builtin.MappingIdEnum")
57+
def _mappingidenum(x, context):
58+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
59+
60+
61+
@register_attribute_builder("builtin.ProcessorEnum")
62+
def _processorenum(x, context):
63+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))

‎mlir/python/mlir/dialects/index.py

Copy file name to clipboardExpand all lines: mlir/python/mlir/dialects/index.py
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,11 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._index_ops_gen import *
67
from ._index_enum_gen import *
8+
9+
10+
@register_attribute_builder("builtin.IndexCmpPredicate")
11+
def _indexcmppredicate(x, context):
12+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

‎mlir/python/mlir/dialects/linalg/__init__.py

Copy file name to clipboardExpand all lines: mlir/python/mlir/dialects/linalg/__init__.py
+25Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,28 @@ def broadcast(
102102
)
103103
fill_builtin_region(op.operation)
104104
return op
105+
106+
107+
@register_attribute_builder("builtin.BinaryFn")
108+
def _binaryfn(x, context):
109+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
110+
111+
112+
@register_attribute_builder("builtin.IteratorType")
113+
def _iteratortype(x, context):
114+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
115+
116+
117+
@register_attribute_builder("builtin.TernaryFn")
118+
def _ternaryfn(x, context):
119+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
120+
121+
122+
@register_attribute_builder("builtin.TypeFn")
123+
def _typefn(x, context):
124+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
125+
126+
127+
@register_attribute_builder("builtin.UnaryFn")
128+
def _unaryfn(x, context):
129+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

‎mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Copy file name to clipboardExpand all lines: mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+66-50Lines changed: 66 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,7 @@ def conv_2d_nchw_fchw_q(
888888
- TypeFn.cast_signed(U, IZp)
889889
) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp))
890890

891+
891892
@linalg_structured_op
892893
def conv_2d_nchw_fchw(
893894
I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
@@ -1082,16 +1083,19 @@ def conv_3d_ndhwc_dhwcf(
10821083
"""
10831084
implements(ConvolutionOpInterface)
10841085
domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
1085-
O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
1086-
U,
1087-
I[
1088-
D.n,
1089-
D.od * S.SD + D.kd * S.DD,
1090-
D.oh * S.SH + D.kh * S.DH,
1091-
D.ow * S.SW + D.kw * S.DW,
1092-
D.c,
1093-
],
1094-
) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
1086+
O[D.n, D.od, D.oh, D.ow, D.f] += (
1087+
TypeFn.cast_signed(
1088+
U,
1089+
I[
1090+
D.n,
1091+
D.od * S.SD + D.kd * S.DD,
1092+
D.oh * S.SH + D.kh * S.DH,
1093+
D.ow * S.SW + D.kw * S.DW,
1094+
D.c,
1095+
],
1096+
)
1097+
* TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
1098+
)
10951099

10961100

10971101
@linalg_structured_op
@@ -1159,16 +1163,19 @@ def conv_3d_ncdhw_fcdhw(
11591163
"""
11601164
implements(ConvolutionOpInterface)
11611165
domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
1162-
O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed(
1163-
U,
1164-
I[
1165-
D.n,
1166-
D.c,
1167-
D.od * S.SD + D.kd * S.DD,
1168-
D.oh * S.SH + D.kh * S.DH,
1169-
D.ow * S.SW + D.kw * S.DW,
1170-
],
1171-
) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
1166+
O[D.n, D.f, D.od, D.oh, D.ow] += (
1167+
TypeFn.cast_signed(
1168+
U,
1169+
I[
1170+
D.n,
1171+
D.c,
1172+
D.od * S.SD + D.kd * S.DD,
1173+
D.oh * S.SH + D.kh * S.DH,
1174+
D.ow * S.SW + D.kw * S.DW,
1175+
],
1176+
)
1177+
* TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
1178+
)
11721179

11731180

11741181
@linalg_structured_op
@@ -1368,16 +1375,19 @@ def depthwise_conv_3d_ndhwc_dhwc(
13681375
"""
13691376
implements(ConvolutionOpInterface)
13701377
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
1371-
O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
1372-
U,
1373-
I[
1374-
D.n,
1375-
D.od * S.SD + D.kd * S.DD,
1376-
D.oh * S.SH + D.kh * S.DH,
1377-
D.ow * S.SW + D.kw * S.DW,
1378-
D.ic,
1379-
],
1380-
) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic])
1378+
O[D.n, D.od, D.oh, D.ow, D.ic] += (
1379+
TypeFn.cast_signed(
1380+
U,
1381+
I[
1382+
D.n,
1383+
D.od * S.SD + D.kd * S.DD,
1384+
D.oh * S.SH + D.kh * S.DH,
1385+
D.ow * S.SW + D.kw * S.DW,
1386+
D.ic,
1387+
],
1388+
)
1389+
* TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic])
1390+
)
13811391

13821392

13831393
@linalg_structured_op
@@ -1403,16 +1413,19 @@ def depthwise_conv_3d_ncdhw_cdhw(
14031413
"""
14041414
implements(ConvolutionOpInterface)
14051415
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
1406-
O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed(
1407-
U,
1408-
I[
1409-
D.n,
1410-
D.ic,
1411-
D.od * S.SD + D.kd * S.DD,
1412-
D.oh * S.SH + D.kh * S.DH,
1413-
D.ow * S.SW + D.kw * S.DW,
1414-
],
1415-
) * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw])
1416+
O[D.n, D.ic, D.od, D.oh, D.ow] += (
1417+
TypeFn.cast_signed(
1418+
U,
1419+
I[
1420+
D.n,
1421+
D.ic,
1422+
D.od * S.SD + D.kd * S.DD,
1423+
D.oh * S.SH + D.kh * S.DH,
1424+
D.ow * S.SW + D.kw * S.DW,
1425+
],
1426+
)
1427+
* TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw])
1428+
)
14161429

14171430

14181431
@linalg_structured_op
@@ -1437,16 +1450,19 @@ def depthwise_conv_3d_ndhwc_dhwcm(
14371450
"""
14381451
implements(ConvolutionOpInterface)
14391452
domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
1440-
O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
1441-
U,
1442-
I[
1443-
D.n,
1444-
D.od * S.SD + D.kd * S.DD,
1445-
D.oh * S.SH + D.kh * S.DH,
1446-
D.ow * S.SW + D.kw * S.DW,
1447-
D.ic,
1448-
],
1449-
) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
1453+
O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += (
1454+
TypeFn.cast_signed(
1455+
U,
1456+
I[
1457+
D.n,
1458+
D.od * S.SD + D.kd * S.DD,
1459+
D.oh * S.SH + D.kh * S.DH,
1460+
D.ow * S.SW + D.kw * S.DW,
1461+
D.ic,
1462+
],
1463+
)
1464+
* TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
1465+
)
14501466

14511467

14521468
@linalg_structured_op

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.