From 216f9f87bbd48640558a72c1f8b12fec1bbdf3bc Mon Sep 17 00:00:00 2001 From: Dominik Richter Date: Fri, 20 Oct 2023 02:30:30 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20expand=20EnsureProvider=20to=20c?= =?UTF-8?q?over=20IDs=20(#2306)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactor the whole EnsureProvider flow so users can specify providers via ID, connector name and connector type. All of them are useful: 1. Install provider from Defaults => just specify the ID 2. Install provider from CLI args (like "local") => just specify name 3. Install provider from connection type => doh Signed-off-by: Dominik Richter --- cli/providers/providers.go | 2 +- cli/sysinfo/sysinfo.go | 2 +- providers/providers.go | 115 +++++++++++++++++++++++-------------- providers/runtime.go | 2 +- 4 files changed, 75 insertions(+), 46 deletions(-) diff --git a/cli/providers/providers.go b/cli/providers/providers.go index 77cdab2c36..bb77905beb 100644 --- a/cli/providers/providers.go +++ b/cli/providers/providers.go @@ -36,7 +36,7 @@ func AttachCLIs(rootCmd *cobra.Command, commands ...*Command) error { connectorName, autoUpdate := detectConnectorName(os.Args, rootCmd, commands, existing) if connectorName != "" { - if _, err := providers.EnsureProvider(connectorName, "", autoUpdate, existing); err != nil { + if _, err := providers.EnsureProvider(providers.ProviderLookup{ConnName: connectorName}, autoUpdate, existing); err != nil { return err } } diff --git a/cli/sysinfo/sysinfo.go b/cli/sysinfo/sysinfo.go index 966008fe02..9c7933ead0 100644 --- a/cli/sysinfo/sysinfo.go +++ b/cli/sysinfo/sysinfo.go @@ -52,7 +52,7 @@ func GatherSystemInfo(opts ...SystemInfoOption) (*SystemInfo, error) { cfg.runtime = providers.Coordinator.NewRuntime() // init runtime - if _, err := providers.EnsureProvider("local", "", true, nil); err != nil { + if _, err := providers.EnsureProvider(providers.ProviderLookup{ConnName: "local"}, true, nil); err != nil { return nil, err } if err := cfg.runtime.UseProvider(providers.DefaultOsID); err != nil { diff --git a/providers/providers.go b/providers/providers.go index 0a39d9419d..8596573eaa 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -51,8 +51,71 @@ func init() { LastProviderInstall = time.Now().Unix() } +type ProviderLookup struct { + ID string + ConnName string + ConnType string +} + +func (s ProviderLookup) String() string { + res := []string{} + if s.ID != "" { + res = append(res, "id="+s.ID) + } + if s.ConnName != "" { + res = append(res, "name="+s.ConnName) + } + if s.ConnType != "" { + res = append(res, "name="+s.ConnType) + } + return strings.Join(res, " ") +} + type Providers map[string]*Provider +// Lookup a provider in this list. If you search via ProviderID we will +// try to find the exact provider. Otherwise we will try to find a matching +// connector type first and name second. +func (p Providers) Lookup(search ProviderLookup) *Provider { + if search.ID != "" { + return p[search.ID] + } + + if search.ConnType != "" { + for _, provider := range p { + if slices.Contains(provider.ConnectionTypes, search.ConnType) { + return provider + } + for i := range provider.Connectors { + if slices.Contains(provider.Connectors[i].Aliases, search.ConnType) { + return provider + } + } + } + } + + if search.ConnName != "" { + for _, provider := range p { + for i := range provider.Connectors { + if provider.Connectors[i].Name == search.ConnName { + return provider + } + if slices.Contains(provider.Connectors[i].Aliases, search.ConnName) { + return provider + } + } + } + } + + return nil +} + +func (p Providers) Add(nu *Provider) { + if nu != nil { + p[nu.ID] = nu + } +} + type Provider struct { *plugin.Provider Schema *resources.Schema @@ -169,15 +232,16 @@ func ListAll() ([]*Provider, error) { // EnsureProvider makes sure that a given provider exists and returns it. // You can supply providers either via: -// 1. connectorName, which is what you see in the CLI e.g. "local", "ssh", ... -// 2. connectorType, which is how assets define the connector type when +// 1. providerID, which universally identifies it, e.g. "go.mondoo.com/cnquery/v9/providers/os" +// 2. connectorName, which is what you see in the CLI e.g. "local", "ssh", ... +// 3. connectorType, which is how assets define the connector type when // they are moved between discovery and execution, e.g. "registry-image". // // If you disable autoUpdate, it will neither update NOR install missing providers. // // If you don't supply existing providers, it will look for alist of all // active providers first. -func EnsureProvider(connectorName string, connectorType string, autoUpdate bool, existing Providers) (*Provider, error) { +func EnsureProvider(search ProviderLookup, autoUpdate bool, existing Providers) (*Provider, error) { if existing == nil { var err error existing, err = ListActive() @@ -186,17 +250,17 @@ func EnsureProvider(connectorName string, connectorType string, autoUpdate bool, } } - provider := existing.ForConnection(connectorName, connectorType) + provider := existing.Lookup(search) if provider != nil { return provider, nil } - if connectorName == "mock" || connectorType == "mock" { + if search.ID == mockProvider.ID || search.ConnName == "mock" || search.ConnType == "mock" { existing.Add(&mockProvider) return &mockProvider, nil } - upstream := DefaultProviders.ForConnection(connectorName, connectorType) + upstream := DefaultProviders.Lookup(search) if upstream == nil { // we can't find any provider for this connector in our default set // FIXME: This causes a panic in the CLI, we should handle this better @@ -204,13 +268,14 @@ func EnsureProvider(connectorName string, connectorType string, autoUpdate bool, } if !autoUpdate { - return nil, errors.New("cannot find installed provider for connection " + connectorName) + return nil, errors.New("cannot find installed provider for " + search.String()) } nu, err := Install(upstream.Name, "") if err != nil { return nil, err } + existing.Add(nu) PrintInstallResults([]*Provider{nu}) return nu, nil @@ -626,42 +691,6 @@ func (p *Provider) binPath() string { return filepath.Join(p.Path, name) } -func (p Providers) ForConnection(name string, typ string) *Provider { - if name != "" { - for _, provider := range p { - for i := range provider.Connectors { - if provider.Connectors[i].Name == name { - return provider - } - if slices.Contains(provider.Connectors[i].Aliases, name) { - return provider - } - } - } - } - - if typ != "" { - for _, provider := range p { - if slices.Contains(provider.ConnectionTypes, typ) { - return provider - } - for i := range provider.Connectors { - if slices.Contains(provider.Connectors[i].Aliases, typ) { - return provider - } - } - } - } - - return nil -} - -func (p Providers) Add(nu *Provider) { - if nu != nil { - p[nu.ID] = nu - } -} - func MustLoadSchema(name string, data []byte) *resources.Schema { var res resources.Schema if err := json.Unmarshal(data, &res); err != nil { diff --git a/providers/runtime.go b/providers/runtime.go index fa481784af..6cc70d9a44 100644 --- a/providers/runtime.go +++ b/providers/runtime.go @@ -179,7 +179,7 @@ func (r *Runtime) DetectProvider(asset *inventory.Asset) error { conn.Type = inventory.ConnBackendToType(conn.Backend) } - provider, err := EnsureProvider("", conn.Type, true, r.coordinator.Providers) + provider, err := EnsureProvider(ProviderLookup{ConnType: conn.Type}, true, r.coordinator.Providers) if err != nil { errs.Add(err) continue