This repo contains a JAX-based implementation of the model introduced in "Neural Additive Models: Interpretable Machine Learning with Neural Nets" by R. Agarwal et al. (2021).

Dependencies:

- jax
- optax
- haiku # used to implement the neural network model
- torch # used to create mini-batches
- numpy
- scikit-learn

Check out the nam_regression_example.ipynb notebook to see an example of using the model on the California housing dataset.
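
For orientation, the additive structure behind the model can be sketched in plain JAX: each input feature goes through its own small network, and the prediction is the sum of the per-feature outputs plus a bias. This is only a minimal sketch, not the code in this repo. It uses raw `jax.numpy` instead of haiku, a plain ReLU MLP per feature, and hypothetical names (`init_feature_net`, `init_nam`, `feature_net_apply`, `nam_apply`); the hidden size and initialization scale are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp


def init_feature_net(key, hidden=16):
    """Parameters of one small MLP that maps a scalar feature to a scalar."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (1, hidden)) * 0.5,  # assumed init scale
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, 1)) * 0.5,
        "b2": jnp.zeros(1),
    }


def init_nam(key, num_features, hidden=16):
    """One feature network per input column, stacked along a leading axis."""
    keys = jax.random.split(key, num_features)
    feature_nets = jax.vmap(lambda k: init_feature_net(k, hidden))(keys)
    return {"feature_nets": feature_nets, "bias": jnp.zeros(())}


def feature_net_apply(params, x):
    """Apply a single feature network. x: (batch,) -> (batch,)."""
    h = jax.nn.relu(x[:, None] @ params["w1"] + params["b1"])
    return (h @ params["w2"] + params["b2"])[:, 0]


def nam_apply(params, x):
    """x: (batch, num_features) -> (prediction, per-feature contributions)."""
    # Map over the stacked networks (axis 0) and the feature columns (axis 1).
    contribs = jax.vmap(feature_net_apply, in_axes=(0, 1), out_axes=1)(
        params["feature_nets"], x
    )
    return contribs.sum(axis=1) + params["bias"], contribs


# Dummy batch: 4 rows, 8 features (the California housing data has 8 features).
params = init_nam(jax.random.PRNGKey(0), num_features=8)
x = jnp.ones((4, 8))
pred, contribs = nam_apply(params, x)
```

Because the prediction is a plain sum, each column of `contribs` can be plotted against its feature to see how that feature alone moves the output, which is where the interpretability of the model comes from.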