Skip to content

Commit

Permalink
[Bugfix] Work around "Internal Compiler Error" in MSVC (#271)
Browse files Browse the repository at this point in the history
This PR adds a workaround for an MSVC "Internal Compiler Error" that prevents
the nightly Windows PyPI package of TVM Unity from being released. Basically,
it splits a deeply nested method into several smaller ones so that MSVC's
internal handling is slightly simpler and less error-prone.
  • Loading branch information
junrushao authored Jul 22, 2023
1 parent 2031026 commit dc745fc
Showing 1 changed file with 84 additions and 78 deletions.
162 changes: 84 additions & 78 deletions src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,86 +495,29 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
access_annotations_;
};

/*! \brief Storage alignment requirement attached to a single buffer dimension. */
struct DimAlignInfo {
  /*! \brief Alignment factor; defaults to 0. */
  int align_factor = 0;
  /*! \brief Alignment offset; defaults to 0. */
  int align_offset = 0;
};

/*! \brief Per-buffer bookkeeping gathered before the buffer is reallocated with a compact shape. */
struct BufferAllocInfo {
/*! \brief The buffer access region. */
Region region;
/*!
* \brief The storage alignment information, one entry per buffer dimension.
* \note Left empty when no storage alignment annotation was found for the buffer.
*/
std::vector<DimAlignInfo> dim_aligns;
/*!
* \brief The reallocated buffer with minimal size.
* \note The value is NullOpt if the buffer does not need to be reallocated (e.g. a parameter
*       buffer). NOTE(review): the field is a plain Buffer rather than an Optional; confirm
*       how the "not reallocated" case is actually represented.
*/
Buffer new_buffer;
};

/*! \brief Reallocate the buffers with minimal region. */
class BufferCompactor : public StmtExprMutator {
public:
/*!
* \brief Build the allocation info for every buffer in `regions` and rewrite the function body.
*
* For each (buffer, region) pair a copy of the buffer node is made whose shape is the
* extents of the region and whose strides honor the storage alignment annotation, if any.
* The resulting table, keyed by the buffer's data var, is handed to a BufferCompactor
* which is then applied to `f->body`.
*
* \note This is the pre-change form shown in the diff. According to the commit message, this
*       deeply nested method triggers an MSVC "Internal Compiler Error"; the same logic is
*       split into the free functions CalcStrides and BufferCompactorCompact.
*
* \param f The function whose body is rewritten.
* \param regions The access region of each buffer.
* \param storage_align The storage alignment annotation, looked up by the buffer's data var.
* \return The statement produced by applying the compactor to `f->body`.
*/
static Stmt Compact(
const PrimFunc& f,
const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& regions,
const std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>&
storage_align) {
// collect buffer allocation info for no-alias buffers
std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info;
for (const auto& kv : regions) {
const Buffer& buffer = kv.first;

// set dim alignment info
Region region = kv.second;
BufferAllocInfo alloc_info;
auto it = storage_align.find(buffer->data);
if (it != storage_align.end()) {
// One slot per dimension of the original buffer; dimensions without an annotation
// keep the default (factor 0, offset 0).
std::vector<DimAlignInfo> dim_aligns(buffer->shape.size());
for (const StorageAlignTuple& dim_align : (*it).second) {
ICHECK(dim_align.size() == 4);
// Elements 1..3 of the tuple are read as (dim, factor, offset); element 0 is not used here.
int dim = dim_align[1]->value;
int factor = dim_align[2]->value;
int offset = dim_align[3]->value;
// .at() throws std::out_of_range if `dim` is not a valid dimension index.
dim_aligns.at(dim) = {factor, offset};
}
alloc_info.dim_aligns = std::move(dim_aligns);
}

// prepare new buffer
// The compact shape is the extent of the access region along each dimension.
Array<PrimExpr> shape = region.Map([](const Range& range) { return range->extent; });
Array<PrimExpr> strides;
// Strides are only materialized when alignment info exists; otherwise they stay empty.
if (alloc_info.dim_aligns.size()) {
ICHECK(alloc_info.dim_aligns.size() == shape.size());
strides.resize(shape.size());
PrimExpr stride = make_const(shape[0].dtype(), 1);
// Walk from the last dimension to the first, accumulating the running stride.
for (size_t i = shape.size(); i != 0; --i) {
size_t dim = i - 1;
if (alloc_info.dim_aligns[dim].align_factor != 0) {
PrimExpr factor = make_const(stride.dtype(), alloc_info.dim_aligns[dim].align_factor);
PrimExpr offset = make_const(stride.dtype(), alloc_info.dim_aligns[dim].align_offset);
// Pad the stride so that it becomes congruent to `offset` modulo `factor`
// (assuming indexmod is an ordinary modulo -- confirm against its definition).
stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
}
strides.Set(dim, stride);
stride = stride * shape[dim];
}
}
// Copy the original buffer node, then override only its shape and strides.
ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
n->shape = std::move(shape);
n->strides = std::move(strides);
alloc_info.new_buffer = Buffer(std::move(n));
alloc_info.region = region;
buffer_info.emplace(buffer->data, std::move(alloc_info));
}
BufferCompactor compactor(std::move(buffer_info));
Stmt stmt = compactor(f->body);
return stmt;
}

private:
/*! \brief Storage alignment requirement attached to a single buffer dimension. */
struct DimAlignInfo {
  /*! \brief Alignment factor; defaults to 0. */
  int align_factor = 0;
  /*! \brief Alignment offset; defaults to 0. */
  int align_offset = 0;
};

/*! \brief Per-buffer bookkeeping gathered before the buffer is reallocated with a compact shape. */
struct BufferAllocInfo {
/*! \brief The buffer access region. */
Region region;
/*!
* \brief The storage alignment information, one entry per buffer dimension.
* \note Left empty when no storage alignment annotation was found for the buffer.
*/
std::vector<DimAlignInfo> dim_aligns;
/*!
* \brief The reallocated buffer with minimal size.
* \note The value is NullOpt if the buffer does not need to be reallocated (e.g. a parameter
*       buffer). NOTE(review): the field is a plain Buffer rather than an Optional; confirm
*       how the "not reallocated" case is actually represented.
*/
Buffer new_buffer;
};

/*!
* \brief Construct a compactor from a precomputed allocation table.
* \param buffer_info Allocation info keyed by Var; taken by value and moved into buffer_info_.
*/
explicit BufferCompactor(
std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info)
: buffer_info_(std::move(buffer_info)) {}
Expand Down Expand Up @@ -709,13 +652,76 @@ class BufferCompactor : public StmtExprMutator {
std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info_;
};

/*!
 * \brief Compute the strides of a compacted buffer, honoring per-dimension storage alignment.
 *
 * Dimensions are visited from last to first while a running stride is accumulated. For a
 * dimension whose align_factor is non-zero, the running stride is first padded so that it
 * becomes congruent to align_offset modulo align_factor (assuming indexmod is an ordinary
 * modulo), then recorded, then multiplied by that dimension's extent.
 *
 * \note Kept as a small standalone function: per the commit message, splitting the formerly
 *       nested method into smaller ones works around an MSVC internal compiler error.
 *
 * \param alloc_info Allocation info whose dim_aligns drive the padding.
 * \param shape The compacted shape; must match dim_aligns in length when dim_aligns is non-empty.
 * \return One stride per dimension, or an empty array when there is no alignment info.
 */
Array<PrimExpr> CalcStrides(const BufferAllocInfo& alloc_info, const Array<PrimExpr>& shape) {
  std::vector<PrimExpr> strides;
  if (alloc_info.dim_aligns.empty()) {
    // Nothing to align: leave the strides empty.
    return strides;
  }

  ICHECK(alloc_info.dim_aligns.size() == shape.size());
  strides.resize(shape.size());

  PrimExpr running = make_const(shape[0].dtype(), 1);
  for (size_t dim = shape.size(); dim-- > 0;) {
    const DimAlignInfo align = alloc_info.dim_aligns[dim];
    if (align.align_factor != 0) {
      const PrimExpr factor = make_const(running.dtype(), align.align_factor);
      const PrimExpr offset = make_const(running.dtype(), align.align_offset);
      const PrimExpr padding = indexmod(factor + offset - indexmod(running, factor), factor);
      running = running + padding;
    }
    strides[dim] = running;
    running = running * shape[dim];
  }
  return strides;
}

/*!
 * \brief Reallocate every buffer in `regions` with its compact shape and rewrite the body of `f`.
 *
 * For each (buffer, region) pair this gathers the storage alignment annotation (if one is
 * registered for the buffer's data var), derives the new shape from the region extents and
 * the strides via CalcStrides, and stores a copy of the buffer node carrying that shape and
 * those strides. The table, keyed by the buffer's data var, is then given to a
 * BufferCompactor which is applied to `f->body`.
 *
 * \param f The function whose body is rewritten.
 * \param regions The access region of each buffer.
 * \param storage_align The storage alignment annotation, looked up by the buffer's data var.
 * \return The statement produced by applying the compactor to `f->body`.
 */
Stmt BufferCompactorCompact(
    const PrimFunc& f,
    const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& regions,
    const std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>&
        storage_align) {
  // Allocation info for no-alias buffers, keyed by each buffer's data var.
  std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> alloc_table;

  for (const auto& entry : regions) {
    const Buffer& old_buffer = entry.first;
    Region access_region = entry.second;

    BufferAllocInfo info;
    info.region = access_region;

    // Per-dimension alignment, only when an annotation exists for this buffer.
    auto align_it = storage_align.find(old_buffer->data);
    if (align_it != storage_align.end()) {
      std::vector<DimAlignInfo> per_dim(old_buffer->shape.size());
      for (const StorageAlignTuple& dim_align : align_it->second) {
        ICHECK(dim_align.size() == 4);
        // Elements 1..3 are read as (dim, factor, offset); element 0 is not used here.
        int axis = dim_align[1]->value;
        int align_factor = dim_align[2]->value;
        int align_offset = dim_align[3]->value;
        per_dim.at(axis) = DimAlignInfo{align_factor, align_offset};
      }
      info.dim_aligns = std::move(per_dim);
    }

    // The compact shape is the region extent along each dimension.
    Array<PrimExpr> new_shape =
        access_region.Map([](const Range& range) { return range->extent; });
    Array<PrimExpr> new_strides = CalcStrides(info, new_shape);

    // Copy the original buffer node, overriding only its shape and strides.
    ObjectPtr<BufferNode> node = make_object<BufferNode>(*old_buffer.get());
    node->shape = std::move(new_shape);
    node->strides = std::move(new_strides);
    info.new_buffer = Buffer(std::move(node));

    alloc_table.emplace(old_buffer->data, std::move(info));
  }

  BufferCompactor compactor(std::move(alloc_table));
  return compactor(f->body);
}

PrimFunc CompactBufferAllocation(PrimFunc f, bool is_strict) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
auto region = BufferAccessRegionCollector::Collect(f, /*collect_inbound=*/is_strict);
auto storage_align = CollectStorageAlignAnnotation(f->body);
fptr->body = BufferCompactor::Compact(f, region, storage_align);
fptr->body = BufferCompactorCompact(f, region, storage_align);
return f;
} else {
return f;
Expand Down

0 comments on commit dc745fc

Please sign in to comment.