Skip to content

Commit

Permalink
make provider reconnectable
Browse files Browse the repository at this point in the history
  • Loading branch information
jaym committed Aug 26, 2024
1 parent ff88599 commit aa0e94b
Show file tree
Hide file tree
Showing 6 changed files with 436 additions and 49 deletions.
79 changes: 39 additions & 40 deletions providers/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"os/exec"
"strconv"
"sync"
"time"

"github.com/cockroachdb/errors"
"github.com/hashicorp/go-hclog"
Expand Down Expand Up @@ -310,51 +309,51 @@ func (c *coordinator) unsafeStartProvider(id string, update UpdateProvidersConfi
}
}

pluginCmd := exec.Command(provider.binPath(), []string{"run_as_plugin", "--log-level", zerolog.GlobalLevel().String()}...)

addColorConfig(pluginCmd)

pluginLogger := &hclogger{Logger: log.Logger}
pluginLogger.SetLevel(hclog.Warn)
client := plugin.NewClient(&plugin.ClientConfig{
HandshakeConfig: pp.Handshake,
Plugins: pp.PluginMap,
Cmd: pluginCmd,
AllowedProtocols: []plugin.Protocol{
plugin.ProtocolNetRPC, plugin.ProtocolGRPC,
},
Logger: pluginLogger,
Stderr: os.Stderr,
})

// Connect via RPC
rpcClient, err := client.Client()
if err != nil {
client.Kill()
return nil, errors.Wrap(err, "failed to initialize plugin client")
}
connectFunc := func() (pp.ProviderPlugin, *plugin.Client, error) {
pluginCmd := exec.Command(provider.binPath(), []string{"run_as_plugin", "--log-level", zerolog.GlobalLevel().String()}...)

addColorConfig(pluginCmd)

pluginLogger := &hclogger{Logger: log.Logger}
pluginLogger.SetLevel(hclog.Warn)
client := plugin.NewClient(&plugin.ClientConfig{
HandshakeConfig: pp.Handshake,
Plugins: pp.PluginMap,
Cmd: pluginCmd,
AllowedProtocols: []plugin.Protocol{
plugin.ProtocolNetRPC, plugin.ProtocolGRPC,
},
Logger: pluginLogger,
Stderr: os.Stderr,
})

// Connect via RPC
rpcClient, err := client.Client()
if err != nil {
client.Kill()
return nil, nil, errors.Wrap(err, "failed to initialize plugin client")
}

// Request the plugin
pluginName := "provider"
raw, err := rpcClient.Dispense(pluginName)
if err != nil {
client.Kill()
return nil, errors.Wrap(err, "failed to call "+pluginName+" plugin")
// Request the plugin
pluginName := "provider"
raw, err := rpcClient.Dispense(pluginName)
if err != nil {
client.Kill()
return nil, nil, errors.Wrap(err, "failed to call "+pluginName+" plugin")
}

return raw.(pp.ProviderPlugin), client, nil
}

res := &RunningProvider{
Name: provider.Name,
ID: provider.ID,
Plugin: raw.(pp.ProviderPlugin),
Client: client,
Schema: provider.Schema,
interval: 2 * time.Second,
gracePeriod: 3 * time.Second,
plug, client, err := connectFunc()
if err != nil {
return nil, err
}

c.schema.Add(provider.ID, provider.Schema)

if err := res.heartbeat(); err != nil {
res, err := SupervisedRunningProivder(provider.Name, provider.ID, plug, client, provider.Schema, connectFunc)
if err != nil {
return nil, err
}
c.runningByID[res.ID] = res
Expand Down Expand Up @@ -388,7 +387,7 @@ func (c *coordinator) Shutdown() {
if err := provider.Shutdown(); err != nil {
log.Warn().Err(err).Str("provider", provider.Name).Msg("failed to shut down provider")
}
provider.Client.Kill()
provider.KillClient()
}
c.runningByID = map[string]*RunningProvider{}
c.runtimes = map[string]*Runtime{}
Expand Down
2 changes: 0 additions & 2 deletions providers/coordinator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"testing"

"github.com/hashicorp/go-plugin"
"github.com/stretchr/testify/assert"
"go.mondoo.com/cnquery/v11/providers-sdk/v1/inventory"
pp "go.mondoo.com/cnquery/v11/providers-sdk/v1/plugin"
Expand All @@ -32,7 +31,6 @@ func TestShutdown(t *testing.T) {
c.runningByID[id] = &RunningProvider{
ID: id,
Plugin: mockPlugin,
Client: &plugin.Client{},
}
}

Expand Down
5 changes: 4 additions & 1 deletion providers/providers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package providers

import (
"context"
"syscall"
"testing"
"time"
Expand Down Expand Up @@ -50,7 +51,9 @@ func TestProviderShutdown(t *testing.T) {
interval: 500 * time.Millisecond,
gracePeriod: 500 * time.Millisecond,
}
err := s.heartbeat()
hbtCtx, hbtCancel := context.WithCancel(context.Background())
s.hbCancelFunc = hbtCancel
err := s.heartbeat(hbtCtx, hbtCancel)
require.NoError(t, err)
require.False(t, s.isCloseOrShutdown())
// the shutdown here takes 10 seconds, whereas the heartbeat interval is every second.
Expand Down
Loading

0 comments on commit aa0e94b

Please sign in to comment.