Commit b22941b (1 parent: e7db83a)
Showing 9 changed files with 140 additions and 3 deletions.
@@ -0,0 +1,5 @@
{
    "files.associations": {
        "complex": "cpp"
    }
}
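VS Code's files.associations setting maps filename patterns to language IDs; here any file named complex, such as the extensionless C++ standard header included by the Fourier layer below, is treated as C++ for highlighting.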
@@ -0,0 +1 @@
// Optimized implementation for deep recurrent neural networks
@@ -0,0 +1 @@
// Implementation for Patch Embedding
@@ -0,0 +1 @@
// Optimized implementation of the Fourier layer
@@ -0,0 +1 @@
// Optimized implementation of Gram Schmidt Networks
@@ -0,0 +1 @@
// Optimized partial implementation
@@ -1 +1,13 @@
// Implementation for Patch Embedding
#include <torch/extension.h>

namespace F = torch::nn::functional;

at::Tensor patch_embed_2d () {
    // TODO: not yet implemented; placeholder return so the stub compiles
    return at::Tensor();
}

at::Tensor patch_embed_3d () {
    // TODO: not yet implemented; placeholder return so the stub compiles
    return at::Tensor();
}
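Since both entry points are still stubs, here is a minimal sketch of the usual patch-embedding computation for the 2-D case, a convolution whose kernel and stride both equal the patch size; the function name and shapes are assumptions, not part of the commit:

// Hypothetical sketch: patch embedding as a strided convolution.
at::Tensor patch_embed_2d_sketch(const at::Tensor& input,   // (B, C, H, W)
                                 const at::Tensor& weight,  // (D, C, P, P)
                                 int64_t patch_size) {
    // Kernel == stride tiles the image into non-overlapping patches,
    // each projected to a D-dimensional embedding
    at::Tensor out = F::conv2d(input, weight,
                               F::Conv2dFuncOptions().stride(patch_size)); // (B, D, H/P, W/P)
    return out.flatten(2).transpose(1, 2);                                 // (B, N, D), N = HW/P^2
}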
@@ -1 +1,90 @@
// Optimized implementation of the Fourier layer
#include <torch/extension.h>
#include <complex>
#include <iostream>
#include <vector>

namespace F = torch::nn::functional;

// Defining the derivatives for the backward pass

// Defining the Backward pass
std::vector<torch::Tensor> fourier_conv_backward(
    const torch::Tensor& grad_output,
    const torch::Tensor& input,
    const torch::Tensor& weight,
    const torch::Tensor& bias,
    const bool& pre_fft,
    const bool& post_ifft
) {
    torch::Tensor grad_input, grad_weight, grad_bias, grad_output_fft;

    if (pre_fft) {
        // Compute the derivative of the fourier transform
    }
    // Compute the default derivative

    if (post_ifft) {
        // Compute the derivative of the inverse fourier transform
    }
    // Placeholder return so the stub compiles; the gradients above are still TODO
    return {grad_input, grad_weight, grad_bias};
}

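// Note on the missing derivatives (an assumption, not part of the commit):
// fftn is linear, so with torch's default "backward" normalization the
// vector-Jacobian product of y = fftn(x) is grad_x = n * ifftn(grad_y),
// where n is the number of transformed elements; dually, for y = ifftn(x)
// it is grad_x = fftn(grad_y) / n (up to torch's conjugation convention
// for complex tensors).
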
// Defining the Forward pass
torch::Tensor fourier_conv_forward (
    at::Tensor& input,
    at::Tensor& weight,
    at::Tensor& bias,
    const bool& pre_fft,
    const bool& post_ifft
) {
    at::Tensor out;
    // Assertions
    TORCH_CHECK(input.dim() >= 3 && input.dim() <= 5, "Input must be at least 3-dimensional (min 1d conv: (batch, channel, size)) and at most 5-dimensional (max 3d conv: (batch, channel, frame, height, width))");
    TORCH_CHECK(input.dim() - 2 == weight.dim(), "Weight dimension must equal the number of spatial dimensions of the input");
    TORCH_CHECK(bias.dim() == 1, "Bias should be 1-dimensional, don't reshape.");
    TORCH_INTERNAL_ASSERT(input.device().type() == at::DeviceType::CPU);
    TORCH_INTERNAL_ASSERT(weight.device().type() == at::DeviceType::CPU);
    TORCH_INTERNAL_ASSERT(bias.device().type() == at::DeviceType::CPU);

    // Declare the transform dimensions at function scope so the
    // post_ifft branch below can reuse them
    std::vector<int64_t> dim;
    if (pre_fft || post_ifft) {
        // Defining the dimensions for the fft and ifft
        switch (input.dim() - 2) {
            case 1: dim = {-1}; break;
            case 2: dim = {-1, -2}; break;
            case 3: dim = {-1, -2, -3}; break;
            default: TORCH_CHECK(false, "Unsupported input dimension"); break;
        }

        if (pre_fft) {
            // To Fourier space
            input = torch::fft::fftn(input, {}, dim);
            weight = torch::fft::fftn(weight, {}, dim);
            bias = torch::fft::fftn(bias, {}, std::vector<int64_t>{-1});
        }
    }

    // Convolution in Fourier space is an element-wise product
    out = input * weight;

    // Add the bias term, broadcast over the batch and spatial dimensions
    if (bias.defined()) {
        std::vector<int64_t> bias_shape(out.dim(), 1);
        bias_shape[1] = -1;  // channel dimension
        out += bias.view(bias_shape);
    }

    if (post_ifft) {
        // Compute the ifft
        out = torch::fft::ifftn(out, {}, dim);
    }
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fourier_conv_forward", &fourier_conv_forward, "Fourier Convolutions Forward");
    m.def("fourier_conv_backward", &fourier_conv_backward, "Fourier Convolutions Backward");
}
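A minimal usage sketch for the forward pass above; the tensor shapes are assumptions for the 2-D case, not taken from the commit:

// Hypothetical driver, compiled against the same translation unit.
int main() {
    at::Tensor input  = torch::randn({2, 4, 32, 32});  // (batch, channel, height, width)
    at::Tensor weight = torch::randn({32, 32});         // one 2-D spectral filter, broadcast over channels
    at::Tensor bias   = torch::randn({4});              // one bias per channel
    at::Tensor out = fourier_conv_forward(input, weight, bias,
                                          /*pre_fft=*/true, /*post_ifft=*/true);
    std::cout << out.sizes() << std::endl;              // [2, 4, 32, 32], complex-valued
}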
@@ -1 +1,27 @@
// Optimized implementation of Gram Schmidt Networks
#include <torch/extension.h>
#include <cmath>

at::Tensor gsn() {
    // TODO: not yet implemented; placeholder return so the stub compiles
    return at::Tensor();
}

// Parameter type assumed: the layer matrices to orthogonalize
at::Tensor gram_schmidt (const std::vector<at::Tensor>& layers) {
    // TODO: not yet implemented; placeholder return so the stub compiles
    return at::Tensor();
}

// Defining the Frobenius inner product for the space of linear transformations
// (assumes real-valued layers; take the real part for complex ones)
float inner (const at::Tensor& L_1, const at::Tensor& L_2) {
    return torch::trace(torch::matmul(L_1.transpose(-1, -2).conj(), L_2)).item<float>();
}

// Defining the norm induced by the inner product
float norm (const at::Tensor& L) {
    return std::sqrt(inner(L, L));
}

// Defining the projection operation proj_u v
at::Tensor proj(const at::Tensor& u, const at::Tensor& v) {
    return (inner(v, u) / inner(u, u)) * u;
}
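A sketch of how the gram_schmidt stub could be completed with the helpers above, assuming it orthogonalizes a list of layer matrices under the Frobenius inner product; this is an illustration, not the committed implementation:

// Hypothetical completion of the classical Gram-Schmidt loop.
std::vector<at::Tensor> gram_schmidt_sketch(const std::vector<at::Tensor>& layers) {
    std::vector<at::Tensor> basis;
    for (const at::Tensor& v : layers) {
        at::Tensor u = v.clone();
        // Remove the components along every previously orthogonalized layer
        for (const at::Tensor& b : basis) {
            u = u - proj(b, u);
        }
        // Normalize under the induced norm
        basis.push_back(u / norm(u));
    }
    return basis;
}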