Adding resize(PadOp) vectorization analysis #3321

Open: wants to merge 100 commits into base: main

Commits (100):
8f9708f
relaxing check
jjsjann123 Sep 2, 2024
54826aa
allow cache on inputs for pad
jjsjann123 Sep 3, 2024
e54938c
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Sep 3, 2024
2bc3c7a
cpp example
jjsjann123 Sep 24, 2024
d04e8c3
Merge branch 'jjsjann123/pad_vec' into jjsjann123/resize_vec
jjsjann123 Sep 24, 2024
d0addc4
reverting earlier changes
jjsjann123 Sep 24, 2024
490fdbe
Revert "reverting earlier changes"
jjsjann123 Sep 24, 2024
51c3022
cherry-pick my revert
jjsjann123 Sep 24, 2024
1158ef0
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Oct 2, 2024
fdc6a9a
debug print
jjsjann123 Oct 3, 2024
9a6c03a
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Oct 4, 2024
a9d16ce
removing comments
jjsjann123 Oct 7, 2024
3401119
removing assert
jjsjann123 Oct 8, 2024
5d05284
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Oct 8, 2024
b6587ee
patching test
jjsjann123 Oct 10, 2024
28decac
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Oct 10, 2024
3e53feb
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Oct 20, 2024
ad61ecb
fixing test
jjsjann123 Oct 20, 2024
a8edc56
fixing
jjsjann123 Oct 20, 2024
9cdeb64
fixing test
jjsjann123 Oct 21, 2024
09a2aee
does this work to replace Ternary(where) with IfThenElse
jjsjann123 Oct 21, 2024
895d0bf
fixing build
jjsjann123 Oct 21, 2024
7a15e22
removing print
jjsjann123 Oct 22, 2024
a6e8fb1
restore lower to ternary:where; restore vectorization on tests
jjsjann123 Oct 22, 2024
fe0f263
testing water
jjsjann123 Oct 23, 2024
baa7b09
fixing syntax
jjsjann123 Oct 23, 2024
ca5ced1
now it's functional
jjsjann123 Oct 23, 2024
e0492d3
better formatting on printed code
jjsjann123 Oct 23, 2024
b528429
adding a tab
jjsjann123 Oct 23, 2024
a23e010
supporting local memory
jjsjann123 Oct 23, 2024
57b90d1
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Oct 23, 2024
7a976c7
clangformat
jjsjann123 Oct 23, 2024
f11d662
apparently there are ternary operations on scalars
jjsjann123 Oct 23, 2024
5a83fc6
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Oct 23, 2024
39f83f7
fixing
jjsjann123 Oct 23, 2024
07eafd1
fixing
jjsjann123 Oct 23, 2024
986b361
clangformat
jjsjann123 Oct 23, 2024
7409913
clangformat
jjsjann123 Oct 23, 2024
76cbcd8
clangformat again
jjsjann123 Oct 23, 2024
a42b507
hacking mapping
jjsjann123 Oct 26, 2024
07fecdc
removing more hardcode on vectorization through resize
jjsjann123 Oct 26, 2024
69350a0
cpp test
jjsjann123 Oct 26, 2024
6b8527c
some typo in comment
jjsjann123 Oct 28, 2024
8236b8a
WIP for vectorization analysis
jjsjann123 Oct 28, 2024
5ed2139
WIP
jjsjann123 Oct 29, 2024
6c6bbde
fixing build
jjsjann123 Oct 29, 2024
fa64d0c
fixing kernel issues
jjsjann123 Oct 30, 2024
7cfa250
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Oct 30, 2024
62cc45d
fixing build
jjsjann123 Oct 30, 2024
8cce0fd
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Oct 30, 2024
c02ad8c
adding tests first
jjsjann123 Oct 30, 2024
28dd1a1
more tests
jjsjann123 Oct 30, 2024
5f996fc
Merge branch 'main' into jjsjann123/resize_vec
jjsjann123 Oct 31, 2024
e1558c8
Merge branch 'jjsjann123/resize_vec' into jjsjann123/pad_vec_analysis
jjsjann123 Oct 31, 2024
bd2161e
check should work now?!
jjsjann123 Oct 31, 2024
b36b04c
Merge remote-tracking branch 'origin/jjsjann123/pad_vec_analysis' int…
jjsjann123 Oct 31, 2024
04c9db0
WIP
jjsjann123 Oct 31, 2024
7816a93
adding one more test
jjsjann123 Oct 31, 2024
7836b0f
disable mark alias for test
jjsjann123 Oct 31, 2024
3e29156
updating tests
jjsjann123 Nov 1, 2024
3f58e68
adding filter on resize
jjsjann123 Nov 1, 2024
274bd98
fixing build
jjsjann123 Nov 1, 2024
03cb4a7
FIXING
jjsjann123 Nov 2, 2024
2f43e0d
typo
jjsjann123 Nov 2, 2024
d4f7680
fixing the logic
jjsjann123 Nov 4, 2024
de83f9d
Merge remote-tracking branch 'origin/main' into jjsjann123/pad_vec_an…
jjsjann123 Nov 4, 2024
7170a84
clangformat
jjsjann123 Nov 4, 2024
803a95b
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Nov 4, 2024
a67fb57
polish PR for review
jjsjann123 Nov 4, 2024
11cd4d1
Merge remote-tracking branch 'origin/jjsjann123/resize_vec' into jjsj…
jjsjann123 Nov 4, 2024
65aa77d
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Nov 4, 2024
1f75d7a
missed one arg
jjsjann123 Nov 4, 2024
4c92371
oops, fixing the generated code
jjsjann123 Nov 4, 2024
2765bb7
Merge remote-tracking branch 'origin/jjsjann123/resize_vec' into HEAD
jjsjann123 Nov 4, 2024
3ec2a6b
review comments
jjsjann123 Nov 4, 2024
d2864ab
fixing code
jjsjann123 Nov 5, 2024
1b4f2c1
I think this is fixed now
jjsjann123 Nov 5, 2024
0e4e61f
adding comments per review request
jjsjann123 Nov 5, 2024
4d4f747
another comment
jjsjann123 Nov 5, 2024
fc3b28f
Merge branch 'jjsjann123/resize_vec' into HEAD
jjsjann123 Nov 5, 2024
85517b8
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Nov 5, 2024
69469c5
fixing cacheAfter
jjsjann123 Nov 5, 2024
cc69971
fixing assert in cacheAfter
jjsjann123 Nov 5, 2024
c46e50e
fix
jjsjann123 Nov 5, 2024
a63bb67
fixing tests
jjsjann123 Nov 6, 2024
974c85b
fixing tests
jjsjann123 Nov 6, 2024
4c7366b
err should have checked the input for uop
jjsjann123 Nov 6, 2024
2a1ff76
clangformat
jjsjann123 Nov 6, 2024
7b71934
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Nov 6, 2024
d019f9a
comment
jjsjann123 Nov 6, 2024
e8033c1
Update tests/cpp/test_resize.cpp
jjsjann123 Nov 7, 2024
6848dad
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Nov 8, 2024
5299327
addressing review comment; adding more tests
jjsjann123 Nov 8, 2024
48cafba
errr
jjsjann123 Nov 8, 2024
61d7c3f
fix
jjsjann123 Nov 8, 2024
688b99c
Merge branch 'main' into jjsjann123/pad_vec_analysis
jjsjann123 Nov 9, 2024
67beb5c
Merge branch 'main' into jjsjann123/pad_vec_analysis
jjsjann123 Nov 12, 2024
d64ea9c
move from unordered_set to vector
jjsjann123 Nov 12, 2024
d9b3cc8
update const
jjsjann123 Nov 12, 2024
a4743bf
switching to assert
jjsjann123 Nov 12, 2024
8 changes: 7 additions & 1 deletion csrc/ir/interface_nodes.h
@@ -482,10 +482,16 @@ class NVF_API TensorView : public Val {
//!
//! @param op_type: memory operator to use for the inserted op between
//! the data tensor and the cache tensor
//! @param cache_op: cache operator, see enum class CacheOp
//! @param propagate_allocation_domain: replay allocation domain on cached
//! load
//! @param cached_uses: if empty, cache all uses; otherwise, only try to cache
//! uses in cached_uses.
TensorView* cacheAfter(
LoadStoreOpType op_type = LoadStoreOpType::Set,
CacheOp cache_op = CacheOp::Unspecified,
bool propagate_allocation_domain = true);
bool propagate_allocation_domain = true,
std::vector<Expr*> cached_uses = {});

// For a fusion output with other uses, we want to avoid writing to global
// memory and then reading the output again. We write to global memory
5 changes: 4 additions & 1 deletion csrc/preseg_passes/move_pad.cpp
@@ -287,7 +287,10 @@ TensorView* replayConcretePad(

auto* new_out = IrBuilder::create<TensorView>(
IrBuilder::create<TensorDomain>(
merged_root_ids, merged_logical_ids, merged_logical_ids),
merged_root_ids,
merged_logical_ids,
merged_logical_ids,
TensorDomain::getContiguityFilledWith(merged_logical_ids, true)),
pad_tv->getDataType().value());
IrBuilder::create<PadOp>(
new_out,
68 changes: 34 additions & 34 deletions csrc/scheduler/utils.cpp
@@ -1198,12 +1198,36 @@ std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll) {
for (auto tv : in_tvs) {
if (tv->uses().empty() || ir_utils::isTorchGatherLookupTv(tv) ||
ir_utils::isIndexSelectLookupTv(tv) ||
ir_utils::isTvUsedByOpsOfType<SliceOp, SelectOp, PadOp>(tv)) {
ir_utils::isTvUsedByOpsOfType<SliceOp, SelectOp>(tv)) {
// Right now, tensors that are input to the slice, select, and pad ops
// can't be cached as they must be in global memory.
continue;
}
auto cached_tv = tv->cacheAfter();

// TODO: might need to reverse this when scheduler handles pad directly
// Do not insert a cache for pad as vectorization needs to be
// done directly.
//
// Note that this means that if an input is padded and also is
// used without padding, it will be read twice, once for pad and
// once more for caching load. It would make sense to use the PTX
// caching load instructions.
std::vector<Expr*> cached_uses;
for (auto use : tv->uses()) {
if (!use->isA<PadOp>()) {
cached_uses.push_back(use);
}
}

if (cached_uses.empty()) {
continue;
}

auto cached_tv = tv->cacheAfter(
/*op_type=*/LoadStoreOpType::Set,
/*cache_op=*/CacheOp::Unspecified,
/*propagate_allocation_domain=*/true,
/*cached_uses=*/cached_uses);
cached_inputs.emplace_back(cached_tv);
}
return cached_inputs;
@@ -1290,12 +1314,7 @@ IterDomain* projectIdToRoot(
} else if (expr->isA<Resize>()) {
auto resize = expr->as<Resize>();
if (resize->out() == projected_id) {
// We do not allow vectorization with resize at this moment
if (vectorize_pass) {
projected_id = nullptr;
} else {
projected_id = resize->in();
}
projected_id = resize->in();
}
} else {
NVF_THROW("Didn't recognize the iterdomain expression: ", expr);
@@ -1350,12 +1369,7 @@ IterDomain* projectIdToRFactor(
} else if (expr->isA<Resize>()) {
auto resize = expr->as<Resize>();
if (resize->in() == projected_id) {
// We do not allow vectorization wit resize at this moment
if (vectorize_pass) {
projected_id = nullptr;
} else {
projected_id = resize->out();
}
projected_id = resize->out();
}
} else {
NVF_THROW("Didn't recognize the iterdomain expression: ", expr);
@@ -1549,12 +1563,6 @@ std::vector<TensorView*> getInputsOutputsWithInnerDim(
// scheduler prefer to use output instead of input as reference tensor.
for (auto output_tv :
ir_utils::filterByType<TensorView>(reference_tv->fusion()->outputs())) {
// At this moment, vectorization through resize is not
// supported. This is not required currently as we always insert
// cacheBefore, but just in case.
if (ir_utils::hasResizedRfactor(output_tv)) {
continue;
}
if (hasInnerDim(output_tv, vectorizable_dims, vectorize_pass)) {
vectorizable_tensors.push_back(output_tv);
}
@@ -1569,19 +1577,11 @@ std::vector<TensorView*> getInputsOutputsWithInnerDim(
continue;
}

auto expr_resizes = [](Expr* e) -> bool {
return std::any_of(
e->outputs().begin(), e->outputs().end(), [](Val* out) -> bool {
if (auto* out_tv = dynamic_cast<TensorView*>(out)) {
return ir_utils::hasResizedRfactor(out_tv);
}
return false;
});
};

// At this moment, vectorization through resize is not supported
if (std::any_of(
input_tv->uses().begin(), input_tv->uses().end(), expr_resizes)) {
// Slice op is explicitly not enabled for vectorized load.
if (std::all_of(
input_tv->uses().begin(),
input_tv->uses().end(),
[](Expr* e) -> bool { return e->isA<SliceOp>(); })) {
continue;
}

@@ -2385,7 +2385,7 @@ bool revertUseOfInputCache(
void prepareForMemoryTypePromotion(Fusion* fusion) {
auto non_pwise_pairs = getNonPointwiseProducerConsumerPairs(fusion);

// Inserting a copy of each proucer. If a tensor shows up as a
// Inserting a copy of each producer. If a tensor shows up as a
// producer for multiple consumers, only insert one
// copy and share it with all the consumers.
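
To illustrate the use-filtering step that this PR adds to cacheInputs, here is a minimal, self-contained sketch. It does not use nvFuser types: `ToyExpr` and `selectCachedUses` are hypothetical names introduced only for this example, and the `is_pad` flag stands in for the `use->isA<PadOp>()` check in the diff.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for an nvFuser Expr; only the op kind matters here.
struct ToyExpr {
  std::string name;
  bool is_pad;  // true when the expression models a PadOp
};

// Returns the uses that should read through the cached copy, i.e. every use
// that is not a pad. An empty result means no cache should be inserted.
std::vector<const ToyExpr*> selectCachedUses(
    const std::vector<ToyExpr>& uses) {
  std::vector<const ToyExpr*> cached_uses;
  for (const ToyExpr& use : uses) {
    if (!use.is_pad) {
      cached_uses.push_back(&use);
    }
  }
  return cached_uses;
}
```

With uses `{pad, add, mul}` the sketch selects `add` and `mul`, so the pad keeps reading the input directly while the other two read the cached copy. With only a pad use, the result is empty and the input is skipped, which matches the `continue` in the diff.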

81 changes: 74 additions & 7 deletions csrc/scheduler/vectorize_helper.cpp
@@ -55,6 +55,30 @@ Val* ContiguousInnerDimensionsMapper::isFullyProjected(IterDomain* id) {
getProjectedExtent(id), commonOrConstExtent(ca_map_, id));
}

void ContiguousInnerDimensionsMapper::initializeResizeInfo(Fusion* fusion) {
auto exprs = fusion->exprs();
for (auto* pad_op : ir_utils::filterByType<PadOp>(exprs)) {
if (!pad_op->out()->isA<TensorView>()) {
continue;
}

auto* out_tv = pad_op->out()->as<TensorView>();

auto consumer_exprs = StmtSort::getExprsBetween(
{out_tv->getMaybeRootDomain().begin(),
out_tv->getMaybeRootDomain().end()},
{out_tv->getLogicalDomain().begin(), out_tv->getLogicalDomain().end()});

// NOTE: if we can assume that PadOp is always on inputs, then we can skip
// to innermost resize instead.
auto resize_ops = ir_utils::filterByType<Resize>(consumer_exprs);
std::copy(
resize_ops.begin(),
resize_ops.end(),
std::inserter(resize_in_pad_, resize_in_pad_.end()));
}
}

ContiguousInnerDimensionsMapper::ContiguousInnerDimensionsMapper(
TensorView* reference,
const std::vector<IterDomain*>& ids,
@@ -67,6 +91,9 @@ ContiguousInnerDimensionsMapper::ContiguousInnerDimensionsMapper(
ca_map_(std::move(ca_map)),
divisible_splits_(divisible_splits) {
FusionGuard fg(reference->fusion());

initializeResizeInfo(reference->fusion());

// Exclude reduction IDs if the reference is a fusion input as they
// don't manifest at all in the fusion. This simplifies the
// analysis in getContigMergeOfInnerSize, which only looks at
@@ -365,9 +392,51 @@ std::vector<IterDomain*> ContiguousInnerDimensionsMapper::projectId(
distributePE(merge_or_split);
};

auto clear_left_of = [&frontier](IterDomain* id) {
auto it = std::find(frontier.begin(), frontier.end(), id);
if (it != frontier.end()) {
auto propagateResize = [&frontier, this](Resize* resize_op, bool p2c) {
IterDomain* id_from = p2c ? resize_op->in() : resize_op->out();
IterDomain* id_to = p2c ? resize_op->out() : resize_op->in();

auto it = std::find(frontier.begin(), frontier.end(), id_from);
if (it == frontier.end()) {
return;
}

auto pos = std::distance(frontier.begin(), it);
if (resize_in_pad_.count(resize_op) != 0) {
// resize created by PadOp.

// project resize op to frontier.
frontier[pos] = id_to;
// clear left of resize, since those are no longer contiguous.
frontier.erase(frontier.begin(), it);

if (recording_) {
// TODO: support negative resize extent.
//
// Limit current support to only positive resize extent for now. So we
// only consider the pad_extent, which becomes the real buffer on
// output. Hence we do GCD among padded extent as well as extent of the
// id_from. Note since we are taking the GCD here, I don't think using
// id_from or id_to makes a difference.
auto consumer_factor = getProjectedExtent(id_from);
auto comp = [](Val* factor, Val* extent) {
return SimplifyingIrBuilder::whereExpr(
SimplifyingIrBuilder::eqExpr(
extent, extent->container()->zeroVal()),
factor,
// for extent < 0, we'll take max(1, extent). Because of the gcd,
// This is effectively excluding the resize id from vectorization.
SimplifyingIrBuilder::gcdExpr(
factor,
SimplifyingIrBuilder::maxExpr(
extent->container()->oneVal(), extent)));
};
consumer_factor = comp(consumer_factor, resize_op->leftExpand());
consumer_factor = comp(consumer_factor, resize_op->rightExpand());
addProjectedExtent(id_to, consumer_factor);
}
} else {
// unsupported resize.
frontier.erase(frontier.begin(), it + 1);
}
};
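
The `comp` lambda above builds a symbolic expression. As a rough illustration of the arithmetic it encodes, the following sketch evaluates the same rule on concrete integers. The function names `foldPadExtent` and `paddedVectorFactor` are hypothetical and exist only in this example; the real code operates on `Val*` through `SimplifyingIrBuilder`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>

// Folds one pad extent into the running factor, following the rule in `comp`:
//   extent == 0 -> factor is unchanged
//   extent  > 0 -> gcd(factor, extent)
//   extent  < 0 -> gcd(factor, max(1, extent)) == gcd(factor, 1) == 1
int64_t foldPadExtent(int64_t factor, int64_t extent) {
  assert(factor > 0);
  if (extent == 0) {
    return factor;
  }
  return std::gcd(factor, std::max<int64_t>(1, extent));
}

// Applies the rule to the left and then the right pad extent, in the same
// order as the two comp(...) calls in the diff.
int64_t paddedVectorFactor(int64_t factor, int64_t left, int64_t right) {
  return foldPadExtent(foldPadExtent(factor, left), right);
}
```

For example, a starting factor of 8 with pad extents (4, 0) yields 4, while any negative extent drives the result to 1, which effectively excludes that ID from vectorization as the comment in the diff describes.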
@@ -391,8 +460,7 @@ std::vector<IterDomain*> ContiguousInnerDimensionsMapper::projectId(
} else if (Merge* merge = dynamic_cast<Merge*>(expr)) {
propagateDistribute(merge);
} else if (Resize* resize = dynamic_cast<Resize*>(expr)) {
// Cannot vectorize through resize
clear_left_of(resize->out());
propagateResize(resize, false);
} else {
// TODO: I wonder if we should just remove all inputs instead of erroring.
// Seems that would be safe.
@@ -415,8 +483,7 @@ std::vector<IterDomain*> ContiguousInnerDimensionsMapper::projectId(
} else if (Split* split = dynamic_cast<Split*>(expr)) {
propagateDistribute(split);
} else if (Resize* resize = dynamic_cast<Resize*>(expr)) {
// Cannot vectorize through resize
clear_left_of(resize->in());
propagateResize(resize, true);
} else {
// TODO: I wonder if we should just remove all inputs instead of erroring.
// Seems that would be safe.
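
The frontier update performed by `propagateResize` in both directions can be sketched on a toy model in which IDs are plain strings ordered from outermost to innermost. `propagateResizeSketch` is a hypothetical name, and the caller picks `id_from` and `id_to`, which corresponds to the `p2c` switch in the diff. This is an illustration of the erase logic only, not the actual implementation.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Updates `frontier` in place for a resize that maps id_from to id_to.
// resize_from_pad models membership in resize_in_pad_.
void propagateResizeSketch(
    std::vector<std::string>& frontier,
    const std::string& id_from,
    const std::string& id_to,
    bool resize_from_pad) {
  auto it = std::find(frontier.begin(), frontier.end(), id_from);
  if (it == frontier.end()) {
    return;  // the resized ID is not being tracked; nothing to do
  }
  if (resize_from_pad) {
    // Project through the resize, then drop everything to its left, since
    // those IDs are no longer contiguous with the resized one.
    *it = id_to;
    frontier.erase(frontier.begin(), it);
  } else {
    // Unsupported resize: drop the resized ID and everything to its left.
    frontier.erase(frontier.begin(), it + 1);
  }
}
```

Starting from `{i0, i1, i2}` and resizing `i1`, a pad-originated resize leaves `{i1_padded, i2}`, whereas any other resize leaves only `{i2}`.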
10 changes: 8 additions & 2 deletions csrc/scheduler/vectorize_helper.h
@@ -31,7 +31,7 @@ namespace vectorize_helper {

// Projects IterDomains through the fusion starting at provided reference. IDs
// in the reference are expected to be "contiguous", simply means dimensions
// that the iter domains are consecutive and next to eachother in the
// that the iter domains are consecutive and next to each other in the
// reference. This property is not enforced, but mapping can have some
// unpredictable results if they are not. The reason we want contiguity here
// is this class is primarily used for vectorization analysis. Domains may be
@@ -78,7 +78,7 @@ namespace vectorize_helper {
// tv1[2*3, 5, 7*11] = view(tv0)
// with tv1 and [2*3, 7*11] as the reference and ids. tv0's 2 and 11 dim are
// easily identified as being mapped. The 3*5*7 dimension however, is
// partially mapped on the left and right side. Since this class is intended to
// partially mapped on the left and right side. Since this class is intended to
// line up "inner dimensions" of tensors through out the graph for the purpose
// of unrolling and vectorization, it only tracks partial dimensions as they are
// on the right hand side of iteration domains. For example in the last case we
@@ -289,6 +289,9 @@ class NVF_API ContiguousInnerDimensionsMapper
void propagateP2C(TensorView* from, TensorView* to) final;
void propagateSibling(TensorView* from, TensorView* to) final;

// traverse fusion to mark the origin of Resize
void initializeResizeInfo(Fusion* fusion);

// Initialized to false, series of compute... calls will be performed to find
// the spanning tree. Then propagate... calls will call the compute... calls.
// recording_ starts as false, and stays that way during the first series of
@@ -308,6 +311,9 @@ class NVF_API ContiguousInnerDimensionsMapper
tv_infos_;

std::unordered_map<IterDomain*, Val*> projected_extent_;

//! stores all Resize* op that's added from PadOp*
std::unordered_set<Resize*> resize_in_pad_;
};

// logical_reorder_map is provided to assume reference_tv will be reordered per
43 changes: 33 additions & 10 deletions csrc/tensor_view.cpp
@@ -1168,15 +1168,37 @@ TensorView* TensorView::cacheFork() {
TensorView* TensorView::cacheAfter(
LoadStoreOpType op_type,
CacheOp cache_op,
bool propagate_allocation_domain) {
bool propagate_allocation_domain,
std::vector<Expr*> cached_uses) {
NVF_ERROR(
!container()->isA<kir::Kernel>(),
"Function invalid for kernel container.");
FusionGuard fg(fusion());

if (!cached_uses.empty()) {
std::unordered_set<Expr*> unique_uses = fusion()->unordered_uses(this);
for (auto use : cached_uses) {
NVF_ERROR(
unique_uses.count(use),
"cached_uses is not among the use of the TensorView");
}
} else {
// avoid non-determinism and ensure unique
std::unordered_set<Expr*> unique_uses;
auto this_uses = uses();
Review comment (Collaborator):
Is it possible to have duplicates?

Reply (jjsjann123, Collaborator, Author, Nov 12, 2024):
I don't think so, since uses_ is private and addUse skips duplicates:

  bool Val::addUse(Expr* expr) {
    if (std::find(uses_.begin(), uses_.end(), expr) == uses_.end()) {
      uses_.push_back(expr);
      return true;
    }
    return false;
  }

But I'm a bit wary of leaving it unchecked, since it's only a std::vector and the implementation is free to change that.

Reply (Collaborator):
If so, let's make it an error if a duplicate is found.

Reply (Collaborator, Author):
Switched. I'll kick off the CI to see if there are any surprises. 🤞

cached_uses.reserve(this_uses.size());
for (Expr* use : this_uses) {
NVF_ERROR(
unique_uses.count(use) == 0,
"detect duplicated entries in TensorView::uses()");
cached_uses.push_back(use);
unique_uses.insert(use);
}
}

// Get all the uses for this Tensorview
NVF_CHECK(
!uses().empty(),
!cached_uses.empty(),
"Error adding cacheAfter ",
this,
" we restrict using cacheAfter on tensors that have no further uses.");
@@ -1187,18 +1209,19 @@ TensorView* TensorView::cacheAfter(
!hasComputeAt(),
"Caching computed-at tensors is not allowed. Apply caching before computeAt.");

bool is_allowed_op =
!ir_utils::isTvUsedByOpsOfType<SliceOp, SelectOp, PadOp>(this) &&
!ir_utils::isIndexSelectLookupTv(this);
NVF_CHECK(
is_allowed_op,
"Right now, caching tensors that are input to the select/slice/pad ops are not allowed as they must be in global memory.")
// disallow cache on operation where we require data remain in global memory.
for (auto use : cached_uses) {
NVF_ERROR(
!(use->isOneOf<SliceOp, SelectOp, PadOp>()) &&
!(use->isA<IndexSelectOp>() && use->input(0) == this),
"Right now, caching tensors that are input to the select/slice/pad ops are not allowed as they must be in global memory.");
}

// It also did additional transformation when this tensor is an
// input and the outputs of its consumers have computeAt. Make sure
// we no longer rely on that behavior.
if (isFusionInput()) {
for (const auto& expr : uses()) {
for (const auto& expr : cached_uses) {
for (TensorView* output :
ir_utils::filterByType<TensorView>(expr->outputs())) {
NVF_CHECK(
@@ -1241,7 +1264,7 @@ TensorView* TensorView::cacheAfter(
// After: This TV -> [Set Op] -> New CA TV -> [Use Op] -> Next TV

// Expr* consumer_uses =
for (auto expr : fusion()->unordered_uses(this)) {
for (auto expr : cached_uses) {
ir_utils::replaceValInExprInputs(expr, this, consumer);
}
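
The way cacheAfter resolves its `cached_uses` argument in this diff (validate a caller-supplied subset, or fall back to all uses while rejecting duplicates) can be sketched without nvFuser types. In this hypothetical example uses are modeled as integers, `resolveCachedUses` is an invented name, and standard exceptions stand in for `NVF_ERROR`.

```cpp
#include <stdexcept>
#include <unordered_set>
#include <vector>

// all_uses plays the role of TensorView::uses(); requested plays the role of
// the cached_uses argument. Returns the uses that would be rewired to the
// cached tensor.
std::vector<int> resolveCachedUses(
    const std::vector<int>& all_uses,
    std::vector<int> requested) {
  if (!requested.empty()) {
    // Every requested use must be an actual use of the tensor.
    std::unordered_set<int> known(all_uses.begin(), all_uses.end());
    for (int use : requested) {
      if (known.count(use) == 0) {
        throw std::invalid_argument("requested use is not a use of the tensor");
      }
    }
    return requested;
  }
  // No subset given: cache every use, keeping the original order and
  // treating a duplicate entry as an error.
  std::unordered_set<int> seen;
  requested.reserve(all_uses.size());
  for (int use : all_uses) {
    if (!seen.insert(use).second) {
      throw std::logic_error("duplicate entry in uses");
    }
    requested.push_back(use);
  }
  return requested;
}
```

Iterating over the ordered vector rather than an unordered set keeps the rewiring order deterministic, which appears to be the motivation behind the "move from unordered_set to vector" commit.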
