diff --git a/src/gentropy/dataset/l2g_features/other.py b/src/gentropy/dataset/l2g_features/other.py index 2fc32592b..fe8424ecf 100644 --- a/src/gentropy/dataset/l2g_features/other.py +++ b/src/gentropy/dataset/l2g_features/other.py @@ -85,42 +85,56 @@ def common_genecount_feature_logic( def is_protein_coding_feature_logic( study_loci_to_annotate: StudyLocus | L2GGoldStandard, *, - gene_index: GeneIndex, + variant_index: VariantIndex, feature_name: str, - genomic_window: int, + genomic_window: int = 500_000, ) -> DataFrame: """Computes the feature to indicate if a gene is protein-coding or not. Args: study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - gene_index (GeneIndex): Dataset containing information related to all genes in release. + variant_index (VariantIndex): Dataset containing information related to all overlapping genes within a genomic window. feature_name (str): The name of the feature - genomic_window (int): The maximum window size to consider + genomic_window (int): The window size around the locus to consider. Defaults to its maximum value: 500kb upstream and downstream of the locus Returns: DataFrame: Feature dataset, with 1 if the gene is protein-coding, 0 if not. """ - study_loci_window = ( - study_loci_to_annotate.df.withColumn( - "window_start", f.col("position") - (genomic_window / 2) + assert genomic_window <= 500_000, "Genomic window must be less than 500kb." 
+ genes_in_window = ( + variant_index.df.withColumn( + "transcriptConsequence", f.explode("transcriptConsequences") ) - .withColumn("window_end", f.col("position") + (genomic_window / 2)) - .withColumnRenamed("chromosome", "SL_chromosome") + .select( + "variantId", + f.col("transcriptConsequence.targetId").alias("geneId"), + f.col("transcriptConsequence.biotype").alias("biotype"), + f.col("transcriptConsequence.distanceFromFootprint").alias( + "distanceFromFootprint" + ), + ) + .filter(f.col("distanceFromFootprint") <= genomic_window) ) + if isinstance(study_loci_to_annotate, StudyLocus): + variants_df = study_loci_to_annotate.df.select( + f.explode_outer("locus.variantId").alias("variantId"), + "studyLocusId", + ).filter(f.col("variantId").isNotNull()) + elif isinstance(study_loci_to_annotate, L2GGoldStandard): + variants_df = study_loci_to_annotate.df.select("studyLocusId", "variantId") return ( - study_loci_window.join( - gene_index.df.alias("genes"), - on=( - (f.col("SL_chromosome") == f.col("genes.chromosome")) - & (f.col("genes.tss") >= f.col("window_start")) - & (f.col("genes.tss") <= f.col("window_end")) - ), - how="inner", + # Annotate all genes in the window of a locus + variants_df.join( + genes_in_window, + on="variantId", ) + # Apply flag across all variants in the locus .withColumn( feature_name, - f.when(f.col("biotype") == "protein_coding", f.lit(1)).otherwise(f.lit(0)), + f.when(f.col("biotype") == "protein_coding", f.lit(1.0)).otherwise( + f.lit(0.0) + ), ) .select("studyLocusId", "geneId", feature_name) .distinct() @@ -211,7 +225,7 @@ def compute( class ProteinCodingFeature(L2GFeature): """Indicates whether a gene is protein-coding within a specified window size from the study locus.""" - feature_dependency_type = GeneIndex + feature_dependency_type = VariantIndex feature_name = "isProteinCoding" @classmethod @@ -224,12 +238,12 @@ def compute( Args: study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will 
be used for annotation - feature_dependency (dict[str, Any]): Dictionary containing dependencies, including gene index + feature_dependency (dict[str, Any]): Dictionary containing dependencies, including variant index Returns: ProteinCodingFeature: Feature dataset with 1 if the gene is protein-coding, 0 otherwise """ - genomic_window = 1000000 + genomic_window = 500_000 protein_coding_df = is_protein_coding_feature_logic( study_loci_to_annotate=study_loci_to_annotate, feature_name=cls.feature_name, diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py index feb8e449a..9320b3aa4 100644 --- a/tests/gentropy/dataset/test_l2g_feature.py +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -238,9 +238,11 @@ def sample_variant_index_schema() -> StructType: ArrayType( StructType( [ + StructField("distanceFromFootprint", LongType(), True), StructField("distanceFromTss", LongType(), True), StructField("targetId", StringType(), True), StructField("isEnsemblCanonical", BooleanType(), True), + StructField("biotype", StringType(), True), ] ) ), @@ -624,13 +626,17 @@ def _setup( [ { "distanceFromTss": 10, + "distanceFromFootprint": 0, "targetId": "gene1", "isEnsemblCanonical": True, + "biotype": "protein_coding", }, { "distanceFromTss": 2, + "distanceFromFootprint": 0, "targetId": "gene2", "isEnsemblCanonical": True, + "biotype": "protein_coding", }, ], ), @@ -643,8 +649,10 @@ def _setup( [ { "distanceFromTss": 5, + "distanceFromFootprint": 0, "targetId": "gene1", "isEnsemblCanonical": True, + "biotype": "protein_coding", }, ], ), @@ -928,9 +936,8 @@ class TestCommonProteinCodingFeatureLogic: [ ( [ - {"studyLocusId": "1", "geneId": "gene1", "isProteinCoding500kb": 1}, - {"studyLocusId": "1", "geneId": "gene2", "isProteinCoding500kb": 1}, - {"studyLocusId": "1", "geneId": "gene3", "isProteinCoding500kb": 0}, + {"studyLocusId": "1", "geneId": "gene1", "isProteinCoding": 1.0}, + {"studyLocusId": "1", "geneId": "gene2", 
"isProteinCoding": 0.0}, ] ), ], @@ -944,17 +951,16 @@ def test_is_protein_coding_feature_logic( observed_df = ( is_protein_coding_feature_logic( study_loci_to_annotate=self.sample_study_locus, - gene_index=self.sample_gene_index, - feature_name="isProteinCoding500kb", - genomic_window=500000, + variant_index=self.sample_variant_index, + feature_name="isProteinCoding", ) - .select("studyLocusId", "geneId", "isProteinCoding500kb") + .select("studyLocusId", "geneId", "isProteinCoding") .orderBy("studyLocusId", "geneId") ) expected_df = ( spark.createDataFrame(expected_data) - .select("studyLocusId", "geneId", "isProteinCoding500kb") + .select("studyLocusId", "geneId", "isProteinCoding") .orderBy("studyLocusId", "geneId") ) assert ( @@ -962,7 +968,11 @@ def test_is_protein_coding_feature_logic( ), "Expected and observed DataFrames do not match." @pytest.fixture(autouse=True) - def _setup(self: TestCommonProteinCodingFeatureLogic, spark: SparkSession) -> None: + def _setup( + self: TestCommonProteinCodingFeatureLogic, + spark: SparkSession, + sample_variant_index_schema: StructType, + ) -> None: """Set up sample data for the test.""" # Sample study locus data self.sample_study_locus = StudyLocus( @@ -974,39 +984,47 @@ def _setup(self: TestCommonProteinCodingFeatureLogic, spark: SparkSession) -> No "studyId": "study1", "chromosome": "1", "position": 1000000, + "locus": [ + { + "variantId": "var1", + }, + ], }, ], StudyLocus.get_schema(), ), _schema=StudyLocus.get_schema(), ) - - # Sample gene index data with biotype - self.sample_gene_index = GeneIndex( + self.sample_variant_index = VariantIndex( _df=spark.createDataFrame( [ - { - "geneId": "gene1", - "chromosome": "1", - "tss": 950000, - "biotype": "protein_coding", - }, - { - "geneId": "gene2", - "chromosome": "1", - "tss": 1050000, - "biotype": "protein_coding", - }, - { - "geneId": "gene3", - "chromosome": "1", - "tss": 1010000, - "biotype": "non_coding", - }, + ( + "var1", + "chrom", + 1, + "A", + "T", + [ + { + 
"distanceFromFootprint": 0, + "distanceFromTss": 10, + "targetId": "gene1", + "biotype": "protein_coding", + "isEnsemblCanonical": True, + }, + { + "distanceFromFootprint": 0, + "distanceFromTss": 20, + "targetId": "gene2", + "biotype": "non_coding", + "isEnsemblCanonical": True, + }, + ], + ), ], - GeneIndex.get_schema(), + sample_variant_index_schema, ), - _schema=GeneIndex.get_schema(), + _schema=VariantIndex.get_schema(), ) @@ -1067,8 +1085,10 @@ def _setup( [ { "distanceFromTss": 10, + "distanceFromFootprint": 0, "targetId": "gene1", "isEnsemblCanonical": True, + "biotype": "protein_coding", }, ], ) diff --git a/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py b/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py index 79f9d925a..54bcbf8d0 100644 --- a/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py +++ b/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py @@ -29,7 +29,6 @@ from pyspark.sql.session import SparkSession from gentropy.dataset.colocalisation import Colocalisation - from gentropy.dataset.gene_index import GeneIndex from gentropy.dataset.study_locus import StudyLocus @@ -162,15 +161,15 @@ def test_build_feature_matrix( mock_study_locus: StudyLocus, mock_colocalisation: Colocalisation, mock_study_index: StudyIndex, - mock_gene_index: GeneIndex, + mock_variant_index: VariantIndex, ) -> None: - """Test building feature matrix with the eQtlColocH4Maximum feature.""" + """Test building feature matrix with the eQtlColocH4Maximum and isProteinCoding features.""" features_list = ["eQtlColocH4Maximum", "isProteinCoding"] loader = L2GFeatureInputLoader( colocalisation=mock_colocalisation, study_index=mock_study_index, study_locus=mock_study_locus, - gene_index=mock_gene_index, + variant_index=mock_variant_index, ) fm = mock_study_locus.build_feature_matrix(features_list, loader) assert isinstance(