diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bee8a64 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..6d597c3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,7 @@ +Copyright © 2024 Parsiad Azimzadeh + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..a41a6fe --- /dev/null +++ b/README.md @@ -0,0 +1,88 @@ +
+ +
+ + + +Micrograd++ is a minimalistic wrapper around NumPy which adds support for automatic differentiation. +Designed as a learning tool, Micrograd++ provides an accessible entry point for those interested in understanding automatic differentiation or seeking a clean, educational resource. +Explore backpropagation and deepen your understanding of machine learning with Micrograd++. + +Micrograd++ draws inspiration from Andrej Karpathy's awesome [micrograd](https://github.com/karpathy/micrograd) library, prioritizing simplicity and readability over speed. +Unlike micrograd, which tackles scalar inputs, Micrograd++ supports tensor inputs (specifically, NumPy arrays). +This makes it possible to train larger networks. + +## Usage + +Micrograd++ is not yet pip-able. +Therefore, you will have to clone the Micrograd++ repository to your home directory and include it via the PYTHONPATH in any script or notebook you want to use Micrograd++ in: + +```python +import sys +sys.path.insert(0, os.path.expanduser("~/micrograd-pp/python")) +``` + +## Example: MNIST + +![](https://upload.wikimedia.org/wikipedia/commons/f/f7/MnistExamplesModified.png) + +[MNIST](https://en.wikipedia.org/wiki/MNIST_database) is a dataset of handwritten digits (0-9) commonly used for training and testing image processing systems. +It consists of 28x28 pixel grayscale images, with a total of 60,000 training samples and 10,000 test samples. +It's widely used in machine learning for digit recognition tasks. + +Below is an example of using Micrograd++ to train a simple [feedforward neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network) to recognize digits. + +```python +import micrograd_pp as mpp +import numpy as np + +mnist = mpp.datasets.load_mnist(normalize=True) +train_images, train_labels, test_images, test_labels = mnist + +# Flatten images +train_images = train_images.reshape(-1, 28 * 28) +test_images = test_images.reshape(-1, 28 * 28) + +# Drop extra training examples +trim = train_images.shape[0] % batch_sz +train_images = train_images[: train_images.shape[0] - trim] + +# Shuffle +indices = np.random.permutation(train_images.shape[0]) +train_images = train_images[indices] +train_labels = train_labels[indices] + +# Make batches +n_batches = train_images.shape[0] // batch_sz +train_images = np.split(train_images, n_batches) +train_labels = np.split(train_labels, n_batches) + +# Optimizer +opt = mpp.SGD(lr=0.01) + +# Feedforward neural network +model = mpp.Sequential( + mpp.Linear(28 * 28, 128, bias=False), + mpp.ReLU(), + mpp.Linear(128, 10, bias=False), +) + +# Train +accuracy = float("nan") +for epoch in range(n_epochs): + for batch_index in np.random.permutation(np.arange(n_batches)): + x = mpp.Constant(train_images[batch_index]) + y = train_labels[batch_index] + fx = model(x) + fx_max = fx.max(dim=1) + delta = fx - fx_max.expand(fx.shape) + log_sum_exp = delta.exp().sum(dim=1).log().squeeze() + loss = -(delta[np.arange(batch_sz), y] - log_sum_exp).sum() / batch_sz + loss.backward(opt=opt) + opt.step() + test_x = mpp.Constant(test_images) + test_fx = model(test_x) + pred_labels = np.argmax(test_fx.value, axis=1) + accuracy = (pred_labels == test_labels).mean().item() + print(f"Test accuracy at epoch {epoch}: {accuracy * 100:.2f}%") +``` diff --git a/logo.png b/logo.png new file mode 100644 index 0000000..dd7f54d Binary files /dev/null and b/logo.png differ diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..40d4ae3 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,3 @@ +{ + "extraPaths": ["python"] +} diff --git a/python/micrograd_pp/__init__.py b/python/micrograd_pp/__init__.py new file mode 100644 index 0000000..48b65aa --- /dev/null +++ b/python/micrograd_pp/__init__.py @@ -0,0 +1,18 @@ +from ._expr import Constant, Expr, Parameter, maximum, relu +from ._nn import Linear, ReLU, Sequential +from ._opt import SGD + +from . import datasets + +__all__ = ( + "Constant", + "Expr", + "Linear", + "Parameter", + "ReLU", + "Sequential", + "SGD", + "datasets", + "maximum", + "relu", +) diff --git a/python/micrograd_pp/_expr.py b/python/micrograd_pp/_expr.py new file mode 100644 index 0000000..4cbea65 --- /dev/null +++ b/python/micrograd_pp/_expr.py @@ -0,0 +1,548 @@ +from __future__ import annotations + +import itertools +from abc import ABC, abstractmethod +from collections import deque +from functools import lru_cache +from typing import Any, Callable, Sequence + +import numpy as np +import numpy.typing as npt + + +class Expr: + """Represents a differentiable expression in the graph. + + Parameters + ---------- + value + Value + children + Sequence of subexpressions (those through which backpropagation occurs) + label + Human-readable name + requires_grad + If the gradient is not required, backpropagation will stop at this expression (if unspecified, the gradient is + required if the gradient of at least one child is required) + """ + + def __init__( + self, + value: npt.NDArray, + children: Sequence[Expr] = (), + label: str | None = None, + requires_grad: bool | None = None, + ) -> None: + self._value = value + self._children = set(children) + self._label = label + if requires_grad is None: + requires_grad = any(child._requires_grad for child in children) + self._requires_grad = requires_grad + self._grad = None + + def __repr__(self) -> str: + d = {"value": self._value, "requires_grad": self._requires_grad} + if self._label is not None: + d["label"] = self._label + args = ", ".join(f"{k}={v}" for k, v in d.items()) + return f"_Expr({args})" + + def __matmul__(self, other: Any) -> Expr: + return _MatMul(self, other) + + def __add__(self, other: Any) -> Expr: + if isinstance(other, int): + other = float(other) + if isinstance(other, float): + return _AddScalar(self, other) + return _Add(self, other) + + def __getitem__(self, index: Any) -> Expr: + return _Slice(self, index=index) + + def __truediv__(self, other: Any) -> Expr: + return self * other ** (-1) + + def __mul__(self, other: Any) -> Expr: + if isinstance(other, int): + other = float(other) + if isinstance(other, float): + return _MultScalar(self, other) + return _Mult(self, other) + + def __neg__(self) -> Expr: + return self * (-1.0) + + def __pow__(self, pow: Any) -> Expr: + if not isinstance(pow, (int, float)): + msg = f"Expected int or float exponent; received {pow}" + raise ValueError(msg) + return _Pow(self, pow) + + def __radd__(self, other: Any) -> Expr: + return self + other + + def __rtruediv__(self, other: Any) -> Expr: + return self / other + + def __rmatmul__(self, other: Any) -> Expr: + return self @ other + + def __rmul__(self, other: Any) -> Expr: + return self * other + + def __rsub__(self, other: Any) -> Expr: + return (-self) + other + + def __sub__(self, other: Any) -> Expr: + return self + (-other) + + def _backward(self, grad: npt.NDArray) -> None: + del grad + + @lru_cache(maxsize=None) + def _get_nodes(self) -> deque[Expr]: + retval: deque[Expr] = deque() + if not self._requires_grad: + return retval + visited: set[Expr] = set() + marked: set[Expr] = set() + + def visit(node: Expr) -> None: + if node in visited: + return + if node in marked: + msg = "Detected cycle in gradient graph" + raise RuntimeError(msg) + marked.add(node) + for child in node._children: + if not child._requires_grad: + continue + visit(child) + marked.remove(node) + visited.add(node) + retval.appendleft(node) + + visit(self) + return retval + + def backward( + self, + init: np.ndarray | float = 1.0, + opt: Opt | None = None, + retain_grad: bool | None = None, + ): + """Perform backpropagation and return all affected parameters. + + Suppose we call ``loss.backward()``. + If ``loss`` is a scalar, then ``param.grad`` accumulates the derivative of ``loss`` with respect to ``param``. + Otherwise, it accumulates the derivative of ``loss.sum()`` with respect to ``param``. + If ``init`` is specified, it accumulates the derivative of ``(loss * init).sum()`` with respect to ``param``. + + Parameters + ---------- + init + Initial gradient + opt + If specified, the optimizer updates parameters using their respective gradients + retain_grad + Whether or not to deallocate gradients (if unspecified, gradients are deallocated if an optimizer is + specified) + """ + if not self._requires_grad: + msg = "Attempted to perform backward pass on an expression that does not require a gradient" + raise ValueError(msg) + if retain_grad is None: + retain_grad = opt is None + self._grad = np.empty_like(self._value) + self._grad[...] = init + nodes = self._get_nodes() + for node in nodes: + assert node._grad is not None + node._backward(node._grad) + if opt is not None and len(node._children) == 0: + opt.update_param(node) + if not retain_grad: + node._grad = None + + def exp(self) -> Expr: + """Return the element-wise exponential.""" + return _Exp(self) + + def expand(self, shape: tuple[int, ...]) -> Expr: + """Broadcast. + + Parameter + --------- + shape + Shape to broadcast to + """ + return _Expand(self, shape=shape) + + def log(self) -> Expr: + """Take the element-wise natural logarithm.""" + return _Log(self) + + def max(self, dim: int | tuple[int, ...] | None = None) -> Expr: + """Maximize across a dimension. + + Parameters + ---------- + dim + Axis or axes along which to operate. By default, all axes are used. + """ + return _Max(a=self, dim=dim) + + def set_label(self, label: str) -> None: + """Set the expression label.""" + self._label = label + + def squeeze(self, dim: int | tuple[int, ...] | None = None) -> Expr: + """Remove axes of length one. + + Parameters + ---------- + dim + One or more axes to remove. By default, all length one axes are removed. + """ + return _Squeeze(self, dim=dim) + + def sum(self, dim: int | tuple[int, ...] | None = None) -> Expr: + """Sum across one or more dimensions. + + Parameters + ---------- + dim + Axis or axes along which to operate. By default, all axes are used. + """ + return _Sum(self, dim=dim) + + def transpose(self, dim0: int, dim1: int) -> Expr: + """Transpose axes. + + Parameters + ---------- + dim0 + First dimension + dim1 + Second dimension + """ + return _Transpose(self, dim0=dim0, dim1=dim1) + + def update_grad(self, func: Callable[[], npt.NDArray]) -> None: + """Update the gradient by adding to it the output of a function. + + Note that the function is only invoked if the gradient is required. + """ + if not self._requires_grad: + return + if self._grad is None: + self._grad = np.zeros_like(self._value) + self._grad += func() + + def update_value(self, increment: npt.NDArray) -> None: + """Update the value by adding to it.""" + self._value += increment + + def unsqueeze(self, dim: int) -> Expr: + """Insert a dimension of size one at the specified position. + + Parameters + ---------- + dim + Position + """ + return _Unsqueeze(self, dim=dim) + + @property + def dtype(self) -> npt.DTypeLike: + """Data type.""" + return self._value.dtype + + @property + def value(self) -> npt.NDArray: + """Value.""" + return self._value.view() + + @property + def grad(self) -> npt.NDArray: + """Gradient.""" + if self._grad is None: + msg = "Attempted to view untracked gradient" + raise ValueError(msg) + return self._grad.view() + + @property + def ndim(self) -> int: + """Number of dimensions.""" + return self._value.ndim + + @property + def requires_grad(self) -> bool: + """Whether this expression requires its gradient be computed.""" + return self._requires_grad + + @property + def shape(self) -> tuple[int, ...]: + """Shape.""" + return self._value.shape + + +class Opt(ABC): + @abstractmethod + def update_param(self, param: Expr) -> None: + pass + + @abstractmethod + def step(self) -> None: + pass + + +class Constant(Expr): + """An expression that does not require a gradient and has no children. + + Parameters + ---------- + value + Value + label + Human-readable name + """ + + def __init__(self, value: npt.NDArray, label: str | None = None) -> None: + super().__init__(value=value, label=label) + + +class Parameter(Expr): + """An expression that requires a gradient but has no children. + + Parameters + ---------- + value + Value + label + Human-readable name + """ + + def __init__(self, c: npt.NDArray, label: str | None = None) -> None: + super().__init__(value=c, label=label, requires_grad=True) + + +def maximum(a: Expr, b: Expr) -> Expr: + """The element-wise maximum of two expressions.""" + return _Maximum(a, b) + + +def relu(expr: Expr) -> Expr: + """The positive part of an expression.""" + return _ReLU(expr) + + +class _Add(Expr): + def __init__(self, a: Expr, b: Expr) -> None: + _raise_if_not_same_shape(a, b) + super().__init__(value=a._value + b._value, children=(a, b)) + self._a = a + self._b = b + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: grad) + self._b.update_grad(lambda: grad) + + +class _AddScalar(Expr): + def __init__(self, a: Expr, b: float) -> None: + super().__init__(value=a._value + b, children=(a,)) + self._a = a + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: grad) + + +class _Exp(Expr): + def __init__(self, a: Expr) -> None: + super().__init__(value=np.exp(a._value), children=(a,)) + self._a = a + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: grad * self._value) + + +class _Expand(Expr): + def __init__(self, a: Expr, shape: tuple[int, ...]) -> None: + super().__init__(value=np.broadcast_to(a._value, shape=shape), children=(a,)) + self._a = a + + def _backward(self, grad: npt.NDArray) -> None: + def func() -> npt.NDArray: + dim = tuple( + [ + self.ndim - 1 - i + for i, (m, n) in enumerate( + itertools.zip_longest( + reversed(self._a.shape), reversed(self.shape) + ) + ) + if m is None or m != n + ] + ) + return grad.sum(axis=dim, keepdims=True) + + self._a.update_grad(func) + + +class _Log(Expr): + def __init__(self, a: Expr) -> None: + super().__init__(value=np.log(a._value), children=(a,)) + self._a = a + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: grad / self._a._value) + + +class _MatMul(Expr): + def __init__(self, a: Expr, b: Expr) -> None: + if not (a.ndim == 2 and b.ndim == 2): + msg = "Matrix multiplication currently only supports 2-D inputs" + raise NotImplementedError(msg) + super().__init__(value=a._value @ b._value, children=(a, b)) + self._a = a + self._b = b + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: grad @ self._b._value.T) + self._b.update_grad(lambda: self._a._value.T @ grad) + + +class _Max(Expr): + def __init__(self, a: Expr, dim: int | tuple[int, ...] | None) -> None: + super().__init__(value=a._value.max(axis=dim, keepdims=True), children=(a,)) + self._a = a + self._dim = dim + + def _backward(self, grad: npt.NDArray) -> None: + def func() -> npt.NDArray: + # TODO(parsiad): Materializing a mask is expensive + mask = (self._a._value == self._value).astype(self._a._value.dtype) + mask /= mask.sum(axis=self._dim, keepdims=True) + return grad * mask + + self._a.update_grad(func) + + +class _Maximum(Expr): + def __init__(self, a: Expr, b: Expr) -> None: + _raise_if_not_same_shape(a, b) + super().__init__(value=np.maximum(a._value, b._value), children=(a, b)) + self._a = a + self._b = b + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: grad * (self._a._value >= self._b._value)) + self._b.update_grad(lambda: grad * (self._a._value < self._b._value)) + + +class _Mult(Expr): + def __init__(self, a: Expr, b: Expr) -> None: + _raise_if_not_same_shape(a, b) + super().__init__(value=a._value * b._value, children=(a, b)) + self._a = a + self._b = b + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: grad * self._b._value) + self._b.update_grad(lambda: self._a._value * grad) + + +class _MultScalar(Expr): + def __init__(self, a: Expr, b: float) -> None: + super().__init__(value=a._value * b, children=(a,)) + self._a = a + self._b = b + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: grad * self._b) + + +class _Pow(Expr): + def __init__(self, a: Expr, pow: int | float) -> None: + super().__init__(value=a._value**pow, children=(a,)) + self._a = a + self._pow = pow + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad( + lambda: grad * self._pow * self._a._value ** (self._pow - 1) + ) + + +class _ReLU(Expr): + def __init__(self, a: Expr) -> None: + super().__init__(value=np.maximum(a._value, 0.0), children=(a,)) + self._a = a + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: grad * (self._a._value > 0.0)) + + +class _Slice(Expr): + def __init__(self, a: Expr, index: Any) -> None: + super().__init__(value=a._value[index], children=(a,)) + self._a = a + self._index = index + + def _backward(self, grad: npt.NDArray) -> None: + def func() -> npt.NDArray: + backprop_grad = np.zeros_like(self._a._value) + backprop_grad[self._index] = grad + return backprop_grad + + self._a.update_grad(func) + + +class _Squeeze(Expr): + def __init__(self, a: Expr, dim: int | tuple[int, ...] | None) -> None: + super().__init__(value=a._value.squeeze(axis=dim), children=(a,)) + self._a = a + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: grad.reshape(self._a.shape)) + + +class _Sum(Expr): + def __init__(self, a: Expr, dim: int | tuple[int, ...] | None) -> None: + super().__init__(value=a._value.sum(axis=dim, keepdims=True), children=(a,)) + self._a = a + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: grad) + + +class _Transpose(Expr): + def __init__(self, a: Expr, dim0: int, dim1: int) -> None: + self._axes = list(range(a.ndim)) + self._axes[dim0] = dim1 + self._axes[dim1] = dim0 + super().__init__(value=np.transpose(a._value, self._axes), children=(a,)) + self._a = a + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: np.transpose(grad, self._axes)) + + +class _Unsqueeze(Expr): + def __init__(self, a: Expr, dim: int) -> None: + super().__init__(value=np.expand_dims(a._value, axis=dim), children=(a,)) + self._a = a + self._dim = dim + + def _backward(self, grad: npt.NDArray) -> None: + self._a.update_grad(lambda: grad.squeeze(axis=self._dim)) + + +def _raise_if_not_same_shape(*exprs: Expr): + shape = next(iter(exprs)).shape + if not all(expr.shape == shape for expr in exprs): + msg = "Operands must be the same shape" + raise ValueError(msg) diff --git a/python/micrograd_pp/_nn.py b/python/micrograd_pp/_nn.py new file mode 100644 index 0000000..c447a85 --- /dev/null +++ b/python/micrograd_pp/_nn.py @@ -0,0 +1,82 @@ +from collections.abc import Callable +import numpy as np + +from ._expr import Expr, Parameter, relu + + +Module = Callable[[Expr], Expr] + + +class Linear: + """Linear layer. + + Parameters + ---------- + in_features + Number of input features + out_features + Number of output features + bias + Whether or not to include a bias + label + Human-readable name + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + label: str | None = None, + ) -> None: + self._a = Parameter( + np.random.randn(out_features, in_features) / np.sqrt(in_features), + label=None if label is None else f"{label}/weight", + ) + if bias: + self._b = Parameter( + np.zeros((out_features,)), + label=None if label is None else f"{label}/bias", + ) + else: + self._b = None + + def __call__(self, x: Expr) -> Expr: + retval = x @ self._a.transpose(0, 1) + if self._b is not None: + retval = retval + self._b + return retval + + def __repr__(self) -> str: + return f"Linear({self._a.shape[0]}, {self._a.shape[1]})" + + +class ReLU: + """Modular wrapper around the ReLU function.""" + + def __call__(self, expr: Expr) -> Expr: + return relu(expr) + + def __repr__(self) -> str: + return "ReLU()" + + +class Sequential: + """Sequential container of modules. + + Parameters + ---------- + modules + Zero or more modules + """ + + def __init__(self, *modules: Module) -> None: + self._modules = modules + + def __call__(self, x: Expr) -> Expr: + for module in self._modules: + x = module(x) + return x + + def __repr__(self) -> str: + return f"Sequential({', '.join(str(module) for module in self._modules)})" diff --git a/python/micrograd_pp/_opt.py b/python/micrograd_pp/_opt.py new file mode 100644 index 0000000..12c842e --- /dev/null +++ b/python/micrograd_pp/_opt.py @@ -0,0 +1,20 @@ +from ._expr import Expr, Opt + + +class SGD(Opt): + """Performs stochastic gradient descent. + + Parameters + ---------- + lr + Learning rate + """ + + def __init__(self, lr: float) -> None: + self._lr = lr + + def update_param(self, param: Expr) -> None: + param.update_value(-self._lr * param.grad) + + def step(self) -> None: + pass diff --git a/tests/test_expr.py b/tests/test_expr.py new file mode 100644 index 0000000..cadc936 --- /dev/null +++ b/tests/test_expr.py @@ -0,0 +1,215 @@ +import itertools +from typing import Generator + +import numpy as np +import pytest + +import micrograd_pp as mpp + +DIMS = [0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2), None] + + +@pytest.fixture(autouse=True) +def run_before_and_after_tests() -> Generator[None, None, None]: + np.random.seed(0) + yield + + +def test_add() -> None: + a = np.random.randn(3, 2) + b = np.random.randn(3, 2) + a_ = mpp.Parameter(a) + b_ = mpp.Parameter(b) + c_ = a_ + b_ + c_.backward() + grad = np.ones_like(a) + np.testing.assert_equal(a_.grad, grad) + np.testing.assert_equal(b_.grad, grad) + + +def test_add_bcast_fails() -> None: + a = np.random.randn(3, 2) + b = np.random.randn(3, 1) + a_ = mpp.Parameter(a) + b_ = mpp.Parameter(b) + with pytest.raises(ValueError): + c_ = a_ + b_ + del c_ + + +def test_add_scalar() -> None: + a = np.random.randn(3, 2) + a_ = mpp.Parameter(a) + c = 2.0 + b_ = c + a_ + b_.backward() + grad = np.ones_like(a) + np.testing.assert_equal(a_.grad, grad) + + +def test_exp() -> None: + a = np.random.randn(3, 2) + a_ = mpp.Parameter(a) + b_ = a_.exp() + b_.backward() + np.testing.assert_equal(a_.grad, b_.value) + + +@pytest.mark.parametrize("dim", DIMS) +def test_expand(dim: int | tuple[int, ...] | None) -> None: + expand_shape = [4, 3, 2] + shape = expand_shape.copy() + if dim is None: + dim = (0, 2) + if isinstance(dim, int): + dim = (dim,) + n_copies = 1 + for d in dim: + n_copies *= shape[d] + shape[d] = 1 + a = np.random.randn(shape[0], shape[1], shape[2]) + a_ = mpp.Parameter(a) + b_ = a_.expand(tuple(expand_shape)) + b_.backward() + grad = np.full_like(a_.value, fill_value=n_copies) + np.testing.assert_equal(a_.grad, grad) + + +def test_log() -> None: + a = np.random.randn(3, 2) ** 2 + a_ = mpp.Parameter(a) + c_ = a_.log() + c_.backward() + np.testing.assert_equal(a_.grad, 1.0 / a_.value) + + +def test_matmul() -> None: + a = np.random.randn(4, 3) + b = np.random.randn(3, 2) + a_ = mpp.Parameter(a) + b_ = mpp.Parameter(b) + c_ = a_ @ b_ + c_.backward() + for i, j in itertools.product(range(a.shape[0]), range(a.shape[1])): + h = np.zeros_like(a) + h[i, j] = 1e-6 + d = (((a + h) @ b - a @ b) / 1e-6).sum() + assert pytest.approx(d) == a_.grad[i, j] + for i, j in itertools.product(range(b.shape[0]), range(b.shape[1])): + h = np.zeros_like(b) + h[i, j] = 1e-6 + d = ((a @ (b + h) - a @ b) / 1e-6).sum() + assert pytest.approx(d) == b_.grad[i, j] + + +@pytest.mark.parametrize("dim", DIMS) +def test_max(dim: int | tuple[int, ...] | None) -> None: + a = np.random.randn(4, 3, 2) + a_ = mpp.Parameter(a) + b_ = a_.max(dim=dim) + b_.backward() + grad = (np.max(a, axis=dim, keepdims=True) == a).astype(a.dtype) + np.testing.assert_equal(a_.grad, grad) + + +def test_maximum() -> None: + a = np.random.randn(3, 2) + b = np.random.randn(3, 2) + a_ = mpp.Parameter(a) + b_ = mpp.Parameter(b) + c_ = mpp.maximum(a_, b_) + c_.backward() + grad = a > b + np.testing.assert_equal(a_.grad, grad) + np.testing.assert_equal(b_.grad, ~grad) + + +def test_mult() -> None: + a = np.random.randn(3, 2) + b = np.random.randn(3, 2) + a_ = mpp.Parameter(a) + b_ = mpp.Parameter(b) + c_ = a_ * b_ + c_.backward() + np.testing.assert_equal(a_.grad, b) + np.testing.assert_equal(b_.grad, a) + + +def test_mult_bcast_fails() -> None: + a = np.random.randn(3, 2) + b = np.random.randn(3, 1) + a_ = mpp.Parameter(a) + b_ = mpp.Parameter(b) + with pytest.raises(ValueError): + c_ = a_ * b_ + del c_ + + +def test_mult_scalar() -> None: + a = np.random.randn(3, 2) + a_ = mpp.Parameter(a) + c = 2.0 + b_ = c * a_ + b_.backward() + grad = np.full_like(a, fill_value=c) + np.testing.assert_equal(a_.grad, grad) + + +def test_pow() -> None: + a = np.random.randn(3, 2) + a_ = mpp.Parameter(a) + c_ = a_**3 + c_.backward() + np.testing.assert_equal(a_.grad, 3 * a**2) + + +def test_relu() -> None: + a = np.random.randn(3, 2) + a_ = mpp.Parameter(a) + b_ = mpp.relu(a_) + b_.backward() + np.testing.assert_equal(a_.grad, a > 0.0) + + +@pytest.mark.parametrize("dim", DIMS) +def test_squeeze(dim: int | tuple[int, ...] | None) -> None: + shape = [4, 3, 2] + if dim is None: + dim = (0, 2) + if isinstance(dim, int): + dim = (dim,) + for d in dim: + shape[d] = 1 + a = np.random.randn(shape[0], shape[1], shape[2]) + a_ = mpp.Parameter(a) + b_ = a_.squeeze(dim=dim) + b_.backward() + grad = np.ones_like(a) + np.testing.assert_equal(a_.grad, grad) + + +@pytest.mark.parametrize("dim", DIMS) +def test_sum(dim: int | tuple[int, ...] | None) -> None: + a = np.random.randn(4, 3, 2) + a_ = mpp.Parameter(a) + b_ = a_.sum(dim=dim) + b_.backward() + grad = np.ones_like(a) + np.testing.assert_equal(a_.grad, grad) + + +def test_transpose() -> None: + a = np.random.randn(4, 3) + a_ = mpp.Parameter(a) + b_ = a_.transpose(0, 1) + np.testing.assert_equal(b_.value, a.T) + + +@pytest.mark.parametrize("dim", [0, 1, 2]) +def test_unsqueeze(dim: int) -> None: + a = np.random.randn(4, 3, 2) + a_ = mpp.Parameter(a) + b_ = a_.unsqueeze(dim) + b_.backward() + grad = np.ones_like(a) + np.testing.assert_equal(a_.grad, grad) diff --git a/tests/test_mnist.py b/tests/test_mnist.py new file mode 100644 index 0000000..b3d0207 --- /dev/null +++ b/tests/test_mnist.py @@ -0,0 +1,64 @@ +import numpy as np +import pytest + +import micrograd_pp as mpp + + +@pytest.fixture(autouse=True) +def run_before_and_after_tests(): + np.random.seed(0) + yield + + +def test_mnist(batch_sz: int = 64, n_epochs: int = 3): + mnist = mpp.datasets.load_mnist(normalize=True) + train_images, train_labels, test_images, test_labels = mnist + + # Flatten images + train_images = train_images.reshape(-1, 28 * 28) + test_images = test_images.reshape(-1, 28 * 28) + + # Drop extra training examples + trim = train_images.shape[0] % batch_sz + train_images = train_images[: train_images.shape[0] - trim] + + # Shuffle + indices = np.random.permutation(train_images.shape[0]) + train_images = train_images[indices] + train_labels = train_labels[indices] + + # Make batches + n_batches = train_images.shape[0] // batch_sz + train_images = np.split(train_images, n_batches) + train_labels = np.split(train_labels, n_batches) + + # Optimizer + opt = mpp.SGD(lr=0.01) + + # Feedforward neural network + model = mpp.Sequential( + mpp.Linear(28 * 28, 128, bias=False), + mpp.ReLU(), + mpp.Linear(128, 10, bias=False), + ) + + # Train + accuracy = float("nan") + for epoch in range(n_epochs): + for batch_index in np.random.permutation(np.arange(n_batches)): + x = mpp.Constant(train_images[batch_index]) + y = train_labels[batch_index] + fx = model(x) + fx_max = fx.max(dim=1) + delta = fx - fx_max.expand(fx.shape) + log_sum_exp = delta.exp().sum(dim=1).log().squeeze() + loss = -(delta[np.arange(batch_sz), y] - log_sum_exp).sum() / batch_sz + loss.backward(opt=opt) + opt.step() + test_x = mpp.Constant(test_images) + test_fx = model(test_x) + pred_labels = np.argmax(test_fx.value, axis=1) + accuracy = (pred_labels == test_labels).mean().item() + print(f"Test accuracy at epoch {epoch}: {accuracy * 100:.2f}%") + + assert accuracy >= 0.9 diff --git a/tests/test_opt.py b/tests/test_opt.py new file mode 100644 index 0000000..9f87041 --- /dev/null +++ b/tests/test_opt.py @@ -0,0 +1,32 @@ +import numpy as np +import pytest + +import micrograd_pp as mpp + + +@pytest.fixture(autouse=True) +def run_before_and_after_tests(): + np.random.seed(0) + yield + + +def test_mse(): + n = 10 + coef = np.random.randn(3, 1) + coef_hat = np.random.randn(3, 1) + x = np.random.randn(n, 3) + ε = 0.0 * np.random.randn(n, 1) + y = x @ coef + ε + + coef_hat_ = mpp.Parameter(coef_hat) + x_ = mpp.Constant(x) + y_ = mpp.Constant(y) + + opt = mpp.SGD(lr=0.1) + for _ in range(150): + y_pred_ = x_ @ coef_hat_ + mse = ((y_pred_ - y_) ** 2).sum() / n + mse.backward(opt=opt) + opt.step() + + np.testing.assert_allclose(coef, coef_hat)