diff --git a/cmd/connect.go b/cmd/connect.go index 4d58ff62..1d9a49c1 100644 --- a/cmd/connect.go +++ b/cmd/connect.go @@ -2,39 +2,25 @@ package main import ( "context" - "crypto/sha256" "fmt" "io" "net" "net/http" "os" - "os/exec" "os/signal" - "sync" "syscall" "time" "github.com/loopholelabs/logging" "github.com/loopholelabs/logging/types" - "github.com/loopholelabs/silo/pkg/storage" - "github.com/loopholelabs/silo/pkg/storage/config" - "github.com/loopholelabs/silo/pkg/storage/expose" - "github.com/loopholelabs/silo/pkg/storage/integrity" + "github.com/loopholelabs/silo/pkg/storage/devicegroup" "github.com/loopholelabs/silo/pkg/storage/metrics" siloprom "github.com/loopholelabs/silo/pkg/storage/metrics/prometheus" - "github.com/loopholelabs/silo/pkg/storage/modules" "github.com/loopholelabs/silo/pkg/storage/protocol" - "github.com/loopholelabs/silo/pkg/storage/protocol/packets" - "github.com/loopholelabs/silo/pkg/storage/sources" - "github.com/loopholelabs/silo/pkg/storage/waitingcache" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/cobra" - - "github.com/fatih/color" - "github.com/vbauerster/mpb/v8" - "github.com/vbauerster/mpb/v8/decor" ) var ( @@ -49,32 +35,12 @@ var ( // Address to connect to var connectAddr string -// Should we expose each device as an nbd device? -var connectExposeDev bool - -// Should we also mount the devices -var connectMountDev bool - -var connectProgress bool - var connectDebug bool - var connectMetrics string -// List of ExposedStorage so they can be cleaned up on exit. -var dstExposed []storage.ExposedStorage - -var dstProgress *mpb.Progress -var dstBars []*mpb.Bar -var dstWG sync.WaitGroup -var dstWGFirst bool - func init() { rootCmd.AddCommand(cmdConnect) cmdConnect.Flags().StringVarP(&connectAddr, "addr", "a", "localhost:5170", "Address to serve from") - cmdConnect.Flags().BoolVarP(&connectExposeDev, "expose", "e", false, "Expose as an nbd devices") - cmdConnect.Flags().BoolVarP(&connectMountDev, "mount", "m", false, "Mount the nbd devices") - cmdConnect.Flags().BoolVarP(&connectProgress, "progress", "p", false, "Show progress") cmdConnect.Flags().BoolVarP(&connectDebug, "debug", "d", false, "Debug logging (trace)") cmdConnect.Flags().StringVarP(&connectMetrics, "metrics", "M", "", "Prom metrics address") } @@ -116,26 +82,17 @@ func runConnect(_ *cobra.Command, _ []string) { go http.ListenAndServe(connectMetrics, nil) } - if connectProgress { - dstProgress = mpb.New( - mpb.WithOutput(color.Output), - mpb.WithAutoRefresh(), - ) - - dstBars = make([]*mpb.Bar, 0) - } - fmt.Printf("Starting silo connect from source %s\n", connectAddr) - dstExposed = make([]storage.ExposedStorage, 0) + var dg *devicegroup.DeviceGroup // Handle shutdown gracefully to disconnect any exposed devices correctly. c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) go func() { <-c - for _, e := range dstExposed { - _ = dstDeviceShutdown(e) + if dg != nil { + dg.CloseAll() } os.Exit(1) }() @@ -148,16 +105,10 @@ func runConnect(_ *cobra.Command, _ []string) { } // Wrap the connection in a protocol, and handle incoming devices - dstWGFirst = true - dstWG.Add(1) // We need to at least wait for one to complete. 
protoCtx, protoCancelfn := context.WithCancel(context.TODO()) - handleIncomingDevice := func(ctx context.Context, pro protocol.Protocol, dev uint32) { - handleIncomingDeviceWithLogging(ctx, pro, dev, log, siloMetrics) - } - - pro := protocol.NewRW(protoCtx, []io.Reader{con}, []io.Writer{con}, handleIncomingDevice) + pro := protocol.NewRW(protoCtx, []io.Reader{con}, []io.Writer{con}, nil) // Let the protocol do its thing. go func() { @@ -175,347 +126,28 @@ func runConnect(_ *cobra.Command, _ []string) { siloMetrics.AddProtocol("protocol", pro) } - dstWG.Wait() // Wait until the migrations have completed... - - if connectProgress { - dstProgress.Wait() - } - - if log != nil { - metrics := pro.GetMetrics() - log.Debug(). - Uint64("PacketsSent", metrics.PacketsSent). - Uint64("DataSent", metrics.DataSent). - Uint64("PacketsRecv", metrics.PacketsRecv). - Uint64("DataRecv", metrics.DataRecv). - Msg("protocol metrics") - } - - fmt.Printf("\nMigrations completed. Please ctrl-c if you want to shut down, or wait an hour :)\n") - - // We should pause here, to allow the user to do things with the devices - time.Sleep(10 * time.Hour) - - // Shutdown any storage exposed as devices - for _, e := range dstExposed { - _ = dstDeviceShutdown(e) - } -} - -// Handle a new incoming device. This is called when a packet is received for a device we haven't heard about before. -func handleIncomingDeviceWithLogging(ctx context.Context, pro protocol.Protocol, dev uint32, log types.RootLogger, met metrics.SiloMetrics) { - var destStorage storage.Provider - var destWaitingLocal *waitingcache.Local - var destWaitingRemote *waitingcache.Remote - var destMonitorStorage *modules.Hooks - var dest *protocol.FromProtocol - - var devSchema *config.DeviceSchema - - var bar *mpb.Bar - - var blockSize uint - var deviceName string - - var statusString = " " - var statusVerify = " " - var statusExposed = " " - - if !dstWGFirst { - // We have a new migration to deal with - dstWG.Add(1) + // TODO: Modify schemas a bit here... + tweak := func(_ int, _ string, schema string) string { + return schema } - dstWGFirst = false - - // This is a storage factory which will be called when we recive DevInfo. 
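The identity `tweak` function above is the hook for rewriting each incoming device schema before the destination devices are created. A minimal sketch of a receiver that redirects the device files to its own paths (the path names here are illustrative, not part of this change):

    // Hypothetical tweak: point incoming devices at receiver-local paths.
    tweak := func(_ int, _ string, schema string) string {
        return strings.ReplaceAll(schema, "/src/devices/", "/dst/devices/")
    }

The new device group test uses the same pattern, rewriting the `testdev_*` locations to `testrecv_*` so that sender and receiver write to separate files.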
- storageFactory := func(di *packets.DevInfo) storage.Provider { - // fmt.Printf("= %d = Received DevInfo name=%s size=%d blocksize=%d schema=%s\n", dev, di.Name, di.Size, di.Block_size, di.Schema) - - // Decode the schema - devSchema = &config.DeviceSchema{} - err := devSchema.Decode(di.Schema) - if err != nil { - panic(err) - } - - blockSize = uint(di.BlockSize) - deviceName = di.Name - - statusFn := func(_ decor.Statistics) string { - return statusString + statusVerify - } - - if connectProgress { - bar = dstProgress.AddBar(int64(di.Size), - mpb.PrependDecorators( - decor.Name(di.Name, decor.WCSyncSpaceR), - decor.Name(" "), - decor.Any(func(_ decor.Statistics) string { return statusExposed }, decor.WC{W: 4}), - decor.Name(" "), - decor.CountersKiloByte("%d/%d", decor.WCSyncWidth), - ), - mpb.AppendDecorators( - decor.EwmaETA(decor.ET_STYLE_GO, 30), - decor.Name(" "), - decor.EwmaSpeed(decor.SizeB1024(0), "% .2f", 60, decor.WCSyncWidth), - decor.OnComplete(decor.Percentage(decor.WC{W: 5}), "done"), - decor.Name(" "), - decor.Any(statusFn, decor.WC{W: 2}), - ), - ) - - dstBars = append(dstBars, bar) - } - - // You can change this to use sources.NewFileStorage etc etc - cr := func(_ int, s int) (storage.Provider, error) { - return sources.NewMemoryStorage(s), nil - } - // Setup some sharded memory storage (for concurrent write speed) - shardSize := di.Size - if di.Size > 64*1024 { - shardSize = di.Size / 1024 - } - - destStorage, err = modules.NewShardedStorage(int(di.Size), int(shardSize), cr) - if err != nil { - panic(err) // FIXME - } - - destMonitorStorage = modules.NewHooks(destStorage) - - if connectProgress { - lastValue := uint64(0) - lastTime := time.Now() - - destMonitorStorage.PostWrite = func(_ []byte, _ int64, n int, err error) (int, error) { - // Update the progress bar - available, total := destWaitingLocal.Availability() - v := uint64(available) * di.Size / uint64(total) - bar.SetCurrent(int64(v)) - bar.EwmaIncrInt64(int64(v-lastValue), time.Since(lastTime)) - lastTime = time.Now() - lastValue = v - - return n, err - } - } - // Use a WaitingCache which will wait for migration blocks, send priorities etc - // A WaitingCache has two ends - local and remote. - destWaitingLocal, destWaitingRemote = waitingcache.NewWaitingCache(destMonitorStorage, int(di.BlockSize)) + dg, err = devicegroup.NewFromProtocol(protoCtx, pro, tweak, nil, nil, log, siloMetrics) - // Connect the waitingCache to the FromProtocol. - // Note that since these are hints, errors don't matter too much. - destWaitingLocal.NeedAt = func(offset int64, length int32) { - _ = dest.NeedAt(offset, length) + for _, d := range dg.GetDeviceSchema() { + expName := dg.GetExposedDeviceByName(d.Name) + if expName != nil { + fmt.Printf("Device %s exposed at %s\n", d.Name, expName.Device()) } - - destWaitingLocal.DontNeedAt = func(offset int64, length int32) { - _ = dest.DontNeedAt(offset, length) - } - - conf := &config.DeviceSchema{} - _ = conf.Decode(di.Schema) - - // Expose this storage as a device if requested - if connectExposeDev { - p, err := dstDeviceSetup(destWaitingLocal) - if err != nil { - fmt.Printf("= %d = Error during setup (expose nbd) %v\n", dev, err) - } else { - statusExposed = p.Device() - dstExposed = append(dstExposed, p) - } - } - return destWaitingRemote } - dest = protocol.NewFromProtocol(ctx, dev, storageFactory, pro) - - if met != nil { - met.AddFromProtocol(deviceName, dest) - } + // Wait for completion events. 
+ dg.WaitForCompletion() - var handlerWG sync.WaitGroup - - handlerWG.Add(1) - go func() { - _ = dest.HandleReadAt() - handlerWG.Done() - }() - handlerWG.Add(1) - go func() { - _ = dest.HandleWriteAt() - handlerWG.Done() - }() - handlerWG.Add(1) - go func() { - _ = dest.HandleDevInfo() - handlerWG.Done() - }() - - handlerWG.Add(1) - // Handle events from the source - go func() { - _ = dest.HandleEvent(func(e *packets.Event) { - switch e.Type { - - case packets.EventPostLock: - statusString = "L" // red.Sprintf("L") - case packets.EventPreLock: - statusString = "l" // red.Sprintf("l") - case packets.EventPostUnlock: - statusString = "U" // green.Sprintf("U") - case packets.EventPreUnlock: - statusString = "u" // green.Sprintf("u") - - // fmt.Printf("= %d = Event %s\n", dev, protocol.EventsByType[e.Type]) - // Check we have all data... - case packets.EventCompleted: - - if log != nil { - m := destWaitingLocal.GetMetrics() - log.Debug(). - Uint64("WaitForBlock", m.WaitForBlock). - Uint64("WaitForBlockHadRemote", m.WaitForBlockHadRemote). - Uint64("WaitForBlockHadLocal", m.WaitForBlockHadLocal). - Uint64("WaitForBlockTimeMS", uint64(m.WaitForBlockTime.Milliseconds())). - Uint64("WaitForBlockLock", m.WaitForBlockLock). - Uint64("WaitForBlockLockDone", m.WaitForBlockLockDone). - Uint64("MarkAvailableLocalBlock", m.MarkAvailableLocalBlock). - Uint64("MarkAvailableRemoteBlock", m.MarkAvailableRemoteBlock). - Uint64("AvailableLocal", m.AvailableLocal). - Uint64("AvailableRemote", m.AvailableRemote). - Str("name", deviceName). - Msg("waitingCacheMetrics") - - fromMetrics := dest.GetMetrics() - log.Debug(). - Uint64("RecvEvents", fromMetrics.RecvEvents). - Uint64("RecvHashes", fromMetrics.RecvHashes). - Uint64("RecvDevInfo", fromMetrics.RecvDevInfo). - Uint64("RecvAltSources", fromMetrics.RecvAltSources). - Uint64("RecvReadAt", fromMetrics.RecvReadAt). - Uint64("RecvWriteAtHash", fromMetrics.RecvWriteAtHash). - Uint64("RecvWriteAtComp", fromMetrics.RecvWriteAtComp). - Uint64("RecvWriteAt", fromMetrics.RecvWriteAt). - Uint64("RecvWriteAtWithMap", fromMetrics.RecvWriteAtWithMap). - Uint64("RecvRemoveFromMap", fromMetrics.RecvRemoveFromMap). - Uint64("RecvRemoveDev", fromMetrics.RecvRemoveDev). - Uint64("RecvDirtyList", fromMetrics.RecvDirtyList). - Uint64("SentNeedAt", fromMetrics.SentNeedAt). - Uint64("SentDontNeedAt", fromMetrics.SentDontNeedAt). - Str("name", deviceName). - Msg("fromProtocolMetrics") - } - - // We completed the migration, but we should wait for handlers to finish before we ok things... 
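The per-device bookkeeping being removed here (the `dstWG` wait group and the hand-rolled `EventCompleted` handling) now lives inside the device group: each device's `FromProtocol` pushes onto a channel when `EventCompleted` arrives, and `WaitForCompletion` drains one signal per device. A caller that still wants to observe per-device events can pass an event handler to `NewFromProtocol` instead of `nil`; a rough sketch (the logging is illustrative only):

    // Sketch: observe events on the receive side via the eventHandler argument.
    eventHandler := func(e *packets.Event) {
        if e.Type == packets.EventCompleted {
            fmt.Printf("a device completed migration\n")
        }
    }
    dg, err = devicegroup.NewFromProtocol(protoCtx, pro, tweak, eventHandler, nil, log, siloMetrics)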
- // fmt.Printf("Completed, now wait for handlers...\n") - go func() { - handlerWG.Wait() - dstWG.Done() - }() - // available, total := destWaitingLocal.Availability() - // fmt.Printf("= %d = Availability (%d/%d)\n", dev, available, total) - // Set bar to completed - if connectProgress { - bar.SetCurrent(int64(destWaitingLocal.Size())) - } - } - }) - handlerWG.Done() - }() - - handlerWG.Add(1) - go func() { - _ = dest.HandleHashes(func(hashes map[uint][sha256.Size]byte) { - // fmt.Printf("[%d] Got %d hashes...\n", dev, len(hashes)) - if len(hashes) > 0 { - in := integrity.NewChecker(int64(destStorage.Size()), int(blockSize)) - in.SetHashes(hashes) - correct, err := in.Check(destStorage) - if err != nil { - panic(err) - } - // fmt.Printf("[%d] Verification result %t %v\n", dev, correct, err) - if correct { - statusVerify = "\u2611" - } else { - statusVerify = "\u2612" - } - } - }) - handlerWG.Done() - }() - - // Handle dirty list by invalidating local waiting cache - handlerWG.Add(1) - go func() { - _ = dest.HandleDirtyList(func(dirty []uint) { - // fmt.Printf("= %d = LIST OF DIRTY BLOCKS %v\n", dev, dirty) - destWaitingLocal.DirtyBlocks(dirty) - }) - handlerWG.Done() - }() -} - -// Called to setup an exposed storage device -func dstDeviceSetup(prov storage.Provider) (storage.ExposedStorage, error) { - p := expose.NewExposedStorageNBDNL(prov, expose.DefaultConfig) - var err error - - err = p.Init() - if err != nil { - // fmt.Printf("\n\n\np.Init returned %v\n\n\n", err) - return nil, err - } - - device := p.Device() - // fmt.Printf("* Device ready on /dev/%s\n", device) - - // We could also mount the device, but we should do so inside a goroutine, so that it doesn't block things... - if connectMountDev { - err = os.Mkdir(fmt.Sprintf("/mnt/mount%s", device), 0600) - if err != nil { - return nil, fmt.Errorf("error mkdir %v", err) - } - - go func() { - // fmt.Printf("Mounting device...") - cmd := exec.Command("mount", "-r", fmt.Sprintf("/dev/%s", device), fmt.Sprintf("/mnt/mount%s", device)) - err = cmd.Run() - if err != nil { - fmt.Printf("Could not mount device %v\n", err) - return - } - // fmt.Printf("* Device is mounted at /mnt/mount%s\n", device) - }() - } - - return p, nil -} - -// Called to shutdown an exposed storage device -func dstDeviceShutdown(p storage.ExposedStorage) error { - device := p.Device() + fmt.Printf("\nMigrations completed. 
Please ctrl-c if you want to shut down, or wait an hour :)\n") - fmt.Printf("Shutdown %s\n", device) - if connectMountDev { - cmd := exec.Command("umount", fmt.Sprintf("/dev/%s", device)) - err := cmd.Run() - if err != nil { - return err - } - err = os.Remove(fmt.Sprintf("/mnt/mount%s", device)) - if err != nil { - return err - } - } + // We should pause here, to allow the user to do things with the devices + time.Sleep(1 * time.Hour) - err := p.Shutdown() - if err != nil { - return err - } - return nil + // Shutdown any storage exposed as devices + dg.CloseAll() } diff --git a/cmd/serve.go b/cmd/serve.go index caf53b91..c114cf2f 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -8,32 +8,21 @@ import ( "net/http" "os" "os/signal" - "sync" "syscall" "time" - "github.com/fatih/color" "github.com/loopholelabs/logging" "github.com/loopholelabs/logging/types" - "github.com/loopholelabs/silo/pkg/storage" - "github.com/loopholelabs/silo/pkg/storage/blocks" "github.com/loopholelabs/silo/pkg/storage/config" - "github.com/loopholelabs/silo/pkg/storage/device" - "github.com/loopholelabs/silo/pkg/storage/dirtytracker" - "github.com/loopholelabs/silo/pkg/storage/expose" + "github.com/loopholelabs/silo/pkg/storage/devicegroup" "github.com/loopholelabs/silo/pkg/storage/metrics" siloprom "github.com/loopholelabs/silo/pkg/storage/metrics/prometheus" "github.com/loopholelabs/silo/pkg/storage/migrator" - "github.com/loopholelabs/silo/pkg/storage/modules" "github.com/loopholelabs/silo/pkg/storage/protocol" - "github.com/loopholelabs/silo/pkg/storage/protocol/packets" - "github.com/loopholelabs/silo/pkg/storage/volatilitymonitor" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/cobra" - "github.com/vbauerster/mpb/v8" - "github.com/vbauerster/mpb/v8/decor" ) var ( @@ -47,42 +36,18 @@ var ( var serveAddr string var serveConf string -var serveProgress bool var serveContinuous bool -var serveAnyOrder bool -var serveCompress bool var serveMetrics string - -var srcExposed []storage.ExposedStorage -var srcStorage []*storageInfo - -var serveProgressBar *mpb.Progress -var serveBars []*mpb.Bar - var serveDebug bool func init() { rootCmd.AddCommand(cmdServe) cmdServe.Flags().StringVarP(&serveAddr, "addr", "a", ":5170", "Address to serve from") cmdServe.Flags().StringVarP(&serveConf, "conf", "c", "silo.conf", "Configuration file") - cmdServe.Flags().BoolVarP(&serveProgress, "progress", "p", false, "Show progress") - cmdServe.Flags().BoolVarP(&serveContinuous, "continuous", "C", false, "Continuous sync") - cmdServe.Flags().BoolVarP(&serveAnyOrder, "order", "o", false, "Any order (faster)") cmdServe.Flags().BoolVarP(&serveDebug, "debug", "d", false, "Debug logging (trace)") cmdServe.Flags().StringVarP(&serveMetrics, "metrics", "m", "", "Prom metrics address") - cmdServe.Flags().BoolVarP(&serveCompress, "compress", "x", false, "Compress") -} - -type storageInfo struct { - // tracker storage.TrackingStorageProvider - tracker *dirtytracker.Remote - lockable storage.LockableProvider - orderer *blocks.PriorityBlockOrder - numBlocks int - blockSize int - name string - schema string + cmdServe.Flags().BoolVarP(&serveContinuous, "continuous", "C", false, "Continuous sync") } func runServe(_ *cobra.Command, _ []string) { @@ -119,46 +84,37 @@ func runServe(_ *cobra.Command, _ []string) { go http.ListenAndServe(serveMetrics, nil) } - if serveProgress { - serveProgressBar = mpb.New( - 
mpb.WithOutput(color.Output), - mpb.WithAutoRefresh(), - ) - serveBars = make([]*mpb.Bar, 0) - } - - srcExposed = make([]storage.ExposedStorage, 0) - srcStorage = make([]*storageInfo, 0) fmt.Printf("Starting silo serve %s\n", serveAddr) - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - go func() { - <-c - shutdownEverything(log) - os.Exit(1) - }() - siloConf, err := config.ReadSchema(serveConf) if err != nil { panic(err) } - for i, s := range siloConf.Device { - fmt.Printf("Setup storage %d [%s] size %s - %d\n", i, s.Name, s.Size, s.ByteSize()) - sinfo, err := setupStorageDevice(s, log, siloMetrics) - if err != nil { - panic(fmt.Sprintf("Could not setup storage. %v", err)) - } + dg, err := devicegroup.NewFromSchema(siloConf.Device, log, siloMetrics) + if err != nil { + panic(err) + } - srcStorage = append(srcStorage, sinfo) + for _, d := range siloConf.Device { + expName := dg.GetExposedDeviceByName(d.Name) + if expName != nil { + fmt.Printf("Device %s exposed at %s\n", d.Name, expName.Device()) + } } - // Setup listener here. When client connects, migrate data to it. + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + dg.CloseAll() + os.Exit(1) + }() + // Setup listener here. When client connects, migrate data to it. l, err := net.Listen("tcp", serveAddr) if err != nil { - shutdownEverything(log) + dg.CloseAll() panic("Listener issue...") } @@ -179,27 +135,65 @@ func runServe(_ *cobra.Command, _ []string) { siloMetrics.AddProtocol("serve", pro) } - // Lets go through each of the things we want to migrate... ctime := time.Now() - var wg sync.WaitGroup + // Migrate everything to the destination... + err = dg.StartMigrationTo(pro) + if err != nil { + dg.CloseAll() + panic(err) + } + + err = dg.MigrateAll(1000, func(ps map[string]*migrator.MigrationProgress) { + for name, p := range ps { + fmt.Printf("[%s] Progress Moved: %d/%d %.2f%% Clean: %d/%d %.2f%% InProgress: %d\n", + name, p.MigratedBlocks, p.TotalBlocks, p.MigratedBlocksPerc, + p.ReadyBlocks, p.TotalBlocks, p.ReadyBlocksPerc, + p.ActiveBlocks) + } + }) + if err != nil { + dg.CloseAll() + panic(err) + } + + fmt.Printf("All devices migrated in %dms.\n", time.Since(ctime).Milliseconds()) - for i, s := range srcStorage { - wg.Add(1) - go func(index int, src *storageInfo) { - err := migrateDevice(log, siloMetrics, uint32(index), src.name, pro, src) - if err != nil { - fmt.Printf("There was an issue migrating the storage %d %v\n", index, err) + // Now do a dirty block phase... + hooks := &devicegroup.MigrateDirtyHooks{ + PreGetDirty: func(name string) error { + fmt.Printf("# [%s]PreGetDirty\n", name) + return nil + }, + PostGetDirty: func(name string, blocks []uint) (bool, error) { + fmt.Printf("# [%s]PostGetDirty %d\n", name, len(blocks)) + if serveContinuous { + return true, nil } - wg.Done() - }(i, s) + return len(blocks) > 0, nil + }, + PostMigrateDirty: func(name string, blocks []uint) (bool, error) { + fmt.Printf("# [%s]PostMigrateDirty %d\n", name, len(blocks)) + time.Sleep(1 * time.Second) // Wait a bit for next dirty loop + return true, nil + }, + Completed: func(name string) { + fmt.Printf("# [%s]Completed\n", name) + }, + } + err = dg.MigrateDirty(hooks) + if err != nil { + dg.CloseAll() + panic(err) } - wg.Wait() - if serveProgressBar != nil { - serveProgressBar.Wait() + fmt.Printf("All devices migrated(including dirty) in %dms.\n", time.Since(ctime).Milliseconds()) + + err = dg.Completed() // Send completion events for the devices. 
+ if err != nil { + dg.CloseAll() + panic(err) } - fmt.Printf("\n\nMigration completed in %dms\n", time.Since(ctime).Milliseconds()) if log != nil { metrics := pro.GetMetrics() @@ -213,319 +207,5 @@ func runServe(_ *cobra.Command, _ []string) { con.Close() } - shutdownEverything(log) -} - -func shutdownEverything(log types.Logger) { - // first unlock everything - fmt.Printf("Unlocking devices...\n") - for _, i := range srcStorage { - i.lockable.Unlock() - i.tracker.Close() - } - - fmt.Printf("Shutting down devices cleanly...\n") - for _, p := range srcExposed { - device := p.Device() - - fmt.Printf("Shutdown nbd device %s\n", device) - _ = p.Shutdown() - - // Show some metrics... - if log != nil { - nbdDevice, ok := p.(*expose.ExposedStorageNBDNL) - if ok { - m := nbdDevice.GetMetrics() - log.Debug(). - Uint64("PacketsIn", m.PacketsIn). - Uint64("PacketsOut", m.PacketsOut). - Uint64("ReadAt", m.ReadAt). - Uint64("ReadAtBytes", m.ReadAtBytes). - Uint64("ReadAtTimeMS", uint64(m.ReadAtTime.Milliseconds())). - Uint64("WriteAt", m.WriteAt). - Uint64("WriteAtBytes", m.WriteAtBytes). - Uint64("WriteAtTimeMS", uint64(m.WriteAtTime.Milliseconds())). - Str("device", p.Device()). - Msg("NBD metrics") - } - } - } -} - -func setupStorageDevice(conf *config.DeviceSchema, log types.Logger, met metrics.SiloMetrics) (*storageInfo, error) { - source, ex, err := device.NewDeviceWithLoggingMetrics(conf, log, met) - if err != nil { - return nil, err - } - if ex != nil { - fmt.Printf("Device %s exposed as %s\n", conf.Name, ex.Device()) - srcExposed = append(srcExposed, ex) - } - - blockSize := 1024 * 128 - - if conf.BlockSize != "" { - blockSize = int(conf.ByteBlockSize()) - } - - numBlocks := (int(conf.ByteSize()) + blockSize - 1) / blockSize - - sourceMetrics := modules.NewMetrics(source) - sourceDirtyLocal, sourceDirtyRemote := dirtytracker.NewDirtyTracker(sourceMetrics, blockSize) - sourceMonitor := volatilitymonitor.NewVolatilityMonitor(sourceDirtyLocal, blockSize, 10*time.Second) - sourceStorage := modules.NewLockable(sourceMonitor) - - if met != nil { - met.AddDirtyTracker(conf.Name, sourceDirtyRemote) - met.AddVolatilityMonitor(conf.Name, sourceMonitor) - met.AddMetrics(conf.Name, sourceMetrics) - } - - if ex != nil { - ex.SetProvider(sourceStorage) - } - - // Start monitoring blocks. 
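Everything this removed `setupStorageDevice` wired up by hand (metrics, dirty tracker, volatility monitor, lockable storage and the block orderer below) is now built per device inside `devicegroup.NewFromSchema`, so `runServe` only hands over the parsed schema list. A condensed sketch of the replacement, with error handling elided and "Disk0" standing in for any configured device name:

    dg, err := devicegroup.NewFromSchema(siloConf.Device, log, siloMetrics)
    if err != nil {
        panic(err)
    }
    // Devices with expose = true in the schema are still reachable as NBD devices.
    if exp := dg.GetExposedDeviceByName("Disk0"); exp != nil {
        fmt.Printf("Device exposed at %s\n", exp.Device())
    }
    // ... StartMigrationTo / MigrateAll / MigrateDirty / Completed ...
    dg.CloseAll() // single call replaces the old shutdownEverything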
- - var primaryOrderer storage.BlockOrder - primaryOrderer = sourceMonitor - - if serveAnyOrder { - primaryOrderer = blocks.NewAnyBlockOrder(numBlocks, nil) - } - orderer := blocks.NewPriorityBlockOrder(numBlocks, primaryOrderer) - orderer.AddAll() - - schema := string(conf.Encode()) - - sinfo := &storageInfo{ - tracker: sourceDirtyRemote, - lockable: sourceStorage, - orderer: orderer, - blockSize: blockSize, - numBlocks: numBlocks, - name: conf.Name, - schema: schema, - } - - return sinfo, nil -} - -// Migrate a device -func migrateDevice(log types.Logger, met metrics.SiloMetrics, devID uint32, name string, - pro protocol.Protocol, - sinfo *storageInfo) error { - size := sinfo.lockable.Size() - dest := protocol.NewToProtocol(size, devID, pro) - - // Maybe compress writes - dest.SetCompression(serveCompress) - - err := dest.SendDevInfo(name, uint32(sinfo.blockSize), sinfo.schema) - if err != nil { - return err - } - - statusString := " " - - statusFn := func(_ decor.Statistics) string { - return statusString - } - - var bar *mpb.Bar - if serveProgress { - bar = serveProgressBar.AddBar(int64(size), - mpb.PrependDecorators( - decor.Name(name, decor.WCSyncSpaceR), - decor.CountersKiloByte("%d/%d", decor.WCSyncWidth), - ), - mpb.AppendDecorators( - decor.EwmaETA(decor.ET_STYLE_GO, 30), - decor.Name(" "), - decor.EwmaSpeed(decor.SizeB1024(0), "% .2f", 60, decor.WCSyncWidth), - decor.OnComplete(decor.Percentage(decor.WC{W: 5}), "done"), - decor.Name(" "), - decor.Any(statusFn, decor.WC{W: 2}), - ), - ) - - serveBars = append(serveBars, bar) - } - - go func() { - _ = dest.HandleNeedAt(func(offset int64, length int32) { - // Prioritize blocks... - end := uint64(offset + int64(length)) - if end > size { - end = size - } - - bStart := int(offset / int64(sinfo.blockSize)) - bEnd := int((end-1)/uint64(sinfo.blockSize)) + 1 - for b := bStart; b < bEnd; b++ { - // Ask the orderer to prioritize these blocks... - sinfo.orderer.PrioritiseBlock(b) - } - }) - }() - - go func() { - _ = dest.HandleDontNeedAt(func(offset int64, length int32) { - end := uint64(offset + int64(length)) - if end > size { - end = size - } - - bStart := int(offset / int64(sinfo.blockSize)) - bEnd := int((end-1)/uint64(sinfo.blockSize)) + 1 - for b := bStart; b < bEnd; b++ { - sinfo.orderer.Remove(b) - } - }) - }() - - conf := migrator.NewConfig().WithBlockSize(sinfo.blockSize) - conf.Logger = log - conf.LockerHandler = func() { - _ = dest.SendEvent(&packets.Event{Type: packets.EventPreLock}) - sinfo.lockable.Lock() - _ = dest.SendEvent(&packets.Event{Type: packets.EventPostLock}) - } - conf.UnlockerHandler = func() { - _ = dest.SendEvent(&packets.Event{Type: packets.EventPreUnlock}) - sinfo.lockable.Unlock() - _ = dest.SendEvent(&packets.Event{Type: packets.EventPostUnlock}) - } - conf.Concurrency = map[int]int{ - storage.BlockTypeAny: 1000, - } - conf.ErrorHandler = func(_ *storage.BlockInfo, err error) { - // For now... 
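The fixed per-device `Concurrency` map above (1000 workers for `storage.BlockTypeAny`) is replaced by the single budget passed to `MigrateAll` (the new `runServe` passes 1000 for the whole group). Condensed from the new `device_group_to.go`, where `d` ranges over the devices and `totalSize` is the sum of their sizes:

    // One worker is reserved for every device up front, then the remaining
    // budget is shared out in proportion to each device's size.
    maxConcurrency -= len(dg.devices)
    concurrency := 1 + (uint64(maxConcurrency) * d.Size / totalSize)
    cfg.Concurrency = map[int]int{
        storage.BlockTypeAny: int(concurrency),
    }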
- panic(err) - } - conf.Integrity = true - - lastValue := uint64(0) - lastTime := time.Now() - - if serveProgress { - - conf.ProgressHandler = func(p *migrator.MigrationProgress) { - v := uint64(p.ReadyBlocks) * uint64(sinfo.blockSize) - if v > size { - v = size - } - bar.SetCurrent(int64(v)) - bar.EwmaIncrInt64(int64(v-lastValue), time.Since(lastTime)) - lastTime = time.Now() - lastValue = v - } - } else { - conf.ProgressHandler = func(p *migrator.MigrationProgress) { - fmt.Printf("[%s] Progress Moved: %d/%d %.2f%% Clean: %d/%d %.2f%% InProgress: %d\n", - name, p.MigratedBlocks, p.TotalBlocks, p.MigratedBlocksPerc, - p.ReadyBlocks, p.TotalBlocks, p.ReadyBlocksPerc, - p.ActiveBlocks) - } - conf.ErrorHandler = func(b *storage.BlockInfo, err error) { - fmt.Printf("[%s] Error for block %d error %v\n", name, b.Block, err) - } - } - - mig, err := migrator.NewMigrator(sinfo.tracker, dest, sinfo.orderer, conf) - - if err != nil { - return err - } - - if met != nil { - met.AddToProtocol(name, dest) - met.AddMigrator(name, mig) - } - - migrateBlocks := sinfo.numBlocks - - // Now do the migration... - err = mig.Migrate(migrateBlocks) - if err != nil { - return err - } - - // Wait for completion. - err = mig.WaitForCompletion() - if err != nil { - return err - } - - hashes := mig.GetHashes() // Get the initial hashes and send them over for verification... - err = dest.SendHashes(hashes) - if err != nil { - return err - } - - // Optional: Enter a loop looking for more dirty blocks to migrate... - for { - blocks := mig.GetLatestDirty() // - if !serveContinuous && blocks == nil { - break - } - - if blocks != nil { - // Optional: Send the list of dirty blocks over... - err := dest.DirtyList(conf.BlockSize, blocks) - if err != nil { - return err - } - - // fmt.Printf("[%s] Migrating dirty blocks %d\n", name, len(blocks)) - err = mig.MigrateDirty(blocks) - if err != nil { - return err - } - } else { - mig.Unlock() - } - time.Sleep(100 * time.Millisecond) - } - - err = mig.WaitForCompletion() - if err != nil { - return err - } - - // fmt.Printf("[%s] Migration completed\n", name) - - err = dest.SendEvent(&packets.Event{Type: packets.EventCompleted}) - if err != nil { - return err - } - /* - // Completed. - if serve_progress { - // bar.SetCurrent(int64(size)) - // bar.EwmaIncrInt64(int64(size-last_value), time.Since(last_time)) - } - */ - - if log != nil { - toMetrics := dest.GetMetrics() - log.Debug(). - Str("name", name). - Uint64("SentEvents", toMetrics.SentEvents). - Uint64("SentHashes", toMetrics.SentHashes). - Uint64("SentDevInfo", toMetrics.SentDevInfo). - Uint64("SentRemoveDev", toMetrics.SentRemoveDev). - Uint64("SentDirtyList", toMetrics.SentDirtyList). - Uint64("SentReadAt", toMetrics.SentReadAt). - Uint64("SentWriteAtHash", toMetrics.SentWriteAtHash). - Uint64("SentWriteAtComp", toMetrics.SentWriteAtComp). - Uint64("SentWriteAt", toMetrics.SentWriteAt). - Uint64("SentWriteAtWithMap", toMetrics.SentWriteAtWithMap). - Uint64("SentRemoveFromMap", toMetrics.SentRemoveFromMap). - Uint64("SentNeedAt", toMetrics.RecvNeedAt). - Uint64("SentDontNeedAt", toMetrics.RecvDontNeedAt). 
- Msg("ToProtocol metrics") - } - - return nil + dg.CloseAll() } diff --git a/pkg/storage/config/silo.go b/pkg/storage/config/silo.go index 1b837fe2..8b2e9681 100644 --- a/pkg/storage/config/silo.go +++ b/pkg/storage/config/silo.go @@ -1,6 +1,7 @@ package config import ( + "errors" "fmt" "os" "strconv" @@ -123,6 +124,25 @@ func (ds *DeviceSchema) Encode() []byte { return f.Bytes() } +func (ds *DeviceSchema) EncodeAsBlock() []byte { + f := hclwrite.NewEmptyFile() + block := gohcl.EncodeAsBlock(ds, "device") + f.Body().AppendBlock(block) + return f.Bytes() +} + +func DecodeDeviceFromBlock(schema string) (*DeviceSchema, error) { + sf := &SiloSchema{} + err := sf.Decode([]byte(schema)) + if err != nil { + return nil, err + } + if len(sf.Device) != 1 { + return nil, errors.New("more than one device in schema") + } + return sf.Device[0], nil +} + func (ds *DeviceSchema) Decode(schema string) error { file, diag := hclsyntax.ParseConfig([]byte(schema), "", hcl.Pos{Line: 1, Column: 1}) if diag.HasErrors() { diff --git a/pkg/storage/config/silo_test.go b/pkg/storage/config/silo_test.go index dc874831..2f0d597a 100644 --- a/pkg/storage/config/silo_test.go +++ b/pkg/storage/config/silo_test.go @@ -66,3 +66,31 @@ func TestSiloConfig(t *testing.T) { assert.NoError(t, err) // TODO: Check data is as expected } + +func TestSiloConfigBlock(t *testing.T) { + + schema := `device Disk0 { + size = "1G" + expose = true + system = "memory" + } + + device Disk1 { + size = "2M" + system = "memory" + } + ` + + s := new(SiloSchema) + err := s.Decode([]byte(schema)) + assert.NoError(t, err) + + block0 := s.Device[0].EncodeAsBlock() + + ds := &SiloSchema{} + err = ds.Decode(block0) + assert.NoError(t, err) + + // Make sure it used the label + assert.Equal(t, ds.Device[0].Name, s.Device[0].Name) +} diff --git a/pkg/storage/devicegroup/README.md b/pkg/storage/devicegroup/README.md new file mode 100644 index 00000000..8227fd50 --- /dev/null +++ b/pkg/storage/devicegroup/README.md @@ -0,0 +1,75 @@ +# Device Group + +The `DeviceGroup` combines some number of Silo devices into a single unit, which can then be migrated to another Silo instance. +All internal concerns such as volatilityMonitor, waitingCache, as well as the new S3 assist, are now hidden from the consumer. + +## Creation + +There are two methods to create a `DeviceGroup`. + +### NewFromSchema + +This takes in an array of Silo device configs, and creates the devices. If `expose==true` then a corresponding NBD device will be created and attached. + +### NewFromProtocol + +This takes in a `protocol` and creates the devices as they are received from a sender. + +## Usage (Sending devices) + +Devices in a `DeviceGroup` are sent together, which allows Silo to optimize all aspects of the transfer. + + // Create a device group from schema + dg, err := devicegroup.NewFromSchema(devices, log, siloMetrics) + + // Start a migration + err := dg.StartMigrationTo(protocol) + + // Migrate the data with max total concurrency 100 + err = dg.MigrateAll(100, pHandler) + + // Migrate any dirty blocks + // hooks gives some control over the dirty loop + err = dg.MigrateDirty(hooks) + + // Send completion events for all devices + err = dg.Completed() + + // Close everything + dg.CloseAll() + +Within the `MigrateDirty` there are a number of hooks we can use to control things. MigrateDirty will return once all devices have no more dirty data. You can of course then call MigrateDirty again eg for continuous sync. 
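For example, a dirty loop that keeps syncing while a flag is set, mirroring the hooks in the new `cmd/serve.go` (here `serveContinuous` is whatever flag the caller uses to decide when to stop):

    hooks := &devicegroup.MigrateDirtyHooks{
        PostGetDirty: func(name string, blocks []uint) (bool, error) {
            if serveContinuous {
                return true, nil // keep looping even when nothing is dirty
            }
            return len(blocks) > 0, nil // otherwise stop once a device is clean
        },
        PostMigrateDirty: func(name string, blocks []uint) (bool, error) {
            time.Sleep(time.Second) // give the source a moment between passes
            return true, nil
        },
    }
    err = dg.MigrateDirty(hooks)

The full set of hooks is: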
+ + type MigrateDirtyHooks struct { + PreGetDirty func(name string) error + PostGetDirty func(name string, blocks []uint) (bool, error) + PostMigrateDirty func(name string, blocks []uint) (bool, error) + Completed func(name string) + } + + +There is also support for sending global custom data. This would typically be done either in `pHandler` (The progress handler), or in one of the `MigrateDirty` hooks. + + pHandler := func(ps []*migrator.MigrationProgress) { + // Do some test here to see if enough data migrated + + // If so, send a custom Authority Transfer event. + dg.SendCustomData(authorityTransferPacket) + } + +## Usage (Receiving devices) + + // Create a DeviceGroup from protocol + // tweak func allows us to modify the schema, eg pathnames + dg, err = NewFromProtocol(ctx, protocol, tweak, nil, nil) + + // Handle any custom data events + // For example resume the VM here. + go dg.HandleCustomData(func(data []byte) { + // We got sent some custom data! + }) + + // Wait for migration completion + dg.WaitForCompletion() + +Once a `DeviceGroup` is has been created and migration is completed, you can then send the devices somewhere else with `StartMigration(protocol)`. \ No newline at end of file diff --git a/pkg/storage/devicegroup/device_group.go b/pkg/storage/devicegroup/device_group.go new file mode 100644 index 00000000..918e75a5 --- /dev/null +++ b/pkg/storage/devicegroup/device_group.go @@ -0,0 +1,139 @@ +package devicegroup + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/loopholelabs/logging/types" + "github.com/loopholelabs/silo/pkg/storage" + "github.com/loopholelabs/silo/pkg/storage/blocks" + "github.com/loopholelabs/silo/pkg/storage/config" + "github.com/loopholelabs/silo/pkg/storage/dirtytracker" + "github.com/loopholelabs/silo/pkg/storage/metrics" + "github.com/loopholelabs/silo/pkg/storage/migrator" + "github.com/loopholelabs/silo/pkg/storage/protocol" + "github.com/loopholelabs/silo/pkg/storage/protocol/packets" + "github.com/loopholelabs/silo/pkg/storage/volatilitymonitor" + "github.com/loopholelabs/silo/pkg/storage/waitingcache" +) + +const volatilityExpiry = 30 * time.Minute +const defaultBlockSize = 1024 * 1024 + +var errNotSetup = errors.New("toProtocol not setup") + +type DeviceGroup struct { + log types.Logger + met metrics.SiloMetrics + ctx context.Context + devices []*DeviceInformation + controlProtocol protocol.Protocol + incomingDevicesCh chan bool + progressLock sync.Mutex + progress map[string]*migrator.MigrationProgress +} + +type DeviceInformation struct { + Size uint64 + BlockSize uint64 + NumBlocks int + Schema *config.DeviceSchema + Prov storage.Provider + Storage storage.LockableProvider + Exp storage.ExposedStorage + Volatility *volatilitymonitor.VolatilityMonitor + DirtyLocal *dirtytracker.Local + DirtyRemote *dirtytracker.Remote + To *protocol.ToProtocol + Orderer *blocks.PriorityBlockOrder + Migrator *migrator.Migrator + migrationError chan error + WaitingCacheLocal *waitingcache.Local + WaitingCacheRemote *waitingcache.Remote + EventHandler func(e *packets.Event) +} + +func (dg *DeviceGroup) GetDeviceSchema() []*config.DeviceSchema { + s := make([]*config.DeviceSchema, 0) + for _, di := range dg.devices { + s = append(s, di.Schema) + } + return s +} + +func (dg *DeviceGroup) GetAllNames() []string { + names := make([]string, 0) + for _, di := range dg.devices { + names = append(names, di.Schema.Name) + } + return names +} + +func (dg *DeviceGroup) GetDeviceInformationByName(name string) *DeviceInformation { + for _, di := range 
dg.devices { + if di.Schema.Name == name { + return di + } + } + return nil +} + +func (dg *DeviceGroup) GetExposedDeviceByName(name string) storage.ExposedStorage { + for _, di := range dg.devices { + if di.Schema.Name == name && di.Exp != nil { + return di.Exp + } + } + return nil +} + +func (dg *DeviceGroup) GetProviderByName(name string) storage.Provider { + for _, di := range dg.devices { + if di.Schema.Name == name { + return di.Prov + } + } + return nil +} + +func (dg *DeviceGroup) GetBlockSizeByName(name string) int { + for _, di := range dg.devices { + if di.Schema.Name == name { + return int(di.BlockSize) + } + } + return -1 +} + +func (dg *DeviceGroup) CloseAll() error { + if dg.log != nil { + dg.log.Debug().Int("devices", len(dg.devices)).Msg("close device group") + } + + var e error + for _, d := range dg.devices { + // Unlock the storage so nothing blocks here... + // If we don't unlock there may be pending nbd writes that can't be completed. + d.Storage.Unlock() + + err := d.Prov.Close() + if err != nil { + if dg.log != nil { + dg.log.Error().Err(err).Msg("error closing device group storage provider") + } + e = errors.Join(e, err) + } + if d.Exp != nil { + err = d.Exp.Shutdown() + if err != nil { + if dg.log != nil { + dg.log.Error().Err(err).Msg("error closing device group exposed storage") + } + e = errors.Join(e, err) + } + } + } + return e +} diff --git a/pkg/storage/devicegroup/device_group_from.go b/pkg/storage/devicegroup/device_group_from.go new file mode 100644 index 00000000..5fb0bf21 --- /dev/null +++ b/pkg/storage/devicegroup/device_group_from.go @@ -0,0 +1,169 @@ +package devicegroup + +import ( + "context" + "errors" + + "github.com/loopholelabs/logging/types" + "github.com/loopholelabs/silo/pkg/storage" + "github.com/loopholelabs/silo/pkg/storage/config" + "github.com/loopholelabs/silo/pkg/storage/metrics" + "github.com/loopholelabs/silo/pkg/storage/protocol" + "github.com/loopholelabs/silo/pkg/storage/protocol/packets" + "github.com/loopholelabs/silo/pkg/storage/waitingcache" +) + +func NewFromProtocol(ctx context.Context, + pro protocol.Protocol, + tweakDeviceSchema func(index int, name string, schema string) string, + eventHandler func(e *packets.Event), + customDataHandler func(data []byte), + log types.Logger, + met metrics.SiloMetrics) (*DeviceGroup, error) { + + // This is our control channel, and we're expecting a DeviceGroupInfo + _, dgData, err := pro.WaitForCommand(0, packets.CommandDeviceGroupInfo) + if err != nil { + return nil, err + } + dgi, err := packets.DecodeDeviceGroupInfo(dgData) + if err != nil { + return nil, err + } + + devices := make([]*config.DeviceSchema, len(dgi.Devices)) + + // Setup something to listen for custom data... 
+ handleCustomDataEvent := func() error { + // This is our control channel, and we're expecting a DeviceGroupInfo + id, evData, err := pro.WaitForCommand(0, packets.CommandEvent) + if err != nil { + return err + } + ev, err := packets.DecodeEvent(evData) + if err != nil { + return err + } + if ev.Type != packets.EventCustom || ev.CustomType != 0 { + return err + } + + if customDataHandler != nil { + customDataHandler(ev.CustomPayload) + } + + // Reply with ack + eack := packets.EncodeEventResponse() + _, err = pro.SendPacket(0, id, eack, protocol.UrgencyUrgent) + if err != nil { + return err + } + return nil + } + + // Listen for custom data events + go func() { + for { + err := handleCustomDataEvent() + if err != nil && !errors.Is(err, context.Canceled) { + log.Debug().Err(err).Msg("handleCustomDataEvenet returned") + return + } + } + }() + + // First create the devices we need using the schemas sent... + for index, di := range dgi.Devices { + // We may want to tweak schemas here eg autoStart = false on sync. Or modify pathnames. + schema := di.Schema + if tweakDeviceSchema != nil { + schema = tweakDeviceSchema(index-1, di.Name, schema) + } + ds, err := config.DecodeDeviceFromBlock(schema) + if err != nil { + return nil, err + } + devices[index-1] = ds + } + + dg, err := NewFromSchema(devices, log, met) + if err != nil { + return nil, err + } + + dg.controlProtocol = pro + dg.ctx = ctx + + dg.incomingDevicesCh = make(chan bool, len(dg.devices)) + + // We need to create the FromProtocol for each device, and associated goroutines here. + for index, di := range dgi.Devices { + dev := index - 1 + d := dg.devices[dev] + d.EventHandler = eventHandler + + destStorageFactory := func(di *packets.DevInfo) storage.Provider { + d.WaitingCacheLocal, d.WaitingCacheRemote = waitingcache.NewWaitingCacheWithLogger(d.Prov, int(di.BlockSize), dg.log) + + if d.Exp != nil { + d.Exp.SetProvider(d.WaitingCacheLocal) + } + + return d.WaitingCacheRemote + } + + from := protocol.NewFromProtocol(ctx, uint32(index), destStorageFactory, pro) + err = from.SetDevInfo(di) + if err != nil { + return nil, err + } + go func() { + err := from.HandleReadAt() + if err != nil && !errors.Is(err, context.Canceled) { + log.Debug().Err(err).Msg("HandleReadAt returned") + } + }() + go func() { + err := from.HandleWriteAt() + if err != nil && !errors.Is(err, context.Canceled) { + log.Debug().Err(err).Msg("HandleWriteAt returned") + } + }() + go func() { + err := from.HandleDirtyList(func(dirtyBlocks []uint) { + // Tell the waitingCache about it + d.WaitingCacheLocal.DirtyBlocks(dirtyBlocks) + }) + if err != nil && !errors.Is(err, context.Canceled) { + log.Debug().Err(err).Msg("HandleDirtyList returned") + } + }() + go func() { + err := from.HandleEvent(func(p *packets.Event) { + if p.Type == packets.EventCompleted { + dg.incomingDevicesCh <- true + } + if d.EventHandler != nil { + d.EventHandler(p) + } + }) + if err != nil && !errors.Is(err, context.Canceled) { + log.Debug().Err(err).Msg("HandleEvent returned") + } + }() + } + + return dg, nil +} + +// Wait for completion events from all devices here. 
+func (dg *DeviceGroup) WaitForCompletion() error { + for range dg.devices { + select { + case <-dg.incomingDevicesCh: + case <-dg.ctx.Done(): + return dg.ctx.Err() + } + } + return nil +} diff --git a/pkg/storage/devicegroup/device_group_test.go b/pkg/storage/devicegroup/device_group_test.go new file mode 100644 index 00000000..32609b7f --- /dev/null +++ b/pkg/storage/devicegroup/device_group_test.go @@ -0,0 +1,322 @@ +package devicegroup + +import ( + "context" + "crypto/rand" + "io" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/loopholelabs/logging" + "github.com/loopholelabs/logging/types" + "github.com/loopholelabs/silo/pkg/storage" + "github.com/loopholelabs/silo/pkg/storage/config" + "github.com/loopholelabs/silo/pkg/storage/migrator" + "github.com/loopholelabs/silo/pkg/storage/protocol" + "github.com/loopholelabs/silo/pkg/storage/protocol/packets" + "github.com/loopholelabs/silo/pkg/storage/sources" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var testDeviceSchema = []*config.DeviceSchema{ + { + Name: "test1", + Size: "8m", + System: "file", + BlockSize: "1m", + // Expose: true, + Location: "testdev_test1", + }, + + { + Name: "test2", + Size: "16m", + System: "file", + BlockSize: "1m", + // Expose: true, + Location: "testdev_test2", + }, +} + +func setupDeviceGroup(t *testing.T) *DeviceGroup { + /* + currentUser, err := user.Current() + if err != nil { + panic(err) + } + if currentUser.Username != "root" { + fmt.Printf("Cannot run test unless we are root.\n") + return nil + } + */ + dg, err := NewFromSchema(testDeviceSchema, nil, nil) + assert.NoError(t, err) + + t.Cleanup(func() { + err = dg.CloseAll() + assert.NoError(t, err) + + os.Remove("testdev_test1") + os.Remove("testdev_test2") + }) + + return dg +} + +func TestDeviceGroupBasic(t *testing.T) { + dg := setupDeviceGroup(t) + if dg == nil { + return + } +} + +func TestDeviceGroupSendDevInfo(t *testing.T) { + dg := setupDeviceGroup(t) + if dg == nil { + return + } + + pro := protocol.NewMockProtocol(context.TODO()) + + err := dg.StartMigrationTo(pro) + assert.NoError(t, err) + + // Make sure they all got sent correctly... 
+ _, data, err := pro.WaitForCommand(0, packets.CommandDeviceGroupInfo) + assert.NoError(t, err) + + dgi, err := packets.DecodeDeviceGroupInfo(data) + assert.NoError(t, err) + + for index, r := range testDeviceSchema { + di := dgi.Devices[index+1] + + assert.Equal(t, r.Name, di.Name) + assert.Equal(t, uint64(r.ByteSize()), di.Size) + assert.Equal(t, uint32(r.ByteBlockSize()), di.BlockSize) + assert.Equal(t, string(r.EncodeAsBlock()), di.Schema) + } +} + +func TestDeviceGroupMigrateTo(t *testing.T) { + dg := setupDeviceGroup(t) + if dg == nil { + return + } + + log := logging.New(logging.Zerolog, "silo", os.Stdout) + log.SetLevel(types.TraceLevel) + + // Create a simple pipe + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + + ctx, cancelfn := context.WithCancel(context.TODO()) + + var incomingLock sync.Mutex + incomingProviders := make(map[string]storage.Provider) + + prSource := protocol.NewRW(ctx, []io.Reader{r1}, []io.Writer{w2}, nil) + prDest := protocol.NewRW(ctx, []io.Reader{r2}, []io.Writer{w1}, nil) + + go func() { + // This is our control channel, and we're expecting a DeviceGroupInfo + _, dgData, err := prDest.WaitForCommand(0, packets.CommandDeviceGroupInfo) + assert.NoError(t, err) + dgi, err := packets.DecodeDeviceGroupInfo(dgData) + assert.NoError(t, err) + + for index, di := range dgi.Devices { + destStorageFactory := func(di *packets.DevInfo) storage.Provider { + store := sources.NewMemoryStorage(int(di.Size)) + incomingLock.Lock() + incomingProviders[di.Name] = store + incomingLock.Unlock() + return store + } + + from := protocol.NewFromProtocol(ctx, uint32(index), destStorageFactory, prDest) + err = from.SetDevInfo(di) + assert.NoError(t, err) + go func() { + err := from.HandleReadAt() + assert.ErrorIs(t, err, context.Canceled) + }() + go func() { + err := from.HandleWriteAt() + assert.ErrorIs(t, err, context.Canceled) + }() + go func() { + err := from.HandleDirtyList(func(_ []uint) { + }) + assert.ErrorIs(t, err, context.Canceled) + }() + } + }() + + go func() { + _ = prSource.Handle() + }() + go func() { + _ = prDest.Handle() + }() + + // Lets write some data... + for _, s := range testDeviceSchema { + prov := dg.GetProviderByName(s.Name) + assert.NotNil(t, prov) + buff := make([]byte, prov.Size()) + _, err := rand.Read(buff) + assert.NoError(t, err) + _, err = prov.WriteAt(buff, 0) + assert.NoError(t, err) + } + + // Send all the dev info... + err := dg.StartMigrationTo(prSource) + assert.NoError(t, err) + + pHandler := func(_ map[string]*migrator.MigrationProgress) {} + + err = dg.MigrateAll(100, pHandler) + assert.NoError(t, err) + + // Check the data all got migrated correctly + for _, s := range testDeviceSchema { + prov := dg.GetProviderByName(s.Name) + // Find the correct destProvider... 
+ destProvider := incomingProviders[s.Name] + assert.NotNil(t, destProvider) + eq, err := storage.Equals(prov, destProvider, 1024*1024) + assert.NoError(t, err) + assert.True(t, eq) + } + + cancelfn() +} + +func TestDeviceGroupMigrate(t *testing.T) { + dg := setupDeviceGroup(t) + if dg == nil { + return + } + + // Remove the receiving files + t.Cleanup(func() { + os.Remove("testrecv_test1") + os.Remove("testrecv_test2") + }) + + log := logging.New(logging.Zerolog, "silo", os.Stdout) + log.SetLevel(types.TraceLevel) + + // Create a simple pipe + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + + ctx, cancelfn := context.WithCancel(context.TODO()) + + prSource := protocol.NewRW(ctx, []io.Reader{r1}, []io.Writer{w2}, nil) + prDest := protocol.NewRW(ctx, []io.Reader{r2}, []io.Writer{w1}, nil) + + var prDone sync.WaitGroup + + prDone.Add(2) + go func() { + _ = prSource.Handle() + prDone.Done() + }() + go func() { + _ = prDest.Handle() + prDone.Done() + }() + + // Lets write some data... + for _, s := range testDeviceSchema { + prov := dg.GetProviderByName(s.Name) + buff := make([]byte, prov.Size()) + _, err := rand.Read(buff) + assert.NoError(t, err) + _, err = prov.WriteAt(buff, 0) + assert.NoError(t, err) + } + + var dg2 *DeviceGroup + var wg sync.WaitGroup + + // We will tweak schema in recv here so we have separate paths. + tweak := func(_ int, _ string, schema string) string { + s := strings.ReplaceAll(schema, "testdev_test1", "testrecv_test1") + s = strings.ReplaceAll(s, "testdev_test2", "testrecv_test2") + return s + } + + // TransferAuthority + var tawg sync.WaitGroup + tawg.Add(1) + cdh := func(data []byte) { + assert.Equal(t, []byte("Hello"), data) + tawg.Done() + } + + wg.Add(1) + go func() { + var err error + dg2, err = NewFromProtocol(ctx, prDest, tweak, nil, cdh, nil, nil) + assert.NoError(t, err) + wg.Done() + }() + + // Send all the dev info... + err := dg.StartMigrationTo(prSource) + assert.NoError(t, err) + + // Make sure the incoming devices were setup completely + wg.Wait() + + // TransferAuthority + tawg.Add(1) + time.AfterFunc(100*time.Millisecond, func() { + dg.SendCustomData([]byte("Hello")) + tawg.Done() + }) + + pHandler := func(_ map[string]*migrator.MigrationProgress) {} + + err = dg.MigrateAll(100, pHandler) + assert.NoError(t, err) + + // Make sure authority has been transferred as expected. + tawg.Wait() + + err = dg.Completed() + assert.NoError(t, err) + + // Make sure all incoming devices are complete + dg2.WaitForCompletion() + + // Check the data all got migrated correctly from dg to dg2. 
+ for _, s := range testDeviceSchema { + prov := dg.GetProviderByName(s.Name) + require.NotNil(t, prov) + destProvider := dg2.GetProviderByName(s.Name) + require.NotNil(t, destProvider) + eq, err := storage.Equals(prov, destProvider, 1024*1024) + assert.NoError(t, err) + assert.True(t, eq) + } + + // Cancel context + cancelfn() + + // Close protocol bits + prDone.Wait() + r1.Close() + w1.Close() + r2.Close() + w2.Close() +} diff --git a/pkg/storage/devicegroup/device_group_to.go b/pkg/storage/devicegroup/device_group_to.go new file mode 100644 index 00000000..3f834ac2 --- /dev/null +++ b/pkg/storage/devicegroup/device_group_to.go @@ -0,0 +1,448 @@ +package devicegroup + +import ( + "context" + "time" + + "github.com/loopholelabs/logging/types" + "github.com/loopholelabs/silo/pkg/storage" + "github.com/loopholelabs/silo/pkg/storage/blocks" + "github.com/loopholelabs/silo/pkg/storage/config" + "github.com/loopholelabs/silo/pkg/storage/device" + "github.com/loopholelabs/silo/pkg/storage/dirtytracker" + "github.com/loopholelabs/silo/pkg/storage/expose" + "github.com/loopholelabs/silo/pkg/storage/metrics" + "github.com/loopholelabs/silo/pkg/storage/migrator" + "github.com/loopholelabs/silo/pkg/storage/modules" + "github.com/loopholelabs/silo/pkg/storage/protocol" + "github.com/loopholelabs/silo/pkg/storage/protocol/packets" + "github.com/loopholelabs/silo/pkg/storage/volatilitymonitor" +) + +func NewFromSchema(ds []*config.DeviceSchema, log types.Logger, met metrics.SiloMetrics) (*DeviceGroup, error) { + dg := &DeviceGroup{ + log: log, + met: met, + devices: make([]*DeviceInformation, 0), + progress: make(map[string]*migrator.MigrationProgress), + } + + for _, s := range ds { + prov, exp, err := device.NewDeviceWithLoggingMetrics(s, log, met) + if err != nil { + if log != nil { + log.Error().Err(err).Str("schema", string(s.Encode())).Msg("could not create device") + } + // We should try to close / shutdown any successful devices we created here... + // But it's likely to be critical. + dg.CloseAll() + return nil, err + } + + blockSize := int(s.ByteBlockSize()) + if blockSize == 0 { + blockSize = defaultBlockSize + } + + local := modules.NewLockable(prov) + mlocal := modules.NewMetrics(local) + dirtyLocal, dirtyRemote := dirtytracker.NewDirtyTracker(mlocal, blockSize) + vmonitor := volatilitymonitor.NewVolatilityMonitor(dirtyLocal, blockSize, volatilityExpiry) + + totalBlocks := (int(local.Size()) + blockSize - 1) / blockSize + orderer := blocks.NewPriorityBlockOrder(totalBlocks, vmonitor) + orderer.AddAll() + + if exp != nil { + exp.SetProvider(vmonitor) + } + + // Add to metrics if given. + if met != nil { + met.AddMetrics(s.Name, mlocal) + if exp != nil { + met.AddNBD(s.Name, exp.(*expose.ExposedStorageNBDNL)) + } + met.AddDirtyTracker(s.Name, dirtyRemote) + met.AddVolatilityMonitor(s.Name, vmonitor) + } + + dg.devices = append(dg.devices, &DeviceInformation{ + Size: local.Size(), + BlockSize: uint64(blockSize), + NumBlocks: totalBlocks, + Schema: s, + Prov: prov, + Storage: local, + Exp: exp, + Volatility: vmonitor, + DirtyLocal: dirtyLocal, + DirtyRemote: dirtyRemote, + Orderer: orderer, + }) + + // Set these two at least, so we know *something* about every device in progress handler. 
+ dg.progress[s.Name] = &migrator.MigrationProgress{ + BlockSize: blockSize, + TotalBlocks: totalBlocks, + } + } + + if log != nil { + log.Debug().Int("devices", len(dg.devices)).Msg("created device group") + } + return dg, nil +} + +func (dg *DeviceGroup) StartMigrationTo(pro protocol.Protocol) error { + // We will use dev 0 to communicate + dg.controlProtocol = pro + + // First lets setup the ToProtocol + for index, d := range dg.devices { + d.To = protocol.NewToProtocol(d.Prov.Size(), uint32(index+1), pro) + d.To.SetCompression(true) + + if dg.met != nil { + dg.met.AddToProtocol(d.Schema.Name, d.To) + } + } + + // Now package devices up into a single DeviceGroupInfo + dgi := &packets.DeviceGroupInfo{ + Devices: make(map[int]*packets.DevInfo), + } + + for index, d := range dg.devices { + di := &packets.DevInfo{ + Size: d.Prov.Size(), + BlockSize: uint32(d.BlockSize), + Name: d.Schema.Name, + Schema: string(d.Schema.EncodeAsBlock()), + } + dgi.Devices[index+1] = di + } + + // Send the single DeviceGroupInfo packet down our control channel 0 + dgiData := packets.EncodeDeviceGroupInfo(dgi) + _, err := dg.controlProtocol.SendPacket(0, protocol.IDPickAny, dgiData, protocol.UrgencyUrgent) + + return err +} + +// This will Migrate all devices to the 'to' setup in SendDevInfo stage. +func (dg *DeviceGroup) MigrateAll(maxConcurrency int, progressHandler func(p map[string]*migrator.MigrationProgress)) error { + for _, d := range dg.devices { + if d.To == nil { + return errNotSetup + } + } + + ctime := time.Now() + + if dg.log != nil { + dg.log.Debug().Int("devices", len(dg.devices)).Msg("migrating device group") + } + + // Add up device sizes, so we can allocate the concurrency proportionally + totalSize := uint64(0) + for _, d := range dg.devices { + totalSize += d.Size + } + + // We need at least this much... + if maxConcurrency < len(dg.devices) { + maxConcurrency = len(dg.devices) + } + // We will allocate each device at least ONE... + maxConcurrency -= len(dg.devices) + + for index, d := range dg.devices { + concurrency := 1 + (uint64(maxConcurrency) * d.Size / totalSize) + d.migrationError = make(chan error, 1) // We will just hold onto the first error for now. + + setMigrationError := func(err error) { + if err != nil && err != context.Canceled { + select { + case d.migrationError <- err: + default: + } + } + } + + // Setup d.to + go func() { + err := d.To.HandleNeedAt(func(offset int64, length int32) { + if dg.log != nil { + dg.log.Debug(). + Int64("offset", offset). + Int32("length", length). + Int("dev", index). + Str("name", d.Schema.Name). + Msg("NeedAt for device") + } + // Prioritize blocks + endOffset := uint64(offset + int64(length)) + if endOffset > d.Size { + endOffset = d.Size + } + + startBlock := int(offset / int64(d.BlockSize)) + endBlock := int((endOffset-1)/d.BlockSize) + 1 + for b := startBlock; b < endBlock; b++ { + d.Orderer.PrioritiseBlock(b) + } + }) + setMigrationError(err) + }() + + go func() { + err := d.To.HandleDontNeedAt(func(offset int64, length int32) { + if dg.log != nil { + dg.log.Debug(). + Int64("offset", offset). + Int32("length", length). + Int("dev", index). + Str("name", d.Schema.Name). 
+ Msg("DontNeedAt for device") + } + // Deprioritize blocks + endOffset := uint64(offset + int64(length)) + if endOffset > d.Size { + endOffset = d.Size + } + + startBlock := int(offset / int64(d.BlockSize)) + endBlock := int((endOffset-1)/d.BlockSize) + 1 + for b := startBlock; b < endBlock; b++ { + d.Orderer.Remove(b) + } + }) + setMigrationError(err) + }() + + cfg := migrator.NewConfig() + cfg.Logger = dg.log + cfg.BlockSize = int(d.BlockSize) + cfg.Concurrency = map[int]int{ + storage.BlockTypeAny: int(concurrency), + } + cfg.LockerHandler = func() { + // setMigrationError(d.to.SendEvent(&packets.Event{Type: packets.EventPreLock})) + // d.Storage.Lock() + // setMigrationError(d.to.SendEvent(&packets.Event{Type: packets.EventPostLock})) + } + cfg.UnlockerHandler = func() { + // setMigrationError(d.to.SendEvent(&packets.Event{Type: packets.EventPreUnlock})) + // d.Storage.Unlock() + // setMigrationError(d.to.SendEvent(&packets.Event{Type: packets.EventPostUnlock})) + } + cfg.ErrorHandler = func(_ *storage.BlockInfo, err error) { + setMigrationError(err) + } + cfg.ProgressHandler = func(p *migrator.MigrationProgress) { + dg.progressLock.Lock() + dg.progress[d.Schema.Name] = p + if progressHandler != nil { + progressHandler(dg.progress) + } + dg.progressLock.Unlock() + } + mig, err := migrator.NewMigrator(d.DirtyRemote, d.To, d.Orderer, cfg) + if err != nil { + return err + } + d.Migrator = mig + if dg.met != nil { + dg.met.AddMigrator(d.Schema.Name, mig) + } + if dg.log != nil { + dg.log.Debug(). + Uint64("concurrency", concurrency). + Int("index", index). + Str("name", d.Schema.Name). + Msg("Setup migrator") + } + } + + errs := make(chan error, len(dg.devices)) + + // Now start them all migrating, and collect err + for _, d := range dg.devices { + go func() { + err := d.Migrator.Migrate(d.NumBlocks) + errs <- err + }() + } + + // Check for error from Migrate, and then Wait for completion of all devices... + for index := range dg.devices { + migErr := <-errs + if migErr != nil { + if dg.log != nil { + dg.log.Error().Err(migErr).Int("index", index).Msg("error migrating device group") + } + return migErr + } + } + + for index, d := range dg.devices { + err := d.Migrator.WaitForCompletion() + if err != nil { + if dg.log != nil { + dg.log.Error().Err(err).Int("index", index).Msg("error migrating device group waiting for completion") + } + return err + } + + // Check for any migration error + select { + case err := <-d.migrationError: + if dg.log != nil { + dg.log.Error().Err(err).Int("index", index).Msg("error migrating device group from goroutines") + } + return err + default: + } + } + + if dg.log != nil { + dg.log.Debug().Int64("duration", time.Since(ctime).Milliseconds()).Int("devices", len(dg.devices)).Msg("migration of device group completed") + } + + return nil +} + +type MigrateDirtyHooks struct { + PreGetDirty func(name string) error + PostGetDirty func(name string, blocks []uint) (bool, error) + PostMigrateDirty func(name string, blocks []uint) (bool, error) + Completed func(name string) +} + +func (dg *DeviceGroup) MigrateDirty(hooks *MigrateDirtyHooks) error { + // If StartMigrationTo or MigrateAll have not been called, return error. 
+	for _, d := range dg.devices {
+		if d.To == nil || d.Migrator == nil {
+			return errNotSetup
+		}
+	}
+
+	errs := make(chan error, len(dg.devices))
+
+	for index, d := range dg.devices {
+		// First unlock the storage if it is locked due to a previous MigrateDirty call
+		d.Storage.Unlock()
+
+		go func() {
+			for {
+				if hooks != nil && hooks.PreGetDirty != nil {
+					hooks.PreGetDirty(d.Schema.Name)
+				}
+
+				blocks := d.Migrator.GetLatestDirty()
+				if dg.log != nil {
+					dg.log.Debug().
+						Int("blocks", len(blocks)).
+						Int("index", index).
+						Str("name", d.Schema.Name).
+						Msg("migrating dirty blocks")
+				}
+
+				if hooks != nil && hooks.PostGetDirty != nil {
+					cont, err := hooks.PostGetDirty(d.Schema.Name, blocks)
+					if err != nil {
+						errs <- err
+					}
+					if !cont {
+						break
+					}
+				}
+
+				err := d.To.DirtyList(int(d.BlockSize), blocks)
+				if err != nil {
+					errs <- err
+					return
+				}
+
+				err = d.Migrator.MigrateDirty(blocks)
+				if err != nil {
+					errs <- err
+					return
+				}
+
+				if hooks != nil && hooks.PostMigrateDirty != nil {
+					cont, err := hooks.PostMigrateDirty(d.Schema.Name, blocks)
+					if err != nil {
+						errs <- err
+					}
+					if !cont {
+						break
+					}
+				}
+			}
+
+			err := d.Migrator.WaitForCompletion()
+			if err != nil {
+				errs <- err
+				return
+			}
+
+			if hooks != nil && hooks.Completed != nil {
+				hooks.Completed(d.Schema.Name)
+			}
+
+			errs <- nil
+		}()
+	}
+
+	// Wait for all dirty migrations to complete
+	// Check for any error and return it
+	for range dg.devices {
+		err := <-errs
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (dg *DeviceGroup) Completed() error {
+	for index, d := range dg.devices {
+		err := d.To.SendEvent(&packets.Event{Type: packets.EventCompleted})
+		if err != nil {
+			return err
+		}
+
+		if dg.log != nil {
+			dg.log.Debug().
+				Int("index", index).
+				Str("name", d.Schema.Name).
+				Msg("migration completed")
+		}
+	}
+	return nil
+}
+
+func (dg *DeviceGroup) SendCustomData(customData []byte) error {
+
+	// Send a single custom Event packet down our control channel 0
+	taData := packets.EncodeEvent(&packets.Event{
+		Type:          packets.EventCustom,
+		CustomType:    0,
+		CustomPayload: customData,
+	})
+	id, err := dg.controlProtocol.SendPacket(0, protocol.IDPickAny, taData, protocol.UrgencyUrgent)
+	if err != nil {
+		return err
+	}
+
+	// Wait for ack
+	ackData, err := dg.controlProtocol.WaitForPacket(0, id)
+	if err != nil {
+		return err
+	}
+
+	return packets.DecodeEventResponse(ackData)
+}
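Putting the sender-side API above together: StartMigrationTo announces every device in one DeviceGroupInfo packet on control channel 0, MigrateAll moves the bulk data and reports per-device progress, MigrateDirty keeps sending dirty blocks until a hook stops it, and Completed tells the receiver we are done. A minimal sketch follows; it assumes the DeviceGroup and the protocol.Protocol were already constructed (that code is outside this hunk), and the function name, concurrency value and progress handling are illustrative only, not part of this diff.

package example

import (
	"fmt"

	"github.com/loopholelabs/silo/pkg/storage/devicegroup"
	"github.com/loopholelabs/silo/pkg/storage/migrator"
	"github.com/loopholelabs/silo/pkg/storage/protocol"
)

// sendDeviceGroup is a sketch of the sending side of a device group migration.
func sendDeviceGroup(dg *devicegroup.DeviceGroup, pro protocol.Protocol) error {
	// Announce all devices in one DeviceGroupInfo packet on control channel 0.
	if err := dg.StartMigrationTo(pro); err != nil {
		return err
	}

	// Spread up to 100 workers across the devices in proportion to size.
	// For example, with devices of 8GiB, 1GiB and 1GiB, each device first gets
	// one worker and the remaining 97 are split by size: 78, 10 and 10.
	err := dg.MigrateAll(100, func(p map[string]*migrator.MigrationProgress) {
		for name, prog := range p {
			fmt.Printf("%s: %d blocks of %d bytes\n", name, prog.TotalBlocks, prog.BlockSize)
		}
	})
	if err != nil {
		return err
	}

	// Chase dirty blocks until a device has none left, then signal completion.
	err = dg.MigrateDirty(&devicegroup.MigrateDirtyHooks{
		PostGetDirty: func(name string, blocks []uint) (bool, error) {
			return len(blocks) > 0, nil // false stops the dirty loop for this device
		},
	})
	if err != nil {
		return err
	}
	return dg.Completed()
}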
diff --git a/pkg/storage/expose/nbd.go b/pkg/storage/expose/nbd.go
index a95f627c..5ecfabde 100644
--- a/pkg/storage/expose/nbd.go
+++ b/pkg/storage/expose/nbd.go
@@ -108,6 +108,7 @@ type ExposedStorageNBDNL struct {
 	provLock    sync.RWMutex
 	deviceIndex int
 	dispatchers []*Dispatch
+	handlersWg  sync.WaitGroup
 }
 
 func NewExposedStorageNBDNL(prov storage.Provider, conf *Config) *ExposedStorageNBDNL {
@@ -210,6 +211,7 @@ func (n *ExposedStorageNBDNL) Init() error {
 		d.asyncReads = n.config.AsyncReads
 		d.asyncWrites = n.config.AsyncWrites
 		// Start reading commands on the socket and dispatching them to our provider
+		n.handlersWg.Add(1)
 		go func() {
 			err := d.Handle()
 			if n.config.Logger != nil {
@@ -218,6 +220,7 @@ func (n *ExposedStorageNBDNL) Init() error {
 					Err(err).
 					Msg("nbd dispatch completed")
 			}
+			n.handlersWg.Done()
 		}()
 		n.socks = append(n.socks, serverc)
 		socks = append(socks, client)
@@ -293,7 +296,19 @@ func (n *ExposedStorageNBDNL) Shutdown() error {
 	// First cancel the context, which will stop waiting on pending readAt/writeAt...
 	n.cancelfn()
 
-	// Now wait for any pending responses to be sent
+	// Close all the socket pairs...
+	for _, v := range n.socks {
+		err := v.Close()
+		if err != nil {
+			return err
+		}
+	}
+
+	// Now wait until the handlers return
+	n.handlersWg.Wait()
+
+	// Now wait for any pending responses to be sent. There should not be any new
+	// requests received since we have disconnected.
 	for _, d := range n.dispatchers {
 		d.Wait()
 	}
@@ -304,14 +319,6 @@ func (n *ExposedStorageNBDNL) Shutdown() error {
 		return err
 	}
 
-	// Close all the socket pairs...
-	for _, v := range n.socks {
-		err = v.Close()
-		if err != nil {
-			return err
-		}
-	}
-
 	// Wait until it's completely disconnected...
 	for {
 		s, err := nbdnl.Status(uint32(n.deviceIndex))
diff --git a/pkg/storage/expose/nbd_dispatch.go b/pkg/storage/expose/nbd_dispatch.go
index dc5b59cd..8bb50c41 100644
--- a/pkg/storage/expose/nbd_dispatch.go
+++ b/pkg/storage/expose/nbd_dispatch.go
@@ -3,6 +3,7 @@ package expose
 import (
 	"context"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
 	"sync"
@@ -13,6 +14,8 @@ import (
 	"github.com/loopholelabs/silo/pkg/storage"
 )
 
+var ErrShuttingDown = errors.New("shutting down. Cannot serve any new requests")
+
 const dispatchBufferSize = 4 * 1024 * 1024
 
 /**
@@ -76,6 +79,8 @@ type Dispatch struct {
 	prov             storage.Provider
 	fatal            chan error
 	pendingResponses sync.WaitGroup
+	shuttingDown     bool
+	shuttingDownLock sync.Mutex
 	metricPacketsIn  uint64
 	metricPacketsOut uint64
 	metricReadAt     uint64
@@ -140,6 +145,11 @@ func (d *Dispatch) GetMetrics() *DispatchMetrics {
 }
 
 func (d *Dispatch) Wait() {
+	d.shuttingDownLock.Lock()
+	d.shuttingDown = true
+	defer d.shuttingDownLock.Unlock()
+	// Stop accepting any new requests...
+
 	if d.logger != nil {
 		d.logger.Trace().Str("device", d.dev).Msg("nbd waiting for pending responses")
 	}
@@ -342,16 +352,22 @@ func (d *Dispatch) cmdRead(cmdHandle uint64, cmdFrom uint64, cmdLength uint32) e
 		case e = <-errchan:
 		}
 
-		errorValue := uint32(0)
 		if e != nil {
-			errorValue = 1
-			data = make([]byte, 0) // If there was an error, don't send data
+			return d.writeResponse(1, handle, []byte{})
 		}
-		return d.writeResponse(errorValue, handle, data)
+		return d.writeResponse(0, handle, data)
 	}
 
-	if d.asyncReads {
+	d.shuttingDownLock.Lock()
+	if !d.shuttingDown {
 		d.pendingResponses.Add(1)
+	} else {
+		d.shuttingDownLock.Unlock()
+		return ErrShuttingDown
+	}
+	d.shuttingDownLock.Unlock()
+
+	if d.asyncReads {
 		go func() {
 			ctime := time.Now()
 			err := performRead(cmdHandle, cmdFrom, cmdLength)
@@ -368,7 +384,6 @@ func (d *Dispatch) cmdRead(cmdHandle uint64, cmdFrom uint64, cmdLength uint32) e
 			d.pendingResponses.Done()
 		}()
 	} else {
-		d.pendingResponses.Add(1)
 		ctime := time.Now()
 		err := performRead(cmdHandle, cmdFrom, cmdLength)
 		if err == nil {
@@ -418,8 +433,16 @@ func (d *Dispatch) cmdWrite(cmdHandle uint64, cmdFrom uint64, cmdLength uint32,
 		return d.writeResponse(errorValue, handle, []byte{})
 	}
 
-	if d.asyncWrites {
+	d.shuttingDownLock.Lock()
+	if !d.shuttingDown {
 		d.pendingResponses.Add(1)
+	} else {
+		d.shuttingDownLock.Unlock()
+		return ErrShuttingDown
+	}
+	d.shuttingDownLock.Unlock()
+
+	if d.asyncWrites {
 		go func() {
 			ctime := time.Now()
 			err := performWrite(cmdHandle, cmdFrom, cmdLength, cmdData)
@@ -436,7 +459,6 @@ func (d *Dispatch) cmdWrite(cmdHandle uint64, cmdFrom uint64, cmdLength uint32,
 			d.pendingResponses.Done()
 		}()
 	} else {
-		d.pendingResponses.Add(1)
 		ctime := time.Now()
 		err := performWrite(cmdHandle, cmdFrom, cmdLength, cmdData)
 		if err == nil {
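The shuttingDown flag added to Dispatch above closes a shutdown race: once Wait() has started draining, no new request is allowed to call pendingResponses.Add, and late requests are refused with ErrShuttingDown instead. A stripped-down sketch of the same guarded-WaitGroup idea follows; the names and types here are illustrative and not part of this diff (and it releases the mutex before waiting, which is equivalent for this purpose).

package example

import "sync"

// gate refuses new work once draining has begun, then waits for in-flight work.
type gate struct {
	mu       sync.Mutex
	draining bool
	inFlight sync.WaitGroup
}

// begin registers one unit of work; it returns false once drain has started.
func (g *gate) begin() bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.draining {
		return false
	}
	g.inFlight.Add(1)
	return true
}

// done marks one unit of work as finished.
func (g *gate) done() { g.inFlight.Done() }

// drain stops new work and blocks until everything in flight has finished.
func (g *gate) drain() {
	g.mu.Lock()
	g.draining = true
	g.mu.Unlock()
	g.inFlight.Wait()
}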
diff --git a/pkg/storage/protocol/from_protocol.go b/pkg/storage/protocol/from_protocol.go
index ee9988ad..0d1c2c2b 100644
--- a/pkg/storage/protocol/from_protocol.go
+++ b/pkg/storage/protocol/from_protocol.go
@@ -251,9 +251,13 @@ func (fp *FromProtocol) HandleDevInfo() error {
 	if err != nil {
 		return err
 	}
-
 	atomic.AddUint64(&fp.metricRecvDevInfo, 1)
 
+	return fp.SetDevInfo(di)
+}
+
+// Alternatively, you can call SetDevInfo to set up the DevInfo.
+func (fp *FromProtocol) SetDevInfo(di *packets.DevInfo) error {
 	// Create storage, and setup a writeCombinator with two inputs
 	fp.initLock.Lock()
 	fp.prov = fp.providerFactory(di)
diff --git a/pkg/storage/protocol/packets/device_group_info.go b/pkg/storage/protocol/packets/device_group_info.go
new file mode 100644
index 00000000..3d43aa89
--- /dev/null
+++ b/pkg/storage/protocol/packets/device_group_info.go
@@ -0,0 +1,58 @@
+package packets
+
+import (
+	"bytes"
+	"encoding/binary"
+)
+
+type DeviceGroupInfo struct {
+	Devices map[int]*DevInfo
+}
+
+func EncodeDeviceGroupInfo(dgi *DeviceGroupInfo) []byte {
+	var buffer bytes.Buffer
+	buffer.WriteByte(CommandDeviceGroupInfo)
+	diHeader := make([]byte, 8)
+
+	for index, di := range dgi.Devices {
+		diBytes := EncodeDevInfo(di)
+		binary.LittleEndian.PutUint32(diHeader, uint32(index))
+		binary.LittleEndian.PutUint32(diHeader[4:], uint32(len(diBytes)))
+		buffer.Write(diHeader)
+		buffer.Write(diBytes)
+	}
+	return buffer.Bytes()
+}
+
+func DecodeDeviceGroupInfo(buff []byte) (*DeviceGroupInfo, error) {
+	dgi := &DeviceGroupInfo{
+		Devices: make(map[int]*DevInfo),
+	}
+
+	if len(buff) < 1 || buff[0] != CommandDeviceGroupInfo {
+		return nil, ErrInvalidPacket
+	}
+
+	ptr := 1
+	for {
+		if ptr == len(buff) {
+			break
+		}
+		if len(buff)-ptr < 8 {
+			return nil, ErrInvalidPacket
+		}
+		index := binary.LittleEndian.Uint32(buff[ptr:])
+		length := binary.LittleEndian.Uint32(buff[ptr+4:])
+		ptr += 8
+		if len(buff)-ptr < int(length) {
+			return nil, ErrInvalidPacket
+		}
+		di, err := DecodeDevInfo(buff[ptr : ptr+int(length)])
+		if err != nil {
+			return nil, err
+		}
+		dgi.Devices[int(index)] = di
+		ptr += int(length)
+	}
+	return dgi, nil
+}
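For reference, the wire format implemented by EncodeDeviceGroupInfo above is a single command byte followed by one (index, length, payload) record per device:

    byte 0            CommandDeviceGroupInfo
    per device:
      4 bytes LE      device index (uint32)
      4 bytes LE      length of the encoded DevInfo (uint32)
      length bytes    the EncodeDevInfo payload

DecodeDeviceGroupInfo walks the buffer until it is exhausted, so records can appear in any order and indices need not be contiguous (the test below uses 0, 1 and 3).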
diff --git a/pkg/storage/protocol/packets/packet.go b/pkg/storage/protocol/packets/packet.go
index 4d40fa1e..0b61fa64 100644
--- a/pkg/storage/protocol/packets/packet.go
+++ b/pkg/storage/protocol/packets/packet.go
@@ -22,6 +22,7 @@ const (
 	CommandRemoveDev        = CommandRequest | byte(10)
 	CommandRemoveFromMap    = CommandRequest | byte(11)
 	CommandAlternateSources = CommandRequest | byte(12)
+	CommandDeviceGroupInfo  = CommandRequest | byte(13)
 )
 
 const (
diff --git a/pkg/storage/protocol/packets/packet_test.go b/pkg/storage/protocol/packets/packet_test.go
index f4332c80..9c420fb7 100644
--- a/pkg/storage/protocol/packets/packet_test.go
+++ b/pkg/storage/protocol/packets/packet_test.go
@@ -371,3 +371,29 @@ func TestAlternateSources(t *testing.T) {
 
 	assert.Equal(t, sources[0].Location, sources2[0].Location)
 }
+
+func TestDeviceGroupInfo(t *testing.T) {
+	dgi := &DeviceGroupInfo{
+		Devices: map[int]*DevInfo{
+			0: {Size: 100, BlockSize: 1, Name: "a-hello", Schema: "a-1234"},
+			1: {Size: 200, BlockSize: 2, Name: "b-hello", Schema: "b-1234"},
+			3: {Size: 300, BlockSize: 3, Name: "c-hello", Schema: "c-1234"},
+		},
+	}
+	b := EncodeDeviceGroupInfo(dgi)
+
+	dgi2, err := DecodeDeviceGroupInfo(b)
+	assert.NoError(t, err)
+
+	// Check that dgi and dgi2 are the same...
+	assert.Equal(t, len(dgi.Devices), len(dgi2.Devices))
+
+	for index, di := range dgi.Devices {
+		di2, ok := dgi2.Devices[index]
+		assert.True(t, ok)
+		assert.Equal(t, di.Size, di2.Size)
+		assert.Equal(t, di.BlockSize, di2.BlockSize)
+		assert.Equal(t, di.Name, di2.Name)
+		assert.Equal(t, di.Schema, di2.Schema)
+	}
+}
diff --git a/pkg/storage/protocol/protocol_rw.go b/pkg/storage/protocol/protocol_rw.go
index 65ac40e0..d6db8772 100644
--- a/pkg/storage/protocol/protocol_rw.go
+++ b/pkg/storage/protocol/protocol_rw.go
@@ -323,14 +323,24 @@ func (p *RW) WaitForPacket(dev uint32, id uint32) ([]byte, error) {
 }
 
 func (p *RW) WaitForCommand(dev uint32, cmd byte) (uint32, []byte, error) {
+	p.activeDevsLock.Lock()
 	p.waitersLock.Lock()
-	w := p.waiters[dev]
+	w, ok := p.waiters[dev]
+	if !ok {
+		p.activeDevs[dev] = true
+		w = Waiters{
+			byCmd: make(map[byte]chan packetinfo),
+			byID:  make(map[uint32]chan packetinfo),
+		}
+		p.waiters[dev] = w
+	}
 	wq, okk := w.byCmd[cmd]
 	if !okk {
 		wq = make(chan packetinfo, packetBufferSize)
 		w.byCmd[cmd] = wq
 	}
 	p.waitersLock.Unlock()
+	p.activeDevsLock.Unlock()
 
 	select {
 	case p := <-wq:
diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go
index 3a3673d4..942d88f1 100644
--- a/pkg/storage/storage.go
+++ b/pkg/storage/storage.go
@@ -65,6 +65,7 @@ type SyncStartConfig struct {
  */
 func Equals(sp1 Provider, sp2 Provider, blockSize int) (bool, error) {
 	if sp1.Size() != sp2.Size() {
+		fmt.Printf("Equals: Size differs (%d %d)\n", sp1.Size(), sp2.Size())
 		return false, nil
 	}
 
@@ -78,20 +79,23 @@ func Equals(sp1 Provider, sp2 Provider, blockSize int) (bool, error) {
 
 		n, err := sp1.ReadAt(sourceBuff, int64(i))
 		if err != nil {
+			fmt.Printf("Equals: sp1.ReadAt %v\n", err)
 			return false, err
 		}
 		sourceBuff = sourceBuff[:n]
 		n, err = sp2.ReadAt(destBuff, int64(i))
 		if err != nil {
+			fmt.Printf("Equals: sp2.ReadAt %v\n", err)
 			return false, err
 		}
 		destBuff = destBuff[:n]
 		if len(sourceBuff) != len(destBuff) {
+			fmt.Printf("Equals: data len sp1 sp2 %d %d\n", len(sourceBuff), len(destBuff))
 			return false, nil
 		}
 		for j := 0; j < n; j++ {
 			if sourceBuff[j] != destBuff[j] {
-				fmt.Printf("Equals: Block %d differs\n", i/blockSize)
+				fmt.Printf("Equals: Block %d differs [sp1 %d, sp2 %d]\n", i/blockSize, sourceBuff[j], destBuff[j])
 				return false, nil
 			}
 		}