Convert for loop stacking, potentially with remat_scan #3336

Answered by chiamp
mahdikooshkbaghi asked this question in Q&A

From my understanding, scan only works when the same module, with the same configuration, is applied at every step; i.e. you want to apply the same module/layer iteratively over an input. In the documentation, it says "you can do this [scan] when you have a sequence of identical layers that you want to apply iteratively to an input".
In your example, since you are using MLP layers with different configurations, their parameters have different shapes and can't be stacked along a single layer axis, so I don't think there's a more efficient implementation that can be written. With remat, you are reducing memory usage at the cost of more computation: intermediate activations are recomputed during the backward pass instead of being stored.
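
To illustrate the "identical layers" requirement, here is a minimal sketch in plain JAX rather than Flax's own `nn.scan` / `nn.remat_scan` wrappers. The helper names (`layer`, `init_stack`, `forward_loop`, `forward_scan`) and the sizes are made up for the example and are not from the original question. It assumes every layer has the same width, which is what lets the weights be stacked along a leading axis and scanned over.

```python
import jax
import jax.numpy as jnp


def layer(x, params):
    # One dense layer followed by tanh (hypothetical layer for illustration).
    w, b = params
    return jnp.tanh(x @ w + b)


def init_stack(key, num_layers, features):
    # Identical layers: every weight has shape (features, features), so all
    # layers can be stacked along a leading "layer" axis.
    w = jax.random.normal(key, (num_layers, features, features)) * 0.1
    b = jnp.zeros((num_layers, features))
    return w, b


def forward_loop(x, stacked):
    # Plain Python for loop over the layers, for comparison.
    w, b = stacked
    for i in range(w.shape[0]):
        x = layer(x, (w[i], b[i]))
    return x


def forward_scan(x, stacked):
    # jax.checkpoint (remat) makes the backward pass recompute each layer's
    # activations instead of storing them: less memory, more compute.
    body = jax.checkpoint(layer)

    def step(carry, params):
        # lax.scan passes one slice (w[i], b[i]) of the stacked params per step.
        return body(carry, params), None

    out, _ = jax.lax.scan(step, x, stacked)
    return out


key = jax.random.PRNGKey(0)
stacked = init_stack(key, num_layers=4, features=8)
x = jnp.ones((2, 8))

y_loop = forward_loop(x, stacked)
y_scan = forward_scan(x, stacked)

# Gradients flow through the scanned, rematerialized layers as usual.
grads = jax.grad(lambda p: forward_scan(x, p).sum())(stacked)
```

Both versions compute the same output here. If the layers had different widths, the weight arrays could not be stacked into a single `(num_layers, ...)` array, and the scan version would not apply; wrapping each layer individually with remat in the for loop would still be possible.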

Answer selected by mahdikooshkbaghi