diff --git a/README.md b/README.md index 72add83..d7ecde3 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,49 @@ func main() { e.SavePolicy() } ``` +## Advanced Example + +```go +package main + +import ( + "github.com/casbin/casbin/v2" + "github.com/casbin/mongodb-adapter/v3" + mongooptions "go.mongodb.org/mongo-driver/mongo/options" +) + +func main() { + // Initialize a MongoDB adapter with NewAdapterWithClientOption: + // The adapter will use custom mongo client options. + // custom database name. + // default collection name 'casbin_rule'. + mongoClientOption := mongooptions.Client().ApplyURI("mongodb://127.0.0.1:27017") + databaseName := "casbin" + a,err := mongodbadapter.NewAdapterWithClientOption(mongoClientOption, databaseName) + // Or you can use NewAdapterWithCollectionName for custom collection name. + if err != nil { + panic(err) + } + + e, err := casbin.NewEnforcer("examples/rbac_model.conf", a) + if err != nil { + panic(err) + } + + // Load the policy from DB. + e.LoadPolicy() + + // Check the permission. + e.Enforce("alice", "data1", "read") + + // Modify the policy. + // e.AddPolicy(...) + // e.RemovePolicy(...) + + // Save the policy back to DB. + e.SavePolicy() +} +``` ## Filtered Policies diff --git a/adapter.go b/adapter.go index 510d0eb..f2e7e31 100644 --- a/adapter.go +++ b/adapter.go @@ -32,6 +32,8 @@ import ( ) const defaultTimeout time.Duration = 30 * time.Second +const defaultDatabaseName string = "casbin" +const defaultCollectionName string = "casbin_rule" // CasbinRule represents a rule in Casbin. type CasbinRule struct { @@ -60,6 +62,7 @@ func finalizer(a *adapter) { // NewAdapter is the constructor for Adapter. If database name is not provided // in the Mongo URL, 'casbin' will be used as database name. +// 'casbin_rule' will be used as a collection name. func NewAdapter(url string, timeout ...interface{}) (persist.BatchAdapter, error) { if !strings.HasPrefix(url, "mongodb+srv://") && !strings.HasPrefix(url, "mongodb://") { url = fmt.Sprint("mongodb://" + url) @@ -79,15 +82,26 @@ func NewAdapter(url string, timeout ...interface{}) (persist.BatchAdapter, error if connString.Database != "" { databaseName = connString.Database } else { - databaseName = "casbin_rule" + databaseName = defaultDatabaseName } - return NewAdapterWithClientOption(clientOption, databaseName, timeout...) + return baseNewAdapter(clientOption, databaseName, defaultCollectionName, timeout...) } // NewAdapterWithClientOption is an alternative constructor for Adapter -// that does the same as NewAdapter, but uses mongo.ClientOption instead of a Mongo URL +// that does the same as NewAdapter, but uses mongo.ClientOption instead of a Mongo URL + a databaseName option func NewAdapterWithClientOption(clientOption *options.ClientOptions, databaseName string, timeout ...interface{}) (persist.BatchAdapter, error) { + return baseNewAdapter(clientOption, databaseName, defaultCollectionName, timeout...) +} + +// NewAdapterWithCollectionName is an alternative constructor for Adapter +// that does the same as NewAdapterWithClientOption, but with an extra collectionName option +func NewAdapterWithCollectionName(clientOption *options.ClientOptions, databaseName string, collectionName string, timeout ...interface{}) (persist.BatchAdapter, error) { + return baseNewAdapter(clientOption, databaseName, collectionName, timeout...) +} + +// baseNewAdapter is a base constructor for Adapter +func baseNewAdapter(clientOption *options.ClientOptions, databaseName string, collectionName string, timeout ...interface{}) (persist.BatchAdapter, error) { a := &adapter{ clientOption: clientOption, } @@ -102,7 +116,7 @@ func NewAdapterWithClientOption(clientOption *options.ClientOptions, databaseNam } // Open the DB, create it if not existed. - err := a.open(databaseName) + err := a.open(databaseName, collectionName) if err != nil { return nil, err } @@ -125,7 +139,7 @@ func NewFilteredAdapter(url string) (persist.FilteredAdapter, error) { return a.(*adapter), nil } -func (a *adapter) open(databaseName string) error { +func (a *adapter) open(databaseName string, collectionName string) error { ctx, cancel := context.WithTimeout(context.TODO(), a.timeout) defer cancel() @@ -135,7 +149,7 @@ func (a *adapter) open(databaseName string) error { } db := client.Database(databaseName) - collection := db.Collection("casbin_rule") + collection := db.Collection(collectionName) a.client = client a.collection = collection diff --git a/adapter_test.go b/adapter_test.go index 065a6a6..7fad8c6 100644 --- a/adapter_test.go +++ b/adapter_test.go @@ -17,11 +17,13 @@ package mongodbadapter import ( "fmt" "os" + "strings" "testing" "github.com/casbin/casbin/v2" "github.com/casbin/casbin/v2/util" "go.mongodb.org/mongo-driver/bson" + mongooptions "go.mongodb.org/mongo-driver/mongo/options" ) var testDbURL = os.Getenv("TEST_MONGODB_URL") @@ -104,7 +106,7 @@ func TestAdapter(t *testing.T) { {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, - }, + }, ) // AutoSave is enabled by default. // Now we disable it. @@ -122,7 +124,7 @@ func TestAdapter(t *testing.T) { {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, - }, + }, ) // Now we enable the AutoSave. @@ -142,8 +144,8 @@ func TestAdapter(t *testing.T) { {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}, - }, - ) + }, + ) // Remove the added rule. e.RemovePolicy("alice", "data1", "write") @@ -158,7 +160,7 @@ func TestAdapter(t *testing.T) { {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, - }, + }, ) // Remove "data2_admin" related policy rules via a filter. @@ -170,7 +172,7 @@ func TestAdapter(t *testing.T) { testGetPolicy(t, e, [][]string{ {"alice", "data1", "read"}, {"bob", "data2", "write"}, - }, + }, ) e.RemoveFilteredPolicy(1, "data1") @@ -204,15 +206,15 @@ func TestAddPolicies(t *testing.T) { {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, - }, + }, ) - a.AddPolicies("p","p",[][]string{ + a.AddPolicies("p", "p", [][]string{ {"bob", "data2", "read"}, {"alice", "data2", "write"}, {"alice", "data2", "read"}, {"bob", "data1", "write"}, {"bob", "data1", "read"}, - }, + }, ) if err := e.LoadPolicy(); err != nil { @@ -222,14 +224,14 @@ func TestAddPolicies(t *testing.T) { testGetPolicy(t, e, [][]string{ {"alice", "data1", "read"}, {"bob", "data2", "write"}, - {"data2_admin", "data2", "read"}, + {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}, {"alice", "data2", "write"}, {"alice", "data2", "read"}, {"bob", "data1", "write"}, {"bob", "data1", "read"}, - }, + }, ) // Remove the added rule. @@ -248,10 +250,10 @@ func TestAddPolicies(t *testing.T) { testGetPolicy(t, e, [][]string{ {"alice", "data1", "read"}, {"bob", "data2", "write"}, - {"data2_admin", "data2", "read"}, + {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, - }, - ) + }, + ) } func TestDeleteFilteredAdapter(t *testing.T) { @@ -322,7 +324,7 @@ func TestFilteredAdapter(t *testing.T) { if err != nil { panic(err) } - + // Load filtered policies from the database. e.AddPolicy("alice", "data1", "write") e.AddPolicy("bob", "data2", "write") @@ -343,7 +345,7 @@ func TestFilteredAdapter(t *testing.T) { testGetPolicy(t, e, [][]string{ {"alice", "data1", "read"}, {"alice", "data1", "write"}, - }, + }, ) // Test safe handling of SavePolicy when using filtered policies. @@ -400,3 +402,30 @@ func TestNewAdapterWithDatabase(t *testing.T) { panic(err) } } + +func TestNewAdapterWithClientOption(t *testing.T) { + uri := getDbURL() + if !strings.HasPrefix(uri, "mongodb+srv://") && !strings.HasPrefix(uri, "mongodb://") { + uri = fmt.Sprint("mongodb://" + uri) + } + mongoClientOption := mongooptions.Client().ApplyURI(uri) + databaseName := "casbin_custom" + _, err := NewAdapterWithClientOption(mongoClientOption, databaseName) + if err != nil { + panic(err) + } +} + +func TestNewAdapterWithCollectionName(t *testing.T) { + uri := getDbURL() + if !strings.HasPrefix(uri, "mongodb+srv://") && !strings.HasPrefix(uri, "mongodb://") { + uri = fmt.Sprint("mongodb://" + uri) + } + mongoClientOption := mongooptions.Client().ApplyURI(uri) + databaseName := "casbin_custom" + collectionName := "casbin_rule_custom" + _, err := NewAdapterWithCollectionName(mongoClientOption, databaseName, collectionName) + if err != nil { + panic(err) + } +}