Skip to content

Commit

Permalink
Refactor Matrix interface to extend from Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
matteo-grella committed Oct 29, 2023
1 parent 9244ad1 commit dea27c8
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 33 deletions.
12 changes: 12 additions & 0 deletions ag/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ func NewOperator(f AutoGradFunction) *Operator {
return &Operator{fn: f}
}

// SetAt sets the value at the given indices.
// It panics if the given indices are out of range.
func (o *Operator) SetAt(m mat.Tensor, indices ...int) {
o.Value().SetAt(m, indices...)
}

// At returns the value at the given indices.
// It panics if the given indices are out of range.
func (o *Operator) At(indices ...int) mat.Tensor {
return o.Value().At(indices...)
}

// Run starts the execution of the operator, performing the forward pass.
// If the optional async argument is set to true, the forward pass will be executed in a separate goroutine.
// The function returns a pointer to the Operator, allowing for method chaining.
Expand Down
4 changes: 2 additions & 2 deletions mat/dense.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ func (d *Dense[T]) Zeros() {

// SetAt sets the value m at the given indices.
// It panics if the given indices are out of range.
func (d *Dense[T]) SetAt(m Matrix, indices ...int) {
func (d *Dense[T]) SetAt(m Tensor, indices ...int) {
d.set(float.ValueOf[T](m.Item()), indices...)
}

// At returns the value at the given indices.
// It panics if the given indices are out of range.
func (d *Dense[T]) At(i ...int) Matrix {
func (d *Dense[T]) At(i ...int) Tensor {
return Scalar[T](d.at(i...))
}

Expand Down
32 changes: 2 additions & 30 deletions mat/matrix.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,8 @@ import (
// such as element-wise addition, subtraction, product and matrix-matrix
// multiplication.
type Matrix interface {
// Shape returns the size in each dimension.
Shape() []int
// Dims returns the number of dimensions.
Dims() int
// Size returns the total number of elements.
Size() int
// Data returns the underlying data of the matrix, as a raw one-dimensional
// slice of values in row-major order.
Data() float.Slice
Tensor

// SetData sets the content of the matrix, copying the given raw
// data representation as one-dimensional slice.
SetData(data float.Slice)
Expand All @@ -38,12 +31,6 @@ type Matrix interface {
Item() float.Float
// Zeros sets all the values of the matrix to zero.
Zeros()
// SetAt sets the value at the given indices.
// It panics if the given indices are out of range.
SetAt(m Matrix, indices ...int)
// At returns the value at the given indices.
// It panics if the given indices are out of range.
At(indices ...int) Matrix
// SetScalar sets the value at the given indices.
// It panics if the given indices are out of range.
SetScalar(v float.Float, indices ...int)
Expand Down Expand Up @@ -235,22 +222,7 @@ type Matrix interface {
// them as row vectors.
NewStack(vs ...Matrix) Matrix

// Value returns the Matrix itself.
Value() Tensor
// Grad returns the accumulated gradients with the AccGrad method.
// A matrix full of zeros and the nil value are considered equivalent.
Grad() Tensor
// HasGrad reports whether there are accumulated gradients.
HasGrad() bool
// RequiresGrad reports whether the Matrix requires gradients.
// It is set by the SetRequiresGrad method or the functional options WithGrad.
RequiresGrad() bool
// SetRequiresGrad sets whether the Matrix requires gradients.
SetRequiresGrad(bool)
// AccGrad accumulates the gradients.
AccGrad(gx Tensor)
// ZeroGrad zeroes the gradients, setting the value of Grad to nil.
ZeroGrad()
}

func init() {
Expand Down
16 changes: 15 additions & 1 deletion mat/tensor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

package mat

import "github.com/nlpodyssey/spago/mat/float"
import (
"encoding/gob"

"github.com/nlpodyssey/spago/mat/float"
)

// Tensor represents an interface for a generic tensor.
type Tensor interface {
Expand All @@ -19,6 +23,12 @@ type Tensor interface {
// Item returns the scalar value.
// It panics if the matrix does not contain exactly one element.
Item() float.Float
// SetAt sets the value at the given indices.
// It panics if the given indices are out of range.
SetAt(m Tensor, indices ...int)
// At returns the value at the given indices.
// It panics if the given indices are out of range.
At(indices ...int) Tensor
// Value returns the value of the node.
// In case of a leaf node, it returns the value of the underlying matrix.
// In case of a non-leaf node, it returns the value of the operation performed during the forward pass.
Expand All @@ -35,3 +45,7 @@ type Tensor interface {
// ZeroGrad zeroes the gradients, setting the value of Grad to nil.
ZeroGrad()
}

func init() {
gob.Register([]Tensor{})
}

0 comments on commit dea27c8

Please sign in to comment.