As written, you will have to re-compile your … A better pattern would be to avoid keeping a list of train states and instead keep a single batched train state that you index into with a dynamic agent index. In other words, use a struct-of-arrays pattern for your train states rather than the current list-of-structs pattern. A further benefit of this would be that you could replace the … Does that make sense?
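A minimal sketch of the batched (struct-of-arrays) idea, assuming every agent's state is a pytree of arrays with identical shapes across agents. A plain dict of arrays stands in for flax's `TrainState`, and `init_agent`, `stack_states`, `get_agent`, `set_agent`, `update_agent`, `LR` and the toy objective are hypothetical names for illustration, not code from the gist. The agent index is an ordinary traced argument, and a Python counter records how often JAX traces the function:

```python
import jax
import jax.numpy as jnp

LR = 0.1  # hypothetical learning rate

def init_agent(seed):
    """Hypothetical per-agent state: a dict of arrays standing in for a TrainState."""
    key = jax.random.PRNGKey(seed)
    return {"w": jax.random.normal(key, (4,)), "step": jnp.zeros((), dtype=jnp.int32)}

def stack_states(states):
    """List-of-structs -> struct-of-arrays: stack every leaf along a new axis 0."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *states)

def get_agent(batched, agent):
    """Read one agent's slice; works when `agent` is a traced integer."""
    return jax.tree_util.tree_map(lambda x: x[agent], batched)

def set_agent(batched, agent, state):
    """Write one agent's slice back functionally (returns a new batch)."""
    return jax.tree_util.tree_map(lambda x, y: x.at[agent].set(y), batched, state)

TRACES = 0

@jax.jit  # no static_argnums: `agent` is an ordinary traced argument
def update_agent(batched, agent):
    global TRACES
    TRACES += 1  # Python side effect: runs only while JAX traces the function
    me = get_agent(batched, agent)
    # Hypothetical objective that depends on every agent in the population:
    # pull this agent's weights toward the population mean.
    mean_w = batched["w"].mean(axis=0)
    grad = jax.grad(lambda w: jnp.sum((w - mean_w) ** 2))(me["w"])
    new_me = {"w": me["w"] - LR * grad, "step": me["step"] + 1}
    return set_agent(batched, agent, new_me)

population = stack_states([init_agent(seed) for seed in range(3)])
for agent in range(3):
    population = update_agent(population, agent)

print(TRACES)  # 1: a single trace is shared by all agent indices
```

Because `agent` is traced rather than static, all three calls reuse one compilation. One caveat of dynamic indexing: JAX clamps out-of-range indices on reads and drops out-of-range updates rather than raising, so an invalid agent index fails silently.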
Yes, it's pretty big. pola_algorithm.py is probably the most important file. The …
Hi folks!
I'm using JAX for MARL experiments with an opponent-shaping algorithm called POLA. This algorithm is complicated and takes a long time to compile, anywhere from a few minutes to 12 hours depending on options. I believe one of the reasons compilation takes so long is that I have many jitted functions that take an agent index as an argument, and that argument is marked as static.
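A rough way to check the intuition about static arguments: `jax.jit` traces and compiles a function once for every distinct value of a static argument. The sketch below uses a hypothetical `scale_agent` function (not from the gist) and a Python counter that is incremented only while JAX traces the function:

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0

@functools.partial(jax.jit, static_argnums=1)  # `agent` is static
def scale_agent(params, agent):
    global trace_count
    trace_count += 1  # runs only while JAX traces, i.e. once per compilation
    return params[agent] * 2.0  # `agent` is a plain Python int here

params = [jnp.ones(3), jnp.zeros(3), jnp.full(3, 5.0)]
for agent in range(3):
    scale_agent(params, agent)
for agent in range(3):  # repeated static values hit the compilation cache
    scale_agent(params, agent)

print(trace_count)  # 3: one trace per distinct agent index
```

With N agents, each function that takes a static agent index can therefore be compiled up to N times, which multiplies across a dozen such functions.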
Here is a grossly simplified example of code that I hope helps explain the issue: https://gist.github.com/cool-RR/dd0b548c1990ec111e03c5b8e0863808
The `Population` class has multiple agents, i.e. multiple `TrainState` objects in a list. The `update_agent` method trains one agent while depending on the behavior of all the other agents. As you can see on line 38, the `agent` argument is marked as static. If I don't mark it as static, JAX raises an exception, since I'm using `agent` as an array index later.

This is a simplified example; in my real code I have a dozen different methods that call each other, and they all have an `agent` argument that's marked as static.

Am I right in my intuition that this static argument is making the compilation of my program slower? Is there any way for me to make the `agent` argument non-static?

Thanks for your help,
Ram Rachum.
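A hypothetical minimal reproduction of the exception mentioned above (not the gist's code; `scale_agent_dynamic` and `scale_agent_stacked` are illustrative names): indexing a Python list with a traced `agent` fails at trace time because the list needs a concrete Python integer, while indexing a stacked array with the same traced value works without any static argument.

```python
import jax
import jax.numpy as jnp

params = [jnp.ones(3), jnp.full(3, 5.0)]  # list-of-structs: one entry per agent

@jax.jit
def scale_agent_dynamic(params, agent):
    # `agent` is a tracer here, and a Python list needs a concrete int index.
    return params[agent] * 2.0

error = None
try:
    scale_agent_dynamic(params, 1)
except TypeError as err:  # JAX's tracer-conversion errors subclass TypeError
    error = err
print(type(error).__name__)

@jax.jit
def scale_agent_stacked(stacked, agent):
    # Array indexing accepts a traced index, so no static argument is needed.
    return stacked[agent] * 2.0

print(scale_agent_stacked(jnp.stack(params), 1))
```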