Skip to content

Commit

Permalink
Added support for alternative Fourier feature map (#54)
Browse files Browse the repository at this point in the history
* added support for alternative Fourier feature map

* added docs and additional tests

* fixed formatting issues

* fixed docs

* fixed formatting issues

* fixed doc formatting

* fixed doc formatting

* fixed doc formatting

* removed trailing whitespaces

* tiny doc fix

* tiny doc fix
  • Loading branch information
ltiao authored Oct 7, 2021
1 parent b5f1175 commit e05d7ba
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 81 deletions.
223 changes: 157 additions & 66 deletions gpflux/layers/basis_functions/random_fourier_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@

from gpflux.layers.basis_functions.utils import (
RFF_SUPPORTED_KERNELS,
_mapping,
_mapping_concat,
_mapping_cosine,
_matern_number,
_sample_students_t,
)
Expand All @@ -39,25 +40,7 @@
"""


class RandomFourierFeatures(tf.keras.layers.Layer):
r"""
Random Fourier Features (RFF) is a method for approximating kernels. The essential
element of the RFF approach :cite:p:`rahimi2007random` is the realization that Bochner's theorem
for stationary kernels can be approximated by a Monte Carlo sum.
We will approximate the kernel :math:`k(x, x')` by :math:`\Phi(x)^\top \Phi(x')`
where :math:`\Phi: x \to \mathbb{R}` is a finite-dimensional feature map.
Each feature is defined as:
.. math:: \Phi(x) = \sqrt{2 \sigma^2 / \ell} \cos(\theta^\top x + \tau)
where :math:`\sigma^2` is the kernel variance.
The features are parameterised by random weights:
* :math:`\theta`, sampled proportional to the kernel's spectral density
* :math:`\tau \sim \mathcal{U}(0, 2\pi)`
"""

class RandomFourierFeaturesBase(tf.keras.layers.Layer):
def __init__(self, kernel: gpflow.kernels.Kernel, output_dim: int, **kwargs: Mapping):
"""
:param kernel: kernel to approximate using a set of random features.
Expand All @@ -66,73 +49,33 @@ def __init__(self, kernel: gpflow.kernels.Kernel, output_dim: int, **kwargs: Map
"""
super().__init__(**kwargs)

assert isinstance(kernel, RFF_SUPPORTED_KERNELS), "Unsupported Kernel"
self.kernel = kernel
assert isinstance(self.kernel, RFF_SUPPORTED_KERNELS), "Unsupported Kernel"
self.output_dim = output_dim # M
if kwargs.get("input_dim", None):
self._input_dim = kwargs["input_dim"]
self.build(tf.TensorShape([self._input_dim]))
else:
self._input_dim = None

def build(self, input_shape: ShapeType) -> None:
"""
Creates the variables of the layer.
See `tf.keras.layers.Layer.build()
<https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#build>`_.
"""
input_dim = input_shape[-1]

shape_bias = [1, self.output_dim]
self.b = self.add_weight(
name="bias",
trainable=False,
shape=shape_bias,
dtype=self.dtype,
initializer=self.bias_init,
)

shape_weights = [self.output_dim, input_dim]
def _weights_build(self, input_dim: int, n_components: int) -> None:
shape = (n_components, input_dim)
self.W = self.add_weight(
name="weights",
trainable=False,
shape=shape_weights,
shape=shape,
dtype=self.dtype,
initializer=self.weights_init,
initializer=self._weights_init,
)

super().build(input_shape)

def bias_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
return tf.random.uniform(shape=shape, maxval=2.0 * np.pi, dtype=dtype)

def weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
if isinstance(self.kernel, gpflow.kernels.SquaredExponential):
return tf.random.normal(shape, dtype=dtype)
else:
p = _matern_number(self.kernel)
nu = 2.0 * p + 1.0 # degrees of freedom
return _sample_students_t(nu, shape, dtype)

def call(self, inputs: TensorType) -> tf.Tensor:
"""
Evaluate the basis functions at ``inputs``.
:param inputs: The evaluation points, a tensor with the shape ``[N, D]``.
:return: A tensor with the shape ``[N, M]``.
"""
output = _mapping(
inputs,
self.W,
self.b,
variance=self.kernel.variance,
lengthscales=self.kernel.lengthscales,
n_components=self.output_dim,
)
tf.ensure_shape(output, self.compute_output_shape(inputs.shape))
return output

def compute_output_shape(self, input_shape: ShapeType) -> tf.TensorShape:
"""
Computes the output shape of the layer.
Expand All @@ -156,3 +99,151 @@ def get_config(self) -> Mapping:
)

return config


class RandomFourierFeatures(RandomFourierFeaturesBase):
    r"""
    Random Fourier features (RFF) is a method for approximating kernels. The essential
    element of the RFF approach :cite:p:`rahimi2007random` is the realization that Bochner's theorem
    for stationary kernels can be approximated by a Monte Carlo sum.

    We will approximate the kernel :math:`k(\mathbf{x}, \mathbf{x}')`
    by :math:`\Phi(\mathbf{x})^\top \Phi(\mathbf{x}')`
    where :math:`\Phi: \mathbb{R}^{D} \to \mathbb{R}^{M}` is a finite-dimensional feature map.

    The feature map is defined as:

    .. math::

        \Phi(\mathbf{x}) = \sqrt{\frac{2 \sigma^2}{M}}
        \begin{bmatrix}
            \sin(\boldsymbol{\theta}_1^\top \mathbf{x}) \\
            \vdots \\
            \sin(\boldsymbol{\theta}_{\frac{M}{2}}^\top \mathbf{x}) \\
            \cos(\boldsymbol{\theta}_1^\top \mathbf{x}) \\
            \vdots \\
            \cos(\boldsymbol{\theta}_{\frac{M}{2}}^\top \mathbf{x})
        \end{bmatrix}

    where :math:`\sigma^2` is the kernel variance and :math:`M` is the total
    number of features (``output_dim``). The inputs :math:`\mathbf{x}` are
    divided by the kernel lengthscales before being projected, and the
    :math:`M / 2` sine features precede the :math:`M / 2` cosine features in
    the output.

    The features are parameterised by random weights:

    - :math:`\boldsymbol{\theta} \sim p(\boldsymbol{\theta})`
      where :math:`p(\boldsymbol{\theta})` is the spectral density of the kernel.

    At least for the squared exponential kernel, this variant of the feature
    mapping has more desirable theoretical properties than its cosine-based
    counterpart :class:`RandomFourierFeaturesCosine` :cite:p:`sutherland2015error`.
    """

    def __init__(self, kernel: gpflow.kernels.Kernel, output_dim: int, **kwargs: Mapping):
        """
        :param kernel: kernel to approximate using a set of random features.
        :param output_dim: total number of basis functions used to approximate
            the kernel. Must be even, because every sampled frequency
            contributes one sine and one cosine feature.
        :raises AssertionError: if ``output_dim`` is odd.
        """
        # Validate before delegating: the base initialiser may already build the
        # layer, which relies on ``output_dim // 2`` being exact.
        assert not output_dim % 2, "must specify an even number of random features"
        super().__init__(kernel, output_dim, **kwargs)

    def build(self, input_shape: ShapeType) -> None:
        """
        Creates the variables of the layer.
        See `tf.keras.layers.Layer.build()
        <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#build>`_.

        :param input_shape: shape of the inputs; its last entry is taken as the
            input dimension ``D``.
        """
        input_dim = input_shape[-1]
        # Only M / 2 frequencies are needed: each one yields a (sin, cos) pair.
        self._weights_build(input_dim, n_components=self.output_dim // 2)

        super().build(input_shape)

    def call(self, inputs: TensorType) -> tf.Tensor:
        """
        Evaluate the basis functions at ``inputs``.

        :param inputs: The evaluation points, a tensor with the shape ``[N, D]``.
        :return: A tensor with the shape ``[N, M]``.
        """
        output = _mapping_concat(
            inputs,
            self.W,
            variance=self.kernel.variance,
            lengthscales=self.kernel.lengthscales,
            n_components=self.output_dim,
        )
        # ``tf.ensure_shape`` returns the checked tensor; the result has to be
        # used, otherwise the runtime check is not part of the returned graph.
        output = tf.ensure_shape(output, self.compute_output_shape(inputs.shape))
        return output


class RandomFourierFeaturesCosine(RandomFourierFeaturesBase):
    r"""
    Random Fourier Features (RFF) is a method for approximating kernels. The essential
    element of the RFF approach :cite:p:`rahimi2007random` is the realization that Bochner's theorem
    for stationary kernels can be approximated by a Monte Carlo sum.

    We will approximate the kernel :math:`k(\mathbf{x}, \mathbf{x}')`
    by :math:`\Phi(\mathbf{x})^\top \Phi(\mathbf{x}')` where
    :math:`\Phi: \mathbb{R}^{D} \to \mathbb{R}^{M}` is a finite-dimensional feature map.

    The feature map is defined as:

    .. math::

        \Phi(\mathbf{x}) = \sqrt{\frac{2 \sigma^2}{M}}
        \begin{bmatrix}
            \cos(\boldsymbol{\theta}_1^\top \mathbf{x} + \tau_1) \\
            \vdots \\
            \cos(\boldsymbol{\theta}_M^\top \mathbf{x} + \tau_M)
        \end{bmatrix}

    where :math:`\sigma^2` is the kernel variance and :math:`M` is the number
    of features (``output_dim``). The inputs :math:`\mathbf{x}` are divided by
    the kernel lengthscales before being projected.

    The features are parameterised by random weights:

    - :math:`\boldsymbol{\theta}_i \sim p(\boldsymbol{\theta})`
      where :math:`p(\boldsymbol{\theta})` is the spectral density of the kernel
    - :math:`\tau_i \sim \mathcal{U}(0, 2\pi)`, one phase per feature

    Equivalent to :class:`RandomFourierFeatures` by elementary trigonometric identities.
    """

    def build(self, input_shape: ShapeType) -> None:
        """
        Creates the variables of the layer.
        See `tf.keras.layers.Layer.build()
        <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#build>`_.

        :param input_shape: shape of the inputs; its last entry is taken as the
            input dimension ``D``.
        """
        input_dim = input_shape[-1]
        # One frequency vector and one phase offset per output feature.
        self._weights_build(input_dim, n_components=self.output_dim)
        self._bias_build(n_components=self.output_dim)

        super().build(input_shape)

    def _bias_build(self, n_components: int) -> None:
        """
        Create the non-trainable phase offsets ``b`` with shape ``[1, n_components]``.

        :param n_components: number of phase offsets to create.
        """
        shape = (1, n_components)
        self.b = self.add_weight(
            name="bias",
            trainable=False,
            shape=shape,
            dtype=self.dtype,
            initializer=self._bias_init,
        )

    def _bias_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
        """
        Initialiser for the phase offsets: uniform samples with ``maxval`` of 2 * pi.

        :param shape: shape of the tensor to sample.
        :param dtype: dtype of the samples.
        """
        return tf.random.uniform(shape=shape, maxval=2.0 * np.pi, dtype=dtype)

    def call(self, inputs: TensorType) -> tf.Tensor:
        """
        Evaluate the basis functions at ``inputs``.

        :param inputs: The evaluation points, a tensor with the shape ``[N, D]``.
        :return: A tensor with the shape ``[N, M]``.
        """
        output = _mapping_cosine(
            inputs,
            self.W,
            self.b,
            variance=self.kernel.variance,
            lengthscales=self.kernel.lengthscales,
            n_components=self.output_dim,
        )
        # ``tf.ensure_shape`` returns the checked tensor; the result has to be
        # used, otherwise the runtime check is not part of the returned graph.
        output = tf.ensure_shape(output, self.compute_output_shape(inputs.shape))
        return output
24 changes: 22 additions & 2 deletions gpflux/layers/basis_functions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _sample_students_t(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
return students_t_rvs


def _mapping(
def _mapping_cosine(
X: TensorType,
W: TensorType,
b: TensorType,
Expand All @@ -89,5 +89,25 @@ def _mapping(
"""
constant = tf.sqrt(2.0 * variance / n_components)
X_scaled = tf.divide(X, lengthscales) # [N, D]
bases = tf.cos(tf.matmul(X_scaled, W, transpose_b=True) + b) # [N, M]
proj = tf.matmul(X_scaled, W, transpose_b=True) # [N, M]
bases = tf.cos(proj + b) # [N, M]
return constant * bases # [N, M]


def _mapping_concat(
    X: TensorType,
    W: TensorType,
    variance: TensorType,
    lengthscales: TensorType,
    n_components: int,
) -> TensorType:
    """
    Feature map for random Fourier features (RFF) as originally prescribed
    by Rahimi & Recht, 2007 :cite:p:`rahimi2007random`.
    See also :cite:p:`sutherland2015error` for additional details.

    The inputs are divided by ``lengthscales``, projected onto the rows of
    ``W``, and the sines of the projections are concatenated with their
    cosines along the last axis (sines first). The result is scaled by
    ``sqrt(2 * variance / n_components)``.
    """
    # Project the lengthscale-normalised inputs onto the frequencies.
    projection = tf.matmul(
        tf.divide(X, lengthscales),  # [N, D]
        W,
        transpose_b=True,
    )  # [N, M // 2]
    sin_features = tf.sin(projection)  # [N, M // 2]
    cos_features = tf.cos(projection)  # [N, M // 2]
    features = tf.concat([sin_features, cos_features], axis=-1)  # [N, M]
    scale = tf.sqrt(2.0 * variance / n_components)
    return scale * features  # [N, M]
Loading

0 comments on commit e05d7ba

Please sign in to comment.