diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index 3c3fb000..a78eedea 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -167,6 +167,7 @@ defmodule Bumblebee do
     "PhiModel" => {Bumblebee.Text.Phi, :base},
     "PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling},
     "PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification},
+    "PhiForTokenClassification" => {Bumblebee.Text.Phi, :for_token_classification},
     "ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
     "ResNetModel" => {Bumblebee.Vision.ResNet, :base},
     "RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex
index bb2a52f7..4bc4a988 100644
--- a/lib/bumblebee/layers.ex
+++ b/lib/bumblebee/layers.ex
@@ -382,9 +382,21 @@ defmodule Bumblebee.Layers do
     * `:kernel_initializer` - initializer for `kernel` weights. Defaults
       to `:glorot_uniform`
 
+    * `:bias_initializer` - initializer for `bias` weights. Defaults
+      to `:zeros`
+
+    * `:use_bias` - whether the layer should add bias to the output.
+      Defaults to `false`
+
   """
   def dense_transposed(%Axon{} = x, units, opts \\ []) do
-    opts = Keyword.validate!(opts, [:name, kernel_initializer: :glorot_uniform])
+    opts =
+      Keyword.validate!(opts, [
+        :name,
+        kernel_initializer: :glorot_uniform,
+        bias_initializer: :zeros,
+        use_bias: false
+      ])
 
     kernel_shape = fn input_shape ->
       kernel_shape = Axon.Shape.dense_kernel(input_shape, units)
@@ -396,13 +408,24 @@ defmodule Bumblebee.Layers do
       |> List.to_tuple()
     end
 
+    bias_shape = &Axon.Shape.dense_bias(&1, units)
+
     kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
 
-    op = fn x, kernel, _opts ->
-      Nx.dot(x, [-1], kernel, [1])
-    end
+    {inputs, op} =
+      if opts[:use_bias] do
+        bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer])
+        {[x, kernel, bias], &dense_transposed_impl/4}
+      else
+        {[x, kernel], &dense_transposed_impl/3}
+      end
 
-    Axon.layer(op, [x, kernel], name: opts[:name], op_name: :dense_transposed)
+    Axon.layer(op, inputs, name: opts[:name], op_name: :dense_transposed)
+  end
+
+  deftransformp dense_transposed_impl(x, kernel, bias \\ 0, _opts) do
+    Nx.dot(x, [-1], kernel, [1])
+    |> Nx.add(bias)
   end
 
   @doc """
diff --git a/lib/bumblebee/text/phi.ex b/lib/bumblebee/text/phi.ex
index de65145c..6f4f5828 100644
--- a/lib/bumblebee/text/phi.ex
+++ b/lib/bumblebee/text/phi.ex
@@ -47,6 +47,14 @@ defmodule Bumblebee.Text.Phi do
         default: :gelu_approx_tanh,
         doc: "the activation function"
       ],
+      rotary_embedding_percentage: [
+        default: 0.5,
+        doc: "percentage of the query and keys that will have rotary embedding"
+      ],
+      rotary_embedding_base: [
+        default: 10_000,
+        doc: "base for computing rotary embedding frequency"
+      ],
       layer_norm_epsilon: [
         default: 1.0e-12,
         doc: "the epsilon used by RMS normalization layers"
@@ -55,14 +63,6 @@ defmodule Bumblebee.Text.Phi do
         default: 0.02,
         doc:
           "the standard deviation of the normal initializer used for initializing kernel parameters"
-      ],
-      rotary_embedding_base: [
-        default: 10_000,
-        doc: "base for computing rotary embedding frequency"
-      ],
-      rotary_embedding_percentage: [
-        default: 0.5,
-        doc: "percentage of the query and keys that will have rotary embedding"
       ]
     ] ++
       Shared.common_options([
@@ -87,6 +87,10 @@ defmodule Bumblebee.Text.Phi do
       classification head. The head returns logits corresponding to
       possible classes
 
+    * `:for_token_classification` - Phi with a token classification
+      head. The head returns logits for each token in the original
+      sequence
+
   ## Inputs
 
     * `"input_ids"` - `{batch_size, sequence_length}`
@@ -144,7 +148,8 @@ defmodule Bumblebee.Text.Phi do
     do: [
       :base,
       :for_causal_language_modeling,
-      :for_sequence_classification
+      :for_sequence_classification,
+      :for_token_classification
     ]
 
   @impl true
@@ -237,6 +242,28 @@ defmodule Bumblebee.Text.Phi do
     })
   end
 
+  def model(%__MODULE__{architecture: :for_token_classification} = spec) do
+    inputs = inputs(spec)
+    outputs = core(inputs, spec)
+
+    logits =
+      outputs.hidden_state
+      |> Axon.dropout(
+        rate: 0.1,
+        name: "token_classification_head.dropout"
+      )
+      |> Axon.dense(spec.num_labels,
+        kernel_initializer: kernel_initializer(spec),
+        name: "token_classification_head.output"
+      )
+
+    Layers.output(%{
+      logits: logits,
+      hidden_states: outputs.hidden_states,
+      attentions: outputs.attentions
+    })
+  end
+
   defp inputs(spec) do
     shape = {nil, nil}
     hidden_shape = {nil, nil, spec.hidden_size}
@@ -332,7 +359,7 @@ defmodule Bumblebee.Text.Phi do
         intermediate_size: spec.intermediate_size,
         activation: spec.activation
       ],
-      block_type: :parallel,
+      block_type: &block_impl/3,
       causal: true,
       rotary_embedding: [
        position_ids: position_ids,
@@ -350,13 +377,33 @@ defmodule Bumblebee.Text.Phi do
     )
   end
 
+  # :parallel block with attention norm applied earlier and without ffn norm
+  defp block_impl(hidden_state, steps, _name) do
+    shortcut = hidden_state
+
+    hidden_state = steps.self_attention_norm.(hidden_state)
+
+    {attention_hidden_state, attention_info} = steps.self_attention.(hidden_state)
+
+    {_hidden_state, cross_attention_info} =
+      steps.cross_attention_maybe.(hidden_state, fn _hidden_state ->
+        raise "cross attention not supported"
+      end)
+
+    ffn_hidden_state = steps.ffn.(hidden_state)
+
+    hidden_state = Axon.add([shortcut, attention_hidden_state, ffn_hidden_state])
+
+    {hidden_state, attention_info, cross_attention_info}
+  end
+
   defp language_modeling_head(hidden_state, spec, opts) do
     name = opts[:name]
 
-    # TODO: Tie to word embedding as spec option
-    Axon.dense(outputs.hidden_state, spec.vocab_size,
+    # TODO: Tie lm-head to word embedding as a spec option
+    Layers.dense_transposed(hidden_state, spec.vocab_size,
       kernel_initializer: kernel_initializer(spec),
-      name: "language_modeling_head.output",
+      name: join(name, "output"),
       use_bias: true
     )
   end
@@ -382,7 +429,7 @@ defmodule Bumblebee.Text.Phi do
           rotary_embedding_base: {"rope_theta", number()},
           rotary_embedding_percentage: {"partial_rotary_factor", number()},
           initializer_scale: {"initializer_range", number()},
-          layer_norm_epsilon: {"rms_norm_eps", number()}
+          layer_norm_epsilon: {"layer_norm_eps", number()}
         ) ++ Shared.common_options_from_transformers(data, spec)
 
       @for.config(spec, opts)
@@ -402,13 +449,10 @@ defmodule Bumblebee.Text.Phi do
           "model.layers.{n}.self_attn.rotary_emb",
         "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.fc1",
         "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.fc2",
-        # TODO: This is a hack because their parallel layer norm implementation
-        # actually does no output norm, but instead uses the normalized hidden
-        # state. Though I thought it was overkill to add another block type?
-        "decoder.blocks.{n}.output_norm" => "model.layers.{n}.input_layernorm",
-        "output_norm" => "final_layernorm",
+        "output_norm" => "model.final_layernorm",
         "language_modeling_head.output" => "lm_head",
-        "sequence_classification_head.output" => "score"
+        "sequence_classification_head.output" => "score",
+        "token_classification_head.output" => "classifier"
       }
     end
   end
diff --git a/test/bumblebee/text/phi_test.exs b/test/bumblebee/text/phi_test.exs
index 4249b37f..83a2f3d4 100644
--- a/test/bumblebee/text/phi_test.exs
+++ b/test/bumblebee/text/phi_test.exs
@@ -7,7 +7,7 @@ defmodule Bumblebee.Text.PhiTest do
 
   test ":base" do
     assert {:ok, %{model: model, params: params, spec: spec}} =
-             Bumblebee.load_model({:hf, "seanmor5/tiny-random-phi"})
+             Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-PhiModel"})
 
     assert %Bumblebee.Text.Phi{architecture: :base} = spec
 
@@ -22,9 +22,74 @@ defmodule Bumblebee.Text.PhiTest do
 
     assert_all_close(
       outputs.hidden_state[[.., 1..3, 1..3]],
-      Nx.tensor([
-        [[-0.4979, 0.0582, -0.3027], [0.0697, 0.5218, -0.5603], [-0.0904, 0.6243, -0.5626]]
-      ])
+      Nx.tensor([[[-0.3275, 0.5231, 0.5690], [0.2239, 0.5028, 0.4599], [-0.0979, 1.0183, 0.3350]]])
     )
   end
+
+  test ":for_sequence_classification" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "bumblebee-testing/tiny-random-PhiForSequenceClassification"}
+             )
+
+    assert %Bumblebee.Text.Phi{architecture: :for_sequence_classification} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.logits) == {1, 2}
+
+    assert_all_close(
+      outputs.logits,
+      Nx.tensor([[0.1403, -0.1382]])
+    )
+  end
+
+  test ":for_token_classification" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "bumblebee-testing/tiny-random-PhiForTokenClassification"}
+             )
+
+    assert %Bumblebee.Text.Phi{architecture: :for_token_classification} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.logits) == {1, 10, 2}
+
+    assert_all_close(
+      outputs.logits[[.., 1..3//1, ..]],
+      Nx.tensor([[[-0.0364, -0.1207], [0.2520, 0.0755], [0.0243, 0.0269]]])
+    )
+  end
+
+  test ":for_causal_language_modeling" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-PhiForCausalLM"})
+
+    assert %Bumblebee.Text.Phi{architecture: :for_causal_language_modeling} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.logits) == {1, 10, 1024}
+
+    assert_all_close(
+      outputs.logits[[.., 1..3, 1..3]],
+      Nx.tensor([[[0.2541, 0.0827, 0.0526], [0.1901, 0.1289, 0.0758], [0.1051, 0.0658, -0.1167]]])
+    )
+  end
 end