Skip to content

Commit

Permalink
added test for conflicting names
Browse files Browse the repository at this point in the history
  • Loading branch information
arcuri82 committed Nov 18, 2024
1 parent b140a1c commit 410cf23
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
public class TableDto {

/**
* In Postgres this represents a "schema", whereas it is a "catalog" for MySQL.
* In other words, this is used to group tables that can be indexed by this group name for disambiguation.
* The schema this table belongs to.
* Note that databases like MySQL make no distinction between catalog and schema.
*/
public String openGroupName;
public String schema;


public String catalog;

/**
* The name of the table
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -545,10 +545,10 @@ private static void addConstraints(DbInfoDto schemaDto, List<DbTableConstraint>
}

private static String getId(TableDto dto){
if(dto.openGroupName == null){
if(dto.schema == null){
return dto.name;
}
return dto.openGroupName + "." + dto.name;
return dto.schema + "." + dto.name;
}

private static void handleTableEntry(Connection connection, DbInfoDto schemaDto, DatabaseMetaData md, ResultSet tables, Set<String> tableIds) throws SQLException {
Expand All @@ -571,15 +571,6 @@ private static void handleTableEntry(Connection connection, DbInfoDto schemaDto,
}
}

//no longer done: we extract all schemas for a given catalog
// if (tableSchema!=null && !tableSchema.equalsIgnoreCase(schemaDto.name)) {
// /**
// * If this table does not belong to the current schema under extraction,
// * skip adding the table.
// */
// return;
// }

List<String> toSkip = SchemasToSkip.get(type);
if(toSkip!=null && toSkip.contains(tableSchema)){
return;
Expand All @@ -588,7 +579,8 @@ private static void handleTableEntry(Connection connection, DbInfoDto schemaDto,
TableDto tableDto = new TableDto();
schemaDto.tables.add(tableDto);
tableDto.name = tables.getString("TABLE_NAME");
tableDto.openGroupName = tableSchema;
tableDto.schema = tableSchema;
tableDto.catalog = tableCatalog;

if (tableIds.contains(getId(tableDto))) {
/*
Expand All @@ -601,7 +593,7 @@ private static void handleTableEntry(Connection connection, DbInfoDto schemaDto,

Set<String> pks = new HashSet<>();
SortedMap<Integer, String> primaryKeySequence = new TreeMap<>();
ResultSet rsPK = md.getPrimaryKeys(schemaDto.name, tableDto.openGroupName, tableDto.name);
ResultSet rsPK = md.getPrimaryKeys(tableDto.catalog, tableDto.schema, tableDto.name);

while (rsPK.next()) {
String pkColumnName = rsPK.getString("COLUMN_NAME");
Expand All @@ -614,7 +606,7 @@ private static void handleTableEntry(Connection connection, DbInfoDto schemaDto,

tableDto.primaryKeySequence.addAll(primaryKeySequence.values());

ResultSet columns = md.getColumns(schemaDto.name, tableDto.openGroupName, tableDto.name, null);
ResultSet columns = md.getColumns(tableDto.catalog, tableDto.schema, tableDto.name, null);

Set<String> columnNames = new HashSet<>();
while (columns.next()) {
Expand Down Expand Up @@ -663,7 +655,7 @@ private static void handleTableEntry(Connection connection, DbInfoDto schemaDto,
columns.close();


ResultSet fks = md.getImportedKeys(null, null, tableDto.name);
ResultSet fks = md.getImportedKeys(tableDto.catalog, tableDto.schema, tableDto.name);
while (fks.next()) {
//TODO need to see how to handle case of multi-columns

Expand Down Expand Up @@ -738,7 +730,7 @@ private static void extractMySQLColumn(DbInfoDto schemaDto,
* corresponding [DATA_TYPE] column value.
*/
String sqlQuery = String.format("SELECT DATA_TYPE, table_schema from INFORMATION_SCHEMA.COLUMNS where\n" +
" table_schema = '%s' and table_name = '%s' and column_name= '%s' ", tableDto.openGroupName, tableDto.name, columnDto.name);
" table_schema = '%s' and table_name = '%s' and column_name= '%s' ", tableDto.schema, tableDto.name, columnDto.name);
try (Statement statement = connection.createStatement()) {
ResultSet rs = statement.executeQuery(sqlQuery);
if (rs.next()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private List<DbTableConstraint> extractTableConstraintsVersionTwoOrHigher(Connec
List<DbTableConstraint> tableCheckExpressions = new ArrayList<>();

for (TableDto tableDto : schemaDto.tables) {
String tableSchema = tableDto.openGroupName;
String tableSchema = tableDto.schema;
String tableName = tableDto.name;
try (Statement statement = connectionToH2.createStatement()) {
final String query = String.format("Select CONSTRAINT_CATALOG,CONSTRAINT_SCHEMA,CONSTRAINT_NAME,CONSTRAINT_TYPE From INFORMATION_SCHEMA.TABLE_CONSTRAINTS\n" +
Expand Down Expand Up @@ -192,7 +192,7 @@ private List<DbTableConstraint> extractTableConstraintsVersionOneOrLower(Connect
List<DbTableConstraint> tableCheckExpressions = new ArrayList<>();

for (TableDto tableDto : schemaDto.tables) {
String tableSchema = tableDto.openGroupName;
String tableSchema = tableDto.schema;
String tableName = tableDto.name;
try (Statement statement = connectionToH2.createStatement()) {
final String query = String.format("Select CONSTRAINT_TYPE, CHECK_EXPRESSION, COLUMN_LIST From INFORMATION_SCHEMA.CONSTRAINTS\n" +
Expand Down Expand Up @@ -255,7 +255,7 @@ private List<DbTableConstraint> extractColumnConstraintsVersion1OrLower(Connecti

List<DbTableConstraint> columnConstraints = new ArrayList<>();
for (TableDto tableDto : schemaDto.tables) {
String tableSchema = tableDto.openGroupName;
String tableSchema = tableDto.schema;
String tableName = tableDto.name;

try (Statement statement = connectionToH2.createStatement()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public List<DbTableConstraint> extract(Connection connectionToMySQL, DbInfoDto s
List<DbTableConstraint> constraints = new ArrayList<>();

for (TableDto tableDto : schemaDto.tables){
String tableSchema = tableDto.openGroupName;
String tableSchema = tableDto.schema;
String tableName = tableDto.name;
try (Statement statement = connectionToMySQL.createStatement()) {
String query = String.format("SELECT *\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public List<DbTableConstraint> extract(Connection connectionToPostgres, DbInfoDt

List<DbTableConstraint> constraints = new ArrayList<>();
for (TableDto tableDto : schemaDto.tables) {
String tableSchema = tableDto.openGroupName;
String tableSchema = tableDto.schema;
String tableName = tableDto.name;
try (Statement statement = connectionToPostgres.createStatement()) {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package org.evomaster.core.sql.multidb

import org.evomaster.client.java.controller.api.dto.database.schema.DatabaseType
import org.evomaster.client.java.sql.DbInfoExtractor
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import java.sql.Connection

class ConflictingSchemasTest : MultiDbTestBase() {

override fun verify(databaseType: DatabaseType, connection: Connection, name: String) {

val info = DbInfoExtractor.extract(connection)
assertEquals(name.lowercase(), info.name.lowercase())
assertEquals(2, info.tables.size)
val first = info.tables.find { it.schema.equals("first",true) }!!
val second = info.tables.find { it.schema.equals("other",true) }!!

assertEquals(first.name, second.name)

assertEquals(2, first.columns.size)
assertEquals(2, second.columns.size)

assertTrue(first.columns.any { it.name.equals("x", true) })
assertTrue(second.columns.any { it.name.equals("y",true) })
}

override fun getSchemaLocation()= "/sql_schema/multidb/conflictingschemas.sql"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.evomaster.core.sql.multidb

import org.evomaster.client.java.controller.api.dto.database.schema.DatabaseType
import org.evomaster.client.java.sql.SqlScriptRunner
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.EnumSource
import java.sql.Connection

abstract class MultiDbTestBase {

@ParameterizedTest
@EnumSource(names = ["MYSQL","POSTGRES","H2"])
fun test(databaseType: DatabaseType){
val name = "dbtest"
val sqlSchemaCommand = this::class.java.getResource(getSchemaLocation()).readText()

MultiDbUtils.startDatabase(databaseType)
try {
MultiDbUtils.resetDatabase(name, databaseType)
val connection = MultiDbUtils.createConnection(name, databaseType)
connection.use {
SqlScriptRunner.execCommand(it, sqlSchemaCommand)
verify(databaseType,it, name)
}
}finally {
MultiDbUtils.stopDatabase(databaseType)
}
}

protected abstract fun verify(databaseType: DatabaseType, connection: Connection, name: String)

protected abstract fun getSchemaLocation(): String
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,27 @@ package org.evomaster.core.sql.multidb

import org.evomaster.client.java.controller.api.dto.database.schema.DatabaseType
import org.evomaster.client.java.sql.DbInfoExtractor
import org.evomaster.client.java.sql.SqlScriptRunner
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.EnumSource
import java.sql.Connection

class SecondSchemaTest {
class SecondSchemaTest : MultiDbTestBase(){

@ParameterizedTest
@EnumSource(names = ["MYSQL","POSTGRES","H2"])
fun test(databaseType: DatabaseType){
val name = "dbtest"
val sqlSchemaCommand = this::class.java.getResource(getSchemaLocation()).readText()

MultiDbUtils.startDatabase(databaseType)
try {
MultiDbUtils.resetDatabase(name, databaseType)
val connection = MultiDbUtils.createConnection(name, databaseType)
connection.use {
SqlScriptRunner.execCommand(it, sqlSchemaCommand)
verify(databaseType,it, name)
}
}finally {
MultiDbUtils.stopDatabase(databaseType)
}
}

private fun verify(databaseType: DatabaseType,connection: Connection, name: String){
override fun verify(databaseType: DatabaseType,connection: Connection, name: String){

val info = DbInfoExtractor.extract(connection)
assertEquals(name.lowercase(), info.name.lowercase())
assertEquals(2, info.tables.size)
val foo = info.tables.find { it.name.lowercase() == "foo" }!!
val bar = info.tables.find { it.name.lowercase() == "bar" }!!
if(databaseType == DatabaseType.MYSQL){
assertEquals(name.lowercase(), foo.openGroupName.lowercase())
assertEquals(name.lowercase(), foo.schema.lowercase())
} else {
assertEquals("public", foo.openGroupName.lowercase())
assertEquals("public", foo.schema.lowercase())
}
assertEquals("other", bar.openGroupName.lowercase())
assertEquals("other", bar.schema.lowercase())
}

fun getSchemaLocation() = "/sql_schema/multidb/secondschema.sql"
override fun getSchemaLocation() = "/sql_schema/multidb/secondschema.sql"
}
16 changes: 16 additions & 0 deletions core/src/test/resources/sql_schema/multidb/conflictingschemas.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
CREATE SCHEMA IF NOT EXISTS first;
CREATE SCHEMA IF NOT EXISTS other;


create table first.Foo (
id bigint not null,
x bigint,
primary key (id)
);


create table other.Foo(
id bigint not null,
y bigint,
primary key (id)
);

0 comments on commit 410cf23

Please sign in to comment.