From 57f46bba8e1e30640821c53d28b4fd7178d07185 Mon Sep 17 00:00:00 2001 From: Jay Mundrawala Date: Thu, 22 Aug 2024 16:57:52 -0500 Subject: [PATCH 1/4] make provider reconnectable --- providers/coordinator.go | 79 ++++---- providers/coordinator_test.go | 2 - providers/providers_test.go | 5 +- providers/running_provider.go | 297 ++++++++++++++++++++++++++++- providers/running_provider_test.go | 79 ++++++++ providers/runtime.go | 23 +++ 6 files changed, 436 insertions(+), 49 deletions(-) create mode 100644 providers/running_provider_test.go diff --git a/providers/coordinator.go b/providers/coordinator.go index 03aef69fa2..7080120c67 100644 --- a/providers/coordinator.go +++ b/providers/coordinator.go @@ -9,7 +9,6 @@ import ( "os/exec" "strconv" "sync" - "time" "github.com/cockroachdb/errors" "github.com/hashicorp/go-hclog" @@ -310,51 +309,51 @@ func (c *coordinator) unsafeStartProvider(id string, update UpdateProvidersConfi } } - pluginCmd := exec.Command(provider.binPath(), []string{"run_as_plugin", "--log-level", zerolog.GlobalLevel().String()}...) - - addColorConfig(pluginCmd) - - pluginLogger := &hclogger{Logger: log.Logger} - pluginLogger.SetLevel(hclog.Warn) - client := plugin.NewClient(&plugin.ClientConfig{ - HandshakeConfig: pp.Handshake, - Plugins: pp.PluginMap, - Cmd: pluginCmd, - AllowedProtocols: []plugin.Protocol{ - plugin.ProtocolNetRPC, plugin.ProtocolGRPC, - }, - Logger: pluginLogger, - Stderr: os.Stderr, - }) - - // Connect via RPC - rpcClient, err := client.Client() - if err != nil { - client.Kill() - return nil, errors.Wrap(err, "failed to initialize plugin client") - } + connectFunc := func() (pp.ProviderPlugin, *plugin.Client, error) { + pluginCmd := exec.Command(provider.binPath(), []string{"run_as_plugin", "--log-level", zerolog.GlobalLevel().String()}...) + + addColorConfig(pluginCmd) + + pluginLogger := &hclogger{Logger: log.Logger} + pluginLogger.SetLevel(hclog.Warn) + client := plugin.NewClient(&plugin.ClientConfig{ + HandshakeConfig: pp.Handshake, + Plugins: pp.PluginMap, + Cmd: pluginCmd, + AllowedProtocols: []plugin.Protocol{ + plugin.ProtocolNetRPC, plugin.ProtocolGRPC, + }, + Logger: pluginLogger, + Stderr: os.Stderr, + }) + + // Connect via RPC + rpcClient, err := client.Client() + if err != nil { + client.Kill() + return nil, nil, errors.Wrap(err, "failed to initialize plugin client") + } - // Request the plugin - pluginName := "provider" - raw, err := rpcClient.Dispense(pluginName) - if err != nil { - client.Kill() - return nil, errors.Wrap(err, "failed to call "+pluginName+" plugin") + // Request the plugin + pluginName := "provider" + raw, err := rpcClient.Dispense(pluginName) + if err != nil { + client.Kill() + return nil, nil, errors.Wrap(err, "failed to call "+pluginName+" plugin") + } + + return raw.(pp.ProviderPlugin), client, nil } - res := &RunningProvider{ - Name: provider.Name, - ID: provider.ID, - Plugin: raw.(pp.ProviderPlugin), - Client: client, - Schema: provider.Schema, - interval: 2 * time.Second, - gracePeriod: 3 * time.Second, + plug, client, err := connectFunc() + if err != nil { + return nil, err } c.schema.Add(provider.ID, provider.Schema) - if err := res.heartbeat(); err != nil { + res, err := SupervisedRunningProivder(provider.Name, provider.ID, plug, client, provider.Schema, connectFunc) + if err != nil { return nil, err } c.runningByID[res.ID] = res @@ -388,7 +387,7 @@ func (c *coordinator) Shutdown() { if err := provider.Shutdown(); err != nil { log.Warn().Err(err).Str("provider", provider.Name).Msg("failed to shut down provider") } - provider.Client.Kill() + provider.KillClient() } c.runningByID = map[string]*RunningProvider{} c.runtimes = map[string]*Runtime{} diff --git a/providers/coordinator_test.go b/providers/coordinator_test.go index 8229997eab..1cc83150f2 100644 --- a/providers/coordinator_test.go +++ b/providers/coordinator_test.go @@ -7,7 +7,6 @@ import ( "fmt" "testing" - "github.com/hashicorp/go-plugin" "github.com/stretchr/testify/assert" "go.mondoo.com/cnquery/v11/providers-sdk/v1/inventory" pp "go.mondoo.com/cnquery/v11/providers-sdk/v1/plugin" @@ -32,7 +31,6 @@ func TestShutdown(t *testing.T) { c.runningByID[id] = &RunningProvider{ ID: id, Plugin: mockPlugin, - Client: &plugin.Client{}, } } diff --git a/providers/providers_test.go b/providers/providers_test.go index 00da53ad8d..08c3db68c7 100644 --- a/providers/providers_test.go +++ b/providers/providers_test.go @@ -4,6 +4,7 @@ package providers import ( + "context" "syscall" "testing" "time" @@ -50,7 +51,9 @@ func TestProviderShutdown(t *testing.T) { interval: 500 * time.Millisecond, gracePeriod: 500 * time.Millisecond, } - err := s.heartbeat() + hbtCtx, hbtCancel := context.WithCancel(context.Background()) + s.hbCancelFunc = hbtCancel + err := s.heartbeat(hbtCtx, hbtCancel) require.NoError(t, err) require.False(t, s.isCloseOrShutdown()) // the shutdown here takes 10 seconds, whereas the heartbeat interval is every second. diff --git a/providers/running_provider.go b/providers/running_provider.go index 94c26577ed..c52258ffa5 100644 --- a/providers/running_provider.go +++ b/providers/running_provider.go @@ -4,22 +4,231 @@ package providers import ( + "context" + "errors" + "fmt" "sync" "time" - "github.com/cockroachdb/errors" "github.com/hashicorp/go-plugin" "github.com/rs/zerolog/log" pp "go.mondoo.com/cnquery/v11/providers-sdk/v1/plugin" "go.mondoo.com/cnquery/v11/providers-sdk/v1/resources" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" ) +type connectionGraphNode struct { + connected bool + data connectReq +} + +type connectionGraph struct { + nodes map[uint32]connectionGraphNode + edges map[uint32]uint32 +} + +func newConnectionGraph() *connectionGraph { + return &connectionGraph{ + nodes: map[uint32]connectionGraphNode{}, + edges: map[uint32]uint32{}, + } +} + +func (c *connectionGraph) addNode(node uint32, data connectReq) { + c.nodes[node] = connectionGraphNode{ + connected: true, + data: data, + } +} + +func (c *connectionGraph) getNode(node uint32) (connectReq, bool) { + n, ok := c.nodes[node] + if !ok { + return connectReq{}, false + } + return n.data, ok +} + +func (c *connectionGraph) setEdge(from, to uint32) { + c.edges[from] = to +} + +func (c *connectionGraph) markDisconnected(id uint32) { + if node, ok := c.nodes[id]; ok { + node.connected = false + c.nodes[id] = node + } +} + +// topoSort returns a topological sorted list of the nodes in the graph. +func (c *connectionGraph) topoSort() []uint32 { + var sorted []uint32 + var visit func(node uint32, visited map[uint32]bool, sorted *[]uint32) + visit = func(node uint32, visited map[uint32]bool, sorted *[]uint32) { + if visited[node] { + return + } + visited[node] = true + if connected, ok := c.edges[node]; ok { + if connected != 0 { + visit(connected, visited, sorted) + } + } + *sorted = append(*sorted, node) + } + visited := map[uint32]bool{} + for nodeId, node := range c.nodes { + if !node.connected { + continue + } + visit(nodeId, visited, &sorted) + } + return sorted +} + +func (c *connectionGraph) garbageCollect() { + sorted := c.topoSort() + + keep := map[uint32]struct{}{} + for _, node := range sorted { + keep[node] = struct{}{} + } + + for node := range c.nodes { + if _, ok := keep[node]; !ok { + delete(c.nodes, node) + delete(c.edges, node) + } + } +} + +type ReconnectFunc func() (pp.ProviderPlugin, *plugin.Client, error) +type connectReq struct { + req *pp.ConnectReq + cb pp.ProviderCallback +} + +const maxRestartCount = 3 + +type RestartableProvider struct { + plugin pp.ProviderPlugin + client *plugin.Client + connectionGraph *connectionGraph + reconnectFunc ReconnectFunc + restartCount int + lock sync.Mutex +} + +func (r *RestartableProvider) Client() *plugin.Client { + r.lock.Lock() + defer r.lock.Unlock() + return r.client +} + +// Connect implements plugin.ProviderPlugin. +func (r *RestartableProvider) Connect(req *pp.ConnectReq, cb pp.ProviderCallback) (*pp.ConnectRes, error) { + if len(req.Asset.GetConnections()) > 0 { + reqClone := proto.Clone(req).(*pp.ConnectReq) + r.lock.Lock() + connectionId := req.Asset.Connections[0].Id + if _, ok := r.connectionGraph.getNode(connectionId); !ok { + r.connectionGraph.addNode(connectionId, connectReq{ + req: reqClone, + cb: cb, + }) + r.connectionGraph.setEdge(connectionId, req.Asset.Connections[0].ParentConnectionId) + } + + r.lock.Unlock() + } + + resp, err := r.plugin.Connect(req, cb) + if err != nil { + return nil, err + } + + return resp, nil +} + +func (r *RestartableProvider) Reconnect() error { + r.lock.Lock() + defer r.lock.Unlock() + + if r.restartCount >= maxRestartCount { + return errors.New("reached maximum provider restart count") + } + r.restartCount++ + + p, c, err := r.reconnectFunc() + if err != nil { + return fmt.Errorf("failed to reconnect: %w", err) + } + r.plugin = p + r.client = c + + connectRequestOrder := r.connectionGraph.topoSort() + + for _, connect := range connectRequestOrder { + cr, ok := r.connectionGraph.getNode(connect) + if !ok { + continue + } + + if _, err := r.plugin.Connect(cr.req, cr.cb); err != nil { + return fmt.Errorf("failed to reconnect connection %d: %w", connect, err) + } + } + + return nil +} + +// Disconnect implements plugin.ProviderPlugin. +func (r *RestartableProvider) Disconnect(req *pp.DisconnectReq) (*pp.DisconnectRes, error) { + r.lock.Lock() + r.connectionGraph.markDisconnected(req.Connection) + r.connectionGraph.garbageCollect() + r.lock.Unlock() + + return r.plugin.Disconnect(req) +} + +// GetData implements plugin.ProviderPlugin. +func (r *RestartableProvider) GetData(req *pp.DataReq) (*pp.DataRes, error) { + return r.plugin.GetData(req) +} + +// Heartbeat implements plugin.ProviderPlugin. +func (r *RestartableProvider) Heartbeat(req *pp.HeartbeatReq) (*pp.HeartbeatRes, error) { + return r.plugin.Heartbeat(req) +} + +// MockConnect implements plugin.ProviderPlugin. +func (r *RestartableProvider) MockConnect(req *pp.ConnectReq, callback pp.ProviderCallback) (*pp.ConnectRes, error) { + return r.plugin.MockConnect(req, callback) +} + +// ParseCLI implements plugin.ProviderPlugin. +func (r *RestartableProvider) ParseCLI(req *pp.ParseCLIReq) (*pp.ParseCLIRes, error) { + return r.plugin.ParseCLI(req) +} + +// Shutdown implements plugin.ProviderPlugin. +func (r *RestartableProvider) Shutdown(req *pp.ShutdownReq) (*pp.ShutdownRes, error) { + return r.plugin.Shutdown(req) +} + +// StoreData implements plugin.ProviderPlugin. +func (r *RestartableProvider) StoreData(req *pp.StoreReq) (*pp.StoreRes, error) { + return r.plugin.StoreData(req) +} + +var _ pp.ProviderPlugin = &RestartableProvider{} + type RunningProvider struct { Name string ID string Plugin pp.ProviderPlugin - Client *plugin.Client Schema resources.ResourcesSchema // isClosed is true for any provider that is not running anymore, @@ -33,10 +242,37 @@ type RunningProvider struct { shutdownLock sync.Mutex interval time.Duration gracePeriod time.Duration + hbCancelFunc context.CancelFunc +} + +func SupervisedRunningProivder(name string, id string, plugin pp.ProviderPlugin, client *plugin.Client, schema resources.ResourcesSchema, reconnectFunc ReconnectFunc) (*RunningProvider, error) { + hbCtx, hbCancelFunc := context.WithCancel(context.Background()) + + rp := &RunningProvider{ + Name: name, + ID: id, + Schema: schema, + isClosed: false, + Plugin: &RestartableProvider{ + plugin: plugin, + client: client, + connectionGraph: newConnectionGraph(), + reconnectFunc: reconnectFunc, + }, + hbCancelFunc: hbCancelFunc, + interval: 2 * time.Second, + gracePeriod: 3 * time.Second, + } + + if err := rp.heartbeat(hbCtx, hbCancelFunc); err != nil { + return nil, err + } + + return rp, nil } // initialize the heartbeat with the provider -func (p *RunningProvider) heartbeat() error { +func (p *RunningProvider) heartbeat(ctx context.Context, cancelFunc context.CancelFunc) error { if err := p.doOneHeartbeat(p.interval + p.gracePeriod); err != nil { log.Error().Err(err).Str("plugin", p.Name).Msg("error in plugin heartbeat") if err := p.Shutdown(); err != nil { @@ -46,6 +282,8 @@ func (p *RunningProvider) heartbeat() error { } go func() { + ticker := time.NewTicker(p.interval) + defer ticker.Stop() for !p.isCloseOrShutdown() { if err := p.doOneHeartbeat(p.interval + p.gracePeriod); err != nil { log.Error().Err(err).Str("plugin", p.Name).Msg("error in plugin heartbeat") @@ -55,7 +293,13 @@ func (p *RunningProvider) heartbeat() error { break } - time.Sleep(p.interval) + select { + case <-ctx.Done(): + cancelFunc() + return + case <-ticker.C: + + } } }() @@ -83,6 +327,35 @@ func (p *RunningProvider) isCloseOrShutdown() bool { return p.isClosed || p.isShutdown } +func (p *RunningProvider) Reconnect() error { + p.lock.Lock() + defer p.lock.Unlock() + p.shutdownLock.Lock() + defer p.shutdownLock.Unlock() + if !(p.isClosed || p.isShutdown) { + return nil + } + + // we can only restart if it is a restartable provider + if rp, ok := p.Plugin.(*RestartableProvider); ok { + log.Warn().Str("plugin", p.Name).Msg("reconnecting provider") + if err := rp.Reconnect(); err != nil { + log.Error().Err(err).Str("plugin", p.Name).Msg("error in plugin reconnect") + return err + } + p.isClosed = false + p.isShutdown = false + hbCtx, hbCancelFunc := context.WithCancel(context.Background()) + if p.hbCancelFunc != nil { + p.hbCancelFunc() + } + p.hbCancelFunc = hbCancelFunc + return p.heartbeat(hbCtx, hbCancelFunc) + } + + return errors.New("provider is not restartable") +} + func (p *RunningProvider) Shutdown() error { p.lock.Lock() defer p.lock.Unlock() @@ -107,8 +380,11 @@ func (p *RunningProvider) Shutdown() error { // If the plugin was not in active use, we may not have a client at this // point. Since all of this is run within a sync-lock, we can check the // client and if it exists use it to send the kill signal. - if p.Client != nil { - p.Client.Kill() + if rp, ok := p.Plugin.(*RestartableProvider); ok { + c := rp.Client() + if c != nil { + c.Kill() + } } p.shutdownLock.Lock() p.isClosed = true @@ -122,3 +398,12 @@ func (p *RunningProvider) Shutdown() error { return err } + +func (p *RunningProvider) KillClient() { + if rp, ok := p.Plugin.(*RestartableProvider); ok { + c := rp.Client() + if c != nil { + c.Kill() + } + } +} diff --git a/providers/running_provider_test.go b/providers/running_provider_test.go new file mode 100644 index 0000000000..0043e0fe24 --- /dev/null +++ b/providers/running_provider_test.go @@ -0,0 +1,79 @@ +package providers + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConnectionGraph(t *testing.T) { + + g := newConnectionGraph() + g.addNode(1, connectReq{}) + g.addNode(2, connectReq{}) + g.addNode(3, connectReq{}) + g.addNode(4, connectReq{}) + + g.setEdge(4, 2) + g.setEdge(2, 1) + g.setEdge(3, 1) + + sorted := g.topoSort() + + require.Len(t, sorted, 4) + + requireComesBefore := func(t *testing.T, sorted []uint32, before, after uint32) { + beforeIdx := -1 + afterIdx := -1 + for i, n := range sorted { + if n == before { + beforeIdx = i + } + if n == after { + afterIdx = i + } + } + require.True(t, beforeIdx >= 0, "before node not found") + require.True(t, afterIdx >= 0, "after node not found") + require.True(t, beforeIdx < afterIdx, "before node does not come before after node") + } + + requireComesBefore(t, sorted, 2, 4) + requireComesBefore(t, sorted, 1, 2) + requireComesBefore(t, sorted, 1, 3) + + g.markDisconnected(1) + g.garbageCollect() + + sorted = g.topoSort() + require.Len(t, sorted, 4) + requireComesBefore(t, sorted, 2, 4) + requireComesBefore(t, sorted, 1, 2) + requireComesBefore(t, sorted, 1, 3) + + g.markDisconnected(2) + g.garbageCollect() + + sorted = g.topoSort() + require.Len(t, sorted, 4) + requireComesBefore(t, sorted, 2, 4) + requireComesBefore(t, sorted, 1, 2) + requireComesBefore(t, sorted, 1, 3) + + g.markDisconnected(4) + g.garbageCollect() + + sorted = g.topoSort() + require.Len(t, sorted, 2) + requireComesBefore(t, sorted, 1, 3) + require.NotContains(t, g.nodes, uint32(2)) + require.NotContains(t, g.nodes, uint32(4)) + + g.markDisconnected(3) + g.garbageCollect() + + sorted = g.topoSort() + require.Len(t, sorted, 0) + require.Empty(t, g.nodes) + require.Empty(t, g.edges) +} diff --git a/providers/runtime.go b/providers/runtime.go index 3e4c7ca3d7..840f09fa25 100644 --- a/providers/runtime.go +++ b/providers/runtime.go @@ -179,6 +179,29 @@ func (r *Runtime) providerForAsset(asset *inventory.Asset) (*Provider, error) { return nil, multierr.Wrap(errs.Deduplicate(), "cannot find provider for this asset") } +func (r *Runtime) EnsureProvidersConnected() error { + if r.Provider == nil { + return errors.New("cannot reconnect, no provider set") + } + + if r.Provider.Connection == nil { + return errors.New("cannot reconnect, no connection set") + } + + err := r.Provider.Instance.Reconnect() + if err != nil { + return err + } + + for _, p := range r.providers { + if err := p.Instance.Reconnect(); err != nil { + return err + } + } + + return nil +} + // Connect to an asset using the main provider func (r *Runtime) Connect(req *plugin.ConnectReq) error { if r.Provider == nil { From 457eb4a1932ddeee483400bf7ded62c436986dcf Mon Sep 17 00:00:00 2001 From: Jay Mundrawala Date: Mon, 26 Aug 2024 12:31:55 -0500 Subject: [PATCH 2/4] disconnect garbage collected connections --- providers/running_provider.go | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/providers/running_provider.go b/providers/running_provider.go index c52258ffa5..0b4d8033bd 100644 --- a/providers/running_provider.go +++ b/providers/running_provider.go @@ -87,7 +87,7 @@ func (c *connectionGraph) topoSort() []uint32 { return sorted } -func (c *connectionGraph) garbageCollect() { +func (c *connectionGraph) garbageCollect() []uint32 { sorted := c.topoSort() keep := map[uint32]struct{}{} @@ -95,12 +95,16 @@ func (c *connectionGraph) garbageCollect() { keep[node] = struct{}{} } + collected := []uint32{} for node := range c.nodes { if _, ok := keep[node]; !ok { + collected = append(collected, node) delete(c.nodes, node) delete(c.edges, node) } } + + return collected } type ReconnectFunc func() (pp.ProviderPlugin, *plugin.Client, error) @@ -187,10 +191,24 @@ func (r *RestartableProvider) Reconnect() error { func (r *RestartableProvider) Disconnect(req *pp.DisconnectReq) (*pp.DisconnectRes, error) { r.lock.Lock() r.connectionGraph.markDisconnected(req.Connection) - r.connectionGraph.garbageCollect() + collected := r.connectionGraph.garbageCollect() r.lock.Unlock() - return r.plugin.Disconnect(req) + resp, err := r.plugin.Disconnect(req) + + for _, c := range collected { + if c == req.Connection { + continue + } + _, err := r.plugin.Disconnect(&pp.DisconnectReq{ + Connection: c, + }) + if err != nil { + log.Warn().Err(err).Uint32("connection", c).Msg("failed to disconnect garbage collected connection") + } + } + + return resp, err } // GetData implements plugin.ProviderPlugin. From 991a3ffd010a94b919adf261751cd9fc72ca691d Mon Sep 17 00:00:00 2001 From: Jay Mundrawala Date: Mon, 26 Aug 2024 12:36:10 -0500 Subject: [PATCH 3/4] add missing header --- providers/running_provider_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/providers/running_provider_test.go b/providers/running_provider_test.go index 0043e0fe24..4461ba0787 100644 --- a/providers/running_provider_test.go +++ b/providers/running_provider_test.go @@ -1,3 +1,6 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + package providers import ( From 29fa5a861a784dc8f503b3f6caf87d5592484918 Mon Sep 17 00:00:00 2001 From: Jay Mundrawala Date: Tue, 27 Aug 2024 09:48:10 -0500 Subject: [PATCH 4/4] add comments for connectionGraph --- providers/running_provider.go | 44 +++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/providers/running_provider.go b/providers/running_provider.go index 0b4d8033bd..7c0453fb25 100644 --- a/providers/running_provider.go +++ b/providers/running_provider.go @@ -18,13 +18,36 @@ import ( "google.golang.org/protobuf/proto" ) +// connectionGraphNode is a node in the connection graph. It represents a connection. type connectionGraphNode struct { - connected bool - data connectReq + // explicitlyConnected is true if the connection was explicitly connected + // it is set to false when explicitly disconnected. + // When reconnecting, disconnected connections are not set to explicitly connected, + // even if we require the connection to connect another connection. + explicitlyConnected bool + // data is the connect request data for the connection + data connectReq } +// connectionGraph is a directed graph of connections. +// Each node represents a connection. It can have one edge to its parent connection. +// +// When a connection is first connected, addNode is called to add the connection to the graph +// and keep track of the connect request data. This is also when setEdge is called to set the +// edge to the parent connection. +// +// When a connection is disconnected, markDisconnected is called to mark the connection as disconnected. +// When a connection is marked as disconnected, it indicates that the connection is not explicitly required. +// It is still possible that the connection needs to be reconnected if another connection has it set as its +// parent. +// This is also when garbageCollect is called to remove connections from the graph is they are not explicitly +// connected and are not required by any other connection. type connectionGraph struct { + // nodes is a map of connection id to connectionGraphNode. We store data to + // reestablish the connection when reconnecting. We also store if the connection + // has been disconnected. nodes map[uint32]connectionGraphNode + // edges is a map of connection id to parent connection id edges map[uint32]uint32 } @@ -35,13 +58,15 @@ func newConnectionGraph() *connectionGraph { } } +// addNode adds a node to the graph with the given data. func (c *connectionGraph) addNode(node uint32, data connectReq) { c.nodes[node] = connectionGraphNode{ - connected: true, - data: data, + explicitlyConnected: true, + data: data, } } +// getNode returns the connect request data for the given node. func (c *connectionGraph) getNode(node uint32) (connectReq, bool) { n, ok := c.nodes[node] if !ok { @@ -50,18 +75,22 @@ func (c *connectionGraph) getNode(node uint32) (connectReq, bool) { return n.data, ok } +// setEdge sets the edge from the from node to the to node. +// from is the child node and to is the parent node. func (c *connectionGraph) setEdge(from, to uint32) { c.edges[from] = to } +// markDisconnected marks the connection as disconnected. It may still be needed by other connections. func (c *connectionGraph) markDisconnected(id uint32) { if node, ok := c.nodes[id]; ok { - node.connected = false + node.explicitlyConnected = false c.nodes[id] = node } } -// topoSort returns a topological sorted list of the nodes in the graph. +// topoSort returns a topological sorted list of the nodes in the graph. Connecting in this order +// will ensure that all connections are connected in the correct order. func (c *connectionGraph) topoSort() []uint32 { var sorted []uint32 var visit func(node uint32, visited map[uint32]bool, sorted *[]uint32) @@ -79,7 +108,7 @@ func (c *connectionGraph) topoSort() []uint32 { } visited := map[uint32]bool{} for nodeId, node := range c.nodes { - if !node.connected { + if !node.explicitlyConnected { continue } visit(nodeId, visited, &sorted) @@ -87,6 +116,7 @@ func (c *connectionGraph) topoSort() []uint32 { return sorted } +// garbageCollect removes nodes from the graph that are not explicitly connected and are not required by any other connection. func (c *connectionGraph) garbageCollect() []uint32 { sorted := c.topoSort()