Skip to content

Commit

Permalink
Merge pull request #3 from derilinx/tweaks
Browse files Browse the repository at this point in the history
Compatibility and Performance
  • Loading branch information
amercader authored Mar 6, 2024
2 parents 7f6c07a + aa47612 commit 4f04c1b
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 24 deletions.
21 changes: 10 additions & 11 deletions ckanext/embeddings/actions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ckan.plugins import toolkit

from ckanext.embeddings.backends import get_embeddings_backend

import logging
log = logging.getLogger(__name__)

@toolkit.side_effect_free
def package_similar_show(context, data_dict):
Expand All @@ -12,18 +12,17 @@ def package_similar_show(context, data_dict):
except ValueError:
raise toolkit.ValidationError(f"Wrong value for limit paramater: {limit}")

field_name = toolkit.config.get("ckanext.embeddings.solr_vector_field_name", "vector")

try:
dataset_dict = toolkit.get_action("package_show")(
{"ignore_auth": True}, {"id": dataset_id}
)
except toolkit.ObjectNotFound:
vectors = toolkit.get_action("package_search")(
{"ignore_auth": True}, {"fq": f"(id:{dataset_id} OR name:{dataset_id})", 'fl':f"{field_name},id"}
)['results']
dataset_dict = vectors.pop()
dataset_embedding = dataset_dict[field_name]
except IndexError:
raise toolkit.ObjectNotFound(f"Dataset not found: {dataset_id}")

backend = get_embeddings_backend()
dataset_embedding = backend.get_embedding_for_dataset(dataset_dict)

field_name = toolkit.config.get("ckanext.embeddings.solr_vector_field_name", "vector")

search_params = {}
search_params["defType"] = "lucene"
search_params["q"] = f"{{!knn f={field_name} topK={limit}}}{list(dataset_embedding)}"
Expand Down
24 changes: 17 additions & 7 deletions ckanext/embeddings/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

class BaseEmbeddingsBackend:
def get_dataset_values(self, dataset_dict):

if dataset_dict.get("notes"):
return dataset_dict["title"] + " " + dataset_dict["notes"]
else:
Expand Down Expand Up @@ -102,19 +101,30 @@ def create_embedding(self, values):

def _load_embeddings_backends():
    """Discover and register embeddings backends.

    Scans the ``ckanext.embeddings.backends`` entry point group and
    populates the module-level ``embeddings_backends`` registry, mapping
    each entry point name to its loaded backend class.
    """
    from importlib.metadata import entry_points

    try:
        eps = entry_points(group="ckanext.embeddings.backends")
    except TypeError:
        # Python 3.8/3.9: entry_points() accepts no selection arguments
        # (the ``group`` keyword was added in 3.10) and returns a mapping
        # of group name -> entry points. Default to an empty sequence if
        # the group is not present, rather than raising KeyError.
        eps = entry_points().get("ckanext.embeddings.backends", [])
    for ep in eps:
        embeddings_backends[ep.name] = ep.load()
        log.debug(f"Registering Embeddings Backend: {ep.name}")

# Module-level cache for the backend instance; instantiating a backend can be
# expensive (e.g. downloading / loading a sentence-transformers model), so it
# is created at most once per process.
_embeddings_backend = None


def get_embeddings_backend():
    """Return the configured embeddings backend instance (cached).

    Reads ``ckanext.embeddings.backend`` from the CKAN config (defaulting
    to ``sentence_transformers``), instantiates that backend on first call
    and returns the cached instance on subsequent calls.

    Raises KeyError if the configured backend name is not registered.
    """
    # TODO: config declaration
    global _embeddings_backend

    backend = toolkit.config.get(
        "ckanext.embeddings.backend", "sentence_transformers"
    )
    log.debug(f"Using Embeddings Backend: {backend}")

    import time

    start = time.time()
    try:
        if _embeddings_backend is None:
            # Make sure the registry is populated before the first lookup.
            _load_embeddings_backends()
            _embeddings_backend = embeddings_backends[backend]()
        return _embeddings_backend
    finally:
        # Lazy %-style args keep formatting off the hot (cached) path.
        log.debug("loading embeddings took: %.3f sec", time.time() - start)

_load_embeddings_backends()
7 changes: 7 additions & 0 deletions ckanext/embeddings/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,10 @@ def search(query: str, limit: int):
for r in result["results"]:

print(f"{r['id']} - {r['title']}")


@embeddings.command()
def load():
    """Loads the backend embeddings, filling whatever cache is required, downloading models, etc."""
    # Instantiating the backend is the whole point here: it triggers any
    # model download / cache warm-up. The instance itself is not needed.
    get_embeddings_backend()
19 changes: 13 additions & 6 deletions ckanext/embeddings/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@ class EmbeddingPlugin(plugins.SingletonPlugin):
plugins.implements(plugins.ITemplateHelpers)
plugins.implements(plugins.IPackageController, inherit=True)

backend = None
_backend = None

@property
def backend(self):
if self._backend is None:
self._backend = get_embeddings_backend()
return self._backend

# IConfigurer

def update_config(self, config):

self.backend = get_embeddings_backend()

toolkit.add_template_directory(config, "templates")
toolkit.add_resource("assets", "ckanext-embeddings")

Expand Down Expand Up @@ -57,8 +60,6 @@ def before_dataset_index(self, dataset_dict):

dataset_id = dataset_dict["id"]

if not self.backend:
self.backend = get_embeddings_backend()
dataset_embedding = self.backend.get_embedding_for_dataset(dataset_dict)

if dataset_embedding is not None:
Expand All @@ -69,6 +70,9 @@ def before_dataset_index(self, dataset_dict):

return dataset_dict

    def before_index(self, dataset_dict):
        # Compatibility shim: older CKAN versions invoke the legacy
        # ``before_index`` IPackageController hook name; delegate to the
        # CKAN 2.10+ equivalent so one implementation serves both.
        return self.before_dataset_index(dataset_dict)

def before_dataset_search(self, search_params):
extras = search_params.get("extras", {})
if isinstance(extras, str):
Expand Down Expand Up @@ -101,3 +105,6 @@ def before_dataset_search(self, search_params):
search_params["q"] = f"{{!knn f={field_name} topK={rows}}}{list(embedding)}"

return search_params

    def before_search(self, search_params):
        # Compatibility shim: older CKAN versions invoke the legacy
        # ``before_search`` IPackageController hook name; delegate to the
        # CKAN 2.10+ equivalent so one implementation serves both.
        return self.before_dataset_search(search_params)

0 comments on commit 4f04c1b

Please sign in to comment.