-
Notifications
You must be signed in to change notification settings - Fork 23
/
demo.cpp
129 lines (102 loc) · 3.74 KB
/
demo.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#include <dirent.h>
#include <sys/stat.h>

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <opencv2/opencv.hpp>
#include <argparse.hpp>
#include <SiamMask/siammask.h>
/// Returns true if `path` names an existing directory.
///
/// Uses stat() (not lstat()), so a symbolic link that points to a directory
/// counts as a directory. Any stat() failure (missing path, permission denied
/// on a parent component, empty string, ...) yields false.
bool dirExists(const std::string& path)
{
    struct stat info{};
    if (stat(path.c_str(), &info) != 0)
        return false;
    // S_ISDIR compares the whole file-type field. Testing only the S_IFDIR bit
    // (`st_mode & S_IFDIR`) also matches block devices and sockets, whose
    // type codes happen to contain that bit.
    return S_ISDIR(info.st_mode);
}
/// Lists the entries of directory `path` whose names end with any suffix in
/// `match_ending`.
///
/// @param path          Directory to scan (non-recursive).
/// @param match_ending  Case-sensitive name suffixes to accept, e.g. {"jpg", "png"}.
///                      Matching is purely textual, so "jpg" also accepts "foojpg".
/// @return Matching entries as `path + "/" + name`, in readdir() order (unsorted).
///         Returns an empty list if the directory exists but cannot be opened.
/// @throws std::runtime_error if `path` is not an existing directory.
std::vector<std::string> listDir(const std::string& path, const std::vector<std::string>& match_ending)
{
    const auto ends_with = [](const std::string& value, const std::string& ending) {
        return ending.size() <= value.size() &&
               std::equal(ending.rbegin(), ending.rend(), value.rbegin());
    };

    if (!dirExists(path)) {
        throw std::runtime_error(std::string("Directory not found: ") + path);
    }

    std::vector<std::string> files;

    // Own the DIR* so closedir() runs on every exit path, including when
    // push_back throws (e.g. std::bad_alloc). The deleter is not invoked for null.
    const std::unique_ptr<DIR, int (*)(DIR*)> dir(opendir(path.c_str()), &closedir);
    if (!dir)
        return files;

    while (const dirent* entry = readdir(dir.get())) {
        const std::string name(entry->d_name);
        const bool wanted = std::any_of(
            match_ending.begin(), match_ending.end(),
            [&](const std::string& ending) { return ends_with(name, ending); });
        if (wanted)
            files.push_back(path + "/" + name);
    }
    return files;
}
/// Highlights `mask` on top of `src`, writing the result to `dst`.
///
/// Channel 2 of the image (red in OpenCV's default BGR order) is replaced by
/// the per-pixel maximum of itself and `mask`; all other channels are copied
/// unchanged. `src` and `dst` may refer to the same Mat, because the channels
/// are split into separate buffers before being merged back.
/// NOTE(review): cv::max requires `mask` to match one channel of `src` in size
/// and depth — presumably an 8-bit single-channel mask; confirm with the tracker.
///
/// @throws std::invalid_argument if `src` has fewer than 3 channels.
void overlayMask(const cv::Mat& src, const cv::Mat& mask, cv::Mat& dst) {
    std::vector<cv::Mat> chans;
    cv::split(src, chans);
    // Guard the chans[2] access: a grayscale (or empty) image would otherwise
    // index past the end of the vector.
    if (chans.size() < 3) {
        throw std::invalid_argument("overlayMask: src must have at least 3 channels");
    }
    cv::max(chans[2], mask, chans[2]);
    cv::merge(chans, dst);
}
/// Draws the outline of a rotated rectangle onto `img`.
///
/// The four corners returned by RotatedRect::points() are joined in order,
/// and the last corner is connected back to the first to close the shape.
/// `color`, `thickness`, `lineType` and `shift` are forwarded to cv::line.
void drawBox(
    cv::Mat& img, const cv::RotatedRect& box, const cv::Scalar& color,
    int thickness = 1, int lineType = cv::LINE_8, int shift = 0
) {
    cv::Point2f vertex[4];
    box.points(vertex);
    for (int from = 0; from < 4; ++from) {
        // Wrap the final edge back to vertex 0 to close the outline.
        const int to = (from == 3) ? 0 : from + 1;
        cv::line(img, vertex[from], vertex[to], color, thickness, lineType, shift);
    }
}
int main(int argc, const char* argv[]) try {
argparse::ArgumentParser parser;
parser.addArgument("-m", "--modeldir", 1, false);
parser.addArgument("-c", "--config", 1, false);
parser.addFinalArgument("target");
parser.parse(argc, argv);
torch::Device device(torch::kCUDA);
SiamMask siammask(parser.retrieve<std::string>("modeldir"), device);
State state;
state.load_config(parser.retrieve<std::string>("config"));
const std::string target_dir = parser.retrieve<std::string>("target");
std::vector<std::string> image_files = listDir(target_dir, {"jpg", "png", "bmp"});
std::sort(image_files.begin(), image_files.end());
std::cout << image_files.size() << " images found in " << target_dir << std::endl;
std::vector<cv::Mat> images;
for(const auto& image_file : image_files) {
images.push_back(cv::imread(image_file));
}
cv::namedWindow("SiamMask");
int64 toc = 0;
cv::Rect roi = cv::selectROI("SiamMask", images.front(), false);
if(roi.empty())
return EXIT_SUCCESS;
for(unsigned long i = 0; i < images.size(); ++i) {
int64 tic = cv::getTickCount();
cv::Mat& src = images[i];
if (i == 0) {
std::cout << "Initializing..." << std::endl;
siameseInit(state, siammask, src, roi, device);
cv::rectangle(src, roi, cv::Scalar(0, 255, 0));
} else {
siameseTrack(state, siammask, src, device);
overlayMask(src, state.mask, src);
drawBox(src, state.rotated_rect, cv::Scalar(0, 255, 0));
}
cv::imshow("SiamMask", src);
toc += cv::getTickCount() - tic;
cv::waitKey(1);
}
double total_time = toc / cv::getTickFrequency();
double fps = image_files.size() / total_time;
printf("SiamMask Time: %.1fs Speed: %.1ffps (with visulization!)\n", total_time, fps);
return EXIT_SUCCESS;
} catch (std::exception& e) {
std::cout << "Exception thrown!\n" << e.what() << std::endl;
return EXIT_FAILURE;
}