Skip to content

Commit

Permalink
update to TensorLayer 1.8.1
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Mar 15, 2018
1 parent 8ebb42f commit 14b6fd6
Show file tree
Hide file tree
Showing 113 changed files with 14,628 additions and 8,597 deletions.
12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
*.gz
*.npz
*.pyc
*~
.DS_Store
.idea
.spyproject/
build/
dist
docs/_build
tensorlayer.egg-info
tensorlayer/__pacache__
5 changes: 2 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os, sys, pprint, time
import scipy.misc
import os, pprint, time
import numpy as np
import tensorflow as tf
import tensorlayer as tl
Expand Down Expand Up @@ -111,7 +110,7 @@ def main(_):
## load image data
batch_idxs = min(len(data_files), FLAGS.train_size) // FLAGS.batch_size

for idx in xrange(0, batch_idxs):
for idx in range(0, batch_idxs):
batch_files = data_files[idx*FLAGS.batch_size:(idx+1)*FLAGS.batch_size]
## get real images
# more image augmentation functions in http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html
Expand Down
18 changes: 8 additions & 10 deletions tensorlayer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""
Deep learning and Reinforcement learning library for Researchers and Engineers
"""
# from __future__ import absolute_import

"""Deep learning and Reinforcement learning library for Researchers and Engineers"""
from __future__ import absolute_import

try:
install_instr = "Please make sure you install a recent enough version of TensorFlow."
Expand All @@ -11,21 +8,22 @@
raise ImportError("__init__.py : Could not import TensorFlow." + install_instr)

from . import activation
act = activation
from . import cost
from . import files
# from . import init
from . import iterate
from . import layers
from . import ops
from . import utils
from . import visualize
from . import prepro # was preprocesse
from . import prepro
from . import nlp
from . import rein
from . import distributed

# alias
act = activation
vis = visualize

__version__ = "1.4.5"
__version__ = "1.8.1"

global_flag = {}
global_dict = {}
Binary file not shown.
Binary file added tensorlayer/__pycache__/__init__.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/_logging.cpython-34.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/_logging.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/activation.cpython-34.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/activation.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/cost.cpython-34.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/cost.cpython-35.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tensorlayer/__pycache__/files.cpython-34.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/files.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/iterate.cpython-34.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/iterate.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/nlp.cpython-34.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/nlp.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/prepro.cpython-34.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/prepro.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/rein.cpython-34.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/rein.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/utils.cpython-34.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/utils.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/visualize.cpython-34.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/visualize.cpython-35.pyc
Binary file not shown.
16 changes: 16 additions & 0 deletions tensorlayer/_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import logging as _logger

logging = _logger.getLogger('tensorlayer')
logging.setLevel(_logger.INFO)
_hander = _logger.StreamHandler()
formatter = _logger.Formatter('[TL] %(message)s')
_hander.setFormatter(formatter)
logging.addHandler(_hander)


def info(fmt, *args):
logging.info(fmt, *args)


def warning(fmt, *args):
logging.warning(fmt, *args)
163 changes: 104 additions & 59 deletions tensorlayer/activation.py
Original file line number Diff line number Diff line change
@@ -1,109 +1,154 @@
#! /usr/bin/python
# -*- coding: utf8 -*-
# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.python.util.deprecation import deprecated

__all__ = [
'identity',
'ramp',
'leaky_relu',
'swish',
'pixel_wise_softmax',
'linear',
'lrelu',
]

import tensorflow as tf

def identity(x, name=None):
"""The identity activation function, Shortcut is ``linear``.
@deprecated("2018-06-30", "This API will be deprecated soon as tf.identity can do the same thing.")
def identity(x):
"""The identity activation function.
Shortcut is ``linear``.
Parameters
----------
x : a tensor input
input(s)
x : Tensor
input.
Returns
--------
A `Tensor` with the same type as `x`.
-------
Tensor
A ``Tensor`` in the same type as ``x``.
"""
return x

# Shortcut
linear = identity

def ramp(x=None, v_min=0, v_max=1, name=None):
def ramp(x, v_min=0, v_max=1, name=None):
"""The ramp activation function.
Parameters
----------
x : a tensor input
input(s)
x : Tensor
input.
v_min : float
if input(s) smaller than v_min, change inputs to v_min
cap input to v_min as a lower bound.
v_max : float
if input(s) greater than v_max, change inputs to v_max
name : a string or None
An optional name to attach to this activation function.
cap input to v_max as a upper bound.
name : str
The function name (optional).
Returns
--------
A `Tensor` with the same type as `x`.
-------
Tensor
A ``Tensor`` in the same type as ``x``.
"""
return tf.clip_by_value(x, clip_value_min=v_min, clip_value_max=v_max, name=name)

def leaky_relu(x=None, alpha=0.1, name="LeakyReLU"):

def leaky_relu(x, alpha=0.1, name="lrelu"):
"""The LeakyReLU, Shortcut is ``lrelu``.
Modified version of ReLU, introducing a nonzero gradient for negative
input.
Modified version of ReLU, introducing a nonzero gradient for negative input.
Parameters
----------
x : A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
`int16`, or `int8`.
alpha : `float`. slope.
name : a string or None
An optional name to attach to this activation function.
x : Tensor
Support input type ``float``, ``double``, ``int32``, ``int64``, ``uint8``,
``int16``, or ``int8``.
alpha : float
Slope.
name : str
The function name (optional).
Examples
---------
>>> network = tl.layers.DenseLayer(network, n_units=100, name = 'dense_lrelu',
... act= lambda x : tl.act.lrelu(x, 0.2))
--------
>>> net = tl.layers.DenseLayer(net, 100, act=lambda x : tl.act.lrelu(x, 0.2), name='dense')
Returns
-------
Tensor
A ``Tensor`` in the same type as ``x``.
References
------------
- `Rectifier Nonlinearities Improve Neural Network Acoustic Models, Maas et al. (2013) <http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf>`_
- `Rectifier Nonlinearities Improve Neural Network Acoustic Models, Maas et al. (2013) <http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf>`__
"""
with tf.name_scope(name) as scope:
# x = tf.nn.relu(x)
# m_x = tf.nn.relu(-x)
# x -= alpha * m_x
x = tf.maximum(x, alpha * x)
# with tf.name_scope(name) as scope:
# x = tf.nn.relu(x)
# m_x = tf.nn.relu(-x)
# x -= alpha * m_x
x = tf.maximum(x, alpha * x, name=name)
return x

#Shortcut
lrelu = leaky_relu

def pixel_wise_softmax(output, name='pixel_wise_softmax'):
def swish(x, name='swish'):
"""The Swish function.
See `Swish: a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941>`__.
Parameters
----------
x : Tensor
input.
name: str
function name (optional).
Returns
-------
Tensor
A ``Tensor`` in the same type as ``x``.
"""
with tf.name_scope(name):
x = tf.nn.sigmoid(x) * x
return x


@deprecated("2018-06-30", "This API will be deprecated soon as tf.nn.softmax can do the same thing.")
def pixel_wise_softmax(x, name='pixel_wise_softmax'):
"""Return the softmax outputs of images, every pixels have multiple label, the sum of a pixel is 1.
Usually be used for image segmentation.
Parameters
------------
output : tensor
- For 2d image, 4D tensor [batch_size, height, weight, channel], channel >= 2.
- For 3d image, 5D tensor [batch_size, depth, height, weight, channel], channel >= 2.
----------
x : Tensor
input.
- For 2d image, 4D tensor (batch_size, height, weight, channel), where channel >= 2.
- For 3d image, 5D tensor (batch_size, depth, height, weight, channel), where channel >= 2.
name : str
function name (optional)
Returns
-------
Tensor
A ``Tensor`` in the same type as ``x``.
Examples
---------
--------
>>> outputs = pixel_wise_softmax(network.outputs)
>>> dice_loss = 1 - dice_coe(outputs, y_, epsilon=1e-5)
References
-----------
- `tf.reverse <https://www.tensorflow.org/versions/master/api_docs/python/array_ops.html#reverse>`_
----------
- `tf.reverse <https://www.tensorflow.org/versions/master/api_docs/python/array_ops.html#reverse>`__
"""
with tf.name_scope(name) as scope:
return tf.nn.softmax(output)
## old implementation
# exp_map = tf.exp(output)
# if output.get_shape().ndims == 4: # 2d image
# evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, True]))
# elif output.get_shape().ndims == 5: # 3d image
# evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, False, True]))
# else:
# raise Exception("output parameters should be 2d or 3d image, not %s" % str(output._shape))
# return tf.div(exp_map, evidence)
with tf.name_scope(name):
return tf.nn.softmax(x)


# Alias
linear = identity
lrelu = leaky_relu
1 change: 1 addition & 0 deletions tensorlayer/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""The tensorlayer.cli module provides a command-line tool for some common tasks."""
14 changes: 14 additions & 0 deletions tensorlayer/cli/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import argparse

from tensorlayer.cli import train

if __name__ == "__main__":
parser = argparse.ArgumentParser(prog='tl')
subparsers = parser.add_subparsers(dest='cmd')
train_parser = subparsers.add_parser('train', help='train a model using multiple local GPUs or CPUs.')
train.build_arg_parser(train_parser)
args = parser.parse_args()
if args.cmd == 'train':
train.main(args)
else:
parser.print_help()
Loading

0 comments on commit 14b6fd6

Please sign in to comment.