diff --git a/cmd/tomlfileconnection/.gitignore b/cmd/tomlfileconnection/.gitignore new file mode 100644 index 000000000..3b1a12e76 --- /dev/null +++ b/cmd/tomlfileconnection/.gitignore @@ -0,0 +1 @@ +tomlfileconnection.go \ No newline at end of file diff --git a/cmd/tomlfileconnection/Makefile b/cmd/tomlfileconnection/Makefile new file mode 100644 index 000000000..7813cb2e1 --- /dev/null +++ b/cmd/tomlfileconnection/Makefile @@ -0,0 +1,16 @@ +include ../../gosnowflake.mak +CMD_TARGET=tomlfileconnection + +## Install +install: cinstall + +## Run +run: crun + +## Lint +lint: clint + +## Format source codes +fmt: cfmt + +.PHONY: install run lint fmt diff --git a/cmd/tomlfileconnection/tomlfileconnection.go b/cmd/tomlfileconnection/tomlfileconnection.go new file mode 100644 index 000000000..f82b0b629 --- /dev/null +++ b/cmd/tomlfileconnection/tomlfileconnection.go @@ -0,0 +1,58 @@ +// Example: How to connect to the server with the toml file configuration +// Prerequiste: following the Snowflake doc: https://docs.snowflake.com/en/developer-guide/snowflake-cli-v2/connecting/specify-credentials +package main + +import ( + "database/sql" + "flag" + "fmt" + "log" + "os" + + sf "github.com/snowflakedb/gosnowflake" +) + +func main() { + if !flag.Parsed() { + flag.Parse() + } + + os.Setenv("SNOWFLAKE_HOME", "") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "") + + cfg, err := sf.LoadConnectionConfig() + if err != nil { + log.Fatalf("failed to create Config, err: %v", err) + } + dsn, err := sf.DSN(cfg) + if err != nil { + log.Fatalf("failed to create DSN from Config: %v, err: %v", cfg, err) + } + + db, err := sql.Open("snowflake", dsn) + if err != nil { + log.Fatalf("failed to connect. %v, err: %v", dsn, err) + } + defer db.Close() + query := "SELECT 1" + rows, err := db.Query(query) // no cancel is allowed + if err != nil { + log.Fatalf("failed to run a query. %v, err: %v", query, err) + } + defer rows.Close() + var v int + for rows.Next() { + err := rows.Scan(&v) + if err != nil { + log.Fatalf("failed to get result. err: %v", err) + } + if v != 1 { + log.Fatalf("failed to get 1. got: %v", v) + } + } + if rows.Err() != nil { + fmt.Printf("ERROR: %v\n", rows.Err()) + return + } + fmt.Printf("Congrats! You have successfully run %v with Snowflake DB!\n", query) +} diff --git a/connection_configuration.go b/connection_configuration.go new file mode 100644 index 000000000..c3660015e --- /dev/null +++ b/connection_configuration.go @@ -0,0 +1,467 @@ +// Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + +package gosnowflake + +import ( + "encoding/base64" + "errors" + "os" + path "path/filepath" + "strconv" + "strings" + "time" + + toml "github.com/BurntSushi/toml" +) + +// LoadConnectionConfig returns connection configs loaded from the toml file. +// By default, SNOWFLAKE_HOME(toml file path) is os.home/snowflake +// and SNOWFLAKE_DEFAULT_CONNECTION_NAME(DSN) is 'default' +func LoadConnectionConfig() (*Config, error) { + cfg := &Config{ + Params: make(map[string]*string), + Authenticator: AuthTypeSnowflake, // Default to snowflake + } + dsn := getConnectionDSN(os.Getenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME")) + snowflakeConfigDir, err := getTomlFilePath(os.Getenv("SNOWFLAKE_HOME")) + if err != nil { + return nil, err + } + tomlFilePath := path.Join(snowflakeConfigDir, "connections.toml") + err = validateFilePermission(tomlFilePath) + if err != nil { + return nil, err + } + tomlInfo := make(map[string]interface{}) + _, err = toml.DecodeFile(tomlFilePath, &tomlInfo) + if err != nil { + return nil, err + } + connectionName, exist := tomlInfo[dsn] + if !exist { + return nil, &SnowflakeError{ + Number: ErrCodeFailedToFindDSNInToml, + Message: errMsgFailedToFindDSNInTomlFile, + } + } + connectionConfig, ok := connectionName.(map[string]interface{}) + if !ok { + return nil, err + } + err = parseToml(cfg, connectionConfig) + if err != nil { + return nil, err + } + return cfg, err +} + +func parseToml(cfg *Config, connection map[string]interface{}) error { + var v, tokenPath string + var parsingErr error + var vv bool + err := &SnowflakeError{ + Number: ErrCodeTomlFileParsingFailed, + Message: errMsgFailedToParseTomlFile, + } + for key, value := range connection { + switch strings.ToLower(key) { + case "user", "username": + cfg.User, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "password": + cfg.Password, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "host": + cfg.Host, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "account": + cfg.Account, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "warehouse": + cfg.Warehouse, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "database": + cfg.Database, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "schema": + cfg.Schema, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "role": + cfg.Role, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "region": + cfg.Region, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "protocol": + cfg.Protocol, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "passcode": + cfg.Passcode, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "port": + if cfg.Port, parsingErr = parseInt(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "passcodeinpassword": + if cfg.PasscodeInPassword, parsingErr = parseBool(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "clienttimeout": + if cfg.ClientTimeout, parsingErr = parseDuration(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "jwtclienttimeout": + if cfg.JWTClientTimeout, parsingErr = parseDuration(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "logintimeout": + if cfg.LoginTimeout, parsingErr = parseDuration(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "requesttimeout": + if cfg.RequestTimeout, parsingErr = parseDuration(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "jwttimeout": + if cfg.JWTExpireTimeout, parsingErr = parseDuration(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "externalbrowsertimeout": + if cfg.ExternalBrowserTimeout, parsingErr = parseDuration(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "maxretrycount": + if cfg.MaxRetryCount, parsingErr = parseInt(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "application": + cfg.Application, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "authenticator": + v, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + parsingErr = determineAuthenticatorType(cfg, v) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "insecuremode": + if cfg.InsecureMode, parsingErr = parseBool(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "ocspfailopen": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + if vv { + cfg.OCSPFailOpen = OCSPFailOpenTrue + } else { + cfg.OCSPFailOpen = OCSPFailOpenFalse + } + + case "token": + cfg.Token, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "privatekey": + v, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + block, decodeErr := base64.URLEncoding.DecodeString(v) + if decodeErr != nil { + return &SnowflakeError{ + Number: ErrCodePrivateKeyParseError, + Message: "Base64 decode failed", + } + } + cfg.PrivateKey, parsingErr = parsePKCS8PrivateKey(block) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "validatedefaultparameters": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + if vv { + cfg.ValidateDefaultParameters = ConfigBoolTrue + } else { + cfg.ValidateDefaultParameters = ConfigBoolFalse + } + case "clientrequestmfatoken": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + if vv { + cfg.ClientRequestMfaToken = ConfigBoolTrue + } else { + cfg.ClientRequestMfaToken = ConfigBoolFalse + } + case "clientstoretemporarycredential": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + if vv { + cfg.ClientStoreTemporaryCredential = ConfigBoolTrue + } else { + cfg.ClientStoreTemporaryCredential = ConfigBoolFalse + } + case "tracing": + cfg.Tracing, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "tmpdirpath": + cfg.TmpDirPath, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "disablequerycontextcache": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + cfg.DisableQueryContextCache = vv + case "includeretryreason": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + if vv { + cfg.IncludeRetryReason = ConfigBoolTrue + } else { + cfg.IncludeRetryReason = ConfigBoolFalse + } + case "clientconfigfile": + cfg.ClientConfigFile, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + case "disableconsolelogin": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + if vv { + cfg.DisableConsoleLogin = ConfigBoolTrue + } else { + cfg.DisableConsoleLogin = ConfigBoolFalse + } + case "disablesamlurlcheck": + if vv, parsingErr = parseBool(value); parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + if vv { + cfg.DisableSamlURLCheck = ConfigBoolTrue + } else { + cfg.DisableSamlURLCheck = ConfigBoolFalse + } + case "token_file_path": + tokenPath, parsingErr = parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + default: + param, parsingErr := parseString(value) + if parsingErr != nil { + err.MessageArgs = []interface{}{key, value} + return err + } + cfg.Params[urlDecodeIfNeeded(key)] = ¶m + } + } + if shouldReadTokenFromFile(cfg) { + v, err := readToken(tokenPath) + if err != nil { + return err + } + cfg.Token = v + } + return nil +} + +func parseInt(i interface{}) (int, error) { + if _, ok := i.(string); !ok { + if _, ok := i.(int); !ok { + return 0, errors.New("failed to parse the value to integer") + } + num := i.(int) + return num, nil + } + v := i.(string) + num, err := strconv.Atoi(v) + + if err != nil { + return num, err + } + return num, nil +} + +func parseBool(i interface{}) (bool, error) { + v, ok := i.(string) + if !ok { + if _, ok := i.(bool); !ok { + return false, errors.New("failed to parse the value to boolean") + } + vv := i.(bool) + return vv, nil + } + vv, err := strconv.ParseBool(v) + if err != nil { + return false, errors.New("failed to parse the value to boolean") + } + return vv, nil +} + +func parseDuration(i interface{}) (time.Duration, error) { + v, ok := i.(string) + if !ok { + num, err := parseInt(i) + if err != nil { + return time.Duration(0), err + } + t := int64(num) + return time.Duration(t * int64(time.Second)), nil + } + t, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return time.Duration(0), err + } + return time.Duration(t * int64(time.Second)), nil +} + +func readToken(tokenPath string) (string, error) { + if tokenPath == "" { + tokenPath = "./snowflake/session/token" + } + if !path.IsAbs(tokenPath) { + snowflakeConfigDir, err := getTomlFilePath(os.Getenv("SNOWFLAKE_HOME")) + if err != nil { + return "", err + } + tokenPath = path.Join(snowflakeConfigDir, tokenPath) + } + err := validateFilePermission(tokenPath) + if err != nil { + return "", err + } + token, err := os.ReadFile(tokenPath) + if err != nil { + return "", err + } + return string(token), nil +} + +func parseString(i interface{}) (string, error) { + v, ok := i.(string) + if !ok { + return "", errors.New("failed to convert the value to string") + } + return v, nil +} + +func getTomlFilePath(filePath string) (string, error) { + if len(filePath) != 0 { + if path.IsAbs(filePath) { + return filePath, nil + } + } else { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err + } + filePath = path.Join(homeDir, "snowflake") + } + absDir, err := path.Abs(filePath) + if err != nil { + return "", err + } + return absDir, nil +} + +func getConnectionDSN(dsn string) string { + if len(dsn) != 0 { + return dsn + } + return "default" +} + +func validateFilePermission(filePath string) error { + if isWindows { + return nil + } + fileInfo, err := os.Stat(filePath) + if err != nil { + return err + } + if permission := fileInfo.Mode().Perm(); permission != os.FileMode(0600) { + return errors.New("your access to the file was denied") + } + return nil +} + +func shouldReadTokenFromFile(cfg *Config) bool { + return cfg != nil && cfg.Authenticator == AuthTypeOAuth && len(cfg.Token) == 0 +} diff --git a/connection_configuration_test.go b/connection_configuration_test.go new file mode 100644 index 000000000..7ee2df0dd --- /dev/null +++ b/connection_configuration_test.go @@ -0,0 +1,260 @@ +// Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + +package gosnowflake + +import ( + "io/fs" + "os" + path "path/filepath" + "testing" + "time" +) + +func TestLoadConnectionConfig_Default(t *testing.T) { + if !isWindows { + _, err := LoadConnectionConfig() + assertNotNilF(t, err, "The error should occur because you cannot change the file permission") + } + + err := os.Chmod("./test_data/connections.toml", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + + os.Setenv("SNOWFLAKE_HOME", "./test_data") + + cfg, err := LoadConnectionConfig() + assertNilF(t, err, "The error should not occur") + assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") + assertEqualF(t, cfg.User, "test_user") + assertEqualF(t, cfg.Password, "test_pass") + assertEqualF(t, cfg.Warehouse, "testw") + assertEqualF(t, cfg.Database, "test_db") + assertEqualF(t, cfg.Schema, "test_go") + assertEqualF(t, cfg.Protocol, "https") + assertEqualF(t, cfg.Port, 443) +} + +func TestLoadConnectionConfig_OAuth(t *testing.T) { + err := os.Chmod("./test_data/connections.toml", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + + os.Setenv("SNOWFLAKE_HOME", "./test_data") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth") + + cfg, err := LoadConnectionConfig() + assertNilF(t, err, "The error should not occur") + assertEqualF(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") + assertEqualF(t, cfg.User, "test_user") + assertEqualF(t, cfg.Password, "test_pass") + assertEqualF(t, cfg.Warehouse, "testw") + assertEqualF(t, cfg.Database, "test_db") + assertEqualF(t, cfg.Schema, "test_go") + assertEqualF(t, cfg.Protocol, "https") + assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) + assertEqualF(t, cfg.Token, "token_value") + assertEqualF(t, cfg.Port, 443) +} + +func TestReadTokenValueWithTokenFilePath(t *testing.T) { + err := os.Chmod("./test_data/connections.toml", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + + err = os.Chmod("./test_data/snowflake/session/token", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + + os.Setenv("SNOWFLAKE_HOME", "./test_data") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "no-token-path") + + cfg, err := LoadConnectionConfig() + assertNilF(t, err, "The error should not occur") + assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) + assertEqualF(t, cfg.Token, "mock_token123456") + + os.Setenv("SNOWFLAKE_HOME", "./test_data") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "read-token") + + cfg, err = LoadConnectionConfig() + assertNilF(t, err, "The error should not occur") + assertEqualF(t, cfg.Authenticator, AuthTypeOAuth) + assertEqualF(t, cfg.Token, "mock_token123456") +} + +func TestLoadConnectionConfigWitNonExistingDSN(t *testing.T) { + err := os.Chmod("./test_data/connections.toml", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + + os.Setenv("SNOWFLAKE_HOME", "./test_data") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "unavailableDSN") + + _, err = LoadConnectionConfig() + assertNotNilF(t, err, "The error should occur") + + driverErr, ok := err.(*SnowflakeError) + assertTrueF(t, ok, "This should be a Snowflake Error") + assertEqualF(t, driverErr.Number, ErrCodeFailedToFindDSNInToml) +} + +func TestLoadConnectionConfigWithTokenFileNotExist(t *testing.T) { + err := os.Chmod("./test_data/connections.toml", 0600) + assertNilF(t, err, "The error occurred because you cannot change the file permission") + + os.Setenv("SNOWFLAKE_HOME", "./test_data") + os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "aws-oauth-file") + + _, err = LoadConnectionConfig() + assertNotNilF(t, err, "The error should occur") + + _, ok := err.(*(fs.PathError)) + assertTrueF(t, ok, "This error should be a path error") +} + +func TestParseInt(t *testing.T) { + var i interface{} + + i = 20 + num, err := parseInt(i) + assertNilF(t, err, "This value should be parsed") + assertEqualF(t, num, 20) + + i = "40" + num, err = parseInt(i) + assertNilF(t, err, "This value should be parsed") + assertEqualF(t, num, 40) + + i = "wrong_num" + _, err = parseInt(i) + assertNotNilF(t, err, "should have failed") +} + +func TestParseBool(t *testing.T) { + var i interface{} + + i = true + b, err := parseBool(i) + assertNilF(t, err, "This value should be parsed") + assertEqualF(t, b, true) + + i = "false" + b, err = parseBool(i) + assertNilF(t, err, "This value should be parsed") + assertEqualF(t, b, false) + + i = "wrong_bool" + _, err = parseInt(i) + assertNotNilF(t, err, "should have failed") +} + +func TestParseDuration(t *testing.T) { + var i interface{} + + i = 300 + dur, err := parseDuration(i) + assertNilF(t, err, "This value should be parsed") + assertEqualF(t, dur, time.Duration(5*int64(time.Minute))) + + i = "30" + dur, err = parseDuration(i) + assertNilF(t, err, "This value should be parsed") + assertEqualF(t, dur, time.Duration(int64(time.Minute)/2)) + + i = false + _, err = parseDuration(i) + assertNotNilF(t, err, "should have failed") +} + +type paramList struct { + testParams []string + values []interface{} +} + +func TestParseToml(t *testing.T) { + testCases := []paramList{ + { + testParams: []string{"user", "password", "host", "account", "warehouse", "database", + "schema", "role", "region", "protocol", "passcode", "application", "token", + "tracing", "tmpDirPath", "clientConfigFile"}, + values: []interface{}{"value"}, + }, + { + testParams: []string{"privatekey"}, + values: []interface{}{generatePKCS8StringSupress(testPrivKey)}, + }, + { + testParams: []string{"port", "maxRetryCount", "clientTimeout", "jwtClientTimeout", "loginTimeout", + "requestTimeout", "jwtTimeout", "externalBrowserTimeout"}, + values: []interface{}{"300", 500}, + }, + { + testParams: []string{"ocspFailOpen", "insecureMode", "PasscodeInPassword", "validateDEFAULTParameters", "clientRequestMFAtoken", + "clientStoreTemporaryCredential", "disableQueryContextCache", "includeRetryReason", "disableConsoleLogin", "disableSamlUrlCheck"}, + values: []interface{}{true, "true", false, "false"}, + }, + } + + for _, testCase := range testCases { + for _, param := range testCase.testParams { + for _, value := range testCase.values { + t.Run(param, func(t *testing.T) { + cfg := &Config{} + connectionMap := make(map[string]interface{}) + connectionMap[param] = value + err := parseToml(cfg, connectionMap) + assertNilF(t, err, "The value should be parsed") + }) + } + } + } +} + +func TestParseTomlWithWrongValue(t *testing.T) { + testCases := []paramList{ + { + testParams: []string{"user", "password", "host", "account", "warehouse", "database", + "schema", "role", "region", "protocol", "passcode", "application", "token", "privateKey", + "tracing", "tmpDirPath", "clientConfigFile", "wrongParams", "token_file_path"}, + values: []interface{}{1, false}, + }, + { + testParams: []string{"port", "maxRetryCount", "clientTimeout", "jwtClientTimeout", "loginTimeout", + "requestTimeout", "jwtTimeout", "externalBrowserTimeout", "authenticator"}, + values: []interface{}{"wrong_value", false}, + }, + { + testParams: []string{"ocspFailOpen", "insecureMode", "PasscodeInPassword", "validateDEFAULTParameters", "clientRequestMFAtoken", + "clientStoreTemporaryCredential", "disableQueryContextCache", "includeRetryReason", "disableConsoleLogin", "disableSamlUrlCheck"}, + values: []interface{}{"wrong_value", 1}, + }, + } + + for _, testCase := range testCases { + for _, param := range testCase.testParams { + for _, value := range testCase.values { + t.Run(param, func(t *testing.T) { + cfg := &Config{} + connectionMap := make(map[string]interface{}) + connectionMap[param] = value + err := parseToml(cfg, connectionMap) + assertNotNilF(t, err, "should have failed") + driverErr, ok := err.(*SnowflakeError) + assertTrueF(t, ok, "This should be a Snowflake Error") + assertEqualF(t, driverErr.Number, ErrCodeTomlFileParsingFailed) + }) + + } + } + } +} + +func TestGetTomlFilePath(t *testing.T) { + dir, err := getTomlFilePath("") + assertNilF(t, err, "should not have failed") + homeDir, err := os.UserHomeDir() + assertNilF(t, err, "The connection cannot find the user home directory") + assertEqualF(t, dir, path.Join(homeDir, "snowflake")) + + location := "../user//somelocation///b" + dir, err = getTomlFilePath(location) + assertNilF(t, err, "should not have failed") + result, err := path.Abs(location) + assertNilF(t, err, "should not have failed") + assertEqualF(t, dir, result) +} diff --git a/doc.go b/doc.go index ec1865151..1b923fee3 100644 --- a/doc.go +++ b/doc.go @@ -169,6 +169,11 @@ Note: GOSNOWFLAKE_SKIP_REGISTERATION should not be used if sql.Open() is used as to connect to the server, as sql.Open will require registration so it can map the driver name to the driver type, which in this case is "snowflake" and SnowflakeDriver{}. +After Version 1.11.2 and later, you can load the connnection configuration with .toml file format. +With two environment variables SNOWFLAKE_HOME(connections.toml file directory) SNOWFLAKE_DEFAULT_CONNECTION_NAME(DSN name), +the driver will search the config file and load the connection. You can find how to use this connection way at ./cmd/tomlfileconnection +or Snowflake doc: https://docs.snowflake.com/en/developer-guide/snowflake-cli-v2/connecting/specify-credentials + # Proxy The Go Snowflake Driver honors the environment variables HTTP_PROXY, HTTPS_PROXY and NO_PROXY for the forward proxy setting. diff --git a/errors.go b/errors.go index 250af2e4f..7e56040fe 100644 --- a/errors.go +++ b/errors.go @@ -127,6 +127,10 @@ const ( ErrCodeFailedToParseAuthenticator = 260011 // ErrCodeClientConfigFailed is an error code for the case where clientConfigFile is invalid or applying client configuration fails ErrCodeClientConfigFailed = 260012 + // ErrCodeTomlFileParsingFailed is an error code for the case where parsing the toml file is failed because of invalid value. + ErrCodeTomlFileParsingFailed = 260013 + // ErrCodeFailedToFindDSNInToml is an error code for the case where the DSN does not exist in the toml file. + ErrCodeFailedToFindDSNInToml = 260014 /* network */ @@ -299,6 +303,8 @@ const ( errMsgClientConfigFailed = "client configuration failed: %v" errMsgNullValueInArray = "for handling null values in arrays use WithArrayValuesNullable(ctx)" errMsgNullValueInMap = "for handling null values in maps use WithMapValuesNullable(ctx)" + errMsgFailedToParseTomlFile = "failed to parse toml file. the params %v occurred error with value %v" + errMsgFailedToFindDSNInTomlFile = "failed to find DSN in toml file." ) // Returned if a DNS doesn't include account parameter. diff --git a/go.mod b/go.mod index 7749971f2..5fb788136 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/99designs/keyring v1.2.2 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0 + github.com/BurntSushi/toml v1.4.0 github.com/apache/arrow/go/v15 v15.0.0 github.com/aws/aws-sdk-go-v2 v1.26.1 github.com/aws/aws-sdk-go-v2/credentials v1.17.11 diff --git a/go.sum b/go.sum index e5b3ef444..6e99023fa 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLC github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0 h1:u/LLAOFgsMv7HmNL4Qufg58y+qElGOt5qv0z1mURkRY= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0/go.mod h1:2e8rMJtl2+2j+HXbTBwnyGpm5Nou7KhvSfxOq8JpTag= github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1 h1:BWe8a+f/t+7KY7zH2mqygeUD0t8hNFXe08p1Pb3/jKE= +github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= +github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= github.com/apache/arrow/go/v15 v15.0.0 h1:1zZACWf85oEZY5/kd9dsQS7i+2G5zVQcbKTHgslqHNA= diff --git a/test_data/connections.toml b/test_data/connections.toml new file mode 100644 index 000000000..878e4fc23 --- /dev/null +++ b/test_data/connections.toml @@ -0,0 +1,42 @@ +[default] +account = 'snowdriverswarsaw.us-west-2.aws' +user = 'test_user' +password = 'test_pass' +warehouse = 'testw' +database = 'test_db' +schema = 'test_go' +protocol = 'https' +port = '443' + +[aws-oauth] +account = 'snowdriverswarsaw.us-west-2.aws' +user = 'test_user' +password = 'test_pass' +warehouse = 'testw' +database = 'test_db' +schema = 'test_go' +protocol = 'https' +port = '443' +authenticator = 'oauth' +testNot = 'problematicParameter' +token = 'token_value' + +[aws-oauth-file] +account = 'snowdriverswarsaw.us-west-2.aws' +user = 'test_user' +password = 'test_pass' +warehouse = 'testw' +database = 'test_db' +schema = 'test_go' +protocol = 'https' +port = '443' +authenticator = 'oauth' +testNot = 'problematicParameter' +token_file_path = '/Users/test/.snowflake/token' + +[read-token] +authenticator = 'oauth' +token_file_path = './snowflake/session/token' + +[no-token-path] +authenticator = 'oauth' \ No newline at end of file diff --git a/test_data/snowflake/session/token b/test_data/snowflake/session/token new file mode 100644 index 000000000..6db96c0a3 --- /dev/null +++ b/test_data/snowflake/session/token @@ -0,0 +1 @@ +mock_token123456 \ No newline at end of file