Add Era of 1.58-bit LLMs BitLinear implementation
Adds an MIT-licensed ternary (1.58-bit) implementation of BitLinear, adapted from:
https://huggingface.co/1bitLLM/bitnet_b1_58-large/blob/main/utils_quant.py

Ternary BitLinear arXiv paper:
https://arxiv.org/abs/2402.17764
gkielian committed Apr 10, 2024
1 parent 5290cd4 commit c337068
Showing 4 changed files with 48 additions and 2 deletions.
2 changes: 1 addition & 1 deletion explorations/linear_sweep.json
@@ -11,7 +11,7 @@
"device": ["cuda"],
"dtype": ["float16"],
"dataset": ["shakespeare_char"],
"linear_variant": ["bitlinear", "bitlinear_optimized", "linear"],
"linear_variant": ["bitlinear_1p58", "bitlinear", "bitlinear_optimized", "linear"],
"compile": [true],
"softmax_variant_attn": ["softmax", "polymax"],
"tensorboard_run_name": ["linear_variation_sweep"]
2 changes: 1 addition & 1 deletion model.py
@@ -21,7 +21,7 @@
from variations.normalization_variations import LayerNorm, RMSNorm
from variations.position_encoding_variations import RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions
from variations.activation_variations import SquaredReLU, activation_dictionary
from variations.linear_variations import BitLinear, BitLinearOptimized, linear_dictionary
from variations.linear_variations import BitLinear1p58, BitLinear, BitLinearOptimized, linear_dictionary

def create_shared_param_group(layer_type, config):
shared_size = None
1 change: 1 addition & 0 deletions train.py
@@ -102,6 +102,7 @@ def parse_args():
choices=[
"linear",
"bitlinear",
"bitlinear_1p58",
"bitlinear_optimized",
],
)
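For reference, a likely invocation once this choice is available (the --linear_variant flag name is inferred from the argparse choices above and the "linear_variant" key in linear_sweep.json, so treat it as an assumption rather than part of the commit):

python train.py --dataset shakespeare_char --linear_variant bitlinear_1p58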
45 changes: 45 additions & 0 deletions variations/linear_variations.py
@@ -2,6 +2,50 @@
import torch.nn as nn
import math

class BitLinear1p58(nn.Linear):
    """ BitLinear layer from "The Era of 1-bit LLMs" paper (1.58-bit, ternary weights)
    Source: https://huggingface.co/1bitLLM/bitnet_b1_58-large/blob/main/utils_quant.py
    Source License: MIT
    Paper Link: https://arxiv.org/abs/2402.17764
    """

    def __init__(self, in_features, out_features, bias=True, num_groups=1):
        super().__init__(in_features, out_features, bias)

        # RMSNorm is placed outside BitLinear.
        # num_groups is accepted for interface compatibility but unused here.
        # weight_bits=1 follows the reference implementation; weight_quant below
        # actually produces ternary weights in {-1, 0, 1}.
        self.weight_bits = 1
        self.input_bits = 8

    def forward(self, x):
        # Straight-through estimator: the quantized values are used in the
        # forward pass, while gradients flow to the full-precision tensors.
        quant_input = x + (self.activation_quant(x, self.input_bits) - x).detach()
        quant_weight = self.weight + (self.weight_quant(self.weight, self.weight_bits) - self.weight).detach()

        out = nn.functional.linear(quant_input, quant_weight)
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)

        return out

    def weight_quant(self, weight, num_bits=1):
        # Ternary quantization: scale by the reciprocal of the mean absolute
        # value, round, clamp to {-1, 0, 1}, then rescale back.
        dtype = weight.dtype
        weight = weight.float()
        s = 1 / weight.abs().mean().clamp(min=1e-5)
        result = (weight * s).round().clamp(-1, 1) / s
        return result.type(dtype)

    def activation_quant(self, x, num_bits=8):
        # Per-token absmax quantization to signed num_bits integers,
        # immediately dequantized back to the original scale.
        dtype = x.dtype
        x = x.float()
        Qn = -2 ** (num_bits - 1)
        Qp = 2 ** (num_bits - 1) - 1
        s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        result = (x * s).round().clamp(Qn, Qp) / s
        return result.type(dtype)

class BitLinear(nn.Linear):
"""PyTorch BitLinear Layer
@@ -175,4 +219,5 @@ def forward(self, input):
"linear": nn.Linear,
"bitlinear": BitLinear,
"bitlinear_optimized": BitLinearOptimized,
"bitlinear_1p58": BitLinear1p58,
}
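A minimal sketch of how the new variant might be exercised once this commit is applied. The import path matches the file above; the toy sizes and the ternary check are illustrative only and not part of the commit:

import torch
from variations.linear_variations import linear_dictionary

# Instantiate the 1.58-bit variant through the dictionary used by the model code
layer = linear_dictionary["bitlinear_1p58"](64, 64, bias=False)
x = torch.randn(2, 8, 64)
y = layer(x)  # forward pass uses quantized weights/activations via the straight-through estimator

# weight_quant should yield at most three distinct values: {-1/s, 0, +1/s}
w_q = layer.weight_quant(layer.weight)
print(w_q.unique().numel())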
