Commit

Pull request #6: Add Variational Autoencoder
Merge in CSID/glaucus from feature/vae to main

Squashed commit of the following:

commit 607d45bc9d0f51e69da68a535bffc2998fe98062
Author: Kyle A Logue <kyle.a.logue@aero.org>
Date:   Tue Apr 16 08:55:15 2024 -0700

    Variational Autoencoder & RFLoss Improvements

    * RFLoss now has an optional `overlap` parameter for spectrogram calculation
    * Glaucus can now be instantiated as an RF U-Net by using `blockgen` with the new modes `unet-encoder` and `unet-decoder` (see the sketch after this list)
    * Add GitHub workflow
    * Increment to v1.2.0
    * Warnings eliminated
    * Eliminate noise layer from basic AE; could be re-enabled for denoising autoencoder
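
A rough sketch of the two new hooks from the bullets above, based only on this commit message; the exact signatures are assumptions, the `overlap` value is illustrative, and feeding the U-Net blocks straight into `GlaucusAE` is a guess at the intended wiring:

```python
from glaucus import GlaucusAE, RFLoss, blockgen

# RFLoss: new optional overlap argument for the spectrogram term
# (assumed to be an overlap fraction; other arguments left at defaults).
criterion = RFLoss(overlap=0.25)

# blockgen: new U-Net modes, mirroring the plain encoder/decoder examples in the README.
encoder_blocks = blockgen(steps=6, spatial_in=4096, spatial_out=16, filters_in=2, filters_out=64, mode="unet-encoder")
decoder_blocks = blockgen(steps=6, spatial_in=16, spatial_out=4096, filters_in=64, filters_out=2, mode="unet-decoder")
model = GlaucusAE(encoder_blocks, decoder_blocks, data_format="nl")
```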
Philip K Giang authored and Kyle A Logue committed Apr 16, 2024
1 parent 48910a0 commit 60b9ed5
Showing 14 changed files with 796 additions and 412 deletions.
26 changes: 26 additions & 0 deletions .github/workflows/main.yml
@@ -0,0 +1,26 @@
name: Python package

on:
  push:
  pull_request:
    types: [opened, synchronize]

jobs:
  build:
    runs-on: ubuntu-20.04
    strategy:
      matrix:
        python-version: ["3.8", "3.10", "3.12"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .
      - name: Test with pytest
        run: |
          coverage run
63 changes: 46 additions & 17 deletions README.md
@@ -23,15 +23,40 @@ processing (DSP) in PyTorch.
* `coverage run`
* `pylint glaucus tests`

### Load Variational Autoencoder Model
*New in v1.2.0*

```python
import torch

from glaucus import blockgen, GlaucusVAE

# define model
encoder_blocks = blockgen(steps=8, spatial_in=4096, spatial_out=16, filters_in=2, filters_out=64, mode="encoder")
decoder_blocks = blockgen(steps=8, spatial_in=16, spatial_out=4096, filters_in=64, filters_out=2, mode="decoder")
model = GlaucusVAE(encoder_blocks, decoder_blocks, bottleneck_in=1024, bottleneck_out=1024, data_format='nl')
# get weights
state_dict = torch.hub.load_state_dict_from_url(
    'https://github.com/the-aerospace-corporation/glaucus/releases/download/v1.2.0/gvae-1920-2b2478a0.pth',
    map_location='cpu')
model.load_state_dict(state_dict)
model.freeze()
model.eval()
# example usage
x_tensor = torch.randn(7, 4096, dtype=torch.complex64)
y_tensor, y_encoded, _, _ = model(x_tensor)
```
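
The forward pass returns four tensors; the example keeps the reconstruction and the latent encoding and discards the remaining two, which in a standard VAE would be the latent mean and log-variance used for the KL term (an assumption here, since the diff does not show the model's return signature).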

### Use pre-trained model with SigMF data

Load the quantized model and return the compressed signal vector and reconstruction.
Our weights were trained and evaluated on a 200 GB corpus of RF waveforms; augmenting
it with various RF impairments yields an effective 1 PB training set.

```python
import torch
import sigmf
import torch

from glaucus import GlaucusAE

# create model
@@ -41,7 +66,7 @@ model = torch.quantization.prepare(model)
state_dict = torch.hub.load_state_dict_from_url(
    'https://github.com/the-aerospace-corporation/glaucus/releases/download/v1.1.0/glaucus-512-3275-5517642b.pth',
    map_location='cpu')
model.load_state_dict(state_dict)
model.load_state_dict(state_dict, strict=False)
# prepare for prediction
model.freeze()
model.eval()
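# The rest of this example is collapsed in the diff. What follows is a minimal
# inference sketch, not the README's own code: it assumes the sigmf package's
# sigmffile.fromfile / read_samples helpers, an illustrative recording path,
# and that the forward pass returns (reconstruction, encoding), as the later
# torch.int_repr(y_encoded) context line suggests.
import numpy as np
from sigmf import sigmffile

signal = sigmffile.fromfile('example.sigmf-meta')  # load SigMF metadata and samples
samples = signal.read_samples()                    # complex IQ ndarray
# chop into the 4096-sample windows the model expects and batch them
n_windows = len(samples) // 4096
x_tensor = torch.from_numpy(samples[:n_windows * 4096].astype(np.complex64)).reshape(n_windows, 4096)
y_tensor, y_encoded = model(x_tensor)              # reconstruction, latent encoding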
@@ -60,7 +85,8 @@ y_encoded_uint8 = torch.int_repr(y_encoded)
```python
# define architecture
import torch
from glaucus import blockgen, GlaucusAE

from glaucus import GlaucusAE, blockgen

encoder_blocks = blockgen(steps=6, spatial_in=4096, spatial_out=16, filters_in=2, filters_out=64, mode='encoder')
decoder_blocks = blockgen(steps=6, spatial_in=16, spatial_out=4096, filters_in=64, filters_out=2, mode='decoder')
@@ -71,7 +97,7 @@ model = torch.quantization.prepare(model)
state_dict = torch.hub.load_state_dict_from_url(
    'https://github.com/the-aerospace-corporation/glaucus/releases/download/v1.1.0/glaucus-1024-761-c49063fd.pth',
    map_location='cpu')
model.load_state_dict(state_dict)
model.load_state_dict(state_dict, strict=False)
# see above for rest
```

@@ -80,6 +106,7 @@ model.load_state_dict(state_dict)
```python
# create model, but skip quantization
from glaucus.utils import adapt_glaucus_quantized_weights

model = GlaucusAE(bottleneck_quantize=False, data_format='nl')
state_dict = torch.hub.load_state_dict_from_url(
    'https://github.com/the-aerospace-corporation/glaucus/releases/download/v1.1.0/glaucus-512-3275-5517642b.pth',
@@ -98,6 +125,7 @@ model.eval()
```python
import numpy as np
import torch

import glaucus

# create criterion
@@ -117,7 +145,8 @@ loss(xxx, yyy)
*Partially implemented; pending an update or a replacement with a notebook example.*

```python
import lightning as pl
import lightning as L

from glaucus import GlaucusAE

model = GlaucusAE(data_format='nl')
@@ -130,7 +159,7 @@ train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    (len(signal_data)*np.array([0.8, 0.1, 0.1])).astype(int),
    generator=torch.Generator().manual_seed(0xcab005e)
)
class RFDataModule(pl.LightningDataModule):
class RFDataModule(L.LightningDataModule):
    '''
    defines the dataloaders for train, val, test and uses datasets
    '''
@@ -150,26 +179,26 @@ class RFDataModule(pl.LightningDataModule):
    def test_dataloader(self):
        return DataLoader(self.test_dataset, num_workers=self.num_workers, batch_size=self.batch_size, shuffle=False, pin_memory=True)

loader = RFDataModule(
datamodule = RFDataModule(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
    batch_size=batch_size, num_workers=num_workers)

trainer = pl.Trainer()
trainer.fit(model, loader)

# rewind to best checkpoint
model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path, strict=False)
trainer = L.Trainer()
trainer.fit(model, datamodule=datamodule)
# test with best checkpoint
trainer.test(model, datamodule=datamodule, ckpt_path="best")
```
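
The snippet above leaves `signal_data`, `batch_size`, `num_workers`, and the `DataLoader` import undefined (they live in collapsed parts of the README); a minimal stand-in, assuming random complex64 tensors of length 4096 in place of real RF captures, could look like:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size, num_workers = 32, 4
# 10k random complex baseband windows of length 4096 as a stand-in corpus,
# wrapped in a TensorDataset so len() and random_split work as shown above.
signal_data = TensorDataset(torch.randn(10_000, 4096, dtype=torch.complex64))
```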

## Pre-trained Model List

| desc | link | size (MB) | params (M) | multiadds (M) | provenance |
|----------|--------------------------------------------------------------------------------------------------------------------------------------------------------|-----------|------------|---------------|---------------------------------------------------------------|
| small | [glaucus-512-3275-5517642b](https://github.com/the-aerospace-corporation/glaucus/releases/download/v1.1.0/glaucus-512-3275-5517642b.pth) | 8.5 | 2.030 | 259 | .009 pfs-days on modulation-only Aerospace DSet |
| accurate | [glaucus-1024-761-c49063fd](https://github.com/the-aerospace-corporation/glaucus/releases/download/v1.1.0/glaucus-1024-761-c49063fd.pth) | 11 | 2.873 | 380 | .035 pfs-days modulation & general waveform Aerospace Dset |
| sig53 | [glaucus-1024-sig53TLe37-2956bcb6](https://github.com/the-aerospace-corporation/glaucus/releases/download/v1.1.3/glaucus-1024-sig53TLe37-2956bcb6.pth) | 11 | 2.873 | 380 | transfer learning from glaucus-1024-761-c49063fd w/Sig53 Dset |
| model weights | desc | published | mem (MB) | params (M) | multiadds (M) | provenance |
|--------------------------------------------------------------------------------------------------------------------------------------------------------|--------------|------------|----------|------------|---------------|--------------------------------------------------------------------------------------------------------------------------------|
| [glaucus-512-3275-5517642b](https://github.com/the-aerospace-corporation/glaucus/releases/download/v1.1.0/glaucus-512-3275-5517642b.pth) | AE small | 2023-03-02 | 17.9 | 2.030 | 259 | .009 pfs-days on modulation-only Aerospace Dset |
| [glaucus-1024-761-c49063fd](https://github.com/the-aerospace-corporation/glaucus/releases/download/v1.1.0/glaucus-1024-761-c49063fd.pth) | AE accurate | 2023-03-02 | 19.9 | 2.873 | 380 | .035 pfs-days modulation & general waveform Aerospace Dset |
| [glaucus-1024-sig53TLe37-2956bcb6](https://github.com/the-aerospace-corporation/glaucus/releases/download/v1.1.3/glaucus-1024-sig53TLe37-2956bcb6.pth) | AE for Sig53 | 2023-05-16 | 19.9 | 2.873 | 380 | transfer learning from glaucus-1024-761-c49063fd w/Sig53 Dset |
| [gvae-1920-2b2478a0.pth](https://github.com/the-aerospace-corporation/glaucus/releases/download/v1.2.0/gvae-1920-2b2478a0.pth) | VAE | 2024-03-25 | 21.6 | 3.440 | 263 | Variational Autoencoder with progressive resampling and a better defined latent space. .006 pfs-days on general waveform Dset. |

### Note on pfs-days
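
For a quick sense of the unit (assuming the usual petaflop/s-day convention of 10^15 floating-point operations per second sustained for one day), a back-of-the-envelope estimate looks like:

```python
# rough pfs-days estimate for a training run (all numbers illustrative)
sustained_tflops = 10   # average achieved throughput in TFLOP/s
training_hours = 20     # wall-clock training time
total_ops = sustained_tflops * 1e12 * training_hours * 3600
pfs_days = total_ops / (1e15 * 86400)
print(f"{pfs_days:.3f} pfs-days")  # about 0.008 for these numbers
```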

10 changes: 5 additions & 5 deletions glaucus/__init__.py
@@ -2,11 +2,11 @@
# This file is a part of Glaucus
# SPDX-License-Identifier: LGPL-3.0-or-later

__version__ = '1.1.4'
__version__ = '1.2.0'

from .rfloss import *
from .layers import *
from .gblocks import *
from .fcblocks import *
from .autoencoders import *
from .fcblocks import *
from .gblocks import *
from .layers import *
from .rfloss import *
from .utils import *
