diff --git a/lm_survey/survey/survey.py b/lm_survey/survey/survey.py index 15f54f0..61eab71 100644 --- a/lm_survey/survey/survey.py +++ b/lm_survey/survey/survey.py @@ -462,9 +462,7 @@ def to_dict(self) -> typing.List[typing.Dict]: return [variable.to_dict() for variable in self.variables] def iterate( - self, - n_samples_per_dependent_variable: typing.Optional[int] = None, - n_cull_sampled_below: typing.Optional[int] = None, + self, n_samples_per_dependent_variable: typing.Optional[int] = None ) -> typing.Iterator[DependentVariableSample]: if n_samples_per_dependent_variable is None: n_samples_per_dependent_variable = len(self.df) @@ -473,8 +471,6 @@ def iterate( key: 0 for key in self._dependent_variables.keys() } - dv_samples = [] - # The index from iterrows gives type errors when using it as a key in iloc. for i, (_, row) in enumerate(self.df.iterrows()): try: @@ -511,7 +507,7 @@ def iterate( correct_completion=correct_completion, ) - dv_sample = DependentVariableSample( + yield DependentVariableSample( variable_name=name, question=dependent_variable.to_question(row), independent_variables=independent_variables, @@ -519,33 +515,9 @@ def iterate( prompt=prompt, completion=completion, ) - if not n_cull_sampled_below: - yield dv_sample - else: - dv_samples.append(dv_sample) n_sampled_per_dependent_variable[name] += 1 - if n_cull_sampled_below: - # print all n_sampled_per_dependent_variable < n_cull_sampled_below - print( - "\n".join( - [ - f"{name}: {n_sampled_per_dependent_variable[name]}" - for name in n_sampled_per_dependent_variable.keys() - if n_sampled_per_dependent_variable[name] < n_cull_sampled_below - ] - ) - ) - # Remove any dependent variables that were sampled below the threshold. - dv_samples = [ - dv_sample - for dv_sample in dv_samples - if n_sampled_per_dependent_variable[dv_sample.variable_name] - >= n_cull_sampled_below - ] - yield from dv_samples - def mutual_info_stats(self, include_demographics=False) -> pd.DataFrame: mutual_info_dvs = {} independent_variable_names = [iv.name for iv in self._independent_variables]