Skip to content

Commit

Permalink
refactor: use wait.Poll function waiting for task state
Browse files Browse the repository at this point in the history
  • Loading branch information
dkoshkin committed Jan 3, 2024
1 parent cabb21b commit 5a13aa2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 76 deletions.
86 changes: 11 additions & 75 deletions pkg/client/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,45 +19,26 @@ package client
import (
"context"
"fmt"
"math"
"time"

ctrl "sigs.k8s.io/controller-runtime"

"github.com/nutanix-cloud-native/prism-go-client/utils"
nutanixClientV3 "github.com/nutanix-cloud-native/prism-go-client/v3"
"k8s.io/apimachinery/pkg/util/wait"
ctrl "sigs.k8s.io/controller-runtime"
)

type stateRefreshFunc func() (string, error)
const (
pollingInterval = time.Second * 2
stateSucceeded = "SUCCEEDED"
)

// WaitForTaskCompletion will poll indefinitely every 2 seconds for the task with uuid to have status of "SUCCEEDED".
// Returns an error from GetTaskState or a timeout error if the context is cancelled.
func WaitForTaskCompletion(ctx context.Context, conn *nutanixClientV3.Client, uuid string) error {
errCh := make(chan error, 1)
go waitForState(
errCh,
"SUCCEEDED",
waitUntilTaskStateFunc(ctx, conn, uuid))

err := <-errCh
return err
}

func waitForState(errCh chan<- error, target string, refresh stateRefreshFunc) {
err := Retry(2, 2, 0, func(_ uint) (bool, error) {
state, err := refresh()
if err != nil {
return false, err
} else if state == target {
return true, nil
}
return false, nil
return wait.PollImmediateInfiniteWithContext(ctx, pollingInterval, func(ctx context.Context) (done bool, err error) {
state, getErr := GetTaskState(ctx, conn, uuid)
return state == stateSucceeded, getErr
})
errCh <- err
}

func waitUntilTaskStateFunc(ctx context.Context, conn *nutanixClientV3.Client, uuid string) stateRefreshFunc {
return func() (string, error) {
return GetTaskState(ctx, conn, uuid)
}
}

func GetTaskState(ctx context.Context, client *nutanixClientV3.Client, taskUUID string) (string, error) {
Expand All @@ -77,48 +58,3 @@ func GetTaskState(ctx context.Context, client *nutanixClientV3.Client, taskUUID
log.V(1).Info(fmt.Sprintf("Status for task with UUID %s: %s", taskUUID, taskStatus))
return taskStatus, nil
}

// RetryableFunc performs an action and returns a bool indicating whether the
// function is done, or if it should keep retrying, and an error which will
// abort the retry and be returned by the Retry function. The 0-indexed attempt
// is passed with each call.
type RetryableFunc func(uint) (bool, error)

/*
Retry retries a function up to numTries times with exponential backoff.
If numTries == 0, retry indefinitely.
If interval == 0, Retry will not delay retrying and there will be no
exponential backoff.
If maxInterval == 0, maxInterval is set to +Infinity.
Intervals are in seconds.
Returns an error if initial > max intervals, if retries are exhausted, or if the passed function returns
an error.
*/
func Retry(initialInterval float64, maxInterval float64, numTries uint, function RetryableFunc) error {
if maxInterval == 0 {
maxInterval = math.Inf(1)
} else if initialInterval < 0 || initialInterval > maxInterval {
return fmt.Errorf("invalid retry intervals (negative or initial < max). Initial: %f, Max: %f", initialInterval, maxInterval)
}

var err error
done := false
interval := initialInterval
for i := uint(0); !done && (numTries == 0 || i < numTries); i++ {
done, err = function(i)
if err != nil {
return err
}

if !done {
// Retry after delay. Calculate next delay.
time.Sleep(time.Duration(interval) * time.Second)
interval = math.Min(interval*2, maxInterval)
}
}

if !done {
return fmt.Errorf("function never succeeded in Retry")
}
return nil
}
3 changes: 2 additions & 1 deletion pkg/client/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/util/wait"

nutanixTestClient "github.com/nutanix-cloud-native/cluster-api-provider-nutanix/test/helpers/prism-go-client/v3"
)
Expand Down Expand Up @@ -126,7 +127,7 @@ func Test_WaitForTaskCompletion(t *testing.T) {
fmt.Fprint(w, `{"status": "PENDING"}`)
},
ctx: ctx,
expectedErr: context.DeadlineExceeded,
expectedErr: wait.ErrWaitTimeout,
},
}
for _, tt := range tests {
Expand Down

0 comments on commit 5a13aa2

Please sign in to comment.