forked from onnx/onnx-tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathonnxProtoUtils.cpp
80 lines (73 loc) · 2.13 KB
/
onnxProtoUtils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnxProtoUtils.hpp"
#include <fstream>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <iostream>
#include <onnx/onnx_pb.h>
#include <sstream>
namespace onnx2trt
{
/// Shortens large `raw_data` string payloads in protobuf text-format output.
///
/// Scans @p s for every `raw_data: "<payload>"` field. When the (escaped)
/// payload is longer than 128 characters it is replaced by "..."; shorter
/// payloads are left untouched. If a payload has no closing quote the
/// remainder of the string is left as-is.
///
/// @param s Text-format dump (e.g. produced by TextFormat::PrintToString);
///          modified in place.
void removeRawDataStrings(std::string& s)
{
    const std::string key = "raw_data: \"";
    const std::string sub = "...";
    // Payloads at or below this many escaped characters are kept verbatim.
    const std::string::size_type maxKeptLength = 128;

    std::string::size_type beg = 0;
    while ((beg = s.find(key, beg)) != std::string::npos)
    {
        beg += key.length(); // first character of the payload

        // Locate the closing quote. A quote is escaped only when it is preceded
        // by an odd number of backslashes: `\"` belongs to the payload, whereas
        // `\\"` is an escaped backslash followed by the real terminator.
        std::string::size_type end = beg;
        while ((end = s.find('"', end)) != std::string::npos)
        {
            std::string::size_type backslashes = 0;
            while (end - backslashes > beg && s[end - backslashes - 1] == '\\')
            {
                ++backslashes;
            }
            if (backslashes % 2 == 0)
            {
                break;
            }
            ++end;
        }

        if (end == std::string::npos)
        {
            // Unterminated payload (malformed input): stop instead of indexing
            // past the end of the string.
            return;
        }

        if (end - beg > maxKeptLength)
        {
            s.replace(beg, end - beg, sub);
            end = beg + sub.length(); // closing quote now follows the placeholder
        }
        beg = end + 1; // resume scanning after the closing quote
    }
}
/// Collapses runs of repeated numeric tensor fields in protobuf text-format output.
///
/// Text format prints each element of a repeated scalar field on its own line
/// (e.g. `  float_data: 1.5`). Each run of consecutive lines that start (after
/// indentation) with the same data field is replaced by a single
/// `<indent><field>: ...` line. All other lines are copied through unchanged.
/// Every output line is terminated with '\n'.
///
/// @param s Text-format dump to filter.
/// @return The filtered text.
std::string removeRepeatedDataStrings(std::string const& s)
{
    // Repeated numeric payload fields (names as in onnx::TensorProto).
    static const std::string kDataFields[]
        = {"float_data:", "int32_data:", "int64_data:", "double_data:", "uint64_data:"};

    std::istringstream iss(s);
    std::ostringstream oss;
    // Header (indent + field name + ':') of the run currently being elided;
    // empty when the previous line was not a data line.
    std::string runHeader;
    for (std::string line; std::getline(iss, line);)
    {
        // Only treat the line as data when the field name starts the line, so a
        // string value that merely contains e.g. "float_data:" is preserved.
        std::string::size_type const first = line.find_first_not_of(" \t");
        bool isDataLine = false;
        if (first != std::string::npos)
        {
            for (auto const& field : kDataFields)
            {
                if (line.compare(first, field.size(), field) == 0)
                {
                    isDataLine = true;
                    break;
                }
            }
        }

        if (!isDataLine)
        {
            runHeader.clear();
            oss << line << '\n';
            continue;
        }

        std::string header = line.substr(0, line.find(':') + 1);
        // Emit one placeholder per run; a different field starts a new run.
        if (header != runHeader)
        {
            oss << header << " ...\n";
            runHeader.swap(header);
        }
    }
    return oss.str();
}
/// Renders @p message in protobuf text format, with bulky tensor payloads
/// shortened: long `raw_data` strings go through removeRawDataStrings and runs
/// of repeated `*_data` lines go through removeRepeatedDataStrings.
///
/// @param message Any protobuf message (e.g. an ONNX model or node).
/// @return The abbreviated text-format dump.
std::string convertProtoToString(::google::protobuf::Message const& message)
{
    std::string text;
    // NOTE(review): the bool result of PrintToString is not checked here.
    ::google::protobuf::TextFormat::PrintToString(message, &text);
    removeRawDataStrings(text);
    return removeRepeatedDataStrings(text);
}
/// Formats an IR version as "major.minor.patch".
///
/// The value is decoded as major * 1000000 + minor * 10000 + patch, so for
/// example 1020003 becomes "1.2.3" and a plain small integer such as 8 becomes
/// "0.0.8".
///
/// @param irVersion Packed version number.
/// @return Dotted version string.
std::string onnxIRVersionAsString(int64_t irVersion)
{
    int64_t const majorPart = irVersion / 1000000;
    int64_t const minorPart = (irVersion % 1000000) / 10000;
    int64_t const patchPart = irVersion % 10000;

    std::string text = std::to_string(majorPart);
    text += '.';
    text += std::to_string(minorPart);
    text += '.';
    text += std::to_string(patchPart);
    return text;
}
} // namespace onnx2trt