Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IOLocal propagation for unsafe access #3636

Merged
merged 39 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0b88c01
POC thread-local iolocals
armanbilge May 16, 2023
db743e2
Simplify and optimize
armanbilge May 16, 2023
716ef32
Special-case for `WorkerThread`
armanbilge May 17, 2023
0a69caf
Load locals in `unsafeRunFiber`
armanbilge May 17, 2023
2775064
Dump locals in more places
armanbilge May 18, 2023
270764f
Refactor `IOLocal`
armanbilge May 21, 2023
d55489d
Use new `IOLocal` APIs in `IOLocals`
armanbilge May 21, 2023
2cf72a5
Mark `IOLocal` methods as `final`
armanbilge May 21, 2023
cb3859d
Add `IOLocalsSpec`
armanbilge Jun 10, 2023
7dce01c
Rename property to `ioLocalPropagation` and fixes
armanbilge Jun 28, 2023
5e171ac
Bump base version
armanbilge Jun 28, 2023
c2f312d
Add files I forgot tocommit :)
armanbilge Jun 28, 2023
638930d
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Jun 28, 2023
9174c6a
Add MiMa filters
armanbilge Jun 28, 2023
1987e3a
Fix scaladoc links
armanbilge Jun 28, 2023
02a43a6
Alias the disambiguations
armanbilge Jun 28, 2023
a7bf748
Copy locals back out after blocking unsafe run
armanbilge Sep 5, 2023
145fc0e
Merge remote-tracking branch 'upstream/series/3.x' into topic/thread-…
armanbilge Sep 5, 2023
fa99a5c
Expose status of `IOLocal` propagation
armanbilge Sep 25, 2023
6cad03c
`propagating` -> `arePropagating`
armanbilge Sep 29, 2023
bb5d4b1
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Sep 30, 2023
7517755
Use `real` instead of `ticked`
armanbilge Sep 30, 2023
8d8e004
Formatting
armanbilge Sep 30, 2023
3589db4
Try keeping the current fiber as a thread-local instead
armanbilge Sep 30, 2023
522677e
Revert spurious whitespace changes
armanbilge Sep 30, 2023
6cc4d38
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge May 27, 2024
ac88480
Update headers
armanbilge May 27, 2024
49e5c30
Update platform headers
armanbilge May 27, 2024
925f504
Remove unused class
armanbilge May 27, 2024
d63a6ff
Expose `IOLocal` propagation as a `ThreadLocal`
armanbilge Jun 4, 2024
d4549fb
`unsafeToThreadLocal()` throws if propagation disabled
armanbilge Jun 4, 2024
2502045
Add scaladoc
armanbilge Jun 5, 2024
535fc8a
Factor out to JVM-only
armanbilge Jun 5, 2024
d854799
Bikeshed API and docs
armanbilge Jun 5, 2024
f070552
Formatting
armanbilge Jun 5, 2024
2cf1d8a
Delete dead code
armanbilge Jun 5, 2024
0eec9dd
Document `ThreadLocal` propagation
armanbilge Aug 5, 2024
af84973
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Nov 14, 2024
1adf368
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ private object IOFiberConstants {
final val CedeR = 6
final val AutoCedeR = 7
final val DoneR = 8

final val dumpLocals = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ private[unsafe] sealed abstract class WorkerThread private () extends Thread {
private[unsafe] def isOwnedBy(threadPool: WorkStealingThreadPool): Boolean
private[unsafe] def monitor(fiber: Runnable): WeakBag.Handle
private[unsafe] def index: Int
private[unsafe] var ioLocalState: IOLocalState
}
4 changes: 3 additions & 1 deletion core/jvm/src/main/java/cats/effect/IOFiberConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package cats.effect;

// defined in Java since Scala doesn't let us define static fields
final class IOFiberConstants {
public final class IOFiberConstants {
Copy link
Member Author

@armanbilge armanbilge May 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not good. To avoid this we'll either have to replicate it at both the cats.effect and cats.effect.unsafe levels, or move the thread-local IOLocals accessors into cats.effect.


static final int MaxStackDepth = 512;

Expand All @@ -43,4 +43,6 @@ final class IOFiberConstants {
static final byte CedeR = 6;
static final byte AutoCedeR = 7;
static final byte DoneR = 8;

public static final boolean dumpLocals = Boolean.getBoolean("cats.effect.tracing.dumpLocals");
}
2 changes: 2 additions & 0 deletions core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ private final class WorkerThread(
private val indexTransfer: LinkedTransferQueue[Integer] = new LinkedTransferQueue()
private[this] val runtimeBlockingExpiration: Duration = pool.runtimeBlockingExpiration

private[unsafe] var ioLocalState: IOLocalState = IOLocalState.empty

val nameIndex: Int = pool.blockedWorkerThreadNamingIndex.getAndIncrement()

// Constructor code.
Expand Down
2 changes: 1 addition & 1 deletion core/shared/src/main/scala/cats/effect/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] = {

val fiber = new IOFiber[A](
Map.empty,
if (IOFiberConstants.dumpLocals) unsafe.IOLocals.getState else Map.empty,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can even go in the opposite direction for IO#unsafeRun* 😁

It's less clear if/how to do this for fibers started in a Dispatcher, since they should be inheriting locals from the fiber backing the Dispatcher.

oc =>
oc.fold(
{
Expand Down
49 changes: 49 additions & 0 deletions core/shared/src/main/scala/cats/effect/IOFiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ private final class IOFiber[A](
pushTracingEvent(cur.event)
}

if (dumpLocals) {
IOLocals.setState(localState)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dumb question: can't we simply do this when we get scheduled on a thread? We know when we're on a thread and we know when we get off of it, so can't we simply set and clear the state respectively at those points?

Copy link
Member Author

@armanbilge armanbilge Sep 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No we can't, unless we unify how the state is represented. Currently it's a var to an immutable map in the fiber and also in the thread. While the fiber is running its copy of the var may be updated effectually in the runloop so the thread-local copy would need to be kept in sync with that. Or we could drive all updates through the thread-local copy of the var, but then there would be a penalty for accessing it esp. if we are not running on a worker thread.

Putting aside technical issues, nobody should be unsafely messing about with IOLocals outside of a properly suspended side-effect block and this strategy enforces that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What we can do is set the current fiber in a thread local every time we get scheduled on a thread. Then the unsafe IOLocals manipulations can operate on the state via the current fiber and we don't need to pay the penalty for every delay block. Based on the benchmarks this strategy is seeming more attractive 😅

Note this would leave the fiber's IOLocal state exposed to unsafe manipulations outside of delay blocks.

}
armanbilge marked this conversation as resolved.
Show resolved Hide resolved

var error: Throwable = null
val r =
try cur.thunk()
Expand All @@ -260,6 +264,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (dumpLocals) {
localState = IOLocals.getAndClearState()
}

val next =
if (error == null) succeeded(r, 0)
else failed(error, 0)
Expand Down Expand Up @@ -324,6 +332,10 @@ private final class IOFiber[A](
pushTracingEvent(delay.event)
}

if (dumpLocals) {
IOLocals.setState(localState)
}

// this code is inlined in order to avoid two `try` blocks
var error: Throwable = null
val result =
Expand All @@ -335,6 +347,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (dumpLocals) {
localState = IOLocals.getAndClearState()
}

val nextIO = if (error == null) succeeded(result, 0) else failed(error, 0)
runLoop(nextIO, nextCancelation - 1, nextAutoCede)

Expand Down Expand Up @@ -391,6 +407,10 @@ private final class IOFiber[A](
pushTracingEvent(delay.event)
}

if (dumpLocals) {
IOLocals.setState(localState)
}

// this code is inlined in order to avoid two `try` blocks
val result =
try f(delay.thunk())
Expand All @@ -401,6 +421,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (dumpLocals) {
localState = IOLocals.getAndClearState()
}

runLoop(result, nextCancelation - 1, nextAutoCede)

case 3 =>
Expand Down Expand Up @@ -446,6 +470,10 @@ private final class IOFiber[A](
pushTracingEvent(delay.event)
}

if (dumpLocals) {
IOLocals.setState(localState)
}

// this code is inlined in order to avoid two `try` blocks
var error: Throwable = null
val result =
Expand All @@ -460,6 +488,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (dumpLocals) {
localState = IOLocals.getAndClearState()
}

val next =
if (error == null) succeeded(Right(result), 0) else succeeded(Left(error), 0)
runLoop(next, nextCancelation - 1, nextAutoCede)
Expand Down Expand Up @@ -965,6 +997,10 @@ private final class IOFiber[A](
if (ec.isInstanceOf[WorkStealingThreadPool]) {
val wstp = ec.asInstanceOf[WorkStealingThreadPool]
if (wstp.canExecuteBlockingCode()) {
if (dumpLocals) {
IOLocals.setState(localState)
}

var error: Throwable = null
val r =
try {
Expand All @@ -976,6 +1012,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (dumpLocals) {
localState = IOLocals.getAndClearState()
}

val next = if (error eq null) succeeded(r, 0) else failed(error, 0)
runLoop(next, nextCancelation, nextAutoCede)
} else {
Expand Down Expand Up @@ -1378,6 +1418,11 @@ private final class IOFiber[A](
var error: Throwable = null
val cur = resumeIO.asInstanceOf[Blocking[Any]]
resumeIO = null

if (dumpLocals) {
IOLocals.setState(localState)
}

val r =
try cur.thunk()
catch {
Expand All @@ -1387,6 +1432,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (dumpLocals) {
localState = IOLocals.getAndClearState()
}

if (isStackTracing) {
// Remove the reference to the fiber monitor handle
objectState.pop().asInstanceOf[WeakBag.Handle].deregister()
Expand Down
129 changes: 55 additions & 74 deletions core/shared/src/main/scala/cats/effect/IOLocal.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,35 +136,49 @@ import cats.data.AndThen
* @tparam A
* the type of the local value
*/
sealed trait IOLocal[A] { self =>
sealed trait IOLocal[A] {

protected[effect] def getOrDefault(state: IOLocalState): A

protected[effect] def set(state: IOLocalState, a: A): IOLocalState

protected[effect] def reset(state: IOLocalState): IOLocalState

/**
* Returns the current value.
*/
def get: IO[A]
final def get: IO[A] =
IO.Local(state => (state, getOrDefault(state)))

/**
* Sets the current value to `value`.
*/
def set(value: A): IO[Unit]
final def set(value: A): IO[Unit] =
IO.Local(state => (set(state, value), ()))

/**
* Replaces the current value with the initial value.
*/
def reset: IO[Unit]
final def reset: IO[Unit] =
IO.Local(state => (reset(state), ()))

/**
* Modifies the current value using the given update function.
*/
def update(f: A => A): IO[Unit]
final def update(f: A => A): IO[Unit] =
IO.Local(state => (set(state, f(getOrDefault(state))), ()))

/**
* Like [[update]] but allows the update function to return an output value of type `B`.
*
* @see
* [[update]]
*/
def modify[B](f: A => (A, B)): IO[B]
final def modify[B](f: A => (A, B)): IO[B] =
IO.Local { state =>
val (a2, b) = f(getOrDefault(state))
(set(state, a2), b)
}

/**
* Replaces the current value with `value`, returning the previous value.
Expand All @@ -176,7 +190,8 @@ sealed trait IOLocal[A] { self =>
* @see
* [[set]]
*/
def getAndSet(value: A): IO[A]
final def getAndSet(value: A): IO[A] =
IO.Local(state => (set(state, value), getOrDefault(state)))

/**
* Replaces the current value with the initial value, returning the previous value.
Expand All @@ -188,7 +203,8 @@ sealed trait IOLocal[A] { self =>
* @see
* [[reset]]
*/
def getAndReset: IO[A]
final def getAndReset: IO[A] =
IO.Local(state => (reset(state), getOrDefault(state)))

/**
* Creates a lens to a value of some type `B` from current value and two functions: getter and
Expand Down Expand Up @@ -220,28 +236,7 @@ sealed trait IOLocal[A] { self =>
* } yield ()
* }}}
*/
final def lens[B](get: A => B)(set: A => B => A): IOLocal[B] = {
import IOLocal.IOLocalLens

self match {
case lens: IOLocalLens[aa, A] =>
// We process already created lens separately so
// we wont pay additional `.get.flatMap` price for every call of
// `set`, `update` or `modify` of resulting lens.
// After all, our getters and setters are pure,
// so `AndThen` allows us to safely compose them and
// proxy calls to the 'original' `IOLocal` independent of
// current nesting level.

val getter = lens.getter.andThen(get)
val setter = lens.setter.compose((p: (aa, B)) => (p._1, set(lens.getter(p._1))(p._2)))
new IOLocalLens(lens.underlying, getter, setter)
case _ =>
val getter = AndThen(get)
val setter = AndThen((p: (A, B)) => set(p._1)(p._2))
new IOLocalLens(self, getter, setter)
}
}
def lens[B](get: A => B)(set: A => B => A): IOLocal[B]

}

Expand All @@ -260,63 +255,49 @@ object IOLocal {
*/
def apply[A](default: A): IO[IOLocal[A]] = IO(new IOLocalImpl(default))

private[IOLocal] final class IOLocalImpl[A](default: A) extends IOLocal[A] { self =>
private[this] def getOrDefault(state: IOLocalState): A =
state.getOrElse(self, default).asInstanceOf[A]

def get: IO[A] =
IO.Local(state => (state, getOrDefault(state)))
private[effect] abstract class AbstractIOLocal[A] extends IOLocal[A] {}
armanbilge marked this conversation as resolved.
Show resolved Hide resolved

def set(value: A): IO[Unit] =
IO.Local(state => (state.updated(self, value), ()))
private final class IOLocalImpl[A](default: A) extends IOLocal[A] {

def reset: IO[Unit] =
IO.Local(state => (state - self, ()))
def getOrDefault(state: IOLocalState): A =
state.getOrElse(this, default).asInstanceOf[A]

def update(f: A => A): IO[Unit] =
IO.Local(state => (state.updated(self, f(getOrDefault(state))), ()))
def set(state: IOLocalState, a: A): IOLocalState = state.updated(this, a)

def modify[B](f: A => (A, B)): IO[B] =
IO.Local { state =>
val (a2, b) = f(getOrDefault(state))
(state.updated(self, a2), b)
}
def reset(state: IOLocalState): IOLocalState = state - this

def getAndSet(value: A): IO[A] =
IO.Local(state => (state.updated(self, value), getOrDefault(state)))

def getAndReset: IO[A] =
IO.Local(state => (state - self, getOrDefault(state)))
def lens[B](get: A => B)(set: A => B => A): IOLocal[B] =
new IOLocal.IOLocalLens(this, get, (ab: (A, B)) => set(ab._1)(ab._2))
}

private[IOLocal] final class IOLocalLens[S, A](
val underlying: IOLocal[S],
val getter: AndThen[S, A],
val setter: AndThen[(S, A), S])
private final class IOLocalLens[S, A](
underlying: IOLocal[S],
getter: S => A,
setter: ((S, A)) => S)
extends IOLocal[A] {
def get: IO[A] =
underlying.get.map(getter(_))

def set(value: A): IO[Unit] =
underlying.get.flatMap(s => underlying.set(setter(s -> value)))

def reset: IO[Unit] =
underlying.reset
def getOrDefault(state: IOLocalState): A =
getter(underlying.getOrDefault(state))

def update(f: A => A): IO[Unit] =
underlying.get.flatMap(s => underlying.set(setter(s -> f(getter(s)))))
def set(state: IOLocalState, a: A): IOLocalState =
underlying.set(state, setter((underlying.getOrDefault(state), a)))

def modify[B](f: A => (A, B)): IO[B] =
underlying.get.flatMap { s =>
val (a2, b) = f(getter(s))
underlying.set(setter(s -> a2)).as(b)
}
def reset(state: IOLocalState): IOLocalState = underlying.reset(state)

def getAndSet(value: A): IO[A] =
underlying.get.flatMap(s => underlying.set(setter(s -> value)).as(getter(s)))
def lens[B](get: A => B)(set: A => B => A): IOLocal[B] = {
// We process already created lens separately so
// we wont pay additional `.get.flatMap` price for every call of
// `set`, `update` or `modify` of resulting lens.
// After all, our getters and setters are pure,
// so `AndThen` allows us to safely compose them and
// proxy calls to the 'original' `IOLocal` independent of
// current nesting level.

def getAndReset: IO[A] =
underlying.get.flatMap(s => underlying.reset.as(getter(s)))
val getter = AndThen(this.getter).andThen(get)
val setter =
AndThen(this.setter).compose((sb: (S, B)) => (sb._1, set(this.getter(sb._1))(sb._2)))
new IOLocalLens(underlying, getter, setter)
}
}

}
3 changes: 3 additions & 0 deletions core/shared/src/main/scala/cats/effect/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ package object effect {
val Ref = cekernel.Ref

private[effect] type IOLocalState = scala.collection.immutable.Map[IOLocal[_], Any]
private[effect] object IOLocalState {
val empty: IOLocalState = scala.collection.immutable.Map.empty
}

private[effect] type ByteStack = ByteStack.T
}
Loading