
Commit d5004ef

Added random seed for reproducibility.

BrianPulfer committed Feb 11, 2023
1 parent 5d88f42
Showing 3 changed files with 17 additions and 9 deletions.
.github/README.md (1 addition, 1 deletion)
@@ -1,7 +1,7 @@
 # Watermarking for language models
 ## Description
 Re-implementation of the watermarking technique proposed in [*A Watermark for Large Language Models*](https://arxiv.org/abs/2301.10226v2)
-by **Kirchenbauer** & **Geiping** et al. ([Original repo](https://github.com/jwkirchenbauer/lm-watermarking))
+by **Kirchenbauer** & **Geiping** et al. ([Original repo](https://github.com/jwkirchenbauer/lm-watermarking)).
 
 ## Usage
 Generating a (soft) watermarked text with your language model is as easy as:
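The README's usage snippet itself is collapsed in this diff view. Purely as a hypothetical sketch of what such a call might look like — `generate()` and `detect_watermark()` are defined in src/watermarking.py, and their actual signatures, the prompt, and the `gamma`/`delta` values shown here are assumptions, not the repository's real example:

```python
# Hypothetical sketch only: the real usage example is collapsed in this diff,
# and the true signatures of generate() and detect_watermark() may differ.
from transformers import AutoTokenizer, GPT2LMHeadModel

from watermarking import detect_watermark, generate

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Assumed parameters: gamma would be the fraction of the vocabulary on the
# "green list", delta the bias added to green-list logits (soft watermark).
text = generate(model, tokenizer, prompt="The quick brown fox", gamma=0.5, delta=2.0)
z_score = detect_watermark(text, tokenizer, gamma=0.5)
print(text, z_score)
```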
src/main.py (4 additions, 2 deletions)
@@ -1,7 +1,6 @@
 import torch
 from evaluate import load
 from accelerate import Accelerator
-from transformers import AutoTokenizer, GPT2LMHeadModel
+from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
 
 from watermarking import generate, detect_watermark, get_perplexities
-
@@ -20,6 +19,9 @@ def forward(self, input_ids):
 
 def main():
     """Plots the perplexity of the GPT2 model and the z-statistic for sentences generated with and without watermarking."""
+    # Setting seed
+    set_seed(0)
+
     # Device
     device = Accelerator().device
 
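For context on the one-line fix above: `transformers.set_seed(seed)` is a real utility that seeds Python's `random` module, NumPy, and PyTorch (CPU and CUDA) in a single call, which is what makes sampling-based generation repeatable across runs. A minimal sketch of the effect:

```python
# Minimal demonstration of why the commit calls transformers.set_seed():
# re-seeding all RNGs (Python random, NumPy, PyTorch) makes draws repeatable.
import torch
from transformers import set_seed

set_seed(0)
first = torch.rand(3)

set_seed(0)
second = torch.rand(3)

assert torch.equal(first, second)  # same seed, same draws
print(first, second)
```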
src/plot.py (12 additions, 6 deletions)
@@ -1,11 +1,12 @@
-from watermarking import detect_watermark, generate, get_perplexities
-from argparse import ArgumentParser
-import torch
-from transformers import AutoTokenizer, GPT2LMHeadModel
-
 from tqdm import tqdm
 import matplotlib.pyplot as plt
 plt.rcParams.update({'font.size': 22})
+from argparse import ArgumentParser
+
+import torch
+from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
+
+from watermarking import detect_watermark, generate, get_perplexities
 
 
 def parse_args():
@@ -23,6 +24,7 @@ def parse_args():
                         help="Amount to add to the logits of the model when watermarking")
     parser.add_argument("--device", type=int, default=0,
                         help="Device to use for generation")
+    parser.add_argument("--seed", type=int, default=0, help="Seed for the generation")
 
     return vars(parser.parse_args())
 
@@ -48,6 +50,10 @@ def main():
     batch_size = args["batch_size"]
     gamma = args["gamma"]
     delta = args["delta"]
+    seed = args["seed"]
+
+    # Setting seed
+    set_seed(seed)
 
     # Device
     device = torch.device(
@@ -88,7 +94,7 @@
     plt.xlabel("Perplexity")
     plt.ylabel("Z-score")
     plt.savefig(
-        f"perplexity_vs_zscore_(n={n_sentences}, seq_len={seq_len}, gamma={gamma}, delta={delta}).png")
+        f"perplexity_vs_zscore_(n={n_sentences}, seq_len={seq_len}, gamma={gamma}, delta={delta}, seed={seed}).png")
     plt.show()
     print("Program completed successfully!")
 
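Taken together, the plot.py changes let a run be pinned from the command line: `--seed` (default 0) is fed to `set_seed()` before generation, and the seed is also interpolated into the `savefig` filename, so plots produced with different seeds no longer overwrite each other. For example, a run such as `python src/plot.py --seed 42` (the remaining flags are those defined in `parse_args()`) would save `perplexity_vs_zscore_(..., seed=42).png`.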
