Skip to content

Commit

Permalink
Add tests for pytorch ragged iterator
Browse files Browse the repository at this point in the history
Signed-off-by: Szymon Karpiński <skarpinski@nvidia.com>
  • Loading branch information
szkarpinski committed Jan 16, 2024
1 parent 60c54fb commit e329611
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,12 @@ def get_fw_iterator_class(self):

def equal(self, a, b):
return (a == b).all()

class TestPytorchRagged(FwTestBase):
def get_fw_iterator_class(self):
from nvidia.dali.plugin.pytorch import DALIRaggedIterator

return DALIRaggedIterator

def equal(self, a, b):
return (a == b).all()
1 change: 1 addition & 0 deletions qa/TL0_python-self-test-core/test_body.sh
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ test_type_annotations() {
test_pytorch() {
${python_invoke_test} --attr '!slow,pytorch' test_dali_variable_batch_size.py
${python_new_invoke_test} -A '!slow' checkpointing.test_dali_checkpointing_fw_iterators.TestPytorch
${python_new_invoke_test} -A '!slow' checkpointing.test_dali_checkpointing_fw_iterators.TestPytorchRagged
${python_new_invoke_test} -A 'pytorch' -s type_annotations
}

Expand Down

0 comments on commit e329611

Please sign in to comment.