diff --git a/torch_frame/data/mapper.py b/torch_frame/data/mapper.py index 3c388f79..c019b364 100644 --- a/torch_frame/data/mapper.py +++ b/torch_frame/data/mapper.py @@ -110,7 +110,7 @@ def forward( def backward(self, tensor: Tensor) -> pd.Series: index = tensor.cpu().numpy() - ser = pd.Series(self.categories[index].index) + ser = pd.Series(self.categories.iloc[index].index) ser[index < 0] = None return ser