diff --git a/PyITA/ITA.py b/PyITA/ITA.py index 039771a..6d6f390 100644 --- a/PyITA/ITA.py +++ b/PyITA/ITA.py @@ -556,7 +556,7 @@ def soft(self, no_partial_softmax = False): def step5_AV(self): self.O_soft = np.array( - [np.matmul(self.A_partial_softmax[i], self.Vp_requant[i], dtype = np.int32) for i in range(self.H)]) + [np.matmul(self.A_partial_softmax[i].astype(np.uint8), self.Vp_requant[i], dtype = np.int32) for i in range(self.H)]) self.O_soft = np.clip(self.O_soft, -2**(self.WO - 1), 2**(self.WO - 1) - 1) self.O_soft_requant = requantize(self.O_soft, self.requant_eps_mult[4], self.requant_right_shift[4], self.requant_add[4])