Skip to content

Commit

Permalink
Use bf16-rounded sqrt for scaling embeddings to match Gemma
Browse files Browse the repository at this point in the history
Thanks Daniel & Michael Han for pointing this out.
https://unsloth.ai/blog/gemma-bugs

PiperOrigin-RevId: 615250003
  • Loading branch information
jan-wassenberg authored and copybara-github committed Mar 13, 2024
1 parent 0221956 commit 5fa2eb1
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions gemma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#ifndef GEMMA_ONCE
#define GEMMA_ONCE

#include <math.h> // sqrtf
#include <stddef.h>
#include <stdio.h>

Expand Down Expand Up @@ -426,6 +427,25 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
activations.ffw_out.data() + batch_idx * kModelDim, pool);
}

// __builtin_sqrt is not constexpr as of Clang 17.
#if HWY_COMPILER_GCC_ACTUAL && defined(HWY_HAVE_SCALAR_BF16_OPERATORS) && \
HWY_HAVE_SCALAR_BF16_OPERATORS
#define GEMMA_CONSTEXPR_SQRT constexpr
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
return __builtin_sqrt(x);
}
#else
#define GEMMA_CONSTEXPR_SQRT
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
#endif

template <typename TConfig>
GEMMA_CONSTEXPR_SQRT float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(TConfig::kModelDim))));
}

template <typename TConfig, size_t kBatchSize>
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
const CompressedWeights<TConfig>& c_weights,
Expand All @@ -434,8 +454,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
hwy::ThreadPool& inner_pool) {
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
static constexpr size_t kModelDim = TConfig::kModelDim;
static const float kEmbScaling =
static_cast<float>(sqrt(static_cast<double>(kModelDim)));
const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();

pool.Run(
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
Expand Down Expand Up @@ -490,12 +509,10 @@ void Transformer(int token, size_t pos,
static constexpr size_t kLayers = TConfig::kLayers;
static constexpr size_t kModelDim = TConfig::kModelDim;

static const float kEmbScaling =
static_cast<float>(sqrt(static_cast<double>(kModelDim)));

Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
activations.x.data(), kModelDim);

const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
MulByConst(kEmbScaling, activations.x.data(), kModelDim);

for (size_t layer = 0; layer < kLayers; ++layer) {
Expand Down

0 comments on commit 5fa2eb1

Please sign in to comment.