Simpler MatMul interface, vocab types, Tristate for use_spinning
Add Extents2D, Range2D vocab types
MatMul uses ConstMat for inputs and RowPtr for output
Move RowVectorBatch to basics.h
Separate threading.cc
Fix topology string: report cores not LPs, and #HT
Move QStride/IsMHA into LayerConfig
ImageTokens does not require make_unique.
matmul_test: no longer require template args
PiperOrigin-RevId: 691460778
jan-wassenberg authored and copybara-github committed Nov 4, 2024
1 parent baaa221 commit efd1d63
Showing 26 changed files with 1,315 additions and 975 deletions.
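The headline change is the MatMul calling convention: inputs are passed as ConstMat views (pointer plus Extents2D, carrying the weight scale; see ConstMatFromWeights in compression/compress.h below) and the output as a RowPtr. Neither RowPtr nor MatMul's full signature is among the files shown here, so the following is only a rough sketch of the intended call shape; every name beyond ConstMatFromWeights is an assumption:

    // Hypothetical sketch; RowPtr's constructor and MatMul's parameter list
    // are assumptions, not taken from the files shown in this diff.
    ConstMat<float> a = ConstMatFromWeights(w1);  // w1, w2: MatPtrT<float>
    ConstMat<float> b = ConstMatFromWeights(w2);  // scale travels with the view
    RowPtr<float> c(out, /*stride=*/cols);        // writable output rows
    MatMul(a, b, /*add=*/nullptr, env, c);        // T deduced, no template args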
11 changes: 9 additions & 2 deletions BUILD.bazel
@@ -30,8 +30,11 @@ cc_library(
 
 cc_library(
     name = "threading",
+    srcs = ["util/threading.cc"],
     hdrs = ["util/threading.h"],
     deps = [
+        ":basics",
+        # Placeholder for container detection, do not remove
        "@highway//:hwy",
        "@highway//:thread_pool",
        "@highway//:topology",
@@ -173,7 +176,9 @@ cc_test(
     tags = ["hwy_ops_test"],
     deps = [
         ":allocator",
+        ":basics",
         ":ops",
+        ":test_util",
         ":threading",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//compression:compress",
@@ -280,6 +285,7 @@ cc_library(
         ":kv_cache",
         ":weights",
         ":threading",
+        "//compression:compress",
         "//compression:io",
         "//compression:sfp",
         "//paligemma:image",
@@ -307,6 +313,7 @@ cc_library(
     name = "args",
     hdrs = ["util/args.h"],
     deps = [
+        ":basics",
         "//compression:io",
         "@highway//:hwy",
     ],
@@ -317,6 +324,7 @@ cc_library(
     hdrs = ["util/app.h"],
     deps = [
         ":args",
+        ":basics",
         ":common",
         ":gemma_lib",
         ":threading",
@@ -342,8 +350,6 @@ cc_library(
         "//compression:compress",
         "@highway//:hwy",
         "@highway//:nanobenchmark",
-        "@highway//:thread_pool",
-        "@highway//:topology",
     ],
 )

@@ -583,6 +589,7 @@ cc_test(
     },
     deps = [
         ":backprop",
+        ":basics",
         ":common",
         ":gemma_lib",
         ":optimizer",
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -101,6 +101,7 @@ set(SOURCES
   util/args.h
   util/basics.h
   util/test_util.h
+  util/threading.cc
   util/threading.h
 )

4 changes: 3 additions & 1 deletion backprop/optimize_test.cc
@@ -33,13 +33,15 @@
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
+#include "util/basics.h"
 #include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
 TEST(OptimizeTest, GradientDescent) {
-  NestedPools pools(1, /*pin=*/0, BoundedSlice(0, 1), BoundedSlice(0, 1));
+  NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
+                    BoundedSlice(0, 1));
   hwy::ThreadPool& pool = pools.Pool();
   std::mt19937 gen(42);
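NestedPools' pin argument changes from a bare integer to a Tristate, making "unset" representable and call sites self-documenting. Only Tristate::kFalse appears in the files shown; a minimal sketch of such a flag, assuming the conventional default/false/true triple (the real definition presumably lives in util/basics.h):

    // Sketch only; enumerator set and values beyond kFalse are assumptions.
    enum class Tristate : int { kDefault = -1, kFalse = 0, kTrue = 1 };

Call sites then read unambiguously, as in the test above: /*pin=*/Tristate::kFalse rather than /*pin=*/0.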
25 changes: 23 additions & 2 deletions compression/compress.h
@@ -33,6 +33,7 @@
 #include "compression/blob_store.h"
 #include "compression/io.h"
 #include "compression/shared.h"
+#include "util/basics.h"
 // IWYU pragma: end_exports
 #include "util/allocator.h"
 #if COMPRESS_STATS
@@ -62,7 +63,9 @@ class MatPtr {
         num_elements_(rows * cols),
         rows_(rows),
         cols_(cols),
-        ptr_(nullptr) {}
+        ptr_(nullptr) {
+    stride_ = cols;
+  }
   // Default is to leave all fields default-initialized.
   MatPtr() = default;
   virtual ~MatPtr();
@@ -85,7 +88,9 @@ class MatPtr {
         element_size_(key2.hi),
         num_elements_(key2.lo),
         rows_(key3.lo),
-        cols_(key3.hi) {}
+        cols_(key3.hi) {
+    stride_ = cols_;
+  }
 
   // Adds the contents entry to the table of contents.
   void AddToToc(std::vector<hwy::uint128_t>& toc) const {
@@ -137,6 +142,12 @@ class MatPtr {
   // Returns the number of columns in the 2-d array (inner dimension).
   size_t Cols() const { return cols_; }
 
+  Extents2D Extents() const { return Extents2D(rows_, cols_); }
+
+  // Currently same as cols, but may differ in the future. This is the offset by
+  // which to advance pointers to the next row.
+  size_t Stride() const { return stride_; }
+
   // Decoded elements should be multiplied by this to restore their original
   // range. This is required because SfpStream can only encode a limited range
   // of magnitudes.
@@ -187,6 +198,8 @@ class MatPtr {
   // freed. The underlying memory is owned by a subclass or some external class
   // and must outlive this object.
   void* ptr_ = nullptr;
+
+  size_t stride_;
 };
 
 // MatPtrT adds a single template argument to MatPtr for an explicit type.
@@ -288,7 +301,15 @@ decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
   }
 }
 
+template <typename T>
+ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
+  ConstMat<T> mat = MakeConstMat(const_cast<T*>(m.data()), m.Extents(), ofs);
+  mat.scale = m.scale();
+  return mat;
+}
+
 // MatStorageT adds the actual data storage to MatPtrT.
+// TODO: use Extents2D instead of rows and cols.
 template <typename MatT>
 class MatStorageT : public MatPtrT<MatT> {
  public:
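The new Stride() accessor decouples row advancement from Cols(): the two are equal today, but callers that advance row pointers by Stride() keep working if padding is ever inserted between rows. A short sketch of the intended iteration pattern, assuming MatPtrT's usual data()/Rows() accessors:

    // Sketch: advance by Stride(), touch only the Cols() valid elements.
    void ScaleRows(MatPtrT<float>& m, float factor) {
      float* row = m.data();  // assumption: data() returns T*
      for (size_t r = 0; r < m.Rows(); ++r, row += m.Stride()) {
        for (size_t c = 0; c < m.Cols(); ++c) row[c] *= factor;
      }
    }

ConstMatFromWeights is the bridge to the new MatMul interface: it wraps a weight matrix as a ConstMat and forwards its scale, so callers no longer pass raw pointers, dimensions, and scale separately.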
8 changes: 6 additions & 2 deletions compression/shared.h
@@ -267,8 +267,12 @@ struct PackedSpan {
     // check the compressed count and ensure we have that many.
     const size_t required =
         CompressedArrayElements<Packed>(packed_ofs + num_accessible);
-    HWY_DASSERT(num >= required);
-    (void)required;
+    if constexpr (HWY_IS_DEBUG_BUILD) {
+      if (num < required) {
+        HWY_ABORT("PackedSpan: ofs %zu, want %zu, req %zu > %zu packed",
+                  packed_ofs, num_accessible, required, num);
+      }
+    }
   }
 
   Packed* HWY_RESTRICT ptr;
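The bare HWY_DASSERT becomes a check gated on HWY_IS_DEBUG_BUILD that aborts with the actual offsets and sizes, which is much easier to act on than a plain assertion failure. The same pattern, reduced to a self-contained sketch without the Highway macros:

    #include <cstddef>
    #include <cstdio>
    #include <cstdlib>

    #ifndef NDEBUG
    inline constexpr bool kDebugBuild = true;
    #else
    inline constexpr bool kDebugBuild = false;
    #endif

    // Debug-only bounds check: the branch is never taken in release builds,
    // while debug builds report the offending values before aborting.
    inline void CheckAccessible(size_t ofs, size_t want, size_t required,
                                size_t num) {
      if constexpr (kDebugBuild) {
        if (num < required) {
          std::fprintf(stderr, "span: ofs %zu, want %zu, req %zu > %zu\n",
                       ofs, want, required, num);
          std::abort();
        }
      }
    }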
4 changes: 2 additions & 2 deletions evals/benchmark_helper.cc
@@ -229,12 +229,12 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
   fprintf(stderr,
           "Date & Time : %s"  // dt includes \n
           "CPU : %s\n"
-          "CPU topology : %s\n"
+          "CPU topology : %s, %s\n"
           "Instruction set : %s (%zu bits)\n"
           "Compiled config : %s\n"
           "Weight Type : %s\n"
           "EmbedderInput Type : %s\n",
-          dt, cpu100, pools.TopologyString(),
+          dt, cpu100, pools.TopologyString(), pools.PinString(),
           hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8,
           CompiledConfig(), StringFromType(loader.Info().weight),
           TypeName<EmbedderInputT>());
45 changes: 20 additions & 25 deletions gemma/activations.h
@@ -72,18 +72,11 @@ struct Activations {
   size_t seq_len;
   size_t cache_pos_size = 0;
 
-  // Multi-Head Attention?
-  bool IsMHA() const { return layer_config.heads == layer_config.kv_heads; }
-
-  // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
-  // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
-  size_t QStride() const { return layer_config.qkv_dim * (IsMHA() ? 3 : 1); }
-
   static RowVectorBatch<float> CreateInvTimescale(size_t qkv_dim,
                                                   PostQKType post_qk) {
     const size_t rope_dim =
         post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim;
-    RowVectorBatch<float> inv_timescale(1, rope_dim / 2);
+    RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
     for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
       const float freq_exponents =
           static_cast<float>(2 * dim) / static_cast<float>(rope_dim);
@@ -100,29 +93,31 @@
     const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
     const size_t vocab_size = weights_config.vocab_size;
 
-    x = RowVectorBatch<float>(batch_size, model_dim);
-    q = RowVectorBatch<float>(batch_size, layer_config.heads * QStride());
+    x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+    q = RowVectorBatch<float>(
+        Extents2D(batch_size, layer_config.heads * layer_config.QStride()));
     if (vocab_size > 0) {
-      logits = RowVectorBatch<float>(batch_size, vocab_size);
+      logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
     }
 
-    pre_att_rms_out = RowVectorBatch<float>(batch_size, model_dim);
-    att = RowVectorBatch<float>(batch_size,
-                                layer_config.heads * weights_config.seq_len);
-    att_out = RowVectorBatch<float>(batch_size,
-                                    layer_config.heads * layer_config.qkv_dim);
-    att_sums = RowVectorBatch<float>(batch_size, model_dim);
+    pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+    att = RowVectorBatch<float>(
+        Extents2D(batch_size, layer_config.heads * weights_config.seq_len));
+    att_out = RowVectorBatch<float>(
+        Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim));
+    att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
 
-    bf_pre_ffw_rms_out = RowVectorBatch<BF16>(batch_size, model_dim);
-    C1 = RowVectorBatch<float>(batch_size, ff_hidden_dim);
-    C2 = RowVectorBatch<float>(batch_size, ff_hidden_dim);
-    ffw_out = RowVectorBatch<float>(batch_size, model_dim);
+    bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
+    C1 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
+    C2 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
+    ffw_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
 
     if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
-      griffin_x = RowVectorBatch<float>(batch_size, model_dim);
-      griffin_y = RowVectorBatch<float>(batch_size, model_dim);
-      griffin_gate_x = RowVectorBatch<float>(batch_size, model_dim);
-      griffin_multiplier = RowVectorBatch<float>(batch_size, model_dim);
+      griffin_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+      griffin_y = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+      griffin_gate_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+      griffin_multiplier =
+          RowVectorBatch<float>(Extents2D(batch_size, model_dim));
     }
 
     inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);
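RowVectorBatch is now constructed from a single Extents2D instead of separate (rows, cols) arguments, which keeps the two dimensions from being swapped and matches the new MatPtr::Extents(). Extents2D itself is not among the files shown (the commit adds it to util/basics.h alongside Range2D); judged purely from its call sites, it is approximately:

    // Sketch inferred from call sites; the real type in util/basics.h may
    // differ, and the Area() helper here is an assumption.
    struct Extents2D {
      Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {}
      size_t Area() const { return rows * cols; }
      size_t rows;
      size_t cols;
    };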
7 changes: 7 additions & 0 deletions gemma/configs.h
@@ -119,6 +119,13 @@ enum class Model {
 struct LayerConfig {
   size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }
 
+  // Multi-Head Attention?
+  bool IsMHA() const { return heads == kv_heads; }
+
+  // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
+  // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
+  size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
+
   size_t model_dim = 0;
   size_t griffin_dim = 0;
   size_t ff_hidden_dim = 0;
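Moving IsMHA() and QStride() from Activations onto LayerConfig lets any code holding a layer config answer the layout question, not just the activation buffers. Concretely, with illustrative values and qkv_dim = 256:

    LayerConfig mha;
    mha.heads = 8; mha.kv_heads = 8; mha.qkv_dim = 256;
    // mha.IsMHA() == true,  mha.QStride() == 256 * 3 == 768:
    // Q, K and V are interleaved per query as Q,K,V, Q,K,V, ...

    LayerConfig gqa;
    gqa.heads = 8; gqa.kv_heads = 2; gqa.qkv_dim = 256;
    // gqa.IsMHA() == false, gqa.QStride() == 256: K and V are stored apart.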
