From f784a495c22f0564bc35cb3f14e4ab83a1b8406d Mon Sep 17 00:00:00 2001 From: breandan Date: Sun, 25 Feb 2024 23:34:03 -0500 Subject: [PATCH] speed up PCFG sampler --- .../ai/hypergraph/kaliningraph/CommonUtils.kt | 3 +- .../hypergraph/kaliningraph/automata/FSA.kt | 2 +- .../kaliningraph/parsing/BarHillel.kt | 4 +-- .../ai/hypergraph/kaliningraph/parsing/CFG.kt | 8 +++-- .../kaliningraph/parsing/SeqValiant.kt | 33 +++++++++---------- .../kaliningraph/repair/SyntaxRepair.kt | 2 ++ .../kaliningraph/parsing/JVMBarHillel.kt | 32 ++++++++---------- 7 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/CommonUtils.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/CommonUtils.kt index a47e83be..e5323d74 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/CommonUtils.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/CommonUtils.kt @@ -93,7 +93,8 @@ fun Array.toDoubleMatrix() = DoubleMatrix(size, this[0].size) { i, fun kroneckerDelta(i: Int, j: Int) = if (i == j) 1.0 else 0.0 -fun hashPair(i1: Int, i2: Int): Int = i1 * 31 + i2 +fun hash(vararg ints: Any): Int = ints.fold(0) { acc, i -> 31 * acc + i.hashCode() } +fun hash(vararg ints: Int): Int = ints.fold(0) { acc, i -> 31 * acc + i } const val DEFAULT_FEATURE_LEN = 20 fun String.vectorize(len: Int = DEFAULT_FEATURE_LEN) = diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/FSA.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/FSA.kt index a52a7e8f..238533e6 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/FSA.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/FSA.kt @@ -22,7 +22,7 @@ open class FSA(open val Q: TSA, open val init: Set<Σᐩ>, open val final: Set< val stateMap by lazy { states.toList().withIndex().associate { it.value to it.index } } val APSP: Map by lazy { graph.APSP.map { (k, v) -> - Pair(hashPair(stateMap[k.first.label]!!, stateMap[k.second.label]!!), v) + Pair(hash(stateMap[k.first.label]!!, stateMap[k.second.label]!!), v) }.toMap() } diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt index 9a8f826d..95d33cfe 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt @@ -1,7 +1,7 @@ package ai.hypergraph.kaliningraph.parsing import ai.hypergraph.kaliningraph.automata.* -import ai.hypergraph.kaliningraph.hashPair +import ai.hypergraph.kaliningraph.hash import ai.hypergraph.kaliningraph.repair.MAX_TOKENS import ai.hypergraph.kaliningraph.types.* import ai.hypergraph.kaliningraph.types.times @@ -270,7 +270,7 @@ private fun manhattanDistance(first: Pair, second: Pair): In // Range of the shortest path to the longest path, i.e., Manhattan distance private fun FSA.SPLP(a: STC, b: STC) = - (APSP[hashPair(a.π1, b.π1)] ?: Int.MAX_VALUE).. + (APSP[hash(a.π1, b.π1)] ?: Int.MAX_VALUE).. manhattanDistance(a.coords(), b.coords()) private fun IntRange.overlaps(other: IntRange) = diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt index 7a27a279..844533a2 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt @@ -52,6 +52,9 @@ val CFG.tmap: Map, Set<Σᐩ>> by cache { .mapValues { it.value.map { it.second }.toSet() } } +val CFG.ntLst by cache { (symbols + "ε").toList() } +val CFG.ntMap by cache { ntLst.mapIndexed { i, s -> s to i }.toMap() } + // Maps each nonterminal to the set of nonterminals that can generate it val CFG.vindex: Array by cache { Array(bindex.indexedNTs.size) { i -> @@ -256,8 +259,9 @@ class BiMap(cfg: CFG) { } // n.b., this only works if the CFG is acyclic, i.e., finite otherwise it will loop forever -fun CFG.toPTree(from: Σᐩ = START_SYMBOL): PTree = - PTree(from, bimap[from].map { toPTree(it[0]) to if (it.size == 1) PTree() else toPTree(it[1]) }) +fun CFG.toPTree(from: Σᐩ = START_SYMBOL, origCFG: CFG = this): PTree = + PTree(from, bimap[from].map { toPTree(it[0], origCFG) to if (it.size == 1) PTree() else toPTree(it[1], origCFG) }) + .also { it.ntIdx = (origCFG.ntMap[(if('~' in from) from.split('~')[1] else from)] ?: Int.MAX_VALUE) } /* Γ ⊢ ∀ v.[α→*]∈G ⇒ α→[β] "If all productions rooted at α diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt index ec31ed8c..d9b6f0a9 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt @@ -17,6 +17,7 @@ typealias PForest = Map // ℙ₃ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf()) { // val hash by lazy { root.hashCode() + if (branches.isEmpty()) 0 else branches.hashCode() } // override fun hashCode(): Int = hash + var ntIdx = -1 val branchRatio: Pair by lazy { if (branches.isEmpty()) 0.0 to 0.0 else (branches.size.toDouble() + branches.sumOf { (l, r) -> l.branchRatio.first + r.branchRatio.first }) to @@ -120,10 +121,10 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() while (i < 9 * totalTrees) yield(decodeString(i++ * stride + offset).first) } - fun sampleStrWithPCFG5(pcfgTable: Map): Sequence = + fun sampleStrWithPCFG5(pcfgTable: Map): Sequence = sequence { while (true) yield(samplePCFG5(pcfgTable)) } - fun sampleStrWithPCFG3(pcfgTable: Map<Π3A<Σᐩ>, Int>): Sequence = + fun sampleStrWithPCFG3(pcfgTable: Map): Sequence = sequence { while (true) yield(samplePCFG3(pcfgTable)) } // Samples instantaneously from the parse forest, but may return duplicates @@ -152,28 +153,32 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b" } - fun Σᐩ.name() = if ("~" in this) split("~")[1] else this - val triples : List<Π3A<Σᐩ>> by lazy { branches.map { root.name() to it.first.root.name() to it.second.root.name() } } + fun Σᐩ.name() = if ('~' in this) split('~')[1] else this + val triples : List<Π2A> by lazy { branches.map { it.first.ntIdx to it.second.ntIdx } } val rootName by lazy { root.name() } val isLeaf by lazy { branches.isEmpty() } - fun samplePCFG5(pcfgTable: Map, upUp: Σᐩ = "NIL", upLeft: Σᐩ = "NIL", upRight: Σᐩ = "NIL"): Σᐩ { + fun samplePCFG5(pcfgTable: Map, upUp: Int = 0, upLeft: Int = 0, upRight: Int = 0): Σᐩ { if (isLeaf) return epsStr - val probs = triples.map { (pcfgTable[StrQuintuple(upUp, upLeft, upRight, it.second, it.third)] ?: 1) + 1 } + val probs = triples.map { + val hash = hash(upUp, upLeft, upRight, it.first, it.second) + (pcfgTable[hash] ?: 1) +// .also { if(Random.nextInt(10000) == 3) if (it == 1) println("$hash Miss"); else println("$hash Hit") } + + 1 } val cdf = probs.runningReduce { acc, i -> acc + i } val rnd = Random.nextInt(probs.sum()) val childIdx = cdf.binarySearch { it.compareTo(rnd) }.let { if (it < 0) -it - 1 else it } val (l, r) = branches[childIdx] - val (lr, rr) = l.rootName to r.rootName - val (a, b) = l.samplePCFG5(pcfgTable, rootName, "$lr*", rr) to - r.samplePCFG5(pcfgTable, rootName, lr, "$rr*") + val (lr, rr) = l.ntIdx to r.ntIdx + val (a, b) = l.samplePCFG5(pcfgTable, ntIdx, 31 * lr, rr) to + r.samplePCFG5(pcfgTable, ntIdx, lr, 31 * rr) return if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b" } - fun samplePCFG3(pcfgTable: Map<Π3A<Σᐩ>, Int>): Σᐩ { + fun samplePCFG3(pcfgTable: Map): Σᐩ { if (branches.isEmpty()) return epsStr - val probs = triples.map { (pcfgTable[it] ?: 1) + 1 } + val probs = triples.map { (pcfgTable[hash(ntIdx, it.first, it.second)] ?: 1) + 1 } val cdf = probs.runningReduce { acc, i -> acc + i } val rnd = Random.nextInt(probs.sum()) val childIdx = cdf.binarySearch { it.compareTo(rnd) }.let { if (it < 0) -it - 1 else it } @@ -206,12 +211,6 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() // } } -data class StrQuintuple(val a: String, val b: String, val c: String, val d: String, val e: String) { - val hash = a.hashCode() + b.hashCode() + c.hashCode() + d.hashCode() + e.hashCode() - override fun hashCode(): Int = hash - override fun equals(other: Any?): Boolean = other is StrQuintuple && other.hash == hash -} - fun CFG.startPTree(tokens: List) = //measureTimedValue { initPForestMat(tokens).seekFixpoint().diagonals.last()[0][START_SYMBOL] //}.also { println("Took ${it.duration} to compute parse forest") }.value diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/repair/SyntaxRepair.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/repair/SyntaxRepair.kt index a232f549..41c89e45 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/repair/SyntaxRepair.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/repair/SyntaxRepair.kt @@ -10,6 +10,8 @@ import ai.hypergraph.kaliningraph.types.Π2A import kotlin.math.pow import kotlin.time.* +var CFG_THRESH = 20_000 +var MAX_UNIQUE = 20_000 // Maximum number of unique samples to generate var MAX_SAMPLE = 20 // Maximum number of repairs to sample var MAX_TOKENS = 80 // Maximum number of tokens per repair var MAX_RADIUS = 4 diff --git a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt index 68e13ba7..1125d0ee 100644 --- a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt +++ b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt @@ -75,31 +75,27 @@ fun CFG.sampleDirectlyWR( } } -fun CFG.sampleWithPCFG( - pcfgTable: Map, +fun PTree.sampleWithPCFG( + pcfgTable: Map, cores: Int = NUM_CORES, stoppingCriterion: () -> Boolean = { true } ): Stream = - toPTree().let { - (0.. - it.sampleStrWithPCFG5(pcfgTable) - .takeWhile { stoppingCriterion() } - .distinct() - .asStream() - } + (0.. + sampleStrWithPCFG5(pcfgTable) + .takeWhile { stoppingCriterion() } + .distinct() + .asStream() } -fun CFG.sampleDirectlyWOR( +fun PTree.sampleDirectlyWOR( cores: Int = NUM_CORES, stoppingCriterion: () -> Boolean = { true } ): Stream = - toPTree().let { - (0.. - it.sampleStrWithoutReplacement(cores, i) - .takeWhile { stoppingCriterion() } - .distinct() - .asStream() - } + (0.. + sampleStrWithoutReplacement(cores, i) + .takeWhile { stoppingCriterion() } + .distinct() + .asStream() } fun CFG.parallelEnumListWR( @@ -169,8 +165,6 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG { // For each production A → BC in P, for every p, q, r ∈ Q, // we have the production [p,A,r] → [p,B,q] [q,C,r] in P′. - val ntLst = nonterminals.toList() - val ntMap = ntLst.mapIndexed { i, s -> s to i }.toMap() val prods: Set = nonterminalProductions .map { (a, b) -> ntMap[a]!! to b.map { ntMap[it]!! } }.toSet() val lengthBoundsCache = lengthBounds.let { lb -> nonterminals.map { lb[it]!! } }