From 89591fa3a4762bfa0916c067841dad727c34b322 Mon Sep 17 00:00:00 2001 From: breandan Date: Mon, 29 Apr 2024 19:52:28 -0400 Subject: [PATCH] speed up slow tests --- .../kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt | 5 +++-- .../ai/hypergraph/kaliningraph/parsing/BarHillelTest.kt | 3 ++- .../kotlin/ai/hypergraph/kaliningraph/parsing/Grammars.kt | 1 + .../kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt | 3 ++- .../ai/hypergraph/kaliningraph/repair/ProbabilisticLBH.kt | 3 ++- .../kotlin/ai/hypergraph/kaliningraph/sat/SATValiantTest.kt | 4 ++-- 6 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt index 8ee3a03c..2f7cf034 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt @@ -14,7 +14,8 @@ import kotlin.time.TimeSource infix fun FSA.intersectLevFSA(cfg: CFG) = cfg.intersectLevFSA(this) -infix fun CFG.intersectLevFSA(fsa: FSA): CFG = intersectLevFSAP(fsa) +fun CFG.intersectLevFSA(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CFG = + intersectLevFSAP(fsa, parikhMap) // subgrammar(fsa.alphabet) // .also { it.forEach { println("${it.LHS} -> ${it.RHS.joinToString(" ")}") } } @@ -26,7 +27,7 @@ fun CFG.barHillelRepair(prompt: List<Σᐩ>, distance: Int) = // http://www.cs.umd.edu/~gasarch/BLOGPAPERS/cfg.pdf#page=2 // https://browse.arxiv.org/pdf/2209.06809.pdf#page=5 -private infix fun CFG.intersectLevFSAP(fsa: FSA): CFG { +private fun CFG.intersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CFG { var clock = TimeSource.Monotonic.markNow() val nts = mutableSetOf("START") fun Σᐩ.isSyntheticNT() = diff --git a/src/commonTest/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillelTest.kt b/src/commonTest/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillelTest.kt index 7565f3d7..37d8c9d1 100644 --- a/src/commonTest/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillelTest.kt +++ b/src/commonTest/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillelTest.kt @@ -1,6 +1,7 @@ package ai.hypergraph.kaliningraph.parsing import Grammars +import Grammars.shortS2PParikhMap import ai.hypergraph.kaliningraph.* import ai.hypergraph.kaliningraph.automata.* import kotlin.test.* @@ -319,7 +320,7 @@ class BarHillelTest { println(levBall.states.size) // println(levBall.toDot()) // throw Exception("") - val intGram = gram.intersectLevFSA(levBall) + val intGram = gram.intersectLevFSA(levBall, shortS2PParikhMap) val clock = TimeSource.Monotonic.markNow() diff --git a/src/commonTest/kotlin/ai/hypergraph/kaliningraph/parsing/Grammars.kt b/src/commonTest/kotlin/ai/hypergraph/kaliningraph/parsing/Grammars.kt index f5f8e769..eeb00031 100644 --- a/src/commonTest/kotlin/ai/hypergraph/kaliningraph/parsing/Grammars.kt +++ b/src/commonTest/kotlin/ai/hypergraph/kaliningraph/parsing/Grammars.kt @@ -114,6 +114,7 @@ object Grammars { T -> [ Q ] """.trimIndent().parseCFG().noNonterminalStubs + val shortS2PParikhMap by lazy { ParikhMap(seq2parsePythonCFG, 20) } val seq2parsePythonCFG: CFG = """ START -> Stmts_Or_Newlines Stmts_Or_Newlines -> Stmt_Or_Newline | Stmt_Or_Newline Stmts_Or_Newlines diff --git a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt index 4244408d..05c74bc4 100644 --- a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt +++ b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt @@ -1,6 +1,7 @@ package ai.hypergraph.kaliningraph.automata import Grammars +import Grammars.shortS2PParikhMap import ai.hypergraph.kaliningraph.parsing.* import net.jhoogland.jautomata.* import net.jhoogland.jautomata.Automaton @@ -38,7 +39,7 @@ class WFSATest { fun testLBHRepair() { val toRepair = "NAME : NEWLINE NAME = STRING NEWLINE NAME = NAME . NAME ( STRING ) NEWLINE" val radius = 1 - val pt = Grammars.seq2parsePythonCFG.makeLevPTree(toRepair, radius) + val pt = Grammars.seq2parsePythonCFG.makeLevPTree(toRepair, radius, shortS2PParikhMap) val repairs = pt.sampleStrWithoutReplacement().distinct().take(100).toSet() println("Found ${repairs.size} repairs by enumerating PTree") measureTimedValue { diff --git a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/repair/ProbabilisticLBH.kt b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/repair/ProbabilisticLBH.kt index 68893113..03236f86 100644 --- a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/repair/ProbabilisticLBH.kt +++ b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/repair/ProbabilisticLBH.kt @@ -1,6 +1,7 @@ package ai.hypergraph.kaliningraph.repair import Grammars +import Grammars.shortS2PParikhMap import ai.hypergraph.kaliningraph.parsing.* import ai.hypergraph.kaliningraph.tokenizeByWhitespace import ai.hypergraph.markovian.* @@ -115,7 +116,7 @@ class ProbabilisticLBH { val clock = TimeSource.Monotonic.markNow() val levBall = makeLevFSA(source.tokenizeByWhitespace(), levDist) - val intGram = gram.jvmIntersectLevFSA(levBall) + val intGram = gram.jvmIntersectLevFSA(levBall, shortS2PParikhMap) println("Finished ${intGram.size}-prod ∩-grammar in ${clock.elapsedNow()}") val lbhSet = intGram.toPTree().sampleDirectlyWOR() .takeWhile { clock.elapsedNow().inWholeSeconds < 30 }.collect(Collectors.toSet()) diff --git a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/sat/SATValiantTest.kt b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/sat/SATValiantTest.kt index cf2a1e74..e29e148e 100644 --- a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/sat/SATValiantTest.kt +++ b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/sat/SATValiantTest.kt @@ -598,10 +598,10 @@ class SATValiantTest { """.trimIndent().parseCFG() /* - ./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.sat.SATValiantTest.testLevensheteinIntersection" + ./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.sat.SATValiantTest.testLevenshteinIntersection" */ @Test - fun testLevensheteinIntersection() { + fun testLevenshteinIntersection() { val cfg = sumCFG.noNonterminalStubs val strWithParseErr = "1 + 2 + + +".tokenizeByWhitespace() val dist = 2