diff --git a/python/hidet/graph/transforms/conv_channel_last.py b/python/hidet/graph/transforms/conv_channel_last.py index 161f23a0b..6e958790d 100644 --- a/python/hidet/graph/transforms/conv_channel_last.py +++ b/python/hidet/graph/transforms/conv_channel_last.py @@ -26,7 +26,7 @@ class PermutedOp: """ # Some imports will cause circular import errors if imported at top level of class - # Therefore we import those inside PermutedOp.get_scroped_ops() instead + # Therefore we import those inside PermutedOp.initialize_scoped_ops() instead. from hidet.graph.ops.arithmetic import UnaryElementwiseOp, BinaryElementwiseOp from hidet.graph.ops.transform import ConcatOp @@ -73,8 +73,8 @@ def initialize_scoped_ops(): """ The Operators imported within this function will cause circular imports if imported at the top level of the class. As a temporary solution, we import them here and add them - to the class variable sets. This function will be called when the pass is instantiated. - TODO: Restructure the code to resolve the circular imports. + to the class variable tuples. This function will be called when the pass is instantiated. + TODO: Restructure the code to resolve the circular imports, and remove this function. """ from hidet.graph.ops.conv2d import Conv2dOp from hidet.graph.ops.pool import MaxPool2dOp, AvgPool2dOp, AdaptiveAvgPool2dOp @@ -99,7 +99,7 @@ def get_scoped_ops(): 2. P is the set of attributes of Op; x_chnlast = x.transpose[0, 2, 3, 1]; and there exists another set of attributes P', such that Op(x_chnlast, P')).transpose([0, 3, 1, 2]) == Op(x, P) - This list is incomplete. To support a new operator in the scope, implement its + TODO: This list is incomplete. To support a new operator in the scope, implement its corresponding PermutedOp or use an existing one and update `get_permuted_op()`. """ return PermutedOp.scoped_operators @@ -184,6 +184,7 @@ class ConvChannelLastPass(GraphPass): 1. 
Find a list of operators within scope (see `PermutedOp.get_scoped_ops()`) 2. For operators within scope, transform its input to channel last and record this permutation. 3. For operators out of scope, convert its input back to channel first based on recorded permutations. + TODO: Not all CNNs benefit from this pass. Find a heuristic to enable/disable this pass for best performance. """ def span_from_nodes(self, usage: Dict[Tensor, List[Tuple[Operator, int]]], seeds: List[Operator]) -> List[Operator]: