diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/BusyWait.scala b/spark/src/main/scala/org/apache/spark/sql/delta/BusyWait.scala new file mode 100644 index 00000000000..840d3ddedec --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/delta/BusyWait.scala @@ -0,0 +1,46 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta + +import scala.concurrent.duration._ + +object BusyWait { + /** + * Keep checking if `check` returns `true` until it's the case or `waitTime` expires. + * + * Return `true` when the `check` returned `true`, and `false` if `waitTime` expired. + * + * Note: This function is used as a helper function for the Concurrency Testing framework, + * and should not be used in production code. Production code should not use polling + * and should instead use signalling to coordinate. + */ + def until( + check: => Boolean, + waitTime: FiniteDuration): Boolean = { + val DEFAULT_SLEEP_TIME: Duration = 10.millis + val deadline = waitTime.fromNow + + do { + if (check) { + return true + } + val sleepTimeMs = DEFAULT_SLEEP_TIME.min(deadline.timeLeft).toMillis + Thread.sleep(sleepTimeMs) + } while (deadline.hasTimeLeft()) + false + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/fuzzer/AtomicBarrier.scala b/spark/src/main/scala/org/apache/spark/sql/delta/fuzzer/AtomicBarrier.scala new file mode 100644 index 00000000000..9a5d2b5a4e0 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/delta/fuzzer/AtomicBarrier.scala @@ -0,0 +1,135 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.fuzzer + +import java.util.concurrent.atomic.AtomicInteger + +/** + * An atomic barrier is similar to a countdown latch, + * except that the content is a state transition system with semantic meaning + * instead of a simple counter. + * + * It is designed with a single writer ("unblocker") thread and a single reader ("waiter") thread + * in mind. It is concurrency safe with more writers and readers, but using more is likely to cause + * race conditions for legal transitions. That is to say, trying to perform an otherwise + * legal transition twice is illegal and may occur if there is more than one unblocker or + * waiter thread. + * Having additional passive state observers that only call [[load()]] is never an issue. + * + * Legal transitions are: + * - BLOCKED -> UNBLOCKED + * - BLOCKED -> REQUESTED + * - REQUESTED -> UNBLOCKED + * - UNBLOCKED -> PASSED + */ +class AtomicBarrier { + + import AtomicBarrier._ + + private final val state: AtomicInteger = new AtomicInteger(State.Blocked.ordinal) + + /** Get the current state. */ + def load(): State = { + val ordinal = state.get() + // We should never be putting illegal state ordinals into `state`, + // so this should always succeed. + stateIndex(ordinal) + } + + /** Transition to the Unblocked state. */ + def unblock(): Unit = { + // Just hot-retry this, since it never needs to wait to make progress. + var successful = false + while(!successful) { + val currentValue = state.get() + if (currentValue == State.Blocked.ordinal || currentValue == State.Requested.ordinal) { + this.synchronized { + successful = state.compareAndSet(currentValue, State.Unblocked.ordinal) + if (successful) { + this.notifyAll() + } + } + } else { + // if it's in any other state we will never make progress + throw new IllegalStateTransitionException(stateIndex(currentValue), State.Unblocked) + } + } + } + + /** Wait until this barrier can be passed and then mark it as Passed. */ + def waitToPass(): Unit = { + while (true) { + val currentState = load() + currentState match { + case State.Unblocked => + val updated = state.compareAndSet(currentState.ordinal, State.Passed.ordinal) + if (updated) { + return + } + case State.Passed => + throw new IllegalStateTransitionException(State.Passed, State.Passed) + case State.Requested => + this.synchronized { + if (load().ordinal == State.Requested.ordinal) { + this.wait() + } + } + case State.Blocked => + this.synchronized { + val updated = state.compareAndSet(currentState.ordinal, State.Requested.ordinal) + if (updated) { + this.wait() + } + } // else (if we didn't succeed) just hot-retry until we do + // (or more likely pass, since unblocking is the only legal concurrent + // update with a single concurrent "waiter") + } + } + } + + override def toString: String = s"AtomicBarrier(state=${load()})" +} + +object AtomicBarrier { + + sealed trait State { + def ordinal: Int + } + + object State { + case object Blocked extends State { + override final val ordinal = 0 + } + case object Unblocked extends State { + override final val ordinal = 1 + } + case object Requested extends State { + override final val ordinal = 2 + } + case object Passed extends State { + override final val ordinal = 3 + } + } + + final val stateIndex: Map[Int, State] = + List(State.Blocked, State.Unblocked, State.Requested, State.Passed) + .map(state => state.ordinal -> state) + .toMap +} + +class IllegalStateTransitionException(fromState: AtomicBarrier.State, toState: AtomicBarrier.State) + extends RuntimeException(s"State transition from $fromState to $toState is illegal.") diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/concurrency/PhaseLockingTestMixin.scala b/spark/src/test/scala/org/apache/spark/sql/delta/concurrency/PhaseLockingTestMixin.scala new file mode 100644 index 00000000000..c17867c38c3 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/delta/concurrency/PhaseLockingTestMixin.scala @@ -0,0 +1,51 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.concurrency + +import scala.concurrent.duration._ + +import org.apache.spark.sql.delta.BusyWait +import org.apache.spark.sql.delta.fuzzer.AtomicBarrier + +import org.apache.spark.SparkFunSuite + +trait PhaseLockingTestMixin { self: SparkFunSuite => + /** Keep checking if `barrier` in `state` until it's the case or `waitTime` expires. */ + def busyWaitForState( + barrier: AtomicBarrier, + state: AtomicBarrier.State, + waitTime: FiniteDuration): Unit = + busyWaitFor( + barrier.load() == state, + waitTime, + s"Exceeded deadline waiting for $barrier to transition to state $state") + + /** + * Keep checking if `check` return `true` until it's the case or `waitTime` expires. + * + * Optionally provide a custom error `message`. + */ + def busyWaitFor( + check: => Boolean, + timeout: FiniteDuration, + // lazy evaluate so closed over states are evaluated at time of failure not invocation + message: => String = "Exceeded deadline waiting for check to become true."): Unit = { + if (!BusyWait.until(check, timeout)) { + fail(message) + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/fuzzer/AtomicBarrierSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/fuzzer/AtomicBarrierSuite.scala new file mode 100644 index 00000000000..ae06ebcbef2 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/delta/fuzzer/AtomicBarrierSuite.scala @@ -0,0 +1,60 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.fuzzer + +import scala.concurrent.duration._ + +import org.apache.spark.sql.delta.concurrency.PhaseLockingTestMixin + +import org.apache.spark.SparkFunSuite + +class AtomicBarrierSuite extends SparkFunSuite + with PhaseLockingTestMixin { + + val timeout: FiniteDuration = 5000.millis + + test("Atomic Barrier - wait before unblock") { + val barrier = new AtomicBarrier + assert(AtomicBarrier.State.Blocked === barrier.load()) + val thread = new Thread(() => { + barrier.waitToPass() + }) + assert(AtomicBarrier.State.Blocked === barrier.load()) + thread.start() + busyWaitForState(barrier, AtomicBarrier.State.Requested, timeout) + assert(thread.isAlive) // should be stuck waiting for unblock + barrier.unblock() + busyWaitForState(barrier, AtomicBarrier.State.Passed, timeout) + thread.join(timeout.toMillis) // shouldn't take long + assert(!thread.isAlive) // should have passed the barrier and completed + } + + test("Atomic Barrier - unblock before wait") { + val barrier = new AtomicBarrier + assert(AtomicBarrier.State.Blocked === barrier.load()) + val thread = new Thread(() => { + barrier.waitToPass() + }) + assert(AtomicBarrier.State.Blocked === barrier.load()) + barrier.unblock() + assert(AtomicBarrier.State.Unblocked === barrier.load()) + thread.start() + busyWaitForState(barrier, AtomicBarrier.State.Passed, timeout) + thread.join(timeout.toMillis) // shouldn't take long + assert(!thread.isAlive) // should have passed the barrier and completed + } +}