From 617a495ee0ca56d7e03f85c14bbc9fcf52d175df Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Fri, 30 Aug 2024 13:01:04 -0700 Subject: [PATCH 01/45] implemented toml connection config --- connection_configuration.go | 499 +++++++++++++++++++++++++++++++ connection_configuration_test.go | 177 +++++++++++ errors.go | 6 + go.mod | 1 + go.sum | 2 + 5 files changed, 685 insertions(+) create mode 100644 connection_configuration.go create mode 100644 connection_configuration_test.go diff --git a/connection_configuration.go b/connection_configuration.go new file mode 100644 index 000000000..5ce71d914 --- /dev/null +++ b/connection_configuration.go @@ -0,0 +1,499 @@ +// Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + +package gosnowflake + +import ( + "encoding/base64" + "os" + "strconv" + "strings" + "time" + + path "path/filepath" + + toml "github.com/BurntSushi/toml" +) + +func LoadConnectionConfig() (*Config, error) { + cfg := &Config{ + Params: make(map[string]*string), + Authenticator: AuthTypeSnowflake, // Default to snowflake + } + var dsn string = getConnectionDSN(os.Getenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME")) + snowflakeConfigDir, err := getTomlFilePath(os.Getenv("SNOWFLAKE_HOME")) + if err != nil { + return nil, err + } + tomlFilePath := path.Join(snowflakeConfigDir, "connections.toml") + err = validateFilePermission(tomlFilePath) + if err != nil { + return nil, err + } + var tomlInfo = make(map[string]interface{}) + + _, err = toml.DecodeFile(tomlFilePath, &tomlInfo) + if err != nil { + return nil, err + } + connectionName, exist := tomlInfo[dsn] + if !exist { + err = &SnowflakeError{ + Number: ErrCodeFailedToFindDSNInToml, + Message: errMsgFailedToFindDSNInTomlFile, + } + return nil, err + } + + connectionConfig, ok := connectionName.(map[string]interface{}) + if !ok { + return nil, err + } + + err = parseToml(cfg, connectionConfig) + if err != nil { + return nil, err + } + + return cfg, err +} + +func parseToml(cfg *Config, connection map[string]interface{}) error { + var ok, vv bool + var err error = &SnowflakeError{ + Number: ErrCodeTomlFileParsingFailed, + Message: errMsgFailedToParseTomlFile, + MessageArgs: []interface{}{cfg.Host}, + } + var v, tokenPath string + for key, value := range connection { + switch strings.ToLower(key) { + case "user", "username": + cfg.User, ok = value.(string) + if !ok { + // //errorinterface + return err + } + case "password": + cfg.Password, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "host": + cfg.Host, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "account": + cfg.Account, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "warehouse": + cfg.Warehouse, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "database": + cfg.Database, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "schema": + cfg.Schema, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "role": + cfg.Role, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "region": + cfg.Region, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "protocol": + cfg.Protocol, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "passcode": + cfg.Passcode, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "port": + cfg.Port, err = parseInt(value) + if err != nil { + //errorinterface + return err + } + case "passcodeInPassword": + cfg.PasscodeInPassword, err = parseBool(value) + if err != nil { + //errorinterface + return err + } + case "clientTimeout": + cfg.ClientTimeout, err = parseDuration(value) + if err != nil { + //errorinterface + return err + } + case "jwtClientTimeout": + cfg.JWTClientTimeout, err = parseDuration(value) + if err != nil { + //errorinterface + return err + } + case "loginTimeout": + cfg.LoginTimeout, err = parseDuration(value) + if err != nil { + //errorinterface + return err + } + case "requestTimeout": + cfg.RequestTimeout, err = parseDuration(value) + if err != nil { + //errorinterface + return err + } + case "jwtTimeout": + cfg.JWTExpireTimeout, err = parseDuration(value) + if err != nil { + //errorinterface + return err + } + case "externalBrowserTimeout": + cfg.ExternalBrowserTimeout, err = parseDuration(value) + if err != nil { + //errorinterface + return err + } + case "maxRetryCount": + cfg.MaxRetryCount, err = parseInt(value) + if err != nil { + //errorinterface + return err + } + case "application": + cfg.Application, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "authenticator": + v, ok = value.(string) + err = determineAuthenticatorType(cfg, v) + if err != nil { + //errorinterface + return err + } + case "insecureMode": + cfg.InsecureMode, err = parseBool(value) + if err != nil { + //errorinterface + return err + } + case "ocspFailOpen": + vv, err = parseBool(value) + if err != nil { + //errorinterface + return err + } + if vv { + cfg.OCSPFailOpen = OCSPFailOpenTrue + } else { + cfg.OCSPFailOpen = OCSPFailOpenFalse + } + + case "token": + cfg.Token, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "privateKey": + v, ok = value.(string) + if !ok { + //errorinterface + return err + } + var decodeErr error + block, decodeErr := base64.URLEncoding.DecodeString(v) + if decodeErr != nil { + err = &SnowflakeError{ + Number: ErrCodePrivateKeyParseError, + Message: "Base64 decode failed", + } + return err + } + cfg.PrivateKey, err = parsePKCS8PrivateKey(block) + if err != nil { + //errorinterface + return err + } + case "validateDefaultParameters": + vv, err = parseBool(value) + if err != nil { + //errorinterface + return err + } + if vv { + cfg.ValidateDefaultParameters = ConfigBoolTrue + } else { + cfg.ValidateDefaultParameters = ConfigBoolFalse + } + case "clientRequestMfaToken": + vv, err = parseBool(value) + if err != nil { + //errorinterface + return err + } + if vv { + cfg.ClientRequestMfaToken = ConfigBoolTrue + } else { + cfg.ClientRequestMfaToken = ConfigBoolFalse + } + case "clientStoreTemporaryCredential": + vv, err = parseBool(value) + if err != nil { + //errorinterface + return err + } + if vv { + cfg.ClientStoreTemporaryCredential = ConfigBoolTrue + } else { + cfg.ClientStoreTemporaryCredential = ConfigBoolFalse + } + case "tracing": + cfg.Tracing, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "tmpDirPath": + cfg.TmpDirPath, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "disableQueryContextCache": + vv, err = parseBool(value) + if err != nil { + //errorinterface + return err + } + cfg.DisableQueryContextCache = vv + case "includeRetryReason": + vv, err = parseBool(value) + if err != nil { + //errorinterface + return err + } + if vv { + cfg.IncludeRetryReason = ConfigBoolTrue + } else { + cfg.IncludeRetryReason = ConfigBoolFalse + } + case "clientConfigFile": + cfg.ClientConfigFile, ok = value.(string) + if !ok { + //errorinterface + return err + } + case "disableConsoleLogin": + vv, err = parseBool(value) + if err != nil { + //errorinterface + return err + } + if vv { + cfg.DisableConsoleLogin = ConfigBoolTrue + } else { + cfg.DisableConsoleLogin = ConfigBoolFalse + } + case "disableSamlURLCheck": + vv, err = parseBool(value) + if err != nil { + //errorinterface + return err + } + if vv { + cfg.DisableSamlURLCheck = ConfigBoolTrue + } else { + cfg.DisableSamlURLCheck = ConfigBoolFalse + } + case "token_file_path": + tokenPath, ok = value.(string) + if !ok { + //errorinterface + return err + } + default: + var param string + param, ok = value.(string) + if !ok { + //errorinterface + return err + } + cfg.Params[urlDecodeIfNeeded(key)] = ¶m + } + } + if shouldReadTokenFromFile(cfg) { + v, err := readToken(tokenPath) + if err != nil { + return err + } + cfg.Token = v + } + return nil +} + +func parseInt(i interface{}) (int, error) { + var v string + var ok bool + var num int + var err, parseErr error + parseErr = &SnowflakeError{ + Number: ErrCodeTomlFileParsingFailed, + Message: errMsgFailedToParseTomlFile, + MessageArgs: []interface{}{i}, + } + if v, ok = i.(string); !ok { + if num, ok = i.(int); !ok { + return 0, parseErr + } else { + return num, nil + } + } else { + num, err = strconv.Atoi(v) + if err != nil { + return 0, parseErr + } + return num, nil + } +} + +func parseBool(i interface{}) (bool, error) { + var v string + var ok, vv bool + var err, parseErr error + parseErr = &SnowflakeError{ + Number: ErrCodeTomlFileParsingFailed, + Message: errMsgFailedToParseTomlFile, + MessageArgs: []interface{}{i}, + } + if v, ok = i.(string); !ok { + if vv, ok = i.(bool); !ok { + return false, parseErr + } else { + return vv, nil + } + } else { + vv, err = strconv.ParseBool(v) + if err != nil { + return false, parseErr + } + return vv, nil + } +} + +func parseDuration(i interface{}) (time.Duration, error) { + var v string + var ok bool + var num int + var t int64 + var err, parseErr error + parseErr = &SnowflakeError{ + Number: ErrCodeTomlFileParsingFailed, + Message: errMsgFailedToParseTomlFile, + MessageArgs: []interface{}{i}, + } + if v, ok = i.(string); !ok { + if num, err = parseInt(i); err != nil { + return time.Duration(0), parseErr + } else { + t = int64(num) + return time.Duration(t * int64(time.Second)), nil + } + } else { + t, err = strconv.ParseInt(v, 10, 64) + if err != nil { + return time.Duration(0), parseErr + } + return time.Duration(t * int64(time.Second)), nil + } +} + +func readToken(tokenPath string) (string, error) { + if !path.IsAbs(tokenPath) { + snowflakeConfigDir, err := getTomlFilePath(os.Getenv("SNOWFLAKE_HOME")) + if err != nil { + return "", err + } + tokenPath = path.Join(snowflakeConfigDir, tokenPath) + } + err := validateFilePermission(tokenPath) + if err != nil { + return "", err + } + token, err := os.ReadFile(tokenPath) + if err != nil { + return "", err + } + return string(token), nil +} + +func getTomlFilePath(filePath string) (string, error) { + var dir string + if len(filePath) != 0 { + dir = filePath + } else { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err + } + dir = path.Join(homeDir, "snowflake") + } + absDir, err := path.Abs(dir) + if err != nil { + return "", err + } + return absDir, nil +} + +func getConnectionDSN(dsn string) string { + if len(dsn) != 0 { + return dsn + } else { + return "default" + } +} + +func validateFilePermission(filePath string) error { + fileInfo, err := os.Stat(filePath) + if err != nil { + return err + } + permission := fileInfo.Mode().Perm() + if permission != 0o600 { + return err + } + return nil +} + +func shouldReadTokenFromFile(cfg *Config) bool { + return cfg != nil && cfg.Authenticator == AuthTypeOAuth && len(cfg.Token) == 0 +} diff --git a/connection_configuration_test.go b/connection_configuration_test.go new file mode 100644 index 000000000..1ab0ed1a2 --- /dev/null +++ b/connection_configuration_test.go @@ -0,0 +1,177 @@ +package gosnowflake + +import ( + "io/fs" + "os" + "testing" + "time" +) + +func TestLoadConnectionConfig_Default(t *testing.T) { + os.Setenv("SNOWFLAKE_HOME", "./") + + cfg, err := LoadConnectionConfig() + + if err != nil { + t.Fatalf("err: %v", err) + } + + assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") + assertEqualF(t, cfg.User, "test_user") + assertEqualF(t, cfg.Password, "test_pass") + assertEqualF(t, cfg.Warehouse, "testw") + assertEqualF(t, cfg.Database, "test_db") + assertEqualF(t, cfg.Schema, "test_go") + assertEqualF(t, cfg.Protocol, "https") + assertEqualF(t, cfg.Port, 443) +} + +func TestLoadConnectionConfig_OAuth(t *testing.T) { + os.Setenv("SNOWFLAKE_HOME", "./") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth") + cfg, err := LoadConnectionConfig() + + if err != nil { + t.Fatalf("err: %v", err) + } + + assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") + assertEqualF(t, cfg.User, "test_user") + assertEqualF(t, cfg.Password, "test_pass") + assertEqualF(t, cfg.Warehouse, "testw") + assertEqualF(t, cfg.Database, "test_db") + assertEqualF(t, cfg.Schema, "test_go") + assertEqualF(t, cfg.Protocol, "https") + assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) + assertEqualF(t, cfg.Token, "token_value") + assertEqualF(t, cfg.Port, 443) +} + +func TestLoadConnectionConfigWitNonExisitngDSN(t *testing.T) { + os.Setenv("SNOWFLAKE_HOME", "./") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "unavailableDSN") + + _, err := LoadConnectionConfig() + + if err == nil { + t.Fatal("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrCodeFailedToFindDSNInToml { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeFailedToFindDSNInToml, driverErr.Number) + } +} + +func TestLoadConnectionConfigWithTokenFileNotExist(t *testing.T) { + os.Setenv("SNOWFLAKE_HOME", "./") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth-file") + + _, err := LoadConnectionConfig() + + _, ok := err.(*(fs.PathError)) + if !ok { + t.Fatalf("should be io/fs error. err: %v", err) + } +} + +func TestParseInt(t *testing.T) { + var i interface{} + var num int + var err error + + i = 20 + num, err = parseInt(i) + if err != nil { + t.Fatalf("should be parsed: %v", err) + } + assertEqualF(t, num, 20) + + i = "40" + num, err = parseInt(i) + if err != nil { + t.Fatalf("should be parsed: %v", err) + } + assertEqualF(t, num, 40) + + i = "wrong_num" + _, err = parseInt(i) + if err == nil { + t.Fatal("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrCodeTomlFileParsingFailed { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeTomlFileParsingFailed, driverErr.Number) + } +} + +func TestParseBool(t *testing.T) { + var i interface{} + var b bool + var err error + + i = true + b, err = parseBool(i) + if err != nil { + t.Fatalf("should be parsed: %v", err) + } + assertEqualF(t, b, true) + + i = "false" + b, err = parseBool(i) + if err != nil { + t.Fatalf("should be parsed: %v", err) + } + assertEqualF(t, b, false) + + i = "wrong_bool" + _, err = parseInt(i) + if err == nil { + t.Fatal("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrCodeTomlFileParsingFailed { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeTomlFileParsingFailed, driverErr.Number) + } +} + +func TestParseDuration(t *testing.T) { + var i interface{} + var dur time.Duration + var err error + + i = 300 + dur, err = parseDuration(i) + if err != nil { + t.Fatalf("should be parsed: %v", err) + } + assertEqualF(t, dur, time.Duration(5*int64(time.Minute))) + + i = "30" + dur, err = parseDuration(i) + if err != nil { + t.Fatalf("should be parsed: %v", err) + } + assertEqualF(t, dur, time.Duration(int64(time.Minute)/2)) + + i = false + _, err = parseDuration(i) + if err == nil { + t.Fatal("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrCodeTomlFileParsingFailed { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeTomlFileParsingFailed, driverErr.Number) + } +} diff --git a/errors.go b/errors.go index 250af2e4f..3c2bf10ad 100644 --- a/errors.go +++ b/errors.go @@ -127,6 +127,10 @@ const ( ErrCodeFailedToParseAuthenticator = 260011 // ErrCodeClientConfigFailed is an error code for the case where clientConfigFile is invalid or applying client configuration fails ErrCodeClientConfigFailed = 260012 + // ErrCodeTomlFileParsing is an error code for the case when parsing the toml file is failed because of invalid value. + ErrCodeTomlFileParsingFailed = 260013 + // ErrCodeTomlFileParsing is an error code for the case when parsing the toml file is failed because of invalid value. + ErrCodeFailedToFindDSNInToml = 260013 /* network */ @@ -299,6 +303,8 @@ const ( errMsgClientConfigFailed = "client configuration failed: %v" errMsgNullValueInArray = "for handling null values in arrays use WithArrayValuesNullable(ctx)" errMsgNullValueInMap = "for handling null values in maps use WithMapValuesNullable(ctx)" + errMsgFailedToParseTomlFile = "failed to parse toml file. the params occurred error: %v" + errMsgFailedToFindDSNInTomlFile = "failed to find DSN in toml file." ) // Returned if a DNS doesn't include account parameter. diff --git a/go.mod b/go.mod index 7749971f2..ad2738313 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( require ( github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect + github.com/BurntSushi/toml v1.4.0 // indirect github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect diff --git a/go.sum b/go.sum index e5b3ef444..6e99023fa 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLC github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0 h1:u/LLAOFgsMv7HmNL4Qufg58y+qElGOt5qv0z1mURkRY= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0/go.mod h1:2e8rMJtl2+2j+HXbTBwnyGpm5Nou7KhvSfxOq8JpTag= github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1 h1:BWe8a+f/t+7KY7zH2mqygeUD0t8hNFXe08p1Pb3/jKE= +github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= +github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= github.com/apache/arrow/go/v15 v15.0.0 h1:1zZACWf85oEZY5/kd9dsQS7i+2G5zVQcbKTHgslqHNA= From 93fed7d19233d2d126d3c8445a8d9856e77b8109 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Fri, 30 Aug 2024 13:01:59 -0700 Subject: [PATCH 02/45] added testing toml file --- connections.toml | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 connections.toml diff --git a/connections.toml b/connections.toml new file mode 100644 index 000000000..05e71c547 --- /dev/null +++ b/connections.toml @@ -0,0 +1,35 @@ +[default] +account = 'snowdriverswarsaw.us-west-2.aws' +user = 'test_user' +password = 'test_pass' +warehouse = 'testw' +database = 'test_db' +schema = 'test_go' +protocol = 'https' +port = '443' + +[aws-oauth] +account = 'snowdriverswarsaw.us-west-2.aws' +user = 'test_user' +password = 'test_pass' +warehouse = 'testw' +database = 'test_db' +schema = 'test_go' +protocol = 'https' +port = '443' +authenticator = 'oauth' +testNot = 'problematicParameter' +token = 'token_value' + +[aws-oauth-file] +account = 'snowdriverswarsaw.us-west-2.aws' +user = 'test_user' +password = 'test_pass' +warehouse = 'testw' +database = 'test_db' +schema = 'test_go' +protocol = 'https' +port = '443' +authenticator = 'oauth' +testNot = 'problematicParameter' +token_file_path = '/Users/test/.snowflake/token' \ No newline at end of file From cc025d3195fa080ef12de339e64a41aa4ed5a455 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Fri, 30 Aug 2024 13:13:20 -0700 Subject: [PATCH 03/45] fix lint --- connection_configuration.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/connection_configuration.go b/connection_configuration.go index 5ce71d914..e9adfba03 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -195,6 +195,10 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } case "authenticator": v, ok = value.(string) + if !ok { + //errorinterface + return err + } err = determineAuthenticatorType(cfg, v) if err != nil { //errorinterface From 873f7b7722c4a9c104551aa16e7d6716bfd0d5a9 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Fri, 30 Aug 2024 19:35:20 -0700 Subject: [PATCH 04/45] add comments --- connection_configuration.go | 17 ++++++++--------- errors.go | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index e9adfba03..014c4b652 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -14,6 +14,9 @@ import ( toml "github.com/BurntSushi/toml" ) +// LoadConnectionConfig returns connection configs loaded from the toml file. +// By default, SNOWFLAKE_HOME(toml file path) is os.home/snowflake +// and SNOWFLAKE_DEFAULT_CONNECTION_NAME(DSN) is 'default' func LoadConnectionConfig() (*Config, error) { cfg := &Config{ Params: make(map[string]*string), @@ -378,9 +381,8 @@ func parseInt(i interface{}) (int, error) { if v, ok = i.(string); !ok { if num, ok = i.(int); !ok { return 0, parseErr - } else { - return num, nil } + return num, nil } else { num, err = strconv.Atoi(v) if err != nil { @@ -402,9 +404,8 @@ func parseBool(i interface{}) (bool, error) { if v, ok = i.(string); !ok { if vv, ok = i.(bool); !ok { return false, parseErr - } else { - return vv, nil } + return vv, nil } else { vv, err = strconv.ParseBool(v) if err != nil { @@ -428,10 +429,9 @@ func parseDuration(i interface{}) (time.Duration, error) { if v, ok = i.(string); !ok { if num, err = parseInt(i); err != nil { return time.Duration(0), parseErr - } else { - t = int64(num) - return time.Duration(t * int64(time.Second)), nil } + t = int64(num) + return time.Duration(t * int64(time.Second)), nil } else { t, err = strconv.ParseInt(v, 10, 64) if err != nil { @@ -481,9 +481,8 @@ func getTomlFilePath(filePath string) (string, error) { func getConnectionDSN(dsn string) string { if len(dsn) != 0 { return dsn - } else { - return "default" } + return "default" } func validateFilePermission(filePath string) error { diff --git a/errors.go b/errors.go index 3c2bf10ad..8d5daba49 100644 --- a/errors.go +++ b/errors.go @@ -127,9 +127,9 @@ const ( ErrCodeFailedToParseAuthenticator = 260011 // ErrCodeClientConfigFailed is an error code for the case where clientConfigFile is invalid or applying client configuration fails ErrCodeClientConfigFailed = 260012 - // ErrCodeTomlFileParsing is an error code for the case when parsing the toml file is failed because of invalid value. + // ErrCodeTomlFileParsingFailed is an error code for the case where parsing the toml file is failed because of invalid value. ErrCodeTomlFileParsingFailed = 260013 - // ErrCodeTomlFileParsing is an error code for the case when parsing the toml file is failed because of invalid value. + // ErrCodeFailedToFindDSNInToml is an error code for the case where the DSN does not exist in the toml file. ErrCodeFailedToFindDSNInToml = 260013 /* network */ From 0575096f88ecc7b0e5777cb8db3345fe76a4c60b Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Fri, 30 Aug 2024 19:41:56 -0700 Subject: [PATCH 05/45] fix lint --- connection_configuration.go | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 014c4b652..96bbc8b07 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -383,13 +383,12 @@ func parseInt(i interface{}) (int, error) { return 0, parseErr } return num, nil - } else { - num, err = strconv.Atoi(v) - if err != nil { - return 0, parseErr - } - return num, nil } + num, err = strconv.Atoi(v) + if err != nil { + return 0, parseErr + } + return num, nil } func parseBool(i interface{}) (bool, error) { @@ -406,13 +405,12 @@ func parseBool(i interface{}) (bool, error) { return false, parseErr } return vv, nil - } else { - vv, err = strconv.ParseBool(v) - if err != nil { - return false, parseErr - } - return vv, nil } + vv, err = strconv.ParseBool(v) + if err != nil { + return false, parseErr + } + return vv, nil } func parseDuration(i interface{}) (time.Duration, error) { @@ -432,13 +430,12 @@ func parseDuration(i interface{}) (time.Duration, error) { } t = int64(num) return time.Duration(t * int64(time.Second)), nil - } else { - t, err = strconv.ParseInt(v, 10, 64) - if err != nil { - return time.Duration(0), parseErr - } - return time.Duration(t * int64(time.Second)), nil } + t, err = strconv.ParseInt(v, 10, 64) + if err != nil { + return time.Duration(0), parseErr + } + return time.Duration(t * int64(time.Second)), nil } func readToken(tokenPath string) (string, error) { From be15d2eb91416afb3ca1cb0541a01e5973311e66 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Wed, 4 Sep 2024 19:04:08 -0700 Subject: [PATCH 06/45] add more testing cases --- connection_configuration.go | 304 +++++++++++++++---------------- connection_configuration_test.go | 87 +++++++++ errors.go | 2 +- 3 files changed, 232 insertions(+), 161 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 96bbc8b07..7558753de 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -4,13 +4,13 @@ package gosnowflake import ( "encoding/base64" + "errors" "os" + path "path/filepath" "strconv" "strings" "time" - path "path/filepath" - toml "github.com/BurntSushi/toml" ) @@ -46,7 +46,6 @@ func LoadConnectionConfig() (*Config, error) { } return nil, err } - connectionConfig, ok := connectionName.(map[string]interface{}) if !ok { return nil, err @@ -56,167 +55,159 @@ func LoadConnectionConfig() (*Config, error) { if err != nil { return nil, err } - return cfg, err } func parseToml(cfg *Config, connection map[string]interface{}) error { var ok, vv bool - var err error = &SnowflakeError{ - Number: ErrCodeTomlFileParsingFailed, - Message: errMsgFailedToParseTomlFile, - MessageArgs: []interface{}{cfg.Host}, + var err, parsingErr error + err = &SnowflakeError{ + Number: ErrCodeTomlFileParsingFailed, + Message: errMsgFailedToParseTomlFile, } var v, tokenPath string for key, value := range connection { switch strings.ToLower(key) { case "user", "username": - cfg.User, ok = value.(string) - if !ok { - // //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + cfg.User = value.(string) case "password": - cfg.Password, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + cfg.Password = value.(string) case "host": - cfg.Host, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + cfg.Host = value.(string) case "account": - cfg.Account, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + cfg.Account = value.(string) case "warehouse": - cfg.Warehouse, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + cfg.Warehouse = value.(string) case "database": - cfg.Database, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + cfg.Database = value.(string) case "schema": - cfg.Schema, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + cfg.Schema = value.(string) case "role": - cfg.Role, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + cfg.Role = value.(string) case "region": - cfg.Region, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + cfg.Region = value.(string) case "protocol": - cfg.Protocol, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + cfg.Protocol = value.(string) case "passcode": - cfg.Passcode, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + cfg.Passcode = value.(string) case "port": - cfg.Port, err = parseInt(value) - if err != nil { - //errorinterface + + if cfg.Port, parsingErr = parseInt(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "passcodeInPassword": - cfg.PasscodeInPassword, err = parseBool(value) - if err != nil { - //errorinterface + case "passcodeinpassword": + + if cfg.PasscodeInPassword, parsingErr = parseBool(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "clientTimeout": - cfg.ClientTimeout, err = parseDuration(value) - if err != nil { - //errorinterface + case "clienttimeout": + if cfg.ClientTimeout, parsingErr = parseDuration(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "jwtClientTimeout": - cfg.JWTClientTimeout, err = parseDuration(value) - if err != nil { - //errorinterface + case "jwtclienttimeout": + if cfg.JWTClientTimeout, parsingErr = parseDuration(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "loginTimeout": - cfg.LoginTimeout, err = parseDuration(value) - if err != nil { - //errorinterface + case "logintimeout": + + if cfg.LoginTimeout, parsingErr = parseDuration(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "requestTimeout": - cfg.RequestTimeout, err = parseDuration(value) - if err != nil { - //errorinterface + case "requesttimeout": + if cfg.RequestTimeout, parsingErr = parseDuration(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "jwtTimeout": - cfg.JWTExpireTimeout, err = parseDuration(value) - if err != nil { - //errorinterface + case "jwttimeout": + if cfg.JWTExpireTimeout, parsingErr = parseDuration(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "externalBrowserTimeout": - cfg.ExternalBrowserTimeout, err = parseDuration(value) - if err != nil { - //errorinterface + case "externalbrowsertimeout": + if cfg.ExternalBrowserTimeout, parsingErr = parseDuration(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "maxRetryCount": - cfg.MaxRetryCount, err = parseInt(value) - if err != nil { - //errorinterface + case "maxretrycount": + + if cfg.MaxRetryCount, parsingErr = parseInt(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } case "application": - cfg.Application, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + cfg.Application = value.(string) case "authenticator": - v, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + v = value.(string) err = determineAuthenticatorType(cfg, v) if err != nil { - //errorinterface + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "insecureMode": - cfg.InsecureMode, err = parseBool(value) - if err != nil { - //errorinterface + case "insecuremode": + if cfg.InsecureMode, parsingErr = parseBool(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "ocspFailOpen": - vv, err = parseBool(value) - if err != nil { - //errorinterface + case "ocspfailopen": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } if vv { @@ -226,17 +217,17 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } case "token": - cfg.Token, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "privateKey": - v, ok = value.(string) - if !ok { - //errorinterface + cfg.Token = value.(string) + case "privatekey": + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + v = value.(string) var decodeErr error block, decodeErr := base64.URLEncoding.DecodeString(v) if decodeErr != nil { @@ -248,13 +239,12 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } cfg.PrivateKey, err = parsePKCS8PrivateKey(block) if err != nil { - //errorinterface + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "validateDefaultParameters": - vv, err = parseBool(value) - if err != nil { - //errorinterface + case "validatedefaultparameters": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } if vv { @@ -262,10 +252,9 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } else { cfg.ValidateDefaultParameters = ConfigBoolFalse } - case "clientRequestMfaToken": - vv, err = parseBool(value) - if err != nil { - //errorinterface + case "clientrequestmfatoken": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } if vv { @@ -273,10 +262,9 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } else { cfg.ClientRequestMfaToken = ConfigBoolFalse } - case "clientStoreTemporaryCredential": - vv, err = parseBool(value) - if err != nil { - //errorinterface + case "clientstoretemporarycredential": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } if vv { @@ -285,28 +273,26 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { cfg.ClientStoreTemporaryCredential = ConfigBoolFalse } case "tracing": - cfg.Tracing, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "tmpDirPath": - cfg.TmpDirPath, ok = value.(string) - if !ok { - //errorinterface + cfg.Tracing = value.(string) + case "tmpdirpath": + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "disableQueryContextCache": - vv, err = parseBool(value) - if err != nil { - //errorinterface + cfg.TmpDirPath = value.(string) + case "disablequerycontextcache": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } cfg.DisableQueryContextCache = vv - case "includeRetryReason": - vv, err = parseBool(value) - if err != nil { - //errorinterface + case "includeretryreason": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } if vv { @@ -314,16 +300,15 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } else { cfg.IncludeRetryReason = ConfigBoolFalse } - case "clientConfigFile": - cfg.ClientConfigFile, ok = value.(string) - if !ok { - //errorinterface + case "clientconfigfile": + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } - case "disableConsoleLogin": - vv, err = parseBool(value) - if err != nil { - //errorinterface + cfg.ClientConfigFile = value.(string) + case "disableconsolelogin": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } if vv { @@ -331,10 +316,9 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } else { cfg.DisableConsoleLogin = ConfigBoolFalse } - case "disableSamlURLCheck": - vv, err = parseBool(value) - if err != nil { - //errorinterface + case "disablesamlurlcheck": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } if vv { @@ -343,18 +327,18 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { cfg.DisableSamlURLCheck = ConfigBoolFalse } case "token_file_path": - tokenPath, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + tokenPath = value.(string) default: var param string - param, ok = value.(string) - if !ok { - //errorinterface + if _, ok = value.(string); !ok { + err.(*SnowflakeError).MessageArgs = []interface{}{key, value} return err } + param = value.(string) cfg.Params[urlDecodeIfNeeded(key)] = ¶m } } @@ -371,22 +355,19 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { func parseInt(i interface{}) (int, error) { var v string var ok bool - var num int - var err, parseErr error - parseErr = &SnowflakeError{ - Number: ErrCodeTomlFileParsingFailed, - Message: errMsgFailedToParseTomlFile, - MessageArgs: []interface{}{i}, - } - if v, ok = i.(string); !ok { - if num, ok = i.(int); !ok { - return 0, parseErr + var num int = 0 + var err error = errors.New("parse Error") + if _, ok = i.(string); !ok { + if _, ok = i.(int); !ok { + return num, err } + num = i.(int) return num, nil } - num, err = strconv.Atoi(v) - if err != nil { - return 0, parseErr + v = i.(string) + + if num, err = strconv.Atoi(v); err != nil { + return num, err } return num, nil } @@ -400,12 +381,14 @@ func parseBool(i interface{}) (bool, error) { Message: errMsgFailedToParseTomlFile, MessageArgs: []interface{}{i}, } - if v, ok = i.(string); !ok { - if vv, ok = i.(bool); !ok { + if _, ok = i.(string); !ok { + if _, ok = i.(bool); !ok { return false, parseErr } + vv = i.(bool) return vv, nil } + v = i.(string) vv, err = strconv.ParseBool(v) if err != nil { return false, parseErr @@ -424,13 +407,14 @@ func parseDuration(i interface{}) (time.Duration, error) { Message: errMsgFailedToParseTomlFile, MessageArgs: []interface{}{i}, } - if v, ok = i.(string); !ok { + if _, ok = i.(string); !ok { if num, err = parseInt(i); err != nil { return time.Duration(0), parseErr } t = int64(num) return time.Duration(t * int64(time.Second)), nil } + v = i.(string) t, err = strconv.ParseInt(v, 10, 64) if err != nil { return time.Duration(0), parseErr diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 1ab0ed1a2..f22ea444c 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -175,3 +175,90 @@ func TestParseDuration(t *testing.T) { t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeTomlFileParsingFailed, driverErr.Number) } } + +type paramList struct { + testParams []string + values []interface{} +} + +func TestParseToml(t *testing.T) { + testCases := []paramList{ + { + testParams: []string{"user", "password", "host", "account", "warehouse", "database", + "schema", "role", "region", "protocol", "passcode", "application", "token", + "tracing", "tmpDirPath", "clientConfigFile"}, + values: []interface{}{"value"}, + }, + { + testParams: []string{"port", "maxRetryCount", "clientTimeout", "jwtClientTimeout", "loginTimeout", + "requestTimeout", "jwtTimeout", "externalBrowserTimeout"}, + values: []interface{}{"300", 500}, + }, + { + testParams: []string{"ocspFailOpen", "insecureMode", "PasscodeInPassword", "validateDEFAULTParameters", "clientRequestMFAtoken", + "clientStoreTemporaryCredential", "disableQueryContextCache", "includeRetryReason", "disableConsoleLogin", "disableSamlUrlCheck"}, + values: []interface{}{true, "true", false, "false"}, + }, + } + + for _, testCase := range testCases { + for _, param := range testCase.testParams { + for _, value := range testCase.values { + t.Run(param, func(t *testing.T) { + cfg := &Config{} + var connectionMap = make(map[string]interface{}) + connectionMap[param] = value + err := parseToml(cfg, connectionMap) + if err != nil { + t.Fatal("should not have failed") + } + }) + } + } + } +} + +func TestParseTomlWithWrongValue(t *testing.T) { + testCases := []paramList{ + { + testParams: []string{"user", "password", "host", "account", "warehouse", "database", + "schema", "role", "region", "protocol", "passcode", "application", "token", "privateKey", + "tracing", "tmpDirPath", "clientConfigFile", "wrongParams"}, + values: []interface{}{1}, + }, + { + testParams: []string{"port", "maxRetryCount", "clientTimeout", "jwtClientTimeout", "loginTimeout", + "requestTimeout", "jwtTimeout", "externalBrowserTimeout"}, + values: []interface{}{"wrong_value", false}, + }, + { + testParams: []string{"ocspFailOpen", "insecureMode", "PasscodeInPassword", "validateDEFAULTParameters", "clientRequestMFAtoken", + "clientStoreTemporaryCredential", "disableQueryContextCache", "includeRetryReason", "disableConsoleLogin", "disableSamlUrlCheck"}, + values: []interface{}{"wrong_value", 1}, + }, + } + + for _, testCase := range testCases { + for _, param := range testCase.testParams { + for _, value := range testCase.values { + t.Run(param, func(t *testing.T) { + cfg := &Config{} + var connectionMap = make(map[string]interface{}) + connectionMap[param] = value + err := parseToml(cfg, connectionMap) + if err == nil { + t.Fatal("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrCodeTomlFileParsingFailed { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeTomlFileParsingFailed, driverErr.Number) + } + }) + + } + } + } +} diff --git a/errors.go b/errors.go index 8d5daba49..77d6d214e 100644 --- a/errors.go +++ b/errors.go @@ -303,7 +303,7 @@ const ( errMsgClientConfigFailed = "client configuration failed: %v" errMsgNullValueInArray = "for handling null values in arrays use WithArrayValuesNullable(ctx)" errMsgNullValueInMap = "for handling null values in maps use WithMapValuesNullable(ctx)" - errMsgFailedToParseTomlFile = "failed to parse toml file. the params occurred error: %v" + errMsgFailedToParseTomlFile = "failed to parse toml file. the params %v occurred error with value %v" errMsgFailedToFindDSNInTomlFile = "failed to find DSN in toml file." ) From f36fc279a7f4542e192f6f94430ffa8572992d6d Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Thu, 5 Sep 2024 12:55:45 -0700 Subject: [PATCH 07/45] fix --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index ad2738313..5fb788136 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/99designs/keyring v1.2.2 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0 + github.com/BurntSushi/toml v1.4.0 github.com/apache/arrow/go/v15 v15.0.0 github.com/aws/aws-sdk-go-v2 v1.26.1 github.com/aws/aws-sdk-go-v2/credentials v1.17.11 @@ -23,7 +24,6 @@ require ( require ( github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect - github.com/BurntSushi/toml v1.4.0 // indirect github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect From 1a79d15080593bfb3409d47c248bba18415d2342 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Thu, 5 Sep 2024 16:43:22 -0700 Subject: [PATCH 08/45] add more testing --- connection_configuration.go | 3 +++ connection_configuration_test.go | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/connection_configuration.go b/connection_configuration.go index 7558753de..596c3d139 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -445,6 +445,9 @@ func getTomlFilePath(filePath string) (string, error) { var dir string if len(filePath) != 0 { dir = filePath + if path.IsAbs(dir) { + return dir, nil + } } else { homeDir, err := os.UserHomeDir() if err != nil { diff --git a/connection_configuration_test.go b/connection_configuration_test.go index f22ea444c..780886782 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -3,6 +3,7 @@ package gosnowflake import ( "io/fs" "os" + path "path/filepath" "testing" "time" ) @@ -262,3 +263,25 @@ func TestParseTomlWithWrongValue(t *testing.T) { } } } + +func TestGetTomlFilePath(t *testing.T) { + dir, err := getTomlFilePath("") + if err != nil { + t.Fatal("should not have failed") + } + homeDir, err := os.UserHomeDir() + if err != nil { + t.Fatal("The connection cannot find the user home directory") + } + + assertEqualF(t, dir, path.Join(homeDir, "snowflake")) + + var location string = "../user//somelocation///b" + dir, err = getTomlFilePath(location) + if err != nil { + t.Fatal("should not have failed") + } + result, err := path.Abs(location) + assertEqualF(t, dir, result) + +} From a9825c59e7ebac9bcbc63c6e239cb97aea842b0f Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Thu, 5 Sep 2024 17:36:46 -0700 Subject: [PATCH 09/45] fix lint --- connection_configuration_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 780886782..0ba5a6bde 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -282,6 +282,9 @@ func TestGetTomlFilePath(t *testing.T) { t.Fatal("should not have failed") } result, err := path.Abs(location) + if err != nil { + t.Fatal("should not have failed") + } assertEqualF(t, dir, result) } From 7e67b834753cf54d5be10a96d698b3d7d1d8af9d Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Fri, 6 Sep 2024 10:32:51 -0700 Subject: [PATCH 10/45] fix error --- connection_configuration_test.go | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 0ba5a6bde..6c0b68413 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -102,13 +102,6 @@ func TestParseInt(t *testing.T) { if err == nil { t.Fatal("should have failed") } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrCodeTomlFileParsingFailed { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeTomlFileParsingFailed, driverErr.Number) - } } func TestParseBool(t *testing.T) { @@ -135,13 +128,6 @@ func TestParseBool(t *testing.T) { if err == nil { t.Fatal("should have failed") } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrCodeTomlFileParsingFailed { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeTomlFileParsingFailed, driverErr.Number) - } } func TestParseDuration(t *testing.T) { @@ -273,7 +259,6 @@ func TestGetTomlFilePath(t *testing.T) { if err != nil { t.Fatal("The connection cannot find the user home directory") } - assertEqualF(t, dir, path.Join(homeDir, "snowflake")) var location string = "../user//somelocation///b" @@ -286,5 +271,4 @@ func TestGetTomlFilePath(t *testing.T) { t.Fatal("should not have failed") } assertEqualF(t, dir, result) - } From f68f8c921235fcc93f03f78c336a245b69413ffc Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Sat, 7 Sep 2024 01:49:56 -0700 Subject: [PATCH 11/45] replaced all if to assertX in the testing codes --- connection_configuration_test.go | 102 ++++++++----------------------- 1 file changed, 24 insertions(+), 78 deletions(-) diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 6c0b68413..3d65a1a10 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -13,10 +13,7 @@ func TestLoadConnectionConfig_Default(t *testing.T) { cfg, err := LoadConnectionConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - + assertNilF(t, err, "The error should not occured") assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") assertEqualF(t, cfg.User, "test_user") assertEqualF(t, cfg.Password, "test_pass") @@ -32,10 +29,7 @@ func TestLoadConnectionConfig_OAuth(t *testing.T) { os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth") cfg, err := LoadConnectionConfig() - if err != nil { - t.Fatalf("err: %v", err) - } - + assertNilF(t, err, "The error should not occurred") assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") assertEqualF(t, cfg.User, "test_user") assertEqualF(t, cfg.Password, "test_pass") @@ -53,17 +47,11 @@ func TestLoadConnectionConfigWitNonExisitngDSN(t *testing.T) { os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "unavailableDSN") _, err := LoadConnectionConfig() + assertNotNilF(t, err, "The error should be occurred") - if err == nil { - t.Fatal("should have failed") - } driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrCodeFailedToFindDSNInToml { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeFailedToFindDSNInToml, driverErr.Number) - } + assertTrueF(t, ok, "This should be a Snowflake Error") + assertEqualF(t, driverErr.Number, ErrCodeFailedToFindDSNInToml) } func TestLoadConnectionConfigWithTokenFileNotExist(t *testing.T) { @@ -71,11 +59,10 @@ func TestLoadConnectionConfigWithTokenFileNotExist(t *testing.T) { os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth-file") _, err := LoadConnectionConfig() + assertNotNilF(t, err, "The error should be occurred") _, ok := err.(*(fs.PathError)) - if !ok { - t.Fatalf("should be io/fs error. err: %v", err) - } + assertTrueF(t, ok, "This error should be a path error") } func TestParseInt(t *testing.T) { @@ -85,23 +72,17 @@ func TestParseInt(t *testing.T) { i = 20 num, err = parseInt(i) - if err != nil { - t.Fatalf("should be parsed: %v", err) - } + assertNilF(t, err, "This value should be parsed") assertEqualF(t, num, 20) i = "40" num, err = parseInt(i) - if err != nil { - t.Fatalf("should be parsed: %v", err) - } + assertNilF(t, err, "This value should be parsed") assertEqualF(t, num, 40) i = "wrong_num" _, err = parseInt(i) - if err == nil { - t.Fatal("should have failed") - } + assertNotNilF(t, err, "should have failed") } func TestParseBool(t *testing.T) { @@ -111,23 +92,17 @@ func TestParseBool(t *testing.T) { i = true b, err = parseBool(i) - if err != nil { - t.Fatalf("should be parsed: %v", err) - } + assertNilF(t, err, "This value should be parsed") assertEqualF(t, b, true) i = "false" b, err = parseBool(i) - if err != nil { - t.Fatalf("should be parsed: %v", err) - } + assertNilF(t, err, "This value should be parsed") assertEqualF(t, b, false) i = "wrong_bool" _, err = parseInt(i) - if err == nil { - t.Fatal("should have failed") - } + assertNotNilF(t, err, "should have failed") } func TestParseDuration(t *testing.T) { @@ -137,30 +112,17 @@ func TestParseDuration(t *testing.T) { i = 300 dur, err = parseDuration(i) - if err != nil { - t.Fatalf("should be parsed: %v", err) - } + assertNilF(t, err, "This value should be parsed") assertEqualF(t, dur, time.Duration(5*int64(time.Minute))) i = "30" dur, err = parseDuration(i) - if err != nil { - t.Fatalf("should be parsed: %v", err) - } + assertNilF(t, err, "This value should be parsed") assertEqualF(t, dur, time.Duration(int64(time.Minute)/2)) i = false _, err = parseDuration(i) - if err == nil { - t.Fatal("should have failed") - } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrCodeTomlFileParsingFailed { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeTomlFileParsingFailed, driverErr.Number) - } + assertNotNilF(t, err, "should have failed") } type paramList struct { @@ -196,9 +158,7 @@ func TestParseToml(t *testing.T) { var connectionMap = make(map[string]interface{}) connectionMap[param] = value err := parseToml(cfg, connectionMap) - if err != nil { - t.Fatal("should not have failed") - } + assertNilF(t, err, "The value should be parsed") }) } } @@ -233,16 +193,10 @@ func TestParseTomlWithWrongValue(t *testing.T) { var connectionMap = make(map[string]interface{}) connectionMap[param] = value err := parseToml(cfg, connectionMap) - if err == nil { - t.Fatal("should have failed") - } + assertNotNilF(t, err, "should have failed") driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrCodeTomlFileParsingFailed { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeTomlFileParsingFailed, driverErr.Number) - } + assertTrueF(t, ok, "This should be a Snowflake Error") + assertEqualF(t, driverErr.Number, ErrCodeTomlFileParsingFailed) }) } @@ -252,23 +206,15 @@ func TestParseTomlWithWrongValue(t *testing.T) { func TestGetTomlFilePath(t *testing.T) { dir, err := getTomlFilePath("") - if err != nil { - t.Fatal("should not have failed") - } + assertNilF(t, err, "should not have failed") homeDir, err := os.UserHomeDir() - if err != nil { - t.Fatal("The connection cannot find the user home directory") - } + assertNilF(t, err, "The connection cannot find the user home directory") assertEqualF(t, dir, path.Join(homeDir, "snowflake")) var location string = "../user//somelocation///b" dir, err = getTomlFilePath(location) - if err != nil { - t.Fatal("should not have failed") - } + assertNilF(t, err, "should not have failed") result, err := path.Abs(location) - if err != nil { - t.Fatal("should not have failed") - } + assertNilF(t, err, "should not have failed") assertEqualF(t, dir, result) } From 41b35ae8c63ec2f5da132f259acce12acc29fd45 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Sat, 7 Sep 2024 01:51:51 -0700 Subject: [PATCH 12/45] fix typos --- connection_configuration_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 3d65a1a10..881640f2e 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -13,7 +13,7 @@ func TestLoadConnectionConfig_Default(t *testing.T) { cfg, err := LoadConnectionConfig() - assertNilF(t, err, "The error should not occured") + assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") assertEqualF(t, cfg.User, "test_user") assertEqualF(t, cfg.Password, "test_pass") @@ -29,7 +29,7 @@ func TestLoadConnectionConfig_OAuth(t *testing.T) { os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth") cfg, err := LoadConnectionConfig() - assertNilF(t, err, "The error should not occurred") + assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") assertEqualF(t, cfg.User, "test_user") assertEqualF(t, cfg.Password, "test_pass") @@ -47,7 +47,7 @@ func TestLoadConnectionConfigWitNonExisitngDSN(t *testing.T) { os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "unavailableDSN") _, err := LoadConnectionConfig() - assertNotNilF(t, err, "The error should be occurred") + assertNotNilF(t, err, "The error should occur") driverErr, ok := err.(*SnowflakeError) assertTrueF(t, ok, "This should be a Snowflake Error") @@ -59,7 +59,7 @@ func TestLoadConnectionConfigWithTokenFileNotExist(t *testing.T) { os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth-file") _, err := LoadConnectionConfig() - assertNotNilF(t, err, "The error should be occurred") + assertNotNilF(t, err, "The error should occur") _, ok := err.(*(fs.PathError)) assertTrueF(t, ok, "This error should be a path error") From 3d6151fdd2dc26dbe80669054489ae0b58668bd3 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Mon, 9 Sep 2024 13:40:51 -0700 Subject: [PATCH 13/45] updated doc, add sample application --- cmd/tomlfileconnection/.gitignore | 1 + cmd/tomlfileconnection/Makefile | 16 ++++++++++++++++ connection_configuration.go | 4 ++++ connection_configuration_test.go | 8 ++++---- doc.go | 5 +++++ connections.toml => test_data/connections.toml | 0 6 files changed, 30 insertions(+), 4 deletions(-) create mode 100644 cmd/tomlfileconnection/.gitignore create mode 100644 cmd/tomlfileconnection/Makefile rename connections.toml => test_data/connections.toml (100%) diff --git a/cmd/tomlfileconnection/.gitignore b/cmd/tomlfileconnection/.gitignore new file mode 100644 index 000000000..3b1a12e76 --- /dev/null +++ b/cmd/tomlfileconnection/.gitignore @@ -0,0 +1 @@ +tomlfileconnection.go \ No newline at end of file diff --git a/cmd/tomlfileconnection/Makefile b/cmd/tomlfileconnection/Makefile new file mode 100644 index 000000000..7813cb2e1 --- /dev/null +++ b/cmd/tomlfileconnection/Makefile @@ -0,0 +1,16 @@ +include ../../gosnowflake.mak +CMD_TARGET=tomlfileconnection + +## Install +install: cinstall + +## Run +run: crun + +## Lint +lint: clint + +## Format source codes +fmt: cfmt + +.PHONY: install run lint fmt diff --git a/connection_configuration.go b/connection_configuration.go index 596c3d139..c06903c53 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -423,6 +423,10 @@ func parseDuration(i interface{}) (time.Duration, error) { } func readToken(tokenPath string) (string, error) { + if tokenPath == "" { + tokenPath = "./snowflake/session/token" + } + if !path.IsAbs(tokenPath) { snowflakeConfigDir, err := getTomlFilePath(os.Getenv("SNOWFLAKE_HOME")) if err != nil { diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 881640f2e..37ed3eb19 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -9,7 +9,7 @@ import ( ) func TestLoadConnectionConfig_Default(t *testing.T) { - os.Setenv("SNOWFLAKE_HOME", "./") + os.Setenv("SNOWFLAKE_HOME", "./test_data") cfg, err := LoadConnectionConfig() @@ -25,7 +25,7 @@ func TestLoadConnectionConfig_Default(t *testing.T) { } func TestLoadConnectionConfig_OAuth(t *testing.T) { - os.Setenv("SNOWFLAKE_HOME", "./") + os.Setenv("SNOWFLAKE_HOME", "./test_data") os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth") cfg, err := LoadConnectionConfig() @@ -43,7 +43,7 @@ func TestLoadConnectionConfig_OAuth(t *testing.T) { } func TestLoadConnectionConfigWitNonExisitngDSN(t *testing.T) { - os.Setenv("SNOWFLAKE_HOME", "./") + os.Setenv("SNOWFLAKE_HOME", "./test_data") os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "unavailableDSN") _, err := LoadConnectionConfig() @@ -55,7 +55,7 @@ func TestLoadConnectionConfigWitNonExisitngDSN(t *testing.T) { } func TestLoadConnectionConfigWithTokenFileNotExist(t *testing.T) { - os.Setenv("SNOWFLAKE_HOME", "./") + os.Setenv("SNOWFLAKE_HOME", "./test_data") os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth-file") _, err := LoadConnectionConfig() diff --git a/doc.go b/doc.go index ec1865151..db700d32b 100644 --- a/doc.go +++ b/doc.go @@ -169,6 +169,11 @@ Note: GOSNOWFLAKE_SKIP_REGISTERATION should not be used if sql.Open() is used as to connect to the server, as sql.Open will require registration so it can map the driver name to the driver type, which in this case is "snowflake" and SnowflakeDriver{}. +After Version 1.11.1 and later, you can load the connnection configuration with .toml file format. +With two environment variables SNOWFLAKE_HOME(connections.toml file directory) SNOWFLAKE_DEFAULT_CONNECTION_NAME(DSN name), +the driver will search the config file and load the connection. You can find how to use this connection way at ./cmd/tomlfileconnection +or Snowflake doc: https://docs.snowflake.com/en/developer-guide/snowflake-cli-v2/connecting/specify-credentials + # Proxy The Go Snowflake Driver honors the environment variables HTTP_PROXY, HTTPS_PROXY and NO_PROXY for the forward proxy setting. diff --git a/connections.toml b/test_data/connections.toml similarity index 100% rename from connections.toml rename to test_data/connections.toml From a7e234fbc403e2242428407c58ebbad72f9a5b73 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Mon, 9 Sep 2024 13:47:51 -0700 Subject: [PATCH 14/45] add sample file --- cmd/tomlfileconnection/tomlfileconnection.go | 54 ++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 cmd/tomlfileconnection/tomlfileconnection.go diff --git a/cmd/tomlfileconnection/tomlfileconnection.go b/cmd/tomlfileconnection/tomlfileconnection.go new file mode 100644 index 000000000..e332ba038 --- /dev/null +++ b/cmd/tomlfileconnection/tomlfileconnection.go @@ -0,0 +1,54 @@ +// Example: How to connect to the server with the toml file configuration +// Prerequiste: following the Snowflake doc: https://docs.snowflake.com/en/developer-guide/snowflake-cli-v2/connecting/specify-credentials +package main + +import ( + "database/sql" + "flag" + "fmt" + "log" + + sf "github.com/snowflakedb/gosnowflake" +) + +func main() { + if !flag.Parsed() { + flag.Parse() + } + + cfg, err := sf.LoadConnectionConfig() + if err != nil { + log.Fatalf("failed to create Config, err: %v", err) + } + dsn, err := sf.DSN(cfg) + if err != nil { + log.Fatalf("failed to create DSN from Config: %v, err: %v", cfg, err) + } + + db, err := sql.Open("snowflake", dsn) + if err != nil { + log.Fatalf("failed to connect. %v, err: %v", dsn, err) + } + defer db.Close() + query := "SELECT 1" + rows, err := db.Query(query) // no cancel is allowed + if err != nil { + log.Fatalf("failed to run a query. %v, err: %v", query, err) + } + defer rows.Close() + var v int + for rows.Next() { + err := rows.Scan(&v) + if err != nil { + log.Fatalf("failed to get result. err: %v", err) + } + if v != 1 { + log.Fatalf("failed to get 1. got: %v", v) + } + } + if rows.Err() != nil { + fmt.Printf("ERROR: %v\n", rows.Err()) + return + } + fmt.Printf("Congrats! You have successfully run %v with Snowflake DB!\n", query) +} From 234532af022da759e05cd1bbc1dfd91f7285f74d Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Mon, 9 Sep 2024 16:43:59 -0700 Subject: [PATCH 15/45] added details and modified the file permission to unify --- cmd/tomlfileconnection/tomlfileconnection.go | 4 ++++ connection_configuration.go | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/cmd/tomlfileconnection/tomlfileconnection.go b/cmd/tomlfileconnection/tomlfileconnection.go index e332ba038..f82b0b629 100644 --- a/cmd/tomlfileconnection/tomlfileconnection.go +++ b/cmd/tomlfileconnection/tomlfileconnection.go @@ -7,6 +7,7 @@ import ( "flag" "fmt" "log" + "os" sf "github.com/snowflakedb/gosnowflake" ) @@ -16,6 +17,9 @@ func main() { flag.Parse() } + os.Setenv("SNOWFLAKE_HOME", "") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "") + cfg, err := sf.LoadConnectionConfig() if err != nil { log.Fatalf("failed to create Config, err: %v", err) diff --git a/connection_configuration.go b/connection_configuration.go index c06903c53..9c9eae16a 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -479,7 +479,7 @@ func validateFilePermission(filePath string) error { return err } permission := fileInfo.Mode().Perm() - if permission != 0o600 { + if permission != 0600 { return err } return nil From 84ba09dd29eb7de191d365f552a77d3887fe26fa Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Mon, 9 Sep 2024 18:28:23 -0700 Subject: [PATCH 16/45] remove var keywords --- connection_configuration.go | 213 ++++++++++++++----------------- connection_configuration_test.go | 19 +-- 2 files changed, 100 insertions(+), 132 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 9c9eae16a..940a1d3cf 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -22,7 +22,7 @@ func LoadConnectionConfig() (*Config, error) { Params: make(map[string]*string), Authenticator: AuthTypeSnowflake, // Default to snowflake } - var dsn string = getConnectionDSN(os.Getenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME")) + dsn := getConnectionDSN(os.Getenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME")) snowflakeConfigDir, err := getTomlFilePath(os.Getenv("SNOWFLAKE_HOME")) if err != nil { return nil, err @@ -59,155 +59,151 @@ func LoadConnectionConfig() (*Config, error) { } func parseToml(cfg *Config, connection map[string]interface{}) error { - var ok, vv bool - var err, parsingErr error - err = &SnowflakeError{ + var parsingErr error + var vv bool + var tokenPath string + err := &SnowflakeError{ Number: ErrCodeTomlFileParsingFailed, Message: errMsgFailedToParseTomlFile, } - var v, tokenPath string for key, value := range connection { switch strings.ToLower(key) { case "user", "username": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.User = value.(string) case "password": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Password = value.(string) case "host": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Host = value.(string) case "account": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Account = value.(string) case "warehouse": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Warehouse = value.(string) case "database": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Database = value.(string) case "schema": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Schema = value.(string) case "role": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Role = value.(string) case "region": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Region = value.(string) case "protocol": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Protocol = value.(string) case "passcode": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Passcode = value.(string) case "port": - if cfg.Port, parsingErr = parseInt(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } case "passcodeinpassword": - if cfg.PasscodeInPassword, parsingErr = parseBool(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } case "clienttimeout": if cfg.ClientTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } case "jwtclienttimeout": if cfg.JWTClientTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } case "logintimeout": - if cfg.LoginTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } case "requesttimeout": if cfg.RequestTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } case "jwttimeout": if cfg.JWTExpireTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } case "externalbrowsertimeout": if cfg.ExternalBrowserTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } case "maxretrycount": - if cfg.MaxRetryCount, parsingErr = parseInt(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } case "application": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Application = value.(string) case "authenticator": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } - v = value.(string) - err = determineAuthenticatorType(cfg, v) - if err != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + v := value.(string) + parsingErr = determineAuthenticatorType(cfg, v) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} return err } case "insecuremode": if cfg.InsecureMode, parsingErr = parseBool(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } case "ocspfailopen": if vv, parsingErr = parseBool(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } if vv { @@ -217,17 +213,17 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } case "token": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Token = value.(string) case "privatekey": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } - v = value.(string) + v := value.(string) var decodeErr error block, decodeErr := base64.URLEncoding.DecodeString(v) if decodeErr != nil { @@ -237,14 +233,14 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } return err } - cfg.PrivateKey, err = parsePKCS8PrivateKey(block) - if err != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + cfg.PrivateKey, parsingErr = parsePKCS8PrivateKey(block) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} return err } case "validatedefaultparameters": if vv, parsingErr = parseBool(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } if vv { @@ -254,7 +250,7 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } case "clientrequestmfatoken": if vv, parsingErr = parseBool(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } if vv { @@ -264,7 +260,7 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } case "clientstoretemporarycredential": if vv, parsingErr = parseBool(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } if vv { @@ -273,26 +269,26 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { cfg.ClientStoreTemporaryCredential = ConfigBoolFalse } case "tracing": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.Tracing = value.(string) case "tmpdirpath": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.TmpDirPath = value.(string) case "disablequerycontextcache": if vv, parsingErr = parseBool(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } cfg.DisableQueryContextCache = vv case "includeretryreason": if vv, parsingErr = parseBool(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } if vv { @@ -301,14 +297,14 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { cfg.IncludeRetryReason = ConfigBoolFalse } case "clientconfigfile": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } cfg.ClientConfigFile = value.(string) case "disableconsolelogin": if vv, parsingErr = parseBool(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } if vv { @@ -318,7 +314,7 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } case "disablesamlurlcheck": if vv, parsingErr = parseBool(value); parsingErr != nil { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + err.MessageArgs = []interface{}{key, value} return err } if vv { @@ -327,18 +323,17 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { cfg.DisableSamlURLCheck = ConfigBoolFalse } case "token_file_path": - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } tokenPath = value.(string) default: - var param string - if _, ok = value.(string); !ok { - err.(*SnowflakeError).MessageArgs = []interface{}{key, value} + if _, ok := value.(string); !ok { + err.MessageArgs = []interface{}{key, value} return err } - param = value.(string) + param := value.(string) cfg.Params[urlDecodeIfNeeded(key)] = ¶m } } @@ -353,71 +348,51 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } func parseInt(i interface{}) (int, error) { - var v string - var ok bool - var num int = 0 - var err error = errors.New("parse Error") - if _, ok = i.(string); !ok { - if _, ok = i.(int); !ok { - return num, err + if _, ok := i.(string); !ok { + if _, ok := i.(int); !ok { + return 0, errors.New("parse Error") } - num = i.(int) + num := i.(int) return num, nil } - v = i.(string) + v := i.(string) + num, err := strconv.Atoi(v) - if num, err = strconv.Atoi(v); err != nil { + if err != nil { return num, err } return num, nil } func parseBool(i interface{}) (bool, error) { - var v string - var ok, vv bool - var err, parseErr error - parseErr = &SnowflakeError{ - Number: ErrCodeTomlFileParsingFailed, - Message: errMsgFailedToParseTomlFile, - MessageArgs: []interface{}{i}, - } - if _, ok = i.(string); !ok { - if _, ok = i.(bool); !ok { - return false, parseErr + if _, ok := i.(string); !ok { + if _, ok := i.(bool); !ok { + return false, errors.New("parse Error") } - vv = i.(bool) + vv := i.(bool) return vv, nil } - v = i.(string) - vv, err = strconv.ParseBool(v) + v := i.(string) + vv, err := strconv.ParseBool(v) if err != nil { - return false, parseErr + return false, errors.New("parse Error") } return vv, nil } func parseDuration(i interface{}) (time.Duration, error) { - var v string - var ok bool - var num int - var t int64 - var err, parseErr error - parseErr = &SnowflakeError{ - Number: ErrCodeTomlFileParsingFailed, - Message: errMsgFailedToParseTomlFile, - MessageArgs: []interface{}{i}, - } - if _, ok = i.(string); !ok { - if num, err = parseInt(i); err != nil { - return time.Duration(0), parseErr + if _, ok := i.(string); !ok { + num, err := parseInt(i) + if err != nil { + return time.Duration(0), err } - t = int64(num) + t := int64(num) return time.Duration(t * int64(time.Second)), nil } - v = i.(string) - t, err = strconv.ParseInt(v, 10, 64) + v := i.(string) + t, err := strconv.ParseInt(v, 10, 64) if err != nil { - return time.Duration(0), parseErr + return time.Duration(0), err } return time.Duration(t * int64(time.Second)), nil } diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 37ed3eb19..694018d38 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -67,11 +67,8 @@ func TestLoadConnectionConfigWithTokenFileNotExist(t *testing.T) { func TestParseInt(t *testing.T) { var i interface{} - var num int - var err error - i = 20 - num, err = parseInt(i) + num, err := parseInt(i) assertNilF(t, err, "This value should be parsed") assertEqualF(t, num, 20) @@ -87,11 +84,9 @@ func TestParseInt(t *testing.T) { func TestParseBool(t *testing.T) { var i interface{} - var b bool - var err error i = true - b, err = parseBool(i) + b, err := parseBool(i) assertNilF(t, err, "This value should be parsed") assertEqualF(t, b, true) @@ -107,11 +102,9 @@ func TestParseBool(t *testing.T) { func TestParseDuration(t *testing.T) { var i interface{} - var dur time.Duration - var err error i = 300 - dur, err = parseDuration(i) + dur, err := parseDuration(i) assertNilF(t, err, "This value should be parsed") assertEqualF(t, dur, time.Duration(5*int64(time.Minute))) @@ -155,7 +148,7 @@ func TestParseToml(t *testing.T) { for _, value := range testCase.values { t.Run(param, func(t *testing.T) { cfg := &Config{} - var connectionMap = make(map[string]interface{}) + connectionMap := make(map[string]interface{}) connectionMap[param] = value err := parseToml(cfg, connectionMap) assertNilF(t, err, "The value should be parsed") @@ -190,7 +183,7 @@ func TestParseTomlWithWrongValue(t *testing.T) { for _, value := range testCase.values { t.Run(param, func(t *testing.T) { cfg := &Config{} - var connectionMap = make(map[string]interface{}) + connectionMap := make(map[string]interface{}) connectionMap[param] = value err := parseToml(cfg, connectionMap) assertNotNilF(t, err, "should have failed") @@ -211,7 +204,7 @@ func TestGetTomlFilePath(t *testing.T) { assertNilF(t, err, "The connection cannot find the user home directory") assertEqualF(t, dir, path.Join(homeDir, "snowflake")) - var location string = "../user//somelocation///b" + location := "../user//somelocation///b" dir, err = getTomlFilePath(location) assertNilF(t, err, "should not have failed") result, err := path.Abs(location) From 9eeaaf3c250e7d0dfa0305cd7291f59cd1e9bd13 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:28:35 -0700 Subject: [PATCH 17/45] fix error and refactored code --- connection_configuration.go | 121 +++++++++++++++++++----------------- errors.go | 2 +- 2 files changed, 66 insertions(+), 57 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 940a1d3cf..affa19969 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -32,7 +32,7 @@ func LoadConnectionConfig() (*Config, error) { if err != nil { return nil, err } - var tomlInfo = make(map[string]interface{}) + tomlInfo := make(map[string]interface{}) _, err = toml.DecodeFile(tomlFilePath, &tomlInfo) if err != nil { @@ -59,9 +59,9 @@ func LoadConnectionConfig() (*Config, error) { } func parseToml(cfg *Config, connection map[string]interface{}) error { + var v, tokenPath string var parsingErr error var vv bool - var tokenPath string err := &SnowflakeError{ Number: ErrCodeTomlFileParsingFailed, Message: errMsgFailedToParseTomlFile, @@ -69,71 +69,71 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { for key, value := range connection { switch strings.ToLower(key) { case "user", "username": - if _, ok := value.(string); !ok { + cfg.User, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.User = value.(string) case "password": - if _, ok := value.(string); !ok { + cfg.Password, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Password = value.(string) case "host": - if _, ok := value.(string); !ok { + cfg.Host, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Host = value.(string) case "account": - if _, ok := value.(string); !ok { + cfg.Account, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Account = value.(string) case "warehouse": - if _, ok := value.(string); !ok { + cfg.Warehouse, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Warehouse = value.(string) case "database": - if _, ok := value.(string); !ok { + cfg.Database, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Database = value.(string) case "schema": - if _, ok := value.(string); !ok { + cfg.Schema, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Schema = value.(string) case "role": - if _, ok := value.(string); !ok { + cfg.Role, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Role = value.(string) case "region": - if _, ok := value.(string); !ok { + cfg.Region, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Region = value.(string) case "protocol": - if _, ok := value.(string); !ok { + cfg.Protocol, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Protocol = value.(string) case "passcode": - if _, ok := value.(string); !ok { + cfg.Passcode, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Passcode = value.(string) case "port": if cfg.Port, parsingErr = parseInt(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} @@ -180,17 +180,17 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { return err } case "application": - if _, ok := value.(string); !ok { + cfg.Application, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Application = value.(string) case "authenticator": - if _, ok := value.(string); !ok { + v, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - v := value.(string) parsingErr = determineAuthenticatorType(cfg, v) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} @@ -213,17 +213,17 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } case "token": - if _, ok := value.(string); !ok { + cfg.Token, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Token = value.(string) case "privatekey": - if _, ok := value.(string); !ok { + v, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - v := value.(string) var decodeErr error block, decodeErr := base64.URLEncoding.DecodeString(v) if decodeErr != nil { @@ -269,17 +269,17 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { cfg.ClientStoreTemporaryCredential = ConfigBoolFalse } case "tracing": - if _, ok := value.(string); !ok { + cfg.Tracing, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.Tracing = value.(string) case "tmpdirpath": - if _, ok := value.(string); !ok { + cfg.TmpDirPath, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.TmpDirPath = value.(string) case "disablequerycontextcache": if vv, parsingErr = parseBool(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} @@ -297,11 +297,11 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { cfg.IncludeRetryReason = ConfigBoolFalse } case "clientconfigfile": - if _, ok := value.(string); !ok { + cfg.ClientConfigFile, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.ClientConfigFile = value.(string) case "disableconsolelogin": if vv, parsingErr = parseBool(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} @@ -323,17 +323,18 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { cfg.DisableSamlURLCheck = ConfigBoolFalse } case "token_file_path": - if _, ok := value.(string); !ok { + tokenPath, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - tokenPath = value.(string) default: - if _, ok := value.(string); !ok { + var param string + param, parsingErr = populateSessionParams(value) + if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - param := value.(string) cfg.Params[urlDecodeIfNeeded(key)] = ¶m } } @@ -365,36 +366,36 @@ func parseInt(i interface{}) (int, error) { } func parseBool(i interface{}) (bool, error) { - if _, ok := i.(string); !ok { + if v, ok := i.(string); !ok { if _, ok := i.(bool); !ok { return false, errors.New("parse Error") } vv := i.(bool) return vv, nil + } else { + vv, err := strconv.ParseBool(v) + if err != nil { + return false, errors.New("parse Error") + } + return vv, nil } - v := i.(string) - vv, err := strconv.ParseBool(v) - if err != nil { - return false, errors.New("parse Error") - } - return vv, nil } func parseDuration(i interface{}) (time.Duration, error) { - if _, ok := i.(string); !ok { + if v, ok := i.(string); !ok { num, err := parseInt(i) if err != nil { return time.Duration(0), err } t := int64(num) return time.Duration(t * int64(time.Second)), nil + } else { + t, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return time.Duration(0), err + } + return time.Duration(t * int64(time.Second)), nil } - v := i.(string) - t, err := strconv.ParseInt(v, 10, 64) - if err != nil { - return time.Duration(0), err - } - return time.Duration(t * int64(time.Second)), nil } func readToken(tokenPath string) (string, error) { @@ -420,6 +421,14 @@ func readToken(tokenPath string) (string, error) { return string(token), nil } +func populateSessionParams(i interface{}) (string, error) { + if v, ok := i.(string); !ok { + return "", errors.New("Error") + } else { + return v, nil + } +} + func getTomlFilePath(filePath string) (string, error) { var dir string if len(filePath) != 0 { diff --git a/errors.go b/errors.go index 77d6d214e..7e56040fe 100644 --- a/errors.go +++ b/errors.go @@ -130,7 +130,7 @@ const ( // ErrCodeTomlFileParsingFailed is an error code for the case where parsing the toml file is failed because of invalid value. ErrCodeTomlFileParsingFailed = 260013 // ErrCodeFailedToFindDSNInToml is an error code for the case where the DSN does not exist in the toml file. - ErrCodeFailedToFindDSNInToml = 260013 + ErrCodeFailedToFindDSNInToml = 260014 /* network */ From 4eaa365e185b79b5eda207e00d34b6106aad3d03 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:52:22 -0700 Subject: [PATCH 18/45] add error --- connection_configuration.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connection_configuration.go b/connection_configuration.go index affa19969..3e961dbc3 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -464,7 +464,7 @@ func validateFilePermission(filePath string) error { } permission := fileInfo.Mode().Perm() if permission != 0600 { - return err + return errors.New("Your access to the file was denied. Please check the permission of your toml file") } return nil } From 54bf52030677debe1f833074d180ccfcddf19092 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Tue, 10 Sep 2024 11:02:31 -0700 Subject: [PATCH 19/45] fix the file permission issue --- connection_configuration.go | 50 +++++++++++++++++--------------- connection_configuration_test.go | 16 ++++++++-- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 3e961dbc3..ee5881dc7 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -5,6 +5,7 @@ package gosnowflake import ( "encoding/base64" "errors" + "fmt" "os" path "path/filepath" "strconv" @@ -69,67 +70,67 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { for key, value := range connection { switch strings.ToLower(key) { case "user", "username": - cfg.User, parsingErr = populateSessionParams(value) + cfg.User, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "password": - cfg.Password, parsingErr = populateSessionParams(value) + cfg.Password, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "host": - cfg.Host, parsingErr = populateSessionParams(value) + cfg.Host, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "account": - cfg.Account, parsingErr = populateSessionParams(value) + cfg.Account, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "warehouse": - cfg.Warehouse, parsingErr = populateSessionParams(value) + cfg.Warehouse, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "database": - cfg.Database, parsingErr = populateSessionParams(value) + cfg.Database, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "schema": - cfg.Schema, parsingErr = populateSessionParams(value) + cfg.Schema, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "role": - cfg.Role, parsingErr = populateSessionParams(value) + cfg.Role, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "region": - cfg.Region, parsingErr = populateSessionParams(value) + cfg.Region, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "protocol": - cfg.Protocol, parsingErr = populateSessionParams(value) + cfg.Protocol, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "passcode": - cfg.Passcode, parsingErr = populateSessionParams(value) + cfg.Passcode, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err @@ -180,13 +181,13 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { return err } case "application": - cfg.Application, parsingErr = populateSessionParams(value) + cfg.Application, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "authenticator": - v, parsingErr = populateSessionParams(value) + v, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err @@ -213,13 +214,13 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } case "token": - cfg.Token, parsingErr = populateSessionParams(value) + cfg.Token, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "privatekey": - v, parsingErr = populateSessionParams(value) + v, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err @@ -269,13 +270,13 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { cfg.ClientStoreTemporaryCredential = ConfigBoolFalse } case "tracing": - cfg.Tracing, parsingErr = populateSessionParams(value) + cfg.Tracing, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "tmpdirpath": - cfg.TmpDirPath, parsingErr = populateSessionParams(value) + cfg.TmpDirPath, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err @@ -297,7 +298,7 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { cfg.IncludeRetryReason = ConfigBoolFalse } case "clientconfigfile": - cfg.ClientConfigFile, parsingErr = populateSessionParams(value) + cfg.ClientConfigFile, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err @@ -323,14 +324,14 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { cfg.DisableSamlURLCheck = ConfigBoolFalse } case "token_file_path": - tokenPath, parsingErr = populateSessionParams(value) + tokenPath, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } default: var param string - param, parsingErr = populateSessionParams(value) + param, parsingErr = parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err @@ -421,7 +422,7 @@ func readToken(tokenPath string) (string, error) { return string(token), nil } -func populateSessionParams(i interface{}) (string, error) { +func parseString(i interface{}) (string, error) { if v, ok := i.(string); !ok { return "", errors.New("Error") } else { @@ -458,12 +459,15 @@ func getConnectionDSN(dsn string) string { } func validateFilePermission(filePath string) error { + if isWindows { + return nil + } fileInfo, err := os.Stat(filePath) if err != nil { return err } - permission := fileInfo.Mode().Perm() - if permission != 0600 { + permission := fmt.Sprintf("%04o", fileInfo.Mode().Perm()) + if permission != "0600" { return errors.New("Your access to the file was denied. Please check the permission of your toml file") } return nil diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 694018d38..64057304b 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -9,6 +9,9 @@ import ( ) func TestLoadConnectionConfig_Default(t *testing.T) { + err := os.Chmod("./test_data/connections.toml", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + os.Setenv("SNOWFLAKE_HOME", "./test_data") cfg, err := LoadConnectionConfig() @@ -25,6 +28,9 @@ func TestLoadConnectionConfig_Default(t *testing.T) { } func TestLoadConnectionConfig_OAuth(t *testing.T) { + err := os.Chmod("./test_data/connections.toml", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + os.Setenv("SNOWFLAKE_HOME", "./test_data") os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth") cfg, err := LoadConnectionConfig() @@ -43,10 +49,13 @@ func TestLoadConnectionConfig_OAuth(t *testing.T) { } func TestLoadConnectionConfigWitNonExisitngDSN(t *testing.T) { + err := os.Chmod("./test_data/connections.toml", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + os.Setenv("SNOWFLAKE_HOME", "./test_data") os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "unavailableDSN") - _, err := LoadConnectionConfig() + _, err = LoadConnectionConfig() assertNotNilF(t, err, "The error should occur") driverErr, ok := err.(*SnowflakeError) @@ -55,10 +64,13 @@ func TestLoadConnectionConfigWitNonExisitngDSN(t *testing.T) { } func TestLoadConnectionConfigWithTokenFileNotExist(t *testing.T) { + err := os.Chmod("./test_data/connections.toml", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + os.Setenv("SNOWFLAKE_HOME", "./test_data") os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth-file") - _, err := LoadConnectionConfig() + _, err = LoadConnectionConfig() assertNotNilF(t, err, "The error should occur") _, ok := err.(*(fs.PathError)) From f7e1932e8b179511d92b07b603c49fcf5fbb2c1a Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Tue, 10 Sep 2024 11:19:39 -0700 Subject: [PATCH 20/45] fix lint --- connection_configuration.go | 38 +++++++++++++++++--------------- connection_configuration_test.go | 2 +- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index ee5881dc7..38815e22e 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -367,36 +367,38 @@ func parseInt(i interface{}) (int, error) { } func parseBool(i interface{}) (bool, error) { - if v, ok := i.(string); !ok { + v, ok := i.(string) + if !ok { if _, ok := i.(bool); !ok { return false, errors.New("parse Error") } vv := i.(bool) return vv, nil - } else { - vv, err := strconv.ParseBool(v) - if err != nil { - return false, errors.New("parse Error") - } - return vv, nil } + vv, err := strconv.ParseBool(v) + if err != nil { + return false, errors.New("parse Error") + } + return vv, nil + } func parseDuration(i interface{}) (time.Duration, error) { - if v, ok := i.(string); !ok { + v, ok := i.(string) + if !ok { num, err := parseInt(i) if err != nil { return time.Duration(0), err } t := int64(num) return time.Duration(t * int64(time.Second)), nil - } else { - t, err := strconv.ParseInt(v, 10, 64) - if err != nil { - return time.Duration(0), err - } - return time.Duration(t * int64(time.Second)), nil } + t, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return time.Duration(0), err + } + return time.Duration(t * int64(time.Second)), nil + } func readToken(tokenPath string) (string, error) { @@ -423,11 +425,11 @@ func readToken(tokenPath string) (string, error) { } func parseString(i interface{}) (string, error) { - if v, ok := i.(string); !ok { + v, ok := i.(string) + if !ok { return "", errors.New("Error") - } else { - return v, nil } + return v, nil } func getTomlFilePath(filePath string) (string, error) { @@ -468,7 +470,7 @@ func validateFilePermission(filePath string) error { } permission := fmt.Sprintf("%04o", fileInfo.Mode().Perm()) if permission != "0600" { - return errors.New("Your access to the file was denied. Please check the permission of your toml file") + return errors.New("your access to the file was denied") } return nil } diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 64057304b..429967151 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -176,7 +176,7 @@ func TestParseTomlWithWrongValue(t *testing.T) { testParams: []string{"user", "password", "host", "account", "warehouse", "database", "schema", "role", "region", "protocol", "passcode", "application", "token", "privateKey", "tracing", "tmpDirPath", "clientConfigFile", "wrongParams"}, - values: []interface{}{1}, + values: []interface{}{1, false}, }, { testParams: []string{"port", "maxRetryCount", "clientTimeout", "jwtClientTimeout", "loginTimeout", From 061445ff5368bf9cd162ed17b254ca09d58a403c Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Tue, 10 Sep 2024 12:32:07 -0700 Subject: [PATCH 21/45] refactored --- connection_configuration.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 38815e22e..29713f160 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -5,7 +5,6 @@ package gosnowflake import ( "encoding/base64" "errors" - "fmt" "os" path "path/filepath" "strconv" @@ -468,8 +467,7 @@ func validateFilePermission(filePath string) error { if err != nil { return err } - permission := fmt.Sprintf("%04o", fileInfo.Mode().Perm()) - if permission != "0600" { + if permission := fileInfo.Mode().Perm(); permission != os.FileMode(0600) { return errors.New("your access to the file was denied") } return nil From 54123ebfdbdb4082c0e39e81c611ab64dd30a45c Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Tue, 10 Sep 2024 13:43:06 -0700 Subject: [PATCH 22/45] remove spaces --- connection_configuration.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 29713f160..bb5e0da97 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -379,7 +379,6 @@ func parseBool(i interface{}) (bool, error) { return false, errors.New("parse Error") } return vv, nil - } func parseDuration(i interface{}) (time.Duration, error) { @@ -397,14 +396,12 @@ func parseDuration(i interface{}) (time.Duration, error) { return time.Duration(0), err } return time.Duration(t * int64(time.Second)), nil - } func readToken(tokenPath string) (string, error) { if tokenPath == "" { tokenPath = "./snowflake/session/token" } - if !path.IsAbs(tokenPath) { snowflakeConfigDir, err := getTomlFilePath(os.Getenv("SNOWFLAKE_HOME")) if err != nil { From 9fcaaff26c8336d6e631527b47c30bb232327b1c Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Wed, 11 Sep 2024 12:06:46 -0700 Subject: [PATCH 23/45] fix --- connection_configuration.go | 18 ++++++------------ connection_configuration_test.go | 1 + 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index bb5e0da97..a03d5d366 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -33,7 +33,6 @@ func LoadConnectionConfig() (*Config, error) { return nil, err } tomlInfo := make(map[string]interface{}) - _, err = toml.DecodeFile(tomlFilePath, &tomlInfo) if err != nil { return nil, err @@ -50,7 +49,6 @@ func LoadConnectionConfig() (*Config, error) { if !ok { return nil, err } - err = parseToml(cfg, connectionConfig) if err != nil { return nil, err @@ -224,7 +222,6 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { err.MessageArgs = []interface{}{key, value} return err } - var decodeErr error block, decodeErr := base64.URLEncoding.DecodeString(v) if decodeErr != nil { err = &SnowflakeError{ @@ -329,8 +326,7 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { return err } default: - var param string - param, parsingErr = parseString(value) + param, parsingErr := parseString(value) if parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err @@ -423,26 +419,24 @@ func readToken(tokenPath string) (string, error) { func parseString(i interface{}) (string, error) { v, ok := i.(string) if !ok { - return "", errors.New("Error") + return "", errors.New("failed to convert the value to string") } return v, nil } func getTomlFilePath(filePath string) (string, error) { - var dir string if len(filePath) != 0 { - dir = filePath - if path.IsAbs(dir) { - return dir, nil + if path.IsAbs(filePath) { + return filePath, nil } } else { homeDir, err := os.UserHomeDir() if err != nil { return "", err } - dir = path.Join(homeDir, "snowflake") + filePath = path.Join(homeDir, "snowflake") } - absDir, err := path.Abs(dir) + absDir, err := path.Abs(filePath) if err != nil { return "", err } diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 429967151..666c89eed 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -79,6 +79,7 @@ func TestLoadConnectionConfigWithTokenFileNotExist(t *testing.T) { func TestParseInt(t *testing.T) { var i interface{} + i = 20 num, err := parseInt(i) assertNilF(t, err, "This value should be parsed") From 363fb30e628136eb1eacfcccc367c6bd664a6f7e Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Wed, 11 Sep 2024 17:21:23 -0700 Subject: [PATCH 24/45] fix error message --- connection_configuration.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index a03d5d366..551fca3e8 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -347,7 +347,7 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { func parseInt(i interface{}) (int, error) { if _, ok := i.(string); !ok { if _, ok := i.(int); !ok { - return 0, errors.New("parse Error") + return 0, errors.New("Failed to parse the value to integer") } num := i.(int) return num, nil @@ -365,14 +365,14 @@ func parseBool(i interface{}) (bool, error) { v, ok := i.(string) if !ok { if _, ok := i.(bool); !ok { - return false, errors.New("parse Error") + return false, errors.New("Failed to parse the value to boolean") } vv := i.(bool) return vv, nil } vv, err := strconv.ParseBool(v) if err != nil { - return false, errors.New("parse Error") + return false, errors.New("Failed to parse the value to boolean") } return vv, nil } From 3c880d04b90192c9f2223ffbd2f62dc1283b64d8 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Wed, 11 Sep 2024 17:22:38 -0700 Subject: [PATCH 25/45] lint fix --- connection_configuration.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 551fca3e8..6d148dbbc 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -347,7 +347,7 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { func parseInt(i interface{}) (int, error) { if _, ok := i.(string); !ok { if _, ok := i.(int); !ok { - return 0, errors.New("Failed to parse the value to integer") + return 0, errors.New("failed to parse the value to integer") } num := i.(int) return num, nil @@ -365,14 +365,14 @@ func parseBool(i interface{}) (bool, error) { v, ok := i.(string) if !ok { if _, ok := i.(bool); !ok { - return false, errors.New("Failed to parse the value to boolean") + return false, errors.New("failed to parse the value to boolean") } vv := i.(bool) return vv, nil } vv, err := strconv.ParseBool(v) if err != nil { - return false, errors.New("Failed to parse the value to boolean") + return false, errors.New("failed to parse the value to boolean") } return vv, nil } From d34546352c582f58a0e7f67e45185aacdb28e5ab Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Wed, 11 Sep 2024 18:03:18 -0700 Subject: [PATCH 26/45] add one more testing to read token from a file --- connection_configuration.go | 3 +-- connection_configuration_test.go | 19 +++++++++++++++++-- doc.go | 2 +- test_data/connections.toml | 6 +++++- 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 6d148dbbc..25128cd94 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -39,11 +39,10 @@ func LoadConnectionConfig() (*Config, error) { } connectionName, exist := tomlInfo[dsn] if !exist { - err = &SnowflakeError{ + return nil, &SnowflakeError{ Number: ErrCodeFailedToFindDSNInToml, Message: errMsgFailedToFindDSNInTomlFile, } - return nil, err } connectionConfig, ok := connectionName.(map[string]interface{}) if !ok { diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 666c89eed..5969829f9 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -15,7 +15,6 @@ func TestLoadConnectionConfig_Default(t *testing.T) { os.Setenv("SNOWFLAKE_HOME", "./test_data") cfg, err := LoadConnectionConfig() - assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") assertEqualF(t, cfg.User, "test_user") @@ -33,8 +32,8 @@ func TestLoadConnectionConfig_OAuth(t *testing.T) { os.Setenv("SNOWFLAKE_HOME", "./test_data") os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth") - cfg, err := LoadConnectionConfig() + cfg, err := LoadConnectionConfig() assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") assertEqualF(t, cfg.User, "test_user") @@ -48,6 +47,22 @@ func TestLoadConnectionConfig_OAuth(t *testing.T) { assertEqualF(t, cfg.Port, 443) } +func TestReadTokenValueWithTokenFilePath(t *testing.T) { + err := os.Chmod("./test_data/connections.toml", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + + err = os.Chmod("./test_data/token_file/token", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + + os.Setenv("SNOWFLAKE_HOME", "./test_data") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "read-token") + + cfg, err := LoadConnectionConfig() + assertNilF(t, err, "The error should not occur") + assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) + assertEqualF(t, cfg.Token, "mock_token123456") +} + func TestLoadConnectionConfigWitNonExisitngDSN(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") diff --git a/doc.go b/doc.go index db700d32b..1b923fee3 100644 --- a/doc.go +++ b/doc.go @@ -169,7 +169,7 @@ Note: GOSNOWFLAKE_SKIP_REGISTERATION should not be used if sql.Open() is used as to connect to the server, as sql.Open will require registration so it can map the driver name to the driver type, which in this case is "snowflake" and SnowflakeDriver{}. -After Version 1.11.1 and later, you can load the connnection configuration with .toml file format. +After Version 1.11.2 and later, you can load the connnection configuration with .toml file format. With two environment variables SNOWFLAKE_HOME(connections.toml file directory) SNOWFLAKE_DEFAULT_CONNECTION_NAME(DSN name), the driver will search the config file and load the connection. You can find how to use this connection way at ./cmd/tomlfileconnection or Snowflake doc: https://docs.snowflake.com/en/developer-guide/snowflake-cli-v2/connecting/specify-credentials diff --git a/test_data/connections.toml b/test_data/connections.toml index 05e71c547..b3c52ed84 100644 --- a/test_data/connections.toml +++ b/test_data/connections.toml @@ -32,4 +32,8 @@ protocol = 'https' port = '443' authenticator = 'oauth' testNot = 'problematicParameter' -token_file_path = '/Users/test/.snowflake/token' \ No newline at end of file +token_file_path = '/Users/test/.snowflake/token' + +[read-token] +authenticator = 'oauth' +token_file_path = './token_file/token' \ No newline at end of file From bc2f5157009149da39be3bb6f3f4f40dc8006dd9 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Wed, 11 Sep 2024 18:04:14 -0700 Subject: [PATCH 27/45] add mock token file --- test_data/token_file/token | 1 + 1 file changed, 1 insertion(+) create mode 100644 test_data/token_file/token diff --git a/test_data/token_file/token b/test_data/token_file/token new file mode 100644 index 000000000..6db96c0a3 --- /dev/null +++ b/test_data/token_file/token @@ -0,0 +1 @@ +mock_token123456 \ No newline at end of file From 2099982c009bb2528b2a39bded06a34091b0c054 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Thu, 12 Sep 2024 12:54:02 -0700 Subject: [PATCH 28/45] add more testing cases --- connection_configuration_test.go | 27 ++++++++++++++++++++++----- test_data/connections.toml | 5 ++++- test_data/token_file/token | 1 - 3 files changed, 26 insertions(+), 7 deletions(-) delete mode 100644 test_data/token_file/token diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 5969829f9..dd22a348e 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -9,6 +9,11 @@ import ( ) func TestLoadConnectionConfig_Default(t *testing.T) { + if !isWindows { + _, err := LoadConnectionConfig() + assertNotNilF(t, err, "The error should occur because you cannot change the file permission") + } + err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") @@ -51,19 +56,27 @@ func TestReadTokenValueWithTokenFilePath(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - err = os.Chmod("./test_data/token_file/token", 0600) + err = os.Chmod("./test_data/snowflake/session/token", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") os.Setenv("SNOWFLAKE_HOME", "./test_data") - os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "read-token") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "no-token-path") cfg, err := LoadConnectionConfig() assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) assertEqualF(t, cfg.Token, "mock_token123456") + + os.Setenv("SNOWFLAKE_HOME", "./test_data") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "read-token") + + cfg, err = LoadConnectionConfig() + assertNilF(t, err, "The error should not occur") + assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) + assertEqualF(t, cfg.Token, "mock_token123456") } -func TestLoadConnectionConfigWitNonExisitngDSN(t *testing.T) { +func TestLoadConnectionConfigWitNonExistingDSN(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") @@ -159,6 +172,10 @@ func TestParseToml(t *testing.T) { "tracing", "tmpDirPath", "clientConfigFile"}, values: []interface{}{"value"}, }, + { + testParams: []string{"privatekey"}, + values: []interface{}{generatePKCS8StringSupress(testPrivKey)}, + }, { testParams: []string{"port", "maxRetryCount", "clientTimeout", "jwtClientTimeout", "loginTimeout", "requestTimeout", "jwtTimeout", "externalBrowserTimeout"}, @@ -191,12 +208,12 @@ func TestParseTomlWithWrongValue(t *testing.T) { { testParams: []string{"user", "password", "host", "account", "warehouse", "database", "schema", "role", "region", "protocol", "passcode", "application", "token", "privateKey", - "tracing", "tmpDirPath", "clientConfigFile", "wrongParams"}, + "tracing", "tmpDirPath", "clientConfigFile", "wrongParams", "token_file_path"}, values: []interface{}{1, false}, }, { testParams: []string{"port", "maxRetryCount", "clientTimeout", "jwtClientTimeout", "loginTimeout", - "requestTimeout", "jwtTimeout", "externalBrowserTimeout"}, + "requestTimeout", "jwtTimeout", "externalBrowserTimeout", "authenticator"}, values: []interface{}{"wrong_value", false}, }, { diff --git a/test_data/connections.toml b/test_data/connections.toml index b3c52ed84..878e4fc23 100644 --- a/test_data/connections.toml +++ b/test_data/connections.toml @@ -36,4 +36,7 @@ token_file_path = '/Users/test/.snowflake/token' [read-token] authenticator = 'oauth' -token_file_path = './token_file/token' \ No newline at end of file +token_file_path = './snowflake/session/token' + +[no-token-path] +authenticator = 'oauth' \ No newline at end of file diff --git a/test_data/token_file/token b/test_data/token_file/token deleted file mode 100644 index 6db96c0a3..000000000 --- a/test_data/token_file/token +++ /dev/null @@ -1 +0,0 @@ -mock_token123456 \ No newline at end of file From 3b2e0864e337313edbc5c9f668ae1d744cc944f1 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Thu, 12 Sep 2024 12:54:31 -0700 Subject: [PATCH 29/45] moved the testing token file --- test_data/snowflake/session/token | 1 + 1 file changed, 1 insertion(+) create mode 100644 test_data/snowflake/session/token diff --git a/test_data/snowflake/session/token b/test_data/snowflake/session/token new file mode 100644 index 000000000..6db96c0a3 --- /dev/null +++ b/test_data/snowflake/session/token @@ -0,0 +1 @@ +mock_token123456 \ No newline at end of file From 9dbe2ab5ea62e9daa4e5cb7a0d8af800ffab85d7 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Thu, 12 Sep 2024 13:55:20 -0700 Subject: [PATCH 30/45] add copyright --- connection_configuration_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/connection_configuration_test.go b/connection_configuration_test.go index dd22a348e..7ee2df0dd 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -1,3 +1,5 @@ +// Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + package gosnowflake import ( From 9c8c15013a02b4b477aa653a7415f47a5c881853 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:37:29 -0700 Subject: [PATCH 31/45] fix --- connection_configuration.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 25128cd94..c3660015e 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -223,11 +223,10 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } block, decodeErr := base64.URLEncoding.DecodeString(v) if decodeErr != nil { - err = &SnowflakeError{ + return &SnowflakeError{ Number: ErrCodePrivateKeyParseError, Message: "Base64 decode failed", } - return err } cfg.PrivateKey, parsingErr = parsePKCS8PrivateKey(block) if parsingErr != nil { From c57441129340606f8d4b07a1c38db1742162baf2 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:56:33 -0700 Subject: [PATCH 32/45] updated --- connection_configuration.go | 52 ++++++++++++-------------------- connection_configuration_test.go | 35 +++++++++++++-------- test_data/connections.toml | 22 +++++++------- 3 files changed, 54 insertions(+), 55 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index c3660015e..491e2f9fd 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -37,14 +37,14 @@ func LoadConnectionConfig() (*Config, error) { if err != nil { return nil, err } - connectionName, exist := tomlInfo[dsn] + dsnMap, exist := tomlInfo[dsn] if !exist { return nil, &SnowflakeError{ Number: ErrCodeFailedToFindDSNInToml, Message: errMsgFailedToFindDSNInTomlFile, } } - connectionConfig, ok := connectionName.(map[string]interface{}) + connectionConfig, ok := dsnMap.(map[string]interface{}) if !ok { return nil, err } @@ -66,68 +66,57 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { for key, value := range connection { switch strings.ToLower(key) { case "user", "username": - cfg.User, parsingErr = parseString(value) - if parsingErr != nil { + if cfg.User, parsingErr = parseString(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "password": - cfg.Password, parsingErr = parseString(value) - if parsingErr != nil { + if cfg.Password, parsingErr = parseString(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "host": - cfg.Host, parsingErr = parseString(value) - if parsingErr != nil { + if cfg.Host, parsingErr = parseString(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "account": - cfg.Account, parsingErr = parseString(value) - if parsingErr != nil { + if cfg.Account, parsingErr = parseString(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "warehouse": - cfg.Warehouse, parsingErr = parseString(value) - if parsingErr != nil { + if cfg.Warehouse, parsingErr = parseString(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "database": - cfg.Database, parsingErr = parseString(value) - if parsingErr != nil { + if cfg.Database, parsingErr = parseString(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "schema": - cfg.Schema, parsingErr = parseString(value) - if parsingErr != nil { + if cfg.Schema, parsingErr = parseString(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "role": - cfg.Role, parsingErr = parseString(value) - if parsingErr != nil { + if cfg.Role, parsingErr = parseString(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "region": - cfg.Region, parsingErr = parseString(value) - if parsingErr != nil { + if cfg.Region, parsingErr = parseString(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "protocol": - cfg.Protocol, parsingErr = parseString(value) - if parsingErr != nil { + if cfg.Protocol, parsingErr = parseString(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } case "passcode": - cfg.Passcode, parsingErr = parseString(value) - if parsingErr != nil { + if cfg.Passcode, parsingErr = parseString(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } @@ -276,11 +265,10 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { return err } case "disablequerycontextcache": - if vv, parsingErr = parseBool(value); parsingErr != nil { + if cfg.DisableQueryContextCache, parsingErr = parseBool(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} return err } - cfg.DisableQueryContextCache = vv case "includeretryreason": if vv, parsingErr = parseBool(value); parsingErr != nil { err.MessageArgs = []interface{}{key, value} @@ -343,14 +331,14 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } func parseInt(i interface{}) (int, error) { - if _, ok := i.(string); !ok { - if _, ok := i.(int); !ok { + v, ok := i.(string) + if !ok { + num, ok := i.(int) + if !ok { return 0, errors.New("failed to parse the value to integer") } - num := i.(int) return num, nil } - v := i.(string) num, err := strconv.Atoi(v) if err != nil { @@ -362,10 +350,10 @@ func parseInt(i interface{}) (int, error) { func parseBool(i interface{}) (bool, error) { v, ok := i.(string) if !ok { - if _, ok := i.(bool); !ok { + vv, ok := i.(bool) + if !ok { return false, errors.New("failed to parse the value to boolean") } - vv := i.(bool) return vv, nil } vv, err := strconv.ParseBool(v) diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 7ee2df0dd..fc1acfad4 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -24,13 +24,13 @@ func TestLoadConnectionConfig_Default(t *testing.T) { cfg, err := LoadConnectionConfig() assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") - assertEqualF(t, cfg.User, "test_user") - assertEqualF(t, cfg.Password, "test_pass") - assertEqualF(t, cfg.Warehouse, "testw") - assertEqualF(t, cfg.Database, "test_db") - assertEqualF(t, cfg.Schema, "test_go") + assertEqualF(t, cfg.User, "test_default_user") + assertEqualF(t, cfg.Password, "test_default_pass") + assertEqualF(t, cfg.Warehouse, "testw_default") + assertEqualF(t, cfg.Database, "test_default_db") + assertEqualF(t, cfg.Schema, "test_default_go") assertEqualF(t, cfg.Protocol, "https") - assertEqualF(t, cfg.Port, 443) + assertEqualF(t, cfg.Port, 300) } func TestLoadConnectionConfig_OAuth(t *testing.T) { @@ -43,11 +43,11 @@ func TestLoadConnectionConfig_OAuth(t *testing.T) { cfg, err := LoadConnectionConfig() assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") - assertEqualF(t, cfg.User, "test_user") - assertEqualF(t, cfg.Password, "test_pass") - assertEqualF(t, cfg.Warehouse, "testw") - assertEqualF(t, cfg.Database, "test_db") - assertEqualF(t, cfg.Schema, "test_go") + assertEqualF(t, cfg.User, "test_oauth_user") + assertEqualF(t, cfg.Password, "test_oauth_pass") + assertEqualF(t, cfg.Warehouse, "testw_oauth") + assertEqualF(t, cfg.Database, "test_oauth_db") + assertEqualF(t, cfg.Schema, "test_oauth_go") assertEqualF(t, cfg.Protocol, "https") assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) assertEqualF(t, cfg.Token, "token_value") @@ -149,7 +149,7 @@ func TestParseDuration(t *testing.T) { i = 300 dur, err := parseDuration(i) assertNilF(t, err, "This value should be parsed") - assertEqualF(t, dur, time.Duration(5*int64(time.Minute))) + assertEqualF(t, dur, time.Duration(300*int64(time.Second))) i = "30" dur, err = parseDuration(i) @@ -257,4 +257,15 @@ func TestGetTomlFilePath(t *testing.T) { result, err := path.Abs(location) assertNilF(t, err, "should not have failed") assertEqualF(t, dir, result) + + result = "/user/somelocation/b" + if isWindows { + result = "c:\\user\\somelocation\\b" + } + location = "/user//somelocation///b" + dir, err = getTomlFilePath(location) + assertNilF(t, err, "should not have failed") + // result, err = path.Abs(location) + assertNilF(t, err, "should not have failed") + assertEqualF(t, dir, result) } diff --git a/test_data/connections.toml b/test_data/connections.toml index 878e4fc23..d6475b1ea 100644 --- a/test_data/connections.toml +++ b/test_data/connections.toml @@ -1,20 +1,20 @@ [default] account = 'snowdriverswarsaw.us-west-2.aws' -user = 'test_user' -password = 'test_pass' -warehouse = 'testw' -database = 'test_db' -schema = 'test_go' +user = 'test_default_user' +password = 'test_default_pass' +warehouse = 'testw_default' +database = 'test_default_db' +schema = 'test_default_go' protocol = 'https' -port = '443' +port = '300' [aws-oauth] account = 'snowdriverswarsaw.us-west-2.aws' -user = 'test_user' -password = 'test_pass' -warehouse = 'testw' -database = 'test_db' -schema = 'test_go' +user = 'test_oauth_user' +password = 'test_oauth_pass' +warehouse = 'testw_oauth' +database = 'test_oauth_db' +schema = 'test_oauth_go' protocol = 'https' port = '443' authenticator = 'oauth' From 13cbdbb12c41e9d040fc366818504d2ad0196774 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Wed, 25 Sep 2024 19:46:02 -0700 Subject: [PATCH 33/45] fixed --- connection_configuration.go | 425 +++++++++++-------------------- connection_configuration_test.go | 18 +- 2 files changed, 160 insertions(+), 283 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 491e2f9fd..86474adda 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -14,6 +14,11 @@ import ( toml "github.com/BurntSushi/toml" ) +const ( + connectionName = "SNOWFLAKE_DEFAULT_CONNECTION_NAME" + home = "SNOWFLAKE_HOME" +) + // LoadConnectionConfig returns connection configs loaded from the toml file. // By default, SNOWFLAKE_HOME(toml file path) is os.home/snowflake // and SNOWFLAKE_DEFAULT_CONNECTION_NAME(DSN) is 'default' @@ -22,8 +27,8 @@ func LoadConnectionConfig() (*Config, error) { Params: make(map[string]*string), Authenticator: AuthTypeSnowflake, // Default to snowflake } - dsn := getConnectionDSN(os.Getenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME")) - snowflakeConfigDir, err := getTomlFilePath(os.Getenv("SNOWFLAKE_HOME")) + dsn := getConnectionDSN(os.Getenv(connectionName)) + snowflakeConfigDir, err := getTomlFilePath(os.Getenv(home)) if err != nil { return nil, err } @@ -56,276 +61,145 @@ func LoadConnectionConfig() (*Config, error) { } func parseToml(cfg *Config, connection map[string]interface{}) error { - var v, tokenPath string - var parsingErr error - var vv bool - err := &SnowflakeError{ - Number: ErrCodeTomlFileParsingFailed, - Message: errMsgFailedToParseTomlFile, - } for key, value := range connection { - switch strings.ToLower(key) { - case "user", "username": - if cfg.User, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "password": - if cfg.Password, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "host": - if cfg.Host, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "account": - if cfg.Account, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "warehouse": - if cfg.Warehouse, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "database": - if cfg.Database, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "schema": - if cfg.Schema, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "role": - if cfg.Role, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "region": - if cfg.Region, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "protocol": - if cfg.Protocol, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "passcode": - if cfg.Passcode, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "port": - if cfg.Port, parsingErr = parseInt(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "passcodeinpassword": - if cfg.PasscodeInPassword, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "clienttimeout": - if cfg.ClientTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "jwtclienttimeout": - if cfg.JWTClientTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "logintimeout": - if cfg.LoginTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "requesttimeout": - if cfg.RequestTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "jwttimeout": - if cfg.JWTExpireTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "externalbrowsertimeout": - if cfg.ExternalBrowserTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "maxretrycount": - if cfg.MaxRetryCount, parsingErr = parseInt(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "application": - cfg.Application, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "authenticator": - v, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - parsingErr = determineAuthenticatorType(cfg, v) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "insecuremode": - if cfg.InsecureMode, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "ocspfailopen": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.OCSPFailOpen = OCSPFailOpenTrue - } else { - cfg.OCSPFailOpen = OCSPFailOpenFalse - } - - case "token": - cfg.Token, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "privatekey": - v, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - block, decodeErr := base64.URLEncoding.DecodeString(v) - if decodeErr != nil { - return &SnowflakeError{ - Number: ErrCodePrivateKeyParseError, - Message: "Base64 decode failed", - } - } - cfg.PrivateKey, parsingErr = parsePKCS8PrivateKey(block) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "validatedefaultparameters": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.ValidateDefaultParameters = ConfigBoolTrue - } else { - cfg.ValidateDefaultParameters = ConfigBoolFalse - } - case "clientrequestmfatoken": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.ClientRequestMfaToken = ConfigBoolTrue - } else { - cfg.ClientRequestMfaToken = ConfigBoolFalse - } - case "clientstoretemporarycredential": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.ClientStoreTemporaryCredential = ConfigBoolTrue - } else { - cfg.ClientStoreTemporaryCredential = ConfigBoolFalse - } - case "tracing": - cfg.Tracing, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "tmpdirpath": - cfg.TmpDirPath, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "disablequerycontextcache": - if cfg.DisableQueryContextCache, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "includeretryreason": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.IncludeRetryReason = ConfigBoolTrue - } else { - cfg.IncludeRetryReason = ConfigBoolFalse - } - case "clientconfigfile": - cfg.ClientConfigFile, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "disableconsolelogin": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.DisableConsoleLogin = ConfigBoolTrue - } else { - cfg.DisableConsoleLogin = ConfigBoolFalse - } - case "disablesamlurlcheck": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.DisableSamlURLCheck = ConfigBoolTrue - } else { - cfg.DisableSamlURLCheck = ConfigBoolFalse - } - case "token_file_path": - tokenPath, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - default: - param, parsingErr := parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - cfg.Params[urlDecodeIfNeeded(key)] = ¶m + if err := handleSingleParam(cfg, key, value); err != nil { + return err } } if shouldReadTokenFromFile(cfg) { + v, err := readToken("") + if err != nil { + return err + } + cfg.Token = v + } + return nil +} + +func handleSingleParam(cfg *Config, key string, value interface{}) error { + var parsingErr error + var v, tokenPath string + switch strings.ToLower(key) { + case "user", "username": + cfg.User, parsingErr = parseString(value) + case "password": + cfg.Password, parsingErr = parseString(value) + case "host": + cfg.Host, parsingErr = parseString(value) + case "account": + cfg.Account, parsingErr = parseString(value) + case "warehouse": + cfg.Warehouse, parsingErr = parseString(value) + case "database": + cfg.Database, parsingErr = parseString(value) + case "schema": + cfg.Schema, parsingErr = parseString(value) + case "role": + cfg.Role, parsingErr = parseString(value) + case "region": + cfg.Region, parsingErr = parseString(value) + case "protocol": + cfg.Protocol, parsingErr = parseString(value) + case "passcode": + cfg.Passcode, parsingErr = parseString(value) + case "port": + cfg.Port, parsingErr = parseInt(value) + case "passcodeinpassword": + cfg.PasscodeInPassword, parsingErr = parseBool(value) + case "clienttimeout": + cfg.ClientTimeout, parsingErr = parseDuration(value) + case "jwtclienttimeout": + cfg.JWTClientTimeout, parsingErr = parseDuration(value) + case "logintimeout": + cfg.LoginTimeout, parsingErr = parseDuration(value) + case "requesttimeout": + cfg.RequestTimeout, parsingErr = parseDuration(value) + case "jwttimeout": + cfg.JWTExpireTimeout, parsingErr = parseDuration(value) + case "externalbrowsertimeout": + cfg.ExternalBrowserTimeout, parsingErr = parseDuration(value) + case "maxretrycount": + cfg.MaxRetryCount, parsingErr = parseInt(value) + case "application": + cfg.Application, parsingErr = parseString(value) + case "authenticator": + v, parsingErr = parseString(value) + if err := checkParsingError(parsingErr, key, value); err != nil { + return err + } + parsingErr = determineAuthenticatorType(cfg, v) + case "insecuremode": + cfg.InsecureMode, parsingErr = parseBool(value) + case "ocspfailopen": + var vv ConfigBool + vv, parsingErr = parseConfigBool(value) + if err := checkParsingError(parsingErr, key, value); err != nil { + return err + } + cfg.OCSPFailOpen = OCSPFailOpenMode(vv) + case "token": + cfg.Token, parsingErr = parseString(value) + case "privatekey": + v, parsingErr = parseString(value) + if err := checkParsingError(parsingErr, key, value); err != nil { + return err + } + block, decodeErr := base64.URLEncoding.DecodeString(v) + if decodeErr != nil { + return &SnowflakeError{ + Number: ErrCodePrivateKeyParseError, + Message: "Base64 decode failed", + } + } + cfg.PrivateKey, parsingErr = parsePKCS8PrivateKey(block) + case "validatedefaultparameters": + cfg.ValidateDefaultParameters, parsingErr = parseConfigBool(value) + case "clientrequestmfatoken": + cfg.ClientRequestMfaToken, parsingErr = parseConfigBool(value) + case "clientstoretemporarycredential": + cfg.ClientStoreTemporaryCredential, parsingErr = parseConfigBool(value) + case "tracing": + cfg.Tracing, parsingErr = parseString(value) + case "tmpdirpath": + cfg.TmpDirPath, parsingErr = parseString(value) + case "disablequerycontextcache": + cfg.DisableQueryContextCache, parsingErr = parseBool(value) + case "includeretryreason": + cfg.IncludeRetryReason, parsingErr = parseConfigBool(value) + case "clientconfigfile": + cfg.ClientConfigFile, parsingErr = parseString(value) + case "disableconsolelogin": + cfg.DisableConsoleLogin, parsingErr = parseConfigBool(value) + case "disablesamlurlcheck": + cfg.DisableSamlURLCheck, parsingErr = parseConfigBool(value) + case "token_file_path": + tokenPath, parsingErr = parseString(value) + if err := checkParsingError(parsingErr, key, value); err != nil { + return err + } v, err := readToken(tokenPath) if err != nil { return err } cfg.Token = v + default: + param, parsingErr := parseString(value) + if err := checkParsingError(parsingErr, key, value); err != nil { + return err + } + cfg.Params[urlDecodeIfNeeded(key)] = ¶m + } + return checkParsingError(parsingErr, key, value) +} + +func checkParsingError(parsingErr error, key string, value interface{}) error { + if parsingErr != nil { + err := &SnowflakeError{ + Number: ErrCodeTomlFileParsingFailed, + Message: errMsgFailedToParseTomlFile, + MessageArgs: []interface{}{key, value}, + } + return err } return nil } @@ -363,6 +237,17 @@ func parseBool(i interface{}) (bool, error) { return vv, nil } +func parseConfigBool(i interface{}) (ConfigBool, error) { + vv, err := parseBool(i) + if err != nil { + return ConfigBoolFalse, err + } + if vv { + return ConfigBoolTrue, nil + } + return ConfigBoolFalse, nil +} + func parseDuration(i interface{}) (time.Duration, error) { v, ok := i.(string) if !ok { @@ -373,11 +258,7 @@ func parseDuration(i interface{}) (time.Duration, error) { t := int64(num) return time.Duration(t * int64(time.Second)), nil } - t, err := strconv.ParseInt(v, 10, 64) - if err != nil { - return time.Duration(0), err - } - return time.Duration(t * int64(time.Second)), nil + return parseTimeout(v) } func readToken(tokenPath string) (string, error) { @@ -385,7 +266,7 @@ func readToken(tokenPath string) (string, error) { tokenPath = "./snowflake/session/token" } if !path.IsAbs(tokenPath) { - snowflakeConfigDir, err := getTomlFilePath(os.Getenv("SNOWFLAKE_HOME")) + snowflakeConfigDir, err := getTomlFilePath(os.Getenv(home)) if err != nil { return "", err } @@ -411,11 +292,7 @@ func parseString(i interface{}) (string, error) { } func getTomlFilePath(filePath string) (string, error) { - if len(filePath) != 0 { - if path.IsAbs(filePath) { - return filePath, nil - } - } else { + if len(filePath) == 0 { homeDir, err := os.UserHomeDir() if err != nil { return "", err diff --git a/connection_configuration_test.go b/connection_configuration_test.go index fc1acfad4..f68f42b08 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -258,14 +258,14 @@ func TestGetTomlFilePath(t *testing.T) { assertNilF(t, err, "should not have failed") assertEqualF(t, dir, result) - result = "/user/somelocation/b" - if isWindows { - result = "c:\\user\\somelocation\\b" + //Absolute path for windows can be varied depend on which disk the driver is located. + // As a result, this test is available on non-Window machines. + if !isWindows { + result = "/user/somelocation/b" + location = "/user//somelocation///b" + dir, err = getTomlFilePath(location) + assertNilF(t, err, "should not have failed") + assertEqualF(t, dir, result) } - location = "/user//somelocation///b" - dir, err = getTomlFilePath(location) - assertNilF(t, err, "should not have failed") - // result, err = path.Abs(location) - assertNilF(t, err, "should not have failed") - assertEqualF(t, dir, result) + } From cf175f11bcf0e67347821b316da1747f45cc1037 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Thu, 26 Sep 2024 17:25:46 -0700 Subject: [PATCH 34/45] fix --- connection_configuration.go | 29 ++++++++++------------------- connection_configuration_test.go | 26 +++++++++++++------------- 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 86474adda..fbcc4a0ed 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -15,20 +15,20 @@ import ( ) const ( - connectionName = "SNOWFLAKE_DEFAULT_CONNECTION_NAME" - home = "SNOWFLAKE_HOME" + snowflake_connectionName = "SNOWFLAKE_DEFAULT_CONNECTION_NAME" + snowflake_home = "SNOWFLAKE_HOME" ) // LoadConnectionConfig returns connection configs loaded from the toml file. -// By default, SNOWFLAKE_HOME(toml file path) is os.home/snowflake +// By default, SNOWFLAKE_HOME(toml file path) is os.snowflake_home/.snowflake // and SNOWFLAKE_DEFAULT_CONNECTION_NAME(DSN) is 'default' func LoadConnectionConfig() (*Config, error) { cfg := &Config{ Params: make(map[string]*string), Authenticator: AuthTypeSnowflake, // Default to snowflake } - dsn := getConnectionDSN(os.Getenv(connectionName)) - snowflakeConfigDir, err := getTomlFilePath(os.Getenv(home)) + dsn := getConnectionDSN(os.Getenv(snowflake_connectionName)) + snowflakeConfigDir, err := getTomlFilePath(os.Getenv(snowflake_home)) if err != nil { return nil, err } @@ -213,12 +213,7 @@ func parseInt(i interface{}) (int, error) { } return num, nil } - num, err := strconv.Atoi(v) - - if err != nil { - return num, err - } - return num, nil + return strconv.Atoi(v) } func parseBool(i interface{}) (bool, error) { @@ -230,11 +225,7 @@ func parseBool(i interface{}) (bool, error) { } return vv, nil } - vv, err := strconv.ParseBool(v) - if err != nil { - return false, errors.New("failed to parse the value to boolean") - } - return vv, nil + return strconv.ParseBool(v) } func parseConfigBool(i interface{}) (ConfigBool, error) { @@ -266,7 +257,7 @@ func readToken(tokenPath string) (string, error) { tokenPath = "./snowflake/session/token" } if !path.IsAbs(tokenPath) { - snowflakeConfigDir, err := getTomlFilePath(os.Getenv(home)) + snowflakeConfigDir, err := getTomlFilePath(os.Getenv(snowflake_home)) if err != nil { return "", err } @@ -297,7 +288,7 @@ func getTomlFilePath(filePath string) (string, error) { if err != nil { return "", err } - filePath = path.Join(homeDir, "snowflake") + filePath = path.Join(homeDir, ".snowflake") } absDir, err := path.Abs(filePath) if err != nil { @@ -322,7 +313,7 @@ func validateFilePermission(filePath string) error { return err } if permission := fileInfo.Mode().Perm(); permission != os.FileMode(0600) { - return errors.New("your access to the file was denied") + return errors.New("file permissions different than read/write for user") } return nil } diff --git a/connection_configuration_test.go b/connection_configuration_test.go index f68f42b08..1dc130ca3 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -19,7 +19,7 @@ func TestLoadConnectionConfig_Default(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - os.Setenv("SNOWFLAKE_HOME", "./test_data") + os.Setenv(snowflake_home, "./test_data") cfg, err := LoadConnectionConfig() assertNilF(t, err, "The error should not occur") @@ -37,8 +37,8 @@ func TestLoadConnectionConfig_OAuth(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - os.Setenv("SNOWFLAKE_HOME", "./test_data") - os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth") + os.Setenv(snowflake_home, "./test_data") + os.Setenv(snowflake_connectionName, "aws-oauth") cfg, err := LoadConnectionConfig() assertNilF(t, err, "The error should not occur") @@ -61,16 +61,16 @@ func TestReadTokenValueWithTokenFilePath(t *testing.T) { err = os.Chmod("./test_data/snowflake/session/token", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - os.Setenv("SNOWFLAKE_HOME", "./test_data") - os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "no-token-path") + os.Setenv(snowflake_home, "./test_data") + os.Setenv(snowflake_connectionName, "no-token-path") cfg, err := LoadConnectionConfig() assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) assertEqualF(t, cfg.Token, "mock_token123456") - os.Setenv("SNOWFLAKE_HOME", "./test_data") - os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "read-token") + os.Setenv(snowflake_home, "./test_data") + os.Setenv(snowflake_connectionName, "read-token") cfg, err = LoadConnectionConfig() assertNilF(t, err, "The error should not occur") @@ -82,8 +82,8 @@ func TestLoadConnectionConfigWitNonExistingDSN(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - os.Setenv("SNOWFLAKE_HOME", "./test_data") - os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "unavailableDSN") + os.Setenv(snowflake_home, "./test_data") + os.Setenv(snowflake_connectionName, "unavailableDSN") _, err = LoadConnectionConfig() assertNotNilF(t, err, "The error should occur") @@ -97,8 +97,8 @@ func TestLoadConnectionConfigWithTokenFileNotExist(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - os.Setenv("SNOWFLAKE_HOME", "./test_data") - os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth-file") + os.Setenv(snowflake_home, "./test_data") + os.Setenv(snowflake_connectionName, "aws-oauth-file") _, err = LoadConnectionConfig() assertNotNilF(t, err, "The error should occur") @@ -248,8 +248,8 @@ func TestGetTomlFilePath(t *testing.T) { dir, err := getTomlFilePath("") assertNilF(t, err, "should not have failed") homeDir, err := os.UserHomeDir() - assertNilF(t, err, "The connection cannot find the user home directory") - assertEqualF(t, dir, path.Join(homeDir, "snowflake")) + assertNilF(t, err, "The connection cannot find the user snowflake_home directory") + assertEqualF(t, dir, path.Join(homeDir, ".snowflake")) location := "../user//somelocation///b" dir, err = getTomlFilePath(location) From cdc9031fff06b3548403bf6ec465292afc62ecae Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Thu, 26 Sep 2024 17:45:42 -0700 Subject: [PATCH 35/45] lint fix --- connection_configuration.go | 12 ++++++------ connection_configuration_test.go | 24 ++++++++++++------------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index fbcc4a0ed..0b624aad4 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -15,20 +15,20 @@ import ( ) const ( - snowflake_connectionName = "SNOWFLAKE_DEFAULT_CONNECTION_NAME" - snowflake_home = "SNOWFLAKE_HOME" + snowflakeConnectionName = "SNOWFLAKE_DEFAULT_CONNECTION_NAME" + snowflakeHome = "SNOWFLAKE_HOME" ) // LoadConnectionConfig returns connection configs loaded from the toml file. -// By default, SNOWFLAKE_HOME(toml file path) is os.snowflake_home/.snowflake +// By default, SNOWFLAKE_HOME(toml file path) is os.snowflakeHome/.snowflake // and SNOWFLAKE_DEFAULT_CONNECTION_NAME(DSN) is 'default' func LoadConnectionConfig() (*Config, error) { cfg := &Config{ Params: make(map[string]*string), Authenticator: AuthTypeSnowflake, // Default to snowflake } - dsn := getConnectionDSN(os.Getenv(snowflake_connectionName)) - snowflakeConfigDir, err := getTomlFilePath(os.Getenv(snowflake_home)) + dsn := getConnectionDSN(os.Getenv(snowflakeConnectionName)) + snowflakeConfigDir, err := getTomlFilePath(os.Getenv(snowflakeHome)) if err != nil { return nil, err } @@ -257,7 +257,7 @@ func readToken(tokenPath string) (string, error) { tokenPath = "./snowflake/session/token" } if !path.IsAbs(tokenPath) { - snowflakeConfigDir, err := getTomlFilePath(os.Getenv(snowflake_home)) + snowflakeConfigDir, err := getTomlFilePath(os.Getenv(snowflakeHome)) if err != nil { return "", err } diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 1dc130ca3..abd6661b4 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -19,7 +19,7 @@ func TestLoadConnectionConfig_Default(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - os.Setenv(snowflake_home, "./test_data") + os.Setenv(snowflakeHome, "./test_data") cfg, err := LoadConnectionConfig() assertNilF(t, err, "The error should not occur") @@ -37,8 +37,8 @@ func TestLoadConnectionConfig_OAuth(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - os.Setenv(snowflake_home, "./test_data") - os.Setenv(snowflake_connectionName, "aws-oauth") + os.Setenv(snowflakeHome, "./test_data") + os.Setenv(snowflakeConnectionName, "aws-oauth") cfg, err := LoadConnectionConfig() assertNilF(t, err, "The error should not occur") @@ -61,16 +61,16 @@ func TestReadTokenValueWithTokenFilePath(t *testing.T) { err = os.Chmod("./test_data/snowflake/session/token", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - os.Setenv(snowflake_home, "./test_data") - os.Setenv(snowflake_connectionName, "no-token-path") + os.Setenv(snowflakeHome, "./test_data") + os.Setenv(snowflakeConnectionName, "no-token-path") cfg, err := LoadConnectionConfig() assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) assertEqualF(t, cfg.Token, "mock_token123456") - os.Setenv(snowflake_home, "./test_data") - os.Setenv(snowflake_connectionName, "read-token") + os.Setenv(snowflakeHome, "./test_data") + os.Setenv(snowflakeConnectionName, "read-token") cfg, err = LoadConnectionConfig() assertNilF(t, err, "The error should not occur") @@ -82,8 +82,8 @@ func TestLoadConnectionConfigWitNonExistingDSN(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - os.Setenv(snowflake_home, "./test_data") - os.Setenv(snowflake_connectionName, "unavailableDSN") + os.Setenv(snowflakeHome, "./test_data") + os.Setenv(snowflakeConnectionName, "unavailableDSN") _, err = LoadConnectionConfig() assertNotNilF(t, err, "The error should occur") @@ -97,8 +97,8 @@ func TestLoadConnectionConfigWithTokenFileNotExist(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - os.Setenv(snowflake_home, "./test_data") - os.Setenv(snowflake_connectionName, "aws-oauth-file") + os.Setenv(snowflakeHome, "./test_data") + os.Setenv(snowflakeConnectionName, "aws-oauth-file") _, err = LoadConnectionConfig() assertNotNilF(t, err, "The error should occur") @@ -248,7 +248,7 @@ func TestGetTomlFilePath(t *testing.T) { dir, err := getTomlFilePath("") assertNilF(t, err, "should not have failed") homeDir, err := os.UserHomeDir() - assertNilF(t, err, "The connection cannot find the user snowflake_home directory") + assertNilF(t, err, "The connection cannot find the user home directory") assertEqualF(t, dir, path.Join(homeDir, ".snowflake")) location := "../user//somelocation///b" From 23a254d8066c01c477ddbffd7c7c3652dddb536f Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Tue, 1 Oct 2024 15:07:07 -0700 Subject: [PATCH 36/45] fix --- connection_configuration.go | 7 ++++--- connection_configuration_test.go | 10 +--------- test_data/connections.toml | 2 +- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 0b624aad4..96029fa29 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -17,6 +17,7 @@ import ( const ( snowflakeConnectionName = "SNOWFLAKE_DEFAULT_CONNECTION_NAME" snowflakeHome = "SNOWFLAKE_HOME" + defaultTokenPath = "/snowflake/session/token" ) // LoadConnectionConfig returns connection configs loaded from the toml file. @@ -254,14 +255,14 @@ func parseDuration(i interface{}) (time.Duration, error) { func readToken(tokenPath string) (string, error) { if tokenPath == "" { - tokenPath = "./snowflake/session/token" + tokenPath = defaultTokenPath } if !path.IsAbs(tokenPath) { - snowflakeConfigDir, err := getTomlFilePath(os.Getenv(snowflakeHome)) + var err error + tokenPath, err = path.Abs(tokenPath) if err != nil { return "", err } - tokenPath = path.Join(snowflakeConfigDir, tokenPath) } err := validateFilePermission(tokenPath) if err != nil { diff --git a/connection_configuration_test.go b/connection_configuration_test.go index abd6661b4..dafb595e6 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -61,18 +61,10 @@ func TestReadTokenValueWithTokenFilePath(t *testing.T) { err = os.Chmod("./test_data/snowflake/session/token", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - os.Setenv(snowflakeHome, "./test_data") - os.Setenv(snowflakeConnectionName, "no-token-path") - - cfg, err := LoadConnectionConfig() - assertNilF(t, err, "The error should not occur") - assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) - assertEqualF(t, cfg.Token, "mock_token123456") - os.Setenv(snowflakeHome, "./test_data") os.Setenv(snowflakeConnectionName, "read-token") - cfg, err = LoadConnectionConfig() + cfg, err := LoadConnectionConfig() assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) assertEqualF(t, cfg.Token, "mock_token123456") diff --git a/test_data/connections.toml b/test_data/connections.toml index d6475b1ea..95a9fcba5 100644 --- a/test_data/connections.toml +++ b/test_data/connections.toml @@ -36,7 +36,7 @@ token_file_path = '/Users/test/.snowflake/token' [read-token] authenticator = 'oauth' -token_file_path = './snowflake/session/token' +token_file_path = './test_data/snowflake/session/token' [no-token-path] authenticator = 'oauth' \ No newline at end of file From 8f4033851b391fb98552721e4a98f4f68835e824 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Tue, 1 Oct 2024 17:01:03 -0700 Subject: [PATCH 37/45] add permission testing --- connection_configuration_test.go | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/connection_configuration_test.go b/connection_configuration_test.go index dafb595e6..a104c2e31 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -10,12 +10,41 @@ import ( "time" ) -func TestLoadConnectionConfig_Default(t *testing.T) { +func TestTokenFilePermission(t *testing.T) { if !isWindows { _, err := LoadConnectionConfig() assertNotNilF(t, err, "The error should occur because you cannot change the file permission") + + _, err = readToken("./test_data/snowflake/session") + assertNotNilF(t, err, "The error should occur because you cannot change the file permission") + + err = os.Chmod("./test_data/connections.toml", 0666) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + + err = os.Chmod("./test_data/snowflake/session/token", 0666) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + + _, err = LoadConnectionConfig() + assertNotNilF(t, err, "The error should occur because you cannot change the file permission") + + _, err = readToken("./test_data/snowflake/session") + assertNotNilF(t, err, "The error should occur because you cannot change the file permission") + + err = os.Chmod("./test_data/connections.toml", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + + err = os.Chmod("./test_data/snowflake/session/token", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + + _, err = LoadConnectionConfig() + assertNilF(t, err, "The error should occur because you cannot change the file permission") + + _, err = readToken("./test_data/snowflake/session") + assertNilF(t, err, "The error should occur because you cannot change the file permission") } +} +func TestLoadConnectionConfig_Default(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") From 608fb0cdfe41d241af210698dc008ceaeb02bd30 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Tue, 1 Oct 2024 17:03:27 -0700 Subject: [PATCH 38/45] fix testing error message --- connection_configuration_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/connection_configuration_test.go b/connection_configuration_test.go index a104c2e31..a1dfdae92 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -13,22 +13,22 @@ import ( func TestTokenFilePermission(t *testing.T) { if !isWindows { _, err := LoadConnectionConfig() - assertNotNilF(t, err, "The error should occur because you cannot change the file permission") + assertNotNilF(t, err, "The error should occur because you the permission is not 0600") _, err = readToken("./test_data/snowflake/session") - assertNotNilF(t, err, "The error should occur because you cannot change the file permission") + assertNotNilF(t, err, "The error should occur because you the permission is not 0600") err = os.Chmod("./test_data/connections.toml", 0666) - assertNilF(t, err, "The error occurred because you cannot change the file permission") + assertNotNilF(t, err, "The error occurred because you cannot change the file permission") err = os.Chmod("./test_data/snowflake/session/token", 0666) - assertNilF(t, err, "The error occurred because you cannot change the file permission") + assertNotNilF(t, err, "TThe error occurred because you cannot change the file permission") _, err = LoadConnectionConfig() - assertNotNilF(t, err, "The error should occur because you cannot change the file permission") + assertNotNilF(t, err, "The error should occur because you the permission is not 0600") _, err = readToken("./test_data/snowflake/session") - assertNotNilF(t, err, "The error should occur because you cannot change the file permission") + assertNotNilF(t, err, "The error should occur because you the permission is not 0600") err = os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") @@ -37,10 +37,10 @@ func TestTokenFilePermission(t *testing.T) { assertNilF(t, err, "The error occurred because you cannot change the file permission") _, err = LoadConnectionConfig() - assertNilF(t, err, "The error should occur because you cannot change the file permission") + assertNilF(t, err, "The error should occur because you the permission is not 0600") _, err = readToken("./test_data/snowflake/session") - assertNilF(t, err, "The error should occur because you cannot change the file permission") + assertNilF(t, err, "The error should occur because you the permission is not 0600") } } From 418705dd7833d334d004b465b68c5dc7fff6ddb7 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Tue, 1 Oct 2024 17:15:28 -0700 Subject: [PATCH 39/45] removed unnecessary DSN --- test_data/connections.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test_data/connections.toml b/test_data/connections.toml index 95a9fcba5..fb4d8e4ad 100644 --- a/test_data/connections.toml +++ b/test_data/connections.toml @@ -36,7 +36,4 @@ token_file_path = '/Users/test/.snowflake/token' [read-token] authenticator = 'oauth' -token_file_path = './test_data/snowflake/session/token' - -[no-token-path] -authenticator = 'oauth' \ No newline at end of file +token_file_path = './test_data/snowflake/session/token' \ No newline at end of file From 6042d914ac9376ae693b4827fb560bb668d41c10 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Tue, 1 Oct 2024 18:00:45 -0700 Subject: [PATCH 40/45] fixed testing --- connection_configuration_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/connection_configuration_test.go b/connection_configuration_test.go index a1dfdae92..5b59e7bcb 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -12,6 +12,7 @@ import ( func TestTokenFilePermission(t *testing.T) { if !isWindows { + os.Setenv(snowflakeHome, "./test_data") _, err := LoadConnectionConfig() assertNotNilF(t, err, "The error should occur because you the permission is not 0600") From a4d7d517802d67f964eb16846a8510270e4e1f6f Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Wed, 2 Oct 2024 10:28:12 -0700 Subject: [PATCH 41/45] fix error --- connection_configuration_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 5b59e7bcb..f95f08823 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -14,22 +14,22 @@ func TestTokenFilePermission(t *testing.T) { if !isWindows { os.Setenv(snowflakeHome, "./test_data") _, err := LoadConnectionConfig() - assertNotNilF(t, err, "The error should occur because you the permission is not 0600") + assertNotNilF(t, err, "The error should occur because the permission is not 0600") _, err = readToken("./test_data/snowflake/session") - assertNotNilF(t, err, "The error should occur because you the permission is not 0600") + assertNotNilF(t, err, "The error should occur because the permission is not 0600") err = os.Chmod("./test_data/connections.toml", 0666) - assertNotNilF(t, err, "The error occurred because you cannot change the file permission") + assertNilF(t, err, "The error occurred because you cannot change the file permission") err = os.Chmod("./test_data/snowflake/session/token", 0666) - assertNotNilF(t, err, "TThe error occurred because you cannot change the file permission") + assertNilF(t, err, "TThe error occurred because you cannot change the file permission") _, err = LoadConnectionConfig() - assertNotNilF(t, err, "The error should occur because you the permission is not 0600") + assertNotNilF(t, err, "The error should occur because the permission is not 0600") _, err = readToken("./test_data/snowflake/session") - assertNotNilF(t, err, "The error should occur because you the permission is not 0600") + assertNotNilF(t, err, "The error should occur because the permission is not 0600") err = os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") @@ -38,10 +38,10 @@ func TestTokenFilePermission(t *testing.T) { assertNilF(t, err, "The error occurred because you cannot change the file permission") _, err = LoadConnectionConfig() - assertNilF(t, err, "The error should occur because you the permission is not 0600") + assertNilF(t, err, "The error occurred because the permission is not 0600") _, err = readToken("./test_data/snowflake/session") - assertNilF(t, err, "The error should occur because you the permission is not 0600") + assertNilF(t, err, "The error occurred because the permission is not 0600") } } From 107d01885af7034594950eaba514e6a568f26738 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Wed, 2 Oct 2024 20:34:45 -0700 Subject: [PATCH 42/45] fix --- cmd/tomlfileconnection/tomlfileconnection.go | 20 +++++--------------- connection_configuration_test.go | 2 +- driver.go | 8 +++++++- 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/cmd/tomlfileconnection/tomlfileconnection.go b/cmd/tomlfileconnection/tomlfileconnection.go index f82b0b629..91b91d9dc 100644 --- a/cmd/tomlfileconnection/tomlfileconnection.go +++ b/cmd/tomlfileconnection/tomlfileconnection.go @@ -7,9 +7,8 @@ import ( "flag" "fmt" "log" - "os" - sf "github.com/snowflakedb/gosnowflake" + _ "github.com/snowflakedb/gosnowflake" ) func main() { @@ -17,21 +16,12 @@ func main() { flag.Parse() } - os.Setenv("SNOWFLAKE_HOME", "") - os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "") + // os.Setenv("SNOWFLAKE_HOME", "") + // os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "") - cfg, err := sf.LoadConnectionConfig() + db, err := sql.Open("snowflake", "autoConfig") if err != nil { - log.Fatalf("failed to create Config, err: %v", err) - } - dsn, err := sf.DSN(cfg) - if err != nil { - log.Fatalf("failed to create DSN from Config: %v, err: %v", cfg, err) - } - - db, err := sql.Open("snowflake", dsn) - if err != nil { - log.Fatalf("failed to connect. %v, err: %v", dsn, err) + log.Fatalf("failed to connect. %v,", err) } defer db.Close() query := "SELECT 1" diff --git a/connection_configuration_test.go b/connection_configuration_test.go index f95f08823..90a645735 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -40,7 +40,7 @@ func TestTokenFilePermission(t *testing.T) { _, err = LoadConnectionConfig() assertNilF(t, err, "The error occurred because the permission is not 0600") - _, err = readToken("./test_data/snowflake/session") + _, err = readToken("./test_data/snowflake/session/token") assertNilF(t, err, "The error occurred because the permission is not 0600") } } diff --git a/driver.go b/driver.go index 9679118bb..8d89a1b3b 100644 --- a/driver.go +++ b/driver.go @@ -18,9 +18,15 @@ type SnowflakeDriver struct{} // Open creates a new connection. func (d SnowflakeDriver) Open(dsn string) (driver.Conn, error) { + var cfg *Config + var err error logger.Info("Open") ctx := context.Background() - cfg, err := ParseDSN(dsn) + if dsn == "autoConfig" { + cfg, err = LoadConnectionConfig() + } else { + cfg, err = ParseDSN(dsn) + } if err != nil { return nil, err } From 4a2d8191038b4b563f941fbc100acf81104d6b81 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Wed, 2 Oct 2024 21:50:01 -0700 Subject: [PATCH 43/45] added error code for the validation and fixed the bug for cmd tomlconnection.go --- connection_configuration.go | 13 +++++++++++-- connection_configuration_test.go | 32 ++++++++++++++++++++++---------- driver.go | 2 +- errors.go | 3 +++ 4 files changed, 37 insertions(+), 13 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 96029fa29..534b567ba 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -23,7 +23,7 @@ const ( // LoadConnectionConfig returns connection configs loaded from the toml file. // By default, SNOWFLAKE_HOME(toml file path) is os.snowflakeHome/.snowflake // and SNOWFLAKE_DEFAULT_CONNECTION_NAME(DSN) is 'default' -func LoadConnectionConfig() (*Config, error) { +func loadConnectionConfig() (*Config, error) { cfg := &Config{ Params: make(map[string]*string), Authenticator: AuthTypeSnowflake, // Default to snowflake @@ -74,6 +74,11 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } cfg.Token = v } + + err := fillMissingConfigParameters(cfg) + if err != nil { + return err + } return nil } @@ -314,7 +319,11 @@ func validateFilePermission(filePath string) error { return err } if permission := fileInfo.Mode().Perm(); permission != os.FileMode(0600) { - return errors.New("file permissions different than read/write for user") + return err := &SnowflakeError{ + Number: ErrCodeInvalidFilePermission, + Message: errMsgInvalidPermissionToTomlFile, + MessageArgs: []interface{}{permission}, + } } return nil } diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 90a645735..6a21c850f 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -13,11 +13,17 @@ import ( func TestTokenFilePermission(t *testing.T) { if !isWindows { os.Setenv(snowflakeHome, "./test_data") - _, err := LoadConnectionConfig() + _, err := loadConnectionConfig() assertNotNilF(t, err, "The error should occur because the permission is not 0600") + driverErr, ok := err.(*SnowflakeError) + assertTrueF(t, ok, "This should be a Snowflake Error") + assertEqualF(t, driverErr.Number, ErrCodeInvalidFilePermission) _, err = readToken("./test_data/snowflake/session") assertNotNilF(t, err, "The error should occur because the permission is not 0600") + driverErr, ok = err.(*SnowflakeError) + assertTrueF(t, ok, "This should be a Snowflake Error") + assertEqualF(t, driverErr.Number, ErrCodeInvalidFilePermission) err = os.Chmod("./test_data/connections.toml", 0666) assertNilF(t, err, "The error occurred because you cannot change the file permission") @@ -25,11 +31,17 @@ func TestTokenFilePermission(t *testing.T) { err = os.Chmod("./test_data/snowflake/session/token", 0666) assertNilF(t, err, "TThe error occurred because you cannot change the file permission") - _, err = LoadConnectionConfig() + _, err = loadConnectionConfig() assertNotNilF(t, err, "The error should occur because the permission is not 0600") + driverErr, ok = err.(*SnowflakeError) + assertTrueF(t, ok, "This should be a Snowflake Error") + assertEqualF(t, driverErr.Number, ErrCodeInvalidFilePermission) _, err = readToken("./test_data/snowflake/session") assertNotNilF(t, err, "The error should occur because the permission is not 0600") + driverErr, ok = err.(*SnowflakeError) + assertTrueF(t, ok, "This should be a Snowflake Error") + assertEqualF(t, driverErr.Number, ErrCodeInvalidFilePermission) err = os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") @@ -37,7 +49,7 @@ func TestTokenFilePermission(t *testing.T) { err = os.Chmod("./test_data/snowflake/session/token", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") - _, err = LoadConnectionConfig() + _, err = loadConnectionConfig() assertNilF(t, err, "The error occurred because the permission is not 0600") _, err = readToken("./test_data/snowflake/session/token") @@ -45,13 +57,13 @@ func TestTokenFilePermission(t *testing.T) { } } -func TestLoadConnectionConfig_Default(t *testing.T) { +func TestloadConnectionConfig_Default(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") os.Setenv(snowflakeHome, "./test_data") - cfg, err := LoadConnectionConfig() + cfg, err := loadConnectionConfig() assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") assertEqualF(t, cfg.User, "test_default_user") @@ -63,14 +75,14 @@ func TestLoadConnectionConfig_Default(t *testing.T) { assertEqualF(t, cfg.Port, 300) } -func TestLoadConnectionConfig_OAuth(t *testing.T) { +func TestloadConnectionConfig_OAuth(t *testing.T) { err := os.Chmod("./test_data/connections.toml", 0600) assertNilF(t, err, "The error occurred because you cannot change the file permission") os.Setenv(snowflakeHome, "./test_data") os.Setenv(snowflakeConnectionName, "aws-oauth") - cfg, err := LoadConnectionConfig() + cfg, err := loadConnectionConfig() assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") assertEqualF(t, cfg.User, "test_oauth_user") @@ -94,7 +106,7 @@ func TestReadTokenValueWithTokenFilePath(t *testing.T) { os.Setenv(snowflakeHome, "./test_data") os.Setenv(snowflakeConnectionName, "read-token") - cfg, err := LoadConnectionConfig() + cfg, err := loadConnectionConfig() assertNilF(t, err, "The error should not occur") assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) assertEqualF(t, cfg.Token, "mock_token123456") @@ -107,7 +119,7 @@ func TestLoadConnectionConfigWitNonExistingDSN(t *testing.T) { os.Setenv(snowflakeHome, "./test_data") os.Setenv(snowflakeConnectionName, "unavailableDSN") - _, err = LoadConnectionConfig() + _, err = loadConnectionConfig() assertNotNilF(t, err, "The error should occur") driverErr, ok := err.(*SnowflakeError) @@ -122,7 +134,7 @@ func TestLoadConnectionConfigWithTokenFileNotExist(t *testing.T) { os.Setenv(snowflakeHome, "./test_data") os.Setenv(snowflakeConnectionName, "aws-oauth-file") - _, err = LoadConnectionConfig() + _, err = loadConnectionConfig() assertNotNilF(t, err, "The error should occur") _, ok := err.(*(fs.PathError)) diff --git a/driver.go b/driver.go index 8d89a1b3b..3c8c7b291 100644 --- a/driver.go +++ b/driver.go @@ -23,7 +23,7 @@ func (d SnowflakeDriver) Open(dsn string) (driver.Conn, error) { logger.Info("Open") ctx := context.Background() if dsn == "autoConfig" { - cfg, err = LoadConnectionConfig() + cfg, err = loadConnectionConfig() } else { cfg, err = ParseDSN(dsn) } diff --git a/errors.go b/errors.go index 7e56040fe..2e5d902f0 100644 --- a/errors.go +++ b/errors.go @@ -131,6 +131,8 @@ const ( ErrCodeTomlFileParsingFailed = 260013 // ErrCodeFailedToFindDSNInToml is an error code for the case where the DSN does not exist in the toml file. ErrCodeFailedToFindDSNInToml = 260014 + // ErrCodeInvalidFilePermission is an error code for the case where the user does not have 0600 permission to the toml file . + ErrCodeInvalidFilePermission = 260015 /* network */ @@ -305,6 +307,7 @@ const ( errMsgNullValueInMap = "for handling null values in maps use WithMapValuesNullable(ctx)" errMsgFailedToParseTomlFile = "failed to parse toml file. the params %v occurred error with value %v" errMsgFailedToFindDSNInTomlFile = "failed to find DSN in toml file." + errMsgInvalidPermissionToTomlFile = "file permissions different than read/write for user. Your Permission: %v" ) // Returned if a DNS doesn't include account parameter. From cd0beebd81bbb672ce987a88eeebc832bcd1ba97 Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Wed, 2 Oct 2024 22:21:42 -0700 Subject: [PATCH 44/45] fix --- connection_configuration.go | 11 +++++------ test_data/connections.toml | 7 +++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/connection_configuration.go b/connection_configuration.go index 534b567ba..33b3bcb00 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -58,6 +58,10 @@ func loadConnectionConfig() (*Config, error) { if err != nil { return nil, err } + err = fillMissingConfigParameters(cfg) + if err != nil { + return nil, err + } return cfg, err } @@ -74,11 +78,6 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } cfg.Token = v } - - err := fillMissingConfigParameters(cfg) - if err != nil { - return err - } return nil } @@ -319,7 +318,7 @@ func validateFilePermission(filePath string) error { return err } if permission := fileInfo.Mode().Perm(); permission != os.FileMode(0600) { - return err := &SnowflakeError{ + return &SnowflakeError{ Number: ErrCodeInvalidFilePermission, Message: errMsgInvalidPermissionToTomlFile, MessageArgs: []interface{}{permission}, diff --git a/test_data/connections.toml b/test_data/connections.toml index fb4d8e4ad..7fc07201b 100644 --- a/test_data/connections.toml +++ b/test_data/connections.toml @@ -35,5 +35,12 @@ testNot = 'problematicParameter' token_file_path = '/Users/test/.snowflake/token' [read-token] +account = 'snowdriverswarsaw.us-west-2.aws' +user = 'test_default_user' +password = 'test_default_pass' +warehouse = 'testw_default' +database = 'test_default_db' +schema = 'test_default_go' +protocol = 'https' authenticator = 'oauth' token_file_path = './test_data/snowflake/session/token' \ No newline at end of file From c995093e69fca88210ebb8c2ea35930f6f8d4a2a Mon Sep 17 00:00:00 2001 From: John Yun <140559986+sfc-gh-ext-simba-jy@users.noreply.github.com> Date: Tue, 8 Oct 2024 17:33:25 -0700 Subject: [PATCH 45/45] updated --- cmd/tomlfileconnection/tomlfileconnection.go | 20 ++-- connection_configuration.go | 104 +++++++++---------- connection_configuration_test.go | 1 - doc.go | 2 +- 4 files changed, 64 insertions(+), 63 deletions(-) diff --git a/cmd/tomlfileconnection/tomlfileconnection.go b/cmd/tomlfileconnection/tomlfileconnection.go index 91b91d9dc..8d34d9a3f 100644 --- a/cmd/tomlfileconnection/tomlfileconnection.go +++ b/cmd/tomlfileconnection/tomlfileconnection.go @@ -25,21 +25,23 @@ func main() { } defer db.Close() query := "SELECT 1" - rows, err := db.Query(query) // no cancel is allowed + rows, err := db.Query(query) if err != nil { log.Fatalf("failed to run a query. %v, err: %v", query, err) } defer rows.Close() var v int - for rows.Next() { - err := rows.Scan(&v) - if err != nil { - log.Fatalf("failed to get result. err: %v", err) - } - if v != 1 { - log.Fatalf("failed to get 1. got: %v", v) - } + if !rows.Next() { + log.Fatalf("no rows returned, expected 1") } + err = rows.Scan(&v) + if err != nil { + log.Fatalf("failed to get result. err: %v", err) + } + if v != 1 { + log.Fatalf("failed to get 1. got: %v", v) + } + if rows.Err() != nil { fmt.Printf("ERROR: %v\n", rows.Err()) return diff --git a/connection_configuration.go b/connection_configuration.go index 33b3bcb00..d2e191aef 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -65,8 +65,8 @@ func loadConnectionConfig() (*Config, error) { return cfg, err } -func parseToml(cfg *Config, connection map[string]interface{}) error { - for key, value := range connection { +func parseToml(cfg *Config, connectionMap map[string]interface{}) error { + for key, value := range connectionMap { if err := handleSingleParam(cfg, key, value); err != nil { return err } @@ -82,71 +82,71 @@ func parseToml(cfg *Config, connection map[string]interface{}) error { } func handleSingleParam(cfg *Config, key string, value interface{}) error { - var parsingErr error + var err error var v, tokenPath string switch strings.ToLower(key) { case "user", "username": - cfg.User, parsingErr = parseString(value) + cfg.User, err = parseString(value) case "password": - cfg.Password, parsingErr = parseString(value) + cfg.Password, err = parseString(value) case "host": - cfg.Host, parsingErr = parseString(value) + cfg.Host, err = parseString(value) case "account": - cfg.Account, parsingErr = parseString(value) + cfg.Account, err = parseString(value) case "warehouse": - cfg.Warehouse, parsingErr = parseString(value) + cfg.Warehouse, err = parseString(value) case "database": - cfg.Database, parsingErr = parseString(value) + cfg.Database, err = parseString(value) case "schema": - cfg.Schema, parsingErr = parseString(value) + cfg.Schema, err = parseString(value) case "role": - cfg.Role, parsingErr = parseString(value) + cfg.Role, err = parseString(value) case "region": - cfg.Region, parsingErr = parseString(value) + cfg.Region, err = parseString(value) case "protocol": - cfg.Protocol, parsingErr = parseString(value) + cfg.Protocol, err = parseString(value) case "passcode": - cfg.Passcode, parsingErr = parseString(value) + cfg.Passcode, err = parseString(value) case "port": - cfg.Port, parsingErr = parseInt(value) + cfg.Port, err = parseInt(value) case "passcodeinpassword": - cfg.PasscodeInPassword, parsingErr = parseBool(value) + cfg.PasscodeInPassword, err = parseBool(value) case "clienttimeout": - cfg.ClientTimeout, parsingErr = parseDuration(value) + cfg.ClientTimeout, err = parseDuration(value) case "jwtclienttimeout": - cfg.JWTClientTimeout, parsingErr = parseDuration(value) + cfg.JWTClientTimeout, err = parseDuration(value) case "logintimeout": - cfg.LoginTimeout, parsingErr = parseDuration(value) + cfg.LoginTimeout, err = parseDuration(value) case "requesttimeout": - cfg.RequestTimeout, parsingErr = parseDuration(value) + cfg.RequestTimeout, err = parseDuration(value) case "jwttimeout": - cfg.JWTExpireTimeout, parsingErr = parseDuration(value) + cfg.JWTExpireTimeout, err = parseDuration(value) case "externalbrowsertimeout": - cfg.ExternalBrowserTimeout, parsingErr = parseDuration(value) + cfg.ExternalBrowserTimeout, err = parseDuration(value) case "maxretrycount": - cfg.MaxRetryCount, parsingErr = parseInt(value) + cfg.MaxRetryCount, err = parseInt(value) case "application": - cfg.Application, parsingErr = parseString(value) + cfg.Application, err = parseString(value) case "authenticator": - v, parsingErr = parseString(value) - if err := checkParsingError(parsingErr, key, value); err != nil { + v, err = parseString(value) + if err = checkParsingError(err, key, value); err != nil { return err } - parsingErr = determineAuthenticatorType(cfg, v) + err = determineAuthenticatorType(cfg, v) case "insecuremode": - cfg.InsecureMode, parsingErr = parseBool(value) + cfg.InsecureMode, err = parseBool(value) case "ocspfailopen": var vv ConfigBool - vv, parsingErr = parseConfigBool(value) - if err := checkParsingError(parsingErr, key, value); err != nil { + vv, err = parseConfigBool(value) + if err := checkParsingError(err, key, value); err != nil { return err } cfg.OCSPFailOpen = OCSPFailOpenMode(vv) case "token": - cfg.Token, parsingErr = parseString(value) + cfg.Token, err = parseString(value) case "privatekey": - v, parsingErr = parseString(value) - if err := checkParsingError(parsingErr, key, value); err != nil { + v, err = parseString(value) + if err = checkParsingError(err, key, value); err != nil { return err } block, decodeErr := base64.URLEncoding.DecodeString(v) @@ -156,30 +156,30 @@ func handleSingleParam(cfg *Config, key string, value interface{}) error { Message: "Base64 decode failed", } } - cfg.PrivateKey, parsingErr = parsePKCS8PrivateKey(block) + cfg.PrivateKey, err = parsePKCS8PrivateKey(block) case "validatedefaultparameters": - cfg.ValidateDefaultParameters, parsingErr = parseConfigBool(value) + cfg.ValidateDefaultParameters, err = parseConfigBool(value) case "clientrequestmfatoken": - cfg.ClientRequestMfaToken, parsingErr = parseConfigBool(value) + cfg.ClientRequestMfaToken, err = parseConfigBool(value) case "clientstoretemporarycredential": - cfg.ClientStoreTemporaryCredential, parsingErr = parseConfigBool(value) + cfg.ClientStoreTemporaryCredential, err = parseConfigBool(value) case "tracing": - cfg.Tracing, parsingErr = parseString(value) + cfg.Tracing, err = parseString(value) case "tmpdirpath": - cfg.TmpDirPath, parsingErr = parseString(value) + cfg.TmpDirPath, err = parseString(value) case "disablequerycontextcache": - cfg.DisableQueryContextCache, parsingErr = parseBool(value) + cfg.DisableQueryContextCache, err = parseBool(value) case "includeretryreason": - cfg.IncludeRetryReason, parsingErr = parseConfigBool(value) + cfg.IncludeRetryReason, err = parseConfigBool(value) case "clientconfigfile": - cfg.ClientConfigFile, parsingErr = parseString(value) + cfg.ClientConfigFile, err = parseString(value) case "disableconsolelogin": - cfg.DisableConsoleLogin, parsingErr = parseConfigBool(value) + cfg.DisableConsoleLogin, err = parseConfigBool(value) case "disablesamlurlcheck": - cfg.DisableSamlURLCheck, parsingErr = parseConfigBool(value) + cfg.DisableSamlURLCheck, err = parseConfigBool(value) case "token_file_path": - tokenPath, parsingErr = parseString(value) - if err := checkParsingError(parsingErr, key, value); err != nil { + tokenPath, err = parseString(value) + if err = checkParsingError(err, key, value); err != nil { return err } v, err := readToken(tokenPath) @@ -188,18 +188,18 @@ func handleSingleParam(cfg *Config, key string, value interface{}) error { } cfg.Token = v default: - param, parsingErr := parseString(value) - if err := checkParsingError(parsingErr, key, value); err != nil { + param, err := parseString(value) + if err = checkParsingError(err, key, value); err != nil { return err } cfg.Params[urlDecodeIfNeeded(key)] = ¶m } - return checkParsingError(parsingErr, key, value) + return checkParsingError(err, key, value) } -func checkParsingError(parsingErr error, key string, value interface{}) error { - if parsingErr != nil { - err := &SnowflakeError{ +func checkParsingError(err error, key string, value interface{}) error { + if err != nil { + err = &SnowflakeError{ Number: ErrCodeTomlFileParsingFailed, Message: errMsgFailedToParseTomlFile, MessageArgs: []interface{}{key, value}, diff --git a/connection_configuration_test.go b/connection_configuration_test.go index 6a21c850f..18612ddbf 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -301,5 +301,4 @@ func TestGetTomlFilePath(t *testing.T) { assertNilF(t, err, "should not have failed") assertEqualF(t, dir, result) } - } diff --git a/doc.go b/doc.go index 7113c2d96..cf095d51d 100644 --- a/doc.go +++ b/doc.go @@ -169,7 +169,7 @@ Note: GOSNOWFLAKE_SKIP_REGISTERATION should not be used if sql.Open() is used as to connect to the server, as sql.Open will require registration so it can map the driver name to the driver type, which in this case is "snowflake" and SnowflakeDriver{}. -After Version 1.11.2 and later, you can load the connnection configuration with .toml file format. +You can load the connnection configuration with .toml file format. With two environment variables SNOWFLAKE_HOME(connections.toml file directory) SNOWFLAKE_DEFAULT_CONNECTION_NAME(DSN name), the driver will search the config file and load the connection. You can find how to use this connection way at ./cmd/tomlfileconnection or Snowflake doc: https://docs.snowflake.com/en/developer-guide/snowflake-cli-v2/connecting/specify-credentials