Skip to content

Commit

Permalink
update superresolution example
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Nov 8, 2023
1 parent 4b82d4f commit a7787f6
Show file tree
Hide file tree
Showing 81 changed files with 507 additions and 335 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
55 changes: 39 additions & 16 deletions examples/vision/ipynb/super_resolution_sub_pixel.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
"## Introduction\n",
"\n",
"ESPCN (Efficient Sub-Pixel CNN), proposed by [Shi, 2016](https://arxiv.org/abs/1609.05158)\n",
"is a model that reconstructs a high-resolution version of an image given a low-resolution version.\n",
"is a model that reconstructs a high-resolution version of an image given a low-resolution\n",
"version.\n",
"It leverages efficient \"sub-pixel convolution\" layers, which learns an array of\n",
"image upscaling filters.\n",
"\n",
"In this code example, we will implement the model from the paper and train it on a small dataset,\n",
"In this code example, we will implement the model from the paper and train it on a small\n",
"dataset,\n",
"[BSDS500](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html).\n",
"[BSDS500](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html)."
]
},
Expand All @@ -48,19 +51,19 @@
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
"from keras.utils import load_img\n",
"from keras.utils import array_to_img\n",
"from keras.utils import img_to_array\n",
"from keras.preprocessing import image_dataset_from_directory\n",
"import tensorflow as tf # only for data preprocessing\n",
"\n",
"import os\n",
"import math\n",
"import numpy as np\n",
"\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras.utils import load_img\n",
"from tensorflow.keras.utils import array_to_img\n",
"from tensorflow.keras.utils import img_to_array\n",
"from tensorflow.keras.preprocessing import image_dataset_from_directory\n",
"\n",
"from IPython.display import display"
]
},
Expand Down Expand Up @@ -316,19 +319,38 @@
},
"outputs": [],
"source": [
"\n",
"class DepthToSpace(layers.Layer):\n",
" def __init__(self, block_size):\n",
" super().__init__()\n",
" self.block_size = block_size\n",
"\n",
" def call(self, input):\n",
" batch, height, width, depth = ops.shape(input)\n",
" depth = depth // (self.block_size**2)\n",
"\n",
" x = ops.reshape(\n",
" input, [batch, height, width, self.block_size, self.block_size, depth]\n",
" )\n",
" x = ops.transpose(x, [0, 1, 3, 2, 4, 5])\n",
" x = ops.reshape(\n",
" x, [batch, height * self.block_size, width * self.block_size, depth]\n",
" )\n",
" return x\n",
"\n",
"\n",
"def get_model(upscale_factor=3, channels=1):\n",
" conv_args = {\n",
" \"activation\": \"relu\",\n",
" \"kernel_initializer\": \"Orthogonal\",\n",
" \"kernel_initializer\": \"orthogonal\",\n",
" \"padding\": \"same\",\n",
" }\n",
" inputs = keras.Input(shape=(None, None, channels))\n",
" x = layers.Conv2D(64, 5, **conv_args)(inputs)\n",
" x = layers.Conv2D(64, 3, **conv_args)(x)\n",
" x = layers.Conv2D(32, 3, **conv_args)(x)\n",
" x = layers.Conv2D(channels * (upscale_factor ** 2), 3, **conv_args)(x)\n",
" outputs = tf.nn.depth_to_space(x, upscale_factor)\n",
" x = layers.Conv2D(channels * (upscale_factor**2), 3, **conv_args)(x)\n",
" outputs = DepthToSpace(upscale_factor)(x)\n",
"\n",
" return keras.Model(inputs, outputs)\n",
""
Expand Down Expand Up @@ -492,11 +514,11 @@
"source": [
"early_stopping_callback = keras.callbacks.EarlyStopping(monitor=\"loss\", patience=10)\n",
"\n",
"checkpoint_filepath = \"/tmp/checkpoint\"\n",
"checkpoint_filepath = \"/tmp/checkpoint.keras\"\n",
"\n",
"model_checkpoint_callback = keras.callbacks.ModelCheckpoint(\n",
" filepath=checkpoint_filepath,\n",
" save_weights_only=True,\n",
" save_weights_only=False,\n",
" monitor=\"loss\",\n",
" mode=\"min\",\n",
" save_best_only=True,\n",
Expand Down Expand Up @@ -530,7 +552,8 @@
"epochs = 100\n",
"\n",
"model.compile(\n",
" optimizer=optimizer, loss=loss_fn,\n",
" optimizer=optimizer,\n",
" loss=loss_fn,\n",
")\n",
"\n",
"model.fit(\n",
Expand Down
Loading

0 comments on commit a7787f6

Please sign in to comment.