Optimization of Deep Learning Models for Genomic Variant Calling in a Distributed Environment with GPUs
Deoxyribonucleic acid (DNA) is the chemical compound that contains the instructions needed to develop and direct the activities of nearly all living organisms. A genome is an organism's complete set of DNA; the human genome is made up of approximately 3 billion DNA base pairs. Genomics is the study of the genome, including the interactions of genes with each other and with the organism's environment [1]. However, there are major computational bottlenecks and inefficiencies throughout the genome analysis pipeline.

Variant identification and classification is an important step of genome analysis that gives doctors and scientists information about an organism's response to certain infections and drugs, and about the conditions it is genetically predisposed to. A variety of algorithms and tools have been developed to call generic and specific variants, such as GATK and its HaplotypeCaller. However, these tools are extremely time consuming and inefficient when run on CPUs. This led to the rise of deep-neural-network-based variant callers such as Google's DeepVariant. These still have scope for improvement: training the models to good accuracy across different data requires multiple read-aligned and variant-called genomes with a very large number of samples. Our objective in this project is to optimize the variant classification process in current deep learning models to reduce these inefficiencies and computational bottlenecks.
Deliverables from this project:
- Analysed Clairvoyante to identify the pitfalls and challenges in training the model.
- Analysed Clair3 and then used Data Parallelism to improve the performance of the model.
This repository is divided into three main folders:
- Clairvoyante: A CNN model using Python 2 and TensorFlow 1.
- Clair3GPU: An RNN model using Python 3, TensorFlow 2, and 4 GPUs. (Implements Data Parallelism; see the sketch after this list.)
- Clair3CPU: An RNN model using Python 3 and TensorFlow 2. (CPU version of Clair3GPU.)
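As a rough illustration of how the data parallelism in Clair3GPU is set up, the sketch below uses TensorFlow 2's `tf.distribute.MirroredStrategy` to replicate a model across the visible GPUs and split each batch between them. This is a minimal, self-contained sketch rather than the actual Clair3 training code: `build_rnn_model()`, the tensor shapes, the class count, the batch size, and the synthetic dataset are all illustrative placeholders.

```python
import tensorflow as tf

# Minimal sketch of data-parallel training with tf.distribute.MirroredStrategy.
# NOTE: build_rnn_model(), the tensor shapes, the class count, and the synthetic
# dataset below are illustrative placeholders, not the real Clair3 pipeline.

def build_rnn_model():
    # Placeholder bidirectional-RNN classifier over per-position pileup tensors.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(33, 8)),                      # (positions, features): illustrative
        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
        tf.keras.layers.Dense(21, activation="softmax"),    # illustrative number of classes
    ])

strategy = tf.distribute.MirroredStrategy()                  # one replica per visible GPU
print("Replicas in sync:", strategy.num_replicas_in_sync)

# The global batch is split evenly across replicas, so each GPU processes
# global_batch // num_replicas examples per training step.
global_batch = 256 * strategy.num_replicas_in_sync

with strategy.scope():                                       # variables are mirrored on every replica
    model = build_rnn_model()
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])

# Synthetic stand-in for the prepared training tensors.
x = tf.random.normal((4096, 33, 8))
y = tf.random.uniform((4096,), maxval=21, dtype=tf.int32)
train_ds = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(4096).batch(global_batch)

model.fit(train_ds, epochs=3)
```

Under `MirroredStrategy`, each replica computes gradients on its shard of the global batch and the gradients are all-reduced before the weight update; this reduces the time per epoch, though not necessarily linearly in the number of GPUs.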
- Data Preparation (Clairvoyante)
wget 'http://www.bio8.cs.hku.hk/testingData.tar'
tar -xf testingData.tar
cd dataPrepScripts
sh PrepDataBeforeDemo.sh
- Training the model
Follow `jupyter_nb/demo.ipynb`.
- Data Preparation (Clair3)
Data Sources:
Clair3 Data: HG001 BAM
Clair3 Data: HG002 BAM
cd dataPrepScripts
sh PrepDataBeforeDemo.sh
- Training the model
Execute `sbatch train_batch_hg001.sh`.
Please change the name of the file according to the read sample you want to use for training: `hg001` or `hg002`.
- Clairvoyante: The model converged to 92% validation accuracy after ~50 epochs. Convergence was not stable even with a very small learning rate, indicating that the model structure is susceptible to overfitting.
- With an increased learning rate and a scheduler to reduce the learning rate periodically, the model's performance degraded massively and it could not reach an acceptable validation accuracy (see the step-decay sketch below).
- The authors do mention a higher learning rate being a pitfall, but their claim suggests that the breakdown in performance should occur much earlier than we observed.
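For reference, the "scheduler to reduce the learning rate periodically" mentioned above was a step-decay style schedule. The sketch below shows the general shape of that experiment, written against TensorFlow 2's Keras API for readability (Clairvoyante itself is TensorFlow 1 / Python 2); the starting rate, decay factor, and decay period are illustrative values, not the exact ones from our runs.

```python
import tensorflow as tf

# Illustrative step decay: start from a deliberately higher learning rate and
# cut it periodically. The starting rate, decay factor, and period below are
# placeholders, not the exact values used in the Clairvoyante experiments.
def step_decay(epoch, lr, initial_lr=1e-2, drop=0.5, epochs_per_drop=10):
    # The current lr passed in by the callback is ignored; the new rate
    # depends only on the epoch index.
    return initial_lr * (drop ** (epoch // epochs_per_drop))

scheduler = tf.keras.callbacks.LearningRateScheduler(step_decay, verbose=1)

# Usage (model and train_ds defined elsewhere):
# model.fit(train_ds, epochs=50, callbacks=[scheduler])
```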
- After 30 epochs, the model achieves a validation F1 score of 98% on CPU and 95% on 4 GPUs with DataParallel.
- Given the very large size of the dataset, parallelizing the training by distributing data across the GPUs decreases the time per epoch by 50%, and hence the execution time for 30 epochs is also halved.
- Accuracy scales more slowly on GPUs with DataParallel than on the CPU: the time to reach 94% is approximately 1682 seconds on 4 GPUs versus 928.9 seconds on the CPU, despite the time per epoch being significantly lower with DataParallel. A short worked comparison of these figures follows below.
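To make the trade-off above concrete, here is a small worked comparison using the numbers reported in this section (the ~2x per-epoch speedup and the measured times to reach 94%); the derived epoch-count ratio is an estimate that assumes the per-epoch speedup holds throughout training.

```python
# Worked comparison: throughput (time per epoch) vs. time-to-accuracy, using
# the figures reported above. Only the time-to-94% values are measured; the
# epoch-count ratio is derived under the assumption of a constant 2x
# per-epoch speedup on 4 GPUs.
time_to_94_cpu_s = 928.9     # CPU, seconds to reach 94% validation accuracy
time_to_94_gpu_s = 1682.0    # 4 GPUs with DataParallel
per_epoch_speedup = 2.0      # per-epoch wall-clock speedup on 4 GPUs

wallclock_ratio = time_to_94_gpu_s / time_to_94_cpu_s   # ~1.81x longer
epoch_ratio = wallclock_ratio * per_epoch_speedup       # ~3.6x more epochs

print(f"4 GPUs take {wallclock_ratio:.2f}x longer to reach 94%, "
      f"i.e. roughly {epoch_ratio:.1f}x as many epochs as the CPU run.")
```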