Skip to content

Commit

Permalink
optimize sampling utils
Browse files (browse the repository at this point in the history)
  • Loading branch information
kjappelbaum committed Aug 14, 2024
1 parent 0604d7c commit 233e275
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 7 deletions.
1 change: 0 additions & 1 deletion data/text_sampling/text_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,6 @@ def get_sample_dict(self, sample: pd.Series, template: str):
multiple_choice_enum_idx = multiple_choice_enum_idx[0] # unpack list
multiple_choice_enum = input_variables[multiple_choice_enum_idx]


# get multiple_choice_var
multiple_choice_var_idx = [
i for i, x in enumerate(input_variables) if x.endswith("%")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dataset_creation = [
chemnlp-generate-meta = "chemnlp.data.meta_yaml_generator:cli"
chemnlp-augment-meta = "chemnlp.data.meta_yaml_augmenter:cli"
chemnlp-sample = "chemnlp.data.sampler_cli:cli"
chemlp-add-random-split-column = "chemnlp.data.utils:add_random_split_column_cli"
chemnlp-add-random-split-column = "chemnlp.data.utils:add_random_split_column_cli"

[tool.setuptools_scm]
version_scheme = "post-release"
2 changes: 2 additions & 0 deletions src/chemnlp/data/meta_yaml_augmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,10 @@ def _cli(data_dir: str, model: str = "gpt-4o", override: bool = False):

return augmented_meta_yaml


def cli():
fire.Fire(_cli)


if __name__ == "__main__":
fire.Fire(_cli)
1 change: 1 addition & 0 deletions src/chemnlp/data/meta_yaml_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def _cli(
def cli():
fire.Fire(_cli)


# Example usage
if __name__ == "__main__":
fire.Fire(_cli)
4 changes: 3 additions & 1 deletion src/chemnlp/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ def _wrap_identifier(self, identifier: str, value: str) -> str:
except ValueError:
identifier_type = None

if identifier_type and identifier_type not in self.config.get('excluded_from_wrapping', []):
if identifier_type and identifier_type not in self.config.get(
"excluded_from_wrapping", []
):
return f"[BEGIN_{identifier_type}]{value}[END_{identifier_type}]"
return value

Expand Down
4 changes: 3 additions & 1 deletion src/chemnlp/data/sampler_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def process_dataset(
"multiple_choice_benchmarking_format": None,
"wrap_identifiers": wrap_identifiers,
"benchmarking_templates": benchmarking,
"excluded_from_wrapping": ['Other']
"excluded_from_wrapping": ["Other"],
}

templates = meta["templates"]
Expand Down Expand Up @@ -172,8 +172,10 @@ def main(
wrap_identifiers,
)


def cli():
fire.Fire(main)


if __name__ == "__main__":
fire.Fire(main)
4 changes: 3 additions & 1 deletion src/chemnlp/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ def add_random_split_column(df):

return df


def _add_random_split_column(file):
df = pd.read_csv(file)
df = add_random_split_column(df)
df.to_csv(file, index=False)

def add_random_split_column_cli(file: str):

def add_random_split_column_cli():
fire.Fire(_add_random_split_column)


Expand Down
4 changes: 2 additions & 2 deletions tests/data/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def sample_config():
"multiple_choice_rnd_symbols": ["", ".)", ")"],
"multiple_choice_benchmarking_templates": False,
"multiple_choice_benchmarking_format": None,
"excluded_from_wrapping": ['Other']
"excluded_from_wrapping": ["Other"],
}


Expand All @@ -123,7 +123,7 @@ def sample_config_with_wrapping():
"multiple_choice_benchmarking_templates": False,
"multiple_choice_benchmarking_format": None,
"wrap_identifiers": True,
"excluded_from_wrapping": ['Other']
"excluded_from_wrapping": ["Other"],
}


Expand Down

0 comments on commit 233e275

Please sign in to comment.