Skip to content

Commit

Permalink
Added & improved on phylomodels (treeppl#75)
Browse files Browse the repository at this point in the history
discussed with Emma, ok to merge
  • Loading branch information
vresbok authored Sep 20, 2024
1 parent 1ef64cd commit ec72e47
Show file tree
Hide file tree
Showing 11 changed files with 455 additions and 69 deletions.
6 changes: 0 additions & 6 deletions models/data/treeinference/phylo_toydata1.json

This file was deleted.

56 changes: 56 additions & 0 deletions models/data/treeinference/phylo_toydata_fixtree.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
{
"data":
[
[1, 1, 0, 2, 3, 0, 0, 0, 0, 3, 1, 1, 3, 3, 2],
[1, 1, 0, 2, 0, 1, 0, 0, 0, 3, 1, 0, 1, 1, 0],
[0, 0, 1, 1, 0, 2, 1, 0, 0, 0, 2, 0, 3, 3, 0],
[0, 0, 1, 1, 0, 3, 0, 1, 0, 0, 2, 2, 3, 1, 0]
],
"tree":
{
"__data__": {
"age": 1.6497,
"left": {
"__data__": {
"age": 0.3347,
"left": {
"__data__": {
"age": 0.0,
"index": 3
},
"__constructor__": "Leaf"
},
"right": {
"__data__": {
"age": 0.0,
"index": 4
},
"__constructor__": "Leaf"
}
},
"__constructor__": "Node"
},
"right": {
"__data__": {
"age": 0.4414,
"left": {
"__data__": {
"age": 0.0,
"index": 1
},
"__constructor__": "Leaf"
},
"right": {
"__data__": {
"age": 0.0,
"index": 2
},
"__constructor__": "Leaf"
}
},
"__constructor__": "Node"
}
},
"__constructor__": "Node"
}
}
8 changes: 8 additions & 0 deletions models/data/treeinference/phylo_toydata_gaps.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{"data":
[
[1, 1, 0, 2, 3, 0, 0, 0, 0, 3, 1, 1, 3, 3, 4],
[1, 1, 0, 2, 0, 1, 0, 0, 0, 3, 1, 0, 1, 1, 0],
[0, 0, 1, 1, 0, 2, 1, 0, 0, 0, 2, 0, 3, 3, 0],
[4, 0, 1, 1, 0, 3, 0, 1, 0, 0, 2, 2, 3, 1, 0]
]
}
15 changes: 15 additions & 0 deletions models/data/treeinference/primates.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion models/metabarcoding/optmodel.tppl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* Compilation:
* tpplc models/metabarcoding/optmodel.tppl -m 'smc-bpf' --resample align --output optmodel
* Execution with 10000 particles and 1 sweep (homogenate data H_real.json):
* ./optmodel models/data/H_real.json 10000 1 > output/H.json
* ./optmodel models/data/H_real.json 10000 1 > H.json
*/

/*
Expand Down
133 changes: 133 additions & 0 deletions models/phylo/phylogeny/substmodel_belief_propagation.tppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// TreePPL script for inference substitution
// model parameters of the GTR model for a fixed
// tree. Here we use explicit belief propagation, which should
// produce a state-of-the-art model with MCMC, using
// suitable drift kernels, slice sampling, or with HMC.

// Run this by compiling:
// tpplc models/phylo/phylogeny/substmodel_belief_propagation.tppl -m 'smc-apf' --resample align --subsample -n 1 --output substmodel
// And then execution x particles x sweeps, using for example a toy dataset found in phylo_toydata_fixtree.json:
// ./substmodel models/data/treeinference/phylo_toydata_fixtree.json 1000 1 > substmodel_output.json

// Data is a matrix of integers character, with each row
// consisting of the (translated) aligned DNA sequence for a single
// taxon. The nucleotide is specified as
// "A" for adenine - 0
// "C" for cytosine - 1
// "G" for guanine - 2
// "T" for thymine - 3
// "?" for any of the above - 4
// "-" a gap, usually treated the same as "?" - 4
// IUPAC has single-character codes for all possible
// combinations of the four nucleotides, but typically
// only those given above are used.

// TYPES

type TheReturn = TheReturn{stationary_probs: Real[], exchangeability_rates: Real[]}

type GTR_Tree =
| Leaf {age: Real, index: Int}
| Node {age: Real, left: GTR_Tree, right: GTR_Tree}


// MODEL FUNCTIONS

// A help function to compute the scaled rate matrix for GTR
function gtr(pi: Real[], r: Real[]) : Tensor[Real] {
// Construct a matrix of exchangeability rates
let unscaled_q = mtxCreate(4,4,
[ -(pi[2]*r[1]+pi[3]*r[2]+pi[4]*r[3]), pi[2]*r[1], pi[3]*r[2], pi[4]*r[3] ,
pi[1]*r[1], -(pi[1]*r[1]+pi[3]*r[4]+pi[4]*r[5]), pi[3]*r[4], pi[4]*r[5] ,
pi[1]*r[2], pi[2]*r[4], -(pi[1]*r[2]+pi[2]*r[4]+pi[4]*r[6]), pi[4]*r[6] ,
pi[1]*r[3], pi[2]*r[5], pi[3]*r[6], -(pi[1]*r[3]+pi[2]*r[5]+pi[3]*r[6])
]);

let scaler_1 = -(mtxGet(1, 1, unscaled_q) * pi[1]);
let scaler_2 = -(mtxGet(2, 2, unscaled_q) * pi[2]);
let scaler_3 = -(mtxGet(3, 3, unscaled_q) * pi[3]);
let scaler_4 = -(mtxGet(4, 4, unscaled_q) * pi[4]);
let scaler = scaler_1 + scaler_2 + scaler_3 + scaler_4;

return mtxSclrMul((1.0/scaler), unscaled_q);
}

// A simple help function to compute a message for a branch in the tree in a
// postorder traversal towards the root. We assume standard conventions here
// for the structure of q with respect to time (rows = from-states)
function message(start_msg: Tensor[Real][], q: Tensor[Real], time: Real) : Tensor[Real][] {
let trans_matrix = mtxTrans(mtxExp(mtxSclrMul(time, q)));
return sapply1(start_msg, mtxMul, trans_matrix);
}

// Compute postorder messages on the observed tree
function compute_postorder_message(tree: GTR_Tree, data: Int[][], q: Tensor[Real]) : Tensor[Real][] {
if (tree is Leaf) {
return sapply(data[tree.index], get_leaf_message);
}

let left_msg = message( compute_postorder_message(tree.left, data, q), q, (tree.age - tree.left.age) );
let right_msg = message( compute_postorder_message(tree.right, data, q), q, (tree.age - tree.right.age) );

return messageElemMul(left_msg, right_msg);
}


// Get message from leaves for each site
function get_leaf_message(seq: Int) : Tensor[Real] {
if (eqi(seq, 0)) { // "A"
let message = rvecCreate(4, [1.0, 0.0, 0.0, 0.0]);
return message;
}
if (eqi(seq, 1)) { // "C"
let message = rvecCreate(4, [0.0, 1.0, 0.0, 0.0]);
return message;
}
if (eqi(seq, 2)) { // "G"
let message = rvecCreate(4, [0.0, 0.0, 1.0, 0.0]);
return message;
}
if (eqi(seq, 3)) { // "T"
let message = rvecCreate(4, [0.0, 0.0, 0.0, 1.0]);
return message;
}
if (eqi(seq, 4)) { // "-" or "?"
let message = rvecCreate(4, [1.0, 1.0, 1.0, 1.0]);
return message;
}
else {
return error("Invalid state at leaf");
}
}

//Compute log likelihood for each site
function get_log_likes(msg: Tensor[Real], pi_col: Tensor[Real]): Real {
let like = mtxMul(msg, pi_col);
let log_like = log(mtxGet(1, 1, like));
return log_like;
}

// MODEL
// clock_rate not needed in this version
model function myModel(data: Int[][], tree: GTR_Tree) : TheReturn {

// Sample stationary probs and exchangeability rates and
// compute the scaled rate matrix for the GTR model.
// We assume nucleotides (A, C, G, T) represented by ints (0, 1, 2, 3)
assume pi ~ Dirichlet([1.0, 1.0, 1.0, 1.0]);

// We assume exchangeability rates being [r_AC, r_AG, r_AT, r_CG, r_CT, r_GT]
// according to convention
assume er ~ Dirichlet([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);

let q = gtr(pi, er);

let msgs = compute_postorder_message(tree, data, q);

let log_likes = sapply1(msgs, get_log_likes, cvecCreate(4, pi));

logWeight (seqSumReal(log_likes));

return TheReturn{stationary_probs = pi, exchangeability_rates = er};
}

28 changes: 14 additions & 14 deletions models/phylo/phylogeny/tree_inference.tppl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@

// Run this by compiling:
// tpplc models/phylo/phylogeny/tree_inference.tppl -m 'smc-apf' --resample align --output basic_phylo
// And then execution 10000 particles 1 sweep, using for example toy dataset phylo_toydata2_Ints.json.json:
// ./basic_phylo models/data/treeinference/phylo_toydata2_Ints.json 10000 1 > basic_phylo.json
// Or if subsampling is desired, where n is the number of samples in putput:
// tpplc models/phylo/phylogeny/tree_inference.tppl -m 'smc-apf' --resample align --subsample -n 1 --output basic_phylo
// And then execution 10000 particles 1 sweep, using for example toy dataset phylo_toydata.json.json:
// ./basic_phylo models/data/treeinference/phylo_toydata.json 10000 1 > basic_phylo.json

// Toy dataset as json file phylo_toydata2_Ints.json. Integers.
// Toy dataset as json file phylo_toydata.json. Integers.
// Data is a matrix, with each row
// consisting of the aligned DNA sequence for a single
// taxon. The nucleotide is specified as
// "A" for adenine 0
// "C" for cytosine 1
// "G" for guanine 2
// "T" for thymine 3
// "-" currently not handled in this model

// TYPES

Expand All @@ -38,15 +41,12 @@ function buildForest(data: Int[][], forest: SeqTree[], index: Int, dataLen: Int)
}
}

// Randomly sample two indices in the trees vector, to be combined
function pickpair(n: Int): Int[]
{
assume i ~ Categorical(rep(n, 1./Real(n)));
assume j ~ Categorical(rep(n, 1./Real(n)));
if (eqi(j, i)) {
return pickpair(n);
}
return [addi(i, 1), addi(j, 1)]; //adding 1 to avoid index 0
// Randomly sample two indices in the trees vector, to be combined. Avoiding mirror cases.
function pickpair(n: Int): Int[] {
assume i_1 ~ Categorical(rep(subi(n, 1), 1./Real(subi(n, 1))));
let i = addi(i_1, 2); //Adding two, avoiding index zero for i and avoiding trouble picking j
assume j ~ Categorical(rep(subi(i, 1), 1./Real(subi(i, 1)))); // j is always smaller than i, avoiding mirror cases
return [i, addi(j, 1)]; //avoid index of zero for j
}

//Note: this is not optimal, in terms of repeated calculation of matrix exp
Expand Down Expand Up @@ -97,8 +97,8 @@ function cluster(q: Tensor[Real], trees: SeqTree[], maxAge: Real, seqLen: Int):
let parent = Node{age=age, seq=seq, left=leftChild, right=rightChild};

// Compute new_trees list
let min = mini(pairs[1], pairs[2]);
let max = maxi(pairs[1], pairs[2]);
let min = pairs[2];
let max = pairs[1];
let new_trees = join([slice(trees, 1, min), slice(trees, addi(min, 1), max), slice(trees, addi(max, 1), addi(n, 1)), [parent]]);

// Recursive call to cluster for new_trees
Expand Down
43 changes: 22 additions & 21 deletions models/phylo/phylogeny/tree_inference_pruning.tppl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

// Run this by compiling:
// tpplc models/phylo/phylogeny/tree_inference_pruning.tppl -m 'smc-apf' --resample align --output phylo
// And then execution 10000 particles 1 sweep, using for example toy dataset phylo_toydata2.json:
// ./phylo models/data/treeinference/phylo_toydata2.json 10000 1 > phylo_pruning.json
// Or if subsampling is desired, where n is the number of samples in putput:
// tpplc models/phylo/phylogeny/tree_inference_pruning.tppl -m 'smc-apf' --resample align --subsample -n 1 --output phylo
// And then execution 10000 particles 1 sweep, using for example toy dataset phylo_toydata.json:
// ./phylo models/data/treeinference/phylo_toydata.json 10000 1 > phylo_pruning.json

// Toy dataset as json file phylo_toydata2.json.
// Data is a matrix with each row
Expand All @@ -14,28 +16,27 @@
// "C" for cytosine - 1
// "G" for guanine - 2
// "T" for thymine - 3
// "-" for gaps - 4 //To be extended

// TYPES

type Tree =
type MsgTree =
| Leaf {age: Real, index: Int, msg: Tensor[Real][]}
| Node {age: Real, msg: Tensor[Real][], left: Tree, right: Tree}
| Node {age: Real, msg: Tensor[Real][], left: MsgTree, right: MsgTree}


// FUNCTIONS

// Randomly sample two indices in the trees vector, to be combined
// Randomly sample two indices in the trees vector, to be combined. Avoiding mirror cases.
function pickpair(n: Int): Int[] {
assume i ~ Categorical(rep(n, 1./Real(n)));
assume j ~ Categorical(rep(n, 1./Real(n)));
if (eqi(j, i)) {
return pickpair(n);
}
return [addi(i, 1), addi(j, 1)]; //adding 1 to avoid index 0
assume i_1 ~ Categorical(rep(subi(n, 1), 1./Real(subi(n, 1))));
let i = addi(i_1, 2); //Adding two, avoiding index zero for i and avoiding trouble picking j
assume j ~ Categorical(rep(subi(i, 1), 1./Real(subi(i, 1)))); // j is always smaller than i, avoiding mirror cases
return [i, addi(j, 1)]; //avoid index of zero for j
}

// Build forest of trees from leaves, recursively
function build_forest(data: Int[][], forest: Tree[], index: Int, data_len: Int, seq_len: Int): Tree[] {
function build_forest(data: Int[][], forest: MsgTree[], index: Int, data_len: Int, seq_len: Int): MsgTree[] {
let new_message = sapply(data[index], get_leaf_message);
let new_leaf = Leaf{age = 0.0, index = index, msg = new_message};
let new_forest = join([forest, [new_leaf]]);
Expand Down Expand Up @@ -65,6 +66,10 @@ function get_leaf_message(seq: Int) : Tensor[Real] {
let message = rvecCreate(4, [0.0, 0.0, 0.0, 1.0]);
return message;
}
if (eqi(seq, 4)) { // "-"
let message = rvecCreate(4, [1.0, 1.0, 1.0, 1.0]);
return message;
}
else {
return error("Invalid state at leaf");
}
Expand All @@ -79,7 +84,7 @@ function get_log_likes(msg: Tensor[Real]): Real {
}

// KEY FUNCTION: CLUSTER
function cluster(q: Tensor[Real], trees: Tree[], maxAge: Real, seq_len: Int): Tree[] {
function cluster(q: Tensor[Real], trees: MsgTree[], maxAge: Real, seq_len: Int): MsgTree[] {

let n = length(trees);

Expand Down Expand Up @@ -108,12 +113,8 @@ function cluster(q: Tensor[Real], trees: Tree[], maxAge: Real, seq_len: Int): Tr
// Get the message of the new node.
let node_msg = messageElemMul(left_in_msg, right_in_msg);

// Now we want to weight/factor according to the message we just computed, with state probs weighted
// in some suitable fashion, e.g., by the stationary probs, which are [0.25,0.25,0.25,0.25] for
// the Jukes Cantor model. This has the additional advantage that this is the standard assumption
// for the root sequence in the tree, that is, we are assuming stationarity at the root.
// Weights
let log_likes = sapply(node_msg, get_log_likes);

logWeight (seqSumReal(log_likes));

// Remove weights of previous internal nodes
Expand All @@ -130,8 +131,8 @@ function cluster(q: Tensor[Real], trees: Tree[], maxAge: Real, seq_len: Int): Tr
let parent = Node{age=age, msg=node_msg, left=left_child, right=right_child};

// Compute new_trees list
let min = mini(pairs[1], pairs[2]);
let max = maxi(pairs[1], pairs[2]);
let min = pairs[2];
let max = pairs[1];
let new_trees = join([slice(trees, 1, min), slice(trees, addi(min, 1), max), slice(trees, addi(max, 1), addi(n, 1)), [parent]]);

// Recursive call to cluster for new_trees
Expand All @@ -140,7 +141,7 @@ function cluster(q: Tensor[Real], trees: Tree[], maxAge: Real, seq_len: Int): Tr

// MODEL FUNCTION

model function myModel(data: Int[][]) : Tree[] {
model function myModel(data: Int[][]) : MsgTree[] {
// Define the scaled rate matrix for Jukes-Cantor
let q = mtxCreate(4,4,
[ -1.0, (1.0/3.0), (1.0/3.0), (1.0/3.0),
Expand Down
Loading

0 comments on commit ec72e47

Please sign in to comment.