From ce97f960f1ae921c5a7bbf9ddd63bbe0bd55b25a Mon Sep 17 00:00:00 2001 From: Mehmet Suzen Date: Sat, 17 Sep 2022 15:47:47 +0200 Subject: [PATCH] vaniall cpse resolved #34 --- README.md | 18 ++++++++++++++++++ bristol/cPSE.py | 15 +++++++++++++++ bristol/version.py | 2 +- 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index edac3d6..0d3a8b8 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,24 @@ pip install -upgrade git+https://github.com/msuzen/bristol.git ## Documentation ### Complexity of a deep learning model: cPSE +### Vanilla case + +In the vanilla case a list of matrices that are representative of +ordered set of weight matrices can be used to compute cPSE over +layers. As an examples: + +```python +from bristol import cPSE +import numpy as np +np.random.seed(42) +matrices = [np.random.normal(size=(64,64)) for _ in range(10)] +(d_layers, cpse) = cPSE.cpse_measure_vanilla(matrices) +``` +Even for set of Gaussian matrices, d_layers decrease. +Note that different layer types should be converted to a matrix +format, i.e., CNNs to 2D matrices. See the main paper. + +### For torch models You need to put your model as pretrained model format of PyTorch. An example for vgg, and use `cPSE.cpse_measure` function simply: diff --git a/bristol/cPSE.py b/bristol/cPSE.py index 2e8a9c1..0760236 100644 --- a/bristol/cPSE.py +++ b/bristol/cPSE.py @@ -161,4 +161,19 @@ def cpse_measure(pmodel): d_layers = d_layers_pse(eset_per) return d_layers, np.mean(np.log10(d_layers)) +def cpse_measure_vanilla(matrices): + """ + + Given list of weight matrices. + pse on layers and mean log pse Cascading PSE + (d_layers, cpse) : d_layers vector and real number cpse + + np.random.seed(42) + matrices = [np.random.normal(size=(64,64)) for _ in range(10)] + (d_layers, cpse) = cPSE.cpse_measure_vanilla(matrices) + """ + eset = get_eigenvals_layer_matrix_set(matrices) + eset_per = eigenvals_set_to_periodic(eset) + d_layers = d_layers_pse(eset_per) + return d_layers, np.mean(np.log10(d_layers)) diff --git a/bristol/version.py b/bristol/version.py index 0caa809..0652ff5 100644 --- a/bristol/version.py +++ b/bristol/version.py @@ -1 +1 @@ -__version__='0.2.11' +__version__='0.2.12'