diff --git a/README.md b/README.md index 1f5cb0b..041788c 100755 --- a/README.md +++ b/README.md @@ -20,9 +20,6 @@ Transform texts in a hundred different [languages](https://github.com/artitw/tex * [Index](https://github.com/artitw/text2text#index) * [Distance](https://github.com/artitw/text2text#levenshtein-sub-word-edit-distance) * [Translation](https://github.com/artitw/text2text#translation) - * [Question Answering](https://github.com/artitw/text2text#question-answering) - * [Question Generation](https://github.com/artitw/text2text#question-generation) - * [Summarization](https://github.com/artitw/text2text#summarization) * [Data Augmentation](https://github.com/artitw/text2text#data-augmentation--back-translation) * [Finetuning](https://github.com/artitw/text2text#training--finetuning) * [Identification](https://github.com/artitw/text2text#identification) @@ -60,10 +57,7 @@ Module Importing | `import text2text as t2t` | Libraries imported [BM25](https://github.com/artitw/text2text#bm25) | `t2t.Bm25er().transform(["Hello, World!"])` | `[{'!': 0.3068528194400547, ',': 0.3068528194400547, '▁Hello': 0.3068528194400547, '▁World': 0.3068528194400547}]` [Indexer](https://github.com/artitw/text2text#index) | `index = t2t.Indexer().transform(["Hello, World!"])` | Index object for information retrieval [Translation](https://github.com/artitw/text2text#translation) | `t2t.Translater().transform(["Hello, World!"], src_lang="en, tgt_lang="zh")` | `['你好,世界!']` -[Question Generation](https://github.com/artitw/text2text#question-generation) | `t2t.Questioner().transform(["Hello, World!"], src_lang="en)` | `[('What is the name of the world you are in?', 'The world')]` -[Summarization](https://github.com/artitw/text2text#summarization) | `t2t.Summarizer().transform(["Hello, World!"], src_lang="en)` | `["World ' s largest world"]` [Data Augmentation](https://github.com/artitw/text2text#data-augmentation--back-translation) | `t2t.Variator().transform(["Hello, World!"], src_lang="en)` | `['Hello the world!', 'Welcome to the world.', 'Hello to the world!',...` -[Question Answering](https://github.com/artitw/text2text#question-answering) | `t2t.Answerer().transform(["Hello, World! [SEP] Hello, what?"], src_lang="en")` | `['World']` [Distance](https://github.com/artitw/text2text#levenshtein-sub-word-edit-distance) | `t2t.Measurer().transform(["Hello, World! [SEP] Hello, what?"])` | `[2]` [Training/Finetuning](https://github.com/artitw/text2text#training--finetuning) | `t2t.Fitter().transform(["Hello, World! [TGT] Hello, what?"])` | Finetuned model saved [Identification](https://github.com/artitw/text2text#identification) | `t2t.Identifier().transform(["Aj keď sa Buzz Aldrin stal až „druhým človekom“..."])` | `['sk', 'Slovak']` @@ -225,7 +219,7 @@ class Song(BaseModel): result = asst.chat_completion([ {"role": "user", "content": "What is Britney Spears's best song?"} -], schema=Song) +], schema=Song) # Song(name='Toxic', artist='Britney Spears') # Embeddings @@ -456,76 +450,6 @@ t2t.Translator().transform( -### Question Answering -Question must follow context with ` [SEP] ` in between. -``` -t2t.Answerer().transform([ - "Hello, this is Text2Text! [SEP] What is this?", - "It works very well. It's awesome! [SEP] How is it?" -]) - -t2t.Answerer().transform([ - "很喜欢陈慧琳唱歌。[SEP] 喜欢做什么?" -], src_lang="zh") - -# Answers -['Text2Text', 'awesome'] -['唱歌'] -``` - -### Question Generation -``` -t2t.Questioner().transform(["很喜欢陈慧琳唱歌。"], src_lang='zh') -t2t.Questioner().transform([ - bio_str, - bio_str, - bio_str, - bio_str, - bio_str, - "I will go to school today to take my math exam.", - "I will go to school today to take my math exam.", - "Tomorrow is my cousin's birthday. He will turn 24 years old.", - notre_dame_str, - bacteria_str, - bacteria_str, - bacteria_str, - "I will go to school today to take my math exam. [SEP] school", - "I will go to school today to take my math exam. [SEP] exam", - "I will go to school today to take my math exam. [SEP] math", -], src_lang='en') - -``` -Note that the last three answers were controlled by specifying the `[SEP]` token in the input above. -``` -# Questions -[('我喜欢做什么?', '唱歌')] -[('What is biology the science that studies?', 'life'), - ('What is the study of life?', 'studies'), - ('What would you find the question " life "?', 'sound'), - ('What can viruses do to living organisms?', 'attack'), - ('What is the study of life?', 'studies'), - ('Where will I go to to take my math exam?', 'school'), - ('Where will I go to to take my math exam?', 'school'), - ("What will my cousin's birthday?", 'turn'), - ('What type of oversight does The Observer not have?', 'editorial'), - ('What shape can bacteria be found in?', 'rods'), - ('What is the typical length of bacteria?', 'micrometres'), - ('What is the typical length of bacteria?', 'micrometres'), - ('Where will I go to to take my math exam?', 'school'), - ('What will I take after school?', 'exam'), - ('What exam will I take?', 'math')] -``` - -### Summarization -``` -t2t.Summarizer().transform([notre_dame_str, bacteria_str, bio_str], src_lang='en') - -# Summaries -["Notre Dame's students run nine student - run outlets . [X_SEP] Scholastic magazine claims to be the oldest continuous collegiate publication in the United States . [X_SEP] The Observer is an independent publication .", - 'Bacteria were among the first life forms to appear on Earth .', - 'biology is the science that studies life .'] -``` - ### Data Augmentation / Back-Translation Back-translations useful for augmenting training data ``` diff --git a/demos/Text2Text_Demos.ipynb b/demos/Text2Text_Demos.ipynb index e988776..6bcd3a4 100644 --- a/demos/Text2Text_Demos.ipynb +++ b/demos/Text2Text_Demos.ipynb @@ -531,99 +531,6 @@ "execution_count": null, "outputs": [] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jFWWM9jFTAat", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "fc7c97d6-2b04-4e38-abaa-6a349d66b68a" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "100%|██████████| 213450/213450 [00:00<00:00, 321640.64B/s]\n", - "100%|██████████| 1242874899/1242874899 [01:49<00:00, 11337284.32B/s]\n", - "100%|██████████| 1/1 [00:03<00:00, 3.80s/it]\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "['Hello , World! Hello , World! Hello , World! Hello , World! Hello , World! Hello , World! Hello , World! Hello , World! Hello , World! Hello , World! Hello , World! Hello , World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World! World!']" - ] - }, - "metadata": {}, - "execution_count": 9 - } - ], - "source": [ - "summarizer = t2t.Summarizer()\n", - "summarizer.transform([\"Hello, World!\"], src_lang=\"en\") #[\"World ' s largest world\"]" - ] - }, - { - "cell_type": "code", - "source": [ - "del summarizer" - ], - "metadata": { - "id": "wqnjVpig5mps" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "T6bEcoNCTBt7", - "outputId": "e8b58a21-ff9f-4bae-c381-0340b8107ad8" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "100%|██████████| 1/1 [00:01<00:00, 1.41s/it]\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "[('What is the name of the Hello , World!', 'world')]" - ] - }, - "metadata": {}, - "execution_count": 10 - } - ], - "source": [ - "questioner = t2t.Questioner()\n", - "questioner.transform([\"Hello, World!\"], src_lang=\"en\")\n", - "#[('What is the name of the world you are in?', 'The world')]" - ] - }, - { - "cell_type": "code", - "source": [ - "del questioner" - ], - "metadata": { - "id": "HYQToFAP5nuu" - }, - "execution_count": null, - "outputs": [] - }, { "cell_type": "code", "execution_count": null, @@ -763,203 +670,6 @@ "execution_count": null, "outputs": [] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "referenced_widgets": [ - "a6e4771ecffb42ec81d24699e1a7d42d", - "580822060dea4a0180e47f8ddf8f1cbe", - "667b64ddc26c48629f1edc2a67171ced", - "d382ff8acf544f9aa2e50ba95942c6d7", - "fb9751542f8448a787537ed4c79fd6f0", - "806f718f6d6d461e8f1165b5afe94c52", - "87bffb9eb0c6449e8117c600e9091c16", - "3a6fcef9dfd04d0e8c34d1d4e2d03826", - "168fe0de904b4396921f0b7235b0f5e3", - "685b31392dae473296125e42323f5fda", - "dd42ad5f3fea4467868b6d94a1bbc8c6", - "e3a63d656ce5487d89eea37326033140", - "d01e01c16d3945fda05d2551f82762a0", - "1a3f8943ac2e436f8d31bc27fc952d6d", - "07fb63fe75b648dba86200779dca502e", - "c20f4c32680c497aafb12dc81125711c", - "fc6ccd7d0ab94bfca8dd9af6ced8b0e2", - "4a7710e1baf14a47a8c3e81130802e4c", - "7457df1e20fa4c24ac50c3b3049158e5", - "79887031de6b4853b5c96f9bc20498d3", - "b8e4a237e1a54301b8d7c34716e88e04", - "5d125743b1e7423fa912be21473a25da", - "7efa8ba623534414a229c99b4fbbb42e", - "c41dbdbb1fc7499d86c0e00b7aad7e6e", - "5f4a9330bea84f9f802a151d6ab15bfd", - "a181e9f5de0a4421807dabc4d26b21a0", - "14fed89053784858b9882de493521fab", - "588371472a864470be59e60263939024", - "fa9e2b2d05704055b605dddfff1c1059", - "fcbf062448084d8abe420f8d78ec1dbf", - "d438e96471914d7d9c43f080e15374dd", - "a4d6be53b6ea4058838d5fab89ecc89b", - "81877e5f311f46bb956cb79287e16116", - "c3aa57d233424fcdbd22c68c82df36f4", - "529bd1b6a2f442b9a596827eac574545", - "a8df66f3f7e4454fa719f0b5d65f30cb", - "58c7474f716d4d54aedadd67d6036a60", - "2f869d86cb71420089f1ce6fced96998", - "0173736e3f04409b9fb315a6930de4f3", - "b76b605c05b9437394a828a824db2463", - "6a4a0dff8a434cacb026bc432136ca3f", - "308824685f57438faf627b30fc6c1812", - "71735ef41cba4f769ea5141c2baf0e34", - "63b6d654ac084e3c932d184f866d4a2f", - "13277ee1a96741d1900328f11ed29f12", - "a9fd9d7e47624c4e993e325ebae01a69", - "952faee482c247f5a61d22e7792d2086", - "119379a3e3e449408e3103c1c104511a", - "c41453eb477f4ce590f75a8f0340056b", - "d5e5ff8a033d4828b1855417ca824133", - "6867a31a4877455082fa0b9a386106f9", - "d2bf9877f1ee4a67be2162b5b0d2ad3b", - "b8e671a962094f6d830c437b0c2637e3", - "7d5e2b2596704cac96de72a2ba3e3909", - "a4aa711ac7ad4285b32bfe38762c8e17", - "050ffd14831c4299abb33fc632036218", - "fb1764c86a1641c58db34168793a6f11", - "bd20e40c38a24902a584418a39b48bd8", - "1f079a7ccb2a4b48a02c1bbf56a9d157", - "f1aa7f2ba1d5425d84b683a69a058f97", - "e3ff967da05043ca9fe038252ebb64b1", - "f9a0bbd58f874a6fba2d2f82850c4dcd", - "fe4bd2a7d19a41a4967994a5414aa133", - "b0d389b36b984b9d8b75e6f56722e1b4", - "b891150ee6204bf4a4cc0d5bb1a1b54b", - "ac9282ec8a7f4f75b369edf1f613a708" - ], - "base_uri": "https://localhost:8080/", - "height": 226 - }, - "id": "1IbuXQ0HTEbc", - "outputId": "acb4caa4-dc69-4406-a90e-f35149e5021b" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Downloading: 0%| | 0.00/26.0 [00:00 0: - r_list[-1] = r_list[-1] + tk[2:] - else: - r_list.append(tk) - return r_list - - def _download_pretrained_model(self): - pretrained_parameters = self.__class__.pretrained_parameters - if os.path.isfile(pretrained_parameters["model_recover_path"]): - return - s = requests.session() - file_id = pretrained_parameters["file_id"] - r = s.get(f'https://docs.google.com/uc?export=download&id={file_id}&confirm=t') - z = zipfile.ZipFile(io.BytesIO(r.content)) - z.extractall() - - def _get_token_id_set(self, s): - r = None - if s: - w_list = [] - for w in s.split('|'): - if w.startswith('[') and w.endswith(']'): - w_list.append(w.upper()) - else: - w_list.append(w) - r = set(self.__class__.tokenizer.convert_tokens_to_ids(w_list)) - return r - - def __init__(self, **kwargs): - self.__class__.pretrained_translator = kwargs.get('pretrained_translator') - pretrained_parameters = self.__class__.pretrained_parameters - if pretrained_parameters["max_tgt_length"] >= pretrained_parameters["max_seq_length"] - 2: - raise ValueError("Maximum tgt length exceeds max seq length - 2.") - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - n_gpu = torch.cuda.device_count() - - seed = self.__class__.SEED - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if n_gpu > 0: - torch.cuda.manual_seed_all(seed) - tokenizer = BertTokenizer.from_pretrained(pretrained_parameters["bert_model"], **pretrained_parameters) - tokenizer.max_len = pretrained_parameters["max_seq_length"] - - pair_num_relation = 0 - bi_uni_pipeline = [] - bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, pretrained_parameters["max_seq_length"], **pretrained_parameters)) - - # Prepare model - cls_num_labels = 2 - type_vocab_size = 6 + \ - (1 if pretrained_parameters["s2s_add_segment"] else 0) if pretrained_parameters["new_segment_ids"] else 2 - mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(["[MASK]", "[SEP]", "[S2S_SOS]"]) - - self.__class__.tokenizer = tokenizer - - forbid_ignore_set = self._get_token_id_set(pretrained_parameters["forbid_ignore_word"]) - not_predict_set = self._get_token_id_set(pretrained_parameters["not_predict_token"]) - - self._download_pretrained_model() - - model_recover_path = glob.glob(pretrained_parameters["model_recover_path"].strip())[0] - map_device = None - if not torch.cuda.is_available(): - map_device='cpu' - model_recover = torch.load(model_recover_path,map_location=map_device) - pretrained_parameters["max_position_embeddings"] = pretrained_parameters["max_seq_length"] - params = {k: v for k, v in pretrained_parameters.items() if k in BertForSeq2SeqDecoder.__dict__} - model = BertForSeq2SeqDecoder.from_pretrained(pretrained_parameters["bert_model"], state_dict=model_recover, num_labels=cls_num_labels, num_rel=pair_num_relation, type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id, - eos_id=eos_word_ids, sos_id=sos_word_id, forbid_ignore_set=forbid_ignore_set, not_predict_set=not_predict_set, **params) - del model_recover - - if pretrained_parameters["fp16"]: - model.half() - model.to(device) - if n_gpu > 1: - model = torch.nn.DataParallel(model) - torch.cuda.empty_cache() - model.eval() - - self.__class__.device = device - self.__class__.model = model - self.__class__.bi_uni_pipeline = bi_uni_pipeline - - def _translate_lines(self, input_lines, src_lang, tgt_lang): - translator = getattr(self.__class__, "translator", t2t.Translator(pretrained_translator=self.__class__.pretrained_translator)) - self.__class__.translator = translator - return translator.transform(input_lines, src_lang=src_lang, tgt_lang=tgt_lang) - - def transform(self, input_lines, src_lang='en', **kwargs): - input_lines = t2t.Transformer.transform(self, input_lines, src_lang, **kwargs) - if src_lang != 'en': - input_lines = self._translate_lines(input_lines, src_lang, 'en') - - pretrained_parameters = self.__class__.pretrained_parameters - tokenizer = self.__class__.tokenizer - model = self.__class__.model - bi_uni_pipeline = self.__class__.bi_uni_pipeline - device = self.__class__.device - - max_src_length = pretrained_parameters["max_seq_length"] - 2 - pretrained_parameters["max_tgt_length"] - input_lines = [tokenizer.tokenize(x)[:max_src_length] for x in input_lines] - input_lines = sorted(list(enumerate(input_lines)), key=lambda x: -len(x[1])) - output_lines = [""] * len(input_lines) - score_trace_list = [None] * len(input_lines) - total_batch = math.ceil(len(input_lines) / pretrained_parameters["batch_size"]) - next_i = 0 - with tqdm(total=total_batch) as pbar: - while next_i < len(input_lines): - _chunk = input_lines[next_i:next_i + pretrained_parameters["batch_size"]] - buf_id = [x[0] for x in _chunk] - buf = [x[1] for x in _chunk] - next_i += pretrained_parameters["batch_size"] - max_a_len = max([len(x) for x in buf]) - instances = [] - for instance in [(x, max_a_len) for x in buf]: - for proc in bi_uni_pipeline: - instances.append(proc(instance)) - with torch.no_grad(): - batch = seq2seq_loader.batch_list_to_batch_tensors(instances) - batch = [t.to(device) if t is not None else None for t in batch] - input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch - traces = model(input_ids, token_type_ids, - position_ids, input_mask, task_idx=task_idx, mask_qkv=mask_qkv) - output_ids = traces.tolist() - for i in range(len(buf)): - w_ids = output_ids[i] - output_buf = tokenizer.convert_ids_to_tokens(w_ids) - output_tokens = [] - for t in output_buf: - if t in ("[SEP]", "[PAD]"): - break - output_tokens.append(t) - output_sequence = ' '.join(self._detokenize(output_tokens)) - output_sequence = re.sub(r'\s([?.!"](?:\s|$))', r'\1', output_sequence) - output_sequence = re.sub(r"\b\s+'\b", r"'", output_sequence) - output_sequence = output_sequence.replace("[X_SEP]", "") - output_lines[buf_id[i]] = output_sequence - pbar.update(1) - - if src_lang != 'en': - output_lines = self._translate_lines(output_lines, src_lang='en', tgt_lang=src_lang) - - return output_lines \ No newline at end of file diff --git a/text2text/answerer.py b/text2text/answerer.py deleted file mode 100644 index 940a088..0000000 --- a/text2text/answerer.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch -import text2text as t2t - -from transformers import AutoModelForQuestionAnswering, AutoTokenizer - -class Answerer(t2t.Transformer): - pretrained_answerer = "valhalla/longformer-base-4096-finetuned-squadv1" - - def __init__(self, **kwargs): - pretrained_answerer = kwargs.get('pretrained_answerer') - if not pretrained_answerer: - pretrained_answerer = self.__class__.pretrained_answerer - self.__class__.tokenizer = AutoTokenizer.from_pretrained(pretrained_answerer) - self.__class__.model = AutoModelForQuestionAnswering.from_pretrained(pretrained_answerer, device_map="auto", load_in_8bit=True) - - def _translate_lines(self, input_lines, src_lang, tgt_lang): - translator = getattr(self.__class__, "translator", t2t.Translator()) - self.__class__.translator = translator - return translator.transform(input_lines, src_lang=src_lang, tgt_lang=tgt_lang) - - def _get_answers(self, input_lines): - tokenizer = self.__class__.tokenizer - model = self.__class__.model - num_examples = len(input_lines) - encoded_inputs = tokenizer.batch_encode_plus(input_lines, padding=True, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu") - input_ids = encoded_inputs["input_ids"] - attention_mask = encoded_inputs["attention_mask"] - results = model(input_ids, attention_mask=attention_mask) - ans_ids = [None] * num_examples - for i in range(num_examples): - max_startscore = torch.argmax(results["start_logits"][i]) - max_endscore = torch.argmax(results["end_logits"][i]) - ans_ids[i] = input_ids[i][max_startscore:max_endscore+1] - answers = tokenizer.batch_decode(ans_ids, skip_special_tokens=True) - answers = [a.strip() for a in answers] - return answers - - def transform(self, input_lines, src_lang='en', **kwargs): - input_lines = t2t.Transformer.transform(self, input_lines, src_lang, **kwargs) - if src_lang != 'en': - input_lines = self._translate_lines(input_lines, src_lang, 'en') - - input_lines = [line.split(" [SEP] ")[::-1] for line in input_lines] - output_lines = self._get_answers(input_lines) - - if src_lang != 'en': - output_lines = self._translate_lines(output_lines, src_lang='en', tgt_lang=src_lang) - - return output_lines \ No newline at end of file diff --git a/text2text/assistant.py b/text2text/assistant.py index 652b2e2..bd3c513 100644 --- a/text2text/assistant.py +++ b/text2text/assistant.py @@ -11,7 +11,7 @@ def __init__(self, **kwargs): self.host = kwargs.get("host", "http://localhost") self.port = kwargs.get("port", 11434) self.model_url = f"{self.host}:{self.port}" - self.model_name = kwargs.get("model_name", "llama3.1") + self.model_name = kwargs.get("model_name", "llama3.2") self.load_model() self.client = ollama.Client(host=self.model_url) self.structured_client = Ollama(model=self.model_name, request_timeout=120.0) diff --git a/text2text/biunilm/__init__.py b/text2text/biunilm/__init__.py deleted file mode 100755 index 33c63fe..0000000 --- a/text2text/biunilm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import seq2seq_loader -from .loader_utils import get_random_word, batch_list_to_batch_tensors, Pipeline diff --git a/text2text/biunilm/loader_utils.py b/text2text/biunilm/loader_utils.py deleted file mode 100755 index ba67101..0000000 --- a/text2text/biunilm/loader_utils.py +++ /dev/null @@ -1,299 +0,0 @@ -from random import randint, shuffle -from random import random as rand -import numpy as np - -import torch -import torch.utils.data - - -def get_random_word(vocab_words): - i = randint(0, len(vocab_words)-1) - return vocab_words[i] - - -def batch_list_to_batch_tensors(batch): - batch_tensors = [] - for x in zip(*batch): - if x[0] is None: - batch_tensors.append(None) - elif isinstance(x[0], torch.Tensor): - batch_tensors.append(torch.stack(x)) - else: - batch_tensors.append(torch.tensor(x, dtype=torch.long)) - return batch_tensors - - -class TrieNode(object): - def __init__(self): - self.children = {} - self.is_leaf = False - - def try_get_children(self, key): - if key not in self.children: - self.children[key] = TrieNode() - return self.children[key] - - -class TrieTree(object): - def __init__(self): - self.root = TrieNode() - - def add(self, tokens): - r = self.root - for token in tokens: - r = r.try_get_children(token) - r.is_leaf = True - - def get_pieces(self, tokens, offset): - pieces = [] - r = self.root - token_id = 0 - last_valid = 0 - match_count = 0 - while last_valid < len(tokens): - if token_id < len(tokens) and tokens[token_id] in r.children: - r = r.children[tokens[token_id]] - match_count += 1 - if r.is_leaf: - last_valid = token_id - token_id += 1 - else: - pieces.append( - list(range(token_id - match_count + offset, last_valid + 1 + offset))) - last_valid += 1 - token_id = last_valid - r = self.root - match_count = 0 - - return pieces - - -def _get_word_split_index(tokens, st, end): - split_idx = [] - i = st - while i < end: - if (not tokens[i].startswith('##')) or (i == st): - split_idx.append(i) - i += 1 - split_idx.append(end) - return split_idx - - -def _expand_whole_word(tokens, st, end): - new_st, new_end = st, end - while (new_st >= 0) and tokens[new_st].startswith('##'): - new_st -= 1 - while (new_end < len(tokens)) and tokens[new_end].startswith('##'): - new_end += 1 - return new_st, new_end - - -class Pipeline(): - """ Pre-process Pipeline Class : callable """ - - def __init__(self): - super().__init__() - self.skipgram_prb = None - self.skipgram_size = None - self.pre_whole_word = None - self.mask_whole_word = None - self.word_subsample_prb = None - self.sp_prob = None - self.pieces_dir = None - self.vocab_words = None - self.pieces_threshold = 10 - self.trie = None - self.call_count = 0 - self.offline_mode = False - self.skipgram_size_geo_list = None - self.span_same_mask = False - - def init_skipgram_size_geo_list(self, p): - if p > 0: - g_list = [] - t = p - for _ in range(self.skipgram_size): - g_list.append(t) - t *= (1-p) - s = sum(g_list) - self.skipgram_size_geo_list = [x/s for x in g_list] - - def create_trie_tree(self, pieces_dir): - print("sp_prob = {}".format(self.sp_prob)) - print("pieces_threshold = {}".format(self.pieces_threshold)) - if pieces_dir is not None: - self.trie = TrieTree() - pieces_files = [pieces_dir] - for token in self.vocab_words: - self.trie.add([token]) - for piece_file in pieces_files: - print("Load piece file: {}".format(piece_file)) - with open(piece_file, mode='r', encoding='utf-8') as reader: - for line in reader: - parts = line.split('\t') - if int(parts[-1]) < self.pieces_threshold: - pass - tokens = [] - for part in parts[:-1]: - tokens.extend(part.split(' ')) - self.trie.add(tokens) - - def __call__(self, instance): - raise NotImplementedError - - # pre_whole_word: tokenize to words before masking - # post whole word (--mask_whole_word): expand to words after masking - def get_masked_pos(self, tokens, n_pred, add_skipgram=False, mask_segment=None, protect_range=None): - if self.pieces_dir is not None and self.trie is None: - self.create_trie_tree(self.pieces_dir) - if self.pre_whole_word: - if self.trie is not None: - pieces = self.trie.get_pieces(tokens, 0) - - new_pieces = [] - for piece in pieces: - if len(new_pieces) > 0 and tokens[piece[0]].startswith("##"): - new_pieces[-1].extend(piece) - else: - new_pieces.append(piece) - del pieces - pieces = new_pieces - - pre_word_split = list(_[-1] for _ in pieces) - pre_word_split.append(len(tokens)) - else: - pre_word_split = _get_word_split_index(tokens, 0, len(tokens)) - index2piece = None - else: - pre_word_split = list(range(0, len(tokens)+1)) - - if self.trie is not None: - pieces = self.trie.get_pieces(tokens, 0) - - index2piece = {} - for piece in pieces: - for index in piece: - index2piece[index] = (piece[0], piece[-1]) - else: - index2piece = None - - span_list = list(zip(pre_word_split[:-1], pre_word_split[1:])) - - # candidate positions of masked tokens - cand_pos = [] - special_pos = set() - if mask_segment: - for i, sp in enumerate(span_list): - sp_st, sp_end = sp - if (sp_end-sp_st == 1) and tokens[sp_st].endswith('SEP]'): - segment_index = i - break - for i, sp in enumerate(span_list): - sp_st, sp_end = sp - if (sp_end-sp_st == 1) and (tokens[sp_st].endswith('CLS]') or tokens[sp_st].endswith('SEP]')): - special_pos.add(i) - else: - if mask_segment: - if ((i < segment_index) and ('a' in mask_segment)) or ((i > segment_index) and ('b' in mask_segment)): - cand_pos.append(i) - else: - cand_pos.append(i) - shuffle(cand_pos) - - masked_pos = set() - for i_span in cand_pos: - if len(masked_pos) >= n_pred: - break - cand_st, cand_end = span_list[i_span] - if len(masked_pos)+cand_end-cand_st > n_pred: - continue - if any(p in masked_pos for p in range(cand_st, cand_end)): - continue - - n_span = 1 - if index2piece is not None: - p_start, p_end = index2piece[i_span] - if p_start < p_end and (rand() < self.sp_prob): - # n_span = p_end - p_start + 1 - st_span, end_span = p_start, p_end + 1 - else: - st_span, end_span = i_span, i_span + 1 - else: - rand_skipgram_size = 0 - # ngram - if self.skipgram_size_geo_list: - # sampling ngram size from geometric distribution - rand_skipgram_size = np.random.choice( - len(self.skipgram_size_geo_list), 1, p=self.skipgram_size_geo_list)[0] + 1 - else: - if add_skipgram and (self.skipgram_prb > 0) and (self.skipgram_size >= 2) and (rand() < self.skipgram_prb): - rand_skipgram_size = min( - randint(2, self.skipgram_size), len(span_list)-i_span) - for n in range(2, rand_skipgram_size+1): - tail_st, tail_end = span_list[i_span+n-1] - if (tail_end-tail_st == 1) and (tail_st in special_pos): - break - if len(masked_pos)+tail_end-cand_st > n_pred: - break - n_span = n - st_span, end_span = i_span, i_span + n_span - - if self.mask_whole_word: - # pre_whole_word==False: position index of span_list is the same as tokens - st_span, end_span = _expand_whole_word( - tokens, st_span, end_span) - - # subsampling according to frequency - if self.word_subsample_prb: - skip_pos = set() - if self.pre_whole_word: - w_span_list = span_list[st_span:end_span] - else: - split_idx = _get_word_split_index( - tokens, st_span, end_span) - w_span_list = list( - zip(split_idx[:-1], split_idx[1:])) - for i, sp in enumerate(w_span_list): - sp_st, sp_end = sp - if sp_end-sp_st == 1: - w_cat = tokens[sp_st] - else: - w_cat = ''.join(tokens[sp_st:sp_end]) - if (w_cat in self.word_subsample_prb) and (rand() < self.word_subsample_prb[w_cat]): - for k in range(sp_st, sp_end): - skip_pos.add(k) - else: - skip_pos = None - - for sp in range(st_span, end_span): - for mp in range(span_list[sp][0], span_list[sp][1]): - if not(skip_pos and (mp in skip_pos)) and (mp not in special_pos) and not(protect_range and (protect_range[0] <= mp < protect_range[1])): - masked_pos.add(mp) - - if len(masked_pos) < n_pred: - shuffle(cand_pos) - for pos in cand_pos: - if len(masked_pos) >= n_pred: - break - if pos not in masked_pos: - masked_pos.add(pos) - masked_pos = list(masked_pos) - if len(masked_pos) > n_pred: - # shuffle(masked_pos) - masked_pos = masked_pos[:n_pred] - return masked_pos - - def replace_masked_tokens(self, tokens, masked_pos): - if self.span_same_mask: - masked_pos = sorted(list(masked_pos)) - prev_pos, prev_rand = None, None - for pos in masked_pos: - if self.span_same_mask and (pos-1 == prev_pos): - t_rand = prev_rand - else: - t_rand = rand() - if t_rand < 0.8: # 80% - tokens[pos] = '[MASK]' - elif t_rand < 0.9: # 10% - tokens[pos] = get_random_word(self.vocab_words) - prev_pos, prev_rand = pos, t_rand diff --git a/text2text/biunilm/seq2seq_loader.py b/text2text/biunilm/seq2seq_loader.py deleted file mode 100755 index 08d00e5..0000000 --- a/text2text/biunilm/seq2seq_loader.py +++ /dev/null @@ -1,407 +0,0 @@ -from random import randint, shuffle, choice -from random import random as rand -import math -import torch - -from .loader_utils import get_random_word, batch_list_to_batch_tensors, Pipeline - -# Input file format : -# 1. One sentence per line. These should ideally be actual sentences, -# not entire paragraphs or arbitrary spans of text. (Because we use -# the sentence boundaries for the "next sentence prediction" task). -# 2. Blank lines between documents. Document boundaries are needed -# so that the "next sentence prediction" task doesn't span between documents. - - -def truncate_tokens_pair(tokens_a, tokens_b, max_len, max_len_a=0, max_len_b=0, trunc_seg=None, always_truncate_tail=False): - num_truncated_a = [0, 0] - num_truncated_b = [0, 0] - while True: - if len(tokens_a) + len(tokens_b) <= max_len: - break - if (max_len_a > 0) and len(tokens_a) > max_len_a: - trunc_tokens = tokens_a - num_truncated = num_truncated_a - elif (max_len_b > 0) and len(tokens_b) > max_len_b: - trunc_tokens = tokens_b - num_truncated = num_truncated_b - elif trunc_seg: - # truncate the specified segment - if trunc_seg == 'a': - trunc_tokens = tokens_a - num_truncated = num_truncated_a - else: - trunc_tokens = tokens_b - num_truncated = num_truncated_b - else: - # truncate the longer segment - if len(tokens_a) > len(tokens_b): - trunc_tokens = tokens_a - num_truncated = num_truncated_a - else: - trunc_tokens = tokens_b - num_truncated = num_truncated_b - # whether always truncate source sequences - if (not always_truncate_tail) and (rand() < 0.5): - del trunc_tokens[0] - num_truncated[0] += 1 - else: - trunc_tokens.pop() - num_truncated[1] += 1 - return num_truncated_a, num_truncated_b - - -class Seq2SeqDataset(torch.utils.data.Dataset): - """ Load sentence pair (sequential or random order) from corpus """ - - def __init__(self, file_src, file_tgt, batch_size, tokenizer, max_len, file_oracle=None, short_sampling_prob=0.1, sent_reverse_order=False, bi_uni_pipeline=[]): - super().__init__() - self.tokenizer = tokenizer # tokenize function - self.max_len = max_len # maximum length of tokens - self.short_sampling_prob = short_sampling_prob - self.bi_uni_pipeline = bi_uni_pipeline - self.batch_size = batch_size - self.sent_reverse_order = sent_reverse_order - - # read the file into memory - self.ex_list = [] - if file_oracle is None: - with open(file_src, "r", encoding='utf-8') as f_src, open(file_tgt, "r", encoding='utf-8') as f_tgt: - for src, tgt in zip(f_src, f_tgt): - src_tk = tokenizer.tokenize(src.strip()) - tgt_tk = tokenizer.tokenize(tgt.strip()) - assert len(src_tk) > 0 - assert len(tgt_tk) > 0 - self.ex_list.append((src_tk, tgt_tk)) - else: - with open(file_src, "r", encoding='utf-8') as f_src, \ - open(file_tgt, "r", encoding='utf-8') as f_tgt, \ - open(file_oracle, "r", encoding='utf-8') as f_orc: - for src, tgt, orc in zip(f_src, f_tgt, f_orc): - src_tk = tokenizer.tokenize(src.strip()) - tgt_tk = tokenizer.tokenize(tgt.strip()) - s_st, labl = orc.split('\t') - s_st = [int(x) for x in s_st.split()] - labl = [int(x) for x in labl.split()] - self.ex_list.append((src_tk, tgt_tk, s_st, labl)) - print('Load {0} documents'.format(len(self.ex_list))) - - def __len__(self): - return len(self.ex_list) - - def __getitem__(self, idx): - instance = self.ex_list[idx] - proc = choice(self.bi_uni_pipeline) - instance = proc(instance) - return instance - - def __iter__(self): # iterator to load data - for __ in range(math.ceil(len(self.ex_list) / float(self.batch_size))): - batch = [] - for __ in range(self.batch_size): - idx = randint(0, len(self.ex_list)-1) - batch.append(self.__getitem__(idx)) - # To Tensor - yield batch_list_to_batch_tensors(batch) - - -class Preprocess4Seq2seq(Pipeline): - """ Pre-processing steps for pretraining transformer """ - - def __init__(self, max_pred, mask_prob, vocab_words, indexer, max_len=512, skipgram_prb=0, skipgram_size=0, block_mask=False, mask_whole_word=False, new_segment_ids=False, truncate_config={}, mask_source_words=False, mode="s2s", has_oracle=False, num_qkv=0, s2s_special_token=False, s2s_add_segment=False, s2s_share_segment=False, pos_shift=False, **kwargs): - super().__init__() - self.max_len = max_len - self.max_pred = max_pred # max tokens of prediction - self.mask_prob = mask_prob # masking probability - self.vocab_words = vocab_words # vocabulary (sub)words - self.indexer = indexer # function from token to token index - self.max_len = max_len - self._tril_matrix = torch.tril(torch.ones( - (max_len, max_len), dtype=torch.long)) - self.skipgram_prb = skipgram_prb - self.skipgram_size = skipgram_size - self.mask_whole_word = mask_whole_word - self.new_segment_ids = new_segment_ids - self.always_truncate_tail = truncate_config.get( - 'always_truncate_tail', False) - self.max_len_a = truncate_config.get('max_len_a', None) - self.max_len_b = truncate_config.get('max_len_b', None) - self.trunc_seg = truncate_config.get('trunc_seg', None) - self.task_idx = 3 # relax projection layer for different tasks - self.mask_source_words = mask_source_words - assert mode in ("s2s", "l2r") - self.mode = mode - self.has_oracle = has_oracle - self.num_qkv = num_qkv - self.s2s_special_token = s2s_special_token - self.s2s_add_segment = s2s_add_segment - self.s2s_share_segment = s2s_share_segment - self.pos_shift = pos_shift - - def __call__(self, instance): - tokens_a, tokens_b = instance[:2] - - if self.pos_shift: - tokens_b = ['[S2S_SOS]'] + tokens_b - - # -3 for special tokens [CLS], [SEP], [SEP] - num_truncated_a, _ = truncate_tokens_pair(tokens_a, tokens_b, self.max_len - 3, max_len_a=self.max_len_a, - max_len_b=self.max_len_b, trunc_seg=self.trunc_seg, always_truncate_tail=self.always_truncate_tail) - - # Add Special Tokens - if self.s2s_special_token: - tokens = ['[S2S_CLS]'] + tokens_a + \ - ['[S2S_SEP]'] + tokens_b + ['[SEP]'] - else: - tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]'] - - if self.new_segment_ids: - if self.mode == "s2s": - if self.s2s_add_segment: - if self.s2s_share_segment: - segment_ids = [0] + [1] * \ - (len(tokens_a)+1) + [5]*(len(tokens_b)+1) - else: - segment_ids = [4] + [6] * \ - (len(tokens_a)+1) + [5]*(len(tokens_b)+1) - else: - segment_ids = [4] * (len(tokens_a)+2) + \ - [5]*(len(tokens_b)+1) - else: - segment_ids = [2] * (len(tokens)) - else: - segment_ids = [0]*(len(tokens_a)+2) + [1]*(len(tokens_b)+1) - - if self.pos_shift: - n_pred = min(self.max_pred, len(tokens_b)) - masked_pos = [len(tokens_a)+2+i for i in range(len(tokens_b))] - masked_weights = [1]*n_pred - masked_ids = self.indexer(tokens_b[1:]+['[SEP]']) - else: - # For masked Language Models - # the number of prediction is sometimes less than max_pred when sequence is short - effective_length = len(tokens_b) - if self.mask_source_words: - effective_length += len(tokens_a) - n_pred = min(self.max_pred, max( - 1, int(round(effective_length*self.mask_prob)))) - # candidate positions of masked tokens - cand_pos = [] - special_pos = set() - for i, tk in enumerate(tokens): - # only mask tokens_b (target sequence) - # we will mask [SEP] as an ending symbol - if (i >= len(tokens_a)+2) and (tk != '[CLS]'): - cand_pos.append(i) - elif self.mask_source_words and (i < len(tokens_a)+2) and (tk != '[CLS]') and (not tk.startswith('[SEP')): - cand_pos.append(i) - else: - special_pos.add(i) - shuffle(cand_pos) - - masked_pos = set() - max_cand_pos = max(cand_pos) - for pos in cand_pos: - if len(masked_pos) >= n_pred: - break - if pos in masked_pos: - continue - - def _expand_whole_word(st, end): - new_st, new_end = st, end - while (new_st >= 0) and tokens[new_st].startswith('##'): - new_st -= 1 - while (new_end < len(tokens)) and tokens[new_end].startswith('##'): - new_end += 1 - return new_st, new_end - - if (self.skipgram_prb > 0) and (self.skipgram_size >= 2) and (rand() < self.skipgram_prb): - # ngram - cur_skipgram_size = randint(2, self.skipgram_size) - if self.mask_whole_word: - st_pos, end_pos = _expand_whole_word( - pos, pos + cur_skipgram_size) - else: - st_pos, end_pos = pos, pos + cur_skipgram_size - else: - # directly mask - if self.mask_whole_word: - st_pos, end_pos = _expand_whole_word(pos, pos + 1) - else: - st_pos, end_pos = pos, pos + 1 - - for mp in range(st_pos, end_pos): - if (0 < mp <= max_cand_pos) and (mp not in special_pos): - masked_pos.add(mp) - else: - break - - masked_pos = list(masked_pos) - if len(masked_pos) > n_pred: - shuffle(masked_pos) - masked_pos = masked_pos[:n_pred] - - masked_tokens = [tokens[pos] for pos in masked_pos] - for pos in masked_pos: - if rand() < 0.8: # 80% - tokens[pos] = '[MASK]' - elif rand() < 0.5: # 10% - tokens[pos] = get_random_word(self.vocab_words) - # when n_pred < max_pred, we only calculate loss within n_pred - masked_weights = [1]*len(masked_tokens) - - # Token Indexing - masked_ids = self.indexer(masked_tokens) - # Token Indexing - input_ids = self.indexer(tokens) - - # Zero Padding - n_pad = self.max_len - len(input_ids) - input_ids.extend([0]*n_pad) - segment_ids.extend([0]*n_pad) - - if self.num_qkv > 1: - mask_qkv = [0]*(len(tokens_a)+2) + [1] * (len(tokens_b)+1) - mask_qkv.extend([0]*n_pad) - else: - mask_qkv = None - - input_mask = torch.zeros(self.max_len, self.max_len, dtype=torch.long) - if self.mode == "s2s": - input_mask[:, :len(tokens_a)+2].fill_(1) - second_st, second_end = len( - tokens_a)+2, len(tokens_a)+len(tokens_b)+3 - input_mask[second_st:second_end, second_st:second_end].copy_( - self._tril_matrix[:second_end-second_st, :second_end-second_st]) - else: - st, end = 0, len(tokens_a) + len(tokens_b) + 3 - input_mask[st:end, st:end].copy_(self._tril_matrix[:end, :end]) - - # Zero Padding for masked target - if self.max_pred > n_pred: - n_pad = self.max_pred - n_pred - if masked_ids is not None: - masked_ids.extend([0]*n_pad) - if masked_pos is not None: - masked_pos.extend([0]*n_pad) - if masked_weights is not None: - masked_weights.extend([0]*n_pad) - - oracle_pos = None - oracle_weights = None - oracle_labels = None - if self.has_oracle: - s_st, labls = instance[2:] - oracle_pos = [] - oracle_labels = [] - for st, lb in zip(s_st, labls): - st = st - num_truncated_a[0] - if st > 0 and st < len(tokens_a): - oracle_pos.append(st) - oracle_labels.append(lb) - oracle_pos = oracle_pos[:20] - oracle_labels = oracle_labels[:20] - oracle_weights = [1] * len(oracle_pos) - if len(oracle_pos) < 20: - x_pad = 20 - len(oracle_pos) - oracle_pos.extend([0] * x_pad) - oracle_labels.extend([0] * x_pad) - oracle_weights.extend([0] * x_pad) - - return (input_ids, segment_ids, input_mask, mask_qkv, masked_ids, - masked_pos, masked_weights, -1, self.task_idx, - oracle_pos, oracle_weights, oracle_labels) - - return (input_ids, segment_ids, input_mask, mask_qkv, masked_ids, masked_pos, masked_weights, -1, self.task_idx) - - -class Preprocess4Seq2seqDecoder(Pipeline): - """ Pre-processing steps for pretraining transformer """ - - def __init__(self, vocab_words, indexer, max_len=512, max_tgt_length=128, new_segment_ids=False, mode="s2s", num_qkv=0, s2s_special_token=False, s2s_add_segment=False, s2s_share_segment=False, pos_shift=False, **kwargs): - super().__init__() - self.max_len = max_len - self.vocab_words = vocab_words # vocabulary (sub)words - self.indexer = indexer # function from token to token index - self.max_len = max_len - self._tril_matrix = torch.tril(torch.ones( - (max_len, max_len), dtype=torch.long)) - self.new_segment_ids = new_segment_ids - self.task_idx = 3 # relax projection layer for different tasks - assert mode in ("s2s", "l2r") - self.mode = mode - self.max_tgt_length = max_tgt_length - self.num_qkv = num_qkv - self.s2s_special_token = s2s_special_token - self.s2s_add_segment = s2s_add_segment - self.s2s_share_segment = s2s_share_segment - self.pos_shift = pos_shift - - def __call__(self, instance): - tokens_a, max_a_len = instance - - # Add Special Tokens - if self.s2s_special_token: - padded_tokens_a = ['[S2S_CLS]'] + tokens_a + ['[S2S_SEP]'] - else: - padded_tokens_a = ['[CLS]'] + tokens_a + ['[SEP]'] - assert len(padded_tokens_a) <= max_a_len + 2 - if max_a_len + 2 > len(padded_tokens_a): - padded_tokens_a += ['[PAD]'] * \ - (max_a_len + 2 - len(padded_tokens_a)) - assert len(padded_tokens_a) == max_a_len + 2 - max_len_in_batch = min(self.max_tgt_length + - max_a_len + 2, self.max_len) - tokens = padded_tokens_a - if self.new_segment_ids: - if self.mode == "s2s": - _enc_seg1 = 0 if self.s2s_share_segment else 4 - if self.s2s_add_segment: - if self.s2s_share_segment: - segment_ids = [ - 0] + [1]*(len(padded_tokens_a)-1) + [5]*(max_len_in_batch - len(padded_tokens_a)) - else: - segment_ids = [ - 4] + [6]*(len(padded_tokens_a)-1) + [5]*(max_len_in_batch - len(padded_tokens_a)) - else: - segment_ids = [4]*(len(padded_tokens_a)) + \ - [5]*(max_len_in_batch - len(padded_tokens_a)) - else: - segment_ids = [2]*max_len_in_batch - else: - segment_ids = [0]*(len(padded_tokens_a)) \ - + [1]*(max_len_in_batch - len(padded_tokens_a)) - - if self.num_qkv > 1: - mask_qkv = [0]*(len(padded_tokens_a)) + [1] * \ - (max_len_in_batch - len(padded_tokens_a)) - else: - mask_qkv = None - - position_ids = [] - for i in range(len(tokens_a) + 2): - position_ids.append(i) - for i in range(len(tokens_a) + 2, max_a_len + 2): - position_ids.append(0) - for i in range(max_a_len + 2, max_len_in_batch): - position_ids.append(i - (max_a_len + 2) + len(tokens_a) + 2) - - # Token Indexing - input_ids = self.indexer(tokens) - - # Zero Padding - input_mask = torch.zeros( - max_len_in_batch, max_len_in_batch, dtype=torch.long) - if self.mode == "s2s": - input_mask[:, :len(tokens_a)+2].fill_(1) - else: - st, end = 0, len(tokens_a) + 2 - input_mask[st:end, st:end].copy_( - self._tril_matrix[:end, :end]) - input_mask[end:, :len(tokens_a)+2].fill_(1) - second_st, second_end = len(padded_tokens_a), max_len_in_batch - - input_mask[second_st:second_end, second_st:second_end].copy_( - self._tril_matrix[:second_end-second_st, :second_end-second_st]) - - return (input_ids, segment_ids, position_ids, input_mask, mask_qkv, self.task_idx) \ No newline at end of file diff --git a/text2text/handler.py b/text2text/handler.py index 2e87aa8..aa568cd 100644 --- a/text2text/handler.py +++ b/text2text/handler.py @@ -6,7 +6,6 @@ class Handler(object): """ EXPOSED_TRANSFORMERS = { - "answer": t2t.Answerer, "assist": t2t.Assistant, "bm25": t2t.Bm25er, "count": t2t.Counter, @@ -14,8 +13,6 @@ class Handler(object): "index": t2t.Indexer, "fit": t2t.Fitter, "measure": t2t.Measurer, - "question": t2t.Questioner, - "summarize": t2t.Summarizer, "tfidf": t2t.Tfidfer, "tokenize": t2t.Tokenizer, "translate": t2t.Translator, diff --git a/text2text/langchain/__init__.py b/text2text/langchain/__init__.py deleted file mode 100644 index 1a43829..0000000 --- a/text2text/langchain/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .stfidf import STFIDFRetriever -from .text2text_assistant import Text2TextAssistant \ No newline at end of file diff --git a/text2text/langchain/stfidf.py b/text2text/langchain/stfidf.py deleted file mode 100644 index fe31b76..0000000 --- a/text2text/langchain/stfidf.py +++ /dev/null @@ -1,75 +0,0 @@ -"""STF-IDF Retriever. - -Based on https://github.com/artitw/text2text""" - -from __future__ import annotations - -from typing import Any, Dict, Iterable, List, Optional - -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) -from langchain.schema import BaseRetriever, Document - - -class STFIDFRetriever(BaseRetriever): - index: Any - docs: List[Document] - k: int = 4 - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @classmethod - def from_texts( - cls, - texts: Iterable[str], - metadatas: Optional[Iterable[dict]] = None, - **kwargs: Any, - ) -> STFIDFRetriever: - try: - import text2text as t2t - except ImportError: - raise ImportError( - "Could not import text2text, please install with `pip install " - "text2text`." - ) - - index = t2t.Indexer().transform(texts) - metadatas = metadatas or ({} for _ in texts) - docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)] - return cls(index=index, docs=docs, **kwargs) - - @classmethod - def from_documents( - cls, - documents: Iterable[Document], - *, - tfidf_params: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> STFIDFRetriever: - texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents)) - return cls.from_texts( - texts=texts, metadatas=metadatas, **kwargs - ) - - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: - distances, pred_ids = self.index.search([query], k=self.k) - return [self.docs[i] for i in pred_ids[0] if i >= 0] - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError - - async def aadd_documents( - self, documents: List[Document], **kwargs: Any - ) -> List[str]: - texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents)) - self.docs += documents - self.index.add(texts) \ No newline at end of file diff --git a/text2text/langchain/test_stfidf.py b/text2text/langchain/test_stfidf.py deleted file mode 100644 index 67f8e28..0000000 --- a/text2text/langchain/test_stfidf.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest - -from text2text.langchain.stfidf import STFIDFRetriever -from langchain.schema import Document - - -@pytest.mark.requires("langchain") -def test_from_texts() -> None: - input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."] - stfidf_retriever = STFIDFRetriever.from_texts(texts=input_texts) - assert len(stfidf_retriever.docs) == 3 - - -@pytest.mark.requires("langchain") -def test_retrieval() -> None: - input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."] - stfidf_retriever = STFIDFRetriever.from_texts( - texts=input_texts - ) - assert len(stfidf_retriever.index.retrieve(["pen"], k=2)[0]) == 2 - - -@pytest.mark.requires("langchain") -def test_from_documents() -> None: - input_docs = [ - Document(page_content="I have a pen."), - Document(page_content="Do you have a pen?"), - Document(page_content="I have a bag."), - ] - tfidf_retriever = STFIDFRetriever.from_documents(documents=input_docs) - assert len(tfidf_retriever.docs) == 3 diff --git a/text2text/langchain/test_text2text_assistant.py b/text2text/langchain/test_text2text_assistant.py deleted file mode 100644 index 14d6065..0000000 --- a/text2text/langchain/test_text2text_assistant.py +++ /dev/null @@ -1,10 +0,0 @@ -import pytest - -from text2text.langchain.text2text_assistant import Text2TextAssistant - -@pytest.mark.requires("langchain") -def test_llm_inference() -> None: - input_text = 'Say "hello, world" back to me' - llm = Text2TextAssistant() - result = llm(input_text) - assert "hello" in result.lower() \ No newline at end of file diff --git a/text2text/langchain/text2text_assistant.py b/text2text/langchain/text2text_assistant.py deleted file mode 100644 index fa3f93a..0000000 --- a/text2text/langchain/text2text_assistant.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Any, List, Mapping, Optional - -import text2text as t2t -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.llms.base import LLM - -class Text2TextAssistant(LLM): - model: t2t.Assistant = t2t.Assistant() - - @property - def _llm_type(self) -> str: - return "Text2Text" - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs - ) -> str: - if stop is not None: - raise ValueError("stop kwargs are not permitted.") - return self.model.transform(messages=[{"role": "user", "content": prompt}], **kwargs)["content"] - - @property - def _identifying_params(self) -> Mapping[str, Any]: - """Get the identifying parameters.""" - return {"type": self._llm_type} \ No newline at end of file diff --git a/text2text/mixtral/__init__.py b/text2text/mixtral/__init__.py deleted file mode 100644 index 95cef1d..0000000 --- a/text2text/mixtral/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .build_model import OffloadConfig, QuantConfig, build_model diff --git a/text2text/mixtral/build_model.py b/text2text/mixtral/build_model.py deleted file mode 100644 index 09d3b16..0000000 --- a/text2text/mixtral/build_model.py +++ /dev/null @@ -1,263 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Artyom Eliseev, Denis Mazur -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import os -import json -from functools import cache -from dataclasses import dataclass -import typing as tp - -import torch -from torch import nn - -from transformers import AutoConfig -from transformers.models.mixtral import MixtralForCausalLM, MixtralConfig - -from safetensors.torch import load_file - -from torch import nn -from tqdm.auto import trange - -from hqq.core.quantize import BaseQuantizeConfig - -from .expert_cache import ExpertCache -from .expert_wrapper import MixtralExpertWrapper -from .custom_layers import ( - HQQLinearTritonSavable, - MixtralBLockSparseTop2MLP_HQQ, - SparseMoeWrapper, -) -from .utils import with_default_dtype - - -@dataclass(frozen=True) -class OffloadConfig: - main_size: int - offload_size: int - buffer_size: int - offload_per_layer: int - - -class QuantConfig: - def __init__( - self, - ffn_config: BaseQuantizeConfig, - attn_config: BaseQuantizeConfig, - ): - self.ffn_config = ffn_config - self.attn_config = attn_config - - @cache - def get_ffn_metas(self, hidden_dim: int, ffn_dim: int) -> tuple[tp.Any, tp.Any]: - return ( - HQQLinearTritonSavable.get_hqq_meta((hidden_dim, ffn_dim), self.ffn_config), - HQQLinearTritonSavable.get_hqq_meta((ffn_dim, hidden_dim), self.ffn_config), - ) - - -def replace_attn_layers( - model: MixtralForCausalLM, - config: MixtralConfig, - quant_config: QuantConfig, - device: torch.device, -) -> None: - attn_quant_config = quant_config.attn_config - - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - head_dim = hidden_size // num_heads - num_key_value_heads = config.num_key_value_heads - - shapes = [ - (hidden_size, num_heads * head_dim), - (hidden_size, num_key_value_heads * head_dim), - (hidden_size, num_key_value_heads * head_dim), - (num_heads * head_dim, hidden_size), - ] - - shape_to_meta = { - shape: HQQLinearTritonSavable.get_hqq_meta(shape, attn_quant_config) - for shape in shapes - } - - def patch_fct_hqq(shape, quant_config): - meta = shape_to_meta[shape] - layer = HQQLinearTritonSavable(None, quant_config, meta=meta) - return layer - - for layer in model.model.layers: - layer.block_sparse_moe.gate = nn.Linear( - config.hidden_size, - config.num_local_experts, - dtype=torch.float16, - device=device, - bias=False, - ) - - layer.self_attn.q_proj = patch_fct_hqq( - (hidden_size, num_heads * head_dim), attn_quant_config - ) - layer.self_attn.k_proj = patch_fct_hqq( - (hidden_size, num_key_value_heads * head_dim), attn_quant_config - ) - layer.self_attn.v_proj = patch_fct_hqq( - (hidden_size, num_key_value_heads * head_dim), attn_quant_config - ) - layer.self_attn.o_proj = patch_fct_hqq( - (hidden_size, num_heads * head_dim), attn_quant_config - ) - - -@cache -def get_default_ffn_quant_config(ffn_dim: int = 14336, hidden_dim: int = 4096): - quant_config = BaseQuantizeConfig( - nbits=2, - group_size=16, - quant_zero=True, - quant_scale=True, - ) - - meta1 = HQQLinearTritonSavable.get_hqq_meta((hidden_dim, ffn_dim), quant_config) - meta2 = HQQLinearTritonSavable.get_hqq_meta((ffn_dim, hidden_dim), quant_config) - - return quant_config, meta1, meta2 - - -def make_empty_expert( - model_config: MixtralConfig, quant_config: QuantConfig -) -> MixtralBLockSparseTop2MLP_HQQ: - meta1, meta2 = quant_config.get_ffn_metas( - model_config.hidden_size, model_config.intermediate_size - ) - return MixtralBLockSparseTop2MLP_HQQ( - model_config, - quant_config.ffn_config, - meta1, - meta2, - ) - - -def make_and_load_expert_wrapper( - config: MixtralConfig, - quant_config: QuantConfig, - states_dir: str, - expert_uid: tuple[int, int], - device: torch.device, -) -> MixtralExpertWrapper: - layer_idx, expert_idx = expert_uid - - index_path = os.path.join(states_dir, "model.safetensors.index.json") - with open(index_path) as f: - module_idx = f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}" - state_fpath = json.load(f)["weight_map"][f"{module_idx}.w1.W_q"] - - state_dict = load_file(os.path.join(states_dir, state_fpath), device=str(device)) - expert = make_empty_expert(config, quant_config) - expert.load_state_dict(state_dict, strict=True) - - return MixtralExpertWrapper(expert, device) - - -def load_00_expert_state_dict(states_dir: str, device: torch.device): - index_path = os.path.join(states_dir, "model.safetensors.index.json") - with open(index_path) as f: - module_idx = f"model.layers.0.block_sparse_moe.experts.0" - state_fpath = json.load(f)["weight_map"][f"{module_idx}.w1.W_q"] - return load_file(os.path.join(states_dir, state_fpath), device=str(device)) - - -def build_model( - device: torch.device, - quant_config: QuantConfig, - offload_config: OffloadConfig, - state_path: str, -): - model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1" - - state_dict_00 = load_00_expert_state_dict(state_path, device) - - def _make_module(): - config = AutoConfig.from_pretrained(model_name) - expert = make_empty_expert(config, quant_config) - expert.load_state_dict(state_dict_00) - return MixtralExpertWrapper(expert, device=device) - - with device, with_default_dtype(torch.float16): - model = MixtralForCausalLM( - AutoConfig.from_pretrained( - model_name, - num_local_experts=0, - torch_dtype=torch.float16, - device_map=device, - ), - ) - - model_config = AutoConfig.from_pretrained(model_name) - replace_attn_layers(model, model_config, quant_config, device) - state_index_path = os.path.join(state_path, "model.safetensors.index.json") - with open(state_index_path) as f: - weight_map = json.load(f)["weight_map"] - - trunk_state_path = os.path.join( - state_path, - weight_map["model.embed_tokens.weight"], - ) - model.load_state_dict(load_file(trunk_state_path, device=str(device)), strict=True) - - expert_cache = ExpertCache( - make_module=_make_module, - main_size=offload_config.main_size, - offload_size=offload_config.offload_size, - buffer_size=offload_config.buffer_size, - ) - for layer_idx in trange(model_config.num_hidden_layers, desc="Loading experts"): - curr_layer = model.model.layers[layer_idx] - curr_layer.block_sparse_moe = SparseMoeWrapper( - model_config, - layer_idx, - curr_layer.block_sparse_moe.gate, - expert_cache, - ) - - for expert_idx in range(model_config.num_local_experts): - do_offload = expert_idx < offload_config.offload_per_layer - - expert_wrapper = make_and_load_expert_wrapper( - config=model_config, - quant_config=quant_config, - states_dir=state_path, - expert_uid=(layer_idx, expert_idx), - device=device, - ) - - expert_cache.add_expert( - uid=(layer_idx, expert_idx), - module=expert_wrapper, - eviction_group=layer_idx, - offload=do_offload, - ) - - del expert_wrapper - torch.cuda.synchronize(device) - torch.cuda.empty_cache() - - return model \ No newline at end of file diff --git a/text2text/mixtral/custom_layers.py b/text2text/mixtral/custom_layers.py deleted file mode 100644 index 8909d17..0000000 --- a/text2text/mixtral/custom_layers.py +++ /dev/null @@ -1,336 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Artyom Eliseev, Denis Mazur -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import copy -import functools -from transformers.models.mixtral.configuration_mixtral import MixtralConfig -from transformers.activations import ACT2FN -from typing import Dict, Any -from hqq.core.quantize import HQQLinear, Quantizer - -import torch -from torch import nn -from torch.nn import functional as F - -from .packing import pack_4bit_u8_common, pack_2bit_u8_common, unpack_4bit_u8_common, unpack_2bit_u8_common -from .triton_kernels import triton_matmul4_transpose, triton_matmul3_transpose, triton_matmul2_transpose - - -class HQQLinearTritonSavable(HQQLinear): - def __init__(self, layer, quant_config, meta=None, **kwargs): - """ - Example how to get meta: - >>>> meta1 = HQQLinearSavable.get_hqq_meta((hidden_dim, ffn_dim), quant_config) - >>>> meta2 = HQQLinearSavable.get_hqq_meta((ffn_dim, hidden_dim), quant_config) - """ - - assert quant_config['weight_quant_params']['nbits'] in [2, 3, 4] - - super().__init__(layer, quant_config, **kwargs) - - if not hasattr(self, 'meta'): - assert meta is not None - self.meta = copy.deepcopy(meta) - - self._register_state_dict_hook(self._add_to_state_dict_hook) - self._register_load_state_dict_pre_hook(self._load_from_state_dict_hook) - - def quantize(self, *args, **kwargs): - super().quantize(*args, **kwargs) - - # repacking - self.repack() - - def repack(self): - if self.W_q.shape != self.meta['shape']: - W_q = Quantizer.unpack[self.meta['packing']](self.W_q) - sh = self.meta['shape'] - W_q = W_q.reshape((-1,) + sh[1:]) - W_q = W_q[:sh[0], ...] - self.W_q = Quantizer.pack[self.meta['packing']](W_q) - - def forward(self, x): - return self.forward_triton(x) - - def set_backend(self, backend): - pass - - @torch.inference_mode() - def forward_triton(self, x): - assert self.ready, "model was not quantized" - assert self.meta['axis'] == 0 - - W_q, meta = self.W_q, self.meta - - del_keys = [] - if 'quant_scale' in meta and meta['quant_scale']: - meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale') - if 'quant_zero' in meta and meta['quant_zero']: - meta['zero'] = Quantizer.dequantize(meta['zero_q'], meta['meta_zero']); del_keys.append('zero') - - K = meta['shape'][1] - N = meta['shape'][0] - - if self.meta['nbits'] == 4: - fn = triton_matmul4_transpose - elif self.meta['nbits'] == 3: - fn = functools.partial(triton_matmul3_transpose, N=N) - elif self.meta['nbits'] == 2: - fn = triton_matmul2_transpose - else: - raise RuntimeError(f"nbits == {self.meta['nbits']} isn't yet supported") - - output = fn( - meta['group_size'], x, - W_q.view(-1, K), - meta['scale'].view(-1, K), - meta['zero'].view(-1, K), - bias=self.bias if hasattr(self, 'bias') else None, - ) - - #Cleanup - for key in del_keys: - del meta[key] - - return output - - # to support .forward_pytorch(...) - backward compatibility - @torch.inference_mode() - def dequantize(self): - assert self.ready, "model was not quantized" - W_q, meta = self.W_q, self.meta - del_keys = [] - if(meta['quant_scale']): - meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale') - if(meta['quant_zero']): - meta['zero'] = Quantizer.dequantize(meta['zero_q'], meta['meta_zero']); del_keys.append('zero') - - W_q_p = Quantizer.unpack[meta['packing']](W_q).half() - W_q_p = W_q_p[:meta['shape'][0], ...] - W_q_p = W_q_p.reshape((meta['group_size'], -1)) - - if((meta['group_size'] is not None) and (meta['nbits']==3)): - W_q_p = W_q_p[:meta['group_size']] if (meta['axis']==0) else W_q_p[:,:meta['group_size']] - W_est = ((W_q_p - meta['zero'])*meta['scale']).reshape(meta['shape']) - - #Cleanup - del W_q_p - for key in del_keys: del meta[key] - return W_est - - @classmethod - def get_hqq_meta(cls, linear_shape, quant_config): - layer = HQQLinear(nn.Linear(*linear_shape, bias=False), quant_config) - meta = layer.meta - - def _remove_tensors_recursive(d): - keys = list(d.keys()) - - for k in keys: - if isinstance(d[k], torch.Tensor): - del d[k] - elif isinstance(d[k], dict): - _remove_tensors_recursive(d[k]) - - _remove_tensors_recursive(meta) - - return meta - - @staticmethod - def _add_to_state_dict_hook(self, state_dict, prefix, local_metadata): - tensor_paths = self._get_tensor_paths(self.meta) - assert set(tensor_paths).issubset( - {'scale_q', 'meta_scale.scale', 'meta_scale.zero', 'zero_q', 'meta_zero.scale', 'meta_zero.zero', - 'scale', 'zero'} - ) - - def _add(name, value): - state_dict[prefix + name] = value - - _add('W_q', self.W_q) - - if self.bias is not None: - _add('bias', self.bias) - - if 'meta_scale' in self.meta: - _add('meta.scale_q', self.meta['scale_q']) - _add('meta.meta_scale.scale', self.meta['meta_scale']['scale']) - _add('meta.meta_scale.zero', self.meta['meta_scale']['zero']) - else: - _add('meta.scale', self.meta['scale']) - - if 'meta_zero' in self.meta: - _add('meta.zero_q', self.meta['zero_q']) - _add('meta.meta_zero.scale', self.meta['meta_zero']['scale']) - _add('meta.meta_zero.zero', self.meta['meta_zero']['zero']) - else: - _add('meta.zero', self.meta['zero']) - - return state_dict - - def _load_from_state_dict_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - tensor_paths = [k[len(prefix + 'meta.'):] for k in state_dict.keys() if k.startswith(prefix + 'meta.')] - assert set(tensor_paths).issubset( - {'scale_q', 'meta_scale.scale', 'meta_scale.zero', 'zero_q', 'meta_zero.scale', 'meta_zero.zero', - 'scale', 'zero'} - ) - - def _del(name): - del state_dict[prefix + name] - def _set(name): - setattr(self, name, state_dict[prefix + name]) - _del(name) - def _get(name): - v = state_dict[prefix + name] - _del(name) - return v - - _set('W_q') - if 'bias' in state_dict: - _set('bias') - else: - self.bias = None - - if not hasattr(self, 'meta'): - self.meta = {} - - if (prefix + 'meta.meta_scale.scale') in state_dict: - self.meta['scale_q'] = _get('meta.scale_q') - self.meta['quant_scale'] = True - if not 'meta_scale' in self.meta: - self.meta['meta_scale'] = {} - self.meta['meta_scale'] |= { - 'scale': _get('meta.meta_scale.scale'), - 'zero': _get('meta.meta_scale.zero') - } - else: - self.meta['scale'] = _get('meta.scale') - if (prefix + 'meta.meta_zero.scale') in state_dict: - self.meta['zero_q'] = _get('meta.zero_q') - self.meta['quant_zero'] = True - if not 'meta_zero' in self.meta: - self.meta['meta_zero'] = {} - self.meta['meta_zero'] |= { - 'scale': _get('meta.meta_zero.scale'), - 'zero': _get('meta.meta_zero.zero') - } - else: - self.meta['zero'] = _get('meta.zero') - self.ready = True - - # self.cuda() - # self.in_gpu = self.W_q.device.type == 'cuda' - # assert self.in_gpu - - self.repack() - - @classmethod - def _get_tensor_paths(cls, state: Dict[str, Any], prefix=''): - paths = [] - - for k, v in state.items(): - if isinstance(v, dict): - paths += cls._get_tensor_paths(v, prefix=k + '.') - elif isinstance(v, torch.Tensor): - paths.append(prefix + k) - - return paths - - def state_dict(self, *args, **kwargs): - return nn.Module.state_dict(self, *args, **kwargs) - - def load_state_dict(self, *args, **kwargs): - nn.Module.load_state_dict(self, *args, **kwargs) - - -class MixtralBLockSparseTop2MLP_HQQ(nn.Module): - def __init__(self, config: MixtralConfig, quant_config: Dict[str, Any], meta1, meta2): - super().__init__() - - self.w1 = HQQLinearTritonSavable(None, quant_config, meta1) - self.w2 = HQQLinearTritonSavable(None, quant_config, meta2) - self.w3 = HQQLinearTritonSavable(None, quant_config, meta1) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - - -class SparseMoeWrapper(nn.Module): - def __init__(self, config, layer_id, gate, expert_cache): - super().__init__() - - self.hidden_dim = config.hidden_size - self.ffn_dim = config.intermediate_size - self.num_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - self.layer_id = layer_id - - self.gate = gate - self.experts = expert_cache - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - active_experts = selected_experts.flatten().unique().tolist() - - # Loop over all available experts in the model and perform the computation on each expert - for (_layer_index, expert_idx), expert_layer in self.experts.load_experts( - *((self.layer_id, expert_idx) for expert_idx in active_experts), unordered=True): - idx, top_x = torch.where(expert_mask[expert_idx]) - assert top_x.shape[0] > 0 - - # in torch it is faster to index using lists than torch tensors - top_x_list = top_x.tolist() - idx_list = idx.tolist() - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states, router_logits \ No newline at end of file diff --git a/text2text/mixtral/expert_cache.py b/text2text/mixtral/expert_cache.py deleted file mode 100644 index 9d47625..0000000 --- a/text2text/mixtral/expert_cache.py +++ /dev/null @@ -1,223 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Artyom Eliseev, Denis Mazur -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Iterator, Tuple, List -from collections import deque, defaultdict, OrderedDict -from .expert_wrapper import MixtralExpertWrapper - -import torch -from torch import nn - -ExpertUID = Any - - -@dataclass(frozen=False) -class ExpertInfo: - uid: ExpertUID - eviction_group: int - offloaded: bool - index: int - - -@dataclass -class EvictionGroupInfo: - # infos in main and offload devices; ordered from least recently used to most - main_infos: OrderedDict[ExpertUID, ExpertInfo] = field(default_factory=OrderedDict) - offloaded_infos: OrderedDict[ExpertUID, ExpertInfo] = field(default_factory=OrderedDict) - hits: int = field(default=0) - misses: int = field(default=0) - - def add(self, info: ExpertInfo): - infos_odict = self.offloaded_infos if info.offloaded else self.main_infos - assert info.uid not in infos_odict, f"expert {info.uid} already exists" - infos_odict[info.uid] = info - - def choose_expert_to_evict(self) -> ExpertInfo: - for uid, info in self.main_infos.items(): - return info # least recently used - raise ValueError("No evictable experts") - - def swap(self, info_to_load: ExpertInfo, info_to_evict: ExpertInfo): - assert info_to_load.uid in self.offloaded_infos and info_to_evict.uid in self.main_infos - self.main_infos[info_to_load.uid] = self.offloaded_infos.pop(info_to_load.uid) - self.main_infos.move_to_end(info_to_load.uid, last=True) - self.offloaded_infos[info_to_evict.uid] = self.main_infos.pop(info_to_evict.uid) - - def mark_used(self, info: ExpertInfo): - if info.uid in self.main_infos: - self.main_infos.move_to_end(info.uid, last=True) - self.hits += 1 - elif info.uid in self.offloaded_infos: - self.offloaded_infos.move_to_end(info.uid, last=True) - self.misses += 1 - else: - raise ValueError(f"Expert {info} not in group") - - -class ExpertCache: - def __init__(self, make_module: callable, main_size: int, offload_size: int, buffer_size: int): - """Dynamically loads an array of modules with identical hyperparameters""" - self.module_type = self.module_size = self.device = None - self.active = False - - self.registered_experts: Dict[ExpertUID, ExpertInfo] = dict() - - self.main_modules = [self._check_module(make_module()) for i in range(main_size)] - self.main_infos: List[Optional[ExpertInfo]] = [None for _ in range(main_size)] - - assert self.module_size is not None - self.offloaded_storages = [ - torch.UntypedStorage(self.module_size).pin_memory(self.device) for _ in range(offload_size)] - self.offloaded_infos: List[Optional[ExpertInfo]] = [None for _ in range(offload_size)] - - # temporary storage to shave off latency - self.device_expert_buffers = deque([self._check_module(make_module()) for _ in range(buffer_size)]) - self.offloaded_storage_buffers = deque([ - torch.UntypedStorage(self.module_size).pin_memory(self.device) for _ in range(buffer_size)]) - self.group_infos: Dict[int, EvictionGroupInfo] = defaultdict(EvictionGroupInfo) - - def _check_module(self, module: MixtralExpertWrapper): - assert isinstance(module.storage, torch.UntypedStorage) - if self.module_type is None: - self.module_type = type(module) - self.module_size = len(module.storage) - self.device = module.storage.device - else: - assert isinstance(module, self.module_type) - assert len(module.storage) == self.module_size - assert module.storage.device == self.device - return module - - def add_expert(self, uid: ExpertUID, module: MixtralExpertWrapper, eviction_group: int = 0, - offload: Optional[bool] = None): - """Register an expert to the cache and associate it with uid""" - assert self.module_type is not None - assert isinstance(module, self.module_type) - return self.add_expert_storage(uid, module.storage, eviction_group=eviction_group, offload=offload) - - def add_expert_storage(self, uid: ExpertUID, storage: torch.UntypedStorage, - eviction_group: int = 0, offload: Optional[bool] = None): - assert uid not in self.registered_experts, f"expert {uid} already registered" - assert isinstance(storage, torch.UntypedStorage) - assert len(storage) == self.module_size - - if offload is None or not offload: # False or None - for i in range(len(self.main_modules)): - if self.main_infos[i] is None: - self.main_modules[i].storage.copy_(storage) - info = ExpertInfo(uid, eviction_group=eviction_group, offloaded=False, index=i) - self.registered_experts[uid] = self.main_infos[i] = info - self.group_infos[eviction_group].add(info) - return # done allocating; found spot on device - if offload is None or offload: # True or None - for i in range(len(self.offloaded_storages)): - if self.offloaded_infos[i] is None: - self.offloaded_storages[i].copy_(storage) - info = ExpertInfo(uid, eviction_group=eviction_group, offloaded=True, index=i) - self.registered_experts[uid] = self.offloaded_infos[i] = info - self.group_infos[eviction_group].add(info) - return # done allocating; found an offloaded spot - raise ValueError("Cache is full") - - def load_experts( - self, *uids: ExpertUID, unordered: bool = False) -> Iterator[Tuple[ExpertUID, MixtralExpertWrapper]]: - """ - :example: - >>> for uid, expert in expert_cache.load_experts(*list_of_uids, unordered=True): - >>> for uid, expert in expert_iter: - >>> result += expert(x) * get_moe_weight(uid) - - :param uids: iterate over the specified expert uids. Same uids as in add_expert - :param unordered: if True, allows cache to iterate experts in arbitrary order - The order is chosen to minimize the total wait time. - :returns: an iterator that yields (uid, expert) pairs, only usable inside the for loop - - """ - assert len(set(uids)) == len(uids) - assert not self.active, "already loading experts; buffers are busy" - if unordered: # yield non-offloaded experts first - uids = sorted(uids, key=lambda uid: self.registered_experts[uid].offloaded) - infos = [self.registered_experts[uid] for uid in uids] - - assert len(set(info.eviction_group for info in infos)) == 1, "experts must be in the same evicton group" - eviction_group = self.group_infos[infos[0].eviction_group] - for info in infos: - eviction_group.mark_used(info) - - try: - self.active = True - # save pre-loaded experts before they can be swapped - pre_loaded_infos = deque([info for info in infos if not info.offloaded]) - pre_loaded_experts = deque([self.main_modules[info.index] for info in pre_loaded_infos]) - - # begin loading experts into free buffers in background (via non-blocking copy) - infos_to_load = deque([info for info in infos if info.offloaded]) - infos_in_loading = deque([]) - experts_in_loading = deque([]) - window_size = min(len(self.device_expert_buffers) - 1, - len(eviction_group.main_infos), - len(infos_to_load)) - for _ in range(window_size): - info_to_load = infos_to_load.popleft() - infos_in_loading.append(info_to_load) - experts_in_loading.append( - self._swap(info_to_load, eviction_group.choose_expert_to_evict())) - - for info in infos: - if len(pre_loaded_infos) > 0 and info is pre_loaded_infos[0]: - pre_loaded_infos.popleft() - yield (info.uid, pre_loaded_experts.popleft()) - elif len(infos_in_loading) > 0 and info is infos_in_loading[0]: - infos_in_loading.popleft() - yield (info.uid, experts_in_loading.popleft()) - if len(infos_to_load) > 0: - info_to_load = infos_to_load.popleft() - infos_in_loading.append(info_to_load) - experts_in_loading.append( - self._swap(info_to_load, eviction_group.choose_expert_to_evict())) - else: - raise RuntimeError("internal error: caching algorithm failed") - finally: - self.active = False - - def _swap(self, info_to_load: ExpertInfo, info_to_evict: ExpertInfo) -> nn.Module: - """Swap an offloaded expert (info_to_load) with an on-device expert (info_to_evict) return the loaded expert""" - assert info_to_load.offloaded and not info_to_evict.offloaded - assert info_to_load.eviction_group == info_to_evict.eviction_group - # swap a single on-device expert with a single offloaded expert using buffers for parallelism - offloaded_storage_buffer = self.offloaded_storage_buffers.popleft() - device_expert_buffer = self.device_expert_buffers.popleft() - device_expert_buffer.storage.copy_(self.offloaded_storages[info_to_load.index], non_blocking=True) - offloaded_storage_buffer.copy_(self.main_modules[info_to_evict.index].storage, non_blocking=True) - - self.device_expert_buffers.append(self.main_modules[info_to_evict.index]) - self.main_modules[info_to_evict.index] = device_expert_buffer - self.offloaded_storage_buffers.append(self.offloaded_storages[info_to_load.index]) - self.offloaded_storages[info_to_load.index] = offloaded_storage_buffer - - self.main_infos[info_to_evict.index] = info_to_load - self.offloaded_infos[info_to_load.index] = info_to_evict - info_to_evict.offloaded, info_to_load.offloaded = info_to_load.offloaded, info_to_evict.offloaded - info_to_evict.index, info_to_load.index = info_to_load.index, info_to_evict.index - self.group_infos[info_to_load.eviction_group].swap(info_to_load, info_to_evict) - return device_expert_buffer \ No newline at end of file diff --git a/text2text/mixtral/expert_wrapper.py b/text2text/mixtral/expert_wrapper.py deleted file mode 100644 index a37fdd3..0000000 --- a/text2text/mixtral/expert_wrapper.py +++ /dev/null @@ -1,107 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Artyom Eliseev, Denis Mazur -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import typing as tp - -import torch -from torch import nn - -from .utils import nested_flatten, nested_pack - - -class MixtralExpertWrapper(nn.Module): - def __init__( - self, - expert_module: tp.Any, - device: torch.device, - ): - super().__init__() - - expert_module, self.storage = self.replace_layer_storage(expert_module, device) - self.expert_module = lambda *args, **kwargs: expert_module(*args, **kwargs) - - self._register_state_dict_hook(self._add_storage_to_state_dict_hook) - self._register_load_state_dict_pre_hook(self._load_storage_from_state_dict_hook) - - @staticmethod - def _add_storage_to_state_dict_hook(self, state_dict, prefix, local_metadata): - state_dict[prefix + 'storage'] = torch.as_tensor(self.storage, dtype=torch.uint8) - return state_dict - - def _load_storage_from_state_dict_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - self.storage.copy_(state_dict[prefix + 'storage'].storage().untyped()) - del state_dict[prefix + 'storage'] - - def forward(self, *args, **kwargs): - return self.expert_module(*args, **kwargs) - - - @staticmethod - def replace_layer_storage( - layer: tp.Any, - device: torch.device, - ): - state_dict = { - f"w{i}": { - "W_q": getattr(layer, f"w{i}").W_q, - "meta": getattr(layer, f"w{i}").meta, - "bias": getattr(layer, f"w{i}").bias, - } - for i in range(1, 4) - } - - storage_size = 0 - offsets = [0] - - for x in nested_flatten(state_dict): - if not isinstance(x, torch.Tensor): - continue - storage_size += x.nbytes - offsets.append(storage_size) - - storage = torch.UntypedStorage(storage_size, device=device) - - i = 0 - new_flattened_states = list() - for x in nested_flatten(state_dict): - if not isinstance(x, torch.Tensor): - new_flattened_states.append(x) - continue - - start = offsets[i] - end = offsets[i + 1] - a_view = torch.as_tensor(storage[start:end], dtype=x.dtype, device=device).view(x.shape) - a_view[...] = x - assert a_view.data_ptr() == storage.data_ptr() + start - i += 1 - new_flattened_states.append(a_view) - - state_dict = nested_pack(new_flattened_states, state_dict) - - for layer_id, states in state_dict.items(): - patched = getattr(layer, layer_id) - patched.W_q = states["W_q"] - patched.meta = states["meta"] - patched.bias = states["bias"] - setattr(layer, layer_id, patched) - - return layer, storage \ No newline at end of file diff --git a/text2text/mixtral/packing.py b/text2text/mixtral/packing.py deleted file mode 100644 index 3e1d6ef..0000000 --- a/text2text/mixtral/packing.py +++ /dev/null @@ -1,135 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Artyom Eliseev, Denis Mazur -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import torch -from hqq.core.quantize import Quantizer -from hqq.core.bitpack import BitPack - -class PackedTensor(torch.Tensor): - def __init__(self, t: torch.Tensor): - self = t - -# 4 bit to uint8 -def pack_4bit_u8_common(W_q: torch.Tensor): - height = W_q.size(0) - assert height % 2 == 0 - - W_q = W_q.to(torch.uint8) - p = (W_q[::2, ...] << 4) | (W_q[1::2, ...]) - - return PackedTensor(p.to(torch.uint8)) - -def unpack_4bit_u8_common(W_q: torch.Tensor): - height = W_q.size(0) - W_q = W_q.to(torch.uint8) - result = torch.empty([2 * height] + list(W_q.shape[1:]), - dtype=torch.uint8, device=W_q.device) - result[::2, ...] = (W_q >> 4) - result[1::2, ...] = (W_q & 0b1111) - - return result - -def unpack_4bit_u8_universal(W_q: torch.Tensor): - if isinstance(W_q, PackedTensor): - return unpack_4bit_u8_common(W_q) - else: - return BitPack.unpack_4bit_u8(W_q) - -# 2 bit to uin8 -def pack_2bit_u8_common(W_q: torch.Tensor): - W_q = W_q.to(torch.uint8) - height = W_q.size(0) - p = (W_q[::4, ...] << 6) | (W_q[1::4, ...] << 4) | (W_q[2::4, ...] << 2) | (W_q[3::4, ...]) - - return PackedTensor(p) - -def unpack_2bit_u8_common(W_q: torch.Tensor): - W_q = W_q.to(torch.uint8) - height = W_q.size(0) - result = torch.empty([4 * height] + list(W_q.shape[1:]), - dtype=torch.uint8, device=W_q.device) - result[::4, ...] = (W_q >> 6) & 0b11 - result[1::4, ...] = (W_q >> 4) & 0b11 - result[2::4, ...] = (W_q >> 2) & 0b11 - result[3::4, ...] = W_q & 0b11 - - return result - -def unpack_2bit_u8_universal(W_q: torch.Tensor): - if isinstance(W_q, PackedTensor): - return unpack_2bit_u8_common(W_q) - else: - return BitPack.unpack_2bit_u8(W_q) - -# 3 bit to int32 -def pack_3bit_i32_common(W_q: torch.Tensor): - height = W_q.size(0) - - # rounding height to nearest 10, because i32 can fit 10 3-bit integers - rem = height % 10 - if rem == 0: - rem = 10 - - new_height = (height + 10 - 1) // 10 - p = torch.zeros((new_height,) + W_q.shape[1:], device=W_q.device, dtype=torch.int32) - - for i in range(10): - if i < rem: - p |= W_q[i::10, ...].to(torch.int32) << (3 * (9 - i)) - else: - p[:new_height - 1, ...] |= W_q[i::10, ...].to(torch.int32) << (3 * (9 - i)) - - assert p.dtype == torch.int32 - - return PackedTensor(p) - -def unpack_3bit_i32_common(W_q: torch.Tensor): - """ - There may be spare rows after unpacking (height is rounded to nearest multiple of 10) - """ - - assert W_q.dtype == torch.int32 - height = W_q.size(0) - - result = torch.empty([10 * height] + list(W_q.shape[1:]), - dtype=torch.uint8, device=W_q.device) - - for i in range(10): - result[i::10, ...] = (W_q >> (3 * (9 - i))) & 0b111 - - return result - -def unpack_3bit_i32_universal(W_q: torch.Tensor): - if isinstance(W_q, PackedTensor): - return unpack_3bit_i32_common(W_q) - else: - return BitPack.unpack_3bit_32(W_q) - -def patch_packing(): - Quantizer.pack['4bit_u8'] = pack_4bit_u8_common - Quantizer.unpack['4bit_u8'] = unpack_4bit_u8_universal - Quantizer.pack['2bit_u8'] = pack_2bit_u8_common - Quantizer.unpack['2bit_u8'] = unpack_2bit_u8_universal - Quantizer.pack['3bit_32'] = pack_3bit_i32_common - Quantizer.unpack['3bit_32'] = unpack_3bit_i32_universal - -patch_packing() \ No newline at end of file diff --git a/text2text/mixtral/triton_kernels.py b/text2text/mixtral/triton_kernels.py deleted file mode 100644 index 2fc6626..0000000 --- a/text2text/mixtral/triton_kernels.py +++ /dev/null @@ -1,586 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Artyom Eliseev, Denis Mazur -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import triton -import triton.language as tl -import torch -from typing import Optional - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': N, - 'BLOCK_SIZE_K': K, 'GROUP_SIZE_M': 1}, - num_stages=S, num_warps=W) for N, K, S, W in - [ -# (32, 16, 1, 2), - (32, 32, 4, 4), -# (32, 32, 5, 2), -# (32, 32, 5, 8), -# (32, 128, 2, 4), -# (64, 32, 2, 4), -# (64, 32, 3, 4), -# (64, 32, 4, 4), -# (64, 32, 4, 8), -# (64, 32, 5, 2), -# (64, 32, 5, 8), -# (64, 64, 3, 8), -# (128, 32, 2, 8), -# (128, 32, 3, 4), -# (128, 32, 3, 8), -# (128, 32, 4, 4), -# (128, 32, 4, 8), -# (256, 32, 3, 8), -# (256, 32, 4, 4), -# (256, 64, 3, 8), - ] - - ], - key=['M', 'N', 'K'], -) -@triton.jit -def matmul4_kernel_transpose( - a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, - M, N, K, - stride_am, stride_ak, - stride_bn, stride_bk, - stride_cm, stride_cn, - stride_scales_g, stride_scales_n, - stride_zeros_g, stride_zeros_n, - groupsize, NO_GROUPS: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (N//2, K) int32 - C is of shape (M, N) float16 - scales is of shape (G, K) float16 - zeros is of shape (G, K) int32 - groupsize is an int specifying the size of groups for scales and zeros. - G is N // groupsize. - Set NO_GROUPS to groupsize == N, in which case G = 1 and the kernel is more efficient. - - WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K. - WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N. - WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K. - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group # - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the N axis 2 times - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + (offs_bn[None, :] // 2) * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - - G = N // groupsize - scales_ptrs = scales_ptr + (offs_bn[None, :] % G) * stride_scales_g # (1, BLOCK_SIZE_N) - zeros_ptrs = zeros_ptr + (offs_bn[None, :] % G) * stride_zeros_g # (1, BLOCK_SIZE_N) - - # shifter is used to extract the 4 bits of each element in the 8-bit word from B - shifter = ((offs_bn + 1) % 2) * 4 - - # If G == 1, scales and zeros are the same for all N, so we can load them once - if NO_GROUPS: - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 - - # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) - # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension - # So this loop is along the infeatures dimension (K) - # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, num_pid_k): - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - if not NO_GROUPS: - offs_k_scale = BLOCK_SIZE_K * k + offs_k - ptr = scales_ptrs + offs_k_scale[:, None] * stride_scales_n # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - scales = tl.load(ptr) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - ptr = zeros_ptrs + offs_k_scale[:, None] * stride_zeros_n # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(ptr) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - # Now we need to unpack b (which is 4-bit values) into 8-bit values - b = (b >> shifter[None, :]) & 0xF # Extract the 4-bit values - b = b.to(tl.float16) - b = (b - zeros) * scales # Scale and shift - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - c = accumulator.to(tl.float16) - - # Store the result - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - -def triton_matmul4_transpose(groupsize: int, a: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, zeros: torch.FloatTensor, bias: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: - """ - Compute the matrix multiplication C = A x B + bias. - Where B is quantized using GPTQ and groupsize = -1 into 4-bit values. - - A is of shape (M, K) float16 - qweight is of shape (N//2, K) int32 - scales is of shape (G, K) float16 - zeros is of shape (G, K) float16 - bias is of shape (1, N) float16 - - groupsize is the number of infeatures in each group. - G = N // groupsize - - C = A @ qweight.T - Returns C of shape (..., N) float16 - """ - assert a.shape[-1] == (qweight.shape[1]) - assert a.is_contiguous(), "A must be contiguous" - assert scales.shape[1] == zeros.shape[1] - assert scales.shape[1] == qweight.shape[1] - - # Flatten a into (-1, K) - x = a.view(-1, a.shape[-1]) - - M, K = x.shape - N = qweight.shape[0] * 2 - # This is based on the possible BLOCK_SIZE_Ks -# assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" - # This is based on the possible BLOCK_SIZE_Ns -# assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0 and N % 128 == 0 and N % 256 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" - # This is based on the possible BLOCK_SIZE_Ks -# assert groupsize % 32 == 0 and groupsize % 64 == 0 and groupsize % 128 == 0, "groupsize must be a multiple of 32, 64, and 128" - - c = torch.empty((M, N), device='cuda', dtype=torch.float16) - - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - matmul4_kernel_transpose[grid]( - x, qweight, c, - scales, zeros, - M, N, K, - x.stride(0), x.stride(1), - qweight.stride(0), qweight.stride(1), - c.stride(0), c.stride(1), - scales.stride(0), scales.stride(1), - zeros.stride(0), zeros.stride(1), - groupsize, groupsize == N, - ) - - # Reshape c - c = c.view(a.shape[:-1] + (N,)) # (..., N) - - # Add bias - if bias is not None: - c = c + bias - - return c - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': N, - 'BLOCK_SIZE_K': K, 'GROUP_SIZE_M': 1}, - num_stages=S, num_warps=W) for N, K, S, W in - [ -# (32, 16, 1, 2), - (32, 32, 4, 4), # best -# (32, 32, 5, 2), -# (32, 32, 5, 8), -# (32, 128, 2, 4), -# (64, 32, 2, 4), -# (64, 32, 3, 4), -# (64, 32, 4, 4), -# (64, 32, 4, 8), -# (64, 32, 5, 2), -# (64, 32, 5, 8), -# (64, 64, 3, 8), -# (128, 32, 2, 8), -# (128, 32, 3, 4), -# (128, 32, 3, 8), -# (128, 32, 4, 4), -# (128, 32, 4, 8), -# (256, 32, 3, 8), -# (256, 32, 4, 4), -# (256, 64, 3, 8), - ] - - ], - key=['M', 'N', 'K'], -) -@triton.jit -def matmul2_kernel_transpose( - a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, - M, N, K, - stride_am, stride_ak, - stride_bn, stride_bk, - stride_cm, stride_cn, - stride_scales_g, stride_scales_n, - stride_zeros_g, stride_zeros_n, - groupsize, NO_GROUPS: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (N // 4, K) int8 - C is of shape (M, N) float16 - scales is of shape (G, K) float16 - zeros is of shape (G, K) int32 - groupsize is an int specifying the size of groups for scales and zeros. - G is N // groupsize. - Set NO_GROUPS to groupsize == N, in which case G = 1 and the kernel is more efficient. - - WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K. - WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N. - WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K. - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group # - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the N axis 4 times - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + (offs_bn[None, :] // 4) * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - - G = N // groupsize - scales_ptrs = scales_ptr + (offs_bn[None, :] % G) * stride_scales_g # (1, BLOCK_SIZE_N) - zeros_ptrs = zeros_ptr + (offs_bn[None, :] % G) * stride_zeros_g # (1, BLOCK_SIZE_N) - - # shifter is used to extract the 2 bits of each element in the 8-bit word from B - shifter = (3 - (offs_bn % 4)) * 2 - - # If G == 1, scales and zeros are the same for all N, so we can load them once - if NO_GROUPS: - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,) - - # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) - # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension - # So this loop is along the infeatures dimension (K) - # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, num_pid_k): - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - if not NO_GROUPS: - offs_k_scale = BLOCK_SIZE_K * k + offs_k - ptr = scales_ptrs + offs_k_scale[:, None] * stride_scales_n # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - scales = tl.load(ptr) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - ptr = zeros_ptrs + offs_k_scale[:, None] * stride_zeros_n # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(ptr) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - # Now we need to unpack b (which is 4-bit values) into 8-bit values - b = (b >> shifter[None, :]) & 0b11 # Extract the 2-bit values - b = b.to(tl.float16) - b = (b - zeros) * scales # Scale and shift - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - c = accumulator.to(tl.float16) - - # Store the result - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - -def triton_matmul2_transpose(groupsize: int, a: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, zeros: torch.FloatTensor, bias: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: - """ - Compute the matrix multiplication C = A x B + bias. - Where B is quantized using GPTQ and groupsize = -1 into 4-bit values. - - A is of shape (M, K) float16 - qweight is of shape (N // 4, K) int32 - scales is of shape (G, K) float16 - zeros is of shape (G, K) float16 - bias is of shape (1, N) float16 - - groupsize is the number of infeatures in each group. - G = N // groupsize - - C = A @ qweight.T - Returns C of shape (..., N) float16 - """ - - assert a.shape[-1] == (qweight.shape[1]) - assert a.is_contiguous(), "A must be contiguous" - assert scales.shape[1] == zeros.shape[1] - assert scales.shape[1] == qweight.shape[1] - - # Flatten a into (-1, K) - x = a.view(-1, a.shape[-1]) - - M, K = x.shape - N = qweight.shape[0] * 4 - # This is based on the possible BLOCK_SIZE_Ks -# assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" - # This is based on the possible BLOCK_SIZE_Ns -# assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0 and N % 128 == 0 and N % 256 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" - # This is based on the possible BLOCK_SIZE_Ks -# assert groupsize % 32 == 0 and groupsize % 64 == 0 and groupsize % 128 == 0, "groupsize must be a multiple of 32, 64, and 128" - - c = torch.empty((M, N), device='cuda', dtype=torch.float16) - - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - matmul2_kernel_transpose[grid]( - x, qweight, c, - scales, zeros, - M, N, K, - x.stride(0), x.stride(1), - qweight.stride(0), qweight.stride(1), - c.stride(0), c.stride(1), - scales.stride(0), scales.stride(1), - zeros.stride(0), zeros.stride(1), - groupsize, groupsize == N, - ) - - # Reshape c - c = c.view(a.shape[:-1] + (N,)) # (..., N) - - # Add bias - if bias is not None: - c = c + bias - - return c - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': N, - 'BLOCK_SIZE_K': K, 'GROUP_SIZE_M': 1}, - num_stages=S, num_warps=W) for N, K, S, W in - [ -# (32, 16, 1, 2), -# (32, 32, 4, 4), -# (32, 32, 5, 2), - (32, 32, 5, 8), # best -# (32, 128, 2, 4), -# (64, 32, 2, 4), -# (64, 32, 3, 4), -# (64, 32, 4, 4), -# (64, 32, 4, 8), -# (64, 32, 5, 2), -# (64, 32, 5, 8), -# (64, 64, 3, 8), -# (128, 32, 2, 8), -# (128, 32, 3, 4), -# (128, 32, 3, 8), -# (128, 32, 4, 4), -# (128, 32, 4, 8), -# (256, 32, 3, 8), -# (256, 32, 4, 4), -# (256, 64, 3, 8), - ] - - ], - key=['M', 'N', 'K'], -) -@triton.jit -def matmul3_kernel_transpose( - a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, - M, N, K, - stride_am, stride_ak, - stride_bn, stride_bk, - stride_cm, stride_cn, - stride_scales_g, stride_scales_n, - stride_zeros_g, stride_zeros_n, - groupsize, NO_GROUPS: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (ceil(N / 10), K) int32 - C is of shape (M, N) float16 - scales is of shape (G, K) float16 - zeros is of shape (G, K) int32 - groupsize is an int specifying the size of groups for scales and zeros. - G is N // groupsize. - Set NO_GROUPS to groupsize == N, in which case G = 1 and the kernel is more efficient. - - WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K. - WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N. - WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K. - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group # - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = (offs_am[:, None] < M) - - # b_ptrs is set up such that it repeats elements along the N axis 10 times - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + (offs_bn[None, :] // 10) * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - - G = N // groupsize - scales_ptrs = scales_ptr + (offs_bn[None, :] % G) * stride_scales_g # (1, BLOCK_SIZE_N) - zeros_ptrs = zeros_ptr + (offs_bn[None, :] % G) * stride_zeros_g # (1, BLOCK_SIZE_N) - - # shifter is used to extract the 3 bits of each element in the 32-bit word from B - shifter = (9 - (offs_bn % 10)) * 3 - - # If G == 1, scales and zeros are the same for all N, so we can load them once - if NO_GROUPS: - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,) - - # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) - # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension - # So this loop is along the infeatures dimension (K) - # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, num_pid_k): - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - if not NO_GROUPS: - offs_k_scale = BLOCK_SIZE_K * k + offs_k - ptr = scales_ptrs + offs_k_scale[:, None] * stride_scales_n # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - scales = tl.load(ptr) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - ptr = zeros_ptrs + offs_k_scale[:, None] * stride_zeros_n # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(ptr) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - # Now we need to unpack b (which is 3-bit values into 32-bit values) - b = (b >> shifter[None, :]) & 0b111 # Extract the 3-bit values - b = b.to(tl.float16) - b = (b - zeros) * scales # Scale and shift - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - c = accumulator.to(tl.float16) - - # Store the result - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - -def triton_matmul3_transpose(groupsize: int, a: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, zeros: torch.FloatTensor, N: int, bias: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: - """ - Compute the matrix multiplication C = A x B + bias. - Where B is quantized using GPTQ and groupsize = -1 into 4-bit values. - - A is of shape (M, K) float16 - qweight is of shape (ceil(N / 10), K) int32 - scales is of shape (G, K) float16 - zeros is of shape (G, K) float16 - bias is of shape (1, N) float16 - - groupsize is the number of infeatures in each group. - G = N // groupsize - - C = A @ qweight.T - Returns C of shape (..., N) float16 - """ - - assert a.shape[-1] == (qweight.shape[1]) - assert a.is_contiguous(), "A must be contiguous" - assert scales.shape[1] == zeros.shape[1] - assert scales.shape[1] == qweight.shape[1] - - # Flatten a into (-1, K) - x = a.view(-1, a.shape[-1]) - - M, K = x.shape - assert 0 <= (qweight.shape[0] * 10 - N) < 10 - - c = torch.empty((M, N), device='cuda', dtype=torch.float16) - - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - matmul3_kernel_transpose[grid]( - x, qweight, c, - scales, zeros, - M, N, K, - x.stride(0), x.stride(1), - qweight.stride(0), qweight.stride(1), - c.stride(0), c.stride(1), - scales.stride(0), scales.stride(1), - zeros.stride(0), zeros.stride(1), - groupsize, groupsize == N, - ) - - # Reshape c - c = c.view(a.shape[:-1] + (N,)) # (..., N) - - # Add bias - if bias is not None: - c = c + bias - - return c \ No newline at end of file diff --git a/text2text/mixtral/utils.py b/text2text/mixtral/utils.py deleted file mode 100644 index 32a2464..0000000 --- a/text2text/mixtral/utils.py +++ /dev/null @@ -1,123 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Artyom Eliseev, Denis Mazur -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from contextlib import contextmanager -import torch -""" utility functions that help you process nested dicts, tuples, lists and namedtuples """ - - -def nested_compare(t, u): - """ - Return whether nested structure of t1 and t2 matches. - """ - if isinstance(t, (list, tuple)): - if not isinstance(u, type(t)): - return False - if len(t) != len(u): - return False - for a, b in zip(t, u): - if not nested_compare(a, b): - return False - return True - - if isinstance(t, dict): - if not isinstance(u, dict): - return False - if set(t.keys()) != set(u.keys()): - return False - for k in t: - if not nested_compare(t[k], u[k]): - return False - return True - - else: - return True - - -def nested_flatten(t): - """ - Turn nested list/tuple/dict into a flat iterator. - """ - if isinstance(t, (list, tuple)): - for x in t: - yield from nested_flatten(x) - elif isinstance(t, dict): - for k, v in sorted(t.items()): - yield from nested_flatten(v) - else: - yield t - - -def nested_pack(flat, structure): - """ - Restore nested structure from flattened state - :param flat: result of nested_flatten - :param structure: used as example when recovering structure - :returns: nested structure like :structure: filled with elements of :flat: - """ - return _nested_pack(iter(flat), structure) - - -def _nested_pack(flat_iter, structure): - if is_namedtuple(structure): - return type(structure)(*[_nested_pack(flat_iter, x) for x in structure]) - elif isinstance(structure, (list, tuple)): - return type(structure)(_nested_pack(flat_iter, x) for x in structure) - elif isinstance(structure, dict): - return {k: _nested_pack(flat_iter, v) for k, v in sorted(structure.items())} - else: - return next(flat_iter) - - -def is_namedtuple(x): - """Checks if x is a namedtuple instance. Taken from https://stackoverflow.com/a/2166841 .""" - t = type(x) - b = t.__bases__ - if len(b) != 1 or b[0] != tuple: - return False - f = getattr(t, "_fields", None) - if not isinstance(f, tuple): - return False - return all(type(n) == str for n in f) - - -def nested_map(fn, *t): - # Check arguments. - if not t: - raise ValueError("Expected 2+ arguments, got 1") - for i in range(1, len(t)): - if not nested_compare(t[0], t[i]): - msg = "Nested structure of %r and %r differs" - raise ValueError(msg % (t[0], t[i])) - - flat = map(nested_flatten, t) - return nested_pack(map(fn, *flat), t[0]) - -@contextmanager -def with_default_dtype(dtype): - _dtype_original = torch.get_default_dtype() - - try: - torch.set_default_dtype(dtype) - yield - finally: - torch.set_default_dtype(_dtype_original) \ No newline at end of file diff --git a/text2text/pytorch_pretrained_bert/__init__.py b/text2text/pytorch_pretrained_bert/__init__.py deleted file mode 100755 index 7ea34c4..0000000 --- a/text2text/pytorch_pretrained_bert/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tokenization import BertTokenizer -from .modeling import BertForSeq2SeqDecoder -from .file_utils import cached_path diff --git a/text2text/pytorch_pretrained_bert/__main__.py b/text2text/pytorch_pretrained_bert/__main__.py deleted file mode 100755 index 79ad842..0000000 --- a/text2text/pytorch_pretrained_bert/__main__.py +++ /dev/null @@ -1,22 +0,0 @@ -# coding: utf8 -def main(): - import sys - try: - from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch - except ModuleNotFoundError: - print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " - "In that case, it requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions.") - raise - - if len(sys.argv) != 5: - # pylint: disable=line-too-long - print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") - else: - PYTORCH_DUMP_OUTPUT = sys.argv.pop() - TF_CONFIG = sys.argv.pop() - TF_CHECKPOINT = sys.argv.pop() - convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) - -if __name__ == '__main__': - main() diff --git a/text2text/pytorch_pretrained_bert/file_utils.py b/text2text/pytorch_pretrained_bert/file_utils.py deleted file mode 100755 index 7718eb7..0000000 --- a/text2text/pytorch_pretrained_bert/file_utils.py +++ /dev/null @@ -1,237 +0,0 @@ -""" -Utilities for working with the local dataset cache. -This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp -Copyright by the AllenNLP authors. -""" - -import os -import logging -import shutil -import tempfile -import json -from urllib.parse import urlparse -from pathlib import Path -from typing import Optional, Tuple, Union, IO, Callable, Set -from hashlib import sha256 -from functools import wraps - -from tqdm import tqdm - -import requests - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - -PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', - Path.home() / '.pytorch_pretrained_bert')) - - -def url_to_filename(url: str, etag: str = None) -> str: - """ - Convert `url` into a hashed filename in a repeatable way. - If `etag` is specified, append its hash to the url's, delimited - by a period. - """ - url_bytes = url.encode('utf-8') - url_hash = sha256(url_bytes) - filename = url_hash.hexdigest() - - if etag: - etag_bytes = etag.encode('utf-8') - etag_hash = sha256(etag_bytes) - filename += '.' + etag_hash.hexdigest() - - return filename - - -def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: - """ - Return the url and etag (which may be ``None``) stored for `filename`. - Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. - """ - if cache_dir is None: - cache_dir = PYTORCH_PRETRAINED_BERT_CACHE - if isinstance(cache_dir, Path): - cache_dir = str(cache_dir) - - cache_path = os.path.join(cache_dir, filename) - if not os.path.exists(cache_path): - raise FileNotFoundError("file {} not found".format(cache_path)) - - meta_path = cache_path + '.json' - if not os.path.exists(meta_path): - raise FileNotFoundError("file {} not found".format(meta_path)) - - with open(meta_path) as meta_file: - metadata = json.load(meta_file) - url = metadata['url'] - etag = metadata['etag'] - - return url, etag - - -def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: - """ - Given something that might be a URL (or might be a local path), - determine which. If it's a URL, download the file and cache it, and - return the path to the cached file. If it's already a local path, - make sure the file exists and then return the path. - """ - if cache_dir is None: - cache_dir = PYTORCH_PRETRAINED_BERT_CACHE - if isinstance(url_or_filename, Path): - url_or_filename = str(url_or_filename) - if isinstance(cache_dir, Path): - cache_dir = str(cache_dir) - - parsed = urlparse(url_or_filename) - - if parsed.scheme in ('http', 'https', 's3'): - # URL, so get it from the cache (downloading if necessary) - return get_from_cache(url_or_filename, cache_dir) - elif os.path.exists(url_or_filename): - # File, and it exists. - return url_or_filename - elif parsed.scheme == '': - # File, but it doesn't exist. - raise FileNotFoundError("file {} not found".format(url_or_filename)) - else: - # Something unknown - raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) - - -def split_s3_path(url: str) -> Tuple[str, str]: - """Split a full s3 path into the bucket name and path.""" - parsed = urlparse(url) - if not parsed.netloc or not parsed.path: - raise ValueError("bad s3 path {}".format(url)) - bucket_name = parsed.netloc - s3_path = parsed.path - # Remove '/' at beginning of path. - if s3_path.startswith("/"): - s3_path = s3_path[1:] - return bucket_name, s3_path - - -def s3_request(func: Callable): - """ - Wrapper function for s3 requests in order to create more helpful error - messages. - """ - - @wraps(func) - def wrapper(url: str, *args, **kwargs): - try: - return func(url, *args, **kwargs) - except ClientError as exc: - if int(exc.response["Error"]["Code"]) == 404: - raise FileNotFoundError("file {} not found".format(url)) - else: - raise - - return wrapper - - -@s3_request -def s3_etag(url: str) -> Optional[str]: - """Check ETag on S3 object.""" - s3_resource = boto3.resource("s3") - bucket_name, s3_path = split_s3_path(url) - s3_object = s3_resource.Object(bucket_name, s3_path) - return s3_object.e_tag - - -@s3_request -def s3_get(url: str, temp_file: IO) -> None: - """Pull a file directly from S3.""" - s3_resource = boto3.resource("s3") - bucket_name, s3_path = split_s3_path(url) - s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) - - -def http_get(url: str, temp_file: IO) -> None: - req = requests.get(url, stream=True) - content_length = req.headers.get('Content-Length') - total = int(content_length) if content_length is not None else None - progress = tqdm(unit="B", total=total) - for chunk in req.iter_content(chunk_size=1024): - if chunk: # filter out keep-alive new chunks - progress.update(len(chunk)) - temp_file.write(chunk) - progress.close() - - -def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: - """ - Given a URL, look for the corresponding dataset in the local cache. - If it's not there, download it. Then return the path to the cached file. - """ - if cache_dir is None: - cache_dir = PYTORCH_PRETRAINED_BERT_CACHE - if isinstance(cache_dir, Path): - cache_dir = str(cache_dir) - - os.makedirs(cache_dir, exist_ok=True) - - # Get eTag to add to filename, if it exists. - if url.startswith("s3://"): - etag = s3_etag(url) - else: - response = requests.head(url, allow_redirects=True) - if response.status_code != 200: - raise IOError("HEAD request failed for url {} with status code {}" - .format(url, response.status_code)) - etag = response.headers.get("ETag") - - filename = url_to_filename(url, etag) - - # get cache path to put the file - cache_path = os.path.join(cache_dir, filename) - - if not os.path.exists(cache_path): - # Download to temporary file, then copy to cache dir once finished. - # Otherwise you get corrupt cache entries if the download gets interrupted. - with tempfile.NamedTemporaryFile() as temp_file: - logger.info("%s not found in cache, downloading to %s", url, temp_file.name) - - # GET file object - if url.startswith("s3://"): - s3_get(url, temp_file) - else: - http_get(url, temp_file) - - # we are copying the file before closing it, so flush to avoid truncation - temp_file.flush() - # shutil.copyfileobj() starts at the current position, so go to the start - temp_file.seek(0) - - logger.info("copying %s to cache at %s", temp_file.name, cache_path) - with open(cache_path, 'wb') as cache_file: - shutil.copyfileobj(temp_file, cache_file) - - logger.info("creating metadata file for %s", cache_path) - meta = {'url': url, 'etag': etag} - meta_path = cache_path + '.json' - with open(meta_path, 'w') as meta_file: - json.dump(meta, meta_file) - - logger.info("removing temp file %s", temp_file.name) - - return cache_path - - -def read_set_from_file(filename: str) -> Set[str]: - ''' - Extract a de-duped collection (set) of text from a file. - Expected file format is one item per line. - ''' - collection = set() - with open(filename, 'r', encoding='utf-8') as file_: - for line in file_: - collection.add(line.rstrip()) - return collection - - -def get_file_extension(path: str, dot=True, lower: bool = True): - ext = os.path.splitext(path)[1] - ext = ext if dot else ext[1:] - return ext.lower() if lower else ext diff --git a/text2text/pytorch_pretrained_bert/loss.py b/text2text/pytorch_pretrained_bert/loss.py deleted file mode 100755 index 590f9e4..0000000 --- a/text2text/pytorch_pretrained_bert/loss.py +++ /dev/null @@ -1,48 +0,0 @@ -# coding=utf-8 - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -import torch.nn.functional as F -from torch.nn.modules.loss import _Loss - - -class LabelSmoothingLoss(_Loss): - """ - With label smoothing, - KL-divergence between q_{smoothed ground truth prob.}(w) - and p_{prob. computed by model}(w) is minimized. - """ - - def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0, size_average=None, reduce=None, reduction='mean'): - assert 0.0 < label_smoothing <= 1.0 - self.ignore_index = ignore_index - super(LabelSmoothingLoss, self).__init__( - size_average=size_average, reduce=reduce, reduction=reduction) - - assert label_smoothing > 0 - assert tgt_vocab_size > 0 - - smoothing_value = label_smoothing / (tgt_vocab_size - 2) - one_hot = torch.full((tgt_vocab_size,), smoothing_value) - one_hot[self.ignore_index] = 0 - self.register_buffer('one_hot', one_hot.unsqueeze(0)) - self.confidence = 1.0 - label_smoothing - self.tgt_vocab_size = tgt_vocab_size - - def forward(self, output, target): - """ - output (FloatTensor): batch_size * num_pos * n_classes - target (LongTensor): batch_size * num_pos - """ - assert self.tgt_vocab_size == output.size(2) - batch_size, num_pos = target.size(0), target.size(1) - output = output.view(-1, self.tgt_vocab_size) - target = target.view(-1) - model_prob = self.one_hot.repeat(target.size(0), 1) - model_prob.scatter_(1, target.unsqueeze(1), self.confidence) - model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) - - return F.kl_div(output, model_prob, reduction='none').view(batch_size, num_pos, -1).sum(2) diff --git a/text2text/pytorch_pretrained_bert/modeling.py b/text2text/pytorch_pretrained_bert/modeling.py deleted file mode 100755 index b8ea0fe..0000000 --- a/text2text/pytorch_pretrained_bert/modeling.py +++ /dev/null @@ -1,2210 +0,0 @@ -# coding=utf-8 -"""PyTorch BERT model.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import copy -import json -import math -import logging -import tarfile -import tempfile -import shutil -import numpy as np -from scipy.stats import truncnorm - -import torch -from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss -import torch.nn.functional as F - -from .file_utils import cached_path -from .loss import LabelSmoothingLoss - -logger = logging.getLogger(__name__) - -PRETRAINED_MODEL_ARCHIVE_MAP = { - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", - 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", - 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", -} -CONFIG_NAME = 'bert_config.json' -WEIGHTS_NAME = 'pytorch_model.bin' - - -def gelu(x): - """Implementation of the gelu activation function. - For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): - 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - """ - return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) - - -def swish(x): - return x * torch.sigmoid(x) - - -ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} - - -class BertConfig(object): - """Configuration class to store the configuration of a `BertModel`. - """ - - def __init__(self, - vocab_size_or_config_json_file, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - relax_projection=0, - new_pos_ids=False, - initializer_range=0.02, - task_idx=None, - fp32_embedding=False, - ffn_type=0, - label_smoothing=None, - num_qkv=0, - seg_emb=False): - """Constructs BertConfig. - - Args: - vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. - hidden_size: Size of the encoder layers and the pooler layer. - num_hidden_layers: Number of hidden layers in the Transformer encoder. - num_attention_heads: Number of attention heads for each attention layer in - the Transformer encoder. - intermediate_size: The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder. - hidden_act: The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" are supported. - hidden_dropout_prob: The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler. - attention_probs_dropout_prob: The dropout ratio for the attention - probabilities. - max_position_embeddings: The maximum sequence length that this model might - ever be used with. Typically set this to something large just in case - (e.g., 512 or 1024 or 2048). - type_vocab_size: The vocabulary size of the `token_type_ids` passed into - `BertModel`. - initializer_range: The sttdev of the truncated_normal_initializer for - initializing all weight matrices. - """ - if isinstance(vocab_size_or_config_json_file, str): - with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: - json_config = json.loads(reader.read()) - self.__dict__.update(json_config) - elif isinstance(vocab_size_or_config_json_file, int): - self.vocab_size = vocab_size_or_config_json_file - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.relax_projection = relax_projection - self.new_pos_ids = new_pos_ids - self.initializer_range = initializer_range - self.task_idx = task_idx - self.fp32_embedding = fp32_embedding - self.ffn_type = ffn_type - self.label_smoothing = label_smoothing - self.num_qkv = num_qkv - self.seg_emb = seg_emb - else: - raise ValueError("First argument must be either a vocabulary size (int)" - "or the path to a pretrained model config file (str)") - - @classmethod - def from_dict(cls, json_object): - """Constructs a `BertConfig` from a Python dictionary of parameters.""" - config = BertConfig(vocab_size_or_config_json_file=-1) - config.__dict__.update(json_object) - return config - - @classmethod - def from_json_file(cls, json_file): - """Constructs a `BertConfig` from a json file of parameters.""" - with open(json_file, "r", encoding='utf-8') as reader: - text = reader.read() - return cls.from_dict(json.loads(text)) - - def __repr__(self): - return str(self.to_json_string()) - - def to_dict(self): - """Serializes this instance to a Python dictionary.""" - output = copy.deepcopy(self.__dict__) - return output - - def to_json_string(self): - """Serializes this instance to a JSON string.""" - return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" - - -try: - from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm -except ImportError: - class BertLayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-5): - """Construct a layernorm module in the TF style (epsilon inside the square root). - """ - super(BertLayerNorm, self).__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.bias = nn.Parameter(torch.zeros(hidden_size)) - self.variance_epsilon = eps - - def forward(self, x): - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.variance_epsilon) - return self.weight * x + self.bias - - -class PositionalEmbedding(nn.Module): - def __init__(self, demb): - super(PositionalEmbedding, self).__init__() - - self.demb = demb - - inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) - self.register_buffer('inv_freq', inv_freq) - - def forward(self, pos_seq, bsz=None): - sinusoid_inp = torch.ger(pos_seq, self.inv_freq) - pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) - - if bsz is not None: - return pos_emb[:, None, :].expand(-1, bsz, -1) - else: - return pos_emb[:, None, :] - - -class BertEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings. - """ - - def __init__(self, config): - super(BertEmbeddings, self).__init__() - self.word_embeddings = nn.Embedding( - config.vocab_size, config.hidden_size) - self.token_type_embeddings = nn.Embedding( - config.type_vocab_size, config.hidden_size) - if hasattr(config, 'fp32_embedding'): - self.fp32_embedding = config.fp32_embedding - else: - self.fp32_embedding = False - - if hasattr(config, 'new_pos_ids') and config.new_pos_ids: - self.num_pos_emb = 4 - else: - self.num_pos_emb = 1 - self.position_embeddings = nn.Embedding( - config.max_position_embeddings, config.hidden_size*self.num_pos_emb) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, input_ids, token_type_ids=None, position_ids=None, task_idx=None): - seq_length = input_ids.size(1) - if position_ids is None: - position_ids = torch.arange( - seq_length, dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - if self.num_pos_emb > 1: - num_batch = position_embeddings.size(0) - num_pos = position_embeddings.size(1) - position_embeddings = position_embeddings.view( - num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] - - embeddings = words_embeddings + position_embeddings + token_type_embeddings - if self.fp32_embedding: - embeddings = embeddings.half() - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class BertSelfAttention(nn.Module): - def __init__(self, config): - super(BertSelfAttention, self).__init__() - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads)) - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int( - config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - if hasattr(config, 'num_qkv') and (config.num_qkv > 1): - self.num_qkv = config.num_qkv - else: - self.num_qkv = 1 - - self.query = nn.Linear( - config.hidden_size, self.all_head_size*self.num_qkv) - self.key = nn.Linear(config.hidden_size, - self.all_head_size*self.num_qkv) - self.value = nn.Linear( - config.hidden_size, self.all_head_size*self.num_qkv) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - - self.uni_debug_flag = True if os.getenv( - 'UNI_DEBUG_FLAG', '') else False - if self.uni_debug_flag: - self.register_buffer('debug_attention_probs', - torch.zeros((512, 512))) - if hasattr(config, 'seg_emb') and config.seg_emb: - self.b_q_s = nn.Parameter(torch.zeros( - 1, self.num_attention_heads, 1, self.attention_head_size)) - self.seg_emb = nn.Embedding( - config.type_vocab_size, self.all_head_size) - else: - self.b_q_s = None - self.seg_emb = None - - def transpose_for_scores(self, x, mask_qkv=None): - if self.num_qkv > 1: - sz = x.size()[:-1] + (self.num_qkv, - self.num_attention_heads, self.all_head_size) - # (batch, pos, num_qkv, head, head_hid) - x = x.view(*sz) - if mask_qkv is None: - x = x[:, :, 0, :, :] - elif isinstance(mask_qkv, int): - x = x[:, :, mask_qkv, :, :] - else: - # mask_qkv: (batch, pos) - if mask_qkv.size(1) > sz[1]: - mask_qkv = mask_qkv[:, :sz[1]] - # -> x: (batch, pos, head, head_hid) - x = x.gather(2, mask_qkv.view(sz[0], sz[1], 1, 1, 1).expand( - sz[0], sz[1], 1, sz[3], sz[4])).squeeze(2) - else: - sz = x.size()[:-1] + (self.num_attention_heads, - self.attention_head_size) - # (batch, pos, head, head_hid) - x = x.view(*sz) - # (batch, head, pos, head_hid) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states, attention_mask, history_states=None, mask_qkv=None, seg_ids=None): - if history_states is None: - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - else: - x_states = torch.cat((history_states, hidden_states), dim=1) - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(x_states) - mixed_value_layer = self.value(x_states) - - query_layer = self.transpose_for_scores(mixed_query_layer, mask_qkv) - key_layer = self.transpose_for_scores(mixed_key_layer, mask_qkv) - value_layer = self.transpose_for_scores(mixed_value_layer, mask_qkv) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch, head, pos, pos) - attention_scores = torch.matmul( - query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2)) - - if self.seg_emb is not None: - seg_rep = self.seg_emb(seg_ids) - # (batch, pos, head, head_hid) - seg_rep = seg_rep.view(seg_rep.size(0), seg_rep.size( - 1), self.num_attention_heads, self.attention_head_size) - qs = torch.einsum('bnih,bjnh->bnij', - query_layer+self.b_q_s, seg_rep) - attention_scores = attention_scores + qs - - # attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.Softmax(dim=-1)(attention_scores) - - if self.uni_debug_flag: - _pos = attention_probs.size(-1) - self.debug_attention_probs[:_pos, :_pos].copy_( - attention_probs[0].mean(0).view(_pos, _pos)) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[ - :-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - return context_layer - - -class BertSelfOutput(nn.Module): - def __init__(self, config): - super(BertSelfOutput, self).__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertAttention(nn.Module): - def __init__(self, config): - super(BertAttention, self).__init__() - self.self = BertSelfAttention(config) - self.output = BertSelfOutput(config) - - def forward(self, input_tensor, attention_mask, history_states=None, mask_qkv=None, seg_ids=None): - self_output = self.self( - input_tensor, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids) - attention_output = self.output(self_output, input_tensor) - return attention_output - - -class BertIntermediate(nn.Module): - def __init__(self, config): - super(BertIntermediate, self).__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - self.intermediate_act_fn = ACT2FN[config.hidden_act] \ - if isinstance(config.hidden_act, str) else config.hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - -class BertOutput(nn.Module): - def __init__(self, config): - super(BertOutput, self).__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class TransformerFFN(nn.Module): - def __init__(self, config): - super(TransformerFFN, self).__init__() - self.ffn_type = config.ffn_type - assert self.ffn_type in (1, 2) - if self.ffn_type in (1, 2): - self.wx0 = nn.Linear(config.hidden_size, config.hidden_size) - if self.ffn_type in (2,): - self.wx1 = nn.Linear(config.hidden_size, config.hidden_size) - if self.ffn_type in (1, 2): - self.output = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, x): - if self.ffn_type in (1, 2): - x0 = self.wx0(x) - if self.ffn_type == 1: - x1 = x - elif self.ffn_type == 2: - x1 = self.wx1(x) - out = self.output(x0 * x1) - out = self.dropout(out) - out = self.LayerNorm(out + x) - return out - - -class BertLayer(nn.Module): - def __init__(self, config): - super(BertLayer, self).__init__() - self.attention = BertAttention(config) - self.ffn_type = config.ffn_type - if self.ffn_type: - self.ffn = TransformerFFN(config) - else: - self.intermediate = BertIntermediate(config) - self.output = BertOutput(config) - - def forward(self, hidden_states, attention_mask, history_states=None, mask_qkv=None, seg_ids=None): - attention_output = self.attention( - hidden_states, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids) - if self.ffn_type: - layer_output = self.ffn(attention_output) - else: - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - return layer_output - - -class BertEncoder(nn.Module): - def __init__(self, config): - super(BertEncoder, self).__init__() - layer = BertLayer(config) - self.layer = nn.ModuleList([copy.deepcopy(layer) - for _ in range(config.num_hidden_layers)]) - - def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, seg_ids=None): - # history embedding and encoded layer must be simultanously given - assert (prev_embedding is None) == (prev_encoded_layers is None) - - all_encoder_layers = [] - if (prev_embedding is not None) and (prev_encoded_layers is not None): - history_states = prev_embedding - for i, layer_module in enumerate(self.layer): - hidden_states = layer_module( - hidden_states, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids) - if output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - if prev_encoded_layers is not None: - history_states = prev_encoded_layers[i] - else: - for layer_module in self.layer: - hidden_states = layer_module( - hidden_states, attention_mask, mask_qkv=mask_qkv, seg_ids=seg_ids) - if output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - if not output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - return all_encoder_layers - - -class BertPooler(nn.Module): - def __init__(self, config): - super(BertPooler, self).__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class BertPredictionHeadTransform(nn.Module): - def __init__(self, config): - super(BertPredictionHeadTransform, self).__init__() - self.transform_act_fn = ACT2FN[config.hidden_act] \ - if isinstance(config.hidden_act, str) else config.hidden_act - hid_size = config.hidden_size - if hasattr(config, 'relax_projection') and (config.relax_projection > 1): - hid_size *= config.relax_projection - self.dense = nn.Linear(config.hidden_size, hid_size) - self.LayerNorm = BertLayerNorm(hid_size, eps=1e-5) - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class BertLMPredictionHead(nn.Module): - def __init__(self, config, bert_model_embedding_weights): - super(BertLMPredictionHead, self).__init__() - self.transform = BertPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(bert_model_embedding_weights.size(1), - bert_model_embedding_weights.size(0), - bias=False) - self.decoder.weight = bert_model_embedding_weights - self.bias = nn.Parameter(torch.zeros( - bert_model_embedding_weights.size(0))) - if hasattr(config, 'relax_projection') and (config.relax_projection > 1): - self.relax_projection = config.relax_projection - else: - self.relax_projection = 0 - self.fp32_embedding = config.fp32_embedding - - def convert_to_type(tensor): - if self.fp32_embedding: - return tensor.half() - else: - return tensor - self.type_converter = convert_to_type - self.converted = False - - def forward(self, hidden_states, task_idx=None): - if not self.converted: - self.converted = True - if self.fp32_embedding: - self.transform.half() - hidden_states = self.transform(self.type_converter(hidden_states)) - if self.relax_projection > 1: - num_batch = hidden_states.size(0) - num_pos = hidden_states.size(1) - # (batch, num_pos, relax_projection*hid) -> (batch, num_pos, relax_projection, hid) -> (batch, num_pos, hid) - hidden_states = hidden_states.view( - num_batch, num_pos, self.relax_projection, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] - if self.fp32_embedding: - hidden_states = F.linear(self.type_converter(hidden_states), self.type_converter( - self.decoder.weight), self.type_converter(self.bias)) - else: - hidden_states = self.decoder(hidden_states) + self.bias - return hidden_states - - -class BertOnlyMLMHead(nn.Module): - def __init__(self, config, bert_model_embedding_weights): - super(BertOnlyMLMHead, self).__init__() - self.predictions = BertLMPredictionHead( - config, bert_model_embedding_weights) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class BertOnlyNSPHead(nn.Module): - def __init__(self, config): - super(BertOnlyNSPHead, self).__init__() - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, pooled_output): - seq_relationship_score = self.seq_relationship(pooled_output) - return seq_relationship_score - - -class BertPreTrainingHeads(nn.Module): - def __init__(self, config, bert_model_embedding_weights, num_labels=2): - super(BertPreTrainingHeads, self).__init__() - self.predictions = BertLMPredictionHead( - config, bert_model_embedding_weights) - self.seq_relationship = nn.Linear(config.hidden_size, num_labels) - - def forward(self, sequence_output, pooled_output, task_idx=None): - prediction_scores = self.predictions(sequence_output, task_idx) - if pooled_output is None: - seq_relationship_score = None - else: - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class PreTrainedBertModel(nn.Module): - """ An abstract class to handle weights initialization and - a simple interface for dowloading and loading pretrained models. - """ - - def __init__(self, config, *inputs, **kwargs): - super(PreTrainedBertModel, self).__init__() - if not isinstance(config, BertConfig): - raise ValueError( - "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " - "To create a model from a Google pretrained model use " - "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( - self.__class__.__name__, self.__class__.__name__ - )) - self.config = config - - def init_bert_weights(self, module): - """ Initialize the weights. - """ - if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, BertLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - @classmethod - def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs): - """ - Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. - Download and cache the pre-trained model file if needed. - - Params: - pretrained_model_name: either: - - a str with the name of a pre-trained model to load selected in the list of: - . `bert-base-uncased` - . `bert-large-uncased` - . `bert-base-cased` - . `bert-base-multilingual` - . `bert-base-chinese` - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance - cache_dir: an optional path to a folder in which the pre-trained models will be cached. - state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models - *inputs, **kwargs: additional input for the specific Bert class - (ex: num_labels for BertForSequenceClassification) - """ - if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: - archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] - else: - archive_file = pretrained_model_name - # redirect to the cache, if necessary - try: - resolved_archive_file = cached_path( - archive_file, cache_dir=cache_dir) - except FileNotFoundError: - logger.error( - "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find any file " - "associated to this path or url.".format( - pretrained_model_name, - ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), - archive_file)) - return None - if resolved_archive_file == archive_file: - logger.info("loading archive file {}".format(archive_file)) - else: - logger.info("loading archive file {} from cache at {}".format( - archive_file, resolved_archive_file)) - tempdir = None - if os.path.isdir(resolved_archive_file): - serialization_dir = resolved_archive_file - else: - # Extract archive to temp dir - tempdir = tempfile.mkdtemp() - logger.info("extracting archive file {} to temp dir {}".format( - resolved_archive_file, tempdir)) - with tarfile.open(resolved_archive_file, 'r:gz') as archive: - def is_within_directory(directory, target): - - abs_directory = os.path.abspath(directory) - abs_target = os.path.abspath(target) - - prefix = os.path.commonprefix([abs_directory, abs_target]) - - return prefix == abs_directory - - def safe_extract(tar, path=".", members=None, *, numeric_owner=False): - - for member in tar.getmembers(): - member_path = os.path.join(path, member.name) - if not is_within_directory(path, member_path): - raise Exception("Attempted Path Traversal in Tar File") - - tar.extractall(path, members, numeric_owner=numeric_owner) - - - safe_extract(archive, tempdir) - serialization_dir = tempdir - # Load config - if ('config_path' in kwargs) and kwargs['config_path']: - config_file = kwargs['config_path'] - else: - config_file = os.path.join(serialization_dir, CONFIG_NAME) - config = BertConfig.from_json_file(config_file) - - # define new type_vocab_size (there might be different numbers of segment ids) - if 'type_vocab_size' in kwargs: - config.type_vocab_size = kwargs['type_vocab_size'] - # define new relax_projection - if ('relax_projection' in kwargs) and kwargs['relax_projection']: - config.relax_projection = kwargs['relax_projection'] - # new position embedding - if ('new_pos_ids' in kwargs) and kwargs['new_pos_ids']: - config.new_pos_ids = kwargs['new_pos_ids'] - # define new relax_projection - if ('task_idx' in kwargs) and kwargs['task_idx']: - config.task_idx = kwargs['task_idx'] - # define new max position embedding for length expansion - if ('max_position_embeddings' in kwargs) and kwargs['max_position_embeddings']: - config.max_position_embeddings = kwargs['max_position_embeddings'] - # use fp32 for embeddings - if ('fp32_embedding' in kwargs) and kwargs['fp32_embedding']: - config.fp32_embedding = kwargs['fp32_embedding'] - # type of FFN in transformer blocks - if ('ffn_type' in kwargs) and kwargs['ffn_type']: - config.ffn_type = kwargs['ffn_type'] - # label smoothing - if ('label_smoothing' in kwargs) and kwargs['label_smoothing']: - config.label_smoothing = kwargs['label_smoothing'] - # dropout - if ('hidden_dropout_prob' in kwargs) and kwargs['hidden_dropout_prob']: - config.hidden_dropout_prob = kwargs['hidden_dropout_prob'] - if ('attention_probs_dropout_prob' in kwargs) and kwargs['attention_probs_dropout_prob']: - config.attention_probs_dropout_prob = kwargs['attention_probs_dropout_prob'] - # different QKV - if ('num_qkv' in kwargs) and kwargs['num_qkv']: - config.num_qkv = kwargs['num_qkv'] - # segment embedding for self-attention - if ('seg_emb' in kwargs) and kwargs['seg_emb']: - config.seg_emb = kwargs['seg_emb'] - # initialize word embeddings - _word_emb_map = None - if ('word_emb_map' in kwargs) and kwargs['word_emb_map']: - _word_emb_map = kwargs['word_emb_map'] - - logger.info("Model config {}".format(config)) - - # clean the arguments in kwargs - for arg_clean in ('config_path', 'type_vocab_size', 'relax_projection', 'new_pos_ids', 'task_idx', 'max_position_embeddings', 'fp32_embedding', 'ffn_type', 'label_smoothing', 'hidden_dropout_prob', 'attention_probs_dropout_prob', 'num_qkv', 'seg_emb', 'word_emb_map'): - if arg_clean in kwargs: - del kwargs[arg_clean] - - # Instantiate model. - model = cls(config, *inputs, **kwargs) - if state_dict is None: - weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load(weights_path) - - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if 'gamma' in key: - new_key = key.replace('gamma', 'weight') - if 'beta' in key: - new_key = key.replace('beta', 'bias') - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - - # initialize new segment embeddings - _k = 'bert.embeddings.token_type_embeddings.weight' - if (_k in state_dict) and (config.type_vocab_size != state_dict[_k].shape[0]): - logger.info("config.type_vocab_size != state_dict[bert.embeddings.token_type_embeddings.weight] ({0} != {1})".format( - config.type_vocab_size, state_dict[_k].shape[0])) - if config.type_vocab_size > state_dict[_k].shape[0]: - # state_dict[_k].data = state_dict[_k].data.resize_(config.type_vocab_size, state_dict[_k].shape[1]) - state_dict[_k].resize_( - config.type_vocab_size, state_dict[_k].shape[1]) - # L2R - if config.type_vocab_size >= 3: - state_dict[_k].data[2, :].copy_(state_dict[_k].data[0, :]) - # R2L - if config.type_vocab_size >= 4: - state_dict[_k].data[3, :].copy_(state_dict[_k].data[0, :]) - # S2S - if config.type_vocab_size >= 6: - state_dict[_k].data[4, :].copy_(state_dict[_k].data[0, :]) - state_dict[_k].data[5, :].copy_(state_dict[_k].data[1, :]) - if config.type_vocab_size >= 7: - state_dict[_k].data[6, :].copy_(state_dict[_k].data[1, :]) - elif config.type_vocab_size < state_dict[_k].shape[0]: - state_dict[_k].data = state_dict[_k].data[:config.type_vocab_size, :] - - _k = 'bert.embeddings.position_embeddings.weight' - n_config_pos_emb = 4 if config.new_pos_ids else 1 - if (_k in state_dict) and (n_config_pos_emb*config.hidden_size != state_dict[_k].shape[1]): - logger.info("n_config_pos_emb*config.hidden_size != state_dict[bert.embeddings.position_embeddings.weight] ({0}*{1} != {2})".format( - n_config_pos_emb, config.hidden_size, state_dict[_k].shape[1])) - assert state_dict[_k].shape[1] % config.hidden_size == 0 - n_state_pos_emb = int(state_dict[_k].shape[1]/config.hidden_size) - assert (n_state_pos_emb == 1) != (n_config_pos_emb == - 1), "!!!!n_state_pos_emb == 1 xor n_config_pos_emb == 1!!!!" - if n_state_pos_emb == 1: - state_dict[_k].data = state_dict[_k].data.unsqueeze(1).repeat( - 1, n_config_pos_emb, 1).reshape((config.max_position_embeddings, n_config_pos_emb*config.hidden_size)) - elif n_config_pos_emb == 1: - if hasattr(config, 'task_idx') and (config.task_idx is not None) and (0 <= config.task_idx <= 3): - _task_idx = config.task_idx - else: - _task_idx = 0 - state_dict[_k].data = state_dict[_k].data.view( - config.max_position_embeddings, n_state_pos_emb, config.hidden_size).select(1, _task_idx) - - # initialize new position embeddings - _k = 'bert.embeddings.position_embeddings.weight' - if _k in state_dict and config.max_position_embeddings != state_dict[_k].shape[0]: - logger.info("config.max_position_embeddings != state_dict[bert.embeddings.position_embeddings.weight] ({0} - {1})".format( - config.max_position_embeddings, state_dict[_k].shape[0])) - if config.max_position_embeddings > state_dict[_k].shape[0]: - old_size = state_dict[_k].shape[0] - # state_dict[_k].data = state_dict[_k].data.resize_(config.max_position_embeddings, state_dict[_k].shape[1]) - state_dict[_k].resize_( - config.max_position_embeddings, state_dict[_k].shape[1]) - start = old_size - while start < config.max_position_embeddings: - chunk_size = min( - old_size, config.max_position_embeddings - start) - state_dict[_k].data[start:start+chunk_size, - :].copy_(state_dict[_k].data[:chunk_size, :]) - start += chunk_size - elif config.max_position_embeddings < state_dict[_k].shape[0]: - state_dict[_k].data = state_dict[_k].data[:config.max_position_embeddings, :] - - # initialize relax projection - _k = 'cls.predictions.transform.dense.weight' - n_config_relax = 1 if (config.relax_projection < - 1) else config.relax_projection - if (_k in state_dict) and (n_config_relax*config.hidden_size != state_dict[_k].shape[0]): - logger.info("n_config_relax*config.hidden_size != state_dict[cls.predictions.transform.dense.weight] ({0}*{1} != {2})".format( - n_config_relax, config.hidden_size, state_dict[_k].shape[0])) - assert state_dict[_k].shape[0] % config.hidden_size == 0 - n_state_relax = int(state_dict[_k].shape[0]/config.hidden_size) - assert (n_state_relax == 1) != (n_config_relax == - 1), "!!!!n_state_relax == 1 xor n_config_relax == 1!!!!" - if n_state_relax == 1: - _k = 'cls.predictions.transform.dense.weight' - state_dict[_k].data = state_dict[_k].data.unsqueeze(0).repeat( - n_config_relax, 1, 1).reshape((n_config_relax*config.hidden_size, config.hidden_size)) - for _k in ('cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias'): - state_dict[_k].data = state_dict[_k].data.unsqueeze( - 0).repeat(n_config_relax, 1).view(-1) - elif n_config_relax == 1: - if hasattr(config, 'task_idx') and (config.task_idx is not None) and (0 <= config.task_idx <= 3): - _task_idx = config.task_idx - else: - _task_idx = 0 - _k = 'cls.predictions.transform.dense.weight' - state_dict[_k].data = state_dict[_k].data.view( - n_state_relax, config.hidden_size, config.hidden_size).select(0, _task_idx) - for _k in ('cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias'): - state_dict[_k].data = state_dict[_k].data.view( - n_state_relax, config.hidden_size).select(0, _task_idx) - - # initialize QKV - _all_head_size = config.num_attention_heads * \ - int(config.hidden_size / config.num_attention_heads) - n_config_num_qkv = 1 if (config.num_qkv < 1) else config.num_qkv - for qkv_name in ('query', 'key', 'value'): - _k = 'bert.encoder.layer.0.attention.self.{0}.weight'.format( - qkv_name) - if (_k in state_dict) and (n_config_num_qkv*_all_head_size != state_dict[_k].shape[0]): - logger.info("n_config_num_qkv*_all_head_size != state_dict[_k] ({0}*{1} != {2})".format( - n_config_num_qkv, _all_head_size, state_dict[_k].shape[0])) - for layer_idx in range(config.num_hidden_layers): - _k = 'bert.encoder.layer.{0}.attention.self.{1}.weight'.format( - layer_idx, qkv_name) - assert state_dict[_k].shape[0] % _all_head_size == 0 - n_state_qkv = int(state_dict[_k].shape[0]/_all_head_size) - assert (n_state_qkv == 1) != (n_config_num_qkv == - 1), "!!!!n_state_qkv == 1 xor n_config_num_qkv == 1!!!!" - if n_state_qkv == 1: - _k = 'bert.encoder.layer.{0}.attention.self.{1}.weight'.format( - layer_idx, qkv_name) - state_dict[_k].data = state_dict[_k].data.unsqueeze(0).repeat( - n_config_num_qkv, 1, 1).reshape((n_config_num_qkv*_all_head_size, _all_head_size)) - _k = 'bert.encoder.layer.{0}.attention.self.{1}.bias'.format( - layer_idx, qkv_name) - state_dict[_k].data = state_dict[_k].data.unsqueeze( - 0).repeat(n_config_num_qkv, 1).view(-1) - elif n_config_num_qkv == 1: - if hasattr(config, 'task_idx') and (config.task_idx is not None) and (0 <= config.task_idx <= 3): - _task_idx = config.task_idx - else: - _task_idx = 0 - assert _task_idx != 3, "[INVALID] _task_idx=3: n_config_num_qkv=1 (should be 2)" - if _task_idx == 0: - _qkv_idx = 0 - else: - _qkv_idx = 1 - _k = 'bert.encoder.layer.{0}.attention.self.{1}.weight'.format( - layer_idx, qkv_name) - state_dict[_k].data = state_dict[_k].data.view( - n_state_qkv, _all_head_size, _all_head_size).select(0, _qkv_idx) - _k = 'bert.encoder.layer.{0}.attention.self.{1}.bias'.format( - layer_idx, qkv_name) - state_dict[_k].data = state_dict[_k].data.view( - n_state_qkv, _all_head_size).select(0, _qkv_idx) - - if _word_emb_map: - _k = 'bert.embeddings.word_embeddings.weight' - for _tgt, _src in _word_emb_map: - state_dict[_k].data[_tgt, :].copy_( - state_dict[_k].data[_src, :]) - - missing_keys = [] - unexpected_keys = [] - error_msgs = [] - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - def load(module, prefix=''): - local_metadata = {} if metadata is None else metadata.get( - prefix[:-1], {}) - module._load_from_state_dict( - state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + '.') - load(model, prefix='' if hasattr(model, 'bert') else 'bert.') - model.missing_keys = missing_keys - if len(missing_keys) > 0: - logger.info("Weights of {} not initialized from pretrained model: {}".format( - model.__class__.__name__, missing_keys)) - if len(unexpected_keys) > 0: - logger.info("Weights from pretrained model not used in {}: {}".format( - model.__class__.__name__, unexpected_keys)) - if len(error_msgs) > 0: - logger.info('\n'.join(error_msgs)) - if tempdir: - # Clean up temp dir - shutil.rmtree(tempdir) - return model - - -class BertModel(PreTrainedBertModel): - """BERT model ("Bidirectional Embedding Representations from a Transformer"). - - Params: - config: a BertConfig class instance with the configuration to build a new model - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. - - Outputs: Tuple of (encoded_layers, pooled_output) - `encoded_layers`: controled by `output_all_encoded_layers` argument: - - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end - of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each - encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], - - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding - to the last attention block of shape [batch_size, sequence_length, hidden_size], - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a - classifier pretrained on top of the hidden state associated to the first character of the - input (`CLF`) to train on the Next-Sentence task (see BERT's paper). - ``` - """ - - def __init__(self, config): - super(BertModel, self).__init__(config) - self.embeddings = BertEmbeddings(config) - self.encoder = BertEncoder(config) - self.pooler = BertPooler(config) - self.apply(self.init_bert_weights) - - def rescale_some_parameters(self): - for layer_id, layer in enumerate(self.encoder.layer): - layer.attention.output.dense.weight.data.div_( - math.sqrt(2.0*(layer_id + 1))) - layer.output.dense.weight.data.div_(math.sqrt(2.0*(layer_id + 1))) - - def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask): - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - if attention_mask.dim() == 2: - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - elif attention_mask.dim() == 3: - extended_attention_mask = attention_mask.unsqueeze(1) - else: - raise NotImplementedError - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to( - dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - return extended_attention_mask - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, mask_qkv=None, task_idx=None): - extended_attention_mask = self.get_extended_attention_mask( - input_ids, token_type_ids, attention_mask) - - embedding_output = self.embeddings( - input_ids, token_type_ids, task_idx=task_idx) - encoded_layers = self.encoder(embedding_output, extended_attention_mask, - output_all_encoded_layers=output_all_encoded_layers, mask_qkv=mask_qkv, seg_ids=token_type_ids) - sequence_output = encoded_layers[-1] - pooled_output = self.pooler(sequence_output) - if not output_all_encoded_layers: - encoded_layers = encoded_layers[-1] - return encoded_layers, pooled_output - - -class BertModelIncr(BertModel): - def __init__(self, config): - super(BertModelIncr, self).__init__(config) - - def forward(self, input_ids, token_type_ids, position_ids, attention_mask, output_all_encoded_layers=True, prev_embedding=None, - prev_encoded_layers=None, mask_qkv=None, task_idx=None): - extended_attention_mask = self.get_extended_attention_mask( - input_ids, token_type_ids, attention_mask) - - embedding_output = self.embeddings( - input_ids, token_type_ids, position_ids, task_idx=task_idx) - encoded_layers = self.encoder(embedding_output, - extended_attention_mask, - output_all_encoded_layers=output_all_encoded_layers, - prev_embedding=prev_embedding, - prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv, seg_ids=token_type_ids) - sequence_output = encoded_layers[-1] - pooled_output = self.pooler(sequence_output) - if not output_all_encoded_layers: - encoded_layers = encoded_layers[-1] - return embedding_output, encoded_layers, pooled_output - - -class BertForPreTraining(PreTrainedBertModel): - """BERT model with pre-training heads. - This module comprises the BERT model followed by the two pre-training heads: - - the masked language modeling head, and - - the next sentence classification head. - Params: - config: a BertConfig class instance with the configuration to build a new model. - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss - is only computed for the labels set in [0, ..., vocab_size] - `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] - with indices selected in [0, 1]. - 0 => next sentence is the continuation, 1 => next sentence is a random sentence. - Outputs: - if `masked_lm_labels` and `next_sentence_label` are not `None`: - Outputs the total_loss which is the sum of the masked language modeling loss and the next - sentence classification loss. - if `masked_lm_labels` or `next_sentence_label` is `None`: - Outputs a tuple comprising - - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and - - the next sentence classification logits of shape [batch_size, 2]. - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - model = BertForPreTraining(config) - masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config): - super(BertForPreTraining, self).__init__(config) - self.bert = BertModel(config) - self.cls = BertPreTrainingHeads( - config, self.bert.embeddings.word_embeddings.weight) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, mask_qkv=None, task_idx=None): - sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) - prediction_scores, seq_relationship_score = self.cls( - sequence_output, pooled_output) - - if masked_lm_labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) - next_sentence_loss = loss_fct( - seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss - return total_loss - else: - return prediction_scores, seq_relationship_score - - -class BertPreTrainingPairTransform(nn.Module): - def __init__(self, config): - super(BertPreTrainingPairTransform, self).__init__() - self.dense = nn.Linear(config.hidden_size*2, config.hidden_size) - self.transform_act_fn = ACT2FN[config.hidden_act] \ - if isinstance(config.hidden_act, str) else config.hidden_act - # self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) - - def forward(self, pair_x, pair_y): - hidden_states = torch.cat([pair_x, pair_y], dim=-1) - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - # hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class BertPreTrainingPairRel(nn.Module): - def __init__(self, config, num_rel=0): - super(BertPreTrainingPairRel, self).__init__() - self.R_xy = BertPreTrainingPairTransform(config) - self.rel_emb = nn.Embedding(num_rel, config.hidden_size) - - def forward(self, pair_x, pair_y, pair_r, pair_pos_neg_mask): - # (batch, num_pair, hidden) - xy = self.R_xy(pair_x, pair_y) - r = self.rel_emb(pair_r) - _batch, _num_pair, _hidden = xy.size() - pair_score = (xy * r).sum(-1) - # torch.bmm(xy.view(-1, 1, _hidden),r.view(-1, _hidden, 1)).view(_batch, _num_pair) - # .mul_(-1.0): objective to loss - return F.logsigmoid(pair_score * pair_pos_neg_mask.type_as(pair_score)).mul_(-1.0) - - -class BertForPreTrainingLossMask(PreTrainedBertModel): - """refer to BertForPreTraining""" - - def __init__(self, config, num_labels=2, num_rel=0, num_sentlvl_labels=0, no_nsp=False): - super(BertForPreTrainingLossMask, self).__init__(config) - self.bert = BertModel(config) - self.cls = BertPreTrainingHeads( - config, self.bert.embeddings.word_embeddings.weight, num_labels=num_labels) - self.num_sentlvl_labels = num_sentlvl_labels - self.cls2 = None - if self.num_sentlvl_labels > 0: - self.secondary_pred_proj = nn.Embedding( - num_sentlvl_labels, config.hidden_size) - self.cls2 = BertPreTrainingHeads( - config, self.secondary_pred_proj.weight, num_labels=num_sentlvl_labels) - self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none') - if no_nsp: - self.crit_next_sent = None - else: - self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1) - self.num_labels = num_labels - self.num_rel = num_rel - if self.num_rel > 0: - self.crit_pair_rel = BertPreTrainingPairRel( - config, num_rel=num_rel) - if hasattr(config, 'label_smoothing') and config.label_smoothing: - self.crit_mask_lm_smoothed = LabelSmoothingLoss( - config.label_smoothing, config.vocab_size, ignore_index=0, reduction='none') - else: - self.crit_mask_lm_smoothed = None - self.apply(self.init_bert_weights) - self.bert.rescale_some_parameters() - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, - next_sentence_label=None, masked_pos=None, masked_weights=None, task_idx=None, pair_x=None, - pair_x_mask=None, pair_y=None, pair_y_mask=None, pair_r=None, pair_pos_neg_mask=None, - pair_loss_mask=None, masked_pos_2=None, masked_weights_2=None, masked_labels_2=None, - num_tokens_a=None, num_tokens_b=None, mask_qkv=None): - if token_type_ids is None and attention_mask is None: - task_0 = (task_idx == 0) - task_1 = (task_idx == 1) - task_2 = (task_idx == 2) - task_3 = (task_idx == 3) - - sequence_length = input_ids.shape[-1] - index_matrix = torch.arange(sequence_length).view( - 1, sequence_length).to(input_ids.device) - - num_tokens = num_tokens_a + num_tokens_b - - base_mask = (index_matrix < num_tokens.view(-1, 1) - ).type_as(input_ids) - segment_a_mask = ( - index_matrix < num_tokens_a.view(-1, 1)).type_as(input_ids) - - token_type_ids = ( - task_idx + 1 + task_3.type_as(task_idx)).view(-1, 1) * base_mask - token_type_ids = token_type_ids - segment_a_mask * \ - (task_0 | task_3).type_as(segment_a_mask).view(-1, 1) - - index_matrix = index_matrix.view(1, 1, sequence_length) - index_matrix_t = index_matrix.view(1, sequence_length, 1) - - tril = index_matrix <= index_matrix_t - - attention_mask_task_0 = ( - index_matrix < num_tokens.view(-1, 1, 1)) & (index_matrix_t < num_tokens.view(-1, 1, 1)) - attention_mask_task_1 = tril & attention_mask_task_0 - attention_mask_task_2 = torch.transpose( - tril, dim0=-2, dim1=-1) & attention_mask_task_0 - attention_mask_task_3 = ( - (index_matrix < num_tokens_a.view(-1, 1, 1)) | tril) & attention_mask_task_0 - - attention_mask = (attention_mask_task_0 & task_0.view(-1, 1, 1)) | \ - (attention_mask_task_1 & task_1.view(-1, 1, 1)) | \ - (attention_mask_task_2 & task_2.view(-1, 1, 1)) | \ - (attention_mask_task_3 & task_3.view(-1, 1, 1)) - attention_mask = attention_mask.type_as(input_ids) - sequence_output, pooled_output = self.bert( - input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) - - def gather_seq_out_by_pos(seq, pos): - return torch.gather(seq, 1, pos.unsqueeze(2).expand(-1, -1, seq.size(-1))) - - def gather_seq_out_by_pos_average(seq, pos, mask): - # pos/mask: (batch, num_pair, max_token_num) - batch_size, max_token_num = pos.size(0), pos.size(-1) - # (batch, num_pair, max_token_num, seq.size(-1)) - pos_vec = torch.gather(seq, 1, pos.view(batch_size, -1).unsqueeze( - 2).expand(-1, -1, seq.size(-1))).view(batch_size, -1, max_token_num, seq.size(-1)) - # (batch, num_pair, seq.size(-1)) - mask = mask.type_as(pos_vec) - pos_vec_masked_sum = ( - pos_vec * mask.unsqueeze(3).expand_as(pos_vec)).sum(2) - return pos_vec_masked_sum / mask.sum(2, keepdim=True).expand_as(pos_vec_masked_sum) - - def loss_mask_and_normalize(loss, mask): - mask = mask.type_as(loss) - loss = loss * mask - denominator = torch.sum(mask) + 1e-5 - return (loss / denominator).sum() - - if masked_lm_labels is None: - if masked_pos is None: - prediction_scores, seq_relationship_score = self.cls( - sequence_output, pooled_output, task_idx=task_idx) - else: - sequence_output_masked = gather_seq_out_by_pos( - sequence_output, masked_pos) - prediction_scores, seq_relationship_score = self.cls( - sequence_output_masked, pooled_output, task_idx=task_idx) - return prediction_scores, seq_relationship_score - - # masked lm - sequence_output_masked = gather_seq_out_by_pos( - sequence_output, masked_pos) - prediction_scores_masked, seq_relationship_score = self.cls( - sequence_output_masked, pooled_output, task_idx=task_idx) - if self.crit_mask_lm_smoothed: - masked_lm_loss = self.crit_mask_lm_smoothed( - F.log_softmax(prediction_scores_masked.float(), dim=-1), masked_lm_labels) - else: - masked_lm_loss = self.crit_mask_lm( - prediction_scores_masked.transpose(1, 2).float(), masked_lm_labels) - masked_lm_loss = loss_mask_and_normalize( - masked_lm_loss.float(), masked_weights) - - # next sentence - if self.crit_next_sent is None or next_sentence_label is None: - next_sentence_loss = 0.0 - else: - next_sentence_loss = self.crit_next_sent( - seq_relationship_score.view(-1, self.num_labels).float(), next_sentence_label.view(-1)) - - if self.cls2 is not None and masked_pos_2 is not None: - sequence_output_masked_2 = gather_seq_out_by_pos( - sequence_output, masked_pos_2) - prediction_scores_masked_2, _ = self.cls2( - sequence_output_masked_2, None) - masked_lm_loss_2 = self.crit_mask_lm( - prediction_scores_masked_2.transpose(1, 2).float(), masked_labels_2) - masked_lm_loss_2 = loss_mask_and_normalize( - masked_lm_loss_2.float(), masked_weights_2) - masked_lm_loss = masked_lm_loss + masked_lm_loss_2 - - if pair_x is None or pair_y is None or pair_r is None or pair_pos_neg_mask is None or pair_loss_mask is None: - return masked_lm_loss, next_sentence_loss - - # pair and relation - if pair_x_mask is None or pair_y_mask is None: - pair_x_output_masked = gather_seq_out_by_pos( - sequence_output, pair_x) - pair_y_output_masked = gather_seq_out_by_pos( - sequence_output, pair_y) - else: - pair_x_output_masked = gather_seq_out_by_pos_average( - sequence_output, pair_x, pair_x_mask) - pair_y_output_masked = gather_seq_out_by_pos_average( - sequence_output, pair_y, pair_y_mask) - pair_loss = self.crit_pair_rel( - pair_x_output_masked, pair_y_output_masked, pair_r, pair_pos_neg_mask) - pair_loss = loss_mask_and_normalize( - pair_loss.float(), pair_loss_mask) - return masked_lm_loss, next_sentence_loss, pair_loss - - -class BertForExtractiveSummarization(PreTrainedBertModel): - """refer to BertForPreTraining""" - - def __init__(self, config): - super(BertForExtractiveSummarization, self).__init__(config) - self.bert = BertModel(config) - self.secondary_pred_proj = nn.Embedding(2, config.hidden_size) - self.cls2 = BertPreTrainingHeads( - config, self.secondary_pred_proj.weight, num_labels=2) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_pos_2=None, masked_weights_2=None, task_idx=None, mask_qkv=None): - sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) - - def gather_seq_out_by_pos(seq, pos): - return torch.gather(seq, 1, pos.unsqueeze(2).expand(-1, -1, seq.size(-1))) - - sequence_output_masked_2 = gather_seq_out_by_pos( - sequence_output, masked_pos_2) - prediction_scores_masked_2, _ = self.cls2( - sequence_output_masked_2, None, task_idx=task_idx) - - predicted_probs = torch.nn.functional.softmax( - prediction_scores_masked_2, dim=-1) - - return predicted_probs, masked_pos_2, masked_weights_2 - - -class BertForSeq2SeqDecoder(PreTrainedBertModel): - """refer to BertForPreTraining""" - - def __init__(self, config, mask_word_id=0, num_labels=2, num_rel=0, - search_beam_size=1, length_penalty=1.0, eos_id=0, sos_id=0, - forbid_duplicate_ngrams=False, forbid_ignore_set=None, not_predict_set=None, ngram_size=3, min_len=0, mode="s2s", pos_shift=False): - super(BertForSeq2SeqDecoder, self).__init__(config) - self.bert = BertModelIncr(config) - self.cls = BertPreTrainingHeads( - config, self.bert.embeddings.word_embeddings.weight, num_labels=num_labels) - self.apply(self.init_bert_weights) - self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none') - self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1) - self.mask_word_id = mask_word_id - self.num_labels = num_labels - self.num_rel = num_rel - if self.num_rel > 0: - self.crit_pair_rel = BertPreTrainingPairRel( - config, num_rel=num_rel) - self.search_beam_size = search_beam_size - self.length_penalty = length_penalty - self.eos_id = eos_id - self.sos_id = sos_id - self.forbid_duplicate_ngrams = forbid_duplicate_ngrams - self.forbid_ignore_set = forbid_ignore_set - self.not_predict_set = not_predict_set - self.ngram_size = ngram_size - self.min_len = min_len - assert mode in ("s2s", "l2r") - self.mode = mode - self.pos_shift = pos_shift - - def forward(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None): - if self.search_beam_size > 1: - return self.beam_search(input_ids, token_type_ids, position_ids, attention_mask, task_idx=task_idx, mask_qkv=mask_qkv) - - input_shape = list(input_ids.size()) - batch_size = input_shape[0] - input_length = input_shape[1] - output_shape = list(token_type_ids.size()) - output_length = output_shape[1] - - output_ids = [] - prev_embedding = None - prev_encoded_layers = None - curr_ids = input_ids - mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) - next_pos = input_length - if self.pos_shift: - sos_ids = input_ids.new(batch_size, 1).fill_(self.sos_id) - - while next_pos < output_length: - curr_length = list(curr_ids.size())[1] - - if self.pos_shift: - if next_pos == input_length: - x_input_ids = torch.cat((curr_ids, sos_ids), dim=1) - start_pos = 0 - else: - x_input_ids = curr_ids - start_pos = next_pos - else: - start_pos = next_pos - curr_length - x_input_ids = torch.cat((curr_ids, mask_ids), dim=1) - - curr_token_type_ids = token_type_ids[:, start_pos:next_pos+1] - curr_attention_mask = attention_mask[:, - start_pos:next_pos+1, :next_pos+1] - curr_position_ids = position_ids[:, start_pos:next_pos+1] - new_embedding, new_encoded_layers, _ = \ - self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask, - output_all_encoded_layers=True, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv) - - last_hidden = new_encoded_layers[-1][:, -1:, :] - prediction_scores, _ = self.cls( - last_hidden, None, task_idx=task_idx) - if self.not_predict_set: - for token_id in self.not_predict_set: - prediction_scores[:, :, token_id].fill_(-10000.0) - _, max_ids = torch.max(prediction_scores, dim=-1) - output_ids.append(max_ids) - - if self.pos_shift: - if prev_embedding is None: - prev_embedding = new_embedding - else: - prev_embedding = torch.cat( - (prev_embedding, new_embedding), dim=1) - if prev_encoded_layers is None: - prev_encoded_layers = [x for x in new_encoded_layers] - else: - prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip( - prev_encoded_layers, new_encoded_layers)] - else: - if prev_embedding is None: - prev_embedding = new_embedding[:, :-1, :] - else: - prev_embedding = torch.cat( - (prev_embedding, new_embedding[:, :-1, :]), dim=1) - if prev_encoded_layers is None: - prev_encoded_layers = [x[:, :-1, :] - for x in new_encoded_layers] - else: - prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1) - for x in zip(prev_encoded_layers, new_encoded_layers)] - curr_ids = max_ids - next_pos += 1 - - return torch.cat(output_ids, dim=1) - - def beam_search(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None): - input_shape = list(input_ids.size()) - batch_size = input_shape[0] - input_length = input_shape[1] - output_shape = list(token_type_ids.size()) - output_length = output_shape[1] - - output_ids = [] - prev_embedding = None - prev_encoded_layers = None - curr_ids = input_ids - mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) - next_pos = input_length - if self.pos_shift: - sos_ids = input_ids.new(batch_size, 1).fill_(self.sos_id) - - K = self.search_beam_size - - total_scores = [] - beam_masks = [] - step_ids = [] - step_back_ptrs = [] - partial_seqs = [] - forbid_word_mask = None - buf_matrix = None - - while next_pos < output_length: - curr_length = list(curr_ids.size())[1] - - if self.pos_shift: - if next_pos == input_length: - x_input_ids = torch.cat((curr_ids, sos_ids), dim=1) - start_pos = 0 - else: - x_input_ids = curr_ids - start_pos = next_pos - else: - start_pos = next_pos - curr_length - x_input_ids = torch.cat((curr_ids, mask_ids), dim=1) - - curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1] - curr_attention_mask = attention_mask[:, - start_pos:next_pos + 1, :next_pos + 1] - curr_position_ids = position_ids[:, start_pos:next_pos + 1] - new_embedding, new_encoded_layers, _ = \ - self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask, - output_all_encoded_layers=True, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv) - - last_hidden = new_encoded_layers[-1][:, -1:, :] - prediction_scores, _ = self.cls( - last_hidden, None, task_idx=task_idx) - log_scores = torch.nn.functional.log_softmax( - prediction_scores, dim=-1) - if forbid_word_mask is not None: - log_scores += (forbid_word_mask * -10000.0) - if self.min_len and (next_pos-input_length+1 <= self.min_len): - log_scores[:, :, self.eos_id].fill_(-10000.0) - if self.not_predict_set: - for token_id in self.not_predict_set: - log_scores[:, :, token_id].fill_(-10000.0) - kk_scores, kk_ids = torch.topk(log_scores, k=K) - if len(total_scores) == 0: - k_ids = torch.reshape(kk_ids, [batch_size, K]) - back_ptrs = torch.zeros(batch_size, K, dtype=torch.long) - k_scores = torch.reshape(kk_scores, [batch_size, K]) - else: - last_eos = torch.reshape( - beam_masks[-1], [batch_size * K, 1, 1]) - last_seq_scores = torch.reshape( - total_scores[-1], [batch_size * K, 1, 1]) - kk_scores += last_eos * (-10000.0) + last_seq_scores - kk_scores = torch.reshape(kk_scores, [batch_size, K * K]) - k_scores, k_ids = torch.topk(kk_scores, k=K) - back_ptrs = torch.div(k_ids, K) - kk_ids = torch.reshape(kk_ids, [batch_size, K * K]) - k_ids = torch.gather(kk_ids, 1, k_ids) - step_back_ptrs.append(back_ptrs) - step_ids.append(k_ids) - beam_masks.append(torch.eq(k_ids, self.eos_id).float()) - total_scores.append(k_scores) - - def first_expand(x): - input_shape = list(x.size()) - expanded_shape = input_shape[:1] + [1] + input_shape[1:] - x = torch.reshape(x, expanded_shape) - repeat_count = [1, K] + [1] * (len(input_shape) - 1) - x = x.repeat(*repeat_count) - x = torch.reshape(x, [input_shape[0] * K] + input_shape[1:]) - return x - - def select_beam_items(x, ids): - id_shape = list(ids.size()) - id_rank = len(id_shape) - assert len(id_shape) == 2 - x_shape = list(x.size()) - x = torch.reshape(x, [batch_size, K] + x_shape[1:]) - x_rank = len(x_shape) + 1 - assert x_rank >= 2 - if id_rank < x_rank: - ids = torch.reshape( - ids, id_shape + [1] * (x_rank - id_rank)) - ids = ids.expand(id_shape + x_shape[1:]) - y = torch.gather(x, 1, ids) - y = torch.reshape(y, x_shape) - return y - - is_first = (prev_embedding is None) - - if self.pos_shift: - if prev_embedding is None: - prev_embedding = first_expand(new_embedding) - else: - prev_embedding = torch.cat( - (prev_embedding, new_embedding), dim=1) - prev_embedding = select_beam_items( - prev_embedding, back_ptrs) - if prev_encoded_layers is None: - prev_encoded_layers = [first_expand( - x) for x in new_encoded_layers] - else: - prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip( - prev_encoded_layers, new_encoded_layers)] - prev_encoded_layers = [select_beam_items( - x, back_ptrs) for x in prev_encoded_layers] - else: - if prev_embedding is None: - prev_embedding = first_expand(new_embedding[:, :-1, :]) - else: - prev_embedding = torch.cat( - (prev_embedding, new_embedding[:, :-1, :]), dim=1) - prev_embedding = select_beam_items( - prev_embedding, back_ptrs) - if prev_encoded_layers is None: - prev_encoded_layers = [first_expand( - x[:, :-1, :]) for x in new_encoded_layers] - else: - prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1) - for x in zip(prev_encoded_layers, new_encoded_layers)] - prev_encoded_layers = [select_beam_items( - x, back_ptrs) for x in prev_encoded_layers] - - curr_ids = torch.reshape(k_ids, [batch_size * K, 1]) - - if is_first: - token_type_ids = first_expand(token_type_ids) - position_ids = first_expand(position_ids) - attention_mask = first_expand(attention_mask) - mask_ids = first_expand(mask_ids) - if mask_qkv is not None: - mask_qkv = first_expand(mask_qkv) - - if self.forbid_duplicate_ngrams: - wids = step_ids[-1].tolist() - ptrs = step_back_ptrs[-1].tolist() - if is_first: - partial_seqs = [] - for b in range(batch_size): - for k in range(K): - partial_seqs.append([wids[b][k]]) - else: - new_partial_seqs = [] - for b in range(batch_size): - for k in range(K): - new_partial_seqs.append( - partial_seqs[ptrs[b][k] + b * K] + [wids[b][k]]) - partial_seqs = new_partial_seqs - - def get_dup_ngram_candidates(seq, n): - cands = set() - if len(seq) < n: - return [] - tail = seq[-(n-1):] - if self.forbid_ignore_set and any(tk in self.forbid_ignore_set for tk in tail): - return [] - for i in range(len(seq) - (n - 1)): - mismatch = False - for j in range(n - 1): - if tail[j] != seq[i + j]: - mismatch = True - break - if (not mismatch) and not(self.forbid_ignore_set and (seq[i + n - 1] in self.forbid_ignore_set)): - cands.add(seq[i + n - 1]) - return list(sorted(cands)) - - if len(partial_seqs[0]) >= self.ngram_size: - dup_cands = [] - for seq in partial_seqs: - dup_cands.append( - get_dup_ngram_candidates(seq, self.ngram_size)) - if max(len(x) for x in dup_cands) > 0: - if buf_matrix is None: - vocab_size = list(log_scores.size())[-1] - buf_matrix = np.zeros( - (batch_size * K, vocab_size), dtype=float) - else: - buf_matrix.fill(0) - for bk, cands in enumerate(dup_cands): - for i, wid in enumerate(cands): - buf_matrix[bk, wid] = 1.0 - forbid_word_mask = torch.tensor( - buf_matrix, dtype=log_scores.dtype) - forbid_word_mask = torch.reshape( - forbid_word_mask, [batch_size * K, 1, vocab_size]).cuda() - else: - forbid_word_mask = None - next_pos += 1 - - # [(batch, beam)] - total_scores = [x.tolist() for x in total_scores] - step_ids = [x.tolist() for x in step_ids] - step_back_ptrs = [x.tolist() for x in step_back_ptrs] - # back tracking - traces = {'pred_seq': [], 'scores': [], 'wids': [], 'ptrs': []} - for b in range(batch_size): - # [(beam,)] - scores = [x[b] for x in total_scores] - wids_list = [x[b] for x in step_ids] - ptrs = [x[b] for x in step_back_ptrs] - traces['scores'].append(scores) - traces['wids'].append(wids_list) - traces['ptrs'].append(ptrs) - # first we need to find the eos frame where all symbols are eos - # any frames after the eos frame are invalid - last_frame_id = len(scores) - 1 - for i, wids in enumerate(wids_list): - if all(wid == self.eos_id for wid in wids): - last_frame_id = i - break - max_score = -math.inf - frame_id = -1 - pos_in_frame = -1 - - for fid in range(last_frame_id + 1): - for i, wid in enumerate(wids_list[fid]): - if wid == self.eos_id or fid == last_frame_id: - s = scores[fid][i] - if self.length_penalty > 0: - s /= math.pow((5 + fid + 1) / 6.0, - self.length_penalty) - if s > max_score: - max_score = s - frame_id = fid - pos_in_frame = i - if frame_id == -1: - traces['pred_seq'].append([0]) - else: - seq = [wids_list[frame_id][pos_in_frame]] - for fid in range(frame_id, 0, -1): - pos_in_frame = ptrs[fid][pos_in_frame] - seq.append(wids_list[fid - 1][pos_in_frame]) - seq.reverse() - traces['pred_seq'].append(seq) - - def _pad_sequence(sequences, max_len, padding_value=0): - trailing_dims = sequences[0].size()[1:] - out_dims = (len(sequences), max_len) + trailing_dims - - out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) - for i, tensor in enumerate(sequences): - length = tensor.size(0) - # use index notation to prevent duplicate references to the tensor - out_tensor[i, :length, ...] = tensor - return out_tensor - - # convert to tensors for DataParallel - for k in ('pred_seq', 'scores', 'wids', 'ptrs'): - ts_list = traces[k] - if not isinstance(ts_list[0], torch.Tensor): - dt = torch.float if k == 'scores' else torch.long - ts_list = [torch.tensor(it, dtype=dt) for it in ts_list] - traces[k] = _pad_sequence( - ts_list, output_length, padding_value=0).to(input_ids.device) - - return traces - - -class BertForMaskedLM(PreTrainedBertModel): - """BERT model with the masked language modeling head. - This module comprises the BERT model followed by the masked language modeling head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss - is only computed for the labels set in [0, ..., vocab_size] - - Outputs: - if `masked_lm_labels` is `None`: - Outputs the masked language modeling loss. - if `masked_lm_labels` is `None`: - Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForMaskedLM(config) - masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config): - super(BertForMaskedLM, self).__init__(config) - self.bert = BertModel(config) - self.cls = BertOnlyMLMHead( - config, self.bert.embeddings.word_embeddings.weight) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, mask_qkv=None, task_idx=None): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) - prediction_scores = self.cls(sequence_output) - - if masked_lm_labels is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) - return masked_lm_loss - else: - return prediction_scores - - -class BertForNextSentencePrediction(PreTrainedBertModel): - """BERT model with next sentence prediction head. - This module comprises the BERT model followed by the next sentence classification head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] - with indices selected in [0, 1]. - 0 => next sentence is the continuation, 1 => next sentence is a random sentence. - - Outputs: - if `next_sentence_label` is not `None`: - Outputs the total_loss which is the sum of the masked language modeling loss and the next - sentence classification loss. - if `next_sentence_label` is `None`: - Outputs the next sentence classification logits of shape [batch_size, 2]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForNextSentencePrediction(config) - seq_relationship_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config): - super(BertForNextSentencePrediction, self).__init__(config) - self.bert = BertModel(config) - self.cls = BertOnlyNSPHead(config) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, mask_qkv=None, task_idx=None): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) - seq_relationship_score = self.cls(pooled_output) - - if next_sentence_label is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - next_sentence_loss = loss_fct( - seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - return next_sentence_loss - else: - return seq_relationship_score - - -class BertForSequenceClassification(PreTrainedBertModel): - """BERT model for classification. - This module is composed of the BERT model with a linear layer on top of - the pooled output. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_labels`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] - with indices selected in [0, ..., num_labels]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_labels = 2 - - model = BertForSequenceClassification(config, num_labels) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config, num_labels=2): - super(BertForSequenceClassification, self).__init__(config) - self.num_labels = num_labels - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, num_labels) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None): - _, pooled_output = self.bert( - input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - if labels is not None: - if labels.dtype == torch.long: - loss_fct = CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.num_labels), labels.view(-1)) - elif labels.dtype == torch.half or labels.dtype == torch.float: - loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: - print('unkown labels.dtype') - loss = None - return loss - else: - return logits - - -class BertForMultipleChoice(PreTrainedBertModel): - """BERT model for multiple choice tasks. - This module is composed of the BERT model with a linear layer on top of - the pooled output. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_choices`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] - with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` - and type 1 corresponds to a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] - with indices selected in [0, ..., num_choices]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) - input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) - token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_choices = 2 - - model = BertForMultipleChoice(config, num_choices) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config, num_choices=2): - super(BertForMultipleChoice, self).__init__(config) - self.num_choices = num_choices - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, 1) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None): - flat_input_ids = input_ids.view(-1, input_ids.size(-1)) - flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) - flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) - _, pooled_output = self.bert( - flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, self.num_choices) - - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - return loss - else: - return reshaped_logits - - -class BertForTokenClassification(PreTrainedBertModel): - """BERT model for token-level classification. - This module is composed of the BERT model with a linear layer on top of - the full hidden state of the last layer. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_labels`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] - with indices selected in [0, ..., num_labels]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_labels = 2 - - model = BertForTokenClassification(config, num_labels) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config, num_labels=2): - super(BertForTokenClassification, self).__init__(config) - self.num_labels = num_labels - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, num_labels) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None): - sequence_output, _ = self.bert( - input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - if labels is not None: - loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels)[active_loss] - active_labels = labels.view(-1)[active_loss] - loss = loss_fct(active_logits, active_labels) - else: - loss = loss_fct( - logits.view(-1, self.num_labels), labels.view(-1)) - return loss - else: - return logits - - -class BertForQuestionAnswering(PreTrainedBertModel): - """BERT model for Question Answering (span extraction). - This module is composed of the BERT model with a linear layer on top of - the sequence output that computes start_logits and end_logits - - Params: - `config`: either - - a BertConfig class instance with the configuration to build a new model, or - - a str with the name of a pre-trained model to load selected in the list of: - . `bert-base-uncased` - . `bert-large-uncased` - . `bert-base-cased` - . `bert-base-multilingual` - . `bert-base-chinese` - The pre-trained model will be downloaded and cached if needed. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. - Positions are clamped to the length of the sequence and position outside of the sequence are not taken - into account for computing the loss. - `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. - Positions are clamped to the length of the sequence and position outside of the sequence are not taken - into account for computing the loss. - - Outputs: - if `start_positions` and `end_positions` are not `None`: - Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. - if `start_positions` or `end_positions` is `None`: - Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end - position tokens of shape [batch_size, sequence_length]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForQuestionAnswering(config) - start_logits, end_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config): - super(BertForQuestionAnswering, self).__init__(config) - self.bert = BertModel(config) - # self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, task_idx=None): - sequence_output, _ = self.bert( - input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, task_idx=task_idx) - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions.clamp_(0, ignored_index) - end_positions.clamp_(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - return total_loss - else: - return start_logits, end_logits diff --git a/text2text/pytorch_pretrained_bert/tokenization.py b/text2text/pytorch_pretrained_bert/tokenization.py deleted file mode 100755 index a2a4d6d..0000000 --- a/text2text/pytorch_pretrained_bert/tokenization.py +++ /dev/null @@ -1,390 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tokenization classes.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import unicodedata -import os -import logging - -from .file_utils import cached_path - -logger = logging.getLogger(__name__) - -PRETRAINED_VOCAB_ARCHIVE_MAP = { - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", - 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", - 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", -} -PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { - 'bert-base-uncased': 512, - 'bert-large-uncased': 512, - 'bert-base-cased': 512, - 'bert-large-cased': 512, - 'bert-base-multilingual-uncased': 512, - 'bert-base-multilingual-cased': 512, - 'bert-base-chinese': 512, -} -VOCAB_NAME = 'vocab.txt' - - -def load_vocab(vocab_file): - """Loads a vocabulary file into a dictionary.""" - # mapping unused tokens to special tokens - extra_map = {} - extra_map['[unused1]'] = '[X_SEP]' - for i in range(10): - extra_map['[unused{}]'.format(i+2)] = '[SEP_{}]'.format(i) - extra_map['[unused12]'] = '[S2S_SEP]' - extra_map['[unused13]'] = '[S2S_CLS]' - extra_map['[unused14]'] = '[L2R_SEP]' - extra_map['[unused15]'] = '[L2R_CLS]' - extra_map['[unused16]'] = '[R2L_SEP]' - extra_map['[unused17]'] = '[R2L_CLS]' - extra_map['[unused18]'] = '[S2S_SOS]' - - vocab = collections.OrderedDict() - index = 0 - with open(vocab_file, "r", encoding="utf-8") as reader: - while True: - token = reader.readline() - if not token: - break - token = token.strip() - if token in extra_map: - token = extra_map[token] - vocab[token] = index - index += 1 - return vocab - - -def whitespace_tokenize(text): - """Runs basic whitespace cleaning and splitting on a peice of text.""" - text = text.strip() - if not text: - return [] - tokens = text.split() - return tokens - - -class BertTokenizer(object): - """Runs end-to-end tokenization: punctuation splitting + wordpiece""" - - def __init__(self, vocab_file, do_lower_case=True, max_len=None, never_split=("[UNK]", "[SEP]", "[X_SEP]", "[PAD]", "[CLS]", "[MASK]"), **kwargs): - if not os.path.isfile(vocab_file): - raise ValueError( - "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " - "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) - self.vocab = load_vocab(vocab_file) - self.ids_to_tokens = collections.OrderedDict( - [(ids, tok) for tok, ids in self.vocab.items()]) - self.basic_tokenizer = BasicTokenizer( - do_lower_case=do_lower_case, never_split=never_split) - self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) - self.max_len = max_len if max_len is not None else int(1e12) - - def tokenize(self, text): - split_tokens = [] - for token in self.basic_tokenizer.tokenize(text): - for sub_token in self.wordpiece_tokenizer.tokenize(token): - split_tokens.append(sub_token) - return split_tokens - - def convert_tokens_to_ids(self, tokens): - """Converts a sequence of tokens into ids using the vocab.""" - ids = [] - for token in tokens: - ids.append(self.vocab[token]) - if len(ids) > self.max_len: - raise ValueError( - "Token indices sequence length is longer than the specified maximum " - " sequence length for this BERT model ({} > {}). Running this" - " sequence through BERT will result in indexing errors".format( - len(ids), self.max_len) - ) - return ids - - def convert_ids_to_tokens(self, ids): - """Converts a sequence of ids in wordpiece tokens using the vocab.""" - tokens = [] - for i in ids: - tokens.append(self.ids_to_tokens[i]) - return tokens - - @classmethod - def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): - """ - Instantiate a PreTrainedBertModel from a pre-trained model file. - Download and cache the pre-trained model file if needed. - """ - if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: - vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] - else: - vocab_file = pretrained_model_name - if os.path.isdir(vocab_file): - vocab_file = os.path.join(vocab_file, VOCAB_NAME) - # redirect to the cache, if necessary - try: - resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) - except FileNotFoundError: - logger.error( - "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find any file " - "associated to this path or url.".format( - pretrained_model_name, - ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), - vocab_file)) - return None - if resolved_vocab_file == vocab_file: - logger.info("loading vocabulary file {}".format(vocab_file)) - else: - logger.info("loading vocabulary file {} from cache at {}".format( - vocab_file, resolved_vocab_file)) - if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: - # if we're using a pretrained model, ensure the tokenizer wont index sequences longer - # than the number of positional embeddings - max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] - kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) - # Instantiate tokenizer. - tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) - return tokenizer - - -class WhitespaceTokenizer(object): - def tokenize(self, text): - return whitespace_tokenize(text) - - -class BasicTokenizer(object): - """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" - - def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): - """Constructs a BasicTokenizer. - - Args: - do_lower_case: Whether to lower case the input. - """ - self.do_lower_case = do_lower_case - self.never_split = never_split - - def tokenize(self, text): - """Tokenizes a piece of text.""" - text = self._clean_text(text) - # This was added on November 1st, 2018 for the multilingual and Chinese - # models. This is also applied to the English models now, but it doesn't - # matter since the English models were not trained on any Chinese data - # and generally don't have any Chinese data in them (there are Chinese - # characters in the vocabulary because Wikipedia does have some Chinese - # words in the English Wikipedia.). - text = self._tokenize_chinese_chars(text) - orig_tokens = whitespace_tokenize(text) - split_tokens = [] - for token in orig_tokens: - if self.do_lower_case and token not in self.never_split: - token = token.lower() - token = self._run_strip_accents(token) - split_tokens.extend(self._run_split_on_punc(token)) - - output_tokens = whitespace_tokenize(" ".join(split_tokens)) - return output_tokens - - def _run_strip_accents(self, text): - """Strips accents from a piece of text.""" - text = unicodedata.normalize("NFD", text) - output = [] - for char in text: - cat = unicodedata.category(char) - if cat == "Mn": - continue - output.append(char) - return "".join(output) - - def _run_split_on_punc(self, text): - """Splits punctuation on a piece of text.""" - if text in self.never_split: - return [text] - chars = list(text) - i = 0 - start_new_word = True - output = [] - while i < len(chars): - char = chars[i] - if _is_punctuation(char): - output.append([char]) - start_new_word = True - else: - if start_new_word: - output.append([]) - start_new_word = False - output[-1].append(char) - i += 1 - - return ["".join(x) for x in output] - - def _tokenize_chinese_chars(self, text): - """Adds whitespace around any CJK character.""" - output = [] - for char in text: - cp = ord(char) - if self._is_chinese_char(cp): - output.append(" ") - output.append(char) - output.append(" ") - else: - output.append(char) - return "".join(output) - - def _is_chinese_char(self, cp): - """Checks whether CP is the codepoint of a CJK character.""" - # This defines a "chinese character" as anything in the CJK Unicode block: - # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) - # - # Note that the CJK Unicode block is NOT all Japanese and Korean characters, - # despite its name. The modern Korean Hangul alphabet is a different block, - # as is Japanese Hiragana and Katakana. Those alphabets are used to write - # space-separated words, so they are not treated specially and handled - # like the all of the other languages. - if ((cp >= 0x4E00 and cp <= 0x9FFF) or # - (cp >= 0x3400 and cp <= 0x4DBF) or # - (cp >= 0x20000 and cp <= 0x2A6DF) or # - (cp >= 0x2A700 and cp <= 0x2B73F) or # - (cp >= 0x2B740 and cp <= 0x2B81F) or # - (cp >= 0x2B820 and cp <= 0x2CEAF) or - (cp >= 0xF900 and cp <= 0xFAFF) or # - (cp >= 0x2F800 and cp <= 0x2FA1F)): # - return True - - return False - - def _clean_text(self, text): - """Performs invalid character removal and whitespace cleanup on text.""" - output = [] - for char in text: - cp = ord(char) - if cp == 0 or cp == 0xfffd or _is_control(char): - continue - if _is_whitespace(char): - output.append(" ") - else: - output.append(char) - return "".join(output) - - -class WordpieceTokenizer(object): - """Runs WordPiece tokenization.""" - - def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): - self.vocab = vocab - self.unk_token = unk_token - self.max_input_chars_per_word = max_input_chars_per_word - - def tokenize(self, text): - """Tokenizes a piece of text into its word pieces. - - This uses a greedy longest-match-first algorithm to perform tokenization - using the given vocabulary. - - For example: - input = "unaffable" - output = ["un", "##aff", "##able"] - - Args: - text: A single token or whitespace separated tokens. This should have - already been passed through `BasicTokenizer`. - - Returns: - A list of wordpiece tokens. - """ - - output_tokens = [] - for token in whitespace_tokenize(text): - chars = list(token) - if len(chars) > self.max_input_chars_per_word: - output_tokens.append(self.unk_token) - continue - - is_bad = False - start = 0 - sub_tokens = [] - while start < len(chars): - end = len(chars) - cur_substr = None - while start < end: - substr = "".join(chars[start:end]) - if start > 0: - substr = "##" + substr - if substr in self.vocab: - cur_substr = substr - break - end -= 1 - if cur_substr is None: - is_bad = True - break - sub_tokens.append(cur_substr) - start = end - - if is_bad: - output_tokens.append(self.unk_token) - else: - output_tokens.extend(sub_tokens) - return output_tokens - - -def _is_whitespace(char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False - - -def _is_control(char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char == "\t" or char == "\n" or char == "\r": - return False - cat = unicodedata.category(char) - if cat.startswith("C"): - return True - return False - - -def _is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. - if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or - (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False \ No newline at end of file diff --git a/text2text/questioner.py b/text2text/questioner.py deleted file mode 100644 index df4beb1..0000000 --- a/text2text/questioner.py +++ /dev/null @@ -1,49 +0,0 @@ -import random -import string -import text2text as t2t - -class Questioner(t2t.Abstractor): - pretrained_parameters = { - "file_id": "1JN2wnkSRotwUnJ_Z-AbWwoPdP53Gcfsn", - "fp16": False, - "amp": False, - "model_recover_path": "qg_model.bin", - "max_seq_length": 512, - "max_tgt_length": 48, - "batch_size": 16, - "search_beam_size": 1, - "length_penalty": 0, - "forbid_duplicate_ngrams": False, - "forbid_ignore_word": None, - "bert_model": "bert-large-cased", - "ffn_type": 0, - "num_qkv": 0, - "seg_emb": False, - "do_lower_case": False, - "new_segment_ids": True, - "min_len": None, - "ngram_size": 3, - "mode": "s2s", - "s2s_special_token": False, - "s2s_add_segment": False, - "s2s_share_segment": False, - "pos_shift": False, - "not_predict_token": None, - } - - def _get_random_answer(self, doc): - unique_words = set(doc.lower().translate(str.maketrans('', '', string.punctuation)).split()) - answers = list(unique_words-self.__class__.STOP_WORDS) - return random.choice(answers) if answers else random.choice(list(unique_words)) - - def transform(self, input_lines, src_lang='en', **kwargs): - input_lines = t2t.Transformer.transform(self, input_lines, src_lang, **kwargs) - if src_lang != 'en': - input_lines = self._translate_lines(input_lines, src_lang, 'en') - input_lines = [x + " [SEP] " + self._get_random_answer(x) if " [SEP] " not in x else x for x in input_lines] - questions = t2t.Abstractor.transform(self, input_lines, src_lang='en', **kwargs) - answers = [x.split(" [SEP] ")[1] for x in input_lines] - if src_lang != 'en': - questions = self._translate_lines(questions, 'en', src_lang) - answers = self._translate_lines(answers, 'en', src_lang) - return list(zip(questions, answers)) \ No newline at end of file diff --git a/text2text/summarizer.py b/text2text/summarizer.py deleted file mode 100644 index 8770b48..0000000 --- a/text2text/summarizer.py +++ /dev/null @@ -1,45 +0,0 @@ -import re -import glob -import math -from tqdm import tqdm -import numpy as np -import torch -import random -import requests, zipfile, io -import os - -from .pytorch_pretrained_bert.tokenization import BertTokenizer -from .pytorch_pretrained_bert.modeling import BertForSeq2SeqDecoder - -from .biunilm import seq2seq_loader - -import text2text as t2t - -class Summarizer(t2t.Abstractor): - pretrained_parameters = { - "file_id": "1RyJxShxC9tDYVAyZwUwqkSoQ3l5DfjuE", - "fp16": True, - "amp": True, - "model_recover_path": "cnndm_model.bin", - "max_seq_length": 768, - "max_tgt_length": 128, - "batch_size": 64, - "search_beam_size": 5, - "length_penalty": 0, - "forbid_duplicate_ngrams": True, - "forbid_ignore_word": ".|[X_SEP]", - "bert_model": "bert-large-cased", - "ffn_type": 0, - "num_qkv": 0, - "seg_emb": False, - "do_lower_case": False, - "new_segment_ids": True, - "min_len": None, - "ngram_size": 3, - "mode": "s2s", - "s2s_special_token": False, - "s2s_add_segment": False, - "s2s_share_segment": False, - "pos_shift": False, - "not_predict_token": None, - }