diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py index 7ec097d806..c150191695 100644 --- a/src/adapters/methods/bottleneck.py +++ b/src/adapters/methods/bottleneck.py @@ -215,7 +215,7 @@ def repeat(self, state: BottleneckState, channels: int) -> BottleneckState: state.bottleneck_up.repeat(channels, 1, 1) if state.bottleneck_up is not None else None, ) - def mean(self, states: List[NamedTuple], weights: torch.Tensor) -> NamedTuple: + def mean(self, states: List[BottleneckState], weights: torch.Tensor) -> BottleneckState: return BottleneckState( torch.mean(torch.stack([s.hidden_states for s in states], 0) * weights, dim=0), states[0].input_tensor,