Skip to content

Commit

Permalink
updating to stata17
Browse files Browse the repository at this point in the history
  • Loading branch information
mahaalbashir committed Mar 19, 2024
1 parent bfa6696 commit f0c6895
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 92 deletions.
276 changes: 187 additions & 89 deletions acro/acro_stata_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
MIT licenses apply.
"""

import re

import pandas as pd
import statsmodels.iolib.summary as sm_iolib_summary

Expand Down Expand Up @@ -83,61 +85,91 @@ def apply_stata_expstmt(raw: str, all_data: pd.DataFrame) -> pd.DataFrame:
return all_data.iloc[start : end + 1]


def find_brace_contents(word: str, raw: str):
def find_brace_word(word: str, raw: str):
"""
Given a word followed by a (
finds and returns as a list of strings
the rest of the contents up to the closing ).
first returned value is True/False depending on parsing ok.
"""
result = []
idx = raw.find(word)
if idx == -1:
return False, f"{word} not found"
idx += len(word) + 1
substr = ""
while idx < len(raw) and raw[idx] != ")":
substr += raw[idx]
idx += 1
while idx != -1:
substr = ""
idx += len(word) + 1
while idx < len(raw) and raw[idx] != ")":
substr += raw[idx]
idx += 1

if idx == len(raw):
return False, "phrase not completed"

if idx == len(raw):
return False, "phrase not completed"
return True, substr
result.append(substr)
idx = raw.find(word, idx)

return True, result

def parse_table_details(varlist: list, varnames: list, options: str) -> dict:

def parse_table_details(
varlist: list, varnames: list, options: str, stata_version: str
) -> dict:
"""Function to parse stata-16 style table calls
Note this is not for latest version of stata, syntax here:
https://www.stata.com/manuals16/rtable.pdf
>> table rowvar [colvar [supercolvar] [if] [in] [weight] [, options].
"""
details: dict = {"errmsg": "", "rowvars": list([]), "colvars": list([])}
details["rowvars"] = [varlist.pop(0)]
details["colvars"] = list(reversed(varlist))
# by() contents are super-rows
found, superrows = find_brace_contents("by", options)
if found and len(superrows) > 0:
extras = superrows.split()
for word in extras:
if word not in varnames:
details["errmsg"] = (
f"Error: word {word} in by-list is not a variables name"
)
return details
if word not in details["rowvars"]:
details["rowvars"].insert(0, word)

if stata_version == "16":
details["rowvars"] = [varlist.pop(0)]
details["colvars"] = list(reversed(varlist))

contents_found, content = find_brace_word("contents", options)

# by() contents are super-rows
by_found, superrows = find_brace_word("by", options)
if by_found and len(superrows) > 0:
for row in superrows:
extras = row.split()
for word in extras:
if word not in varnames:
details["errmsg"] = (
f"Error: word {word} in by-list is not a variables name"
)
return details
if word not in details["rowvars"]:
details["rowvars"].insert(0, word)

elif stata_version == "17":
details["rowvars"] = varlist.pop(0).split()
details["colvars"] = varlist.pop(0).split()
if len(details["rowvars"]) == 0 or len(details["colvars"]) == 0:
details["errmsg"] = (
"acro does not currently support one dimensioanl tables. "
"To calculate cross tabulation, you need to provide at least "
"one row and one column."
)
return details
# print(details["rowvars"])
# print(details["colvars"])
if varlist:
details["tables"] = varlist.pop(0).split()
# print(f"table is {details['tables']}")

contents_found, content = find_brace_word("statistic", options)

# contents can be variable names or aggregation functions
details["aggfuncs"], details["values"] = list([]), list([])
found, content = find_brace_contents("contents", options)
if found and len(content) > 0:
contents = content.split()
for word in contents:
if word in varnames:
if word not in details["values"]:
details["values"].append(word)
else:
if word not in details["aggfuncs"]:
details["aggfuncs"].append(word)
if contents_found and len(content) > 0:
for element in content:
contents = element.split()
for word in contents:
if word in varnames:
if word not in details["values"]:
details["values"].append(word)
else:
if word not in details["aggfuncs"]:
details["aggfuncs"].append(word)

# default values
details["totals"] = False
Expand All @@ -156,6 +188,7 @@ def parse_and_run( # pylint: disable=too-many-arguments,too-many-locals
exp: str,
weights: str,
options: str,
stata_version: str,
) -> pd.DataFrame:
"""
Takes a dataframe and the parsed stata command line.
Expand All @@ -172,8 +205,29 @@ def parse_and_run( # pylint: disable=too-many-arguments,too-many-locals
# Sometime_TODO de-abbreviate according to
# https://www.stata.com/manuals13/u11.pdf#u11.1.3ifexp

varlist: list = varlist_as_str.split()
# print(f' split varlist is {varlist}')
if stata_version == "16":
varlist: list = varlist_as_str.split()
elif stata_version == "17":
# Regular expression pattern to match strings within parentheses
pattern = re.compile(r"\((.*?)\)")

# Extract strings within parentheses
strings_within_parentheses = re.findall(pattern, varlist_as_str)
# print(f"string_within_parentheses are {strings_within_parentheses}")

# Remove strings within parentheses from the input string
remaining_string = re.sub(pattern, " string_with_parentheses ", varlist_as_str)
# print(f"remaining string is {remaining_string}")

# Combine the strings within parentheses and strings outside parentheses
varlist = []
for string in remaining_string.split():
# print(f"string is {string}")
if string == "string_with_parentheses":
varlist.append(strings_within_parentheses.pop(0))
else:
varlist.append(string)
# print(f"split varlist is {varlist}")

# data reduction
# print(f'before in {mydata.shape}')
Expand All @@ -191,7 +245,7 @@ def parse_and_run( # pylint: disable=too-many-arguments,too-many-locals
elif command in ["remove_output", "rename_output", "add_comments", "add_exception"]:
outcome = run_output_command(command, varlist)
elif command == "table":
outcome = run_table_command(mydata, varlist, weights, options)
outcome = run_table_command(mydata, varlist, weights, options, stata_version)

elif command in ["regress", "probit", "logit"]:
outcome = run_regression(command, mydata, varlist)
Expand Down Expand Up @@ -256,6 +310,7 @@ def run_output_command(command: str, varlist: list) -> str:
stata_config.stata_acro.rename_output(the_output, the_str)
outcome = f"output {the_output} renamed to {the_str}.\n"
elif command == "add_comments":
print("entering the add_comments")
stata_config.stata_acro.add_comments(the_output, the_str)
outcome = f"Comments added to output {the_output}.\n"
elif command == "add_exception":
Expand All @@ -272,6 +327,7 @@ def run_table_command( # pylint: disable=too-many-arguments,too-many-locals
varlist: list,
weights: str,
options: str,
stata_version: str,
) -> str:
"""
Converts a stata table command into an acro.crosstab
Expand All @@ -282,67 +338,109 @@ def run_table_command( # pylint: disable=too-many-arguments,too-many-locals
return f"weights not currently implemented for _{weights}_\n"

varnames = data.columns
details = parse_table_details(varlist, varnames, options)
details = parse_table_details(varlist, varnames, options, stata_version)
if len(details["errmsg"]) > 0:
return details["errmsg"]

aggfuncs = list(map(lambda x: x.replace("sd", "std"), details["aggfuncs"]))
rows, cols = [], []
# don't pass single aggfunc as a list
if len(aggfuncs) == 1:
aggfuncs = aggfuncs[0]

for row in details["rowvars"]:
rows.append(data[row])
for col in details["colvars"]:
cols.append(data[col])
if len(aggfuncs) > 0 and len(details["values"]) > 0:
# sanity checking
# if len(rows) > 1 or len(cols) > 1:
# msg = (
# "acro crosstab with an aggregation function "
# " does not currently support hierarchies within rows or columns"
# )
# return msg

if len(details["values"]) > 1:
msg = (
"pandas crosstab can aggregate over multiple functions "
"but only over one feature/attribute: provided as 'value'"
)
return msg
val = details["values"][0]
values = data[val]

safe_output = stata_config.stata_acro.crosstab(
index=rows,
columns=cols,
aggfunc=aggfuncs,
values=values,
margins=details["totals"],
margins_name="Total",
set_of_data = {"Total": data}
msg = ""
# if tables var parameter was assigned, each table will
# be treated as an exlcion which will be applied to the data.
# The number of datasets will be equal to the number of unique values in the tables var
# Crosstabulation will be calculate for each dataset
if "tables" in details:
# print(f"table is {details['tables']}")
msg = (
"You need to manually check all the outputs for the risk of differncing.\n"
)
for table in details["tables"]:
unique_values = data[table].unique()
# print(f"unique_values are {unique_values}")
for value in unique_values:
if isinstance(value, str):
exclusion = f"{table}=='{value}'"
else:
exclusion = f"{table}=={value}"
# print(f"exclusion is {exclusion}")
my_data = apply_stata_ifstmt(exclusion, data)
set_of_data[exclusion] = my_data
# print(f"set of data is {set_of_data}")
results = ""
output_count = 0
for exclusion, my_data in set_of_data.items():
rows, cols = [], []
# print(f"my data is {my_data}")
for row in details["rowvars"]:
rows.append(my_data[row])
for col in details["colvars"]:
cols.append(my_data[col])
# print(f"rows are {rows}")
# print(f"cols are {cols}")
if len(aggfuncs) > 0 and len(details["values"]) > 0:
# sanity checking
# if len(rows) > 1 or len(cols) > 1:
# msg = (
# "acro crosstab with an aggregation function "
# " does not currently support hierarchies within rows or columns"
# )
# return msg

if len(details["values"]) > 1:
msg = (
"pandas crosstab can aggregate over multiple functions "
"but only over one feature/attribute: provided as 'value'"
)
return msg
val = details["values"][0]
values = data[val]
print(exclusion)
safe_output = stata_config.stata_acro.crosstab(
index=rows,
columns=cols,
aggfunc=aggfuncs,
values=values,
margins=details["totals"],
margins_name="Total",
)

else:
safe_output = stata_config.stata_acro.crosstab(
index=rows,
columns=cols,
# suppress=details['suppress'],
margins=details["totals"],
margins_name="Total",
else:
print(exclusion)
safe_output = stata_config.stata_acro.crosstab(
index=rows,
columns=cols,
# suppress=details['suppress'],
margins=details["totals"],
margins_name="Total",
)
run_output_command(
"add_comments",
[
f"output_{output_count}",
"You need to manually check all the outputs for the risk of differncing.\n",
],
)
options_str = ""
formatting = [
"cellwidth",
"csepwidth",
"stubwidth",
"scsepwidth",
"center",
"left",
]
if any(word in options for word in formatting):
options_str = "acro does not currently support table formatting commands.\n "
return options_str + prettify_table_string(safe_output) + "\n"
results += f"{exclusion}\n{prettify_table_string(safe_output)}\n"
output_count += 1

options_str = ""
formatting = [
"cellwidth",
"csepwidth",
"stubwidth",
"scsepwidth",
"center",
"left",
]
if any(word in options for word in formatting):
options_str = (
"acro does not currently support table formatting commands.\n "
)
return msg + options_str + results


def run_regression(command: str, data: pd.DataFrame, varlist: list) -> str:
Expand Down
Loading

0 comments on commit f0c6895

Please sign in to comment.