Skip to content

Commit

Permalink
Revert "Add 'culling sampled below n' code (note!)"
Browse files (browse the repository at this point in the history)
This reverts commit 39026e1.
  • Loading branch information
vinhowe committed May 20, 2023
1 parent 7e7adef commit 7435097
Showing 1 changed file with 2 additions and 30 deletions.
32 changes: 2 additions & 30 deletions lm_survey/survey/survey.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -462,9 +462,7 @@ def to_dict(self) -> typing.List[typing.Dict]:
return [variable.to_dict() for variable in self.variables]

def iterate(
self,
n_samples_per_dependent_variable: typing.Optional[int] = None,
n_cull_sampled_below: typing.Optional[int] = None,
self, n_samples_per_dependent_variable: typing.Optional[int] = None
) -> typing.Iterator[DependentVariableSample]:
if n_samples_per_dependent_variable is None:
n_samples_per_dependent_variable = len(self.df)
Expand All @@ -473,8 +471,6 @@ def iterate(
key: 0 for key in self._dependent_variables.keys()
}

dv_samples = []

# The index from iterrows gives type errors when using it as a key in iloc.
for i, (_, row) in enumerate(self.df.iterrows()):
try:
Expand Down Expand Up @@ -511,41 +507,17 @@ def iterate(
correct_completion=correct_completion,
)

dv_sample = DependentVariableSample(
yield DependentVariableSample(
variable_name=name,
question=dependent_variable.to_question(row),
independent_variables=independent_variables,
index=i,
prompt=prompt,
completion=completion,
)
if not n_cull_sampled_below:
yield dv_sample
else:
dv_samples.append(dv_sample)

n_sampled_per_dependent_variable[name] += 1

if n_cull_sampled_below:
# print all n_sampled_per_dependent_variable < n_cull_sampled_below
print(
"\n".join(
[
f"{name}: {n_sampled_per_dependent_variable[name]}"
for name in n_sampled_per_dependent_variable.keys()
if n_sampled_per_dependent_variable[name] < n_cull_sampled_below
]
)
)
# Remove any dependent variables that were sampled below the threshold.
dv_samples = [
dv_sample
for dv_sample in dv_samples
if n_sampled_per_dependent_variable[dv_sample.variable_name]
>= n_cull_sampled_below
]
yield from dv_samples

def mutual_info_stats(self, include_demographics=False) -> pd.DataFrame:
mutual_info_dvs = {}
independent_variable_names = [iv.name for iv in self._independent_variables]
Expand Down

0 comments on commit 7435097

Please sign in to comment.