Skip to content

Commit

Permalink
benchdnn: graph: extend --dt to support ID:DT pairs
Browse files Browse the repository at this point in the history
  • Loading branch information
TaoLv committed Jan 16, 2025
1 parent 283cf37 commit 129575f
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 12 deletions.
8 changes: 8 additions & 0 deletions tests/benchdnn/doc/driver_graph.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ where *graph-knobs* are:
field of all logical tensors in the input JSON file from `f32` to `f16`. If
`--dt` is not specified or specified as `undef`, the original data types
contained in the input JSON file will be used for testing.
- `--dt=ID:DT[+ID:DT]` -- Another format to specify the data types in the
input JSON file. `ID` specifies the input or output tensor of an operation
in the JSON file. `DT` is the target data type. To specify the data types of
multiple tensors, use `+` to concatenate the `ID` and `DT` pairs. An error
will occur if `ID` is not contained in the JSON file. oneDNN operations have
restrictions for input and output tensor data types. Changing the data type
of a tensor may lead to graph construction failures, for example, failure to
perform the `add_op()` operation in the graph.

and *graph-case* is a JSON file which is dumped by a library or created from scratch.
It must be passed to the graph driver as `--case=JSON_FILE`. Refer to the JSON
Expand Down
8 changes: 5 additions & 3 deletions tests/benchdnn/graph/bench_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,18 @@ void check_correctness(const settings_t &s) {
for_(const auto &i_expected_n_partition : s.expected_n_partition_vec)
for_(const auto &i_fpmath_mode : s.fpmath_mode_vec)
for_(const auto &i_dt : s.dt)
for_(const auto &i_dt_map : s.dt_map)
for (const auto &i_mb : s.mb) {
deserialized_graph dg;
dg.load(locate_file(s.json_file));
flex_rewrite fw(i_in_shapes, i_op_attrs, i_fpmath_mode, i_mb, i_dt);
flex_rewrite fw(
i_in_shapes, i_op_attrs, i_fpmath_mode, i_mb, i_dt, i_dt_map);
fw.rewrite(dg);
BENCHDNN_PRINT(7, "[INFO] Graph dump:\n%s\n", dg.get_string().c_str());

const prb_t prb(dg, i_expected_n_partition);
const auto &cpp_pstr = case_to_str(s.json_file, i_in_shapes, i_op_attrs,
i_fpmath_mode, i_expected_n_partition, i_mb, i_dt);
i_fpmath_mode, i_expected_n_partition, i_mb, i_dt, i_dt_map);
const char *pstr = cpp_pstr.c_str();
BENCHDNN_PRINT(1, "run: %s\n", pstr);
res_t res {};
Expand Down Expand Up @@ -76,7 +78,7 @@ int bench(int argc, char **argv) {
for (; argc > 0; --argc, ++argv) {
const bool parsed_options = parse_bench_settings(argv[0])
|| parse_batch(bench, argv[0])
|| parse_dt(s.dt, def.dt, argv[0])
|| parse_dt(s.dt, s.dt_map, argv[0])
|| parse_input_shapes(s.in_shapes_vec, argv[0])
|| parse_op_attrs(s.op_attrs_vec, argv[0])
|| parse_graph_expected_n_partitions(
Expand Down
40 changes: 40 additions & 0 deletions tests/benchdnn/graph/flex_rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ void flex_rewrite::rewrite(deserialized_graph &dgraph) {
quantized_graph_rewrite(dgraph);
graph_attrs_rewrite(dgraph);
dt_rewrite(dgraph);
dt_map_rewrite(dgraph);
}

void flex_rewrite::split_ncx(const std::string &data_format, dims_t &in,
Expand Down Expand Up @@ -1162,6 +1163,45 @@ void flex_rewrite::dt_rewrite(deserialized_graph &dgraph) {
}
}

void flex_rewrite::dt_map_rewrite(deserialized_graph &dgraph) {
// check the IDs and data types in dt_map.
for (const auto &v : dt_map_) {
if (v.second == dnnl_data_type_undef) return;

bool found_id = false;
for (auto &aop : dgraph.ops_) {
for (auto &lt : aop.in_lts_) {
if (lt.id_ == v.first) found_id = true;
}

for (auto &lt : aop.out_lts_) {
if (lt.id_ == v.first) found_id = true;
}
}

if (!found_id) {
BENCHDNN_PRINT(0,
"graph: rewrite: ID `%zd` is not found in the graph\n",
v.first);
SAFE_V(FAIL);
}
}

// rewrite
for (const auto &v : dt_map_) {
const std::string str_dt(dt2str(v.second));
for (auto &aop : dgraph.ops_) {
for (auto &lt : aop.in_lts_) {
if (lt.id_ == v.first) lt.data_type_ = str_dt;
}

for (auto &lt : aop.out_lts_) {
if (lt.id_ == v.first) lt.data_type_ = str_dt;
}
}
}
}

void flex_rewrite::graph_attrs_rewrite(deserialized_graph &dgraph) {

// if the fpmath mode is specified by users through cml, replace the fpmath
Expand Down
14 changes: 9 additions & 5 deletions tests/benchdnn/graph/flex_rewrite.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022-2024 Intel Corporation
* Copyright 2022-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,13 +29,14 @@ struct flex_rewrite {
flex_rewrite(const std::map<size_t, std::string> &in_shapes,
const std::map<size_t, std::string> &op_attrs,
const graph_fpmath_mode_t &fpmath_mode, const int64_t mb,
const dnnl_data_type_t dt)

const dnnl_data_type_t dt,
const std::map<size_t, dnnl_data_type_t> &dt_map)
: in_shapes_(in_shapes)
, op_attrs_(op_attrs)
, fpmath_mode_(fpmath_mode)
, mb_(mb)
, dt_(dt) {}
, dt_(dt)
, dt_map_(dt_map) {}

void rewrite(deserialized_graph &dgraph);

Expand All @@ -46,7 +47,9 @@ struct flex_rewrite {
std::map<size_t, std::string> op_attrs_;
graph_fpmath_mode_t fpmath_mode_;
int64_t mb_;
dnnl_data_type_t dt_;
dnnl_data_type_t dt_; // Updates whole graph with a single dt value.
std::map<size_t, dnnl_data_type_t>
dt_map_; // Updates specific LT with selected dt values.

void split_ncx(const std::string &data_format, dims_t &in, int64_t &n,
int64_t &c, dims_t &x) const;
Expand All @@ -70,6 +73,7 @@ struct flex_rewrite {
bool change_stride);
void graph_attrs_rewrite(deserialized_graph &dgraph);
void dt_rewrite(deserialized_graph &dgraph);
void dt_map_rewrite(deserialized_graph &dgraph);
};

} // namespace graph
Expand Down
14 changes: 13 additions & 1 deletion tests/benchdnn/graph/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,26 @@ std::string case_to_str(const std::string &json_file,
const std::map<size_t, std::string> &op_attrs,
const graph_fpmath_mode_t &fpmath_mode,
const size_t expected_n_partitions, const int64_t mb,
const dnnl_data_type_t dt) {
const dnnl_data_type_t dt,
const std::map<size_t, dnnl_data_type_t> &dt_map) {
std::stringstream s;
dump_global_params(s);

if (mb != 0) { s << "--mb=" << mb << " "; }

if (dt != dnnl_data_type_undef) { s << "--dt=" << dt << " "; }

const bool skip_dts = dt_map.empty()
|| (dt_map.size() == 1 && dt_map.count(SIZE_MAX) == 1);
if (!skip_dts) {
s << "--dt=";
std::string tmp;
for (const auto &v : dt_map) {
tmp += (std::to_string(v.first) + ":" + dt2str(v.second) + "+");
}
s << tmp.substr(0, tmp.length() - 1) << " ";
}

if (!(in_shapes.size() == 1 && in_shapes.count(0)
&& in_shapes.at(0) == "default")) {
s << "--in-shapes=";
Expand Down
5 changes: 4 additions & 1 deletion tests/benchdnn/graph/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ struct settings_t : public base_settings_t {
std::vector<size_t> expected_n_partition_vec {1};
std::vector<graph_fpmath_mode_t> fpmath_mode_vec {graph_fpmath_mode_t {}};
std::vector<dnnl_data_type_t> dt {dnnl_data_type_undef};
std::vector<std::map<size_t, dnnl_data_type_t>> dt_map {
{{SIZE_MAX, dnnl_data_type_undef}}};

const char *perf_template_csv
= "perf,%engine%,%DESC%,"
Expand Down Expand Up @@ -83,7 +85,8 @@ std::string case_to_str(const std::string &json_file,
const std::map<size_t, std::string> &op_attrs,
const graph_fpmath_mode_t &fpmath_mode,
const size_t expected_n_partitions, const int64_t mb,
const dnnl_data_type_t dt);
const dnnl_data_type_t dt,
const std::map<size_t, dnnl_data_type_t> &dt_map);

struct perf_report_t : public base_perf_report_t {
perf_report_t(const std::string case_str, const char *perf_template)
Expand Down
46 changes: 45 additions & 1 deletion tests/benchdnn/graph/parser.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022-2024 Intel Corporation
* Copyright 2022-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -117,6 +117,50 @@ bool parse_op_attrs(std::vector<std::map<size_t, std::string>> &op_attrs_vec,
return parse_key_value(op_attrs_vec, op_attrs_str), true;
}

bool parse_dt(std::vector<dnnl_data_type_t> &dt,
std::vector<std::map<size_t, dnnl_data_type_t>> &dt_map,
const char *str, const std::string &option_name) {
std::string dts_str;
if (!parse_string(dts_str, str, option_name)) return false;

if (dts_str.find(":") == std::string::npos) {
// `dt` object: format like --dt=f32,bf16,f16
const bool has_dt_map
= dt_map.size() != 1 || (dt_map[0].count(SIZE_MAX) == 0);
if (has_dt_map) {
BENCHDNN_PRINT(0, "%s\n",
"Error: --dt is specified twice with different styles.");
SAFE_V(FAIL);
}

const std::vector<dnnl_data_type_t> def_dt = {dnnl_data_type_undef};
parser::parse_dt(dt, def_dt, str, option_name);
} else {
// `dt_map` object: format like --dt=0:f32+1:f32,0:f16+1:f16
const bool has_dt = dt.size() != 1 || dt[0] != dnnl_data_type_undef;
if (has_dt) {
BENCHDNN_PRINT(0, "%s\n",
"Error: --dt is specified twice with different styles.");
SAFE_V(FAIL);
}

std::vector<std::map<size_t, std::string>> dts_tmp;
parse_key_value(dts_tmp, dts_str, option_name);
dt_map.clear();
dt_map.resize(dts_tmp.size());
// convert size_t:string to size_t:dnnl_data_type_t
for (size_t i = 0; i < dts_tmp.size(); i++) {
std::map<size_t, dnnl_data_type_t> tmp;
for (const auto &v : dts_tmp[i]) {
tmp[v.first] = str2dt(v.second.c_str());
}
dt_map[i] = std::move(tmp);
}
}

return true;
}

bool parse_graph_expected_n_partitions(
std::vector<size_t> &expected_n_partition_vec, const char *str) {
std::string expected_n_partitions_str;
Expand Down
6 changes: 5 additions & 1 deletion tests/benchdnn/graph/parser.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022-2024 Intel Corporation
* Copyright 2022-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,6 +44,10 @@ bool parse_graph_fpmath_mode(

bool parse_input_file(std::string &json_file, const char *str);

bool parse_dt(std::vector<dnnl_data_type_t> &dt,
std::vector<std::map<size_t, dnnl_data_type_t>> &dt_map,
const char *str, const std::string &option_name = "dt");

std::map<std::string, std::string> parse_attrs(const std::string &attrs_str);

// Convert f32 vec attrs string into f32 vec
Expand Down

0 comments on commit 129575f

Please sign in to comment.