Excessive memory usage when the number of trees is large (e.g. 500 trees) #31
This problem can be reproduced even on very small training sets (e.g. 1e4 samples), but with a large number of trees (e.g. 500).
This should hopefully be fixed by #21, right? At least partially.
Memory use is still much higher with pygbm. Maybe we could try to reuse the grower instance over and over again with some kind of a `reset` method.
That's exactly what LGBM is doing, BTW. I think it would make sense, yes: 500 trees = 500 array allocations of size n_samples, so it takes a toll. But that means that somehow the old grower instances aren't properly garbage collected?
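For a rough sense of scale (purely illustrative numbers, not taken from the benchmark: assume 1e7 training samples and 8-byte values), keeping all of those per-tree gradient and hessian arrays alive would cost:

```python
>>> n_trees, n_samples, bytes_per_value = 500, int(1e7), 8
>>> n_trees * n_samples * bytes_per_value * 2 / 1e9  # gradients + hessians, in GB
80.0
```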
This is what I suspect. Maybe numba is keeping some internal references on jitclass instances?
I tried with toy examples outside of pygbm and everything looks fine (no leak).
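For reference, a toy test of that kind might look like the sketch below (not pygbm code; the `Holder` class and the array size are made up, and it assumes numba and psutil are installed): allocate jitclass instances holding large arrays in a loop and watch the process RSS.

```python
import numpy as np
import psutil
from numba import float32
from numba.experimental import jitclass


@jitclass([('values', float32[::1])])
class Holder:  # hypothetical stand-in for a pygbm jitclass such as the splitting context
    def __init__(self, values):
        self.values = values


proc = psutil.Process()
for i in range(500):
    h = Holder(np.ones(10**6, dtype=np.float32))  # ~4 MB payload per instance
    del h  # should be freed here if numba releases jitclass memory properly
    if i % 100 == 0:
        print(f"iteration {i}: RSS = {proc.memory_info().rss / 1e6:.1f} MB")
```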
So that would suggest that the problem doesn't come from some garbage collector issue? I tried a quick hack:

```python
# reset grower
def reset(self, all_gradients, all_hessians, shrinkage=1.):
    self.splitting_context.reset(all_gradients, all_hessians)
    self.shrinkage = shrinkage
    del self.splittable_nodes  # just to make sure
    del self.finalized_leaves
    self.splittable_nodes = []
    self.finalized_leaves = []
    self.total_find_split_time = 0.  # time spent finding the best splits
    self.total_apply_split_time = 0.  # time spent splitting nodes
    self._intilialize_root()
    self.n_nodes = 1

# reset split info
def reset(self, all_gradients, all_hessians):
    self.all_gradients[:] = all_gradients
    self.all_hessians[:] = all_hessians
    self.sum_gradients = self.all_gradients.sum()
    self.sum_hessians = self.all_hessians.sum()
    self.ordered_gradients[:] = all_gradients.copy()
    self.ordered_hessians[:] = all_hessians.copy()
```

but I'm still observing the huge VIRT usage (about 12 GB while I only have 8) on the benchmark from #30 after ~100 trees. Note that this doesn't reuse the space allocated for the histograms, so maybe it comes from there, but I doubt it: for a given node the histograms take up a fairly small amount of memory.
https://mg.pov.lt/objgraph/ might be helpful for finding the cause of the memory growth.
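For example (a minimal sketch, assuming objgraph is installed), `objgraph.show_growth()` called between two boosting iterations prints the object types whose counts increased since the previous call:

```python
import objgraph

# The first call records a baseline; subsequent calls print which Python
# types gained instances since the last call (top 10 by growth).
objgraph.show_growth(limit=10)
```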
I tried a bit with objgraph and the only new Python-level object between 2 iterations of gradient boosting is a `TreePredictor`.
Yeah, but that makes sense, right? I am 90% sure that if everything were correctly garbage collected, we wouldn't have any leak. And the TreePredictor objects are super small anyway. So if we assume that CPython's GC is working, that means there's something going on with the numba interaction.

Putting this inside the main loop of `fit`:

```python
print(objgraph.at(id(grower)))
print(objgraph.at(id(grower.splitting_context)))
```

will give me a result for the grower but nothing for the splitting context. So this suggests that the SplittingContext isn't "tracked by the GC", at least according to objgraph.
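A quick way to double-check that from the standard library (sketch; `grower` here refers to the grower instance from the training loop) is `gc.is_tracked`, which reports whether the cyclic garbage collector knows about a given object:

```python
import gc

print(gc.is_tracked(grower))                    # True for ordinary Python objects
print(gc.is_tracked(grower.splitting_context))  # False would confirm it isn't tracked
```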
Interesting. This is really weird. We should try to come up with a minimal reproduction case by pruning the code of pygbm while still reproducing the memory "leak", until it's no longer specific to pygbm concepts and is reduced to basic numpy + numba primitives.
Fun stuff incoming. I patched the main training loop of `fit` to record the process memory (RSS) at each iteration:

```python
import psutil  # added at the top of the module

l = []
p = psutil.Process()
while True:  # <-- main training loop
    use = p.memory_info().rss
    print(f"{use / 1e6} MB")
    l.append(use)
    should_stop = self._stopping_criterion(
        ....
return self, l
```

and then I plot this list for various training set sizes:

```python
from sklearn.datasets import make_classification
from pygbm import GradientBoostingMachine
import matplotlib.pyplot as plt

sizes = [1e1, 1e2, 1e3, 1e4, 1e5]
n_trees = 1000
n_leaf_nodes = 20

fig, ax = plt.subplots(1)
for n_samples in sizes:
    n_samples = int(n_samples)
    X, y = make_classification(n_samples=n_samples, random_state=0)
    est = GradientBoostingMachine(learning_rate=1, max_iter=n_trees,
                                  max_leaf_nodes=n_leaf_nodes,
                                  random_state=0, scoring=None,
                                  verbose=1, validation_split=None)
    _, l = est.fit(X, y)
    ax.plot([x / 1e6 for x in l], '-', label=f'n_samples={n_samples}')

ax.set_ylabel('MB used')
ax.set_xlabel('iteration')
plt.legend()
plt.show()
```

(changes are on my `leak` branch). A few observations:
> This could be caused by a cyclic reference. You can break cyclic references by forcing a `gc.collect()`.

Note that I had already tried to add this inside the main loop of `fit`.
You should try with a smaller learning rate, e.g. 0.01. I noticed that with a large learning rate and many trees, the boosting algorithm can degenerate into producing small decision stumps (only 2 leaves). Maybe the additive gradient descent process has converged: the gain would then be too small to yield any improvement with larger trees. I have not checked.
The only data structures that could cause this growth are the histograms stored on each grower node. Assuming 500 approximately balanced trees (each with 255 nodes with stored histograms), that would result in about 22 GB if they are not collected (500 trees × 255 nodes × 28 features × 256 bins × 8 bytes × 3 fields):

```python
>>> 500 * 255 * 28 * 256 * 8 * 3 / 1e9
21.93408
```
I think I found the culprit: it seems to be the histograms.
Shouldn't this be 4 bytes instead of 8? But yeah, that would make sense: the leak in my plots is approximately 1.7 MB per iteration, which is close to the histogram size. Can't wait for the fix!! :)
Yes you are right: 4 bytes instead of 8. I had a preliminary fix but forgot to commit it...
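For reference, redoing the earlier back-of-the-envelope estimate with 4-byte histogram entries instead of 8 still gives on the order of 11 GB if nothing is ever collected:

```python
>>> 500 * 255 * 28 * 256 * 4 * 3 / 1e9
10.96704
```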
@NicolasHug you can have a look at #36.
As investigated in #30, when the number of trees is large, htop reports VIRT memory that exceeds the physical memory of the host, which causes the computer to hang and slows down computation.
What is weird is that the RES memory seems to stay constant.
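For anyone reproducing this: psutil (already used in the instrumentation above) exposes both numbers, with `rss` roughly corresponding to htop's RES and `vms` to VIRT. A minimal sketch:

```python
import psutil

mem = psutil.Process().memory_info()
print(f"RES (rss):  {mem.rss / 1e6:.1f} MB")  # resident set size: physical memory in use
print(f"VIRT (vms): {mem.vms / 1e6:.1f} MB")  # virtual memory size: address space reserved
```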