-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DeepONet tutorial for Advection Problem #217
base: 0.2
Are you sure you want to change the base?
Conversation
# This test is wrong! The user could define a custom network and do a | ||
# reshape at the end! A more general way to check input and output is | ||
# needed | ||
# def test_constructor_fails_when_invalid_inner_layer_size(): | ||
# branch_net = FeedForward(input_dimensions=1, output_dimensions=10) | ||
# trunk_net = FeedForward(input_dimensions=2, output_dimensions=8) | ||
# with pytest.raises(ValueError): | ||
# DeepONet(branch_net=branch_net, | ||
# trunk_net=trunk_net, | ||
# input_indeces_branch_net=['a'], | ||
# input_indeces_trunk_net=['b', 'c'], | ||
# reduction='+', | ||
# aggregator='*') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not understand why the check is removed. What does it mean The user could define a custom network and do a reshape at the end
?
# check trunk branch nets consistency | ||
shapes = [] | ||
for key, value in networks.items(): | ||
check_consistency(value, (str, int)) | ||
check_consistency(key, torch.nn.Module) | ||
input_ = torch.rand(10, len(value)) | ||
shapes.append(key(input_).shape[-1]) | ||
# TODO: Fix it does not work | ||
# # check trunk branch nets consistency | ||
# shapes = [] | ||
# for key, value in networks.items(): | ||
# check_consistency(value, (str, int)) | ||
# check_consistency(key, torch.nn.Module) | ||
# input_ = torch.rand(10, len(value)) | ||
# shapes.append(key(input_).shape[-1]) | ||
|
||
if not all(map(lambda x: x == shapes[0], shapes)): | ||
raise ValueError('The passed networks have not the same ' | ||
'output dimension.') | ||
# if not all(map(lambda x: x == shapes[0], shapes)): | ||
# raise ValueError('The passed networks have not the same ' | ||
# 'output dimension.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why disabling checks about the shapes?
The comments are made because the checks are just correct for 2D inputs. I commented because it is not still clear what is the data format for all models (2d/3d?). |
@ndem0 can we merge or do we wait for fixing the input format? |
In this PR I uploaded a new tutorial for the
DeepONet
class since it was missing.When writing the tutorial I noticed some bugs which I corrected.
Here I report a list of the bugs:
MIONet._aggregator
) had the problem that the input was already stacked, but sometimes it is not possible to stack (e.g.net1
output shape[T, N, d]
,net2
output shape[T, d]
and aggregation performseinsum('TNd,Td->TNd', ...)
). So to avoid this problem now theMIONet._aggregator
function takes as input a tuple of tensors (also more simple for the user to understand)..reshape
with.unsqueeze
inoutput_ = self._reduction(aggregated).reshape(-1, 1)
since the final aim of the operation is to add an extra dimension, and not reshape the tensor.Finally, I remove the check consistency of the network outputs before aggregation and put a warning message that the consistency is up to the user to check. This is not because I believe it should be up to the user the check, but it is a permanent fix since the way we were checking the consistency is wrong. Indeed consider what we were doing:
This is wrong if the user wants to specify a custom network whose inputs is not in the form
[N, len(value)]
. For example, consider the following net:Due to the input reshape the line
shapes.append(key(input_).shape[-1])
will raise an error since we are trying to slice a two dimensional tensor (input_ = torch.rand(10, len(value))
) like a three dimensional one (linex[:, 0, :]
).The problem to check network consistency is more general than just MIONet or DeepONet, as also in FNO we have the same problem. Maybe we can think to make a PR just for this problem.