Skip to content

Commit

Permalink
update comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hjjq committed Sep 6, 2023
1 parent a7c1481 commit 6fbae2c
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions python/hidet/graph/transforms/conv_channel_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class PermutedOp:
"""

# Some imports will cause circular import errors if imported at top level of class
# Therefore we import those inside PermutedOp.get_scroped_ops() instead
# Therefore we import those inside PermutedOp.initialize_scoped_ops() instead.
from hidet.graph.ops.arithmetic import UnaryElementwiseOp, BinaryElementwiseOp
from hidet.graph.ops.transform import ConcatOp

Expand Down Expand Up @@ -73,8 +73,8 @@ def initialize_scoped_ops():
"""
The Operators imported within this function will cause circular imports if imported
at the top level of the class. As a temporary solution, we import them here and add them
to the class variable sets. This function will be called when the pass is instantiated.
TODO: Restructure the code to resolve the circular imports.
to the class variable tuples. This function will be called when the pass is instantiated.
TODO: Restructure the code to resolve the circular imports, and remove this function.
"""
from hidet.graph.ops.conv2d import Conv2dOp
from hidet.graph.ops.pool import MaxPool2dOp, AvgPool2dOp, AdaptiveAvgPool2dOp
Expand All @@ -99,7 +99,7 @@ def get_scoped_ops():
2. P is the set of attributes of Op; x_chnlast = x.transpose[0, 2, 3, 1];
and there exists another set of attributes P', such that
Op(x_chnlast, P')).transpose([0, 3, 1, 2]) == Op(x, P)
This list is incomplete. To support a new operator in the scope, implement its
TODO: This list is incomplete. To support a new operator in the scope, implement its
corresponding PermutedOp or use an existing one and update `get_permuted_op()`.
"""
return PermutedOp.scoped_operators
Expand Down Expand Up @@ -184,6 +184,7 @@ class ConvChannelLastPass(GraphPass):
1. Find a list of operators within scope (see `PermutedOp.get_scoped_op()`)
2. For operators within scope, transform its input to channel last and record this permutation.
3. For operators out of scope, convert its input back to channel first based on recorded permuations.
TODO: Not all CNNs benefit from this pass. Find a heuristic to enable/disable this pass for best performance.
"""

def span_from_nodes(self, usage: Dict[Tensor, List[Tuple[Operator, int]]], seeds: List[Operator]) -> List[Operator]:
Expand Down

0 comments on commit 6fbae2c

Please sign in to comment.