Skip to content

sunyilgdx/RoBERTa4Keras

Repository files navigation

RoBERTa4Keras

让英文原版RoBERTa在bert4keras框架下运行

使用了苏神的bert4keras框架

分别借鉴了并改动了

优点

得益于bert4keras的可扩展性,完全不需要修改bert4keras的框架,仅需要调用BERT类和更改tokenizer即可实现在bert4keras框架使用英文原版RoBERTa

核心思路

  • 使用英文RoBERTa、GPT2、BART等模型的bpe分词
  • 传入custom position id, 将position id设置为 [2, max_length]
  • 抛弃segment embeddings, 将segment_vocab_size=0
  • fairseq的pytorch RoBERTa转为google原版的BERT

使用例子

简单写了一个文本分类的代码,可以直接运行cls_classification_roberta.py,也可以使用run_cls_roberta.sh(可能需要修改相对路径)

  1. 将fairseq原版的pytorch RoBERTa转为tensorflow格式
  1. 基于[CLS]位置的fine-tuning关键步骤
  • 加载bpe分词器

    merges_file = r'./models/roberta_large_fairseq_tf/merges.txt'
    vocab_file = r'./models/roberta_large_fairseq_tf/vocab.json'
    tokenizer = RobertaTokenizer(vocab_file=vocab_file, merges_file=merges_file)
    
  • 使用bert4keras的框架中的BERT模型加载RoBERTa模型 这里需要设置两个关键参数,custom_position_ids=True传入自定义position_ids, segment_vocab_size=0将segment embeddings维度设置为0

    bert = build_transformer_model(
          config_path,
          checkpoint_path,
          model=model,
          with_pool=False,
          return_keras_model=False,
          custom_position_ids=True,
          segment_vocab_size=0  # RoBERTa don't have the segment embeddings (token type embeddings).
      )
    
  • 分词和token编码 这里会调用bpe_tokenization.py中的分词模型,将会把输入的文本编码为<s> X </s>这种形式

    token_ids, _ = tokenizer.encode(text, maxlen=maxlen)
    
  • 位置编码 这里需要把位置编码设置为[2,max_len],具体原因需要到fairseq的仓库下查issues,padding使用的是1这个position id(可能也没有影响)

    custom_position_ids = [2 + i for i in range(len(token_ids))]
    batch_custom_position_ids = sequence_padding(batch_custom_position_ids, value=1)
    

其他地方与中文BERT基本一致,只需要改一些inputs就可以了,不再赘述

Releases

No releases published

Packages

No packages published