Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make groups, irreps, and gspaces into stateless, picklable singletons #78

Open
wants to merge 8 commits into
base: master
Choose a base branch
from

Conversation

kalekundert
Copy link
Contributor

My reason for making this PR is that I wanted to be able to pickle escnn modules for the purpose of making checkpoints during training. Before I started, I noticed #37, which has the same goal. However, it hasn't been merged because @Gabri95 thought it would be better to take advantage of the fact that group, irrep, and gspace instances don't really have state. The problems come from trying to pickle complicated internal caches maintained by those objects, but in principle there should be no need to even try do that.

I really agree with this approach, so that's what I tried to implement here. A big part of my implementation is a general-purpose singleton class. Groups are already singletons, but in a way that relies on (i) each subclass properly implementing the _keys() and _generator() methods and (ii) everyone using factory functions instead of constructors. I simplified this by moving all the singleton logic into a metaclass. Metaclasses can control every aspect of the object instantiation process (here's a good intro if you're not familiar), and this one intercepts the arguments to the constructor and uses them to decide whether to instantiate a new object or return an existing one. It also uses this information to pickle, compare, and hash the objects in question.

Irreps are also already singletons, although they're simpler because they have a single natural factory—the Group.irrep() method—and they don't have any subclasses. In this case, I don't think my new singleton metaclass would be much of an improvement. It wouldn't really simplify anything, as groups would still need to keep track of their own irreps, so I just added the necessary pickling code and kept everything else mostly as is. I did move the singleton logic out of the various group subclasses and into the base class, to reduce code duplication.

GSpaces are not currently singletons, but it would make sense for them to be, so I applied my singleton metaclass. That was the only change I really had to make.

Even after doing all of the above, it was still not possible to pickle R3Conv. The issue was that FieldType references a Representation, and representations need to have a function that they can use to turn group elements into matrices. The default function is a direct sum of irreps, possibly with a change-of-basis. I made this function pickleable by using functools.partial(), although another option would've been to turn it into a class (as in #37).


Important unresolved issue!

This PR is not ready to merge yet, because I ran into one conceptual issue that I'm not sure how to resolve. The issue relates to (i) the maximum_frequency argument to all the orthogonal groups (O2, SO2, O3, SO3), (ii) the Group.irreps() method, and (iii) the fact that singleton objects should be immutable.

When you instantiate (for example) the O2 group, you can specify a maximum frequency. This determines how many irreps are pre-calculated (although you can always get more by calling irrep() with the necessary arguments). If you later try to instantiate the O2 group again, because groups are singletons, you'll always get the same original object. But if you asked for more frequencies than before, that object will be updated with more irreps.

This is a problem because of the irreps() method, which simply returns all of the irreps that have been requested up to that point. Since this can change, either as more objects are instantiated with higher maximum frequencies, or as the irrep() method is called, these groups are not immutable. This leads to fragile, long-range dependencies in the code. For example:

g1 = o2_group(maximum_frequency=3)

# Hundreds of lines later, or in a different file...

g2 = o2_group(maximum_frequency=2)
function_that_requires_max_freq_3(g2)

This code works, because the first instantiation of the O2 group puts the second in the right state to succeed. But if the first line is changed, the second will mysteriously stop working. It turns out that this exact situation happens a bunch of times in the unit tests, and probably happens in real-life code as well. The potential for this kind of "spooky action at a distance" is why singleton objects (or any form of global state) should be immutable.

As I see it, the problem is the irreps() method. It shouldn't reveal the number of cached irreps. That means it either needs to accept an argument specifying how many irreps to return (i.e. a dimension or something like that), or it needs to be a generator that can literally yield an infinite number of irreps. I'm not sure which approach is best, and both seem like they could be very disruptive. So I want to get your thoughts on the matter before doing anything.

For what it's worth, this branch currently treats the maximum frequency parameter as part of the group's "state". I made that decision back when I didn't understand the code as well, and I've come to realize that it's basically the worst of both worlds. Conceptually, it's wrong because it's the same group no matter how many frequencies you look at. Practically, it doesn't change the fact that irreps() is mutable, e.g. if irrep() is called.


Minor changes that aren't directly related to the main goal of the PR:

  • I added the class name to the __repr__() for a number of classes, because without that information I was having a hard time debugging things. I tried to stick to the following convention: ClassName(*args, **kwargs) for reprs that would actually reconstruct the object if copy-and-pasted into a python interpreter, and ClassName[extra info] for reprs that have more free-form formatting. I think this is a substantial improvement, but I can revert it if you prefer things the way they were.

  • I gave the GSpaces names based on the factory function that was used to create them.

  • I reimplemented some hash functions. Now that groups, irreps, and gspaces are all hashable, they can easily be incorporated into hashes as necessary.

  • I refactored some test code, mostly to remove code duplication, and I fixed some failing tests.

  • I copied the Python .gitignore file from github/gitignore into the project. It recognizes a bunch of files that tend to pop up in python projects.

@Gabri95
Copy link
Collaborator

Gabri95 commented Aug 28, 2023

Hi @kalekundert

This is amazing, thanks a lot for all this precious work!! 😄

If you don't mind, this will take me a bit of time to really process :/

In principle, I fully support the singleton strategy!
My current implementation reflects my attempt to achieve singletons essentially (I am not very familiar with metaclasses, so maybe that's a cleaner way to achieve what I wanted actually).

Let me reply to a couple of points in the meantime:

As I see it, the problem is the irreps() method. It shouldn't reveal the number of cached irreps.

This method should indeed not be expected to have a "deterministic" behaviour.
In some occasions, it is useful to be able to loop over all irreps instantiated so far (with precisely this meaning), e.g. when trying to decompose a representation. In this occasion, indeed, the decomposition is always theoretically possible, but not enough irreps might have been instantiated yet, hence causing a failure.
It might be worth deprecating this method though (or at least hiding it from the user interface, although I didn't think about this too much yet).

function_that_requires_max_freq_3(g2)

A function of this kind should internally ensure all irreps up to frequency 3 have been instantiated so far.
Some groups have method to do that (e.g. the maximum_frequency keyword or the bl_irreps() method); see also my comments later.

I agree that this situation is a bit tricky, though. Indeed, as you pointed out, some unittests might fail sometimes depending on their execution order (this is because .irreps() is used multiple times). I tried to catch most of these cases by removing the use of .irreps() but I might have missed some.

That means it either needs to accept an argument specifying how many irreps to return (i.e. a dimension or something like that), or it needs to be a generator that can literally yield an infinite number of irreps.

Unfortunately, neither of these options are possible.
Irreps are "indexed" by an integer (e.g. a frequency) only in some special cases (like 2-d or 3-d rotations), but this is not true for any group (simply think of the group SO(2) x SO(2), whose irreps are indexed by pairs of frequencies).
That means having an infinite iterator is not sufficient (do we iterate on the first or second frequency first? this boils down to the problem of building a complete ordering of Z^2).
This is the reason why so far I took this approach:

  • irreps() returns all irreps instantiated so far
  • each group has a number of custom methods to return finite subsets of irreps (e.g. most rotation groups have a bl_irreps(L) method)

The user should mostly rely on the second methods for deterministic behaviour, but this requires a lot of ad-hoc code for each group. This is also why I preferred preserving the irreps() method.
However, I can find reasonable to deprecate .irreps() now to avoid the problems you mention.

For what it's worth, this branch currently treats the maximum frequency parameter as part of the group's "state".

If we deprecate .irreps(),is it still possible to keep the maximum_frequency argument? This is only needed for initialization of the hidden group._irreps list.
This is essentially useful for caching purpose, but shouldn't change the "state" of a singleton group.
If the user only interfaces with irreps via ad-hoc methods (i.e. not with irreps()), this shoudln't be an issue, right?

I added the class name to the repr() for a number of classes

I didn't do that for groups and representations since it would make the already verbose representation worse.
I also assumed that when the user prints an instance of group or representation, its main interest is understanding what group or what representation it is, not whether it is an instance of Group or Representation.
Do I understand correctly this is what you meant?

I gave the GSpaces names based on the factory function that was used to create them.

Does this break the singleton strategy? Two identical gspaces might be generated 1) manually by setting a particular subgroup id and 2) via a factory function. In these two cases, they will have a different representation (so I suppose also a different private attribute).
The idea seems nice though, I'd support this if it doesn't break the new singleton framework.

I reimplemented some hash functions. Now that groups, irreps, and gspaces are all hashable, they can easily be incorporated into hashes as necessary.

I refactored some test code, mostly to remove code duplication, and I fixed some failing tests.

copied the Python .gitignore file from github/gitignore into the project. It recognizes a bunch of files that tend to pop up in python projects.

Amazing, thaaaanks!

Again, this is quite some changes to process and I need to get more familiar with the metaclasses before properly commenting on this PR, sorry.
Unfortunately, I'm also a little busy at the moment so I cannot fully focus on this PR 😢
This means I might take quite some time for me to merge this PR.

However, I will try to keep this discussion going (as I get more familiar with your code) since I think this is a really important issue in the library.

I really appreciate your PR, it will be very useful! 🎉
Thanks a lot for your effort!

Best,
Gabriele

@kalekundert
Copy link
Contributor Author

If you don't mind, this will take me a bit of time to really process :/

No worries, I definitely understand. Take your time, and let me know if you have any questions about how anything works.

That means it either needs to accept an argument specifying how many irreps to return (i.e. a dimension or something like that), or it needs to be a generator that can literally yield an infinite number of irreps.

Unfortunately, neither of these options are possible.

Would it be possible to limit irreps based on their size? That is, generate all irreps that are 4x4 or smaller (in arbitrary order)? And if this is possible, would it be useful? Would the functions calling irreps() know the size of the irreps they need?

If we deprecate .irreps(), is it still possible to keep the maximum_frequency argument?

If we deprecate irreps(), I don't think there will be any need for the maximum_frequency argument any more. I would advocate for getting rid of it. All the other methods of generating irreps should know how many are needed, so the group can either use it's cache or generate the needed irreps behind the scenes.

That said, I assume that the maximum_frequency argument, or really the size of the _irreps cache, are implicitly used by various nn modules to encode things like how many basis functions to use when constructing filters. If this is the case, then I think this information should be removed from the group class and stored directly in the modules, since it is clearly state that pertains to the modules themselves and not the group as a whole. But this also sounds like it could be a disruptive change.

I agree that this situation is a bit tricky, though. Indeed, as you pointed out, some unittests might fail sometimes depending on their execution order...

In case it's useful, here are all the tests I found that seem to depend on the state of the irreps() cache. Note that these tests all pass in the master branch and fail in this branch:

nn.test_basisexpansion.TestBasisExpansion.test_conv2d
nn.test_linear.TestLinear.test_o2
kernelspaces.test_differentiable_restrictedwignereckart.TestWEbasis.test_sphere_o2_dihedral
kernelspaces.test_differentiable_restrictedwignereckart.TestWEbasis.test_sphere_o2_conical
nn.test_point_convolution.TestConvolution.test_o2
kernelspaces.test_differentiable_restrictedwignereckart.TestWEbasis.test_sphere_c2xso2_cyl
nn.test_amp.TestMixedPrecision.test_r2conv_non_uniform
kernelspaces.test_differentiable_wignereckart.TestWEbasis.test_circular_shell
kernelspaces.test_differentiable_restrictedwignereckart.TestWEbasis.test_sphere_c2xo2_fullcyl
nn.test_basisexpansion.TestBasisExpansion.test_linear
kernelspaces.test_wignereckart.TestWEbasis.test_circular_shell
kernelspaces.test_restrictedwignereckart.TestWEbasis.test_sphere_c2xo2_fullcyl
kernelspaces.test_restrictedwignereckart.TestWEbasis.test_sphere_o2_dihedral
kernelspaces.test_differentiable_kernels.TestBasesDifferentiable.test_o2_irreps_onR3
kernelspaces.test_restrictedwignereckart.TestWEbasis.test_sphere_c2xso2_cyl
nn.test_basisexpansion.TestBasisExpansion.test_conv3d
kernelspaces.test_equivariance.TestSolutionsEquivariance.test_o2_irreps_onR3
group.test_tensor_prod_representation.TestTensorProductRepresentation.test_tensor_irreps_o2

I only looked at the last one in detail. The error ultimately comes from escnn.group._clebsh_gordon._find_tensor_decomposition(), which I think is exactly the "decompose a representation" scenario that you brought up above. This function seems to know when it doesn't have enough irreps. Would it be possible for it to ask the group to generate more?

I gave the GSpaces names based on the factory function that was used to create them.

Does this break the singleton strategy?

The singleton strategy treats names as a special case. They aren't really part of the object state, since they don't affect how the object behaves, but they're useful for debugging and therefore worth keeping around. So what happens is that the names are ignored when (i) deciding whether to create a new instance or reuse an existing one, (ii) comparing two objects, and (iii) hashing an object. But, they're pickled and unpickled. Going back to your example with two identical gspaces created in different ways, the name would comes from whichever was created first (because that's the only object that will ever be created). I can see that being confusing in some situations, but I think it's made up for by the extra clarity in the most common situations.

It's also worth noting that groups already have names (and in fact have this same issue where you can get different names depending on whether or not you use a factory, e.g. the Klein 4 group), so I had to come up with some way of handling names anyways. Once I did that, it was natural to apply the same logic to gspaces.

Strictly speaking, it would be more principled to remove the name arguments completely (since they're not state), and to let each group/gspace subclass figure out what its name should be from the information it has. For some of the groups with factory functions, this would look something like:

class DirectProductGroup(Group):

    @property
    def name(self):
        if self.G1 is cyclic_group(2) and self.G2 is cyclic_group(2):
            return 'Klein 4'

        if self.G1 is ico_group(2) and self.G2 is cyclic_group(2):
            return 'Full Icosahedral'

        ...

The biggest downside to this approach is that it would make it impossible for users to give names to novel groups/gspaces that they construct. But that's just another way of saying that these classes are immutable, which is what they should be. I initially dismissed this approach for being too limiting, but now that I think about it more, I actually think it's better than what I did. Let me know if you have any objections, otherwise I'll probably switch to something more like this.

I added the class name to the repr() for a number of classes

I also assumed that when the user prints an instance of group or representation, its main interest is understanding what group or what representation it is, not whether it is an instance of Group or Representation.

The issue for me comes from composite objects like R3Conv. This repr includes input and output field types, which include gspaces and representations, which include groups and maybe irreps. So you end up getting a lot of output and not really knowing what any of it is.

For what it's worth, here's what the python docs have to say about the repr() function:

Return a string containing a printable representation of an object. For many types, this function makes an attempt to return a string that would yield an object with the same value when passed to eval(); otherwise, the representation is a string enclosed in angle brackets that contains the name of the type of the object together with additional information often including the name and address of the object.

Reading this now, I realize that I probably should've used the angle-bracket convention instead of my ad-hoc square-bracket convention. I do think it's best to include type names in repr strings, like the docs recommend, but I also don't care too strongly one way or the other, so I'm happy to go back to the old reprs if you prefer them. After all, you'll probably have to read these reprs more often than me. 😉

@Gabri95
Copy link
Collaborator

Gabri95 commented Aug 29, 2023

Would it be possible to limit irreps based on their size? That is, generate all irreps that are 4x4 or smaller (in arbitrary order)? And if this is possible, would it be useful? Would the functions calling irreps() know the size of the irreps they need?

This also wouldn't work :( Think about SO(2): all the irreps are 2 dimensional. I am not aware of any simple and general way to generate a finite set of irreps procedurally for any group unfortunately...

If we deprecate irreps(), I don't think there will be any need for the maximum_frequency argument any more. I would advocate for getting rid of it. All the other methods of generating irreps should know how many are needed, so the group can either use it's cache or generate the needed irreps behind the scenes.

only looked at the last one in detail. The error ultimately comes from escnn.group._clebsh_gordon._find_tensor_decomposition(), which I think is exactly the "decompose a representation" scenario that you brought up above. This function seems to know when it doesn't have enough irreps. Would it be possible for it to ask the group to generate more?

This is the main reason why I want to keep the maximum_frequency keyword to initialize the cache.
For rotation groups, the decomposition of the tensor product of frequency M and N contains irreps up to frequency M + N.
This decomposition happens whenever you build any linear or convolution layer.
Hence, whatever is the maximum frequency of the irreps you instantiated to build your network, you should ensure more irreps have been cached (e.g. if M is the max frequency you are using, you should cache up to frequency 2M).
Right now I'm manually initialising these caches with sufficiently high frequencies such that the user doesn't incur in these issues.
Unfortunately, this still happens sometimes in the unittests (this might be due to their execution order: some code caches more irreps and then these are tensored again by another test, reaching frequencies higher than those cached so far).

Because the tensor product decomposition is a general code which is agnostic of the group, it can only access the list of irreps cached so far. I thought about this for some time in the past, but I couldn't come up with a better strategy than the current one: the user must pre-cache sufficiently many irreps and then the backend of the library access the irreps() method. An error is thrown if some irreps are missing, and the user should manually instantiate them using the ad-hoc methods of the group currently used.

I hope this clarifies why I care about this maximum_frequency argument.

That said, I assume that the maximum_frequency argument, or really the size of the _irreps cache, are implicitly used by various nn modules to encode things like how many basis functions to use when constructing filters. If this is the case, then I think this information should be removed from the group class and stored directly in the modules, since it is clearly state that pertains to the modules themselves and not the group as a whole. But this also sounds like it could be a disruptive change.

I think this decomposition of tensor products is essentially the only place in the nn module which cares about this information. However, considering my previous comment, I don't think this information pertains to the modules themselves.

Maybe a good solution is considering groups as singleton as you recommend and the irreps cache as a separate entity.
The maximum frequency argument is still used to initialize the cache when the group is create (or one tries to re-instantiate it) but it has no effect on the group itself.
If a group is pickled, we also independently pickle the cached irreps (we only need to pickle the list of their ids to be able to re-instantiate them) and, similarly, when a group is unpickled we also unpickle its irreps cache.
This should achieve all we want:

  • groups are stateless
  • the irreps cache is properly maintained and can be conveniently initialized in a group-specific way
  • the irreps cache can be recovered whenever the group is unpickled, guaranteeing some deterministic behaviour
    Do you think something like this could work?

The singleton strategy treats names as a special case. They aren't really part of the object state, since they don't affect how the object behaves, but they're useful for debugging and therefore worth keeping around. So what happens is that the names are ignored when (i) deciding whether to create a new instance or reuse an existing one, (ii) comparing two objects, and (iii) hashing an object. But, they're pickled and unpickled. Going back to your example with two identical gspaces created in different ways, the name would

This solution sounds good for me!

Strictly speaking, it would be more principled to remove the name arguments completely (since they're not state), and to let each group/gspace subclass figure out what its name should be from the information it has. For some of the groups with factory functions, this would look something like: [...]

I'd avoid this solution, since the same abstract group could have different names.
For example, the group C_2 could be interpreted as the group of mirroring in 3D or the group of 180 degree planar rotations.
Similarly, the group C_2 x C_2 can be thought as the Klein group (typically we think of this as a subgroup of SO(3)) or the group of mirroring and 180deg rotations along a certain axis in 3D.
There exist a few cases where the same abstract group has different names in different contexts.

The issue for me comes from composite objects like R3Conv. This repr includes input and output field types, which include gspaces and representations, which include groups and maybe irreps. So you end up getting a lot of output and not really knowing what any of it is.

Do you think this can be more easily fixed within the repr of the R3Conv class?

For what it's worth, here's what the python docs have to say [...]

This is also a really good point actually. I will think a bit more about this

Thanks for the helpful discussion!

@kalekundert
Copy link
Contributor Author

Ok, I know you've told me twice now that there's no way to generate the necessary finite sets of irreps for arbitrary groups, so I apologize if I'm beating a dead horse, but I want to revisit the infinite generator idea. I think the reason I'm so stuck on trying to find a way to generate irreps on the fly is that (i) it seems like the most conceptually "right" thing to do and (ii) everything works if the group is initialized with enough irreps, so it seems like it should be possible (in the worst case) to find some way of continually re-initializing the group until there are enough irreps.

It's clear how an infinite irrep generator would be implemented for finite groups (just yield all the irreps) and rotational groups (yield the irreps in frequency order). The only other groups in the codebase, I believe, are the direct product and double groups. The double group is basically just a direct product between a group and itself, so the direct product case is the only one we need to consider.

Right now, the direct product group considers all pairs of cached irreps from the two groups being combined. For rotational groups, this means it considers all the irreps below the frequency threshold that was used to populate the cache. We can write an infinite generator that similarly considers all pairs below any given frequency before any pairs above that frequency. The trick is to alternate between the two groups every time a new irrep is needed:

from itertools import cycle

def direct_product_irreps(g1, g2):
    irrep_iterators = [
        g1.yield_all_irreps(),
        g2.yield_all_irreps(),
    ]

    irreps = [[], []]

    for i in cycle([0, 1]):
        irrep_i = next(irrep_iterators[i])
        irreps[i].append(irrep_i)

        for irrep_j in irreps[(i + 1) % 2]:
            yield (irrep_i, irrep_j)[::1 if i == 0 else -1]
      
        # Some extra logic would be needed to terminate in the case of finite 
        # groups, but that wouldn't change anything conceptually.

This would start by yielding all the same irreps as if the groups were both initialized with maximum_frequency == 1, then would yield the irreps as if maximum_frequency == 2, and so on forever. Here are the first 10 pairs that would be generated by this function, if it were given two infinite iterators that both independently count up from 0:

(0, 0)
(1, 0)
(0, 1)
(1, 1)
(2, 0)
(2, 1)
(0, 2)
(1, 2)
(2, 2)
(3, 0)

Is there something I'm overlooking? I think this would effectively behave in the same way as the code already does, but without requiring any externally observable state. In fact, it might even result in fewer Clebsh-Gordon coefficients needing to be calculated, since it goes in order from low to high frequencies. (Although I don't know if those calculations have any meaningful effect on runtime. Probably not, because they only need to be calculated once.)

Maybe a good solution is considering groups as singleton as you recommend and the irreps cache as a separate entity.

In order for this to work, the group wouldn't be able to provide access to the cache, because doing so would effectively make the group mutable. The cache itself also couldn't be a global object, because the whole point is that global objects should never be mutable. So the cache would either have to be a wrapper around the group, or provided alongside the group. The former would probably be easier to implement, especially if the cache wrapper were to implement the whole group interface.

I think the end result would actually be pretty similar (most of the time) to the way this branch behaves right now, where the maximum frequency is just considered part of the group's state. This might not be as bad of a solution as I thought earlier. What it really means is that if you instantiate (for example) O2 twice with different maximum frequencies, you'll get two different objects that each have their own separate caches. There's probably some potential for confusion if you try to have two different O2 instances interact with each other (because the fact that they're different might not be obvious), but I think the most common thing anyways is just to create a single gspace with a single group and to use it for everything.

I do think that if we take this approach, we'd still have to modify the irrep() method of the rotational groups to either not update the cache, or to raise an exception if an uncached irrep is requested. In other words, since this approach means making the state of the cache part of the public interface for these groups, any means of modifying the cache would have to be removed.

I'd avoid this solution, since the same abstract group could have different names.

Ok, I'll leave it as it is.

Do you think this can be more easily fixed within the repr of the R3Conv class?

Well, it was pretty easy to fix by changing the gspace/representation/irrep reprs (I didn't actually change the group repr, since I never has any trouble recognizing group names), so I wouldn't say that it would be easier to fix by changing R3Conv. R3Conv probably isn't the only class that has a composite repr like this, either. But I don't think ease really matters here, it's just a (subjective) question of whether the extra verbosity makes thing more or less clear.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants