From 4c278081e96b0829610949991d50b4273939f158 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 16 Sep 2023 21:19:17 -0400 Subject: [PATCH] Revert "[TIR] Shuffle in PointerValueTypeRewrite for scalar reads (#15517)" This reverts commit 925148e444103f044e9dbe111aacf0c5079abc3a. --- python/tvm/tir/transform/transform.py | 14 --- src/target/spirv/codegen_spirv.cc | 11 -- src/target/spirv/codegen_spirv.h | 1 - src/tir/transforms/storage_rewrite.cc | 112 +++++------------- ...ir_transform_pointer_value_type_rewrite.py | 73 ------------ 5 files changed, 31 insertions(+), 180 deletions(-) delete mode 100644 tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index c58062045c..dda81ce34f 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -230,20 +230,6 @@ def StorageRewrite(): return _ffi_api.StorageRewrite() # type: ignore -def PointerValueTypeRewrite(): - """ - Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use - the most frequently accessed type for load/store to avoid pointer casting in backend when - possible. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.PointerValueTypeRewrite() # type: ignore - - def UnrollLoop(): """Unroll the constant loop marked by unroll. diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 8f1aa8063b..3a98bfa305 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -610,17 +610,6 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::functionvectors.size() == 1 && op->indices.size() == 1) - << "SPIR-V codegen only supports shuffle " - << "of one vector with one index"; - spirv::Value vector = MakeValue(op->vectors[0]); - int index = Downcast(op->indices[0])->value; - spirv::SType etype = builder_->GetSType(op->dtype); - spirv::Value element = builder_->MakeValue(spv::OpCompositeExtract, etype, vector, index); - return element; -} - void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers"; Var buffer_var = op->buffer->data; diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 8ea90a9c4b..1e7b535585 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -102,7 +102,6 @@ class CodeGenSPIRV : public ExprFunctor, spirv::Value VisitExpr_(const RampNode* op) override; spirv::Value VisitExpr_(const BroadcastNode* op) override; spirv::Value VisitExpr_(const BufferLoadNode* op) override; - spirv::Value VisitExpr_(const ShuffleNode* op) override; // stmt void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ForNode* op) override; diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index f271769c80..3ecd0f64bb 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -36,7 +36,6 @@ #include #include -#include "../../arith/int_operator.h" #include "../../runtime/thread_storage_scope.h" #include "../ir/buffer_common.h" #include "ir_utils.h" @@ -1067,18 +1066,12 @@ struct BufferVarInfo { // packing in StorageRewrite) or in number of lanes (e.g. float16* // cast to float16x4*). std::unordered_set access_dtype; - // Data types used for scalar reads. This is used to record vectorized read dtypes that can be - // shuffled for scalar reads when rewrite_scalar_read_to_vector_shuffle is enabled. - std::unordered_set scalar_read_dtype; DataType get_preferred_dtype() const { std::unordered_set base_access_dtype; for (auto dtype : access_dtype) { base_access_dtype.insert(dtype.element_of()); } - for (auto dtype : scalar_read_dtype) { - base_access_dtype.insert(dtype.element_of()); - } // If the array is accessed as multiple base types within a // function, no point in changing the declared type. CodeGenC can // handle this with a type-cast prior to indexing. Vulkan will @@ -1095,19 +1088,12 @@ struct BufferVarInfo { // size, then the buffer is vectorizable. In the future, this // could be improved to allow vectorized buffer access of size // GCD(*lanes_used), if necessary. - // When there are scalar reads and no writes, access_dtype can be empty and we should avoid - // rewriting. int preferred_lanes = element_dtype.lanes(); - if (element_dtype.lanes() == 1 && (access_dtype.size() == 1)) { - int lanes = access_dtype.begin()->lanes(); - // Check the scalar read dtypes are compatible with the vectorized access dtype. - for (auto dtype : scalar_read_dtype) { - if (dtype.lanes() % lanes != 0) { - return element_dtype; - } - } + if ((element_dtype.lanes() == 1) && (access_dtype.size() == 1)) { arith::Analyzer analyzer_; arith::ModularSet me = analyzer_.modular_set(extent); + + int lanes = access_dtype.begin()->lanes(); if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { preferred_lanes = lanes; } @@ -1134,10 +1120,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor { * type as it is later accessed, with scalar element types. */ VectorTypeAccessChecker(const Array& params, const Map& buffer_map, - bool allow_untyped_pointers = false, - bool detect_scalar_read_patterns = true) - : allow_untyped_pointers_(allow_untyped_pointers), - detect_scalar_read_patterns_(detect_scalar_read_patterns) { + bool allow_untyped_pointers = false) + : allow_untyped_pointers_(allow_untyped_pointers) { // If a parameter is in the buffer map, we want to track the // version in the map. for (auto it : buffer_map) { @@ -1161,12 +1145,12 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) final { - OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices, /*is_buffer_load=*/true); + OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices); StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode* op) final { - OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices, /*is_buffer_load=*/false); + OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices); StmtExprVisitor::VisitStmt_(op); } @@ -1175,10 +1159,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); PrimExpr index = op->args[2]; - OnArrayAccess(dtype, buffer, {index}, false); - } else if (op->op.same_as(builtin::address_of())) { - BufferLoad load = Downcast(op->args[0]); - OnArrayAccess(load->dtype, load->buffer->data.get(), load->indices, /*is_buffer_load=*/false); + OnArrayAccess(dtype, buffer, {index}); } StmtExprVisitor::VisitExpr_(op); } @@ -1245,7 +1226,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor { if (element_dtype == DataType::Bool()) { element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes()); } - info_map_[buffer.get()] = BufferVarInfo{buffer, element_dtype, extent, declaration_location}; + + info_map_[buffer.get()] = {buffer, element_dtype, extent, declaration_location}; } /* Update the type map for a buffer based on its usage @@ -1255,12 +1237,11 @@ class VectorTypeAccessChecker : public StmtExprVisitor { * * @param buffer The VarNode representing the buffer. * - * @param indices The index at which the value is being stored/loaded. + * @param index The index at which the value is being stored/loaded. * - * @param is_buffer_load Whether the access is BufferLoad + * @param predicate The predicate used for the store/load. */ - void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array& indices, - bool is_buffer_load) { + void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array& indices) { auto it = info_map_.find(buffer); ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer << ") occurred before its declaration."; @@ -1323,14 +1304,6 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } } - if (detect_scalar_read_patterns_ && is_buffer_load && indices.size()) { - const PrimExpr last_dim_index = indices[indices.size() - 1]; - if (last_dim_index.dtype().lanes() == 1) { - arith::ModularSet me = analyzer_.modular_set(last_dim_index); - var_info.scalar_read_dtype.emplace(access_dtype.with_lanes(me->coeff)); - return; - } - } var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used)); } @@ -1339,8 +1312,6 @@ class VectorTypeAccessChecker : public StmtExprVisitor { // bool allow_untyped_pointers_{false}; - // Whether to detect scalar read patterns for rewriting to vector shuffle - bool detect_scalar_read_patterns_{true}; // internal analyzer arith::Analyzer analyzer_; @@ -1395,8 +1366,7 @@ class VectorTypeRewriter : public StmtExprMutator { VectorTypeRewriter(const std::unordered_map& info_map, bool rewrite_params = true, bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, bool rewrite_indices = true, - bool rewrite_let_node = true, bool rewrite_allocate_const_node = true, - bool rewrite_scalar_read_to_vector_shuffle = true) + bool rewrite_let_node = true, bool rewrite_allocate_const_node = true) : rewrite_indices_(rewrite_indices) { int rewrite_mask = 0; if (rewrite_params) { @@ -1431,53 +1401,42 @@ class VectorTypeRewriter : public StmtExprMutator { } } - /*! - * \brief Mutator for BufferLoad or BufferStore. - * \return The rewritten node and the shuffle index. (Only for BufferLoad) When the shuffle index - * is non-negative, the caller should generate Shuffle to extract the element from the vector. - */ template - std::pair VisitBufferAccess(Node node) { - int shuffle_index = -1; + Node VisitBufferAccess(Node node) { if (!rewrite_indices_) { - return {node, shuffle_index}; + return node; } auto it = rewrite_map_.find(node->buffer->data.get()); if (it == rewrite_map_.end()) { - return {node, shuffle_index}; + return node; } const auto& info = it->second; Array indices = node->indices; - const PrimExpr& last_dim_index = indices[indices.size() - 1]; - if (const RampNode* ramp_index = last_dim_index.as(); - ramp_index && is_one(ramp_index->stride)) { + + const RampNode* ramp_index = indices[indices.size() - 1].as(); + if (ramp_index && is_one(ramp_index->stride)) { PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); if (ramp_index->lanes != info.factor()) { - ICHECK(info.factor() && ramp_index->lanes % info.factor() == 0); - int new_lanes = ramp_index->lanes / info.factor(); - new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes, ramp_index->span); + new_index = Ramp(new_index, ramp_index->stride, ramp_index->lanes / info.factor(), + ramp_index->span); } - indices.Set(indices.size() - 1, new_index); - } else if (last_dim_index.dtype().lanes() == 1 && info.factor() > 1) { - arith::ModularSet me = analyzer_.modular_set(last_dim_index); - ICHECK(me->coeff == 0 || info.factor() % me->coeff == 0); - PrimExpr new_index = last_dim_index / make_const(last_dim_index.dtype(), info.factor()); - shuffle_index = me->base; + indices.Set(indices.size() - 1, new_index); } auto writer = node.CopyOnWrite(); writer->buffer = RemapBuffer(node->buffer); writer->indices = indices; - return {node, shuffle_index}; + + return node; } PrimExpr VisitExpr_(const BufferLoadNode* op) final { auto node = Downcast(StmtExprMutator::VisitExpr_(op)); - auto [modified, shuffle_index] = VisitBufferAccess(node); + auto modified = VisitBufferAccess(node); // Not needed for BufferStoreNode, so we can't just call // LegalizeDtype() in VisitBufferAccess. @@ -1486,18 +1445,13 @@ class VectorTypeRewriter : public StmtExprMutator { } else { auto writer = modified.CopyOnWrite(); writer->LegalizeDType(); - if (shuffle_index >= 0) { - return Shuffle::ExtractElement(std::move(modified), shuffle_index); - } return std::move(modified); } } Stmt VisitStmt_(const BufferStoreNode* op) final { auto node = Downcast(StmtExprMutator::VisitStmt_(op)); - auto [modified, shuffle_index] = VisitBufferAccess(std::move(node)); - ICHECK(shuffle_index < 0); - return std::move(modified); + return VisitBufferAccess(std::move(node)); } Stmt VisitStmt_(const LetStmtNode* op) final { @@ -1673,7 +1627,6 @@ class VectorTypeRewriter : public StmtExprMutator { bool rewrite_indices_{true}; std::unordered_map rewrite_map_; std::unordered_map buffer_map_; - arith::Analyzer analyzer_; }; // Rewrite allocates, pointer parameters, and buffer map into vectorized versions @@ -1682,15 +1635,13 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false bool rewrite_params = true, bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, bool rewrite_indices = true, bool rewrite_let_node = true, - bool rewrite_allocate_const_node = true, - bool rewrite_scalar_read_to_vector_shuffle = true) { - VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers, - rewrite_scalar_read_to_vector_shuffle); + bool rewrite_allocate_const_node = true) { + VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers); checker(f->body); VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, rewrite_buffer_map, rewrite_allocate_node, rewrite_indices, rewrite_let_node, - rewrite_allocate_const_node, rewrite_scalar_read_to_vector_shuffle); + rewrite_allocate_const_node); PrimFuncNode* n = f.CopyOnWrite(); n->body = rewriter(std::move(n->body)); rewriter.Finalize(&f); @@ -1710,8 +1661,7 @@ Pass StorageRewrite() { // padded out to 32 bits) would require either rewriting // AllocateConst::data, or would require the code generators to // handle vectorized constants. - return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false, - false); + return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false); }; return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } diff --git a/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py b/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py deleted file mode 100644 index 7baa96c1a1..0000000000 --- a/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py +++ /dev/null @@ -1,73 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -import tvm.testing -from tvm import te -from tvm.driver.build_module import schedule_to_module -from tvm.script import tir as T - - -class BaseCompare(tvm.testing.CompareBeforeAfter): - transform = tvm.tir.transform.PointerValueTypeRewrite() - - -class TestRewriteToShuffle(BaseCompare): - @T.prim_func - def before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")): - A_local_data = T.allocate([16], "float32", scope="local") - A_local = T.Buffer((16,), "float32", data=A_local_data, scope="local") - for i in range(4): - A_local[i * 4 : i * 4 + 4] = A[i * 4 : i * 4 + 4] - for i in range(4): - B[i] = A_local[i * 4] + A_local[i * 4 + 1] + A_local[i * 4 + 2] + A_local[i * 4 + 3] - - @T.prim_func - def expected(A: T.Buffer((4,), "float32x4"), B: T.Buffer((4,), "float32")): - A_local_data = T.allocate([4], "float32x4", scope="local") - A_local = T.Buffer((4,), "float32x4", data=A_local_data, scope="local") - for i in range(4): - A_local[T.Div(i * 4, 4)] = A[T.Div(i * 4, 4)] - for i in range(4): - B[i] = ( - T.Shuffle([A_local[T.Div(i * 4, 4)]], [0]) - + T.Shuffle([A_local[T.Div(i * 4 + 1, 4)]], [1]) - + T.Shuffle([A_local[T.Div(i * 4 + 2, 4)]], [2]) - + T.Shuffle([A_local[T.Div(i * 4 + 3, 4)]], [3]) - ) - - -class TestAddressOf(BaseCompare): - @T.prim_func - def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): - for i in range(4): - T.evaluate(T.address_of(A[i * 4])) - B[i * 4 : i * 4 + 4] = A[i * 4 : i * 4 + 4] - - @T.prim_func - def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32x4")): - for i in range(4): - T.evaluate(T.address_of(A[i * 4])) - B[T.Div(i * 4, 4)] = A[i * 4 : i * 4 + 4] - - -class TestScalarReadWithoutWrite(BaseCompare): - @T.prim_func - def before(A: T.Buffer((16,), "float32")): - for i in range(4): - T.evaluate(A[i * 4]) - - expected = before