Skip to content

Commit

Permalink
refine functions
Browse files Browse the repository at this point in the history
  • Loading branch information
linzs148 committed Apr 30, 2024
1 parent a8fd7a3 commit bea5080
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 218 deletions.
288 changes: 144 additions & 144 deletions oneflow/core/device/cudnn_conv_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,90 @@ CudnnConvArgs::CudnnConvArgs(const user_op::KernelComputeContext& ctx, DataType
params.max_ws_size = max_workspace_size;
}

// Wraps a user_op tensor's shape, strides and dtype into a cuDNN frontend
// tensor descriptor tagged with uid `id` (later used to bind data pointers).
cudnn_frontend::Tensor GetTensorDescriptor(const user_op::Tensor* t, const int64_t id) {
  const auto shape = t->shape_view();
  const auto strides = t->stride();
  return cudnn_frontend::TensorBuilder()
      .setDataType(GetCudnnDataType(t->data_type()))
      .setDim(shape.size(), shape.data())
      .setStride(strides.size(), strides.data())
      .setId(id)
      // NOTE(review): 32-byte buffer alignment is assumed here for all
      // tensors — confirm this matches the allocator's guarantee.
      .setAlignment(32)
      .build();
}

// Overload for compile-time TensorDesc (no device buffer yet): builds the
// same cuDNN frontend tensor descriptor from the desc's shape and stride.
cudnn_frontend::Tensor GetTensorDescriptor(const user_op::TensorDesc& t, const int64_t id) {
  const auto shape = t.shape();
  const auto strides = t.stride();
  return cudnn_frontend::TensorBuilder()
      .setDataType(GetCudnnDataType(t.data_type()))
      .setDim(shape.size(), shape.data())
      .setStride(strides.size(), strides.data())
      .setId(id)
      // NOTE(review): 32-byte buffer alignment is assumed here for all
      // tensors — confirm this matches the allocator's guarantee.
      .setAlignment(32)
      .build();
}

// Builds a cuDNN frontend convolution descriptor from the op attributes
// ("padding_before", "strides", "dilation_rate") carried by `ctx`.
// Half/bfloat16 convolutions use a float compute type.
cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::InferContext& ctx,
                                           cudnnDataType_t data_type) {
  // Widen reduced-precision input types to float for accumulation.
  if (data_type == CUDNN_DATA_HALF || data_type == CUDNN_DATA_BFLOAT16) {
    data_type = CUDNN_DATA_FLOAT;
  }

  // The frontend API takes int64_t arrays while the op attributes are
  // int32_t; construct the widened copies directly from the iterator ranges
  // instead of the previous unqualified `copy` + `back_inserter` (which
  // relied on ADL to find std::copy).
  const auto& padding_before = ctx.Attr<std::vector<int32_t>>("padding_before");
  const std::vector<int64_t> padding(padding_before.begin(), padding_before.end());

  const auto& strides = ctx.Attr<std::vector<int32_t>>("strides");
  const std::vector<int64_t> stride(strides.begin(), strides.end());

  const auto& dilation_rate = ctx.Attr<std::vector<int32_t>>("dilation_rate");
  const std::vector<int64_t> dilation(dilation_rate.begin(), dilation_rate.end());

  const uint64_t ndim = stride.size();
  // NOTE(review): "padding_before" is used for both pre- and post-padding,
  // i.e. symmetric padding is assumed — confirm against the op definition.
  return cudnn_frontend::ConvDescBuilder()
      .setDataType(data_type)
      .setMathMode(CUDNN_CROSS_CORRELATION)
      .setNDims(ndim)
      .setStrides(ndim, stride.data())
      .setPrePadding(ndim, padding.data())
      .setPostPadding(ndim, padding.data())
      .setDilation(ndim, dilation.data())
      .build();
}

// Overload of GetConvDescriptor for kernel-compute time: identical attribute
// handling, reading "padding_before" / "strides" / "dilation_rate" from the
// KernelComputeContext. Half/bfloat16 convolutions use a float compute type.
cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::KernelComputeContext& ctx,
                                           cudnnDataType_t data_type) {
  // Widen reduced-precision input types to float for accumulation.
  if (data_type == CUDNN_DATA_HALF || data_type == CUDNN_DATA_BFLOAT16) {
    data_type = CUDNN_DATA_FLOAT;
  }

  // The frontend API takes int64_t arrays while the op attributes are
  // int32_t; construct the widened copies directly from the iterator ranges
  // instead of the previous unqualified `copy` + `back_inserter` (which
  // relied on ADL to find std::copy).
  const auto& padding_before = ctx.Attr<std::vector<int32_t>>("padding_before");
  const std::vector<int64_t> padding(padding_before.begin(), padding_before.end());

  const auto& strides = ctx.Attr<std::vector<int32_t>>("strides");
  const std::vector<int64_t> stride(strides.begin(), strides.end());

  const auto& dilation_rate = ctx.Attr<std::vector<int32_t>>("dilation_rate");
  const std::vector<int64_t> dilation(dilation_rate.begin(), dilation_rate.end());

  const uint64_t ndim = stride.size();
  // NOTE(review): "padding_before" is used for both pre- and post-padding,
  // i.e. symmetric padding is assumed — confirm against the op definition.
  return cudnn_frontend::ConvDescBuilder()
      .setDataType(data_type)
      .setMathMode(CUDNN_CROSS_CORRELATION)
      .setNDims(ndim)
      .setStrides(ndim, stride.data())
      .setPrePadding(ndim, padding.data())
      .setPostPadding(ndim, padding.data())
      .setDilation(ndim, dilation.data())
      .build();
}

CudnnConvArgsV8::CudnnConvArgsV8(const user_op::InferContext& ctx, const user_op::TensorDesc& x,
const user_op::TensorDesc& y, const user_op::TensorDesc& w)
: xdesc(GetTensorDescriptor(x, 'x')),
Expand Down Expand Up @@ -443,30 +527,6 @@ cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvReso
args.wdesc.Get(), algo, sz);
}

void RunSingleConv(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc,
user_op::Tensor* x, user_op::Tensor* y, user_op::Tensor* w, user_op::Tensor* b,
const CudnnConvArgsV8& args) {
std::string tag;
auto configs =
GetConfigs(handle, desc, args.xdesc, args.ydesc, args.wdesc, args.cdesc, args.beta, tag);
TryConfigs(handle, x, y, w, b, configs, tag);
}

// Builds the conv operation graph and asks the cuDNN frontend generator for
// candidate engine configs. The graph tag is reported back through `tag` so
// callers can key caches / errata lookups on it.
cudnn_frontend::EngineConfigList GetConfigs(const cudnnHandle_t handle,
                                            const cudnnBackendDescriptorType_t desc,
                                            const cudnn_frontend::Tensor& xdesc,
                                            const cudnn_frontend::Tensor& ydesc,
                                            const cudnn_frontend::Tensor& wdesc,
                                            const cudnn_frontend::ConvDesc& cdesc, float beta,
                                            std::string& tag) {
  auto graph = BuildConvOpGraph(handle, desc, xdesc, ydesc, wdesc, cdesc, beta);
  tag = graph.getTag();
  const auto gen_sources = GetGeneratorSources(desc);
  cudnn_frontend::EngineConfigGenerator generator(gen_sources.size(), gen_sources.data());
  return generator.generate_engine_config(graph);
}

cudnn_frontend::OperationGraph BuildConvOpGraph(const cudnnHandle_t handle,
const cudnnBackendDescriptorType_t desc,
const cudnn_frontend::Tensor& xdesc,
Expand All @@ -488,88 +548,20 @@ cudnn_frontend::OperationGraph BuildConvOpGraph(const cudnnHandle_t handle,
return op_graph;
}

// Wraps a user_op tensor's shape, strides and dtype into a cuDNN frontend
// tensor descriptor tagged with uid `id` (later used to bind data pointers).
cudnn_frontend::Tensor GetTensorDescriptor(const user_op::Tensor* t, const int64_t id) {
  const auto shape = t->shape_view();
  const auto strides = t->stride();
  return cudnn_frontend::TensorBuilder()
      .setDataType(GetCudnnDataType(t->data_type()))
      .setDim(shape.size(), shape.data())
      .setStride(strides.size(), strides.data())
      .setId(id)
      // NOTE(review): 32-byte buffer alignment is assumed here for all
      // tensors — confirm this matches the allocator's guarantee.
      .setAlignment(32)
      .build();
}

// Overload for compile-time TensorDesc (no device buffer yet): builds the
// same cuDNN frontend tensor descriptor from the desc's shape and stride.
cudnn_frontend::Tensor GetTensorDescriptor(const user_op::TensorDesc& t, const int64_t id) {
  const auto shape = t.shape();
  const auto strides = t.stride();
  return cudnn_frontend::TensorBuilder()
      .setDataType(GetCudnnDataType(t.data_type()))
      .setDim(shape.size(), shape.data())
      .setStride(strides.size(), strides.data())
      .setId(id)
      // NOTE(review): 32-byte buffer alignment is assumed here for all
      // tensors — confirm this matches the allocator's guarantee.
      .setAlignment(32)
      .build();
}

// Builds a cuDNN frontend convolution descriptor from the op attributes
// ("padding_before", "strides", "dilation_rate") carried by `ctx`.
// Half/bfloat16 convolutions use a float compute type.
cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::InferContext& ctx,
                                           cudnnDataType_t data_type) {
  // Widen reduced-precision input types to float for accumulation.
  if (data_type == CUDNN_DATA_HALF || data_type == CUDNN_DATA_BFLOAT16) {
    data_type = CUDNN_DATA_FLOAT;
  }

  // The frontend API takes int64_t arrays while the op attributes are
  // int32_t; construct the widened copies directly from the iterator ranges
  // instead of the previous unqualified `copy` + `back_inserter` (which
  // relied on ADL to find std::copy).
  const auto& padding_before = ctx.Attr<std::vector<int32_t>>("padding_before");
  const std::vector<int64_t> padding(padding_before.begin(), padding_before.end());

  const auto& strides = ctx.Attr<std::vector<int32_t>>("strides");
  const std::vector<int64_t> stride(strides.begin(), strides.end());

  const auto& dilation_rate = ctx.Attr<std::vector<int32_t>>("dilation_rate");
  const std::vector<int64_t> dilation(dilation_rate.begin(), dilation_rate.end());

  const uint64_t ndim = stride.size();
  // NOTE(review): "padding_before" is used for both pre- and post-padding,
  // i.e. symmetric padding is assumed — confirm against the op definition.
  return cudnn_frontend::ConvDescBuilder()
      .setDataType(data_type)
      .setMathMode(CUDNN_CROSS_CORRELATION)
      .setNDims(ndim)
      .setStrides(ndim, stride.data())
      .setPrePadding(ndim, padding.data())
      .setPostPadding(ndim, padding.data())
      .setDilation(ndim, dilation.data())
      .build();
}

// Overload of GetConvDescriptor for kernel-compute time: identical attribute
// handling, reading "padding_before" / "strides" / "dilation_rate" from the
// KernelComputeContext. Half/bfloat16 convolutions use a float compute type.
// FIX: the closing brace of this function was missing in the source text
// (the definition ran straight into the next function); it is restored here.
cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::KernelComputeContext& ctx,
                                           cudnnDataType_t data_type) {
  // Widen reduced-precision input types to float for accumulation.
  if (data_type == CUDNN_DATA_HALF || data_type == CUDNN_DATA_BFLOAT16) {
    data_type = CUDNN_DATA_FLOAT;
  }

  // The frontend API takes int64_t arrays while the op attributes are
  // int32_t; construct the widened copies directly from the iterator ranges.
  const auto& padding_before = ctx.Attr<std::vector<int32_t>>("padding_before");
  const std::vector<int64_t> padding(padding_before.begin(), padding_before.end());

  const auto& strides = ctx.Attr<std::vector<int32_t>>("strides");
  const std::vector<int64_t> stride(strides.begin(), strides.end());

  const auto& dilation_rate = ctx.Attr<std::vector<int32_t>>("dilation_rate");
  const std::vector<int64_t> dilation(dilation_rate.begin(), dilation_rate.end());

  const uint64_t ndim = stride.size();
  // NOTE(review): "padding_before" is used for both pre- and post-padding,
  // i.e. symmetric padding is assumed — confirm against the op definition.
  return cudnn_frontend::ConvDescBuilder()
      .setDataType(data_type)
      .setMathMode(CUDNN_CROSS_CORRELATION)
      .setNDims(ndim)
      .setStrides(ndim, stride.data())
      .setPrePadding(ndim, padding.data())
      .setPostPadding(ndim, padding.data())
      .setDilation(ndim, dilation.data())
      .build();
}
// Copies engine configs from `from` into `to`, rejecting configs we never
// want to run: any config that down-converts its inputs, and — when
// `deterministic` is requested — any config flagged as non-deterministic.
void FilterEngineConfigs(cudnn_frontend::EngineConfigList& from,
                         cudnn_frontend::EngineConfigList& to, bool deterministic) {
  auto reject = [deterministic](cudnnBackendDescriptor_t c) {
    // Short-circuits exactly like the branchy form: the down-convert note is
    // only queried when the determinism check did not already reject.
    return (deterministic
            && cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(c))
           || cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c);
  };
  cudnn_frontend::filter(from, to, reject);
}

std::vector<cudnn_frontend::GeneratorSource> GetGeneratorSources(
Expand All @@ -579,7 +571,7 @@ std::vector<cudnn_frontend::GeneratorSource> GetGeneratorSources(
.cudnn_conf()
.cudnn_conv_use_deterministic_algo_only();
bool heuristic = ParseBooleanFromEnv("ONEFLOW_CUDNN_USE_HEURISTIC_MODE_B", false);
auto heur_mode = heuristic ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_A;
auto heur_mode = heuristic ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT;
// Method for engine config generator based on heuristics
const auto heurgen_method =
[deterministic,
Expand Down Expand Up @@ -610,20 +602,43 @@ std::vector<cudnn_frontend::GeneratorSource> GetGeneratorSources(
return sources;
}

void FilterEngineConfigs(cudnn_frontend::EngineConfigList& from,
cudnn_frontend::EngineConfigList& to, bool deterministic) {
auto filter = [=](cudnnBackendDescriptor_t c) {
if (deterministic) {
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(c)) {
return true;
}
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
// Builds the conv operation graph and asks the cuDNN frontend generator for
// candidate engine configs. The graph tag is reported back through `tag` so
// callers can key caches / errata lookups on it.
cudnn_frontend::EngineConfigList CudnnFrontendGetConfigs(const cudnnHandle_t handle,
                                                         const cudnnBackendDescriptorType_t desc,
                                                         const cudnn_frontend::Tensor& xdesc,
                                                         const cudnn_frontend::Tensor& ydesc,
                                                         const cudnn_frontend::Tensor& wdesc,
                                                         const cudnn_frontend::ConvDesc& cdesc,
                                                         float beta, std::string& tag) {
  auto graph = BuildConvOpGraph(handle, desc, xdesc, ydesc, wdesc, cdesc, beta);
  tag = graph.getTag();
  const auto gen_sources = GetGeneratorSources(desc);
  cudnn_frontend::EngineConfigGenerator generator(gen_sources.size(), gen_sources.data());
  return generator.generate_engine_config(graph);
}

// Returns true when `executionPlanTag` is listed in the (optionally loaded)
// errata json config, meaning the corresponding execution plan must be
// skipped. Returns false when no errata config could be loaded.
// FIX: the source text of this function was corrupted by diff interleaving
// (a stray `};`, a leftover `cudnn_frontend::filter(from, to, filter);` line
// belonging to FilterEngineConfigs, and a misplaced `} else {`); this body
// restores the coherent implementation.
bool PlanErrataException(const cudnnHandle_t handle, const std::string& executionPlanTag) {
  static nlohmann::json errata_json_handle;
  // NOTE(review): the empty-string path presumably selects the frontend's
  // default errata source — confirm against cudnn_frontend documentation.
  static bool has_json = cudnn_frontend::load_from_config(errata_json_handle, "");
  if (!has_json) {
    return false;
  } else {
    return cudnn_frontend::check_errata(errata_json_handle, executionPlanTag, handle,
                                        []() { return true; });
  }
}

// Executes a prepared cuDNN execution plan: binds the x/y/w device pointers
// to the uids 'x'/'y'/'w' (as assigned in GetTensorDescriptor / the
// CudnnConvArgsV8 constructor) and uses `buf` as the workspace buffer.
void RunConvPlan(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* y,
                 user_op::Tensor* w, user_op::Tensor* buf,
                 const cudnn_frontend::ExecutionPlan& plan) {
  constexpr int kNumTensors = 3;
  void* data_ptrs[kNumTensors] = {x->mut_dptr(), y->mut_dptr(), w->mut_dptr()};
  int64_t uids[kNumTensors] = {'x', 'y', 'w'};
  auto pack = cudnn_frontend::VariantPackBuilder()
                  .setWorkspacePointer(buf->mut_dptr())
                  .setDataPointers(kNumTensors, data_ptrs)
                  .setUids(kNumTensors, uids)
                  .build();
  OF_CUDNN_CHECK(cudnnBackendExecute(handle, plan.get_raw_desc(), pack.get_raw_desc()));
}

void TryConfigs(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* y,
Expand All @@ -642,6 +657,15 @@ void TryConfigs(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor*
}
}

// Generates candidate engine configs for the conv op type `desc` and runs
// the first config that yields a working execution plan (via TryConfigs).
void CudnnFrontendRunConv(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc,
                          user_op::Tensor* x, user_op::Tensor* y, user_op::Tensor* w,
                          user_op::Tensor* b, const CudnnConvArgsV8& args) {
  std::string op_tag;
  auto candidates = CudnnFrontendGetConfigs(handle, desc, args.xdesc, args.ydesc, args.wdesc,
                                            args.cdesc, args.beta, op_tag);
  TryConfigs(handle, x, y, w, b, candidates, op_tag);
}

size_t GetCudnnConvWorkspaceSizeV8(const cudnnHandle_t handle,
cudnn_frontend::EngineConfigList& configs,
const std::string& tag) {
Expand All @@ -658,30 +682,6 @@ size_t GetCudnnConvWorkspaceSizeV8(const cudnnHandle_t handle,
return 1L;
}

// Returns true when `executionPlanTag` is listed in the (optionally loaded)
// errata json config, meaning the corresponding execution plan must be
// skipped. Returns false when no errata config could be loaded.
bool PlanErrataException(const cudnnHandle_t handle, const std::string& executionPlanTag) {
  static nlohmann::json errata_json_handle;
  // NOTE(review): the empty-string path presumably selects the frontend's
  // default errata source — confirm against cudnn_frontend documentation.
  static bool has_json = cudnn_frontend::load_from_config(errata_json_handle, "");
  if (has_json) {
    return cudnn_frontend::check_errata(errata_json_handle, executionPlanTag, handle,
                                        []() { return true; });
  }
  return false;
}

// Executes a prepared cuDNN execution plan: binds the x/y/w device pointers
// to the uids 'x'/'y'/'w' (as assigned in GetTensorDescriptor / the
// CudnnConvArgsV8 constructor) and uses `buf` as the workspace buffer.
void RunConvPlan(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* y,
                 user_op::Tensor* w, user_op::Tensor* buf,
                 const cudnn_frontend::ExecutionPlan& plan) {
  constexpr int kNumTensors = 3;
  void* data_ptrs[kNumTensors] = {x->mut_dptr(), y->mut_dptr(), w->mut_dptr()};
  int64_t uids[kNumTensors] = {'x', 'y', 'w'};
  auto pack = cudnn_frontend::VariantPackBuilder()
                  .setWorkspacePointer(buf->mut_dptr())
                  .setDataPointers(kNumTensors, data_ptrs)
                  .setUids(kNumTensors, uids)
                  .build();
  OF_CUDNN_CHECK(cudnnBackendExecute(handle, plan.get_raw_desc(), pack.get_raw_desc()));
}

template<>
struct CudnnConvAlgorithmSearch<cudnnConvolutionFwdAlgoPerf_t> {
using perf_t = cudnnConvolutionFwdAlgoPerf_t;
Expand Down
Loading

0 comments on commit bea5080

Please sign in to comment.