diff --git a/src/main/java/org/casbin/adapter/HibernateAdapter.java b/src/main/java/org/casbin/adapter/HibernateAdapter.java index 8a53d2b..e7e45eb 100644 --- a/src/main/java/org/casbin/adapter/HibernateAdapter.java +++ b/src/main/java/org/casbin/adapter/HibernateAdapter.java @@ -10,6 +10,9 @@ import org.hibernate.Transaction; import org.hibernate.cfg.Configuration; +import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.SQLException; import java.util.*; public class HibernateAdapter implements Adapter { @@ -20,22 +23,31 @@ public class HibernateAdapter implements Adapter { private int size = 0; private boolean dbSpecified; private SessionFactory factory; + private DataSource dataSource; + private String databaseProductName; - public HibernateAdapter(String driver, String url, String username, String password) { + public HibernateAdapter(String driver, String url, String username, String password) throws SQLException { this(driver, url, username, password, false); } - public HibernateAdapter(String driver, String url, String username, String password, boolean dbSpecified) { + public HibernateAdapter(DataSource dataSource) throws SQLException { + this.dataSource = dataSource; + + open(); + } + + public HibernateAdapter(String driver, String url, String username, String password, boolean dbSpecified) throws SQLException { this.driver = driver; this.url = url; this.username = username; this.password = password; this.dbSpecified = dbSpecified; + setDatabaseProductName(); open(); } - private void open() { + private void open() throws SQLException { this.factory = initSessionFactory(); if (!this.dbSpecified) { createDatabase(); @@ -47,10 +59,10 @@ private void open() { private void createDatabase() { Session session = factory.openSession(); Transaction tx = session.beginTransaction(); - if (this.driver.contains("mysql")) { + if (this.databaseProductName.contains("MySQL")) { session.createSQLQuery("CREATE DATABASE IF NOT EXISTS casbin").executeUpdate(); session.createSQLQuery("USE casbin").executeUpdate(); - } else if (this.driver.contains("sqlserver")) { + } else if (this.databaseProductName.contains("SQLServer")) { session.createSQLQuery("IF NOT EXISTS (" + "SELECT * FROM sysdatabases WHERE name = 'casbin') CREATE DATABASE casbin ON PRIMARY " + "( NAME = N'casbin', FILENAME = N'C:\\Program Files\\Microsoft SQL Server\\MSSQL.1\\MSSQL\\DATA\\casbinDB.mdf' , SIZE = 3072KB , MAXSIZE = UNLIMITED, FILEGROWTH = 1024KB ) " + @@ -66,7 +78,7 @@ private void createDatabase() { private void createTable() { Session session = factory.openSession(); Transaction tx = session.beginTransaction(); - if (this.driver.contains("mysql")) { + if (this.databaseProductName.contains("MySQL")) { session.createSQLQuery("CREATE TABLE IF NOT EXISTS casbin_rule (" + "id INT not NULL primary key," + "ptype VARCHAR(100) not NULL," + @@ -76,7 +88,7 @@ private void createTable() { "v3 VARCHAR(100)," + "v4 VARCHAR(100)," + "v5 VARCHAR(100))").executeUpdate(); - } else if (this.driver.contains("oracle")) { + } else if (this.databaseProductName.contains("Oracle")) { session.createSQLQuery("declare " + "nCount NUMBER;" + "v_sql LONG;" + @@ -97,7 +109,7 @@ private void createTable() { "execute immediate v_sql;" + "END IF;" + "end;").executeUpdate(); - } else if (this.driver.contains("sqlserver")) { + } else if (this.databaseProductName.contains("SQLServer")) { session.createSQLQuery("if not exists (select * from sysobjects where id = object_id('casbin_rule')) " + "create table casbin_rule (" + " id int, " + @@ -118,9 +130,9 @@ private void createTable() { private void dropTable() { Session session = factory.openSession(); Transaction tx = session.beginTransaction(); - if (this.driver.contains("mysql")) { + if (this.databaseProductName.contains("MySQL")) { session.createSQLQuery("DROP TABLE IF EXISTS casbin_rule").executeUpdate(); - } else if (this.driver.contains("oracle")) { + } else if (this.databaseProductName.contains("Oracle")) { session.createSQLQuery("declare " + "nCount NUMBER;" + "v_sql LONG;" + @@ -132,26 +144,33 @@ private void dropTable() { "execute immediate v_sql;" + "END IF;" + "end;").executeUpdate(); - } else if (this.driver.contains("sqlserver")) { + } else if (this.databaseProductName.contains("SQLServer")) { session.createSQLQuery("if exists (select * from sysobjects where id = object_id('casbin_rule') drop table casbin_rule").executeUpdate(); } tx.commit(); session.close(); } - private SessionFactory initSessionFactory() { + private SessionFactory initSessionFactory() throws SQLException { Configuration configuration = new Configuration(); Properties properties = new Properties(); - properties.setProperty("hibernate.connection.driver_class", this.driver); - properties.setProperty("hibernate.connection.url", this.url); - properties.setProperty("hibernate.connection.username", this.username); - properties.setProperty("hibernate.connection.password", this.password); - if (this.driver.contains("mysql")) { - properties.setProperty("hibernate.dialect", "org.hibernate.dialect.MySQL57Dialect"); - } else if (this.driver.contains("oracle")) { - properties.setProperty("hibernate.dialect", "org.hibernate.dialect.Oracle9iDialect"); - } else if (this.driver.contains("sqlserver")) { - properties.setProperty("hibernate.dialect", "org.hibernate.dialect.SQLServer2012Dialect"); + if (dataSource != null) { + properties.put("hibernate.connection.datasource", this.dataSource); + Connection conn = this.dataSource.getConnection(); + this.databaseProductName = conn.getMetaData().getDatabaseProductName(); + conn.close(); + } else { + properties.setProperty("hibernate.connection.driver_class", this.driver); + properties.setProperty("hibernate.connection.url", this.url); + properties.setProperty("hibernate.connection.username", this.username); + properties.setProperty("hibernate.connection.password", this.password); + if (this.driver.contains("mysql")) { + properties.setProperty("hibernate.dialect", "org.hibernate.dialect.MySQL57Dialect"); + } else if (this.driver.contains("oracle")) { + properties.setProperty("hibernate.dialect", "org.hibernate.dialect.Oracle9iDialect"); + } else if (this.driver.contains("sqlserver")) { + properties.setProperty("hibernate.dialect", "org.hibernate.dialect.SQLServer2012Dialect"); + } } configuration.setProperties(properties); @@ -329,6 +348,18 @@ private void reset() { session.close(); } + private void setDatabaseProductName() { + if (this.driver != null) { + if (this.driver.contains("mysql")) { + this.databaseProductName = "MySQL"; + } else if (this.driver.contains("oracle")) { + this.databaseProductName = "Oracle"; + } else if (this.driver.contains("sqlserver")) { + this.databaseProductName = "SQLServer"; + } + } + } + public int getPolicySize() { return size; } diff --git a/src/test/java/org/casbin/test/HibernateAdapterTest.java b/src/test/java/org/casbin/test/HibernateAdapterTest.java index 059f1cb..1e47f66 100644 --- a/src/test/java/org/casbin/test/HibernateAdapterTest.java +++ b/src/test/java/org/casbin/test/HibernateAdapterTest.java @@ -1,11 +1,14 @@ package org.casbin.test; +import com.mysql.jdbc.jdbc2.optional.MysqlDataSource; import org.casbin.adapter.HibernateAdapter; import org.casbin.jcasbin.main.Enforcer; import org.casbin.jcasbin.persist.Adapter; import org.junit.Before; import org.junit.Test; +import java.sql.SQLException; + import static org.junit.Assert.assertEquals; public class HibernateAdapterTest { @@ -15,7 +18,7 @@ public class HibernateAdapterTest { private static final String PASSWORD = "casbin_test"; @Before - public void initDataBase() { + public void initDataBase() throws SQLException { Enforcer e = new Enforcer("examples/rbac_with_domains_model.conf"); Adapter adapter = new HibernateAdapter(DRIVER, URL, USERNAME, PASSWORD, true); @@ -33,7 +36,7 @@ public void initDataBase() { } @Test - public void testLoadPolicy() { + public void testLoadPolicy() throws SQLException { Enforcer e = new Enforcer("examples/rbac_with_domains_model.conf"); Adapter adapter = new HibernateAdapter(DRIVER, URL, USERNAME, PASSWORD, true); @@ -52,7 +55,31 @@ public void testLoadPolicy() { } @Test - public void testSavePolicy() { + public void testLoadPolicyWithDataSource() throws SQLException { + Enforcer e = new Enforcer("examples/rbac_with_domains_model.conf"); + + MysqlDataSource dataSource = new MysqlDataSource(); + dataSource.setURL(URL); + dataSource.setUser(USERNAME); + dataSource.setPassword(PASSWORD); + + Adapter adapter = new HibernateAdapter(dataSource); + + e.setAdapter(adapter); + e.loadPolicy(); + + testDomainEnforce(e, "alice", "domain1", "data1", "read", true); + testDomainEnforce(e, "alice", "domain1", "data1", "write", true); + testDomainEnforce(e, "alice", "domain1", "data2", "read", false); + testDomainEnforce(e, "alice", "domain1", "data2", "write", false); + testDomainEnforce(e, "bob", "domain2", "data1", "read", false); + testDomainEnforce(e, "bob", "domain2", "data1", "write", false); + testDomainEnforce(e, "bob", "domain2", "data2", "read", true); + testDomainEnforce(e, "bob", "domain2", "data2", "write", true); + } + + @Test + public void testSavePolicy() throws SQLException { Enforcer e = new Enforcer("examples/rbac_with_domains_model.conf"); Adapter adapter = new HibernateAdapter(DRIVER, URL, USERNAME, PASSWORD, true); @@ -80,7 +107,7 @@ public void testSavePolicy() { } @Test - public void testRemovePolicy() { + public void testRemovePolicy() throws SQLException { Enforcer e = new Enforcer("examples/rbac_with_domains_model.conf"); Adapter adapter = new HibernateAdapter(DRIVER, URL, USERNAME, PASSWORD, true);