-
Notifications
You must be signed in to change notification settings - Fork 0
/
tf_read_write_schema.py
124 lines (99 loc) · 4.14 KB
/
tf_read_write_schema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 5 19:20:05 2021
@author: edwardcui
"""
import tensorflow as tf
def _bytes_feature(value):
    """Return a tf.train.Feature holding ``value`` in a one-element bytes_list.

    Parameters
    ----------
    value : bytes, str or tf.Tensor
        A tensor is unwrapped with ``.numpy()`` first; a Python ``str`` is
        encoded as UTF-8. Any other value is handed to ``BytesList`` as is.

    Returns
    -------
    tf.train.Feature
        Feature whose ``bytes_list`` contains exactly one entry.
    """
    if isinstance(value, tf.Tensor):
        # BytesList won't unpack a string from an EagerTensor.
        value = value.numpy()
    if isinstance(value, str):
        # BytesList only accepts bytes, so text has to be encoded first.
        value = value.encode("utf-8")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
    """Wrap a single float / double in a tf.train.Feature (float_list)."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in a tf.train.Feature (int64_list)."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def tensors2tfrecord(filename, **kwargs):
    """
    Write a {feature_name: feature_val} dict to a TFRecord file.

    Every feature is sliced along its first axis with
    tf.data.Dataset.from_tensor_slices; the i-th slices of all features
    together form the i-th tf.train.Example written to the file. The feature
    helpers wrap each slice in a one-element list, so slices are expected to
    be scalars.

    Parameters
    ----------
    filename : str
        Filename of the tfrecord to save to
    kwargs: dict
        Input the feature data as {feature_name: feature_val},
        or feature_name=feature_val.
        Integer dtypes are stored as int64 features, floating point dtypes
        as float features, and everything else as bytes features.

    Raises
    ------
    ValueError
        If no feature is given.

    The saved serialized tfrecord can be read with
    tf.data.experimental.make_batched_features_dataset
    """
    if not kwargs:
        raise ValueError(
            "tensors2tfrecord needs at least one feature_name=feature_val pair")

    # Dicts keep insertion order, so names and values stay aligned.
    feature_names = tuple(kwargs)
    features = tuple(kwargs.values())

    def serialize_example(*args):
        """
        Creates a tf.train.Example message ready to be written to a file.
        """
        # Create a dictionary mapping the feature name to the
        # tf.train.Example-compatible data type.
        feature = {}
        for key, val in zip(feature_names, args):
            if val.dtype.is_integer:
                # Covers every signed / unsigned integer width, not only
                # int32 and int64.
                feature[key] = _int64_feature(val)
            elif val.dtype.is_floating:
                feature[key] = _float_feature(val)
            else:
                feature[key] = _bytes_feature(val)

        # Create a Features message using tf.train.Example
        example_proto = tf.train.Example(
            features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()

    def tf_serialize_example(*args):
        # Dataset.map traces its function in graph mode; py_function lets the
        # Python-side protobuf serialization run on concrete values.
        tf_string = tf.py_function(
            serialize_example,
            args,  # pass these args to the above function.
            tf.string)  # the return type is `tf.string`.
        return tf.reshape(tf_string, ())  # The result is a scalar

    feature_dataset = tf.data.Dataset.from_tensor_slices(features)
    serialized_features_dataset = feature_dataset.map(tf_serialize_example)

    # write to tf record
    # NOTE(review): tf.data.experimental.TFRecordWriter is marked deprecated in
    # recent TensorFlow releases in favour of tf.io.TFRecordWriter; it is kept
    # here so the writing behaviour stays exactly as before -- confirm against
    # the TensorFlow version in use before migrating.
    writer = tf.data.experimental.TFRecordWriter(filename)
    writer.write(serialized_features_dataset)
def tfrecord2dataset(file_pattern, feature_spec, label_key, batch_size=5,
                     num_epochs=2):
    """Load TFRecord files through make_batched_features_dataset.

    All arguments are forwarded unchanged; ``feature_spec`` is passed as the
    ``features`` argument of the TensorFlow reader.

    Returns:
        A dataset that contains (features, indices) tuple where features is a
        dictionary of Tensors, and indices is a single Tensor of label indices.
    """
    reader_options = dict(
        file_pattern=file_pattern,
        batch_size=batch_size,
        num_epochs=num_epochs,
        features=feature_spec,
        label_key=label_key,
    )
    return tf.data.experimental.make_batched_features_dataset(**reader_options)
def test_tensors2tfrecord():
    """Write a small three-column integer table to ``temp.tfrecord``.

    Column 0 is stored as ``name``, column 1 as ``content`` and column 2,
    cast to float32, as ``weight``.
    """
    rows = [[1, 2, 3], [4, 5, 6], [1, 2, 3], [6, 7, 8], [2, 3, 5], [3, 5, 7]]
    table = tf.constant(rows)
    name_col = table[:, 0]
    content_col = table[:, 1]
    weight_col = tf.cast(table[:, 2], "float32")
    tensors2tfrecord(
        "temp.tfrecord",
        name=name_col,
        content=content_col,
        weight=weight_col,
    )
def test_tfrecord2dataset():
    """Read ``temp.tfrecord`` back and print every batch with its index.

    ``weight`` is used as the label; ``content`` and ``name`` are regrouped
    into an input tuple by the mapping function.
    """
    spec = {
        "name": tf.io.FixedLenFeature([], dtype=tf.int64),
        "content": tf.io.FixedLenFeature([], dtype=tf.int64),
        "weight": tf.io.FixedLenFeature([], dtype=tf.float32),
    }

    def split_inputs(features, label):
        # Turn the feature dict into a (content, name) tuple; the label
        # passes through untouched.
        inputs = (features["content"], features["name"])
        return inputs, label

    batches = tfrecord2dataset(
        ["temp.tfrecord"], spec, label_key="weight", batch_size=5)
    batches = batches.map(split_inputs)

    for index, batch in enumerate(batches):
        print(index)
        print(batch)