diff --git a/modules/core/shared/src/main/scala/Fragment.scala b/modules/core/shared/src/main/scala/Fragment.scala index 61083342..f6366bfc 100644 --- a/modules/core/shared/src/main/scala/Fragment.scala +++ b/modules/core/shared/src/main/scala/Fragment.scala @@ -42,6 +42,18 @@ final case class Fragment[A]( def ~[B](fb: Fragment[B]): Fragment[A ~ B] = product(fb) + def stripMargin: Fragment[A] = stripMargin('|') + + def stripMargin(marginChar: Char): Fragment[A] = { + val ps = parts.map { + _.bimap( + _.stripMargin(marginChar).replaceAll("\n", " "), + _.map(_.stripMargin(marginChar).replaceAll("\n", " ")) + ) + } + Fragment(ps, encoder, origin) + } + def apply(a: A): AppliedFragment = AppliedFragment(this, a) diff --git a/modules/tests/shared/src/test/scala/FragmentTest.scala b/modules/tests/shared/src/test/scala/FragmentTest.scala index a77f496b..fecf90f2 100644 --- a/modules/tests/shared/src/test/scala/FragmentTest.scala +++ b/modules/tests/shared/src/test/scala/FragmentTest.scala @@ -60,5 +60,18 @@ class FragmentTest extends SkunkTest { } yield "ok" } } -} + pureTest("stripMargin") { + val f = sql"""select + |$int4 + |""".stripMargin + f.sql.trim == sql"select $int4".sql + } + + pureTest("stripMargin with char") { + val f = sql"""select + ^$int4 + ^""".stripMargin('^') + f.sql.trim == sql"select $int4".sql + } +}