-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathexampleAdultLLAMA2.py
159 lines (113 loc) · 7.62 KB
/
exampleAdultLLAMA2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import networkx as nx
import pandas as pd
import numpy as np
import os
from core.discovery import get_graphs
from langchain.llms import Replicate
def graph_adult_income():
    """Build the hand-specified causal DAG for the Adult income dataset.

    The structure follows the causal graph from
    https://arxiv.org/pdf/1611.07438.pdf. Every parent is added as a node and
    then wired to its children in a fixed order, so both ``list(G.nodes)`` and
    edge iteration order are deterministic. The script below relies on the
    node order as the column order of the loaded data.

    Returns:
        networkx.DiGraph: directed graph over the 11 Adult features. ``"income"``
        has no outgoing edges.
    """
    # Parent -> children, listed in insertion order. Do not reorder: graph
    # node/edge iteration order (and hence downstream column order) follows it.
    children_of = {
        "ethnicity": ["income", "occupation", "marital-status", "hours-per-week", "education"],
        "age": ["income", "occupation", "marital-status", "workclass", "education",
                "hours-per-week", "relationship"],
        "native-country": ["education", "workclass", "hours-per-week",
                           "marital-status", "relationship", "income"],
        "gender": ["education", "hours-per-week", "marital-status", "occupation",
                   "relationship", "income"],
        "education": ["occupation", "workclass", "hours-per-week", "relationship", "income"],
        "hours-per-week": ["workclass", "marital-status", "income"],
        "workclass": ["occupation", "marital-status", "income"],
        "marital-status": ["occupation", "relationship", "income"],
        "occupation": ["income"],
        "relationship": ["income"],
        "income": [],
    }

    # NOTE: ``directed=True`` is not a DiGraph option; networkx stores extra
    # keyword arguments as graph attributes (G.graph["directed"]). Kept so the
    # returned object is unchanged.
    graph = nx.DiGraph(directed=True)
    for parent, kids in children_of.items():
        graph.add_node(parent)
        graph.add_edges_from((parent, kid) for kid in kids)
    return graph
def _map_labels_strict(series, mapping):
    """Map ``series`` through ``mapping`` and return the result as ``int``.

    ``Series.map`` turns any label absent from ``mapping`` (and any missing
    value) into NaN, which a plain ``astype(int)`` rejects with an opaque
    casting error. This raises a ``ValueError`` naming the column and the
    offending labels instead.
    """
    mapped = series.map(mapping)
    unknown = series[mapped.isna()]
    if not unknown.empty:
        labels = sorted({str(value) for value in unknown})
        raise ValueError(f"unexpected values in column {series.name!r}: {labels}")
    return mapped.astype(int)


def load_adultdataset(nodes, path="data/adult_income_dataset.csv", codes_path="dict_adult.txt"):
    """Read and preprocess the Adult income dataset and return only ``nodes`` columns.

    Preprocessing:
      * ``income`` is binarized: ``'<=50K'`` -> 0, ``'>50K'`` -> 1.
      * ``'?'`` in ``native-country``, ``workclass`` or ``occupation`` is treated
        as missing, and every row with any missing value is dropped.
      * ``education`` is mapped to an ordinal 0-15 scale.
      * The nominal columns are replaced by pandas category codes. The
        code -> label lists are written (as ``str`` of a dict) to ``codes_path``.
      * All columns become float64, except ``income``, which becomes int32.

    The row count is printed before and after dropping rows.

    Args:
        nodes: column names to return, in the desired order.
        path: CSV file to read.
        codes_path: file that receives the category label lists.

    Returns:
        pandas.DataFrame restricted to ``nodes``.

    Raises:
        ValueError: if ``income`` or ``education`` holds a label not covered by
            its mapping.
        KeyError: if a name in ``nodes`` is not a column of the dataset.
    """
    df = pd.read_csv(path)
    print(df.shape)

    # Binarize the target: 0 = <=50K, 1 = >50K.
    df['income'] = _map_labels_strict(df['income'], {'<=50K': 0, '>50K': 1})

    # '?' marks unknown values in these columns. Turn it into NaN, then drop
    # every ROW that has any missing value (in any column).
    unknown_cols = ['native-country', 'workclass', 'occupation']
    df[unknown_cols] = df[unknown_cols].replace('?', np.nan)
    df.dropna(how='any', inplace=True)
    print(df.shape)

    # Ordinal education levels, following the level order of the Adult data in R arules:
    # https://www.rdocumentation.org/packages/arules/versions/1.6-6/topics/Adult
    education_order = {
        'Preschool': 0, '1st-4th': 1, '5th-6th': 2, '7th-8th': 3, '9th': 4, '10th': 5,
        '11th': 6, '12th': 7, 'HS-grad': 8, 'Prof-school': 9, 'Assoc-acdm': 10,
        'Assoc-voc': 11, 'Some-college': 12, 'Bachelors': 13, 'Masters': 14, 'Doctorate': 15,
    }
    df['education'] = _map_labels_strict(df['education'], education_order)

    # Nominal features -> category codes. Code i corresponds to the i-th entry
    # of dict_categorical[col] (the inferred, sorted categories).
    catvars = ['workclass', 'marital-status', 'occupation', 'relationship',
               'ethnicity', 'gender', 'native-country']
    dict_categorical = {}
    for col in catvars:
        as_category = df[col].astype("category")
        dict_categorical[col] = list(as_category.cat.categories)
        df[col] = as_category.cat.codes

    # All features as float; the binary target stays integer.
    df = df.astype("float64")
    df["income"] = df["income"].astype("int32")

    # Persist the code -> label tables so codes can be decoded later.
    with open(codes_path, 'w', encoding='utf-8') as f:
        f.write(str(dict_categorical))
    return df[nodes]
### START EXPERIMENT
# Disabled one-off download step, kept for reference. To re-enable it,
# uncomment it and add `import io` and `import requests` (this file imports
# neither):
#
# s = requests.get("https://raw.githubusercontent.com/jbrownlee/Datasets/master/adult-all.csv").text
# df = pd.read_csv(io.StringIO(s), names=["age", "workclass", "fnlwgt", "education", "education-num",
#                                         "marital-status", "occupation", "relationship", "ethnicity", "gender",
#                                         "capital-gain", "capital-loss", "hours-per-week", "native-country", "income"])
# # save original dataset
# df.to_csv("data/adult_income_dataset.csv")  # save as csv file

# --- Causal graph: its node order is also the column order of the data below.
G = graph_adult_income()
nodes = [node for node in G.nodes]

# --- Feature metadata (not referenced further in this script).
constraints_features = dict(
    immutable=["ethnicity", "native-country", "gender"],
    higher=["age", "education"],
)
categ_features = [
    "gender", "ethnicity", "occupation", "marital-status",
    "education", "workclass", "relationship", "native-country",
]

# --- Preprocessed data restricted to the graph's nodes.
df = load_adultdataset(nodes)

# NOTE(review): unlike constraints_features, this list also treats "age" as
# immutable (there it is "higher") — confirm which is intended.
immutable_features = ["ethnicity", "native-country", "gender", "age"]

# Natural-language feature descriptions handed to get_graphs.
# NOTE(review): load_adultdataset replaces these columns' labels with numeric
# codes, while several descriptions quote the original string labels — confirm
# get_graphs expects that.
descriptions = {
    "ethnicity": "Refers to an individual's ethnic lineage. Helps in understanding socio-economic patterns and identifying potential discrimination. Value example: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black",
    "income": "The derived feature representing an individual's estimated annual income, categorized as <=50K or >50K.",
    "occupation": "Denotes a person's job role. Key for understanding income variation based on professional domains. Value example: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.",
    "marital-status": "Represents marital standing. Useful for insights on combined incomes, financial commitments, and fiscal stability. Value example: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.",
    "hours-per-week": "The number of work hours weekly. Directly relates to earnings potential and employment type.",
    "education": "Denotes academic level. Highlights the relationship between education, job prospects, and earnings. Value example: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool",
    "age": "Indicates an individual's age. Offers insights into career stage and potential earnings.",
    "workclass": "The employment status of the individual, categorized into Private, Self-emp, Govt, Without-pay, or Never-worked. This feature is valuable for predicting annual income by highlighting income disparities, occupation types, job stability, and interactions with other features like education, aiding in more accurate income predictions.",
    "relationship": "Outlines family dynamics, like 'Wife' or 'Unmarried'. Helps understand household financial responsibilities. Value example: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried",
    "native-country": "Individual's birth country. Offers insights into economic backgrounds and potential income based on origin.",
    "gender": "Specifies as Female or Male. Useful for highlighting potential income disparities and gender-based biases.",
}

# --- LLM: Llama-2-70b-chat on Replicate, pinned to a specific version hash,
# low temperature (0.01) and an empty system prompt.
llm = Replicate(
    model="replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1",
    input={"temperature": 0.01, "system_prompt": ""},
)

# --- Run discovery. The arguments are: data, descriptions, immutable features,
# task description, target column, "results" (presumably an output name or
# location — TODO confirm in core.discovery) and the LLM.
result = get_graphs(
    df,
    descriptions,
    immutable_features,
    "individual's annual income results from various factors",
    "income",
    "results",
    llm,
)