diff --git a/config/scheduler/config.yaml b/config/scheduler/config.yaml index 7f29baa6db9..739127ce978 100644 --- a/config/scheduler/config.yaml +++ b/config/scheduler/config.yaml @@ -140,10 +140,8 @@ scheduling: gangIdAnnotation: armadaproject.io/gangId gangCardinalityAnnotation: armadaproject.io/gangCardinality failureEstimatorConfig: - nodeSuccessProbabilityCordonThreshold: 0.1 - queueSuccessProbabilityCordonThreshold: 0.05 - nodeCordonTimeout: "10m" - queueCordonTimeout: "1m" - nodeEquilibriumFailureRate: 0.0167 # 1 per minute. - queueEquilibriumFailureRate: 0.0167 # 1 per minute. - + # Optimised default parameters. + numInnerIterations: 10 + innerOptimiserStepSize: 0.05 + outerOptimiserStepSize: 0.05 + outerOptimiserNesterovAcceleration: 0.2 diff --git a/go.mod b/go.mod index 02768131905..a896a0a5309 100644 --- a/go.mod +++ b/go.mod @@ -49,11 +49,11 @@ require ( github.com/spf13/viper v1.15.0 github.com/stretchr/testify v1.8.4 github.com/weaveworks/promrus v1.2.0 - golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 - golang.org/x/net v0.20.0 + golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 + golang.org/x/net v0.21.0 golang.org/x/oauth2 v0.16.0 - golang.org/x/sync v0.5.0 - golang.org/x/tools v0.6.0 // indirect + golang.org/x/sync v0.6.0 + golang.org/x/tools v0.18.0 // indirect google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 // indirect google.golang.org/grpc v1.57.1 gopkg.in/yaml.v2 v2.4.0 @@ -89,6 +89,7 @@ require ( github.com/segmentio/fasthash v1.0.3 github.com/xitongsys/parquet-go v1.6.2 golang.org/x/time v0.3.0 + gonum.org/v1/gonum v0.14.0 google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9 gopkg.in/yaml.v3 v3.0.1 ) @@ -184,10 +185,10 @@ require ( github.com/xitongsys/parquet-go-source v0.0.0-20200817004010-026bad9b25d0 // indirect go.mongodb.org/mongo-driver v1.13.1 // indirect go.uber.org/atomic v1.9.0 // indirect - golang.org/x/crypto v0.18.0 // indirect - golang.org/x/mod v0.9.0 // indirect - golang.org/x/sys v0.16.0 // indirect - golang.org/x/term v0.16.0 // indirect + golang.org/x/crypto v0.19.0 // indirect + golang.org/x/mod v0.15.0 // indirect + golang.org/x/sys v0.17.0 // indirect + golang.org/x/term v0.17.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/go.sum b/go.sum index 564da9d52ab..6b0692e99e5 100644 --- a/go.sum +++ b/go.sum @@ -857,8 +857,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= -golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -869,8 +869,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 h1:QfTh0HpN6hlw6D3vu8DAwC8pBIwikq0AI1evdm+FksE= -golang.org/x/exp v0.0.0-20221031165847-c99f073a8326/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= +golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -897,8 +897,8 @@ golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= -golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.15.0 h1:SernR4v+D55NyBH2QiEQrlBAnj1ECL6AGrA5+dPaMY8= +golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -947,8 +947,8 @@ golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= -golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -974,8 +974,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= -golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180821044426-4ea2f632f6e9/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1049,14 +1049,14 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE= -golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= +golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20180810153555-6e3c4e7365dd/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1138,8 +1138,8 @@ golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ= +golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1148,6 +1148,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk= golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= +gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= diff --git a/internal/armada/configuration/types.go b/internal/armada/configuration/types.go index 3ac46696dd7..4cbcdb5f257 100644 --- a/internal/armada/configuration/types.go +++ b/internal/armada/configuration/types.go @@ -295,13 +295,11 @@ type PreemptionConfig struct { // FailureEstimatorConfig contains config controlling node and queue success probability estimation. // See the internal/scheduler/failureestimator package for details. type FailureEstimatorConfig struct { - Disabled bool - NodeSuccessProbabilityCordonThreshold float64 - QueueSuccessProbabilityCordonThreshold float64 - NodeCordonTimeout time.Duration - QueueCordonTimeout time.Duration - NodeEquilibriumFailureRate float64 - QueueEquilibriumFailureRate float64 + Disabled bool + NumInnerIterations int `validate:"gt=0"` + InnerOptimiserStepSize float64 `validate:"gt=0"` + OuterOptimiserStepSize float64 `validate:"gt=0"` + OuterOptimiserNesterovAcceleration float64 `validate:"gte=0"` } // TODO: we can probably just typedef this to map[string]string diff --git a/internal/common/interfaces/interfaces.go b/internal/common/interfaces/interfaces.go index 91abe495611..3275e3622bd 100644 --- a/internal/common/interfaces/interfaces.go +++ b/internal/common/interfaces/interfaces.go @@ -1,13 +1,20 @@ package interfaces -// DeepCopier expresses that the object can be deep-copied. +import "golang.org/x/exp/constraints" + +// DeepCopier represents object that can be deep-copied. type DeepCopier[T any] interface { // DeepCopy returns a deep copy of the object. DeepCopy() T } -// Equaler expresses that objects can be compared for equality via the Equals method. +// Equaler represents objects can be compared for equality via the Equals method. type Equaler[T any] interface { // Returns true if both objects are equal. Equal(T) bool } + +// Number represents any integer or floating-point number. +type Number interface { + constraints.Integer | constraints.Float +} diff --git a/internal/common/linalg/linalg.go b/internal/common/linalg/linalg.go new file mode 100644 index 00000000000..c807ca651d0 --- /dev/null +++ b/internal/common/linalg/linalg.go @@ -0,0 +1,19 @@ +package linalg + +import "gonum.org/v1/gonum/mat" + +// ExtendVecDense extends the length of vec in-place to be at least n. +func ExtendVecDense(vec *mat.VecDense, n int) *mat.VecDense { + if vec == nil { + return mat.NewVecDense(n, make([]float64, n)) + } + rawVec := vec.RawVector() + d := n - rawVec.N + if d <= 0 { + return vec + } + rawVec.Data = append(rawVec.Data, make([]float64, d)...) + rawVec.N = n + vec.SetRawVector(rawVec) + return vec +} diff --git a/internal/common/linalg/lingalg_test.go b/internal/common/linalg/lingalg_test.go new file mode 100644 index 00000000000..200d9c0aacf --- /dev/null +++ b/internal/common/linalg/lingalg_test.go @@ -0,0 +1,45 @@ +package linalg + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "gonum.org/v1/gonum/mat" + + armadaslices "github.com/armadaproject/armada/internal/common/slices" +) + +func TestExtendVecDense(t *testing.T) { + tests := map[string]struct { + vec *mat.VecDense + n int + expected *mat.VecDense + }{ + "nil vec": { + vec: nil, + n: 3, + expected: mat.NewVecDense(3, armadaslices.Zeros[float64](3)), + }, + "extend": { + vec: mat.NewVecDense(1, armadaslices.Zeros[float64](1)), + n: 3, + expected: mat.NewVecDense(3, armadaslices.Zeros[float64](3)), + }, + "extend unnecessary due to greater length": { + vec: mat.NewVecDense(3, armadaslices.Zeros[float64](3)), + n: 1, + expected: mat.NewVecDense(3, armadaslices.Zeros[float64](3)), + }, + "extend unnecessary due to equal length": { + vec: mat.NewVecDense(3, armadaslices.Zeros[float64](3)), + n: 3, + expected: mat.NewVecDense(3, armadaslices.Zeros[float64](3)), + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + actual := ExtendVecDense(tc.vec, tc.n) + assert.Equal(t, tc.expected, actual) + }) + } +} diff --git a/internal/common/optimisation/descent/descent.go b/internal/common/optimisation/descent/descent.go new file mode 100644 index 00000000000..c57051084ac --- /dev/null +++ b/internal/common/optimisation/descent/descent.go @@ -0,0 +1,44 @@ +package descent + +import ( + "fmt" + + "github.com/pkg/errors" + "gonum.org/v1/gonum/mat" + + "github.com/armadaproject/armada/internal/common/armadaerrors" +) + +// Gradient descent optimiser; see the following link for details: +// https://fluxml.ai/Flux.jl/stable/training/optimisers/ +type Descent struct { + eta float64 +} + +func New(eta float64) (*Descent, error) { + if eta < 0 { + return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ + Name: "eta", + Value: eta, + Message: fmt.Sprintf("outside allowed range [0, Inf)"), + }) + } + return &Descent{eta: eta}, nil +} + +func MustNew(eta float64) *Descent { + opt, err := New(eta) + if err != nil { + panic(err) + } + return opt +} + +func (o *Descent) Update(out, p *mat.VecDense, g mat.Vector) *mat.VecDense { + out.AddScaledVec(p, -o.eta, g) + return p +} + +func (o *Descent) Extend(_ int) { + return +} diff --git a/internal/common/optimisation/descent/descent_test.go b/internal/common/optimisation/descent/descent_test.go new file mode 100644 index 00000000000..83571b7ea46 --- /dev/null +++ b/internal/common/optimisation/descent/descent_test.go @@ -0,0 +1,44 @@ +package descent + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "gonum.org/v1/gonum/mat" + + armadaslices "github.com/armadaproject/armada/internal/common/slices" +) + +func TestDescent(t *testing.T) { + tests := map[string]struct { + eta float64 + p *mat.VecDense + g *mat.VecDense + expected *mat.VecDense + }{ + "eta is zero": { + eta: 0.0, + p: mat.NewVecDense(2, armadaslices.Ones[float64](2)), + g: mat.NewVecDense(2, armadaslices.Ones[float64](2)), + expected: mat.NewVecDense(2, armadaslices.Ones[float64](2)), + }, + "eta is non-zero": { + eta: 2.0, + p: mat.NewVecDense(2, armadaslices.Zeros[float64](2)), + g: mat.NewVecDense(2, armadaslices.Ones[float64](2)), + expected: func() *mat.VecDense { + rv := mat.NewVecDense(2, armadaslices.Ones[float64](2)) + rv.ScaleVec(-2, rv) + return rv + }(), + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + opt := MustNew(tc.eta) + rv := opt.Update(tc.p, tc.p, tc.g) + assert.Equal(t, tc.p, rv) + assert.Equal(t, tc.expected, tc.p) + }) + } +} diff --git a/internal/common/optimisation/nesterov/nesterov.go b/internal/common/optimisation/nesterov/nesterov.go new file mode 100644 index 00000000000..050f91f5e7f --- /dev/null +++ b/internal/common/optimisation/nesterov/nesterov.go @@ -0,0 +1,60 @@ +package nesterov + +import ( + "fmt" + "math" + + "github.com/pkg/errors" + "gonum.org/v1/gonum/mat" + + "github.com/armadaproject/armada/internal/common/armadaerrors" + "github.com/armadaproject/armada/internal/common/linalg" +) + +// Nesterov accelerated gradient descent optimiser; see the following link for details: +// https://fluxml.ai/Flux.jl/stable/training/optimisers/ +type Nesterov struct { + eta float64 + rho float64 + vel *mat.VecDense +} + +func New(eta, rho float64) (*Nesterov, error) { + if eta < 0 { + return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ + Name: "eta", + Value: eta, + Message: fmt.Sprintf("outside allowed range [0, Inf)"), + }) + } + if rho < 0 || rho >= 1 { + return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ + Name: "rho", + Value: rho, + Message: fmt.Sprintf("outside allowed range [0, 1)"), + }) + } + return &Nesterov{eta: eta, rho: rho}, nil +} + +func MustNew(eta, rho float64) *Nesterov { + opt, err := New(eta, rho) + if err != nil { + panic(err) + } + return opt +} + +func (o *Nesterov) Update(out, p *mat.VecDense, g mat.Vector) *mat.VecDense { + out.CopyVec(p) + out.AddScaledVec(out, math.Pow(o.rho, 2), o.vel) + out.AddScaledVec(out, -(1+o.rho)*o.eta, g) + + o.vel.ScaleVec(o.rho, o.vel) + o.vel.AddScaledVec(o.vel, -o.eta, g) + return p +} + +func (o *Nesterov) Extend(n int) { + o.vel = linalg.ExtendVecDense(o.vel, n) +} diff --git a/internal/common/optimisation/nesterov/nesterov_test.go b/internal/common/optimisation/nesterov/nesterov_test.go new file mode 100644 index 00000000000..24df99f2120 --- /dev/null +++ b/internal/common/optimisation/nesterov/nesterov_test.go @@ -0,0 +1,83 @@ +package nesterov + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "gonum.org/v1/gonum/mat" + + armadaslices "github.com/armadaproject/armada/internal/common/slices" +) + +func TestNesterov(t *testing.T) { + tests := map[string]struct { + eta float64 + rho float64 + p0 *mat.VecDense + gs []*mat.VecDense + expecteds []*mat.VecDense + }{ + "eta is zero": { + eta: 0.0, + rho: 0.9, + p0: mat.NewVecDense(2, armadaslices.Ones[float64](2)), + gs: []*mat.VecDense{ + mat.NewVecDense(2, armadaslices.Ones[float64](2)), + mat.NewVecDense(2, armadaslices.Ones[float64](2)), + }, + expecteds: []*mat.VecDense{ + mat.NewVecDense(2, armadaslices.Ones[float64](2)), + mat.NewVecDense(2, armadaslices.Ones[float64](2)), + }, + }, + "rho is zero": { + eta: 2.0, + rho: 0.0, + p0: mat.NewVecDense(2, armadaslices.Zeros[float64](2)), + gs: []*mat.VecDense{ + mat.NewVecDense(2, armadaslices.Ones[float64](2)), + mat.NewVecDense(2, armadaslices.Ones[float64](2)), + }, + expecteds: func() []*mat.VecDense { + rv := make([]*mat.VecDense, 2) + rv[0] = mat.NewVecDense(2, armadaslices.Ones[float64](2)) + rv[0].ScaleVec(-2, rv[0]) + rv[1] = mat.NewVecDense(2, armadaslices.Ones[float64](2)) + rv[1].ScaleVec(-4, rv[1]) + return rv + }(), + }, + "eta and rho non-zero": { + eta: 2.0, + rho: 0.5, + p0: mat.NewVecDense(2, armadaslices.Zeros[float64](2)), + gs: []*mat.VecDense{ + mat.NewVecDense(2, armadaslices.Ones[float64](2)), + mat.NewVecDense(2, armadaslices.Ones[float64](2)), + mat.NewVecDense(2, armadaslices.Ones[float64](2)), + }, + expecteds: func() []*mat.VecDense { + rv := make([]*mat.VecDense, 3) + rv[0] = mat.NewVecDense(2, armadaslices.Ones[float64](2)) + rv[0].ScaleVec(-3, rv[0]) + rv[1] = mat.NewVecDense(2, armadaslices.Ones[float64](2)) + rv[1].ScaleVec(-6.5, rv[1]) + rv[2] = mat.NewVecDense(2, armadaslices.Ones[float64](2)) + rv[2].ScaleVec(-10.25, rv[2]) + return rv + }(), + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + opt := MustNew(tc.eta, tc.rho) + p := tc.p0 + for i, g := range tc.gs { + opt.Extend(g.Len()) + rv := opt.Update(p, p, g) + assert.Equal(t, p, rv) + assert.Equal(t, tc.expecteds[i], p) + } + }) + } +} diff --git a/internal/common/optimisation/optimisation.go b/internal/common/optimisation/optimisation.go new file mode 100644 index 00000000000..4bad4826f72 --- /dev/null +++ b/internal/common/optimisation/optimisation.go @@ -0,0 +1,11 @@ +package optimisation + +import "gonum.org/v1/gonum/mat" + +// Optimiser represents a first-order optimisation algorithm. +type Optimiser interface { + // Update the parameters using gradient and store the result in out. + Update(out, parameters *mat.VecDense, gradient mat.Vector) *mat.VecDense + // Extend the internal state of the optimiser to accommodate at least n parameters. + Extend(n int) +} diff --git a/internal/common/slices/slices.go b/internal/common/slices/slices.go index 2cfaf7c5710..9063bf85053 100644 --- a/internal/common/slices/slices.go +++ b/internal/common/slices/slices.go @@ -6,6 +6,8 @@ import ( "math/rand" goslices "golang.org/x/exp/slices" + + "github.com/armadaproject/armada/internal/common/interfaces" ) // PartitionToLen partitions the elements of s into non-overlapping slices, @@ -178,3 +180,22 @@ func AnyFunc[S ~[]T, T any](s S, predicate func(val T) bool) bool { } return false } + +// Zeros returns a slice T[] of length n with all elements equal to zero. +func Zeros[T any](n int) []T { + return make([]T, n) +} + +// Fill returns a slice T[] of length n with all elements equal to v. +func Fill[T any](v T, n int) []T { + rv := make([]T, n) + for i := range rv { + rv[i] = v + } + return rv +} + +// Ones returns a slice T[] of length n with all elements equal to 1. +func Ones[T interfaces.Number](n int) []T { + return Fill[T](1, n) +} diff --git a/internal/common/slices/slices_test.go b/internal/common/slices/slices_test.go index 4d57dbf7fac..c7bf81cc00c 100644 --- a/internal/common/slices/slices_test.go +++ b/internal/common/slices/slices_test.go @@ -326,3 +326,18 @@ func TestAny(t *testing.T) { AnyFunc([]int{1, 2, 3}, func(v int) bool { return v > 3 }), ) } + +func TestZeros(t *testing.T) { + assert.Equal(t, make([]int, 3), Zeros[int](3)) + assert.Equal(t, make([]string, 3), Zeros[string](3)) +} + +func TestOnes(t *testing.T) { + assert.Equal(t, []int{1, 1, 1}, Ones[int](3)) + assert.Equal(t, []float64{1, 1, 1}, Ones[float64](3)) +} + +func TestFill(t *testing.T) { + assert.Equal(t, []int{2, 2, 2}, Fill[int](2, 3)) + assert.Equal(t, []float64{0.5, 0.5, 0.5}, Fill[float64](0.5, 3)) +} diff --git a/internal/scheduler/context/context.go b/internal/scheduler/context/context.go index 571338f8963..1db201cb4e6 100644 --- a/internal/scheduler/context/context.go +++ b/internal/scheduler/context/context.go @@ -440,7 +440,15 @@ func (qctx *QueueSchedulingContext) ReportString(verbosity int32) string { }, ) reasons := maps.Keys(jobIdsByReason) - slices.SortFunc(reasons, func(a, b string) bool { return len(jobIdsByReason[a]) < len(jobIdsByReason[b]) }) + slices.SortFunc(reasons, func(a, b string) int { + if len(jobIdsByReason[a]) < len(jobIdsByReason[b]) { + return -1 + } else if len(jobIdsByReason[a]) > len(jobIdsByReason[b]) { + return 1 + } else { + return 0 + } + }) for i := len(reasons) - 1; i >= 0; i-- { reason := reasons[i] jobIds := jobIdsByReason[reason] diff --git a/internal/scheduler/database/executor_repository_test.go b/internal/scheduler/database/executor_repository_test.go index 76a0e14c9f9..b31bd67caef 100644 --- a/internal/scheduler/database/executor_repository_test.go +++ b/internal/scheduler/database/executor_repository_test.go @@ -61,8 +61,14 @@ func TestExecutorRepository_LoadAndSave(t *testing.T) { } retrievedExecutors, err := repo.GetExecutors(ctx) require.NoError(t, err) - executorSort := func(a *schedulerobjects.Executor, b *schedulerobjects.Executor) bool { - return a.Id > b.Id + executorSort := func(a *schedulerobjects.Executor, b *schedulerobjects.Executor) int { + if a.Id > b.Id { + return -1 + } else if a.Id < b.Id { + return 1 + } else { + return 0 + } } slices.SortFunc(retrievedExecutors, executorSort) slices.SortFunc(tc.executors, executorSort) diff --git a/internal/scheduler/database/job_repository_test.go b/internal/scheduler/database/job_repository_test.go index 9be871caa67..63d8b7ce335 100644 --- a/internal/scheduler/database/job_repository_test.go +++ b/internal/scheduler/database/job_repository_test.go @@ -367,7 +367,15 @@ func TestFindInactiveRuns(t *testing.T) { inactive, err := repo.FindInactiveRuns(ctx, tc.runsToCheck) require.NoError(t, err) - uuidSort := func(a uuid.UUID, b uuid.UUID) bool { return a.String() > b.String() } + uuidSort := func(a uuid.UUID, b uuid.UUID) int { + if a.String() > b.String() { + return -1 + } else if a.String() < b.String() { + return 1 + } else { + return 0 + } + } slices.SortFunc(inactive, uuidSort) slices.SortFunc(tc.expectedInactive, uuidSort) assert.Equal(t, tc.expectedInactive, inactive) @@ -518,7 +526,15 @@ func TestFetchJobRunLeases(t *testing.T) { leases, err := repo.FetchJobRunLeases(ctx, tc.executor, tc.maxRowsToFetch, tc.excludedRuns) require.NoError(t, err) - leaseSort := func(a *JobRunLease, b *JobRunLease) bool { return a.RunID.String() > b.RunID.String() } + leaseSort := func(a *JobRunLease, b *JobRunLease) int { + if a.RunID.String() > b.RunID.String() { + return -1 + } else if a.RunID.String() < b.RunID.String() { + return 1 + } else { + return 0 + } + } slices.SortFunc(leases, leaseSort) slices.SortFunc(tc.expectedLeases, leaseSort) assert.Equal(t, tc.expectedLeases, leases) diff --git a/internal/scheduler/database/queue_repository_test.go b/internal/scheduler/database/queue_repository_test.go index edd33be3d41..718d0904290 100644 --- a/internal/scheduler/database/queue_repository_test.go +++ b/internal/scheduler/database/queue_repository_test.go @@ -58,7 +58,15 @@ func TestLegacyQueueRepository_GetAllQueues(t *testing.T) { } retrievedQueues, err := repo.GetAllQueues() require.NoError(t, err) - sortFunc := func(a, b *Queue) bool { return a.Name > b.Name } + sortFunc := func(a, b *Queue) int { + if a.Name > b.Name { + return -1 + } else if a.Name > b.Name { + return 1 + } else { + return 0 + } + } slices.SortFunc(tc.expectedQueues, sortFunc) slices.SortFunc(retrievedQueues, sortFunc) assert.Equal(t, tc.expectedQueues, retrievedQueues) diff --git a/internal/scheduler/database/redis_executor_repository_test.go b/internal/scheduler/database/redis_executor_repository_test.go index bf5b0ea9629..be1eb42d9b6 100644 --- a/internal/scheduler/database/redis_executor_repository_test.go +++ b/internal/scheduler/database/redis_executor_repository_test.go @@ -61,8 +61,14 @@ func TestRedisExecutorRepository_LoadAndSave(t *testing.T) { } retrievedExecutors, err := repo.GetExecutors(ctx) require.NoError(t, err) - executorSort := func(a *schedulerobjects.Executor, b *schedulerobjects.Executor) bool { - return a.Id > b.Id + executorSort := func(a *schedulerobjects.Executor, b *schedulerobjects.Executor) int { + if a.Id > b.Id { + return -1 + } else if a.Id < b.Id { + return 1 + } else { + return 0 + } } slices.SortFunc(retrievedExecutors, executorSort) slices.SortFunc(tc.executors, executorSort) diff --git a/internal/scheduler/failureestimator/failureestimator.go b/internal/scheduler/failureestimator/failureestimator.go index 9e35f5158e0..c131ded8da8 100644 --- a/internal/scheduler/failureestimator/failureestimator.go +++ b/internal/scheduler/failureestimator/failureestimator.go @@ -4,12 +4,17 @@ import ( "fmt" "math" "sync" - "time" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" + "gonum.org/v1/gonum/mat" "github.com/armadaproject/armada/internal/common/armadaerrors" + "github.com/armadaproject/armada/internal/common/linalg" + armadamath "github.com/armadaproject/armada/internal/common/math" + "github.com/armadaproject/armada/internal/common/optimisation" + "github.com/armadaproject/armada/internal/common/slices" + armadaslices "github.com/armadaproject/armada/internal/common/slices" ) const ( @@ -18,8 +23,6 @@ const ( // Floating point tolerance. Also used when applying limits to avoid divide-by-zero issues. eps = 1e-15 - // Assumed success probability of "good" nodes (queues) used when calculating step size. - healthySuccessProbability = 0.95 ) // FailureEstimator is a system for answering the following question: @@ -28,53 +31,41 @@ const ( // Denote by // - P_q the probability of a job from queue q succeeding when running on a perfect node and // - P_n is the probability of a perfect job succeeding on node n. -// The success probability of a job from queue q on node n is then Pr(p_q*p_n=1), +// The success probability of a job from queue q on node n is then Pr(p_q*p_n = 1), // where p_q and p_n are drawn from Bernoulli distributions with parameter P_q and P_n, respectively. // -// Now, the goal is to jointly estimate P_q and P_n for each queue and node using observed successes and failures. -// The method used is statistical and only relies on knowing which queue a job belongs to and on which node it ran. -// The intuition of the method is that: -// - A job from a queue with many failures doesn't say much about the node; likely it's the job that's the problem. -// - A job failing on a node with many failures doesn't say much about the job; likely it's the node that's the problem. +// Now, the goal is to jointly estimate P_q and P_n for each queue and node from observed successes and failures. +// We do so here with a statistical method. The intuition of the method is that: +// - A job from a queue with many failures failing doesn't say much about the node; likely the problem is with the job. +// - A job failing on a node with many failures doesn't say much about the job; likely the problem is with the node. // And vice versa. // -// Specifically, we maximise the log-likelihood function of P_q and P_n using observed successes and failures. +// Specifically, we maximise the log-likelihood function of P_q and P_n over observed successes and failures. // This maximisation is performed using online gradient descent, where for each success or failure, -// we update the corresponding P_q and P_n by taking a gradient step. -// See New(...) for more details regarding step size. -// -// Finally, we exponentially decay P_q and P_N towards 1 over time, -// such that nodes and queues for which we observe no failures appear to become healthier over time. -// See New(...) function for details regarding decay. +// we update the corresponding P_q and P_n by taking a gradient step. See the Update() function for details. // // This module internally only maintains success probability estimates, as this makes the maths cleaner. -// When exporting these via API calls we convert to failure probabilities as these are more intuitive to reason about. +// We convert these to failure probabilities when exporting these via API calls. type FailureEstimator struct { - // Map from node (queue) name to the estimated success probability of that node (queue). For example: - // - successProbabilityByNode["myNode"] = 0.85]: estimated failure probability of a perfect job run on "myNode" is 15%. - // - successProbabilityByQueue["myQueue"] = 0.60]: estimated failure probability of a job from "myQueue" run on a perfect node is 40%. - successProbabilityByNode map[string]float64 - successProbabilityByQueue map[string]float64 - - // Success probability below which to consider nodes (jobs) unhealthy. - nodeSuccessProbabilityCordonThreshold float64 - queueSuccessProbabilityCordonThreshold float64 - - // Exponential decay factor controlling how quickly estimated node (queue) success probability decays towards 1. - // Computed from: - // - {node, queue}SuccessProbabilityCordonThreshold - // - {node, queue}CordonTimeout - nodeFailureProbabilityDecayRate float64 - queueFailureProbabilityDecayRate float64 - timeOfLastDecay time.Time - - // Gradient descent step size. Controls the extent to which new data affects successProbabilityBy{Node, Queue}. - // Computed from: - // - {node, queue}SuccessProbabilityCordonThreshold - // - {node, queue}FailureProbabilityDecayRate - // - {node, queue}EquilibriumFailureRate - nodeStepSize float64 - queueStepSize float64 + // Success probability estimates for all nodes and queues. + parameters *mat.VecDense + intermediateParameters *mat.VecDense + + // Gradient buffer. + gradient *mat.VecDense + + // Maps node (queue) names to the corresponding index of parameters. + // E.g., if parameterIndexByNode["myNode"] = 10, then parameters[10] is the estimated success probability of myNode. + parameterIndexByNode map[string]int + parameterIndexByQueue map[string]int + + // Samples that have not been processed yet. + samples []Sample + + // Optimisation settings. + numInnerIterations int + innerOptimiser optimisation.Optimiser + outerOptimiser optimisation.Optimiser // Prometheus metrics. failureProbabilityByNodeDesc *prometheus.Desc @@ -84,88 +75,41 @@ type FailureEstimator struct { disabled bool // Mutex protecting the above fields. - // Prevents concurrent map modification issues when scraping metrics. + // Prevents concurrent map modification when scraping metrics. mu sync.Mutex } -// New returns a new FailureEstimator. Parameters have the following meaning: -// - {node, queue}SuccessProbabilityCordonThreshold: Success probability below which nodes (queues) are considered unhealthy. -// - {node, queue}CordonTimeout: Amount of time for which nodes (queues) remain unhealthy in the absence of any job successes or failures for that node (queue). -// - {node, queue}EquilibriumFailureRate: Job failures per second necessary for a node (queue) to remain unhealthy in the absence of any successes for that node (queue). +type Sample struct { + i int + j int + c bool +} + +// New returns a new FailureEstimator. func New( - nodeSuccessProbabilityCordonThreshold float64, - queueSuccessProbabilityCordonThreshold float64, - nodeCordonTimeout time.Duration, - queueCordonTimeout time.Duration, - nodeEquilibriumFailureRate float64, - queueEquilibriumFailureRate float64, + numInnerIterations int, + innerOptimiser optimisation.Optimiser, + outerOptimiser optimisation.Optimiser, ) (*FailureEstimator, error) { - if nodeSuccessProbabilityCordonThreshold < 0 || nodeSuccessProbabilityCordonThreshold > 1 { - return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ - Name: "nodeSuccessProbabilityCordonThreshold", - Value: nodeSuccessProbabilityCordonThreshold, - Message: fmt.Sprintf("outside allowed range [0, 1]"), - }) - } - if queueSuccessProbabilityCordonThreshold < 0 || queueSuccessProbabilityCordonThreshold > 1 { + if numInnerIterations < 1 { return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ - Name: "queueSuccessProbabilityCordonThreshold", - Value: queueSuccessProbabilityCordonThreshold, - Message: fmt.Sprintf("outside allowed range [0, 1]"), - }) - } - if nodeCordonTimeout < 0 { - return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ - Name: "nodeCordonTimeout", - Value: nodeCordonTimeout, - Message: fmt.Sprintf("outside allowed range [0, Inf)"), - }) - } - if queueCordonTimeout < 0 { - return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ - Name: "queueCordonTimeout", - Value: queueCordonTimeout, - Message: fmt.Sprintf("outside allowed range [0, Inf)"), - }) - } - if nodeEquilibriumFailureRate <= 0 { - return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ - Name: "nodeEquilibriumFailureRate", - Value: nodeEquilibriumFailureRate, - Message: fmt.Sprintf("outside allowed range (0, Inf)"), - }) - } - if queueEquilibriumFailureRate <= 0 { - return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ - Name: "queueEquilibriumFailureRate", - Value: queueEquilibriumFailureRate, - Message: fmt.Sprintf("outside allowed range (0, Inf)"), + Name: "numInnerIterations", + Value: numInnerIterations, + Message: fmt.Sprintf("outside allowed range [1, Inf)"), }) } + return &FailureEstimator{ + parameters: mat.NewVecDense(32, armadaslices.Fill[float64](0.5, 32)), + intermediateParameters: mat.NewVecDense(32, armadaslices.Zeros[float64](32)), + gradient: mat.NewVecDense(32, armadaslices.Zeros[float64](32)), - // Compute decay rates such that a node (queue) with success probability 0 will over {node, queue}CordonTimeout time - // decay to a success probability of {node, queue}SuccessProbabilityCordonThreshold. - nodeFailureProbabilityDecayRate := math.Exp(math.Log(1-nodeSuccessProbabilityCordonThreshold) / nodeCordonTimeout.Seconds()) - queueFailureProbabilityDecayRate := math.Exp(math.Log(1-queueSuccessProbabilityCordonThreshold) / queueCordonTimeout.Seconds()) + parameterIndexByNode: make(map[string]int, 16), + parameterIndexByQueue: make(map[string]int, 16), - // Compute step size such that a node (queue) with success probability {node, queue}SuccessProbabilityCordonThreshold - // for which we observe 0 successes and {node, queue}EquilibriumFailureRate failures per second from "good" nodes (queues) - // will remain at exactly {node, queue}SuccessProbabilityCordonThreshold success probability. - dNodeSuccessProbability := healthySuccessProbability / (1 - nodeSuccessProbabilityCordonThreshold*healthySuccessProbability) - dQueueSuccessProbability := healthySuccessProbability / (1 - queueSuccessProbabilityCordonThreshold*healthySuccessProbability) - nodeStepSize := (1 - nodeSuccessProbabilityCordonThreshold - (1-nodeSuccessProbabilityCordonThreshold)*nodeFailureProbabilityDecayRate) / dNodeSuccessProbability / nodeEquilibriumFailureRate - queueStepSize := (1 - queueSuccessProbabilityCordonThreshold - (1-queueSuccessProbabilityCordonThreshold)*queueFailureProbabilityDecayRate) / dQueueSuccessProbability / queueEquilibriumFailureRate + numInnerIterations: numInnerIterations, + innerOptimiser: innerOptimiser, + outerOptimiser: outerOptimiser, - return &FailureEstimator{ - successProbabilityByNode: make(map[string]float64, 1024), - successProbabilityByQueue: make(map[string]float64, 128), - nodeSuccessProbabilityCordonThreshold: nodeSuccessProbabilityCordonThreshold, - queueSuccessProbabilityCordonThreshold: queueSuccessProbabilityCordonThreshold, - nodeFailureProbabilityDecayRate: nodeFailureProbabilityDecayRate, - queueFailureProbabilityDecayRate: queueFailureProbabilityDecayRate, - timeOfLastDecay: time.Now(), - nodeStepSize: nodeStepSize, - queueStepSize: queueStepSize, failureProbabilityByNodeDesc: prometheus.NewDesc( fmt.Sprintf("%s_%s_node_failure_probability", namespace, subsystem), "Estimated per-node failure probability.", @@ -195,61 +139,93 @@ func (fe *FailureEstimator) IsDisabled() bool { return fe.disabled } -// Decay moves the success probabilities of nodes (queues) closer to 1, depending on the configured cordon timeout. -// Periodically calling Decay() ensures nodes (queues) considered unhealthy are eventually considered healthy again, -// even if we observe no successes for those nodes (queues). -func (fe *FailureEstimator) Decay() { +// Push adds a sample to the internal buffer of the failure estimator. +// Samples added via Push are processed on the next call to Update. +func (fe *FailureEstimator) Push(node, queue string, success bool) { fe.mu.Lock() defer fe.mu.Unlock() - t := time.Now() - fe.decay(t.Sub(fe.timeOfLastDecay).Seconds()) - fe.timeOfLastDecay = t - return -} -func (fe *FailureEstimator) decay(secondsSinceLastDecay float64) { - nodeFailureProbabilityDecay := math.Pow(fe.nodeFailureProbabilityDecayRate, secondsSinceLastDecay) - for k, v := range fe.successProbabilityByNode { - failureProbability := 1 - v - failureProbability *= nodeFailureProbabilityDecay - successProbability := 1 - failureProbability - fe.successProbabilityByNode[k] = applyBounds(successProbability) + i, ok := fe.parameterIndexByNode[node] + if !ok { + i = len(fe.parameterIndexByNode) + len(fe.parameterIndexByQueue) + fe.parameterIndexByNode[node] = i + } + j, ok := fe.parameterIndexByQueue[queue] + if !ok { + j = len(fe.parameterIndexByNode) + len(fe.parameterIndexByQueue) + fe.parameterIndexByQueue[queue] = j } + fe.extendParameters(armadamath.Max(i, j) + 1) + fe.samples = append(fe.samples, Sample{ + i: i, + j: j, + c: success, + }) +} - queueFailureProbabilityDecay := math.Pow(fe.queueFailureProbabilityDecayRate, secondsSinceLastDecay) - for k, v := range fe.successProbabilityByQueue { - failureProbability := 1 - v - failureProbability *= queueFailureProbabilityDecay - successProbability := 1 - failureProbability - fe.successProbabilityByQueue[k] = applyBounds(successProbability) +func (fe *FailureEstimator) extendParameters(n int) { + oldN := fe.parameters.Len() + fe.parameters = linalg.ExtendVecDense(fe.parameters, n) + if oldN < n { + for i := oldN; i < n; i++ { + // Initialise new parameters with 50% success probability. + fe.parameters.SetVec(i, 0.5) + } } - return + fe.intermediateParameters = linalg.ExtendVecDense(fe.intermediateParameters, n) + fe.gradient = linalg.ExtendVecDense(fe.gradient, n) } -// Update with success=false decreases the estimated success probability of the provided node and queue. -// Calling with success=true increases the estimated success probability of the provided node and queue. -// This update is performed by taking one gradient descent step. -func (fe *FailureEstimator) Update(node, queue string, success bool) { +// Update success probability estimates based on pushed samples. +func (fe *FailureEstimator) Update() { fe.mu.Lock() defer fe.mu.Unlock() - - // Assume that nodes (queues) we haven't seen before have a 50% success probability. - // Avoiding extreme values for new nodes (queues) helps avoid drastic changes to existing estimates. - nodeSuccessProbability, ok := fe.successProbabilityByNode[node] - if !ok { - nodeSuccessProbability = 0.5 + if len(fe.samples) == 0 { + // Nothing to do. + return } - queueSuccessProbability, ok := fe.successProbabilityByQueue[queue] - if !ok { - queueSuccessProbability = 0.5 + + // Inner loop to compute intermediateParameters from samples. + // Passing over samples multiple times in random order helps improve convergence. + fe.intermediateParameters.CopyVec(fe.parameters) + for k := 0; k < fe.numInnerIterations; k++ { + + // Compute gradient with respect to updates. + fe.gradient.Zero() + slices.Shuffle(fe.samples) + for _, sample := range fe.samples { + gi, gj := fe.negLogLikelihoodGradient( + fe.intermediateParameters.AtVec(sample.i), + fe.intermediateParameters.AtVec(sample.j), + sample.c, + ) + fe.gradient.SetVec(sample.i, fe.gradient.AtVec(sample.i)+gi) + fe.gradient.SetVec(sample.j, fe.gradient.AtVec(sample.j)+gj) + } + + // Update intermediateParameters using this gradient. + fe.innerOptimiser.Extend(fe.intermediateParameters.Len()) + fe.intermediateParameters = fe.innerOptimiser.Update(fe.intermediateParameters, fe.intermediateParameters, fe.gradient) + applyBoundsVec(fe.intermediateParameters) } - dNodeSuccessProbability, dQueueSuccessProbability := fe.negLogLikelihoodGradient(nodeSuccessProbability, queueSuccessProbability, success) - nodeSuccessProbability -= fe.nodeStepSize * dNodeSuccessProbability - queueSuccessProbability -= fe.queueStepSize * dQueueSuccessProbability + // Let the gradient be the difference between parameters and intermediateParameters, + // i.e., we use the inner loop as a method to estimate the gradient, + // and then update parameters using this gradient and the outer optimiser. + fe.gradient.CopyVec(fe.parameters) + fe.gradient.SubVec(fe.gradient, fe.intermediateParameters) + fe.outerOptimiser.Extend(fe.parameters.Len()) + fe.parameters = fe.outerOptimiser.Update(fe.parameters, fe.parameters, fe.gradient) + applyBoundsVec(fe.parameters) + + // Empty the buffer. + fe.samples = fe.samples[0:0] +} - fe.successProbabilityByNode[node] = applyBounds(nodeSuccessProbability) - fe.successProbabilityByQueue[queue] = applyBounds(queueSuccessProbability) +func applyBoundsVec(vec *mat.VecDense) { + for i := 0; i < vec.Len(); i++ { + vec.SetVec(i, applyBounds(vec.AtVec(i))) + } } // applyBounds ensures values stay in the range [eps, 1-eps]. @@ -264,7 +240,7 @@ func applyBounds(v float64) float64 { } } -// negLogLikelihoodGradient returns the gradient of the negated log-likelihood function with respect to P_q and P_n. +// negLogLikelihoodGradient returns the gradient of the negated log-likelihood function. func (fe *FailureEstimator) negLogLikelihoodGradient(nodeSuccessProbability, queueSuccessProbability float64, success bool) (float64, float64) { if success { dNodeSuccessProbability := -1 / nodeSuccessProbability @@ -294,13 +270,13 @@ func (fe *FailureEstimator) Collect(ch chan<- prometheus.Metric) { // Report failure probability rounded to nearest multiple of 0.01. // (As it's unlikely the estimate is accurate to within less than this.) - for k, v := range fe.successProbabilityByNode { - failureProbability := 1 - v + for k, i := range fe.parameterIndexByNode { + failureProbability := 1 - fe.parameters.AtVec(i) failureProbability = math.Round(failureProbability*100) / 100 ch <- prometheus.MustNewConstMetric(fe.failureProbabilityByNodeDesc, prometheus.GaugeValue, failureProbability, k) } - for k, v := range fe.successProbabilityByQueue { - failureProbability := 1 - v + for k, j := range fe.parameterIndexByQueue { + failureProbability := 1 - fe.parameters.AtVec(j) failureProbability = math.Round(failureProbability*100) / 100 ch <- prometheus.MustNewConstMetric(fe.failureProbabilityByQueueDesc, prometheus.GaugeValue, failureProbability, k) } diff --git a/internal/scheduler/failureestimator/failureestimator_test.go b/internal/scheduler/failureestimator/failureestimator_test.go index 3b02f2a096a..5a4bbb3c8bd 100644 --- a/internal/scheduler/failureestimator/failureestimator_test.go +++ b/internal/scheduler/failureestimator/failureestimator_test.go @@ -1,190 +1,81 @@ package failureestimator import ( - "math" + "fmt" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" -) - -func TestNew(t *testing.T) { - successProbabilityCordonThreshold := 0.05 - cordonTimeout := 10 * time.Minute - equilibriumFailureRate := 0.1 - - // Node decay rate and step size tests. - fe, err := New( - successProbabilityCordonThreshold, - 0, - cordonTimeout, - 0, - equilibriumFailureRate, - eps, - ) - require.NoError(t, err) - assert.LessOrEqual(t, fe.nodeFailureProbabilityDecayRate, 0.9999145148300861+eps) - assert.GreaterOrEqual(t, fe.nodeFailureProbabilityDecayRate, 0.9999145148300861-eps) - assert.LessOrEqual(t, fe.nodeStepSize, 0.0008142462434300169+eps) - assert.GreaterOrEqual(t, fe.nodeStepSize, 0.0008142462434300169-eps) - - // Queue decay rate and step size tests. - fe, err = New( - 0, - successProbabilityCordonThreshold, - 0, - cordonTimeout, - eps, - equilibriumFailureRate, - ) - require.NoError(t, err) - assert.LessOrEqual(t, fe.queueFailureProbabilityDecayRate, 0.9999145148300861+eps) - assert.GreaterOrEqual(t, fe.queueFailureProbabilityDecayRate, 0.9999145148300861-eps) - assert.LessOrEqual(t, fe.queueStepSize, 0.0008142462434300169+eps) - assert.GreaterOrEqual(t, fe.queueStepSize, 0.0008142462434300169-eps) -} - -func TestDecay(t *testing.T) { - successProbabilityCordonThreshold := 0.05 - cordonTimeout := 10 * time.Minute - equilibriumFailureRate := 0.1 - // Node decay tests. - fe, err := New( - successProbabilityCordonThreshold, - 0, - cordonTimeout, - 0, - equilibriumFailureRate, - eps, - ) - require.NoError(t, err) - - fe.successProbabilityByNode["foo"] = 0.0 - fe.successProbabilityByNode["bar"] = 0.4 - fe.successProbabilityByNode["baz"] = 1.0 - - fe.decay(0) - assert.Equal( - t, - fe.successProbabilityByNode, - map[string]float64{ - "foo": eps, - "bar": 0.4, - "baz": 1.0 - eps, - }, - ) - - fe.decay(1) - assert.Equal( - t, - fe.successProbabilityByNode, - map[string]float64{ - "foo": 1 - (1-eps)*math.Pow(fe.nodeFailureProbabilityDecayRate, 1), - "bar": 1 - (1-0.4)*math.Pow(fe.nodeFailureProbabilityDecayRate, 1), - "baz": 1.0 - eps, - }, - ) - - fe.decay(1e6) - assert.Equal( - t, - fe.successProbabilityByNode, - map[string]float64{ - "foo": 1.0 - eps, - "bar": 1.0 - eps, - "baz": 1.0 - eps, - }, - ) - - // Queue decay tests. - fe, err = New( - 0, - successProbabilityCordonThreshold, - 0, - cordonTimeout, - eps, - equilibriumFailureRate, - ) - require.NoError(t, err) - - fe.successProbabilityByQueue["foo"] = 0.0 - fe.successProbabilityByQueue["bar"] = 0.4 - fe.successProbabilityByQueue["baz"] = 1.0 - - fe.decay(0) - assert.Equal( - t, - fe.successProbabilityByQueue, - map[string]float64{ - "foo": eps, - "bar": 0.4, - "baz": 1.0 - eps, - }, - ) - - fe.decay(1) - assert.Equal( - t, - fe.successProbabilityByQueue, - map[string]float64{ - "foo": 1 - (1-eps)*math.Pow(fe.queueFailureProbabilityDecayRate, 1), - "bar": 1 - (1-0.4)*math.Pow(fe.queueFailureProbabilityDecayRate, 1), - "baz": 1.0 - eps, - }, - ) - - fe.decay(1e6) - assert.Equal( - t, - fe.successProbabilityByQueue, - map[string]float64{ - "foo": 1.0 - eps, - "bar": 1.0 - eps, - "baz": 1.0 - eps, - }, - ) -} + "github.com/armadaproject/armada/internal/common/optimisation/descent" + "github.com/armadaproject/armada/internal/common/optimisation/nesterov" +) func TestUpdate(t *testing.T) { - successProbabilityCordonThreshold := 0.05 - cordonTimeout := 10 * time.Minute - equilibriumFailureRate := 0.1 - fe, err := New( - successProbabilityCordonThreshold, - successProbabilityCordonThreshold, - cordonTimeout, - cordonTimeout, - equilibriumFailureRate, - equilibriumFailureRate, + 10, + descent.MustNew(0.05), + nesterov.MustNew(0.05, 0.2), ) require.NoError(t, err) - fe.Update("node", "queue", false) - nodeSuccessProbability, ok := fe.successProbabilityByNode["node"] + // Test initialisation. + fe.Push("node", "queue", false) + nodeParameterIndex, ok := fe.parameterIndexByNode["node"] + require.True(t, ok) + queueParameterIndex, ok := fe.parameterIndexByQueue["queue"] + require.True(t, ok) + require.Equal(t, 0, nodeParameterIndex) + require.Equal(t, 1, queueParameterIndex) + require.Equal(t, 0.5, fe.parameters.AtVec(0)) + require.Equal(t, 0.5, fe.parameters.AtVec(1)) + + for i := 0; i < 100; i++ { + fe.Push(fmt.Sprintf("node-%d", i), "queue-0", false) + } + nodeParameterIndex, ok = fe.parameterIndexByNode["node-99"] require.True(t, ok) - queueSuccessProbability, ok := fe.successProbabilityByQueue["queue"] + queueParameterIndex, ok = fe.parameterIndexByQueue["queue-0"] require.True(t, ok) + require.Equal(t, 2+100, nodeParameterIndex) + require.Equal(t, 3, queueParameterIndex) + require.Equal(t, 0.5, fe.parameters.AtVec(102)) + require.Equal(t, 0.5, fe.parameters.AtVec(3)) + + // Test that the estimates move in the expected direction on failure. + fe.Update() + nodeSuccessProbability := fe.parameters.AtVec(0) + queueSuccessProbability := fe.parameters.AtVec(1) assert.Greater(t, nodeSuccessProbability, eps) assert.Greater(t, queueSuccessProbability, eps) - assert.Less(t, nodeSuccessProbability, healthySuccessProbability-eps) - assert.Less(t, queueSuccessProbability, healthySuccessProbability-eps) - - fe.Update("node", "queue", true) - assert.Greater(t, fe.successProbabilityByNode["node"], nodeSuccessProbability) - assert.Greater(t, fe.successProbabilityByQueue["queue"], queueSuccessProbability) - - for i := 0; i < 100000; i++ { - fe.Update("node", "queue", false) + assert.Less(t, nodeSuccessProbability, 0.5-eps) + assert.Less(t, queueSuccessProbability, 0.5-eps) + + // Test that the estimates move in the expected direction on success. + fe.Push("node", "queue", true) + fe.Update() + assert.Greater(t, fe.parameters.AtVec(0), nodeSuccessProbability) + assert.Greater(t, fe.parameters.AtVec(1), queueSuccessProbability) + + for i := 0; i < 1000; i++ { + for i := 0; i < 10; i++ { + fe.Push("node", "queue", false) + } + fe.Update() } - assert.Equal(t, fe.successProbabilityByNode["node"], eps) - assert.Equal(t, fe.successProbabilityByQueue["queue"], eps) - - for i := 0; i < 100000; i++ { - fe.Update("node", "queue", true) + assert.Greater(t, fe.parameters.AtVec(0), 0.0) + assert.Greater(t, fe.parameters.AtVec(1), 0.0) + assert.Less(t, fe.parameters.AtVec(0), 2*eps) + assert.Less(t, fe.parameters.AtVec(1), 2*eps) + + for i := 0; i < 1000; i++ { + for i := 0; i < 10; i++ { + fe.Push("node", "queue", true) + } + fe.Update() } - assert.Equal(t, fe.successProbabilityByNode["node"], 1-eps) - assert.Equal(t, fe.successProbabilityByQueue["queue"], 1-eps) + assert.Greater(t, fe.parameters.AtVec(0), 1-2*eps) + assert.Greater(t, fe.parameters.AtVec(1), 1-2*eps) + assert.Less(t, fe.parameters.AtVec(0), 1.0) + assert.Less(t, fe.parameters.AtVec(1), 1.0) } diff --git a/internal/scheduler/failureestimator/simulator.jl b/internal/scheduler/failureestimator/simulator.jl new file mode 100644 index 00000000000..726a0ad8991 --- /dev/null +++ b/internal/scheduler/failureestimator/simulator.jl @@ -0,0 +1,352 @@ +using Random +using StatsBase +using Optimization +using OptimizationOptimJL +using OptimizationBBO +using ForwardDiff +using Plots + +# Julia script to simulate failure estimation and optimise parameters. + +function neg_log_likelihood(As, Bs, Is, cs) + llsum = 0 + for k = 1:length(cs) + i = Is[k, 1] + j = Is[k, 2] + if cs[k] + llsum += log(As[i]*Bs[j]) + else + llsum += log(1 - As[i]*Bs[j]) + end + end + return -llsum +end + +function neg_log_likelihood_gradient!(dAs, dBs, As, Bs, Is, cs) + dAs .= 0 + dBs .= 0 + for k = 1:length(cs) + i = Is[k, 1] + j = Is[k, 2] + dA, dB = neg_log_likelihood_gradient_inner(As[i], Bs[j], cs[k]) + dAs[i] += dA + dBs[j] += dB + end + return dAs, dBs +end + +function neg_log_likelihood_gradient_inner(A, B, c) + if c + dA = -1 / A + dB = -1 / B + return dA, dB + else + dA = B / (1 - A*B) + dB = A / (1 - A*B) + return dA, dB + end +end + +function f(u, p) + Is, cs, n = p + As = view(u, 1:n) + Bs = view(u, n+1:length(u)) + return neg_log_likelihood(As, Bs, Is, cs) +end + +function g(G, u, p) + Is, cs, n = p + dAs = view(G, 1:n) + dBs = view(G, n+1:length(u)) + neg_log_likelihood_gradient!(dAs, dBs, As, Bs, Is, cs) + return G +end + +function global_optimization_solution(n, k, Is, cs) + prob = OptimizationProblem( + OptimizationFunction(f, grad=g), + ones(n+k)./2, + (Is, cs, n), + lb=zeros(n+k), + ub=ones(n+k), + ) + return solve(prob, NelderMead()) +end + +function gd(Is, Js, cs, ts, n, k; inner_opt, outer_opt, update_interval::Float64=60.0, num_sub_iterations::Integer=1) + nsamples = length(cs) + x = fill(0.5, n+k) + xk = zeros(n+k) + g = zeros(n+k) + As = zeros(n, nsamples) + Bs = zeros(k, nsamples) + Ak = view(x, 1:n) + Bk = view(x, (n+1):(n+k)) + Aki = view(xk, 1:n) + Bki = view(xk, (n+1):(n+k)) + Agki = view(g, 1:n) + Bgki = view(g, (n+1):(n+k)) + + index_of_last_update = 0 + time_of_last_update = 0.0 + for sample_index = 1:nsamples + As[:, sample_index] .= Ak + Bs[:, sample_index] .= Bk + + if !(ts[sample_index] - time_of_last_update >= update_interval) + # Not yet time to update. + continue + end + + # Perform several gradient descent steps over the collected data. + xk .= x + batch_indices = collect((index_of_last_update+1):sample_index) + for _ = 1:num_sub_iterations + g .= 0 + shuffle!(batch_indices) + for batch_index = batch_indices + i = Is[batch_index] + j = Js[batch_index] + dA, dB = neg_log_likelihood_gradient_inner(Aki[i], Bki[j], cs[batch_index]) + Agki[i] += dA + Bgki[j] += dB + end + + Flux.Optimise.update!(inner_opt, xk, g) + xk .= min.(max.(xk, 0), 1) + end + + g .= x .- xk + Flux.Optimise.update!(outer_opt, x, g) + x .= min.(max.(x, 0), 1) + + index_of_last_update = sample_index + time_of_last_update = ts[sample_index] + + As[:, sample_index] .= Ak + Bs[:, sample_index] .= Bk + end + return As, Bs +end + +struct NodeTemplate + n::Int + p::Float64 + df::Float64 +end + +struct QueueTemplate + k::Int + p::Float64 + df::Float64 + ds::Float64 + ws::Vector{Int} +end + +struct Parameters + # Minimum simulation time. + t::Float64 + node_templates::Vector{NodeTemplate} + queue_templates::Vector{QueueTemplate} +end + +function scenario1() + return Parameters( + 36000, + [ + NodeTemplate(1, 0.9, 60), + NodeTemplate(1, 0.01, 60), + ], + [ + QueueTemplate(1, 0.9, 60, 60, [1]), + QueueTemplate(1, 0.01, 60, 60, [1]), + ], + ) +end + +function scenario2() + return Parameters( + 36000, + [ + NodeTemplate(18, 0.9, 60), + NodeTemplate(2, 0.1, 60), + ], + [ + QueueTemplate(1, 0.9, 60, 60, [1]), + QueueTemplate(1, 0.1, 60, 60, [1]), + ], + ) +end + +function scenario3() + return Parameters( + 36000, + [ + NodeTemplate(90, 0.9, 60), + NodeTemplate(10, 0.1, 60), + ], + [ + QueueTemplate(8, 0.9, 60, 600, ones(Int, 8)), + QueueTemplate(2, 0.1, 60, 600, ones(Int, 2)), + ], + ) +end + +function simulate(p::Parameters) + for tmpl in p.queue_templates + if length(tmpl.ws) != tmpl.k + throw(ArgumentError("ws must be of length k")) + end + end + n = sum(tmpl.n for tmpl in p.node_templates) + k = sum(tmpl.k for tmpl in p.queue_templates) + ws = [sum(tmpl.ws) for tmpl in p.queue_templates] + ncs = cumsum(tmpl.n for tmpl in p.node_templates) + kcs = cumsum(tmpl.k for tmpl in p.queue_templates) + + Is = Vector{Int}() + Js = Vector{Int}() + as = Vector{Bool}() + bs = Vector{Bool}() + cs = Vector{Bool}() + ts = Vector{Float64}() + for (node_template_index, node_template) in enumerate(p.node_templates) + for node_template_i = 1:node_template.n + i = node_template_i + if node_template_index > 1 + i += ncs[node_template_index-1] + end + + t = 0.0 + while t < p.t + queue_template_index = sample(1:length(p.queue_templates), Weights(ws)) + queue_template = p.queue_templates[queue_template_index] + j = sample(1:queue_template.k, Weights(queue_template.ws)) + if queue_template_index > 1 + j += kcs[queue_template_index-1] + end + + # Sample node and queue failure. + a = rand() < node_template.p + b = rand() < queue_template.p + c = a*b + + # Compute duration until job termination. + d = 0.0 + if !a + # Node failure. + d = node_template.df + elseif !b + # Queue failure. + d = queue_template.df + else + # Job success. + d = queue_template.ds + end + + t += d + push!(Is, i) + push!(Js, j) + push!(as, a) + push!(bs, b) + push!(cs, c) + push!(ts, t) + end + end + end + + # Sort by time and return. + p = sortperm(ts) + return Is[p], Js[p], as[p], bs[p], cs[p], ts[p] +end + +function rate_by_index(Is, as) + as_by_i = Dict{Int, Vector{Bool}}() + for k = 1:length(Is) + i = Is[k] + a = as[k] + vs = get(as_by_i, i, Vector{Bool}()) + as_by_i[i] = push!(vs, a) + end + rate_by_index = Dict{Int, Float64}() + for (i, vs) in as_by_i + rate_by_index[i] = sum(vs)/length(vs) + end + return rate_by_index +end + +function squared_error(p::Parameters, As, Bs) + Ase = copy(As) + Bse = copy(Bs) + i = 1 + for node_template = p.node_templates + for _ = 1:node_template.n + Ase[i, :] .-= node_template.p + Ase[i, :] .^= 2 + i += 1 + end + end + j = 1 + for queue_template = p.queue_templates + for _ = 1:queue_template.k + Bse[j, :] .-= queue_template.p + Bse[j, :] .^= 2 + j += 1 + end + end + return Ase, Bse +end + +function grid_search(p::Parameters; num_simulations=10) + n = sum(tmpl.n for tmpl in p.node_templates) + k = sum(tmpl.k for tmpl in p.queue_templates) + + it = Base.Iterators.product( + # Inner step-size. + range(1e-5, 1e-2, length=10), + # Number of sub-iterations. + range(1, 10), + # Outer step-size. + range(1e-2, 0.25, length=10), + # Outer Nesterov acceleration. + range(0.0, 0.99, length=10) + ) + mses = zeros(length(it)) + for _ = num_simulations + Is, Js, as, bs, cs, ts = simulate(p) + for (i, (inner_step_size, num_sub_iterations, outer_step_size, outer_nesterov_acceleration)) = enumerate(it) + As, Bs = gd(Is, Js, cs, ts, n, k, inner_opt=Descent(inner_step_size), outer_opt=Nesterov(outer_step_size, outer_nesterov_acceleration), update_interval=60.0, num_sub_iterations=num_sub_iterations) + Ase, Bse = squared_error(p, As, Bs) + mses[i] += (sum(Ase) + sum(Bse)) / ((n+k)*length(cs)) + end + end + mses ./= num_simulations + return it, mses +end + +function main() + p = scenario3() + n = sum(tmpl.n for tmpl in p.node_templates) + k = sum(tmpl.k for tmpl in p.queue_templates) + + Is, Js, as, bs, cs, ts = simulate(p) + + plots = Vector{Plots.Plot}() + + As, Bs = gd(Is, Js, cs, ts, n, k, inner_opt=Descent(0.05), outer_opt=Nesterov(0.05, 0.2), update_interval=60.0, num_sub_iterations=10) + Ase, Bse = squared_error(p, As, Bs) + push!(plots, plot!(plot(ts, As', legend=false), ts, mean(Ase, dims=1)', color="black", legend=false)) + push!(plots, plot!(plot(ts, Bs', legend=false), ts, mean(Bse, dims=1)', color="black", legend=false)) + + As, Bs = gd(Is, Js, cs, ts, n, k, inner_opt=Descent(0.05), outer_opt=Nesterov(0.01, 0.9), update_interval=60.0, num_sub_iterations=10) + Ase, Bse = squared_error(p, As, Bs) + push!(plots, plot!(plot(ts, As', legend=false), ts, mean(Ase, dims=1)', color="black", legend=false)) + push!(plots, plot!(plot(ts, Bs', legend=false), ts, mean(Bse, dims=1)', color="black", legend=false)) + + As, Bs = gd(Is, Js, cs, ts, n, k, inner_opt=Descent(0.05), outer_opt=Nesterov(0.005, 0.99), update_interval=60.0, num_sub_iterations=10) + Ase, Bse = squared_error(p, As, Bs) + push!(plots, plot!(plot(ts, As', legend=false), ts, mean(Ase, dims=1)', color="black", legend=false)) + push!(plots, plot!(plot(ts, Bs', legend=false), ts, mean(Bse, dims=1)', color="black", legend=false)) + + plot(plots..., layout=(3, 2)) +end diff --git a/internal/scheduler/jobdb/jobdb_test.go b/internal/scheduler/jobdb/jobdb_test.go index abd6497dc2a..eda76e4e0b2 100644 --- a/internal/scheduler/jobdb/jobdb_test.go +++ b/internal/scheduler/jobdb/jobdb_test.go @@ -170,11 +170,23 @@ func TestJobDb_TestGetAll(t *testing.T) { require.NoError(t, err) actual := txn.GetAll() expected := []*Job{job1, job2} - slices.SortFunc(expected, func(a, b *Job) bool { - return a.id > b.id + slices.SortFunc(expected, func(a, b *Job) int { + if a.id > b.id { + return -1 + } else if a.id < b.id { + return 1 + } else { + return 0 + } }) - slices.SortFunc(actual, func(a, b *Job) bool { - return a.id > b.id + slices.SortFunc(actual, func(a, b *Job) int { + if a.id > b.id { + return -1 + } else if a.id < b.id { + return 1 + } else { + return 0 + } }) assert.Equal(t, expected, actual) } diff --git a/internal/scheduler/jobiteration.go b/internal/scheduler/jobiteration.go index e914988e4a3..f6733c26e39 100644 --- a/internal/scheduler/jobiteration.go +++ b/internal/scheduler/jobiteration.go @@ -72,8 +72,8 @@ func (repo *InMemoryJobRepository) EnqueueMany(jctxs []*schedulercontext.JobSche // sortQueue sorts jobs in a specified queue by the order in which they should be scheduled. func (repo *InMemoryJobRepository) sortQueue(queue string) { - slices.SortFunc(repo.jctxsByQueue[queue], func(a, b *schedulercontext.JobSchedulingContext) bool { - return a.Job.SchedulingOrderCompare(b.Job) == -1 + slices.SortFunc(repo.jctxsByQueue[queue], func(a, b *schedulercontext.JobSchedulingContext) int { + return a.Job.SchedulingOrderCompare(b.Job) }) } diff --git a/internal/scheduler/nodedb/nodeiteration_test.go b/internal/scheduler/nodedb/nodeiteration_test.go index 54447741808..5c23f158a88 100644 --- a/internal/scheduler/nodedb/nodeiteration_test.go +++ b/internal/scheduler/nodedb/nodeiteration_test.go @@ -45,7 +45,15 @@ func TestNodesIterator(t *testing.T) { } sortedNodes := slices.Clone(tc.Nodes) - slices.SortFunc(sortedNodes, func(a, b *schedulerobjects.Node) bool { return a.Id < b.Id }) + slices.SortFunc(sortedNodes, func(a, b *schedulerobjects.Node) int { + if a.Id < b.Id { + return -1 + } else if a.Id > b.Id { + return 1 + } else { + return 0 + } + }) expected := make([]int, len(sortedNodes)) for i, node := range sortedNodes { expected[i] = indexById[node.Id] diff --git a/internal/scheduler/preempting_queue_scheduler_test.go b/internal/scheduler/preempting_queue_scheduler_test.go index ea9650c66f9..9b4d622d7e2 100644 --- a/internal/scheduler/preempting_queue_scheduler_test.go +++ b/internal/scheduler/preempting_queue_scheduler_test.go @@ -2022,8 +2022,14 @@ func TestPreemptingQueueScheduler(t *testing.T) { // which jobs are preempted). slices.SortFunc( result.ScheduledJobs, - func(a, b *schedulercontext.JobSchedulingContext) bool { - return a.Job.GetSubmitTime().Before(b.Job.GetSubmitTime()) + func(a, b *schedulercontext.JobSchedulingContext) int { + if a.Job.GetSubmitTime().Before(b.Job.GetSubmitTime()) { + return -1 + } else if b.Job.GetSubmitTime().Before(a.Job.GetSubmitTime()) { + return 1 + } else { + return 0 + } }, ) var scheduledJobs []*jobdb.Job diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index 556ce651b61..8598c1a8d60 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -280,7 +280,6 @@ func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToke // Update success probability estimates. if !s.failureEstimator.IsDisabled() { - s.failureEstimator.Decay() for _, jst := range jsts { if jst.Job == nil { continue @@ -290,12 +289,13 @@ func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToke continue } if jst.Failed { - s.failureEstimator.Update(run.NodeName(), jst.Job.GetQueue(), false) + s.failureEstimator.Push(run.NodeName(), jst.Job.GetQueue(), false) } if jst.Succeeded { - s.failureEstimator.Update(run.NodeName(), jst.Job.GetQueue(), true) + s.failureEstimator.Push(run.NodeName(), jst.Job.GetQueue(), true) } } + s.failureEstimator.Update() } // Generate any eventSequences that came out of synchronising the db state. diff --git a/internal/scheduler/schedulerapp.go b/internal/scheduler/schedulerapp.go index 057da09276c..bc408ef156e 100644 --- a/internal/scheduler/schedulerapp.go +++ b/internal/scheduler/schedulerapp.go @@ -25,6 +25,8 @@ import ( grpcCommon "github.com/armadaproject/armada/internal/common/grpc" "github.com/armadaproject/armada/internal/common/health" "github.com/armadaproject/armada/internal/common/logging" + "github.com/armadaproject/armada/internal/common/optimisation/descent" + "github.com/armadaproject/armada/internal/common/optimisation/nesterov" "github.com/armadaproject/armada/internal/common/profiling" "github.com/armadaproject/armada/internal/common/pulsarutils" "github.com/armadaproject/armada/internal/common/serve" @@ -220,12 +222,14 @@ func Run(config schedulerconfig.Configuration) error { } failureEstimator, err := failureestimator.New( - config.Scheduling.FailureEstimatorConfig.NodeSuccessProbabilityCordonThreshold, - config.Scheduling.FailureEstimatorConfig.QueueSuccessProbabilityCordonThreshold, - config.Scheduling.FailureEstimatorConfig.NodeCordonTimeout, - config.Scheduling.FailureEstimatorConfig.QueueCordonTimeout, - config.Scheduling.FailureEstimatorConfig.NodeEquilibriumFailureRate, - config.Scheduling.FailureEstimatorConfig.QueueEquilibriumFailureRate, + config.Scheduling.FailureEstimatorConfig.NumInnerIterations, + // Invalid config will have failed validation. + descent.MustNew(config.Scheduling.FailureEstimatorConfig.InnerOptimiserStepSize), + // Invalid config will have failed validation. + nesterov.MustNew( + config.Scheduling.FailureEstimatorConfig.OuterOptimiserStepSize, + config.Scheduling.FailureEstimatorConfig.OuterOptimiserNesterovAcceleration, + ), ) if err != nil { return err diff --git a/internal/scheduler/schedulerobjects/nodetype.go b/internal/scheduler/schedulerobjects/nodetype.go index d52d35e36fb..020c57e675d 100644 --- a/internal/scheduler/schedulerobjects/nodetype.go +++ b/internal/scheduler/schedulerobjects/nodetype.go @@ -33,7 +33,15 @@ func NewNodeType(taints []v1.Taint, labels map[string]string, indexedTaints map[ // Sort taints to ensure node type id is consistent regardless of // the order in which taints are set on the node. - slices.SortFunc(taints, func(a, b v1.Taint) bool { return a.Key < b.Key }) // TODO: Use less ambiguous sorting. + slices.SortFunc(taints, func(a, b v1.Taint) int { + if a.Key < b.Key { + return -1 + } else if a.Key > b.Key { + return 1 + } else { + return 0 + } + }) // TODO: Use less ambiguous sorting. // Filter out any labels that should not be indexed. if indexedLabels != nil { diff --git a/internal/scheduler/schedulerobjects/podutils.go b/internal/scheduler/schedulerobjects/podutils.go index bbd087702db..1892c959b5b 100644 --- a/internal/scheduler/schedulerobjects/podutils.go +++ b/internal/scheduler/schedulerobjects/podutils.go @@ -227,45 +227,45 @@ func (skg *PodRequirementsSerialiser) AppendResourceList(out []byte, resourceLis return out } -func lessToleration(a, b v1.Toleration) bool { +func lessToleration(a, b v1.Toleration) int { if a.Key < b.Key { - return true + return -1 } else if a.Key > b.Key { - return false + return 1 } if a.Value < b.Value { - return true + return -1 } else if a.Value > b.Value { - return false + return 1 } if string(a.Operator) < string(b.Operator) { - return true + return -1 } else if string(a.Operator) > string(b.Operator) { - return false + return 1 } if string(a.Effect) < string(b.Effect) { - return true + return -1 } else if string(a.Effect) > string(b.Effect) { - return false + return 1 } - return true + return 0 } -func lessNodeSelectorRequirement(a, b v1.NodeSelectorRequirement) bool { +func lessNodeSelectorRequirement(a, b v1.NodeSelectorRequirement) int { if a.Key < b.Key { - return true + return -1 } else if a.Key > b.Key { - return false + return 1 } if string(a.Operator) < string(b.Operator) { - return true + return -1 } else if string(a.Operator) > string(b.Operator) { - return false + return 1 } if len(a.Values) < len(b.Values) { - return true + return -1 } else if len(a.Values) > len(b.Values) { - return false + return 1 } - return true + return 0 } diff --git a/internal/scheduler/simulator/simulator.go b/internal/scheduler/simulator/simulator.go index bf364bb75f1..b270b54bfd6 100644 --- a/internal/scheduler/simulator/simulator.go +++ b/internal/scheduler/simulator/simulator.go @@ -499,21 +499,21 @@ func (s *Simulator) handleScheduleEvent(ctx *armadacontext.Context) error { preemptedJobs := scheduler.PreemptedJobsFromSchedulerResult[*jobdb.Job](result) scheduledJobs := slices.Clone(result.ScheduledJobs) failedJobs := scheduler.FailedJobsFromSchedulerResult[*jobdb.Job](result) - lessJob := func(a, b *jobdb.Job) bool { + lessJob := func(a, b *jobdb.Job) int { if a.Queue() < b.Queue() { - return true + return -1 } else if a.Queue() > b.Queue() { - return false + return 1 } if a.Id() < b.Id() { - return true + return -1 } else if a.Id() > b.Id() { - return false + return 1 } - return false + return 0 } slices.SortFunc(preemptedJobs, lessJob) - slices.SortFunc(scheduledJobs, func(a, b *schedulercontext.JobSchedulingContext) bool { + slices.SortFunc(scheduledJobs, func(a, b *schedulercontext.JobSchedulingContext) int { return lessJob(a.Job.(*jobdb.Job), b.Job.(*jobdb.Job)) }) slices.SortFunc(failedJobs, lessJob)