Merge pull request #755 from TransformerLensOrg/main
Upstream update
bryce13950 authored Oct 16, 2024
2 parents 70029b9 + 336df99 commit 0dbc7a8
Showing 1 changed file with 20 additions and 5 deletions.
transformer_lens/utils.py
@@ -312,8 +312,6 @@ def tokenize_and_concatenate(
     Returns:
         Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens"
-
-    Note: There is a bug when inputting very small datasets (eg, <1 batch per process) where it just outputs nothing. I'm not super sure why
     """
     dataset = keep_single_column(dataset, column_name)
     if tokenizer.pad_token is None:
@@ -329,6 +327,11 @@ def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
         text = examples[column_name]
         # Concatenate it all into an enormous string, separated by eos_tokens
         full_text = tokenizer.eos_token.join(text)
+
+        # Handle the case when full_text is empty
+        if not full_text.strip():
+            return {"tokens": np.array([], dtype=np.int64)}
+
         # Divide into 20 chunks of ~ equal length
         num_chunks = 20
         chunk_length = (len(full_text) - 1) // num_chunks + 1
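
The guard added above short-circuits tokenize_function before any chunking happens when the joined text is empty or whitespace-only. A minimal standalone sketch of that branch (the join_and_guard helper and the eos_token default are illustrative, not part of the PR):

import numpy as np

def join_and_guard(texts, eos_token="<|endoftext|>"):
    # Mirrors the patched tokenize_function: join all rows into one string,
    # separated by the EOS token, then bail out if nothing remains to tokenize.
    full_text = eos_token.join(texts)
    if not full_text.strip():
        return {"tokens": np.array([], dtype=np.int64)}
    # ... chunking and tokenization would continue here in the real function
    return {"tokens": full_text}

print(join_and_guard([]))       # {'tokens': array([], dtype=int64)}
print(join_and_guard(["   "]))  # whitespace-only input also short-circuits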
@@ -338,9 +341,21 @@ def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
         # Drop padding tokens
         tokens = tokens[tokens != tokenizer.pad_token_id]
         num_tokens = len(tokens)
-        num_batches = num_tokens // (seq_len)
-        # Drop the final tokens if not enough to make a full sequence
-        tokens = tokens[: seq_len * num_batches]
+
+        # Handle cases where num_tokens is less than seq_len
+        if num_tokens < seq_len:
+            num_batches = 1
+            # Pad tokens if necessary
+            tokens = tokens[:seq_len]
+            if len(tokens) < seq_len:
+                padding_length = seq_len - len(tokens)
+                padding = np.full(padding_length, tokenizer.pad_token_id)
+                tokens = np.concatenate([tokens, padding], axis=0)
+        else:
+            num_batches = num_tokens // seq_len
+            # Drop the final tokens if not enough to make a full sequence
+            tokens = tokens[: seq_len * num_batches]
+
         tokens = einops.rearrange(
             tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
         )
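
Taken together, the new branch guarantees at least one full row: inputs shorter than seq_len are padded up to a single sequence with pad_token_id, while longer inputs keep the old drop-the-remainder behavior before the einops reshape. A self-contained sketch of the same logic (the batchify helper is illustrative, not part of the PR):

import numpy as np
import einops

def batchify(tokens, seq_len, pad_token_id):
    # Mirror of the patched batching logic: pad short inputs up to one full
    # sequence; otherwise drop the trailing partial sequence as before.
    num_tokens = len(tokens)
    if num_tokens < seq_len:
        num_batches = 1
        tokens = tokens[:seq_len]
        if len(tokens) < seq_len:
            padding = np.full(seq_len - len(tokens), pad_token_id)
            tokens = np.concatenate([tokens, padding], axis=0)
    else:
        num_batches = num_tokens // seq_len
        tokens = tokens[: seq_len * num_batches]
    return einops.rearrange(
        tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
    )

print(batchify(np.arange(3), seq_len=5, pad_token_id=0))
# [[0 1 2 0 0]]            <- one padded row instead of an empty output
print(batchify(np.arange(12), seq_len=5, pad_token_id=0))
# [[0 1 2 3 4]
#  [5 6 7 8 9]]            <- two full rows; the last 2 tokens are dropped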