[feature]: incremental saving checkpoint #174

Open · wants to merge 21 commits into base: master
2 changes: 2 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,2 @@
include easy_rec/python/ops/1.12/*.so*
include easy_rec/python/ops/1.15/*.so*
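Note (not part of this diff): MANIFEST.in only affects source distributions; for the prebuilt `.so` files to also ship in wheels and installed packages, `setup.py` typically needs matching `package_data` entries. An illustrative fragment, assuming a standard setuptools layout; the project's real setup.py may already declare this:

```python
# illustrative setup.py fragment, not taken from this PR
from setuptools import find_packages, setup

setup(
    name='easy_rec',
    packages=find_packages(),
    include_package_data=True,
    package_data={
        'easy_rec': [
            'python/ops/1.12/*.so*',
            'python/ops/1.15/*.so*',
        ],
    },
)
```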
16 changes: 10 additions & 6 deletions easy_rec/__init__.py
@@ -15,6 +15,16 @@
logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')

ops_dir = os.path.join(curr_dir, 'python/ops')
if 'PAI' in tf.__version__:
ops_dir = os.path.join(ops_dir, '1.12_pai')
elif tf.__version__.startswith('1.12'):
ops_dir = os.path.join(ops_dir, '1.12')
elif tf.__version__.startswith('1.15'):
ops_dir = os.path.join(ops_dir, '1.15')
else:
ops_dir = None

from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402
from easy_rec.python.main import evaluate # isort:skip # noqa: E402
from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402
@@ -32,12 +42,6 @@

_global_config = {}

ops_dir = os.path.join(curr_dir, 'python/ops')
if tf.__version__.startswith('1.12'):
ops_dir = os.path.join(ops_dir, '1.12')
elif tf.__version__.startswith('1.15'):
ops_dir = os.path.join(ops_dir, '1.15')


def help():
print("""
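The hunk above only resolves `ops_dir` by TensorFlow version (including the new `1.12_pai` branch and the `None` fallback); it does not show how the library is actually loaded. A minimal sketch of how such a version-specific custom-op library could be loaded, assuming the standard `tf.load_op_library` API; the `incr_record.so` filename is a hypothetical placeholder:

```python
import logging
import os

import tensorflow as tf


def load_incr_record_ops(ops_dir):
  """Load the incremental-record custom ops when a matching build exists.

  ops_dir is the version-specific directory resolved above; it is None when
  no prebuilt library matches the installed TensorFlow version.
  """
  if ops_dir is None:
    logging.warning('no prebuilt custom ops for tf %s, incremental save '
                    'would be unavailable', tf.__version__)
    return None
  so_path = os.path.join(ops_dir, 'incr_record.so')  # hypothetical filename
  if not os.path.exists(so_path):
    logging.warning('custom op library not found: %s', so_path)
    return None
  return tf.load_op_library(so_path)
```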
21 changes: 19 additions & 2 deletions easy_rec/python/compat/optimizers.py
@@ -35,6 +35,9 @@
from tensorflow.python.training import moving_averages
from tensorflow.python.training import optimizer as optimizer_
from tensorflow.python.training import training as train
from easy_rec.python.ops.incr_record import set_sparse_indices
import tensorflow as tf
import logging

OPTIMIZER_CLS_NAMES = {
'Adagrad':
@@ -75,7 +78,8 @@ def optimize_loss(loss,
summaries=None,
colocate_gradients_with_ops=False,
not_apply_grad_after_first_step=False,
increment_global_step=True):
increment_global_step=True,
incr_save=False):
"""Given loss and parameters for optimizer, returns a training op.

Various ways of passing optimizers include:
@@ -146,6 +150,7 @@ def optimize_loss(loss,
calls `optimize_loss` multiple times per training step (e.g. to optimize
different parts of the model), use this arg to avoid incrementing
`global_step` more times than necessary.
    incr_save: if True, register sparse and dense update variables (and the
      updated sparse indices) so that checkpoints can be dumped incrementally.

Returns:
Training op.
@@ -300,11 +305,23 @@ def optimize_loss(loss,

# Create gradient updates.
def _apply_grad():
incr_save_ops = []
if incr_save:
for grad, var in gradients:
if isinstance(grad, ops.IndexedSlices):
with ops.colocate_with(var):
incr_save_op = set_sparse_indices(grad.indices, var_name=var.op.name)
incr_save_ops.append(incr_save_op)
ops.add_to_collection('SPARSE_UPDATE_VARIABLES', (var, grad.indices.dtype))
else:
ops.add_to_collection('DENSE_UPDATE_VARIABLES', var)

grad_updates = opt.apply_gradients(
gradients,
global_step=global_step if increment_global_step else None,
name='train')
return control_flow_ops.with_dependencies([grad_updates], loss)

return control_flow_ops.with_dependencies([grad_updates] + incr_save_ops, loss)

if not_apply_grad_after_first_step:
train_tensor = control_flow_ops.cond(global_step > 0, lambda: loss,
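The consumer of the `SPARSE_UPDATE_VARIABLES` and `DENSE_UPDATE_VARIABLES` collections is not part of this hunk. As a rough sketch only, assuming a `SessionRunHook` elsewhere in the PR reads these collections to decide what to dump between full checkpoints, it might look roughly like this (class name and save interval are made up):

```python
import logging

import tensorflow as tf


class IncrSaveSketchHook(tf.train.SessionRunHook):
  """Illustrative only: periodically inspects variables registered for
  incremental saving; the real hook in this PR may work differently."""

  def __init__(self, save_steps=100):
    self._save_steps = save_steps
    self._step = 0

  def begin(self):
    # (variable, indices_dtype) pairs added in optimize_loss when incr_save=True
    self._sparse_vars = tf.get_collection('SPARSE_UPDATE_VARIABLES')
    # dense variables whose full value would be dumped each interval
    self._dense_vars = tf.get_collection('DENSE_UPDATE_VARIABLES')

  def after_run(self, run_context, run_values):
    self._step += 1
    if self._step % self._save_steps != 0:
      return
    dense_vals = run_context.session.run(self._dense_vars)
    logging.info('incremental dump: %d dense vars, %d sparse vars registered',
                 len(dense_vals), len(self._sparse_vars))
    # a real implementation would also fetch the row indices recorded by
    # set_sparse_indices and dump only the touched rows of each sparse var
```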
171 changes: 117 additions & 54 deletions easy_rec/python/input/datahub_input.py
@@ -4,7 +4,9 @@
import time

import numpy as np
import json
import tensorflow as tf
import traceback

from easy_rec.python.input.input import Input
from easy_rec.python.utils import odps_util
@@ -18,14 +20,18 @@
from datahub.exceptions import DatahubException
from datahub.models import RecordType
from datahub.models import CursorType
except Exception:
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
logging.getLogger('datahub.account').setLevel(logging.INFO)
except Exception as ex:
# logging.warning(traceback.format_exc(ex))
logging.warning(
'DataHub is not installed. You can install it by: pip install pydatahub')
DataHub = None


class DataHubInput(Input):
"""Common IO based interface, could run at local or on data science."""
"""DataHubInput is used for online train."""


def __init__(self,
data_config,
@@ -35,27 +41,59 @@ def __init__(self,
task_num=1):
super(DataHubInput, self).__init__(data_config, feature_config, '',
task_index, task_num)
if DataHub is None:
      logging.error('please install datahub: '
                    'pip install pydatahub; Python 3.6 recommended')

try:
self._datahub_config = datahub_config
if self._datahub_config is None:
pass
self._datahub = DataHub(self._datahub_config.akId,
self._datahub_config.akSecret,
self._datahub_config.region)
self._num_epoch = 0
self._datahub_config = datahub_config
if self._datahub_config is not None:
akId = self._datahub_config.akId
akSecret = self._datahub_config.akSecret
region = self._datahub_config.region
if not isinstance(akId, str):
akId = akId.encode('utf-8')
akSecret = akSecret.encode('utf-8')
region = region.encode('utf-8')
self._datahub = DataHub(akId, akSecret, region)
else:
self._datahub = None
except Exception as ex:
logging.info('exception in init datahub:', str(ex))
logging.info('exception in init datahub: %s' % str(ex))
pass
self._offset_dict = {}
if datahub_config:
if self._datahub_config.offset_info:
self._offset_dict = json.loads(self._datahub_config.offset_info)
shard_result = self._datahub.list_shard(self._datahub_config.project,
self._datahub_config.topic)
shards = shard_result.shards
self._shards = [shards[i] for i in range(len(shards)) if (i % task_num) == task_index]
logging.info('all shards: %s' % str(self._shards))
offset_dict = {}
for x in self._shards:
if x.shard_id in self._offset_dict:
offset_dict[x.shard_id] = self._offset_dict[x.shard_id]
self._offset_dict = offset_dict

def _parse_record(self, *fields):
fields = list(fields)
inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
field_dict = {self._input_fields[x]: fields[x] for x in self._effective_fids}
for x in self._label_fids:
inputs[self._input_fields[x]] = fields[x]
return inputs
field_dict[self._input_fields[x]] = fields[x]
field_dict[Input.DATA_OFFSET] = fields[-1]
return field_dict

def _preprocess(self, field_dict):
output_dict = super(DataHubInput, self)._preprocess(field_dict)

# append offset fields
if Input.DATA_OFFSET in field_dict:
output_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]

# for _get_features to include DATA_OFFSET
if Input.DATA_OFFSET not in self._appended_fields:
self._appended_fields.append(Input.DATA_OFFSET)

return output_dict

def _datahub_generator(self):
logging.info('start epoch[%d]' % self._num_epoch)
@@ -65,62 +103,87 @@ def _datahub_generator(self):
self.get_type_defaults(x, v)
for x, v in zip(self._input_field_types, self._input_field_defaults)
]
batch_defaults = [
np.array([x] * self._data_config.batch_size) for x in record_defaults
batch_data = [
np.asarray([x] * self._data_config.batch_size, order='C', dtype=object)
if isinstance(x, str) else
np.array([x] * self._data_config.batch_size)
for x in record_defaults
]
batch_data.append(json.dumps(self._offset_dict))

try:
self._datahub.wait_shards_ready(self._datahub_config.project,
self._datahub_config.topic)
topic_result = self._datahub.get_topic(self._datahub_config.project,
self._datahub_config.topic)
if topic_result.record_type != RecordType.TUPLE:
logging.error('topic type illegal !')
logging.error('datahub topic type(%s) illegal' % str(topic_result.record_type))
record_schema = topic_result.record_schema
shard_result = self._datahub.list_shard(self._datahub_config.project,
self._datahub_config.topic)
shards = shard_result.shards
for shard in shards:
shard_id = shard._shard_id
cursor_result = self._datahub.get_cursor(self._datahub_config.project,

batch_size = self._data_config.batch_size

tid = 0
while True:
shard_id = self._shards[tid].shard_id
tid += 1
if tid >= len(self._shards):
tid = 0
if shard_id not in self._offset_dict:
cursor_result = self._datahub.get_cursor(self._datahub_config.project,
self._datahub_config.topic,
shard_id, CursorType.OLDEST)
cursor = cursor_result.cursor
limit = self._data_config.batch_size
while True:
get_result = self._datahub.get_tuple_records(
self._datahub_config.project, self._datahub_config.topic,
shard_id, record_schema, cursor, limit)
batch_data_np = [x.copy() for x in batch_defaults]
for row_id, record in enumerate(get_result.records):
for col_id in range(len(record_defaults)):
if record.values[col_id] not in ['', 'Null', None]:
batch_data_np[col_id][row_id] = record.values[col_id]
yield tuple(batch_data_np)
if 0 == get_result.record_count:
time.sleep(1)
cursor = get_result.next_cursor
except DatahubException as e:
logging.error(e)
cursor = cursor_result.cursor
else:
cursor = self._offset_dict[shard_id]['cursor']

get_result = self._datahub.get_tuple_records(
self._datahub_config.project, self._datahub_config.topic,
shard_id, record_schema, cursor, batch_size)
count = get_result.record_count
if count == 0:
continue
time_offset = 0
sequence_offset = 0
for row_id, record in enumerate(get_result.records):
if record.system_time > time_offset:
time_offset = record.system_time
if record.sequence > sequence_offset:
sequence_offset = record.sequence
for col_id in range(len(record_defaults)):
if record.values[col_id] not in ['', 'Null', 'null', 'NULL', None]:
batch_data[col_id][row_id] = record.values[col_id]
else:
batch_data[col_id][row_id] = record_defaults[col_id]
cursor = get_result.next_cursor
self._offset_dict[shard_id] = {'sequence_offset': sequence_offset,
'time_offset': time_offset,
'cursor': cursor
}
batch_data[-1] = json.dumps(self._offset_dict)
yield tuple(batch_data)
except DatahubException as ex:
logging.error('DatahubException: %s' % str(ex))

def _build(self, mode, params):
# get input type
list_type = [self.get_tf_type(x) for x in self._input_field_types]
list_type = tuple(list_type)
list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))]
# get input types
list_types = [self.get_tf_type(x) for x in self._input_field_types]
list_types.append(tf.string)
list_types = tuple(list_types)
list_shapes = [tf.TensorShape([None]) for x in range(0, len(self._input_field_types))]
list_shapes.append(tf.TensorShape([]))
list_shapes = tuple(list_shapes)
# read datahub
dataset = tf.data.Dataset.from_generator(
self._datahub_generator,
output_types=list_type,
output_types=list_types,
output_shapes=list_shapes)
if mode == tf.estimator.ModeKeys.TRAIN:
dataset = dataset.shuffle(
self._data_config.shuffle_buffer_size,
seed=2020,
reshuffle_each_iteration=True)
dataset = dataset.repeat(self.num_epochs)
else:
dataset = dataset.repeat(1)
if self._data_config.shuffle:
dataset = dataset.shuffle(
self._data_config.shuffle_buffer_size,
seed=2020,
reshuffle_each_iteration=True)

dataset = dataset.map(
self._parse_record,
num_parallel_calls=self._data_config.num_parallel_calls)
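For reference, the offset bookkeeping above round-trips through JSON: each yielded batch carries `json.dumps(self._offset_dict)` in its last field, and on restart the same string can be passed back through `datahub_config.offset_info`, where `__init__` filters it down to the shards owned by this worker. A small illustration with made-up shard ids, cursors, and offsets:

```python
import json

# shape of the string carried in the batch's last field (values are made up)
offset_info = json.dumps({
    '0': {'sequence_offset': 1024, 'time_offset': 1620000000000,
          'cursor': '30005af0000000000000'},
    '2': {'sequence_offset': 987, 'time_offset': 1620000000123,
          'cursor': '30007bc0000000000000'},
})

# on restart, DataHubInput.__init__ does json.loads(offset_info) and keeps
# only the shards assigned to this worker (shard index % task_num == task_index)
restored = json.loads(offset_info)
assert restored['0']['sequence_offset'] == 1024
```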
2 changes: 2 additions & 0 deletions easy_rec/python/input/input.py
@@ -23,6 +23,8 @@

class Input(six.with_metaclass(_meta_type, object)):

DATA_OFFSET = 'DATA_OFFSET'

def __init__(self,
data_config,
feature_configs,
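The `DATA_OFFSET` key added here is appended to the feature dict by `DataHubInput._preprocess`, so the offset JSON travels with every batch. How it gets persisted is not shown in this diff; a hedged sketch of one way a hook could write it next to the checkpoints, so a restarted job can feed it back via `datahub_config.offset_info` (hook name and file name are hypothetical):

```python
import os

import tensorflow as tf


class OffsetDumpSketchHook(tf.train.SessionRunHook):
  """Illustrative only: persists the latest DATA_OFFSET string to model_dir."""

  def __init__(self, offset_tensor, model_dir):
    # offset_tensor is features[Input.DATA_OFFSET], a scalar string tensor
    self._offset_tensor = offset_tensor
    self._offset_path = os.path.join(model_dir, 'datahub_offset.json')

  def before_run(self, run_context):
    return tf.train.SessionRunArgs(fetches=self._offset_tensor)

  def after_run(self, run_context, run_values):
    offset_json = run_values.results
    if isinstance(offset_json, bytes):
      offset_json = offset_json.decode('utf-8')
    with open(self._offset_path, 'w') as fout:
      fout.write(offset_json)
```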