Skip to content

Commit

Permalink
speed up PCFG sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Feb 26, 2024
1 parent 5863d0d commit f784a49
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ fun Array<DoubleArray>.toDoubleMatrix() = DoubleMatrix(size, this[0].size) { i,

// Kronecker delta: 1.0 when both indices coincide, 0.0 otherwise
fun kroneckerDelta(i: Int, j: Int) = when (i) {
  j -> 1.0
  else -> 0.0
}

// Combines two ints as 31 * i1 + i2 (same result as hash(i1, i2))
fun hashPair(i1: Int, i2: Int): Int = 31 * i1 + i2

// Polynomial (base-31) combination of the arguments' hashCode() values, left to right
fun hash(vararg ints: Any): Int {
  var acc = 0
  for (item in ints) acc = 31 * acc + item.hashCode()
  return acc
}

// Same combination specialised to primitive ints, using each value directly
fun hash(vararg ints: Int): Int {
  var acc = 0
  for (value in ints) acc = 31 * acc + value
  return acc
}

const val DEFAULT_FEATURE_LEN = 20
fun String.vectorize(len: Int = DEFAULT_FEATURE_LEN) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ open class FSA(open val Q: TSA, open val init: Set<Σᐩ>, open val final: Set<
// Position of every state within states.toList()
val stateMap by lazy { states.toList().mapIndexed { idx, state -> state to idx }.toMap() }
// Values of graph.APSP re-keyed by hash(sourceIndex, targetIndex).
// NOTE(review): hash(i, j) == 31 * i + j, so distinct pairs share a key once an index reaches 31
// (e.g. (0, 31) and (1, 0)); toMap() then keeps only the last such entry — confirm this is acceptable.
val APSP: Map<Int, Int> by lazy {
  graph.APSP.map { (edge, dist) ->
    val src = stateMap[edge.first.label]!!
    val dst = stateMap[edge.second.label]!!
    hash(src, dst) to dist
  }.toMap()
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.automata.*
import ai.hypergraph.kaliningraph.hashPair
import ai.hypergraph.kaliningraph.hash
import ai.hypergraph.kaliningraph.repair.MAX_TOKENS
import ai.hypergraph.kaliningraph.types.*
import ai.hypergraph.kaliningraph.types.times
Expand Down Expand Up @@ -270,7 +270,7 @@ private fun manhattanDistance(first: Pair<Int, Int>, second: Pair<Int, Int>): In

// Range from the APSP entry for (a, b) (Int.MAX_VALUE when there is none) up to the Manhattan distance
private fun FSA.SPLP(a: STC, b: STC) = run {
  val shortest = APSP[hash(a.π1, b.π1)] ?: Int.MAX_VALUE
  val longest = manhattanDistance(a.coords(), b.coords())
  shortest..longest
}

private fun IntRange.overlaps(other: IntRange) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ val CFG.tmap: Map<Set<Σᐩ>, Set<Σᐩ>> by cache {
.mapValues { it.value.map { it.second }.toSet() }
}

// All symbols of the grammar plus "ε", as a list
val CFG.ntLst by cache { symbols.plus("ε").toList() }
// Maps each entry of ntLst to its position in that list
val CFG.ntMap by cache { ntLst.withIndex().associate { (idx, sym) -> sym to idx } }

// Maps each nonterminal to the set of nonterminals that can generate it
val CFG.vindex: Array<IntArray> by cache {
Array(bindex.indexedNTs.size) { i ->
Expand Down Expand Up @@ -256,8 +259,9 @@ class BiMap(cfg: CFG) {
}

// n.b., this only terminates if the CFG is acyclic (i.e., finite); on a cyclic grammar it recurses forever.
// Every node's ntIdx is looked up in origCFG.ntMap under the second '~'-separated segment of `from`
// (or `from` itself when it has no '~'), falling back to Int.MAX_VALUE when the key is absent.
fun CFG.toPTree(from: Σᐩ = START_SYMBOL, origCFG: CFG = this): PTree {
  val children = bimap[from].map { rhs ->
    val left = toPTree(rhs[0], origCFG)
    val right = if (rhs.size == 1) PTree() else toPTree(rhs[1], origCFG)
    left to right
  }
  val lookupKey = if ('~' in from) from.split('~')[1] else from
  return PTree(from, children).also { it.ntIdx = origCFG.ntMap[lookupKey] ?: Int.MAX_VALUE }
}

/*
Γ ⊢ ∀ v.[α→*]∈G ⇒ α→[β] "If all productions rooted at α
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ typealias PForest = Map<String, PTree> // ℙ₃
class PTree(val root: String = "", val branches: List<Π2A<PTree>> = listOf()) {
// val hash by lazy { root.hashCode() + if (branches.isEmpty()) 0 else branches.hashCode() }
// override fun hashCode(): Int = hash
var ntIdx = -1

val branchRatio: Pair<Double, Double> by lazy { if (branches.isEmpty()) 0.0 to 0.0 else
(branches.size.toDouble() + branches.sumOf { (l, r) -> l.branchRatio.first + r.branchRatio.first }) to
Expand Down Expand Up @@ -120,10 +121,10 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
while (i < 9 * totalTrees) yield(decodeString(i++ * stride + offset).first)
}

// Endless sequence of independent samplePCFG5 draws against pcfgTable
fun sampleStrWithPCFG5(pcfgTable: Map<Int, Int>): Sequence<String> = sequence {
  while (true) {
    val sample = samplePCFG5(pcfgTable)
    yield(sample)
  }
}

// Endless sequence of independent samplePCFG3 draws against pcfgTable
fun sampleStrWithPCFG3(pcfgTable: Map<Int, Int>): Sequence<String> = sequence {
  while (true) {
    val sample = samplePCFG3(pcfgTable)
    yield(sample)
  }
}

// Samples instantaneously from the parse forest, but may return duplicates
Expand Down Expand Up @@ -152,28 +153,32 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b"
}

// Second '~'-separated segment of the symbol, or the symbol itself when it contains no '~'
fun Σᐩ.name() = when {
  contains('~') -> split('~')[1]
  else -> this
}
// One (left ntIdx, right ntIdx) pair per branch; the ntIdx values are read once, on first access
val triples : List<Π2A<Int>> by lazy { branches.map { (left, right) -> left.ntIdx to right.ntIdx } }
val rootName by lazy { root.name() }
val isLeaf by lazy { branches.none() }

/**
 * Draws one string from this parse forest, choosing each branch with probability proportional
 * to its smoothed count in [pcfgTable].
 *
 * A branch is looked up under hash(upUp, upLeft, upRight, left.ntIdx, right.ntIdx); a missing key
 * counts as 1 and every weight is incremented by 1. The chosen children are sampled recursively,
 * with this node's ntIdx passed as their [upUp] and the child being expanded marked by multiplying
 * its own index by 31 in the [upLeft]/[upRight] slot.
 *
 * @return the sampled left and right substrings joined by a single space, or whichever one is
 *   non-empty; `epsStr` for a leaf.
 */
fun samplePCFG5(pcfgTable: Map<Int, Int>, upUp: Int = 0, upLeft: Int = 0, upRight: Int = 0): Σᐩ {
  if (isLeaf) return epsStr
  // "+ 1" must stay on the same line as its left operand: Kotlin parses a line that starts with
  // "+ 1" as a separate unary-plus expression, which made every weight 1 and ignored the table.
  val probs = triples.map { (pcfgTable[hash(upUp, upLeft, upRight, it.first, it.second)] ?: 1) + 1 }
  val cdf = probs.runningReduce { acc, i -> acc + i }
  val rnd = Random.nextInt(probs.sum())
  // Bucket k owns the half-open interval [cdf[k-1], cdf[k]), so take the first cumulative sum that
  // is strictly greater than rnd. (Treating rnd == cdf[k] as bucket k shifted one unit of weight
  // from each bucket to its predecessor.) rnd < cdf.last(), so a match always exists.
  val childIdx = cdf.indexOfFirst { it > rnd }
  val (l, r) = branches[childIdx]
  val lr = l.ntIdx
  val rr = r.ntIdx
  val a = l.samplePCFG5(pcfgTable, ntIdx, 31 * lr, rr)
  val b = r.samplePCFG5(pcfgTable, ntIdx, lr, 31 * rr)
  return if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b"
}

fun samplePCFG3(pcfgTable: Map<Π3A<Σᐩ>, Int>): Σᐩ {
fun samplePCFG3(pcfgTable: Map<Int, Int>): Σᐩ {
if (branches.isEmpty()) return epsStr

val probs = triples.map { (pcfgTable[it] ?: 1) + 1 }
val probs = triples.map { (pcfgTable[hash(ntIdx, it.first, it.second)] ?: 1) + 1 }
val cdf = probs.runningReduce { acc, i -> acc + i }
val rnd = Random.nextInt(probs.sum())
val childIdx = cdf.binarySearch { it.compareTo(rnd) }.let { if (it < 0) -it - 1 else it }
Expand Down Expand Up @@ -206,12 +211,6 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
// }
}

/**
 * Five-string lookup key with a precomputed hash code.
 *
 * The hash is order-sensitive (base-31 combination of the component hash codes), and equality
 * compares the components themselves; the cached hash is only used as a cheap early rejection.
 * Comparing hashes alone would treat permuted tuples and colliding strings (e.g. "Aa" vs "BB")
 * as the same key.
 */
data class StrQuintuple(val a: String, val b: String, val c: String, val d: String, val e: String) {
  val hash: Int =
    (((a.hashCode() * 31 + b.hashCode()) * 31 + c.hashCode()) * 31 + d.hashCode()) * 31 + e.hashCode()

  override fun hashCode(): Int = hash

  override fun equals(other: Any?): Boolean =
    this === other || (
      other is StrQuintuple && other.hash == hash &&
        a == other.a && b == other.b && c == other.c && d == other.d && e == other.e
      )
}

// Builds the parse-forest matrix for `tokens`, runs it to a fixpoint and returns the START_SYMBOL
// entry of the first cell on the last diagonal.
fun CFG.startPTree(tokens: List<String>) = run { //measureTimedValue {
  val fixpoint = initPForestMat(tokens).seekFixpoint()
  val topCell = fixpoint.diagonals.last()[0]
  topCell[START_SYMBOL]
}
//}.also { println("Took ${it.duration} to compute parse forest") }.value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import ai.hypergraph.kaliningraph.types.Π2A
import kotlin.math.pow
import kotlin.time.*

// Mutable top-level tuning parameters.
// NOTE(review): CFG_THRESH is not explained here — presumably a grammar-size cutoff; confirm and document
var CFG_THRESH = 20_000
var MAX_UNIQUE = 20_000 // Maximum number of unique samples to generate
var MAX_SAMPLE = 20 // Maximum number of repairs to sample
var MAX_TOKENS = 80 // Maximum number of tokens per repair
// NOTE(review): MAX_RADIUS has no description — presumably a maximum edit/search radius; confirm
var MAX_RADIUS = 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,31 +75,27 @@ fun CFG.sampleDirectlyWR(
}
}

/**
 * Streams strings drawn from this parse forest with [sampleStrWithPCFG5], running one sampler
 * per entry of 0..<cores on a parallel stream.
 *
 * Each sampler keeps producing while [stoppingCriterion] returns true and drops its own repeats;
 * the same string can still be emitted by more than one sampler.
 */
fun PTree.sampleWithPCFG(
  pcfgTable: Map<Int, Int>,
  cores: Int = NUM_CORES,
  stoppingCriterion: () -> Boolean = { true }
): Stream<String> {
  val workerIds = (0..<cores).toList()
  return workerIds.parallelStream().flatMap { _ ->
    sampleStrWithPCFG5(pcfgTable).takeWhile { stoppingCriterion() }.distinct().asStream()
  }
}

/**
 * Streams strings from this parse forest via [sampleStrWithoutReplacement], running one sampler
 * per entry of 0..<cores on a parallel stream; sampler `id` is called with (cores, id).
 *
 * Each sampler keeps producing while [stoppingCriterion] returns true and drops its own repeats.
 */
fun PTree.sampleDirectlyWOR(
  cores: Int = NUM_CORES,
  stoppingCriterion: () -> Boolean = { true }
): Stream<String> {
  val workerIds = (0..<cores).toList()
  return workerIds.parallelStream().flatMap { id ->
    sampleStrWithoutReplacement(cores, id).takeWhile { stoppingCriterion() }.distinct().asStream()
  }
}

fun CFG.parallelEnumListWR(
Expand Down Expand Up @@ -169,8 +165,6 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG {

// For each production A → BC in P, for every p, q, r ∈ Q,
// we have the production [p,A,r] → [p,B,q] [q,C,r] in P′.
val ntLst = nonterminals.toList()
val ntMap = ntLst.mapIndexed { i, s -> s to i }.toMap()
val prods: Set<IProduction> = nonterminalProductions
.map { (a, b) -> ntMap[a]!! to b.map { ntMap[it]!! } }.toSet()
val lengthBoundsCache = lengthBounds.let { lb -> nonterminals.map { lb[it]!! } }
Expand Down

0 comments on commit f784a49

Please sign in to comment.