Skip to content

Commit

Permalink
Fixing Package Compiling issue
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed Jan 19, 2024
1 parent 1b04902 commit 8a72c9e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from jax import tree_util
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask_info as mask_info_lib
from . import splash_attention_mask as mask_lib
from . import splash_attention_mask_info as mask_info_lib
import jax.numpy as jnp
import numpy as np

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import functools
from typing import Callable, Dict, List, NamedTuple, Set, Tuple
from jax import util as jax_util
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
from . import splash_attention_mask as mask_lib
import numpy as np


Expand Down

0 comments on commit 8a72c9e

Please sign in to comment.