Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

[flang][fir] Add locality specifiers modeling to fir.do_concurrent.loop #138506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion 33 flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3647,6 +3647,13 @@ def fir_DoConcurrentOp : fir_Op<"do_concurrent",
let hasVerifier = 1;
}

def fir_LocalSpecifier {
dag arguments = (ins
Variadic<AnyType>:$local_vars,
OptionalAttr<SymbolRefArrayAttr>:$local_syms
);
}

def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getLoopInductionVars"]>,
Expand Down Expand Up @@ -3700,7 +3707,7 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
LLVM.
}];

let arguments = (ins
defvar opArgs = (ins
Variadic<Index>:$lowerBound,
Variadic<Index>:$upperBound,
Variadic<Index>:$step,
Expand All @@ -3709,16 +3716,40 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
OptionalAttr<LoopAnnotationAttr>:$loopAnnotation
);

let arguments = !con(opArgs, fir_LocalSpecifier.arguments);

let regions = (region SizedRegion<1>:$region);

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;

let extraClassDeclaration = [{
unsigned getNumInductionVars() { return getLowerBound().size(); }

unsigned getNumLocalOperands() { return getLocalVars().size(); }

mlir::Block::BlockArgListType getInductionVars() {
return getBody()->getArguments().slice(0, getNumInductionVars());
}

mlir::Block::BlockArgListType getRegionLocalArgs() {
return getBody()->getArguments().slice(getNumInductionVars(),
getNumLocalOperands());
}

/// Number of operands controlling the loop
unsigned getNumControlOperands() { return getLowerBound().size() * 3; }

// Get Number of reduction operands
unsigned getNumReduceOperands() {
return getReduceOperands().size();
}

mlir::Operation::operand_range getLocalOperands() {
return getOperands()
.slice(getNumControlOperands() + getNumReduceOperands(),
getNumLocalOperands());
}
}];
}

Expand Down
2 changes: 1 addition & 1 deletion 2 flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2460,7 +2460,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
nestReduceAttrs.empty()
? nullptr
: mlir::ArrayAttr::get(builder->getContext(), nestReduceAttrs),
nullptr);
nullptr, /*local_vars=*/std::nullopt, /*local_syms=*/nullptr);

llvm::SmallVector<mlir::Type> loopBlockArgTypes(
incrementLoopNestInfo.size(), builder->getIndexType());
Expand Down
112 changes: 94 additions & 18 deletions 112 flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5033,29 +5033,33 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
auto &builder = parser.getBuilder();
// Parse an opening `(` followed by induction variables followed by `)`
llvm::SmallVector<mlir::OpAsmParser::Argument, 4> ivs;
if (parser.parseArgumentList(ivs, mlir::OpAsmParser::Delimiter::Paren))
llvm::SmallVector<mlir::OpAsmParser::Argument, 4> regionArgs;

if (parser.parseArgumentList(regionArgs, mlir::OpAsmParser::Delimiter::Paren))
return mlir::failure();

llvm::SmallVector<mlir::Type> argTypes(regionArgs.size(),
builder.getIndexType());

// Parse loop bounds.
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> lower;
if (parser.parseEqual() ||
parser.parseOperandList(lower, ivs.size(),
parser.parseOperandList(lower, regionArgs.size(),
mlir::OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(lower, builder.getIndexType(), result.operands))
return mlir::failure();

llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> upper;
if (parser.parseKeyword("to") ||
parser.parseOperandList(upper, ivs.size(),
parser.parseOperandList(upper, regionArgs.size(),
mlir::OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(upper, builder.getIndexType(), result.operands))
return mlir::failure();

// Parse step values.
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> steps;
if (parser.parseKeyword("step") ||
parser.parseOperandList(steps, ivs.size(),
parser.parseOperandList(steps, regionArgs.size(),
mlir::OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(steps, builder.getIndexType(), result.operands))
return mlir::failure();
Expand Down Expand Up @@ -5086,20 +5090,72 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
builder.getArrayAttr(arrayAttr));
}

// Now parse the body.
mlir::Region *body = result.addRegion();
for (auto &iv : ivs)
iv.type = builder.getIndexType();
if (parser.parseRegion(*body, ivs))
return mlir::failure();
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> localOperands;
if (succeeded(parser.parseOptionalKeyword("local"))) {
std::size_t oldArgTypesSize = argTypes.size();
if (failed(parser.parseLParen()))
return mlir::failure();

llvm::SmallVector<mlir::SymbolRefAttr> localSymbolVec;
if (failed(parser.parseCommaSeparatedList([&]() {
if (failed(parser.parseAttribute(localSymbolVec.emplace_back())))
return mlir::failure();

if (parser.parseOperand(localOperands.emplace_back()) ||
parser.parseArrow() ||
parser.parseArgument(regionArgs.emplace_back()))
return mlir::failure();

return mlir::success();
})))
return mlir::failure();

if (failed(parser.parseColon()))
return mlir::failure();

if (failed(parser.parseCommaSeparatedList([&]() {
if (failed(parser.parseType(argTypes.emplace_back())))
return mlir::failure();

return mlir::success();
})))
return mlir::failure();

if (regionArgs.size() != argTypes.size())
return parser.emitError(parser.getNameLoc(),
"mismatch in number of local arg and types");

if (failed(parser.parseRParen()))
return mlir::failure();

for (auto operandType : llvm::zip_equal(
localOperands, llvm::drop_begin(argTypes, oldArgTypesSize)))
if (parser.resolveOperand(std::get<0>(operandType),
std::get<1>(operandType), result.operands))
return mlir::failure();

llvm::SmallVector<mlir::Attribute> symbolAttrs(localSymbolVec.begin(),
localSymbolVec.end());
result.addAttribute(getLocalSymsAttrName(result.name),
builder.getArrayAttr(symbolAttrs));
}

// Set `operandSegmentSizes` attribute.
result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
builder.getDenseI32ArrayAttr(
{static_cast<int32_t>(lower.size()),
static_cast<int32_t>(upper.size()),
static_cast<int32_t>(steps.size()),
static_cast<int32_t>(reduceOperands.size())}));
static_cast<int32_t>(reduceOperands.size()),
static_cast<int32_t>(localOperands.size())}));

// Now parse the body.
for (auto [arg, type] : llvm::zip_equal(regionArgs, argTypes))
arg.type = type;

mlir::Region *body = result.addRegion();
if (parser.parseRegion(*body, regionArgs))
return mlir::failure();

// Parse attributes.
if (parser.parseOptionalAttrDict(result.attributes))
Expand All @@ -5109,8 +5165,9 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
}

void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
<< ") to (" << getUpperBound() << ") step (" << getStep() << ")";
p << " (" << getBody()->getArguments().slice(0, getNumInductionVars())
<< ") = (" << getLowerBound() << ") to (" << getUpperBound() << ") step ("
<< getStep() << ")";

if (!getReduceOperands().empty()) {
p << " reduce(";
Expand All @@ -5123,12 +5180,27 @@ void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
p << ')';
}

if (!getLocalVars().empty()) {
p << " local(";
llvm::interleaveComma(llvm::zip_equal(getLocalSymsAttr(), getLocalVars(),
getRegionLocalArgs()),
p, [&](auto it) {
p << std::get<0>(it) << " " << std::get<1>(it)
<< " -> " << std::get<2>(it);
});
p << " : ";
llvm::interleaveComma(getLocalVars(), p,
[&](auto it) { p << it.getType(); });
p << ")";
}

p << ' ';
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(
(*this)->getAttrs(),
/*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
DoConcurrentLoopOp::getReduceAttrsAttrName()});
DoConcurrentLoopOp::getReduceAttrsAttrName(),
DoConcurrentLoopOp::getLocalSymsAttrName()});
}

llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions() {
Expand All @@ -5139,6 +5211,7 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
mlir::Operation::operand_range lbValues = getLowerBound();
mlir::Operation::operand_range ubValues = getUpperBound();
mlir::Operation::operand_range stepValues = getStep();
mlir::Operation::operand_range localVars = getLocalVars();

if (lbValues.empty())
return emitOpError(
Expand All @@ -5152,11 +5225,13 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
// Check that the body defines the same number of block arguments as the
// number of tuple elements in step.
mlir::Block *body = getBody();
if (body->getNumArguments() != stepValues.size())
unsigned numIndVarArgs = body->getNumArguments() - localVars.size();

if (numIndVarArgs != stepValues.size())
return emitOpError() << "expects the same number of induction variables: "
<< body->getNumArguments()
<< " as bound and step values: " << stepValues.size();
for (auto arg : body->getArguments())
for (auto arg : body->getArguments().slice(0, numIndVarArgs))
if (!arg.getType().isIndex())
return emitOpError(
"expects arguments for the induction variable to be of index type");
Expand All @@ -5171,7 +5246,8 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {

std::optional<llvm::SmallVector<mlir::Value>>
fir::DoConcurrentLoopOp::getLoopInductionVars() {
return llvm::SmallVector<mlir::Value>{getBody()->getArguments()};
return llvm::SmallVector<mlir::Value>{
getBody()->getArguments().slice(0, getLowerBound().size())};
}

//===----------------------------------------------------------------------===//
Expand Down
54 changes: 53 additions & 1 deletion 54 flang/test/Fir/do_concurrent.fir
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index,
// CHECK: }
// CHECK: }


fir.local {type = local} @local_privatizer : i32

// CHECK: fir.local {type = local} @[[LOCAL_PRIV_SYM:local_privatizer]] : i32
Expand All @@ -109,3 +108,56 @@ fir.local {type = local_init} @local_init_privatizer : i32 copy {
// CHECK: fir.store %[[ORIG_VAL_LD]] to %[[LOCAL_VAL]] : !fir.ref<i32>
// CHECK: fir.yield(%[[LOCAL_VAL]] : !fir.ref<i32>)
// CHECK: }

func.func @do_concurrent_with_locality_specs() {
%3 = fir.alloca i32 {bindc_name = "local_init_var"}
%4:2 = hlfir.declare %3 {uniq_name = "_QFdo_concurrentElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%5 = fir.alloca i32 {bindc_name = "local_var"}
%6:2 = hlfir.declare %5 {uniq_name = "_QFdo_concurrentElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index

fir.do_concurrent {
%9 = fir.alloca i32 {bindc_name = "i"}
%10:2 = hlfir.declare %9 {uniq_name = "_QFdo_concurrentEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)

fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1)
local(@local_privatizer %6#0 -> %arg1, @local_init_privatizer %4#0 -> %arg2 : !fir.ref<i32>, !fir.ref<i32>) {
%11 = fir.convert %arg0 : (index) -> i32
fir.store %11 to %10#0 : !fir.ref<i32>
%13:2 = hlfir.declare %arg1 {uniq_name = "_QFdo_concurrentElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%15:2 = hlfir.declare %arg2 {uniq_name = "_QFdo_concurrentElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
}
}
return
}

// CHECK-LABEL: func.func @do_concurrent_with_locality_specs() {
// CHECK: %[[LOC_INIT_ALLOC:.*]] = fir.alloca i32 {bindc_name = "local_init_var"}
// CHECK: %[[LOC_INIT_DECL:.*]]:2 = hlfir.declare %[[LOC_INIT_ALLOC]]

// CHECK: %[[LOC_ALLOC:.*]] = fir.alloca i32 {bindc_name = "local_var"}
// CHECK: %[[LOC_DECL:.*]]:2 = hlfir.declare %[[LOC_ALLOC]]

// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C10:.*]] = arith.constant 10 : index

// CHECK: fir.do_concurrent {
// CHECK: %[[DC_I_ALLOC:.*]] = fir.alloca i32 {bindc_name = "i"}
// CHECK: %[[DC_I_DECL:.*]]:2 = hlfir.declare %[[DC_I_ALLOC]]

// CHECK: fir.do_concurrent.loop (%[[IV:.*]]) = (%[[C1]]) to
// CHECK-SAME: (%[[C10]]) step (%[[C1]])
// CHECK-SAME: local(@[[LOCAL_PRIV_SYM]] %[[LOC_DECL]]#0 -> %[[LOC_ARG:[^,]*]],
// CHECK-SAME: @[[LOCAL_INIT_PRIV_SYM]] %[[LOC_INIT_DECL]]#0 -> %[[LOC_INIT_ARG:.*]] :
// CHECK-SAME: !fir.ref<i32>, !fir.ref<i32>) {

// CHECK: %[[IV_CVT:.*]] = fir.convert %[[IV]] : (index) -> i32
// CHECK: fir.store %[[IV_CVT]] to %[[DC_I_DECL]]#0 : !fir.ref<i32>

// CHECK: %[[LOC_PRIV_DECL:.*]]:2 = hlfir.declare %[[LOC_ARG]]
// CHECK: %[[LOC_INIT_PRIV_DECL:.*]]:2 = hlfir.declare %[[LOC_INIT_ARG]]
// CHECK: }
// CHECK: }
// CHECK: return
// CHECK: }
10 changes: 5 additions & 5 deletions 10 flang/test/Fir/invalid.fir
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,7 @@ func.func @dc_0d() {

func.func @dc_invalid_parent(%arg0: index, %arg1: index) {
// expected-error@+1 {{'fir.do_concurrent.loop' op expects parent op 'fir.do_concurrent'}}
"fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>}> ({
"fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0, 0>}> ({
^bb0(%arg2: index):
%tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
}) : (index, index) -> ()
Expand All @@ -1208,7 +1208,7 @@ func.func @dc_invalid_parent(%arg0: index, %arg1: index) {
func.func @dc_invalid_control(%arg0: index, %arg1: index) {
// expected-error@+2 {{'fir.do_concurrent.loop' op different number of tuple elements for lowerBound, upperBound or step}}
fir.do_concurrent {
"fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>}> ({
"fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0, 0>}> ({
^bb0(%arg2: index):
%tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
}) : (index, index) -> ()
Expand All @@ -1221,7 +1221,7 @@ func.func @dc_invalid_control(%arg0: index, %arg1: index) {
func.func @dc_invalid_ind_var(%arg0: index, %arg1: index) {
// expected-error@+2 {{'fir.do_concurrent.loop' op expects the same number of induction variables: 2 as bound and step values: 1}}
fir.do_concurrent {
"fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>}> ({
"fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>}> ({
^bb0(%arg3: index, %arg4: index):
%tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
}) : (index, index, index) -> ()
Expand All @@ -1234,7 +1234,7 @@ func.func @dc_invalid_ind_var(%arg0: index, %arg1: index) {
func.func @dc_invalid_ind_var_type(%arg0: index, %arg1: index) {
// expected-error@+2 {{'fir.do_concurrent.loop' op expects arguments for the induction variable to be of index type}}
fir.do_concurrent {
"fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>}> ({
"fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>}> ({
^bb0(%arg3: i32):
%tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
}) : (index, index, index) -> ()
Expand All @@ -1248,7 +1248,7 @@ func.func @dc_invalid_reduction(%arg0: index, %arg1: index) {
%sum = fir.alloca i32
// expected-error@+2 {{'fir.do_concurrent.loop' op mismatch in number of reduction variables and reduction attributes}}
fir.do_concurrent {
"fir.do_concurrent.loop"(%arg0, %arg1, %arg0, %sum) <{operandSegmentSizes = array<i32: 1, 1, 1, 1>}> ({
"fir.do_concurrent.loop"(%arg0, %arg1, %arg0, %sum) <{operandSegmentSizes = array<i32: 1, 1, 1, 1, 0>}> ({
^bb0(%arg3: index):
%tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
}) : (index, index, index, !fir.ref<i32>) -> ()
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.