Skip to content

Commit

Permalink
lint fixed (#240)
Browse files Browse the repository at this point in the history
* lint fixed

* lint fix

* removed binary file

* files for example data are restored

* removed some file

* renamed filter_component

* some fixes

* some fixes

* some lint fixes

* some lint fixes

* fixed some comments

* fixed some formating
  • Loading branch information
vulkomilev authored Jul 4, 2023
1 parent 4f728ab commit 8a276a7
Show file tree
Hide file tree
Showing 7 changed files with 788 additions and 0 deletions.
51 changes: 51 additions & 0 deletions examples/example_filter/data/test_data.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
label,col1
,2
,2
,2
,2
,2
,2
,2
,2
,2
,2
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
481 changes: 481 additions & 0 deletions examples/example_filter/filter_example_colab.ipynb

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions examples/example_filter/filter_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Filters the data from input data by using the filter function."""


def filter_function(x_list):
"""Filters the data from input data by using the filter function.
Args:
x_list: Input list of data to be filtered.
Returns:
filtered list
"""
new_list = []
for element in x_list:
if element['label'] == [0]:
new_list.append(element)
return new_list
96 changes: 96 additions & 0 deletions tfx_addons/example_filter/component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
the component for filter addon
"""

import importlib
import os

import tensorflow as tf
from tfx.dsl.component.experimental.annotations import OutputDict
from tfx.dsl.io.fileio import listdir
from tfx.types import standard_artifacts
from tfx.v1.dsl.components import InputArtifact, Parameter
from tfx_bsl.coders import example_coder


def _get_data_from_tfrecords(train_uri: str):
'''
Reads and returns data from TFRecords at URI as a list
of dictionaries with values as numpy arrays
Example:
_get_data_from_tfrecords('path_to_TFRecords')
'''
train_uri = [
os.path.join(train_uri, file_path) for file_path in listdir(train_uri)
]
raw_dataset = tf.data.TFRecordDataset(train_uri, compression_type='GZIP')

np_dataset = []
for tfrecord in raw_dataset:
serialized_example = tfrecord.numpy()
example = example_coder.ExampleToNumpyDict(serialized_example)
np_dataset.append(example)

return np_dataset


def filter_component(input_data: InputArtifact[standard_artifacts.Examples],
filter_function_str: Parameter[str],
output_file: Parameter[str]) -> OutputDict(list_len=int):
"""Filters the data from input data by using the filter function.
Args:
input_data: Input list of data to be filtered.
output_file: the name of the file to be saved to.
filter_function_str: Module name of the function that will be used to
filter the data.
Example for the function
my_example/my_filter.py:
# filter module must have filter_function implemented
def filter_function(input_list: Array):
output_list = []
for element in input_list:
if element.something:
output_list.append(element)
return output_list
pipeline.py:
filter_component(input_data ,'my_example.my_filter',output_data)
Returns:
len of the list after the filter
{
'list_len': len(output_list)
}
"""
records = _get_data_from_tfrecords(input_data.uri + "/Split-train")
filter_function = importlib.import_module(
filter_function_str).filter_function
filtered_data = filter_function(records)
result_len = len(filtered_data)
new_data = []
for key in list(filtered_data[0].keys()):
local_list = []
for i in range(result_len):
local_list.append(str(filtered_data[i][key][0]))
new_data.append(str(local_list))
writer = tf.io.TFRecordWriter(output_file)
writer.write(tf.data.Dataset.from_tensor_slices(new_data).map(lambda x: x))

return {'list_len': result_len}
43 changes: 43 additions & 0 deletions tfx_addons/example_filter/component_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Component test for the filter component."""

import os

import tensorflow as tf
from absl.testing import absltest
from tfx.types import artifact_utils, standard_artifacts

from tfx_addons.example_filter.component import filter_component


class ComponentTest(absltest.TestCase):
def testConstructWithOptions(self):
source_data_dir = os.path.join(os.path.dirname(__file__), 'data')

examples = standard_artifacts.Examples()
examples.uri = os.path.join(source_data_dir, "example_gen")
examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])

params = {
"input_data": examples,
"filter_function_str": 'filter_function',
"output_file": 'output',
}
filter_component(**params)


if __name__ == '__main__':
tf.test.main()
51 changes: 51 additions & 0 deletions tfx_addons/example_filter/data/test_data.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
label,col1
,2
,2
,2
,2
,2
,2
,2
,2
,2
,2
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
1,1
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
0,0
33 changes: 33 additions & 0 deletions tfx_addons/example_filter/filter_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example function to demonstrate the filter functionality of the module."""


def filter_function(x_list):
"""Filters the data from input data by using the filter function.
Args:
x_list: Input list of data to be filtered.
Returns:
filtered list
"""
new_list = []
for element in x_list:
if element['label'] == [0]:
new_list.append(element)
return new_list

0 comments on commit 8a276a7

Please sign in to comment.