Skip to content

Commit

Permalink
Deprecate aqt's make_dot_general.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623636468
  • Loading branch information
cdh4696 authored and copybara-github committed Apr 10, 2024
1 parent e27e953 commit 1fad619
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions chirp/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from typing import Callable, NamedTuple

from aqt.jax.v2 import aqt_conv_general
from aqt.jax.v2 import aqt_dot_general
from aqt.jax.v2 import config as aqt_cfg
from chirp.models import layers
from flax import linen as nn
import flax.typing as flax_typing
import jax
from jax import numpy as jnp


Expand Down Expand Up @@ -131,7 +131,7 @@ class OpSet:
sigmoid=nn.hard_sigmoid,
stem_activation=nn.hard_swish,
head_activation=nn.hard_swish,
dot_general=aqt_dot_general.make_dot_general(None),
dot_general=jax.lax.dot_general,
conv_general_dilated=aqt_conv_general.make_conv_general_dilated(
aqt_cfg.DotGeneralRaw.make_conv_general_dilated()
),
Expand Down

0 comments on commit 1fad619

Please sign in to comment.