Skip to content

Commit

Permalink
Allow predict --cls to receive a file path
Browse files Browse the repository at this point in the history
  • Loading branch information
johnbradley committed Aug 22, 2024
1 parent 1abf49e commit 3d4d647
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@ Options:
--format=FORMAT format of the output (table or csv) for predict mode [default: csv]
--rank=RANK rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species]
--k=K number of top predictions to show [default: 5]
--cls=CLS comma separated list of classes to predict, when specified the --rank and --k
arguments are not allowed
--cls-file CLS_FILE path to file with list of classes to predict, one per line, when specified the --rank and --k arguments are not allowed
--cls=CLS classes to predict either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed.
--rank and --k arguments are not allowed
--device=DEVICE device to use matrix math (cpu or cuda or mps) [default: cpu]
--output=OUTFILE print output to file OUTFILE [default: stdout]
```
Expand Down
11 changes: 6 additions & 5 deletions src/bioclip/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
from .predict import BIOCLIP_MODEL_STR
import open_clip as oc
import os
import json
import sys
import prettytable as pt
Expand Down Expand Up @@ -83,9 +84,9 @@ def create_parser():
predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
help='rank of the classification, default: species (when)')
predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5')
cls_group = predict_parser.add_mutually_exclusive_group(required=False)
cls_group.add_argument('--cls', help='comma separated list of classes to predict, when specified the --rank argument is not allowed')
cls_group.add_argument('--cls-file', help='path to file with list of classes to predict, one per line, when specified the --rank and --k arguments are not allowed')
cls_help = "classes to predict either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed."
predict_parser.add_argument('--cls', help=cls_help)

predict_parser.add_argument('--device', **device_arg)
predict_parser.add_argument('--model', **model_arg)
predict_parser.add_argument('--pretrained', **pretrained_arg)
Expand Down Expand Up @@ -147,8 +148,8 @@ def main():
pretrained_str=args.pretrained)
elif args.command == 'predict':
cls_str = args.cls
if args.cls_file:
cls_str = create_classes_str(args.cls_file)
if os.path.exists(args.cls):
cls_str = create_classes_str(args.cls)
predict(args.image_file,
format=args.format,
output=args.output,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def test_parse_args(self):
args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
self.assertEqual(args.k, 10)

args = parse_args(['predict', '--cls-file', 'somefile.txt', 'image.jpg'])
self.assertEqual(args.cls_file, 'somefile.txt')
self.assertEqual(args.cls, None)
# example showing filename
args = parse_args(['predict', 'image.jpg', '--cls', 'classes.txt', '--k', '10'])
self.assertEqual(args.cls, 'classes.txt')

args = parse_args(['embed', 'image.jpg'])
self.assertEqual(args.command, 'embed')
Expand Down

0 comments on commit 3d4d647

Please sign in to comment.