Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

[MLIR][NVVM] Add dot.accumulate.4way OP #139043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions 64 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3444,6 +3444,70 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// NVVM dot.accumulate.4way Op
//===----------------------------------------------------------------------===//

// Element-type cases for the `a` and `b` operands of
// `nvvm.dot.accumulate.4way`. Per the op description, `s8` elements are
// sign-extended and `u8` elements are zero-extended to 32 bits before the dot
// product is computed. The integer values are U8 = 0 and S8 = 1.
def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;

// 32-bit integer enum made of the two cases above, placed in the
// ::mlir::NVVM C++ namespace.
def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
"NVVM DotAccumulate4WayType",
[DotAccumulate4WayS8, DotAccumulate4WayU8]> {
let cppNamespace = "::mlir::NVVM";
// NOTE(review): presumably set to 0 so that no standalone attribute class is
// generated for this enum and the dialect-level EnumAttr wrapper is used
// instead -- confirm against the ODS enum documentation.
let genSpecializedAttr = 0;
}

// NVVM dialect attribute wrapping DotAccumulate4WayType, registered under the
// mnemonic `dot_accumulate_4way_type`. The custom assembly format prints only
// the enum value in angle brackets, e.g. `<s8>` or `<u8>`.
def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
let assemblyFormat = "`<` $value `>`";
}

// Four-way byte dot-product-accumulate op (`nvvm.dot.accumulate.4way`).
//
// Takes two `vector<4xi8>` operands, one element-type attribute per operand
// (`s8` or `u8`) and an `i32` accumulator, and produces an `i32` result.
// Example of the textual form:
//   %r = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c : vector<4xi8>, vector<4xi8>
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
let summary = "Four-way byte dot product-accumulate instruction.";
let description = [{
Performs a four-way byte dot-product which is accumulated in a 32-bit
result.
Operand `a` and `b` are vectors of 4 bytes between which the dot product is
computed.
The `a_type` and `b_type` attributes specify the type of the elements in `a`
and `b` respectively.
If `a_type` or `b_type` is `s8`, then the elements in the corresponding
vector are sign-extended to 32-bit before the dot product is computed.
If `a_type` or `b_type` is `u8`, then the elements in the corresponding
vector are zero-extended to 32-bit instead.
Operand `c` is a 32-bit integer to which the result is accumulated. It is
treated as holding a signed integer if any of `a_type` or `b_type` is `s8`.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
}];

// Each vector operand is immediately followed by its element-type attribute,
// mirroring the order used in the assembly format below.
let arguments = (ins
VectorOfLengthAndType<[4], [I8]>:$a,
DotAccumulate4WayTypeAttr:$a_type,
VectorOfLengthAndType<[4], [I8]>:$b,
DotAccumulate4WayTypeAttr:$b_type,
I32:$c
);

let results = (outs I32:$res);

// Only the two vector operand types are spelled out after the colon; `$c`
// and the result are constrained to I32 above.
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";

// C++ helpers used by the LLVM IR translation snippet below:
//  - getIntrinsicID: maps the (a_type, b_type) pair to an LLVM intrinsic ID.
//  - getPackedArg:   converts a vector operand into the value that is passed
//                    to the intrinsic.
let extraClassDeclaration = [{
static llvm::Intrinsic::ID
getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
NVVM::DotAccumulate4WayType b_type);
llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
}];

// Translation to LLVM IR: pick the intrinsic from the two type attributes,
// run both vector operands through getPackedArg, then emit a single intrinsic
// call with the arguments in the order (a, b, c).
string llvmBuilder = [{
llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
llvm::Value* argA = op.getPackedArg($a, builder);
llvm::Value* argB = op.getPackedArg($b, builder);
$res = createIntrinsicCall(builder, id, {argA, argB, $c});
}];
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
Expand Down
28 changes: 28 additions & 0 deletions 28 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -1203,6 +1204,13 @@ LogicalResult NVVM::VoteSyncOp::verify() {
return success();
}

/// Reinterprets `arg` as a single 32-bit integer by emitting a bitcast
/// through `builder`, and returns the resulting value.
llvm::Value *
NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
                                        llvm::IRBuilderBase &builder) {
  // The i32 type is taken from the builder's own LLVM context.
  llvm::Type *packedTy = builder.getInt32Ty();
  return builder.CreateBitCast(arg, packedTy);
}

//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1590,6 +1598,26 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}

/// Returns the `llvm.nvvm.idp4a.*` intrinsic matching the element types of the
/// two operands.
///
/// The intrinsic suffix is `<a>_<b>`, where each position is `s` when the
/// corresponding type is `S8` and `u` otherwise:
///
///   a_type | b_type | intrinsic
///   -------+--------+----------------
///   U8     | U8     | nvvm_idp4a_u_u
///   U8     | S8     | nvvm_idp4a_u_s
///   S8     | U8     | nvvm_idp4a_s_u
///   S8     | S8     | nvvm_idp4a_s_s
///
/// \param a_type element type of operand `a`.
/// \param b_type element type of operand `b`.
llvm::Intrinsic::ID
DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
                                    NVVM::DotAccumulate4WayType b_type) {
  const bool isASigned = a_type == NVVM::DotAccumulate4WayType::S8;
  const bool isBSigned = b_type == NVVM::DotAccumulate4WayType::S8;

  // Two booleans give exactly four combinations, so select directly instead
  // of packing them into an integer and switching with an unreachable default.
  if (isASigned)
    return isBSigned ? llvm::Intrinsic::nvvm_idp4a_s_s
                     : llvm::Intrinsic::nvvm_idp4a_s_u;
  return isBSigned ? llvm::Intrinsic::nvvm_idp4a_u_s
                   : llvm::Intrinsic::nvvm_idp4a_u_u;
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 9 additions & 0 deletions 9 mlir/test/Dialect/LLVMIR/nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,15 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
return
}

// Round-trip test for nvvm.dot.accumulate.4way. The expected output spells out
// the per-operand element-type attributes so that a parser/printer bug that
// drops or swaps `<u8>` and `<s8>` is detected rather than being absorbed by a
// wildcard.
// CHECK-LABEL: @dot_accumulate_4way
func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
  // CHECK: nvvm.dot.accumulate.4way %{{.*}} <u8>, %{{.*}} <u8>, %{{.*}} : vector<4xi8>, vector<4xi8>
  %1 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
  // CHECK: nvvm.dot.accumulate.4way %{{.*}} <s8>, %{{.*}} <s8>, %{{.*}} : vector<4xi8>, vector<4xi8>
  %3 = nvvm.dot.accumulate.4way %a_vec <s8>, %b_vec <s8>, %c: vector<4xi8>, vector<4xi8>
  return
}

// -----

// Just check these don't emit errors.
Expand Down
22 changes: 22 additions & 0 deletions 22 mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,25 @@ llvm.func @nvvm_st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size:
nvvm.st.bulk %addr_shared, size = %size, init = 0: !llvm.ptr<3>
llvm.return
}

// -----
// Translation test for nvvm.dot.accumulate.4way. For each of the four
// (a_type, b_type) combinations, both vector<4xi8> operands are expected to be
// bitcast to i32 and passed, together with the accumulator, to the matching
// llvm.nvvm.idp4a.{u|s}.{u|s} intrinsic.
// CHECK-LABEL: @nvvm_dot_accumulate_4way
llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32) {
// a = u8, b = u8 -> idp4a.u.u
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
%0 = nvvm.dot.accumulate.4way %a <u8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
// a = s8, b = u8 -> idp4a.s.u
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
%1 = nvvm.dot.accumulate.4way %a <s8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
// a = u8, b = s8 -> idp4a.u.s
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
%2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
// a = s8, b = s8 -> idp4a.s.s
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
%3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
llvm.return
}
Morty Proxy This is a proxified and sanitized view of the page, visit original site.