Skip to content

Commit

Permalink
Revert "[OpenMP][libc] Remove special handling for OpenMP printf (llv…
Browse files Browse the repository at this point in the history
…m#98940)"

This reverts commit 069e8bc.

Summary:
Some tests failing, revert this for now.
  • Loading branch information
jhuber6 committed Jul 26, 2024
1 parent 1978c21 commit fea5914
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 16 deletions.
2 changes: 2 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5986,6 +5986,8 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
getTarget().getTriple().isAMDGCN() ||
(getTarget().getTriple().isSPIRV() &&
getTarget().getTriple().getVendor() == Triple::VendorType::AMD)) {
if (getLangOpts().OpenMPIsTargetDevice)
return EmitOpenMPDevicePrintfCallExpr(E);
if (getTarget().getTriple().isNVPTX())
return EmitNVPTXDevicePrintfCallExpr(E);
if ((getTarget().getTriple().isAMDGCN() ||
Expand Down
29 changes: 29 additions & 0 deletions clang/lib/CodeGen/CGGPUBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ llvm::Function *GetVprintfDeclaration(llvm::Module &M) {
VprintfFuncType, llvm::GlobalVariable::ExternalLinkage, "vprintf", &M);
}

llvm::Function *GetOpenMPVprintfDeclaration(CodeGenModule &CGM) {
const char *Name = "__llvm_omp_vprintf";
llvm::Module &M = CGM.getModule();
llvm::Type *ArgTypes[] = {llvm::PointerType::getUnqual(M.getContext()),
llvm::PointerType::getUnqual(M.getContext()),
llvm::Type::getInt32Ty(M.getContext())};
llvm::FunctionType *VprintfFuncType = llvm::FunctionType::get(
llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false);

if (auto *F = M.getFunction(Name)) {
if (F->getFunctionType() != VprintfFuncType) {
CGM.Error(SourceLocation(),
"Invalid type declaration for __llvm_omp_vprintf");
return nullptr;
}
return F;
}

return llvm::Function::Create(
VprintfFuncType, llvm::GlobalVariable::ExternalLinkage, Name, &M);
}

// Transforms a call to printf into a call to the NVPTX vprintf syscall (which
// isn't particularly special; it's invoked just like a regular function).
// vprintf takes two args: A format string, and a pointer to a buffer containing
Expand Down Expand Up @@ -191,3 +213,10 @@ RValue CodeGenFunction::EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E) {
Builder.SetInsertPoint(IRB.GetInsertBlock(), IRB.GetInsertPoint());
return RValue::get(Printf);
}

RValue CodeGenFunction::EmitOpenMPDevicePrintfCallExpr(const CallExpr *E) {
assert(getTarget().getTriple().isNVPTX() ||
getTarget().getTriple().isAMDGCN());
return EmitDevicePrintfCallExpr(E, this, GetOpenMPVprintfDeclaration(CGM),
true);
}
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4536,6 +4536,7 @@ class CodeGenFunction : public CodeGenTypeCache {

RValue EmitNVPTXDevicePrintfCallExpr(const CallExpr *E);
RValue EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E);
RValue EmitOpenMPDevicePrintfCallExpr(const CallExpr *E);

RValue EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
const CallExpr *E, ReturnValueSlot ReturnValue);
Expand Down
1 change: 1 addition & 0 deletions libc/config/gpu/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ set(TARGET_LIBC_ENTRYPOINTS

# gpu/rpc.h entrypoints
libc.src.gpu.rpc_host_call
libc.src.gpu.rpc_fprintf
)

set(TARGET_LIBM_ENTRYPOINTS
Expand Down
8 changes: 8 additions & 0 deletions libc/spec/gpu_ext.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ def GPUExtensions : StandardSpec<"GPUExtensions"> {
RetValSpec<VoidType>,
[ArgSpec<VoidPtr>, ArgSpec<VoidPtr>, ArgSpec<SizeTType>]
>,
FunctionSpec<
"rpc_fprintf",
RetValSpec<IntType>,
[ArgSpec<FILERestrictedPtr>,
ArgSpec<ConstCharRestrictedPtr>,
ArgSpec<VoidPtr>,
ArgSpec<SizeTType>]
>,
]
>;
let Headers = [
Expand Down
12 changes: 12 additions & 0 deletions libc/src/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,15 @@ add_entrypoint_object(
libc.src.__support.RPC.rpc_client
libc.src.__support.GPU.utils
)

add_entrypoint_object(
rpc_fprintf
SRCS
rpc_fprintf.cpp
HDRS
rpc_fprintf.h
DEPENDS
libc.src.stdio.gpu.gpu_file
libc.src.__support.RPC.rpc_client
libc.src.__support.GPU.utils
)
75 changes: 75 additions & 0 deletions libc/src/gpu/rpc_fprintf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//===-- GPU implementation of fprintf -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "rpc_fprintf.h"

#include "src/__support/CPP/string_view.h"
#include "src/__support/GPU/utils.h"
#include "src/__support/RPC/rpc_client.h"
#include "src/__support/common.h"
#include "src/__support/macros/config.h"
#include "src/stdio/gpu/file.h"

namespace LIBC_NAMESPACE_DECL {

template <uint16_t opcode>
int fprintf_impl(::FILE *__restrict file, const char *__restrict format,
size_t format_size, void *args, size_t args_size) {
uint64_t mask = gpu::get_lane_mask();
rpc::Client::Port port = rpc::client.open<opcode>();

if constexpr (opcode == RPC_PRINTF_TO_STREAM) {
port.send([&](rpc::Buffer *buffer) {
buffer->data[0] = reinterpret_cast<uintptr_t>(file);
});
}

port.send_n(format, format_size);
port.recv([&](rpc::Buffer *buffer) {
args_size = static_cast<size_t>(buffer->data[0]);
});
port.send_n(args, args_size);

uint32_t ret = 0;
for (;;) {
const char *str = nullptr;
port.recv([&](rpc::Buffer *buffer) {
ret = static_cast<uint32_t>(buffer->data[0]);
str = reinterpret_cast<const char *>(buffer->data[1]);
});
// If any lanes have a string argument it needs to be copied back.
if (!gpu::ballot(mask, str))
break;

uint64_t size = str ? internal::string_length(str) + 1 : 0;
port.send_n(str, size);
}

port.close();
return ret;
}

// TODO: Delete this and port OpenMP to use `printf`.
// place of varargs. Once varargs support is added we will use that to
// implement the real version.
LLVM_LIBC_FUNCTION(int, rpc_fprintf,
(::FILE *__restrict stream, const char *__restrict format,
void *args, size_t size)) {
cpp::string_view str(format);
if (stream == stdout)
return fprintf_impl<RPC_PRINTF_TO_STDOUT>(stream, format, str.size() + 1,
args, size);
else if (stream == stderr)
return fprintf_impl<RPC_PRINTF_TO_STDERR>(stream, format, str.size() + 1,
args, size);
else
return fprintf_impl<RPC_PRINTF_TO_STREAM>(stream, format, str.size() + 1,
args, size);
}

} // namespace LIBC_NAMESPACE_DECL
23 changes: 23 additions & 0 deletions libc/src/gpu/rpc_fprintf.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===-- Implementation header for RPC functions -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_GPU_RPC_HOST_CALL_H
#define LLVM_LIBC_SRC_GPU_RPC_HOST_CALL_H

#include "hdr/types/FILE.h"
#include "src/__support/macros/config.h"
#include <stddef.h>

namespace LIBC_NAMESPACE_DECL {

int rpc_fprintf(::FILE *__restrict stream, const char *__restrict format,
void *argc, size_t size);

} // namespace LIBC_NAMESPACE_DECL

#endif // LLVM_LIBC_SRC_GPU_RPC_HOST_CALL_H
3 changes: 1 addition & 2 deletions llvm/lib/Target/AMDGPU/AMDGPUPrintfRuntimeBinding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,7 @@ bool AMDGPUPrintfRuntimeBindingImpl::run(Module &M) {
return false;

auto PrintfFunction = M.getFunction("printf");
if (!PrintfFunction || !PrintfFunction->isDeclaration() ||
M.getModuleFlag("openmp"))
if (!PrintfFunction || !PrintfFunction->isDeclaration())
return false;

for (auto &U : PrintfFunction->uses()) {
Expand Down
1 change: 1 addition & 0 deletions offload/DeviceRTL/include/LibC.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ extern "C" {

int memcmp(const void *lhs, const void *rhs, size_t count);
void memset(void *dst, int C, size_t count);

int printf(const char *format, ...);
}

Expand Down
44 changes: 30 additions & 14 deletions offload/DeviceRTL/src/LibC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,44 @@
#pragma omp begin declare target device_type(nohost)

namespace impl {
int32_t omp_vprintf(const char *Format, __builtin_va_list vlist);
int32_t omp_vprintf(const char *Format, void *Arguments, uint32_t);
}

#ifndef OMPTARGET_HAS_LIBC
namespace impl {
#pragma omp begin declare variant match( \
device = {arch(nvptx, nvptx64)}, \
implementation = {extension(match_any)})
extern "C" int vprintf(const char *format, ...);
int omp_vprintf(const char *Format, __builtin_va_list vlist) {
return vprintf(Format, vlist);
extern "C" int32_t vprintf(const char *, void *);
namespace impl {
int32_t omp_vprintf(const char *Format, void *Arguments, uint32_t) {
return vprintf(Format, Arguments);
}
} // namespace impl
#pragma omp end declare variant

#pragma omp begin declare variant match(device = {arch(amdgcn)})
int omp_vprintf(const char *Format, __builtin_va_list) { return -1; }
#pragma omp end declare variant
} // namespace impl

extern "C" int printf(const char *Format, ...) {
__builtin_va_list vlist;
__builtin_va_start(vlist, Format);
return impl::omp_vprintf(Format, vlist);
#ifdef OMPTARGET_HAS_LIBC
// TODO: Remove this handling once we have varargs support.
extern "C" struct FILE *stdout;
extern "C" int32_t rpc_fprintf(FILE *, const char *, void *, uint64_t);

namespace impl {
int32_t omp_vprintf(const char *Format, void *Arguments, uint32_t Size) {
return rpc_fprintf(stdout, Format, Arguments, Size);
}
#endif // OMPTARGET_HAS_LIBC
} // namespace impl
#else
// We do not have a vprintf implementation for AMD GPU so we use a stub.
namespace impl {
int32_t omp_vprintf(const char *Format, void *Arguments, uint32_t) {
return -1;
}
} // namespace impl
#endif
#pragma omp end declare variant

extern "C" {

[[gnu::weak]] int memcmp(const void *lhs, const void *rhs, size_t count) {
auto *L = reinterpret_cast<const unsigned char *>(lhs);
auto *R = reinterpret_cast<const unsigned char *>(rhs);
Expand All @@ -54,6 +65,11 @@ extern "C" {
for (size_t I = 0; I < count; ++I)
dstc[I] = C;
}

/// printf() calls are rewritten by CGGPUBuiltin to __llvm_omp_vprintf
int32_t __llvm_omp_vprintf(const char *Format, void *Arguments, uint32_t Size) {
return impl::omp_vprintf(Format, Arguments, Size);
}
}

#pragma omp end declare target

0 comments on commit fea5914

Please sign in to comment.