Skip to content

Commit

Permalink
Reformat
Browse files Browse the repository at this point in the history
Signed-off-by: Szymon Karpiński <skarpinski@nvidia.com>
  • Loading branch information
szkarpinski committed Jan 16, 2024
1 parent 50447fe commit 75122b9
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from nvidia.dali.pipeline import pipeline_def
from nose2.tools import params, cartesian_params


class FwTestBase:
FwIterator = None

Expand All @@ -43,27 +44,26 @@ def check_pipeline_checkpointing(self, pipeline_factory, reader_name=None, size=

restored = pipeline_factory(**pipeline_args, checkpoint=iter.checkpoints()[0])
restored.build()
iter2 = self.FwIterator(restored, ["data"], auto_reset=True, reader_name=reader_name, size=size)
iter2 = self.FwIterator(
restored, ["data"], auto_reset=True, reader_name=reader_name, size=size
)

for out1, out2 in zip(iter, iter2):
for d1, d2 in zip(out1, out2):
for key in d1.keys():
assert self.equal(d1[key], d2[key])


def check_single_input_operator(self, op, device, **kwargs):
pipeline_factory = check_single_input_operator_pipeline(op, device, **kwargs)
self.check_pipeline_checkpointing(pipeline_factory, reader_name="Reader")


def check_no_input_operator(self, op, device, **kwargs):
@pipeline_def
def pipeline_factory():
return op(device=device, **kwargs)

self.check_pipeline_checkpointing(pipeline_factory, size=8)


# Reader tests section

@params(
Expand Down

0 comments on commit 75122b9

Please sign in to comment.