From 1fad61979d1614e99b5ea159cc6b0ea38a8fe378 Mon Sep 17 00:00:00 2001 From: DongHyun Choi Date: Wed, 10 Apr 2024 16:18:58 -0700 Subject: [PATCH] Deprecate aqt's make_dot_general. PiperOrigin-RevId: 623636468 --- chirp/models/efficientnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chirp/models/efficientnet.py b/chirp/models/efficientnet.py index 122a40d3..4e034493 100644 --- a/chirp/models/efficientnet.py +++ b/chirp/models/efficientnet.py @@ -23,11 +23,11 @@ from typing import Callable, NamedTuple from aqt.jax.v2 import aqt_conv_general -from aqt.jax.v2 import aqt_dot_general from aqt.jax.v2 import config as aqt_cfg from chirp.models import layers from flax import linen as nn import flax.typing as flax_typing +import jax from jax import numpy as jnp @@ -131,7 +131,7 @@ class OpSet: sigmoid=nn.hard_sigmoid, stem_activation=nn.hard_swish, head_activation=nn.hard_swish, - dot_general=aqt_dot_general.make_dot_general(None), + dot_general=jax.lax.dot_general, conv_general_dilated=aqt_conv_general.make_conv_general_dilated( aqt_cfg.DotGeneralRaw.make_conv_general_dilated() ),