diff --git a/requirements.txt b/requirements.txt
index 4982de1..dd6071f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
-tqdm==4.26.0
-pytorch_pretrained_bert==0.6.1
-torch==1.0.0
-numpy==1.15.1
+tqdm>=4.27
+transformers==4.8.1
+torch==1.9.0
+numpy>=1.17
 spacy==2.1.3
 scipy==1.1.0
 scikit_learn==0.21.1
diff --git a/wsi/WSISettings.py b/wsi/WSISettings.py
index c9c14f7..9b7d951 100644
--- a/wsi/WSISettings.py
+++ b/wsi/WSISettings.py
@@ -2,7 +2,7 @@
 
 WSISettings = namedtuple('WSISettings', ['n_represents', 'n_samples_per_rep',
                                          'cuda_device', 'debug_dir', 'disable_tfidf',
                                          'disable_lemmatization', 'run_name', 'patterns',
-                                         'min_sense_instances', 'bert_model',
+                                         'min_sense_instances', 'bert_model', 'spacy_lang',
                                          'max_batch_size', 'prediction_cutoff', 'max_number_senses',
                                          ])
@@ -26,6 +26,7 @@
     # sense clusters that dominate less than this number of samples
     # would be remapped to their closest big sense
 
+    spacy_lang="en",
    max_batch_size=10,
    prediction_cutoff=200,
    bert_model='bert-large-uncased'
diff --git a/wsi/lm_bert.py b/wsi/lm_bert.py
index ba6e2e4..5186ba5 100644
--- a/wsi/lm_bert.py
+++ b/wsi/lm_bert.py
@@ -1,6 +1,6 @@
 from .slm_interface import SLM
 import multiprocessing
-from pytorch_pretrained_bert import BertForMaskedLM, tokenization
+from transformers import BertForMaskedLM, BertTokenizer
 import torch
 import numpy as np
 from tqdm import tqdm
@@ -26,7 +26,7 @@ def get_batches(from_iter, group_size):
 
 
 class LMBert(SLM):
-    def __init__(self, cuda_device, bert_model, max_batch_size=20):
+    def __init__(self, cuda_device, bert_model, spacy_lang="en", max_batch_size=20):
         super().__init__()
         logging.info(
             'creating bert in device %d. bert ath %s'
@@ -43,7 +43,7 @@ def __init__(self, cuda_device, bert_model, max_batch_size=20):
         model.eval()
         self.bert = model
 
-        self.tokenizer = tokenization.BertTokenizer.from_pretrained(bert_model)
+        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
 
         self.max_sent_len = model.config.max_position_embeddings
         # self.max_sent_len = config.max_position_embeddings
@@ -54,7 +54,7 @@ def __init__(self, cuda_device, bert_model, max_batch_size=20):
         self.original_vocab = []
 
         import spacy
-        nlp = spacy.load("en", disable=['ner', 'parser'])
+        nlp = spacy.load(spacy_lang, disable=['ner', 'parser'])
         self._lemmas_cache = {}
         self._spacy = nlp
         for spacyed in tqdm(
@@ -141,7 +141,7 @@ def predict_sent_substitute_representatives(self, inst_id_to_sentence: Dict[str,
 
             torch_mask = torch_input_ids != 0
-            logits_all_tokens = self.bert(torch_input_ids, attention_mask=torch_mask)
+            logits_all_tokens = self.bert(torch_input_ids, attention_mask=torch_mask).logits
 
             logits_target_tokens = torch.zeros((len(batch_sents), logits_all_tokens.shape[2])).to(self.device)
             for i in range(0, len(batch_sents)):
diff --git a/wsi/wsi.py b/wsi/wsi.py
index d79c672..2352c5b 100644
--- a/wsi/wsi.py
+++ b/wsi/wsi.py
@@ -124,3 +124,58 @@ def run(self, wsisettings: WSISettings,
         print(msg)
 
         return scores2010['all'], scores2013['all']
+
+
+from typing import Dict, List, Tuple
+
+
+def perform_wsi_on_ds_gen(
+        lm: SLM,
+        ds_name: str,
+        gen: List[Tuple[str, str, str, str]],
+        wsisettings: WSISettings,
+        print_progress=False,
+) -> Dict[str, Dict[str, int]]:
+
+    ds_by_target = defaultdict(dict)
+    for pre, target, post, inst_id in gen:
+        lemma_pos = inst_id.rsplit('.', 1)[0]
+        ds_by_target[lemma_pos][inst_id] = (pre, target, post)
+
+    inst_id_to_sense = {}
+    gen = ds_by_target.items()
+    if print_progress:
+        gen = tqdm(gen, desc=f'predicting substitutes {ds_name}')
+    for lemma_pos, inst_id_to_sentence in gen:
+        inst_ids_to_representatives = \
+            lm.predict_sent_substitute_representatives(
+                inst_id_to_sentence=inst_id_to_sentence,
+                wsisettings=wsisettings,
+            )
+
+        clusters, statistics = cluster_inst_ids_representatives(
+            inst_ids_to_representatives=inst_ids_to_representatives,
+            max_number_senses=wsisettings.max_number_senses,
+            min_sense_instances=wsisettings.min_sense_instances,
+            disable_tfidf=wsisettings.disable_tfidf,
+            explain_features=True,
+        )
+        inst_id_to_sense.update(clusters)
+        if statistics:
+            logging.info('Sense cluster statistics:')
+            for idx, (rep_count, best_features, best_features_pmi, best_instance_id) in enumerate(statistics):
+                best_instance = ds_by_target[lemma_pos][best_instance_id]
+                nice_print_instance = f'{best_instance[0]} -{best_instance[1]}- {best_instance[2]}'
+                logging.info(
+                    f'Sense {idx}, # reps: {rep_count}, best feature words: {", ".join(best_features)},'
+                    f' best feature words (PMI): {", ".join(best_features_pmi)},'
+                    f' closest instance ({best_instance_id}):\n---\n{nice_print_instance}\n---\n')
+
+    out_key_path = None
+    if wsisettings.debug_dir:
+        out_key_path = os.path.join(wsisettings.debug_dir, f'{wsisettings.run_name}-{ds_name}.key')
+
+    if print_progress and out_key_path:
+        print(f'writing {ds_name} key file to {out_key_path}')
+
+    return inst_id_to_sense
diff --git a/wsi_bert.py b/wsi_bert.py
index b1f364a..a799145 100644
--- a/wsi_bert.py
+++ b/wsi_bert.py
@@ -34,6 +34,7 @@
 startmsg = startmsg.strip()
 
 lm = LMBert(settings.cuda_device, settings.bert_model,
+            spacy_lang=settings.spacy_lang,
             max_batch_size=settings.max_batch_size)
 
 if settings.debug_dir:
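
Usage note (not part of the patch): the sketch below shows how the new `spacy_lang` setting and the `perform_wsi_on_ds_gen` entry point might be wired together. It assumes the default settings instance in wsi/WSISettings.py is exported as `DEFAULT_PARAMS` (that name is not visible in this diff) and that instance ids follow the `<lemma>.<pos>.<number>` convention implied by `inst_id.rsplit('.', 1)`; the toy sentences and ids are illustrative only.

    # Hedged sketch, not part of the patch. Assumes wsi/WSISettings.py exposes
    # DEFAULT_PARAMS (a WSISettings instance); that name is not shown in this diff.
    from wsi.WSISettings import DEFAULT_PARAMS
    from wsi.lm_bert import LMBert
    from wsi.wsi import perform_wsi_on_ds_gen

    # namedtuple._replace returns a copy with the new spacy_lang field overridden
    settings = DEFAULT_PARAMS._replace(spacy_lang='en')

    lm = LMBert(settings.cuda_device, settings.bert_model,
                spacy_lang=settings.spacy_lang,
                max_batch_size=settings.max_batch_size)

    # Each record is (pre context, target word, post context, instance id);
    # ids sharing a '<lemma>.<pos>' prefix are grouped and clustered together.
    toy_gen = [
        ('I deposited the check at the', 'bank', 'near our office.', 'bank.n.1'),
        ('We sat on the grassy', 'bank', 'of the river.', 'bank.n.2'),
    ]

    inst_id_to_sense = perform_wsi_on_ds_gen(
        lm, 'toy-ds', toy_gen, settings, print_progress=True)
    print(inst_id_to_sense)  # e.g. {'bank.n.1': {...}, 'bank.n.2': {...}}

Threading `spacy_lang` through WSISettings rather than hard-coding "en" is what lets the same pipeline run with any spaCy model whose language matches the chosen BERT checkpoint.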