Skip to content

Commit

Permalink
Revert "[TIR] Shuffle in PointerValueTypeRewrite for scalar reads (#15517)"

Browse files Browse the repository at this point in the history

This reverts commit 925148e.
  • Loading branch information
MasterJH5574 authored and junrushao committed Sep 24, 2023
1 parent c30cbf3 commit 4c27808
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 180 deletions.
14 changes: 0 additions & 14 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,20 +230,6 @@ def StorageRewrite():
return _ffi_api.StorageRewrite() # type: ignore


def PointerValueTypeRewrite():
    """
    Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use
    the most frequently accessed type for load/store to avoid pointer casting in backend when
    possible.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.PointerValueTypeRewrite()  # type: ignore


def UnrollLoop():
"""Unroll the constant loop marked by unroll.
Expand Down
11 changes: 0 additions & 11 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -610,17 +610,6 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function<void(int i, spirv:
}
}

spirv::Value CodeGenSPIRV::VisitExpr_(const ShuffleNode* op) {
  // SPIR-V codegen handles only the degenerate shuffle form: extracting a
  // single scalar lane from exactly one vector operand.
  ICHECK(op->vectors.size() == 1 && op->indices.size() == 1)
      << "SPIR-V codegen only supports shuffle "
      << "of one vector with one index";
  // Lower to OpCompositeExtract: pull the requested lane out of the
  // evaluated vector, typed as the shuffle's result dtype.
  spirv::Value src_vector = MakeValue(op->vectors[0]);
  const int lane = Downcast<Integer>(op->indices[0])->value;
  spirv::SType result_type = builder_->GetSType(op->dtype);
  return builder_->MakeValue(spv::OpCompositeExtract, result_type, src_vector, lane);
}

void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers";
Var buffer_var = op->buffer->data;
Expand Down
1 change: 0 additions & 1 deletion src/target/spirv/codegen_spirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
spirv::Value VisitExpr_(const RampNode* op) override;
spirv::Value VisitExpr_(const BroadcastNode* op) override;
spirv::Value VisitExpr_(const BufferLoadNode* op) override;
spirv::Value VisitExpr_(const ShuffleNode* op) override;
// stmt
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const ForNode* op) override;
Expand Down
112 changes: 31 additions & 81 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
#include <unordered_map>
#include <unordered_set>

#include "../../arith/int_operator.h"
#include "../../runtime/thread_storage_scope.h"
#include "../ir/buffer_common.h"
#include "ir_utils.h"
Expand Down Expand Up @@ -1067,18 +1066,12 @@ struct BufferVarInfo {
// packing in StorageRewrite) or in number of lanes (e.g. float16*
// cast to float16x4*).
std::unordered_set<DataType> access_dtype;
// Data types used for scalar reads. This is used to record vectorized read dtypes that can be
// shuffled for scalar reads when rewrite_scalar_read_to_vector_shuffle is enabled.
std::unordered_set<DataType> scalar_read_dtype;

DataType get_preferred_dtype() const {
std::unordered_set<DataType> base_access_dtype;
for (auto dtype : access_dtype) {
base_access_dtype.insert(dtype.element_of());
}
for (auto dtype : scalar_read_dtype) {
base_access_dtype.insert(dtype.element_of());
}
// If the array is accessed as multiple base types within a
// function, no point in changing the declared type. CodeGenC can
// handle this with a type-cast prior to indexing. Vulkan will
Expand All @@ -1095,19 +1088,12 @@ struct BufferVarInfo {
// size, then the buffer is vectorizable. In the future, this
// could be improved to allow vectorized buffer access of size
// GCD(*lanes_used), if necessary.
// When there are scalar reads and no writes, access_dtype can be empty and we should avoid
// rewriting.
int preferred_lanes = element_dtype.lanes();
if (element_dtype.lanes() == 1 && (access_dtype.size() == 1)) {
int lanes = access_dtype.begin()->lanes();
// Check the scalar read dtypes are compatible with the vectorized access dtype.
for (auto dtype : scalar_read_dtype) {
if (dtype.lanes() % lanes != 0) {
return element_dtype;
}
}
if ((element_dtype.lanes() == 1) && (access_dtype.size() == 1)) {
arith::Analyzer analyzer_;
arith::ModularSet me = analyzer_.modular_set(extent);

int lanes = access_dtype.begin()->lanes();
if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
preferred_lanes = lanes;
}
Expand All @@ -1134,10 +1120,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
* type as it is later accessed, with scalar element types.
*/
VectorTypeAccessChecker(const Array<tir::Var>& params, const Map<Var, Buffer>& buffer_map,
bool allow_untyped_pointers = false,
bool detect_scalar_read_patterns = true)
: allow_untyped_pointers_(allow_untyped_pointers),
detect_scalar_read_patterns_(detect_scalar_read_patterns) {
bool allow_untyped_pointers = false)
: allow_untyped_pointers_(allow_untyped_pointers) {
// If a parameter is in the buffer map, we want to track the
// version in the map.
for (auto it : buffer_map) {
Expand All @@ -1161,12 +1145,12 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
}

void VisitExpr_(const BufferLoadNode* op) final {
OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices, /*is_buffer_load=*/true);
OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices);
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const BufferStoreNode* op) final {
OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices, /*is_buffer_load=*/false);
OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices);
StmtExprVisitor::VisitStmt_(op);
}

Expand All @@ -1175,10 +1159,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
PrimExpr index = op->args[2];
OnArrayAccess(dtype, buffer, {index}, false);
} else if (op->op.same_as(builtin::address_of())) {
BufferLoad load = Downcast<BufferLoad>(op->args[0]);
OnArrayAccess(load->dtype, load->buffer->data.get(), load->indices, /*is_buffer_load=*/false);
OnArrayAccess(dtype, buffer, {index});
}
StmtExprVisitor::VisitExpr_(op);
}
Expand Down Expand Up @@ -1245,7 +1226,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
if (element_dtype == DataType::Bool()) {
element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
}
info_map_[buffer.get()] = BufferVarInfo{buffer, element_dtype, extent, declaration_location};

info_map_[buffer.get()] = {buffer, element_dtype, extent, declaration_location};
}

/* Update the type map for a buffer based on its usage
Expand All @@ -1255,12 +1237,11 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
*
* @param buffer The VarNode representing the buffer.
*
* @param indices The index at which the value is being stored/loaded.
* @param index The index at which the value is being stored/loaded.
*
* @param is_buffer_load Whether the access is BufferLoad
* @param predicate The predicate used for the store/load.
*/
void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array<PrimExpr>& indices,
bool is_buffer_load) {
void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array<PrimExpr>& indices) {
auto it = info_map_.find(buffer);
ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer
<< ") occurred before its declaration.";
Expand Down Expand Up @@ -1323,14 +1304,6 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
}
}

if (detect_scalar_read_patterns_ && is_buffer_load && indices.size()) {
const PrimExpr last_dim_index = indices[indices.size() - 1];
if (last_dim_index.dtype().lanes() == 1) {
arith::ModularSet me = analyzer_.modular_set(last_dim_index);
var_info.scalar_read_dtype.emplace(access_dtype.with_lanes(me->coeff));
return;
}
}
var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used));
}

Expand All @@ -1339,8 +1312,6 @@ class VectorTypeAccessChecker : public StmtExprVisitor {

//
bool allow_untyped_pointers_{false};
// Whether to detect scalar read patterns for rewriting to vector shuffle
bool detect_scalar_read_patterns_{true};

// internal analyzer
arith::Analyzer analyzer_;
Expand Down Expand Up @@ -1395,8 +1366,7 @@ class VectorTypeRewriter : public StmtExprMutator {
VectorTypeRewriter(const std::unordered_map<const VarNode*, BufferVarInfo>& info_map,
bool rewrite_params = true, bool rewrite_buffer_map = true,
bool rewrite_allocate_node = true, bool rewrite_indices = true,
bool rewrite_let_node = true, bool rewrite_allocate_const_node = true,
bool rewrite_scalar_read_to_vector_shuffle = true)
bool rewrite_let_node = true, bool rewrite_allocate_const_node = true)
: rewrite_indices_(rewrite_indices) {
int rewrite_mask = 0;
if (rewrite_params) {
Expand Down Expand Up @@ -1431,53 +1401,42 @@ class VectorTypeRewriter : public StmtExprMutator {
}
}

/*!
* \brief Mutator for BufferLoad or BufferStore.
* \return The rewritten node and the shuffle index. (Only for BufferLoad) When the shuffle index
* is non-negative, the caller should generate Shuffle to extract the element from the vector.
*/
template <typename Node>
std::pair<Node, int> VisitBufferAccess(Node node) {
int shuffle_index = -1;
Node VisitBufferAccess(Node node) {
if (!rewrite_indices_) {
return {node, shuffle_index};
return node;
}

auto it = rewrite_map_.find(node->buffer->data.get());
if (it == rewrite_map_.end()) {
return {node, shuffle_index};
return node;
}
const auto& info = it->second;

Array<PrimExpr> indices = node->indices;
const PrimExpr& last_dim_index = indices[indices.size() - 1];
if (const RampNode* ramp_index = last_dim_index.as<RampNode>();
ramp_index && is_one(ramp_index->stride)) {

const RampNode* ramp_index = indices[indices.size() - 1].as<RampNode>();
if (ramp_index && is_one(ramp_index->stride)) {
PrimExpr new_index =
ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes);
if (ramp_index->lanes != info.factor()) {
ICHECK(info.factor() && ramp_index->lanes % info.factor() == 0);
int new_lanes = ramp_index->lanes / info.factor();
new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes, ramp_index->span);
new_index = Ramp(new_index, ramp_index->stride, ramp_index->lanes / info.factor(),
ramp_index->span);
}
indices.Set(indices.size() - 1, new_index);
} else if (last_dim_index.dtype().lanes() == 1 && info.factor() > 1) {
arith::ModularSet me = analyzer_.modular_set(last_dim_index);
ICHECK(me->coeff == 0 || info.factor() % me->coeff == 0);
PrimExpr new_index = last_dim_index / make_const(last_dim_index.dtype(), info.factor());
shuffle_index = me->base;

indices.Set(indices.size() - 1, new_index);
}

auto writer = node.CopyOnWrite();
writer->buffer = RemapBuffer(node->buffer);
writer->indices = indices;
return {node, shuffle_index};

return node;
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto [modified, shuffle_index] = VisitBufferAccess(node);
auto modified = VisitBufferAccess(node);

// Not needed for BufferStoreNode, so we can't just call
// LegalizeDtype() in VisitBufferAccess.
Expand All @@ -1486,18 +1445,13 @@ class VectorTypeRewriter : public StmtExprMutator {
} else {
auto writer = modified.CopyOnWrite();
writer->LegalizeDType();
if (shuffle_index >= 0) {
return Shuffle::ExtractElement(std::move(modified), shuffle_index);
}
return std::move(modified);
}
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto [modified, shuffle_index] = VisitBufferAccess(std::move(node));
ICHECK(shuffle_index < 0);
return std::move(modified);
return VisitBufferAccess(std::move(node));
}

Stmt VisitStmt_(const LetStmtNode* op) final {
Expand Down Expand Up @@ -1673,7 +1627,6 @@ class VectorTypeRewriter : public StmtExprMutator {
bool rewrite_indices_{true};
std::unordered_map<const VarNode*, RewriteInfo> rewrite_map_;
std::unordered_map<const BufferNode*, Buffer> buffer_map_;
arith::Analyzer analyzer_;
};

// Rewrite allocates, pointer parameters, and buffer map into vectorized versions
Expand All @@ -1682,15 +1635,13 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false
bool rewrite_params = true, bool rewrite_buffer_map = true,
bool rewrite_allocate_node = true, bool rewrite_indices = true,
bool rewrite_let_node = true,
bool rewrite_allocate_const_node = true,
bool rewrite_scalar_read_to_vector_shuffle = true) {
VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers,
rewrite_scalar_read_to_vector_shuffle);
bool rewrite_allocate_const_node = true) {
VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers);
checker(f->body);

VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, rewrite_buffer_map,
rewrite_allocate_node, rewrite_indices, rewrite_let_node,
rewrite_allocate_const_node, rewrite_scalar_read_to_vector_shuffle);
rewrite_allocate_const_node);
PrimFuncNode* n = f.CopyOnWrite();
n->body = rewriter(std::move(n->body));
rewriter.Finalize(&f);
Expand All @@ -1710,8 +1661,7 @@ Pass StorageRewrite() {
// padded out to 32 bits) would require either rewriting
// AllocateConst::data, or would require the code generators to
// handle vectorized constants.
return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false,
false);
return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false);
};
return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
}
Expand Down

This file was deleted.

0 comments on commit 4c27808

Please sign in to comment.