Skip to content

Commit

Permalink
Add compare_iters helper
Browse files Browse the repository at this point in the history
Signed-off-by: Szymon Karpiński <skarpinski@nvidia.com>
  • Loading branch information
szkarpinski committed Jan 16, 2024
1 parent 75122b9 commit d438b8f
Showing 1 changed file with 9 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def equal(self, a, b):

# Helpers

def compare_iters(self, iter, iter2):
for out1, out2 in zip(iter, iter2):
for d1, d2 in zip(out1, out2):
for key in d1.keys():
assert self.equal(d1[key], d2[key])

def check_pipeline_checkpointing(self, pipeline_factory, reader_name=None, size=-1):
pipe = pipeline_factory(**pipeline_args)
pipe.build()
Expand All @@ -48,10 +54,7 @@ def check_pipeline_checkpointing(self, pipeline_factory, reader_name=None, size=
restored, ["data"], auto_reset=True, reader_name=reader_name, size=size
)

for out1, out2 in zip(iter, iter2):
for d1, d2 in zip(out1, out2):
for key in d1.keys():
assert self.equal(d1[key], d2[key])
self.compare_iters(iter, iter2)

def check_single_input_operator(self, op, device, **kwargs):
pipeline_factory = check_single_input_operator_pipeline(op, device, **kwargs)
Expand Down Expand Up @@ -120,10 +123,7 @@ def pipeline():
restored.build()
iter2 = self.FwIterator(restored, ["data", "labels"], auto_reset=True, reader_name="Reader")

for out1, out2 in zip(iter, iter2):
for d1, d2 in zip(out1, out2):
for key in d1.keys():
assert self.equal(d1[key], d2[key])
self.compare_iters(iter, iter2)

# Random operators section

Expand Down Expand Up @@ -169,10 +169,7 @@ def run(iterator, iterations):
restored.build()
iter2 = self.FwIterator(restored, ["data"], auto_reset=True, size=size)

for out1, out2 in zip(iter, iter2):
for d1, d2 in zip(out1, out2):
for key in d1.keys():
assert self.equal(d1[key], d2[key])
self.compare_iters(iter, iter2)

@cartesian_params(
((1, 1), (4, 5)), # (epoch size, batch size)
Expand Down

0 comments on commit d438b8f

Please sign in to comment.