Skip to content

Commit

Permalink
[r2] fix seeds in se_a and se_atten (#3880) (#3947)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Bug Fixes**
- Resolved inconsistencies in seed values by incrementing `self.seed`
conditionally in descriptor modules.
  
- **Tests**
- Updated test arrays `refe`, `reff`, and `refv` with new reference
values.
- Adjusted expected values in `test_model_ener` method for better
accuracy.

These changes ensure more reliable descriptor computations and improved
test accuracy.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------


(cherry picked from commit 0c472d1)

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Jul 3, 2024
1 parent a85d58f commit 84ca63c
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 33 deletions.
7 changes: 6 additions & 1 deletion deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,8 @@ def _filter_lower(
mixed_prec=self.mixed_prec,
)
net_output = tf.nn.embedding_lookup(net_output, idx)
if (not self.uniform_seed) and (self.seed is not None):
self.seed += self.seed_shift
net_output = tf.reshape(net_output, [-1, self.filter_neuron[-1]])
else:
xyz_scatter = self._concat_type_embedding(
Expand All @@ -1042,7 +1044,7 @@ def _filter_lower(
)
# natom x 4 x outputs_size
if nvnmd_cfg.enable:
return filter_lower_R42GR(
oo = filter_lower_R42GR(
type_i,
type_input,
inputs_i,
Expand All @@ -1060,6 +1062,9 @@ def _filter_lower(
self.filter_resnet_dt,
self.embedding_net_variables,
)
if (not self.uniform_seed) and (self.seed is not None):
self.seed += self.seed_shift
return oo
if self.compress and (not is_exclude):
if self.stripped_type_embedding:
net_output = tf.nn.embedding_lookup(
Expand Down
16 changes: 14 additions & 2 deletions deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,8 @@ def _attention_layers(
uniform_seed=self.uniform_seed,
initial_variables=self.attention_layer_variables,
)
if not self.uniform_seed and self.seed is not None:
self.seed += 1
K_c = one_layer(
input_xyz,
self.att_n,
Expand All @@ -972,6 +974,8 @@ def _attention_layers(
uniform_seed=self.uniform_seed,
initial_variables=self.attention_layer_variables,
)
if not self.uniform_seed and self.seed is not None:
self.seed += 1
V_c = one_layer(
input_xyz,
self.att_n,
Expand All @@ -985,6 +989,8 @@ def _attention_layers(
uniform_seed=self.uniform_seed,
initial_variables=self.attention_layer_variables,
)
if not self.uniform_seed and self.seed is not None:
self.seed += 1
# # natom x nei_type_i x out_size
# xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1] // 4, outputs_size[-1]))
# natom x nei_type_i x att_n
Expand Down Expand Up @@ -1017,6 +1023,8 @@ def _attention_layers(
uniform_seed=self.uniform_seed,
initial_variables=self.attention_layer_variables,
)
if not self.uniform_seed and self.seed is not None:
self.seed += 1
input_xyz = tf.keras.layers.LayerNormalization(
beta_initializer=tf.constant_initializer(self.beta[i]),
gamma_initializer=tf.constant_initializer(self.gamma[i]),
Expand Down Expand Up @@ -1080,6 +1088,8 @@ def _filter_lower(
initial_variables=self.embedding_net_variables,
mixed_prec=self.mixed_prec,
)
if (not self.uniform_seed) and (self.seed is not None):
self.seed += self.seed_shift
else:
if self.attn_layer == 0:
log.info(
Expand Down Expand Up @@ -1119,6 +1129,8 @@ def _filter_lower(
initial_variables=self.embedding_net_variables,
mixed_prec=self.mixed_prec,
)
if (not self.uniform_seed) and (self.seed is not None):
self.seed += self.seed_shift
else:
net = "filter_net"
info = [
Expand Down Expand Up @@ -1176,6 +1188,8 @@ def _filter_lower(
initial_variables=self.two_side_embeeding_net_variables,
mixed_prec=self.mixed_prec,
)
if (not self.uniform_seed) and (self.seed is not None):
self.seed += self.seed_shift
two_embd = tf.nn.embedding_lookup(
embedding_of_two_side_type_embedding, index_of_two_side
)
Expand All @@ -1194,8 +1208,6 @@ def _filter_lower(
is_sorted=len(self.exclude_types) == 0,
)

if (not self.uniform_seed) and (self.seed is not None):
self.seed += self.seed_shift
input_r = tf.slice(
tf.reshape(inputs_i, (-1, shape_i[1] // 4, 4)), [0, 0, 1], [-1, -1, 3]
)
Expand Down
56 changes: 28 additions & 28 deletions source/tests/test_model_se_a_ebd_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,37 +139,37 @@ def test_model(self):
f = f.reshape([-1])
v = v.reshape([-1])

refe = [5.435394596262052014e-01]
refe = [6.100037044296185e-01]
reff = [
6.583728125594628944e-02,
7.228993116083935744e-02,
1.971543579114074483e-03,
6.567474563776359853e-02,
7.809421727465599983e-02,
-4.866958849094786890e-03,
-8.670511901715304004e-02,
3.525374157021862048e-02,
1.415748959800727487e-03,
6.375813001810648473e-02,
-1.139053242798149790e-01,
-4.178593754384440744e-03,
-1.471737787218250215e-01,
4.189712704724830872e-02,
7.011731363309440038e-03,
3.860874082716164030e-02,
-1.136296927731473005e-01,
-1.353471298745012206e-03,
8.448651008616304e-02,
8.613568658155157e-02,
4.377711655236228e-03,
9.264613309788312e-02,
9.351200240060925e-02,
-6.743918515275118e-03,
-1.268078358219972e-01,
4.855965861982662e-02,
1.361334787979757e-04,
4.193213089916692e-02,
-1.324120032345251e-01,
-4.507320444374342e-03,
-1.314595297986654e-01,
4.120567370248839e-02,
7.896917575801866e-03,
3.920259153744955e-02,
-1.370010180699507e-01,
-1.159523750186610e-03,
]
refv = [
-4.243979601186427253e-01,
1.097173849143971286e-01,
1.227299373463585502e-02,
1.097173849143970314e-01,
-2.462891443164323124e-01,
-5.711664180530139426e-03,
1.227299373463585502e-02,
-5.711664180530143763e-03,
-6.217348853341628408e-04,
-0.277134219204478,
0.088897922530779,
0.008633318264458,
0.088897922530779,
-0.292191560546969,
-0.005709595520904,
0.008633318264458,
-0.005709595520904,
-0.000682136341924,
]
refe = np.reshape(refe, [-1])
reff = np.reshape(reff, [-1])
Expand Down
4 changes: 2 additions & 2 deletions source/tests/test_pairwise_dprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,8 @@ def test_model_ener(self):
# the model is pairwise!
self.assertAllClose(e[1] + e[2] + e[3] - 3 * e[0], e[4] - e[0])
self.assertAllClose(f[1] + f[2] + f[3] - 3 * f[0], f[4] - f[0])
self.assertAllClose(e[0], 0.189075, 1e-6)
self.assertAllClose(f[0, 0], 0.060047, 1e-6)
self.assertAllClose(e[0], 4.82969, 1e-6)
self.assertAllClose(f[0, 0], -0.104339, 1e-6)

def test_nloc(self):
jfile = tests_path / "pairwise_dprc.json"
Expand Down

0 comments on commit 84ca63c

Please sign in to comment.