Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IOLocal propagation for unsafe access #3636

Merged
merged 39 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0b88c01
POC thread-local iolocals
armanbilge May 16, 2023
db743e2
Simplify and optimize
armanbilge May 16, 2023
716ef32
Special-case for `WorkerThread`
armanbilge May 17, 2023
0a69caf
Load locals in `unsafeRunFiber`
armanbilge May 17, 2023
2775064
Dump locals in more places
armanbilge May 18, 2023
270764f
Refactor `IOLocal`
armanbilge May 21, 2023
d55489d
Use new `IOLocal` APIs in `IOLocals`
armanbilge May 21, 2023
2cf72a5
Mark `IOLocal` methods as `final`
armanbilge May 21, 2023
cb3859d
Add `IOLocalsSpec`
armanbilge Jun 10, 2023
7dce01c
Rename property to `ioLocalPropagation` and fixes
armanbilge Jun 28, 2023
5e171ac
Bump base version
armanbilge Jun 28, 2023
c2f312d
Add files I forgot tocommit :)
armanbilge Jun 28, 2023
638930d
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Jun 28, 2023
9174c6a
Add MiMa filters
armanbilge Jun 28, 2023
1987e3a
Fix scaladoc links
armanbilge Jun 28, 2023
02a43a6
Alias the disambiguations
armanbilge Jun 28, 2023
a7bf748
Copy locals back out after blocking unsafe run
armanbilge Sep 5, 2023
145fc0e
Merge remote-tracking branch 'upstream/series/3.x' into topic/thread-…
armanbilge Sep 5, 2023
fa99a5c
Expose status of `IOLocal` propagation
armanbilge Sep 25, 2023
6cad03c
`propagating` -> `arePropagating`
armanbilge Sep 29, 2023
bb5d4b1
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Sep 30, 2023
7517755
Use `real` instead of `ticked`
armanbilge Sep 30, 2023
8d8e004
Formatting
armanbilge Sep 30, 2023
3589db4
Try keeping the current fiber as a thread-local instead
armanbilge Sep 30, 2023
522677e
Revert spurious whitespace changes
armanbilge Sep 30, 2023
6cc4d38
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge May 27, 2024
ac88480
Update headers
armanbilge May 27, 2024
49e5c30
Update platform headers
armanbilge May 27, 2024
925f504
Remove unused class
armanbilge May 27, 2024
d63a6ff
Expose `IOLocal` propagation as a `ThreadLocal`
armanbilge Jun 4, 2024
d4549fb
`unsafeToThreadLocal()` throws if propagation disabled
armanbilge Jun 4, 2024
2502045
Add scaladoc
armanbilge Jun 5, 2024
535fc8a
Factor out to JVM-only
armanbilge Jun 5, 2024
d854799
Bikeshed API and docs
armanbilge Jun 5, 2024
f070552
Formatting
armanbilge Jun 5, 2024
2cf1d8a
Delete dead code
armanbilge Jun 5, 2024
0eec9dd
Document `ThreadLocal` propagation
armanbilge Aug 5, 2024
af84973
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Nov 14, 2024
1adf368
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,15 @@ lazy val core = crossProject(JSPlatform, JVMPlatform, NativePlatform)
"cats.effect.unsafe.IORuntimeBuilder.this"),
// introduced by #3695, which enabled fiber dumps on native
ProblemFilters.exclude[MissingClassProblem](
"cats.effect.unsafe.FiberMonitorCompanionPlatform")
"cats.effect.unsafe.FiberMonitorCompanionPlatform"),
// introduced by #3636, IOLocal propagation
// IOLocal is a sealed trait
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.getOrDefault"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.set"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.reset"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.lens"),
// this filter is particulary terrible, because it can also mask real issues :(
ProblemFilters.exclude[DirectMissingMethodProblem]("cats.effect.IOLocal.lens")
) ++ {
if (tlIsScala3.value) {
// Scala 3 specific exclusions
Expand Down Expand Up @@ -905,7 +913,8 @@ lazy val tests: CrossProject = crossProject(JSPlatform, JVMPlatform, NativePlatf
scalacOptions ~= { _.filterNot(_.startsWith("-P:scalajs:mapSourceURI")) }
)
.jvmSettings(
fork := true
fork := true,
Test / javaOptions += "-Dcats.effect.ioLocalPropagation=true"
)
.nativeSettings(
Compile / mainClass := Some("catseffect.examples.NativeRunner")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ private object IOFiberConstants {
final val AutoCedeR = 7
final val DoneR = 8

final val ioLocalPropagation = false
djspiewak marked this conversation as resolved.
Show resolved Hide resolved

@nowarn212
@inline def isVirtualThread(t: Thread): Boolean = false
}
19 changes: 19 additions & 0 deletions core/js-native/src/main/scala/cats/effect/IOLocalPlatform.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright 2020-2024 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect

private[effect] trait IOLocalPlatform[A]
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ private[effect] sealed abstract class WorkStealingThreadPool[P] private ()
Map[Runnable, Trace])
}

private[unsafe] sealed abstract class WorkerThread[P] private () extends Thread {
private[effect] sealed abstract class WorkerThread[P] private () extends Thread {
private[unsafe] def isOwnedBy(threadPool: WorkStealingThreadPool[_]): Boolean
private[unsafe] def monitor(fiber: Runnable): WeakBag.Handle
private[unsafe] def index: Int
private[effect] var currentIOFiber: IOFiber[_]
}
2 changes: 2 additions & 0 deletions core/jvm/src/main/java/cats/effect/IOFiberConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ final class IOFiberConstants {
static final byte AutoCedeR = 7;
static final byte DoneR = 8;

static final boolean ioLocalPropagation = Boolean.getBoolean("cats.effect.ioLocalPropagation");

static boolean isVirtualThread(final Thread thread) {
try {
return (boolean) THREAD_IS_VIRTUAL_HANDLE.invokeExact(thread);
Expand Down
57 changes: 57 additions & 0 deletions core/jvm/src/main/scala/cats/effect/IOLocalPlatform.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2020-2024 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect

import IOFiberConstants.ioLocalPropagation

private[effect] trait IOLocalPlatform[A] { self: IOLocal[A] =>

/**
* Returns a [[java.lang.ThreadLocal]] view of this [[IOLocal]] that allows to unsafely get,
* set, and remove (aka reset) the value in the currently running fiber. The system property
* `cats.effect.ioLocalPropagation` must be `true`, otherwise throws an
* [[java.lang.UnsupportedOperationException]].
*/
def unsafeThreadLocal(): ThreadLocal[A] = if (ioLocalPropagation)
djspiewak marked this conversation as resolved.
Show resolved Hide resolved
new ThreadLocal[A] {
override def get(): A = {
val fiber = IOFiber.currentIOFiber()
val state = if (fiber ne null) fiber.getLocalState() else IOLocalState.empty
self.getOrDefault(state)
}

override def set(value: A): Unit = {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) {
fiber.setLocalState(self.set(fiber.getLocalState(), value))
}
}

override def remove(): Unit = {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) {
fiber.setLocalState(self.reset(fiber.getLocalState()))
}
}
}
else
throw new UnsupportedOperationException(
"IOLocal-ThreadLocal propagation is disabled.\n" +
"Enable by setting cats.effect.ioLocalPropagation=true."
)

}
5 changes: 4 additions & 1 deletion core/jvm/src/main/scala/cats/effect/IOPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ abstract private[effect] class IOPlatform[+A] extends Serializable { self: IO[A]
implicit runtime: unsafe.IORuntime): Option[A] = {
val queue = new ArrayBlockingQueue[Either[Throwable, A]](1)

unsafeRunAsync { r =>
val fiber = unsafeRunAsyncImpl { r =>
queue.offer(r)
()
}
Expand All @@ -82,6 +82,9 @@ abstract private[effect] class IOPlatform[+A] extends Serializable { self: IO[A]
} catch {
case _: InterruptedException =>
None
} finally {
if (IOFiberConstants.ioLocalPropagation)
IOLocal.setThreadLocalState(fiber.getLocalState())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import java.util.concurrent.atomic.AtomicBoolean
* system when compared to a fixed size thread pool whose worker threads all draw tasks from a
* single global work queue.
*/
private final class WorkerThread[P <: AnyRef](
private[effect] final class WorkerThread[P <: AnyRef](
idx: Int,
// Local queue instance with exclusive write access.
private[this] var queue: LocalQueue,
Expand Down Expand Up @@ -107,6 +107,8 @@ private final class WorkerThread[P <: AnyRef](
private val indexTransfer: LinkedTransferQueue[Integer] = new LinkedTransferQueue()
private[this] val runtimeBlockingExpiration: Duration = pool.runtimeBlockingExpiration

private[effect] var currentIOFiber: IOFiber[_] = _
djspiewak marked this conversation as resolved.
Show resolved Hide resolved

private[this] val RightUnit = Right(())
private[this] val noop = new Function0[Unit] with Runnable {
def apply() = ()
Expand Down
11 changes: 8 additions & 3 deletions core/shared/src/main/scala/cats/effect/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,12 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
*/
def unsafeRunAsync(cb: Either[Throwable, A] => Unit)(
implicit runtime: unsafe.IORuntime): Unit = {
unsafeRunAsyncImpl(cb)
()
}

private[effect] def unsafeRunAsyncImpl(cb: Either[Throwable, A] => Unit)(
implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] =
djspiewak marked this conversation as resolved.
Show resolved Hide resolved
unsafeRunFiber(
cb(Left(new CancellationException("The fiber was canceled"))),
t => {
Expand All @@ -988,8 +994,6 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
},
a => cb(Right(a))
)
()
}

def unsafeRunAsyncOutcome(cb: Outcome[Id, Throwable, A @uncheckedVariance] => Unit)(
implicit runtime: unsafe.IORuntime): Unit = {
Expand Down Expand Up @@ -1092,7 +1096,8 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] = {

val fiber = new IOFiber[A](
Map.empty,
if (IOFiberConstants.ioLocalPropagation) IOLocal.getThreadLocalState()
else IOLocalState.empty,
oc =>
oc.fold(
{
Expand Down
30 changes: 30 additions & 0 deletions core/shared/src/main/scala/cats/effect/IOFiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,18 @@ private final class IOFiber[A](
@volatile
private[this] var outcome: OutcomeIO[A] = _

def getLocalState(): IOLocalState = localState

def setLocalState(s: IOLocalState): Unit = localState = s

override def run(): Unit = {
// insert a read barrier after every async boundary
readBarrier()

if (ioLocalPropagation) {
IOFiber.setCurrentIOFiber(this)
}

(resumeTag: @switch) match {
case 0 => execR()
case 1 => asyncContinueSuccessfulR()
Expand All @@ -121,6 +130,10 @@ private final class IOFiber[A](
case 7 => autoCedeR()
case 8 => () // DoneR
}

if (ioLocalPropagation) {
IOFiber.setCurrentIOFiber(null)
}
djspiewak marked this conversation as resolved.
Show resolved Hide resolved
}

/* backing fields for `cancel` and `join` */
Expand Down Expand Up @@ -1559,6 +1572,23 @@ private object IOFiber {
@static private[IOFiber] val OutcomeCanceled = Outcome.Canceled()
@static private[effect] val RightUnit = Right(())

@static private[this] val threadLocal = new ThreadLocal[IOFiber[_]]
@static def currentIOFiber(): IOFiber[_] = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread[_]])
thread.asInstanceOf[WorkerThread[_]].currentIOFiber
else
threadLocal.get()
}

@static private def setCurrentIOFiber(f: IOFiber[_]): Unit = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread[_]])
thread.asInstanceOf[WorkerThread[_]].currentIOFiber = f
else
threadLocal.set(f)
}

@static def onFatalFailure(t: Throwable): Nothing = {
val interrupted = Thread.interrupted()

Expand Down
Loading
Loading