diff --git a/README.md b/README.md index e1a9e40..a7ffeb8 100644 --- a/README.md +++ b/README.md @@ -191,7 +191,7 @@ You can Inference your **YOLO-NAS** model with **Single Command Line** `-n`, `--num`: Number of classes the model trained on
`-m`, `--model`: Model type (choices: `yolo_nas_s`, `yolo_nas_m`, `yolo_nas_l`)
- `-w`, `--weight`: path to trained model weight
+ `-w`, `--weight`: path to trained model weight, for COCO model: `coco`
`-s`, `--source`: video path/cam-id/RTSP
`-c`, `--conf`: model prediction confidence (0 `--save`: to save video
@@ -200,7 +200,13 @@ You can Inference your **YOLO-NAS** model with **Single Command Line** **Example:** +```bash +# For COCO YOLO-NAS Model +python3 inference.py --model yolo_nas_s --weight coco --source 0 # Camera +python3 inference.py --model yolo_nas_m --weight coco --source /test/video.mp4 --conf 0.66 # video ``` + +```bash python3 inference.py --num 3 --model yolo_nas_m --weight /runs/train4/ckpt_best.pth --source /test/video.mp4 --conf 0.66 # video --source /test/sample.jpg --conf 0.5 --save # Image save --source /test/video.mp4 --conf 0.75 --hide # to save and hide video window diff --git a/inference.py b/inference.py index 76da162..edf5ef5 100644 --- a/inference.py +++ b/inference.py @@ -9,7 +9,7 @@ ap = argparse.ArgumentParser() -ap.add_argument("-n", "--num", type=int, required=True, +ap.add_argument("-n", "--num", type=int, required=False, help="number of classes the model trained on") ap.add_argument("-m", "--model", type=str, default='yolo_nas_s', choices=['yolo_nas_s', 'yolo_nas_m', 'yolo_nas_l'], @@ -52,12 +52,17 @@ def get_bbox(img): return labels +# Load COCO YOLO-NAS Model +if args["weight"] == "coco": + model = models.get(args['model'], pretrained_weights="coco") # Load YOLO-NAS Model -model = models.get( - args['model'], - num_classes=args['num'], - checkpoint_path=args["weight"] -) +else: + model = models.get( + args['model'], + num_classes=args['num'], + checkpoint_path=args["weight"] + ) + model = model.to("cuda" if torch.cuda.is_available() else "cpu") class_names = model.predict(np.zeros((1,1,3)), conf=args['conf'])._images_prediction_lst[0].class_names print('Class Names: ', class_names)