Skip to content

Commit

Permalink
parallelize beam search inner loop
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Nov 30, 2024
1 parent 54c233b commit a189f6f
Showing 1 changed file with 12 additions and 68 deletions.
80 changes: 12 additions & 68 deletions src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import ai.hypergraph.markovian.mcmc.MarkovChain
import dk.brics.automaton.Automaton.*
import dk.brics.automaton.Transition
import java.util.PriorityQueue
import java.util.concurrent.PriorityBlockingQueue
import kotlin.random.Random
import kotlin.time.*

Expand Down Expand Up @@ -98,99 +99,42 @@ fun BAutomaton.decodeDFA(
// Purpose: enumerates strings of this automaton by repeatedly expanding the
// queued partial trajectory with the smallest length-normalized score, and
// returns the distinct string forms of every completed trajectory found.
//
// mc:       Markov chain used to score each extension (via scoreChunk); its
//           `memory` also determines how many trailing tokens form the context.
// dec:      lookup from automaton characters to Markov-chain tokens (see below).
// callback: invoked with the string form of each completed trajectory at the
//           moment it is found (before deduplication).
// topK:     the search stops once this many completed trajectories have been
//           collected; the returned list may be shorter after deduplication.
// timeout:  wall-clock budget, checked once per expansion of a partial trajectory.
mc: MarkovChain<Σᐩ>,
// BAutomata uses a Unicode alphabet, and the Markov Chain recognizes a
// string-based alphabet, so we need a way to translate between the two
dec: Map<Char, Σᐩ>, // Maps unicode characters back to strings because BAutomata uses Unicode
callback: (Σᐩ) -> Unit = {},
topK: Int = 10_000_000, // Total number of top-K results to return
timeout: Duration = Duration.INFINITE,
): List<Σᐩ> {
val startTime = TimeSource.Monotonic.markNow()
// Initial capacity for both priority queues (not a bound on their size).
val load = 100_000
// Both queues order trajectories by score divided by trajectory length, ascending,
// so poll() yields the entry with the smallest normalized score.
// NOTE(review): presumably a lower score means a more likely string — confirm
// against MarkovChain.scoreChunk.
val fullTrajectories = PriorityQueue<FSATrajectory>(load, compareBy { it.score / it.traj.size })
val partTrajectories =
PriorityQueue<FSATrajectory>(load, compareBy { it.score / it.traj.size })
// Seed: a trajectory of mc.memory null tokens at the initial state with score 0.0.
.apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) }

// Keep expanding until enough full trajectories are collected, the frontier
// is exhausted, or the time budget runs out.
while (
fullTrajectories.size < topK &&
partTrajectories.size > 0 &&
startTime.elapsedNow() < timeout
) {
val partTraj = partTrajectories.poll()
// Context for scoring: the first (mc.memory - 1) entries of traj, reversed.
// NOTE(review): presumably traj stores the most recent token first, so the
// reversal restores chronological order — confirm against FSATrajectory.
val lastToks = partTraj.traj.take(mc.memory - 1).reversed()
// Enumerate every character of every outgoing transition's [min, max] range.
partTraj.lastState.transitions.flatMap { next ->
(next.min..next.max).map { tok ->
// dec[tok] yields null when the character has no entry in dec.
val decTok = dec[tok]
// Scores are accumulated additively along the trajectory.
val nextScore = partTraj.score + mc.scoreChunk(lastToks + decTok)

Triple(next, decTok, nextScore)
}
}
// .sortedBy { (_, _, nextScore) -> -nextScore }.take(100)
.forEach { (next: Transition, decTok: String?, nextScore: Double) ->
val traj = partTraj.append(decTok, next.dest, nextScore)
if (!traj.isComplete) { partTrajectories.add(traj) }
else {
// Record the completed trajectory and report it to the caller immediately.
fullTrajectories.add(traj.also { callback(it.toString()) })
// A completed trajectory whose state still has outgoing transitions is
// also re-queued so that longer strings through it are explored.
if (traj.lastState.transitions.isNotEmpty()) partTrajectories.add(traj)
}
}
}

// Deduplicate by string form.
// NOTE(review): this iterates the PriorityQueue directly, and its iteration
// order is not guaranteed to be sorted, so the result is not necessarily
// ordered by score — confirm whether callers rely on ordering.
val deduped = fullTrajectories.map { it.toString() }.distinct().toList()
// .map { it.toString() to mc.score(it.tokens) }
// .distinct().toList().sortedBy { it.second }.map { it.first }

// println("Top 10 trajectories:")
// fullTrajectories.take(10).forEach { println(it.score.toString().take(5) + ": $it") }
println("Took ${startTime.elapsedNow()} to decode ${deduped.size} trajectories, with ${partTrajectories.size} in queue")

return deduped
}

fun BAutomaton.decodeDFAWithBeamSearch(
mc: MarkovChain<Σᐩ>,
dec: Map<Char, Σᐩ>, // Maps unicode characters back to strings
callback: (Σᐩ) -> Unit = {},
topK: Int = 10_000_000, // Total number of top-K results to return
timeout: Duration = Duration.INFINITE,
beamWidth: Int = 100_000, // Maximum number of trajectories to keep at each step
beamWidth: Long = 1_000_000L, // Maximum number of trajectories to keep at each step
): List<Σᐩ> {
val startTime = TimeSource.Monotonic.markNow()
val fullTrajectories = PriorityQueue<FSATrajectory>(compareBy { it.score / it.traj.size }) // Max-heap for full trajectories
val fullTrajectories = PriorityBlockingQueue<FSATrajectory>(10000, compareBy { it.score / it.traj.size }) // Max-heap for full trajectories
val beam = PriorityQueue<FSATrajectory>(compareBy { it.score / it.traj.size }) // Beam for partial trajectories

beam.add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0))

while (
fullTrajectories.size < topK &&
fullTrajectories.size < beamWidth &&
beam.isNotEmpty() &&
startTime.elapsedNow() < timeout
) {
val nextBeam = PriorityQueue<FSATrajectory>(compareBy { it.score / it.traj.size })

while (beam.isNotEmpty() && startTime.elapsedNow() < timeout) {
val partTraj = beam.poll()
val nextBeam = beam.parallelStream().flatMap { partTraj ->
val lastToks = partTraj.traj.take(mc.memory - 1).reversed()

partTraj.lastState.transitions.flatMap { next ->
(next.min..next.max).map { tok ->
val decTok = dec[tok]
val nextScore = partTraj.score + mc.scoreChunk(lastToks + decTok)
partTraj.append(decTok, next.dest, nextScore)
}
}.forEach { traj ->
}.flatMap { traj ->
if (traj.isComplete) {
if (traj.lastState.transitions.isNotEmpty()) nextBeam.add(traj)
fullTrajectories.add(traj)
callback(traj.toString())
} else {
nextBeam.add(traj)
}
}
}
if (traj.lastState.transitions.isNotEmpty()) listOf(traj to traj.score) else emptyList()
} else { listOf(traj to traj.score) }
}.stream()
}.sorted(compareBy { it.second / it.first.traj.size })
.limit(beamWidth).map { it.first }.toList()

beam.clear()
beam.addAll(nextBeam.take(beamWidth))
beam.addAll(nextBeam)
}

val deduped = fullTrajectories.map { it.toString() }.distinct().toList()
Expand Down

0 comments on commit a189f6f

Please sign in to comment.