Need help using jax.grad -- getting type error in loss function #14607
Unanswered
timnewsham asked this question in Q&A
Replies: 2 comments 6 replies
-
Can you include the traceback of the error you're seeing? Just a guess, without running your code: perhaps you need …
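Until the traceback is available, it may help to rule out two TypeErrors that jax.grad raises by design. By default it rejects integer-typed inputs (unless allow_int=True is passed), and it only accepts functions that return a scalar. The sketch below uses a hypothetical linear model, `loss(w, x, expected)`, and made-up data, not the project's actual code:

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model; names and data are illustrative only.
def loss(w, x, expected):
    actual = x @ w                  # prediction for each row of x
    error = expected - actual       # ordinary jnp arithmetic, fine under grad
    return jnp.mean(error ** 2)     # jax.grad needs a scalar output

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
expected = jnp.array([1.0, 2.0])

# 1) Integer-typed parameters: jax.grad raises TypeError by default.
int_error = None
try:
    jax.grad(loss)(jnp.array([0, 0]), x, expected)
except TypeError as e:
    int_error = e

# 2) Non-scalar output (a vector of errors): jax.grad also raises TypeError.
shape_error = None
try:
    jax.grad(lambda w: expected - x @ w)(jnp.zeros(2))
except TypeError as e:
    shape_error = e

# Float parameters and a scalar loss work.
g = jax.grad(loss)(jnp.zeros(2), x, expected)
print(type(int_error).__name__, type(shape_error).__name__)
print(g)
```

If neither case matches, the traceback should show which line of the loss function is actually failing.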
-
My training data in the stack trace is:
All values of … I'm using
-
Hi, I need some help using jax.grad. I'm getting a type error in my loss function, defined as:
When computing
error = expected - actual
the expected field is a JAX array, and the actual field seems to be some wrapped type that the gradient code is using to track data. The entire project (not large) is in the attached file, jax_help.tgz, and is invoked with:
python test_train.py
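For context, the "wrapped type" is most likely a JAX tracer: jax.grad calls the loss function with tracer objects in place of the arguments being differentiated. Arithmetic such as `expected - actual` works on tracers through jax.numpy, but code that forces a concrete NumPy value (for example `np.asarray`) raises a TypeError subclass. This is a minimal sketch with hypothetical names (`loss`, `loss_with_numpy`, `seen`) and made-up data, not the code from the attached project:

```python
import numpy as np
import jax
import jax.numpy as jnp

seen = {}  # records what the loss function receives while jax.grad traces it

def loss(w, x, expected):
    actual = x @ w
    seen["actual_type"] = type(actual).__name__  # a tracer type, not a plain array
    error = expected - actual                     # jnp arithmetic accepts tracers
    return jnp.sum(error ** 2)

def loss_with_numpy(w, x, expected):
    actual = np.asarray(x @ w)  # forces a concrete NumPy array: fails under grad
    return jnp.sum((expected - actual) ** 2)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
expected = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

print(jax.grad(loss)(w, x, expected))
print(seen["actual_type"])

numpy_error = None
try:
    jax.grad(loss_with_numpy)(w, x, expected)
except TypeError as e:  # TracerArrayConversionError subclasses TypeError
    numpy_error = e
print(type(numpy_error).__name__)
```

Keeping the whole loss computation in jax.numpy, rather than NumPy or Python float conversions, lets jax.grad trace through it.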
The error I received is: