-
(Similar problem as in #5693.) I'm trying to write some modules that apply transformations. These transformations depend on something config-like and take quite a while to compute, so they are initialized when the module is initialized. The following is a dummy example.
resulting in:
Using a pure function:
results in:

Does the capturing of the [...]
Replies: 3 comments 2 replies
-
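The difference between `m.f` and `f` in your example is that `m.f` embeds a 10000000-entry f32 vector as a constant in the compiled computation (due to the `static_argnums` flag), whereas `f` takes it as a parameter. Timing the `f` call doesn't account for transferring that large vector to device memory. I suspect that timing the `m.f` call does, hence the longer wait.

One option is to write something along the lines of:

```python
from jax import jit

@jit
def _f(x, t): return x * t

class Module:
  # ...
  def f(self, x): return _f(x, self.t)
```

where `t` takes the role of what you call `transformation` in your example.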
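For reference, here is a minimal runnable sketch of that pattern. The constructor signature, the small array size, and the `trace_count` counter are assumptions added for illustration and are not part of the original example. The counter relies on Python side effects inside a jitted function running only while JAX traces it.

```python
import jax.numpy as jnp
from jax import jit, random

trace_count = 0  # hypothetical counter: bumped only while JAX traces _f


@jit
def _f(x, t):
    global trace_count
    trace_count += 1  # Python side effect, so it runs at trace time only
    return x * t


class Module:
    def __init__(self, key, n):
        # Stand-in for the expensive, config-dependent setup.
        self.t = random.uniform(key, (n,))

    def f(self, x):
        # self.t is passed as an ordinary argument, not captured as a constant.
        return _f(x, self.t)


n = 1000  # kept small here; the thread uses 10000000
m1 = Module(random.PRNGKey(0), n)
m2 = Module(random.PRNGKey(1), n)
x = jnp.ones((n,))

y1 = m1.f(x)
y2 = m2.f(x)  # same shapes and dtypes, so the cached trace of _f is reused
```

Because `t` is an argument, a second `Module` with an array of the same shape and dtype should reuse the already compiled `_f` rather than trigger a new trace.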
-
Thanks for the explanation. It's quite surprising that the embedding takes such a long time. Unfortunately, I have multiple nested modules and can only jit the outermost functions.
-
Embedding isn't what takes time per se, so much as transferring a large array to and from the device. Before it can be run, a program and its data must be sent to device memory. If an array is an embedded constant (via [...]

Your array is the latter. It is a large random array computed on device from a much smaller input (the RNG key). This line in your example only transfers the key, and then holds a pointer to the resulting array:

```python
transformation = random.uniform(key, (10000000,))  # expensive computation
```

This avoids a large transfer. When the array is later captured as an embedded constant, 10000000 floating-point numbers are pulled from the device, embedded into the compiled program as a constant, and then sent back with the program. Using a parameter lets you accept the pointer directly instead.
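To make the constant-versus-parameter distinction concrete, here is a small sketch that traces both variants and counts their inputs. It uses closure capture rather than `static_argnums` on `self`, on the assumption that both lead to the array being treated as a constant of the traced computation. The names `f_captured` and `f_param` and the reduced array size are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import jit, random

n = 1000  # kept small here; the thread uses 10000000
t = random.uniform(random.PRNGKey(0), (n,))  # computed on device from the key
x = jnp.ones((n,))


def f_captured(x):
    # t is closed over, so tracing treats it as a constant of the computation.
    return x * t


def f_param(x, t):
    # t is an explicit input; an array already on the device can be passed as-is.
    return x * t


# Count the inputs of each traced computation.
n_inputs_captured = len(jax.make_jaxpr(f_captured)(x).jaxpr.invars)
n_inputs_param = len(jax.make_jaxpr(f_param)(x, t).jaxpr.invars)

y_captured = jit(f_captured)(x)
y_param = jit(f_param)(x, t)
```

The captured variant has a single input (`x`), while the parameter variant has two (`x` and `t`). Both compute the same values, so the difference is only in how `t` reaches the compiled program.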