diff --git a/.gcloudignore b/.gcloudignore index eaeced0..cc6d495 100644 --- a/.gcloudignore +++ b/.gcloudignore @@ -13,3 +13,4 @@ package*.json /README* /LICENSE /account.json +.idea diff --git a/.gitignore b/.gitignore index 561d105..cfdcd8c 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ package-lock.json /~* /logs /account.json +.idea diff --git a/README.md b/README.md index 73bcd3d..39852f8 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,8 @@ An engaging virtual assistant service for answering (almost) any question about ## Setup ### Requirements -- (optional) *\/account.json* with valid *GCP service account* data +- *Dialogflow* agent restored from `knowledge-agent.zip` (see *releases*): + - the *Fandom* knowledge base ('https://{0}.fandom.com/wiki/{1}') - *enabled*, identified by `Fandom KB ID` (the part after '.../editKnowledgeBase/'); - *\/config.yaml* with the following: ```yaml google_api: @@ -36,14 +37,18 @@ google_api: custom_search: cx: - + +dialogflow: + fandom: + redis: - host: port: - pass: + pass: # optional - ... ``` +- (optional) *\/account.json* with valid *GCP service account* data. -The Redis credentials are tried sequentially until a successful database connection is established. +The Redis credentials are tried sequentially until the first successful database connection. ## MIT License diff --git a/_bareasgi.py b/_bareasgi.py new file mode 100644 index 0000000..83d3a16 --- /dev/null +++ b/_bareasgi.py @@ -0,0 +1,6 @@ +import bareasgi as _bareasgi +from bareasgi import * + +def json_response(data, status=200, headers={}): + headers = [] # FIXME + return _bareasgi.json_response(status, headers, data) diff --git a/app.js b/app.js index b75822b..7540e5c 100644 --- a/app.js +++ b/app.js @@ -46,10 +46,10 @@ function get_knowledge() { } function send_message(text='', cb) { - request.post('/respond').send(text).then(({ body }) => { - state.conversation.push({ text: body[0] }); + request.post('/message').send(text).then(({ body }) => { + state.conversation.push(body); }, (e) => { - _error('POST', '/respond', e); + _error('POST', '/message', e); if (cb) { cb(); } diff --git a/config.py b/config.py new file mode 100644 index 0000000..576a04e --- /dev/null +++ b/config.py @@ -0,0 +1,17 @@ +import os, yaml, json + +from util import realpath + + +with open(realpath('config.yaml')) as f: + locals().update(yaml.load(f, Loader=yaml.SafeLoader)) + +project_id = os.getenv('GOOGLE_CLOUD_PROJECT', None) +if not project_id: + with open(realpath('account.json')) as f: + _account = json.load(f) + project_id = _account['project_id'] + +__all__ = ( + 'google_api', 'custom_search', 'dialogflow', 'redis', 'project_id' +) diff --git a/dialogflow.py b/dialogflow.py index fceb680..374f292 100644 --- a/dialogflow.py +++ b/dialogflow.py @@ -1,11 +1,12 @@ +import os from uuid import uuid4 from dialogflow_v2beta1 import SessionsClient, KnowledgeBasesClient, DocumentsClient -from dialogflow_v2beta1 import types as dialogflow -from dialogflow_v2beta1 import enums +from dialogflow_v2beta1 import types, enums from google.api_core.exceptions import InvalidArgument, GoogleAPICallError -from util import * +from util import realpath +from config import project_id EXTRACTIVE_QA = [enums.Document.KnowledgeType.EXTRACTIVE_QA] _account = realpath('account.json') @@ -17,6 +18,30 @@ session = SessionsClient() kb = KnowledgeBasesClient() docs = DocumentsClient() + +class KnowledgeBase: + def __init__(self, id): + if isinstance(id, types.KnowledgeBase): + self._path = id.name + self.caption = id.display_name + else: + self._path = kb.knowledge_base_path(project_id, str(id)) + self.caption = kb.get_knowledge_base(self._path).display_name + + def __iter__(self): + yield from docs.list_documents(self._path) + + def create(self, caption, text=None): + if text is None: + caption, text = caption + doc = types.Document( + display_name=caption, mime_type='text/plain', + knowledge_types=EXTRACTIVE_QA, content=text) + try: + return docs.create_document(self._path, doc).result() + except (InvalidArgument, GoogleAPICallError): + res = [d for d in self if d.display_name == caption] + return res[0] if res else None class Dialogflow: def __init__(self, session_id=uuid4(), language_code='en'): @@ -26,27 +51,14 @@ def __init__(self, session_id=uuid4(), language_code='en'): self.language_code = language_code self.min_confidence = 0.8 - def init(self, name): - return kb.create_knowledge_base(self._kb, dialogflow.KnowledgeBase(display_name=name)) - - def store(self, container, title, text): - doc = dialogflow.Document( - display_name=title, mime_type='text/plain', - knowledge_types=EXTRACTIVE_QA, content=text) - try: - return docs.create_document(container, doc).result() - except (InvalidArgument, GoogleAPICallError): - res = [d for d in self.documents(container) if d.display_name == title] - return res[0] if res else None - def __call__(self, text=None, event=None): language_code = self.language_code if text is not None: - text_input = dialogflow.TextInput(text=text, language_code=language_code) - query_input = dialogflow.QueryInput(text=text_input) + text_input = types.TextInput(text=text, language_code=language_code) + query_input = types.QueryInput(text=text_input) elif event is not None: - event_input = dialogflow.EventInput(name=event, language_code=language_code) - query_input = dialogflow.QueryInput(event=event_input) + event_input = types.EventInput(name=event, language_code=language_code) + query_input = types.QueryInput(event=event_input) else: return None return session.detect_intent(session=self._session, query_input=query_input) @@ -72,11 +84,6 @@ def event(self, name, raw=False): res = self(event=name) return res if raw else res.query_result.fulfillment_text - def knowledge_bases(self): - return kb.list_knowledge_bases(self._kb) - - def documents(self, container): - name = container - if not isinstance(container, str): - name = container.name - return docs.list_documents(name) + def __iter__(self): + for item in kb.list_knowledge_bases(self._kb): + yield KnowledgeBase(item) diff --git a/document.py b/document.py new file mode 100644 index 0000000..ebb142d --- /dev/null +++ b/document.py @@ -0,0 +1,106 @@ +from urllib.error import HTTPError +from collections import namedtuple + +from util import pq, List + + +_excludes = ( + 'Recommended_Readings', + 'See_Also', + 'Residents', + 'Paraphernalia', + 'Alternate_Reality_Versions' +) + +scrape_excludes = List( + [ + *_excludes, + 'Links_and_References', + 'References', + 'Points_of_Interest', + 'Links' + ], + format=':not(#{item})', + str='' +) + +Fragment = namedtuple('Fragment', ('caption', 'text')) + +def _text(el, to_strip=None): + if el is None: + return None + return el.text().strip().strip(to_strip).strip() + +class Document: + def __init__(self, url=None, name=None, quotes=False): + if name is not None: + url = url.format(*name.split('|')) + self.name = name + else: + self.name = '|'.join([url.subdomain, url.basename]) + + self.url = str(url) + try: + doc = pq(url=self.url) + except HTTPError: + doc = pq([]) + self.caption = doc.children('head > title').text().split('|', 1)[0].strip() + self.site = doc.find('link[rel="search').attr('title').rstrip('(en)').strip() + self._doc = doc + self.__content = None + self._data = None + self._quotes = quotes + sel = List(['h3, p, ul, ol']) + if self._quotes: + sel.append('.quote') + self._sel = str(sel) + + def __bool__(self): + return bool(self._doc) + + def _content(self): + if self.__content is None: + content = self._doc.find('.mw-content-text') + content.find('.noprint, noscript, script, style, link, iframe, embed, video, img, .editsection').remove() + content.find('*').remove_attr('style') + self.__content = content + return self.__content + + def __iter__(self): + if not self: + return + + if self._data is not None: + yield from self._data + + self._data = [] + content = self._content() + content.find('.reference').remove() + if self._quotes: + for quote in content.find('.quote').items(): + author = quote.find('.selflink').closest('b') + author.closest('dl').remove() + _quote = quote.find('i') + _quote.text('"' + _text(_quote, '"\'') + '"') + author.append('said').prependTo(_quote.closest('dd')) + + h2_list = content.children(f'h2{scrape_excludes} > {scrape_excludes}').closest('h2') + for h2 in h2_list.items(): + self._append(h2.nextUntil('h2, h3', self._sel), h2) + for h3 in h2.nextUntil('h2', 'h3'): + self._append(h3.nextUntil('h2, h3', self._sel), h2, h3) + + def _append(self, body, *heads): + _data = self._data + if _data is None or not body: + return False + + caption = List((_text(h) for h in heads), str='/') + text = List((_text(e) for e in body.items()), False, str='\n') + _data.append(Fragment(f"{self.name}#{caption}", str(text))) + return True + + @staticmethod + def parse_name(name): + name, heads = name.split('#') + return name, heads.split('/') diff --git a/engine.py b/engine.py index 49024e8..fc575e1 100644 --- a/engine.py +++ b/engine.py @@ -1,23 +1,25 @@ -import os, json, re -from math import inf +import re -from walrus import Database -from bareasgi import text_reader -from redis.exceptions import ResponseError +from _bareasgi import text_reader, json_response +#from redis import StrictRedis +#from redis.exceptions import ResponseError +import rom +from rom import util as _rom, session -from util import * -from scraper import scrape, parse -from search import Search -from dialogflow import Dialogflow +from config import dialogflow as _dialogflow, redis +from document import Document +from search import search +from dialogflow import Dialogflow, KnowledgeBase -google_api = config['google_api'] -search = Search(google_api['key'], config['custom_search']['cx']) from uuid import uuid1 ping = bytes(str(uuid1()), 'utf-8') -for _redis in config['redis']: +db_error = '' +for server in redis: try: - db = Database(host=_redis['host'], port=_redis['port'], password=_redis.get('pass'), db=0) + _rom.set_connection_settings(host=server['host'], port=server['port'], password=server.get('pass'), decode_responses=True) + db = _rom.get_connection() + #db = StrictRedis(host=server['host'], port=server['port'], password=server.get('pass'), db=0, decode_responses=True) if db is None: db_error = '' continue @@ -29,28 +31,56 @@ pass if db_error: - print('[WARN] Redis connection failed: ', db_error) + print('[WARN] Redis connection failed:', db_error) db = None else: - print('[INFO] Redis connection: ', db.connection_pool.connection_kwargs['host']) - url_db = db.Hash('urls') + print('[INFO] Redis connection:', db.connection_pool.connection_kwargs['host']) + +class _Fragment(rom.Model): + path = rom.String(required=True, unique=True) + name = rom.String(required=True, unique=True) + document = rom.ManyToOne('_Document', required=True, on_delete='no action') + +class _Document(rom.Model): + name = rom.String(required=True, unique=True) + url = rom.String(unique=True) + caption = rom.String(required=True) + site = rom.String(default='') + fragments = rom.OneToMany('_Fragment') dialogflow = Dialogflow() +fandom = KnowledgeBase(_dialogflow['fandom']) +_url = fandom.caption + +# TODO: Delete database entries not present in Fandom KB +docs = {} +for fragment in fandom: + if not _Fragment.get_by(path=fragment.name): + name, heads = Document.parse_name(fragment.display_name) + docs.setdefault(name, set()).add(fragment) -kb_dict = {} -for kb in dialogflow.knowledge_bases(): - kb_dict[kb.display_name] = kb.name +for name, fragments in docs.items(): + _doc = _Document.get_by(name=name) + if not _doc: + doc = Document(_url, name) + _doc = _Document(name=name, url=doc.url, caption=doc.caption, site=doc.site) + for fragment in fragments: + _fragment = _Fragment(path=fragment.name, name=fragment.display_name, document=_doc) +session.flush() +sites = {} async def knowledge(scope, info, matches, content): + for _doc in _Document.query.all(): + sites.setdefault(_doc.site, {}).setdefault(_doc.url, _doc.caption) + res = [] - for name, path in kb_dict.items(): - docs = [] - for doc in dialogflow.documents(path): - docs.append({ 'caption': doc.display_name, 'url': url_db.get(doc.name, b'').decode() or None }) - res.append({ 'caption': name, 'documents': sorted(docs, key=lambda e: e['caption']) }) + for site, docs in sites.items(): + _docs = ({'caption': caption, 'url': url} for url, caption in docs.items()) + res.append({ 'caption': site, 'documents': sorted(_docs, key=lambda e: e['caption']) }) + return json_response(sorted(res, key=lambda e: e['caption'])) -async def respond(scope, info, matches, content): +async def message(scope, info, matches, content): text = re.sub(r'\s+', ' ', (await text_reader(content)).strip().lstrip('.').strip()) if text == '': return json_response([dialogflow.event('WELCOME')]) @@ -63,37 +93,28 @@ async def respond(scope, info, matches, content): if not query: return 400 - docs = str_set() - for e in search(query)[:10]: - if e['url'] in db: - path = db[e['url']] - docs.add(path) - url_db[path] = e['url'] - else: - print('[INFO] Generating document:', e['url']) - doc = parse(url=e['url']) - if doc is None: - print('[WARN] URL request failed:', e['url']) + urls = set() + for url in search(query)[:10]: + if not _Document.get_by(url=str(url)): + doc = Document(url) + print('[INFO] Generating document:', doc.name) + if not doc: + print('[WARN] URL request failed:', doc.url) continue - if doc.site not in kb_dict: - print('[WARN] Missing knowledge base:', doc.site) - continue - #db[e['domain']] = dialogflow.init(doc.site).name - res = dialogflow.store(kb_dict[doc.site], doc.title, scrape(doc)) - if res is None: - print('[WARN] Document creation failed:', e['url']) - continue - print('[INFO] Document created:', res.name) - db[e['url']] = res.name - url_db[res.name] = e['url'] - docs.add(res.name) + _doc = _Document(name=doc.name, url=doc.url, caption=doc.caption, site=doc.site) + for fragment in doc: + res = fandom.create(fragment) + if res is None: + print('[WARN] Fragment creation failed:', fragment.caption) + continue + _fragment = _Fragment(path=res.name, name=fragment.caption, document=_doc) + print('[INFO] Document created:', doc.name) - try: - db.bgsave() - except ResponseError: - print("[WARN] Redis: command 'BGSAVE' failed") + urls.add(str(url)) + + session.flush() - answers = dialogflow.get_answers(text, filter=lambda a: a.source in docs) + answers = dialogflow.get_answers(text, filter=lambda a: _Fragment.get_by(path=a.source).document.url in urls) if answers: return json_response(answers) diff --git a/main.py b/main.py index 272502f..7ded84a 100644 --- a/main.py +++ b/main.py @@ -4,16 +4,16 @@ import os import uvicorn -from bareasgi import Application +from _bareasgi import Application from bareasgi_static import add_static_file_provider -from engine import respond, knowledge +from engine import knowledge, message here = os.path.abspath(os.path.dirname(__file__)) app = Application() app.http_router.add({'GET'}, '/knowledge', knowledge) -app.http_router.add({'POST'}, '/respond', respond) +app.http_router.add({'POST'}, '/message', message) add_static_file_provider(app, os.path.join(here, 'static'), index_filename='index.html') diff --git a/requirements.txt b/requirements.txt index f8162d8..2a6f974 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,12 @@ uvicorn bareasgi bareasgi_static -walrus pyquery google-api-python-client dialogflow PyYAML aiofiles +redis +rom +pywsd +numpy diff --git a/scraper.py b/scraper.py deleted file mode 100644 index f8bec5d..0000000 --- a/scraper.py +++ /dev/null @@ -1,94 +0,0 @@ -from urllib.error import HTTPError - -from pyquery import PyQuery as pq - -from util import * - -class Selector(tuple): - def __new__(cls, *data): - return tuple.__new__(cls, data) - - def __str__(self): - return self._str - - def replace(self, *args): - return str(self).replace(*args) - -class Inclusion(Selector): - def __init__(self, *data): - self._str = ', '.join(f'#{e}' for e in data) - -class Exclusion(Selector): - def __init__(self, *data): - self._str = ''.join(f':not(#{e})' for e in data) - -_excludes = Exclusion( - 'Recommended_Readings', - 'See_Also', - 'Residents', -) - -scrape_excludes = Exclusion( - *_excludes, - 'Links_and_References', - 'References', - 'Points_of_Interest', - 'Links' -) - -with_h3 = Inclusion( - 'Paraphernalia', - 'Alternate_Reality_Versions' -) - -@attach(pq.fn) -def content(): - if not hasattr(this, '_content'): - this._content = this.find('.mw-content-text') - this._content.find('.noprint, noscript, script, style, link, iframe, embed, video, img, .editsection').remove() - this._content.find('*').remove_attr('style') - return this._content - -def parse(**kw): - key, val = next(iter(kw.items())) - kw = { key: make_url(val) if key is 'url' else val } - try: - doc = pq(**kw) - _title = doc.children('head > title').text() - title = _title.split('|', 2) - doc.title, doc.site = (e.strip() for e in title[:2]) - return doc - except HTTPError: - pass - except ValueError: - print('[WARN] Could not determine site:', _title) - doc.title, doc.site = title.strip(), '' - return doc - -def _scrape(doc): - content = doc.content() - for ref in content.find('.reference').items(): - ref.attr('data-ref', ref.text().strip('[]')) - ref.text('') - for quote in content.find('.quote').items(): - author = quote.find('.selflink').closest('b') - author.closest('dl').remove() - text = quote.find('i') - text.text('"' + text.text().strip('"\'').strip() + '"') - author.append('said').prependTo(text.closest('dd')); - sections = [] - headlines = content.children(f'h2{scrape_excludes} > {scrape_excludes}').closest('h2') - for h in headlines.items(): - sel = 'p, ul, ol, .quote' - sel = f'h3, {sel}' if h.children(with_h3) else sel - h.body = h.nextUntil('h2', sel) - sections.append(h) - doc.sections = sections - return doc - -def scrape(doc): - _scrape(doc) - res = [] - for section in doc.sections: - res.extend(e.text().strip() for e in section.body.items()) - return '\n'.join(s for s in res if s) diff --git a/search.py b/search.py index 12345b3..d9fa29e 100644 --- a/search.py +++ b/search.py @@ -3,37 +3,29 @@ from googleapiclient.discovery import build -from util import * +from util import URL +from config import custom_search, google_api -site_priority = dict([e, i] for i, e in enumerate([ - 'marvel', 'ironman', 'dc', 'batman', 'arkhamcity', 'agentsofshield', 'marvelcinematicuniverse', 'marvel-movies', 'dcau', 'gotham' +site_priority = dict((e, i) for i, e in enumerate([ + 'marvel', 'ironman', 'xmen', 'dc', 'batman', 'arkham', 'dcau', 'heroes', 'arkhamcity', 'agentsofshield', 'marvelcinematicuniverse', 'marvel-movies' ])) -class Search: - def __init__(self, key, cx): - self.service = build('customsearch', 'v1', developerKey=key) - self.cx = cx - self.cse = self.service.cse() +_service = build('customsearch', 'v1', developerKey=google_api['key']) +_cx = custom_search['cx'] +_cse = _service.cse() - def __call__(self, query, limit=30): - num_pages = ceil(limit / 10) - res = [] - for page in range(num_pages): - out = self.cse.list(q=query, cx=self.cx, num=10, start=1+page*10).execute() - for e in out['items']: - url = make_url(e['formattedUrl']) - basename = os.path.basename(urlsplit(url).path) - m = re.match(r'^.+?\(Earth-(\d+)\)$', basename) - if m and m[1] != '616': - continue - domain = e['displayLink'] - site = domain.rstrip('.fandom.com') - m = re.match(r'^.+?\(.+?\)$', basename) - if m and site == 'batman': - continue - res.append({ - 'url': url, - 'domain': domain, - 'priority': site_priority.get(domain.rstrip('.fandom.com'), inf) - }) - return sorted(res, key=lambda e: e['priority']) +def search(query, limit=30): + num_pages = ceil(limit / 10) + res = [] + for page in range(num_pages): + out = _cse.list(q=query, cx=_cx, num=10, start=1+page*10).execute() + for e in out['items']: + url = URL(e['formattedUrl'], 'query') + m = re.match(r'^.+?\(Earth-(\d+)\)$', url.basename) + if m and m[1] != '616': + continue + m = re.match(r'^(.+?)_\(.+?\)$', url.basename) + if m and url.subdomain == 'batman': + continue + res.append(url) + return sorted(res, key=lambda e: site_priority.get(e.subdomain, inf)) diff --git a/semantic_similarity.py b/semantic_similarity.py new file mode 100644 index 0000000..04b3d74 --- /dev/null +++ b/semantic_similarity.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +from pywsd import disambiguate +#from pywsd.similarity import max_similarity as maxsim +import numpy as np +from collections import defaultdict +alpha = 0.2 +beta = 0.45 +benchmark_similarity = 0.8025 +gamma = 1.8 +""" +Semantic similarity based on the paper: + Calculating the similarity between words and sentences using a lexical database and corpus statistics +TKDE, 2018 +""" + +def _synset_similarity(s1,s2): + L1 =dict() + L2 =defaultdict(list) + + for syn1 in s1: + L1[syn1[0]] =list() + for syn2 in s2: + + subsumer = syn1[1].lowest_common_hypernyms(syn2[1], simulate_root=True)[0] + h =subsumer.max_depth() + 1 # as done on NLTK wordnet + syn1_dist_subsumer = syn1[1].shortest_path_distance(subsumer,simulate_root =True) + syn2_dist_subsumer = syn2[1].shortest_path_distance(subsumer,simulate_root =True) + l =syn1_dist_subsumer + syn2_dist_subsumer + f1 = np.exp(-alpha*l) + a = np.exp(beta*h) + b = np.exp(-beta*h) + f2 = (a-b) /(a+b) + sim = f1*f2 + L1[syn1[0]].append(sim) + L2[syn2[0]].append(sim) + return L1, L2 + +def similarity(s1,s2): + wsd = ( + [syn for syn in disambiguate(s) if syn[1]] + for s in (s1, s2) + ) + + #vector_length = max(len(s1_wsd), len(s2_wsd)) + + L = _synset_similarity(*wsd) + V1, V2 = ( + np.array([max(e[key]) for key in e.keys()]) + for e in L + ) + S = np.linalg.norm(V1)*np.linalg.norm(V2) + C1, C2 = ( + sum(V >= benchmark_similarity) + for V in (V1, V2) + ) + + Xi = (C1+C2) / gamma + + if C1+C2 == 0: + Xi = max(V1.size, V2.size) / 2 + return S/Xi diff --git a/util.py b/util.py deleted file mode 100644 index eed1af3..0000000 --- a/util.py +++ /dev/null @@ -1,129 +0,0 @@ -from urllib.parse import urlsplit, urlunsplit, urljoin - -import yaml, os, json -from pyquery import PyQuery as pq -import bareasgi - -class str_dict(dict): - def __setitem__(self, key, value): - if isinstance(key, bytes): - key = key.decode() - return super().__setitem__(key, value) - -class str_set(set): - def add(self, val): - if isinstance(val, bytes): - val = val.decode() - return super().add(val) - -def json_response(data, status=200, headers={}): - headers = [] # FIXME - return bareasgi.json_response(status, headers, data) - -from platform import python_version_tuple as get_pyversion -_pyversion = tuple(int(e) for e in get_pyversion()) - -if _pyversion >= (3, 6, 0): - OrderedDict = dict -else: - from collections import OrderedDict - -__dir__ = os.path.dirname(os.path.realpath(__file__)) -def realpath(path): - return os.path.join(__dir__, path) - -with open(realpath('config.yaml')) as f: - config = yaml.load(f, Loader=yaml.SafeLoader) - -project_id = os.getenv('GOOGLE_CLOUD_PROJECT', None) -if not project_id: - with open(realpath('account.json')) as f: - _account = json.load(f) - project_id = _account['project_id'] - -def make_url(url, base=None): - if base: - url = urljoin(base, url) - return urlunsplit((*urlsplit(url)[:3], '', '')) - -class OrderedSet: - def __init__(self, src=None): - super().__init__() - self._data = OrderedDict() - - if src: - self.update(src) - - def update(self, *sources): - for src in sources: - for item in src: - self.add(item) - return self - - def difference_update(self, src): - for src in sources: - for item in src: - self.discard(item) - return self - - def add(self, item): - if item in self: - return False - self._data[item] = None - return True - - def __contains__(self, item): - return item in self._data - - def discard(self, item): - if item in self: - return False - del self._data[item] - return True - - def remove(self, item): - if not self.discard(item): - raise KeyError - - def __iter__(self): - yield from self._data - - def clear(self): - if not self: - return False - self._data.clear() - return True - - def __bool__(self): - return bool(self._data) - - def __repr__(self): - return f"{{{', '.join(self)}}}" - - def __len__(self): - return len(self._data) - - def __getitem__(self, index): - res = list(self) - if index == slice(None): - return res - return res[index] - -def attach(target): - def deco(func): - setattr(target, func.__name__, func) - return func - return deco - -@attach(pq.fn) -def nextUntil(sel, filter=None): - res = OrderedSet() - for node in this.items(): - while True: - node = node.next() - if node.is_(sel) or not node: - break - if node.is_(filter): - res.add(node[0]) - - return pq(res[:]) diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..f6609d4 --- /dev/null +++ b/util/__init__.py @@ -0,0 +1,43 @@ +import os +from platform import python_version_tuple as get_pyversion + +from pyquery import PyQuery as pq + +pyversion = tuple(int(e) for e in get_pyversion()) +if pyversion >= (3, 6, 0): + OrderedDict = dict +else: + from collections import OrderedDict + +__dir__ = os.path.normpath(os.path.dirname(os.path.realpath(__file__)) + '/..') + +def realpath(path): + return os.path.join(__dir__, path) + +def new(cls, *args, **kw): + if not isinstance(cls, type): + cls = type(cls) + return cls(*args, **kw) + +from .set import Set, OrderedSet +from .list import Tuple, List +from .url import URL + +def attach(target): + def deco(func): + setattr(target, func.__name__, func) + return func + return deco + +@attach(pq.fn) +def nextUntil(sel, filter=None): + res = OrderedSet() + for node in this.items(): + while True: + node = node.next() + if node.is_(sel) or not node: + break + if node.is_(filter): + res.add(node[0]) + + return pq(res[:]) diff --git a/util/list.py b/util/list.py new file mode 100644 index 0000000..41fbb59 --- /dev/null +++ b/util/list.py @@ -0,0 +1,66 @@ +from util import new, Set + + +class _List: + def __add__(self, item): + return new(self, [*self, item]) + + def __radd__(self, item): + return new(self, [item, *self]) + + def __and__(self, other): + return new(self, [*self, *other]) + + def __rand__(self, other): + return new(self, [*other, *self]) + + def __mul__(self, other): + return new(self, zip(self, other)) + + def __rmul__(self, other): + return new(self, zip(other, self)) + + def join(self, sep=''): + return sep.join(self) + + def range(self): + return range(len(self)) + + def get(self, index, default=None): + return self[index] if index in self.range() else default + + def index(self, item, default=None): + try: + return super().index(item) + except ValueError: + return default + +class Tuple(_List, tuple): + pass + +class List(_List, list): + def __init__(self, src=None, banned=None, **kw): + super().__init__() + self.__banned = Set(banned) + self.__format = kw.get('format', '{item}') + self._str = kw.get('str', ', ') + self.extend(src) + + def _format(self, item): + return self.__format.format(item=str(item)) + + def __str__(self): + return self._str.join(self._format(e) for e in self) + + def _banned(self, kw): + return self.__banned | kw.get('banned') + + def append(self, item, **kw): + if item not in self._banned(kw): + super().append(item) + + def extend(self, other, **kw): + if other is None: + return False + super().extend(e for e in other if e not in self._banned(kw)) + return True diff --git a/util/set.py b/util/set.py new file mode 100644 index 0000000..d35ce36 --- /dev/null +++ b/util/set.py @@ -0,0 +1,103 @@ +from util import OrderedDict + + +def _check(value): + if isinstance(value, bool): + return True + try: + iter(value) + return True + except TypeError: + return False + +class Set(set): + def __init__(self, src): + super().__init__() + self._bool = None + self.update(src) + + def clone(self): + res = Set(self) + res._bool = self._bool + return res + + def union(self, other): + if _check(other): + return self.clone().update(other) + return self + + def __contains__(self, item): + if super().__contains__(item): + return True + return self._bool is bool(item) + + def update(self, src): + if isinstance(src, bool): + self._bool = src + return True + if _check(src): + super().update(src) + return True + return False + +class OrderedSet: + def __init__(self, src=None): + super().__init__() + self._data = OrderedDict() + + if src: + self.update(src) + + def update(self, *sources): + for src in sources: + for item in src: + self.add(item) + return self + + def difference_update(self, src): + for item in src: + self.discard(item) + return self + + def add(self, item): + if item in self: + return False + self._data[item] = None + return True + + def __contains__(self, item): + return item in self._data + + def discard(self, item): + if item in self: + return False + del self._data[item] + return True + + def remove(self, item): + if not self.discard(item): + raise KeyError + + def __iter__(self): + yield from self._data + + def clear(self): + if not self: + return False + self._data.clear() + return True + + def __bool__(self): + return bool(self._data) + + def __repr__(self): + return f"{{{', '.join(self)}}}" + + def __len__(self): + return len(self._data) + + def __getitem__(self, index): + res = list(self) + if index == slice(None): + return res + return res[index] diff --git a/util/url.py b/util/url.py new file mode 100644 index 0000000..8e60c9e --- /dev/null +++ b/util/url.py @@ -0,0 +1,84 @@ +import os +from urllib.parse import urlsplit, urlunsplit, SplitResult + +from util import Tuple + + +class URL(object): + class _Dict(dict): + _extra = { 'basename', 'subdomain' } + + def regular(self, key): + return key in self and key not in self._extra + + def __setitem__(self, key, value): + if key is 'path': + super().__setitem__('basename', os.path.basename(value) if value else '') + elif key is 'netloc': + domain = value.rsplit('.', 2) + super().__setitem__('subdomain', domain[0] if len(domain) is 3 else '') + + return super().__setitem__(key, value) + + _keys = Tuple(SplitResult._fields) + basename = None + subdomain = None + + def __init__(self, url=None, cut=None): + self.__dict__ = self._Dict() + _keys = self._keys + if isinstance(url, str): + data = _keys * urlsplit(url) + elif isinstance(url, dict): + data = url + elif url is not None: + data = _keys * url + else: + data = [] + + index = len(_keys) + if isinstance(cut, str): + index = _keys.index(cut, index) + elif isinstance(cut, int): + index = cut + + _dict = dict(data) + for k in _keys[:index]: + self._set(k, _dict.get(k, '')) + for k in _keys[index:]: + self._set(k) + + def _set(self, key, value=''): + self.__dict__[key] = value + + def _regular(self, key): + return self.__dict__.regular(key) + + def _key(self, key): + return key if isinstance(key, str) else self._keys.get(key) + + def __getitem__(self, key): + return self.__dict__[self._key(key)] + + def __setitem__(self, key, value): + key = self._key(key) + if self._regular(key): + self._set(key, value) + + def __len__(self): + return len(self._keys) + + def __iter__(self): + for k in self._keys: + yield self[k] + + def __str__(self): + return urlunsplit(self) + + def format(self, *args, **kw): + return str(self).format(*args, **kw) + + def __repr__(self): + _name = self.__class__.__name__ + data = (f"{k}='{self[k]}'" for k in self._keys) + return f"{_name}({', '.join(data)})"