//===-- Atomic.cpp -- Lowering of atomic constructs -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Atomic.h"
#include "Clauses.h"
#include "flang/Evaluate/expression.h"
#include "flang/Evaluate/fold.h"
#include "flang/Evaluate/tools.h"
#include "flang/Evaluate/traverse.h"
#include "flang/Evaluate/type.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/type.h"
#include "flang/Support/Fortran.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#include <string>
#include <type_traits>
#include <variant>
#include <vector>
static llvm::cl::opt<bool> DumpAtomicAnalysis("fdebug-dump-atomic-analysis");
using namespace Fortran;
// Enable the "s" suffix for std::string literals used in dumpAtomicAnalysis.
using namespace std::string_literals;
// Don't import the entire Fortran::lower.
namespace omp {
using namespace Fortran::lower::omp;
}
namespace {
// An example of a type that can be used to deduce the return type of the
// visitor:
//   visitor(type_identity<Xyz>) -> result_type
using SomeArgType = evaluate::Type<common::TypeCategory::Integer, 4>;
struct GetProc
: public evaluate::Traverse<GetProc, const evaluate::ProcedureDesignator *,
false> {
using Result = const evaluate::ProcedureDesignator *;
using Base = evaluate::Traverse<GetProc, Result, false>;
GetProc() : Base(*this) {}
using Base::operator();
static Result Default() { return nullptr; }
Result operator()(const evaluate::ProcedureDesignator &p) const { return &p; }
static Result Combine(Result a, Result b) { return a != nullptr ? a : b; }
};
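
// Dispatch a visitor over the static evaluate::Type that corresponds to a
// DynamicType (used by makeCall below to build a typed FunctionRef).
// Illustrative sketch of the intended usage, not taken from this file:
//   evaluate::DynamicType dyType = ...;
//   WithType(dyType).visit([&](auto tag) {
//     using Ty = typename decltype(tag)::type; // e.g. evaluate::Type<Integer, 4>
//     ...
//   });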
struct WithType {
WithType(const evaluate::DynamicType &t) : type(t) {
assert(type.category() != common::TypeCategory::Derived &&
"Type cannot be a derived type");
}
template <typename VisitorTy> //
auto visit(VisitorTy &&visitor) const
-> std::invoke_result_t<VisitorTy, SomeArgType> {
switch (type.category()) {
case common::TypeCategory::Integer:
switch (type.kind()) {
case 1:
return visitor(llvm::type_identity<evaluate::Type<Integer, 1>>{});
case 2:
return visitor(llvm::type_identity<evaluate::Type<Integer, 2>>{});
case 4:
return visitor(llvm::type_identity<evaluate::Type<Integer, 4>>{});
case 8:
return visitor(llvm::type_identity<evaluate::Type<Integer, 8>>{});
case 16:
return visitor(llvm::type_identity<evaluate::Type<Integer, 16>>{});
}
break;
case common::TypeCategory::Unsigned:
switch (type.kind()) {
case 1:
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 1>>{});
case 2:
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 2>>{});
case 4:
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 4>>{});
case 8:
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 8>>{});
case 16:
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 16>>{});
}
break;
case common::TypeCategory::Real:
switch (type.kind()) {
case 2:
return visitor(llvm::type_identity<evaluate::Type<Real, 2>>{});
case 3:
return visitor(llvm::type_identity<evaluate::Type<Real, 3>>{});
case 4:
return visitor(llvm::type_identity<evaluate::Type<Real, 4>>{});
case 8:
return visitor(llvm::type_identity<evaluate::Type<Real, 8>>{});
case 10:
return visitor(llvm::type_identity<evaluate::Type<Real, 10>>{});
case 16:
return visitor(llvm::type_identity<evaluate::Type<Real, 16>>{});
}
break;
case common::TypeCategory::Complex:
switch (type.kind()) {
case 2:
return visitor(llvm::type_identity<evaluate::Type<Complex, 2>>{});
case 3:
return visitor(llvm::type_identity<evaluate::Type<Complex, 3>>{});
case 4:
return visitor(llvm::type_identity<evaluate::Type<Complex, 4>>{});
case 8:
return visitor(llvm::type_identity<evaluate::Type<Complex, 8>>{});
case 10:
return visitor(llvm::type_identity<evaluate::Type<Complex, 10>>{});
case 16:
return visitor(llvm::type_identity<evaluate::Type<Complex, 16>>{});
}
break;
case common::TypeCategory::Logical:
switch (type.kind()) {
case 1:
return visitor(llvm::type_identity<evaluate::Type<Logical, 1>>{});
case 2:
return visitor(llvm::type_identity<evaluate::Type<Logical, 2>>{});
case 4:
return visitor(llvm::type_identity<evaluate::Type<Logical, 4>>{});
case 8:
return visitor(llvm::type_identity<evaluate::Type<Logical, 8>>{});
}
break;
case common::TypeCategory::Character:
switch (type.kind()) {
case 1:
return visitor(llvm::type_identity<evaluate::Type<Character, 1>>{});
case 2:
return visitor(llvm::type_identity<evaluate::Type<Character, 2>>{});
case 4:
return visitor(llvm::type_identity<evaluate::Type<Character, 4>>{});
}
break;
case common::TypeCategory::Derived:
(void)Derived;
break;
}
llvm_unreachable("Unhandled type");
}
const evaluate::DynamicType &type;
private:
// Shorter names.
static constexpr auto Character = common::TypeCategory::Character;
static constexpr auto Complex = common::TypeCategory::Complex;
static constexpr auto Derived = common::TypeCategory::Derived;
static constexpr auto Integer = common::TypeCategory::Integer;
static constexpr auto Logical = common::TypeCategory::Logical;
static constexpr auto Real = common::TypeCategory::Real;
static constexpr auto Unsigned = common::TypeCategory::Unsigned;
};
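
// Produce an rvalue that can be consumed by constructors taking arguments
// by value or by rvalue reference (e.g. evaluate::FunctionRef in makeCall
// below): the lvalue overload hands out a movable copy, the rvalue overload
// simply forwards its argument.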
template <typename T, typename U = std::remove_const_t<T>>
U AsRvalue(T &t) {
U copy{t};
return std::move(copy);
}
template <typename T>
T &&AsRvalue(T &&t) {
return std::move(t);
}
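
// Replace the argument list of the first FunctionRef encountered while
// traversing an expression; used by replaceArgs/genReducedMinMax below.
// For example (illustrative), applying it with new arguments {x, tmp} to
// the expression "max(a, b, c)" yields "max(x, tmp)".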
struct ArgumentReplacer
: public evaluate::Traverse<ArgumentReplacer, bool, false> {
using Base = evaluate::Traverse<ArgumentReplacer, bool, false>;
using Result = bool;
Result Default() const { return false; }
ArgumentReplacer(evaluate::ActualArguments &&newArgs)
: Base(*this), args_(std::move(newArgs)) {}
using Base::operator();
template <typename T>
Result operator()(const evaluate::FunctionRef<T> &x) {
assert(!done_);
auto &mut = const_cast<evaluate::FunctionRef<T> &>(x);
mut.arguments() = args_;
done_ = true;
return true;
}
Result Combine(Result &&a, Result &&b) { return a || b; }
private:
bool done_{false};
evaluate::ActualArguments &&args_;
};
} // namespace
[[maybe_unused]] static void
dumpAtomicAnalysis(const parser::OpenMPAtomicConstruct::Analysis &analysis) {
auto whatStr = [](int k) {
std::string txt = "?";
switch (k & parser::OpenMPAtomicConstruct::Analysis::Action) {
case parser::OpenMPAtomicConstruct::Analysis::None:
txt = "None";
break;
case parser::OpenMPAtomicConstruct::Analysis::Read:
txt = "Read";
break;
case parser::OpenMPAtomicConstruct::Analysis::Write:
txt = "Write";
break;
case parser::OpenMPAtomicConstruct::Analysis::Update:
txt = "Update";
break;
}
switch (k & parser::OpenMPAtomicConstruct::Analysis::Condition) {
case parser::OpenMPAtomicConstruct::Analysis::IfTrue:
txt += " | IfTrue";
break;
case parser::OpenMPAtomicConstruct::Analysis::IfFalse:
txt += " | IfFalse";
break;
}
return txt;
};
auto exprStr = [&](const parser::TypedExpr &expr) {
if (auto *maybe = expr.get()) {
if (maybe->v)
return maybe->v->AsFortran();
}
return "<null>"s;
};
auto assignStr = [&](const parser::AssignmentStmt::TypedAssignment &assign) {
if (auto *maybe = assign.get(); maybe && maybe->v) {
std::string str;
llvm::raw_string_ostream os(str);
maybe->v->AsFortran(os);
return str;
}
return "<null>"s;
};
const semantics::SomeExpr &atom = *analysis.atom.get()->v;
llvm::errs() << "Analysis {\n";
llvm::errs() << " atom: " << atom.AsFortran() << "\n";
llvm::errs() << " cond: " << exprStr(analysis.cond) << "\n";
llvm::errs() << " op0 {\n";
llvm::errs() << " what: " << whatStr(analysis.op0.what) << "\n";
llvm::errs() << " assign: " << assignStr(analysis.op0.assign) << "\n";
llvm::errs() << " }\n";
llvm::errs() << " op1 {\n";
llvm::errs() << " what: " << whatStr(analysis.op1.what) << "\n";
llvm::errs() << " assign: " << assignStr(analysis.op1.assign) << "\n";
llvm::errs() << " }\n";
llvm::errs() << "}\n";
}
static bool isPointerAssignment(const evaluate::Assignment &assign) {
return common::visit(
common::visitors{
[](const evaluate::Assignment::BoundsSpec &) { return true; },
[](const evaluate::Assignment::BoundsRemapping &) { return true; },
[](const auto &) { return false; },
},
assign.u);
}
static fir::FirOpBuilder::InsertPoint
getInsertionPointBefore(mlir::Operation *op) {
return fir::FirOpBuilder::InsertPoint(op->getBlock(),
mlir::Block::iterator(op));
}
static fir::FirOpBuilder::InsertPoint
getInsertionPointAfter(mlir::Operation *op) {
return fir::FirOpBuilder::InsertPoint(op->getBlock(),
++mlir::Block::iterator(op));
}
static mlir::IntegerAttr getAtomicHint(lower::AbstractConverter &converter,
const omp::List<omp::Clause> &clauses) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
for (const omp::Clause &clause : clauses) {
if (clause.id != llvm::omp::Clause::OMPC_hint)
continue;
auto &hint = std::get<omp::clause::Hint>(clause.u);
auto maybeVal = evaluate::ToInt64(hint.v);
CHECK(maybeVal);
return builder.getI64IntegerAttr(*maybeVal);
}
return nullptr;
}
static mlir::omp::ClauseMemoryOrderKind
getMemoryOrderKind(common::OmpMemoryOrderType kind) {
switch (kind) {
case common::OmpMemoryOrderType::Acq_Rel:
return mlir::omp::ClauseMemoryOrderKind::Acq_rel;
case common::OmpMemoryOrderType::Acquire:
return mlir::omp::ClauseMemoryOrderKind::Acquire;
case common::OmpMemoryOrderType::Relaxed:
return mlir::omp::ClauseMemoryOrderKind::Relaxed;
case common::OmpMemoryOrderType::Release:
return mlir::omp::ClauseMemoryOrderKind::Release;
case common::OmpMemoryOrderType::Seq_Cst:
return mlir::omp::ClauseMemoryOrderKind::Seq_cst;
}
llvm_unreachable("Unexpected kind");
}
static std::optional<mlir::omp::ClauseMemoryOrderKind>
getMemoryOrderKind(llvm::omp::Clause clauseId) {
switch (clauseId) {
case llvm::omp::Clause::OMPC_acq_rel:
return mlir::omp::ClauseMemoryOrderKind::Acq_rel;
case llvm::omp::Clause::OMPC_acquire:
return mlir::omp::ClauseMemoryOrderKind::Acquire;
case llvm::omp::Clause::OMPC_relaxed:
return mlir::omp::ClauseMemoryOrderKind::Relaxed;
case llvm::omp::Clause::OMPC_release:
return mlir::omp::ClauseMemoryOrderKind::Release;
case llvm::omp::Clause::OMPC_seq_cst:
return mlir::omp::ClauseMemoryOrderKind::Seq_cst;
default:
return std::nullopt;
}
}
static std::optional<mlir::omp::ClauseMemoryOrderKind>
getMemoryOrderFromRequires(const semantics::Scope &scope) {
  // The REQUIRES directive is only allowed in the main program and in
  // module scopes, but it seems that we also accept it in subprogram
  // scopes.
  // To be safe, traverse all enclosing scopes and check whether their
  // symbols carry the REQUIRES information.
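  // For example (illustrative), if an enclosing module contains
  //   !$omp requires atomic_default_mem_order(acq_rel)
  // then the corresponding memory order kind is returned from here.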
for (const auto *sc{&scope}; sc->kind() != semantics::Scope::Kind::Global;
sc = &sc->parent()) {
const semantics::Symbol *sym = sc->symbol();
if (!sym)
continue;
const common::OmpMemoryOrderType *admo = common::visit(
[](auto &&s) {
using WithOmpDeclarative = semantics::WithOmpDeclarative;
if constexpr (std::is_convertible_v<decltype(s),
const WithOmpDeclarative &>) {
return s.ompAtomicDefaultMemOrder();
}
return static_cast<const common::OmpMemoryOrderType *>(nullptr);
},
sym->details());
if (admo)
return getMemoryOrderKind(*admo);
}
return std::nullopt;
}
static std::optional<mlir::omp::ClauseMemoryOrderKind>
getDefaultAtomicMemOrder(semantics::SemanticsContext &semaCtx) {
unsigned version = semaCtx.langOptions().OpenMPVersion;
if (version > 50)
return mlir::omp::ClauseMemoryOrderKind::Relaxed;
return std::nullopt;
}
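
// Determine the memory order for an atomic construct: an explicit
// memory-order clause takes precedence, then a REQUIRES
// atomic_default_mem_order from an enclosing scope, and finally the
// version-dependent default above.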
static std::optional<mlir::omp::ClauseMemoryOrderKind>
getAtomicMemoryOrder(semantics::SemanticsContext &semaCtx,
const omp::List<omp::Clause> &clauses,
const semantics::Scope &scope) {
for (const omp::Clause &clause : clauses) {
if (auto maybeKind = getMemoryOrderKind(clause.id))
return *maybeKind;
}
if (auto maybeKind = getMemoryOrderFromRequires(scope))
return *maybeKind;
return getDefaultAtomicMemOrder(semaCtx);
}
static mlir::omp::ClauseMemoryOrderKindAttr
makeMemOrderAttr(lower::AbstractConverter &converter,
std::optional<mlir::omp::ClauseMemoryOrderKind> maybeKind) {
if (maybeKind) {
return mlir::omp::ClauseMemoryOrderKindAttr::get(
converter.getFirOpBuilder().getContext(), *maybeKind);
}
return nullptr;
}
static bool replaceArgs(semantics::SomeExpr &expr,
evaluate::ActualArguments &&newArgs) {
return ArgumentReplacer(std::move(newArgs))(expr);
}
static semantics::SomeExpr makeCall(const evaluate::DynamicType &type,
const evaluate::ProcedureDesignator &proc,
const evaluate::ActualArguments &args) {
return WithType(type).visit([&](auto &&s) -> semantics::SomeExpr {
using Type = typename llvm::remove_cvref_t<decltype(s)>::type;
return evaluate::AsGenericExpr(
evaluate::FunctionRef<Type>(AsRvalue(proc), AsRvalue(args)));
});
}
static const evaluate::ProcedureDesignator &
getProcedureDesignator(const semantics::SomeExpr &call) {
const evaluate::ProcedureDesignator *proc = GetProc{}(call);
assert(proc && "Call has no procedure designator");
return *proc;
}
static semantics::SomeExpr //
genReducedMinMax(const semantics::SomeExpr &orig,
const semantics::SomeExpr *atomArg,
const std::vector<semantics::SomeExpr> &args) {
// Take a list of arguments to a min/max operation, e.g. [a0, a1, ...]
// One of the a_i's, say a_t, must be atomArg.
// Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate
// call = min/max(a_t, tmp).
// Return "call".
  // The min/max intrinsics have two mandatory arguments; the rest are
  // optional.
// Make sure that the "tmp = min/max(...)" doesn't promote an optional
// argument to a non-optional position. This could happen if a_t is at
// position 0 or 1.
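  // Example (illustrative): with atomArg pointing at "b" in
  //   min(a, b, c, d)
  // this generates
  //   tmp = min(a, a, c, d)
  // and returns the call min(b, tmp).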
if (args.size() <= 2)
return orig;
evaluate::ActualArguments nonAtoms;
auto AsActual = [](const semantics::SomeExpr &x) {
semantics::SomeExpr copy = x;
return evaluate::ActualArgument(std::move(copy));
};
  // Semantic checks guarantee that the "atom" appears exactly once in the
  // argument list (with potential conversions around it).
  // For the first two (non-optional) arguments, if "atom" is among them,
  // replace it with another occurrence of the other non-optional argument.
if (atomArg == &args[0]) {
// (atom, x, y...) -> (x, x, y...)
nonAtoms.push_back(AsActual(args[1]));
nonAtoms.push_back(AsActual(args[1]));
} else if (atomArg == &args[1]) {
// (x, atom, y...) -> (x, x, y...)
nonAtoms.push_back(AsActual(args[0]));
nonAtoms.push_back(AsActual(args[0]));
} else {
// (x, y, z...) -> unchanged
nonAtoms.push_back(AsActual(args[0]));
nonAtoms.push_back(AsActual(args[1]));
}
  // The remaining arguments are optional, so we can simply skip "atom".
for (size_t i = 2, e = args.size(); i != e; ++i) {
if (atomArg != &args[i])
nonAtoms.push_back(AsActual(args[i]));
}
  // The type of the intermediate min/max is the same as the type of its
  // arguments, which may be different from the type of the original
  // expression. The original expression may have additional conversions.
auto tmp =
makeCall(*atomArg->GetType(), getProcedureDesignator(orig), nonAtoms);
semantics::SomeExpr call = orig;
replaceArgs(call, {AsActual(*atomArg), AsActual(tmp)});
return call;
}
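
// Lowering of the individual atomic actions. Each genAtomic* function below
// receives three insertion points:
//   preAt    - where non-atomic code preceding the atomic operation goes,
//   atomicAt - where the atomic operation itself is created,
//   postAt   - where non-atomic follow-up code (e.g. conversions) goes.
// For a stand-alone construct all three coincide; for ATOMIC CAPTURE the
// atomic point lies inside the omp.atomic.capture region (see lowerAtomic).
// For example (illustrative), "!$omp atomic read" over "v = x" becomes an
// omp.atomic.read of x into v, or into a temporary that is converted and
// stored to v at postAt when the types of v and x differ.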
static mlir::Operation * //
genAtomicRead(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx, mlir::Location loc,
lower::StatementContext &stmtCtx, mlir::Value atomAddr,
const semantics::SomeExpr &atom,
const evaluate::Assignment &assign, mlir::IntegerAttr hint,
std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder,
fir::FirOpBuilder::InsertPoint preAt,
fir::FirOpBuilder::InsertPoint atomicAt,
fir::FirOpBuilder::InsertPoint postAt) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
builder.restoreInsertionPoint(preAt);
// If the atomic clause is read then the memory-order clause must
// not be release.
if (memOrder) {
if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Release) {
// Reset it back to the default.
memOrder = getDefaultAtomicMemOrder(semaCtx);
} else if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel) {
// The MLIR verifier doesn't like acq_rel either.
memOrder = mlir::omp::ClauseMemoryOrderKind::Acquire;
}
}
mlir::Value storeAddr =
fir::getBase(converter.genExprAddr(assign.lhs, stmtCtx, &loc));
mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
mlir::Type storeType = fir::unwrapRefType(storeAddr.getType());
mlir::Value toAddr = [&]() {
if (atomType == storeType)
return storeAddr;
return builder.createTemporary(loc, atomType, ".tmp.atomval");
}();
builder.restoreInsertionPoint(atomicAt);
mlir::Operation *op = builder.create<mlir::omp::AtomicReadOp>(
loc, atomAddr, toAddr, mlir::TypeAttr::get(atomType), hint,
makeMemOrderAttr(converter, memOrder));
if (atomType != storeType) {
lower::ExprToValueMap overrides;
// The READ operation could be a part of UPDATE CAPTURE, so make sure
// we don't emit extra code into the body of the atomic op.
builder.restoreInsertionPoint(postAt);
mlir::Value load = builder.create<fir::LoadOp>(loc, toAddr);
overrides.try_emplace(&atom, load);
converter.overrideExprValues(&overrides);
mlir::Value value =
fir::getBase(converter.genExprValue(assign.rhs, stmtCtx, &loc));
converter.resetExprOverrides();
builder.create<fir::StoreOp>(loc, value, storeAddr);
}
return op;
}
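
// Lower an atomic write such as (illustrative)
//   !$omp atomic write
//     x = expr
// into an omp.atomic.write, with the right-hand side evaluated before the
// atomic operation (at preAt) and converted to the type of the atomic
// variable.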
static mlir::Operation * //
genAtomicWrite(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx, mlir::Location loc,
lower::StatementContext &stmtCtx, mlir::Value atomAddr,
const semantics::SomeExpr &atom,
const evaluate::Assignment &assign, mlir::IntegerAttr hint,
std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder,
fir::FirOpBuilder::InsertPoint preAt,
fir::FirOpBuilder::InsertPoint atomicAt,
fir::FirOpBuilder::InsertPoint postAt) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
builder.restoreInsertionPoint(preAt);
// If the atomic clause is write then the memory-order clause must
// not be acquire.
if (memOrder) {
if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acquire) {
// Reset it back to the default.
memOrder = getDefaultAtomicMemOrder(semaCtx);
} else if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel) {
// The MLIR verifier doesn't like acq_rel either.
memOrder = mlir::omp::ClauseMemoryOrderKind::Release;
}
}
mlir::Value value =
fir::getBase(converter.genExprValue(assign.rhs, stmtCtx, &loc));
mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
mlir::Value converted = builder.createConvert(loc, atomType, value);
builder.restoreInsertionPoint(atomicAt);
mlir::Operation *op = builder.create<mlir::omp::AtomicWriteOp>(
loc, atomAddr, converted, hint, makeMemOrderAttr(converter, memOrder));
return op;
}
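
// Lower an atomic update such as (illustrative)
//   !$omp atomic update
//     x = x + expr
// into an omp.atomic.update whose region receives the current value of x
// as a block argument, re-evaluates the right-hand side with the atomic
// variable overridden by that argument, and yields the converted result.
// The non-atomic operands are evaluated before the atomic operation.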
static mlir::Operation *
genAtomicUpdate(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx, mlir::Location loc,
lower::StatementContext &stmtCtx, mlir::Value atomAddr,
const semantics::SomeExpr &atom,
const evaluate::Assignment &assign, mlir::IntegerAttr hint,
std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder,
fir::FirOpBuilder::InsertPoint preAt,
fir::FirOpBuilder::InsertPoint atomicAt,
fir::FirOpBuilder::InsertPoint postAt) {
lower::ExprToValueMap overrides;
lower::StatementContext naCtx;
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
builder.restoreInsertionPoint(preAt);
mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
// This must exist by now.
semantics::SomeExpr rhs = assign.rhs;
semantics::SomeExpr input = *evaluate::GetConvertInput(rhs);
auto [opcode, args] = evaluate::GetTopLevelOperation(input);
assert(!args.empty() && "Update operation without arguments");
// Pass args as an argument to avoid capturing a structured binding.
const semantics::SomeExpr *atomArg = [&](auto &args) {
for (const semantics::SomeExpr &e : args) {
if (evaluate::IsSameOrConvertOf(e, atom))
return &e;
}
llvm_unreachable("Atomic variable not in argument list");
}(args);
if (opcode == evaluate::operation::Operator::Min ||
opcode == evaluate::operation::Operator::Max) {
// Min and max operations are expanded inline, so reduce them to
// operations with exactly two (non-optional) arguments.
rhs = genReducedMinMax(rhs, atomArg, args);
input = *evaluate::GetConvertInput(rhs);
std::tie(opcode, args) = evaluate::GetTopLevelOperation(input);
atomArg = nullptr; // No longer valid.
}
for (auto &arg : args) {
if (!evaluate::IsSameOrConvertOf(arg, atom)) {
mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc));
overrides.try_emplace(&arg, val);
}
}
builder.restoreInsertionPoint(atomicAt);
auto updateOp = builder.create<mlir::omp::AtomicUpdateOp>(
loc, atomAddr, hint, makeMemOrderAttr(converter, memOrder));
mlir::Region &region = updateOp->getRegion(0);
mlir::Block *block = builder.createBlock(&region, {}, {atomType}, {loc});
mlir::Value localAtom = fir::getBase(block->getArgument(0));
overrides.try_emplace(&atom, localAtom);
converter.overrideExprValues(&overrides);
mlir::Value updated =
fir::getBase(converter.genExprValue(rhs, stmtCtx, &loc));
mlir::Value converted = builder.createConvert(loc, atomType, updated);
builder.create<mlir::omp::YieldOp>(loc, converted);
converter.resetExprOverrides();
builder.restoreInsertionPoint(postAt); // For naCtx cleanups
return updateOp;
}
static mlir::Operation *
genAtomicOperation(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx, mlir::Location loc,
lower::StatementContext &stmtCtx, int action,
mlir::Value atomAddr, const semantics::SomeExpr &atom,
const evaluate::Assignment &assign, mlir::IntegerAttr hint,
std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder,
fir::FirOpBuilder::InsertPoint preAt,
fir::FirOpBuilder::InsertPoint atomicAt,
fir::FirOpBuilder::InsertPoint postAt) {
if (isPointerAssignment(assign)) {
TODO(loc, "Code generation for pointer assignment is not implemented yet");
}
  // This function and the functions it calls do not preserve the builder's
  // insertion point, nor do they set it to anything specific.
switch (action) {
case parser::OpenMPAtomicConstruct::Analysis::Read:
return genAtomicRead(converter, semaCtx, loc, stmtCtx, atomAddr, atom,
assign, hint, memOrder, preAt, atomicAt, postAt);
case parser::OpenMPAtomicConstruct::Analysis::Write:
return genAtomicWrite(converter, semaCtx, loc, stmtCtx, atomAddr, atom,
assign, hint, memOrder, preAt, atomicAt, postAt);
case parser::OpenMPAtomicConstruct::Analysis::Update:
return genAtomicUpdate(converter, semaCtx, loc, stmtCtx, atomAddr, atom,
assign, hint, memOrder, preAt, atomicAt, postAt);
default:
return nullptr;
}
}
void Fortran::lower::omp::lowerAtomic(
AbstractConverter &converter, SymMap &symTable,
semantics::SemanticsContext &semaCtx, pft::Evaluation &eval,
const parser::OpenMPAtomicConstruct &construct) {
auto get = [](auto &&typedWrapper) -> decltype(&*typedWrapper.get()->v) {
if (auto *maybe = typedWrapper.get(); maybe && maybe->v) {
return &*maybe->v;
} else {
return nullptr;
}
};
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
auto &dirSpec = std::get<parser::OmpDirectiveSpecification>(construct.t);
omp::List<omp::Clause> clauses = makeClauses(dirSpec.Clauses(), semaCtx);
lower::StatementContext stmtCtx;
const parser::OpenMPAtomicConstruct::Analysis &analysis = construct.analysis;
if (DumpAtomicAnalysis)
dumpAtomicAnalysis(analysis);
const semantics::SomeExpr &atom = *get(analysis.atom);
mlir::Location loc = converter.genLocation(construct.source);
mlir::Value atomAddr =
fir::getBase(converter.genExprAddr(atom, stmtCtx, &loc));
mlir::IntegerAttr hint = getAtomicHint(converter, clauses);
std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder =
getAtomicMemoryOrder(semaCtx, clauses,
semaCtx.FindScope(construct.source));
if (auto *cond = get(analysis.cond)) {
(void)cond;
TODO(loc, "OpenMP ATOMIC COMPARE");
} else {
int action0 = analysis.op0.what & analysis.Action;
int action1 = analysis.op1.what & analysis.Action;
mlir::Operation *captureOp = nullptr;
fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
fir::FirOpBuilder::InsertPoint atomicAt, postAt;
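    // For a capturing construct the two component operations are created
    // inside the region of an omp.atomic.capture op, e.g. (illustrative):
    //   omp.atomic.capture {
    //     omp.atomic.read ...
    //     omp.atomic.update ...
    //     omp.terminator
    //   }
    // The hint and memory-order attributes are attached to the capture op
    // itself rather than to the nested operations.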
if (construct.IsCapture()) {
// Capturing operation.
      assert(action0 != analysis.None && action1 != analysis.None &&
             "Expecting two actions");
(void)action0;
(void)action1;
captureOp = builder.create<mlir::omp::AtomicCaptureOp>(
loc, hint, makeMemOrderAttr(converter, memOrder));
// Set the non-atomic insertion point to before the atomic.capture.
preAt = getInsertionPointBefore(captureOp);
mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
builder.setInsertionPointToEnd(block);
// Set the atomic insertion point to before the terminator inside
// atomic.capture.
mlir::Operation *term = builder.create<mlir::omp::TerminatorOp>(loc);
atomicAt = getInsertionPointBefore(term);
postAt = getInsertionPointAfter(captureOp);
hint = nullptr;
memOrder = std::nullopt;
} else {
// Non-capturing operation.
      assert(action0 != analysis.None && action1 == analysis.None &&
             "Expecting a single action");
assert(!(analysis.op0.what & analysis.Condition));
postAt = atomicAt = preAt;
}
// The builder's insertion point needs to be specifically set before
// each call to `genAtomicOperation`.
mlir::Operation *firstOp = genAtomicOperation(
converter, semaCtx, loc, stmtCtx, analysis.op0.what, atomAddr, atom,
*get(analysis.op0.assign), hint, memOrder, preAt, atomicAt, postAt);
assert(firstOp && "Should have created an atomic operation");
atomicAt = getInsertionPointAfter(firstOp);
mlir::Operation *secondOp = nullptr;
if (analysis.op1.what != analysis.None) {
secondOp = genAtomicOperation(
converter, semaCtx, loc, stmtCtx, analysis.op1.what, atomAddr, atom,
*get(analysis.op1.assign), hint, memOrder, preAt, atomicAt, postAt);
}
if (construct.IsCapture()) {
      // If this is a capture operation, the first and second ops will be
      // inside it. Set the insertion point past the capture op itself.
builder.restoreInsertionPoint(postAt);
} else {
if (secondOp) {
builder.setInsertionPointAfter(secondOp);
} else {
builder.setInsertionPointAfter(firstOp);
}
}
}
}