Skip to content

Commit

Permalink
Fix getGeneratedKeys (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
koxudaxi authored Mar 20, 2021
1 parent d244bbc commit 2df0bbe
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ fun Application.module(testing: Boolean = false) {
)
val executeStatementResponse = try {
val statement = if (request.parameters == null) {
resource.connection.prepareStatementWithReturnGeneratedKeys(request.sql)
resource.connection.prepareStatement(request.sql)
} else {
val parameters = Parameters.parse(request.sql)
val statement =
resource.connection.prepareStatementWithReturnGeneratedKeys(parameters.sql)
resource.connection.prepareStatement(parameters.sql)
parameters.apply(statement, request.parameters.map { Pair(it.name, it.castValue) }.toMap())
statement
}
Expand All @@ -117,7 +117,7 @@ fun Application.module(testing: Boolean = false) {
} else {
ExecuteStatementResponse(
updatedCount,
statement.updateResults.lastOrNull() ?: emptyList(),
getGeneratedKeys(statement.connection),
)
}
if (resource.transactionId == null) {
Expand Down Expand Up @@ -151,7 +151,7 @@ fun Application.module(testing: Boolean = false) {
val batchExecuteStatementResponse = try {
val parameters = Parameters.parse(request.sql)
val statement =
resource.connection.prepareStatementWithReturnGeneratedKeys(parameters.sql)
resource.connection.prepareStatement(parameters.sql, Statement.RETURN_GENERATED_KEYS)

request.parameterSets.forEach { parameterSet ->
parameters.apply(statement, parameterSet.map { Pair(it.name, it.castValue) }.toMap())
Expand Down
24 changes: 11 additions & 13 deletions kotlin/local-data-api/src/com/koxudaxi/localDataApi/LocalDataApi.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ fun createField(resultSet: ResultSet, index: Int): Field {
}
}

fun isPostgreSQL(connection: Connection): Boolean = "PostgreSQL" in connection.metaData.databaseProductName

fun getGeneratedKeys(connection: Connection): List<Field> {
if (isPostgreSQL(connection)) return emptyList()
val resultSet = connection.createStatement().executeQuery("SELECT LAST_INSERT_ID()")
resultSet.next()
return IntRange(1, resultSet.metaData.columnCount).mapNotNull { index ->
resultSet.getInt(index)
}.filter { it > 0 }.map { Field(longValue = it.toLong()) }.toList()
}

val Statement.updateResults: List<List<Field>>
get() {
return this.generatedKeys.let { resultSet ->
Expand Down Expand Up @@ -56,19 +67,6 @@ val Statement.records: List<List<Field>>
return records.toList()
}

fun isReturnGeneratedKeysType(sql: String): Boolean {
val match = Regex("^[^a-zA-Z]*([a-zA-Z]+)").find(sql) ?: return false
return match.destructured.component1().toUpperCase() in listOf("INSERT", "UPDATE", "DELETE")
}

fun Connection.prepareStatementWithReturnGeneratedKeys(sql: String): PreparedStatement {
return if (isReturnGeneratedKeysType(sql)) {
this.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)
} else {
this.prepareStatement(sql)
}
}

fun createColumnMetadata(resultSet: ResultSet): List<ColumnMetadata> {
return resultSet.metaData.let {
IntRange(1, resultSet.metaData.columnCount).map { index ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,72 @@ class ApplicationTest {
}
}

@Test
fun testExecuteInsertParametersPostgresql() {
ResourceManager.INSTANCE.setResource("h2:./test;MODE=PostgreSQL", dummyResourceArn, null, null, emptyMap())

mockkStatic("com.koxudaxi.localDataApi.LocalDataApiKt")
every { isPostgreSQL(any()) } returns true

withTestApplication({ module(testing = true) }) {
handleRequest(HttpMethod.Post, "/Execute") {
addHeader(HttpHeaders.ContentType, "*/*")
setBody(Json.encodeToString(ExecuteStatementRequest(dummyResourceArn, dummySecretArn,
"CREATE TABLE TEST(id INT PRIMARY KEY AUTO_INCREMENT, name VARCHAR(10), age INT)"
)))
}.apply {
assertEquals(
"{\"numberOfRecordsUpdated\":0,\"generatedFields\":[],\"records\":null,\"columnMetadata\":null}",
response.content)
assertEquals(HttpStatusCode.OK, response.status())
}
handleRequest(HttpMethod.Post, "/Execute") {
addHeader(HttpHeaders.ContentType, "*/*")
setBody(Json.encodeToString(ExecuteStatementRequest(dummyResourceArn, dummySecretArn,
"select * from TEST")))
}.apply {
assertEquals(
"{\"numberOfRecordsUpdated\":0,\"generatedFields\":null,\"records\":[],\"columnMetadata\":null}",
response.content)
assertEquals(HttpStatusCode.OK, response.status())
}
handleRequest(HttpMethod.Post, "/Execute") {
addHeader(HttpHeaders.ContentType, "*/*")
setBody(Json.encodeToString(ExecuteStatementRequest(dummyResourceArn, dummySecretArn,
"INSERT INTO TEST (name, age) VALUES ('cat', 1)")))
}.apply {
assertEquals(
"{\"numberOfRecordsUpdated\":1,\"generatedFields\":[],\"records\":null,\"columnMetadata\":null}",
response.content)
assertEquals(HttpStatusCode.OK, response.status())
}

handleRequest(HttpMethod.Post, "/Execute") {
addHeader(HttpHeaders.ContentType, "*/*")
setBody(Json.encodeToString(ExecuteStatementRequest(dummyResourceArn, dummySecretArn,
"INSERT INTO test (name, age) VALUES (:name, :age)", parameters = listOf(
SqlParameter("name", Field(stringValue = "dog")),
SqlParameter("age", Field(longValue = 3)),
))))
}.apply {
assertEquals(
"{\"numberOfRecordsUpdated\":1,\"generatedFields\":[],\"records\":null,\"columnMetadata\":null}",
response.content)
assertEquals(HttpStatusCode.OK, response.status())
}
handleRequest(HttpMethod.Post, "/Execute") {
addHeader(HttpHeaders.ContentType, "*/*")
setBody(Json.encodeToString(ExecuteStatementRequest(dummyResourceArn, dummySecretArn,
"select * from TEST", includeResultMetadata = true)))
}.apply {
assertEquals(
"{\"numberOfRecordsUpdated\":0,\"generatedFields\":null,\"records\":[[{\"blobValue\":null,\"booleanValue\":null,\"doubleValue\":null,\"isNull\":null,\"longValue\":1,\"stringValue\":null},{\"blobValue\":null,\"booleanValue\":null,\"doubleValue\":null,\"isNull\":null,\"longValue\":null,\"stringValue\":\"cat\"},{\"blobValue\":null,\"booleanValue\":null,\"doubleValue\":null,\"isNull\":null,\"longValue\":1,\"stringValue\":null}],[{\"blobValue\":null,\"booleanValue\":null,\"doubleValue\":null,\"isNull\":null,\"longValue\":2,\"stringValue\":null},{\"blobValue\":null,\"booleanValue\":null,\"doubleValue\":null,\"isNull\":null,\"longValue\":null,\"stringValue\":\"dog\"},{\"blobValue\":null,\"booleanValue\":null,\"doubleValue\":null,\"isNull\":null,\"longValue\":3,\"stringValue\":null}]],\"columnMetadata\":[{\"arrayBaseColumnType\":0,\"isAutoIncrement\":true,\"isCaseSensitive\":true,\"isCurrency\":false,\"isSigned\":true,\"label\":\"ID\",\"name\":\"ID\",\"nullable\":0,\"precision\":10,\"scale\":0,\"schemaName\":\"PUBLIC\",\"tableName\":\"TEST\",\"type\":4,\"typeName\":\"INTEGER\"},{\"arrayBaseColumnType\":0,\"isAutoIncrement\":false,\"isCaseSensitive\":true,\"isCurrency\":false,\"isSigned\":true,\"label\":\"NAME\",\"name\":\"NAME\",\"nullable\":1,\"precision\":10,\"scale\":0,\"schemaName\":\"PUBLIC\",\"tableName\":\"TEST\",\"type\":12,\"typeName\":\"VARCHAR\"},{\"arrayBaseColumnType\":0,\"isAutoIncrement\":false,\"isCaseSensitive\":true,\"isCurrency\":false,\"isSigned\":true,\"label\":\"AGE\",\"name\":\"AGE\",\"nullable\":1,\"precision\":10,\"scale\":0,\"schemaName\":\"PUBLIC\",\"tableName\":\"TEST\",\"type\":4,\"typeName\":\"INTEGER\"}]}",
response.content)
assertEquals(HttpStatusCode.OK, response.status())
}
}
}

@Test
fun testExecuteInsertParametersWithTransaction() {
withTestApplication({ module(testing = true) }) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,6 @@ class LocalDataApiTest {
assertEquals(listOf(Types.TIMESTAMP, Types.TIMESTAMP_WITH_TIMEZONE), DATETIME)
}

@Test
fun testIsReturnGeneratedKeysType() {
assertEquals( false, isReturnGeneratedKeysType("selEct 1"))
assertEquals( false, isReturnGeneratedKeysType(" selEct 1"))
assertEquals( true, isReturnGeneratedKeysType(" insert 1"))
assertEquals( true, isReturnGeneratedKeysType(" Update 1"))
assertEquals( true, isReturnGeneratedKeysType(" DELETE 1"))
assertEquals( false, isReturnGeneratedKeysType(" CREATE 1"))
assertEquals( false, isReturnGeneratedKeysType(" ;"))
assertEquals( false, isReturnGeneratedKeysType(""))
}

@Test
fun testMySQL() {
mockkStatic(System::class)
Expand Down Expand Up @@ -78,7 +66,11 @@ class LocalDataApiTest {
val resource = mockk<Resource>(relaxed = true)
mockkStatic(Resource::class)
every {
Resource(Resource.Config("mysql", "arn:aws:rds:us-east-1:123456789012:cluster:dummy", "127.0.0.1", 3306, emptyMap()),
Resource(Resource.Config("mysql",
"arn:aws:rds:us-east-1:123456789012:cluster:dummy",
"127.0.0.1",
3306,
emptyMap()),
"root",
"example",
null,
Expand Down Expand Up @@ -107,7 +99,12 @@ class LocalDataApiTest {
val resource = mockk<Resource>(relaxed = true)
mockkStatic(Resource::class)
every {
Resource(Resource.Config("postgresql", "abc", "localhost", 1234, mapOf("stringtype" to "unspecified")), "user", "pass", null, null, "xyz")
Resource(Resource.Config("postgresql", "abc", "localhost", 1234, mapOf("stringtype" to "unspecified")),
"user",
"pass",
null,
null,
"xyz")
} returns resource

val env = mapOf(
Expand Down

0 comments on commit 2df0bbe

Please sign in to comment.