Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #164 #168

Merged
merged 3 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import com.softwaremill.SbtSoftwareMillCommon.commonSmlBuildSettings
import com.softwaremill.Publish.{ossPublishSettings, updateDocs}
import com.softwaremill.UpdateVersionInDocs

Global / cancelable := true

lazy val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
organization := "com.softwaremill.ox",
scalaVersion := "3.3.3",
Expand Down Expand Up @@ -50,7 +52,8 @@ lazy val core: Project = (project in file("core"))
scalaTest
),
// Check IO usage in core
useRequireIOPlugin
useRequireIOPlugin,
Test / fork := true
)

lazy val plugin: Project = (project in file("plugin"))
Expand Down
21 changes: 17 additions & 4 deletions core/src/main/scala/ox/fork.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def forkAll[T](fs: Seq[() => T])(using Ox): Fork[Seq[T]] =
val forks = fs.map(f => fork(f()))
new Fork[Seq[T]]:
override def join(): Seq[T] = forks.map(_.join())
override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = forks.exists(_.wasInterruptedWith(ie))

/** Starts a fork (logical thread of execution), which is guaranteed to complete before the enclosing [[supervised]], [[supervisedError]] or
* [[unsupervised]] block completes, and which can be cancelled on-demand.
Expand Down Expand Up @@ -177,8 +178,13 @@ def forkCancellable[T](f: => T)(using OxUnsupervised): CancellableFork[T] =
if !started.getAndSet(true)
then result.completeExceptionally(new InterruptedException("fork was cancelled before it started")).discard

override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean =
result.isCompletedExceptionally && (result.exceptionNow() eq ie)

private def newForkUsingResult[T](result: CompletableFuture[T]): Fork[T] = new Fork[T]:
override def join(): T = unwrapExecutionException(result.get())
override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean =
result.isCompletedExceptionally && (result.exceptionNow() eq ie)

private[ox] inline def unwrapExecutionException[T](f: => T): T =
try f
Expand Down Expand Up @@ -208,16 +214,23 @@ trait Fork[T]:
def joinEither(): Either[Throwable, T] =
try Right(join())
catch
// normally IE is fatal, but here it was meant to cancel the fork, not the joining parent, hence we catch it
case e: InterruptedException => Left(e)
// normally IE is fatal, but here it could have meant that the fork was cancelled, hence we catch it
// we do discern between the fork and the current thread being cancelled and rethrow if it's us who's getting the axe
case e: InterruptedException => if wasInterruptedWith(e) then Left(e) else throw e
case NonFatal(e) => Left(e)

private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean

object Fork:
/** A dummy pretending to represent a fork which successfully completed with the given value. */
def successful[T](value: T): Fork[T] = () => value
def successful[T](value: T): Fork[T] = new Fork[T]:
override def join(): T = value
override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = false

/** A dummy pretending to represent a fork which failed with the given exception. */
def failed[T](e: Throwable): Fork[T] = () => throw e
def failed[T](e: Throwable): Fork[T] = new Fork[T]:
override def join(): T = throw e
override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = e eq ie

/** A fork started using [[forkCancellable]], backed by a (virtual) thread. */
trait CancellableFork[T] extends Fork[T]:
Expand Down
27 changes: 27 additions & 0 deletions core/src/test/scala/ox/SupervisedTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,31 @@ class SupervisedTest extends AnyFlatSpec with Matchers {
trail.add("done")
trail.get shouldBe Vector("b", "a", "done")
}

it should "handle interruption of multiple forks with `joinEither` correctly" in {
val e = intercept[Exception] {
supervised {
def computation(withException: Option[String]): Int = {
withException match
case None => 1
case Some(value) =>
throw new Exception(value)
}

val fork1 = fork:
computation(withException = None)
val fork2 = fork:
computation(withException = Some("Oh no!"))
val fork3 = fork:
computation(withException = Some("Oh well.."))

fork1.joinEither() // 1
fork2.joinEither() // 2
fork3.joinEither() // 3
}
}

e.getMessage should startWith("Oh")
}

}
Loading