-
Notifications
You must be signed in to change notification settings - Fork 93
/
inputs.py
65 lines (56 loc) · 2.16 KB
/
inputs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""Input feature columns and input_fn for models.

Handles training, evaluation, and inference.
"""
import tensorflow as tf
def BuildTextExample(text, ngrams=None, label=None):
    """Pack tokenized text, and optionally a label and n-grams, into a tf.train.Example.

    Every value is converted with ``tf.compat.as_bytes`` and stored in a
    ``bytes_list`` feature.

    Args:
      text: iterable of tokens, stored under the "text" feature. Each element
        is converted individually. NOTE(review): a bare string would be
        iterated character by character; callers presumably pass a token list.
      ngrams: optional iterable of n-gram strings, stored under "ngrams".
        Skipped when None.
      label: optional single label value, stored under "label". Skipped when
        None.

    Returns:
      The populated tf.train.Example proto.
    """
    example = tf.train.Example()
    feature_map = example.features.feature

    text_bytes = [tf.compat.as_bytes(token) for token in text]
    feature_map["text"].bytes_list.value.extend(text_bytes)

    if label is not None:
        feature_map["label"].bytes_list.value.append(tf.compat.as_bytes(label))

    if ngrams is not None:
        ngram_bytes = [tf.compat.as_bytes(gram) for gram in ngrams]
        feature_map["ngrams"].bytes_list.value.extend(ngram_bytes)

    return example
def ParseSpec(use_ngrams, include_target):
    """Describe how to parse the features written by BuildTextExample.

    Args:
      use_ngrams: when truthy, also parse the variable-length "ngrams" feature.
      include_target: when truthy, also parse the scalar string "label" feature.

    Returns:
      A dict mapping feature names to TF parsing configs. "text" is always
      present, "ngrams" follows if requested, and "label" comes last.
    """
    string_list_spec = lambda: tf.VarLenFeature(dtype=tf.string)

    spec = {"text": string_list_spec()}
    if use_ngrams:
        spec["ngrams"] = string_list_spec()
    if include_target:
        # NOTE(review): with default_value=None, the TF docs indicate that an
        # example missing "label" fails to parse (i.e. the label is required).
        spec["label"] = tf.FixedLenFeature(
            shape=(), dtype=tf.string, default_value=None)
    return spec
def InputFn(mode,
            use_ngrams,
            input_file,
            vocab_file,
            vocab_size,
            embedding_dimension,
            num_oov_vocab_buckets,
            label_file,
            label_size,
            ngram_embedding_dimension,
            num_ngram_hash_buckets,
            batch_size,
            num_epochs=None,
            num_threads=1):
    """Build an Estimator input_fn that reads batched examples from TFRecord files.

    Args:
      mode: a ``tf.estimator.ModeKeys`` value. Labels are parsed and returned
        for every mode except PREDICT.
      use_ngrams: whether to also parse the "ngrams" feature.
      input_file: file pattern handed to ``read_batch_features``.
      vocab_file, vocab_size, embedding_dimension, num_oov_vocab_buckets,
      label_file, label_size, ngram_embedding_dimension,
      num_ngram_hash_buckets: not used by this function.
      batch_size: examples per batch, passed to ``read_batch_features``.
      num_epochs: passes over the data. None, 0, or any negative value is
        normalized to None before being passed to ``read_batch_features``.
        NOTE(review): its docs describe None as cycling forever — confirm for
        the TF version in use.
      num_threads: forwarded as ``reader_num_threads``.

    Returns:
      A zero-argument function returning ``(features, label)``. ``label`` is
      None in PREDICT mode; otherwise it is popped out of ``features``.
    """
    # The default is None, so guard it before comparing: `None <= 0` raises
    # TypeError on Python 3. Treat None and non-positive values alike.
    if num_epochs is not None and num_epochs <= 0:
        num_epochs = None

    def input_fn():
        include_target = mode != tf.estimator.ModeKeys.PREDICT
        parse_spec = ParseSpec(use_ngrams, include_target)
        tf.logging.info("ParseSpec: %s", parse_spec)
        tf.logging.info("Input file: %s", input_file)
        features = tf.contrib.learn.read_batch_features(
            input_file, batch_size, parse_spec, tf.TFRecordReader,
            num_epochs=num_epochs, reader_num_threads=num_threads)
        label = None
        if include_target:
            # Separate the target from the model inputs, as Estimators expect.
            label = features.pop("label")
        return features, label

    return input_fn
def ServingInputFn(use_ngrams):
    """Create a serving input receiver fn that parses serialized tf.train.Examples.

    The parse spec comes from ``ParseSpec`` with the "label" feature excluded,
    so serving requests only need to carry the model inputs.

    Args:
      use_ngrams: whether the receiver should also parse the "ngrams" feature.

    Returns:
      The result of ``tf.estimator.export.build_parsing_serving_input_receiver_fn``
      applied to that spec.
    """
    serving_spec = ParseSpec(use_ngrams, include_target=False)
    make_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn
    return make_receiver_fn(serving_spec)