diff --git a/flang/lib/Lower/Support/Utils.cpp b/flang/lib/Lower/Support/Utils.cpp index 5a9a839330364..ed2700c42fc55 100644 --- a/flang/lib/Lower/Support/Utils.cpp +++ b/flang/lib/Lower/Support/Utils.cpp @@ -478,9 +478,47 @@ class IsEqualEvaluateExpr { return isEqual(x.proc(), y.proc()) && isEqual(x.arguments(), y.arguments()); } template + static bool isEqual(const Fortran::evaluate::ImpliedDo &x, + const Fortran::evaluate::ImpliedDo &y) { + return isEqual(x.values(), y.values()) && isEqual(x.lower(), y.lower()) && + isEqual(x.upper(), y.upper()) && isEqual(x.stride(), y.stride()); + } + template + static bool isEqual(const Fortran::evaluate::ArrayConstructorValues &x, + const Fortran::evaluate::ArrayConstructorValues &y) { + using Expr = Fortran::evaluate::Expr; + using ImpliedDo = Fortran::evaluate::ImpliedDo; + for (const auto &[xValue, yValue] : llvm::zip(x, y)) { + bool checkElement = Fortran::common::visit( + common::visitors{ + [&](const Expr &v, const Expr &w) { return isEqual(v, w); }, + [&](const ImpliedDo &v, const ImpliedDo &w) { + return isEqual(v, w); + }, + [&](const Expr &, const ImpliedDo &) { return false; }, + [&](const ImpliedDo &, const Expr &) { return false; }, + }, + xValue.u, yValue.u); + if (!checkElement) { + return false; + } + } + return true; + } + static bool isEqual(const Fortran::evaluate::SubscriptInteger &x, + const Fortran::evaluate::SubscriptInteger &y) { + return x == y; + } + template static bool isEqual(const Fortran::evaluate::ArrayConstructor &x, const Fortran::evaluate::ArrayConstructor &y) { - llvm::report_fatal_error("not implemented"); + bool checkCharacterType = true; + if constexpr (A::category == Fortran::common::TypeCategory::Character) { + checkCharacterType = isEqual(*x.LEN(), *y.LEN()); + } + using Base = Fortran::evaluate::ArrayConstructorValues; + return isEqual((Base)x, (Base)y) && + (x.GetType() == y.GetType() && checkCharacterType); } static bool isEqual(const Fortran::evaluate::ImpliedDoIndex &x, const Fortran::evaluate::ImpliedDoIndex &y) { diff --git a/flang/test/Lower/OpenMP/atomic-update.f90 b/flang/test/Lower/OpenMP/atomic-update.f90 index 16dae9d5f301c..7d04745015faa 100644 --- a/flang/test/Lower/OpenMP/atomic-update.f90 +++ b/flang/test/Lower/OpenMP/atomic-update.f90 @@ -185,4 +185,19 @@ program OmpAtomicUpdate !$omp atomic update w = max(w,x,y,z) +!CHECK: %[[IMP_DO:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr { +!CHECK: ^bb0(%{{.*}}: index): +! [...] +!CHECK: %[[ADD_I1:.*]] = arith.addi {{.*}} : i32 +!CHECK: hlfir.yield_element %[[ADD_I1]] : i32 +!CHECK: } +! [...] +!CHECK: %[[SUM:.*]] = hlfir.sum %[[IMP_DO]] +!CHECK: omp.atomic.update %[[VAL_X_DECLARE]]#1 : !fir.ref { +!CHECK: ^bb0(%[[ARG0:.*]]: i32): +!CHECK: %[[ADD_I2:.*]] = arith.addi %[[ARG0]], %[[SUM]] : i32 +!CHECK: omp.yield(%[[ADD_I2]] : i32) +!CHECK: } + !$omp atomic update + x = x + sum([ (y+2, y=1, z) ]) end program OmpAtomicUpdate