Skip to content

Commit

Permalink
Add a classifier option to convnets
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Jan 9, 2025
1 parent 4d7db0e commit 0b3de08
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion neurobayes/flax_nets/convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class FlaxConvNet(nn.Module):
target_dim: int
activation: str = 'tanh'
kernel_size: Union[int, Tuple[int, ...]] = 3
classification: bool = False # Explicit flag for classification tasks

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
Expand All @@ -51,7 +52,8 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = FlaxMLP(
hidden_dims=self.fc_layers,
target_dim=self.target_dim,
activation=self.activation)(x)
activation=self.activation,
classification=self.classification)(x)

return x

Expand Down

0 comments on commit 0b3de08

Please sign in to comment.