diff --git a/smashbenchmarking/bench.py b/smashbenchmarking/bench.py index 8df944f..4c57cb5 100755 --- a/smashbenchmarking/bench.py +++ b/smashbenchmarking/bench.py @@ -51,19 +51,36 @@ date_run = datetime.datetime.now().strftime("%Y-%m-%d %H:%M") # this needs to move to another class -tsv_header = ['VariantType','#True','#Pred','Precision','Recall','TP','FP','FN','NonReferenceDiscrepancy'] - -def tsv_row(variant_name,stats,err): - return [variant_name, - stats['num_true'], - stats['num_pred'], - interval(*bound_precision(stats['good_predictions'],stats['false_positives'],err)), - interval(*bound_recall(stats['good_predictions'],stats['false_negatives'],err)), - stats['good_predictions'], - stats['false_negatives'], - stats['false_positives'], - get_nrd(stats) - ] +def get_tsv_header(knownFP=False): + if knownFP: + return ['VariantType','#True','#Pred','Precision', 'FP Precision','Recall','TP','FP','FN','NonReferenceDiscrepancy'] + else: + return ['VariantType','#True','#Pred','Precision','Recall','TP','FP','FN','NonReferenceDiscrepancy'] + +def tsv_row(variant_name,stats,err,knownFP=False): + if knownFP: + return [variant_name, + stats['num_true'], + stats['num_pred'], + interval(*bound_precision(stats['good_predictions'],stats['false_positives'],err)), + interval(*bound_precision(stats['good_predictions'],stats['calls_at_known_fp'],err)), + interval(*bound_recall(stats['good_predictions'],stats['false_negatives'],err)), + stats['good_predictions'], + stats['calls_at_known_fp'], + stats['false_negatives'], + get_nrd(stats) + ] + else: + return [variant_name, + stats['num_true'], + stats['num_pred'], + interval(*bound_precision(stats['good_predictions'],stats['false_positives'],err)), + interval(*bound_recall(stats['good_predictions'],stats['false_negatives'],err)), + stats['good_predictions'], + stats['false_positives'], + stats['false_negatives'], + get_nrd(stats) + ] def get_nrd(stats): if stats['num_true'] > 0: @@ -103,7 +120,7 @@ def bound_precision(tp, fp, e): return 0, 0 return (tp - e) / p, (tp + e) / p -def print_snp_results(num_true, num_pred, num_fp, num_fn, num_ib, num_ig, nrd, known_fp_prec, err, known_fp_vars=False): +def print_snp_results(num_true, num_pred, num_fp, num_fn, num_ib, num_ig, nrd, known_fp_calls, err, known_fp_vars=False): print("\n-----------") print("SNP Results") print("-----------") @@ -113,7 +130,7 @@ def print_snp_results(num_true, num_pred, num_fp, num_fn, num_ib, num_ig, nrd, k assert tp + num_fp <= num_pred print("\t# precision =", interval(*bound_precision(tp, num_fp, err))) if known_fp_vars: - print("\t# precision (known FP) = %.1f" % (100*known_fp_prec) ) + print("\t# precision (known FP) =", interval(*bound_precision(tp,known_fp_calls,err)) ) print("\t# recall =", interval(*bound_recall(tp, num_fn, err))) print("\t# allele mismatch = %d" % num_ib) print("\t# correct = %d" % num_ig) @@ -126,7 +143,7 @@ def print_snp_stats(stats, err, known_fp_vars=False): print_snp_results(stats['num_true'], stats['num_pred'], stats['false_positives'], stats['false_negatives'], stats['intersect_bad'], stats['good_predictions'],ratio(stats['nrd_wrong'],stats['nrd_total']), - 1-ratio(stats['known_fp_calls'],stats['known_fp']),err, known_fp_vars) + stats['known_fp_calls'],err, known_fp_vars) def ratio(a,b,sig=5): if b == 0: @@ -135,14 +152,14 @@ def ratio(a,b,sig=5): return 0.0 return float(int(10**sig*(float(a)/b)))/10**sig -def print_sv_results(var_type_str, num_true, num_pred, num_fp, num_fn, num_mm, num_gp, nrd, known_fp_prec, err, known_fp_vars=False): +def print_sv_results(var_type_str, num_true, num_pred, num_fp, num_fn, num_mm, num_gp, nrd, known_fp_calls, err, known_fp_vars=False): print("\n\n------------------------") print("%s Results" % var_type_str) print("------------------------") print("# True = %d; # Predicted = %d" % (num_true, num_pred)) print("\t# precision =", interval(*bound_precision(num_gp, num_fp, err))) if known_fp_vars: - print("\t# precision (known FP) = %.1f" % (100*known_fp_prec) ) + print("\t# precision (known FP) = ", interval(*bound_precision(num_gp,known_fp_calls,err)) ) print("\t# recall =", interval(*bound_recall(num_true - num_fn, num_fn, err))) #print "\t# multiple matches = %d" % num_mm print("\t# correct = %d" % num_gp) @@ -161,7 +178,7 @@ def print_sv_stats(description, stats, err): stats['false_positives'], stats['false_negatives'], #stats['mult_matches'], 0, stats['good_predictions'], ratio(stats['nrd_wrong'],stats['nrd_total']), - 1-ratio(stats['known_fp_calls'],stats['known_fp']), err) + stats['known_fp_calls'], err) def print_sv_other_results(var_type_str, num_true, num_pred): print("\n\n------------------------") @@ -265,7 +282,7 @@ def main(params): with open(args.knownFP) as f: known_fp_vcf = vcf.Reader(f) known_fp_vars = Variants(known_fp_vcf, - MAX_INDEL_LEN, knownFp=True) + MAX_INDEL_LEN, knownFP=True) else: known_fp_vars = None @@ -291,7 +308,7 @@ def main(params): if args.output == "tsv": print(get_text_header(params),file=sys.stdout) tsvwriter = csv.writer(sys.stdout, delimiter='\t') - tsvwriter.writerow(tsv_header) + tsvwriter.writerow(get_tsv_header(args.knownFP)) tsvwriter.writerow(tsv_row("SNP",stat_reporter(VARIANT_TYPE.SNP),snp_err)) tsvwriter.writerow(tsv_row("Indel Deletions",stat_reporter(VARIANT_TYPE.INDEL_DEL),indel_err)) tsvwriter.writerow(tsv_row("Indel Insertions",stat_reporter(VARIANT_TYPE.INDEL_INS),indel_err)) diff --git a/smashbenchmarking/vcf_eval/chrom_variants.py b/smashbenchmarking/vcf_eval/chrom_variants.py index e061fe8..26cc9ca 100644 --- a/smashbenchmarking/vcf_eval/chrom_variants.py +++ b/smashbenchmarking/vcf_eval/chrom_variants.py @@ -182,7 +182,7 @@ def add_record(self, record): raise AssertionError("VCF contains lower-case bases in the reference: " + record.REF+ " at "+str(record.POS)) alt = map(str, record.ALT) - if _lacks_alt_alleles(record): + if _lacks_alt_alleles(record) and not self._args.get('knownFP',None): raise Exception("Monomorphic records (no alt allele) are not supported.") len_ref = len(ref) @@ -228,6 +228,8 @@ def add_appropriate_variant(indel_type,sv_type): if record.is_snp: add_variant(VARIANT_TYPE.SNP) + elif self._args.get('knownFP',None) and len(record.REF) == 1 and not record.ALT[0]: # is_snp False if no alt allele + add_variant(VARIANT_TYPE.SNP) else: allele_lengths = map(len, alt) + [len_ref] is_indel = (not is_sv(record,self._max_indel_len) ) and max(allele_lengths) <= self._max_indel_len diff --git a/smashbenchmarking/vcf_eval/eval_helper.py b/smashbenchmarking/vcf_eval/eval_helper.py index b9eef39..43e1e6d 100644 --- a/smashbenchmarking/vcf_eval/eval_helper.py +++ b/smashbenchmarking/vcf_eval/eval_helper.py @@ -127,6 +127,9 @@ def structural_match(true_variant,pred_vars_all,sv_eps,sv_eps_bp): matches = get_closest(true_variant,matches) return matches +def vartype_match_at_location(fp_vars,pred_vars,loc): + return fp_vars.all_variants[loc].ref == pred_vars.all_variants[loc].ref + class ChromVariantStats: """Stats for a certain contig's worth of variants.""" @@ -256,12 +259,12 @@ def chrom_evaluate_variants(true_var,pred_var,sv_eps,sv_eps_bp,ref,window,known_ all_known_fp = _type_dict() if known_fp: for loc in pred_loc.intersection(known_fp.all_locations): - vartype = known_fp.all_variants[loc].var_type - match = any_var_match_at_loc(known_fp,pred_var,loc) + match = vartype_match_at_location(known_fp,pred_var,loc) if match: calls_at_known_fp[vartype] += 1 known_fp_calls_positions.append(loc) - all_known_fp[vartype] += 1 + vartype = known_fp.all_variants[loc].var_type + all_known_fp[vartype] += 1 # note this only holds known fp sharing a location with pred var, NOT all # structural variants are a special case if not matching exactly for loc in (true_loc - pred_loc): @@ -288,7 +291,7 @@ def chrom_evaluate_variants(true_var,pred_var,sv_eps,sv_eps_bp,ref,window,known_ if ( known_fp ): variant_stats.known_fp = all_known_fp variant_stats.calls_at_known_fp = calls_at_known_fp - variant_stats.known_fp_variants = variant_stats._extract(variant_stats.pred_var,known_fp_calls_positions,None) + variant_stats.known_fp_variants = variant_stats._extract(variant_stats.pred_var,known_fp_calls_positions,_type_dict()) variant_stats.intersect_bad = intersect_bad_dict #stats = variant_stats.to_dict() #stats['intersect_bad'] = len(intersect_bad) diff --git a/test/chrom_stats.py b/test/chrom_stats.py index cc4248a..22ab88c 100755 --- a/test/chrom_stats.py +++ b/test/chrom_stats.py @@ -29,7 +29,7 @@ import unittest import StringIO -from test_helper import MAX_INDEL_LEN +from test_helper import MAX_INDEL_LEN,vcf_to_Variants,get_reference sys.path.insert(0,'..') from smashbenchmarking import Variants,evaluate_variants @@ -357,5 +357,45 @@ def test_sv_out_of_range(self): self.trueNegative(stat_reporter,VARIANT_TYPE.SV_DEL) + def test_known_false_positives(self): + true_vcf = """##fileformat=VCFv4.0\n +##FORMAT=\n +##source=TVsim\n +#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n +chr1 1 . T A 20 PASS . GT 0/1\n +chr1 8 . A C 20 PASS . GT 1/1\n +""" + pred_vcf = """##fileformat=VCFv4.0\n +##FORMAT=\n +##source=TVsim\n +#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n +chr1 3 . G C 20 PASS . GT 1/1\n +chr1 5 . C G 20 PASS . GT 0/1\n +chr1 8 . A C 20 PASS . GT 1/1\n +""" + known_fp_vcf = """##fileformat=VCFv4.0\n +##FORMAT=\n +##source=TVsim\n +#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n +chr1 3 . G . 20 PASS . GT 0/0\n +chr1 5 . C G 20 PASS . GT 0/0\n +chr1 9 . T . 20 PASS . GT 0/0\n +""" + known_fp_io = StringIO.StringIO(known_fp_vcf) + known_fp_vars = Variants(vcf.Reader(known_fp_io),MAX_INDEL_LEN,knownFP=True) + + stat_reporter, vcf_output = evaluate_variants(vcf_to_Variants(true_vcf),vcf_to_Variants(pred_vcf),sv_eps,sv_eps, \ + get_reference(),50,known_fp_vars) + + snp_stats = stat_reporter(VARIANT_TYPE.SNP) + + self.assertEqual(snp_stats['num_true'],2) + self.assertEqual(snp_stats['num_pred'],3) + self.assertEqual(snp_stats['good_predictions'],1) + self.assertEqual(snp_stats['false_positives'],2) # predicted vars not in ground truth + self.assertEqual(snp_stats['false_negatives'],1) + self.assertEqual(snp_stats['known_fp_calls'],2) + self.assertEqual(snp_stats['known_fp'],2) + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/chrom_variants.py b/test/chrom_variants.py index 8f5974e..4482aeb 100644 --- a/test/chrom_variants.py +++ b/test/chrom_variants.py @@ -148,6 +148,22 @@ def testRemoveRecord(self): self.assertEqual(len(newChromVar._var_locations[VARIANT_TYPE.SV_DEL]),1) self.assertEqual(len(newChromVar._var_dict(VARIANT_TYPE.SV_DEL)),1) + def testKnownFalsePositives(self): + vcf_str = """##fileformat=VCFv4.0\n +##FORMAT=\n +#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n +chr1 7 . G . 20 PASS . GT 0/0\n +""" + newChromVar = ChromVariants('chr1',MAX_INDEL_LEN,knownFP=True) + vcf_io = StringIO.StringIO(vcf_str) + vcfr = vcf.Reader(vcf_io) + for r in vcfr: + newChromVar.add_record(r) + self.assertEqual(newChromVar.all_locations,[7]) + var = newChromVar.all_variants[7] + self.assertEqual(var.ref,'G') + self.assertEqual(var.alt[0], 'None') + #test rando helper methods class ChromVariantHelperMethodsTestCase(unittest.TestCase): def testExtractRange(self): diff --git a/test/eval_helper.py b/test/eval_helper.py index 4f7545e..2d597b2 100644 --- a/test/eval_helper.py +++ b/test/eval_helper.py @@ -35,6 +35,7 @@ from smashbenchmarking.vcf_eval.eval_helper import * from smashbenchmarking.vcf_eval.eval_helper import _genotype_concordance_dict from smashbenchmarking.vcf_eval.chrom_variants import Variant,VARIANT_TYPE,GENOTYPE_TYPE +from smashbenchmarking.vcf_eval.variants import Variants EPS_BP = 10 EPS_LEN = 10 @@ -169,35 +170,36 @@ def testChromEvaluateGenotypeConcordance(self): self.assertEqual(cvs.genotype_concordance[VARIANT_TYPE.SNP][GENOTYPE_TYPE.HOM_VAR][GENOTYPE_TYPE.HOM_VAR],1) self.assertEqual(cvs._nrd_counts(VARIANT_TYPE.SNP),(0,2)) + def testChromEvaluateVariantsKnownFP(self): + # one known true variant + true_str = """##fileformat=VCFv4.0\n +##FORMAT=\n +#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n +chr1 2 . A T 20 PASS . GT 0/1\n + """ + # call var where known fp is, where true var is, where nothing is known + pred_str = """##fileformat=VCFv4.0\n +##FORMAT=\n +#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n +chr1 2 . A T 20 PASS . GT 0/1\n +chr1 4 . G C 20 PASS . GT 1/1\n +chr1 7 . G A 20 PASS . GT 0/1\n + """ + # known locations with NO variant + known_fp_str = """##fileformat=VCFv4.0\n +##FORMAT=\n +#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n +chr1 1 . A T 20 PASS . GT ./.\n +chr1 7 . G . 20 PASS . GT 0/0\n + """ + true_vars = vcf_to_ChromVariants(true_str,'chr1') + pred_vars = vcf_to_ChromVariants(pred_str,'chr1') + known_fp_io = StringIO.StringIO(known_fp_str) + known_fp = Variants(vcf.Reader(known_fp_io),MAX_INDEL_LEN,knownFP=True) + cvs = chrom_evaluate_variants(true_vars,pred_vars,100,100,get_reference(),50,known_fp.on_chrom('chr1')) + self.assertEqual(cvs.num_fp[VARIANT_TYPE.SNP],2) # usual definition, in pred vars but not in true + self.assertEqual(cvs.calls_at_known_fp[VARIANT_TYPE.SNP],1) # call at location known to NOT have SNP -# def testChromEvaluateVariantsKnownFP(self): -# # one known true variant -# true_str = """##fileformat=VCFv4.0\n -# ##FORMAT=\n -# #CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n -# chr1 2 . A T 20 PASS . GT 0/1\n -# """ -# # call var where known fp is, where true var is, where nothing is known -# pred_str = """##fileformat=VCFv4.0\n -# ##FORMAT=\n -# #CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n -# chr1 2 . A T 20 PASS . GT 0/1\n -# chr1 4 . GCC G 20 PASS . GT 1/1\n -# chr1 7 . G A 20 PASS . GT 0/1\n -# """ -# # known location with NO variant -# known_fp_str = """##fileformat=VCFv4.0\n -# ##FORMAT=\n -# #CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n -# chr1 7 . G . 20 PASS . GT 0/0\n -# """ -# true_vars = vcf_to_ChromVariants(true_str,'chr1') -# pred_vars = vcf_to_ChromVariants(pred_str,'chr1') -# known_fp = vcf_to_ChromVariants(known_fp_str,'chr1',True) -# cvs = chrom_evaluate_variants(true_vars,pred_vars,100,100,get_reference(),50,known_fp) -# print(cvs.known_fp) -# print(cvs.calls_at_known_fp) -# print(cvs.known_fp_vars) def testChromEvaluateVariantsSV(self): #NB: SVs aren't rescued, just checked for within breakpoint tolerance true_str = """##fileformat=VCFv4.0\n @@ -456,5 +458,7 @@ def testTruePosRectify(self): self.assertEqual(cvs.num_fp[VARIANT_TYPE.INDEL_DEL],0) + + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/test/test_helper.py b/test/test_helper.py index 25a6623..08a52d1 100644 --- a/test/test_helper.py +++ b/test/test_helper.py @@ -36,6 +36,11 @@ MAX_INDEL_LEN = 50 +def vcf_to_Variants(vcf_str): + str_io = StringIO.StringIO(vcf_str) + str_vcf = vcf.Reader(str_io) + return Variants(str_vcf,MAX_INDEL_LEN) + def vcf_to_ChromVariants(vcf_str,chrom): str_io = StringIO.StringIO(vcf_str) str_vcf = vcf.Reader(str_io) diff --git a/test/variants.py b/test/variants.py index d1ef1d3..6449c71 100755 --- a/test/variants.py +++ b/test/variants.py @@ -6,6 +6,8 @@ import unittest import StringIO +from test_helper import MAX_INDEL_LEN,vcf_to_ChromVariants,get_reference + sys.path.insert(0,'..') from smashbenchmarking.vcf_eval.variants import * from smashbenchmarking.vcf_eval.variants import _aggregate @@ -13,17 +15,6 @@ from smashbenchmarking.vcf_eval.eval_helper import chrom_evaluate_variants,ChromVariantStats,_genotype_concordance_dict from smashbenchmarking.parsers.genome import Genome -MAX_INDEL_LEN = 50 - -def vcf_to_ChromVariants(vcf_str,chrom): - str_io = StringIO.StringIO(vcf_str) - str_vcf = vcf.Reader(str_io) - str_vars = Variants(str_vcf,MAX_INDEL_LEN) - return str_vars.on_chrom(chrom) - -def get_reference(): - return Genome('ref.fasta',lambda t: t.split()[0]) - class VariantsTestCase(unittest.TestCase): def testInit(self): vcf_str = """##fileformat=VCFv4.0\n @@ -36,12 +27,25 @@ def testInit(self): """ vcf_io = StringIO.StringIO(vcf_str) newvcf = vcf.Reader(vcf_io) - newvars = Variants(newvcf, 50) + newvars = Variants(newvcf, MAX_INDEL_LEN) self.assertEqual(len(newvars.chroms),2) self.assertEqual(newvars.var_num(VARIANT_TYPE.SNP),1) self.assertEqual(newvars.var_num(VARIANT_TYPE.INDEL_DEL),2) self.assertEqual(newvars.var_num(VARIANT_TYPE.INDEL_INS),1) + def testKnownFalsePositives(self): + vcf_str = """##fileformat=VCFv4.0\n +##FORMAT=\n +#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n +chr1 7 . G . 20 PASS . GT 0/0\n +""" + vcf_io = StringIO.StringIO(vcf_str) + newvcf = vcf.Reader(vcf_io) + newvars = Variants(newvcf,MAX_INDEL_LEN,knownFP=True) + chromvars = newvars.on_chrom('chr1') + self.assertEqual(chromvars.all_locations,[7]) + self.assertEqual(chromvars.all_variants[7].alt[0],'None') + class VariantsHelpersTestCase(unittest.TestCase): def testAggregate(self):