[Example] Add preformer for precipitation nowcasting #976
Open: EricKing19 wants to merge 8 commits into PaddlePaddle:develop from EricKing19:dev_model
+1,054 −0
Commits:
b7e0216 merge code of upstream (EricKing19)
9092f86 Update preformer.md (EricKing19)
038b3fc Update preformer.md (EricKing19)
52adb6a Update train.yaml (EricKing19)
c110ae4 Update and rename train.py to main.py (EricKing19)
88eda6a Update era5sq_dataset.py (EricKing19)
b99da76 Update era5sq_dataset.py (EricKing19)
4511659 Update preformer.md (EricKing19)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
# Preformer | ||
|
||
Before training or evaluation, download the dataset files. | ||
|
||
Before evaluation, download a pretrained model or train one yourself. | ||
|
||
Review comment on lines +3 to +6: Could you briefly describe how to prepare the dataset, for example how to download it and how the files are organized after extraction? |
||
=== "Training command" | ||
|
||
``` sh | ||
python main.py | ||
``` | ||
|
||
=== "Evaluation command" | ||
|
||
``` sh | ||
python main.py mode=eval | ||
``` | ||
|
||
## 1. Background | ||
|
||
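Precipitation is a weather phenomenon closely tied to human production and daily life. Accurate precipitation nowcasting provides key technical support for public services such as agricultural management, transportation planning, and disaster prevention, and it is also a challenging research task. In recent years, deep learning has achieved major breakthroughs in weather forecasting. Studying deep-learning-based precipitation nowcasting methods on multimodal three-dimensional (altitude, longitude, and latitude) meteorological data therefore has significant theoretical value and broad application prospects. | ||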
|
||
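Preformer is a spatiotemporal Transformer network for precipitation nowcasting. It consists of an encoder, an evolver, and a decoder. The encoder encodes spatial features by exploring the dependencies between embeddings. The evolver learns global temporal dynamics from the rearranged embeddings. Finally, the decoder decodes the spatiotemporal representation into future precipitation. | ||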
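As a rough illustration of what "rearranged embeddings" could mean, the sketch below regroups per-frame patch embeddings so that each patch position carries its own time series, which is the layout a temporal module would attend over. This is an assumption about the general idea, not the code in `ppsci/arch/preformer.py`; the function name `rearrange_for_time` and the toy sizes are hypothetical.

```python
from typing import List

Vector = List[float]


def rearrange_for_time(frames: List[List[Vector]]) -> List[List[Vector]]:
    """Regroup [T][N] embeddings (frame-major) into [N][T] (patch-major)."""
    num_patches = len(frames[0])
    if any(len(frame) != num_patches for frame in frames):
        raise ValueError("every frame must hold the same number of patches")
    # zip(*frames) walks the same patch index across all frames.
    return [list(series) for series in zip(*frames)]


# Toy input (hypothetical sizes): T=3 frames, N=2 patches, embedding size 1.
frames = [[[float(10 * t + n)] for n in range(2)] for t in range(3)]
by_patch = rearrange_for_time(frames)
print(by_patch)
```

In the frame-major layout, attention within one frame mixes spatial positions; in the patch-major layout, attention within one patch mixes time steps.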
|
||
|
||
## 2. Model Principles | ||
|
||
This section briefly introduces how Preformer works. | ||
|
||
### 2.1 Encoder | ||
|
||
This module uses two Transformer layers to extract spatial features and update the node features: | ||
|
||
``` py linenums="8" title="ppsci/arch/preformer.py" | ||
--8<-- | ||
ppsci/arch/preformer.py:194:217 | ||
--8<-- | ||
``` | ||
|
||
### 2.2 Evolver | ||
|
||
This module uses two Transformer layers to learn global temporal dynamics: | ||
|
||
``` py linenums="29" title="ppsci/arch/preformer.py" | ||
--8<-- | ||
ppsci/arch/preformer.py:220:254 | ||
--8<-- | ||
``` | ||
|
||
### 2.3 Decoder | ||
|
||
This module uses two convolutional layers to decode the spatiotemporal representation into future precipitation: | ||
|
||
``` py linenums="29" title="ppsci/arch/preformer.py" | ||
--8<-- | ||
ppsci/arch/preformer.py:257:273 | ||
--8<-- | ||
``` | ||
|
||
### 2.4 Preformer Model Structure | ||
|
||
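The Preformer model first uses a feature embedding layer to encode the spatial features of the input signal (the meteorological variables of the past few hours): | ||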
|
||
``` py linenums="73" title="ppsci/arch/preformer.py" | ||
--8<-- | ||
ppsci/arch/preformer.py:293:293 | ||
--8<-- | ||
``` | ||
|
||
``` py linenums="94" title="ppsci/arch/preformer.py" | ||
--8<-- | ||
ppsci/arch/preformer.py:194:217 | ||
--8<-- | ||
``` | ||
|
||
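The model then uses the evolver to learn the dynamics of the spatial features and predict the meteorological features of the next few hours: | ||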
|
||
``` py linenums="75" title="ppsci/arch/preformer.py" | ||
--8<-- | ||
ppsci/arch/preformer.py:310:313 | ||
--8<-- | ||
``` | ||
|
||
``` py linenums="96" title="ppsci/arch/preformer.py" | ||
--8<-- | ||
ppsci/arch/preformer.py:220:254 | ||
--8<-- | ||
``` | ||
|
||
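Finally, the model combines the spatiotemporal dynamics with the initial low-level meteorological features and uses two convolutional layers to predict the short-term future precipitation intensity: | ||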
|
||
``` py linenums="112" title="ppsci/arch/preformer.py" | ||
--8<-- | ||
ppsci/arch/preformer.py:83:86 | ||
--8<-- | ||
``` | ||
|
||
``` py linenums="35" title="ppsci/arch/preformer.py" | ||
--8<-- | ||
ppsci/arch/preformer.py:112:116 | ||
--8<-- | ||
``` | ||
|
||
## 3. Model Training | ||
|
||
### 3.1 Dataset | ||
|
||
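This example uses the preprocessed ERA5SQ dataset, a subset of the ERA5 reanalysis data. ERA5SQ contains multiple atmospheric, land, and ocean variables with global coverage at a resolution of 31 km. The dataset spans 1979 to 2018 and provides hourly estimates of weather conditions, which makes it well suited to tasks such as precipitation prediction and total water vapor analysis. | ||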
|
||
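The dataset is stored as a T x H x W array that records the rainfall at each location and time, where T is the length of the time series and H and W are the height and width of the latitude-longitude grid. The data is split by year into training, validation, and test sets at a ratio of 7:2:1. The mean and standard deviation of the rainfall data are precomputed and used for normalization in later steps. | ||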
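The split and normalization described above can be sketched as follows. This is a minimal illustration, not the dataset code from this PR: it assumes the years are assigned chronologically (the text only says the split is by year), and the helper names `split_years` and `normalize` are hypothetical.

```python
from typing import Dict, List, Sequence


def split_years(years: Sequence[int], ratios=(7, 2, 1)) -> Dict[str, List[int]]:
    """Split an ordered sequence of years into train/valid/test by ratio."""
    total = sum(ratios)
    n_train = len(years) * ratios[0] // total
    n_valid = len(years) * ratios[1] // total
    return {
        "train": list(years[:n_train]),
        "valid": list(years[n_train:n_train + n_valid]),
        "test": list(years[n_train + n_valid:]),
    }


def normalize(values: Sequence[float], mean: float, std: float) -> List[float]:
    """Standardize values with a precomputed mean and standard deviation."""
    if std <= 0:
        raise ValueError("std must be positive")
    return [(v - mean) / std for v in values]


splits = split_years(range(1979, 2019))  # 1979..2018 inclusive, 40 years
print({name: (ys[0], ys[-1], len(ys)) for name, ys in splits.items()})
print(normalize([0.0, 2.0, 4.0], mean=2.0, std=2.0))
```

With 40 years and a 7:2:1 ratio, this assignment yields 28 training years, 8 validation years, and 4 test years.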
|
||
### 3.2 Model Training | ||
|
||
#### 3.2.1 Model Construction | ||
|
||
This example is built on the Preformer model, expressed in PaddleScience code as follows: | ||
|
||
``` py linenums="79" title="examples/preformer/train.py" | ||
--8<-- | ||
examples/preformer/train.py:133:133 | ||
--8<-- | ||
``` | ||
|
||
#### 3.2.2 Constraint Construction | ||
|
||
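This example solves the problem with a data-driven approach, so it uses PaddleScience's built-in `SupervisedConstraint` to build a supervised constraint. Before defining the constraint, the parameters used for data loading must be specified. | ||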
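A supervised constraint consumes (input, label) pairs. One common way to obtain such pairs from a time series is a sliding window, sketched below. This is an assumption about how samples might be formed, not the dataset class in this PR: `make_windows` is a hypothetical helper, `in_len=6` mirrors `SQ_LEN: 6` from the config, and `out_len` and `num_steps` are made-up values.

```python
from typing import List, Tuple


def make_windows(num_steps: int, in_len: int, out_len: int) -> List[Tuple[range, range]]:
    """Return (input_indices, label_indices) pairs for a series of num_steps frames."""
    windows = []
    last_start = num_steps - in_len - out_len
    for start in range(last_start + 1):
        split = start + in_len
        # Inputs are in_len consecutive frames; labels are the out_len frames after them.
        windows.append((range(start, split), range(split, split + out_len)))
    return windows


# in_len=6 follows SQ_LEN in the config; out_len=6 and num_steps=14 are illustrative.
pairs = make_windows(num_steps=14, in_len=6, out_len=6)
for inputs, labels in pairs:
    print(list(inputs), "->", list(labels))
```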
|
||
The training data loading code is as follows: | ||
|
||
``` py linenums="20" title="examples/preformer/train.py" | ||
--8<-- | ||
examples/preformer/train.py:44:79 | ||
--8<-- | ||
``` | ||
|
||
The supervised constraint is defined as follows: | ||
|
||
``` py linenums="40" title="examples/preformer/train.py" | ||
--8<-- | ||
examples/preformer/train.py:81:86 | ||
--8<-- | ||
``` | ||
|
||
#### 3.2.3 Validator Construction | ||
|
||
During training, the current model is evaluated on the validation set at a fixed interval of epochs, so a `SupervisedValidator` is needed to build the validator. | ||
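To make the evaluation interval concrete, the sketch below lists the epochs at which a periodic evaluation would run for the values in this PR's configuration (`epochs: 150`, `eval_freq: 20`). It assumes evaluation fires whenever the 1-based epoch number is divisible by `eval_freq`; the exact rule applied by the solver is not shown in this PR, and `eval_epochs` is a hypothetical helper.

```python
def eval_epochs(total_epochs: int, eval_freq: int) -> list:
    """Epochs (1-based) at which a periodic evaluation would run."""
    if eval_freq <= 0:
        raise ValueError("eval_freq must be positive")
    return [e for e in range(1, total_epochs + 1) if e % eval_freq == 0]


schedule = eval_epochs(total_epochs=150, eval_freq=20)
print(schedule)
```

Under this assumption the final epoch (150) is not evaluated, because 150 is not a multiple of 20.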
|
||
The validation data loading code is as follows: | ||
|
||
``` py linenums="44" title="examples/preformer/train.py" | ||
--8<-- | ||
examples/preformer/train.py:177:191 | ||
--8<-- | ||
``` | ||
|
||
The supervised validator is defined as follows: | ||
|
||
``` py linenums="65" title="examples/preformer/train.py" | ||
--8<-- | ||
examples/preformer/train.py:195:203 | ||
--8<-- | ||
``` | ||
|
||
#### 3.2.4 Learning Rate and Optimizer | ||
|
||
This example sets the learning rate to `1e-3` and uses the `Adam` optimizer, expressed in PaddleScience code as follows: | ||
|
||
``` py linenums="83" title="examples/preformer/train.py" | ||
--8<-- | ||
examples/preformer/train.py:136:140 | ||
--8<-- | ||
``` | ||
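For readers unfamiliar with Adam, the sketch below applies the standard Adam update rule to a single scalar parameter with the learning rate `1e-3` used here. It is a plain-Python illustration of the algorithm, not PaddleScience or Paddle code; `beta1`, `beta2`, and `eps` are the commonly used defaults and are assumptions, as is the toy objective.

```python
import math


def adam_step(param, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter; returns (param, m, v)."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - beta1 ** step)             # bias correction
    v_hat = v / (1 - beta2 ** step)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# Toy objective: minimize f(w) = w**2 starting from w = 1.0; the gradient is 2 * w.
w, m, v = 1.0, 0.0, 0.0
for step in range(1, 4):
    w, m, v = adam_step(w, 2 * w, m, v, step)
    print(step, round(w, 6))
```

Because Adam rescales the gradient by its running magnitude, each early step moves the parameter by roughly the learning rate, regardless of the gradient's scale.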
|
||
#### 3.2.5 Model Training | ||
|
||
With the above in place, pass the instantiated objects to `ppsci.solver.Solver` in order and start training. | ||
|
||
``` py linenums="88" title="examples/preformer/train.py" | ||
--8<-- | ||
examples/preformer/train.py:143:156 | ||
--8<-- | ||
``` | ||
|
||
## 4. Complete Code | ||
|
||
``` py linenums="1" title="examples/preformer/main.py" | ||
--8<-- | ||
examples/preformer/main.py | ||
--8<-- |
``` | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
defaults: | ||
  - ppsci_default | ||
  - TRAIN: train_default | ||
  - TRAIN/ema: ema_default | ||
  - TRAIN/swa: swa_default | ||
  - EVAL: eval_default | ||
  - INFER: infer_default | ||
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | ||
  - _self_ | ||
|
||
hydra: | ||
  run: | ||
    # fixed output directory for this example | ||
    dir: outputs_preformer | ||
  job: | ||
    name: ${mode} # name of logfile | ||
    chdir: false # keep current working directory unchanged | ||
  sweep: | ||
    # output directory for multirun | ||
    dir: ${hydra.run.dir} | ||
    subdir: ./ | ||
|
||
# general settings | ||
mode: train # running mode: train/eval | ||
seed: 1024 | ||
output_dir: ${hydra:run.dir} | ||
log_freq: 20 | ||
|
||
# set training hyper-parameters | ||
SQ_LEN: 6 | ||
IMG_H: 192 | ||
IMG_W: 256 | ||
USE_SAMPLED_DATA: false | ||
|
||
# set train data path | ||
TRAIN_FILE_PATH: /path/to/ERA5/ | ||
DATA_MEAN_PATH: examples/weather/datasets/era5/stat/mean.nc | ||
DATA_STD_PATH: examples/weather/datasets/era5/stat/std.nc | ||
|
||
# set evaluate data path | ||
VALID_FILE_PATH: /path/to/ERA5/ | ||
|
||
# model settings | ||
MODEL: | ||
  input_keys: ["input"] | ||
  output_keys: ["output"] | ||
  shape_in: | ||
    - 6 | ||
    - 12 | ||
    - ${IMG_H} | ||
    - ${IMG_W} | ||
|
||
# training settings | ||
TRAIN: | ||
  epochs: 150 | ||
  save_freq: 20 | ||
  eval_during_train: true | ||
  eval_freq: 20 | ||
  lr_scheduler: | ||
    epochs: ${TRAIN.epochs} | ||
    learning_rate: 0.001 | ||
    by_epoch: true | ||
  batch_size: 16 | ||
  pretrained_model_path: null | ||
  checkpoint_path: null | ||
|
||
# evaluation settings | ||
EVAL: | ||
  pretrained_model_path: null | ||
  compute_metric_by_batch: true | ||
  eval_with_no_grad: true | ||
  batch_size: 16 |
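As a quick sanity check on these settings, the sketch below works out how many values one input batch holds and how much memory that takes. The numbers are copied from the config; reading `shape_in` as (time steps, channels, height, width) and storing values as float32 are assumptions, not something the config states.

```python
import math

# Values copied from the config; the (time, channel, height, width) reading
# of shape_in and the float32 dtype are assumptions.
shape_in = (6, 12, 192, 256)
batch_size = 16
bytes_per_value = 4  # float32

values_per_sample = math.prod(shape_in)
values_per_batch = batch_size * values_per_sample
mib_per_batch = values_per_batch * bytes_per_value / 2**20

print(values_per_sample)
print(values_per_batch)
print(mib_per_batch)
```

This covers only the raw input tensor; activations, gradients, and optimizer state need additional memory on top of it.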
Review comment: All occurrences of the path examples/preformer/train.py in the file should be replaced with examples/preformer/main.py