Commit

implement GRE propagation
breandan committed Dec 3, 2024
1 parent 03f4b8f commit 5e63a0d
Showing 8 changed files with 100 additions and 4 deletions.
4 changes: 2 additions & 2 deletions latex/tacm2024/tacm_poster.tex
@@ -175,7 +175,7 @@

\mysection{Motivation}
\null\hspace*{3cm}\begin{minipage}[c]{0.85\columnwidth}
- Suppose we want to force an autoregressive LLM to generate syntactically valid next tokens $P(x_n \mid x_1, \ldots, x_{n-1})$, under certain resource constraints. Here is a concrete example: ``Generate an arithmetic expression with two or more variables in ten or fewer tokens.''. If we sample the partial trajectory,
+ Suppose we want to force an autoregressive LLM to generate syntactically valid next tokens $P(x_n \mid x_1, \ldots, x_{n-1})$, under certain resource constraints. Here is a concrete example: ``Generate an arithmetic expression with two or more variables in ten or fewer tokens.'' If we sample the partial trajectory,
\begin{center}\texttt{( x + ( y * }\underline{\texttt{(}}\end{center}\\
then we will spend quite a long time rejecting invalid completions, because this trajectory has passed the point of no return. Even though \texttt{(} is a locally valid continuation, we need to avoid this scenario, because we would like a linear sampling delay and to guarantee this, we must avoid backtracking.
\end{minipage}
@@ -398,7 +398,7 @@
Consider a time series, $A$, whose points which are not too close nor far apart, and $n \leq \sum_{i=1}^{|A|} \mathbf{1}[A_i = \bs]$. We want to sample the typical set using an LLM.\vspace{0.5cm}
\begin{itemize}[leftmargin=2cm]
\item The words are bitvectors of some length, $T$, i.e., $A = \{\ws, \bs\}^T$
- \item Consecutive $\bs$ separated by $\ws^{[a,b]}$, i.e., $B = \ws^*(\bs\ws^{[a, b]})^{[n,\infty)}\{\bs,\epsilon\}\ws^*$
+ \item Consecutive $\bs$ separated by $\ws^{[a,b]}$, i.e., $B = \ws^*(\bs\ws^{[a, b]})^{[n,\infty)}\{\bs,\varepsilon\}\ws^*$
\end{itemize}\vspace{0.5cm}

The DPP language is regular. Let $C$ be an FSA such that $\mathcal{L}(C) = \mathcal{L}(A) \cap \mathcal{L}(B)$. For example, here is the minimal automaton for $T=13, a=3, b=5, n=2$.
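
As a rough illustration of the constraint above, the expression for $B$ can be transliterated into a Kotlin `Regex` and combined with the length check for $A$. This is only a sketch under stated assumptions: `\ws` and `\bs` are taken to be two distinct symbols, encoded here as `0` and `1`, and the helper name `inDppLanguage` is hypothetical and does not appear in the poster or the repository.

```kotlin
// Hypothetical helper (not part of the poster or of kaliningraph).
// Assumption: '0' encodes \ws and '1' encodes \bs.
fun inDppLanguage(word: String, t: Int, a: Int, b: Int, n: Int): Boolean {
  // A: every word of length exactly t over the two-symbol alphabet
  val inA = word.length == t && word.all { it == '0' || it == '1' }
  // B: 0* (1 0^{[a,b]})^{[n,∞)} {1, ε} 0*, transliterated symbol by symbol
  val patternB = Regex("0*(?:10{$a,$b}){$n,}1?0*")
  return inA && patternB.matches(word)
}

fun main() {
  // Parameters of the example automaton: T=13, a=3, b=5, n=2
  println(inDppLanguage("1000100010000", t = 13, a = 3, b = 5, n = 2)) // true
  println(inDppLanguage("1010000000000", t = 13, a = 3, b = 5, n = 2)) // false
}
```

The second word is rejected because its first two `1` symbols are separated by a single `0`, which is shorter than the minimum gap `a = 3`.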
Binary file modified latex/thesis/Thesis.pdf
Binary file not shown.
2 changes: 1 addition & 1 deletion latex/thesis/content/Ch2_Formal_Language_Theory.tex
@@ -88,7 +88,7 @@ \chapter{\rm\bfseries Formal Language Theory}
\caption{TODO: depict product construction for finite automata here.}
\end{figure}

- The goal of this thesis is to speed up the product construction by leveraging (1) parameterized complexity (2) pruning and (3) parallelization to speed up the wallclock runtime of the product construction and generalize it to CFG-REG intersections. We show it is possible to decide intersection non-emptiness in realtime for Levenshtein automata and build a tool to demonstrate it on real-world programming languages and grammars.
+ The goal of this thesis is to speed up the product construction by leveraging (1) parameterized complexity (2) pruning and (3) parallelization to speed up the wallclock runtime of the product construction and generalize it to CFG-REG intersections. We show it is possible to decide INE in realtime for Levenshtein automata and build a tool to demonstrate it on real-world programming languages and grammars.

Finally, we show a probabilistic extension of the REG-CFL product construction, which can be used to decode the top-K most probable words in the intersection of two languages. This is useful for applications in natural language processing, where we might want to find the most natural word that satisfies multiple constraints, such as being a valid repair with fewer than $k$ edits whose probability is maximized.

2 changes: 1 addition & 1 deletion latex/thesis/content/Terminology.tex
@@ -8,7 +8,7 @@ \chapter*{\rm\bfseries Terminology}
\item \textbf{Deterministic}: A property of a system that, given the same input, will always produce the same output.
\item \textbf{Grammar}: A set of rules that define the syntax of a language.
\item \textbf{Language}: A set of words generated by a grammar. For the purposes of this thesis, the language can be finite or infinite.
- \item \textbf{Word}: A member of a language, consisting of a sequence of terminals. For the purposes of this thesis, words are always finite.
+ \item \textbf{Word}: A member of a language, consisting of a sequence of terminals. For the purposes of this thesis, a word is always finite.
\item \textbf{Terminal}: A single token from an alphabet. For the purposes of this thesis, the alphabet is always finite.
\item \textbf{Intersection}: The set of elements common to two or more sets.
\item \textbf{Probabilistic}: A property of a system that, given the same input, may produce different outputs.
70 changes: 70 additions & 0 deletions src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/GRE.kt
@@ -0,0 +1,70 @@
package ai.hypergraph.kaliningraph.automata

import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.kaliningraph.tensor.UTMatrix
import ai.hypergraph.kaliningraph.types.*

// Generalized regular expression: https://planetmath.org/generalizedregularexpression
sealed class GRE(vararg val args: GRE) {
  companion object { operator fun invoke(s: Σᐩ) = ONE(s) }

  class EPS: GRE()
  class ONE(val s: Σᐩ): GRE()
  class SET(val s: Set<Σᐩ>): GRE()
  class NEG(val g: GRE): GRE(g)
  class UNI(val l: GRE, val r: GRE): GRE(l, r)
  class CAT(val l: GRE, val r: GRE): GRE(l, r)
  class INT(val l: GRE, val r: GRE): GRE(l, r)

  infix fun and(a: GRE): GRE = INT(this, a)
  operator fun plus(g: GRE): GRE = UNI(this, g)
  operator fun times(g: GRE): GRE = CAT(this, g)
  operator fun not(): GRE = NEG(this)

  override fun toString(): String = when (this) {
    is ONE -> s
    is SET -> "( ${s.joinToString(" ")} )"
    is NEG -> "! ( $g )"
    is UNI -> "( $l ∪ $r )"
    is CAT -> "$l $r"
    is INT -> "$l ∩ $r"
    is EPS -> "ε"
  }
}
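
To make the intended meaning of the constructors concrete, here is a minimal, self-contained sketch of a membership test for generalized regular expressions. It is an assumption-laden illustration: the `Rx` class and the `matches` function are hypothetical stand-ins that mirror the shape of `GRE` (minus `SET`), and the naive recursive matcher is not how the library evaluates these expressions.

```kotlin
// Hypothetical stand-in for GRE; names and semantics are assumptions for illustration only.
sealed class Rx {
  object Eps : Rx()                             // the empty word
  data class One(val s: String) : Rx()          // a single token
  data class Uni(val l: Rx, val r: Rx) : Rx()   // union
  data class Cat(val l: Rx, val r: Rx) : Rx()   // concatenation
  data class And(val l: Rx, val r: Rx) : Rx()   // intersection
  data class Neg(val g: Rx) : Rx()              // complement

  operator fun plus(g: Rx): Rx = Uni(this, g)
  operator fun times(g: Rx): Rx = Cat(this, g)
  infix fun and(g: Rx): Rx = And(this, g)
  operator fun not(): Rx = Neg(this)
}

// Naive membership test by structural recursion (exponential in the worst case).
fun Rx.matches(w: List<String>): Boolean = when (this) {
  is Rx.Eps -> w.isEmpty()
  is Rx.One -> w == listOf(s)
  is Rx.Uni -> l.matches(w) || r.matches(w)
  is Rx.And -> l.matches(w) && r.matches(w)
  is Rx.Neg -> !g.matches(w)
  // Try every split point of the word between the left and right operands.
  is Rx.Cat -> (0..w.size).any { i -> l.matches(w.take(i)) && r.matches(w.drop(i)) }
}

fun main() {
  val ab = Rx.One("A") + Rx.One("B")
  val nabab = !(ab * ab)                       // complement of (A|B)(A|B)
  println(nabab.matches(listOf("A", "B")))     // false
  println(nabab.matches(listOf("A")))          // true
}
```

Complement and intersection are what distinguish a generalized regular expression from an ordinary one; both still denote regular languages, so the matcher only needs Boolean negation and conjunction.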


fun CFG.initGREListMat(tokens: List<String>): UTMatrix<List<GRE?>> =
  UTMatrix(
    ts = tokens.map { token ->
      val ptreeList = MutableList<GRE?>(nonterminals.size) { null }
      (if (token != HOLE_MARKER) bimap[listOf(token)] else unitNonterminals)
        .associateWith { nt ->
          if (token != HOLE_MARKER) GRE.ONE(token)
          else bimap.UNITS[nt]?.let { GRE.SET(it) }
        }.forEach { (k, v) -> ptreeList[bindex[k]] = v }
      ptreeList
    }.toTypedArray(),
    algebra = greAlgebra
  )

val CFG.greAlgebra: Ring<List<GRE?>> by cache {
  vindex.let {
    Ring.of(
      nil = List(nonterminals.size) { null },
      plus = { x, y -> greUnion(x, y) },
      times = { x, y -> greJoin(x, y) }
    )
  }
}

fun greUnion(l: List<GRE?>, r: List<GRE?>) =
  l.zip(r) { l, r -> if (l == null) r else if (r == null) l else l + r }

fun CFG.greJoin(left: List<GRE?>, right: List<GRE?>): List<GRE?> = vindex2.map {
  val t = it.map { (B, C) -> if (left[B] != null && right[C] != null) left[B]!! * right[C]!! else null }
  if (t.isEmpty()) null else t.reduce { acc, int -> if (acc == null) int else if (int == null) acc else acc + int }
}

fun CFG.startGRE(tokens: List<String>): GRE? =
  initGREListMat(tokens).seekFixpoint().diagonals.last()[0][bindex[START_SYMBOL]]
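
The propagation above can be pictured as a CYK-style chart in which each cell stores one optional expression per nonterminal. The following self-contained sketch mimics that idea on a toy grammar using plain strings instead of `GRE` values and an explicit chart instead of `UTMatrix.seekFixpoint()`. Everything in it (`cnfNts`, `cnfBinary`, `cnfUnit`, `leaf`, `union`, `join`, `startExpr`, and the `_` hole marker) is a hypothetical stand-in, not the library's API.

```kotlin
// Toy CNF grammar (assumption): S -> A B, A -> a, B -> b | c
typealias Cell = List<String?>   // one optional expression per nonterminal

val cnfNts = listOf("S", "A", "B")
val cnfBinary = mapOf("S" to listOf("A" to "B"))
val cnfUnit = mapOf("A" to setOf("a"), "B" to setOf("b", "c"))

// Pointwise union: keep whichever side is present, or combine both.
fun union(x: Cell, y: Cell): Cell =
  x.zip(y) { l, r -> if (l == null) r else if (r == null) l else "( $l ∪ $r )" }

// For each nonterminal, concatenate left[B] and right[C] over its binary rules B C.
fun join(x: Cell, y: Cell): Cell = cnfNts.map { nt ->
  cnfBinary[nt].orEmpty()
    .mapNotNull { (b, c) ->
      val l = x[cnfNts.indexOf(b)]
      val r = y[cnfNts.indexOf(c)]
      if (l != null && r != null) "$l $r" else null
    }
    .reduceOrNull { acc, e -> "( $acc ∪ $e )" }
}

// A hole "_" admits every terminal of a unit nonterminal; a concrete token only itself.
fun leaf(token: String): Cell = cnfNts.map { nt ->
  val ts = cnfUnit[nt].orEmpty()
  when {
    token == "_" && ts.isNotEmpty() -> "( ${ts.joinToString(" ")} )"
    token in ts -> token
    else -> null
  }
}

fun startExpr(tokens: List<String>): String? {
  val n = tokens.size
  // chart[i][j] describes the tokens from position i (inclusive) to j (exclusive)
  val chart = Array(n + 1) { Array<Cell>(n + 1) { List<String?>(cnfNts.size) { null } } }
  for (i in 0 until n) chart[i][i + 1] = leaf(tokens[i])
  for (len in 2..n) for (i in 0..n - len) {
    val j = i + len
    chart[i][j] = (i + 1 until j)
      .map { k -> join(chart[i][k], chart[k][j]) }
      .fold(chart[i][j], ::union)
  }
  return chart[0][n][cnfNts.indexOf("S")]
}

fun main() {
  println(startExpr(listOf("_", "_")))  // ( a ) ( b c )
  println(startExpr(listOf("a", "_")))  // a ( b c )
  println(startExpr(listOf("b", "_")))  // null
}
```

The top-right cell plays the role of `diagonals.last()[0]`: its entry for the start symbol describes every completion of the holes that the toy grammar accepts, and `null` signals that no completion exists.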

@@ -69,6 +69,13 @@ val CFG.vindex: Array<IntArray> by cache {
  }
}

val CFG.vindex2: Array<List<List<Int>>> by cache {
  Array(bindex.indexedNTs.size) { i ->
    bimap[bindex[i]].filter { it.size > 1 }
      .map { listOf(bindex[it[0]], bindex[it[1]]) }
  }
}
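
For readers unfamiliar with this kind of index, the sketch below builds an analogous table for a toy grammar: for each nonterminal index it lists the index pairs of the right-hand sides of its binary productions. The grammar and the names `toyProductions`, `toyNTs`, and `toyIndex2` are hypothetical; the real `vindex2` relies on `bimap` and `bindex`, which are not reproduced here.

```kotlin
// Hypothetical toy grammar in Chomsky normal form; not kaliningraph's CFG type.
val toyProductions: List<Pair<String, List<String>>> = listOf(
  "S" to listOf("A", "B"),
  "S" to listOf("S", "S"),
  "A" to listOf("a"),
  "B" to listOf("b")
)

// Fix an ordering of the nonterminals so that each one has an integer index.
val toyNTs: List<String> = toyProductions.map { it.first }.distinct()

// For each nonterminal index, the [B, C] index pairs of its binary productions.
val toyIndex2: List<List<List<Int>>> = toyNTs.map { nt ->
  toyProductions
    .filter { (lhs, rhs) -> lhs == nt && rhs.size > 1 }
    .map { (_, rhs) -> listOf(toyNTs.indexOf(rhs[0]), toyNTs.indexOf(rhs[1])) }
}

fun main() {
  println(toyNTs)     // [S, A, B]
  println(toyIndex2)  // [[[1, 2], [0, 0]], [], []]
}
```

Keeping the pairs grouped by left-hand side lets a join visit only the productions that can contribute to a given nonterminal, and destructuring each pair as `(B, C)` works because Kotlin lists provide `component1()` and `component2()`.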

val CFG.bindex: Bindex<Σᐩ> by cache { Bindex(nonterminals) }
val CFG.normalForm: CFG by cache { normalize() }
val CFG.depGraph: LabeledGraph by cache { dependencyGraph() }
@@ -3,6 +3,7 @@
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.automata.GRE
import ai.hypergraph.kaliningraph.sampling.*
import ai.hypergraph.kaliningraph.tensor.*
import ai.hypergraph.kaliningraph.types.*
@@ -1,5 +1,8 @@
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.automata.*
import ai.hypergraph.kaliningraph.repair.vanillaS2PCFG
import ai.hypergraph.kaliningraph.tokenizeByWhitespace
import ai.hypergraph.kaliningraph.types.*
import ai.hypergraph.kaliningraph.types.powerset
import kotlin.test.*
@@ -8,6 +11,21 @@ import kotlin.test.*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BrzozowskiTest"
*/
class BrzozowskiTest {
  /*
  ./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BrzozowskiTest.testGRE"
  */
  @Test
  fun testGRE() {
    val ab = GRE("A") + GRE("B")
    val nabab = !(ab * ab)

    println(nabab.toString())

    val t = Grammars.ifThen.startGRE(List(5) { "_" })

    println(t?.toString()?.length)
  }

  /*
  ./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BrzozowskiTest.testLeftQuotient"
  */