Skip to content

Commit

Permalink
Merge pull request #8 from amplab/master
Browse files Browse the repository at this point in the history
get known fp mode to work
  • Loading branch information
kwestbrooks committed Jun 10, 2014
2 parents 29882a0 + e77eec7 commit ca66a3a
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 67 deletions.
59 changes: 38 additions & 21 deletions smashbenchmarking/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,36 @@
date_run = datetime.datetime.now().strftime("%Y-%m-%d %H:%M")

# this needs to move to another class
tsv_header = ['VariantType','#True','#Pred','Precision','Recall','TP','FP','FN','NonReferenceDiscrepancy']

def tsv_row(variant_name,stats,err):
return [variant_name,
stats['num_true'],
stats['num_pred'],
interval(*bound_precision(stats['good_predictions'],stats['false_positives'],err)),
interval(*bound_recall(stats['good_predictions'],stats['false_negatives'],err)),
stats['good_predictions'],
stats['false_negatives'],
stats['false_positives'],
get_nrd(stats)
]
def get_tsv_header(knownFP=False):
if knownFP:
return ['VariantType','#True','#Pred','Precision', 'FP Precision','Recall','TP','FP','FN','NonReferenceDiscrepancy']
else:
return ['VariantType','#True','#Pred','Precision','Recall','TP','FP','FN','NonReferenceDiscrepancy']

def tsv_row(variant_name,stats,err,knownFP=False):
if knownFP:
return [variant_name,
stats['num_true'],
stats['num_pred'],
interval(*bound_precision(stats['good_predictions'],stats['false_positives'],err)),
interval(*bound_precision(stats['good_predictions'],stats['calls_at_known_fp'],err)),
interval(*bound_recall(stats['good_predictions'],stats['false_negatives'],err)),
stats['good_predictions'],
stats['calls_at_known_fp'],
stats['false_negatives'],
get_nrd(stats)
]
else:
return [variant_name,
stats['num_true'],
stats['num_pred'],
interval(*bound_precision(stats['good_predictions'],stats['false_positives'],err)),
interval(*bound_recall(stats['good_predictions'],stats['false_negatives'],err)),
stats['good_predictions'],
stats['false_positives'],
stats['false_negatives'],
get_nrd(stats)
]

def get_nrd(stats):
if stats['num_true'] > 0:
Expand Down Expand Up @@ -103,7 +120,7 @@ def bound_precision(tp, fp, e):
return 0, 0
return (tp - e) / p, (tp + e) / p

def print_snp_results(num_true, num_pred, num_fp, num_fn, num_ib, num_ig, nrd, known_fp_prec, err, known_fp_vars=False):
def print_snp_results(num_true, num_pred, num_fp, num_fn, num_ib, num_ig, nrd, known_fp_calls, err, known_fp_vars=False):
print("\n-----------")
print("SNP Results")
print("-----------")
Expand All @@ -113,7 +130,7 @@ def print_snp_results(num_true, num_pred, num_fp, num_fn, num_ib, num_ig, nrd, k
assert tp + num_fp <= num_pred
print("\t# precision =", interval(*bound_precision(tp, num_fp, err)))
if known_fp_vars:
print("\t# precision (known FP) = %.1f" % (100*known_fp_prec) )
print("\t# precision (known FP) =", interval(*bound_precision(tp,known_fp_calls,err)) )
print("\t# recall =", interval(*bound_recall(tp, num_fn, err)))
print("\t# allele mismatch = %d" % num_ib)
print("\t# correct = %d" % num_ig)
Expand All @@ -126,7 +143,7 @@ def print_snp_stats(stats, err, known_fp_vars=False):
print_snp_results(stats['num_true'], stats['num_pred'],
stats['false_positives'], stats['false_negatives'],
stats['intersect_bad'], stats['good_predictions'],ratio(stats['nrd_wrong'],stats['nrd_total']),
1-ratio(stats['known_fp_calls'],stats['known_fp']),err, known_fp_vars)
stats['known_fp_calls'],err, known_fp_vars)

def ratio(a,b,sig=5):
if b == 0:
Expand All @@ -135,14 +152,14 @@ def ratio(a,b,sig=5):
return 0.0
return float(int(10**sig*(float(a)/b)))/10**sig

def print_sv_results(var_type_str, num_true, num_pred, num_fp, num_fn, num_mm, num_gp, nrd, known_fp_prec, err, known_fp_vars=False):
def print_sv_results(var_type_str, num_true, num_pred, num_fp, num_fn, num_mm, num_gp, nrd, known_fp_calls, err, known_fp_vars=False):
print("\n\n------------------------")
print("%s Results" % var_type_str)
print("------------------------")
print("# True = %d; # Predicted = %d" % (num_true, num_pred))
print("\t# precision =", interval(*bound_precision(num_gp, num_fp, err)))
if known_fp_vars:
print("\t# precision (known FP) = %.1f" % (100*known_fp_prec) )
print("\t# precision (known FP) = ", interval(*bound_precision(num_gp,known_fp_calls,err)) )
print("\t# recall =", interval(*bound_recall(num_true - num_fn, num_fn, err)))
#print "\t# multiple matches = %d" % num_mm
print("\t# correct = %d" % num_gp)
Expand All @@ -161,7 +178,7 @@ def print_sv_stats(description, stats, err):
stats['false_positives'], stats['false_negatives'],
#stats['mult_matches'],
0, stats['good_predictions'], ratio(stats['nrd_wrong'],stats['nrd_total']),
1-ratio(stats['known_fp_calls'],stats['known_fp']), err)
stats['known_fp_calls'], err)

def print_sv_other_results(var_type_str, num_true, num_pred):
print("\n\n------------------------")
Expand Down Expand Up @@ -265,7 +282,7 @@ def main(params):
with open(args.knownFP) as f:
known_fp_vcf = vcf.Reader(f)
known_fp_vars = Variants(known_fp_vcf,
MAX_INDEL_LEN, knownFp=True)
MAX_INDEL_LEN, knownFP=True)
else:
known_fp_vars = None

Expand All @@ -291,7 +308,7 @@ def main(params):
if args.output == "tsv":
print(get_text_header(params),file=sys.stdout)
tsvwriter = csv.writer(sys.stdout, delimiter='\t')
tsvwriter.writerow(tsv_header)
tsvwriter.writerow(get_tsv_header(args.knownFP))
tsvwriter.writerow(tsv_row("SNP",stat_reporter(VARIANT_TYPE.SNP),snp_err))
tsvwriter.writerow(tsv_row("Indel Deletions",stat_reporter(VARIANT_TYPE.INDEL_DEL),indel_err))
tsvwriter.writerow(tsv_row("Indel Insertions",stat_reporter(VARIANT_TYPE.INDEL_INS),indel_err))
Expand Down
4 changes: 3 additions & 1 deletion smashbenchmarking/vcf_eval/chrom_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def add_record(self, record):
raise AssertionError("VCF contains lower-case bases in the reference: " + record.REF+ " at "+str(record.POS))

alt = map(str, record.ALT)
if _lacks_alt_alleles(record):
if _lacks_alt_alleles(record) and not self._args.get('knownFP',None):
raise Exception("Monomorphic records (no alt allele) are not supported.")

len_ref = len(ref)
Expand Down Expand Up @@ -228,6 +228,8 @@ def add_appropriate_variant(indel_type,sv_type):

if record.is_snp:
add_variant(VARIANT_TYPE.SNP)
elif self._args.get('knownFP',None) and len(record.REF) == 1 and not record.ALT[0]: # is_snp False if no alt allele
add_variant(VARIANT_TYPE.SNP)
else:
allele_lengths = map(len, alt) + [len_ref]
is_indel = (not is_sv(record,self._max_indel_len) ) and max(allele_lengths) <= self._max_indel_len
Expand Down
11 changes: 7 additions & 4 deletions smashbenchmarking/vcf_eval/eval_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def structural_match(true_variant,pred_vars_all,sv_eps,sv_eps_bp):
matches = get_closest(true_variant,matches)
return matches

def vartype_match_at_location(fp_vars,pred_vars,loc):
return fp_vars.all_variants[loc].ref == pred_vars.all_variants[loc].ref

class ChromVariantStats:

"""Stats for a certain contig's worth of variants."""
Expand Down Expand Up @@ -256,12 +259,12 @@ def chrom_evaluate_variants(true_var,pred_var,sv_eps,sv_eps_bp,ref,window,known_
all_known_fp = _type_dict()
if known_fp:
for loc in pred_loc.intersection(known_fp.all_locations):
vartype = known_fp.all_variants[loc].var_type
match = any_var_match_at_loc(known_fp,pred_var,loc)
match = vartype_match_at_location(known_fp,pred_var,loc)
if match:
calls_at_known_fp[vartype] += 1
known_fp_calls_positions.append(loc)
all_known_fp[vartype] += 1
vartype = known_fp.all_variants[loc].var_type
all_known_fp[vartype] += 1 # note this only holds known fp sharing a location with pred var, NOT all

# structural variants are a special case if not matching exactly
for loc in (true_loc - pred_loc):
Expand All @@ -288,7 +291,7 @@ def chrom_evaluate_variants(true_var,pred_var,sv_eps,sv_eps_bp,ref,window,known_
if ( known_fp ):
variant_stats.known_fp = all_known_fp
variant_stats.calls_at_known_fp = calls_at_known_fp
variant_stats.known_fp_variants = variant_stats._extract(variant_stats.pred_var,known_fp_calls_positions,None)
variant_stats.known_fp_variants = variant_stats._extract(variant_stats.pred_var,known_fp_calls_positions,_type_dict())
variant_stats.intersect_bad = intersect_bad_dict
#stats = variant_stats.to_dict()
#stats['intersect_bad'] = len(intersect_bad)
Expand Down
42 changes: 41 additions & 1 deletion test/chrom_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import unittest
import StringIO

from test_helper import MAX_INDEL_LEN
from test_helper import MAX_INDEL_LEN,vcf_to_Variants,get_reference

sys.path.insert(0,'..')
from smashbenchmarking import Variants,evaluate_variants
Expand Down Expand Up @@ -357,5 +357,45 @@ def test_sv_out_of_range(self):

self.trueNegative(stat_reporter,VARIANT_TYPE.SV_DEL)

def test_known_false_positives(self):
true_vcf = """##fileformat=VCFv4.0\n
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n
##source=TVsim\n
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n
chr1 1 . T A 20 PASS . GT 0/1\n
chr1 8 . A C 20 PASS . GT 1/1\n
"""
pred_vcf = """##fileformat=VCFv4.0\n
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n
##source=TVsim\n
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n
chr1 3 . G C 20 PASS . GT 1/1\n
chr1 5 . C G 20 PASS . GT 0/1\n
chr1 8 . A C 20 PASS . GT 1/1\n
"""
known_fp_vcf = """##fileformat=VCFv4.0\n
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n
##source=TVsim\n
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n
chr1 3 . G . 20 PASS . GT 0/0\n
chr1 5 . C G 20 PASS . GT 0/0\n
chr1 9 . T . 20 PASS . GT 0/0\n
"""
known_fp_io = StringIO.StringIO(known_fp_vcf)
known_fp_vars = Variants(vcf.Reader(known_fp_io),MAX_INDEL_LEN,knownFP=True)

stat_reporter, vcf_output = evaluate_variants(vcf_to_Variants(true_vcf),vcf_to_Variants(pred_vcf),sv_eps,sv_eps, \
get_reference(),50,known_fp_vars)

snp_stats = stat_reporter(VARIANT_TYPE.SNP)

self.assertEqual(snp_stats['num_true'],2)
self.assertEqual(snp_stats['num_pred'],3)
self.assertEqual(snp_stats['good_predictions'],1)
self.assertEqual(snp_stats['false_positives'],2) # predicted vars not in ground truth
self.assertEqual(snp_stats['false_negatives'],1)
self.assertEqual(snp_stats['known_fp_calls'],2)
self.assertEqual(snp_stats['known_fp'],2)

if __name__ == '__main__':
unittest.main()
16 changes: 16 additions & 0 deletions test/chrom_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,22 @@ def testRemoveRecord(self):
self.assertEqual(len(newChromVar._var_locations[VARIANT_TYPE.SV_DEL]),1)
self.assertEqual(len(newChromVar._var_dict(VARIANT_TYPE.SV_DEL)),1)

def testKnownFalsePositives(self):
vcf_str = """##fileformat=VCFv4.0\n
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n
chr1 7 . G . 20 PASS . GT 0/0\n
"""
newChromVar = ChromVariants('chr1',MAX_INDEL_LEN,knownFP=True)
vcf_io = StringIO.StringIO(vcf_str)
vcfr = vcf.Reader(vcf_io)
for r in vcfr:
newChromVar.add_record(r)
self.assertEqual(newChromVar.all_locations,[7])
var = newChromVar.all_variants[7]
self.assertEqual(var.ref,'G')
self.assertEqual(var.alt[0], 'None')

#test rando helper methods
class ChromVariantHelperMethodsTestCase(unittest.TestCase):
def testExtractRange(self):
Expand Down
60 changes: 32 additions & 28 deletions test/eval_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from smashbenchmarking.vcf_eval.eval_helper import *
from smashbenchmarking.vcf_eval.eval_helper import _genotype_concordance_dict
from smashbenchmarking.vcf_eval.chrom_variants import Variant,VARIANT_TYPE,GENOTYPE_TYPE
from smashbenchmarking.vcf_eval.variants import Variants

EPS_BP = 10
EPS_LEN = 10
Expand Down Expand Up @@ -169,35 +170,36 @@ def testChromEvaluateGenotypeConcordance(self):
self.assertEqual(cvs.genotype_concordance[VARIANT_TYPE.SNP][GENOTYPE_TYPE.HOM_VAR][GENOTYPE_TYPE.HOM_VAR],1)
self.assertEqual(cvs._nrd_counts(VARIANT_TYPE.SNP),(0,2))

def testChromEvaluateVariantsKnownFP(self):
# one known true variant
true_str = """##fileformat=VCFv4.0\n
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n
chr1 2 . A T 20 PASS . GT 0/1\n
"""
# call var where known fp is, where true var is, where nothing is known
pred_str = """##fileformat=VCFv4.0\n
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n
chr1 2 . A T 20 PASS . GT 0/1\n
chr1 4 . G C 20 PASS . GT 1/1\n
chr1 7 . G A 20 PASS . GT 0/1\n
"""
# known locations with NO variant
known_fp_str = """##fileformat=VCFv4.0\n
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n
chr1 1 . A T 20 PASS . GT ./.\n
chr1 7 . G . 20 PASS . GT 0/0\n
"""
true_vars = vcf_to_ChromVariants(true_str,'chr1')
pred_vars = vcf_to_ChromVariants(pred_str,'chr1')
known_fp_io = StringIO.StringIO(known_fp_str)
known_fp = Variants(vcf.Reader(known_fp_io),MAX_INDEL_LEN,knownFP=True)
cvs = chrom_evaluate_variants(true_vars,pred_vars,100,100,get_reference(),50,known_fp.on_chrom('chr1'))
self.assertEqual(cvs.num_fp[VARIANT_TYPE.SNP],2) # usual definition, in pred vars but not in true
self.assertEqual(cvs.calls_at_known_fp[VARIANT_TYPE.SNP],1) # call at location known to NOT have SNP

# def testChromEvaluateVariantsKnownFP(self):
# # one known true variant
# true_str = """##fileformat=VCFv4.0\n
# ##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n
# #CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n
# chr1 2 . A T 20 PASS . GT 0/1\n
# """
# # call var where known fp is, where true var is, where nothing is known
# pred_str = """##fileformat=VCFv4.0\n
# ##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n
# #CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n
# chr1 2 . A T 20 PASS . GT 0/1\n
# chr1 4 . GCC G 20 PASS . GT 1/1\n
# chr1 7 . G A 20 PASS . GT 0/1\n
# """
# # known location with NO variant
# known_fp_str = """##fileformat=VCFv4.0\n
# ##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n
# #CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001\n
# chr1 7 . G . 20 PASS . GT 0/0\n
# """
# true_vars = vcf_to_ChromVariants(true_str,'chr1')
# pred_vars = vcf_to_ChromVariants(pred_str,'chr1')
# known_fp = vcf_to_ChromVariants(known_fp_str,'chr1',True)
# cvs = chrom_evaluate_variants(true_vars,pred_vars,100,100,get_reference(),50,known_fp)
# print(cvs.known_fp)
# print(cvs.calls_at_known_fp)
# print(cvs.known_fp_vars)
def testChromEvaluateVariantsSV(self):
#NB: SVs aren't rescued, just checked for within breakpoint tolerance
true_str = """##fileformat=VCFv4.0\n
Expand Down Expand Up @@ -456,5 +458,7 @@ def testTruePosRectify(self):
self.assertEqual(cvs.num_fp[VARIANT_TYPE.INDEL_DEL],0)




if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions test/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@

MAX_INDEL_LEN = 50

def vcf_to_Variants(vcf_str):
str_io = StringIO.StringIO(vcf_str)
str_vcf = vcf.Reader(str_io)
return Variants(str_vcf,MAX_INDEL_LEN)

def vcf_to_ChromVariants(vcf_str,chrom):
str_io = StringIO.StringIO(vcf_str)
str_vcf = vcf.Reader(str_io)
Expand Down
Loading

0 comments on commit ca66a3a

Please sign in to comment.