Skip to content

Commit

Permalink
More sane algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
eddieantonio committed Oct 26, 2017
1 parent 5e49009 commit e5f5f1e
Showing 1 changed file with 24 additions and 33 deletions.
57 changes: 24 additions & 33 deletions sensibility/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Implements the logic to attempt to fix syntax errors.
"""

from typing import Iterable, Iterator, List, NamedTuple, Sequence, SupportsFloat
from typing import Iterable, Iterator, List, NamedTuple, Set, Sequence, SupportsFloat
from typing import cast

import numpy as np
Expand Down Expand Up @@ -64,10 +64,6 @@ def fix(self, source_file: bytes) -> Sequence[Edit]:
# Holds the lowest agreement at each point in the file.
results: List[IndexResult] = []

# These will hold the TOP predictions at a given index.
forwards_predictions: List[Vind] = []
backwards_predictions: List[Vind] = []

for index, pred in enumerate(predictions):
vind = file_vector[index]
token = tokens[index]
Expand All @@ -78,54 +74,35 @@ def fix(self, source_file: bytes) -> Sequence[Edit]:
# truth.
result = IndexResult(index, file_vector, prefix_pred, suffix_pred, token, vind)
results.append(result)
print(result)

# Store the TOP prediction from each model.
# TODO: document corner cases!
top_next_prediction = prefix_pred.argmax()
top_prev_prediction = suffix_pred.argmax()

assert 0 <= top_next_prediction <= len(language.vocabulary)
assert 0 <= top_prev_prediction <= len(language.vocabulary)
forwards_predictions.append(cast(Vind, top_next_prediction))
backwards_predictions.append(cast(Vind, top_prev_prediction))
assert top_next_prediction == forwards_predictions[index]
assert top_prev_prediction == backwards_predictions[index]

# Rank the results by some metric of similarity defined by IndexResult
# (the top rank will be LEAST similar).
ranked_results = tuple(sorted(results, key=float))

# For the top-k disagreements, synthesize fixes.
# NOTE: k should be determined by the MRR of finding the syntax error!
# NOTE: k should be determined by the xentropy of the models!
fixes = Fixes(file_vector)
# import pdb
# pdb.set_trace()
print("Winners:")
for disagreement in ranked_results[:self.k]:
pos = disagreement.index
print(disagreement)

likely_next: Vind = forwards_predictions[pos]
likely_prev: Vind = backwards_predictions[pos]
likely_tokens = disagreement.best_suggestions()

# Note: the order of these operations SHOULDN'T matter,
# but typically we only report the first fix that works.
# Because missing tokens are the most common
# we'll try to insert tokens first, THEN delete.

# Assume a deletion. Let's try inserting some tokens.
fixes.try_insert(pos, likely_next)
fixes.try_insert(pos, likely_prev)
for likely_token in likely_tokens:
fixes.try_insert(pos, likely_token)

# Assume an insertion. Let's try removing the offensive token.
fixes.try_delete(pos)

# Assume a substitution. Let's try swapping the token.
fixes.try_substitute(pos, likely_next)
fixes.try_substitute(pos, likely_prev)
for likely_token in likely_tokens:
fixes.try_substitute(pos, likely_token)

# TODO: sort by how fixed the result is after applying the fix.
return tuple(fixes)


Expand Down Expand Up @@ -162,7 +139,7 @@ def __init__(self, index: int, program: SourceVector,

# Cross-entropy
p = one_hot(vind, len(a))
self.xentropy = cross_entropy(p, a) + cross_entropy(p, b)
self.xentropy = float(cross_entropy(p, a) + cross_entropy(p, b))

# Use averaged KL-divergence?
# http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.332.4480&rep=rep1&type=pdf
Expand Down Expand Up @@ -213,7 +190,7 @@ def __float__(self) -> float:
"""
# The original paper had this metric:
# Agreement: elementwise harmonic mean of two vectors
return self.indexed_prob
return -self.xentropy

def __str__(self) -> str:
"""
Expand All @@ -234,12 +211,26 @@ def __str__(self) -> str:
cosine_sim = {self.cosine_similarity:5}
"""

def best_suggestions(self) -> Set[Vind]:
return set(self.top_forwards) | set(self.top_backwards)

@property
def top_forwards(self):
return self._top(self.a)

@property
def top_backwards(self):
return self._top(self.b)

def _top(self, vector, k=3) -> np.ndarray:
return vector.argpartition(-k)[-k:][::-1]

def _maxes(self, vector, k=3):
"""
Yields percentage, and token text of top-k entries.
"""
from sensibility import current_language
for idx in vector.argpartition(-k)[-k:][::-1]: # for idx in vector.argsort()[-1:-4:-1]:
for idx in self._top(vector): # for idx in vector.argsort()[-1:-4:-1]:
yield 100. * vector[idx], current_language.to_text(idx)


Expand Down

0 comments on commit e5f5f1e

Please sign in to comment.