Skip to content

Commit

Permalink
Refactor SqlField and improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
loicknuchel committed Oct 31, 2020
1 parent c5b841b commit 1915388
Show file tree
Hide file tree
Showing 30 changed files with 562 additions and 158 deletions.
18 changes: 7 additions & 11 deletions src/main/scala/fr/loicknuchel/safeql/Cond.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,32 @@ object Cond {
override def fr: Fragment = f.fr ++ fr0"=$value"
}

final case class IsField[A](f1: Field[A], f2: Field[A]) extends Cond(List(f1, f2)) {
override def fr: Fragment = f1.fr ++ fr0"=" ++ f2.fr
final case class IsNotValue[A: Put](f: Field[A], value: A) extends Cond(List(f)) {
override def fr: Fragment = f.fr ++ fr0" != $value"
}

final case class IsFieldLeftOpt[A](f1: Field[Option[A]], f2: Field[A]) extends Cond(List(f1, f2)) {
final case class IsField[A](f1: Field[A], f2: Field[A]) extends Cond(List(f1, f2)) {
override def fr: Fragment = f1.fr ++ fr0"=" ++ f2.fr
}

final case class IsQuery[A](f: Field[A], s: Query.Select[A]) extends Cond(List(f)) {
override def fr: Fragment = f.fr ++ fr0"=(" ++ s.fr ++ fr0")"
}

final case class IsNotValue[A: Put](f: Field[A], value: A) extends Cond(List(f)) {
override def fr: Fragment = f.fr ++ fr0" != $value"
}

final case class Like[A](f: Field[A], value: String) extends Cond(List(f)) {
override def fr: Fragment = f.fr ++ fr0" LIKE $value"
}

final case class ILike[A](f: Field[A], value: String) extends Cond(List(f)) {
override def fr: Fragment = f.fr ++ fr0" ILIKE $value"
final case class NotLike[A](f: Field[A], value: String) extends Cond(List(f)) {
override def fr: Fragment = f.fr ++ fr0" NOT LIKE $value"
}

final case class LikeExpr(e: Expr, value: String) extends Cond(e.getFields) {
override def fr: Fragment = e.fr ++ fr0" LIKE $value"
}

final case class NotLike[A](f: Field[A], value: String) extends Cond(List(f)) {
override def fr: Fragment = f.fr ++ fr0" NOT LIKE $value"
final case class ILike[A](f: Field[A], value: String) extends Cond(List(f)) {
override def fr: Fragment = f.fr ++ fr0" ILIKE $value"
}

final case class GtValue[A: Put](f: Field[A], value: A) extends Cond(List(f)) {
Expand Down
70 changes: 34 additions & 36 deletions src/main/scala/fr/loicknuchel/safeql/Field.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ sealed trait Field[A] {

def is(value: A)(implicit p: Put[A]): Cond = IsValue(this, value)

def isNot(value: A)(implicit p: Put[A]): Cond = IsNotValue(this, value)

def is(field: Field[A]): Cond = IsField(this, field)

def is(select: Query.Select[A]): Cond = IsQuery(this, select)

def isNot(value: A)(implicit p: Put[A]): Cond = IsNotValue(this, value)

// TODO restrict to fields with sql string type
def like(value: String): Cond = Like(this, value)

def ilike(value: String): Cond = ILike(this, value)

def notLike(value: String): Cond = NotLike(this, value)

def ilike(value: String): Cond = ILike(this, value)

def gt(value: A)(implicit p: Put[A]): Cond = GtValue(this, value)

def gte(value: A)(implicit p: Put[A]): Cond = GteValue(this, value)
Expand Down Expand Up @@ -91,45 +91,30 @@ object Field {

}

class SqlField[A, +T <: Table.SqlTable](val table: T,
val name: String,
val info: SqlField.JdbcInfo,
val alias: Option[String]) extends Field[A] {
sealed trait SqlField[A, +T <: Table.SqlTable] extends Field[A] {
val table: T
val name: String
val info: SqlField.JdbcInfo
val alias: Option[String]

override def ref: Fragment = const0(s"${table.getAlias.getOrElse(table.getName)}.$name")

override def value: Fragment = ref

def nullable: Boolean = info.nullable

def as(alias: String): SqlField[A, T] = new SqlField[A, T](table, name, info, Some(alias))
override def as(alias: String): SqlField[A, T]

// create a null TableField based on a sql field, useful on union when a field is available on one side only
def asNull: NullField[A] = NullField[A](alias.getOrElse(name))

def asNull(name: String): NullField[A] = NullField[A](name)

override def toString: String = s"SqlField(${table.getName}.$name)"

def canEqual(other: Any): Boolean = other.isInstanceOf[SqlField[_, _]]

override def equals(other: Any): Boolean = other match {
case that: SqlField[_, _] =>
(that canEqual this) &&
table == that.table &&
name == that.name
case _ => false
}

override def hashCode(): Int = {
val state: List[Object] = List(table, name)
state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
}
def nullable: Boolean = info.nullable
}

object SqlField {

def apply[A, T <: Table.SqlTable](table: T, name: String, jdbcDeclaration: String, jdbcType: JdbcType, nullable: Boolean, index: Int): SqlField[A, T] =
new SqlField(table, name, JdbcInfo(nullable, index, jdbcType, jdbcDeclaration), None)
def apply[A, T <: Table.SqlTable](table: T, name: String, jdbcDeclaration: String, jdbcType: JdbcType, nullable: Boolean, index: Int): SqlFieldRaw[A, T] =
SqlFieldRaw(table, name, JdbcInfo(nullable, index, jdbcType, jdbcDeclaration), None)

def apply[A, T <: Table.SqlTable, T2 <: Table.SqlTable](table: T, name: String, jdbcDeclaration: String, jdbcType: JdbcType, nullable: Boolean, index: Int, references: SqlField[A, T2]): SqlFieldRef[A, T, T2] =
SqlFieldRef(table, name, JdbcInfo(nullable, index, jdbcType, jdbcDeclaration), None, references)
Expand All @@ -138,11 +123,22 @@ object SqlField {

}

case class SqlFieldRef[A, T <: Table.SqlTable, T2 <: Table.SqlTable](override val table: T,
override val name: String,
override val info: SqlField.JdbcInfo,
override val alias: Option[String],
references: SqlField[A, T2]) extends SqlField[A, T](table, name, info, alias) {
case class SqlFieldRaw[A, +T <: Table.SqlTable](table: T,
name: String,
info: SqlField.JdbcInfo,
alias: Option[String]) extends SqlField[A, T] {
override def as(alias: String): SqlFieldRaw[A, T] = copy(alias = Some(alias))

override def toString: String = s"SqlFieldRaw(${table.getName}.$name)"
}

case class SqlFieldRef[A, T <: Table.SqlTable, T2 <: Table.SqlTable](table: T,
name: String,
info: SqlField.JdbcInfo,
alias: Option[String],
references: SqlField[A, T2]) extends SqlField[A, T] {
override def as(alias: String): SqlFieldRef[A, T, T2] = copy(alias = Some(alias))

override def toString: String = s"SqlFieldRef(${table.getName}.$name, ${references.table.getName}.${references.name})"
}

Expand Down Expand Up @@ -183,18 +179,20 @@ object AggField {

def apply[A](name: String, alias: String): SimpleAggField[A] = SimpleAggField(name, Some(alias))

def apply[A](query: Query[A]): QueryAggField[A] = QueryAggField(query, None)

def apply[A](query: Query[A], alias: String): QueryAggField[A] = QueryAggField(query, Some(alias))
}

case class SimpleAggField[A](name: String, alias: Option[String]) extends AggField[A] {
case class SimpleAggField[A](name: String, alias: Option[String] = None) extends AggField[A] {
override def ref: Fragment = const0(alias.getOrElse(name))

override def value: Fragment = const0(name)

override def as(alias: String): AggField[A] = copy(alias = Some(alias))
}

case class QueryAggField[A](query: Query[A], alias: Option[String]) extends AggField[A] {
case class QueryAggField[A](query: Query[A], alias: Option[String] = None) extends AggField[A] {
override val name: String = "(" + query.sql + ")"

override def ref: Fragment = alias.map(const0(_)).getOrElse(query.fr)
Expand Down
9 changes: 4 additions & 5 deletions src/main/scala/fr/loicknuchel/safeql/Query.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ object Query {
case class Builder[T <: Table.SqlTable](private val table: T, private val fields: List[SqlField[_, T]]) {
def fields(fields: List[SqlField[_, T]]): Builder[T] = copy(fields = fields)

def fields(fields: SqlField[_, T]*): Builder[T] = this.fields(fields.toList)

// FIXME 2020-10-15: temporary hack waiting I solve the Put[Option[A]] problem (cf https://gist.github.com/loicknuchel/2297d612b58b399395bdd08d3c6dd217)
def values(fr: Fragment): Insert[T] = Insert(table, fields, fr)

Expand Down Expand Up @@ -135,8 +137,6 @@ object Query {
if (f.nullable) copy(values = values :+ (const0(f.name) ++ fr0"=$value")) else throw new Exception(s"Can't use an Option for non nullable field $f")
}

def all: Update[T] = Update(table, values, WhereClause(None, None, None))

def where(cond: Cond): Update[T] = Exceptions.check(cond, table, Update(table, values, WhereClause(Some(cond), None, None)))

def where(cond: T => Cond): Update[T] = where(cond(table))
Expand All @@ -157,8 +157,6 @@ object Query {
object Delete {

case class Builder[T <: Table.SqlTable](private val table: T) {
def all: Delete[T] = Delete(table, WhereClause(None, None, None))

def where(cond: Cond): Delete[T] = Exceptions.check(cond, table, Delete(table, WhereClause(Some(cond), None, None)))

def where(cond: T => Cond): Delete[T] = where(cond(table))
Expand Down Expand Up @@ -256,6 +254,7 @@ object Query {

def withoutFields(fns: (T => Field[_])*): Builder[T] = dropFields(fns.map(f => f(table)).toList)

// unsafe option is useful when a nested queries use a parent field, there is no way to track this right now as it's built independently
def where(cond: Cond, unsafe: Boolean = false): Builder[T] =
if (unsafe) copy(where = WhereClause(Some(cond), None, None)) else Exceptions.check(cond, table, copy(where = WhereClause(Some(cond), None, None)))

Expand All @@ -277,7 +276,7 @@ object Query {

def union[T2 <: Table](other: Builder[T2], alias: Option[String] = None, sorts: List[(String, String, List[String])] = List(), search: List[String] = List()): Table.UnionTable = {
if (fields.length != other.fields.length) throw new Exception(s"Field number do not match (${fields.length} vs ${other.fields.length})")
val invalidFields = fields.zip(other.fields).filter { case (f1, f2) => f1.alias.getOrElse(f1.name) != f2.alias.getOrElse(f2.name) } // TODO check also match of sql type (should be added)
val invalidFields = fields.zip(other.fields).filter { case (f1, f2) => f1.alias.getOrElse(f1.name) != f2.alias.getOrElse(f2.name) } // FIXME check also match of sql type (should be added)
if (invalidFields.nonEmpty) throw new Exception(s"Some fields do not match: ${invalidFields.map { case (f1, f2) => f1.name + " != " + f2.name }.mkString(", ")}")

val getFields = fields.map(f => TableField(f.alias.getOrElse(f.name), alias))
Expand Down
40 changes: 20 additions & 20 deletions src/main/scala/fr/loicknuchel/safeql/gen/Generator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,47 +29,47 @@ object Generator {
.dataSource(new DriverDataSource(this.getClass.getClassLoader, reader.driver, reader.url, reader.user, reader.pass))
.locations(flywayLocations: _*)
.load()
new FlywayGeneratorBuilder(flyway, reader)
FlywayGeneratorBuilder(flyway, reader)
}

class FlywayGeneratorBuilder(flyway: Flyway, reader: H2Reader) {
def writer(writer: Writer): FlywayGenerator = new FlywayGenerator(flyway, reader, writer)
case class FlywayGeneratorBuilder(flyway: Flyway, reader: H2Reader) {
def writer(writer: Writer): FlywayGenerator = FlywayGenerator(flyway, reader, writer)

def excludes(regex: String): FlywayGeneratorBuilder = new FlywayGeneratorBuilder(flyway, reader.excludes(regex))
def excludes(regex: String): FlywayGeneratorBuilder = FlywayGeneratorBuilder(flyway, reader.excludes(regex))
}

class FlywayGenerator(flyway: Flyway, reader: H2Reader, writer: Writer) {
case class FlywayGenerator(flyway: Flyway, reader: H2Reader, writer: Writer) {
def generate(): IO[Unit] = IO(flyway.migrate()).flatMap(_ => Generator.generate(reader, writer))

def excludes(regex: String): FlywayGenerator = new FlywayGenerator(flyway, reader.excludes(regex), writer)
def excludes(regex: String): FlywayGenerator = FlywayGenerator(flyway, reader.excludes(regex), writer)
}

/**
* SQL files Generator
*/

def fromFiles(paths: List[String]): SQLFilesGeneratorBuilder = {
def sqlFiles(paths: List[String]): SQLFilesGeneratorBuilder = {
val reader = H2Reader(
url = s"jdbc:h2:mem:${UUID.randomUUID()};MODE=PostgreSQL;DATABASE_TO_UPPER=false;DB_CLOSE_DELAY=-1",
schema = Some("PUBLIC"),
excludes = None)
new SQLFilesGeneratorBuilder(paths, reader)
SQLFilesGeneratorBuilder(paths, reader)
}

class SQLFilesGeneratorBuilder(paths: List[String], reader: H2Reader) {
def writer(writer: Writer): SQLFilesGenerator = new SQLFilesGenerator(paths, reader, writer)
case class SQLFilesGeneratorBuilder(paths: List[String], reader: H2Reader) {
def writer(writer: Writer): SQLFilesGenerator = SQLFilesGenerator(paths, reader, writer)

def excludes(regex: String): SQLFilesGeneratorBuilder = new SQLFilesGeneratorBuilder(paths, reader.excludes(regex))
def excludes(regex: String): SQLFilesGeneratorBuilder = SQLFilesGeneratorBuilder(paths, reader.excludes(regex))
}

class SQLFilesGenerator(paths: List[String], reader: H2Reader, writer: Writer) {
case class SQLFilesGenerator(paths: List[String], reader: H2Reader, writer: Writer) {
def generate(): IO[Unit] = for {
files <- paths.map(FileUtils.read).sequence.toIO
_ <- files.map(exec(_, reader.xa)).sequence
_ <- Generator.generate(reader, writer)
} yield ()

def excludes(regex: String): SQLFilesGenerator = new SQLFilesGenerator(paths, reader.excludes(regex), writer)
def excludes(regex: String): SQLFilesGenerator = SQLFilesGenerator(paths, reader.excludes(regex), writer)

private def exec(script: String, xa: doobie.Transactor[IO]): IO[Int] =
Update0(script, None).run.transact(xa).recoverWith { case NonFatal(e) => IO.raiseError(FailedScript(script, e)) }
Expand All @@ -79,13 +79,13 @@ object Generator {
* Reader Generator
*/

def reader(reader: Reader) = new ReaderGeneratorBuilder(reader)
def reader(reader: Reader): ReaderGeneratorBuilder = ReaderGeneratorBuilder(reader)

class ReaderGeneratorBuilder(reader: Reader) {
def writer(writer: Writer): ReaderGenerator = new ReaderGenerator(reader, writer)
case class ReaderGeneratorBuilder(reader: Reader) {
def writer(writer: Writer): ReaderGenerator = ReaderGenerator(reader, writer)
}

class ReaderGenerator(reader: Reader, writer: Writer) {
case class ReaderGenerator(reader: Reader, writer: Writer) {
def generate(): IO[Unit] = Generator.generate(reader, writer)
}

Expand All @@ -98,12 +98,12 @@ object Generator {
* Allow to start with writer
*/

def writer(writer: Writer) = new Builder(writer)
def writer(writer: Writer): Builder = Builder(writer)

class Builder(writer: Writer) {
case class Builder(writer: Writer) {
def flyway(flywayLocations: String*): FlywayGenerator = Generator.flyway(flywayLocations: _*).writer(writer)

def fromFiles(paths: List[String]): SQLFilesGenerator = Generator.fromFiles(paths).writer(writer)
def sqlFiles(paths: List[String]): SQLFilesGenerator = Generator.sqlFiles(paths).writer(writer)

def reader(reader: Reader): ReaderGenerator = Generator.reader(reader).writer(writer)
}
Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/fr/loicknuchel/safeql/gen/reader/H2Reader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import fr.loicknuchel.safeql.gen.reader.H2Reader._

import scala.concurrent.ExecutionContext

class H2Reader(val url: String,
val user: String,
val pass: String,
val schema: Option[String],
val excludes: Option[String]) extends Reader {
case class H2Reader(url: String,
user: String,
pass: String,
schema: Option[String],
excludes: Option[String]) extends Reader {
val driver: String = "org.h2.Driver"
protected[gen] lazy val xa: doobie.Transactor[IO] = {
implicit val cs: ContextShift[IO] = IO.contextShift(ExecutionContext.global)
Expand Down
20 changes: 10 additions & 10 deletions src/main/scala/fr/loicknuchel/safeql/gen/writer/ScalaWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ import fr.loicknuchel.safeql.gen.writer.ScalaWriter.{DatabaseConfig, TableConfig
import fr.loicknuchel.safeql.gen.writer.Writer.IdentifierStrategy
import fr.loicknuchel.safeql.utils.StringUtils

class ScalaWriter(val directory: String,
val packageName: String,
val identifierStrategy: IdentifierStrategy,
val config: DatabaseConfig) extends Writer {
case class ScalaWriter(directory: String,
packageName: String,
identifierStrategy: IdentifierStrategy,
config: DatabaseConfig) extends Writer {
require(config.getConfigErrors.isEmpty, s"DatabaseConfig has some errors :${config.getConfigErrors.map("\n - " + _).mkString}")
require(StringUtils.isScalaPackage(packageName), s"'$packageName' is an invalid scala package name")

Expand Down Expand Up @@ -125,7 +125,7 @@ class ScalaWriter(val directory: String,
val fieldRef = (if (r.schema == f.schema && r.table == f.table) "" else s"${idf(r.table)}.table.") + idf(r.field)
s"val $fieldName: SqlFieldRef[$valueType, $tableName, ${idf(r.table)}] = SqlField(this, ${str(f.name)}, ${str(f.jdbcTypeDeclaration)}, JdbcType.$jdbcType, nullable = ${f.nullable}, ${f.index}, $fieldRef)"
}.getOrElse {
s"val $fieldName: SqlField[$valueType, $tableName] = SqlField(this, ${str(f.name)}, ${str(f.jdbcTypeDeclaration)}, JdbcType.$jdbcType, nullable = ${f.nullable}, ${f.index})"
s"val $fieldName: SqlFieldRaw[$valueType, $tableName] = SqlField(this, ${str(f.name)}, ${str(f.jdbcTypeDeclaration)}, JdbcType.$jdbcType, nullable = ${f.nullable}, ${f.index})"
}
}

Expand Down Expand Up @@ -174,7 +174,7 @@ class ScalaWriter(val directory: String,
object ScalaWriter {
def apply(directory: String = "src/main/scala",
packageName: String = "safeql",
identifierStrategy: IdentifierStrategy = Writer.IdentifierStrategy.upperCase,
identifierStrategy: IdentifierStrategy = Writer.IdentifierStrategy.UpperCase,
config: DatabaseConfig = DatabaseConfig()): ScalaWriter =
new ScalaWriter(directory, packageName, identifierStrategy, config)

Expand Down Expand Up @@ -242,15 +242,15 @@ object ScalaWriter {

def apply(alias: String, sort: TableConfig.Sort, search: List[String]): TableConfig = new TableConfig(Some(alias), List(sort), search)

def apply(alias: String, fields: Map[String, FieldConfig]): TableConfig = new TableConfig(Some(alias), List(), List(), fields)
def apply(alias: String, sort: TableConfig.Sort, search: List[String], fields: Map[String, FieldConfig]): TableConfig = new TableConfig(Some(alias), List(sort), search, fields)

def apply(alias: String, sort: TableConfig.Sort, fields: Map[String, FieldConfig]): TableConfig = new TableConfig(Some(alias), List(sort), List(), fields)

def apply(alias: String, sort: String, fields: Map[String, FieldConfig]): TableConfig = new TableConfig(Some(alias), List(TableConfig.Sort(sort)), List(), fields)

def apply(alias: String, sort: String, search: List[String], fields: Map[String, FieldConfig]): TableConfig = new TableConfig(Some(alias), List(TableConfig.Sort(sort)), search, fields)

def apply(alias: String, sort: TableConfig.Sort, fields: Map[String, FieldConfig]): TableConfig = new TableConfig(Some(alias), List(sort), List(), fields)

def apply(alias: String, sort: TableConfig.Sort, search: List[String], fields: Map[String, FieldConfig]): TableConfig = new TableConfig(Some(alias), List(sort), search, fields)
def apply(alias: String, fields: Map[String, FieldConfig]): TableConfig = new TableConfig(Some(alias), List(), List(), fields)

case class Sort(slug: String, label: String, fields: NonEmptyList[Sort.Field])

Expand Down
Loading

0 comments on commit 1915388

Please sign in to comment.