Merge pull request #145 from jeromekelleher/better-dencode
Better dencode
jeromekelleher authored Apr 30, 2024
2 parents 1c47953 + 3faa3c8 commit 5a602c8
Showing 3 changed files with 194 additions and 54 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,8 @@
+# 0.0.7 2024-04-30
+- Change the on-disk format of distributed encode and simplify it
+- Check that all partitions have nominally completed encoding before doing
+  anything destructive in dencode-finalise
+
# 0.0.6 2024-04-24

- Only use NOSHUFFLE by default on ``call_genotype`` and bool arrays.
117 changes: 64 additions & 53 deletions bio2zarr/vcf.py
@@ -294,17 +294,18 @@ def scan_vcf(path, target_num_partitions):

def check_overlap(partitions):
for i in range(1, len(partitions)):
-        prev_partition = partitions[i - 1]
-        current_partition = partitions[i]
-        if (
-            prev_partition.region.contig == current_partition.region.contig
-            and prev_partition.region.end > current_partition.region.start
-        ):
-            raise ValueError(
-                f"Multiple VCFs have the region "
-                f"{prev_partition.region.contig}:{prev_partition.region.start}-"
-                f"{current_partition.region.end}"
-            )
+        prev_region = partitions[i - 1].region
+        current_region = partitions[i].region
+        if prev_region.contig == current_region.contig:
+            if prev_region.end is None:
+                logger.warning("Cannot check overlaps; issue #146")
+                continue
+            if prev_region.end > current_region.start:
+                raise ValueError(
+                    f"Multiple VCFs have the region "
+                    f"{prev_region.contig}:{prev_region.start}-"
+                    f"{current_region.end}"
+                )


def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
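In isolation, the new overlap rule reads like this: a minimal runnable sketch, assuming a simplified Region dataclass (hypothetical here; the real region type lives elsewhere in bio2zarr):

import dataclasses

@dataclasses.dataclass
class Region:
    # Hypothetical stand-in for bio2zarr's region type.
    contig: str
    start: int
    end: int | None = None  # None when no end coordinate is available

def regions_overlap(prev, curr):
    # Same logic as the rewritten check_overlap, for a single pair.
    if prev.contig != curr.contig:
        return False
    if prev.end is None:
        # Without an end coordinate we cannot decide (issue #146),
        # so the check is skipped rather than raising.
        return False
    return prev.end > curr.start

assert regions_overlap(Region("chr1", 0, 1000), Region("chr1", 500))
assert not regions_overlap(Region("chr1", 0, None), Region("chr1", 500))

The key behavioural change is the None guard: previously a region with an unknown end would have crashed the comparison with a TypeError, since None cannot be ordered against an int.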
@@ -453,7 +454,7 @@ def sanitise_value_float_2d(buff, j, value):

def sanitise_int_array(value, ndmin, dtype):
if isinstance(value, tuple):
-        value = [VCF_INT_MISSING if x is None else x for x in value] # NEEDS TEST
+        value = [VCF_INT_MISSING if x is None else x for x in value]  # NEEDS TEST
value = np.array(value, ndmin=ndmin, copy=False)
value[value == VCF_INT_MISSING] = -1
value[value == VCF_INT_FILL] = -2
@@ -1548,10 +1549,8 @@ def parse_max_memory(max_memory):

@dataclasses.dataclass
class VcfZarrPartition:
-    start_index: int
-    stop_index: int
-    start_chunk: int
-    stop_chunk: int
+    start: int
+    stop: int

@staticmethod
def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None):
@@ -1565,9 +1564,7 @@ def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None
stop_chunk = int(chunk_slice[-1]) + 1
start_index = start_chunk * chunk_size
stop_index = min(stop_chunk * chunk_size, num_records)
-            partitions.append(
-                VcfZarrPartition(start_index, stop_index, start_chunk, stop_chunk)
-            )
+            partitions.append(VcfZarrPartition(start_index, stop_index))
return partitions
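The chunk fields disappear from VcfZarrPartition because start and stop are chunk-aligned by construction in generate_partitions, so storing the chunk indexes was redundant. A rough, self-contained sketch of that alignment logic (illustrative names, not the library API):

import numpy as np

def chunk_aligned_partitions(num_records, chunk_size, num_partitions):
    # Split chunk indexes (not record indexes) across partitions, so
    # every partition boundary lands on a chunk boundary.
    num_chunks = -(-num_records // chunk_size)  # ceiling division
    partitions = []
    for chunk_slice in np.array_split(np.arange(num_chunks), num_partitions):
        if len(chunk_slice) == 0:
            continue
        start = int(chunk_slice[0]) * chunk_size
        stop = min((int(chunk_slice[-1]) + 1) * chunk_size, num_records)
        partitions.append((start, stop))
    return partitions

# 25 records in chunks of 10, split 3 ways -> [(0, 10), (10, 20), (20, 25)]
print(chunk_aligned_partitions(25, 10, 3))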


@@ -1590,7 +1587,7 @@ def asdict(self):
def fromdict(d):
if d["format_version"] != VZW_METADATA_FORMAT_VERSION:
raise ValueError(
"VcfZarrWriter format version mismatch: "
"VcfZarrWriter format version mismatch: "
f"{d['format_version']} != {VZW_METADATA_FORMAT_VERSION}"
)
ret = VcfZarrWriterMetadata(**d)
@@ -1675,7 +1672,7 @@ def init(
root = zarr.group(store=store)

for column in self.schema.columns.values():
-            self.init_array(root, column, partitions[-1].stop_index)
+            self.init_array(root, column, partitions[-1].stop)

logger.info("Writing WIP metadata")
with open(self.wip_path / "metadata.json", "w") as f:
@@ -1762,28 +1759,42 @@ def load_metadata(self):
def partition_path(self, partition_index):
return self.partitions_path / f"p{partition_index}"

+    def wip_partition_path(self, partition_index):
+        return self.partitions_path / f"wip_p{partition_index}"
+
def wip_partition_array_path(self, partition_index, name):
-        return self.partition_path(partition_index) / f"wip_{name}"
+        return self.wip_partition_path(partition_index) / name

def partition_array_path(self, partition_index, name):
return self.partition_path(partition_index) / name
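Taken together, these helpers imply a layout where each partition gets a whole working directory that is later renamed, rather than per-array wip_ prefixes inside the final directory. A sketch of the resulting paths (the root name is illustrative):

from pathlib import Path

partitions_path = Path("out.vcz.wip/partitions")  # illustrative root

def wip_partition_path(i):
    return partitions_path / f"wip_p{i}"      # whole-partition working dir

def wip_partition_array_path(i, name):
    return wip_partition_path(i) / name       # arrays nest inside, unprefixed

def partition_path(i):
    return partitions_path / f"p{i}"          # final name after the rename

print(wip_partition_array_path(2, "variant_position"))
# out.vcz.wip/partitions/wip_p2/variant_position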

def encode_partition(self, partition_index):
self.load_metadata()
-        partition_path = self.partition_path(partition_index)
+        if partition_index < 0 or partition_index >= self.num_partitions:
+            raise ValueError(
+                "Partition index must be in the range 0 <= index < num_partitions"
+            )
+        partition_path = self.wip_partition_path(partition_index)
partition_path.mkdir(exist_ok=True)
logger.info(f"Encoding partition {partition_index} to {partition_path}")

-        self.encode_alleles_partition(partition_index)
        self.encode_id_partition(partition_index)
        self.encode_filters_partition(partition_index)
        self.encode_contig_partition(partition_index)
+        self.encode_alleles_partition(partition_index)
for col in self.schema.columns.values():
if col.vcf_field is not None:
self.encode_array_partition(col, partition_index)
if "call_genotype" in self.schema.columns:
self.encode_genotypes_partition(partition_index)
+
+        final_path = self.partition_path(partition_index)
+        logger.info(f"Finalising {partition_index} at {final_path}")
+        if final_path.exists():
+            logger.warning(f"Removing existing partition at {final_path}")
+            shutil.rmtree(final_path)
+        os.rename(partition_path, final_path)
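encode_partition now follows a write-to-WIP-then-rename discipline: everything is built under wip_p{index} and published with a single os.rename, so a crashed or repeated run never leaves a half-written p{index} directory behind. The pattern in isolation (a sketch, not the library API):

import os
import shutil
from pathlib import Path

def commit_directory(wip: Path, final: Path) -> None:
    # Build everything under `wip` first; only the rename (atomic within
    # a single POSIX filesystem) publishes the result under `final`.
    if final.exists():
        shutil.rmtree(final)  # stale earlier attempt; safe to discard
    os.rename(wip, final)

A useful consequence is that the mere existence of p{index} now signals a nominally complete partition, which is exactly what the new check in finalise (further down) relies on.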

def init_partition_array(self, partition_index, name):
wip_path = self.wip_partition_array_path(partition_index, name)
# Create an empty array like the definition
@@ -1795,27 +1806,17 @@ def init_partition_array(self, partition_index, name):
return array

def finalise_partition_array(self, partition_index, name):
-        wip_path = self.wip_partition_array_path(partition_index, name)
-        final_path = self.partition_array_path(partition_index, name)
-        if final_path.exists():
-            # NEEDS TEST
-            logger.warning(f"Removing existing {final_path}")
-            shutil.rmtree(final_path)
-        # Atomic swap
-        os.rename(wip_path, final_path)
logger.debug(f"Encoded {name} partition {partition_index}")

def encode_array_partition(self, column, partition_index):
array = self.init_partition_array(partition_index, column.name)

partition = self.metadata.partitions[partition_index]
-        ba = core.BufferedArray(array, partition.start_index)
+        ba = core.BufferedArray(array, partition.start)
source_col = self.icf.columns[column.vcf_field]
sanitiser = source_col.sanitiser_factory(ba.buff.shape)

-        for value in source_col.iter_values(
-            partition.start_index, partition.stop_index
-        ):
+        for value in source_col.iter_values(partition.start, partition.stop):
# We write directly into the buffer in the sanitiser function
# to make it easier to reason about dimension padding
j = ba.next_buffer_row()
@@ -1831,14 +1832,12 @@ def encode_genotypes_partition(self, partition_index):
)

partition = self.metadata.partitions[partition_index]
-        gt = core.BufferedArray(gt_array, partition.start_index)
-        gt_mask = core.BufferedArray(gt_mask_array, partition.start_index)
-        gt_phased = core.BufferedArray(gt_phased_array, partition.start_index)
+        gt = core.BufferedArray(gt_array, partition.start)
+        gt_mask = core.BufferedArray(gt_mask_array, partition.start)
+        gt_phased = core.BufferedArray(gt_phased_array, partition.start)

source_col = self.icf.columns["FORMAT/GT"]
-        for value in source_col.iter_values(
-            partition.start_index, partition.stop_index
-        ):
+        for value in source_col.iter_values(partition.start, partition.stop):
j = gt.next_buffer_row()
sanitise_value_int_2d(gt.buff, j, value[:, :-1])
j = gt_phased.next_buffer_row()
@@ -1859,13 +1858,13 @@ def encode_alleles_partition(self, partition_index):
array_name = "variant_allele"
alleles_array = self.init_partition_array(partition_index, array_name)
partition = self.metadata.partitions[partition_index]
-        alleles = core.BufferedArray(alleles_array, partition.start_index)
+        alleles = core.BufferedArray(alleles_array, partition.start)
ref_col = self.icf.columns["REF"]
alt_col = self.icf.columns["ALT"]

for ref, alt in zip(
-            ref_col.iter_values(partition.start_index, partition.stop_index),
-            alt_col.iter_values(partition.start_index, partition.stop_index),
+            ref_col.iter_values(partition.start, partition.stop),
+            alt_col.iter_values(partition.start, partition.stop),
):
j = alleles.next_buffer_row()
alleles.buff[j, :] = STR_FILL
@@ -1879,11 +1878,11 @@ def encode_id_partition(self, partition_index):
vid_array = self.init_partition_array(partition_index, "variant_id")
vid_mask_array = self.init_partition_array(partition_index, "variant_id_mask")
partition = self.metadata.partitions[partition_index]
-        vid = core.BufferedArray(vid_array, partition.start_index)
-        vid_mask = core.BufferedArray(vid_mask_array, partition.start_index)
+        vid = core.BufferedArray(vid_array, partition.start)
+        vid_mask = core.BufferedArray(vid_mask_array, partition.start)
col = self.icf.columns["ID"]

-        for value in col.iter_values(partition.start_index, partition.stop_index):
+        for value in col.iter_values(partition.start, partition.stop):
j = vid.next_buffer_row()
k = vid_mask.next_buffer_row()
assert j == k
@@ -1904,10 +1903,10 @@ def encode_filters_partition(self, partition_index):
array_name = "variant_filter"
array = self.init_partition_array(partition_index, array_name)
partition = self.metadata.partitions[partition_index]
-        var_filter = core.BufferedArray(array, partition.start_index)
+        var_filter = core.BufferedArray(array, partition.start)

col = self.icf.columns["FILTERS"]
-        for value in col.iter_values(partition.start_index, partition.stop_index):
+        for value in col.iter_values(partition.start, partition.stop):
j = var_filter.next_buffer_row()
var_filter.buff[j] = False
for f in value:
@@ -1926,10 +1925,10 @@ def encode_contig_partition(self, partition_index):
array_name = "variant_contig"
array = self.init_partition_array(partition_index, array_name)
partition = self.metadata.partitions[partition_index]
-        contig = core.BufferedArray(array, partition.start_index)
+        contig = core.BufferedArray(array, partition.start)
col = self.icf.columns["CHROM"]

-        for value in col.iter_values(partition.start_index, partition.stop_index):
+        for value in col.iter_values(partition.start, partition.stop):
j = contig.next_buffer_row()
            # Note: because we are using the indexes to define the lookups
            # and we always have an index, it seems that the contig lookup
@@ -1950,7 +1949,7 @@ def finalise_array(self, name):
if final_path.exists():
# NEEDS TEST
raise ValueError(f"Array {name} already exists")
-        for partition in range(len(self.metadata.partitions)):
+        for partition in range(self.num_partitions):
# Move all the files in partition dir to dest dir
src = self.partition_array_path(partition, name)
if not src.exists():
@@ -1977,6 +1976,15 @@ def finalise_array(self, name):
def finalise(self, show_progress=False):
self.load_metadata()

logger.info("Scanning {self.num_partitions} partitions")
missing = []
# TODO may need a progress bar here
for partition_id in range(self.num_partitions):
if not self.partition_path(partition_id).exists():
missing.append(partition_id)
if len(missing) > 0:
raise FileNotFoundError(f"Partitions not encoded: {missing}")

progress_config = core.ProgressConfig(
total=len(self.schema.columns),
title="Finalise",
@@ -1994,6 +2002,9 @@ def finalise_array(self, name):
with core.ParallelWorkManager(0, progress_config) as pwm:
for name in self.schema.columns:
pwm.submit(self.finalise_array, name)
logger.debug(f"Removing {self.wip_path}")
shutil.rmtree(self.wip_path)
logger.info("Consolidating Zarr metadata")
zarr.consolidate_metadata(self.path)
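End to end, a distributed encode driver built on these methods would look something like the following hypothetical sketch (writer construction elided; in practice each encode_partition call would run as a separate job):

def run_distributed_encode(writer):
    # `writer` is assumed to be an initialised VcfZarrWriter.
    for index in range(writer.num_partitions):
        # Independent and restartable: work happens in wip_p{index}
        # and is published by an atomic rename.
        writer.encode_partition(index)
    # Raises FileNotFoundError if any p{index} directory is missing,
    # so nothing destructive happens on a partially encoded dataset.
    writer.finalise(show_progress=True)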

######################
