Skip to content

Commit

Permalink
benchdnn: graph: enable displacer for sdpa mask
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Jan 15, 2025
1 parent 17bec50 commit 3389eb0
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 5 deletions.
103 changes: 102 additions & 1 deletion tests/benchdnn/graph/input_displacer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ partition_data_displacer_t::partition_data_displacer_t(

static const std::unordered_set<std::string> main_op_kind {"Convolution",
"ConvTranspose", "AvgPool", "MaxPool", "MatMul", "Add", "Divide",
"Maximum", "Minimum", "Multiply", "Substract"};
"Maximum", "Minimum", "Multiply", "Substract", "Select"};

static const std::unordered_set<std::string> go_through_op_kind {
"StaticTranspose", "StaticReshape", "TypeCast", "Quantize",
Expand Down Expand Up @@ -136,6 +136,78 @@ partition_data_displacer_t::partition_data_displacer_t(
}
}
}

// Alternatively, look for an Add->SoftMax chain, which represents an
// explicit SDPA mask whose upper-right corner should be filled with -inf:
// 0 -inf -inf -inf
// 0 0 -inf -inf
// 0 0 0 -inf
// 0 0 0 0
// This is done to avoid taking future tokens into account by
// influencing SoftMax input values.
while (aop.kind_ == "Add" || aop.kind_ == "Select") {
auto *aop_out_lt = &aop.out_lts_[0];
auto *child_op = &dg_->get_op_by_in_lt(aop_out_lt->id_);
if (child_op->kind_ != "SoftMax") break;

// SoftMax must be a part of the same partition as the mask. This is to
// avoid modifying a mask op that happens to be the last op in the
// partition.
if (op_ids_set_.find(child_op->id_) == op_ids_set_.end()) break;

// Search for an input lt without a parent, this is the one to
// modify for both explicit and implicit masks.
const deserialized_lt *causal_mask_lt = nullptr;
size_t offset = SIZE_MAX;
for (size_t i = 0; i < aop.in_lts_.size(); i++) {
auto *aop_in_lt = &aop.in_lts_[i];
auto *parent_op = &dg_->get_op_by_out_lt(aop_in_lt->id_);
if (!parent_op->empty()) continue;

// Explicit masks expressed through Select op would have cond
// tensor standalone and not filled for every point. This
// represents a padding mask and is not supported for now.
if (aop_in_lt->get_data_type()
== logical_tensor::data_type::boolean)
break;

causal_mask_lt = aop_in_lt;
offset = i;
break;
}
// No suitable tensor/subgraph for a mask displacement.
if (!causal_mask_lt) break;

if (aop.kind_ == "Add") {
const auto ndims = causal_mask_lt->shape_.size();
if (ndims < 2) {
BENCHDNN_PRINT(7, "%s\n",
"[DISPLACE]: Causal mask ndims is less than 2");
break;
}

const auto N = causal_mask_lt->shape_[ndims - 1];
const auto M = causal_mask_lt->shape_[ndims - 2];
if (M == 1 || N == 1) {
BENCHDNN_PRINT(7,
"[DISPLACE]: Causal mask shape has one: {%ld, "
"%ld}\n",
M, N);
break;
}
}

filling_type_t filling_type = filling_type_t::undef;
if (aop.kind_ == "Add")
filling_type = filling_type_t::causal_mask;
else if (aop.kind_ == "Select")
filling_type = filling_type_t::minus_infinity;

quantize_displace_.emplace(causal_mask_lt->id_,
std::make_tuple(
aop, offset, *causal_mask_lt, filling_type));
break;
}
}
}

Expand Down Expand Up @@ -188,6 +260,12 @@ int partition_data_displacer_t::displace_input_data(
= is_div ? pow2_div_vals : (is_mul ? pow2_mul_vals : dummy);
fill_cfg_t fill_cfg(user_set, "Mul/Div displacer");
SAFE(gen_fixed_set_filling(mem_replace, mem.md_, fill_cfg, res), WARN);
} else if (filling_type == filling_type_t::causal_mask) {
SAFE(gen_causal_mask_filling(mem_replace, mem.md_, res), WARN);
} else if (filling_type == filling_type_t::minus_infinity) {
static const std::vector<float> user_set {-INFINITY};
fill_cfg_t fill_cfg(user_set, "Implicit_causal_mask");
SAFE(gen_fixed_set_filling(mem_replace, mem.md_, fill_cfg, res), WARN);
} else {
assert(!"unexpected filling type");
}
Expand Down Expand Up @@ -402,4 +480,27 @@ int partition_data_displacer_t::gen_fixed_set_filling(dnn_mem_t &mem,
return OK;
}

// Generates a causal mask filling described by `md` and moves it into `mem`.
//
// The two innermost dimensions of `md` are treated as an M x N matrix and all
// outer dimensions are collapsed into a single batch count. For every batch
// slice, element (m, n) is set to 0 when m >= n (diagonal and below) and to
// -inf when n > m (strictly above the diagonal). Elements are addressed
// through a flat index `b * M * N + m * N + n`.
//
// `res` is not used by this function. Always returns `OK`.
int partition_data_displacer_t::gen_causal_mask_filling(
        dnn_mem_t &mem, const_dnnl_memory_desc_t md, res_t *res) const {

    const int ndims = query_md_ndims(md);
    assert(ndims >= 2); // This was checked at displacer initialization.
    const auto &mask_dims = query_md_dims(md);

    // Everything in front of the last two dimensions is a batch.
    int64_t n_batches = 1;
    for (int d = 0; d < ndims - 2; ++d)
        n_batches *= mask_dims[d];

    const int64_t rows = mask_dims[ndims - 2];
    const int64_t cols = mask_dims[ndims - 1];

    // Fill a scratch memory first; `mem` is only overwritten once the whole
    // mask has been produced.
    dnn_mem_t scratch(md, get_test_engine());

    benchdnn_parallel_nd(n_batches, rows, cols,
            [&](int64_t ib, int64_t row, int64_t col) {
                const int64_t flat_idx = (ib * rows + row) * cols + col;
                // Positions to the right of the diagonal are masked out.
                const float fill = col > row ? -INFINITY : 0.f;
                scratch.set_elem(flat_idx, fill);
            });

    mem = std::move(scratch);
    return OK;
}

} // namespace graph
9 changes: 8 additions & 1 deletion tests/benchdnn/graph/input_displacer.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2023-2024 Intel Corporation
* Copyright 2023-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,6 +29,10 @@ enum class filling_type_t {
quantization,
// Floating-point power-of-2 values for precise division/multiplication.
pow2,
// Explicit causal mask from SDPA pattern.
causal_mask,
// Implicit causal mask free input.
minus_infinity,
};

// tuple<
Expand Down Expand Up @@ -60,6 +64,9 @@ class partition_data_displacer_t {
// from `fill_cfg`.
int gen_fixed_set_filling(dnn_mem_t &mem, const_dnnl_memory_desc_t md,
const fill_cfg_t &fill_cfg, res_t *res) const;
// Generates causal mask filling for "Add" operation.
int gen_causal_mask_filling(
dnn_mem_t &mem, const_dnnl_memory_desc_t md, res_t *res) const;
};

} // namespace graph
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,12 @@
"shape": [
1,
1,
1,
384,
384
],
"stride": [
384,
384,
147456,
147456,
384,
1
],
Expand Down

0 comments on commit 3389eb0

Please sign in to comment.