Skip to content

Commit

Permalink
use global Python grammar
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Oct 9, 2024
1 parent b1328cf commit 39d3f84
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class ParikhMap(val cfg: CFG, val size: Int, reconstruct: Boolean = true) {
val pm = deserializePM(str.substringBefore("\n\n====\n\n"))
val lb = str.substringAfter("\n\n====\n\n").lines().map { it.split(" ") }
.associate { it.first().toInt() to it.drop(1).toSet() }
println("Deserialized Parikh Map with ${pm.size} lengths and ${lb.size} bounds")
return ParikhMap(cfg, pm.size, false).apply {
parikhMap.putAll(pm)
lengthBounds.putAll(lb)
Expand Down Expand Up @@ -120,7 +121,7 @@ class ParikhMap(val cfg: CFG, val size: Int, reconstruct: Boolean = true) {
val template = List(size) { "_" }
cfg.initPForestMat(template).seekFixpoint().diagonals
.forEachIndexed { i, it ->
println("Computing length $i")
println("Computing PM length $i/$size with ${it.size} keys")
lengthBounds[i + 1] = it.first().keys
parikhMap[i + 1] = it.first().mapValues { it.value.parikhBounds }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Grammars
import Grammars.shortS2PParikhMap
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.automata.*
import ai.hypergraph.kaliningraph.repair.vanillaS2PCFG
import kotlin.test.*
import kotlin.time.*

Expand Down Expand Up @@ -247,7 +248,7 @@ class BarHillelTest {
*/
@Test
fun testPythonBarHillel() {
val gram = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val gram = vanillaS2PCFG.noEpsilonOrNonterminalStubs
val origStr = "NAME = ( NAME . NAME ( NAME NEWLINE"
val toRepair = origStr.tokenizeByWhitespace()
val maxLevDist = 2
Expand Down Expand Up @@ -282,7 +283,7 @@ class BarHillelTest {
// Found 6987 minimal solutions using Levenshtein/Bar-Hillel
// Enumerative solver took 360184ms

val s2pg = Grammars.seq2parsePythonCFG
val s2pg = vanillaS2PCFG
val prbSet = s2pg.fasterRepairSeq(toRepair, 1, 3)
.takeWhile { clock.elapsedNow().inWholeSeconds < 90 }.distinct()
.mapIndexedNotNull { i, it ->
Expand Down Expand Up @@ -313,7 +314,7 @@ class BarHillelTest {
*/
@Test
fun semiRealisticTest() {
val gram = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val gram = vanillaS2PCFG.noEpsilonOrNonterminalStubs
val origStr = "NAME = NAME . NAME ( [ NUMBER , NUMBER , NUMBER ] NEWLINE"
val toRepair = origStr.tokenizeByWhitespace()
val levDist = 2
Expand All @@ -336,7 +337,7 @@ class BarHillelTest {
.also { println("Found ${it.size} minimal solutions using " +
"Levenshtein/Bar-Hillel in ${clock.elapsedNow()}") }

val s2pg = Grammars.seq2parsePythonCFG
val s2pg = vanillaS2PCFG
val prbSet = s2pg.fasterRepairSeq(toRepair, 1, 2)
.takeWhile { clock.elapsedNow().inWholeSeconds < 90 }.distinct()
.mapIndexedNotNull { i, it ->
Expand Down Expand Up @@ -365,14 +366,14 @@ class BarHillelTest {
*/
@Test
fun levenshteinBlanketTest() {
val gram = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val gram = vanillaS2PCFG.noEpsilonOrNonterminalStubs
val origStr= "NAME = NAME . NAME ( [ NUMBER , NUMBER , NUMBER ] NEWLINE"
val toRepair = origStr.tokenizeByWhitespace()
val levDist = 2
val levBall = makeLevFSA(toRepair, levDist)
val clock = TimeSource.Monotonic.markNow()

val s2pg = Grammars.seq2parsePythonCFG
val s2pg = vanillaS2PCFG
s2pg.fasterRepairSeq(toRepair, 1, 2).distinct()
.mapIndexedNotNull { i, it ->
val levDistance = levenshtein(origStr, it)
Expand All @@ -395,7 +396,7 @@ class BarHillelTest {
@Test
fun testHammingBallRepair() {
val timeout = 30
val gram = Grammars.seq2parsePythonCFG
val gram = vanillaS2PCFG
val prompt= "NAME = ( NAME . NAME ( NAME NEWLINE".tokenizeByWhitespace()
val clock = TimeSource.Monotonic.markNow()
val lbhSet = gram.repairSeq(prompt).onEach { println(it) }
Expand All @@ -408,7 +409,7 @@ class BarHillelTest {
*/
@Test
fun testAllBlankSampler() {
val gram = Grammars.seq2parsePythonCFG
val gram = vanillaS2PCFG
val n = 10
gram.startPTree(List(n) { "_" })?.also {
it.sampleWRGD().map { it.removeEpsilon() }.distinct()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.repair.vanillaS2PCFG
import kotlin.test.*

/*
Expand All @@ -11,7 +12,7 @@ class ParikhTest {
*/
@Test
fun testParikhBounds() {
val cfg = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val cfg = vanillaS2PCFG
val parikhMap = ParikhMap(cfg, 10)
(1..10).forEach { i ->
cfg.enumSeq(List(i) { "_" }).take(10).forEach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Grammars.seq2parsePythonVanillaCFG
import Grammars.tinyC
import Grammars.toyArith
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.repair.vanillaS2PCFG
import org.kosat.round
import kotlin.random.Random
import kotlin.test.*
Expand All @@ -21,19 +22,19 @@ class SeqValiantTest {
@Test
fun testSeqValiant() {
var clock = TimeSource.Monotonic.markNow()
val detSols = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val detSols = vanillaS2PCFG
.enumSeq(List(20) {"_"})
.take(10_000).sortedBy { it.length }.toList()

detSols.forEach { assertTrue("\"$it\" was invalid!") { it in Grammars.seq2parsePythonCFG.language } }
detSols.forEach { assertTrue("\"$it\" was invalid!") { it in vanillaS2PCFG.language } }

var elapsed = clock.elapsedNow().inWholeMilliseconds
println("Found ${detSols.size} determinstic solutions in ${elapsed}ms or ~${detSols.size / (elapsed/1000.0)}/s, all were valid!")

clock = TimeSource.Monotonic.markNow()
val randSols = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val randSols = vanillaS2PCFG
.sampleSeq(List(20) { "_" }).take(10_000).toList().distinct()
.onEach { assertTrue("\"$it\" was invalid!") { it in Grammars.seq2parsePythonCFG.language } }
.onEach { assertTrue("\"$it\" was invalid!") { it in vanillaS2PCFG.language } }

// 10k in ~22094ms
elapsed = clock.elapsedNow().inWholeMilliseconds
Expand Down Expand Up @@ -140,7 +141,7 @@ class SeqValiantTest {
val template = List(refLst.size + 3) { "_" }
println("Solving: $template")
measureTime {
Grammars.seq2parsePythonCFG.enumSeq(template)
vanillaS2PCFG.enumSeq(template)
.map { it to levenshtein(it, refStr) }
.filter { it.second < 4 }.distinct().take(100)
.sortedWith(compareBy({ it.second }, { it.first.length }))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ai.hypergraph.kaliningraph.parsing
import Grammars
import Grammars.toyArith
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.repair.vanillaS2PCFG
import ai.hypergraph.kaliningraph.tensor.seekFixpoint
import ai.hypergraph.kaliningraph.types.π2
import kotlinx.datetime.Clock
Expand Down Expand Up @@ -345,7 +346,7 @@ class SetValiantTest {
*/
@Test
fun testUnitParse() {
assertNotNull(Grammars.seq2parsePythonCFG.parse("NEWLINE"))
assertNotNull(vanillaS2PCFG.parse("NEWLINE"))
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ import kotlin.time.Duration.Companion.seconds
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.repair.ProbabilisticLBH"
*/
class ProbabilisticLBH {
init { LangCache.prepopPythonLangCache() }
val pythonTestCases =
invalidPythonStatements.lines().zip(validPythonStatements.lines())
// This ensures the LBH grammar is nonempty, otherwise extragrammatical symbols produce an error
// .map { it.first.tokenizeByWhitespace().map { if (it in Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs.terminals) it else "." }.joinToString(" ") to it.second }
.filter { it.first.tokenizeByWhitespace().all { it in Grammars.seq2parsePythonCFG.terminals } }
// .map { it.first.tokenizeByWhitespace().map { if (it in vanillaS2PCFG.noEpsilonOrNonterminalStubs.terminals) it else "." }.joinToString(" ") to it.second }
.filter { it.first.tokenizeByWhitespace().all { it in vanillaS2PCFG.terminals } }
.shuffled(Random(seed = 1)).filter { (a, b) ->
("$a NEWLINE" !in Grammars.seq2parsePythonCFG.language).also { if (!it) println("Failed invalid") }
&& ("$b NEWLINE" in Grammars.seq2parsePythonCFG.language).also { if (!it) println("Failed valid") }
("$a NEWLINE" !in vanillaS2PCFG.language).also { if (!it) println("Failed invalid") }
&& ("$b NEWLINE" in vanillaS2PCFG.language).also { if (!it) println("Failed valid") }
&& (levenshtein(a, b).also { if (it !in 1..3) println("Failed distance: $it") } in 1..3)
}.distinct().filter { it.first.tokenizeByWhitespace().size < 23 }
/*
Expand All @@ -35,7 +36,7 @@ class ProbabilisticLBH {
@Test
fun testSubgrammarEquivalence() {
val terminalImage = setOf<String>() + "NEWLINE" + validPythonStatements.tokenizeByWhitespace().toSet()
val s2pg = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val s2pg = vanillaS2PCFG
val subgrammar = s2pg.subgrammar(terminalImage)

(validPythonStatements + invalidPythonStatements).lines()
Expand All @@ -48,7 +49,7 @@ class ProbabilisticLBH {
@Test
fun testSubgrammar() {
val terminalImage = setOf<String>() + "NEWLINE" + validPythonStatements.tokenizeByWhitespace().toSet()
val s2pg = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val s2pg = vanillaS2PCFG
val subgrammar = s2pg.subgrammar(terminalImage)
println("Original size: ${s2pg.size}")
println("Subgrammar size: ${subgrammar.size}")
Expand Down Expand Up @@ -79,24 +80,24 @@ class ProbabilisticLBH {
//// subgrammar.parseInvalidWithMaximalFragments(pp).forEach { println(it.prettyPrint() + "\n\n") }
// println(s2pg.parse(pp)!!.prettyPrint())
// println(lastGood.first { it.isNotEmpty() }.first().prettyPrint())
assertTrue(pp in s2pg.language, "$it\nnot in Grammars.seq2parsePythonCFG!")
assertTrue(pp in s2pg.language, "$it\nnot in vanillaS2PCFG!")
assertTrue(pp in subgrammar.language, "$it\nnot in subgrammar!")
}
subgrammar.sampleSeq(List(20) {"_"}).take(100).forEach { pp ->
assertTrue(pp in Grammars.seq2parsePythonCFG.language, "$pp\nnot in Grammars.seq2parsePythonCFG!")
assertTrue(pp in vanillaS2PCFG.language, "$pp\nnot in vanillaS2PCFG!")
assertTrue(pp in subgrammar.language, "$pp\nnot in subgrammar!")
}
}

val topTerms by lazy {
contextCSV.allProbs.entries
.filter { it.key.type != EditType.DEL }
.groupingBy { Grammars.seq2parsePythonCFG.getS2PNT(it.key.newMid) }
.groupingBy { vanillaS2PCFG.getS2PNT(it.key.newMid) }
.aggregate { _, acc: Int?, it, _ -> (acc ?: 0) + it.value }
.map { (k, v) -> k to v }
.sortedBy { -it.second }
// .onEach { println("${it.first}≡${Grammars.seq2parsePythonCFG.bimap[it.first]}: ${it.second}") }
.mapNotNull { Grammars.seq2parsePythonCFG.bimap[it.first].firstOrNull() }
// .onEach { println("${it.first}≡${vanillaS2PCFG.bimap[it.first]}: ${it.second}") }
.mapNotNull { vanillaS2PCFG.bimap[it.first].firstOrNull() }
.take(20)
.toSet()
}
Expand All @@ -109,7 +110,7 @@ class ProbabilisticLBH {
fun threeEditRepair() {
val source = "NAME = { STRING = NUMBER , STRING = NUMBER , STRING = NUMBER } NEWLINE"
val repair = "NAME = { STRING : NUMBER , STRING : NUMBER , STRING : NUMBER } NEWLINE"
val gram = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val gram = vanillaS2PCFG
// MAX_TOKENS = source.tokenizeByWhitespace().size + 5
// MAX_RADIUS = 3
val levDist = 3
Expand All @@ -134,7 +135,7 @@ class ProbabilisticLBH {
invalidPythonStatements.lines().shuffled().take(10).forEach {
val toRepair = "$it NEWLINE".tokenizeByWhitespace()
println("Repairing: ${toRepair.joinToString(" ")}\nRepairs:\n")
Grammars.seq2parsePythonCFG.fasterRepairSeq(toRepair)
vanillaS2PCFG.fasterRepairSeq(toRepair)
.filter { it.isNotEmpty() }.distinct().take(10).forEach {
println(levenshteinAlign(toRepair, it.tokenizeByWhitespace()).paintANSIColors())
}
Expand All @@ -146,7 +147,7 @@ class ProbabilisticLBH {
*/
// @Test
fun testCompleteness() {
val s2pg = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val s2pg = vanillaS2PCFG
val TIMEOUT_MINS = 2
val totalTrials = 10
var currentTrials = 0
Expand Down Expand Up @@ -237,7 +238,7 @@ class ProbabilisticLBH {
val sampleTimeByLevDist = mutableMapOf(1 to 0.0, 2 to 0.0, 3 to 0.0)
val allTimeByLevDist = mutableMapOf(1 to 0.0, 2 to 0.0, 3 to 0.0)
val samplesBeforeMatchByLevDist = mutableMapOf(1 to 0.0, 2 to 0.0, 3 to 0.0)
val s2pg = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val s2pg = vanillaS2PCFG

invalidPythonStatements.lines().zip(validPythonStatements.lines())
// .filter { (invalid, valid) -> 3 == levenshtein(invalid, valid) }.take(50)
Expand Down Expand Up @@ -311,7 +312,7 @@ class ProbabilisticLBH {
val clock = TimeSource.Monotonic.markNow()

val levDist = 2
val s2pg = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val s2pg = vanillaS2PCFG
val levBall = makeLevFSA(toRepair, levDist)
val intGram = s2pg.jvmIntersectLevFSA(levBall)
val template = List(toRepair.size + levDist) { "_" }
Expand All @@ -334,7 +335,7 @@ class ProbabilisticLBH {
validPythonStatements
.lines()
.shuffled()
.flatMap { (0..10).map { _ -> Grammars.seq2parsePythonCFG to it.maskRandomIndices(holes) } }
.flatMap { (0..10).map { _ -> vanillaS2PCFG to it.maskRandomIndices(holes) } }
.filter { (a, b) ->
val clock = TimeSource.Monotonic.markNow()
a.sampleSWOR(b).takeWhile { clock.elapsedNow() < 2.seconds }.distinct().toList().size > 1
Expand Down

0 comments on commit 39d3f84

Please sign in to comment.