From 5863d0d7ecb391547ebb9e2603586305301ad4c3 Mon Sep 17 00:00:00 2001 From: breandan Date: Sun, 25 Feb 2024 16:23:48 -0500 Subject: [PATCH] speed up filtering --- .../ai/hypergraph/kaliningraph/CommonUtils.kt | 2 + .../hypergraph/kaliningraph/automata/FSA.kt | 13 +-- .../kaliningraph/parsing/BarHillel.kt | 79 ++++++++----------- .../ai/hypergraph/kaliningraph/parsing/CFG.kt | 1 + .../hypergraph/kaliningraph/parsing/Parikh.kt | 2 + .../kaliningraph/parsing/JVMBarHillel.kt | 22 +++--- 6 files changed, 59 insertions(+), 60 deletions(-) diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/CommonUtils.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/CommonUtils.kt index ecbbe962..a47e83be 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/CommonUtils.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/CommonUtils.kt @@ -93,6 +93,8 @@ fun Array.toDoubleMatrix() = DoubleMatrix(size, this[0].size) { i, fun kroneckerDelta(i: Int, j: Int) = if (i == j) 1.0 else 0.0 +fun hashPair(i1: Int, i2: Int): Int = i1 * 31 + i2 + const val DEFAULT_FEATURE_LEN = 20 fun String.vectorize(len: Int = DEFAULT_FEATURE_LEN) = Random(hashCode()).let { randomVector(len) { it.nextDouble() } } diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/FSA.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/FSA.kt index 9802d9a4..a52a7e8f 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/FSA.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/FSA.kt @@ -1,17 +1,16 @@ package ai.hypergraph.kaliningraph.automata +import ai.hypergraph.kaliningraph.* import ai.hypergraph.kaliningraph.graphs.* import ai.hypergraph.kaliningraph.parsing.* -import ai.hypergraph.kaliningraph.tokenizeByWhitespace import ai.hypergraph.kaliningraph.types.* -import kotlin.math.* typealias Arc = Π3A<Σᐩ> typealias TSA = Set fun Arc.pretty() = "$π1 -<$π2>-> $π3" fun Σᐩ.coords(): Pair = (length / 2 - 1).let { substring(2, it + 2).toInt() to substring(it + 3).toInt() } -typealias STC = Triple<Σᐩ, Int, Int> +typealias STC = Triple fun STC.coords() = π2 to π3 open class FSA(open val Q: TSA, open val init: Set<Σᐩ>, open val final: Set<Σᐩ>) { @@ -19,9 +18,11 @@ open class FSA(open val Q: TSA, open val init: Set<Σᐩ>, open val final: Set< val isNominalizable by lazy { alphabet.any { it.startsWith("[!=]") } } val nominalForm: NOM by lazy { nominalize() } val states by lazy { Q.states } - val APSP: Map, Int> by lazy { + val stateLst by lazy { states.toList() } + val stateMap by lazy { states.toList().withIndex().associate { it.value to it.index } } + val APSP: Map by lazy { graph.APSP.map { (k, v) -> - Pair(Pair(k.first.label, k.second.label), v) + Pair(hashPair(stateMap[k.first.label]!!, stateMap[k.second.label]!!), v) }.toMap() } @@ -29,7 +30,7 @@ open class FSA(open val Q: TSA, open val init: Set<Σᐩ>, open val final: Set< Q.groupBy { it.π1 }.mapValues { (_, v) -> v.map { it.π2 to it.π3 } } } - val stateCoords: Sequence by lazy { states.map { it.coords().let { (i, j) -> Triple(it, i, j) } }.asSequence() } + val stateCoords: Sequence by lazy { states.map { it.coords().let { (i, j) -> Triple(stateMap[it]!!, i, j) } }.asSequence() } val validTriples by lazy { stateCoords.let { it * it * it }.filter { it.isValidStateTriple() }.toList() } diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt index a729b938..9a8f826d 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt @@ -1,10 +1,11 @@ package ai.hypergraph.kaliningraph.parsing import ai.hypergraph.kaliningraph.automata.* +import ai.hypergraph.kaliningraph.hashPair import ai.hypergraph.kaliningraph.repair.MAX_TOKENS import ai.hypergraph.kaliningraph.types.* import ai.hypergraph.kaliningraph.types.times -import kotlin.math.absoluteValue +import kotlin.math.* import kotlin.time.TimeSource /** @@ -28,7 +29,6 @@ fun CFG.barHillelRepair(prompt: List<Σᐩ>, distance: Int) = // https://browse.arxiv.org/pdf/2209.06809.pdf#page=5 private infix fun CFG.intersectLevFSAP(fsa: FSA): CFG { var clock = TimeSource.Monotonic.markNow() - val lengthBoundsCache = lengthBounds val nts = mutableSetOf("START") fun Σᐩ.isSyntheticNT() = first() == '[' && last() == ']' && count { it == '~' } == 2 @@ -50,11 +50,15 @@ private infix fun CFG.intersectLevFSAP(fsa: FSA): CFG { // For each production A → BC in P, for every p, q, r ∈ Q, // we have the production [p,A,r] → [p,B,q] [q,C,r] in P′. - val validTriples = - fsa.stateCoords.let { it * it * it }.filter { it.isValidStateTriple() }.toList() + val ntLst = nonterminals.toList() + val ntMap = ntLst.withIndex().associate { (i, s) -> s to i } + val prods: Set = nonterminalProductions + .map { (a, b) -> ntMap[a]!! to b.map { ntMap[it]!! } }.toSet() + val lengthBoundsCache = lengthBounds.let { lb -> nonterminals.map { lb[it]!! } } + val validTriples: List> = fsa.validTriples val binaryProds = - nonterminalProductions.map { + prods.map { // if (i % 100 == 0) println("Finished ${i}/${nonterminalProductions.size} productions") val (A, B, C) = it.π1 to it.π2[0] to it.π2[1] validTriples @@ -64,7 +68,7 @@ private infix fun CFG.intersectLevFSAP(fsa: FSA): CFG { .filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap) } .map { (a, b, c) -> val (p, q, r) = a.π1 to b.π1 to c.π1 - "[$p~$A~$r]".also { nts.add(it) } to listOf("[$p~$B~$q]", "[$q~$C~$r]") + "[$p~${ntLst[A]}~$r]".also { nts.add(it) } to listOf("[$p~${ntLst[B]}~$q]", "[$q~${ntLst[C]}~$r]") }.toList() }.flatten().filterRHSInNTS() @@ -244,48 +248,35 @@ fun Π3A.isValidStateTriple(): Boolean { // && obeys(second, third, nts.third) //} -fun Π3A.obeysLevenshteinParikhBounds(nts: Triple<Σᐩ, Σᐩ, Σᐩ>, fsa: FSA, parikhMap: ParikhMap): Boolean { - fun obeys(a: STC, b: STC, nt: Σᐩ): Bln { - val sl = - fsa.levString.size <= a.second || // Part of the LA that handles extra - fsa.levString.size <= b.second // terminals at the end of the string - - if (sl) return true - val margin = (b.third - a.third).absoluteValue - val length = (b.second - a.second) - val range = (length - margin).coerceAtLeast(1)..(length + margin) - val pb = parikhMap.parikhBounds(nt, range) - val pv = fsa.parikhVector(a.second, b.second) - return pb.admits(pv, margin) - } +private fun FSA.obeys(a: STC, b: STC, nt: Int, parikhMap: ParikhMap): Bln { + val sl = levString.size <= max(a.second, b.second) // Part of the LA that handles extra - return obeys(first, third, nts.first) - && obeys(first, second, nts.second) - && obeys(second, third, nts.third) + if (sl) return true + val margin = (b.third - a.third).absoluteValue + val length = (b.second - a.second) + val range = (length - margin).coerceAtLeast(1)..(length + margin) + val pb = parikhMap.parikhBounds(nt, range) + val pv = parikhVector(a.second, b.second) + return pb.admits(pv, margin) } -fun Π3A.isCompatibleWith(nts: Triple<Σᐩ, Σᐩ, Σᐩ>, fsa: FSA, lengthBounds: Map<Σᐩ, IntRange>): Boolean { - fun lengthBounds(nt: Σᐩ): IntRange = - (lengthBounds[nt] ?: -9999..-9990) - // Okay if we overapproximate the length bounds a bit -// .let { (it.first - fudge)..(it.last + fudge) } - - fun manhattanDistance(first: Pair, second: Pair): Int = - (second.second - first.second).absoluteValue + (second.first - first.first).absoluteValue +fun Π3A.obeysLevenshteinParikhBounds(nts: Triple, fsa: FSA, parikhMap: ParikhMap): Boolean = + fsa.obeys(first, third, nts.first, parikhMap) + && fsa.obeys(first, second, nts.second, parikhMap) + && fsa.obeys(second, third, nts.third, parikhMap) - // Range of the shortest path to the longest path, i.e., Manhattan distance - fun SPLP(a: STC, b: STC) = - (fsa.APSP[a.π1 to b.π1] ?: Int.MAX_VALUE).. - manhattanDistance(a.coords(), b.coords()) +private fun manhattanDistance(first: Pair, second: Pair): Int = + (second.second - first.second).absoluteValue + (second.first - first.first).absoluteValue - fun IntRange.overlaps(other: IntRange) = - (other.first in first..last) || (other.last in first..last) +// Range of the shortest path to the longest path, i.e., Manhattan distance +private fun FSA.SPLP(a: STC, b: STC) = + (APSP[hashPair(a.π1, b.π1)] ?: Int.MAX_VALUE).. + manhattanDistance(a.coords(), b.coords()) - // "[$p,$A,$r] -> [$p,$B,$q] [$q,$C,$r]" - fun isCompatible() = - lengthBounds(nts.first).overlaps(SPLP(first, third)) - && lengthBounds(nts.second).overlaps(SPLP(first, second)) - && lengthBounds(nts.third).overlaps(SPLP(second, third)) +private fun IntRange.overlaps(other: IntRange) = + (other.first in first..last) || (other.last in first..last) - return isCompatible() -} \ No newline at end of file +fun Π3A.isCompatibleWith(nts: Triple, fsa: FSA, lengthBounds: List): Boolean = + lengthBounds[nts.first].overlaps(fsa.SPLP(first, third)) + && lengthBounds[nts.second].overlaps(fsa.SPLP(first, second)) + && lengthBounds[nts.third].overlaps(fsa.SPLP(second, third)) \ No newline at end of file diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt index 21e998c5..7a27a279 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt @@ -10,6 +10,7 @@ import kotlin.time.* typealias Σᐩ = String typealias Production = Π2<Σᐩ, List<Σᐩ>> +typealias IProduction = Π2> // TODO: make this immutable typealias CFG = Set diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Parikh.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Parikh.kt index 084b15c8..739ea32c 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Parikh.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Parikh.kt @@ -54,6 +54,7 @@ class ParikhMap(val cfg: CFG, val size: Int) { private val lengthBounds: MutableMap> = mutableMapOf() private val parikhMap: MutableMap = mutableMapOf() val parikhRangeMap: MutableMap = mutableMapOf() + val ntIdx = cfg.nonterminals.toList() companion object { fun genRanges(delta: Int = 2 * MAX_RADIUS + 1, n: Int = MAX_TOKENS) = @@ -83,6 +84,7 @@ class ParikhMap(val cfg: CFG, val size: Int) { } } + fun parikhBounds(nt: Int, range: IntRange): ParikhBounds = parikhBounds(ntIdx[nt], range) fun parikhBounds(nt: Σᐩ, range: IntRange): ParikhBounds = parikhRangeMap[range]?.get(nt) ?: emptyMap() fun parikhBounds(nt: Σᐩ, size: Int): ParikhBounds? = parikhMap[size]?.get(nt) // parikhMap.also { println("Keys (${nt}): " + it.keys.size + ", ${it[size]?.get(nt)}") }[size]?.get(nt) diff --git a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt index 3add28e1..68e13ba7 100644 --- a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt +++ b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt @@ -154,7 +154,6 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG { // if (fsa.Q.size < 650) throw Exception("FSA size was out of bounds") var clock = TimeSource.Monotonic.markNow() - val lengthBoundsCache = lengthBounds val nts = ConcurrentSkipListSet(setOf("START")) val initFinal = @@ -170,31 +169,34 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG { // For each production A → BC in P, for every p, q, r ∈ Q, // we have the production [p,A,r] → [p,B,q] [q,C,r] in P′. - val prods: Set = nonterminalProductions - var i = 0 + val ntLst = nonterminals.toList() + val ntMap = ntLst.mapIndexed { i, s -> s to i }.toMap() + val prods: Set = nonterminalProductions + .map { (a, b) -> ntMap[a]!! to b.map { ntMap[it]!! } }.toSet() + val lengthBoundsCache = lengthBounds.let { lb -> nonterminals.map { lb[it]!! } } val validTriples: List> = fsa.validTriples val elimCounter = AtomicInteger(0) val counter = AtomicInteger(0) + val lpClock = TimeSource.Monotonic.markNow() val binaryProds = prods.parallelStream().flatMap { -// if (i++ % 100 == 0) println("Finished $i/${nonterminalProductions.size} productions") if (BH_TIMEOUT < clock.elapsedNow()) throw Exception("Timeout: ${nts.size} nts") val (A, B, C) = it.π1 to it.π2[0] to it.π2[1] validTriples.stream() // CFG ∩ FSA - in general we are not allowed to do this, but it works // because we assume a Levenshtein FSA, which is monotone and acyclic. - .filter { it.isCompatibleWith(A to B to C, fsa, lengthBoundsCache).also { if (it) elimCounter.incrementAndGet() } } - .filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap).also { if (it) elimCounter.incrementAndGet() } } + .filter { it.isCompatibleWith(A to B to C, fsa, lengthBoundsCache).also { if (!it) elimCounter.incrementAndGet() } } + .filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap).also { if (!it) elimCounter.incrementAndGet() } } .map { (a, b, c) -> if (MAX_PRODS < counter.incrementAndGet()) throw Exception("∩-grammar has too many productions! (>$MAX_PRODS)") - val (p, q, r) = a.π1 to b.π1 to c.π1 - "[$p~$A~$r]".also { nts.add(it) } to listOf("[$p~$B~$q]", "[$q~$C~$r]") + val (p, q, r) = fsa.stateLst[a.π1] to fsa.stateLst[b.π1] to fsa.stateLst[c.π1] + "[$p~${ntLst[A]}~$r]".also { nts.add(it) } to listOf("[$p~${ntLst[B]}~$q]", "[$q~${ntLst[C]}~$r]") } }.toList() - println("LP constraints eliminated $elimCounter productions...") + println("Levenshtein-Parikh constraints eliminated $elimCounter productions in ${lpClock.elapsedNow()}") fun Σᐩ.isSyntheticNT() = first() == '[' && length > 1 // && last() == ']' && count { it == '~' } == 2 @@ -214,7 +216,7 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG { fun CFG.jvmPostProcess(clock: TimeSource.Monotonic.ValueTimeMark) = jvmDropVestigialProductions(clock) .jvmElimVarUnitProds() - .also { println("Reduced ∩-grammar from $size to ${it.size} useful productions in ${clock.elapsedNow()}") } + .also { println("Normalization eliminated ${size - it.size} productions in ${clock.elapsedNow()}") } .freeze() tailrec fun CFG.jvmElimVarUnitProds(