Skip to content

Commit

Permalink
Implement basic scaled operations for MLP model.
Browse files Browse the repository at this point in the history
TODO
  • Loading branch information
samho authored and balancap committed Nov 14, 2023
1 parent bba779d commit 6469c37
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,23 @@ def scaled_transpose(A: ScaledArray, permutation: Sequence[int]) -> ScaledArray:
@core.register_scaled_lax_op
def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
    """Elementwise product of two scaled arrays.

    Data and scale factor through the product independently, so the result
    needs no renormalization: (a*sA) * (b*sB) == (a*b) * (sA*sB).
    """
    out_data = A.data * B.data
    out_scale = A.scale * B.scale
    return ScaledArray(out_data, out_scale)


def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
    """Sum of two scaled arrays under a common renormalized scale.

    The output scale is the quadrature combination sqrt(sA**2 + sB**2);
    each operand's data is rebased onto that common scale before adding.
    """
    # NOTE(review): unlike `scaled_mul`, this op is not registered via
    # @core.register_scaled_lax_op -- confirm whether that is intentional.
    new_scale = lax.sqrt(A.scale**2 + B.scale**2)
    # TODO: check correct output dtype if data and scale precisions mismatch.
    rebased_a = A.data * (A.scale / new_scale)
    rebased_b = B.data * (B.scale / new_scale)
    return ScaledArray(rebased_a + rebased_b, new_scale)


def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray:
    """Difference of two scaled arrays under a common renormalized scale.

    Mirrors `scaled_add`: the output scale is sqrt(sA**2 + sB**2) and both
    operands are rebased onto it before subtracting.
    """
    new_scale = lax.sqrt(A.scale**2 + B.scale**2)
    # TODO: check correct output dtype if data and scale precisions mismatch.
    rebased_a = A.data * (A.scale / new_scale)
    rebased_b = B.data * (B.scale / new_scale)
    return ScaledArray(rebased_a - rebased_b, new_scale)


def scaled_dot(A: ScaledArray, B: ScaledArray) -> ScaledArray:
    """Dot product of two scaled arrays.

    The contraction over K = A.data.shape[-1] terms grows the output
    magnitude by roughly sqrt(K) for unit-variance data, so sqrt(K) is
    moved out of the data component and into the scale to keep the data
    well-normalized.
    """
    # NOTE(review): unlike `scaled_mul`, this op is not registered via
    # @core.register_scaled_lax_op -- confirm whether that is intentional.
    contract_dim = A.data.shape[-1]
    # Bug fix: `lax.sqrt` rejects integer inputs, and `shape[-1]` is a
    # Python int -- compute the square root as a plain Python float instead.
    sqrt_k = float(contract_dim) ** 0.5
    output_scale = A.scale * B.scale * sqrt_k
    output_data = lax.dot(A.data, B.data) / sqrt_k
    return ScaledArray(output_data, output_scale)

0 comments on commit 6469c37

Please sign in to comment.