Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #164 #168

Merged
merged 3 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import com.softwaremill.SbtSoftwareMillCommon.commonSmlBuildSettings
import com.softwaremill.Publish.{ossPublishSettings, updateDocs}
import com.softwaremill.UpdateVersionInDocs

Global / cancelable := true

lazy val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
organization := "com.softwaremill.ox",
scalaVersion := "3.3.3",
Expand Down Expand Up @@ -50,7 +52,8 @@ lazy val core: Project = (project in file("core"))
scalaTest
),
// Check IO usage in core
useRequireIOPlugin
useRequireIOPlugin,
Test / fork := true
)

lazy val plugin: Project = (project in file("plugin"))
Expand Down
21 changes: 17 additions & 4 deletions core/src/main/scala/ox/fork.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def forkAll[T](fs: Seq[() => T])(using Ox): Fork[Seq[T]] =
val forks = fs.map(f => fork(f()))
new Fork[Seq[T]]:
override def join(): Seq[T] = forks.map(_.join())
override def wasInterruptedWith(ie: InterruptedException): Boolean = forks.exists(_.wasInterruptedWith(ie))

/** Starts a fork (logical thread of execution), which is guaranteed to complete before the enclosing [[supervised]], [[supervisedError]] or
* [[unsupervised]] block completes, and which can be cancelled on-demand.
Expand Down Expand Up @@ -177,8 +178,13 @@ def forkCancellable[T](f: => T)(using OxUnsupervised): CancellableFork[T] =
if !started.getAndSet(true)
then result.completeExceptionally(new InterruptedException("fork was cancelled before it started")).discard

override def wasInterruptedWith(ie: InterruptedException): Boolean =
result.isCompletedExceptionally && (result.exceptionNow() eq ie)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be public? it seems like it has the potential to be misused given the eq check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, we discussed that in the issue, it's package private now.


private def newForkUsingResult[T](result: CompletableFuture[T]): Fork[T] = new Fork[T]:
override def join(): T = unwrapExecutionException(result.get())
override def wasInterruptedWith(ie: InterruptedException): Boolean =
result.isCompletedExceptionally && (result.exceptionNow() eq ie)

private[ox] inline def unwrapExecutionException[T](f: => T): T =
try f
Expand Down Expand Up @@ -208,16 +214,23 @@ trait Fork[T]:
def joinEither(): Either[Throwable, T] =
try Right(join())
catch
// normally IE is fatal, but here it was meant to cancel the fork, not the joining parent, hence we catch it
case e: InterruptedException => Left(e)
// normally IE is fatal, but here it could have meant that the fork was cancelled, hence we catch it
// we do discern between the fork and the current thread being cancelled and rethrow if it's us who's getting the axe
case e: InterruptedException => if wasInterruptedWith(e) then Left(e) else throw e
case NonFatal(e) => Left(e)

def wasInterruptedWith(ie: InterruptedException): Boolean

object Fork:
/** A dummy pretending to represent a fork which successfully completed with the given value. */
def successful[T](value: T): Fork[T] = () => value
def successful[T](value: T): Fork[T] = new Fork[T]:
override def join(): T = value
override def wasInterruptedWith(ie: InterruptedException): Boolean = false

/** A dummy pretending to represent a fork which failed with the given exception. */
def failed[T](e: Throwable): Fork[T] = () => throw e
def failed[T](e: Throwable): Fork[T] = new Fork[T]:
override def join(): T = throw e
override def wasInterruptedWith(ie: InterruptedException): Boolean = e eq ie

/** A fork started using [[forkCancellable]], backed by a (virtual) thread. */
trait CancellableFork[T] extends Fork[T]:
Expand Down
27 changes: 27 additions & 0 deletions core/src/test/scala/ox/SupervisedTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,31 @@ class SupervisedTest extends AnyFlatSpec with Matchers {
trail.add("done")
trail.get shouldBe Vector("b", "a", "done")
}

it should "handle interruption of multiple forks with `joinEither` correctly" in {
val e = intercept[Exception] {
supervised {
def computation(withException: Option[String]): Int = {
withException match
case None => 1
case Some(value) =>
throw new Exception(value)
}

val fork1 = fork:
computation(withException = None)
val fork2 = fork:
computation(withException = Some("Oh no!"))
val fork3 = fork:
computation(withException = Some("Oh well.."))

fork1.joinEither() // 1
fork2.joinEither() // 2
fork3.joinEither() // 3
}
}

e.getMessage should startWith("Oh")
}

}
Loading