【Hackathon 7th PPSCI No.4】NO.4 DrivAerNet paper reproduction (#1047)
* Create ReadME.md
* Add files via upload
* Delete examples/fsi/DrivAerNet directory
* Create ReadME.md
* Add files via upload
* drivaernet
* Delete examples/fsi/DrivAerNet-paddle-convert-main directory
* modify some style
* modify yaml
* update example drivaernet code
* update requirments and delete fsi/DrivAerNet.py etc.
* Update drivaernet_dataset.py
* update drivaernet_dataset.py
* update drivaernet_dataset.py
* Update __init__.py
* Update __init__.py
* Update DrivAerNet.py
* Update DrivAerNet.py
* Update __init__.py
* update
* support fraction of the training data
* modify some error about version
* Update regdgcnn.py
* Delete examples/DrivAerNet/requirments.txt
* Delete docs/zh/examples/drivaernet directory
* Update DriveAerNet.yaml
* Rename DrivAerNet.md to drivaernet.md
* Update mkdocs.yml
* Update drivaernet.md
* Update DriveAerNet.yaml
* Update DrivAerNet.py
* Update DriveAerNet.yaml
* Update DrivAerNet.py
* Update DrivAerNet.py
* Update drivaernet.md
* Update DrivAerNet.py
* Update DriveAerNet.yaml
* Update drivaernet_dataset.py
* Update solver.py
* update metric.md
* Update metric.md
* Update optimizer.md
* Update arch.md
* Update DrivAerNetDataset dataset.md
* Update drivaernet_dataset.py
* Update drivaernet.md
* Update dataset.md
* Update arch.md
* Update lr_scheduler.py
* Delete docs/zh/api/arch.md
* Delete docs/zh/api/data/dataset.md
* Delete ppsci/optimizer/lr_scheduler.py
* Update optimizer.md
* Delete docs/zh/api/optimizer.md
* Create arch.md
* Create dataset.md
* Create optimizer.md
* Update lr_scheduler.md
* Create lr_scheduler.py
* Update solver.py
* Update drivaernet.md
* Update arch.md
* Update dataset.md
* Update lr_scheduler.py
* Update lr_scheduler.md
* Rename DriveAerNet.yaml to driveaernet.yaml
* Update and rename DrivAerNet.py to drivaernet.py
* Rename driveaernet.yaml to drivaernet.yaml
* Update drivaernet.md
* Update r2_score.py
* Update max_ae.py
* Update drivaernet_dataset.py
* Update drivaernet.md
* Update solver.py
* Update solver.py
* Update drivaernet.py
* Create drivaernet
* Delete examples/drivaernet
* Create drivaernet
* Delete examples/drivaernet
* Delete examples/DrivAerNet directory
* Create drivaernet.py
* Create drivaernet.yaml
* Update drivaernet.yaml
* Update drivaernet.md
* Update drivaernet.md
* Update regdgcnn.py
* Update lr_scheduler.py
* Update lr_scheduler.py
* Update solver.py
* Update lr_scheduler.py
* Update examples/drivaernet/drivaernet.py Co-authored-by: HydrogenSulfate <490868991@qq.com>
* Update examples/drivaernet/conf/drivaernet.yaml Co-authored-by: HydrogenSulfate <490868991@qq.com>
* Update examples/drivaernet/conf/drivaernet.yaml Co-authored-by: HydrogenSulfate <490868991@qq.com>
* Update examples/drivaernet/conf/drivaernet.yaml Co-authored-by: HydrogenSulfate <490868991@qq.com>
* Update examples/drivaernet/drivaernet.py Co-authored-by: HydrogenSulfate <490868991@qq.com>
* Update ppsci/arch/regdgcnn.py Co-authored-by: HydrogenSulfate <490868991@qq.com>
* Update ppsci/arch/regdgcnn.py Co-authored-by: HydrogenSulfate <490868991@qq.com>
* Update regdgcnn.py
* Update lr_scheduler.py
* Update drivaernet_dataset.py
* Update drivaernet_dataset.py
* Update solver.py
* Update ppsci/solver/solver.py
---------
Co-authored-by: HydrogenSulfate <490868991@qq.com>
1 parent: 6652bb8. Commit: 2706d22. Showing 17 changed files with 1,627 additions and 6 deletions.
arch.md
@@ -37,5 +37,6 @@
  - USCNN
  - LNO
  - TGCN
+ - RegDGCNN
  show_root_heading: true
  heading_level: 3
dataset.md
@@ -32,4 +32,5 @@
  - MOlFLOWDataset
  - CGCNNDataset
  - PEMSDataset
+ - DrivAerNetDataset
  show_root_heading: true
lr_scheduler.md
@@ -13,5 +13,6 @@
  - OneCycleLR
  - Piecewise
  - Step
+ - ReduceOnPlateau
  show_root_heading: true
  heading_level: 3
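ReduceOnPlateau is the scheduler entry added to the docs here. As a minimal, hedged sketch of the plateau-decay behaviour it documents, using Paddle's underlying paddle.optimizer.lr.ReduceOnPlateau directly with made-up validation losses (not necessarily how the ppsci wrapper is implemented internally):

import paddle

# plateau-based decay: cut the learning rate by `factor` once the monitored
# metric stops improving for more than `patience` consecutive steps
sched = paddle.optimizer.lr.ReduceOnPlateau(
    learning_rate=0.001, mode="min", factor=0.1, patience=2
)
for fake_val_loss in [0.90, 0.80, 0.81, 0.82, 0.83]:  # hypothetical metric values
    sched.step(fake_val_loss)
print(sched.get_lr())  # ~1e-4 once the plateau is detected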
metric.md
@@ -13,5 +13,7 @@
  - MeanL2Rel
  - MSE
  - RMSE
+ - MaxAE
+ - R2Score
  show_root_heading: true
  heading_level: 3
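The two newly documented metrics follow the standard definitions of maximum absolute error and the coefficient of determination. A minimal NumPy sketch of those formulas (illustrative only, not the ppsci.metric implementation; the sample values are hypothetical drag coefficients):

import numpy as np

def max_ae(y_true, y_pred):
    # maximum absolute error over all samples
    return float(np.max(np.abs(y_true - y_pred)))

def r2_score(y_true, y_pred):
    # coefficient of determination: 1 - SS_res / SS_tot
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return float(1.0 - ss_res / ss_tot)

y_true = np.array([0.28, 0.30, 0.32])
y_pred = np.array([0.27, 0.31, 0.32])
print(max_ae(y_true, y_pred), r2_score(y_true, y_pred))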
Large diffs are not rendered by default.
examples/drivaernet/conf/drivaernet.yaml
@@ -0,0 +1,76 @@
defaults:
  - ppsci_default
  - TRAIN: train_default
  - TRAIN/ema: ema_default
  - TRAIN/swa: swa_default
  - EVAL: eval_default
  - INFER: infer_default
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
  - _self_

hydra:
  run:
    dir: outputs_drivaernet/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
  job:
    name: ${mode}
    chdir: false
  callbacks:
    init_callback:
      _target_: ppsci.utils.callbacks.InitCallback
  sweep:
    dir: ${hydra.run.dir}
    subdir: ./

# general settings
mode: train
seed: 1
output_dir: ${hydra:run.dir}
log_freq: 100

# model settings
MODEL:
  input_keys: ["vertices"]
  output_keys: ["cd_value"]
  weight_keys: ["weight_keys"]
  dropout: 0.4
  emb_dims: 512
  k: 40
  output_channels: 1

# training settings
TRAIN:
  iters_per_epoch: 2776
  num_points: 5000
  epochs: 100
  num_workers: 32
  eval_during_train: True
  train_ids_file: "train_design_ids.txt"
  eval_ids_file: "val_design_ids.txt"
  batch_size: 1
  train_fractions: 1
  scheduler:
    mode: "min"
    patience: 20
    factor: 0.1
    verbose: True

# evaluation settings
EVAL:
  num_points: 5000
  batch_size: 2
  pretrained_model_path: "https://dataset.bj.bcebos.com/PaddleScience/DNNFluid-Car/DrivAer/CdPrediction_DrivAerNet_r2_100epochs_5k_best_model.pdparams"
  eval_with_no_grad: True
  ids_file: "test_design_ids.txt"
  num_workers: 8

# optimizer settings
optimizer:
  weight_decay: 0.0001
  lr: 0.001
  optimizer: 'adam'

ARGS:
  # dataset settings
  dataset_path: 'data/DrivAerNet_Processed_Point_Clouds_5k_paddle'
  aero_coeff: 'data/AeroCoefficients_DrivAerNet_FilteredCorrected.csv'
  subset_dir: 'data/subset_dir'
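This config is consumed by the Hydra entry point in examples/drivaernet/drivaernet.py, shown next. As a hedged launch sketch, assuming the DrivAerNet data has been downloaded to the paths under ARGS and the working directory is examples/drivaernet:

python drivaernet.py mode=train                     # train with conf/drivaernet.yaml
python drivaernet.py mode=eval EVAL.batch_size=4    # evaluate the released checkpoint, overriding one field from the CLI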
examples/drivaernet/drivaernet.py
@@ -0,0 +1,202 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from functools import partial

import hydra
import paddle
from omegaconf import DictConfig

import ppsci


def train(cfg: DictConfig):
    # set model
    model = ppsci.arch.RegDGCNN(
        input_keys=cfg.MODEL.input_keys,
        label_keys=cfg.MODEL.output_keys,
        weight_keys=cfg.MODEL.weight_keys,
        args=cfg.MODEL,
    )

    train_dataloader_cfg = {
        "dataset": {
            "name": "DrivAerNetDataset",
            "root_dir": cfg.ARGS.dataset_path,
            "input_keys": ("vertices",),
            "label_keys": ("cd_value",),
            "weight_keys": ("weight_keys",),
            "subset_dir": cfg.ARGS.subset_dir,
            "ids_file": cfg.TRAIN.train_ids_file,
            "csv_file": cfg.ARGS.aero_coeff,
            "num_points": cfg.TRAIN.num_points,
            "train_fractions": cfg.TRAIN.train_fractions,
            "mode": cfg.mode,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": cfg.TRAIN.num_workers,
    }

    drivaernet_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.MSELoss("mean"),
        name="DrivAerNet_constraint",
    )

    constraint = {drivaernet_constraint.name: drivaernet_constraint}

    valid_dataloader_cfg = {
        "dataset": {
            "name": "DrivAerNetDataset",
            "root_dir": cfg.ARGS.dataset_path,
            "input_keys": ("vertices",),
            "label_keys": ("cd_value",),
            "weight_keys": ("weight_keys",),
            "subset_dir": cfg.ARGS.subset_dir,
            "ids_file": cfg.TRAIN.eval_ids_file,
            "csv_file": cfg.ARGS.aero_coeff,
            "num_points": cfg.TRAIN.num_points,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": cfg.TRAIN.num_workers,
    }

    drivaernet_valid = ppsci.validate.SupervisedValidator(
        valid_dataloader_cfg,
        loss=ppsci.loss.MSELoss("mean"),
        metric={"MSE": ppsci.metric.MSE()},
        name="DrivAerNet_valid",
    )

    validator = {drivaernet_valid.name: drivaernet_valid}

    # set optimizer
    # the per-epoch iteration count is scaled by the used training fraction and
    # divided across the number of ranks and the batch size
    lr_scheduler = ppsci.optimizer.lr_scheduler.ReduceOnPlateau(
        epochs=cfg.TRAIN.epochs,
        iters_per_epoch=(
            cfg.TRAIN.iters_per_epoch
            * cfg.TRAIN.train_fractions
            // (paddle.distributed.get_world_size() * cfg.TRAIN.batch_size)
            + 1
        ),
        learning_rate=cfg.optimizer.lr,
        mode=cfg.TRAIN.scheduler.mode,
        patience=cfg.TRAIN.scheduler.patience,
        factor=cfg.TRAIN.scheduler.factor,
        verbose=cfg.TRAIN.scheduler.verbose,
    )()

    optimizer = (
        ppsci.optimizer.Adam(lr_scheduler, weight_decay=cfg.optimizer.weight_decay)(
            model
        )
        if cfg.optimizer.optimizer == "adam"
        else ppsci.optimizer.SGD(lr_scheduler, weight_decay=cfg.optimizer.weight_decay)(
            model
        )
    )

    # initialize solver
    solver = ppsci.solver.Solver(
        model=model,
        iters_per_epoch=(
            cfg.TRAIN.iters_per_epoch
            * cfg.TRAIN.train_fractions
            // (paddle.distributed.get_world_size() * cfg.TRAIN.batch_size)
            + 1
        ),
        constraint=constraint,
        output_dir=cfg.output_dir,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        epochs=cfg.TRAIN.epochs,
        validator=validator,
        eval_during_train=cfg.TRAIN.eval_during_train,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )

    # bind the current evaluation metric so ReduceOnPlateau receives it whenever
    # the solver steps the scheduler
    lr_scheduler.step = partial(lr_scheduler.step, metrics=solver.cur_metric)
    solver.lr_scheduler = lr_scheduler

    # train model
    solver.train()

    solver.eval()


def evaluate(cfg: DictConfig):
    # set model
    model = ppsci.arch.RegDGCNN(
        input_keys=cfg.MODEL.input_keys,
        label_keys=cfg.MODEL.output_keys,
        weight_keys=cfg.MODEL.weight_keys,
        args=cfg.MODEL,
    )

    valid_dataloader_cfg = {
        "dataset": {
            "name": "DrivAerNetDataset",
            "root_dir": cfg.ARGS.dataset_path,
            "input_keys": ("vertices",),
            "label_keys": ("cd_value",),
            "weight_keys": ("weight_keys",),
            "subset_dir": cfg.ARGS.subset_dir,
            "ids_file": cfg.EVAL.ids_file,
            "csv_file": cfg.ARGS.aero_coeff,
            "num_points": cfg.EVAL.num_points,
            "mode": cfg.mode,
        },
        "batch_size": cfg.EVAL.batch_size,
        "num_workers": cfg.EVAL.num_workers,
    }

    drivaernet_valid = ppsci.validate.SupervisedValidator(
        valid_dataloader_cfg,
        loss=ppsci.loss.MSELoss("mean"),
        metric={
            "MSE": ppsci.metric.MSE(),
            "MAE": ppsci.metric.MAE(),
            "Max AE": ppsci.metric.MaxAE(),
            "R²": ppsci.metric.R2Score(),
        },
        name="DrivAerNet_valid",
    )

    validator = {drivaernet_valid.name: drivaernet_valid}

    solver = ppsci.solver.Solver(
        model=model,
        validator=validator,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )

    # evaluate model
    solver.eval()


@hydra.main(version_base=None, config_path="./conf", config_name="drivaernet.yaml")
def main(cfg: DictConfig):
    warnings.filterwarnings("ignore")
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
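The iters_per_epoch value passed to the scheduler and the Solver above is derived from the config. A small worked example with the defaults from conf/drivaernet.yaml, assuming a single-card run (world size and batch size of 1):

# defaults from conf/drivaernet.yaml, single-GPU assumption
iters_per_epoch = 2776      # TRAIN.iters_per_epoch
train_fractions = 1         # TRAIN.train_fractions (fraction of training data used)
world_size = 1              # paddle.distributed.get_world_size()
batch_size = 1              # TRAIN.batch_size

steps_per_epoch = iters_per_epoch * train_fractions // (world_size * batch_size) + 1
print(steps_per_epoch)      # 2777 optimizer steps per epoch in this setting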