diff --git a/adapter.go b/adapter.go index aa36e00..8d7d359 100644 --- a/adapter.go +++ b/adapter.go @@ -48,13 +48,26 @@ type Adapter struct { username string password string tlsConfig *tls.Config - conn redis.Conn + _conn redis.Conn + _pool *redis.Pool isFiltered bool } +func (a *Adapter) getConn() redis.Conn { + if a._pool != nil { + return a._pool.Get() + } + return a._conn +} + // finalizer is the destructor for Adapter. func finalizer(a *Adapter) { - a.conn.Close() + if a._conn != nil { + a._conn.Close() + } + if a._pool != nil { + a._pool.Close() + } } func newAdapter(network string, address string, key string, @@ -98,7 +111,24 @@ func NewAdapterWithKey(network string, address string, key string) (*Adapter, er func NewAdapterWithPool(pool *redis.Pool) (*Adapter, error) { a := &Adapter{} a.key = "casbin_rules" - a.conn = pool.Get() + a._conn = pool.Get() + a._pool = pool + + // Call the destructor when the object is released. + runtime.SetFinalizer(a, finalizer) + + return a, nil +} + +// NewAdapterWithPoolAndOptions is the constructor for Adapter. +func NewAdapterWithPoolAndOptions(pool *redis.Pool, options ...Option) (*Adapter, error) { + a := &Adapter{} + a.key = "casbin_rules" + for _, option := range options { + option(a) + } + a._conn = pool.Get() + a._pool = pool // Call the destructor when the object is released. runtime.SetFinalizer(a, finalizer) @@ -166,34 +196,39 @@ func (a *Adapter) open() error { return err } - a.conn = conn + a._conn = conn } else if a.password == "" { conn, err := redis.Dial(a.network, a.address, redis.DialTLSConfig(a.tlsConfig), redis.DialUseTLS(useTls)) if err != nil { return err } - a.conn = conn + a._conn = conn } else { conn, err := redis.Dial(a.network, a.address, redis.DialPassword(a.password), redis.DialTLSConfig(a.tlsConfig), redis.DialUseTLS(useTls)) if err != nil { return err } - a.conn = conn + a._conn = conn } return nil } func (a *Adapter) close() { - a.conn.Close() + if a._conn != nil { + a._conn.Close() + } + if a._pool != nil { + a._pool.Close() + } } func (a *Adapter) createTable() { } func (a *Adapter) dropTable() { - _, _ = a.conn.Do("DEL", a.key) + _, _ = a.getConn().Do("DEL", a.key) } func (c *CasbinRule) toStringPolicy() []string { @@ -230,14 +265,14 @@ func loadPolicyLine(line CasbinRule, model model.Model) { // LoadPolicy loads policy from database. func (a *Adapter) LoadPolicy(model model.Model) error { - num, err := redis.Int(a.conn.Do("LLEN", a.key)) + num, err := redis.Int(a.getConn().Do("LLEN", a.key)) if err == redis.ErrNil { return nil } if err != nil { return err } - values, err := redis.Values(a.conn.Do("LRANGE", a.key, 0, num)) + values, err := redis.Values(a.getConn().Do("LRANGE", a.key, 0, num)) if err != nil { return err } @@ -314,7 +349,7 @@ func (a *Adapter) SavePolicy(model model.Model) error { } } - _, err := a.conn.Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...) + _, err := a.getConn().Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...) return err } @@ -325,7 +360,7 @@ func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error { if err != nil { return err } - _, err = a.conn.Do("RPUSH", a.key, text) + _, err = a.getConn().Do("RPUSH", a.key, text) return err } @@ -336,7 +371,7 @@ func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error { if err != nil { return err } - _, err = a.conn.Do("LREM", a.key, 1, text) + _, err = a.getConn().Do("LREM", a.key, 1, text) return err } @@ -351,7 +386,7 @@ func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error } texts = append(texts, text) } - _, err := a.conn.Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...) + _, err := a.getConn().Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...) return err } @@ -363,7 +398,7 @@ func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) err if err != nil { return err } - _, err = a.conn.Do("LREM", a.key, 1, text) + _, err = a.getConn().Do("LREM", a.key, 1, text) if err != nil { return err } @@ -449,14 +484,14 @@ func filterFieldToLuaPattern(sec string, ptype string, fieldIndex int, fieldValu } func (a *Adapter) loadFilteredPolicy(model model.Model, filter *Filter) error { - num, err := redis.Int(a.conn.Do("LLEN", a.key)) + num, err := redis.Int(a.getConn().Do("LLEN", a.key)) if err == redis.ErrNil { return nil } if err != nil { return err } - values, err := redis.Values(a.conn.Do("LRANGE", a.key, 0, num)) + values, err := redis.Values(a.getConn().Do("LRANGE", a.key, 0, num)) if err != nil { return err } @@ -524,7 +559,7 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, redis.call('lrem', key, 0, '__CASBIN_DELETED__') return `) - _, err := getScript.Do(a.conn, a.key, pattern) + _, err := getScript.Do(a.getConn(), a.key, pattern) return err } @@ -557,7 +592,7 @@ func (a *Adapter) UpdatePolicy(sec string, ptype string, oldRule, newPolicy []st end return false `) - _, err = getScript.Do(a.conn, a.key, textOld, textNew) + _, err = getScript.Do(a.getConn(), a.key, textOld, textNew) return err } @@ -605,7 +640,7 @@ func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [] return false `) args := redis.Args{}.Add(a.key).AddFlat(oldPolicies).AddFlat(newPolicies) - _, err := getScript.Do(a.conn, args...) + _, err := getScript.Do(a.getConn(), args...) return err } @@ -649,7 +684,7 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [ args := redis.Args{}.Add(a.key).Add(pattern).AddFlat(newP) //r, err := getScript.Do(a.conn, args...) //reply, err := redis.Values(r, err) - reply, err := redis.Values(getScript.Do(a.conn, args...)) + reply, err := redis.Values(getScript.Do(a.getConn(), args...)) if err != nil { return nil, err } diff --git a/adapter_test.go b/adapter_test.go index 0082d11..0c3769e 100644 --- a/adapter_test.go +++ b/adapter_test.go @@ -403,3 +403,22 @@ func TestPoolAdapters(t *testing.T) { testUpdatePolicies(t, a) testUpdateFilteredPolicies(t, a) } + +func TestPoolAndOptionsAdapters(t *testing.T) { + a, err := NewAdapterWithPoolAndOptions(&redis.Pool{ + Dial: func() (redis.Conn, error) { + return redis.Dial("tcp", "127.0.0.1:6379") + }, + }, WithKey("casbin:policy:test")) + if err != nil { + t.Fatal(err) + } + + testSaveLoad(t, a) + testAutoSave(t, a) + testFilteredPolicy(t, a) + testAddPolicies(t, a) + testRemovePolicies(t, a) + testUpdatePolicies(t, a) + testUpdateFilteredPolicies(t, a) +}