Skip to content

Commit

Permalink
Updated failure estimator (#3427)
Browse files Browse the repository at this point in the history
* Updated failure estimator

* Added simulator script

* Cleanup simulator script

* Update SortFunc

* Update SortFunc

* Update SortFunc

* Update SortFunc

* Update SortFunc

* Update SortFunc

* Update SortFunc

* Update SortFunc

* Update SortFunc

* Comments and cleanup
  • Loading branch information
severinson authored Feb 28, 2024
1 parent 211878b commit 8e52353
Show file tree
Hide file tree
Showing 31 changed files with 1,069 additions and 420 deletions.
12 changes: 5 additions & 7 deletions config/scheduler/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,8 @@ scheduling:
gangIdAnnotation: armadaproject.io/gangId
gangCardinalityAnnotation: armadaproject.io/gangCardinality
failureEstimatorConfig:
nodeSuccessProbabilityCordonThreshold: 0.1
queueSuccessProbabilityCordonThreshold: 0.05
nodeCordonTimeout: "10m"
queueCordonTimeout: "1m"
nodeEquilibriumFailureRate: 0.0167 # 1 per minute.
queueEquilibriumFailureRate: 0.0167 # 1 per minute.

# Optimised default parameters.
numInnerIterations: 10
innerOptimiserStepSize: 0.05
outerOptimiserStepSize: 0.05
outerOptimiserNesterovAcceleration: 0.2
17 changes: 9 additions & 8 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ require (
github.com/spf13/viper v1.15.0
github.com/stretchr/testify v1.8.4
github.com/weaveworks/promrus v1.2.0
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326
golang.org/x/net v0.20.0
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225
golang.org/x/net v0.21.0
golang.org/x/oauth2 v0.16.0
golang.org/x/sync v0.5.0
golang.org/x/tools v0.6.0 // indirect
golang.org/x/sync v0.6.0
golang.org/x/tools v0.18.0 // indirect
google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 // indirect
google.golang.org/grpc v1.57.1
gopkg.in/yaml.v2 v2.4.0
Expand Down Expand Up @@ -89,6 +89,7 @@ require (
github.com/segmentio/fasthash v1.0.3
github.com/xitongsys/parquet-go v1.6.2
golang.org/x/time v0.3.0
gonum.org/v1/gonum v0.14.0
google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9
gopkg.in/yaml.v3 v3.0.1
)
Expand Down Expand Up @@ -184,10 +185,10 @@ require (
github.com/xitongsys/parquet-go-source v0.0.0-20200817004010-026bad9b25d0 // indirect
go.mongodb.org/mongo-driver v1.13.1 // indirect
go.uber.org/atomic v1.9.0 // indirect
golang.org/x/crypto v0.18.0 // indirect
golang.org/x/mod v0.9.0 // indirect
golang.org/x/sys v0.16.0 // indirect
golang.org/x/term v0.16.0 // indirect
golang.org/x/crypto v0.19.0 // indirect
golang.org/x/mod v0.15.0 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/term v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
google.golang.org/appengine v1.6.7 // indirect
Expand Down
34 changes: 18 additions & 16 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -857,8 +857,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
Expand All @@ -869,8 +869,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 h1:QfTh0HpN6hlw6D3vu8DAwC8pBIwikq0AI1evdm+FksE=
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
Expand All @@ -897,8 +897,8 @@ golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs=
golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0 h1:SernR4v+D55NyBH2QiEQrlBAnj1ECL6AGrA5+dPaMY8=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
Expand Down Expand Up @@ -947,8 +947,8 @@ golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
Expand All @@ -974,8 +974,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180821044426-4ea2f632f6e9/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
Expand Down Expand Up @@ -1049,14 +1049,14 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.0.0-20180810153555-6e3c4e7365dd/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down Expand Up @@ -1138,8 +1138,8 @@ golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ=
golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
Expand All @@ -1148,6 +1148,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk=
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
Expand Down
12 changes: 5 additions & 7 deletions internal/armada/configuration/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,11 @@ type PreemptionConfig struct {
// FailureEstimatorConfig contains config controlling node and queue success probability estimation.
// See the internal/scheduler/failureestimator package for details.
type FailureEstimatorConfig struct {
Disabled bool
NodeSuccessProbabilityCordonThreshold float64
QueueSuccessProbabilityCordonThreshold float64
NodeCordonTimeout time.Duration
QueueCordonTimeout time.Duration
NodeEquilibriumFailureRate float64
QueueEquilibriumFailureRate float64
Disabled bool
NumInnerIterations int `validate:"gt=0"`
InnerOptimiserStepSize float64 `validate:"gt=0"`
OuterOptimiserStepSize float64 `validate:"gt=0"`
OuterOptimiserNesterovAcceleration float64 `validate:"gte=0"`
}

// TODO: we can probably just typedef this to map[string]string
Expand Down
11 changes: 9 additions & 2 deletions internal/common/interfaces/interfaces.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
package interfaces

// DeepCopier expresses that the object can be deep-copied.
import "golang.org/x/exp/constraints"

// DeepCopier represents object that can be deep-copied.
type DeepCopier[T any] interface {
// DeepCopy returns a deep copy of the object.
DeepCopy() T
}

// Equaler expresses that objects can be compared for equality via the Equals method.
// Equaler represents objects can be compared for equality via the Equals method.
type Equaler[T any] interface {
// Returns true if both objects are equal.
Equal(T) bool
}

// Number represents any integer or floating-point number.
type Number interface {
constraints.Integer | constraints.Float
}
19 changes: 19 additions & 0 deletions internal/common/linalg/linalg.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package linalg

import "gonum.org/v1/gonum/mat"

// ExtendVecDense extends the length of vec in-place to be at least n.
func ExtendVecDense(vec *mat.VecDense, n int) *mat.VecDense {
if vec == nil {
return mat.NewVecDense(n, make([]float64, n))
}
rawVec := vec.RawVector()
d := n - rawVec.N
if d <= 0 {
return vec
}
rawVec.Data = append(rawVec.Data, make([]float64, d)...)
rawVec.N = n
vec.SetRawVector(rawVec)
return vec
}
45 changes: 45 additions & 0 deletions internal/common/linalg/lingalg_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package linalg

import (
"testing"

"github.com/stretchr/testify/assert"
"gonum.org/v1/gonum/mat"

armadaslices "github.com/armadaproject/armada/internal/common/slices"
)

func TestExtendVecDense(t *testing.T) {
tests := map[string]struct {
vec *mat.VecDense
n int
expected *mat.VecDense
}{
"nil vec": {
vec: nil,
n: 3,
expected: mat.NewVecDense(3, armadaslices.Zeros[float64](3)),
},
"extend": {
vec: mat.NewVecDense(1, armadaslices.Zeros[float64](1)),
n: 3,
expected: mat.NewVecDense(3, armadaslices.Zeros[float64](3)),
},
"extend unnecessary due to greater length": {
vec: mat.NewVecDense(3, armadaslices.Zeros[float64](3)),
n: 1,
expected: mat.NewVecDense(3, armadaslices.Zeros[float64](3)),
},
"extend unnecessary due to equal length": {
vec: mat.NewVecDense(3, armadaslices.Zeros[float64](3)),
n: 3,
expected: mat.NewVecDense(3, armadaslices.Zeros[float64](3)),
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
actual := ExtendVecDense(tc.vec, tc.n)
assert.Equal(t, tc.expected, actual)
})
}
}
44 changes: 44 additions & 0 deletions internal/common/optimisation/descent/descent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package descent

import (
"fmt"

"github.com/pkg/errors"
"gonum.org/v1/gonum/mat"

"github.com/armadaproject/armada/internal/common/armadaerrors"
)

// Gradient descent optimiser; see the following link for details:
// https://fluxml.ai/Flux.jl/stable/training/optimisers/
type Descent struct {
eta float64
}

func New(eta float64) (*Descent, error) {
if eta < 0 {
return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{
Name: "eta",
Value: eta,
Message: fmt.Sprintf("outside allowed range [0, Inf)"),
})
}
return &Descent{eta: eta}, nil
}

func MustNew(eta float64) *Descent {
opt, err := New(eta)
if err != nil {
panic(err)
}
return opt
}

func (o *Descent) Update(out, p *mat.VecDense, g mat.Vector) *mat.VecDense {
out.AddScaledVec(p, -o.eta, g)
return p
}

func (o *Descent) Extend(_ int) {
return
}
44 changes: 44 additions & 0 deletions internal/common/optimisation/descent/descent_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package descent

import (
"testing"

"github.com/stretchr/testify/assert"
"gonum.org/v1/gonum/mat"

armadaslices "github.com/armadaproject/armada/internal/common/slices"
)

func TestDescent(t *testing.T) {
tests := map[string]struct {
eta float64
p *mat.VecDense
g *mat.VecDense
expected *mat.VecDense
}{
"eta is zero": {
eta: 0.0,
p: mat.NewVecDense(2, armadaslices.Ones[float64](2)),
g: mat.NewVecDense(2, armadaslices.Ones[float64](2)),
expected: mat.NewVecDense(2, armadaslices.Ones[float64](2)),
},
"eta is non-zero": {
eta: 2.0,
p: mat.NewVecDense(2, armadaslices.Zeros[float64](2)),
g: mat.NewVecDense(2, armadaslices.Ones[float64](2)),
expected: func() *mat.VecDense {
rv := mat.NewVecDense(2, armadaslices.Ones[float64](2))
rv.ScaleVec(-2, rv)
return rv
}(),
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
opt := MustNew(tc.eta)
rv := opt.Update(tc.p, tc.p, tc.g)
assert.Equal(t, tc.p, rv)
assert.Equal(t, tc.expected, tc.p)
})
}
}
Loading

0 comments on commit 8e52353

Please sign in to comment.