Skip to content

Commit

Permalink
Even better nbd dispatch shutdown
Browse files Browse the repository at this point in the history
Signed-off-by: Jimmy Moore <jamesmoore@loopholelabs.io>
  • Loading branch information
jimmyaxod committed Dec 17, 2024
1 parent addbbe4 commit 16c9a91
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions pkg/storage/expose/nbd_dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package expose
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
Expand All @@ -13,6 +14,8 @@ import (
"github.com/loopholelabs/silo/pkg/storage"
)

var ErrShuttingDown = errors.New("shutting down. Cannot serve any new requests.")

Check failure on line 17 in pkg/storage/expose/nbd_dispatch.go

View workflow job for this annotation

GitHub Actions / golang

error-strings: error strings should not be capitalized or end with punctuation or a newline (revive)

const dispatchBufferSize = 4 * 1024 * 1024

/**
Expand Down Expand Up @@ -76,6 +79,8 @@ type Dispatch struct {
prov storage.Provider
fatal chan error
pendingResponses sync.WaitGroup
shuttingDown bool
shuttingDownLock sync.Mutex
metricPacketsIn uint64
metricPacketsOut uint64
metricReadAt uint64
Expand Down Expand Up @@ -140,6 +145,11 @@ func (d *Dispatch) GetMetrics() *DispatchMetrics {
}

func (d *Dispatch) Wait() {
d.shuttingDownLock.Lock()
d.shuttingDown = true
defer d.shuttingDownLock.Unlock()
// Stop accepting any new requests...

if d.logger != nil {
d.logger.Trace().Str("device", d.dev).Msg("nbd waiting for pending responses")
}
Expand Down Expand Up @@ -342,16 +352,22 @@ func (d *Dispatch) cmdRead(cmdHandle uint64, cmdFrom uint64, cmdLength uint32) e
case e = <-errchan:
}

errorValue := uint32(0)
if e != nil {
errorValue = 1
data = make([]byte, 0) // If there was an error, don't send data
return d.writeResponse(1, handle, []byte{})
}
return d.writeResponse(errorValue, handle, data)
return d.writeResponse(0, handle, data)
}

if d.asyncReads {
d.shuttingDownLock.Lock()
if !d.shuttingDown {
d.pendingResponses.Add(1)
} else {
d.shuttingDownLock.Unlock()
return ErrShuttingDown
}
d.shuttingDownLock.Unlock()

if d.asyncReads {
go func() {
ctime := time.Now()
err := performRead(cmdHandle, cmdFrom, cmdLength)
Expand All @@ -368,7 +384,6 @@ func (d *Dispatch) cmdRead(cmdHandle uint64, cmdFrom uint64, cmdLength uint32) e
d.pendingResponses.Done()
}()
} else {
d.pendingResponses.Add(1)
ctime := time.Now()
err := performRead(cmdHandle, cmdFrom, cmdLength)
if err == nil {
Expand Down Expand Up @@ -418,8 +433,16 @@ func (d *Dispatch) cmdWrite(cmdHandle uint64, cmdFrom uint64, cmdLength uint32,
return d.writeResponse(errorValue, handle, []byte{})
}

if d.asyncWrites {
d.shuttingDownLock.Lock()
if !d.shuttingDown {
d.pendingResponses.Add(1)
} else {
d.shuttingDownLock.Unlock()
return ErrShuttingDown
}
d.shuttingDownLock.Unlock()

if d.asyncWrites {
go func() {
ctime := time.Now()
err := performWrite(cmdHandle, cmdFrom, cmdLength, cmdData)
Expand All @@ -436,7 +459,6 @@ func (d *Dispatch) cmdWrite(cmdHandle uint64, cmdFrom uint64, cmdLength uint32,
d.pendingResponses.Done()
}()
} else {
d.pendingResponses.Add(1)
ctime := time.Now()
err := performWrite(cmdHandle, cmdFrom, cmdLength, cmdData)
if err == nil {
Expand Down

0 comments on commit 16c9a91

Please sign in to comment.