-
Notifications
You must be signed in to change notification settings - Fork 0
/
task-predict
executable file
·106 lines (93 loc) · 2.11 KB
/
task-predict
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/bin/bash
#######################################
# Print the command-line help for task-predict.
# Arguments: none
# Outputs:   the help text on stdout
#######################################
usage() {
  # Quoted delimiter: the text is emitted literally, with no expansion.
  cat <<'EOF'
usage: task-predict [--device DEV] [--batch N] [--epochs N] [--verbose] [--checkpoint ID] modeldir source [source ...]

arguments:

 modeldir : model storage directory
 source : prediction dataset(s) and/or image folder path(s)

prediction options:

 --device DEV : tensorflow device
 --batch N : prediction batch size
 --epochs N : prediction epochs (only for raw images)
 --verbose : output detailed summaries
 --checkpoint ID : use specific checkpoint
EOF
}
# ---------------------------------------------------------------------------
# Main: parse the command line, pick the checkpoint to use, stage a private
# copy of the model containing only that checkpoint, and run
# `python3 src/run.py predict` against the copy. src/run.py and the staging
# directory are both resolved relative to the current working directory.
#
# Exit status:
#   1  bad command line (usage printed to stderr)
#   2  "$MODEL/checkpoint" does not exist
#   3  the selected checkpoint's .meta file does not exist
#   4  the temporary model copy could not be created
#   *  otherwise, the exit status of the prediction run
# ---------------------------------------------------------------------------
MODEL=
SOURCES=()
# Default to the first GPU when the NVIDIA driver is present, else the CPU.
if [ -f "/proc/driver/nvidia/version" ]; then
  DEVICE="/gpu:0"
else
  DEVICE="/cpu:0"
fi
BATCH_SIZE=1
EPOCHS=1
VERBOSE=
CHECKPOINT=

# Options come first. The first non-option word is the model directory, and
# every remaining word is a prediction source. Requiring at least two words
# per iteration guarantees that value-taking options always get a value and
# that the model directory is followed by at least one source.
while [ $# -ge 2 ]; do
  case "$1" in
    --device)
      DEVICE=$2
      shift 2
      ;;
    --batch)
      BATCH_SIZE=$2
      shift 2
      ;;
    --epochs)
      EPOCHS=$2
      shift 2
      ;;
    --verbose)
      VERBOSE=$1
      shift
      ;;
    --checkpoint)
      CHECKPOINT=$2
      shift 2
      ;;
    *)
      MODEL=$1
      shift
      # Keep each source as its own array element so that paths containing
      # spaces or glob characters reach run.py unchanged.
      SOURCES=("$@")
      shift $#
      ;;
  esac
done
if [ $# -gt 0 ] || [ -z "$MODEL" ] || [ "${#SOURCES[@]}" -eq 0 ]; then
  usage >&2
  exit 1
fi

if [ ! -f "$MODEL/checkpoint" ]; then
  echo "ERROR: no model checkpoint available" >&2
  exit 2
fi
if [ -z "$CHECKPOINT" ]; then
  # Take the prefix from a line of the form
  #   model_checkpoint_path: "<prefix>"
  # i.e. the last space-separated field, with the double quotes stripped.
  CHECKPOINT=$(grep -E '^model_checkpoint_path' -- "$MODEL/checkpoint" \
    | grep -o -E '[^ ]+$' \
    | grep -o -E '[^"]+')
fi
if [ ! -f "$MODEL/$CHECKPOINT.meta" ]; then
  echo "ERROR: model checkpoint not found ($CHECKPOINT)" >&2
  exit 3
fi

# Stage into a fresh directory under the current working directory. mktemp
# keeps the tmp-<timestamp> prefix and adds a random suffix, so two runs
# started in the same second cannot share (and then delete) one directory.
TMPMODEL=$(mktemp -d "tmp-$(date +%Y%m%d%H%M%S)-XXXXXX") || {
  echo "ERROR: cannot create temporary model directory" >&2
  exit 4
}
# Remove the staging copy on every exit path (success, error, Ctrl-C, kill),
# not only after the prediction run completes.
cleanup() {
  echo "MODEL: cleanup $TMPMODEL..."
  rm -rf -- "${TMPMODEL:?}"
}
trap cleanup EXIT
trap 'exit 130' INT
trap 'exit 143' TERM

echo "MODEL: importing checkpoint $CHECKPOINT from $MODEL to $TMPMODEL..."
# Quoted prefixes survive brace expansion; the .data* glob stays unquoted so
# that it matches every data shard of the checkpoint.
if ! cp -p -- "$MODEL"/{graph.pbtxt,parameters.json,topology.yaml} \
  "$MODEL/$CHECKPOINT".{meta,index} "$MODEL/$CHECKPOINT".data* \
  "$TMPMODEL/"; then
  echo "ERROR: failed to import checkpoint $CHECKPOINT into $TMPMODEL" >&2
  exit 4
fi
# The staged copy's index file names only the selected checkpoint.
printf 'model_checkpoint_path: "%s"\n' "$CHECKPOINT" > "$TMPMODEL/checkpoint"

PREDICT_ARGS=(--device "$DEVICE" --batch "$BATCH_SIZE" --epochs "$EPOCHS")
if [ -n "$VERBOSE" ]; then
  PREDICT_ARGS+=("$VERBOSE")
fi
python3 src/run.py predict "${PREDICT_ARGS[@]}" "$TMPMODEL" "${SOURCES[@]}"
# Exit with the prediction run's status; the EXIT trap then cleans up
# without overriding it.
exit $?