From 93ae88e239d89cbe0e4c88b38c2a7c06bec6f643 Mon Sep 17 00:00:00 2001 From: Marco Nicola Date: Sat, 4 Nov 2023 15:42:11 +0100 Subject: [PATCH] Fix gradfn.ReduceMean.Backward not respecting operand shape --- CHANGELOG.md | 5 +++++ mat/gradfn/reducemean.go | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c40775e4..b3c92833 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Fixed + +- `gradfn.ReduceMean.Backward` was not respecting the operand's shape, causing + an incompatible-shape error when reduce-mean was applied to non-vector values. + ## [1.1.0] - 2023-10-30 ### Changed diff --git a/mat/gradfn/reducemean.go b/mat/gradfn/reducemean.go index 142a6fc6..dd8c5702 100644 --- a/mat/gradfn/reducemean.go +++ b/mat/gradfn/reducemean.go @@ -42,7 +42,7 @@ func (r *ReduceMean[O]) Backward(gy mat.Tensor) error { x := r.x.Value().(mat.Matrix) size := x.Size() v := gy.Item().F64() / float64(size) - gx := x.NewMatrix(mat.WithShape(size), mat.WithBacking(mat.CreateInitializedSlice(size, v))) + gx := x.NewMatrix(mat.WithShape(x.Shape()...), mat.WithBacking(mat.CreateInitializedSlice(size, v))) r.x.AccGrad(gx) } return nil