From b0f234f11b555031ff7f197ee70e652ada9c8610 Mon Sep 17 00:00:00 2001 From: Vivek R Date: Wed, 29 May 2024 17:44:06 +0530 Subject: [PATCH] refactor: redis package to v3 spec --- stores/redis/go.mod | 2 +- stores/redis/store.go | 410 +++++++++++++++------------ stores/redis/store_test.go | 551 ++++++++++++++++++++++--------------- 3 files changed, 557 insertions(+), 406 deletions(-) diff --git a/stores/redis/go.mod b/stores/redis/go.mod index b268fed..0177f52 100644 --- a/stores/redis/go.mod +++ b/stores/redis/go.mod @@ -1,4 +1,4 @@ -module github.com/vividvilla/simplesessions/stores/goredis/v9 +module github.com/vividvilla/simplesessions/stores/redis/v3 go 1.18 diff --git a/stores/redis/store.go b/stores/redis/store.go index 82c159f..5fa4d1b 100644 --- a/stores/redis/store.go +++ b/stores/redis/store.go @@ -1,22 +1,19 @@ -package goredis +package redis import ( "context" - "crypto/rand" + "strconv" "time" - "unicode" "github.com/redis/go-redis/v9" - "github.com/vividvilla/simplesessions/conv" ) var ( // Error codes for store errors. This should match the codes // defined in the /simplesessions package exactly. ErrInvalidSession = &Err{code: 1, msg: "invalid session"} - ErrFieldNotFound = &Err{code: 2, msg: "field not found"} + ErrNil = &Err{code: 2, msg: "nil returned"} ErrAssertType = &Err{code: 3, msg: "assertion failed"} - ErrNil = &Err{code: 4, msg: "nil returned"} ) type Err struct { @@ -37,6 +34,8 @@ func (e *Err) Code() int { type Store struct { // Maximum lifetime sessions has to be persisted. ttl time.Duration + // extend TTL on update. + extendTTL bool // Prefix for session id. prefix string @@ -49,7 +48,9 @@ type Store struct { const ( // Default prefix used to store session redis defaultPrefix = "session:" - sessionIDLen = 32 + // Default key used when session is created. + // Its not possible to have empty map in Redis. + defaultSessKey = "_ss" ) // New creates a new Redis store instance. @@ -67,86 +68,58 @@ func (s *Store) SetPrefix(val string) { } // SetTTL sets TTL for session in redis. -func (s *Store) SetTTL(d time.Duration) { +// if isExtend is true then ttl is updated on all set/setmulti. +// otherwise its set only on create(). +func (s *Store) SetTTL(d time.Duration, extend bool) { s.ttl = d + s.extendTTL = extend } // Create returns a new session id but doesn't stores it in redis since empty hashmap can't be created. -func (s *Store) Create() (string, error) { - id, err := generateID(sessionIDLen) - if err != nil { - return "", err +func (s *Store) Create(id string) error { + // Create the session in backend with default session key since + // Redis doesn't support empty hashmap and its impossible to + // check if the session exist or not. + p := s.client.TxPipeline() + p.HSet(s.clientCtx, s.prefix+id, defaultSessKey, "1") + if s.ttl > 0 { + p.Expire(s.clientCtx, s.prefix+id, s.ttl) } - - return id, err + _, err := p.Exec(s.clientCtx) + return err } // Get gets a field in hashmap. If field is nill then ErrFieldNotFound is raised func (s *Store) Get(id, key string) (interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - - pipe := s.client.TxPipeline() - exists := pipe.Exists(s.clientCtx, s.prefix+id) - get := pipe.HGet(s.clientCtx, s.prefix+id, key) - _, err := pipe.Exec(s.clientCtx) - // redis.Nil is returned if a field does not exist. - // Ignore the error and check for key existence check. - if err != nil && err != redis.Nil { + vals, err := s.client.HMGet(s.clientCtx, s.prefix+id, defaultSessKey, key).Result() + if err != nil { return nil, err } - // Check if key exists and return ErrInvalidSession if not. - if ex, err := exists.Result(); err != nil { - return nil, err - } else if ex == 0 { + if vals[0] == nil { return nil, ErrInvalidSession } - v, err := get.Result() - if err != nil && err == redis.Nil { - return nil, ErrFieldNotFound - } - - return v, nil + return vals[1], nil } // GetMulti gets a map for values for multiple keys. If key is not found then its set as nil. func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - - pipe := s.client.TxPipeline() - exists := pipe.Exists(s.clientCtx, s.prefix+id) - get := pipe.HMGet(s.clientCtx, s.prefix+id, keys...) - _, err := pipe.Exec(s.clientCtx) - // redis.Nil is returned if a field does not exist. - // Ignore the error and check for key existence check. - if err != nil && err != redis.Nil { + allKeys := append([]string{defaultSessKey}, keys...) + vals, err := s.client.HMGet(s.clientCtx, s.prefix+id, allKeys...).Result() + if err != nil { return nil, err } - // Check if key exists and return ErrInvalidSession if not. - if ex, err := exists.Result(); err != nil { - return nil, err - } else if ex == 0 { + if vals[0] == nil { return nil, ErrInvalidSession } - v, err := get.Result() - if err != nil { - return nil, err - } - // Form a map with returned results res := make(map[string]interface{}) - for i, k := range keys { - if v[i] == nil { - res[k] = ErrFieldNotFound - } else { - res[k] = v[i] + for i, k := range allKeys { + if k != defaultSessKey { + res[k] = vals[i] } } @@ -155,187 +128,272 @@ func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, err // GetAll gets all fields from hashmap. func (s *Store) GetAll(id string) (map[string]interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - - pipe := s.client.TxPipeline() - exists := pipe.Exists(s.clientCtx, s.prefix+id) - get := pipe.HGetAll(s.clientCtx, s.prefix+id) - _, err := pipe.Exec(s.clientCtx) - // redis.Nil is returned if a field does not exist. - // Ignore the error and check for key existence check. - if err != nil && err != redis.Nil { - return nil, err - } - - // Check if key exists and return ErrInvalidSession if not. - if ex, err := exists.Result(); err != nil { - return nil, err - } else if ex == 0 { - return nil, ErrInvalidSession - } - - res, err := get.Result() + vals, err := s.client.HGetAll(s.clientCtx, s.prefix+id).Result() if err != nil { return nil, err } // Convert results to type `map[string]interface{}` - out := make(map[string]interface{}, len(res)) - for k, v := range res { - out[k] = v + out := make(map[string]interface{}) + for k, v := range vals { + if k != defaultSessKey { + out[k] = v + } } return out, nil } // Set sets a value to given session. +// If session is not present in backend then its still written. func (s *Store) Set(id, key string, val interface{}) error { - if !validateID(id) { - return ErrInvalidSession - } - - pipe := s.client.TxPipeline() - pipe.HSet(s.clientCtx, s.prefix+id, key, val) + p := s.client.TxPipeline() + p.HSet(s.clientCtx, s.prefix+id, key, val) + p.HSet(s.clientCtx, s.prefix+id, defaultSessKey, "1") // Set expiry of key only if 'ttl' is set, this is to // ensure that the key remains valid indefinitely like // how redis handles it by default - if s.ttl > 0 { - pipe.Expire(s.clientCtx, s.prefix+id, s.ttl) + if s.ttl > 0 && s.extendTTL { + p.Expire(s.clientCtx, s.prefix+id, s.ttl) } - _, err := pipe.Exec(s.clientCtx) + _, err := p.Exec(s.clientCtx) return err } // Set sets a value to given session. func (s *Store) SetMulti(id string, data map[string]interface{}) error { - if !validateID(id) { - return ErrInvalidSession - } - // Make slice of arguments to be passed in HGETALL command - args := []interface{}{} + args := []interface{}{defaultSessKey, "1"} for k, v := range data { args = append(args, k, v) } - pipe := s.client.TxPipeline() - pipe.HMSet(s.clientCtx, s.prefix+id, args...) + p := s.client.TxPipeline() + p.HMSet(s.clientCtx, s.prefix+id, args...) // Set expiry of key only if 'ttl' is set, this is to // ensure that the key remains valid indefinitely like // how redis handles it by default - if s.ttl > 0 { - pipe.Expire(s.clientCtx, s.prefix+id, s.ttl) + if s.ttl > 0 && s.extendTTL { + p.Expire(s.clientCtx, s.prefix+id, s.ttl) } - _, err := pipe.Exec(s.clientCtx) + _, err := p.Exec(s.clientCtx) return err } // Delete deletes a key from redis session hashmap. func (s *Store) Delete(id string, key string) error { - if !validateID(id) { - return ErrInvalidSession - } - - pipe := s.client.TxPipeline() - exists := pipe.Exists(s.clientCtx, s.prefix+id) - del := pipe.HDel(s.clientCtx, s.prefix+id, key) - _, err := pipe.Exec(s.clientCtx) - // redis.Nil is returned if a field does not exist. - // Ignore the error and check for key existence check. - if err != nil && err != redis.Nil { - return err - } - - // Check if key exists and return ErrInvalidSession if not. - if ex, err := exists.Result(); err != nil { - return err - } else if ex == 0 { - return ErrInvalidSession - } - - if v, err := del.Result(); err != nil { - return err - } else if v == 0 { - return ErrFieldNotFound - } - - return nil + return s.client.HDel(s.clientCtx, s.prefix+id, key).Err() } // Clear clears session in redis. func (s *Store) Clear(id string) error { - if !validateID(id) { - return ErrInvalidSession - } - return s.client.Del(s.clientCtx, s.prefix+id).Err() } -// Int returns redis reply as integer. +// Int converts interface to integer. func (s *Store) Int(r interface{}, err error) (int, error) { - return conv.Int(r, err) + if err != nil { + return 0, err + } + + switch r := r.(type) { + case int: + return r, nil + case int64: + if x := int(r); int64(x) != r { + return 0, ErrAssertType + } else { + return x, nil + } + case []byte: + if n, err := strconv.ParseInt(string(r), 10, 0); err != nil { + return 0, &Err{code: 3, msg: err.Error()} + } else { + return int(n), nil + } + case string: + if n, err := strconv.ParseInt(r, 10, 0); err != nil { + return 0, &Err{code: 3, msg: err.Error()} + } else { + return int(n), nil + } + case nil: + return 0, ErrNil + case error: + return 0, r + } + + return 0, ErrAssertType } -// Int64 returns redis reply as Int64. +// Int64 converts interface to Int64. func (s *Store) Int64(r interface{}, err error) (int64, error) { - return conv.Int64(r, err) -} + if err != nil { + return 0, err + } -// UInt64 returns redis reply as UInt64. -func (s *Store) UInt64(r interface{}, err error) (uint64, error) { - return conv.UInt64(r, err) -} + switch r := r.(type) { + case int: + return int64(r), nil + case int64: + return r, nil + case []byte: + if n, err := strconv.ParseInt(string(r), 10, 64); err != nil { + return 0, &Err{code: 3, msg: err.Error()} + } else { + return n, nil + } + case string: + if n, err := strconv.ParseInt(r, 10, 64); err != nil { + return 0, &Err{code: 3, msg: err.Error()} + } else { + return n, nil + } + case nil: + return 0, ErrNil + case error: + return 0, r + } -// Float64 returns redis reply as Float64. -func (s *Store) Float64(r interface{}, err error) (float64, error) { - return conv.Float64(r, err) + return 0, ErrAssertType } -// String returns redis reply as String. -func (s *Store) String(r interface{}, err error) (string, error) { - return conv.String(r, err) -} +// UInt64 converts interface to UInt64. +func (s *Store) UInt64(r interface{}, err error) (uint64, error) { + if err != nil { + return 0, err + } -// Bytes returns redis reply as Bytes. -func (s *Store) Bytes(r interface{}, err error) ([]byte, error) { - return conv.Bytes(r, err) -} + switch r := r.(type) { + case uint64: + return r, nil + case int: + if r < 0 { + return 0, ErrAssertType + } + return uint64(r), nil + case int64: + if r < 0 { + return 0, ErrAssertType + } + return uint64(r), nil + case []byte: + if n, err := strconv.ParseUint(string(r), 10, 64); err != nil { + return 0, &Err{code: 3, msg: err.Error()} + } else { + return n, nil + } + case string: + if n, err := strconv.ParseUint(r, 10, 64); err != nil { + return 0, &Err{code: 3, msg: err.Error()} + } else { + return n, nil + } + case nil: + return 0, ErrNil + case error: + return 0, r + } -// Bool returns redis reply as Bool. -func (s *Store) Bool(r interface{}, err error) (bool, error) { - return conv.Bool(r, err) + return 0, ErrAssertType } -func validateID(id string) bool { - if len(id) != sessionIDLen { - return false +// Float64 converts interface to Float64. +func (s *Store) Float64(r interface{}, err error) (float64, error) { + if err != nil { + return 0, err } - - for _, r := range id { - if !unicode.IsDigit(r) && !unicode.IsLetter(r) { - return false + switch r := r.(type) { + case float64: + return r, err + case []byte: + if n, err := strconv.ParseFloat(string(r), 64); err != nil { + return 0, &Err{code: 3, msg: err.Error()} + } else { + return n, nil } + case string: + if n, err := strconv.ParseFloat(r, 64); err != nil { + return 0, &Err{code: 3, msg: err.Error()} + } else { + return n, nil + } + case nil: + return 0, ErrNil + case error: + return 0, r } - - return true + return 0, ErrAssertType } -// generateID generates a random alpha-num session ID. -func generateID(n int) (string, error) { - const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - bytes := make([]byte, n) - if _, err := rand.Read(bytes); err != nil { +// String converts interface to String. +func (s *Store) String(r interface{}, err error) (string, error) { + if err != nil { return "", err } + switch r := r.(type) { + case []byte: + return string(r), nil + case string: + return r, nil + case nil: + return "", ErrNil + case error: + return "", r + } + return "", ErrAssertType +} - for k, v := range bytes { - bytes[k] = dict[v%byte(len(dict))] +// Bytes converts interface to Bytes. +func (s *Store) Bytes(r interface{}, err error) ([]byte, error) { + if err != nil { + return nil, err + } + switch r := r.(type) { + case []byte: + return r, nil + case string: + return []byte(r), nil + case nil: + return nil, ErrNil + case error: + return nil, r } + return nil, ErrAssertType +} - return string(bytes), nil +// Bool converts interface to Bool. +func (s *Store) Bool(r interface{}, err error) (bool, error) { + if err != nil { + return false, err + } + switch r := r.(type) { + case bool: + return r, err + // Very common in redis to reply int64 with 0 for bool flag. + case int: + return r != 0, nil + case int64: + return r != 0, nil + case []byte: + if n, err := strconv.ParseBool(string(r)); err != nil { + return false, &Err{code: 3, msg: err.Error()} + } else { + return n, nil + } + case string: + if n, err := strconv.ParseBool(r); err != nil { + return false, &Err{code: 3, msg: err.Error()} + } else { + return n, nil + } + case nil: + return false, ErrNil + case error: + return false, r + } + return false, ErrAssertType } diff --git a/stores/redis/store_test.go b/stores/redis/store_test.go index 332b54e..e174317 100644 --- a/stores/redis/store_test.go +++ b/stores/redis/store_test.go @@ -1,4 +1,4 @@ -package goredis +package redis import ( "context" @@ -13,6 +13,7 @@ import ( var ( mockRedis *miniredis.Miniredis + errTest = errors.New("test error") ) func init() { @@ -47,85 +48,80 @@ func TestSetPrefix(t *testing.T) { func TestSetTTL(t *testing.T) { testDur := time.Second * 10 str := New(context.TODO(), getRedisClient()) - str.SetTTL(testDur) + str.SetTTL(testDur, true) assert.Equal(t, str.ttl, testDur) + assert.True(t, str.extendTTL) } func TestCreate(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - id, err := str.Create() + var ( + id = "testid_create" + client = getRedisClient() + str = New(context.TODO(), client) + ) + str.SetTTL(time.Second*100, false) + err := str.Create(id) assert.Nil(t, err) - assert.Equal(t, len(id), sessionIDLen) -} - -func TestGet(t *testing.T) { - key := "4dIHy6S2uBuKaNnTUszB218L898ikGY1" - field := "somekey" - value := 100 - client := getRedisClient() - - // Set a key - err := client.HSet(context.TODO(), defaultPrefix+key, field, value).Err() - assert.NoError(t, err) - str := New(context.TODO(), client) - - val, err := str.Int(str.Get(key, field)) + vals, err := client.HGetAll(context.TODO(), str.prefix+id).Result() assert.NoError(t, err) - assert.Equal(t, val, value) + assert.Contains(t, vals, defaultSessKey) - // Check for invalid key. - _, err = str.Int(str.Get(key, "invalidfield")) - assert.ErrorIs(t, ErrFieldNotFound, err) + ttl, _ := client.TTL(context.TODO(), str.prefix+id).Result() + assert.Equal(t, ttl, time.Second*100) } -func TestGetInvalidSession(t *testing.T) { - str := New(context.TODO(), getRedisClient()) +func TestGet(t *testing.T) { + var ( + id = "testid_get" + field = "somekey" + value = 100 + client = getRedisClient() + str = New(context.TODO(), client) + ) + // Invalid session. val, err := str.Get("invalidkey", "invalidkey") assert.Nil(t, val) assert.ErrorIs(t, err, ErrInvalidSession) - id := "10IHy6S2uBuKaNnTUszB218L898ikGY1" - val, err = str.Get(id, "invalidkey") - assert.Nil(t, val) - assert.ErrorIs(t, ErrInvalidSession, err) -} + // Check valid session. + err = client.HMSet(context.TODO(), str.prefix+id, field, value, defaultSessKey, "1").Err() + assert.NoError(t, err) -func TestGetMultiInvalidSession(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - val, err := str.GetMulti("invalidkey", "invalidkey") - assert.Nil(t, val) - assert.ErrorIs(t, ErrInvalidSession, err) + val, err = str.Int(str.Get(id, field)) + assert.NoError(t, err) + assert.Equal(t, val, value) - key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" - field := "somefield" - _, err = str.GetMulti(key, field) - assert.ErrorIs(t, err, ErrInvalidSession) + // Check for invalid key. + _, err = str.Int(str.Get(id, "invalidfield")) + assert.ErrorIs(t, ErrNil, err) } func TestGetMulti(t *testing.T) { var ( - key = "5dIHy6S2uBuKaNnTUszB218L898ikGY1" + id = "testid_getmulti" field1 = "somekey" value1 = 100 field2 = "someotherkey" value2 = "abc123" - field3 = "thishouldntbethere" - value3 = 100.10 invalidField = "foo" client = getRedisClient() + str = New(context.TODO(), client) ) + // Invalid session. + val, err := str.GetMulti("invalidkey", "invalidkey") + assert.Nil(t, val) + assert.ErrorIs(t, err, ErrInvalidSession) // Set a key - err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2, field3, value3).Err() + err = client.HMSet(context.TODO(), str.prefix+id, defaultSessKey, "1", field1, value1, field2, value2).Err() assert.NoError(t, err) - str := New(context.TODO(), client) - vals, err := str.GetMulti(key, field1, field2, invalidField) + vals, err := str.GetMulti(id, field1, field2, invalidField) assert.NoError(t, err) assert.Contains(t, vals, field1) assert.Contains(t, vals, field2) - assert.NotContains(t, vals, field3) + assert.Contains(t, vals, invalidField) val1, err := str.Int(vals[field1], nil) assert.NoError(t, err) @@ -137,42 +133,28 @@ func TestGetMulti(t *testing.T) { // Check for invalid key. _, err = str.String(vals[invalidField], nil) - assert.ErrorIs(t, ErrFieldNotFound, err) -} - -func TestGetAllInvalidSession(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - val, err := str.GetAll("invalidkey") - assert.Nil(t, val) - assert.ErrorIs(t, ErrInvalidSession, err) - - key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" - val, err = str.GetAll(key) - assert.Nil(t, val) - assert.ErrorIs(t, ErrInvalidSession, err) + assert.ErrorIs(t, ErrNil, err) } func TestGetAll(t *testing.T) { - key := "6dIHy6S2uBuKaNnTUszB218L898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - field3 := "thishouldntbethere" - value3 := 100.10 - client := getRedisClient() + var ( + key = "testid_getall" + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + client = getRedisClient() + str = New(context.TODO(), client) + ) // Set a key - err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2, field3, value3).Err() + err := client.HMSet(context.TODO(), str.prefix+key, defaultSessKey, "1", field1, value1, field2, value2).Err() assert.NoError(t, err) - str := New(context.TODO(), client) - vals, err := str.GetAll(key) assert.NoError(t, err) assert.Contains(t, vals, field1) assert.Contains(t, vals, field2) - assert.Contains(t, vals, field3) val1, err := str.Int(vals[field1], nil) assert.NoError(t, err) @@ -181,60 +163,51 @@ func TestGetAll(t *testing.T) { val2, err := str.String(vals[field2], nil) assert.NoError(t, err) assert.Equal(t, val2, value2) - - val3, err := str.Float64(vals[field3], nil) - assert.NoError(t, err) - assert.Equal(t, val3, value3) -} - -func TestSetInvalidSessionError(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - err := str.Set("invalidid", "key", "value") - assert.ErrorIs(t, ErrInvalidSession, err) } func TestSet(t *testing.T) { // Test should only set in internal map and not in redis - client := getRedisClient() - str := New(context.TODO(), client) - ttl := time.Second * 10 - str.SetTTL(ttl) - - // this key is unique across all tests - key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" - field := "somekey" - value := 100 + var ( + client = getRedisClient() + str = New(context.TODO(), client) + ttl = time.Second * 10 + // this key is unique across all tests + key = "testid_set" + field = "somekey" + value = 100 + ) + str.SetTTL(ttl, true) err := str.Set(key, field, value) assert.NoError(t, err) // Check ifs not commited to redis - v1, err := client.Exists(context.TODO(), defaultPrefix+key).Result() + v1, err := client.Exists(context.TODO(), str.prefix+key).Result() assert.NoError(t, err) assert.Equal(t, int64(1), v1) - v2, err := str.Int(client.HGet(context.TODO(), defaultPrefix+key, field).Result()) + v2, err := str.Int(client.HGet(context.TODO(), str.prefix+key, field).Result()) assert.NoError(t, err) assert.Equal(t, value, v2) - dur, err := client.TTL(context.TODO(), defaultPrefix+key).Result() + dur, err := client.TTL(context.TODO(), str.prefix+key).Result() assert.NoError(t, err) assert.Equal(t, dur, ttl) } func TestSetMulti(t *testing.T) { // Test should only set in internal map and not in redis - client := getRedisClient() - str := New(context.TODO(), client) - ttl := time.Second * 10 - str.SetTTL(ttl) - - // this key is unique across all tests - key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" - field1 := "somekey1" - value1 := 100 - field2 := "somekey2" - value2 := "somevalue" + var ( + client = getRedisClient() + str = New(context.TODO(), client) + ttl = time.Second * 10 + key = "testid_setmulti" + field1 = "somekey1" + value1 = 100 + field2 = "somekey2" + value2 = "somevalue" + ) + str.SetTTL(ttl, true) err := str.SetMulti(key, map[string]interface{}{ field1: value1, @@ -243,235 +216,355 @@ func TestSetMulti(t *testing.T) { assert.NoError(t, err) // Check ifs not commited to redis - v1, err := client.Exists(context.TODO(), defaultPrefix+key).Result() + v1, err := client.Exists(context.TODO(), str.prefix+key).Result() assert.NoError(t, err) assert.Equal(t, int64(1), v1) - v2, err := str.Int(client.HGet(context.TODO(), defaultPrefix+key, field1).Result()) + v2, err := str.Int(client.HGet(context.TODO(), str.prefix+key, field1).Result()) assert.NoError(t, err) assert.Equal(t, value1, v2) - dur, err := client.TTL(context.TODO(), defaultPrefix+key).Result() + dur, err := client.TTL(context.TODO(), str.prefix+key).Result() assert.NoError(t, err) assert.Equal(t, dur, ttl) } -func TestDeleteInvalidSessionError(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - err := str.Delete("invalidkey", "somefield") - assert.ErrorIs(t, ErrInvalidSession, err) - - str = New(context.TODO(), getRedisClient()) - err = str.Delete("8dIHy6S2uBuKaNnTUszB2180898ikGY1", "somefield") - assert.ErrorIs(t, ErrInvalidSession, err) -} - func TestDelete(t *testing.T) { // Test should only set in internal map and not in redis - client := getRedisClient() - str := New(context.TODO(), client) - - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" + var ( + client = getRedisClient() + str = New(context.TODO(), client) + + // this key is unique across all tests + key = "testid_delete" + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + ) - err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2).Err() + err := client.HMSet(context.TODO(), str.prefix+key, defaultSessKey, "1", field1, value1, field2, value2).Err() assert.NoError(t, err) err = str.Delete(key, field1) assert.NoError(t, err) - val, err := client.HExists(context.TODO(), defaultPrefix+key, field1).Result() + val, err := client.HExists(context.TODO(), str.prefix+key, field1).Result() assert.False(t, val) assert.NoError(t, err) - val, err = client.HExists(context.TODO(), defaultPrefix+key, field2).Result() + val, err = client.HExists(context.TODO(), str.prefix+key, field2).Result() assert.True(t, val) assert.NoError(t, err) - - err = str.Delete(key, "xxxxx") - assert.ErrorIs(t, err, ErrFieldNotFound) -} - -func TestClearInvalidSessionError(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - err := str.Clear("invalidkey") - assert.ErrorIs(t, ErrInvalidSession, err) } func TestClear(t *testing.T) { // Test should only set in internal map and not in redis - client := getRedisClient() - str := New(context.TODO(), client) - - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" + var ( + client = getRedisClient() + str = New(context.TODO(), client) + + // this key is unique across all tests + key = "testid_clear" + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + ) - err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2).Err() + err := client.HMSet(context.TODO(), str.prefix+key, defaultSessKey, "1", field1, value1, field2, value2).Err() assert.NoError(t, err) // Check if its set - val, err := client.Exists(context.TODO(), defaultPrefix+key).Result() + val, err := client.Exists(context.TODO(), str.prefix+key).Result() assert.NoError(t, err) assert.NotEqual(t, val, int64(0)) err = str.Clear(key) assert.NoError(t, err) - val, err = client.Exists(context.TODO(), defaultPrefix+key).Result() + val, err = client.Exists(context.TODO(), str.prefix+key).Result() assert.NoError(t, err) assert.Equal(t, val, int64(0)) } func TestInt(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) + str := New(context.TODO(), nil) + + v, err := str.Int(1, nil) + assert.NoError(t, err) + assert.Equal(t, 1, v) - field := "somekey" - value := 100 + v, err = str.Int("1", nil) + assert.NoError(t, err) + assert.Equal(t, 1, v) - err := client.Set(context.TODO(), field, value, 0).Err() + v, err = str.Int([]byte("1"), nil) assert.NoError(t, err) + assert.Equal(t, 1, v) - val, err := str.Int(client.Get(context.TODO(), field).Result()) + var tVal int64 = 1 + v, err = str.Int(tVal, nil) assert.NoError(t, err) - assert.Equal(t, value, val) + assert.Equal(t, 1, v) + + var tVal1 interface{} = 1 + v, err = str.Int(tVal1, nil) + assert.NoError(t, err) + assert.Equal(t, 1, v) + + // Test if ErrNil is returned if value is nil. + v, err = str.Int(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, 0, v) - testError := errors.New("test error") - _, err = str.Int(value, testError) - assert.ErrorIs(t, testError, err) + // Test if custom error sent is returned. + v, err = str.Int(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, 0, v) + + // Test invalid assert error. + v, err = str.Int(10.1112, nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, 0, v) } func TestInt64(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) + str := New(context.TODO(), nil) + + v, err := str.Int64(int64(1), nil) + assert.NoError(t, err) + assert.Equal(t, int64(1), v) - field := "somekey" - var value int64 = 100 + v, err = str.Int64("1", nil) + assert.NoError(t, err) + assert.Equal(t, int64(1), v) - err := client.Set(context.TODO(), field, value, 0).Err() + v, err = str.Int64([]byte("1"), nil) assert.NoError(t, err) + assert.Equal(t, int64(1), v) - val, err := str.Int64(client.Get(context.TODO(), field).Result()) + var tVal interface{} = 1 + v, err = str.Int64(tVal, nil) assert.NoError(t, err) - assert.Equal(t, value, val) + assert.Equal(t, int64(1), v) + + // Test if ErrNil is returned if value is nil. + v, err = str.Int64(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, int64(0), v) - testError := errors.New("test error") - _, err = str.Int64(value, testError) - assert.ErrorIs(t, testError, err) + // Test if custom error sent is returned. + v, err = str.Int64(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, int64(0), v) + + // Test invalid assert error. + v, err = str.Int64(10.1112, nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, int64(0), v) } func TestUInt64(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) + str := New(context.TODO(), nil) - field := "somekey" - var value uint64 = 100 + v, err := str.UInt64(uint64(1), nil) + assert.NoError(t, err) + assert.Equal(t, uint64(1), v) + + v, err = str.UInt64("1", nil) + assert.NoError(t, err) + assert.Equal(t, uint64(1), v) - err := client.Set(context.TODO(), field, value, 0).Err() + v, err = str.UInt64([]byte("1"), nil) assert.NoError(t, err) + assert.Equal(t, uint64(1), v) - val, err := str.UInt64(client.Get(context.TODO(), field).Result()) + var tVal interface{} = 1 + v, err = str.UInt64(tVal, nil) assert.NoError(t, err) - assert.Equal(t, value, val) + assert.Equal(t, uint64(1), v) + + // Test if ErrNil is returned if value is nil. + v, err = str.UInt64(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, uint64(0), v) - testError := errors.New("test error") - _, err = str.UInt64(value, testError) - assert.ErrorIs(t, testError, err) + // Test if custom error sent is returned. + v, err = str.UInt64(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, uint64(0), v) + + // Test invalid assert error. + v, err = str.UInt64(10.1112, nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, uint64(0), v) } func TestFloat64(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) + str := New(context.TODO(), nil) + + v, err := str.Float64(float64(1.11), nil) + assert.NoError(t, err) + assert.Equal(t, float64(1.11), v) - field := "somekey" - var value float64 = 100 + v, err = str.Float64("1.11", nil) + assert.NoError(t, err) + assert.Equal(t, float64(1.11), v) - err := client.Set(context.TODO(), field, value, 0).Err() + v, err = str.Float64([]byte("1.11"), nil) assert.NoError(t, err) + assert.Equal(t, float64(1.11), v) - val, err := str.Float64(client.Get(context.TODO(), field).Result()) + var tVal float64 = 1.11 + v, err = str.Float64(tVal, nil) assert.NoError(t, err) - assert.Equal(t, value, val) + assert.Equal(t, float64(1.11), v) + + // Test if ErrNil is returned if value is nil. + v, err = str.Float64(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, float64(0), v) + + // Test if custom error sent is returned. + v, err = str.Float64(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, float64(0), v) - testError := errors.New("test error") - _, err = str.Float64(value, testError) - assert.ErrorIs(t, testError, err) + // Test invalid assert error. + // v, err = str.Float64("abc", nil) + // assert.ErrorIs(t, err, ErrAssertType) + // assert.Equal(t, float64(0), v) } func TestString(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) + str := New(context.TODO(), nil) - field := "somekey" - value := "abc123" + v, err := str.String("abc", nil) + assert.NoError(t, err) + assert.Equal(t, "abc", v) - err := client.Set(context.TODO(), field, value, 0).Err() + v, err = str.String([]byte("abc"), nil) assert.NoError(t, err) + assert.Equal(t, "abc", v) - val, err := str.String(client.Get(context.TODO(), field).Result()) + var tVal interface{} = "abc" + v, err = str.String(tVal, nil) assert.NoError(t, err) - assert.Equal(t, value, val) + assert.Equal(t, "abc", v) + + // Test if ErrNil is returned if value is nil. + v, err = str.String(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, "", v) + + // Test if custom error sent is returned. + v, err = str.String(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, "", v) - testError := errors.New("test error") - _, err = str.String(value, testError) - assert.ErrorIs(t, testError, err) + // Test invalid assert error. + v, err = str.String(10.1112, nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, "", v) } func TestBytes(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) + str := New(context.TODO(), nil) - field := "somekey" - var value []byte = []byte("abc123") + v, err := str.Bytes("abc", nil) + assert.NoError(t, err) + assert.Equal(t, []byte("abc"), v) - err := client.Set(context.TODO(), field, value, 0).Err() + v, err = str.Bytes([]byte("abc"), nil) assert.NoError(t, err) + assert.Equal(t, []byte("abc"), v) - val, err := str.Bytes(client.Get(context.TODO(), field).Result()) + var tVal interface{} = "abc" + v, err = str.Bytes(tVal, nil) assert.NoError(t, err) - assert.Equal(t, value, val) + assert.Equal(t, []byte("abc"), v) + + // Test if ErrNil is returned if value is nil. + v, err = str.Bytes(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, []byte(nil), v) + + // Test if custom error sent is returned. + v, err = str.Bytes(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, []byte(nil), v) - testError := errors.New("test error") - _, err = str.Bytes(value, testError) - assert.ErrorIs(t, testError, err) + // Test invalid assert error. + v, err = str.Bytes(10.1112, nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, []byte(nil), v) } func TestBool(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) + str := New(context.TODO(), nil) - field := "somekey" - value := true + v, err := str.Bool(true, nil) + assert.NoError(t, err) + assert.Equal(t, true, v) - err := client.Set(context.TODO(), field, value, 0).Err() + v, err = str.Bool(false, nil) assert.NoError(t, err) + assert.Equal(t, false, v) - val, err := str.Bool(client.Get(context.TODO(), field).Result()) + v, err = str.Bool(0, nil) assert.NoError(t, err) - assert.Equal(t, value, val) + assert.Equal(t, false, v) - testError := errors.New("test error") - _, err = str.Bool(value, testError) - assert.ErrorIs(t, testError, err) -} + v, err = str.Bool(1, nil) + assert.NoError(t, err) + assert.Equal(t, true, v) + + v, err = str.Bool(int64(0), nil) + assert.NoError(t, err) + assert.Equal(t, false, v) + + v, err = str.Bool(int64(1), nil) + assert.NoError(t, err) + assert.Equal(t, true, v) -func TestValidateID(t *testing.T) { - ok := validateID("xxxx") - assert.False(t, ok) + v, err = str.Bool([]byte("true"), nil) + assert.NoError(t, err) + assert.Equal(t, true, v) - ok = validateID("8dIHy6S2uBuKaNnTUszB2180898ikGY&") - assert.False(t, ok) + v, err = str.Bool([]byte("false"), nil) + assert.NoError(t, err) + assert.Equal(t, false, v) - id, err := generateID(sessionIDLen) + v, err = str.Bool("true", nil) assert.NoError(t, err) - ok = validateID(id) - assert.True(t, ok) + assert.Equal(t, true, v) + + v, err = str.Bool("false", nil) + assert.NoError(t, err) + assert.Equal(t, false, v) + + // Test if ErrNil is returned if value is nil. + v, err = str.Bool(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, false, v) + + // Test if custom error sent is returned. + v, err = str.Bool(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, false, v) + + // Test invalid assert error. + v, err = str.Bool(10.1112, nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, false, v) +} + +func TestError(t *testing.T) { + err := Err{ + code: 1, + msg: "test", + } + assert.Equal(t, 1, err.Code()) + assert.Equal(t, "test", err.Error()) }