-
Notifications
You must be signed in to change notification settings - Fork 136
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Updated failure estimator * Added simulator script * Cleanup simulator script * Update SortFunc * Update SortFunc * Update SortFunc * Update SortFunc * Update SortFunc * Update SortFunc * Update SortFunc * Update SortFunc * Update SortFunc * Comments and cleanup
- Loading branch information
1 parent
211878b
commit 8e52353
Showing
31 changed files
with
1,069 additions
and
420 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,20 @@ | ||
package interfaces | ||
|
||
// DeepCopier expresses that the object can be deep-copied. | ||
import "golang.org/x/exp/constraints" | ||
|
||
// DeepCopier represents object that can be deep-copied. | ||
type DeepCopier[T any] interface { | ||
// DeepCopy returns a deep copy of the object. | ||
DeepCopy() T | ||
} | ||
|
||
// Equaler expresses that objects can be compared for equality via the Equals method. | ||
// Equaler represents objects can be compared for equality via the Equals method. | ||
type Equaler[T any] interface { | ||
// Returns true if both objects are equal. | ||
Equal(T) bool | ||
} | ||
|
||
// Number represents any integer or floating-point number. | ||
type Number interface { | ||
constraints.Integer | constraints.Float | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package linalg | ||
|
||
import "gonum.org/v1/gonum/mat" | ||
|
||
// ExtendVecDense extends the length of vec in-place to be at least n. | ||
func ExtendVecDense(vec *mat.VecDense, n int) *mat.VecDense { | ||
if vec == nil { | ||
return mat.NewVecDense(n, make([]float64, n)) | ||
} | ||
rawVec := vec.RawVector() | ||
d := n - rawVec.N | ||
if d <= 0 { | ||
return vec | ||
} | ||
rawVec.Data = append(rawVec.Data, make([]float64, d)...) | ||
rawVec.N = n | ||
vec.SetRawVector(rawVec) | ||
return vec | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
package linalg | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"gonum.org/v1/gonum/mat" | ||
|
||
armadaslices "github.com/armadaproject/armada/internal/common/slices" | ||
) | ||
|
||
func TestExtendVecDense(t *testing.T) { | ||
tests := map[string]struct { | ||
vec *mat.VecDense | ||
n int | ||
expected *mat.VecDense | ||
}{ | ||
"nil vec": { | ||
vec: nil, | ||
n: 3, | ||
expected: mat.NewVecDense(3, armadaslices.Zeros[float64](3)), | ||
}, | ||
"extend": { | ||
vec: mat.NewVecDense(1, armadaslices.Zeros[float64](1)), | ||
n: 3, | ||
expected: mat.NewVecDense(3, armadaslices.Zeros[float64](3)), | ||
}, | ||
"extend unnecessary due to greater length": { | ||
vec: mat.NewVecDense(3, armadaslices.Zeros[float64](3)), | ||
n: 1, | ||
expected: mat.NewVecDense(3, armadaslices.Zeros[float64](3)), | ||
}, | ||
"extend unnecessary due to equal length": { | ||
vec: mat.NewVecDense(3, armadaslices.Zeros[float64](3)), | ||
n: 3, | ||
expected: mat.NewVecDense(3, armadaslices.Zeros[float64](3)), | ||
}, | ||
} | ||
for name, tc := range tests { | ||
t.Run(name, func(t *testing.T) { | ||
actual := ExtendVecDense(tc.vec, tc.n) | ||
assert.Equal(t, tc.expected, actual) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package descent | ||
|
||
import ( | ||
"fmt" | ||
|
||
"github.com/pkg/errors" | ||
"gonum.org/v1/gonum/mat" | ||
|
||
"github.com/armadaproject/armada/internal/common/armadaerrors" | ||
) | ||
|
||
// Gradient descent optimiser; see the following link for details: | ||
// https://fluxml.ai/Flux.jl/stable/training/optimisers/ | ||
type Descent struct { | ||
eta float64 | ||
} | ||
|
||
func New(eta float64) (*Descent, error) { | ||
if eta < 0 { | ||
return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ | ||
Name: "eta", | ||
Value: eta, | ||
Message: fmt.Sprintf("outside allowed range [0, Inf)"), | ||
}) | ||
} | ||
return &Descent{eta: eta}, nil | ||
} | ||
|
||
func MustNew(eta float64) *Descent { | ||
opt, err := New(eta) | ||
if err != nil { | ||
panic(err) | ||
} | ||
return opt | ||
} | ||
|
||
func (o *Descent) Update(out, p *mat.VecDense, g mat.Vector) *mat.VecDense { | ||
out.AddScaledVec(p, -o.eta, g) | ||
return p | ||
} | ||
|
||
func (o *Descent) Extend(_ int) { | ||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package descent | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"gonum.org/v1/gonum/mat" | ||
|
||
armadaslices "github.com/armadaproject/armada/internal/common/slices" | ||
) | ||
|
||
func TestDescent(t *testing.T) { | ||
tests := map[string]struct { | ||
eta float64 | ||
p *mat.VecDense | ||
g *mat.VecDense | ||
expected *mat.VecDense | ||
}{ | ||
"eta is zero": { | ||
eta: 0.0, | ||
p: mat.NewVecDense(2, armadaslices.Ones[float64](2)), | ||
g: mat.NewVecDense(2, armadaslices.Ones[float64](2)), | ||
expected: mat.NewVecDense(2, armadaslices.Ones[float64](2)), | ||
}, | ||
"eta is non-zero": { | ||
eta: 2.0, | ||
p: mat.NewVecDense(2, armadaslices.Zeros[float64](2)), | ||
g: mat.NewVecDense(2, armadaslices.Ones[float64](2)), | ||
expected: func() *mat.VecDense { | ||
rv := mat.NewVecDense(2, armadaslices.Ones[float64](2)) | ||
rv.ScaleVec(-2, rv) | ||
return rv | ||
}(), | ||
}, | ||
} | ||
for name, tc := range tests { | ||
t.Run(name, func(t *testing.T) { | ||
opt := MustNew(tc.eta) | ||
rv := opt.Update(tc.p, tc.p, tc.g) | ||
assert.Equal(t, tc.p, rv) | ||
assert.Equal(t, tc.expected, tc.p) | ||
}) | ||
} | ||
} |
Oops, something went wrong.