Differentiable local field projection #1822
base: develop
Conversation
Force-pushed 9e2bc35 → cecf530
A few comments. I'm mostly a little confused about the `base.py` stuff?
Force-pushed 260293e → b387268
Some comments, mostly things I thought could be refactored, renamed, or simplified. Otherwise I think it's consistent/correct. The only thing I'm not 100% sure about is in the `DataArray.static` property:

```python
vals = self.values if self.values.size == 0 else np.vectorize(getval)(self.values)
```

Will `getval` work if `self.values` is a `np.ndarray` of `dtype=object`? Otherwise, feeling pretty good about it in its current state.
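For what it's worth, `np.vectorize` calls the wrapped function once per element, so it does handle `dtype=object` arrays as long as the function itself handles each element type. A minimal sketch of that behavior, where `Box` and `getval` are hypothetical stand-ins for the PR's traced-value unwrapping (not the actual implementation):

```python
import numpy as np

# Hypothetical stand-in for a traced/boxed value in the PR.
class Box:
    def __init__(self, value):
        self._value = value

# Hypothetical stand-in for the PR's `getval`: unwrap boxed values,
# pass plain values through unchanged.
def getval(x):
    return x._value if isinstance(x, Box) else x

# np.vectorize applies getval element-by-element, so an object-dtype
# array mixing boxed and plain values is unwrapped correctly.
vals = np.array([Box(1.0), 2.0, Box(3.0)], dtype=object)
out = np.vectorize(getval)(vals)
print(out)  # → [1. 2. 3.]
```

Note that `np.vectorize` infers the output dtype from the first element's result, which is usually what you want here but can surprise you if elements unwrap to different types.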
Force-pushed 65ff9c3 → 5542fff
Force-pushed 5ec140c → 2fe0315
Force-pushed 2fe0315 → 5bb4a0a
Force-pushed de9a4ba → 3cd6ae9
Force-pushed 52b3b7f → 77c5973
Force-pushed c770e70 → c5a5caa
@tylerflex There are some changes to
Force-pushed c5a5caa → 350d1db
@momchil-flex what do you think? Basically, there were some changes in my (recently merged) #1923 that Yannick used in this PR. So ideally we would do another develop -> pre/2.8 merge before Yannick rebases against 2.8?
@yaugenst-flex what do you think will be simplest? I'm already running notebook tests for 2.7.3, and since it's hard to tell whether this PR may have various small effects, I'd rather not include it for now. However, we could merge it to develop after 2.7.3 and then eventually into pre/2.8, or we can merge develop into pre/2.8 again and then you rebase the PR. Either way it will likely only come out in 2.8.0rc1, since we may not do any other 2.7 patches; but if we did, the first approach would put it in there too.
@momchil-flex I think it makes sense either way to merge develop into pre/2.8 after the 2.7.3 release so we keep them somewhat in sync. I'll rebase to pre/2.8 after that, shouldn't be a problem.
Looks really good! Thanks @yaugenst-flex. I don't know if you want @QimingFlex to look at the field projection bits? I couldn't understand most of the physics of the operations, but I did follow through the code and thought it looked good.
```diff
@@ -38,6 +38,7 @@
 dask = "*"
 toml = "*"
 autograd = "1.6.2"
 scipy = "*"
+opt-einsum = "^3.3.0"
```
We need to be careful about this due to the danger of bloating our core dependencies. It became an issue recently, and we agreed to let MC know if we add any dependencies. Is this absolutely necessary, do you think?
`opt-einsum` is a pretty tiny dependency (~200 kB) that itself only depends on numpy, and it's already a dependency of jax (although that's not core). Otherwise I can also rewrite that part using regular numpy `einsum`, not a big deal.
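As a sketch of the fallback mentioned above: `np.einsum` with `optimize=True` performs a similar contraction-order optimization to `opt_einsum.contract`, so a plain-numpy rewrite is usually a drop-in change (the subscripts and shapes here are illustrative, not from the PR):

```python
import numpy as np

# Illustrative three-operand contraction; opt_einsum would write this as
# opt_einsum.contract("ij,jk,kl->il", a, b, c).
a = np.random.rand(4, 5)
b = np.random.rand(5, 6)
c = np.random.rand(6, 3)

# optimize=True lets numpy pick an efficient pairwise contraction order,
# which is the main benefit opt-einsum provides for cases like this.
out = np.einsum("ij,jk,kl->il", a, b, c, optimize=True)

assert out.shape == (4, 3)
assert np.allclose(out, a @ b @ c)
```

The trade-off is that `opt_einsum` supports more backends and caches contraction paths, but for pure-numpy usage the two are typically interchangeable.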
@lei-flex what do you think about adding this dependency?
```diff
@@ -10,4 +10,5 @@ extend-ignore = [
     "E731", # lambda assignment
     "F841", # unused local variable
     "S101", # asserts allowed in tests
+    "NPY201", # numpy 2.* compatibility check
```
Snuck in from another PR?
Actually no, this got flagged (`np.trapz` usage) and I disabled it. We should enable it again once we support numpy 2.
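For context, the `NPY201` rule flags `np.trapz` because numpy 2.0 renamed it to `np.trapezoid`. A version-agnostic shim (a sketch, not code from this PR) would let the rule be re-enabled without pinning numpy:

```python
import numpy as np

# np.trapz was renamed to np.trapezoid in numpy 2.0; pick whichever
# exists so the code runs on both major versions.
if hasattr(np, "trapezoid"):
    trapezoid = np.trapezoid
else:
    trapezoid = np.trapz

# Sanity check: integrate x^2 on [0, 1], which should be close to 1/3.
x = np.linspace(0.0, 1.0, 101)
y = x**2
area = trapezoid(y, x)
```

With 101 sample points the trapezoidal error here is on the order of 1e-5, well within typical tolerances.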
```python
        return super().conj(*args, **kwargs)

    @property
    def real(self):
        """Return the real part of this DataArray."""
        if self.tracers is not None:
            return self.__array_wrap__(anp.real(self.tracers))
        if isbox(self.values.flat[0]) or AUTOGRAD_KEY in self.attrs:
```
Perhaps we can refactor this logic out (e.g. the one in `real`, `conj`, etc.):

```python
def apply_traced(self, fn):
    if isbox(self.values.flat[0]) or AUTOGRAD_KEY in self.attrs:
        tracers = anp.array(self.values.tolist())
        new_values = fn(tracers)
        return self.insert_tracers(new_values)
    return fn(self)

@property
def real(self):
    return self.apply_traced(anp.real)
```
Also, it seems the `anp.array(self.values.tolist())` and the `if` statement could both use the convenience methods/functions used to check these things?
Not really sure how to handle the `super()` call this way? But yeah, I can at least partially refactor this.
Do we need `super`, or can `np.real(self)` suffice?
Wouldn't that be infinite recursion? Pretty sure `np.real(self)` will call `self.real`.
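The recursion concern is well founded: `np.real(val)` returns `val.real` whenever that attribute exists, so calling it inside a `real` property would loop. A toy demonstration with a hypothetical wrapper class (not the PR's `DataArray`) showing the dispatch:

```python
import numpy as np

# Toy wrapper, not the real DataArray: np.real(w) dispatches to w.real.
class Wrapper:
    def __init__(self, data):
        self.data = data

    @property
    def real(self):
        # Delegating to the underlying array here is fine; delegating
        # back to np.real(self) would recurse infinitely.
        return np.real(self.data)

w = Wrapper(np.array([1 + 2j, 3 + 4j]))
out = np.real(w)  # calls Wrapper.real, not a numpy ufunc
print(out)  # → [1. 3.]
```

This is why the PR goes through `super()` (or an explicit helper) rather than `np.real(self)`.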
I am blanking on a lot of the discussions we had regarding the data array handling, but will mostly assume that this logic follows what we already talked about before?
Is there anything you think I should take a more detailed look at?
Yeah, I think it's fine for the scope of this PR, so no need to dive deep for now. I vote that we put together a proper DataArray refactor soonish, though.
```python
power_phi = 0.5 * np.real(-self.Ephi.values * np.conj(self.Htheta.values))
conj = np.vectorize(np.conj)
real = np.vectorize(np.real)
power_theta = 0.5 * real(self.Etheta.values * conj(self.Hphi.values))
```
Do we need to do this treatment in other parts of this class? I think this kind of operation does pop up from time to time?
If we want to differentiate w.r.t. those methods then yeah this needs to be done. Although I'm hoping that we can get rid of this "hack" with a DataArray refactor...
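For readers unfamiliar with the "hack": `np.real`/`np.conj` applied directly to an object-dtype array (which is what holds autograd boxes) don't operate element-wise, so the PR vectorizes them. A self-contained sketch with plain complex numbers standing in for traced values:

```python
import numpy as np

# Vectorized versions apply np.conj/np.real to each element, which is
# what makes this work on object-dtype arrays holding autograd boxes.
conj = np.vectorize(np.conj)
real = np.vectorize(np.real)

# Illustrative field values (plain complex here, boxes in the real code).
E = np.array([1 + 1j, 2 - 3j], dtype=object)
H = np.array([0 + 1j, 1 + 1j], dtype=object)

# Same Poynting-style expression as in the diff above.
power = 0.5 * real(E * conj(H))
print(power)  # → [ 0.5 -0.5]
```

With a DataArray refactor that unwraps tracers centrally, these per-call `np.vectorize` wrappers could indeed go away.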
```diff
@@ -144,7 +143,12 @@ The following components are traceable as outputs of the `td.SimulationData`
 - `SimulationData.get_intensity`
 - `SimulationData.get_poynting`
+- Local field projections using `td.FieldProjector.project_fields()` for:
```
Was this mentioned as a "to do" item? If so, it would be good to remove it.
Not seeing that anywhere in this PR, but I know I've seen it somewhere...
Hey @QimingFlex, it would be great if you could have a quick look at the field projection bits in this PR; I reworked some of it to more easily accommodate autograd. I did try to make sure that all of the field projection values that come out of this are still exactly the same as before, but it wouldn't hurt to have another pair of eyes on it.
Force-pushed 0e92e95 → c064c42
Force-pushed c064c42 → 862d5b8
No description provided.