From dea27c85cd6659106ee0366ca2a150c1c387a286 Mon Sep 17 00:00:00 2001 From: matteo-grella Date: Sun, 29 Oct 2023 23:28:29 +0100 Subject: [PATCH] Refactor `Matrix` interface to extend from `Tensor` --- ag/operator.go | 12 ++++++++++++ mat/dense.go | 4 ++-- mat/matrix.go | 32 ++------------------------------ mat/tensor.go | 16 +++++++++++++++- 4 files changed, 31 insertions(+), 33 deletions(-) diff --git a/ag/operator.go b/ag/operator.go index 165e0860..49e2ce47 100644 --- a/ag/operator.go +++ b/ag/operator.go @@ -121,6 +121,18 @@ func NewOperator(f AutoGradFunction) *Operator { return &Operator{fn: f} } +// SetAt sets the value at the given indices. +// It panics if the given indices are out of range. +func (o *Operator) SetAt(m mat.Tensor, indices ...int) { + o.Value().SetAt(m, indices...) +} + +// At returns the value at the given indices. +// It panics if the given indices are out of range. +func (o *Operator) At(indices ...int) mat.Tensor { + return o.Value().At(indices...) +} + // Run starts the execution of the operator, performing the forward pass. // If the optional async argument is set to true, the forward pass will be executed in a separate goroutine. // The function returns a pointer to the Operator, allowing for method chaining. diff --git a/mat/dense.go b/mat/dense.go index 6c22d79b..a6f9e721 100644 --- a/mat/dense.go +++ b/mat/dense.go @@ -110,13 +110,13 @@ func (d *Dense[T]) Zeros() { // SetAt sets the value m at the given indices. // It panics if the given indices are out of range. -func (d *Dense[T]) SetAt(m Matrix, indices ...int) { +func (d *Dense[T]) SetAt(m Tensor, indices ...int) { d.set(float.ValueOf[T](m.Item()), indices...) } // At returns the value at the given indices. // It panics if the given indices are out of range. -func (d *Dense[T]) At(i ...int) Matrix { +func (d *Dense[T]) At(i ...int) Tensor { return Scalar[T](d.at(i...)) } diff --git a/mat/matrix.go b/mat/matrix.go index 6243535d..d2dff0cd 100644 --- a/mat/matrix.go +++ b/mat/matrix.go @@ -15,15 +15,8 @@ import ( // such as element-wise addition, subtraction, product and matrix-matrix // multiplication. type Matrix interface { - // Shape returns the size in each dimension. - Shape() []int - // Dims returns the number of dimensions. - Dims() int - // Size returns the total number of elements. - Size() int - // Data returns the underlying data of the matrix, as a raw one-dimensional - // slice of values in row-major order. - Data() float.Slice + Tensor + // SetData sets the content of the matrix, copying the given raw // data representation as one-dimensional slice. SetData(data float.Slice) @@ -38,12 +31,6 @@ type Matrix interface { Item() float.Float // Zeros sets all the values of the matrix to zero. Zeros() - // SetAt sets the value at the given indices. - // It panics if the given indices are out of range. - SetAt(m Matrix, indices ...int) - // At returns the value at the given indices. - // It panics if the given indices are out of range. - At(indices ...int) Matrix // SetScalar sets the value at the given indices. // It panics if the given indices are out of range. SetScalar(v float.Float, indices ...int) @@ -235,22 +222,7 @@ type Matrix interface { // them as row vectors. NewStack(vs ...Matrix) Matrix - // Value returns the Matrix itself. - Value() Tensor - // Grad returns the accumulated gradients with the AccGrad method. - // A matrix full of zeros and the nil value are considered equivalent. - Grad() Tensor - // HasGrad reports whether there are accumulated gradients. - HasGrad() bool - // RequiresGrad reports whether the Matrix requires gradients. - // It is set by the SetRequiresGrad method or the functional options WithGrad. - RequiresGrad() bool - // SetRequiresGrad sets whether the Matrix requires gradients. SetRequiresGrad(bool) - // AccGrad accumulates the gradients. - AccGrad(gx Tensor) - // ZeroGrad zeroes the gradients, setting the value of Grad to nil. - ZeroGrad() } func init() { diff --git a/mat/tensor.go b/mat/tensor.go index c96f4dbc..c9d7ecbf 100644 --- a/mat/tensor.go +++ b/mat/tensor.go @@ -4,7 +4,11 @@ package mat -import "github.com/nlpodyssey/spago/mat/float" +import ( + "encoding/gob" + + "github.com/nlpodyssey/spago/mat/float" +) // Tensor represents an interface for a generic tensor. type Tensor interface { @@ -19,6 +23,12 @@ type Tensor interface { // Item returns the scalar value. // It panics if the matrix does not contain exactly one element. Item() float.Float + // SetAt sets the value at the given indices. + // It panics if the given indices are out of range. + SetAt(m Tensor, indices ...int) + // At returns the value at the given indices. + // It panics if the given indices are out of range. + At(indices ...int) Tensor // Value returns the value of the node. // In case of a leaf node, it returns the value of the underlying matrix. // In case of a non-leaf node, it returns the value of the operation performed during the forward pass. @@ -35,3 +45,7 @@ type Tensor interface { // ZeroGrad zeroes the gradients, setting the value of Grad to nil. ZeroGrad() } + +func init() { + gob.Register([]Tensor{}) +}