-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 6d6bfcb
Showing
12 changed files
with
1,127 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
name: pre-commit | ||
|
||
on: | ||
pull_request: | ||
push: | ||
branches: [main] | ||
|
||
jobs: | ||
check_and_test: | ||
runs-on: [self-hosted, linux, x64, cpu-only] | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: actions/setup-python@v4 | ||
id: ko-sentence-transformers | ||
with: | ||
python-version: '3.10' | ||
- name: pre-commit # don't use in self-hosted `- uses: pre-commit/action@v2.0.3` | ||
run: | | ||
pip install -U pre-commit | ||
pre-commit install --install-hooks | ||
pre-commit run -a |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
exclude: ^(legacy|bin) | ||
repos: | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v4.0.1 | ||
hooks: | ||
- id: end-of-file-fixer | ||
types: [python] | ||
- id: trailing-whitespace | ||
types: [python] | ||
- id: mixed-line-ending | ||
types: [python] | ||
- id: check-added-large-files | ||
args: [--maxkb=4096] | ||
- repo: https://github.com/psf/black | ||
rev: 22.3.0 | ||
hooks: | ||
- id: black | ||
args: ["--line-length", "120"] | ||
- repo: https://github.com/pycqa/isort | ||
rev: 5.12.0 | ||
hooks: | ||
- id: isort | ||
name: isort (python) | ||
args: ["--profile", "black", "-l", "120"] | ||
- repo: https://github.com/pycqa/flake8.git | ||
rev: 6.0.0 | ||
hooks: | ||
- id: flake8 | ||
types: [python] | ||
args: ["--max-line-length", "120", "--ignore", "F811,F841,E203,E402,E712,W503"] |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
# kf-deberta-multitask | ||
|
||
kakaobank의 [kf-deberta-base](https://huggingface.co/kakaobank/kf-deberta-base) 모델을 KorNLI, KorSTS 데이터셋으로 파인튜닝한 모델입니다. | ||
[jhgan00/ko-sentence-transformers](https://github.com/jhgan00/ko-sentence-transformers) 코드를 기반으로 일부 수정하여 진행하였습니다. | ||
|
||
<br> | ||
|
||
## KorSTS Benchmark | ||
|
||
- [jhgan00/ko-sentence-transformers](https://github.com/jhgan00/ko-sentence-transformers#korsts-benchmarks)의 결과를 참고하여 재작성하였습니다. | ||
- 학습 및 성능 평가 과정은 `training_*.py`, `benchmark.py` 에서 확인할 수 있습니다. | ||
- 학습된 모델은 허깅페이스 모델 허브에 공개되어 있습니다. | ||
|
||
<br> | ||
|
||
|model|cosine_pearson|cosine_spearman|euclidean_pearson|euclidean_spearman|manhattan_pearson|manhattan_spearman|dot_pearson|dot_spearman| | ||
|:-------------------------|-----------------:|------------------:|--------------------:|---------------------:|--------------------:|---------------------:|--------------:|---------------:| | ||
|[kf-deberta-multitask](https://huggingface.co/upskyy/kf-deberta-multitask)|**85.75**|**86.25**|**84.79**|**85.25**|**84.80**|**85.27**|**82.93**|**82.86**| | ||
|[ko-sroberta-multitask](https://huggingface.co/jhgan/ko-sroberta-multitask)|84.77|85.6|83.71|84.40|83.70|84.38|82.42|82.33| | ||
|[ko-sbert-multitask](https://huggingface.co/jhgan/ko-sbert-multitask)|84.13|84.71|82.42|82.66|82.41|82.69|80.05|79.69| | ||
|[ko-sroberta-base-nli](https://huggingface.co/jhgan/ko-sroberta-nli)|82.83|83.85|82.87|83.29|82.88|83.28|80.34|79.69| | ||
|[ko-sbert-nli](https://huggingface.co/jhgan/ko-sbert-multitask)|82.24|83.16|82.19|82.31|82.18|82.3|79.3|78.78| | ||
|[ko-sroberta-sts](https://huggingface.co/jhgan/ko-sroberta-sts)|81.84|81.82|81.15|81.25|81.14|81.25|79.09|78.54| | ||
|[ko-sbert-sts](https://huggingface.co/jhgan/ko-sbert-sts)|81.55|81.23|79.94|79.79|79.9|79.75|76.02|75.31| | ||
|
||
<br> | ||
|
||
## Examples | ||
|
||
- 예시 출처: <https://github.com/BM-K/KoSentenceBERT-SKT> | ||
|
||
<br> | ||
|
||
아래는 임베딩 벡터를 통해 가장 유사한 문장을 찾는 예시입니다. | ||
더 많은 예시는 [sentence-transformers 문서](https://www.sbert.net/index.html)를 참고해주세요. | ||
|
||
```python | ||
from sentence_transformers import SentenceTransformer, util | ||
import numpy as np | ||
|
||
|
||
embedder = SentenceTransformer("upskyy/kf-deberta-multitask") | ||
|
||
# Corpus with example sentences | ||
corpus = [ | ||
"한 남자가 음식을 먹는다.", | ||
"한 남자가 빵 한 조각을 먹는다.", | ||
"그 여자가 아이를 돌본다.", | ||
"한 남자가 말을 탄다.", | ||
"한 여자가 바이올린을 연주한다.", | ||
"두 남자가 수레를 숲 속으로 밀었다.", | ||
"한 남자가 담으로 싸인 땅에서 백마를 타고 있다.", | ||
"원숭이 한 마리가 드럼을 연주한다.", | ||
"치타 한 마리가 먹이 뒤에서 달리고 있다.", | ||
] | ||
|
||
corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True) | ||
|
||
# Query sentences: | ||
queries = [ | ||
"한 남자가 파스타를 먹는다.", | ||
"고릴라 의상을 입은 누군가가 드럼을 연주하고 있다.", | ||
"치타가 들판을 가로 질러 먹이를 쫓는다." | ||
] | ||
|
||
# Find the closest 5 sentences of the corpus for each query sentence based on cosine similarity | ||
top_k = 5 | ||
for query in queries: | ||
query_embedding = embedder.encode(query, convert_to_tensor=True) | ||
cos_scores = util.pytorch_cos_sim(query_embedding, corpus_embeddings)[0] | ||
cos_scores = cos_scores.cpu() | ||
|
||
# We use np.argpartition, to only partially sort the top_k results | ||
top_results = np.argpartition(-cos_scores, range(top_k))[0:top_k] | ||
|
||
print("\n\n======================\n\n") | ||
print("Query:", query) | ||
print("\nTop 5 most similar sentences in corpus:") | ||
|
||
for idx in top_results[0:top_k]: | ||
print(corpus[idx].strip(), "(Score: %.4f)" % (cos_scores[idx])) | ||
``` | ||
|
||
<br> | ||
|
||
``` | ||
====================== | ||
Query: 한 남자가 파스타를 먹는다. | ||
Top 5 most similar sentences in corpus: | ||
한 남자가 음식을 먹는다. (Score: 0.5826) | ||
한 남자가 빵 한 조각을 먹는다. (Score: 0.5507) | ||
한 남자가 말을 탄다. (Score: 0.1767) | ||
한 남자가 담으로 싸인 땅에서 백마를 타고 있다. (Score: 0.0965) | ||
치타 한 마리가 먹이 뒤에서 달리고 있다. (Score: 0.0429) | ||
====================== | ||
Query: 고릴라 의상을 입은 누군가가 드럼을 연주하고 있다. | ||
Top 5 most similar sentences in corpus: | ||
원숭이 한 마리가 드럼을 연주한다. (Score: 0.7093) | ||
한 여자가 바이올린을 연주한다. (Score: 0.2374) | ||
치타 한 마리가 먹이 뒤에서 달리고 있다. (Score: 0.1872) | ||
그 여자가 아이를 돌본다. (Score: 0.1574) | ||
한 남자가 말을 탄다. (Score: 0.0883) | ||
====================== | ||
Query: 치타가 들판을 가로 질러 먹이를 쫓는다. | ||
Top 5 most similar sentences in corpus: | ||
치타 한 마리가 먹이 뒤에서 달리고 있다. (Score: 0.7740) | ||
한 남자가 담으로 싸인 땅에서 백마를 타고 있다. (Score: 0.2161) | ||
두 남자가 수레를 숲 속으로 밀었다. (Score: 0.1806) | ||
한 남자가 음식을 먹는다. (Score: 0.1651) | ||
한 남자가 말을 탄다. (Score: 0.1352) | ||
``` | ||
|
||
<br> | ||
|
||
## Training | ||
|
||
직접 모델을 파인튜닝하려면 [`kor-nlu-datasets`](https://github.com/kakaobrain/kor-nlu-datasets) 저장소를 clone 하고 `training_*.py` 스크립트를 실행시키면 됩니다. | ||
|
||
`train.sh` 파일에서 학습 예시를 확인할 수 있습니다. | ||
|
||
```bash | ||
git clone https://github.com/upskyy/kf-deberta-multitask.git | ||
cd kf-deberta-multitask | ||
|
||
pip install -r requirements.txt | ||
|
||
git clone https://github.com/kakaobrain/kor-nlu-datasets.git | ||
|
||
python training_multi_task.py --model_name_or_path kakaobank/kf-deberta-base | ||
./bin/train.sh | ||
``` | ||
|
||
<br> | ||
|
||
## Evaluation | ||
|
||
KorSTS Benchmark를 평가하는 방법입니다. | ||
|
||
```bash | ||
git clone https://github.com/upskyy/kf-deberta-multitask.git | ||
cd kf-deberta-multitask | ||
|
||
pip install -r requirements.txt | ||
|
||
git clone https://github.com/kakaobrain/kor-nlu-datasets.git | ||
python bin/benchmark.py | ||
``` | ||
|
||
<br> | ||
|
||
## Export ONNX | ||
|
||
`requirements.txt` 설치 후 `bin` 디렉토리에서 `export_onnx.py` 스크립트를 실행시키면 됩니다. | ||
|
||
```bash | ||
git clone https://github.com/upskyy/kf-deberta-multitask.git | ||
cd kf-deberta-multitask | ||
|
||
pip install -r requirements.txt | ||
|
||
python bin/export_onnx.py | ||
``` | ||
|
||
<br> | ||
|
||
## Acknowledgements | ||
|
||
- [kakaobank/kf-deberta-base](https://huggingface.co/kakaobank/kf-deberta-base) for pretrained model | ||
- [jhgan00/ko-sentence-transformers](https://github.com/jhgan00/ko-sentence-transformers) for original codebase | ||
- [kor-nlu-datasets](https://github.com/kakaobrain/kor-nlu-datasets) for training data | ||
|
||
<br> | ||
|
||
## Citation | ||
|
||
```bibtex | ||
@proceedings{jeon-etal-2023-kfdeberta, | ||
title = {KF-DeBERTa: Financial Domain-specific Pre-trained Language Model}, | ||
author = {Eunkwang Jeon, Jungdae Kim, Minsang Song, and Joohyun Ryu}, | ||
booktitle = {Proceedings of the 35th Annual Conference on Human and Cognitive Language Technology}, | ||
moth = {oct}, | ||
year = {2023}, | ||
publisher = {Korean Institute of Information Scientists and Engineers}, | ||
url = {http://www.hclt.kr/symp/?lnb=conference}, | ||
pages = {143--148}, | ||
} | ||
``` | ||
|
||
```bibtex | ||
@article{ham2020kornli, | ||
title={KorNLI and KorSTS: New Benchmark Datasets for Korean Natural Language Understanding}, | ||
author={Ham, Jiyeon and Choe, Yo Joong and Park, Kyubyong and Choi, Ilji and Soh, Hyungjoon}, | ||
journal={arXiv preprint arXiv:2004.03289}, | ||
year={2020} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import argparse | ||
import csv | ||
import logging | ||
import os | ||
|
||
from sentence_transformers import InputExample, LoggingHandler, SentenceTransformer | ||
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator | ||
|
||
logging.basicConfig( | ||
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--sts_dataset_path", type=str, default="kor-nlu-datasets/KorSTS") | ||
parser.add_argument("--model_name_or_path", type=str, required=True) | ||
args = parser.parse_args() | ||
|
||
# Read STSbenchmark dataset and use it as development set | ||
test_samples = [] | ||
test_file = os.path.join(args.sts_dataset_path, "sts-test.tsv") | ||
|
||
with open(test_file, "rt", encoding="utf8") as fIn: | ||
reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) | ||
for row in reader: | ||
score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 | ||
test_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) | ||
|
||
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test") | ||
|
||
model = SentenceTransformer(args.model_name_or_path) | ||
test_evaluator(model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import os | ||
from pathlib import Path | ||
from transformers.convert_graph_to_onnx import convert | ||
|
||
|
||
if __name__ == "__main__": | ||
output_dir = "models" | ||
|
||
if not os.path.exists(output_dir): | ||
os.makedirs(output_dir, exist_ok=False) | ||
|
||
output_fpath = os.path.join(output_dir, "kf-deberta-multitask.onnx") | ||
convert(framework="pt", model="upskyy/kf-deberta-multitask", output=Path(output_fpath), opset=15) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# To start training, you need to download the KorNLUDatasets first. | ||
# git clone https://github.com/kakaobrain/kor-nlu-datasets.git | ||
|
||
# train on STS dataset only | ||
# python training_sts.py --model_name_or_path klue/bert-base | ||
# python training_sts.py --model_name_or_path klue/roberta-base | ||
# python training_sts.py --model_name_or_path klue/roberta-small | ||
# python training_sts.py --model_name_or_path klue/roberta-large | ||
python training_sts.py --model_name_or_path kakaobank/kf-deberta-base | ||
|
||
# train on both NLI and STS dataset (multi-task) | ||
# python training_multi_task.py --model_name_or_path klue/bert-base | ||
# python training_multi_task.py --model_name_or_path klue/roberta-base | ||
# python training_multi_task.py --model_name_or_path klue/roberta-small | ||
# python training_multi_task.py --model_name_or_path klue/roberta-large | ||
python training_multi_task.py --model_name_or_path kakaobank/kf-deberta-base | ||
|
||
# train on NLI dataset only | ||
# python training_nli.py --model_name_or_path klue/bert-base | ||
# python training_nli.py --model_name_or_path klue/roberta-base | ||
# python training_nli.py --model_name_or_path klue/roberta-small | ||
# python training_nli.py --model_name_or_path klue/roberta-large | ||
python training_nli.py --model_name_or_path kakaobank/kf-deberta-base |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import csv | ||
import random | ||
|
||
from sentence_transformers.readers import InputExample | ||
|
||
|
||
def load_kor_sts_samples(filename): | ||
samples = [] | ||
with open(filename, "rt", encoding="utf8") as fIn: | ||
reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) | ||
for row in reader: | ||
score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 | ||
samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) | ||
return samples | ||
|
||
|
||
def load_kor_nli_samples(filename): | ||
data = {} | ||
|
||
def add_to_samples(sent1, sent2, label): | ||
if sent1 not in data: | ||
data[sent1] = {"contradiction": set(), "entailment": set(), "neutral": set()} | ||
data[sent1][label].add(sent2) | ||
|
||
with open(filename, "r", encoding="utf-8") as fIn: | ||
reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) | ||
for row in reader: | ||
sent1 = row["sentence1"].strip() | ||
sent2 = row["sentence2"].strip() | ||
add_to_samples(sent1, sent2, row["gold_label"]) | ||
add_to_samples(sent2, sent1, row["gold_label"]) # Also add the opposite | ||
|
||
samples = [] | ||
for sent, others in data.items(): | ||
if len(others["entailment"]) > 0 and len(others["contradiction"]) > 0: | ||
samples.append( | ||
InputExample( | ||
texts=[ | ||
sent, | ||
random.choice(list(others["entailment"])), | ||
random.choice(list(others["contradiction"])), | ||
] | ||
) | ||
) | ||
samples.append( | ||
InputExample( | ||
texts=[ | ||
random.choice(list(others["entailment"])), | ||
sent, | ||
random.choice(list(others["contradiction"])), | ||
] | ||
) | ||
) | ||
return samples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
sentence-transformers | ||
onnxruntime | ||
onnx |
Oops, something went wrong.