Skip to content

Commit

Permalink
Add a SQL files generator
Browse files Browse the repository at this point in the history
  • Loading branch information
loicknuchel committed Oct 31, 2020
1 parent 5e8a117 commit 169ee50
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 42 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# SafeQL [![travis-badge][]][travis] [![codecov-badge][]][codecov] [![release-badge][]][release] [![maven-badge][]][maven] [![license-badge][]][license]
# SafeQL

[![travis-badge][]][travis] [![codecov-badge][]][codecov] [![release-badge][]][release] [![maven-badge][]][maven] [![license-badge][]][license]

[travis]: https://travis-ci.com/loicknuchel/SafeQL
[travis-badge]: https://travis-ci.com/loicknuchel/SafeQL.svg?branch=master
Expand Down Expand Up @@ -54,4 +56,4 @@ val postsWithUsers: List[(Post, User)] = POSTS.joinOn(_.AUTHOR).select.all[(Post
## Releasing

Every commit on master is [released as SNAPSHOT](https://oss.sonatype.org/#nexus-search;quick~fr.loicknuchel) so you can use it immediately thanks to [sbt-ci-release](https://github.com/olafurpg/sbt-ci-release) plugin.
To push a stable release, push a tag version starting with 'v' (ex: `v0.1.0`).
To push a stable release, push a tag version starting with 'v' (ex: `v0.1.0`) or create a release from github interface with a correct tage name.
3 changes: 3 additions & 0 deletions src/main/scala/fr/loicknuchel/safeql/Exceptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,8 @@ case class InvalidNumberOfFields[A](read: Read[A], fields: List[Field[_]])
case class FailedQuery(fr: Fragment, cause: Throwable)
extends Exception(s"Fail on ${fr.query.sql}: ${cause.getMessage}", cause)

case class FailedScript(script: String, cause: Throwable)
extends Exception(s"Script has an error: ${cause.getMessage}", cause)

case class NotImplementedJoin[T <: Table, T2 <: Table](t: T, t2: T2)
extends Exception(s"Join between ${t.sql} and ${t2.sql} is not implemented")
64 changes: 54 additions & 10 deletions src/main/scala/fr/loicknuchel/safeql/gen/Generator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@ package fr.loicknuchel.safeql.gen
import java.util.UUID

import cats.effect.IO
import doobie.Update0
import doobie.syntax.connectionio._
import fr.loicknuchel.safeql.FailedScript
import fr.loicknuchel.safeql.gen.reader.{H2Reader, Reader}
import fr.loicknuchel.safeql.gen.writer.Writer
import fr.loicknuchel.safeql.utils.Extensions._
import fr.loicknuchel.safeql.utils.FileUtils
import org.flywaydb.core.Flyway
import org.flywaydb.core.internal.jdbc.DriverDataSource

import scala.util.control.NonFatal

object Generator {
def generate(reader: Reader, writer: Writer): IO[Unit] = for {
database <- reader.read()
_ <- writer.write(database).toIO
} yield ()
/**
* Flyway Generator
*/

def flyway(flywayLocations: String*): FlywayGeneratorBuilder = {
val reader = H2Reader(
Expand All @@ -24,21 +29,56 @@ object Generator {
.dataSource(new DriverDataSource(this.getClass.getClassLoader, reader.driver, reader.url, reader.user, reader.pass))
.locations(flywayLocations: _*)
.load()
new FlywayGeneratorBuilder(reader, flyway)
new FlywayGeneratorBuilder(flyway, reader)
}

class FlywayGeneratorBuilder(reader: H2Reader, flyway: Flyway) {
def writer(writer: Writer): FlywayGenerator = new FlywayGenerator(reader, flyway, writer)
class FlywayGeneratorBuilder(flyway: Flyway, reader: H2Reader) {
def writer(writer: Writer): FlywayGenerator = new FlywayGenerator(flyway, reader, writer)

def excludes(regex: String): FlywayGeneratorBuilder = new FlywayGeneratorBuilder(reader.excludes(regex), flyway)
def excludes(regex: String): FlywayGeneratorBuilder = new FlywayGeneratorBuilder(flyway, reader.excludes(regex))
}

class FlywayGenerator(reader: H2Reader, flyway: Flyway, writer: Writer) {
class FlywayGenerator(flyway: Flyway, reader: H2Reader, writer: Writer) {
def generate(): IO[Unit] = IO(flyway.migrate()).flatMap(_ => Generator.generate(reader, writer))

def excludes(regex: String): FlywayGenerator = new FlywayGenerator(reader.excludes(regex), flyway, writer)
def excludes(regex: String): FlywayGenerator = new FlywayGenerator(flyway, reader.excludes(regex), writer)
}

/**
* SQL files Generator
*/

def fromFiles(paths: List[String]): SQLFilesGeneratorBuilder = {
val reader = H2Reader(
url = s"jdbc:h2:mem:${UUID.randomUUID()};MODE=PostgreSQL;DATABASE_TO_UPPER=false;DB_CLOSE_DELAY=-1",
schema = Some("PUBLIC"),
excludes = None)
new SQLFilesGeneratorBuilder(paths, reader)
}

class SQLFilesGeneratorBuilder(paths: List[String], reader: H2Reader) {
def writer(writer: Writer): SQLFilesGenerator = new SQLFilesGenerator(paths, reader, writer)

def excludes(regex: String): SQLFilesGeneratorBuilder = new SQLFilesGeneratorBuilder(paths, reader.excludes(regex))
}

class SQLFilesGenerator(paths: List[String], reader: H2Reader, writer: Writer) {
def generate(): IO[Unit] = for {
files <- paths.map(FileUtils.read).sequence.toIO
_ <- files.map(exec(_, reader.xa)).sequence
_ <- Generator.generate(reader, writer)
} yield ()

def excludes(regex: String): SQLFilesGenerator = new SQLFilesGenerator(paths, reader.excludes(regex), writer)

private def exec(script: String, xa: doobie.Transactor[IO]): IO[Int] =
Update0(script, None).run.transact(xa).recoverWith { case NonFatal(e) => IO.raiseError(FailedScript(script, e)) }
}

/**
* Reader Generator
*/

def reader(reader: Reader) = new ReaderGeneratorBuilder(reader)

class ReaderGeneratorBuilder(reader: Reader) {
Expand All @@ -49,4 +89,8 @@ object Generator {
def generate(): IO[Unit] = Generator.generate(reader, writer)
}

private def generate(reader: Reader, writer: Writer): IO[Unit] = for {
database <- reader.read()
_ <- writer.write(database).toIO
} yield ()
}
10 changes: 4 additions & 6 deletions src/main/scala/fr/loicknuchel/safeql/gen/reader/H2Reader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ class H2Reader(val url: String,
schema: Option[String],
excludes: Option[String]) extends Reader {
val driver: String = "org.h2.Driver"
protected[gen] lazy val xa: doobie.Transactor[IO] = {
implicit val cs: ContextShift[IO] = IO.contextShift(ExecutionContext.global)
Transactor.fromDriverManager[IO](driver, url, user, pass)
}

override def read(): IO[Database] = for {
xa <- IO.pure(getTransactor)
columns <- readColumns(xa)
crossReferences <- readCrossReferences(xa)
} yield buildDatabase(columns, crossReferences)
Expand All @@ -36,11 +39,6 @@ class H2Reader(val url: String,
}.filterNot(s => excludes.exists(e => s.name.matches(e))))
}

protected[reader] def getTransactor: doobie.Transactor[IO] = {
implicit val cs: ContextShift[IO] = IO.contextShift(ExecutionContext.global)
Transactor.fromDriverManager[IO](driver, url, user, pass)
}

protected[reader] def readColumns(xa: doobie.Transactor[IO]): IO[List[Column]] =
(fr0"SELECT TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, ORDINAL_POSITION, DOMAIN_CATALOG, DOMAIN_SCHEMA, DOMAIN_NAME, COLUMN_DEFAULT, IS_NULLABLE, DATA_TYPE, CHARACTER_MAXIMUM_LENGTH, CHARACTER_OCTET_LENGTH, NUMERIC_PRECISION, NUMERIC_PRECISION_RADIX, NUMERIC_SCALE, DATETIME_PRECISION, INTERVAL_TYPE, INTERVAL_PRECISION, CHARACTER_SET_NAME, COLLATION_NAME, TYPE_NAME, NULLABLE, IS_COMPUTED, SELECTIVITY, CHECK_CONSTRAINT, SEQUENCE_NAME, REMARKS, SOURCE_DATA_TYPE, COLUMN_TYPE, COLUMN_ON_UPDATE, IS_VISIBLE FROM INFORMATION_SCHEMA.COLUMNS" ++
schema.map(s => fr0" WHERE TABLE_SCHEMA=$s").getOrElse(fr0"") ++ fr0" ORDER BY TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, ORDINAL_POSITION").query[Column].to[List].transact(xa)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ class ScalaWriter(directory: String,
require(config.getConfigErrors.isEmpty, s"DatabaseConfig has some errors :${config.getConfigErrors.map("\n - " + _).mkString}")
require(StringUtils.isScalaPackage(packageName), s"'$packageName' is an invalid scala package name")

def directory(dir: String): ScalaWriter = new ScalaWriter(dir, packageName, identifierStrategy, config)

override protected def getDatabaseErrors(db: Database): List[String] = config.getDatabaseErrors(db)

override protected[writer] def rootFolderPath: String = directory + "/" + packageName.replaceAll("\\.", "/")
override protected[gen] def rootFolderPath: String = directory + "/" + packageName.replaceAll("\\.", "/")

override protected[writer] def tableFilePath(t: Table): String = tablesFolderPath + "/" + idf(t.name) + ".scala"

Expand Down
14 changes: 14 additions & 0 deletions src/main/scala/fr/loicknuchel/safeql/utils/Extensions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,20 @@ object Extensions {
}
}

implicit class RichTraversableOnceIO[A, M[X] <: TraversableOnce[X]](val in: M[IO[A]]) extends AnyVal {
def sequence(implicit cbf: CanBuildFrom[M[IO[A]], A, M[A]]): IO[M[A]] = IO {
val init = IO.pure(cbf(in) -> List.empty[Throwable])
in.foldLeft(init) { (acc, cur) =>
acc.flatMap { case (results, errors) =>
Try(cur.unsafeRunSync())
.map { result => (results += result, errors) }
.recover { case NonFatal(error) => (results, error +: errors) }
.toIO
}
}.flatMap(sequenceResult[A, M](_).toIO).unsafeRunSync()
}
}

implicit class RichTraversableOnceFragment[M[X] <: TraversableOnce[X]](val in: M[Fragment]) extends AnyVal {
def mkFragment(sep: Fragment): Fragment = in.toList match {
case Nil => fr0""
Expand Down
67 changes: 52 additions & 15 deletions src/test/scala/fr/loicknuchel/safeql/gen/GeneratorSpec.scala
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
package fr.loicknuchel.safeql.gen

import java.util.UUID

import cats.data.NonEmptyList
import fr.loicknuchel.safeql.gen.reader.H2Reader
import fr.loicknuchel.safeql.gen.writer.ScalaWriter.{DatabaseConfig, FieldConfig, SchemaConfig, TableConfig}
import fr.loicknuchel.safeql.gen.writer.{ScalaWriter, Writer}
import fr.loicknuchel.safeql.testingutils.SqlSpec
import fr.loicknuchel.safeql.testingutils.BaseSpec
import fr.loicknuchel.safeql.utils.Extensions._
import fr.loicknuchel.safeql.utils.FileUtils
import org.flywaydb.core.Flyway
import org.flywaydb.core.internal.jdbc.DriverDataSource
import org.scalatest.BeforeAndAfterEach

import scala.util.Try

class GeneratorSpec extends SqlSpec {
class GeneratorSpec extends BaseSpec with BeforeAndAfterEach {
private val root = "target/tmp-generator-tests"
private val reader = H2Reader(
url = dbUrl,
user = dbUser,
pass = dbPass,
url = s"jdbc:h2:mem:${UUID.randomUUID()};MODE=PostgreSQL;DATABASE_TO_UPPER=false;DB_CLOSE_DELAY=-1",
schema = Some("PUBLIC"),
excludes = Some(".*flyway.*"))
private val writer = ScalaWriter(
Expand All @@ -29,18 +37,47 @@ class GeneratorSpec extends SqlSpec {
"id" -> FieldConfig(customType = Some("Post.Id"))))
)))))

override protected def afterEach(): Unit = FileUtils.delete(root).get

describe("Generator") {
ignore("should generate database tables") {
Generator.generate(reader, writer).unsafeRunSync()
it("should generate the same files with all the generators") {
// Basic generation
Flyway.configure()
.dataSource(new DriverDataSource(this.getClass.getClassLoader, reader.driver, reader.url, reader.user, reader.pass))
.locations("classpath:sql_migrations")
.load().migrate()
val basicPath = s"$root/basic-gen"
Generator.reader(reader).writer(writer.directory(basicPath)).generate().unsafeRunSync()
val basicDb = getFolderContent(basicPath).get

// Flyway generator
val flywapPath = s"$root/flyway-gen"
Generator.flyway("classpath:sql_migrations").writer(writer.directory(flywapPath)).generate().unsafeRunSync()
val flywayDb = getFolderContent(flywapPath).get
flywayDb shouldBe basicDb

// SQL files generator
val sqlFilesPath = s"$root/sql-gen"
Generator.fromFiles(List("src/test/resources/sql_migrations/V1__test_schema.sql")).writer(writer.directory(sqlFilesPath)).generate().unsafeRunSync()
val sqlFilesDb = getFolderContent(sqlFilesPath).get
sqlFilesDb shouldBe basicDb
}
it("should generate same files as before") {
val existingFiles = writer.readFiles().get
val database = reader.read().unsafeRunSync()
val newFiles = writer.generateFiles(database)
newFiles.size shouldBe existingFiles.size
newFiles.map { case (path, content) =>
content.trim shouldBe existingFiles.getOrElse(path, "").trim
}
it("should keep the generated database up to date") {
val flywayWriter = writer.directory(s"$root/flyway-gen")
Generator.flyway("classpath:sql_migrations").writer(flywayWriter).generate().unsafeRunSync()

val flywayDb = getFolderContent(flywayWriter.rootFolderPath).get
val currentDb = getFolderContent(writer.rootFolderPath).get
currentDb shouldBe flywayDb
}
ignore("should generate the database tables") { // run this test to generate the test database tables
Generator.reader(reader).writer(writer).generate().unsafeRunSync()
}
}

private def getFolderContent(path: String): Try[Map[String, String]] = {
FileUtils.listFiles(path)
.flatMap(_.map(p => FileUtils.read(p).map(c => (p.stripPrefix(path), c))).sequence)
.map(_.toMap)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ import org.scalatest.BeforeAndAfterAll

class H2ReaderSpec extends BaseSpec with BeforeAndAfterAll {
private val reader = H2Reader("jdbc:h2:mem:reader_db;MODE=PostgreSQL;DATABASE_TO_UPPER=false;DB_CLOSE_DELAY=-1", schema = Some("PUBLIC"))
private val xa = reader.getTransactor

override def beforeAll(): Unit = {
sql"CREATE TABLE users (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))".update.run.transact(xa).unsafeRunSync()
sql"CREATE TABLE posts (id INT NOT NULL PRIMARY KEY, title VARCHAR(50), author INT NOT NULL REFERENCES users (id))".update.run.transact(xa).unsafeRunSync()
sql"CREATE TABLE users (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))".update.run.transact(reader.xa).unsafeRunSync()
sql"CREATE TABLE posts (id INT NOT NULL PRIMARY KEY, title VARCHAR(50), author INT NOT NULL REFERENCES users (id))".update.run.transact(reader.xa).unsafeRunSync()
()
}

Expand All @@ -39,25 +38,25 @@ class H2ReaderSpec extends BaseSpec with BeforeAndAfterAll {
))))))
}
it("should read columns") {
reader.readColumns(xa).unsafeRunSync() shouldBe List(
reader.readColumns(reader.xa).unsafeRunSync() shouldBe List(
Column("reader_db", "PUBLIC", "posts", "id", 1, None, None, None, None, false, 4, 10, 10, 10, 10, 0, None, None, None, "Unicode", "OFF", "INTEGER", false, false, 50, "", None, "", None, "INT NOT NULL", None, true),
Column("reader_db", "PUBLIC", "posts", "title", 2, None, None, None, None, true, 12, 50, 50, 50, 10, 0, None, None, None, "Unicode", "OFF", "VARCHAR", true, false, 50, "", None, "", None, "VARCHAR(50)", None, true),
Column("reader_db", "PUBLIC", "posts", "author", 3, None, None, None, None, false, 4, 10, 10, 10, 10, 0, None, None, None, "Unicode", "OFF", "INTEGER", false, false, 50, "", None, "", None, "INT NOT NULL", None, true),
Column("reader_db", "PUBLIC", "users", "id", 1, None, None, None, None, false, 4, 10, 10, 10, 10, 0, None, None, None, "Unicode", "OFF", "INTEGER", false, false, 50, "", None, "", None, "INT NOT NULL", None, true),
Column("reader_db", "PUBLIC", "users", "name", 2, None, None, None, None, true, 12, 50, 50, 50, 10, 0, None, None, None, "Unicode", "OFF", "VARCHAR", true, false, 50, "", None, "", None, "VARCHAR(50)", None, true))
}
it("should read constraints") {
reader.readConstraints(xa).unsafeRunSync() shouldBe List(
reader.readConstraints(reader.xa).unsafeRunSync() shouldBe List(
Constraint("reader_db", "PUBLIC", "CONSTRAINT_65E", "REFERENTIAL", "reader_db", "PUBLIC", "posts", "PRIMARY_KEY_6", None, "author", "", """ALTER TABLE "PUBLIC"."posts" ADD CONSTRAINT "PUBLIC"."CONSTRAINT_65E" FOREIGN KEY("author") INDEX "PUBLIC"."CONSTRAINT_INDEX_6" REFERENCES "PUBLIC"."users"("id") NOCHECK""", 11),
Constraint("reader_db", "PUBLIC", "CONSTRAINT_65", "PRIMARY KEY", "reader_db", "PUBLIC", "posts", "PRIMARY_KEY_65", None, "id", "", """ALTER TABLE "PUBLIC"."posts" ADD CONSTRAINT "PUBLIC"."CONSTRAINT_65" PRIMARY KEY("id") INDEX "PUBLIC"."PRIMARY_KEY_65"""", 9),
Constraint("reader_db", "PUBLIC", "CONSTRAINT_6", "PRIMARY KEY", "reader_db", "PUBLIC", "users", "PRIMARY_KEY_6", None, "id", "", """ALTER TABLE "PUBLIC"."users" ADD CONSTRAINT "PUBLIC"."CONSTRAINT_6" PRIMARY KEY("id") INDEX "PUBLIC"."PRIMARY_KEY_6"""", 6))
}
it("should read cross references") {
reader.readCrossReferences(xa).unsafeRunSync() shouldBe List(
reader.readCrossReferences(reader.xa).unsafeRunSync() shouldBe List(
CrossReference("reader_db", "PUBLIC", "users", "id", "reader_db", "PUBLIC", "posts", "author", 1, 1, 1, "CONSTRAINT_65E", "PRIMARY_KEY_6", 7))
}
it("should read tables") {
reader.readTables(xa).unsafeRunSync() shouldBe List(
reader.readTables(reader.xa).unsafeRunSync() shouldBe List(
Table("reader_db", "PUBLIC", "posts", "TABLE", "MEMORY", Some(
"""CREATE MEMORY TABLE "PUBLIC"."posts"(
| "id" INT NOT NULL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ object GeneratorSamples {
}

def generateFromSQLFiles(): Unit = {
// Generator.fromFiles(List()).writer(ScalaWriter()).generate().unsafeRunSync()
Generator
.fromFiles(List("src/test/resources/sql_migrations/V1__test_schema.sql"))
.writer(ScalaWriter(packageName = "com.company.db"))
.generate().unsafeRunSync()
}

def generateFromDatabase(): Unit = {
Expand Down

0 comments on commit 169ee50

Please sign in to comment.