Skip to content

Commit

Permalink
Add new tests
Browse files Browse the repository at this point in the history
  • Loading branch information
benedekrozemberczki committed Jan 11, 2022
1 parent d455890 commit 72e7036
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
56 changes: 53 additions & 3 deletions chemicalx/data/datasetloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,65 @@


class DatasetLoader:
"""
General dataset loader for the integrated drug pair scoring datasets.
"""

def __init__(self, dataset_name: str):
"""
Args:
dataset_name (str): The name of the dataset.
"""
self.base_url = "https://raw.githubusercontent.com/AstraZeneca/chemicalx/main/dataset"
self.dataset_name = dataset_name
assert dataset_name in ["drugcombdb", "drugcomb"]

def generate_path(self, file_name: str) -> str:
    """
    Build the full URL of a file that belongs to this dataset.

    Args:
        file_name (str): Name of the data file.
    Returns:
        str: The base URL, the dataset name and the file name joined with "/".
    """
    url_parts = (self.base_url, self.dataset_name, file_name)
    return "/".join(url_parts)

def load_raw_json_data(self, path: str) -> Dict:
    """
    Fetch a JSON document from a URL and parse it.

    Args:
        path (str): The URL of the JSON file.
    Returns:
        raw_data (dict): The parsed JSON content.
    """
    # Read the whole body while the response is open, then parse afterwards.
    with urllib.request.urlopen(path) as response:
        payload = response.read()
    text = payload.decode()
    return json.loads(text)

def load_raw_csv_data(self, path: str) -> pd.DataFrame:
    """
    Reading the labeled triples CSV in memory.

    Args:
        path (str): The URL of the triples CSV file.
    Returns:
        raw_data (pd.DataFrame): A pandas DataFrame with the data.
    """
    # Use a context manager so the response is closed as soon as the body
    # has been read, instead of leaving the connection to be cleaned up later.
    with urllib.request.urlopen(path) as response:
        data_bytes = response.read()
    # Identifier columns are forced to str; the label column is read as float.
    types = {"drug_1": str, "drug_2": str, "context": str, "label": float}
    raw_data = pd.read_csv(io.BytesIO(data_bytes), encoding="utf8", sep=",", dtype=types)
    return raw_data

def get_context_features(self):
"""
Reading the context feature set.
Returns:
context_feature_set (ContextFeatureSet): The ContextFeatureSet of the dataset of interest.
"""
path = self.generate_path("context_set.json")
raw_data = self.load_raw_json_data(path)
raw_data = {k: np.array(v) for k, v in raw_data.items()}
Expand All @@ -37,6 +75,12 @@ def get_context_features(self):
return context_feature_set

def get_drug_features(self):
"""
Reading the drug feature set.
Returns:
drug_feature_set (DrugFeatureSet): The DrugFeatureSet of the dataset of interest.
"""
path = self.generate_path("drug_set.json")
raw_data = self.load_raw_json_data(path)
raw_data = {k: {"smiles": v["smiles"], "features": np.array(v["features"])} for k, v in raw_data.items()}
Expand All @@ -45,8 +89,14 @@ def get_drug_features(self):
return drug_feature_set

def get_labeled_triples(self):
    """
    Getting the labeled triples file from the storage.

    Returns:
        labeled_triples (LabeledTriples): The labeled triples in the dataset.
    """
    path = self.generate_path("labeled_triples.csv")
    raw_data = self.load_raw_csv_data(path)
    # A single construct/update/return sequence; the duplicated statements
    # that followed the first return were unreachable.
    labeled_triples = LabeledTriples()
    labeled_triples.update_from_pandas(raw_data)
    return labeled_triples
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@
setup(
name="chemicalx",
packages=find_packages(),
version="0.0.4",
version="0.0.5",
license="Apache License, Version 2.0",
description="A Deep Learning Library for Drug Pair Scoring.",
author="Benedek Rozemberczki and Charles Hoyt",
author_email="",
url="https://github.com/AstraZeneca/chemicalx",
download_url="https://github.com/AstraZeneca/chemicalx/archive/v0.0.4.tar.gz",
download_url="https://github.com/AstraZeneca/chemicalx/archive/v0.0.5.tar.gz",
keywords=keywords,
install_requires=install_requires,
setup_requires=setup_requires,
Expand Down

0 comments on commit 72e7036

Please sign in to comment.