Slow jitting in module #5822

Answered by froystig
tetterl asked this question in Q&A
Feb 19, 2021 · 3 comments · 2 replies

The difference between `m.f` and `f` in your example is that `m.f` embeds a 10,000,000-entry f32 vector as a constant in the compiled computation (because of the `static_argnums` flag), whereas `f` takes the vector as a parameter. Timing the `f` call therefore doesn't account for transferring that large vector to device memory. I suspect that your timing of `m.f` does include it, hence the longer wait.
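To make the contrast concrete, here is a minimal sketch, not the original poster's code. `Holder`, `f_static`, `f_param` and `timed` are hypothetical names. `Holder` stands in for the module object passed via `static_argnums`; it hashes by identity, which is what lets a plain object be a static argument. Only the first call is timed, so the measurement covers tracing and compilation plus execution. `block_until_ready()` waits for JAX's asynchronous dispatch to finish, so the timing isn't cut short.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp


class Holder:
    """Hypothetical stand-in for the question's module; hashed by identity."""
    def __init__(self, t):
        self.t = t


@partial(jax.jit, static_argnums=0)
def f_static(holder, x):
    # holder is static, so holder.t is closed over as a concrete array and
    # (in the JAX versions this thread concerns) embedded as a constant.
    return x * holder.t


@jax.jit
def f_param(x, t):
    # t is an ordinary traced argument, supplied as a buffer at call time.
    return x * t


def timed(fn, *args):
    # Wall-clock time of one call, waiting for the asynchronous result.
    start = time.perf_counter()
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start


if __name__ == "__main__":
    n = 10_000_000
    t = jnp.arange(n, dtype=jnp.float32)
    x = jnp.ones(n, dtype=jnp.float32)
    _, dt_static = timed(f_static, Holder(t), x)
    _, dt_param = timed(f_param, x, t)
    print(f"static: {dt_static:.3f}s  param: {dt_param:.3f}s")
```

Both functions compute the same values. Only how `t` enters the compiled computation differs, so any timing gap comes from compilation and transfer rather than from the arithmetic.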

One option is to write something along the lines of:

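```python
from jax import jit

@jit
def _f(x, t):
    return x * t  # t is traced as an argument, not baked in as a constant

class Module:
    def __init__(self, t):
        self.t = t  # e.g. your large "transformation" array
    def f(self, x):
        return _f(x, self.t)
```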

where `t` takes the role of what you call `transformation` in your example.
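A related alternative, which was not part of the answer above, is to register the module as a pytree so that `jit` can take `self` as a regular (traced) argument. JAX's FAQ on using `jit` with methods describes this pattern. The sketch below assumes `t` is the only array field; `tree_flatten` and `tree_unflatten` are the method names that `register_pytree_node_class` expects.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Module:
    def __init__(self, t):
        self.t = t

    def tree_flatten(self):
        # t is the only array leaf; there is no static auxiliary data.
        return (self.t,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def f(self, x):
        # self is flattened into its leaves, so self.t is traced like any
        # other argument instead of being embedded as a constant.
        return x * self.t


if __name__ == "__main__":
    m = Module(jnp.arange(4, dtype=jnp.float32))
    print(m.f(jnp.full(4, 2.0, dtype=jnp.float32)))
```

Because `t` is a leaf, replacing it with another array of the same shape and dtype reuses the compiled computation instead of triggering a recompile.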

Answer selected by froystig
This discussion was converted from issue #5790 on February 23, 2021 18:47.