Skip to content

Commit

Permalink
fix shared variable (solver state) access bug in multithreaded runner
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexeyBond committed Dec 16, 2023
1 parent 2d35395 commit 79595c2
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 17 deletions.
12 changes: 12 additions & 0 deletions addons/wfc/examples/demo_wfc_2d_gridmap_dungeon_class_map.tscn
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,21 @@
[sub_resource type="GDScript" id="GDScript_f3n0e"]
script/source = "extends Node3D

@export
var stress_test: bool = false

func _ready():
$sample.hide()
$class_map.hide()
$target.show()
$generator.start()


func _on_generator_done():
if stress_test:
$target.clear()
$generator.reset()
$generator.start()
"

[sub_resource type="Resource" id="Resource_3csuy"]
Expand Down Expand Up @@ -95,3 +105,5 @@ render_intermediate_results = true

[node name="progressIndicator" parent="." node_paths=PackedStringArray("generator") instance=ExtResource("8_rvqav")]
generator = NodePath("../generator")

[connection signal="done" from="generator" to="." method="_on_generator_done"]
16 changes: 16 additions & 0 deletions addons/wfc/nodes/generator_2d.gd
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,19 @@ func get_progress() -> float:
return 0

return _runner.get_progress()

## Returns [code]true[/code] iff any solver is currently running.
func is_running() -> bool:
if _runner == null:
return false

return _runner.is_running()

## Resets this generator to it's initial state.
## [br]
## Stops any running solver(s), if any.
func reset():
if _runner != null:
if _runner.is_running():
_runner.interrupt()
_runner = null
53 changes: 40 additions & 13 deletions addons/wfc/runners/runner_multithreaded.gd
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,36 @@ const _STUB_NO_MULTITHREADING = false
## Settings for this runner.
var runner_settings: WFCMultithreadedRunnerSettings = WFCMultithreadedRunnerSettings.new()

class _TaskStatusContainer extends RefCounted:
var unsolved_cells: int

var state_snapshot_required: bool
var state_snapshot_mutex: Mutex
var state_snapshot: WFCSolverState

func _init(total_cells: int):
unsolved_cells = total_cells
state_snapshot_mutex = Mutex.new()

func take_and_request_snapshot() -> WFCSolverState:
state_snapshot_mutex.lock()
var snapshot := state_snapshot
state_snapshot_mutex.unlock()
state_snapshot_required = true
return snapshot

class _Task extends RefCounted:
var problem: WFCProblem
var dependencies: PackedInt64Array
var solver: WFCSolver = null
var thread: Thread = null
var is_completed: bool = false
var status_container: _TaskStatusContainer

func _init(problem_: WFCProblem, dependencies_: PackedInt64Array):
problem = problem_
dependencies = dependencies_
status_container = _TaskStatusContainer.new(problem.get_cell_count())

func is_started() -> bool:
return thread != null
Expand Down Expand Up @@ -55,22 +75,24 @@ class _Task extends RefCounted:
return problem.get_cell_count()

func get_unsolved_cells() -> int:
if solver == null:
return get_total_cells()

var state: WFCSolverState = solver.current_state

if state == null:
return 0

return state.unsolved_cells
return status_container.unsolved_cells

var tasks: Array[_Task] = []
var interrupted: bool = false

func _thread_main(solver: WFCSolver):
func _thread_main(solver: WFCSolver, status_container: _TaskStatusContainer):
while (not interrupted) and (not solver.solve_step()):
pass
status_container.unsolved_cells = solver.current_state.unsolved_cells

if status_container.state_snapshot_required:
status_container.state_snapshot_required = false

var state_snapshot: WFCSolverState = solver.current_state.make_snapshot()

var mx := status_container.state_snapshot_mutex
mx.lock()
status_container.state_snapshot = state_snapshot
mx.unlock()

func _noop():
pass
Expand All @@ -85,7 +107,7 @@ func _start_tasks(max_start: int) -> int:
task.thread.start(_noop)
task.solver.solve()
else:
task.thread.start(_thread_main.bind(task.solver))
task.thread.start(_thread_main.bind(task.solver, task.status_container))

started += 1

Expand Down Expand Up @@ -113,6 +135,8 @@ func update():
var running: int = 0
var completed: int = 0

var emit_partial_solution := partial_solution.get_connections().size() > 0

for task in tasks:
if task.is_completed:
completed += 1
Expand All @@ -123,7 +147,10 @@ func update():
sub_problem_solved.emit(task.problem, task.solver.current_state)
completed += 1
else:
partial_solution.emit(task.problem, task.solver.current_state)
if emit_partial_solution:
var snapshot := task.status_container.take_and_request_snapshot()
if snapshot != null:
partial_solution.emit(task.problem, snapshot)
running += 1

if unstarted == 0 and running == 0:
Expand Down
6 changes: 2 additions & 4 deletions addons/wfc/solver/solver.gd
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ func _propagate_constraints() -> bool:
Returns:
true iff solution has failed and backtracking should be performed
"""
assert(current_state != null)

while true:
var changed: PackedInt64Array = current_state.extract_changed_cells()

Expand Down Expand Up @@ -161,10 +163,6 @@ func _try_backtrack() -> bool:
## [br]
## Returns [code]true[/code] iff solution is completed.
func solve_step() -> bool:
"""
Returns:
true iff process has termitated (either successfully or with failure)
"""
assert(current_state != null)

if current_state.is_all_solved():
Expand Down
13 changes: 13 additions & 0 deletions addons/wfc/solver/solver_state.gd
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,19 @@ func make_next() -> WFCSolverState:

return new

## Makes a copy of this state.
## [br]
## The copy is unlinked from this state's previous state.
## The copy is safe to access from thread different from one the solver runs on.
func make_snapshot() -> WFCSolverState:
var new: WFCSolverState = WFCSolverState.new()

new.cell_domains = cell_domains.duplicate()
new.cell_solution_or_entropy = cell_solution_or_entropy.duplicate()
new.unsolved_cells = unsolved_cells

return new

func pick_divergence_cell() -> int:
assert(unsolved_cells > 0)

Expand Down

0 comments on commit 79595c2

Please sign in to comment.