4/20/22
- We have updated setup.py to make installation more flexible.
Please use pip install ws-benchmark==1.1.2rc0 to install the latest version. We strongly suggest creating a new environment in which to install Wrench. We plan to improve compatibility in the next stable release.
If you have any problems with installation, please let us know.
Known incompatibilities: tensorflow==2.8.0, albumentations==0.1.12
3/18/22
- Wrench is now available as the ws-benchmark package; use pip install ws-benchmark for a quick install.
2/13/22
- Added a script to generate labeling functions (LFs) for any tabular dataset, as well as 5 new tabular datasets: mushroom, spambase, PhishingWebsites, Bioresponse, and bank-marketing.
11/04/21
- (beta) Added parallel_fit for torch models to support PyTorch DistributedDataParallel (see the example).
10/15/21
- Added a batch of new methods: WeaSEL, ImplyLoss, ASTRA, MeanTeacher, Meta-Weight-Net, Learning-to-Reweight
- Added support for image classification (dataset class / torchvision backbone), as well as the DomainNet and Animals-with-Attributes2 datasets (check out the datasets folder).
Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common and easy framework for development and evaluation of your own weak supervision models within the benchmark.
For more information, check out our publications:
- WRENCH: A Comprehensive Benchmark for Weak Supervision (NeurIPS 2021)
- A Survey on Programmatic Weak Supervision
If you find this repository helpful, feel free to cite our publication:
@inproceedings{
zhang2021wrench,
title={{WRENCH}: A Comprehensive Benchmark for Weak Supervision},
author={Jieyu Zhang and Yue Yu and Yinghao Li and Yujing Wang and Yaming Yang and Mao Yang and Alexander Ratner},
booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
year={2021},
url={https://openreview.net/forum?id=Q9SKS5k8io}
}
Weak Supervision is a paradigm for automated training data creation without manual annotations.
For a brief overview, please check out this blog.
For more context, please check out this survey.
To track recent advances in weak supervision, please follow this repo.
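To make the idea concrete, here is a minimal plain-Python sketch (it does not use Wrench): a few hypothetical keyword-based labeling functions (LFs) for a spam task either vote for a class or abstain on each unlabeled example, and their votes form a label matrix with one row per example and one column per LF. The LF names, keywords, and the convention of -1 for "abstain" are assumptions chosen for illustration.

```python
from typing import Callable, List

ABSTAIN = -1  # assumed convention: an LF returns -1 when it does not vote
HAM, SPAM = 0, 1

# Hypothetical keyword-based labeling functions for a spam-detection task
def lf_check_out(text: str) -> int:
    # Comments asking readers to "check out" something are often spam
    return SPAM if "check out" in text.lower() else ABSTAIN

def lf_subscribe(text: str) -> int:
    return SPAM if "subscribe" in text.lower() else ABSTAIN

def lf_short_comment(text: str) -> int:
    # Very short comments tend to be genuine reactions
    return HAM if len(text.split()) < 5 else ABSTAIN

def apply_lfs(lfs: List[Callable[[str], int]], texts: List[str]) -> List[List[int]]:
    """Build the label matrix: one row per example, one column per LF."""
    return [[lf(t) for lf in lfs] for t in texts]

lfs = [lf_check_out, lf_subscribe, lf_short_comment]
texts = [
    "Check out my channel and subscribe!",
    "love this song",
    "This video reminds me of my childhood summers",
]
L = apply_lfs(lfs, texts)
print(L)  # [[1, 1, -1], [-1, -1, 0], [-1, -1, -1]]
```

No manual annotation is involved: a label model then aggregates these noisy, partially overlapping votes into training labels, and the last example shows that some inputs may receive no votes at all.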
[1] Install Anaconda (instructions: https://www.anaconda.com/download/)
[2] Clone the repository:
git clone https://github.com/JieyuZ2/wrench.git
cd wrench
[3] Create virtual environment:
conda env create -f environment.yml
source activate wrench
If this does not work, or if you want to use only a subset of Wrench's modules, check out this wiki page.
The datasets can be downloaded via this link.
Note that some datasets may have more training examples than reported in the README/paper because they include the dev set; its indices can be found in labeled_id.json, if that file exists.
Documentation of the dataset format and usage can be found on this wiki page.
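The wiki page is the authoritative reference for the format. As a rough sketch only, the parser below assumes that each split file is a JSON object keyed by numeric example id, where each entry carries a "label", a "weak_labels" list (one vote per LF, -1 for abstain) and a "data" field; treat that layout and the helper name load_split as assumptions. For real use, Wrench's load_dataset already handles loading.

```python
import json
from typing import Dict, List, Tuple

def load_split(raw: str) -> Tuple[List[int], List[List[int]], List[dict]]:
    """Parse one split file into gold labels, weak labels and raw inputs.

    Assumed layout: a JSON object keyed by numeric example id, where each value
    has a "label" (int), "weak_labels" (one int per LF, -1 = abstain) and "data" field.
    """
    examples: Dict[str, dict] = json.loads(raw)
    ids = sorted(examples, key=int)  # numeric order, so row order is deterministic
    labels = [examples[i]["label"] for i in ids]
    weak_labels = [examples[i]["weak_labels"] for i in ids]
    inputs = [examples[i]["data"] for i in ids]
    return labels, weak_labels, inputs

# Tiny in-memory example in the assumed layout
raw = json.dumps({
    "1": {"label": 0, "weak_labels": [-1, 0], "data": {"text": "nice song"}},
    "0": {"label": 1, "weak_labels": [1, -1], "data": {"text": "subscribe to me"}},
})
labels, weak_labels, inputs = load_split(raw)
print(labels)       # [1, 0]
print(weak_labels)  # [[1, -1], [-1, 0]]
```

With a file on disk, the same helper would be called on the result of reading, say, a train.json split.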
Name | Task | # class | # LF | # train | # validation | # test | data source | LF source |
---|---|---|---|---|---|---|---|---|
Census | income classification | 2 | 83 | 10083 | 5561 | 16281 | link | link |
Youtube | spam classification | 2 | 10 | 1586 | 120 | 250 | link | link |
SMS | spam classification | 2 | 73 | 4571 | 500 | 500 | link | link |
IMDB | sentiment classification | 2 | 8 | 20000 | 2500 | 2500 | link | link |
Yelp | sentiment classification | 2 | 8 | 30400 | 3800 | 3800 | link | link |
AGNews | topic classification | 4 | 9 | 96000 | 12000 | 12000 | link | link |
TREC | question classification | 6 | 68 | 4965 | 500 | 500 | link | link |
Spouse | relation classification | 2 | 9 | 22254 | 2801 | 2701 | link | link |
SemEval | relation classification | 9 | 164 | 1749 | 178 | 600 | link | link |
CDR | bio relation classification | 2 | 33 | 8430 | 920 | 4673 | link | link |
Chemprot | chemical relation classification | 10 | 26 | 12861 | 1607 | 1607 | link | link |
Commercial | video frame classification | 2 | 4 | 64130 | 9479 | 7496 | link | link |
Tennis Rally | video frame classification | 2 | 6 | 6959 | 746 | 1098 | link | link |
Basketball | video frame classification | 2 | 4 | 17970 | 1064 | 1222 | link | link |
DomainNet | image classification | - | - | - | - | - | link | link |
Name | # class | # LF | # train | # validation | # test | data source | LF source |
---|---|---|---|---|---|---|---|
CoNLL-03 | 4 | 16 | 14041 | 3250 | 3453 | link | link |
WikiGold | 4 | 16 | 1355 | 169 | 170 | link | link |
OntoNotes 5.0 | 18 | 17 | 115812 | 5000 | 22897 | link | link |
BC5CDR | 2 | 9 | 500 | 500 | 500 | link | link |
NCBI-Disease | 1 | 5 | 592 | 99 | 99 | link | link |
Laptop-Review | 1 | 3 | 2436 | 609 | 800 | link | link |
MIT-Restaurant | 8 | 16 | 7159 | 500 | 1521 | link | link |
MIT-Movies | 12 | 7 | 9241 | 500 | 2441 | link | link |
Detailed documentation is coming soon.
If you find that any of the implementations is wrong or problematic, don't hesitate to open an issue or pull request; we really appreciate it!
TODO-list: check this out!
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
Majority Voting | Label Model | -- | link |
Weighted Majority Voting | Label Model | -- | link |
Dawid-Skene | Label Model | link | link |
Data Programming | Label Model | link | link |
MeTaL | Label Model | link | link |
FlyingSquid | Label Model | link | link |
EBCC | Label Model | link | link |
IBCC | Label Model | link | link |
FABLE | Label Model | link | link |
Logistic Regression | End Model | -- | link |
MLP | End Model | -- | link |
BERT | End Model | link | link |
COSINE | End Model | link | link |
ARS2 | End Model | link | link |
Denoise | Joint Model | link | link |
WeaSEL | Joint Model | link | link |
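The simplest label models in the table, Majority Voting and Weighted Majority Voting, can be sketched in a few lines of NumPy. This is an illustrative sketch, not Wrench's implementation: it assumes abstains are encoded as -1, takes per-LF weights (for example, estimated LF accuracies) as an optional input, and returns a uniform distribution for examples on which no LF fires. Wrench's own classes may differ in tie handling and in how weights are obtained.

```python
import numpy as np

ABSTAIN = -1  # assumed abstain value in the label matrix

def majority_vote_proba(L: np.ndarray, n_class: int, weights=None) -> np.ndarray:
    """Soft labels from (optionally weighted) LF votes.

    L: (n_examples, n_lfs) integer matrix, ABSTAIN where an LF did not vote.
    weights: optional per-LF weights; defaults to 1 for every LF (plain majority vote).
    Examples with no votes get a uniform distribution over classes.
    """
    n, m = L.shape
    w = np.ones(m) if weights is None else np.asarray(weights, dtype=float)
    counts = np.zeros((n, n_class))
    for j in range(m):
        voted = L[:, j] != ABSTAIN
        # add LF j's weight to the class it voted for, on the rows where it voted
        counts[voted, L[voted, j]] += w[j]
    totals = counts.sum(axis=1, keepdims=True)
    proba = np.full((n, n_class), 1.0 / n_class)
    has_vote = totals[:, 0] > 0
    proba[has_vote] = counts[has_vote] / totals[has_vote]
    return proba

L = np.array([[1, 1, -1], [-1, -1, 0], [-1, -1, -1], [0, 1, -1]])
print(majority_vote_proba(L, n_class=2))
print(majority_vote_proba(L, n_class=2, weights=[0.9, 0.6, 0.8]))
```

In the last row the two LFs disagree: the unweighted vote is a 50/50 tie, while the weighted vote favours class 0 (0.9 / 1.5 = 0.6) because the LF voting for it carries more weight.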
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
Hidden Markov Model | Label Model | link | link |
Conditional Hidden Markov Model | Label Model | link | link |
LSTM-CNNs-CRF | End Model | link | link |
BERT-CRF | End Model | link | link |
LSTM-ConNet | Joint Model | link | link |
BERT-ConNet | Joint Model | link | link |
Wrench also provides a SeqLabelModelWrapper that adapts label models for classification tasks to sequence tagging tasks.
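One way to picture such an adapter (a sketch of the general idea, not necessarily how SeqLabelModelWrapper works internally) is to treat every token as an independent classification example: stack the per-sentence token-level weak-label matrices, run a classification label model on the stacked matrix, then split the predictions back at sentence boundaries. The helper names below are hypothetical, the toy label model is a plain majority vote, and abstains are assumed to be -1. Note that this token-wise view ignores label transitions, which sequence-specific label models such as the HMM variants in the table do model.

```python
from typing import Callable, List

import numpy as np

def simple_majority_vote(L: np.ndarray, n_class: int = 3) -> np.ndarray:
    # Count votes per class, ignoring abstains (-1); uniform when no LF fires
    counts = np.stack([(L == c).sum(axis=1) for c in range(n_class)], axis=1).astype(float)
    totals = counts.sum(axis=1, keepdims=True)
    return np.where(totals > 0, counts / np.maximum(totals, 1), 1.0 / n_class)

def seq_to_token_labels(
    seq_weak_labels: List[np.ndarray],
    token_label_model: Callable[[np.ndarray], np.ndarray],
) -> List[np.ndarray]:
    """Apply a classification label model token by token.

    seq_weak_labels: one (n_tokens_i, n_lfs) weak-label matrix per sentence.
    token_label_model: maps an (N, n_lfs) matrix to (N, n_class) probabilities.
    Returns one (n_tokens_i, n_class) probability matrix per sentence.
    """
    lengths = [m.shape[0] for m in seq_weak_labels]
    flat = np.concatenate(seq_weak_labels, axis=0)  # every token becomes one example
    proba = token_label_model(flat)
    split_points = np.cumsum(lengths)[:-1]
    return np.split(proba, split_points, axis=0)  # restore sentence boundaries

# Hypothetical tag ids: 0 = O, 1 = B-PER, 2 = I-PER; two LFs per token
sent1 = np.array([[1, 1], [2, -1]])
sent2 = np.array([[-1, -1], [0, 0], [1, 2]])
per_sentence = seq_to_token_labels([sent1, sent2], simple_majority_vote)
print([p.shape for p in per_sentence])  # [(2, 3), (3, 3)]
```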
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
Meta-Weight-Net | End Model | link | link |
Learning2ReWeight | End Model | link | link |
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
MeanTeacher | End Model | link | link |
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
ImplyLoss | Joint Model | link | link |
ASTRA | Joint Model | link | link |
import logging
import numpy as np
import pprint
from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.search import grid_search
from wrench import labelmodel
from wrench.evaluation import AverageMeter
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)
#### Load dataset
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)
#### Specify the hyper-parameter search space for grid search
search_space = {
'Snorkel': {
'lr': np.logspace(-5, -1, num=5, base=10),
'l2': np.logspace(-5, -1, num=5, base=10),
'n_epochs': [5, 10, 50, 100, 200],
}
}
#### Initialize label model
label_model_name = 'Snorkel'
label_model = getattr(labelmodel, label_model_name)
#### Search best hyper-parameters using validation set in parallel
n_trials = 100
n_repeats = 5
target = 'acc'
searched_paras = grid_search(label_model(), dataset_train=train_data, dataset_valid=valid_data,
metric=target, direction='auto', search_space=search_space[label_model_name],
n_repeats=n_repeats, n_trials=n_trials, parallel=True)
#### Evaluate the label model with searched hyper-parameters and average meter
meter = AverageMeter(names=[target])
for i in range(n_repeats):
    model = label_model(**searched_paras)
    history = model.fit(dataset_train=train_data, dataset_valid=valid_data)
    metric_value = model.test(test_data, target)
    # record the score under the metric name (here 'acc'), matching AverageMeter(names=[target])
    meter.update(**{target: metric_value})
metrics = meter.get_results()
pprint.pprint(metrics)
For detailed guidance on grid_search, please check out this wiki page.
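The wiki page documents the real behaviour of grid_search. To illustrate the general idea, the sketch below enumerates the Cartesian product of a search space, evaluates at most n_trials configurations (sampled at random when the grid is larger), averages each over n_repeats runs, and keeps the configuration with the best mean validation score. It assumes that n_trials caps the number of configurations tried; with the search space above, 5 × 5 × 5 = 125 combinations exceed n_trials = 100, so such a cap would matter. Unlike Wrench's grid_search, this sketch is sequential and always maximizes; simple_grid_search and its evaluate callback are hypothetical names.

```python
import itertools
import random
from statistics import mean
from typing import Any, Callable, Dict, List

def simple_grid_search(
    evaluate: Callable[..., float],
    search_space: Dict[str, List[Any]],
    n_trials: int,
    n_repeats: int,
    seed: int = 0,
) -> Dict[str, Any]:
    """Return the configuration with the best mean validation score (higher is better).

    evaluate(**config) trains once and returns a validation score.
    At most n_trials configurations are sampled from the full grid; each is
    evaluated n_repeats times and the scores are averaged.
    """
    keys = list(search_space)
    grid = [dict(zip(keys, values)) for values in itertools.product(*search_space.values())]
    rng = random.Random(seed)
    candidates = grid if len(grid) <= n_trials else rng.sample(grid, n_trials)
    best_config, best_score = None, float("-inf")
    for config in candidates:
        score = mean(evaluate(**config) for _ in range(n_repeats))
        if score > best_score:
            best_config, best_score = config, score
    return best_config

# Toy objective standing in for "fit a label model and score it on the validation set"
def toy_evaluate(lr: float, l2: float) -> float:
    return -(lr - 0.01) ** 2 - l2

space = {"lr": [0.001, 0.01, 0.1], "l2": [0.0, 0.1]}
print(simple_grid_search(toy_evaluate, space, n_trials=100, n_repeats=2))
# {'lr': 0.01, 'l2': 0.0}
```

Averaging over repeats matters for models with random initialization; for a deterministic toy objective like this one, the repeats simply return identical scores.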
import logging
import torch
from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.endmodel import MLPModel
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)
#### Load dataset
dataset_home = '../datasets'
data = 'youtube'
#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
cache_name=extract_fn, model_name=model_name)
#### Train an MLP classifier
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000
patience = 200
evaluation_step = 50
target='acc'
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, device=device, metric=target,
patience=patience, evaluation_step=evaluation_step)
#### Evaluate the trained model
metric_value = model.test(test_data, target)
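The n_steps, evaluation_step and patience arguments above suggest patience-based early stopping. The sketch below shows that pattern under one assumption worth stating loudly: here patience counts consecutive validation evaluations without improvement, not raw training steps, which may differ from how Wrench counts it. train_step and evaluate are hypothetical callbacks, and a real trainer would also checkpoint the best model.

```python
from typing import Callable, Tuple

def train_with_early_stopping(
    train_step: Callable[[], None],
    evaluate: Callable[[], float],
    n_steps: int,
    evaluation_step: int,
    patience: int,
) -> Tuple[float, int]:
    """Run up to n_steps updates, validating every evaluation_step steps.

    Stops once `patience` consecutive evaluations fail to improve the best
    validation score (higher is better). Returns (best_score, best_step).
    """
    best_score, best_step, bad_evals = float("-inf"), 0, 0
    for step in range(1, n_steps + 1):
        train_step()
        if step % evaluation_step == 0:
            score = evaluate()
            if score > best_score:
                best_score, best_step, bad_evals = score, step, 0
            else:
                bad_evals += 1
                if bad_evals >= patience:
                    break
    return best_score, best_step

# Toy run: validation accuracy peaks at the 2nd evaluation, then stalls
scores = iter([0.5, 0.7, 0.6, 0.65, 0.69, 0.9])
steps_run = []
print(train_with_early_stopping(lambda: steps_run.append(1), lambda: next(scores),
                                n_steps=1000, evaluation_step=10, patience=3))
# (0.7, 20)
```

Training stops after step 50, even though a later evaluation would have scored 0.9; choosing a larger patience trades extra compute for less risk of stopping too early.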
import logging
import torch
from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.endmodel import MLPModel
from wrench.labelmodel import MajorityVoting
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)
#### Load dataset
dataset_home = '../datasets'
data = 'youtube'
#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
cache_name=extract_fn, model_name=model_name)
#### Generate soft training labels via a label model
#### The weak labels provided by supervision sources are already encoded in the dataset object
label_model = MajorityVoting()
label_model.fit(train_data, valid_data)
soft_label = label_model.predict_proba(train_data)
#### Train an MLP classifier with soft labels
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000
patience = 200
evaluation_step = 50
target='acc'
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=soft_label,
device=device, metric=target, patience=patience, evaluation_step=evaluation_step)
#### Evaluate the trained model
metric_value = model.test(test_data, target)
#### We can also train an MLP classifier with hard labels
from snorkel.utils import probs_to_preds
hard_label = probs_to_preds(soft_label)
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=hard_label,
device=device, metric=target, patience=patience, evaluation_step=evaluation_step)
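The difference between the two training modes above can be illustrated with a small NumPy sketch. Hard labels take the argmax of each soft label, while soft labels keep the label model's uncertainty in the loss, so an example the label model is unsure about pulls the classifier less strongly toward one class. This is a standalone illustration of the standard soft-target cross-entropy, not code from Wrench or MLPModel. Also note a tie-breaking detail: the hypothetical to_hard_labels helper always picks the lowest class index (np.argmax behaviour), whereas snorkel's probs_to_preds exposes a tie-break policy that, to our understanding, is random by default.

```python
import numpy as np

def soft_cross_entropy(logits: np.ndarray, target_proba: np.ndarray) -> float:
    """Mean cross-entropy between model predictions and (soft) target distributions."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerically stable log-softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(target_proba * log_probs).sum(axis=1).mean())

def to_hard_labels(proba: np.ndarray) -> np.ndarray:
    """Argmax conversion; ties go to the lowest class index."""
    return proba.argmax(axis=1)

def one_hot(labels: np.ndarray, n_class: int) -> np.ndarray:
    return np.eye(n_class)[labels]

soft_label = np.array([[0.5, 0.5], [0.9, 0.1]])  # e.g. output of a label model
logits = np.array([[2.0, 0.0], [2.0, 0.0]])       # a classifier confident in class 0

hard_label = to_hard_labels(soft_label)           # [0, 0]: the tie in row 0 goes to class 0
soft_loss = soft_cross_entropy(logits, soft_label)
hard_loss = soft_cross_entropy(logits, one_hot(hard_label, 2))
print(round(soft_loss, 4), round(hard_loss, 4))   # the soft loss is larger
```

With hard labels the classifier is rewarded for being confident on the first example, even though the label model had no preference there; the soft loss penalizes that overconfidence.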
import logging
import torch
from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.synthetic import ConditionalIndependentGenerator, NGramLFGenerator
from wrench.labelmodel import FlyingSquid
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)
#### Generate synthetic dataset
generator = ConditionalIndependentGenerator(
n_class=2,
n_lfs=10,
alpha=0.75, # mean accuracy
beta=0.1, # mean propensity
alpha_radius=0.2, # radius of accuracy
beta_radius=0.1 # radius of propensity
)
train_data = generator.generate_split('train', 10000)
valid_data = generator.generate_split('valid', 1000)
test_data = generator.generate_split('test', 1000)
#### Evaluate label model on synthetic dataset
label_model = FlyingSquid()
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
target_value = label_model.test(test_data, metric_fn='auc')
#### Load dataset
dataset_home = '../datasets'
data = 'youtube'
#### Load real-world dataset
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)
#### Generate procedural labeling functions
generator = NGramLFGenerator(dataset=train_data, min_acc_gain=0.1, min_support=0.01, ngram_range=(1, 2))
applier = generator.generate(mode='correlated', n_lfs=10)
L_test = applier.apply(test_data)
L_train = applier.apply(train_data)
#### Evaluate label model on real-world dataset with semi-synthetic labeling functions
label_model = FlyingSquid()
label_model.fit(dataset_train=L_train, dataset_valid=valid_data)
target_value = label_model.test(L_test, metric_fn='auc')
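The ConditionalIndependentGenerator call above is parameterized by a mean LF accuracy (alpha), a mean propensity (beta) and their radii. The NumPy sketch below shows one plausible generative process consistent with those parameter comments; it is an assumption, not Wrench's implementation. Each LF draws an accuracy uniformly from [alpha - alpha_radius, alpha + alpha_radius] and a propensity from [beta - beta_radius, beta + beta_radius]. On each example, an LF votes with probability equal to its propensity; when it votes, it is correct with probability equal to its accuracy and otherwise picks a wrong class uniformly. Because every LF's vote depends only on the gold label and its own coin flips, the LFs are conditionally independent given y. The function name is hypothetical.

```python
import numpy as np

ABSTAIN = -1  # assumed abstain value

def generate_conditionally_independent(
    n_examples: int, n_class: int, n_lfs: int,
    alpha: float, beta: float, alpha_radius: float, beta_radius: float,
    seed: int = 0,
):
    """Sample gold labels y and a weak-label matrix L whose LFs are independent given y."""
    rng = np.random.default_rng(seed)
    acc = np.clip(rng.uniform(alpha - alpha_radius, alpha + alpha_radius, n_lfs), 0, 1)
    prop = np.clip(rng.uniform(beta - beta_radius, beta + beta_radius, n_lfs), 0, 1)
    y = rng.integers(0, n_class, n_examples)
    L = np.full((n_examples, n_lfs), ABSTAIN)
    for j in range(n_lfs):
        fires = rng.random(n_examples) < prop[j]      # does LF j vote on each example?
        correct = rng.random(n_examples) < acc[j]     # if it votes, is it right?
        # a uniformly chosen wrong class: shift the gold label by 1..n_class-1 (mod n_class)
        wrong = (y + rng.integers(1, n_class, n_examples)) % n_class
        L[:, j] = np.where(fires, np.where(correct, y, wrong), ABSTAIN)
    return y, L, acc, prop

y, L, acc, prop = generate_conditionally_independent(
    n_examples=20000, n_class=2, n_lfs=10,
    alpha=0.75, beta=0.1, alpha_radius=0.2, beta_radius=0.1,
)
coverage = (L != ABSTAIN).mean(axis=0)  # should roughly match prop
print(np.round(coverage, 3), np.round(prop, 3))
```

Empirical coverage per LF should track the sampled propensities, and the accuracy of each LF over the examples it labels should track its sampled accuracy, which is what makes such synthetic data useful for checking whether a label model recovers LF quality.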
Contact person: Jieyu Zhang, jieyuzhang97@gmail.com
Don't hesitate to send us an e-mail if you have any questions.
We're also open to any collaboration!
We sincerely welcome any contribution to the datasets or models!