Model Inspection: Computing Model FLOPs #3830
-
I am currently trying to evaluate the FLOPs of a model architecture that uses token compression methods such as pruning and merging. My model architecture repo can be found here, the token compression attention module here, and the token merging/pruning methods here. Is anyone else using a similar setup? I get errors in the attention layers that contain merging/pruning functions. In particular, my pruning method selects the top_k token embeddings based on an importance score, and when computing FLOPs I get empty arrays when trying to compute the indices of the top_k token embeddings, as detailed in the trace below:
Any advice from the Flax dev community is appreciated.

Update: I am able to evaluate the FLOPs of my CompressedMultiHeadDotProductAttention just fine from a Colab notebook where I inspect only this layer and not the overall model architecture. I am still debugging potential causes of the error; I thought it may be due to my module setup for the overall architecture, but this is still TBC.
Replies: 1 comment
-
#1854 resolves my issue for now. When required, I can also compute layer-specific FLOPs with this API.