diff --git a/adapter.go b/adapter.go index 7681584..65daefd 100644 --- a/adapter.go +++ b/adapter.go @@ -236,22 +236,22 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, return a.WithTx(func(tx *ent.Tx) error { cond := make([]predicate.CasbinRule, 0) cond = append(cond, casbinrule.PtypeEQ(ptype)) - if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) { + if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) && len(fieldValues[0-fieldIndex]) > 0 { cond = append(cond, casbinrule.V0EQ(fieldValues[0-fieldIndex])) } - if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) { + if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) && len(fieldValues[1-fieldIndex]) > 0 { cond = append(cond, casbinrule.V1EQ(fieldValues[1-fieldIndex])) } - if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) { + if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) && len(fieldValues[2-fieldIndex]) > 0 { cond = append(cond, casbinrule.V2EQ(fieldValues[2-fieldIndex])) } - if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) { + if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) && len(fieldValues[3-fieldIndex]) > 0 { cond = append(cond, casbinrule.V3EQ(fieldValues[3-fieldIndex])) } - if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) { + if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) && len(fieldValues[4-fieldIndex]) > 0 { cond = append(cond, casbinrule.V4EQ(fieldValues[4-fieldIndex])) } - if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) { + if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) && len(fieldValues[5-fieldIndex]) > 0 { cond = append(cond, casbinrule.V5EQ(fieldValues[5-fieldIndex])) } _, err := tx.CasbinRule.Delete().Where( @@ -446,37 +446,47 @@ func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [] func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) { oldPolicies := make([][]string, 0) err := a.WithTx(func(tx *ent.Tx) error { - line := tx.CasbinRule.Query() - if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) { - line = line.Where(casbinrule.V0EQ(fieldValues[0-fieldIndex])) + cond := make([]predicate.CasbinRule, 0) + cond = append(cond, casbinrule.PtypeEQ(ptype)) + if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) && len(fieldValues[0-fieldIndex]) > 0 { + cond = append(cond, casbinrule.V0EQ(fieldValues[0-fieldIndex])) } - if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) { - line = line.Where(casbinrule.V1EQ(fieldValues[1-fieldIndex])) + if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) && len(fieldValues[1-fieldIndex]) > 0 { + cond = append(cond, casbinrule.V1EQ(fieldValues[1-fieldIndex])) } - if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) { - line = line.Where(casbinrule.V2EQ(fieldValues[2-fieldIndex])) + if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) && len(fieldValues[2-fieldIndex]) > 0 { + cond = append(cond, casbinrule.V2EQ(fieldValues[2-fieldIndex])) } - if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) { - line = line.Where(casbinrule.V3EQ(fieldValues[3-fieldIndex])) + if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) && len(fieldValues[3-fieldIndex]) > 0 { + cond = append(cond, casbinrule.V3EQ(fieldValues[3-fieldIndex])) } - if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) { - line = line.Where(casbinrule.V4EQ(fieldValues[4-fieldIndex])) + if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) && len(fieldValues[4-fieldIndex]) > 0 { + cond = append(cond, casbinrule.V4EQ(fieldValues[4-fieldIndex])) } - if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) { - line = line.Where(casbinrule.V5EQ(fieldValues[5-fieldIndex])) + if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) && len(fieldValues[5-fieldIndex]) > 0 { + cond = append(cond, casbinrule.V5EQ(fieldValues[5-fieldIndex])) } - rules, err := line.All(a.ctx) + rules, err := tx.CasbinRule.Query(). + Where(cond...). + All(a.ctx) if err != nil { return err } - for _, rule := range rules { - if _, err := tx.CasbinRule.Delete().Where( - casbinrule.IDEQ(rule.ID), - ).Exec(a.ctx); err != nil { - return err - } + ruleIDs := make([]int, 0, len(rules)) + for _, r := range rules { + ruleIDs = append(ruleIDs, r.ID) + } + + _, err = tx.CasbinRule.Delete(). + Where(casbinrule.IDIn(ruleIDs...)). + Exec(a.ctx) + if err != nil { + return err + } + + if err := a.createPolicies(tx, ptype, newPolicies); err != nil { + return err } - a.createPolicies(tx, ptype, newPolicies) for _, rule := range rules { oldPolicies = append(oldPolicies, CasbinRuleToStringArray(rule)) }