Commit

jonatanklosko committed Mar 1, 2024
1 parent 9e2441d commit 8412007
Showing 4 changed files with 162 additions and 29 deletions.
1 change: 1 addition & 0 deletions lib/bumblebee.ex
@@ -167,6 +167,7 @@ defmodule Bumblebee do
"PhiModel" => {Bumblebee.Text.Phi, :base},
"PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling},
"PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification},
"PhiForTokenClassification" => {Bumblebee.Text.Phi, :for_token_classification},
"ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
"ResNetModel" => {Bumblebee.Vision.ResNet, :base},
"RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
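The new "PhiForTokenClassification" entry above maps the Hugging Face class name onto the `:for_token_classification` architecture of `Bumblebee.Text.Phi`. A minimal loading sketch, not part of this diff, using the tiny test checkpoint exercised by the test suite later in this commit:

# Sketch only: a checkpoint whose config declares "PhiForTokenClassification"
# now resolves to {Bumblebee.Text.Phi, :for_token_classification}.
{:ok, %{model: _model, params: _params, spec: spec}} =
  Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-PhiForTokenClassification"})

%Bumblebee.Text.Phi{architecture: :for_token_classification} = spec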
33 changes: 28 additions & 5 deletions lib/bumblebee/layers.ex
@@ -382,9 +382,21 @@ defmodule Bumblebee.Layers do
* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`
* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `false`
"""
def dense_transposed(%Axon{} = x, units, opts \\ []) do
opts = Keyword.validate!(opts, [:name, kernel_initializer: :glorot_uniform])
opts =
Keyword.validate!(opts, [
:name,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros,
use_bias: false
])

kernel_shape = fn input_shape ->
kernel_shape = Axon.Shape.dense_kernel(input_shape, units)
@@ -396,13 +408,24 @@ defmodule Bumblebee.Layers do
|> List.to_tuple()
end

bias_shape = &Axon.Shape.dense_bias(&1, units)

kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer])

op = fn x, kernel, _opts ->
Nx.dot(x, [-1], kernel, [1])
end
{inputs, op} =
if opts[:use_bias] do
bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer])
{[x, kernel, bias], &dense_transposed_impl/4}
else
{[x, kernel], &dense_transposed_impl/3}
end

Axon.layer(op, [x, kernel], name: opts[:name], op_name: :dense_transposed)
Axon.layer(op, inputs, name: opts[:name], op_name: :dense_transposed)
end

deftransformp dense_transposed_impl(x, kernel, bias \\ 0, _opts) do
Nx.dot(x, [-1], kernel, [1])
|> Nx.add(bias)
end

@doc """
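With `use_bias: true`, `dense_transposed/3` keeps its transposed `{units, in}` kernel layout and adds a `{units}` bias after the contraction, as `dense_transposed_impl/4` above shows. A standalone Nx sketch of the same computation, with illustrative shapes rather than library code:

# Illustrative shapes only; mirrors the Nx.dot/Nx.add pipeline in dense_transposed_impl/4.
x = Nx.iota({2, 4, 3}, type: :f32)        # {batch, seq, in}
kernel = Nx.iota({5, 3}, type: :f32)       # transposed layout: {units, in}
bias = Nx.broadcast(Nx.tensor(0.1), {5})   # {units}

out =
  x
  |> Nx.dot([-1], kernel, [1])             # contract the in axis of x with axis 1 of the kernel
  |> Nx.add(bias)                          # bias broadcasts over batch and seq

Nx.shape(out)
#=> {2, 4, 5}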
84 changes: 64 additions & 20 deletions lib/bumblebee/text/phi.ex
@@ -47,6 +47,14 @@ defmodule Bumblebee.Text.Phi do
default: :gelu_approx_tanh,
doc: "the activation function"
],
rotary_embedding_percentage: [
default: 0.5,
doc: "percentage of the query and keys that will have rotary embedding"
],
rotary_embedding_base: [
default: 10_000,
doc: "base for computing rotary embedding frequency"
],
layer_norm_epsilon: [
default: 1.0e-12,
doc: "the epsilon used by RMS normalization layers"
@@ -55,14 +63,6 @@ defmodule Bumblebee.Text.Phi do
default: 0.02,
doc:
"the standard deviation of the normal initializer used for initializing kernel parameters"
],
rotary_embedding_base: [
default: 10_000,
doc: "base for computing rotary embedding frequency"
],
rotary_embedding_percentage: [
default: 0.5,
doc: "percentage of the query and keys that will have rotary embedding"
]
] ++
Shared.common_options([
@@ -87,6 +87,10 @@ defmodule Bumblebee.Text.Phi do
classification head. The head returns logits corresponding to
possible classes
* `:for_token_classification` - Phi with a token classification
head. The head returns logits for each token in the original
sequence
## Inputs
* `"input_ids"` - `{batch_size, sequence_length}`
@@ -144,7 +148,8 @@ defmodule Bumblebee.Text.Phi do
do: [
:base,
:for_causal_language_modeling,
:for_sequence_classification
:for_sequence_classification,
:for_token_classification
]

@impl true
@@ -237,6 +242,28 @@ defmodule Bumblebee.Text.Phi do
})
end

def model(%__MODULE__{architecture: :for_token_classification} = spec) do
inputs = inputs(spec)
outputs = core(inputs, spec)

logits =
outputs.hidden_state
|> Axon.dropout(
rate: 0.1,
name: "token_classification_head.dropout"
)
|> Axon.dense(spec.num_labels,
kernel_initializer: kernel_initializer(spec),
name: "token_classification_head.output"
)

Layers.output(%{
logits: logits,
hidden_states: outputs.hidden_states,
attentions: outputs.attentions
})
end

defp inputs(spec) do
shape = {nil, nil}
hidden_shape = {nil, nil, spec.hidden_size}
@@ -332,7 +359,7 @@ defmodule Bumblebee.Text.Phi do
intermediate_size: spec.intermediate_size,
activation: spec.activation
],
block_type: :parallel,
block_type: &block_impl/3,
causal: true,
rotary_embedding: [
position_ids: position_ids,
@@ -350,13 +377,33 @@ defmodule Bumblebee.Text.Phi do
)
end

# :parallel block with attention norm applied earlier and without ffn norm
defp block_impl(hidden_state, steps, _name) do
shortcut = hidden_state

hidden_state = steps.self_attention_norm.(hidden_state)

{attention_hidden_state, attention_info} = steps.self_attention.(hidden_state)

{_hidden_state, cross_attention_info} =
steps.cross_attention_maybe.(hidden_state, fn _hidden_state ->
raise "cross attention not supported"
end)

ffn_hidden_state = steps.ffn.(hidden_state)

hidden_state = Axon.add([shortcut, attention_hidden_state, ffn_hidden_state])

{hidden_state, attention_info, cross_attention_info}
end

defp language_modeling_head(hidden_state, spec, opts) do
name = opts[:name]

# TODO: Tie to word embedding as spec option
Axon.dense(outputs.hidden_state, spec.vocab_size,
# TODO: Tie lm-head to word embedding as a spec option
Layers.dense_transposed(hidden_state, spec.vocab_size,
kernel_initializer: kernel_initializer(spec),
name: "language_modeling_head.output",
name: join(name, "output"),
use_bias: true
)
end
@@ -382,7 +429,7 @@ defmodule Bumblebee.Text.Phi do
rotary_embedding_base: {"rope_theta", number()},
rotary_embedding_percentage: {"partial_rotary_factor", number()},
initializer_scale: {"initializer_range", number()},
layer_norm_epsilon: {"rms_norm_eps", number()}
layer_norm_epsilon: {"layer_norm_eps", number()}
) ++ Shared.common_options_from_transformers(data, spec)

@for.config(spec, opts)
@@ -402,13 +449,10 @@ defmodule Bumblebee.Text.Phi do
"model.layers.{n}.self_attn.rotary_emb",
"decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.fc1",
"decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.fc2",
# TODO: This is a hack because their parallel layer norm implementation
# actually does no output norm, but instead uses the normalized hidden
# state. Though I thought it was overkill to add another block type?
"decoder.blocks.{n}.output_norm" => "model.layers.{n}.input_layernorm",
"output_norm" => "final_layernorm",
"output_norm" => "model.final_layernorm",
"language_modeling_head.output" => "lm_head",
"sequence_classification_head.output" => "score"
"sequence_classification_head.output" => "score",
"token_classification_head.output" => "classifier"
}
end
end
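With the `:for_token_classification` model above and the "token_classification_head.output" parameter mapping in place, a Phi token-classification checkpoint can also back the existing `Bumblebee.Text.token_classification/3` serving. A hedged usage sketch, not part of this diff; it assumes a tokenizer is published alongside the checkpoint:

# Assumes the checkpoint also ships a tokenizer; adjust the repository as needed.
repo = {:hf, "bumblebee-testing/tiny-random-PhiForTokenClassification"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)

serving = Bumblebee.Text.token_classification(model_info, tokenizer, aggregation: :same)

# Returns %{entities: [...]}, one predicted label per aggregated token span.
Nx.Serving.run(serving, "Elixir is a functional language")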
73 changes: 69 additions & 4 deletions test/bumblebee/text/phi_test.exs
@@ -7,7 +7,7 @@ defmodule Bumblebee.Text.PhiTest do

test ":base" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model({:hf, "seanmor5/tiny-random-phi"})
Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-PhiModel"})

assert %Bumblebee.Text.Phi{architecture: :base} = spec

@@ -22,9 +22,74 @@ defmodule Bumblebee.Text.PhiTest do

assert_all_close(
outputs.hidden_state[[.., 1..3, 1..3]],
Nx.tensor([
[[-0.4979, 0.0582, -0.3027], [0.0697, 0.5218, -0.5603], [-0.0904, 0.6243, -0.5626]]
])
Nx.tensor([[[-0.3275, 0.5231, 0.5690], [0.2239, 0.5028, 0.4599], [-0.0979, 1.0183, 0.3350]]])
)
end

test ":for_sequence_classification" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
{:hf, "bumblebee-testing/tiny-random-PhiForSequenceClassification"}
)

assert %Bumblebee.Text.Phi{architecture: :for_sequence_classification} = spec

inputs = %{
"input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
}

outputs = Axon.predict(model, params, inputs)

assert Nx.shape(outputs.logits) == {1, 2}

assert_all_close(
outputs.logits,
Nx.tensor([[0.1403, -0.1382]])
)
end

test ":for_token_classification" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
{:hf, "bumblebee-testing/tiny-random-PhiForTokenClassification"}
)

assert %Bumblebee.Text.Phi{architecture: :for_token_classification} = spec

inputs = %{
"input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
}

outputs = Axon.predict(model, params, inputs)

assert Nx.shape(outputs.logits) == {1, 10, 2}

assert_all_close(
outputs.logits[[.., 1..3//1, ..]],
Nx.tensor([[[-0.0364, -0.1207], [0.2520, 0.0755], [0.0243, 0.0269]]])
)
end

test ":for_causal_language_modeling" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-PhiForCausalLM"})

assert %Bumblebee.Text.Phi{architecture: :for_causal_language_modeling} = spec

inputs = %{
"input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
}

outputs = Axon.predict(model, params, inputs)

assert Nx.shape(outputs.logits) == {1, 10, 1024}

assert_all_close(
outputs.logits[[.., 1..3, 1..3]],
Nx.tensor([[[0.2541, 0.0827, 0.0526], [0.1901, 0.1289, 0.0758], [0.1051, 0.0658, -0.1167]]])
)
end
end
