Skip to content

Commit

Permalink
Merge pull request #8 from shy1st/master
Browse files Browse the repository at this point in the history
fix: Add constructor to datasource.
  • Loading branch information
hsluoyz authored Mar 13, 2021
2 parents 33d35c2 + 25e7345 commit dae97fc
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 26 deletions.
75 changes: 53 additions & 22 deletions src/main/java/org/casbin/adapter/HibernateAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import org.hibernate.Transaction;
import org.hibernate.cfg.Configuration;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.*;

public class HibernateAdapter implements Adapter {
Expand All @@ -20,22 +23,31 @@ public class HibernateAdapter implements Adapter {
private int size = 0;
private boolean dbSpecified;
private SessionFactory factory;
private DataSource dataSource;
private String databaseProductName;

public HibernateAdapter(String driver, String url, String username, String password) {
public HibernateAdapter(String driver, String url, String username, String password) throws SQLException {
this(driver, url, username, password, false);
}

public HibernateAdapter(String driver, String url, String username, String password, boolean dbSpecified) {
public HibernateAdapter(DataSource dataSource) throws SQLException {
this.dataSource = dataSource;

open();
}

public HibernateAdapter(String driver, String url, String username, String password, boolean dbSpecified) throws SQLException {
this.driver = driver;
this.url = url;
this.username = username;
this.password = password;
this.dbSpecified = dbSpecified;

setDatabaseProductName();
open();
}

private void open() {
private void open() throws SQLException {
this.factory = initSessionFactory();
if (!this.dbSpecified) {
createDatabase();
Expand All @@ -47,10 +59,10 @@ private void open() {
private void createDatabase() {
Session session = factory.openSession();
Transaction tx = session.beginTransaction();
if (this.driver.contains("mysql")) {
if (this.databaseProductName.contains("MySQL")) {
session.createSQLQuery("CREATE DATABASE IF NOT EXISTS casbin").executeUpdate();
session.createSQLQuery("USE casbin").executeUpdate();
} else if (this.driver.contains("sqlserver")) {
} else if (this.databaseProductName.contains("SQLServer")) {
session.createSQLQuery("IF NOT EXISTS (" +
"SELECT * FROM sysdatabases WHERE name = 'casbin') CREATE DATABASE casbin ON PRIMARY " +
"( NAME = N'casbin', FILENAME = N'C:\\Program Files\\Microsoft SQL Server\\MSSQL.1\\MSSQL\\DATA\\casbinDB.mdf' , SIZE = 3072KB , MAXSIZE = UNLIMITED, FILEGROWTH = 1024KB ) " +
Expand All @@ -66,7 +78,7 @@ private void createDatabase() {
private void createTable() {
Session session = factory.openSession();
Transaction tx = session.beginTransaction();
if (this.driver.contains("mysql")) {
if (this.databaseProductName.contains("MySQL")) {
session.createSQLQuery("CREATE TABLE IF NOT EXISTS casbin_rule (" +
"id INT not NULL primary key," +
"ptype VARCHAR(100) not NULL," +
Expand All @@ -76,7 +88,7 @@ private void createTable() {
"v3 VARCHAR(100)," +
"v4 VARCHAR(100)," +
"v5 VARCHAR(100))").executeUpdate();
} else if (this.driver.contains("oracle")) {
} else if (this.databaseProductName.contains("Oracle")) {
session.createSQLQuery("declare " +
"nCount NUMBER;" +
"v_sql LONG;" +
Expand All @@ -97,7 +109,7 @@ private void createTable() {
"execute immediate v_sql;" +
"END IF;" +
"end;").executeUpdate();
} else if (this.driver.contains("sqlserver")) {
} else if (this.databaseProductName.contains("SQLServer")) {
session.createSQLQuery("if not exists (select * from sysobjects where id = object_id('casbin_rule')) " +
"create table casbin_rule (" +
" id int, " +
Expand All @@ -118,9 +130,9 @@ private void createTable() {
private void dropTable() {
Session session = factory.openSession();
Transaction tx = session.beginTransaction();
if (this.driver.contains("mysql")) {
if (this.databaseProductName.contains("MySQL")) {
session.createSQLQuery("DROP TABLE IF EXISTS casbin_rule").executeUpdate();
} else if (this.driver.contains("oracle")) {
} else if (this.databaseProductName.contains("Oracle")) {
session.createSQLQuery("declare " +
"nCount NUMBER;" +
"v_sql LONG;" +
Expand All @@ -132,26 +144,33 @@ private void dropTable() {
"execute immediate v_sql;" +
"END IF;" +
"end;").executeUpdate();
} else if (this.driver.contains("sqlserver")) {
} else if (this.databaseProductName.contains("SQLServer")) {
session.createSQLQuery("if exists (select * from sysobjects where id = object_id('casbin_rule') drop table casbin_rule").executeUpdate();
}
tx.commit();
session.close();
}

private SessionFactory initSessionFactory() {
private SessionFactory initSessionFactory() throws SQLException {
Configuration configuration = new Configuration();
Properties properties = new Properties();
properties.setProperty("hibernate.connection.driver_class", this.driver);
properties.setProperty("hibernate.connection.url", this.url);
properties.setProperty("hibernate.connection.username", this.username);
properties.setProperty("hibernate.connection.password", this.password);
if (this.driver.contains("mysql")) {
properties.setProperty("hibernate.dialect", "org.hibernate.dialect.MySQL57Dialect");
} else if (this.driver.contains("oracle")) {
properties.setProperty("hibernate.dialect", "org.hibernate.dialect.Oracle9iDialect");
} else if (this.driver.contains("sqlserver")) {
properties.setProperty("hibernate.dialect", "org.hibernate.dialect.SQLServer2012Dialect");
if (dataSource != null) {
properties.put("hibernate.connection.datasource", this.dataSource);
Connection conn = this.dataSource.getConnection();
this.databaseProductName = conn.getMetaData().getDatabaseProductName();
conn.close();
} else {
properties.setProperty("hibernate.connection.driver_class", this.driver);
properties.setProperty("hibernate.connection.url", this.url);
properties.setProperty("hibernate.connection.username", this.username);
properties.setProperty("hibernate.connection.password", this.password);
if (this.driver.contains("mysql")) {
properties.setProperty("hibernate.dialect", "org.hibernate.dialect.MySQL57Dialect");
} else if (this.driver.contains("oracle")) {
properties.setProperty("hibernate.dialect", "org.hibernate.dialect.Oracle9iDialect");
} else if (this.driver.contains("sqlserver")) {
properties.setProperty("hibernate.dialect", "org.hibernate.dialect.SQLServer2012Dialect");
}
}
configuration.setProperties(properties);

Expand Down Expand Up @@ -329,6 +348,18 @@ private void reset() {
session.close();
}

private void setDatabaseProductName() {
if (this.driver != null) {
if (this.driver.contains("mysql")) {
this.databaseProductName = "MySQL";
} else if (this.driver.contains("oracle")) {
this.databaseProductName = "Oracle";
} else if (this.driver.contains("sqlserver")) {
this.databaseProductName = "SQLServer";
}
}
}

public int getPolicySize() {
return size;
}
Expand Down
35 changes: 31 additions & 4 deletions src/test/java/org/casbin/test/HibernateAdapterTest.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package org.casbin.test;

import com.mysql.jdbc.jdbc2.optional.MysqlDataSource;
import org.casbin.adapter.HibernateAdapter;
import org.casbin.jcasbin.main.Enforcer;
import org.casbin.jcasbin.persist.Adapter;
import org.junit.Before;
import org.junit.Test;

import java.sql.SQLException;

import static org.junit.Assert.assertEquals;

public class HibernateAdapterTest {
Expand All @@ -15,7 +18,7 @@ public class HibernateAdapterTest {
private static final String PASSWORD = "casbin_test";

@Before
public void initDataBase() {
public void initDataBase() throws SQLException {
Enforcer e = new Enforcer("examples/rbac_with_domains_model.conf");

Adapter adapter = new HibernateAdapter(DRIVER, URL, USERNAME, PASSWORD, true);
Expand All @@ -33,7 +36,7 @@ public void initDataBase() {
}

@Test
public void testLoadPolicy() {
public void testLoadPolicy() throws SQLException {
Enforcer e = new Enforcer("examples/rbac_with_domains_model.conf");

Adapter adapter = new HibernateAdapter(DRIVER, URL, USERNAME, PASSWORD, true);
Expand All @@ -52,7 +55,31 @@ public void testLoadPolicy() {
}

@Test
public void testSavePolicy() {
public void testLoadPolicyWithDataSource() throws SQLException {
Enforcer e = new Enforcer("examples/rbac_with_domains_model.conf");

MysqlDataSource dataSource = new MysqlDataSource();
dataSource.setURL(URL);
dataSource.setUser(USERNAME);
dataSource.setPassword(PASSWORD);

Adapter adapter = new HibernateAdapter(dataSource);

e.setAdapter(adapter);
e.loadPolicy();

testDomainEnforce(e, "alice", "domain1", "data1", "read", true);
testDomainEnforce(e, "alice", "domain1", "data1", "write", true);
testDomainEnforce(e, "alice", "domain1", "data2", "read", false);
testDomainEnforce(e, "alice", "domain1", "data2", "write", false);
testDomainEnforce(e, "bob", "domain2", "data1", "read", false);
testDomainEnforce(e, "bob", "domain2", "data1", "write", false);
testDomainEnforce(e, "bob", "domain2", "data2", "read", true);
testDomainEnforce(e, "bob", "domain2", "data2", "write", true);
}

@Test
public void testSavePolicy() throws SQLException {
Enforcer e = new Enforcer("examples/rbac_with_domains_model.conf");

Adapter adapter = new HibernateAdapter(DRIVER, URL, USERNAME, PASSWORD, true);
Expand Down Expand Up @@ -80,7 +107,7 @@ public void testSavePolicy() {
}

@Test
public void testRemovePolicy() {
public void testRemovePolicy() throws SQLException {
Enforcer e = new Enforcer("examples/rbac_with_domains_model.conf");

Adapter adapter = new HibernateAdapter(DRIVER, URL, USERNAME, PASSWORD, true);
Expand Down

0 comments on commit dae97fc

Please sign in to comment.