Skip to content

Commit

Permalink
fix coco dataset training support (#174)
Browse files Browse the repository at this point in the history
* fix training from coco

* add tests for training from coco

* add test data.yaml for cooc training

* fix ci

* update version
  • Loading branch information
fcakyon authored Dec 14, 2022
1 parent 7c010a1 commit 73a6de3
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 3 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,15 @@ jobs:
- name: Test scripts
run: |
pip install -e .
# donwload coco formatted testing data
python tests/download_test_coco_dataset.py
# train (dl base model from hf hub)
yolov5 train --data tests/data/coco_data.yaml --img 128 --batch 16 --weights fcakyon/yolov5n-v7.0 --epochs 1 --device cpu --freeze 10
python yolov5/train.py --img 128 --batch 16 --weights fcakyon/yolov5n-v7.0 --epochs 1 --device cpu
yolov5 train --img 128 --batch 16 --weights fcakyon/yolov5n-v7.0 --epochs 1 --device cpu --freeze 10
yolov5 train --img 128 --batch 16 --weights fcakyon/yolov5n-v7.0 --epochs 1 --device cpu --evolve 2
# train
yolov5 train --data tests/data/coco_data.yaml --img 128 --batch 16 --weights yolov5/weights/yolov5n.pt --epochs 1 --device cpu --freeze 10
python yolov5/train.py --img 128 --batch 16 --weights yolov5/weights/yolov5n.pt --epochs 1 --device cpu
yolov5 train --img 128 --batch 16 --weights yolov5/weights/yolov5n.pt --epochs 1 --device cpu --freeze 10
yolov5 train --img 128 --batch 16 --weights yolov5/weights/yolov5n.pt --epochs 1 --device cpu --evolve 2
Expand Down
4 changes: 4 additions & 0 deletions tests/data/coco_data.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
train_json_path: "tests/data/dataset.json"
train_image_dir: "tests/data/images/"
val_json_path: "tests/data/dataset.json"
val_image_dir: "tests/data/images/"
34 changes: 34 additions & 0 deletions tests/download_test_coco_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
from pathlib import Path

import urllib.request

TEST_COCO_DATASET_URL = "https://github.com/fcakyon/yolov5-pip/releases/download/0.6/test-coco-dataset.zip"
TEST_COCO_DATASET_ZIP_PATH = "tests/data/test-coco-dataset.zip"


def download_from_url(from_url: str, to_path: str):

Path(to_path).parent.mkdir(parents=True, exist_ok=True)

if not os.path.exists(to_path):
urllib.request.urlretrieve(
from_url,
to_path,
)

def download_test_coco_dataset_and_unzip():
if not os.path.exists(TEST_COCO_DATASET_ZIP_PATH):
print("Downloading test coco dataset...")
download_from_url(TEST_COCO_DATASET_URL, TEST_COCO_DATASET_ZIP_PATH)
print("Download complete.")

if not os.path.exists("tests/data/dataset.json"):
print("Unzipping test coco dataset...")
import zipfile
with zipfile.ZipFile(TEST_COCO_DATASET_ZIP_PATH, "r") as zip_ref:
zip_ref.extractall("tests/data")
print("Unzip complete.")

if __name__ == "__main__":
download_test_coco_dataset_and_unzip()
2 changes: 1 addition & 1 deletion yolov5/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from yolov5.helpers import YOLOv5
from yolov5.helpers import load_model as load

__version__ = "7.0.1"
__version__ = "7.0.2"
4 changes: 2 additions & 2 deletions yolov5/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def convert_coco_dataset_to_yolo(opt, save_dir):

if is_coco_data:
from sahi.utils.coco import export_coco_as_yolov5_via_yml
data = export_coco_as_yolov5_via_yml(yml_path=data, output_dir=save_dir / 'data')
data = export_coco_as_yolov5_via_yml(yml_path=opt.data, output_dir=save_dir / 'data')
opt.data = data

# add coco fields to data.yaml
Expand Down Expand Up @@ -158,7 +158,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
cuda = device.type != 'cpu'
init_seeds(opt.seed + 1 + RANK, deterministic=True)
with torch_distributed_zero_first(LOCAL_RANK):
data_dict = check_dataset(data)
data_dict = check_dataset(opt.data)
train_path, val_path = data_dict['train'], data_dict['val']
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
Expand Down

0 comments on commit 73a6de3

Please sign in to comment.