Skip to content

Commit

Permalink
Merge pull request #181 from GilesStrong/refactor_remove_weighting
Browse files Browse the repository at this point in the history
Refactor remove weighting
  • Loading branch information
GilesStrong authored Feb 10, 2024
2 parents 56dd930 + a40c87c commit 6ef1dcc
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 190 deletions.
20 changes: 10 additions & 10 deletions tests/test_forwards_backwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,9 @@ def fixed_budget_sigmoid_panel_inferrer() -> PanelX0Inferrer:
],
)
def test_forwards_panel(mode, inferrer):
pred, weight = inferrer.get_prediction()
pred = inferrer.get_prediction()
loss_func = VoxelX0Loss(target_budget=None, cost_coef=None)
loss_val = loss_func(pred, weight, inferrer.volume)
loss_val = loss_func(pred, inferrer.volume)

if "fixed-budget" in mode:
assert torch.autograd.grad(loss_val, inferrer.volume.budget_weights, retain_graph=True, allow_unused=False)[0].abs().sum() > 0
Expand All @@ -263,9 +263,9 @@ def test_forwards_panel(mode, inferrer):


# def test_forwards_deep_panel(deep_inferrer):
# pred, weight = deep_inferrer.get_prediction()
# pred = deep_inferrer.get_prediction()
# loss_func = VolumeClassLoss(target_budget=1, cost_coef=1e-5, x02id={1: 1})
# loss_val = loss_func(pred, weight, deep_inferrer.volume)
# loss_val = loss_func(pred, deep_inferrer.volume)

# for l in deep_inferrer.volume.get_detectors():
# for p in l.panels:
Expand All @@ -284,9 +284,9 @@ def test_forwards_panel(mode, inferrer):
],
)
def test_backwards_panel(mode, inferrer):
pred, weight = inferrer.get_prediction()
pred = inferrer.get_prediction()
loss_func = VoxelX0Loss(target_budget=1, cost_coef=0.15)
loss_val = loss_func(pred, weight, inferrer.volume)
loss_val = loss_func(pred, inferrer.volume)
opt = torch.optim.SGD(inferrer.volume.parameters(), lr=1)
opt.zero_grad()
loss_val.backward()
Expand All @@ -310,15 +310,15 @@ def test_backwards_panel(mode, inferrer):

@pytest.mark.flaky(max_runs=3, min_passes=1)
def test_backwards_heatmap(heatmap_inferrer):
pred, weight = heatmap_inferrer.get_prediction()
pred = heatmap_inferrer.get_prediction()
init_params = defaultdict(lambda: defaultdict(dict))
for i, l in enumerate(heatmap_inferrer.volume.get_detectors()):
for j, p in enumerate(l.panels):
init_params[i][j]["mu"] = p.mu.detach().clone()
init_params[i][j]["sig"] = p.mu.detach().clone()

loss_func = VoxelX0Loss(target_budget=1, cost_coef=0.15)
loss_val = loss_func(pred, weight, heatmap_inferrer.volume)
loss_val = loss_func(pred, heatmap_inferrer.volume)
opt = torch.optim.SGD(heatmap_inferrer.volume.parameters(), lr=1)
opt.zero_grad()
loss_val.backward()
Expand All @@ -339,9 +339,9 @@ def test_backwards_heatmap(heatmap_inferrer):


# def test_backwards_deep_panel(deep_inferrer):
# pred, weight = deep_inferrer.get_prediction()
# pred = deep_inferrer.get_prediction()
# loss_func = VolumeClassLoss(target_budget=1, cost_coef=0.15, x02id={1: 1})
# loss_val = loss_func(pred, weight, deep_inferrer.volume)
# loss_val = loss_func(pred, deep_inferrer.volume)
# opt = torch.optim.SGD(deep_inferrer.volume.parameters(), lr=1)
# opt.zero_grad()
# loss_val.backward()
Expand Down
34 changes: 9 additions & 25 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,14 +431,13 @@ def test_panel_x0_inferrer_methods(mocker): # noqa F811
assert inferrer._combine_scatters.call_count == 1
assert inferrer._get_voxel_zxy_x0_preds.call_count == 1

p1, w1 = inferrer.get_prediction()
p1 = inferrer.get_prediction()
assert inferrer._combine_scatters.call_count == 1
assert inferrer._get_voxel_zxy_x0_preds.call_count == 1
assert isinstance(inferrer._muon_probs_per_voxel_zxy, Tensor)

true = volume.get_rad_cube()
assert p1.shape == true.shape
assert w1.shape == torch.Size([])
assert inferrer._muon_probs_per_voxel_zxy.shape == torch.Size([len(sb)]) + true.shape
assert (p1 != p1).sum() == 0 # No NaNs
assert (((p1 - true)).abs() / true).mean() < 100
Expand All @@ -448,8 +447,6 @@ def test_panel_x0_inferrer_methods(mocker): # noqa F811
assert torch.autograd.grad(p1.abs().sum(), panel.xy_span, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
assert torch.autograd.grad(p1.abs().sum(), panel.xy, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
assert torch.autograd.grad(p1.abs().sum(), panel.z, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
assert torch.autograd.grad(w1.abs().sum(), panel.xy_span, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
assert torch.autograd.grad(w1.abs().sum(), panel.xy, retain_graph=True, allow_unused=True)[0].abs().sum() > 0

# Multiple batches
mus = MuonResampler.resample(gen(N), volume=volume, gen=gen)
Expand All @@ -472,9 +469,8 @@ def test_panel_x0_inferrer_methods(mocker): # noqa F811
assert inferrer.n_mu == len(sb) + len(sb2)
assert inferrer._combine_scatters.call_count == 2

p2, w2 = inferrer.get_prediction() # Averaged prediction slightly changes with new batch
p2 = inferrer.get_prediction() # Averaged prediction slightly changes with new batch
assert (p2 - p1).abs().sum() > 1e-2
assert (w2 - w1).abs().sum() > 1e-2
assert inferrer._get_voxel_zxy_x0_preds.call_count == 2


Expand All @@ -488,7 +484,7 @@ def test_panel_inferrer_multi_batch():
# one batch
inf = PanelX0Inferrer(volume=volume)
inf.add_scatters(ScatterBatch(mu=mu, volume=volume))
pred1, weight1 = inf.get_prediction()
pred1 = inf.get_prediction()

# multi-batch
inf = PanelX0Inferrer(volume=volume)
Expand All @@ -503,10 +499,9 @@ def test_panel_inferrer_multi_batch():
for xy_pos in mu._hits[pos][var]:
mu_batch._hits[pos][var].append(xy_pos[mask])
inf.add_scatters(ScatterBatch(mu=mu_batch, volume=volume))
pred4, weight4 = inf.get_prediction()
pred4 = inf.get_prediction()

assert (((pred1 - pred4) / pred1).abs() < 1e-4).all()
assert (((weight1 - weight4) / weight1).abs() < 1e-4).all()


def test_panel_x0_inferrer_efficiency(mocker, panel_scatter_batch): # noqa F811
Expand Down Expand Up @@ -676,18 +671,14 @@ def get_passives(self):
# inputs = inferrer._build_inputs(inferrer.in_vars[0])
# assert inputs.shape == torch.Size((600, nvalid, n_infeats + 4)) # +4 since voxels and dpoca_r

# pred, weight = inferrer.get_prediction()
# pred = inferrer.get_prediction()
# assert pred.shape == torch.Size((1, 1))
# assert weight.shape == torch.Size(())

# for l in volume.get_detectors():
# for panel in l.panels:
# assert torch.autograd.grad(pred.abs().sum(), panel.xy_span, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
# assert torch.autograd.grad(pred.abs().sum(), panel.xy, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
# assert torch.autograd.grad(pred.abs().sum(), panel.z, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
# assert torch.autograd.grad(weight.abs().sum(), panel.xy_span, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
# assert torch.autograd.grad(weight.abs().sum(), panel.xy, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
# assert torch.autograd.grad(weight.abs().sum(), panel.z, retain_graph=True, allow_unused=True)[0].abs().sum() == 0


@pytest.mark.flaky(max_runs=2, min_passes=1)
Expand All @@ -709,14 +700,12 @@ def u_rad_length(*, z: float, lw: Tensor, size: float) -> Tensor:
inferrer = DenseBlockClassifierFromX0s(12, PanelX0Inferrer, volume=volume)
inferrer.add_scatters(sb)

p, w = inferrer.get_prediction()
p = inferrer.get_prediction()
for l in volume.get_detectors():
for panel in l.panels:
assert torch.autograd.grad(p.abs().sum(), panel.xy_span, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
assert torch.autograd.grad(p.abs().sum(), panel.xy, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
assert torch.autograd.grad(p.abs().sum(), panel.z, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
assert torch.autograd.grad(w.abs().sum(), panel.xy_span, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
assert torch.autograd.grad(w.abs().sum(), panel.xy, retain_graph=True, allow_unused=True)[0].abs().sum() > 0


def test_abs_int_classifier_from_x0():
Expand All @@ -735,30 +724,25 @@ def x02probs(self, vox_preds: Tensor) -> Tensor:
inferrer = Inf(partial_x0_inferrer=PanelX0Inferrer, volume=volume, output_probs=True)
inferrer.add_scatters(sb)

p, w = inferrer.get_prediction()
p = inferrer.get_prediction()
assert p.shape == torch.Size([6])
assert w.shape == torch.Size([])

for l in volume.get_detectors():
for panel in l.panels:
assert jacobian(p, panel.xy_span).abs().sum() > 0
assert jacobian(p, panel.xy).abs().sum() > 0
assert jacobian(p, panel.z).abs().sum() > 0
assert torch.autograd.grad(w.abs().sum(), panel.xy_span, retain_graph=True, allow_unused=True)[0].abs().sum() > 0
assert torch.autograd.grad(w.abs().sum(), panel.xy, retain_graph=True, allow_unused=True)[0].abs().sum() > 0

# Single prediction
inferrer = Inf(partial_x0_inferrer=PanelX0Inferrer, volume=volume, output_probs=False)
inferrer.add_scatters(sb)
p, w = inferrer.get_prediction()
p = inferrer.get_prediction()
assert p.type() == "torch.LongTensor"
assert p.shape == torch.Size([])
assert w.shape == torch.Size([])

# Single float prediction
inferrer = Inf(partial_x0_inferrer=PanelX0Inferrer, volume=volume, output_probs=False, class2float=lambda x, v: 3.5 * x)
inferrer.add_scatters(sb)
p, w = inferrer.get_prediction()
p = inferrer.get_prediction()
assert p.type() == "torch.FloatTensor"
assert p.shape == torch.Size([])
assert w.shape == torch.Size([])
Loading

0 comments on commit 6ef1dcc

Please sign in to comment.