Skip to content

Commit

Permalink
Refactoring tests to better exercise AsyncRecursiveLock
Browse files Browse the repository at this point in the history
  • Loading branch information
mattmassicotte committed Sep 30, 2024
1 parent 1778458 commit ffb00a2
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 58 deletions.
52 changes: 38 additions & 14 deletions Sources/Lock/AsyncRecursiveLock.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
public final class AsyncRecursiveLock {
// This thing doesn't quite work yet...
final class AsyncRecursiveLock {
@TaskLocal private static var locked = false
@TaskLocal private static var lockedSet = Set<ObjectIdentifier>()

private let internalLock = AsyncLock()

Expand All @@ -10,24 +12,46 @@ public final class AsyncRecursiveLock {
isolation: isolated (any Actor)? = #isolation,
_ block: () async throws -> T
) async rethrows -> T {
if Self.locked {
let id = ObjectIdentifier(self)
var set = Self.lockedSet

let (needsLock, _) = set.insert(id)

print("state:", id, needsLock)

if needsLock == false {
return try await block()
}

await internalLock.lock()

do {
let value = try await Self.$locked.withValue(true) {
return try await internalLock.withLock {
try await Self.$lockedSet.withValue(set) {
try await block()
}

internalLock.unlock()

return value
} catch {
internalLock.unlock()

throw error
}
}

// public func withLock<T: Sendable>(
// isolation: isolated (any Actor)? = #isolation,
// _ block: () async throws -> T
// ) async rethrows -> T {
// if Self.locked {
// return try await block()
// }
//
// await internalLock.lock()
//
// do {
// let value = try await Self.$locked.withValue(true) {
// try await block()
// }
//
// internalLock.unlock()
//
// return value
// } catch {
// internalLock.unlock()
//
// throw error
// }
// }
}
36 changes: 14 additions & 22 deletions Tests/LockTests/LockTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,20 @@ import Testing
import Lock

actor ReentrantActor {
var value = 42
let state = ReentrantSensitiveState()
let lock = AsyncLock()

func doThingWithLock() async {
await lock.withLock {
try! #require(self.value == 42)
self.value = 0
try! await Task.sleep(nanoseconds: 1_000_000)
try! #require(self.value == 0)
self.value = 42
func doThingUsingWithLock() async throws {
try await lock.withLock {
try await state.doThing()
}
}

func doThing() async {
func doThingUsingLockUnlock() async throws {
await lock.lock()
defer { lock.unlock() }

try! #require(self.value == 42)
self.value = 0
try! await Task.sleep(nanoseconds: 1_000_000)
try! #require(self.value == 0)
self.value = 42
try await state.doThing()
}
}

Expand All @@ -37,38 +29,38 @@ struct LockTests {
}

@Test
func serializes() async {
func serializes() async throws {
let actor = ReentrantActor()
var tasks = [Task<Void, Never>]()
var tasks = [Task<Void, any Error>]()

for _ in 0..<1000 {
let task = Task {
await actor.doThing()
try await actor.doThingUsingLockUnlock()
}

tasks.append(task)
}

for task in tasks {
await task.value
try await task.value
}
}

@Test
func serializesWithLock() async {
func serializesWithLock() async throws {
let actor = ReentrantActor()
var tasks = [Task<Void, Never>]()
var tasks = [Task<Void, any Error>]()

for _ in 0..<1000 {
let task = Task {
await actor.doThingWithLock()
try await actor.doThingUsingWithLock()
}

tasks.append(task)
}

for task in tasks {
await task.value
try await task.value
}
}
}
86 changes: 64 additions & 22 deletions Tests/LockTests/RecursiveLockTests.swift
Original file line number Diff line number Diff line change
@@ -1,17 +1,46 @@
import Testing
import Lock
@testable import Lock

actor RecursiveReentrantActor {
actor ReentrantSensitiveState {
var value = 42

func doThing() async throws {
try #require(self.value == 42)
self.value = 0
try await Task.sleep(nanoseconds: 1_000_000)
try #require(self.value == 0)
self.value = 42

}
}

actor RecursiveReentrantActor {
let state = ReentrantSensitiveState()
let lock = AsyncRecursiveLock()

func doThing() async {
await lock.withLock {
try! #require(self.value == 42)
self.value = 0
try! await Task.sleep(nanoseconds: 1_000_000)
try! #require(self.value == 0)
self.value = 42
func doThing() async throws {
try await lock.withLock {
try await state.doThing()
}
}
}

actor TwoLockRecursiveReentrantActor {
let state = ReentrantSensitiveState()
let lock1 = AsyncRecursiveLock()
let lock2 = AsyncRecursiveLock()

func holdBothLocks(with block: () async throws -> Void) async rethrows {
try await lock1.withLock {
try await lock2.withLock {
try await block()
}
}
}

func doThing() async throws {
try await lock2.withLock {
try await state.doThing()
}
}
}
Expand All @@ -27,29 +56,42 @@ struct RecursiveLockTests {
}
}

@Test
func lockAndUnlock() async {
let lock = AsyncLock()

await lock.lock()
lock.unlock()
}

@Test
func serializesWithRecursiveLock() async {
// @Test
func serializesWithRecursiveLock() async throws {
let actor = RecursiveReentrantActor()
var tasks = [Task<Void, Never>]()
var tasks = [Task<Void, any Error>]()

for _ in 0..<1000 {
let task = Task {
await actor.doThing()
try await actor.doThing()
}

tasks.append(task)
}

for task in tasks {
await task.value
try await task.value
}
}

// @Test
func serializesWithTwoLocks() async throws {
let actor = TwoLockRecursiveReentrantActor()

try await actor.holdBothLocks {
var tasks = [Task<Void, any Error>]()

for _ in 0..<1000 {
let task = Task {
try await actor.doThing()
}

tasks.append(task)
}

for task in tasks {
try await task.value
}
}
}
}

0 comments on commit ffb00a2

Please sign in to comment.