Skip to content

Commit

Permalink
[CINN] Fix TopK indices no-grad problem (#2935)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelius84 authored Sep 5, 2024
1 parent 71828c0 commit dfa3642
Show file tree
Hide file tree
Showing 16 changed files with 48 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def forward(
var_11 = paddle.tensor.search.where(var_9, var_10, var_8)
var_12 = var_2.flatten()
var_13 = var_11.flatten()
# bool/int tensors have no grad
var_12.stop_gradient = True
var_13.stop_gradient = True
return var_12, var_13


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def forward(
var_20 = paddle.tensor.search.where(var_18, var_19, var_11)
var_21 = var_2.flatten()
var_22 = var_20.flatten()
# bool/int tensors have no grad
var_21.stop_gradient = True
var_22.stop_gradient = True
return var_21, var_22


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def forward(
var_20 = paddle.tensor.search.where(var_18, var_19, var_11)
var_21 = var_2.flatten()
var_22 = var_20.flatten()
# bool/int tensors have no grad
var_21.stop_gradient = True
var_22.stop_gradient = True
return var_21, var_22


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def forward(
var_20 = paddle.tensor.search.where(var_18, var_19, var_11)
var_21 = var_2.flatten()
var_22 = var_20.flatten()
# bool/int tensors have no grad
var_21.stop_gradient = True
var_22.stop_gradient = True
return var_21, var_22


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def forward(
var_20 = paddle.tensor.search.where(var_18, var_19, var_11)
var_21 = var_2.flatten()
var_22 = var_20.flatten()
# bool/int tensors have no grad
var_21.stop_gradient = True
var_22.stop_gradient = True
return var_21, var_22


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def forward(
var_20 = paddle.tensor.search.where(var_18, var_19, var_11)
var_21 = var_2.flatten()
var_22 = var_20.flatten()
# bool/int tensors have no grad
var_21.stop_gradient = True
var_22.stop_gradient = True
return var_21, var_22


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def forward(
var_11 = paddle.tensor.search.where(var_9, var_10, var_8)
var_12 = var_2.flatten()
var_13 = var_11.flatten()
# bool/int tensors have no grad
var_12.stop_gradient = True
var_13.stop_gradient = True
return var_12, var_13


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def forward(
var_20 = paddle.tensor.search.where(var_18, var_19, var_11)
var_21 = var_2.flatten()
var_22 = var_20.flatten()
# bool/int tensors have no grad
var_21.stop_gradient = True
var_22.stop_gradient = True
return var_21, var_22


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def forward(
var_20 = paddle.tensor.search.where(var_18, var_19, var_11)
var_21 = var_2.flatten()
var_22 = var_20.flatten()
# bool/int tensors have no grad
var_21.stop_gradient = True
var_22.stop_gradient = True
return var_21, var_22


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def forward(
var_11 = paddle.tensor.search.where(var_9, var_10, var_8)
var_12 = var_2.flatten()
var_13 = var_11.flatten()
# bool/int tensors have no grad
var_12.stop_gradient = True
var_13.stop_gradient = True
return var_12, var_13


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def forward(
var_20 = paddle.tensor.search.where(var_18, var_19, var_11)
var_21 = var_2.flatten()
var_22 = var_20.flatten()
# bool/int tensors have no grad
var_21.stop_gradient = True
var_22.stop_gradient = True
return var_21, var_22


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def forward(
var_20 = paddle.tensor.search.where(var_18, var_19, var_11)
var_21 = var_2.flatten()
var_22 = var_20.flatten()
# bool/int tensors have no grad
var_21.stop_gradient = True
var_22.stop_gradient = True
return var_21, var_22


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def forward(
var_20 = paddle.tensor.search.where(var_18, var_19, var_11)
var_21 = var_2.flatten()
var_22 = var_20.flatten()
# bool/int tensors have no grad
var_21.stop_gradient = True
var_22.stop_gradient = True
return var_21, var_22


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def forward(
var_11 = paddle.tensor.search.where(var_9, var_10, var_8)
var_12 = var_2.flatten()
var_13 = var_11.flatten()
# bool/int tensors have no grad
var_12.stop_gradient = True
var_13.stop_gradient = True
return var_12, var_13


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def forward(
var_20 = paddle.tensor.search.where(var_18, var_19, var_11)
var_21 = var_2.flatten()
var_22 = var_20.flatten()
# bool/int tensors have no grad
var_21.stop_gradient = True
var_22.stop_gradient = True
return var_21, var_22


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def forward(
var_11 = paddle.tensor.search.where(var_9, var_10, var_8)
var_12 = var_2.flatten()
var_13 = var_11.flatten()
# bool/int tensors have no grad
var_12.stop_gradient = True
var_13.stop_gradient = True
return var_12, var_13


Expand Down

0 comments on commit dfa3642

Please sign in to comment.