Showing 7 changed files with 161 additions and 4 deletions.
@@ -0,0 +1,6 @@
import os.path as osp

from torch_frame.datasets import Jobs

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'jobs')
dataset = Jobs(root=path)
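A quick, hypothetical follow-up (not part of the commit) can sanity-check the example; it assumes the usual materialize(), df and tensor_frame members of torch_frame.data.Dataset.

# Hypothetical follow-up to the example above (not part of the commit).
dataset.materialize()        # builds dataset.tensor_frame from the raw frame
print(dataset.df.head())     # treated, age, education, ..., RE75, RE78, target
print(dataset.tensor_frame)  # column tensors grouped by semantic type (stype)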
@@ -0,0 +1,67 @@
import pandas as pd

import torch_frame


class Jobs(torch_frame.data.Dataset):
    r"""The `Jobs
    <https://www.fredjo.com/files/jobs_DW_bin.new.10.train.npz>`_
    dataset, based on LaLonde's National Supported Work (NSW) study combined
    with the PSID control sample. The raw columns are: treatment indicator
    (1 if treated, 0 if not treated), age, education, Black (1 if Black,
    0 otherwise), Hispanic (1 if Hispanic, 0 otherwise), married
    (1 if married, 0 otherwise), nodegree (1 if no degree, 0 otherwise),
    RE74 (earnings in 1974, PSID file only), RE75 (earnings in 1975), and
    RE78 (earnings in 1978).
    """
    lalonde_treated = 'https://users.nber.org/~rdehejia/data/nsw_treated.txt'
    lalonde_control = 'https://users.nber.org/~rdehejia/data/nsw_control.txt'
    psid = 'https://users.nber.org/~rdehejia/data/psid_controls.txt'

    def __init__(self, root: str):
        # National Supported Work Demonstration
        nsw_treated = self.download_url(Jobs.lalonde_treated, root)
        nsw_control = self.download_url(Jobs.lalonde_control, root)
        # Panel Study of Income Dynamics
        psid = self.download_url(Jobs.psid, root)
        names = [
            'treated', 'age', 'education', 'Black', 'Hispanic', 'married',
            'nodegree', 'RE75', 'RE78'
        ]

        nsw_treated_df = pd.read_csv(nsw_treated, sep=r'\s+', names=names)
        assert (nsw_treated_df['treated'] == 1).all()
        nsw_treated_df['source'] = 'nsw'

        nsw_control_df = pd.read_csv(nsw_control, sep=r'\s+', names=names)
        assert (nsw_control_df['treated'] == 0).all()
        nsw_control_df['source'] = 'nsw'

        # The PSID file carries an extra RE74 (earnings in 1974) column.
        names.insert(7, 'RE74')

        psid_df = pd.read_csv(psid, sep=r'\s+', names=names)
        assert (psid_df['treated'] == 0).all()
        psid_df['source'] = 'psid'
        psid_df = psid_df.drop('RE74', axis=1)

        df = pd.concat([nsw_treated_df, nsw_control_df, psid_df], axis=0)
        # Binary target: non-zero earnings in 1978, i.e. employed in 1978.
        df['target'] = df['RE78'] != 0

        col_to_stype = {
            'treated': torch_frame.categorical,
            'age': torch_frame.numerical,
            'education': torch_frame.categorical,
            'Black': torch_frame.categorical,
            'Hispanic': torch_frame.categorical,
            'married': torch_frame.categorical,
            'nodegree': torch_frame.categorical,
            'RE75': torch_frame.numerical,
            'target': torch_frame.categorical,
        }

        super().__init__(df, col_to_stype, target_col='target')
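Since the dataset is meant to feed treatment-effect models such as the BCAUSS module further down, a hypothetical usage sketch (not part of the commit) shows how the combined frame could be split into treated and control rows; it assumes the standard df accessor of torch_frame.data.Dataset.

# Hypothetical usage sketch (not part of the commit): split the combined
# frame into treated (NSW treated) and control (NSW control + PSID) rows.
dataset = Jobs(root='data/jobs')
treated_df = dataset.df[dataset.df['treated'] == 1]
control_df = dataset.df[dataset.df['treated'] == 0]
print(len(treated_df), len(control_df))
print(control_df['source'].value_counts())  # 'nsw' vs. 'psid' controls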
@@ -0,0 +1,44 @@
from torch import Tensor
from torch.nn import Linear, ReLU

from torch_frame.nn.decoder import Decoder


class MLPDecoder(Decoder):
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_cols: int,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation = ReLU()
        self.lin_1 = Linear(num_cols, hidden_channels)
        self.lin_2 = Linear(hidden_channels, hidden_channels)
        self.lin_3 = Linear(hidden_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        self.lin_1.reset_parameters()
        self.lin_2.reset_parameters()
        self.lin_3.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        r"""Transform :obj:`x` into output predictions.

        Args:
            x (Tensor): Input column-wise tensor of shape
                :obj:`[batch_size, num_cols, in_channels]`.

        Returns:
            Tensor: Output of shape :obj:`[batch_size, out_channels]`.
        """
        x = self.lin_1(x)
        x = self.activation(x)
        x = self.lin_2(x)
        x = self.activation(x)
        x = self.lin_3(x)
        return x
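One point worth checking in the decoder is the wiring of lin_1, which maps num_cols to hidden_channels and therefore expects num_cols as the trailing dimension of its input, while the docstring describes an input of shape [batch_size, num_cols, in_channels]. The sketch below (hypothetical, not part of the commit) reconciles the two by mean-pooling over the channel dimension before decoding; that pooling step is an assumption about the intended usage, not something stated in the commit.

import torch

# Hypothetical shape check with toy sizes (not part of the commit).
batch_size, num_cols, in_channels = 32, 9, 16
decoder = MLPDecoder(in_channels=in_channels, hidden_channels=64,
                     out_channels=1, num_cols=num_cols)

x = torch.randn(batch_size, num_cols, in_channels)  # column-wise embeddings
# Assumption: pool over the channel dimension so the trailing dim is num_cols.
out = decoder(x.mean(dim=-1))
print(out.shape)  # torch.Size([32, 1])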
@@ -0,0 +1,39 @@
import torch
from torch.nn import Module, Parameter

from torch_frame.nn.decoder import MLPDecoder
from torch_frame.nn.models import MLP


class EpsilonLayer(Module):
    def __init__(self):
        super().__init__()
        self.epsilon = Parameter(torch.randn(1, 1))

    def reset_parameters(self):
        # Parameter has no reset_parameters(); re-initialize it in place.
        with torch.no_grad():
            self.epsilon.normal_()

    def forward(self, t):
        return torch.sigmoid(self.epsilon * torch.ones_like(t)[:, 0:1])


class BCAUSS(Module):
    def __init__(self):
        super().__init__()
        # NOTE: MLP and MLPDecoder expect constructor arguments (channels,
        # number of columns, etc.) that are not passed through here yet.
        self.mlp = MLP()
        self.epsilon = EpsilonLayer()
        # decoder for the treatment group
        self.treatment_decoder = MLPDecoder()
        # decoder for the control group
        self.control_decoder = MLPDecoder()

    def forward(self, x, t):
        r""":obj:`t` is the treatment indicator; the returned tensor is the
        predicted outcome :obj:`y`."""
        out = self.mlp(x)
        # Routes the whole batch through one head; assumes `t` is a scalar
        # (0 or 1) treatment flag.
        if t == 0:
            out = self.control_decoder(out)
        else:
            out = self.treatment_decoder(out)
        penalty = self.epsilon(out)
        return out + penalty
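BCAUSS above follows the common two-headed design for treatment-effect estimation: a shared representation followed by separate outcome decoders for treated and control units. Because the t == 0 branch only handles a scalar treatment flag, the self-contained sketch below (hypothetical, not from the commit, using plain torch.nn layers instead of the torch_frame modules) illustrates how per-sample routing over a batch of 0/1 treatment indicators is typically written.

import torch
from torch import nn


class TwoHeadedSketch(nn.Module):
    """Toy illustration of per-sample routing between two outcome heads."""
    def __init__(self, in_channels: int = 16, hidden: int = 32):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_channels, hidden), nn.ReLU())
        self.treatment_head = nn.Linear(hidden, 1)
        self.control_head = nn.Linear(hidden, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        h = self.backbone(x)            # shared representation
        y1 = self.treatment_head(h)     # predicted outcome if treated
        y0 = self.control_head(h)       # predicted outcome if control
        t = t.view(-1, 1).float()
        return t * y1 + (1.0 - t) * y0  # pick the head per sample


model = TwoHeadedSketch()
x = torch.randn(8, 16)         # 8 samples, 16 features
t = torch.randint(0, 2, (8,))  # per-sample treatment indicator
print(model(x, t).shape)       # torch.Size([8, 1])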