Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-backend examples/generative/text_generation_with_miniature_gpt #1571

Merged
merged 1 commit into from
Nov 9, 2023

Conversation

mattdangerw
Copy link
Member

Manually set this to tf because jax apparently does not work.

The log output is off (tons of repeated progress bars) compared to the original. Not sure if this is a Keras 3 bug?

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! I'd like to resolve the progress bar issue before merging this one.

# both `tensorflow` and `torch`. It does not work with JAX
# due to the behavior of `jax.numpy.tile` in a jit scope
# (used in `causal_attention_mask()`: `tile` in JAX does
# not support a dynamic `reps` argument.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember this issue. It's basically intractable apart from refactoring the example (I'm not sure the tile arg actually needs to be dynamic).

@@ -326,236 +349,1261 @@ model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])
<div class="k-default-codeblock">
```
Epoch 1/25
391/391 - 135s - loss: 5.5949 - dense_2_loss: 5.5949

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend hand-scrubbing all logs like this because they don't add value in the rendered page.

@fchollet
Copy link
Member

fchollet commented Nov 4, 2023

Update: the progress bar is fine. You should add verbose=0 in the predict() call in TextGenerator.on_epoch_end.

@mattdangerw mattdangerw force-pushed the keras-3-text-generation-with-mini-gpt branch from 7a63697 to f9aebfa Compare November 9, 2023 03:13
@mattdangerw
Copy link
Member Author

Done finally! Sorry for the delay. All ready again.

@fchollet fchollet merged commit de5de18 into keras-3 Nov 9, 2023
3 of 5 checks passed
@fchollet fchollet deleted the keras-3-text-generation-with-mini-gpt branch November 28, 2023 00:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants