How do I prune a network when using JAX? #5669
Unanswered
qkrclrl701 asked this question in Q&A
Replies: 1 comment
-
I found that the parameters are a list of tuples whose elements are <class 'jaxlib.xla_extension.Buffer'> or <class 'jax.interpreters.xla._DeviceArray'>. Both Buffer and DeviceArray can be handled the same way as np.ndarray, so pruning is easy to implement with them.
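To illustrate the point about treating these arrays like np.ndarray, here is a minimal sketch. The matrix `W` is a hypothetical stand-in for one Dense layer's weights, and the 50% cut-off is an arbitrary choice for the example; neither comes from the original post. In current JAX releases the array types named above are exposed as `jax.Array`, but the same NumPy-style handling applies.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical weight matrix standing in for one Dense layer's W.
W = jax.random.normal(jax.random.PRNGKey(0), (4, 3))

# A JAX array converts to a plain NumPy array, so NumPy functions apply.
W_np = np.asarray(W)
threshold = np.percentile(np.abs(W_np), 50)  # median magnitude (arbitrary cut-off)

# jax.numpy mirrors the NumPy API, so masking is one vectorized call
# rather than a Python loop over entries.
W_pruned = jnp.where(jnp.abs(W) > threshold, W, 0.0)

# Copy the result back to the host as an np.ndarray if needed.
W_host = np.asarray(W_pruned)
```

The threshold here is computed per matrix; whether to prune per layer or across the whole network is a separate design choice.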
-
I want to prune the network:
    stax.serial(
        stax.Dense(512, W_std=1, b_std=0.05), stax.Relu(),
        stax.Dense(512, W_std=1, b_std=0.05), stax.Relu(),
        stax.Dense(10, W_std=1, b_std=0.05)
    )
I have to build my network from stax building blocks because I use Neural Tangents.
I can access the network's parameters, but they are a list of <class 'jaxlib.xla_extension.Buffer'> objects, and I don't know an efficient way to work with that data structure.
I tried creating a mask with the same size as each parameter, going through all the parameters to flatten the live neurons and find the percentile value, and then going through the live neurons again to prune those whose value is below that percentile.
However, simply iterating over all the parameters is too slow, even for this relatively small network.
Is there a way to speed up the pruning, or does JAX provide any functions that support pruning?
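The mask-and-percentile procedure described above can be sketched without Python-level loops by flattening the whole parameter structure into one vector. This is only a sketch under stated assumptions: `magnitude_masks` and `apply_masks` are hypothetical helper names, the toy `params` merely imitate a list-of-tuples layout rather than coming from a real stax `init_fn`, and the threshold is taken on absolute values (magnitude pruning) over weights and biases alike, which may differ from what the question intended.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def magnitude_masks(params, sparsity):
    """Hypothetical helper: 0/1 masks dropping the smallest-|w| fraction globally."""
    flat, unravel = ravel_pytree(params)        # one 1-D vector holding every leaf
    magnitudes = jnp.abs(flat)
    threshold = jnp.quantile(magnitudes, sparsity)
    mask_flat = (magnitudes > threshold).astype(flat.dtype)
    return unravel(mask_flat)                   # same nesting as params


def apply_masks(params, masks):
    """Hypothetical helper: zero out pruned entries, leaf by leaf."""
    return jax.tree_util.tree_map(lambda p, m: p * m, params, masks)


# Toy parameters (an assumption for this example): one (W, b) tuple per Dense
# layer and an empty tuple for a parameter-free Relu layer.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = [
    (jax.random.normal(k1, (8, 4)), jax.random.normal(k2, (4,))),
    (),
    (jax.random.normal(k3, (4, 2)), jax.random.normal(k4, (2,))),
]

masks = magnitude_masks(params, sparsity=0.5)   # prune roughly half the entries
pruned = apply_masks(params, masks)
```

Because the threshold comes from a single `jnp.quantile` call on the flattened vector, the cost is dominated by a few vectorized array operations. For iterative pruning, one possible extension is to compute the quantile only over entries that are still live in the previous mask.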