diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt index 1ffe9f8a..8ee3a03c 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt @@ -62,6 +62,8 @@ private infix fun CFG.intersectLevFSAP(fsa: FSA): CFG { ct.filter { fsa.obeys(it.π1, it.π2, it.π3, parikhMap) } .forEach { ct2[it.π1.π1][it.π3][it.π2.π1] = true } + val states = fsa.stateLst + val allsym = ntLst val binaryProds = prods.map { // if (i % 100 == 0) println("Finished ${i}/${nonterminalProductions.size} productions") @@ -74,8 +76,8 @@ private infix fun CFG.intersectLevFSAP(fsa: FSA): CFG { .filter { it.checkCompatibility(trip, ct2) } // .filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap) } .map { (a, b, c) -> - val (p, q, r) = fsa.stateLst[a] to fsa.stateLst[b] to fsa.stateLst[c] - "[$p~${ntLst[A]}~$r]".also { nts.add(it) } to listOf("[$p~${ntLst[B]}~$q]", "[$q~${ntLst[C]}~$r]") + val (p, q, r) = states[a] to states[b] to states[c] + "[$p~${allsym[A]}~$r]".also { nts.add(it) } to listOf("[$p~${allsym[B]}~$q]", "[$q~${allsym[C]}~$r]") }.toList() }.flatten().filterRHSInNTS() diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt index 96bef62d..70fdaba5 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt @@ -7,6 +7,7 @@ import ai.hypergraph.kaliningraph.tokenizeByWhitespace import ai.hypergraph.kaliningraph.types.* import kotlin.jvm.JvmName import kotlin.time.* +import kotlin.time.Duration.Companion.seconds typealias Σᐩ = String typealias Production = Π2<Σᐩ, List<Σᐩ>> diff --git a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt index b72dcebc..f1b0cc45 100644 --- a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt +++ b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt @@ -167,19 +167,24 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG { // we have the production [p,A,r] → [p,B,q] [q,C,r] in P′. val prods = nonterminalProductions .map { (a, b) -> ntMap[a]!! to b.map { ntMap[it]!! } }.toSet() -// val lengthBoundsCache = lengthBounds.let { lb -> nonterminals.map { lb[it] ?: 0..0 } } + val lengthBoundsCache = lengthBounds.let { lb -> nonterminals.map { lb[it] ?: 0..0 } } val validTriples = fsa.validTriples.map { arrayOf(it.π1.π1, it.π2.π1, it.π3.π1) } + val ctClock = TimeSource.Monotonic.markNow() val ct = (fsa.validPairs * nonterminals.indices.toSet()).toList() val ct2 = Array(fsa.states.size) { Array(nonterminals.size) { Array(fsa.states.size) { false } } } ct.parallelStream() - .filter { fsa.obeys(it.π1, it.π2, it.π3, parikhMap) } - .toList().also { + .filter { + lengthBoundsCache[it.π3].overlaps(fsa.SPLP(it.π1, it.π2)) && + fsa.obeys(it.π1, it.π2, it.π3, parikhMap) + }.toList().also { val fraction = it.size.toDouble() / (fsa.states.size * nonterminals.size * fsa.states.size) println("Fraction of valid triples: $fraction") }.forEach { ct2[it.π1.π1][it.π3][it.π2.π1] = true } + println("Precomputed LP constraints in ${ctClock.elapsedNow()}") - val elimCounter = AtomicInteger(0) + val states = fsa.stateLst + val allsym = ntLst val counter = AtomicInteger(0) val lpClock = TimeSource.Monotonic.markNow() val binaryProds = @@ -190,17 +195,18 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG { validTriples.stream() // CFG ∩ FSA - in general we are not allowed to do this, but it works // because we assume a Levenshtein FSA, which is monotone and acyclic. -// .filter { it.isCompatibleWith(A to B to C, fsa, lengthBoundsCache).also { if (!it) elimCounter.incrementAndGet() } } +// .filter { it.isCompatibleWith(A to B to C, fsa, lengthBoundsCache) } // .filter { it.checkCT(trip, ct1).also { if (!it) elimCounter.incrementAndGet() } } -// .filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap).also { if (!it) elimCounter.incrementAndGet() } } - .filter { it.checkCompatibility(trip, ct2).also { if (!it) elimCounter.incrementAndGet() } } +// .filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap) } + .filter { it.checkCompatibility(trip, ct2) } .map { (a, b, c) -> if (MAX_PRODS < counter.incrementAndGet()) throw Exception("∩-grammar has too many productions! (>$MAX_PRODS)") - val (p, q, r) = fsa.stateLst[a] to fsa.stateLst[b] to fsa.stateLst[c] - "[$p~${ntLst[A]}~$r]".also { nts.add(it) } to listOf("[$p~${ntLst[B]}~$q]", "[$q~${ntLst[C]}~$r]") + val (p, q, r) = states[a] to states[b] to states[c] + "[$p~${allsym[A]}~$r]".also { nts.add(it) } to listOf("[$p~${allsym[B]}~$q]", "[$q~${allsym[C]}~$r]") } }.toList() + val elimCounter = (validTriples.size * prods.size) - binaryProds.size println("Levenshtein-Parikh constraints eliminated $elimCounter productions in ${lpClock.elapsedNow()}") fun Σᐩ.isSyntheticNT() = @@ -213,6 +219,7 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG { return Stream.concat(binaryProds.stream(), (initFinal + transits + unitProds).stream()).parallel() .filter { (_, rhs) -> rhs.all { !it.isSyntheticNT() || it in nts } } .collect(Collectors.toSet()) + .also { println("Eliminated ${totalProds - it.size} extra productions before normalization") } .jvmPostProcess(clock) } @@ -249,15 +256,15 @@ fun CFG.jvmDropVestigialProductions(clock: TimeSource.Monotonic.ValueTimeMark): val nts: Set<Σᐩ> = asSequence().asStream().parallel().map { it.first }.collect(Collectors.toSet()) val rw = asSequence().asStream().parallel() .filter { prod -> - if (counter.incrementAndGet() % 10 == 0 && BH_TIMEOUT < clock.elapsedNow()) throw Exception("Timeout!") + if (counter.incrementAndGet() % 10 == 0 && BH_TIMEOUT < clock.elapsedNow()) throw Exception("Timeout! ${clock.elapsedNow()}") // Only keep productions whose RHS symbols are not synthetic or are in the set of NTs prod.RHS.all { !(it.first() == '[' && 1 < it.length) || it in nts } } .collect(Collectors.toSet()) -// .also { println("Removed ${size - it.size} invalid productions in ${clock.elapsedNow() - start}") } + .also { println("Removed ${size - it.size} invalid productions in ${clock.elapsedNow() - start}") } .freeze().jvmRemoveUselessSymbols() -// println("Removed ${size - rw.size} vestigial productions, resulting in ${rw.size} productions.") + println("Removed ${size - rw.size} vestigial productions, resulting in ${rw.size} productions.") return if (rw.size == size) rw else rw.jvmDropVestigialProductions(clock) } @@ -293,6 +300,10 @@ private fun CFG.jvmReachSym(from: Σᐩ = START_SYMBOL): Set<Σᐩ> { this@jvmReachSym.asSequence().asStream().parallel() .forEach { (l, r) -> getOrPut(l) { ConcurrentSkipListSet() }.addAll(r) } } +// this@jvmReachSym.asSequence().asStream().parallel() +// .flatMap { (l, r) -> r.stream().map { l to it } } +// // List of second elements grouped by first element +// .collect(Collectors.groupingByConcurrent({ it.first }, Collectors.mapping({ it.second }, Collectors.toSet()))) do { val t = nextReachable.first() @@ -320,6 +331,10 @@ private fun CFG.jvmGenSym( this@jvmGenSym.asSequence().asStream().parallel() .forEach { (l, r) -> r.forEach { getOrPut(it) { ConcurrentSkipListSet() }.add(l) } } } +// this@jvmGenSym.asSequence().asStream().parallel() +// .flatMap { (l, r) -> r.asSequence().asStream().map { it to l } } +// // List of second elements grouped by first element +// .collect(Collectors.groupingByConcurrent({ it.first }, Collectors.mapping({ it.second }, Collectors.toList()))) do { val t = nextGenerating.first()