
Commit

Merge pull request #1 from seahrh/aug22
Aug22
seahrh authored Aug 26, 2018
2 parents 290be32 + ea0d15a commit 26926b0
Showing 4 changed files with 152 additions and 88 deletions.
39 changes: 27 additions & 12 deletions trainer/decode.py
@@ -35,7 +35,18 @@
class BeamSearchDecoder(object):
"""Beam search decoder."""

def __init__(self, model, batcher, vocab, hps, single_pass, log_root, pointer_gen, data_path, beam_size):
def __init__(
self,
model,
batcher,
vocab,
hps,
single_pass,
pointer_gen,
data_path,
beam_size,
conf
):
"""Initialize decoder.
Args:
@@ -49,15 +60,15 @@ def __init__(self, model, batcher, vocab, hps, single_pass, log_root, pointer_ge
self._vocab = vocab
self._hps = hps
self._single_pass = single_pass
self._log_root = log_root
self._pointer_gen = pointer_gen
self._data_path = data_path
self._beam_size = beam_size
self._saver = tf.train.Saver() # we use this to load checkpoints for decoding
self._sess = tf.Session(config=util.get_config())
self._sess = tf.Session(config=conf.session_config)
self._conf = conf

# Load an initial checkpoint to use for decoding
ckpt_path = util.load_ckpt(self._saver, self._sess, log_root=self._log_root)
ckpt_path = util.load_ckpt(self._saver, self._sess, log_root=self._conf.model_dir)

if self._single_pass:
# Make a descriptive decode directory name
@@ -70,25 +81,28 @@ def __init__(self, model, batcher, vocab, hps, single_pass, log_root, pointer_ge
min_dec_steps=hps.min_dec_steps,
max_dec_steps=hps.max_dec_steps
)
self._decode_dir = os.path.join(self._log_root, dir_name)
self._decode_dir = os.path.join(self._conf.model_dir, dir_name)
if os.path.exists(self._decode_dir):
raise Exception("single_pass decode directory %s should not already exist" % self._decode_dir)

else: # Generic decode dir name
self._decode_dir = os.path.join(self._log_root, Modes.PREDICT)
self._decode_dir = os.path.join(self._conf.model_dir, Modes.PREDICT)

# Make the decode dir if necessary
if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir)

if self._single_pass:
# Make the dirs to contain output written in the correct format for pyrouge
self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
if not os.path.exists(self._rouge_ref_dir): os.mkdir(self._rouge_ref_dir)
if not os.path.exists(self._rouge_ref_dir):
os.mkdir(self._rouge_ref_dir)
self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
if not os.path.exists(self._rouge_dec_dir): os.mkdir(self._rouge_dec_dir)
if not os.path.exists(self._rouge_dec_dir):
os.mkdir(self._rouge_dec_dir)

def decode(self):
"""Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
"""Decode examples until data is exhausted (if FLAGS.single_pass) and return,
or decode indefinitely, loading latest checkpoint at regular intervals"""
t0 = time.time()
counter = 0
while True:
@@ -97,7 +111,7 @@ def decode(self):
assert self._single_pass, "Dataset exhausted, but we are not in single_pass mode"
log.info("Decoder has finished reading dataset for single_pass.")
log.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir,
self._rouge_dec_dir)
self._rouge_dec_dir)
results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
rouge_log(results_dict, self._decode_dir)
return
@@ -147,7 +161,7 @@ def decode(self):
log.info(
'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
t1 - t0)
_ = util.load_ckpt(self._saver, self._sess, log_root=self._log_root)
_ = util.load_ckpt(self._saver, self._sess, log_root=self._conf.model_dir)
t0 = time.time()

def write_for_rouge(self, reference_sents, decoded_words, ex_index):
@@ -242,7 +256,8 @@ def rouge_log(results_dict, dir_to_write):


def get_decode_dir_name(ckpt_name, data_path, beam_size, max_enc_steps, min_dec_steps, max_dec_steps):
"""Make a descriptive name for the decode dir, including the name of the checkpoint we use to decode. This is called in single_pass mode."""
"""Make a descriptive name for the decode dir,
including the name of the checkpoint we use to decode. This is called in single_pass mode."""

if "train" in data_path:
dataset = "train"
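
The decode.py hunks above swap the loose log_root argument for a RunConfig-style conf object: session options now come from conf.session_config and every checkpoint and decode path is rooted at conf.model_dir. A minimal standalone sketch of that pattern (TF 1.x, illustrative only; the model_dir path and ConfigProto settings below are placeholders, not taken from this commit):

import os
import tensorflow as tf

# One RunConfig supplies both the session options and the directory that
# checkpoints and decode output hang off.
conf = tf.estimator.RunConfig(
    model_dir='/tmp/summarization',  # placeholder path
    session_config=tf.ConfigProto(allow_soft_placement=True),  # assumed settings
)
sess = tf.Session(config=conf.session_config)        # mirrors the new self._sess
decode_dir = os.path.join(conf.model_dir, 'decode')  # how decode dirs are now rooted
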
46 changes: 27 additions & 19 deletions trainer/model.py
@@ -31,19 +31,20 @@ class SummarizationModel(object):
Supports both baseline mode, pointer-generator mode, and coverage
"""

def __init__(self, hps, vocab, mode, pointer_gen, coverage, log_root, cluster_spec):
def __init__(self, hps, vocab, mode, pointer_gen, coverage, conf):
self._hps = hps
self._vocab = vocab
self._mode = mode
self._pointer_gen = pointer_gen
self._coverage = coverage
self._log_root = log_root
# The model is configured with max_dec_steps=1
# because we only ever run one step of the decoder at a time (to do beam search).
# Note that the batcher is initialized with max_dec_steps equal to e.g. 100
# because the batches need to contain the full summaries
self._max_dec_steps = 1 if mode == Modes.PREDICT else self._hps.max_dec_steps
self._cluster_spec = cluster_spec
self._conf = conf
self.global_step = None
self._summaries = None

def _add_placeholders(self):
"""Add placeholders to the graph. These are entry points for any input data."""
@@ -166,9 +167,7 @@ def _add_decoder(self, inputs):
self._enc_states,
self._enc_padding_mask,
cell,
initial_state_attention=(
self._mode == Modes.PREDICT
),
initial_state_attention=(self._mode == Modes.PREDICT),
pointer_gen=self._pointer_gen,
use_coverage=self._coverage,
prev_coverage=prev_coverage
@@ -247,7 +246,7 @@ def _add_seq2seq(self):
embedding = tf.get_variable('embedding', [vsize, hps.emb_dim], dtype=tf.float32,
initializer=self.trunc_norm_init)
if self._mode == "train":
self._add_emb_vis(embedding, log_root=self._log_root) # add to tensorboard
self._add_emb_vis(embedding, log_root=self._conf.model_dir) # add to tensorboard
emb_enc_inputs = tf.nn.embedding_lookup(embedding,
self._enc_batch) # tensor with shape (batch_size, max_enc_steps, emb_size)
emb_dec_inputs = [tf.nn.embedding_lookup(embedding, x) for x in tf.unstack(self._dec_batch,
@@ -337,29 +336,38 @@ def _add_train_op(self):
gradients = tf.gradients(loss_to_minimize, tvars, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)

# Clip the gradients
with tf.device(tf.train.replica_device_setter(cluster=self._cluster_spec)):
grads, global_norm = tf.clip_by_global_norm(gradients, self._hps.max_grad_norm)
grads, global_norm = tf.clip_by_global_norm(gradients, self._hps.max_grad_norm)

# Add a summary
tf.summary.scalar('global_norm', global_norm)

# Apply adagrad optimizer
optimizer = tf.train.AdagradOptimizer(self._hps.lr, initial_accumulator_value=self._hps.adagrad_init_acc)
with tf.device(tf.train.replica_device_setter(cluster=self._cluster_spec)):
self._train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step,
name='train_step')
optimizer = tf.train.AdagradOptimizer(
self._hps.lr,
initial_accumulator_value=self._hps.adagrad_init_acc
)
self._train_op = optimizer.apply_gradients(
zip(grads, tvars),
global_step=self.global_step,
name='train_step'
)

def build_graph(self):
"""Add the placeholders, model, global step, train_op and summaries to the graph"""
log.info('Building graph...')
t0 = time.time()
self._add_placeholders()
with tf.device(tf.train.replica_device_setter(cluster=self._cluster_spec)):
with tf.device(tf.train.replica_device_setter(cluster=self._conf.cluster_spec)):
self._add_placeholders()
self._add_seq2seq()
self.global_step = tf.Variable(0, name='global_step', trainable=False)
if self._mode == 'train':
self._add_train_op()
self._summaries = tf.summary.merge_all()
self.global_step = tf.get_variable(
'global_step',
dtype=tf.int32,
initializer=tf.constant(0),
trainable=False
)
if self._mode == Modes.TRAIN:
self._add_train_op()
self._summaries = tf.summary.merge_all()
t1 = time.time()
log.info('Time to build graph: %i seconds', t1 - t0)

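
In model.py the replica_device_setter wrappers move off the individual gradient-clip and apply_gradients calls and instead wrap the whole graph build, with the cluster spec now read from the RunConfig rather than passed in separately, and global_step switches from tf.Variable to tf.get_variable. A minimal standalone sketch of the device-placement pattern (TF 1.x, illustrative only; constructing a bare RunConfig here is just for demonstration):

import tensorflow as tf

# cluster_spec is parsed from the TF_CONFIG environment variable; with no
# TF_CONFIG it is empty and replica_device_setter becomes a no-op, so the
# same code path works single-machine and distributed.
conf = tf.estimator.RunConfig()
with tf.device(tf.train.replica_device_setter(cluster=conf.cluster_spec)):
    global_step = tf.get_variable(
        'global_step', dtype=tf.int32,
        initializer=tf.constant(0), trainable=False)
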
67 changes: 28 additions & 39 deletions trainer/task.py
@@ -59,44 +59,44 @@ def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.
return running_avg_loss


def __restore_best_model(log_root):
def __restore_best_model(conf):
"""Load bestmodel file from eval directory, add variables for adagrad, and save to train directory"""
log.info("Restoring bestmodel for training...")

# Initialize all vars in the model
sess = tf.Session(config=util.get_config())
sess = tf.Session(config=conf.session_config)
log.info("Initializing all variables...")
sess.run(tf.global_variables_initializer())

# Restore the best model from eval dir
saver = tf.train.Saver([v for v in tf.global_variables() if "Adagrad" not in v.name])
log.info("Restoring all non-adagrad variables from best model in eval dir...")
curr_ckpt = util.load_ckpt(saver, sess, log_root=log_root, ckpt_dir="eval")
curr_ckpt = util.load_ckpt(saver, sess, log_root=conf.model_dir, ckpt_dir="eval")
log.info("Restored %s." % curr_ckpt)

# Save this model to train dir and quit
new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
new_fname = os.path.join(log_root, "train", new_model_name)
new_fname = os.path.join(conf.model_dir, "train", new_model_name)
log.info("Saving model to %s..." % new_fname)
new_saver = tf.train.Saver() # this saver saves all variables that now exist, including Adagrad variables
new_saver.save(sess, new_fname)
log.info("Saved.")
exit()


def __convert_to_coverage_model(log_root):
def __convert_to_coverage_model(conf):
"""Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
log.info("converting non-coverage model to coverage model..")

# initialize an entire coverage model from scratch
sess = tf.Session(config=util.get_config())
sess = tf.Session(config=conf.session_config)
log.info("initializing everything...")
sess.run(tf.global_variables_initializer())

# load all non-coverage weights from checkpoint
saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name])
log.info("restoring non-coverage variables...")
curr_ckpt = util.load_ckpt(saver, sess, log_root=log_root)
curr_ckpt = util.load_ckpt(saver, sess, log_root=conf.model_dir)
log.info("restored.")

# save this model and quit
@@ -111,16 +111,15 @@ def __convert_to_coverage_model(log_root):
def setup_training(
model,
batcher,
log_root,
convert_to_coverage_model,
coverage,
restore_best_model,
debug,
max_step,
is_chief
conf
):
"""Does setup before starting training (run_training)"""
train_dir = os.path.join(log_root, "train")
train_dir = os.path.join(conf.model_dir, "train")
if not os.path.exists(train_dir):
os.makedirs(train_dir)
model.build_graph() # build the graph
@@ -129,27 +128,27 @@ def setup_training(
To convert your non-coverage model to a coverage model,
run with convert_to_coverage_model=True and coverage=True\
"""
__convert_to_coverage_model(log_root)
__convert_to_coverage_model(conf=conf)
if restore_best_model:
__restore_best_model(log_root)
__restore_best_model(conf=conf)
try:
# this is an infinite loop until interrupted
run_training(model, batcher, train_dir,
coverage=coverage, debug=debug, max_step=max_step, is_chief=is_chief)
coverage=coverage, debug=debug, max_step=max_step, conf=conf)
except KeyboardInterrupt:
log.info("Caught keyboard interrupt on worker. Stopping...")


def __train_session(train_dir, is_chief, debug):
def __train_session(train_dir, debug, conf):
sess = tf.train.MonitoredTrainingSession(
checkpoint_dir=train_dir, # required to restore variables!
summary_dir=train_dir,
is_chief=True,
is_chief=conf.is_chief,
save_summaries_secs=60,
save_checkpoint_secs=60,
max_wait_secs=60,
stop_grace_period_secs=60,
config=util.get_config(),
config=conf.session_config,
scaffold=tf.train.Scaffold(
saver=tf.train.Saver(max_to_keep=3)
)
@@ -160,11 +159,11 @@ def __train_session(train_dir, is_chief, debug):
return sess


def run_training(model, batcher, train_dir, coverage, debug, max_step, is_chief):
def run_training(model, batcher, train_dir, coverage, debug, max_step, conf):
"""Repeatedly runs training iterations, logging loss to screen and writing summaries"""
log.debug("starting run_training")
summary_writer = tf.summary.FileWriterCache.get(train_dir)
with __train_session(train_dir=train_dir, is_chief=is_chief, debug=debug) as sess:
with __train_session(train_dir=train_dir, debug=debug, conf=conf) as sess:
train_step = 0
# repeats until max_step is reached
while not sess.should_stop() and train_step <= max_step:
@@ -186,23 +185,23 @@ def run_training(model, batcher, train_dir, coverage, debug, max_step, is_chief)
summary_writer.add_summary(summaries, train_step) # write the summaries


def run_eval(model, batcher, log_root, coverage):
def run_eval(model, batcher, coverage, conf):
"""
Repeatedly runs eval iterations, logging to screen and writing summaries.
Saves the model with the best loss seen so far.
"""
model.build_graph() # build the graph
saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
sess = tf.Session(config=util.get_config())
eval_dir = os.path.join(log_root, "eval") # make a subdir of the root dir for eval data
sess = tf.Session(config=conf.session_config)
eval_dir = os.path.join(conf.model_dir, "eval") # make a subdir of the root dir for eval data
bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
summary_writer = tf.summary.FileWriter(eval_dir)
# the eval job keeps a smoother, running average loss to tell it when to implement early stopping
running_avg_loss = 0
best_loss = None # will hold the best loss achieved so far

while True:
_ = util.load_ckpt(saver, sess, log_root=log_root) # load a new checkpoint
_ = util.load_ckpt(saver, sess, log_root=conf.model_dir) # load a new checkpoint
batch = batcher.next_batch() # get the next batch

# run eval on the batch
@@ -294,17 +293,9 @@ def __main(
log.info('Starting seq2seq_attention in %s mode...', mode)
log_root = __log_root(log_root, exp_name, mode)
vocab = Vocab(vocab_path, vocab_size) # create a vocabulary
tf_config = util.tf_config()
hps = __hparams(**hparams)
log.info(f'hps={repr(hps)}\ntf_config={repr(tf_config)}')
conf = tf.estimator.RunConfig(
model_dir=log_root,
session_config=util.get_config(),
save_checkpoints_secs=60,
save_summary_steps=100,
keep_checkpoint_max=3,
tf_random_seed=random_seed
)
conf = util.run_config(model_dir=log_root, random_seed=random_seed)
log.info(f'hps={repr(hps)}\nconf={util.repr_run_config(conf)}')

# Create a batcher object that will create minibatches of data
batcher = Batcher(
@@ -321,31 +312,29 @@
mode=mode,
pointer_gen=pointer_gen,
coverage=coverage,
log_root=log_root,
cluster_spec=tf_config['cluster_spec']
conf=conf
)
if mode == Modes.TRAIN:
setup_training(
model,
batcher,
log_root=log_root,
convert_to_coverage_model=convert_to_coverage_model,
coverage=coverage,
restore_best_model=restore_best_model,
debug=debug,
max_step=hps.max_step,
is_chief=tf_config['is_chief']
conf=conf
)
elif mode == Modes.EVAL:
run_eval(model, batcher, log_root=log_root, coverage=coverage)
run_eval(model, batcher, coverage=coverage, conf=conf)
elif mode == Modes.PREDICT:
decoder = BeamSearchDecoder(model, batcher, vocab,
hps=hps,
single_pass=single_pass,
log_root=log_root,
pointer_gen=pointer_gen,
data_path=data_path,
beam_size=beam_size
beam_size=beam_size,
conf=conf
)
# decode indefinitely
# (unless single_pass=True, in which case decode the dataset exactly once)
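
task.py now obtains its RunConfig from util.run_config and logs it with util.repr_run_config; trainer/util.py itself is the fourth changed file and is not rendered above. Going only by the call sites and the inline tf.estimator.RunConfig they replace, those helpers presumably look roughly like the sketch below. Everything beyond the two function names is an assumption, including the session settings:

import tensorflow as tf

def run_config(model_dir, random_seed=None):
    # Assumed sketch of trainer/util.run_config, reconstructed from the
    # removed inline RunConfig call in task.py; session_config is a guess.
    return tf.estimator.RunConfig(
        model_dir=model_dir,
        session_config=tf.ConfigProto(allow_soft_placement=True),
        save_checkpoints_secs=60,
        save_summary_steps=100,
        keep_checkpoint_max=3,
        tf_random_seed=random_seed,
    )

def repr_run_config(conf):
    # Assumed sketch: a readable dump of the RunConfig fields this commit uses.
    return 'RunConfig(model_dir=%r, is_chief=%r, cluster_spec=%r)' % (
        conf.model_dir, conf.is_chief, conf.cluster_spec.as_dict())
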
