Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c15539c

Browse filesBrowse files
authored
[mlir][x86vector] Improve intrinsic operands creation (#138666)
Refactors intrinsic op interface to delegate initial operands mapping to the dialect converter and allow intrinsic operands getters to only perform last mile post-processing.
1 parent aa9f859 commit c15539c
Copy full SHA for c15539c

File tree

4 files changed

+72
-52
lines changed
Filter options

4 files changed

+72
-52
lines changed

‎mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Copy file name to clipboardExpand all lines: mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+20-5Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
8383
}
8484
}];
8585
let extraClassDeclaration = [{
86-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
86+
SmallVector<Value> getIntrinsicOperands(
87+
::mlir::ArrayRef<Value> operands,
88+
const ::mlir::LLVMTypeConverter &typeConverter,
89+
::mlir::RewriterBase &rewriter);
8790
}];
8891
}
8992

@@ -404,7 +407,10 @@ def DotOp : AVX_LowOp<"dot", [Pure,
404407
}
405408
}];
406409
let extraClassDeclaration = [{
407-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
410+
SmallVector<Value> getIntrinsicOperands(
411+
::mlir::ArrayRef<Value> operands,
412+
const ::mlir::LLVMTypeConverter &typeConverter,
413+
::mlir::RewriterBase &rewriter);
408414
}];
409415
}
410416

@@ -452,7 +458,10 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
452458
}];
453459

454460
let extraClassDeclaration = [{
455-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
461+
SmallVector<Value> getIntrinsicOperands(
462+
::mlir::ArrayRef<Value> operands,
463+
const ::mlir::LLVMTypeConverter &typeConverter,
464+
::mlir::RewriterBase &rewriter);
456465
}];
457466

458467
}
@@ -500,7 +509,10 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
500509
}];
501510

502511
let extraClassDeclaration = [{
503-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
512+
SmallVector<Value> getIntrinsicOperands(
513+
::mlir::ArrayRef<Value> operands,
514+
const ::mlir::LLVMTypeConverter &typeConverter,
515+
::mlir::RewriterBase &rewriter);
504516
}];
505517
}
506518

@@ -543,7 +555,10 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
543555
}];
544556

545557
let extraClassDeclaration = [{
546-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
558+
SmallVector<Value> getIntrinsicOperands(
559+
::mlir::ArrayRef<Value> operands,
560+
const ::mlir::LLVMTypeConverter &typeConverter,
561+
::mlir::RewriterBase &rewriter);
547562
}];
548563
}
549564
#endif // X86VECTOR_OPS

‎mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td

Copy file name to clipboardExpand all lines: mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
+4-2Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
5858
}],
5959
/*retType=*/"SmallVector<Value>",
6060
/*methodName=*/"getIntrinsicOperands",
61-
/*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter),
61+
/*args=*/(ins "::mlir::ArrayRef<Value>":$operands,
62+
"const ::mlir::LLVMTypeConverter &":$typeConverter,
63+
"::mlir::RewriterBase &":$rewriter),
6264
/*methodBody=*/"",
63-
/*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
65+
/*defaultImplementation=*/"return SmallVector<Value>(operands);"
6466
>,
6567
];
6668
}

‎mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp

Copy file name to clipboardExpand all lines: mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+36-36Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,11 @@ void x86vector::X86VectorDialect::initialize() {
3131
>();
3232
}
3333

34-
static SmallVector<Value>
35-
getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
36-
RewriterBase &rewriter,
37-
const LLVMTypeConverter &typeConverter) {
38-
SmallVector<Value> operands;
39-
auto opType = memrefVal.getType();
40-
41-
Type llvmStructType = typeConverter.convertType(opType);
42-
Value llvmStruct =
43-
rewriter
44-
.create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
45-
.getResult(0);
46-
MemRefDescriptor memRefDescriptor(llvmStruct);
47-
48-
Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
49-
operands.push_back(ptr);
50-
51-
return operands;
34+
static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
35+
const LLVMTypeConverter &typeConverter,
36+
RewriterBase &rewriter) {
37+
MemRefDescriptor memRefDescriptor(buffer);
38+
return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
5239
}
5340

5441
LogicalResult x86vector::MaskCompressOp::verify() {
@@ -66,48 +53,61 @@ LogicalResult x86vector::MaskCompressOp::verify() {
6653
}
6754

6855
SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
69-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
56+
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
57+
RewriterBase &rewriter) {
7058
auto loc = getLoc();
59+
Adaptor adaptor(operands, *this);
7160

72-
auto opType = getA().getType();
61+
auto opType = adaptor.getA().getType();
7362
Value src;
74-
if (getSrc()) {
75-
src = getSrc();
76-
} else if (getConstantSrc()) {
77-
src = rewriter.create<LLVM::ConstantOp>(loc, opType, getConstantSrcAttr());
63+
if (adaptor.getSrc()) {
64+
src = adaptor.getSrc();
65+
} else if (adaptor.getConstantSrc()) {
66+
src = rewriter.create<LLVM::ConstantOp>(loc, opType,
67+
adaptor.getConstantSrcAttr());
7868
} else {
7969
auto zeroAttr = rewriter.getZeroAttr(opType);
8070
src = rewriter.create<LLVM::ConstantOp>(loc, opType, zeroAttr);
8171
}
8272

83-
return SmallVector<Value>{getA(), src, getK()};
73+
return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()};
8474
}
8575

8676
SmallVector<Value>
87-
x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
88-
const LLVMTypeConverter &typeConverter) {
89-
SmallVector<Value> operands(getOperands());
77+
x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
78+
const LLVMTypeConverter &typeConverter,
79+
RewriterBase &rewriter) {
80+
SmallVector<Value> intrinsicOperands(operands);
9081
// Dot product of all elements, broadcasted to all elements.
9182
Value scale =
9283
rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff);
93-
operands.push_back(scale);
84+
intrinsicOperands.push_back(scale);
9485

95-
return operands;
86+
return intrinsicOperands;
9687
}
9788

9889
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
99-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
100-
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
90+
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
91+
RewriterBase &rewriter) {
92+
Adaptor adaptor(operands, *this);
93+
return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
94+
typeConverter, rewriter)};
10195
}
10296

10397
SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
104-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
105-
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
98+
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
99+
RewriterBase &rewriter) {
100+
Adaptor adaptor(operands, *this);
101+
return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
102+
typeConverter, rewriter)};
106103
}
107104

108105
SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
109-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
110-
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
106+
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
107+
RewriterBase &rewriter) {
108+
Adaptor adaptor(operands, *this);
109+
return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
110+
typeConverter, rewriter)};
111111
}
112112

113113
#define GET_OP_CLASSES

‎mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp

Copy file name to clipboardExpand all lines: mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+12-9Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,20 +84,23 @@ LogicalResult intrinsicRewrite(Operation *op, StringAttr intrinsic,
8484
/// Generic one-to-one conversion of simply mappable operations into calls
8585
/// to their respective LLVM intrinsics.
8686
struct OneToOneIntrinsicOpConversion
87-
: public OpInterfaceRewritePattern<x86vector::OneToOneIntrinsicOp> {
88-
using OpInterfaceRewritePattern<
89-
x86vector::OneToOneIntrinsicOp>::OpInterfaceRewritePattern;
87+
: public OpInterfaceConversionPattern<x86vector::OneToOneIntrinsicOp> {
88+
using OpInterfaceConversionPattern<
89+
x86vector::OneToOneIntrinsicOp>::OpInterfaceConversionPattern;
9090

9191
OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
9292
PatternBenefit benefit = 1)
93-
: OpInterfaceRewritePattern(&typeConverter.getContext(), benefit),
93+
: OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
94+
benefit),
9495
typeConverter(typeConverter) {}
9596

96-
LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
97-
PatternRewriter &rewriter) const override {
98-
return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
99-
op.getIntrinsicOperands(rewriter, typeConverter),
100-
typeConverter, rewriter);
97+
LogicalResult
98+
matchAndRewrite(x86vector::OneToOneIntrinsicOp op, ArrayRef<Value> operands,
99+
ConversionPatternRewriter &rewriter) const override {
100+
return intrinsicRewrite(
101+
op, rewriter.getStringAttr(op.getIntrinsicName()),
102+
op.getIntrinsicOperands(operands, typeConverter, rewriter),
103+
typeConverter, rewriter);
101104
}
102105

103106
private:

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.