Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add state summaries for print and display #4438

Merged
merged 1 commit into from
Jan 10, 2025
Merged

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Dec 15, 2024

What does this PR do?

  • Adds parameter count and size per Variable type to both the penzai / nnx.display representation and to the string representation of NNX objects.
  • Adds colors to the __str__ representation if the console supports it, the __repr__ version doesn't add any colors in case a simpler format is needed. This comes through improvements to nnx.reprlib.
  • Slight simplification of the MNIST tutorial and Flax Basics.
from flax import nnx

class Block(nnx.Module):
  def __init__(self, din, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dout, rngs=rngs)
    self.bn = nnx.BatchNorm(dout, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.dropout(self.bn(self.linear(x))))

class Foo(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.block1 = Block(32, 128, rngs=rngs)
    self.block2 = Block(128, 10, rngs=rngs)

  def __call__(self, x):
    return self.block2(self.block1(x))

foo = Foo(nnx.Rngs(0))

nnx.display(foo)
print(foo)
Screenshot 2025-01-09 at 4 47 31 PM

@cgarciae cgarciae force-pushed the nnx-tabulate branch 2 times, most recently from c926c98 to 37bee51 Compare January 8, 2025 00:33
@cgarciae cgarciae marked this pull request as ready for review January 8, 2025 00:34
@cgarciae cgarciae force-pushed the nnx-tabulate branch 3 times, most recently from fd515c8 to e5247a0 Compare January 8, 2025 02:10
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@cgarciae cgarciae force-pushed the nnx-tabulate branch 2 times, most recently from 49ab76f to a7d60e7 Compare January 8, 2025 03:16
@cgarciae cgarciae force-pushed the nnx-tabulate branch 4 times, most recently from 4352578 to 7651177 Compare January 9, 2025 22:00
@cgarciae cgarciae changed the title [nnx] add tabulate [nnx] add state summaries for print and display Jan 9, 2025
@cgarciae cgarciae force-pushed the nnx-tabulate branch 5 times, most recently from ce53a54 to 1923e0d Compare January 9, 2025 22:58
@copybara-service copybara-service bot merged commit adbad95 into main Jan 10, 2025
19 checks passed
@copybara-service copybara-service bot deleted the nnx-tabulate branch January 10, 2025 00:02
copybara-service bot pushed a commit that referenced this pull request Jan 10, 2025
rollforward #4438

PiperOrigin-RevId: 714034644
copybara-service bot pushed a commit that referenced this pull request Jan 10, 2025
rollforward #4438

PiperOrigin-RevId: 714034644
copybara-service bot pushed a commit that referenced this pull request Jan 10, 2025
rollforward #4438

PiperOrigin-RevId: 714046512
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants