-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
25 changed files
with
737 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#%% | ||
import irene.lang |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from .server import IreneService | ||
from .lang import * | ||
import json | ||
import attr | ||
|
||
# Demo: open an index on a locally running Irene server, build an RM3 query
# over Dirichlet-scored terms, then print the ranking and one stored document.
service = IreneService()
INDEX = 'robust'
service.open(INDEX, 'robust04.irene')

# Let the server tokenize so the query terms agree with the index's analysis.
terms = service.tokenize(INDEX, "hello world!")
scored_terms = [DirQLExpr(TextExpr(term)) for term in terms]
unit_weights = [1.0] * len(terms)
ql = RM3Expr(CombineExpr(children=scored_terms, weights=unit_weights))

ranking = service.query(INDEX, ql, 20)
print(ranking)
document = service.doc(INDEX, 'LA081890-0076')
print(document)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
#%% | ||
from abc import abstractmethod, ABC | ||
import attr | ||
from typing import List, Optional | ||
|
||
@attr.s
class CountStats(object):
    """Raw counts for one expression, plus ratios derived from them.

    Field meanings below are inferred from how the methods combine them
    (NOTE(review): confirm against the server-side definition):
    ``cf``/``cl`` look like collection frequency / collection length and
    ``df``/``dc`` like document frequency / document count.
    """
    source: str = attr.ib()
    cf: int = attr.ib()
    df: int = attr.ib()
    cl: int = attr.ib()
    dc: int = attr.ib()

    def average_doc_len(self):
        """Return ``cl / dc``, or 0.0 when ``dc`` is zero."""
        return 0.0 if self.dc == 0 else self.cl / self.dc

    def count_probability(self):
        """Return ``cf / cl``, or 0.0 when ``cl`` is zero."""
        return 0.0 if self.cl == 0 else self.cf / self.cl

    def nonzero_count_probability(self):
        """Like ``count_probability`` but with ``cf`` floored at 0.5.

        The floor keeps the result strictly positive for unseen terms
        whenever ``cl`` is non-zero; a zero ``cl`` still yields 0.0.
        """
        if self.cl == 0:
            return 0.0
        floored_cf = max(0.5, self.cf)
        return floored_cf / self.cl

    def binary_probability(self):
        """Return ``df / dc``, or 0.0 when ``dc`` is zero."""
        return 0.0 if self.dc == 0 else self.df / self.dc
|
||
@attr.s
class QExpr(object):
    """Base class for every node of the query-expression tree."""

    def children(self):
        """Return the attrs fields of this node whose value is itself a QExpr.

        Only direct field values are inspected, in field-declaration order;
        expressions stored inside a list field are not returned.
        """
        found = []
        for field_def in attr.fields(self.__class__):
            value = getattr(self, field_def.name)
            if isinstance(value, QExpr):
                found.append(value)
        return found

    def weighted(self, weight):
        """Wrap this expression in a WeightedExpr carrying ``weight``."""
        return WeightedExpr(child=self, weight=weight)
|
||
### | ||
# Boolean operations | ||
### | ||
|
||
# Sync this class to Galago semantics. | ||
# - Consider every doc that has a match IFF cond has a match, using value, regardless of whether value also has a match. | ||
# - Implemented by [RequireEval]. | ||
# - If instead you want to score value only if cond has a match, use [MustExpr] -> [MustEval]. | ||
@attr.s
class RequireExpr(QExpr):
    """Consider a document IFF ``cond`` matches, and score it using ``value``.

    Derives from QExpr (it previously derived from ``object``) so that it
    gains ``weighted()``/``children()`` and is recognised by the
    ``isinstance(x, QExpr)`` filter that ``QExpr.children`` applies to parents.
    """
    # Expression that decides whether a document is considered at all.
    cond = attr.ib(type=QExpr)
    # Expression that supplies the score for considered documents.
    value = attr.ib(type=QExpr)
    # Type tag included in the dict form of this node.
    kind = attr.ib(type=str, default='Require')
|
||
# Score the [value] query when it matches IFF [cond] also has a match. This is a logical AND. | ||
@attr.s
class MustExpr(QExpr):
    """Score ``value`` only for documents where ``cond`` also matches.

    Derives from QExpr (it previously derived from ``object``) so that it
    gains ``weighted()``/``children()`` and is recognised by the
    ``isinstance(x, QExpr)`` filter that ``QExpr.children`` applies to parents.

    NOTE(review): unlike RequireExpr this node declares no ``kind`` tag;
    confirm whether the server needs one before relying on serialization.
    """
    # Expression that must match for the document to be scored.
    cond = attr.ib(type=QExpr)
    # Expression that supplies the score.
    value = attr.ib(type=QExpr)
|
||
@attr.s
class AndExpr(QExpr):
    """Boolean node over a list of child expressions.

    Presumably matches only when every child matches (inferred from the
    name) - confirm against the server evaluator.
    """
    children = attr.ib(type=List[QExpr])
|
||
@attr.s
class OrExpr(QExpr):
    """Boolean node over a list of child expressions.

    Presumably matches when any child matches (inferred from the name) -
    confirm against the server evaluator.
    """
    children = attr.ib(type=List[QExpr])
|
||
# AKA: True | ||
@attr.s
class AlwaysMatchLeaf(QExpr):
    """Field-less leaf standing for boolean True.

    Derives from QExpr (it previously derived from ``object``) so it can be
    placed anywhere a query expression is expected.
    """
|
||
# AKA: False | ||
@attr.s
class NeverMatchLeaf(QExpr):
    """Field-less leaf standing for boolean False.

    Derives from QExpr (it previously derived from ``object``) so it can be
    placed anywhere a query expression is expected.
    """
|
||
### | ||
# Scoring Transformations | ||
### | ||
|
||
def SumExpr(children: List[QExpr]):
    """Build a CombineExpr that gives every child a weight of 1.0.

    Args:
        children: expressions to combine.

    Returns:
        A CombineExpr over ``children`` with one 1.0 weight per child.
    """
    return CombineExpr(children, [1.0 for _ in children])
|
||
def MeanExpr(children: List[QExpr]):
    """Build a CombineExpr that gives every child the weight ``1 / len(children)``.

    An empty ``children`` list produces an empty weight list rather than a
    division by zero.
    """
    count = len(children)
    # Guard the division: with no children there is nothing to weight.
    uniform = [1.0 / count] * count if count else []
    return CombineExpr(children, uniform)
|
||
@attr.s
class CombineExpr(QExpr):
    """Node pairing a list of child expressions with a list of weights.

    ``weights`` is presumably parallel to ``children`` (SumExpr and MeanExpr
    build it that way); nothing here enforces equal lengths.
    """
    children = attr.ib(type=List[QExpr])
    weights = attr.ib(type=List[float])
    # Type tag included in the dict form of this node.
    kind: str = attr.ib(default='Combine')
|
||
@attr.s
class MultExpr(QExpr):
    """Node over a list of child expressions, tagged ``'Mult'``.

    Presumably multiplies the child scores (inferred from the name) - confirm.
    """
    children = attr.ib(type=List[QExpr])
    # Type tag included in the dict form of this node.
    kind: str = attr.ib(default='Mult')
|
||
@attr.s
class MaxExpr(QExpr):
    """Node over a list of child expressions, tagged ``'Max'``.

    Presumably takes the maximum child score (inferred from the name) - confirm.
    """
    children = attr.ib(type=List[QExpr])
    # Type tag included in the dict form of this node.
    kind: str = attr.ib(default='Max')
|
||
@attr.s
class WeightedExpr(QExpr):
    """Node attaching a float ``weight`` to a single ``child`` expression.

    This is what ``QExpr.weighted`` constructs.
    """
    child: QExpr = attr.ib()
    weight: float = attr.ib()
    # Type tag included in the dict form of this node.
    kind: str = attr.ib(default='Weighted')
|
||
### | ||
# Leaf Nodes | ||
### | ||
|
||
@attr.s
class TextExpr(QExpr):
    """Leaf node carrying a piece of query text.

    ``field`` and ``stats_field`` are optional and default to None;
    presumably the server then falls back to its own defaults - confirm.
    """
    text: str = attr.ib()
    field: Optional[str] = attr.ib(default=None)
    stats_field: Optional[str] = attr.ib(default=None)
    # Type tag included in the dict form of this node.
    kind: str = attr.ib(default='Text')
|
||
@attr.s
class BoolExpr(QExpr):
    """Leaf node naming a ``field`` and the boolean value ``desired`` for it.

    ``desired`` defaults to True.
    """
    field: str = attr.ib()
    desired: bool = attr.ib(default=True)
|
||
@attr.s
class LengthsExpr(QExpr):
    """Leaf node naming a single ``field``.

    Presumably exposes per-document lengths of that field (inferred from
    the name) - confirm.
    """
    field: str = attr.ib()
|
||
@attr.s
class WhitelistMatchExpr(QExpr):
    """Leaf node holding a list of document names.

    Presumably matches exactly the named documents (inferred from the
    name) - confirm.
    """
    doc_names = attr.ib(type=List[str])
|
||
@attr.s
class DenseLongField(QExpr):
    """Leaf node naming an integer-valued field.

    ``missing`` (default 0) is presumably the value used for documents
    that lack the field - confirm.
    """
    name: str = attr.ib()
    missing: int = attr.ib(default=0)
|
||
@attr.s
class DenseFloatField(QExpr):
    """Leaf node naming a float-valued field.

    The ``@attr.s`` decorator is required: without it the ``attr.ib``
    declarations never become constructor arguments, so
    ``DenseFloatField(name=...)`` raised a TypeError.
    """
    name = attr.ib(type=str)
    # TODO: choose a float32-min sentinel. The most negative finite float32
    # is about -3.4028235e38 (numpy.finfo(numpy.float32).min); the default
    # stays None until the value the server expects is confirmed.
    missing = attr.ib(type=Optional[float], default=None)
|
||
### | ||
# Phrase Nodes | ||
### | ||
|
||
@attr.s
class OrderedWindowExpr(QExpr):
    """Phrase node over a list of child expressions with an integer ``step``.

    ``step`` defaults to 1; presumably it is the allowed positional gap
    between consecutive children - confirm.
    """
    children = attr.ib(type=List[QExpr])
    step: int = attr.ib(default=1)
    # Type tag included in the dict form of this node.
    kind: str = attr.ib(default="OrderedWindow")
|
||
@attr.s
class UnorderedWindowNode(QExpr):
    """Phrase node over a list of child expressions with an integer ``width``.

    ``width`` defaults to 8; presumably it is the window size within which
    all children must occur, in any order - confirm.
    """
    children = attr.ib(type=List[QExpr])
    width: int = attr.ib(default=8)
    # Type tag included in the dict form of this node.
    kind: str = attr.ib(default="UnorderedWindow")
|
||
@attr.s
class SmallestCountExpr(QExpr):
    """Phrase-section node over a list of child expressions.

    Presumably yields the smallest count among its children (inferred from
    the name) - confirm.
    """
    children = attr.ib(type=List[QExpr])
|
||
@attr.s
class SynonymExpr(QExpr):
    """Phrase-section node over a list of child expressions, tagged ``"Synonym"``.

    Presumably treats its children as interchangeable terms (inferred from
    the name) - confirm.
    """
    children = attr.ib(type=List[QExpr])
    # Type tag included in the dict form of this node.
    kind: str = attr.ib(default="Synonym")
|
||
### | ||
# Scorers | ||
### | ||
|
||
@attr.s
class BM25Expr(QExpr):
    """Scorer node wrapping one ``child`` with optional ``b`` and ``k`` values.

    ``b``, ``k`` and ``stats`` all default to None; presumably the server
    then supplies its own parameters and statistics - confirm.
    """
    child: QExpr = attr.ib()
    b: Optional[float] = attr.ib(default=None)
    k: Optional[float] = attr.ib(default=None)
    stats: CountStats = attr.ib(default=None)
    # NOTE(review): the sibling scorers tag themselves without the "Expr"
    # suffix ('DirQL', 'LinearQL', 'RM3'); confirm 'BM25Expr' is the tag the
    # server expects.
    kind: str = attr.ib(default='BM25Expr')
|
||
@attr.s
class LinearQLExpr(QExpr):
    """Scorer node wrapping one ``child`` with an optional ``smoothing_lambda``.

    ``smoothing_lambda`` and ``stats`` default to None; presumably the
    server then supplies its own values - confirm.
    """
    child: QExpr = attr.ib()
    smoothing_lambda: Optional[float] = attr.ib(default=None)
    stats: CountStats = attr.ib(default=None)
    # Type tag included in the dict form of this node.
    kind: str = attr.ib(default='LinearQL')
|
||
|
||
@attr.s
class DirQLExpr(QExpr):
    """Scorer node wrapping one ``child`` with an optional ``mu`` value.

    ``mu`` and ``stats`` default to None; presumably the server then
    supplies its own values - confirm.
    """
    child: QExpr = attr.ib()
    mu: Optional[float] = attr.ib(default=None)
    stats: CountStats = attr.ib(default=None)
    # Type tag included in the dict form of this node.
    kind: str = attr.ib(default='DirQL')
|
||
@attr.s
class RM3Expr(QExpr):
    """Node wrapping one ``child`` query with relevance-feedback settings.

    Defaults: ``orig_weight=0.3``, ``fb_docs=20``, ``fb_terms=100``,
    ``stopwords=True``, ``field=None``. Presumably ``fb_docs``/``fb_terms``
    are the numbers of feedback documents and expansion terms, and
    ``orig_weight`` the weight kept on the original query - confirm.
    """
    child: QExpr = attr.ib()
    orig_weight: float = attr.ib(default=0.3)
    fb_docs: int = attr.ib(default=20)
    fb_terms: int = attr.ib(default=100)
    stopwords: bool = attr.ib(default=True)
    field: Optional[str] = attr.ib(default=None)
    # Type tag included in the dict form of this node.
    kind: str = attr.ib(default='RM3')
|
||
|
||
if __name__ == '__main__':
    # Smoke test: wrap a text leaf in a weight, then show the node and the
    # QExpr-valued fields that children() discovers on it.
    expr = TextExpr("hello").weighted(0.5)
    print(expr)
    print(expr.children())
#%% |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from irene.lang import QExpr, MeanExpr, SumExpr, DirQLExpr, TextExpr, BM25Expr | ||
from typing import List | ||
|
||
def QueryLikelihood(words: List[str], scorer=DirQLExpr) -> QExpr:
    """Build a mean-combined query that scores each word independently.

    Args:
        words: query terms; each one is wrapped in a TextExpr.
        scorer: callable mapping a TextExpr to a scoring expression.
            Defaults to DirQLExpr. This was previously written as an
            annotation (``scorer: lambda ...``), which left the parameter
            without a default and made it a required argument.

    Returns:
        MeanExpr over ``scorer(TextExpr(w))`` for every word.
    """
    return MeanExpr([scorer(TextExpr(w)) for w in words])
|
||
def BM25(words: List[str]) -> QExpr:
    """Build a summed query that scores each word with BM25Expr."""
    per_word = []
    for word in words:
        per_word.append(BM25Expr(TextExpr(word)))
    return SumExpr(per_word)
|
||
def SequentialDependenceModel(words: List[str]):
    """Placeholder: not implemented yet; always returns None.

    TODO: build the actual query expression from ``words``.
    """
    return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import requests | ||
from typing import List, Optional, Dict, Any | ||
from .lang import QExpr | ||
import attr | ||
import json | ||
|
||
@attr.s
class DocResponse(object):
    """One ranked hit: a document name and its score.

    Built from each entry of the ``topdocs`` list in a query response.
    """
    name: str = attr.ib()
    score: float = attr.ib()
|
||
@attr.s
class QueryResponse(object):
    """Result of a query: the ranked hits and the total hit count.

    ``totalHits`` keeps the camelCase spelling of the JSON key it is read from.
    """
    topdocs: List[DocResponse] = attr.ib()
    totalHits: int = attr.ib()
|
||
#%% | ||
class IreneService(object):
    """Thin HTTP client for a running Irene search server.

    Every method issues one request against ``http://host:port``.
    ``tokenize``, ``doc`` and ``query`` raise ``requests.HTTPError`` on a
    4xx/5xx reply instead of failing later with an opaque KeyError or JSON
    decoding error; ``open`` keeps reporting failure through its return value.
    """

    def __init__(self, host="localhost", port=1234):
        self.host = host
        self.port = port
        self.url = 'http://{0}:{1}'.format(host, port)
        # Index name -> path for every index this client opened successfully.
        self.known_open_indexes = {}

    def _url(self, path):
        """Return the absolute URL for a server ``path`` such as ``"/open"``."""
        return self.url + path

    def open(self, name: str, path: str) -> bool:
        """Ask the server to open the index at ``path`` under ``name``.

        Returns:
            True when the server replied with a success status (the index is
            then recorded in ``known_open_indexes``), False otherwise.
        """
        response = requests.post(self._url("/open"), data={'name': name, 'path': path})
        if not response.ok:
            return False
        self.known_open_indexes[name] = path
        return True

    def tokenize(self, index: str, text: str, field: Optional[str] = None) -> List[str]:
        """Return the server's tokenization of ``text`` for ``index``.

        ``field`` is only sent when it is not None.

        Raises:
            requests.HTTPError: if the server replies with an error status.
        """
        params = {'index': index, 'text': text}
        if field is not None:
            params['field'] = field
        response = requests.get(self._url("/tokenize"), params)
        response.raise_for_status()
        return response.json()['terms']

    def doc(self, index: str, name: str) -> Dict[str, Any]:
        """Fetch the stored document called ``name`` from ``index`` as a dict.

        Raises:
            requests.HTTPError: if the server replies with an error status.
        """
        params = {'index': index, 'id': name}
        response = requests.get(self._url('/doc'), params)
        response.raise_for_status()
        return response.json()

    def query(self, index: str, query: QExpr, depth: int=50) -> QueryResponse:
        """Run ``query`` against ``index`` and return up to ``depth`` hits.

        The expression tree is converted to nested dicts with ``attr.asdict``
        and posted as JSON.

        Raises:
            requests.HTTPError: if the server replies with an error status.
        """
        # data class QueryRequest(val index: String, val depth: Int, val query: QExpr)
        params = {'index': index, 'depth': depth, 'query': attr.asdict(query)}
        # NOTE(review): debugging aid kept as-is - dumps the query to stdout.
        print(json.dumps(params['query'], indent=2))
        response = requests.post(self._url('/query'), json=params)
        response.raise_for_status()
        r_json = response.json()
        topdocs = r_json['topdocs']
        return QueryResponse([DocResponse(**td) for td in topdocs], r_json['totalHits'])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.