Skip to content

Commit

Permalink
[luci/pass] Update TaggedShapeAnaylzer methods interface (#14150)
Browse files Browse the repository at this point in the history
This PR updates TaggedShapeAnaylzer methods to get vector instead of CircleNode.

ONE-DCO-1.0-Signed-off-by: seunghui youn <sseung.youn@samsung.com>
  • Loading branch information
zetwhite authored Oct 6, 2024
1 parent 75df9a0 commit 72aea5c
Showing 1 changed file with 45 additions and 60 deletions.
105 changes: 45 additions & 60 deletions compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ bool extract_const(const luci::CircleConst *const_node, std::vector<int32_t> &va

struct TagDim final
{
TagDim(int32_t v) : value(v) {}

int32_t value;
std::vector<uint8_t> tags;
};
Expand All @@ -88,16 +90,12 @@ class TaggedShapeAnalyzer final
public:
bool init(const luci::CircleTranspose *, const luci::CircleReshape *,
const luci::CircleTranspose *);

template <loco::DataType DType> bool can_remove_transposes();
bool can_remove_transposes();

private:
void init_shape_with_tag(const luci::CircleNode *);

template <loco::DataType PermType> void analyze_transpose(const luci::CircleTranspose *);

bool analyze_reshape(const luci::CircleReshape *);

void init_shape_with_tag(const std::vector<int32_t> &);
void analyze_transpose(const std::vector<int32_t> &);
bool analyze_reshape(const std::vector<int32_t> &);
bool verify_tag() const;

private:
Expand All @@ -116,31 +114,30 @@ class TaggedShapeAnalyzer final
};

/**
* @brief initalize _shape with input tensor named in_tensor
* @brief initalize _shape with input tensor named 'in_shape'
*
* @note 'tags' are attached to non-1 valued dimension.
*/
void TaggedShapeAnalyzer::init_shape_with_tag(const luci::CircleNode *in_tensor)
void TaggedShapeAnalyzer::init_shape_with_tag(const std::vector<int32_t> &in_shape)
{
_shape.clear();
uint8_t tag = START_TAG;

for (uint32_t i = 0; i < in_tensor->rank(); i++)
for (auto i : in_shape)
{
TagDim dim;
{
dim.value = in_tensor->dim(i).value();
if (dim.value != 1)
dim.tags.push_back(tag++);
}
TagDim dim(i);

if (dim.value != 1)
dim.tags.push_back(tag++);

_shape.push_back(dim);
}
}

/**
* @brief update _shape based on 'Transpose' permutation value
* @brief update _shape based on 'perm'
*
* @example Let's assume Transpose(perm=0, 3, 1, 2) is given to [before] _shape.
* @example Let's assume perm={0, 3, 1, 2} is given to [before] _shape.
*
* This function reordered the Dims' order based on permutaiton value.
*
Expand All @@ -152,28 +149,24 @@ void TaggedShapeAnalyzer::init_shape_with_tag(const luci::CircleNode *in_tensor)
* - value : (1) (448) (7) (7)
* - tags : (-) (2) (0) (1)
*/
template <loco::DataType PermType>
void TaggedShapeAnalyzer::analyze_transpose(const luci::CircleTranspose *transpose_node)
void TaggedShapeAnalyzer::analyze_transpose(const std::vector<int32_t> &perm)
{
const luci::CircleConst *perm_node = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
assert(perm_node->dtype() == PermType);
assert(_shape.size() == perm.size());

TagShape new_shape;
const auto size = perm_node->size<PermType>();
for (uint32_t i = 0; i < size; i++)
for (auto i : perm)
{
auto perm_idx = perm_node->at<PermType>(i);
new_shape.push_back(_shape.at(perm_idx));
new_shape.push_back(_shape.at(i));
}
_shape = new_shape;
}

/**
* @brief update _shape based on 'Reshape' shape value
* @brief update _shape based on 'new_shape'
*
* @return False, if it determined that removing transposes is impossible
* @return False, if it fails to update _shape
*
* @example Let's assume Reshape(shape=1, 448, 49) is given to [before] _shape.
* @example Let's assume new_shape={1, 448, 49} is given to [before] _shape.
*
* [before] _shape :
* - value : (1) (448) (7) (7)
Expand All @@ -183,24 +176,11 @@ void TaggedShapeAnalyzer::analyze_transpose(const luci::CircleTranspose *transpo
* - value : (1) (448) (49)
* - tags : (-) (2) (0, 1)
*/
bool TaggedShapeAnalyzer::analyze_reshape(const luci::CircleReshape *reshape_node)
bool TaggedShapeAnalyzer::analyze_reshape(const std::vector<int32_t> &new_shape)
{
// At least one element must be in reshape's output-tensor.
if (reshape_node->rank() <= 0)
if (new_shape.size() <= 0)
return false;

// Create new_shape based on reshape_node
TagShape new_shape;
for (uint32_t i = 0; i < reshape_node->rank(); i++)
{
if (not reshape_node->dim(i).known())
return false;

TagDim dim;
dim.value = reshape_node->dim(i).value();
new_shape.push_back(dim);
}

// indexing for _shape [old_shape_start_idx, old_shape_end_idx)
uint32_t old_shape_start_idx = 0;
uint32_t old_shape_end_idx = 1;
Expand All @@ -225,39 +205,45 @@ bool TaggedShapeAnalyzer::analyze_reshape(const luci::CircleReshape *reshape_nod
return true;
};

// Add tags from '_shape' to the 'new_shape'
// Create 'new_tagged_shape' based on 'new_shape'
TagShape new_tagged_shape;

uint32_t new_shape_idx = 0;
while (new_shape_idx < new_shape.size())
{
TagDim &target_dim = new_shape[new_shape_idx];
auto target_dim = new_shape[new_shape_idx];

// Ignore dim == 1
if (target_dim.value == 1)
if (target_dim == 1)
{
new_tagged_shape.emplace_back(1);
new_shape_idx++;
continue;
}

while (old_shape_product < target_dim.value)
while (old_shape_product < target_dim)
{
if (expand_range() == false)
break;
}

if (old_shape_product != target_dim.value)
if (old_shape_product != target_dim)
return false;

assert(old_shape_product == target_dim.value);
assert(old_shape_product == target_dim);

TagDim dim(target_dim);
for (uint32_t idx = old_shape_start_idx; idx < old_shape_end_idx; idx++)
{
const auto &old_tags = _shape[idx].tags;
target_dim.tags.insert(target_dim.tags.end(), old_tags.begin(), old_tags.end());
dim.tags.insert(dim.tags.end(), old_tags.begin(), old_tags.end());
}
new_tagged_shape.push_back(dim);

new_shape_idx++;
move_to_next_range();
}
_shape = new_shape;
_shape = new_tagged_shape;
return true;
}

Expand Down Expand Up @@ -369,20 +355,19 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose,
* Transpose has no effect in final shape, which they can be removed as
* unnecessary Ops.
*/
template <loco::DataType DType> bool TaggedShapeAnalyzer::can_remove_transposes()
bool TaggedShapeAnalyzer::can_remove_transposes()
{
assert(_in != nullptr && _front_transpose != nullptr && _mid_reshape != nullptr &&
_back_transpose != nullptr);

// TODO: Update under methods to use std::vector<int32_t&> intead of CircleNode*
init_shape_with_tag(_in);
init_shape_with_tag(_in_shape_v);

analyze_transpose<DType>(_front_transpose);
analyze_transpose(_front_perm_v);

if (not analyze_reshape(_mid_reshape))
if (not analyze_reshape(_mid_shape_v))
return false;

analyze_transpose<DType>(_back_transpose);
analyze_transpose(_back_perm_v);

if (not verify_tag())
return false;
Expand Down Expand Up @@ -454,7 +439,7 @@ bool remove_unnecessary_transpose(luci::CircleTranspose *node)
if (not analyzer.init(front_transpose, mid_reshape, back_transpose))
return false;

if (not analyzer.can_remove_transposes<loco::DataType::S32>())
if (not analyzer.can_remove_transposes())
return false;

// repalce with new_node
Expand Down

0 comments on commit 72aea5c

Please sign in to comment.