Skip to content

Commit

Permalink
feat: Implement UpdatableAdapter interface
Browse files Browse the repository at this point in the history
Signed-off-by: closetool <4closetool3@gmail.com>
  • Loading branch information
kilosonc committed May 19, 2021
1 parent 34b16df commit 6c3f36b
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 9 deletions.
151 changes: 144 additions & 7 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,7 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
// This is part of the Auto-Save feature.
func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error {
return a.WithTx(func(tx *ent.Tx) error {
lines := make([]*ent.CasbinRuleCreate, 0)
for _, rule := range rules {
lines = append(lines, a.savePolicyLine(tx, ptype, rule))
}
_, err := tx.CasbinRule.CreateBulk(lines...).Save(a.ctx)
return err
return a.createPolicies(tx, ptype, rules)
})
}

Expand All @@ -215,7 +210,15 @@ func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) err
return a.WithTx(func(tx *ent.Tx) error {
for _, rule := range rules {
instance := a.toInstance(ptype, rule)
if err := tx.CasbinRule.DeleteOne(instance).Exec(a.ctx); err != nil {
if _, err := tx.CasbinRule.Delete().Where(
casbinrule.PtypeEQ(instance.Ptype),
casbinrule.V0EQ(instance.V0),
casbinrule.V1EQ(instance.V1),
casbinrule.V2EQ(instance.V2),
casbinrule.V3EQ(instance.V3),
casbinrule.V4EQ(instance.V4),
casbinrule.V5EQ(instance.V5),
).Exec(a.ctx); err != nil {
return err
}
}
Expand Down Expand Up @@ -319,3 +322,137 @@ func (a *Adapter) savePolicyLine(tx *ent.Tx, ptype string, rule []string) *ent.C

return line
}

// UpdatePolicy updates a policy rule from storage.
// This is part of the Auto-Save feature.
func (a *Adapter) UpdatePolicy(sec string, ptype string, oldRule, newPolicy []string) error {
return a.WithTx(func(tx *ent.Tx) error {
rule := a.toInstance(ptype, oldRule)
line := tx.CasbinRule.Update().Where(
casbinrule.PtypeEQ(rule.Ptype),
casbinrule.V0EQ(rule.V0),
casbinrule.V1EQ(rule.V1),
casbinrule.V2EQ(rule.V2),
casbinrule.V3EQ(rule.V3),
casbinrule.V4EQ(rule.V4),
casbinrule.V5EQ(rule.V5),
)
rule = a.toInstance(ptype, newPolicy)
line.SetV0(rule.V0)
line.SetV1(rule.V1)
line.SetV2(rule.V2)
line.SetV3(rule.V3)
line.SetV4(rule.V4)
line.SetV5(rule.V5)
_, err := line.Save(a.ctx)
return err
})
}

// UpdatePolicies updates some policy rules to storage, like db, redis.
func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) error {
return a.WithTx(func(tx *ent.Tx) error {
for _, policy := range oldRules {
rule := a.toInstance(ptype, policy)
if _, err := tx.CasbinRule.Delete().Where(
casbinrule.PtypeEQ(rule.Ptype),
casbinrule.V0EQ(rule.V0),
casbinrule.V1EQ(rule.V1),
casbinrule.V2EQ(rule.V2),
casbinrule.V3EQ(rule.V3),
casbinrule.V4EQ(rule.V4),
casbinrule.V5EQ(rule.V5),
).Exec(a.ctx); err != nil {
return err
}
}
lines := make([]*ent.CasbinRuleCreate, 0)
for _, policy := range newRules {
lines = append(lines, a.savePolicyLine(tx, ptype, policy))
}
if _, err := tx.CasbinRule.CreateBulk(lines...).Save(a.ctx); err != nil {
return err
}
return nil
})
}

// UpdateFilteredPolicies deletes old rules and adds new rules.
func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) {
oldPolicies := make([][]string, 0)
err := a.WithTx(func(tx *ent.Tx) error {
line := tx.CasbinRule.Query()
if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
line = line.Where(casbinrule.V0EQ(fieldValues[0-fieldIndex]))
}
if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
line = line.Where(casbinrule.V1EQ(fieldValues[1-fieldIndex]))
}
if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
line = line.Where(casbinrule.V2EQ(fieldValues[2-fieldIndex]))
}
if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
line = line.Where(casbinrule.V3EQ(fieldValues[3-fieldIndex]))
}
if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
line = line.Where(casbinrule.V4EQ(fieldValues[4-fieldIndex]))
}
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
line = line.Where(casbinrule.V5EQ(fieldValues[5-fieldIndex]))
}
rules, err := line.All(a.ctx)
if err != nil {
return err
}
for _, rule := range rules {
if _, err := tx.CasbinRule.Delete().Where(
casbinrule.IDEQ(rule.ID),
).Exec(a.ctx); err != nil {
return err
}
}
a.createPolicies(tx, ptype, newPolicies)
for _, rule := range rules {
oldPolicies = append(oldPolicies, CasbinRuleToStringArray(rule))
}
return nil
})
if err != nil {
return nil, err
}
return oldPolicies, nil
}

func (a *Adapter) createPolicies(tx *ent.Tx, ptype string, policies [][]string) error {
lines := make([]*ent.CasbinRuleCreate, 0)
for _, policy := range policies {
lines = append(lines, a.savePolicyLine(tx, ptype, policy))
}
if _, err := tx.CasbinRule.CreateBulk(lines...).Save(a.ctx); err != nil {
return err
}
return nil
}

func CasbinRuleToStringArray(rule *ent.CasbinRule) []string {
arr := make([]string, 0)
if rule.V0 != "" {
arr = append(arr, rule.V0)
}
if rule.V1 != "" {
arr = append(arr, rule.V1)
}
if rule.V2 != "" {
arr = append(arr, rule.V2)
}
if rule.V3 != "" {
arr = append(arr, rule.V3)
}
if rule.V4 != "" {
arr = append(arr, rule.V4)
}
if rule.V5 != "" {
arr = append(arr, rule.V5)
}
return arr
}
18 changes: 16 additions & 2 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ func testAutoSave(t *testing.T, a *Adapter) {
e.RemoveFilteredPolicy(0, "data2_admin")
e.LoadPolicy()
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})

e.RemovePolicies([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
e.LoadPolicy()
testGetPolicy(t, e, [][]string{})
}

//func testFilteredPolicy(t *testing.T, a *Adapter) {
Expand Down Expand Up @@ -225,7 +229,7 @@ func testUpdatePolicies(t *testing.T, a *Adapter) {
e.EnableAutoSave(true)
e.UpdatePolicies([][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}}, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}})
e.LoadPolicy()
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
testGetPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}

func testUpdateFilteredPolicies(t *testing.T, a *Adapter) {
Expand All @@ -234,7 +238,7 @@ func testUpdateFilteredPolicies(t *testing.T, a *Adapter) {

e.EnableAutoSave(true)
e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read")
e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write")
e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2")
e.LoadPolicy()
testGetPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}})
}
Expand Down Expand Up @@ -263,4 +267,14 @@ func TestAdapters(t *testing.T) {
a = initAdapterWithClientInstance(t, db)
testAutoSave(t, a)
testSaveLoad(t, a)

a = initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/casbin")
testUpdatePolicy(t, a)
testUpdatePolicies(t, a)
testUpdateFilteredPolicies(t, a)

a = initAdapter(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable dbname=casbin")
testUpdatePolicy(t, a)
testUpdatePolicies(t, a)
testUpdateFilteredPolicies(t, a)
}

0 comments on commit 6c3f36b

Please sign in to comment.