diff --git a/build.sbt b/build.sbt index 7ed28d9d62..21f88cf130 100644 --- a/build.sbt +++ b/build.sbt @@ -660,7 +660,15 @@ lazy val core = crossProject(JSPlatform, JVMPlatform, NativePlatform) "cats.effect.unsafe.IORuntimeBuilder.this"), // introduced by #3695, which enabled fiber dumps on native ProblemFilters.exclude[MissingClassProblem]( - "cats.effect.unsafe.FiberMonitorCompanionPlatform") + "cats.effect.unsafe.FiberMonitorCompanionPlatform"), + // introduced by #3636, IOLocal propagation + // IOLocal is a sealed trait + ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.getOrDefault"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.set"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.reset"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.lens"), + // this filter is particulary terrible, because it can also mask real issues :( + ProblemFilters.exclude[DirectMissingMethodProblem]("cats.effect.IOLocal.lens") ) ++ { if (tlIsScala3.value) { // Scala 3 specific exclusions @@ -905,7 +913,8 @@ lazy val tests: CrossProject = crossProject(JSPlatform, JVMPlatform, NativePlatf scalacOptions ~= { _.filterNot(_.startsWith("-P:scalajs:mapSourceURI")) } ) .jvmSettings( - fork := true + fork := true, + Test / javaOptions += "-Dcats.effect.ioLocalPropagation=true" ) .nativeSettings( Compile / mainClass := Some("catseffect.examples.NativeRunner") diff --git a/core/js-native/src/main/scala/cats/effect/IOFiberConstants.scala b/core/js-native/src/main/scala/cats/effect/IOFiberConstants.scala index c12106b1dd..efd9594bb0 100644 --- a/core/js-native/src/main/scala/cats/effect/IOFiberConstants.scala +++ b/core/js-native/src/main/scala/cats/effect/IOFiberConstants.scala @@ -46,6 +46,8 @@ private object IOFiberConstants { final val AutoCedeR = 7 final val DoneR = 8 + final val ioLocalPropagation = false + @nowarn212 @inline def isVirtualThread(t: Thread): Boolean = false } diff --git a/core/js-native/src/main/scala/cats/effect/IOLocalPlatform.scala b/core/js-native/src/main/scala/cats/effect/IOLocalPlatform.scala new file mode 100644 index 0000000000..ea7d3e8972 --- /dev/null +++ b/core/js-native/src/main/scala/cats/effect/IOLocalPlatform.scala @@ -0,0 +1,19 @@ +/* + * Copyright 2020-2024 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect + +private[effect] trait IOLocalPlatform[A] diff --git a/core/js-native/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala b/core/js-native/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala index acbe94c077..4d13ee8b0a 100644 --- a/core/js-native/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala +++ b/core/js-native/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala @@ -43,8 +43,9 @@ private[effect] sealed abstract class WorkStealingThreadPool[P] private () Map[Runnable, Trace]) } -private[unsafe] sealed abstract class WorkerThread[P] private () extends Thread { +private[effect] sealed abstract class WorkerThread[P] private () extends Thread { private[unsafe] def isOwnedBy(threadPool: WorkStealingThreadPool[_]): Boolean private[unsafe] def monitor(fiber: Runnable): WeakBag.Handle private[unsafe] def index: Int + private[effect] var currentIOFiber: IOFiber[_] } diff --git a/core/jvm/src/main/java/cats/effect/IOFiberConstants.java b/core/jvm/src/main/java/cats/effect/IOFiberConstants.java index 92a7c861a5..7cb6585b3d 100644 --- a/core/jvm/src/main/java/cats/effect/IOFiberConstants.java +++ b/core/jvm/src/main/java/cats/effect/IOFiberConstants.java @@ -48,6 +48,8 @@ final class IOFiberConstants { static final byte AutoCedeR = 7; static final byte DoneR = 8; + static final boolean ioLocalPropagation = Boolean.getBoolean("cats.effect.ioLocalPropagation"); + static boolean isVirtualThread(final Thread thread) { try { return (boolean) THREAD_IS_VIRTUAL_HANDLE.invokeExact(thread); diff --git a/core/jvm/src/main/scala/cats/effect/IOLocalPlatform.scala b/core/jvm/src/main/scala/cats/effect/IOLocalPlatform.scala new file mode 100644 index 0000000000..4a66d31b23 --- /dev/null +++ b/core/jvm/src/main/scala/cats/effect/IOLocalPlatform.scala @@ -0,0 +1,57 @@ +/* + * Copyright 2020-2024 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect + +import IOFiberConstants.ioLocalPropagation + +private[effect] trait IOLocalPlatform[A] { self: IOLocal[A] => + + /** + * Returns a [[java.lang.ThreadLocal]] view of this [[IOLocal]] that allows to unsafely get, + * set, and remove (aka reset) the value in the currently running fiber. The system property + * `cats.effect.ioLocalPropagation` must be `true`, otherwise throws an + * [[java.lang.UnsupportedOperationException]]. + */ + def unsafeThreadLocal(): ThreadLocal[A] = if (ioLocalPropagation) + new ThreadLocal[A] { + override def get(): A = { + val fiber = IOFiber.currentIOFiber() + val state = if (fiber ne null) fiber.getLocalState() else IOLocalState.empty + self.getOrDefault(state) + } + + override def set(value: A): Unit = { + val fiber = IOFiber.currentIOFiber() + if (fiber ne null) { + fiber.setLocalState(self.set(fiber.getLocalState(), value)) + } + } + + override def remove(): Unit = { + val fiber = IOFiber.currentIOFiber() + if (fiber ne null) { + fiber.setLocalState(self.reset(fiber.getLocalState())) + } + } + } + else + throw new UnsupportedOperationException( + "IOLocal-ThreadLocal propagation is disabled.\n" + + "Enable by setting cats.effect.ioLocalPropagation=true." + ) + +} diff --git a/core/jvm/src/main/scala/cats/effect/IOPlatform.scala b/core/jvm/src/main/scala/cats/effect/IOPlatform.scala index c4a68d5a76..c53654eafc 100644 --- a/core/jvm/src/main/scala/cats/effect/IOPlatform.scala +++ b/core/jvm/src/main/scala/cats/effect/IOPlatform.scala @@ -71,7 +71,7 @@ abstract private[effect] class IOPlatform[+A] extends Serializable { self: IO[A] implicit runtime: unsafe.IORuntime): Option[A] = { val queue = new ArrayBlockingQueue[Either[Throwable, A]](1) - unsafeRunAsync { r => + val fiber = unsafeRunAsyncImpl { r => queue.offer(r) () } @@ -82,6 +82,9 @@ abstract private[effect] class IOPlatform[+A] extends Serializable { self: IO[A] } catch { case _: InterruptedException => None + } finally { + if (IOFiberConstants.ioLocalPropagation) + IOLocal.setThreadLocalState(fiber.getLocalState()) } } diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala index a45c1babcc..ce9364d07d 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala @@ -41,7 +41,7 @@ import java.util.concurrent.atomic.AtomicBoolean * system when compared to a fixed size thread pool whose worker threads all draw tasks from a * single global work queue. */ -private final class WorkerThread[P <: AnyRef]( +private[effect] final class WorkerThread[P <: AnyRef]( idx: Int, // Local queue instance with exclusive write access. private[this] var queue: LocalQueue, @@ -107,6 +107,8 @@ private final class WorkerThread[P <: AnyRef]( private val indexTransfer: LinkedTransferQueue[Integer] = new LinkedTransferQueue() private[this] val runtimeBlockingExpiration: Duration = pool.runtimeBlockingExpiration + private[effect] var currentIOFiber: IOFiber[_] = _ + private[this] val RightUnit = Right(()) private[this] val noop = new Function0[Unit] with Runnable { def apply() = () diff --git a/core/shared/src/main/scala/cats/effect/IO.scala b/core/shared/src/main/scala/cats/effect/IO.scala index f2d2c0681c..fa00d3a3d9 100644 --- a/core/shared/src/main/scala/cats/effect/IO.scala +++ b/core/shared/src/main/scala/cats/effect/IO.scala @@ -997,6 +997,12 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] { */ def unsafeRunAsync(cb: Either[Throwable, A] => Unit)( implicit runtime: unsafe.IORuntime): Unit = { + unsafeRunAsyncImpl(cb) + () + } + + private[effect] def unsafeRunAsyncImpl(cb: Either[Throwable, A] => Unit)( + implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] = unsafeRunFiber( cb(Left(new CancellationException("The fiber was canceled"))), t => { @@ -1007,8 +1013,6 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] { }, a => cb(Right(a)) ) - () - } def unsafeRunAsyncOutcome(cb: Outcome[Id, Throwable, A @uncheckedVariance] => Unit)( implicit runtime: unsafe.IORuntime): Unit = { @@ -1111,7 +1115,8 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] { implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] = { val fiber = new IOFiber[A]( - Map.empty, + if (IOFiberConstants.ioLocalPropagation) IOLocal.getThreadLocalState() + else IOLocalState.empty, { oc => if (registerCallback) { runtime.fiberErrorCbs.remove(failure) diff --git a/core/shared/src/main/scala/cats/effect/IOFiber.scala b/core/shared/src/main/scala/cats/effect/IOFiber.scala index 68b905635b..8d095587fe 100644 --- a/core/shared/src/main/scala/cats/effect/IOFiber.scala +++ b/core/shared/src/main/scala/cats/effect/IOFiber.scala @@ -107,9 +107,18 @@ private final class IOFiber[A]( @volatile private[this] var outcome: OutcomeIO[A] = _ + def getLocalState(): IOLocalState = localState + + def setLocalState(s: IOLocalState): Unit = localState = s + override def run(): Unit = { // insert a read barrier after every async boundary readBarrier() + + if (ioLocalPropagation) { + IOFiber.setCurrentIOFiber(this) + } + (resumeTag: @switch) match { case 0 => execR() case 1 => asyncContinueSuccessfulR() @@ -121,6 +130,10 @@ private final class IOFiber[A]( case 7 => autoCedeR() case 8 => () // DoneR } + + if (ioLocalPropagation) { + IOFiber.setCurrentIOFiber(null) + } } /* backing fields for `cancel` and `join` */ @@ -1559,6 +1572,23 @@ private object IOFiber { @static private[IOFiber] val OutcomeCanceled = Outcome.Canceled() @static private[effect] val RightUnit = Right(()) + @static private[this] val threadLocal = new ThreadLocal[IOFiber[_]] + @static def currentIOFiber(): IOFiber[_] = { + val thread = Thread.currentThread() + if (thread.isInstanceOf[WorkerThread[_]]) + thread.asInstanceOf[WorkerThread[_]].currentIOFiber + else + threadLocal.get() + } + + @static private def setCurrentIOFiber(f: IOFiber[_]): Unit = { + val thread = Thread.currentThread() + if (thread.isInstanceOf[WorkerThread[_]]) + thread.asInstanceOf[WorkerThread[_]].currentIOFiber = f + else + threadLocal.set(f) + } + @static def onFatalFailure(t: Throwable): Nothing = { val interrupted = Thread.interrupted() diff --git a/core/shared/src/main/scala/cats/effect/IOLocal.scala b/core/shared/src/main/scala/cats/effect/IOLocal.scala index 64fdb26d44..c0bcb8132f 100644 --- a/core/shared/src/main/scala/cats/effect/IOLocal.scala +++ b/core/shared/src/main/scala/cats/effect/IOLocal.scala @@ -136,27 +136,37 @@ import cats.data.AndThen * @tparam A * the type of the local value */ -sealed trait IOLocal[A] { self => +sealed trait IOLocal[A] extends IOLocalPlatform[A] { self => + + protected[effect] def getOrDefault(state: IOLocalState): A + + protected[effect] def set(state: IOLocalState, a: A): IOLocalState + + protected[effect] def reset(state: IOLocalState): IOLocalState /** * Returns the current value. */ - def get: IO[A] + final def get: IO[A] = + IO.Local(state => (state, getOrDefault(state))) /** * Sets the current value to `value`. */ - def set(value: A): IO[Unit] + final def set(value: A): IO[Unit] = + IO.Local(state => (set(state, value), ())) /** * Replaces the current value with the initial value. */ - def reset: IO[Unit] + final def reset: IO[Unit] = + IO.Local(state => (reset(state), ())) /** * Modifies the current value using the given update function. */ - def update(f: A => A): IO[Unit] + final def update(f: A => A): IO[Unit] = + IO.Local(state => (set(state, f(getOrDefault(state))), ())) /** * Like [[update]] but allows the update function to return an output value of type `B`. @@ -164,31 +174,37 @@ sealed trait IOLocal[A] { self => * @see * [[update]] */ - def modify[B](f: A => (A, B)): IO[B] + final def modify[B](f: A => (A, B)): IO[B] = + IO.Local { state => + val (a2, b) = f(getOrDefault(state)) + (set(state, a2), b) + } /** * Replaces the current value with `value`, returning the previous value. * - * The combination of [[get]] and [[set]]. + * The combination of [[get]] and [[set(value:* set]]. * * @see * [[get]] * @see - * [[set]] + * [[set(value:* set]] */ - def getAndSet(value: A): IO[A] + final def getAndSet(value: A): IO[A] = + IO.Local(state => (set(state, value), getOrDefault(state))) /** * Replaces the current value with the initial value, returning the previous value. * - * The combination of [[get]] and [[reset]]. + * The combination of [[get]] and [[reset:* reset]]. * * @see * [[get]] * @see - * [[reset]] + * [[reset:* reset]] */ - def getAndReset: IO[A] + final def getAndReset: IO[A] = + IO.Local(state => (reset(state), getOrDefault(state))) /** * Creates a lens to a value of some type `B` from current value and two functions: getter and @@ -197,9 +213,9 @@ sealed trait IOLocal[A] { self => * All changes to the original value will be visible via lens getter and all changes applied * to 'refracted' value will be forwarded to the original via setter. * - * Note that [[.set]] method requires special mention: while from the `IOLocal[B]` point of - * view old value will be replaced with a new one, from `IOLocal[A]` POV old value will be - * updated via setter. This means that for 'refracted' `IOLocal[B]` use of `set(b)` is + * Note that [[.set(value* set]] method requires special mention: while from the `IOLocal[B]` + * point of view old value will be replaced with a new one, from `IOLocal[A]` POV old value + * will be updated via setter. This means that for 'refracted' `IOLocal[B]` use of `set(b)` is * equivalent to `reset *> set(b)`, but it does not hold for original `IOLocal[A]`: * * {{{ @@ -220,28 +236,7 @@ sealed trait IOLocal[A] { self => * } yield () * }}} */ - final def lens[B](get: A => B)(set: A => B => A): IOLocal[B] = { - import IOLocal.IOLocalLens - - self match { - case lens: IOLocalLens[aa, A] => - // We process already created lens separately so - // we wont pay additional `.get.flatMap` price for every call of - // `set`, `update` or `modify` of resulting lens. - // After all, our getters and setters are pure, - // so `AndThen` allows us to safely compose them and - // proxy calls to the 'original' `IOLocal` independent of - // current nesting level. - - val getter = lens.getter.andThen(get) - val setter = lens.setter.compose((p: (aa, B)) => (p._1, set(lens.getter(p._1))(p._2))) - new IOLocalLens(lens.underlying, getter, setter) - case _ => - val getter = AndThen(get) - val setter = AndThen((p: (A, B)) => set(p._1)(p._2)) - new IOLocalLens(self, getter, setter) - } - } + def lens[B](get: A => B)(set: A => B => A): IOLocal[B] } @@ -260,63 +255,62 @@ object IOLocal { */ def apply[A](default: A): IO[IOLocal[A]] = IO(new IOLocalImpl(default)) - private[IOLocal] final class IOLocalImpl[A](default: A) extends IOLocal[A] { self => - private[this] def getOrDefault(state: IOLocalState): A = - state.getOrElse(self, default).asInstanceOf[A] + /** + * `true` if IOLocal-Threadlocal propagation is enabled + */ + def isPropagating: Boolean = IOFiberConstants.ioLocalPropagation - def get: IO[A] = - IO.Local(state => (state, getOrDefault(state))) + private[effect] def getThreadLocalState() = { + val fiber = IOFiber.currentIOFiber() + if (fiber ne null) fiber.getLocalState() else IOLocalState.empty + } - def set(value: A): IO[Unit] = - IO.Local(state => (state.updated(self, value), ())) + private[effect] def setThreadLocalState(state: IOLocalState) = { + val fiber = IOFiber.currentIOFiber() + if (fiber ne null) fiber.setLocalState(state) + } - def reset: IO[Unit] = - IO.Local(state => (state - self, ())) + private final class IOLocalImpl[A](default: A) extends IOLocal[A] { - def update(f: A => A): IO[Unit] = - IO.Local(state => (state.updated(self, f(getOrDefault(state))), ())) + def getOrDefault(state: IOLocalState): A = + state.getOrElse(this, default).asInstanceOf[A] - def modify[B](f: A => (A, B)): IO[B] = - IO.Local { state => - val (a2, b) = f(getOrDefault(state)) - (state.updated(self, a2), b) - } + def set(state: IOLocalState, a: A): IOLocalState = state.updated(this, a) - def getAndSet(value: A): IO[A] = - IO.Local(state => (state.updated(self, value), getOrDefault(state))) + def reset(state: IOLocalState): IOLocalState = state - this - def getAndReset: IO[A] = - IO.Local(state => (state - self, getOrDefault(state))) + def lens[B](get: A => B)(set: A => B => A): IOLocal[B] = + new IOLocal.IOLocalLens(this, get, (ab: (A, B)) => set(ab._1)(ab._2)) } - private[IOLocal] final class IOLocalLens[S, A]( - val underlying: IOLocal[S], - val getter: AndThen[S, A], - val setter: AndThen[(S, A), S]) + private final class IOLocalLens[S, A]( + underlying: IOLocal[S], + getter: S => A, + setter: ((S, A)) => S) extends IOLocal[A] { - def get: IO[A] = - underlying.get.map(getter(_)) - - def set(value: A): IO[Unit] = - underlying.get.flatMap(s => underlying.set(setter(s -> value))) - def reset: IO[Unit] = - underlying.reset + def getOrDefault(state: IOLocalState): A = + getter(underlying.getOrDefault(state)) - def update(f: A => A): IO[Unit] = - underlying.get.flatMap(s => underlying.set(setter(s -> f(getter(s))))) + def set(state: IOLocalState, a: A): IOLocalState = + underlying.set(state, setter((underlying.getOrDefault(state), a))) - def modify[B](f: A => (A, B)): IO[B] = - underlying.get.flatMap { s => - val (a2, b) = f(getter(s)) - underlying.set(setter(s -> a2)).as(b) - } + def reset(state: IOLocalState): IOLocalState = underlying.reset(state) - def getAndSet(value: A): IO[A] = - underlying.get.flatMap(s => underlying.set(setter(s -> value)).as(getter(s))) + def lens[B](get: A => B)(set: A => B => A): IOLocal[B] = { + // We process already created lens separately so + // we wont pay additional `.get.flatMap` price for every call of + // `set`, `update` or `modify` of resulting lens. + // After all, our getters and setters are pure, + // so `AndThen` allows us to safely compose them and + // proxy calls to the 'original' `IOLocal` independent of + // current nesting level. - def getAndReset: IO[A] = - underlying.get.flatMap(s => underlying.reset.as(getter(s))) + val getter = AndThen(this.getter).andThen(get) + val setter = + AndThen(this.setter).compose((sb: (S, B)) => (sb._1, set(this.getter(sb._1))(sb._2))) + new IOLocalLens(underlying, getter, setter) + } } } diff --git a/core/shared/src/main/scala/cats/effect/package.scala b/core/shared/src/main/scala/cats/effect/package.scala index 0270eae439..55a52e7cab 100644 --- a/core/shared/src/main/scala/cats/effect/package.scala +++ b/core/shared/src/main/scala/cats/effect/package.scala @@ -80,4 +80,7 @@ package object effect { val Ref = cekernel.Ref private[effect] type IOLocalState = scala.collection.immutable.Map[IOLocal[_], Any] + private[effect] object IOLocalState { + val empty: IOLocalState = scala.collection.immutable.Map.empty + } } diff --git a/docs/core/io-local.md b/docs/core/io-local.md index 345932c378..af3af30284 100644 --- a/docs/core/io-local.md +++ b/docs/core/io-local.md @@ -180,3 +180,9 @@ TraceIdScope.fromIOLocal.flatMap { implicit traceIdScope: TraceIdScope[IO] => service[IO] } ``` + +## Propagating `IOLocal`s as `ThreadLocal`s + +To support integration with Java libraries, `IOLocal` interoperates with the JDK `ThreadLocal` API via `IOLocal#unsafeThreadLocal`. This makes it possible to unsafely read and write the value of an `IOLocal` on the currently running fiber within a suspended side-effect (e.g. `IO.delay` or `IO.blocking`). + +To use this feature you must set the property `cats.effect.ioLocalPropagation=true`. Note that enabling propagation causes a performance hit of up to 25% in some of our microbenchmarks. However, it is not clear that this performance impact matters in practice. diff --git a/docs/core/io-runtime-config.md b/docs/core/io-runtime-config.md index 39828ad429..3810cab0c0 100644 --- a/docs/core/io-runtime-config.md +++ b/docs/core/io-runtime-config.md @@ -37,3 +37,4 @@ This can be done for example with the [EnvironmentPlugin for Webpack](https://we | `cats.effect.cpu.starvation.check.interval`
`CATS_EFFECT_CPU_STARVATION_CHECK_INTERVAL` | `FiniteDuration` (`1.second`) | The starvation checker repeatedly sleeps for this interval and then checks `monotonic` time when it awakens. It will then print a warning to stderr if it finds that the current time is greater than expected (see `threshold` below). | | `cats.effect.cpu.starvation.check.initialDelay`
`CATS_EFFECT_CPU_STARVATION_CHECK_INITIAL_DELAY` | `Duration` (`10.seconds`) | The initial delay before the CPU starvation checker starts running. Avoids spurious warnings due to the JVM not being warmed up yet. Set to `Duration.Inf` to disable CPU starvation checking. | | `cats.effect.cpu.starvation.check.threshold`
`CATS_EFFECT_CPU_STARVATION_CHECK_THRESHOLD` | `Double` (`0.1`) | The starvation checker will print a warning if it finds that it has been asleep for at least `interval * (1 + threshold)` (where `interval` from above is the expected time to be asleep for). Sleeping for too long is indicative of fibers hogging a worker thread either by performing blocking operations on it or by `cede`ing insufficiently frequently. | +| `cats.effect.ioLocalPropagation`
N/A | `Boolean` (`false`) | Enables `IOLocal`s to be propagated as `ThreadLocal`s. | diff --git a/tests/jvm/src/test/scala/cats/effect/unsafe/IOLocalsSpec.scala b/tests/jvm/src/test/scala/cats/effect/unsafe/IOLocalsSpec.scala new file mode 100644 index 0000000000..f425e67fa2 --- /dev/null +++ b/tests/jvm/src/test/scala/cats/effect/unsafe/IOLocalsSpec.scala @@ -0,0 +1,55 @@ +/* + * Copyright 2020-2023 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect +package unsafe + +class IOLocalsSpec extends BaseSpec { + + "IOLocals" should { + "return a default value" in real { + IOLocal(42) + .flatMap(local => IO(local.unsafeThreadLocal().get())) + .map(_ must beEqualTo(42)) + } + + "return a set value" in real { + for { + local <- IOLocal(42) + threadLocal <- IO(local.unsafeThreadLocal()) + _ <- local.set(24) + got <- IO(threadLocal.get()) + } yield got must beEqualTo(24) + } + + "unsafely set" in real { + IOLocal(42).flatMap(local => + IO(local.unsafeThreadLocal().set(24)) *> local.get.map(_ must beEqualTo(24))) + } + + "unsafely reset" in real { + for { + local <- IOLocal(42) + threadLocal <- IO(local.unsafeThreadLocal()) + _ <- local.set(24) + _ <- IO(threadLocal.remove()) + got <- local.get + } yield got must beEqualTo(42) + } + + } + +}