Skip to content

Commit

Permalink
added yolonas coco model inference
Browse files Browse the repository at this point in the history
  • Loading branch information
naseemap47 committed Sep 30, 2024
1 parent 6f838b0 commit 7167a11
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ You can Inference your **YOLO-NAS** model with **Single Command Line**

`-n`, `--num`: Number of classes the model trained on <br>
`-m`, `--model`: Model type (choices: `yolo_nas_s`, `yolo_nas_m`, `yolo_nas_l`) <br>
`-w`, `--weight`: path to trained model weight <br>
`-w`, `--weight`: path to trained model weight, for COCO model: `coco` <br>
`-s`, `--source`: video path/cam-id/RTSP <br>
`-c`, `--conf`: model prediction confidence (0<conf<1) <br>
`--save`: to save video <br>
Expand All @@ -200,7 +200,13 @@ You can Inference your **YOLO-NAS** model with **Single Command Line**
</details>

**Example:**
```bash
# For COCO YOLO-NAS Model
python3 inference.py --model yolo_nas_s --weight coco --source 0 # Camera
python3 inference.py --model yolo_nas_m --weight coco --source /test/video.mp4 --conf 0.66 # video
```

```bash
python3 inference.py --num 3 --model yolo_nas_m --weight /runs/train4/ckpt_best.pth --source /test/video.mp4 --conf 0.66 # video
--source /test/sample.jpg --conf 0.5 --save # Image save
--source /test/video.mp4 --conf 0.75 --hide # to save and hide video window
Expand Down
17 changes: 11 additions & 6 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


ap = argparse.ArgumentParser()
ap.add_argument("-n", "--num", type=int, required=True,
ap.add_argument("-n", "--num", type=int, required=False,
help="number of classes the model trained on")
ap.add_argument("-m", "--model", type=str, default='yolo_nas_s',
choices=['yolo_nas_s', 'yolo_nas_m', 'yolo_nas_l'],
Expand Down Expand Up @@ -52,12 +52,17 @@ def get_bbox(img):
return labels


# Load COCO YOLO-NAS Model
if args["weight"] == "coco":
model = models.get(args['model'], pretrained_weights="coco")
# Load YOLO-NAS Model
model = models.get(
args['model'],
num_classes=args['num'],
checkpoint_path=args["weight"]
)
else:
model = models.get(
args['model'],
num_classes=args['num'],
checkpoint_path=args["weight"]
)

model = model.to("cuda" if torch.cuda.is_available() else "cpu")
class_names = model.predict(np.zeros((1,1,3)), conf=args['conf'])._images_prediction_lst[0].class_names
print('Class Names: ', class_names)
Expand Down

0 comments on commit 7167a11

Please sign in to comment.