-
Notifications
You must be signed in to change notification settings - Fork 15
/
onnx_util.go
99 lines (86 loc) · 2.87 KB
/
onnx_util.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
package main
import (
ort "github.com/yalue/onnxruntime_go"
"runtime"
)
// InitYolo8Session initializes the ONNX Runtime environment and builds a
// ready-to-run session for the model at ModelPath.
//
// input is handed to ort.NewTensor as the data of the 1x3x640x640 input
// tensor (bound to the model input named "images"). An empty float32 tensor
// of shape 1x84x8400 is allocated and bound to the model output named
// "output0". When UseCoreML is set, the CoreML execution provider is appended
// to the session options before the session is created.
//
// On success the returned ModelSession holds the session together with both
// tensors. On failure a zero ModelSession and the non-nil error are returned,
// and any tensors created so far are destroyed.
func InitYolo8Session(input []float32) (ModelSession, error) {
	ort.SetSharedLibraryPath(getSharedLibPath())
	if err := ort.InitializeEnvironment(); err != nil {
		return ModelSession{}, err
	}

	inputShape := ort.NewShape(1, 3, 640, 640)
	inputTensor, err := ort.NewTensor(inputShape, input)
	if err != nil {
		return ModelSession{}, err
	}

	outputShape := ort.NewShape(1, 84, 8400)
	outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
	if err != nil {
		// Best-effort cleanup; the creation error is the one worth reporting.
		_ = inputTensor.Destroy()
		return ModelSession{}, err
	}

	// From here on both tensors exist; release them on any failure path so a
	// failed initialization does not leak native memory.
	succeeded := false
	defer func() {
		if !succeeded {
			_ = inputTensor.Destroy()
			_ = outputTensor.Destroy()
		}
	}()

	options, err := ort.NewSessionOptions()
	if err != nil {
		return ModelSession{}, err
	}
	// The options are only needed while the session is being created, so they
	// are always released when this function returns, with or without CoreML.
	defer options.Destroy()

	if UseCoreML { // If CoreML is enabled, append the CoreML execution provider
		if err := options.AppendExecutionProviderCoreML(0); err != nil {
			return ModelSession{}, err
		}
	}

	session, err := ort.NewAdvancedSession(ModelPath,
		[]string{"images"}, []string{"output0"},
		[]ort.ArbitraryTensor{inputTensor}, []ort.ArbitraryTensor{outputTensor}, options)
	if err != nil {
		return ModelSession{}, err
	}

	succeeded = true
	return ModelSession{
		Session: session,
		Input:   inputTensor,
		Output:  outputTensor,
	}, nil
}
// getSharedLibPath returns the relative path of the bundled onnxruntime
// shared library matching the current operating system and architecture.
//
// Supported combinations are windows/amd64, darwin/arm64, linux/arm64 and,
// as a fallback, linux on any other architecture. All paths are rooted at
// ./third_party. The function panics on every other platform because no
// bundled library is listed for it.
func getSharedLibPath() string {
	switch {
	case runtime.GOOS == "windows" && runtime.GOARCH == "amd64":
		return "./third_party/onnxruntime.dll"
	case runtime.GOOS == "darwin" && runtime.GOARCH == "arm64":
		return "./third_party/onnxruntime_arm64.dylib"
	case runtime.GOOS == "linux" && runtime.GOARCH == "arm64":
		// Same directory as every other platform (was "../third_party").
		return "./third_party/onnxruntime_arm64.so"
	case runtime.GOOS == "linux":
		return "./third_party/onnxruntime.so"
	}
	panic("Unable to find a version of the onnxruntime library supporting this system.")
}
// yolo_classes lists the 80 YOLOv8 class labels, eight per row. The position
// of a label in the slice is presumably the class id reported by the model;
// confirm against the code that decodes the model output.
var yolo_classes = []string{
	// 0-7
	"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
	// 8-15
	"boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
	// 16-23
	"dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
	// 24-31
	"backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
	// 32-39
	"sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
	// 40-47
	"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
	// 48-55
	"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
	// 56-63
	"chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
	// 64-71
	"mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
	// 72-79
	"refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
}
// runInference feeds input to the model held by modelSes and returns the
// contents of its output tensor.
//
// input is copied into the slice returned by modelSes.Input.GetData(); copy
// transfers only as many elements as the shorter of the two slices holds, so
// no length check happens here. The session is then run, and on success the
// slice returned by modelSes.Output.GetData() is handed back as-is. If the
// run fails, the result is nil together with the session's error.
func runInference(modelSes ModelSession, input []float32) ([]float32, error) {
	copy(modelSes.Input.GetData(), input)

	if runErr := modelSes.Session.Run(); runErr != nil {
		return nil, runErr
	}

	result := modelSes.Output.GetData()
	return result, nil
}