diff --git a/README.md b/README.md
index e1a9e40..a7ffeb8 100644
--- a/README.md
+++ b/README.md
@@ -191,7 +191,7 @@ You can Inference your **YOLO-NAS** model with **Single Command Line**
`-n`, `--num`: Number of classes the model trained on
`-m`, `--model`: Model type (choices: `yolo_nas_s`, `yolo_nas_m`, `yolo_nas_l`)
- `-w`, `--weight`: path to trained model weight
+ `-w`, `--weight`: path to trained model weight, for COCO model: `coco`
`-s`, `--source`: video path/cam-id/RTSP
`-c`, `--conf`: model prediction confidence (0
`--save`: to save video
@@ -200,7 +200,13 @@ You can Inference your **YOLO-NAS** model with **Single Command Line**
**Example:**
+```bash
+# For COCO YOLO-NAS Model
+python3 inference.py --model yolo_nas_s --weight coco --source 0 # Camera
+python3 inference.py --model yolo_nas_m --weight coco --source /test/video.mp4 --conf 0.66 # video
```
+
+```bash
python3 inference.py --num 3 --model yolo_nas_m --weight /runs/train4/ckpt_best.pth --source /test/video.mp4 --conf 0.66 # video
--source /test/sample.jpg --conf 0.5 --save # Image save
--source /test/video.mp4 --conf 0.75 --hide # to save and hide video window
diff --git a/inference.py b/inference.py
index 76da162..edf5ef5 100644
--- a/inference.py
+++ b/inference.py
@@ -9,7 +9,7 @@
ap = argparse.ArgumentParser()
-ap.add_argument("-n", "--num", type=int, required=True,
+ap.add_argument("-n", "--num", type=int, required=False,
help="number of classes the model trained on")
ap.add_argument("-m", "--model", type=str, default='yolo_nas_s',
choices=['yolo_nas_s', 'yolo_nas_m', 'yolo_nas_l'],
@@ -52,12 +52,17 @@ def get_bbox(img):
return labels
+# Load COCO YOLO-NAS Model
+if args["weight"] == "coco":
+ model = models.get(args['model'], pretrained_weights="coco")
# Load YOLO-NAS Model
-model = models.get(
- args['model'],
- num_classes=args['num'],
- checkpoint_path=args["weight"]
-)
+else:
+ model = models.get(
+ args['model'],
+ num_classes=args['num'],
+ checkpoint_path=args["weight"]
+ )
+
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
class_names = model.predict(np.zeros((1,1,3)), conf=args['conf'])._images_prediction_lst[0].class_names
print('Class Names: ', class_names)