Skip to content

Commit

Permalink
fix(pool): connection pool exhausted (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
cococolanosugar authored Mar 8, 2024
1 parent 9cd9abd commit 2f9f7cf
Showing 1 changed file with 75 additions and 16 deletions.
91 changes: 75 additions & 16 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ func (a *Adapter) getConn() redis.Conn {
return a._conn
}

func (a *Adapter) release(conn redis.Conn) {
if a._pool != nil {
if conn != nil {
conn.Close()
}
}
}

// finalizer is the destructor for Adapter.
func finalizer(a *Adapter) {
if a._conn != nil {
Expand Down Expand Up @@ -111,7 +119,11 @@ func NewAdapterWithKey(network string, address string, key string) (*Adapter, er
func NewAdapterWithPool(pool *redis.Pool) (*Adapter, error) {
a := &Adapter{}
a.key = "casbin_rules"
a._conn = pool.Get()

conn := pool.Get()
defer a.release(conn)

a._conn = conn
a._pool = pool

// Call the destructor when the object is released.
Expand All @@ -127,7 +139,11 @@ func NewAdapterWithPoolAndOptions(pool *redis.Pool, options ...Option) (*Adapter
for _, option := range options {
option(a)
}
a._conn = pool.Get()

conn := pool.Get()
defer a.release(conn)

a._conn = conn
a._pool = pool

// Call the destructor when the object is released.
Expand Down Expand Up @@ -228,7 +244,10 @@ func (a *Adapter) createTable() {
}

func (a *Adapter) dropTable() {
_, _ = a.getConn().Do("DEL", a.key)
conn := a.getConn()
defer a.release(conn)

_, _ = conn.Do("DEL", a.key)
}

func (c *CasbinRule) toStringPolicy() []string {
Expand Down Expand Up @@ -265,14 +284,17 @@ func loadPolicyLine(line CasbinRule, model model.Model) {

// LoadPolicy loads policy from database.
func (a *Adapter) LoadPolicy(model model.Model) error {
num, err := redis.Int(a.getConn().Do("LLEN", a.key))
conn := a.getConn()
defer a.release(conn)

num, err := redis.Int(conn.Do("LLEN", a.key))
if err == redis.ErrNil {
return nil
}
if err != nil {
return err
}
values, err := redis.Values(a.getConn().Do("LRANGE", a.key, 0, num))
values, err := redis.Values(conn.Do("LRANGE", a.key, 0, num))
if err != nil {
return err
}
Expand Down Expand Up @@ -349,7 +371,10 @@ func (a *Adapter) SavePolicy(model model.Model) error {
}
}

_, err := a.getConn().Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...)
conn := a.getConn()
defer a.release(conn)

_, err := conn.Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...)
return err
}

Expand All @@ -360,7 +385,11 @@ func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
if err != nil {
return err
}
_, err = a.getConn().Do("RPUSH", a.key, text)

conn := a.getConn()
defer a.release(conn)

_, err = conn.Do("RPUSH", a.key, text)
return err
}

Expand All @@ -371,7 +400,11 @@ func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
if err != nil {
return err
}
_, err = a.getConn().Do("LREM", a.key, 1, text)

conn := a.getConn()
defer a.release(conn)

_, err = conn.Do("LREM", a.key, 1, text)
return err
}

Expand All @@ -386,19 +419,26 @@ func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error
}
texts = append(texts, text)
}
_, err := a.getConn().Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...)

conn := a.getConn()
defer a.release(conn)

_, err := conn.Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...)
return err
}

// RemovePolicies removes policy rules from the storage.
func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) error {
conn := a.getConn()
defer a.release(conn)

for _, rule := range rules {
line := savePolicyLine(ptype, rule)
text, err := json.Marshal(line)
if err != nil {
return err
}
_, err = a.getConn().Do("LREM", a.key, 1, text)
_, err = conn.Do("LREM", a.key, 1, text)
if err != nil {
return err
}
Expand Down Expand Up @@ -484,14 +524,17 @@ func filterFieldToLuaPattern(sec string, ptype string, fieldIndex int, fieldValu
}

func (a *Adapter) loadFilteredPolicy(model model.Model, filter *Filter) error {
num, err := redis.Int(a.getConn().Do("LLEN", a.key))
conn := a.getConn()
defer a.release(conn)

num, err := redis.Int(conn.Do("LLEN", a.key))
if err == redis.ErrNil {
return nil
}
if err != nil {
return err
}
values, err := redis.Values(a.getConn().Do("LRANGE", a.key, 0, num))
values, err := redis.Values(conn.Do("LRANGE", a.key, 0, num))
if err != nil {
return err
}
Expand Down Expand Up @@ -559,7 +602,11 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
redis.call('lrem', key, 0, '__CASBIN_DELETED__')
return
`)
_, err := getScript.Do(a.getConn(), a.key, pattern)

conn := a.getConn()
defer a.release(conn)

_, err := getScript.Do(conn, a.key, pattern)
return err
}

Expand Down Expand Up @@ -592,7 +639,11 @@ func (a *Adapter) UpdatePolicy(sec string, ptype string, oldRule, newPolicy []st
end
return false
`)
_, err = getScript.Do(a.getConn(), a.key, textOld, textNew)

conn := a.getConn()
defer a.release(conn)

_, err = getScript.Do(conn, a.key, textOld, textNew)
return err
}

Expand Down Expand Up @@ -640,7 +691,11 @@ func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules []
return false
`)
args := redis.Args{}.Add(a.key).AddFlat(oldPolicies).AddFlat(newPolicies)
_, err := getScript.Do(a.getConn(), args...)

conn := a.getConn()
defer a.release(conn)

_, err := getScript.Do(conn, args...)
return err
}

Expand Down Expand Up @@ -684,7 +739,11 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
args := redis.Args{}.Add(a.key).Add(pattern).AddFlat(newP)
//r, err := getScript.Do(a.conn, args...)
//reply, err := redis.Values(r, err)
reply, err := redis.Values(getScript.Do(a.getConn(), args...))

conn := a.getConn()
defer a.release(conn)

reply, err := redis.Values(getScript.Do(conn, args...))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 2f9f7cf

Please sign in to comment.