-
Notifications
You must be signed in to change notification settings - Fork 40
/
transformer.py
61 lines (49 loc) · 2.15 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
from objectmapper import ObjectMapper
from reader import Reader
class Transformer(object):
    """Convert XML box annotations into Darknet/YOLO ``.txt`` label files.

    Each bound annotation produces one ``.txt`` file under ``out_dir`` holding
    one line per object::

        <class_id> <x_center> <y_center> <width> <height>

    The four box values are divided by the image width/height.
    NOTE(review): the box fields (xmin/ymin/xmax/ymax) suggest Pascal VOC
    style XML; confirm against ``ObjectMapper``.
    """

    def __init__(self, xml_dir, out_dir, class_file):
        """Store the conversion settings.

        :param xml_dir: directory handed to ``Reader`` / ``ObjectMapper`` as the
            source of the XML annotation files.
        :param out_dir: directory under which the ``.txt`` label files are written.
        :param class_file: path passed to ``Reader.get_classes`` to load the
            class-name -> class-id mapping.
        """
        self.xml_dir = xml_dir
        self.out_dir = out_dir
        self.class_file = class_file

    def transform(self):
        """Load classes and XML annotations, then write all Darknet label files."""
        reader = Reader(xml_dir=self.xml_dir)
        xml_files = reader.get_xml_files()
        classes = reader.get_classes(self.class_file)
        object_mapper = ObjectMapper()
        annotations = object_mapper.bind_files(xml_files, xml_dir=self.xml_dir)
        self.write_to_txt(annotations, classes)

    def write_to_txt(self, annotations, classes):
        """Write one Darknet label file per annotation.

        :param annotations: iterable of annotation objects, each exposing
            ``filename``, ``objects`` and ``size``.
        :param classes: mapping of class name -> class id.
        :raises SystemExit: if an object's class is missing from ``classes``
            (propagated from ``to_darknet_format``).
        """
        for annotation in annotations:
            # Build the label text before touching the filesystem so an
            # unknown class aborts without leaving an empty/truncated file.
            content = self.to_darknet_format(annotation, classes)
            output_path = os.path.join(
                self.out_dir, self.darknet_filename_format(annotation.filename))
            output_dir = os.path.dirname(output_path)
            # dirname is "" when the path has no directory part; makedirs("")
            # would raise, so only create a real directory that is missing.
            if output_dir and not os.path.exists(output_dir):
                os.makedirs(output_dir)
            with open(output_path, "w") as f:
                f.write(content)

    def to_darknet_format(self, annotation, classes):
        """Return the Darknet label text for a single annotation.

        :param annotation: object exposing ``objects`` (each with ``name`` and
            ``box``) and ``size`` (with ``width`` / ``height``).
        :param classes: mapping of class name -> class id (presumably ints,
            since the id is formatted with ``%d``).
        :returns: one ``"<id> <x> <y> <w> <h>"`` line per object (6 decimals),
            joined by ``"\\n"`` with no trailing newline; ``""`` if the
            annotation has no objects.
        :raises SystemExit: with a non-zero status if an object's name is not a
            key of ``classes``.
        """
        result = []
        for obj in annotation.objects:
            if obj.name not in classes:
                # SystemExit with a message prints it to stderr and exits with
                # status 1; the interactive-only ``exit()`` reported success.
                raise SystemExit("Please, add '%s' to classes.txt file." % obj.name)
            x, y, width, height = self.get_object_params(obj, annotation.size)
            result.append("%d %.6f %.6f %.6f %.6f" % (classes[obj.name], x, y, width, height))
        return "\n".join(result)

    @staticmethod
    def get_object_params(obj, size):
        """Return ``obj.box`` as ``(x_center, y_center, width, height)``.

        The absolute box corners (xmin/ymin/xmax/ymax) are converted to a
        centre point plus extent, and each value is divided by the image
        width or height from ``size``.
        """
        # float() keeps the division true division under Python 2 as well.
        image_width = float(size.width)
        image_height = float(size.height)
        box = obj.box
        absolute_x = box.xmin + 0.5 * (box.xmax - box.xmin)
        absolute_y = box.ymin + 0.5 * (box.ymax - box.ymin)
        absolute_width = box.xmax - box.xmin
        absolute_height = box.ymax - box.ymin
        x = absolute_x / image_width
        y = absolute_y / image_height
        width = absolute_width / image_width
        height = absolute_height / image_height
        return x, y, width, height

    @staticmethod
    def darknet_filename_format(filename):
        """Return ``filename`` with its last extension replaced by ``.txt``.

        Any directory part of ``filename`` is kept, e.g.
        ``"a/b.c.jpg"`` -> ``"a/b.c.txt"``.
        """
        stem, _ = os.path.splitext(filename)
        return "%s.txt" % stem