Add Causal ML benchmark
yiweny committed Apr 14, 2024
1 parent a9b288e commit 6141684
Showing 7 changed files with 161 additions and 4 deletions.
6 changes: 6 additions & 0 deletions examples/bcauss.py
@@ -0,0 +1,6 @@
import os.path as osp

from torch_frame.datasets import Jobs

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', "jobs")
dataset = Jobs(root=path)
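Continuing from the script above, the dataset can be materialized into tensors. A minimal sketch, not part of the commit (`materialize()`, `df`, and `tensor_frame` are standard `torch_frame.data.Dataset` members):

dataset.materialize()         # convert the raw DataFrame into a TensorFrame
print(len(dataset.df))        # number of rows across the NSW and PSID sources
print(dataset.tensor_frame)   # column tensors grouped by semantic type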
2 changes: 1 addition & 1 deletion torch_frame/data/multi_embedding_tensor.py
@@ -218,7 +218,7 @@ def _empty(self, dim: int) -> MultiEmbeddingTensor:
        Returns:
            MultiEmbeddingTensor: An empty :class:`MultiEmbeddingTensor`.
                Note that if :obj:`dim=0`, it will return with the original
-                offset tensor.
+                offset tensor.git
        """
        return MultiEmbeddingTensor(
            num_rows=0 if dim == 0 else self.num_rows,
2 changes: 2 additions & 0 deletions torch_frame/datasets/__init__.py
@@ -18,6 +18,7 @@
from .amazon_fine_food_reviews import AmazonFineFoodReviews
from .diamond_images import DiamondImages
from .huggingface_dataset import HuggingFaceDatasetDict
+from .jobs import Jobs

real_world_datasets = [
    'Titanic',
@@ -36,6 +37,7 @@
    'Mercari',
    'AmazonFineFoodReviews',
    'DiamondImages',
+    'Jobs',
]

synthetic_datasets = [
67 changes: 67 additions & 0 deletions torch_frame/datasets/jobs.py
@@ -0,0 +1,67 @@
import pandas as pd

import torch_frame


class Jobs(torch_frame.data.Dataset):
    r"""The `Jobs
    <https://www.fredjo.com/files/jobs_DW_bin.new.10.train.npz>`_
    dataset from LaLonde. Each row contains a treatment indicator
    (1 if treated, 0 if not treated), age, education, Black
    (1 if Black, 0 otherwise), Hispanic (1 if Hispanic, 0 otherwise),
    married (1 if married, 0 otherwise), nodegree (1 if no degree,
    0 otherwise), RE74 (earnings in 1974), RE75 (earnings in 1975),
    and RE78 (earnings in 1978).
    """
    lalonde_treated = 'https://users.nber.org/~rdehejia/data/nsw_treated.txt'
    lalonde_control = 'https://users.nber.org/~rdehejia/data/nsw_control.txt'
    psid = 'https://users.nber.org/~rdehejia/data/psid_controls.txt'  # noqa

    def __init__(self, root: str):
        # National Supported Work Demonstration
        nsw_treated = self.download_url(Jobs.lalonde_treated, root)
        nsw_control = self.download_url(Jobs.lalonde_control, root)
        # Panel Study of Income Dynamics
        psid = self.download_url(Jobs.psid, root)
        names = [
            'treated', 'age', 'education', 'Black', 'Hispanic', 'married',
            'nodegree', 'RE75', 'RE78'
        ]

        nsw_treated_df = pd.read_csv(
            nsw_treated,
            sep='\s+',  # noqa
            names=names)
        assert (nsw_treated_df['treated'] == 1).all()
        nsw_treated_df['source'] = 'nsw'

        nsw_control_df = pd.read_csv(
            nsw_control,
            sep='\s+',  # noqa
            names=names)
        assert (nsw_control_df['treated'] == 0).all()
        nsw_control_df['source'] = 'nsw'

        # The PSID control file has an extra column for 1974 earnings.
        names.insert(7, 'RE74')

        psid_df = pd.read_csv(psid, sep='\s+', names=names)  # noqa
        assert (psid_df['treated'] == 0).all()
        psid_df['source'] = 'psid'
        psid_df = psid_df.drop('RE74', axis=1)

        df = pd.concat([nsw_treated_df, nsw_control_df, psid_df], axis=0)
        # Binary outcome: employed in 1978 (non-zero RE78 earnings).
        df['target'] = df['RE78'] != 0

        col_to_stype = {
            'treated': torch_frame.categorical,
            'age': torch_frame.numerical,
            'education': torch_frame.categorical,
            'Black': torch_frame.categorical,
            'Hispanic': torch_frame.categorical,
            'married': torch_frame.categorical,
            'nodegree': torch_frame.categorical,
            'RE75': torch_frame.numerical,
            'target': torch_frame.categorical,
        }

        super().__init__(df, col_to_stype, target_col='target')
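As a quick sanity check on the assembled data, the underlying DataFrame can be inspected directly. A minimal sketch, not part of the commit (the `data/jobs` cache path is hypothetical):

from torch_frame.datasets import Jobs

dataset = Jobs(root='data/jobs')
df = dataset.df
print(df['source'].value_counts())             # rows contributed by NSW vs. PSID
print(df.groupby('treated')['target'].mean())  # 1978 employment rate per group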
5 changes: 2 additions & 3 deletions torch_frame/nn/decoder/__init__.py
@@ -2,9 +2,8 @@
from .decoder import Decoder
from .trompt_decoder import TromptDecoder
from .excelformer_decoder import ExcelFormerDecoder
+from .mlpdecoder import MLPDecoder

__all__ = classes = [
-    'Decoder',
-    'TromptDecoder',
-    'ExcelFormerDecoder',
+    'Decoder', 'TromptDecoder', 'ExcelFormerDecoder', 'MLPDecoder'
]
44 changes: 44 additions & 0 deletions torch_frame/nn/decoder/mlpdecoder.py
@@ -0,0 +1,44 @@
from torch import Tensor
from torch.nn import Linear, ReLU

from torch_frame.nn.decoder import Decoder


class MLPDecoder(Decoder):
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_cols: int,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation = ReLU()
        self.lin_1 = Linear(num_cols, hidden_channels)
        self.lin_2 = Linear(hidden_channels, hidden_channels)
        self.lin_3 = Linear(hidden_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        self.lin_1.reset_parameters()
        self.lin_2.reset_parameters()
        self.lin_3.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        r"""Transforms :obj:`x` into output predictions.

        Args:
            x (Tensor): Input column-wise tensor of shape
                :obj:`[batch_size, num_cols, in_channels]`.

        Returns:
            Tensor: Output predictions of shape
                :obj:`[batch_size, out_channels]`.
        """
        # Collapse the channel dimension so that lin_1, which maps
        # num_cols -> hidden_channels, receives a [batch_size, num_cols]
        # input (mean pooling over channels is an assumption here).
        x = x.mean(dim=-1)
        x = self.lin_1(x)
        x = self.activation(x)
        x = self.lin_2(x)
        x = self.activation(x)
        x = self.lin_3(x)
        return x
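A minimal shape check for the decoder, assuming the channel-pooling interpretation noted in `forward` above (a sketch, not part of the commit):

import torch

from torch_frame.nn.decoder import MLPDecoder

decoder = MLPDecoder(in_channels=8, hidden_channels=32, out_channels=1,
                     num_cols=5)
x = torch.randn(4, 5, 8)  # [batch_size, num_cols, in_channels]
out = decoder(x)
print(out.shape)          # torch.Size([4, 1])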
39 changes: 39 additions & 0 deletions torch_frame/nn/models/bcauss.py
@@ -0,0 +1,39 @@
import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter

from torch_frame.nn.decoder import MLPDecoder
from torch_frame.nn.models import MLP


class EpsilonLayer(Module):
    def __init__(self):
        super().__init__()
        self.epsilon = Parameter(torch.randn(1, 1))

    def reset_parameters(self):
        # Parameter has no reset_parameters(); re-initialize it in place
        # to match the torch.randn construction above.
        torch.nn.init.normal_(self.epsilon)

    def forward(self, t):
        return F.sigmoid(self.epsilon * torch.ones_like(t)[:, 0:1])


class BCAUSS(Module):
    def __init__(self):
        super().__init__()
        self.mlp = MLP()
        self.epsilon = EpsilonLayer()
        # decoder for treatment group
        self.treatment_decoder = MLPDecoder()
        # decoder for control group
        self.control_decoder = MLPDecoder()

    def forward(self, x, t):
        r"""Forward pass, where :obj:`x` is the input features and
        :obj:`t` is the treatment indicator.
        """
        out = self.mlp(x)
        # Note: this branch assumes `t` is a scalar, or a batch that shares
        # a single treatment value; see the per-sample sketch below.
        if t == 0:
            out = self.control_decoder(out)
        else:
            out = self.treatment_decoder(out)
        penalty = self.epsilon(out)
        return out + penalty
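The `if t == 0` branch above only covers a batch whose samples all share one treatment value. A common per-sample alternative (a sketch under that assumption, not the author's implementation) evaluates both heads and selects with `torch.where`:

import torch


def two_headed_forward(shared, treatment_decoder, control_decoder, t):
    # shared: [batch_size, hidden] encoder output; t: [batch_size, 1] in {0, 1}
    y_treated = treatment_decoder(shared)
    y_control = control_decoder(shared)
    # Pick the factual head for each sample according to its treatment.
    return torch.where(t.bool(), y_treated, y_control)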
