Skip to content

Commit

Permalink
Merge branch 'main' of github.com:SapienzaNLP/relik into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Riccorl committed Aug 2, 2024
2 parents 017c942 + e041878 commit b29df2a
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 28 deletions.
3 changes: 1 addition & 2 deletions relik/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,9 @@ def serve(
annotation_type=annotation_type,
host=host,
port=port,
frontend=frontend
)

if frontend:
serve_gradio()


if __name__ == "__main__":
Expand Down
34 changes: 34 additions & 0 deletions relik/inference/annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,37 @@ def __call__(
self.window_manager = WindowManager(
self.tokenizer, self.sentence_splitter
)
else:
if isinstance(self.sentence_splitter, WindowSentenceSplitter):
if not isinstance(window_size, int):
logger.warning(
"With WindowSentenceSplitter the window_size must be an integer. "
f"Using the default window size {self.window_manager.window_size}."
f"If you want to change the window size to `sentence` or `none`, "
f"please create a new Relik instance."
)
window_size = self.window_manager.window_size
window_stride = self.window_manager.window_stride
if isinstance(self.sentence_splitter, SpacySentenceSplitter):
if window_size != "sentence":
logger.warning(
"With SpacySentenceSplitter the window_size must be `sentence`. "
f"Using the default window size {self.window_manager.window_size}."
f"If you want to change the window size to an integer or `none`, "
f"please create a new Relik instance."
)
window_size = "sentence"
window_stride = None
if isinstance(self.sentence_splitter, BlankSentenceSplitter):
if window_size != "none" or window_stride is not None:
logger.warning(
"With BlankSentenceSplitter the window_size must be `none`. "
f"Using the default window size {self.window_manager.window_size}."
f"If you want to change the window size to an integer or `sentence`, "
f"please create a new Relik instance."
)
window_size = "none"
window_stride = None

# sanity check for window size and stride
if (
Expand Down Expand Up @@ -520,6 +551,8 @@ def __call__(
windows = windows + blank_windows
windows.sort(key=lambda x: (x.doc_id, x.offset))

print(windows)

# if there is no reader, just return the windows
if self.reader is None:
# normalize window candidates to be a list of lists, like when the reader is used
Expand All @@ -530,6 +563,7 @@ def __call__(
merged_windows = self.window_manager.merge_windows(windows)

# transform predictions into RelikOutput objects
print(merged_windows)
output = []
for w in merged_windows:
span_labels = []
Expand Down
14 changes: 9 additions & 5 deletions relik/inference/data/window/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from relik.inference.data.splitters.blank_sentence_splitter import BlankSentenceSplitter
from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
from relik.inference.data.splitters.window_based_splitter import WindowSentenceSplitter
from relik.inference.data.splitters.spacy_sentence_splitter import SpacySentenceSplitter
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.inference.data.objects import AnnotationType, TaskType
Expand Down Expand Up @@ -227,10 +228,10 @@ def _merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSam
if len(windows) == 1:
return self._normalize_single_window(windows[0])

if not isinstance(self.splitter, WindowSentenceSplitter):
# here we don't really need to merge windows, just normalize them
# TODO: check if we need to merge windows in this case
return [self._normalize_single_window(w) for w in windows]
# if not isinstance(self.splitter, WindowSentenceSplitter):
# # here we don't really need to merge windows, just normalize them
# # TODO: check if we need to merge windows in this case
# return [self._normalize_single_window(w) for w in windows]

if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
windows = sorted(windows, key=(lambda x: x.offset))
Expand Down Expand Up @@ -507,7 +508,10 @@ def _merge_window_pair(
) = self._merge_predictions(window1, window2)

# merge text, take into account overlapping chars
m_text = window1.text[: window2.offset] + window2.text
if isinstance(self.splitter, SpacySentenceSplitter):
m_text = window1.text[: window2.offset] + " " + window2.text
else:
m_text = window1.text[: window2.offset] + window2.text

merging_output.update(
dict(
Expand Down
17 changes: 17 additions & 0 deletions relik/inference/serve/backend/fastapi_be.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ async def relik_endpoint(
relation_threshold: float = 0.5,
) -> List:
try:
if window_size:
# check if window size is a number as string
if window_size.isdigit():
window_size = int(window_size)

if window_stride:
# check if window stride is a number as string
if window_stride.isdigit():
window_stride = int(window_stride)

# get predictions for the retriever
return await self(
text=text,
Expand Down Expand Up @@ -182,6 +192,7 @@ def main(
workers: int = None,
host: str = "localhost",
port: int = 8000,
frontend: bool = False,
):
app = FastAPI(
title="ReLiK - A blazing fast and lightweight Information Extraction model for Entity Linking and Relation Extraction.",
Expand All @@ -201,6 +212,12 @@ def main(
annotation_type=annotation_type,
)
app.include_router(server.router)
if frontend:
from relik.inference.serve.frontend.gradio_fe import main as serve_frontend
import threading

threading.Thread(target=serve_frontend, daemon=True).start()

uvicorn.run(app, host=host, port=port, log_level="info", workers=workers)


Expand Down
25 changes: 4 additions & 21 deletions relik/inference/serve/frontend/gradio_fe.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,22 +305,8 @@ def generate_graph(
RELIK = os.getenv("RELIK", "localhost:8000/api/relik")


def text_analysis(Text, Model, Relation_Threshold, Window_Size, Window_Stride):
global loaded_model
if Model is None:
return "", ""
# if loaded_model is None or loaded_model["key"] != Model:
# relik = Relik.from_pretrained(Model, index_precision="bf16")
# loaded_model = {"key": Model, "model": relik}
# else:
# relik = loaded_model["model"]
# if Model not in relik_models:
# raise ValueError(f"Model {Model} not found.")
# relik = relik_models[Model]
# spacy for span visualization

def text_analysis(Text, Relation_Threshold, Window_Size, Window_Stride):
relik = RELIK

nlp = spacy.blank("xx")
# annotated_text = relik(
# Text,
Expand All @@ -331,8 +317,10 @@ def text_analysis(Text, Model, Relation_Threshold, Window_Size, Window_Stride):
# window_size=Window_Size,
# window_stride=Window_Stride,
# )
print(f"Using ReLiK at {relik}")
print(f"Querying ReLiK with ?text={Text}&relation_threshold={Relation_Threshold}&window_size={Window_Size}&window_stride={Window_Stride}&annotation_type=word&remove_nmes=False")
response = requests.get(
f"{relik}?text={Text}&relation_threshold={Relation_Threshold}&window_size={Window_Size}&window_stride={Window_Stride}&annotation_type=word&remove_nmes=False"
f"http://{relik}/?text={Text}&relation_threshold={Relation_Threshold}&window_size={Window_Size}&window_stride={Window_Stride}&annotation_type=word&remove_nmes=False",
)
if response.status_code != 200:
raise gr.Error(response.text)
Expand Down Expand Up @@ -390,11 +378,6 @@ def text_analysis(Text, Model, Relation_Threshold, Window_Size, Window_Stride):
text_analysis,
[
gr.Textbox(label="Input Text", placeholder="Enter sentence here..."),
# gr.Dropdown(
# relik_available_models,
# value=relik_available_models[0],
# label="Relik Model",
# ),
gr.Slider(
minimum=0,
maximum=1,
Expand Down

0 comments on commit b29df2a

Please sign in to comment.