Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
RyotaroOKabe committed Sep 25, 2024
1 parent cd85599 commit cf1a896
Show file tree
Hide file tree
Showing 15 changed files with 69 additions and 1,176 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ To set up the environment locally and run the code:
```bash
jupyter notebook
```
Open either `VVN.ipynb` or `kMVN.ipynb`.

---

Expand Down
2 changes: 1 addition & 1 deletion cif_MVN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### load pre-trained model\n",
"### Load pre-trained model\n",
"Load the pre-trained model from the archive. Our code automatically load the hyperparameters and the learned weights from the .torch file. "
]
},
Expand Down
2 changes: 1 addition & 1 deletion cif_VVN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### load pre-trained model\n",
"### Load pre-trained model\n",
"Load the pre-trained model from the archive. Our code automatically load the hyperparameters and the learned weights from the .torch file. "
]
},
Expand Down
10 changes: 5 additions & 5 deletions cif_kMVN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
" warnings.warn(\n",
"/home/rokabe/anaconda3/envs/phonon/lib/python3.9/site-packages/seekpath/hpkot/__init__.py:362: EdgeCaseWarning: aP lattice, but the k_gamma3 angle is almost equal to 90 degrees\n",
" warnings.warn(\n",
"100%|██████████| 5/5 [00:00<00:00, 56.20it/s]\n"
"100%|██████████| 5/5 [00:00<00:00, 55.82it/s]\n"
]
}
],
Expand Down Expand Up @@ -264,7 +264,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### load pre-trained model\n",
"### Load pre-trained model\n",
"Load the pre-trained model from the archive. Our code automatically load the hyperparameters and the learned weights from the .torch file. "
]
},
Expand Down Expand Up @@ -328,7 +328,7 @@
"text": [
" 0%| | 0/5 [00:00<?, ?it/s]/data1/rokabe/phonon/phonon_prediction/utils/utils_plot.py:209: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n",
" df = pd.concat([df, df0], ignore_index=True)\n",
"100%|██████████| 5/5 [00:04<00:00, 1.07it/s]\n"
"100%|██████████| 5/5 [00:05<00:00, 1.00s/it]\n"
]
}
],
Expand All @@ -339,7 +339,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -368,7 +368,7 @@
"Text(0.5, 1.0, '[mp-149] Si₈: phonon band structure')"
]
},
"execution_count": 10,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
Expand Down
1,061 changes: 0 additions & 1,061 deletions previous_codes/kMVN_previous.ipynb

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@
},
"source": [
"## Data Preparation\n",
"### Data provenance\n",
"### Download dataset\n",
"We train our model using the database of Density Functional Perturbation Theory (DFPT)-calculated phonon bands, containing approximately 1,500 crystalline solids [[Petretto et al. 2018]](https://doi.org/10.1038/sdata.2018.65)."
]
},
Expand Down Expand Up @@ -478,8 +478,8 @@
"id": "LZtSOwB9Cols"
},
"source": [
"### Network architecture\n",
"We build a model based on the `Network` described in the `e3nn` [Documentation](https://docs.e3nn.org/en/latest/api/nn/models/gate_points_2101.html), modified to incorporate the periodic boundaries we imposed on the crystal graphs. The network applies equivariant convolutions to each atomic node and finally takes an average over all nodes, normalizing the output."
"## Network architecture\n",
"We build a model based on the `Network` described in the `e3nn` [Documentation](https://docs.e3nn.org/en/latest/api/nn/models/gate_points_2101.html), modified to incorporate our Virtual node Graph Neural Network (VGNN) to deal with variable output dimension. "
]
},
{
Expand Down Expand Up @@ -580,7 +580,7 @@
"id": "-5cuaOFOC_D_"
},
"source": [
"### Training\n",
"## Train the model\n",
"The model is trained using a mean-squared error loss function with an Adam optimizer.\n",
"Plot the prediction results (train/test data) after 'max_iter' epochs. \n",
"Here we demonstrate training only with zero (or one) epoch. If you just want to use pre-trained model for phonon prediction you can skip this tab. "
Expand Down Expand Up @@ -608,8 +608,8 @@
" batch_size,\n",
" k_fold,\n",
" option=option,\n",
" factor=factor, #!\n",
" conf_dict=None) #!"
" factor=factor, \n",
" conf_dict=None) "
]
},
{
Expand Down Expand Up @@ -651,7 +651,7 @@
"id": "5kpYlnpwEFW3"
},
"source": [
"### Results\n",
"## Results\n",
"We evaluate our model by visualizing the predicted and true phonon band in each error tertile. "
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
"print('training ratio: ', tr_ratio)\n",
"print('batch size: ', batch_size)\n",
"print('irreduceble output representation: ', irreps_out)\n",
"print('factor to multiply/divide label data: ', factor) #!"
"print('factor to multiply/divide label data: ', factor) "
]
},
{
Expand Down Expand Up @@ -221,7 +221,7 @@
},
"source": [
"## Data Preparation\n",
"### Data provenance\n",
"### Download dataset\n",
"We train our model using the database of Density Functional Perturbation Theory (DFPT)-calculated phonon bands, containing approximately 1,500 crystalline solids [[Petretto et al. 2018]](https://doi.org/10.1038/sdata.2018.65)."
]
},
Expand Down Expand Up @@ -489,8 +489,8 @@
"id": "LZtSOwB9Cols"
},
"source": [
"### Network architecture\n",
"We build a model based on the `Network` described in the `e3nn` [Documentation](https://docs.e3nn.org/en/latest/api/nn/models/gate_points_2101.html), modified to incorporate the periodic boundaries we imposed on the crystal graphs. The network applies equivariant convolutions to each atomic node and finally takes an average over all nodes, normalizing the output."
"## Network architecture\n",
"We build a model based on the `Network` described in the `e3nn` [Documentation](https://docs.e3nn.org/en/latest/api/nn/models/gate_points_2101.html), modified to incorporate our Virtual node Graph Neural Network (VGNN) to deal with variable output dimension. "
]
},
{
Expand Down Expand Up @@ -591,7 +591,7 @@
"id": "-5cuaOFOC_D_"
},
"source": [
"### Training\n",
"## Train the model\n",
"The model is trained using a mean-squared error loss function with an Adam optimizer.\n",
"Plot the prediction results (train/test data) after 'max_iter' epochs. \n",
"Here we demonstrate training only with zero (or one) epoch. If you just want to use pre-trained model for phonon prediction you can skip this tab. "
Expand Down Expand Up @@ -619,8 +619,8 @@
" batch_size,\n",
" k_fold,\n",
" option=option,\n",
" factor=factor, #!\n",
" conf_dict=None) #!"
" factor=factor, \n",
" conf_dict=None) "
]
},
{
Expand Down Expand Up @@ -662,7 +662,7 @@
"id": "5kpYlnpwEFW3"
},
"source": [
"### Results\n",
"## Results\n",
"We evaluate our model by visualizing the predicted and true phonon band in each error tertile. "
]
},
Expand Down Expand Up @@ -704,8 +704,8 @@
],
"source": [
"# Generate Data Frame\n",
"df_tr = generate_dataframe(model, tr_loader, loss_fn, device, option, factor) #!\n",
"df_te = generate_dataframe(model, te_loader, loss_fn, device, option, factor) #!"
"df_tr = generate_dataframe(model, tr_loader, loss_fn, device, option, factor) \n",
"df_te = generate_dataframe(model, te_loader, loss_fn, device, option, factor) "
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@
},
"source": [
"## Data Preparation\n",
"### Data provenance\n",
"### Download dataset\n",
"We train our model using the database of Density Functional Perturbation Theory (DFPT)-calculated phonon bands, containing approximately 1,500 crystalline solids [[Petretto et al. 2018]](https://doi.org/10.1038/sdata.2018.65)."
]
},
Expand Down Expand Up @@ -485,8 +485,8 @@
"id": "LZtSOwB9Cols"
},
"source": [
"### Network architecture\n",
"We build a model based on the `Network` described in the `e3nn` [Documentation](https://docs.e3nn.org/en/latest/api/nn/models/gate_points_2101.html), modified to incorporate the periodic boundaries we imposed on the crystal graphs. The network applies equivariant convolutions to each atomic node and finally takes an average over all nodes, normalizing the output."
"## Network architecture\n",
"We build a model based on the `Network` described in the `e3nn` [Documentation](https://docs.e3nn.org/en/latest/api/nn/models/gate_points_2101.html), modified to incorporate our Virtual node Graph Neural Network (VGNN) to deal with variable output dimension. "
]
},
{
Expand Down Expand Up @@ -587,7 +587,7 @@
"id": "-5cuaOFOC_D_"
},
"source": [
"### Training\n",
"## Train the model\n",
"The model is trained using a mean-squared error loss function with an Adam optimizer.\n",
"Plot the prediction results (train/test data) after 'max_iter' epochs. \n",
"Here we demonstrate training only with zero (or one) epoch. If you just want to use pre-trained model for phonon prediction you can skip this tab. "
Expand Down Expand Up @@ -641,7 +641,7 @@
],
"source": [
"model_name = pretrained_name if use_pretrained else run_name # model name (load from pretrained model if exists)\n",
"save_name = f'./models/{model_name}' #!\n",
"save_name = f'./models/{model_name}' \n",
"model_file = save_name + '.torch'\n",
"model.load_state_dict(torch.load(model_file)['state'])\n",
"model = model.to(device)\n",
Expand All @@ -656,7 +656,7 @@
"id": "5kpYlnpwEFW3"
},
"source": [
"### Results\n",
"## Results\n",
"We evaluate our model by visualizing the predicted and true phonon band in each error tertile. "
]
},
Expand Down
16 changes: 8 additions & 8 deletions tutorial_MVN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@
},
"source": [
"## Data Preparation\n",
"### Data provenance\n",
"### Download dataset\n",
"We train our model using the database of Density Functional Perturbation Theory (DFPT)-calculated phonon bands, containing approximately 1,500 crystalline solids [[Petretto et al. 2018]](https://doi.org/10.1038/sdata.2018.65)."
]
},
Expand Down Expand Up @@ -478,8 +478,8 @@
"id": "LZtSOwB9Cols"
},
"source": [
"### Network architecture\n",
"We build a model based on the `Network` described in the `e3nn` [Documentation](https://docs.e3nn.org/en/latest/api/nn/models/gate_points_2101.html), modified to incorporate the periodic boundaries we imposed on the crystal graphs. The network applies equivariant convolutions to each atomic node and finally takes an average over all nodes, normalizing the output."
"## Network architecture\n",
"We build a model based on the `Network` described in the `e3nn` [Documentation](https://docs.e3nn.org/en/latest/api/nn/models/gate_points_2101.html), modified to incorporate our Virtual node Graph Neural Network (VGNN) to deal with variable output dimension. "
]
},
{
Expand Down Expand Up @@ -580,7 +580,7 @@
"id": "-5cuaOFOC_D_"
},
"source": [
"### Training\n",
"## Train the model\n",
"The model is trained using a mean-squared error loss function with an Adam optimizer.\n",
"Plot the prediction results (train/test data) after 'max_iter' epochs. \n",
"Here we demonstrate training only with zero (or one) epoch. If you just want to use pre-trained model for phonon prediction you can skip this tab. "
Expand Down Expand Up @@ -608,8 +608,8 @@
" batch_size,\n",
" k_fold,\n",
" option=option,\n",
" factor=factor, #!\n",
" conf_dict=None) #!"
" factor=factor, \n",
" conf_dict=None) "
]
},
{
Expand Down Expand Up @@ -678,8 +678,8 @@
"id": "5kpYlnpwEFW3"
},
"source": [
"### Results\n",
"We evaluate our model by visualizing the predicted and true phonon band in each error tertile. "
"## Results\n",
"We evaluate our model by visualizing the predicted and true phonon in each error tertile. "
]
},
{
Expand Down
29 changes: 15 additions & 14 deletions tutorial_VVN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"from utils.utils_load import load_band_structure_data \n",
"from utils.utils_data import generate_data_dict, get_lattice_parameters\n",
"from utils.utils_model import BandLoss, GraphNetwork_VVN, train, load_model\n",
"from utils.utils_plot import plot_element_count_stack, generate_dataframe, plot_gphonons, plot_loss, compare_models #!\n",
"from utils.utils_plot import plot_element_count_stack, generate_dataframe, plot_gphonons, plot_loss, compare_models \n",
"# from utils.helpers import make_dict\n",
"from config_file import pretrained_model_dict\n",
"torch.set_default_dtype(torch.float64)\n",
Expand Down Expand Up @@ -88,14 +88,14 @@
"batch_size = 1\n",
"k_fold = 5\n",
"irreps_out = '1x0e'\n",
"factor = 1000 #!\n",
"factor = 1000 \n",
"\n",
"print('\\ndata parameters')\n",
"print('method: ', k_fold, '-fold cross validation')\n",
"print('training ratio: ', tr_ratio)\n",
"print('batch size: ', batch_size)\n",
"print('irreduceble output representation: ', irreps_out)\n",
"print('factor to multiply/divide label data: ', factor) #!"
"print('factor to multiply/divide label data: ', factor) "
]
},
{
Expand Down Expand Up @@ -221,7 +221,7 @@
},
"source": [
"## Data Preparation\n",
"### Data provenance\n",
"### Download dataset\n",
"We train our model using the database of Density Functional Perturbation Theory (DFPT)-calculated phonon bands, containing approximately 1,500 crystalline solids [[Petretto et al. 2018]](https://doi.org/10.1038/sdata.2018.65)."
]
},
Expand Down Expand Up @@ -482,8 +482,8 @@
"id": "LZtSOwB9Cols"
},
"source": [
"### Network architecture\n",
"We build a model based on the `Network` described in the `e3nn` [Documentation](https://docs.e3nn.org/en/latest/api/nn/models/gate_points_2101.html), modified to incorporate the periodic boundaries we imposed on the crystal graphs. The network applies equivariant convolutions to each atomic node and finally takes an average over all nodes, normalizing the output."
"## Network architecture\n",
"We build a model based on the `Network` described in the `e3nn` [Documentation](https://docs.e3nn.org/en/latest/api/nn/models/gate_points_2101.html), modified to incorporate our Virtual node Graph Neural Network (VGNN) to deal with variable output dimension. "
]
},
{
Expand Down Expand Up @@ -584,7 +584,7 @@
"id": "-5cuaOFOC_D_"
},
"source": [
"### Training\n",
"## Train the model\n",
"The model is trained using a mean-squared error loss function with an Adam optimizer.\n",
"Plot the prediction results (train/test data) after 'max_iter' epochs. \n",
"Here we demonstrate training only with zero (or one) epoch. If you just want to use pre-trained model for phonon prediction you can skip this tab. "
Expand Down Expand Up @@ -612,8 +612,8 @@
" batch_size,\n",
" k_fold,\n",
" option=option,\n",
" factor=factor, #!\n",
" conf_dict=None) #!"
" factor=factor, \n",
" conf_dict=None) "
]
},
{
Expand Down Expand Up @@ -642,7 +642,7 @@
"source": [
"\n",
"model_name = pretrained_name if use_pretrained else run_name # model name (load from pretrained model if exists)\n",
"save_name = f'./models/{model_name}' #!\n",
"save_name = f'./models/{model_name}' \n",
"model_file = save_name + '.torch'\n",
"model, conf_dict, history, s0 = load_model(GraphNetwork_VVN, model_file, device)\n",
"num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
Expand All @@ -653,7 +653,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot the loss function"
"### Plot the loss function\n",
"Plot the loss history (train, valid)."
]
},
{
Expand Down Expand Up @@ -684,7 +685,7 @@
},
"source": [
"### Results\n",
"We evaluate our model by visualizing the predicted and true phonon band in each error tertile. "
"We evaluate our model by visualizing the predicted and true phonon in each error tertile. "
]
},
{
Expand Down Expand Up @@ -718,8 +719,8 @@
],
"source": [
"# Generate Data Frame\n",
"df_tr = generate_dataframe(model, tr_loader, loss_fn, device, option, factor) #!\n",
"df_te = generate_dataframe(model, te_loader, loss_fn, device, option, factor) #!"
"df_tr = generate_dataframe(model, tr_loader, loss_fn, device, option, factor) \n",
"df_te = generate_dataframe(model, te_loader, loss_fn, device, option, factor) "
]
},
{
Expand Down
Loading

0 comments on commit cf1a896

Please sign in to comment.