
Commit

bug fix and ready for release
WangHexie committed Apr 10, 2020
1 parent 306fb17 commit d09ae54
Showing 5 changed files with 50 additions and 18 deletions.
33 changes: 33 additions & 0 deletions README.md
@@ -0,0 +1,33 @@
# Sentiment Classification (情感分类)
> Text sentiment classification
## Requirements

```sh
pip install -r requirements.txt
git clone https://github.com/WangHexie/ktrain.git
cd ktrain && pip install .
```

## Usage example

```sh
python ./main.py --traning-dataset ./data --prediction-file "./data/test.csv" --test-dataset ./data/test --max_len 80 --model hfl/chinese-roberta-wwm-ext --batch_size 196
```

```sh
usage: sentiment classification [-h] [--traning-dataset TRANING_DATASET]
[--prediction-file PREDICTION_FILE]
[--test-dataset TEST_DATASET] [--model MODEL]
[--batch_size BATCH_SIZE] [--max_len MAX_LEN]
[--temp_path TEMP_PATH]
[--early_stopping EARLY_STOPPING]
[--reduce_on_plateau REDUCE_ON_PLATEAU]

```
7 changes: 7 additions & 0 deletions data/labels.csv
@@ -0,0 +1,7 @@
开心,兴奋,
,,
,,
开心,,
,,
焦虑,开心,
,,
7 changes: 7 additions & 0 deletions data/texts.csv
@@ -0,0 +1,7 @@
大法开机无广告
"想多了吧如若不强制 等于无效"
"说实话网格员 有点浪费了"
每个女孩节日快乐!!!
我就也点赞了一个
条件别满足了
高三初三可以考虑早一点开学
19 changes: 2 additions & 17 deletions main.py
@@ -2,7 +2,7 @@
 from dataclasses import asdict
 
 from src.config.configs import ClassifierParam
-from src.submit import prediction_all, classifier_ensemble
+from src.submit import prediction_all
 
 parser = argparse.ArgumentParser("sentiment classification", fromfile_prefix_chars='@')
 parser.add_argument('--traning-dataset', type=str, help='train_dataset input path')
@@ -19,26 +19,11 @@
 
 args = parser.parse_args()
 
-prediction_paths = ["./test1.csv", "./test2.csv"]
-
 prediction_all(test_dir=args.test_dataset,
                train_file_dir=args.traning_dataset,
-               prediction_file_path=prediction_paths[0],
+               prediction_file_path=args.prediction_file,
                config=asdict(
                    ClassifierParam(epochs=30, batch_size=args.batch_size, max_len=args.max_len, model_name=args.model,
                                    learning_rate=1e-5, early_stopping=args.early_stopping,
                                    reduce_on_plateau=args.reduce_on_plateau)),
                temp_path=args.temp_path)
-
-prediction_all(test_dir=args.test_dataset,
-               train_file_dir=args.traning_dataset,
-               prediction_file_path=prediction_paths[1],
-               config=asdict(
-                   ClassifierParam(epochs=30, batch_size=156, max_len=85, model_name=args.model,
-                                   learning_rate=1e-5, early_stopping=args.early_stopping,
-                                   reduce_on_plateau=args.reduce_on_plateau)),
-               temp_path=args.temp_path)
-
-
-classifier_ensemble(prediction_paths, output_path=args.prediction_file)
-
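The parser in `main.py` is built with `fromfile_prefix_chars='@'`, which lets a long option list live in a file that is passed on the command line as `@path`. A minimal standalone sketch of that argparse feature (the `--batch_size`/`--max_len` flags are taken from the diff; the temp file path is incidental):

```python
import argparse
import os
import tempfile

parser = argparse.ArgumentParser("sentiment classification", fromfile_prefix_chars='@')
parser.add_argument('--batch_size', type=int)
parser.add_argument('--max_len', type=int)

# Write one argument per line; argparse reads each line as a single token,
# so the "--flag=value" form is used.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("--batch_size=196\n--max_len=80\n")
    path = f.name

args = parser.parse_args([f"@{path}"])
os.unlink(path)
print(args.batch_size, args.max_len)  # → 196 80
```

This is why the README's long invocation could equally be stored in a config file and run as `python ./main.py @args.txt`.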
2 changes: 1 addition & 1 deletion src/data/dataset.py
@@ -39,7 +39,7 @@ def reverse_transform_one_hot_to_label(df_prediction):
 
     @staticmethod
     def read_one_hot_label(label_path=FilePath.label_path, mode="train"):
-        return Dataset._transform_original_label_to_one_hot(Dataset.read_original_label(label_path, mode=mode).values)
+        return Dataset._transform_original_label_to_one_hot(Dataset.read_original_label(label_path, mode=mode).values).drop(columns=[''])
 
     @staticmethod
     def read_splitted_train_file(path):
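The `dataset.py` fix appends `.drop(columns=[''])` after the one-hot transform. A plausible mechanism, sketched under the assumption that rows in `labels.csv` such as `开心,兴奋,` split into entries that include an empty string, which then surfaces as a spurious `''` column (`to_one_hot` here is a hypothetical stand-in, not the project's `_transform_original_label_to_one_hot`):

```python
import pandas as pd

# Rows as they would come out of labels.csv: trailing commas
# produce empty-string entries alongside the real labels.
rows = [["开心", "兴奋", ""], ["", "", ""], ["焦虑", "开心", ""]]

def to_one_hot(rows):
    labels = sorted({label for row in rows for label in row})  # '' sneaks in
    return pd.DataFrame(
        [[int(label in row) for label in labels] for row in rows],
        columns=labels,
    )

one_hot = to_one_hot(rows)
assert '' in one_hot.columns          # spurious empty-label column
one_hot = one_hot.drop(columns=[''])  # the fix applied in this commit
print(list(one_hot.columns))          # → ['兴奋', '开心', '焦虑']
```

Without the drop, the classifier would be trained to predict an empty-string "emotion", which matches the commit message's "bug fix".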
