diff --git a/CHANGELOG.md b/CHANGELOG.md index 669d81a..42f79f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ -# 0.0.6 2024-04-xx +# 0.0.6 2024-04-24 - Only use NOSHUFFLE by default on ``call_genotype`` and bool arrays. +- Add initial implementation of distributed encode # 0.0.5 2024-04-17 diff --git a/bio2zarr/cli.py b/bio2zarr/cli.py index 8b8d134..b088bdf 100644 --- a/bio2zarr/cli.py +++ b/bio2zarr/cli.py @@ -5,6 +5,7 @@ import click import coloredlogs +import humanfriendly import numcodecs import tabulate @@ -39,6 +40,14 @@ def list_commands(self, ctx): "zarr_path", type=click.Path(file_okay=False, dir_okay=True) ) +zarr_path = click.argument( + "zarr_path", type=click.Path(exists=True, file_okay=False, dir_okay=True) +) + +num_partitions = click.argument("num_partitions", type=click.IntRange(min=1)) + +partition = click.argument("partition", type=click.IntRange(min=0)) + verbose = click.option("-v", "--verbose", count=True, help="Increase verbosity") force = click.option( @@ -92,6 +101,27 @@ def list_commands(self, ctx): help="Chunk size in the samples dimension", ) +schema = click.option("-s", "--schema", default=None, type=click.Path(exists=True)) + +max_variant_chunks = click.option( + "-V", + "--max-variant-chunks", + type=int, + default=None, + help=( + "Truncate the output in the variants dimension to have " + "this number of chunks. Mainly intended to help with " + "schema tuning." + ), +) + +max_memory = click.option( + "-M", + "--max-memory", + default=None, + help="An approximate bound on overall memory usage (e.g. 10G),", +) + def setup_logging(verbosity): level = "WARNING" @@ -158,7 +188,7 @@ def explode( @click.command @vcfs @new_icf_path -@click.argument("num_partitions", type=click.IntRange(min=1)) +@num_partitions @force @column_chunk_size @compressor @@ -194,7 +224,7 @@ def dexplode_init( @click.command @icf_path -@click.argument("partition", type=click.IntRange(min=0)) +@partition @verbose def dexplode_partition(icf_path, partition, verbose): """ @@ -207,14 +237,14 @@ def dexplode_partition(icf_path, partition, verbose): @click.command -@click.argument("path", type=click.Path(), required=True) +@icf_path @verbose -def dexplode_finalise(path, verbose): +def dexplode_finalise(icf_path, verbose): """ Final step for distributed conversion of VCF(s) to intermediate columnar format. """ setup_logging(verbose) - vcf.explode_finalise(path) + vcf.explode_finalise(icf_path) @click.command @@ -244,26 +274,11 @@ def mkschema(icf_path): @new_zarr_path @force @verbose -@click.option("-s", "--schema", default=None, type=click.Path(exists=True)) +@schema @variants_chunk_size @samples_chunk_size -@click.option( - "-V", - "--max-variant-chunks", - type=int, - default=None, - help=( - "Truncate the output in the variants dimension to have " - "this number of chunks. Mainly intended to help with " - "schema tuning." - ), -) -@click.option( - "-M", - "--max-memory", - default=None, - help="An approximate bound on overall memory usage (e.g. 
10G),", -) +@max_variant_chunks +@max_memory @worker_processes def encode( icf_path, @@ -288,13 +303,96 @@ def encode( schema_path=schema, variants_chunk_size=variants_chunk_size, samples_chunk_size=samples_chunk_size, - max_v_chunks=max_variant_chunks, + max_variant_chunks=max_variant_chunks, worker_processes=worker_processes, max_memory=max_memory, show_progress=True, ) +@click.command +@icf_path +@new_zarr_path +@num_partitions +@force +@schema +@variants_chunk_size +@samples_chunk_size +@max_variant_chunks +@verbose +def dencode_init( + icf_path, + zarr_path, + num_partitions, + force, + schema, + variants_chunk_size, + samples_chunk_size, + max_variant_chunks, + verbose, +): + """ + Initialise conversion of intermediate format to VCF Zarr. This will + set up the specified ZARR_PATH to perform this conversion over + NUM_PARTITIONS. + + The output of this commmand is the actual number of partitions generated + (which may be less then the requested number, if there is not sufficient + chunks in the variants dimension) and a rough lower-bound on the amount + of memory required to encode a partition. + + NOTE: the format of this output will likely change in subsequent releases; + it should not be considered machine-readable for now. + """ + setup_logging(verbose) + check_overwrite_dir(zarr_path, force) + num_partitions, max_memory = vcf.encode_init( + icf_path, + zarr_path, + target_num_partitions=num_partitions, + schema_path=schema, + variants_chunk_size=variants_chunk_size, + samples_chunk_size=samples_chunk_size, + max_variant_chunks=max_variant_chunks, + show_progress=True, + ) + formatted_size = humanfriendly.format_size(max_memory, binary=True) + # NOTE adding the size to the stdout here so that users can parse it + # and use in their submission scripts. This is a first pass, and + # will most likely change as we see what works and doesn't. + # NOTE we probably want to format this as a table, which lists + # some other properties, line by line + # NOTE This size number is also not quite enough, you need a bit of + # headroom with it (probably 10% or so). We should include this. + click.echo(f"{num_partitions}\t{formatted_size}") + + +@click.command +@zarr_path +@partition +@verbose +def dencode_partition(zarr_path, partition, verbose): + """ + Convert a partition from intermediate columnar format to VCF Zarr. + Must be called *after* the Zarr path has been initialised with dencode_init. + Partition indexes must be from 0 (inclusive) to the number of paritions + returned by dencode_init (exclusive). + """ + setup_logging(verbose) + vcf.encode_partition(zarr_path, partition) + + +@click.command +@zarr_path +@verbose +def dencode_finalise(zarr_path, verbose): + """ + Final step for distributed conversion of ICF to VCF Zarr. 
+ """ + setup_logging(verbose) + vcf.encode_finalise(zarr_path, show_progress=True) + + @click.command(name="convert") @vcfs @new_zarr_path @@ -382,6 +480,9 @@ def vcf2zarr(): vcf2zarr.add_command(dexplode_init) vcf2zarr.add_command(dexplode_partition) vcf2zarr.add_command(dexplode_finalise) +vcf2zarr.add_command(dencode_init) +vcf2zarr.add_command(dencode_partition) +vcf2zarr.add_command(dencode_finalise) @click.command(name="convert") diff --git a/bio2zarr/core.py b/bio2zarr/core.py index cafc43e..8b07ac5 100644 --- a/bio2zarr/core.py +++ b/bio2zarr/core.py @@ -110,6 +110,7 @@ def flush(self): sync_flush_2d_array( self.buff[: self.buffer_row], self.array, self.array_offset ) + # FIXME the array.name doesn't seem to be working here for some reason logger.debug( f"Flushed <{self.array.name} {self.array.shape} " f"{self.array.dtype}> " @@ -131,8 +132,7 @@ def sync_flush_2d_array(np_buffer, zarr_array, offset): # encoder implementations. s = slice(offset, offset + np_buffer.shape[0]) samples_chunk_size = zarr_array.chunks[1] - # TODO use zarr chunks here to support non-uniform chunking later - # and for simplicity + # TODO use zarr chunks here for simplicity zarr_array_width = zarr_array.shape[1] start = 0 while start < zarr_array_width: @@ -192,7 +192,7 @@ def __init__(self, worker_processes=1, progress_config=None): self.progress_config = progress_config self.progress_bar = tqdm.tqdm( total=progress_config.total, - desc=f"{progress_config.title:>7}", + desc=f"{progress_config.title:>8}", unit_scale=True, unit=progress_config.units, smoothing=0.1, diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py index 1794707..e7288f8 100644 --- a/bio2zarr/vcf.py +++ b/bio2zarr/vcf.py @@ -1318,6 +1318,17 @@ def _choose_compressor_settings(self): self.compressor["shuffle"] = shuffle + @property + def variant_chunk_nbytes(self): + """ + Returns the nbytes for a single variant chunk of this array. 
+ """ + chunk_items = self.chunks[0] + for size in self.shape[1:]: + chunk_items *= size + dt = np.dtype(self.dtype) + return chunk_items * dt.itemsize + ZARR_SCHEMA_FORMAT_VERSION = "0.2" @@ -1526,15 +1537,6 @@ def summary_table(self): return data -@dataclasses.dataclass -class EncodingWork: - func: callable = dataclasses.field(repr=False) - start: int - stop: int - columns: list[str] - memory: int = 0 - - def parse_max_memory(max_memory): if max_memory is None: # Effectively unbounded @@ -1545,32 +1547,199 @@ def parse_max_memory(max_memory): return max_memory +@dataclasses.dataclass +class VcfZarrPartition: + start_index: int + stop_index: int + start_chunk: int + stop_chunk: int + + @staticmethod + def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None): + num_chunks = int(np.ceil(num_records / chunk_size)) + if max_chunks is not None: + num_chunks = min(num_chunks, max_chunks) + partitions = [] + splits = np.array_split(np.arange(num_chunks), min(num_partitions, num_chunks)) + for chunk_slice in splits: + start_chunk = int(chunk_slice[0]) + stop_chunk = int(chunk_slice[-1]) + 1 + start_index = start_chunk * chunk_size + stop_index = min(stop_chunk * chunk_size, num_records) + partitions.append( + VcfZarrPartition(start_index, stop_index, start_chunk, stop_chunk) + ) + return partitions + + +VZW_METADATA_FORMAT_VERSION = "0.1" + + +@dataclasses.dataclass +class VcfZarrWriterMetadata: + format_version: str + icf_path: str + schema: VcfZarrSchema + dimension_separator: str + partitions: list + provenance: dict + + def asdict(self): + return dataclasses.asdict(self) + + @staticmethod + def fromdict(d): + if d["format_version"] != VZW_METADATA_FORMAT_VERSION: + raise ValueError( + "VcfZarrWriter format version mismatch: " + f"{d['format_version']} != {VZW_METADATA_FORMAT_VERSION}" + ) + ret = VcfZarrWriterMetadata(**d) + ret.schema = VcfZarrSchema.fromdict(ret.schema) + ret.partitions = [VcfZarrPartition(**p) for p in ret.partitions] + return ret + + class VcfZarrWriter: - def __init__(self, path, icf, schema, dimension_separator=None): + def __init__(self, path): self.path = pathlib.Path(path) + self.wip_path = self.path / "wip" + self.arrays_path = self.wip_path / "arrays" + self.partitions_path = self.wip_path / "partitions" + self.metadata = None + self.icf = None + + @property + def schema(self): + return self.metadata.schema + + @property + def num_partitions(self): + return len(self.metadata.partitions) + + ####################### + # init + ####################### + + def init( + self, + icf, + *, + target_num_partitions, + schema, + dimension_separator=None, + max_variant_chunks=None, + ): self.icf = icf - self.schema = schema + if self.path.exists(): + raise ValueError("Zarr path already exists") # NEEDS TEST + partitions = VcfZarrPartition.generate_partitions( + self.icf.num_records, + schema.variants_chunk_size, + target_num_partitions, + max_chunks=max_variant_chunks, + ) # Default to using nested directories following the Zarr v3 default. 
# This seems to require version 2.17+ to work properly - self.dimension_separator = ( + dimension_separator = ( "/" if dimension_separator is None else dimension_separator ) + self.metadata = VcfZarrWriterMetadata( + format_version=VZW_METADATA_FORMAT_VERSION, + icf_path=str(self.icf.path), + schema=schema, + dimension_separator=dimension_separator, + partitions=partitions, + # Bare minimum here for provenance - see comments above + provenance={"source": f"bio2zarr-{provenance.__version__}"}, + ) + + self.path.mkdir() store = zarr.DirectoryStore(self.path) - self.root = zarr.group(store=store) + root = zarr.group(store=store) + root.attrs.update( + { + "vcf_zarr_version": "0.2", + "vcf_header": self.icf.vcf_header, + "source": f"bio2zarr-{provenance.__version__}", + } + ) + # Doing this syncronously - this is fine surely + self.encode_samples(root) + self.encode_filter_id(root) + self.encode_contig_id(root) + + self.wip_path.mkdir() + self.arrays_path.mkdir() + self.partitions_path.mkdir() + store = zarr.DirectoryStore(self.arrays_path) + root = zarr.group(store=store) + + for column in self.schema.columns.values(): + self.init_array(root, column, partitions[-1].stop_index) - def init_array(self, variable): + logger.info("Writing WIP metadata") + with open(self.wip_path / "metadata.json", "w") as f: + json.dump(self.metadata.asdict(), f, indent=4) + return len(partitions) + + def encode_samples(self, root): + if not np.array_equal(self.schema.sample_id, self.icf.metadata.samples): + raise ValueError( + "Subsetting or reordering samples not supported currently" + ) # NEEDS TEST + array = root.array( + "sample_id", + self.schema.sample_id, + dtype="str", + compressor=DEFAULT_ZARR_COMPRESSOR, + chunks=(self.schema.samples_chunk_size,), + ) + array.attrs["_ARRAY_DIMENSIONS"] = ["samples"] + logger.debug("Samples done") + + def encode_contig_id(self, root): + array = root.array( + "contig_id", + self.schema.contig_id, + dtype="str", + compressor=DEFAULT_ZARR_COMPRESSOR, + ) + array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"] + if self.schema.contig_length is not None: + array = root.array( + "contig_length", + self.schema.contig_length, + dtype=np.int64, + compressor=DEFAULT_ZARR_COMPRESSOR, + ) + array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"] + + def encode_filter_id(self, root): + array = root.array( + "filter_id", + self.schema.filter_id, + dtype="str", + compressor=DEFAULT_ZARR_COMPRESSOR, + ) + array.attrs["_ARRAY_DIMENSIONS"] = ["filters"] + + def init_array(self, root, variable, variants_dim_size): object_codec = None if variable.dtype == "O": object_codec = numcodecs.VLenUTF8() - a = self.root.empty( - "wip_" + variable.name, - shape=variable.shape, + shape = list(variable.shape) + # Truncate the variants dimension is max_variant_chunks was specified + shape[0] = variants_dim_size + a = root.empty( + variable.name, + shape=shape, chunks=variable.chunks, dtype=variable.dtype, compressor=numcodecs.get_codec(variable.compressor), filters=[numcodecs.get_codec(filt) for filt in variable.filters], object_codec=object_codec, - dimension_separator=self.dimension_separator, + dimension_separator=self.metadata.dimension_separator, ) a.attrs.update( { @@ -1579,38 +1748,98 @@ def init_array(self, variable): "_ARRAY_DIMENSIONS": variable.dimensions, } ) + logger.debug(f"Initialised {a}") - def get_array(self, name): - return self.root["wip_" + name] + ####################### + # encode_partition + ####################### - def finalise_array(self, variable_name): - source = self.path / ("wip_" + 
variable_name) - dest = self.path / variable_name + def load_metadata(self): + if self.metadata is None: + with open(self.wip_path / "metadata.json") as f: + self.metadata = VcfZarrWriterMetadata.fromdict(json.load(f)) + self.icf = IntermediateColumnarFormat(self.metadata.icf_path) + + def partition_path(self, partition_index): + return self.partitions_path / f"p{partition_index}" + + def wip_partition_array_path(self, partition_index, name): + return self.partition_path(partition_index) / f"wip_{name}" + + def partition_array_path(self, partition_index, name): + return self.partition_path(partition_index) / name + + def encode_partition(self, partition_index): + self.load_metadata() + partition_path = self.partition_path(partition_index) + partition_path.mkdir(exist_ok=True) + logger.info(f"Encoding partition {partition_index} to {partition_path}") + + self.encode_alleles_partition(partition_index) + self.encode_id_partition(partition_index) + self.encode_filters_partition(partition_index) + self.encode_contig_partition(partition_index) + for col in self.schema.columns.values(): + if col.vcf_field is not None: + self.encode_array_partition(col, partition_index) + if "call_genotype" in self.schema.columns: + self.encode_genotypes_partition(partition_index) + + def init_partition_array(self, partition_index, name): + wip_path = self.wip_partition_array_path(partition_index, name) + # Create an empty array like the definition + src = self.arrays_path / name + # Overwrite any existing WIP files + shutil.copytree(src, wip_path, dirs_exist_ok=True) + array = zarr.open(wip_path) + logger.debug(f"Opened empty array {array} @ {wip_path}") + return array + + def finalise_partition_array(self, partition_index, name): + wip_path = self.wip_partition_array_path(partition_index, name) + final_path = self.partition_array_path(partition_index, name) + if final_path.exists(): + # NEEDS TEST + logger.warning(f"Removing existing {final_path}") + shutil.rmtree(final_path) # Atomic swap - os.rename(source, dest) - logger.info(f"Finalised {variable_name}") + os.rename(wip_path, final_path) + logger.debug(f"Encoded {name} partition {partition_index}") + + def encode_array_partition(self, column, partition_index): + array = self.init_partition_array(partition_index, column.name) - def encode_array_slice(self, column, start, stop): + partition = self.metadata.partitions[partition_index] + ba = core.BufferedArray(array, partition.start_index) source_col = self.icf.columns[column.vcf_field] - array = self.get_array(column.name) - ba = core.BufferedArray(array, start) sanitiser = source_col.sanitiser_factory(ba.buff.shape) - for value in source_col.iter_values(start, stop): + for value in source_col.iter_values( + partition.start_index, partition.stop_index + ): # We write directly into the buffer in the sanitiser function # to make it easier to reason about dimension padding j = ba.next_buffer_row() sanitiser(ba.buff, j, value) ba.flush() - logger.debug(f"Encoded {column.name} slice {start}:{stop}") + self.finalise_partition_array(partition_index, column.name) - def encode_genotypes_slice(self, start, stop): - source_col = self.icf.columns["FORMAT/GT"] - gt = core.BufferedArray(self.get_array("call_genotype"), start) - gt_mask = core.BufferedArray(self.get_array("call_genotype_mask"), start) - gt_phased = core.BufferedArray(self.get_array("call_genotype_phased"), start) + def encode_genotypes_partition(self, partition_index): + gt_array = self.init_partition_array(partition_index, "call_genotype") + gt_mask_array = 
self.init_partition_array(partition_index, "call_genotype_mask") + gt_phased_array = self.init_partition_array( + partition_index, "call_genotype_phased" + ) + + partition = self.metadata.partitions[partition_index] + gt = core.BufferedArray(gt_array, partition.start_index) + gt_mask = core.BufferedArray(gt_mask_array, partition.start_index) + gt_phased = core.BufferedArray(gt_phased_array, partition.start_index) - for value in source_col.iter_values(start, stop): + source_col = self.icf.columns["FORMAT/GT"] + for value in source_col.iter_values( + partition.start_index, partition.stop_index + ): j = gt.next_buffer_row() sanitise_value_int_2d(gt.buff, j, value[:, :-1]) j = gt_phased.next_buffer_row() @@ -1622,29 +1851,40 @@ def encode_genotypes_slice(self, start, stop): gt.flush() gt_phased.flush() gt_mask.flush() - logger.debug(f"Encoded GT slice {start}:{stop}") - def encode_alleles_slice(self, start, stop): + self.finalise_partition_array(partition_index, "call_genotype") + self.finalise_partition_array(partition_index, "call_genotype_mask") + self.finalise_partition_array(partition_index, "call_genotype_phased") + + def encode_alleles_partition(self, partition_index): + array_name = "variant_allele" + alleles_array = self.init_partition_array(partition_index, array_name) + partition = self.metadata.partitions[partition_index] + alleles = core.BufferedArray(alleles_array, partition.start_index) ref_col = self.icf.columns["REF"] alt_col = self.icf.columns["ALT"] - alleles = core.BufferedArray(self.get_array("variant_allele"), start) for ref, alt in zip( - ref_col.iter_values(start, stop), alt_col.iter_values(start, stop) + ref_col.iter_values(partition.start_index, partition.stop_index), + alt_col.iter_values(partition.start_index, partition.stop_index), ): j = alleles.next_buffer_row() alleles.buff[j, :] = STR_FILL alleles.buff[j, 0] = ref[0] alleles.buff[j, 1 : 1 + len(alt)] = alt alleles.flush() - logger.debug(f"Encoded alleles slice {start}:{stop}") - def encode_id_slice(self, start, stop): + self.finalise_partition_array(partition_index, array_name) + + def encode_id_partition(self, partition_index): + vid_array = self.init_partition_array(partition_index, "variant_id") + vid_mask_array = self.init_partition_array(partition_index, "variant_id_mask") + partition = self.metadata.partitions[partition_index] + vid = core.BufferedArray(vid_array, partition.start_index) + vid_mask = core.BufferedArray(vid_mask_array, partition.start_index) col = self.icf.columns["ID"] - vid = core.BufferedArray(self.get_array("variant_id"), start) - vid_mask = core.BufferedArray(self.get_array("variant_id_mask"), start) - for value in col.iter_values(start, stop): + for value in col.iter_values(partition.start_index, partition.stop_index): j = vid.next_buffer_row() k = vid_mask.next_buffer_row() assert j == k @@ -1656,13 +1896,19 @@ def encode_id_slice(self, start, stop): vid_mask.buff[j] = True vid.flush() vid_mask.flush() - logger.debug(f"Encoded ID slice {start}:{stop}") - def encode_filters_slice(self, lookup, start, stop): - col = self.icf.columns["FILTERS"] - var_filter = core.BufferedArray(self.get_array("variant_filter"), start) + self.finalise_partition_array(partition_index, "variant_id") + self.finalise_partition_array(partition_index, "variant_id_mask") - for value in col.iter_values(start, stop): + def encode_filters_partition(self, partition_index): + lookup = {filt: index for index, filt in enumerate(self.schema.filter_id)} + array_name = "variant_filter" + array = 
self.init_partition_array(partition_index, array_name) + partition = self.metadata.partitions[partition_index] + var_filter = core.BufferedArray(array, partition.start_index) + + col = self.icf.columns["FILTERS"] + for value in col.iter_values(partition.start_index, partition.stop_index): j = var_filter.next_buffer_row() var_filter.buff[j] = False for f in value: @@ -1670,16 +1916,21 @@ def encode_filters_slice(self, lookup, start, stop): var_filter.buff[j, lookup[f]] = True except KeyError: raise ValueError( - f"Filter '{f}' was not defined " f"in the header." + f"Filter '{f}' was not defined in the header." ) from None var_filter.flush() - logger.debug(f"Encoded FILTERS slice {start}:{stop}") - def encode_contig_slice(self, lookup, start, stop): + self.finalise_partition_array(partition_index, array_name) + + def encode_contig_partition(self, partition_index): + lookup = {contig: index for index, contig in enumerate(self.schema.contig_id)} + array_name = "variant_contig" + array = self.init_partition_array(partition_index, array_name) + partition = self.metadata.partitions[partition_index] + contig = core.BufferedArray(array, partition.start_index) col = self.icf.columns["CHROM"] - contig = core.BufferedArray(self.get_array("variant_contig"), start) - for value in col.iter_values(start, stop): + for value in col.iter_values(partition.start_index, partition.stop_index): j = contig.next_buffer_row() # Note: because we are using the indexes to define the lookups # and we always have an index, it seems that we the contig lookup @@ -1687,161 +1938,120 @@ def encode_contig_slice(self, lookup, start, stop): # here, please do open an issue with a reproducible example! contig.buff[j] = lookup[value[0]] contig.flush() - logger.debug(f"Encoded CHROM slice {start}:{stop}") - - def encode_samples(self): - if not np.array_equal(self.schema.sample_id, self.icf.metadata.samples): - raise ValueError( - "Subsetting or reordering samples not supported currently" - ) # NEEDS TEST - array = self.root.array( - "sample_id", - self.schema.sample_id, - dtype="str", - compressor=DEFAULT_ZARR_COMPRESSOR, - chunks=(self.schema.samples_chunk_size,), - ) - array.attrs["_ARRAY_DIMENSIONS"] = ["samples"] - logger.debug("Samples done") - def encode_contig_id(self): - array = self.root.array( - "contig_id", - self.schema.contig_id, - dtype="str", - compressor=DEFAULT_ZARR_COMPRESSOR, - ) - array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"] - if self.schema.contig_length is not None: - array = self.root.array( - "contig_length", - self.schema.contig_length, - dtype=np.int64, - compressor=DEFAULT_ZARR_COMPRESSOR, + self.finalise_partition_array(partition_index, array_name) + + ####################### + # finalise + ####################### + + def finalise_array(self, name): + logger.info(f"Finalising {name}") + final_path = self.path / name + if final_path.exists(): + # NEEDS TEST + raise ValueError(f"Array {name} already exists") + for partition in range(len(self.metadata.partitions)): + # Move all the files in partition dir to dest dir + src = self.partition_array_path(partition, name) + if not src.exists(): + # Needs test + raise ValueError(f"Partition {partition} of {name} does not exist") + dest = self.arrays_path / name + # This is Zarr v2 specific. Chunks in v3 with start with "c" prefix. + chunk_files = [ + path for path in src.iterdir() if not path.name.startswith(".") + ] + # TODO check for a count of then number of files. 
If we require a + # dimension_separator of "/" then we could make stronger assertions + # here, as we'd always have num_variant_chunks + logger.debug( + f"Moving {len(chunk_files)} chunks for {name} partition {partition}" ) - array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"] - return {v: j for j, v in enumerate(self.schema.contig_id)} + for chunk_file in chunk_files: + os.rename(chunk_file, dest / chunk_file.name) + # Finally, once all the chunks have moved into the arrays dir, + # we move it out of wip + os.rename(self.arrays_path / name, self.path / name) + core.update_progress(1) - def encode_filter_id(self): - array = self.root.array( - "filter_id", - self.schema.filter_id, - dtype="str", - compressor=DEFAULT_ZARR_COMPRESSOR, + def finalise(self, show_progress=False): + self.load_metadata() + + progress_config = core.ProgressConfig( + total=len(self.schema.columns), + title="Finalise", + units="array", + show=show_progress, ) - array.attrs["_ARRAY_DIMENSIONS"] = ["filters"] - return {v: j for j, v in enumerate(self.schema.filter_id)} + # NOTE: it's not clear that adding more workers will make this quicker, + # as it's just going to be causing contention on the file system. + # Something to check empirically in some deployments. + # FIXME we're just using worker_processes=0 here to hook into the + # SynchronousExecutor which is intended for testing purposes so + # that we get test coverage. Should fix this either by allowing + # for multiple workers, or making a standard wrapper for tqdm + # that allows us to have a consistent look and feel. + with core.ParallelWorkManager(0, progress_config) as pwm: + for name in self.schema.columns: + pwm.submit(self.finalise_array, name) + zarr.consolidate_metadata(self.path) - def init(self): - self.root.attrs["vcf_zarr_version"] = "0.2" - self.root.attrs["vcf_header"] = self.icf.vcf_header - self.root.attrs["source"] = f"bio2zarr-{provenance.__version__}" - for column in self.schema.columns.values(): - self.init_array(column) + ###################### + # encode_all_partitions + ###################### - def finalise(self): - zarr.consolidate_metadata(self.path) + def get_max_encoding_memory(self): + """ + Return the approximate maximum memory used to encode a variant chunk. 
+ """ + max_encoding_mem = max( + col.variant_chunk_nbytes for col in self.schema.columns.values() + ) + gt_mem = 0 + if "call_genotype" in self.schema.columns: + encoded_together = [ + "call_genotype", + "call_genotype_phased", + "call_genotype_mask", + ] + gt_mem = sum( + self.schema.columns[col].variant_chunk_nbytes + for col in encoded_together + ) + return max(max_encoding_mem, gt_mem) - def encode( - self, - worker_processes=1, - max_v_chunks=None, - show_progress=False, - max_memory=None, + def encode_all_partitions( + self, *, worker_processes=1, show_progress=False, max_memory=None ): max_memory = parse_max_memory(max_memory) - - # TODO this will move into the setup logic later when we're making it possible - # to split the work by slice - num_slices = max(1, worker_processes * 4) - # Using POS arbitrarily to get the array slices - slices = core.chunk_aligned_slices( - self.get_array("variant_position"), num_slices, max_chunks=max_v_chunks + self.load_metadata() + num_partitions = self.num_partitions + per_worker_memory = self.get_max_encoding_memory() + logger.info( + f"Encoding Zarr over {num_partitions} partitions with " + f"{worker_processes} workers and {display_size(per_worker_memory)} " + "per worker" ) - truncated = slices[-1][-1] - for array in self.root.values(): - if array.attrs["_ARRAY_DIMENSIONS"][0] == "variants": - shape = list(array.shape) - shape[0] = truncated - array.resize(shape) - - total_bytes = 0 - encoding_memory_requirements = {} - for col in self.schema.columns.values(): - array = self.get_array(col.name) - # NOTE!! this is bad, we're potentially creating quite a large - # numpy array for basically nothing. We can compute this. - variant_chunk_size = array.blocks[0].nbytes - encoding_memory_requirements[col.name] = variant_chunk_size - logger.debug( - f"{col.name} requires at least {display_size(variant_chunk_size)} " - f"per worker" - ) - total_bytes += array.nbytes - - filter_id_map = self.encode_filter_id() - contig_id_map = self.encode_contig_id() - - work = [] - for start, stop in slices: - for col in self.schema.columns.values(): - if col.vcf_field is not None: - f = functools.partial(self.encode_array_slice, col) - work.append( - EncodingWork( - f, - start, - stop, - [col.name], - encoding_memory_requirements[col.name], - ) - ) - work.append( - EncodingWork(self.encode_alleles_slice, start, stop, ["variant_allele"]) + # Each partition requires per_worker_memory bytes, so to prevent more that + # max_memory being used, we clamp the number of workers + max_num_workers = max_memory // per_worker_memory + if max_num_workers < worker_processes: + logger.warning( + f"Limiting number of workers to {max_num_workers} to " + f"keep within specified memory budget of {display_size(max_memory)}" ) - work.append( - EncodingWork( - self.encode_id_slice, start, stop, ["variant_id", "variant_id_mask"] - ) - ) - work.append( - EncodingWork( - functools.partial(self.encode_filters_slice, filter_id_map), - start, - stop, - ["variant_filter"], - ) - ) - work.append( - EncodingWork( - functools.partial(self.encode_contig_slice, contig_id_map), - start, - stop, - ["variant_contig"], - ) + if max_num_workers <= 0: + raise ValueError( + f"Insufficient memory to encode a partition:" + f"{display_size(per_worker_memory)} > {display_size(max_memory)}" ) - if "call_genotype" in self.schema.columns: - variables = [ - "call_genotype", - "call_genotype_phased", - "call_genotype_mask", - ] - gt_memory = sum( - encoding_memory_requirements[name] for name in variables - ) - 
work.append( - EncodingWork( - self.encode_genotypes_slice, start, stop, variables, gt_memory - ) - ) + num_workers = min(max_num_workers, worker_processes) - # Fail early if we can't fit a particular column into memory - for wp in work: - if wp.memory > max_memory: - raise ValueError( - f"Insufficient memory for {wp.columns}: " - f"{display_size(wp.memory)} > {display_size(max_memory)}" - ) + total_bytes = 0 + for col in self.schema.columns.values(): + # Open the array definition to get the total size + total_bytes += zarr.open(self.arrays_path / col.name).nbytes progress_config = core.ProgressConfig( total=total_bytes, @@ -1849,54 +2059,9 @@ def encode( units="B", show=show_progress, ) - - used_memory = 0 - # We need to keep some bounds on the queue size or the memory bounds algorithm - # below doesn't really work. - max_queued = 4 * max(1, worker_processes) - encoded_slices = collections.Counter() - - with core.ParallelWorkManager(worker_processes, progress_config) as pwm: - future = pwm.submit(self.encode_samples) - future_to_work = {future: EncodingWork(None, 0, 0, [])} - - def service_completed_futures(): - nonlocal used_memory - - completed = pwm.wait_for_completed() - for future in completed: - wp_done = future_to_work.pop(future) - used_memory -= wp_done.memory - logger.debug( - f"Complete {wp_done}: used mem={display_size(used_memory)}" - ) - for column in wp_done.columns: - encoded_slices[column] += 1 - if encoded_slices[column] == len(slices): - # Do this syncronously for simplicity. Should be - # fine as the workers will probably be busy with - # large encode tasks most of the time. - self.finalise_array(column) - - for wp in work: - while ( - used_memory + wp.memory > max_memory - or len(future_to_work) > max_queued - ): - logger.debug( - f"Wait: mem_required={used_memory + wp.memory} " - f"max_mem={max_memory} queued={len(future_to_work)} " - f"max_queued={max_queued}" - ) - service_completed_futures() - future = pwm.submit(wp.func, wp.start, wp.stop) - used_memory += wp.memory - logger.debug(f"Submit {wp}: used mem={display_size(used_memory)}") - future_to_work[future] = wp - - logger.debug("All work submitted") - while len(future_to_work) > 0: - service_completed_futures() + with core.ParallelWorkManager(num_workers, progress_config) as pwm: + for partition_index in range(num_partitions): + pwm.submit(self.encode_partition, partition_index) def mkschema(if_path, out): @@ -1911,13 +2076,48 @@ def encode( schema_path=None, variants_chunk_size=None, samples_chunk_size=None, - max_v_chunks=None, + max_variant_chunks=None, dimension_separator=None, max_memory=None, worker_processes=1, show_progress=False, ): - icf = IntermediateColumnarFormat(if_path) + # Rough heuristic to split work up enough to keep utilisation high + target_num_partitions = max(1, worker_processes * 4) + encode_init( + if_path, + zarr_path, + target_num_partitions, + schema_path=schema_path, + variants_chunk_size=variants_chunk_size, + samples_chunk_size=samples_chunk_size, + max_variant_chunks=max_variant_chunks, + dimension_separator=dimension_separator, + ) + vzw = VcfZarrWriter(zarr_path) + vzw.encode_all_partitions( + worker_processes=worker_processes, + show_progress=show_progress, + max_memory=max_memory, + ) + vzw.finalise(show_progress) + + +def encode_init( + icf_path, + zarr_path, + target_num_partitions, + *, + schema_path=None, + variants_chunk_size=None, + samples_chunk_size=None, + max_variant_chunks=None, + dimension_separator=None, + max_memory=None, + worker_processes=1, + 
show_progress=False, +): + icf = IntermediateColumnarFormat(icf_path) if schema_path is None: schema = VcfZarrSchema.generate( icf, @@ -1933,18 +2133,25 @@ def encode( with open(schema_path) as f: schema = VcfZarrSchema.fromjson(f.read()) zarr_path = pathlib.Path(zarr_path) - if zarr_path.exists(): - logger.warning(f"Deleting existing {zarr_path}") - shutil.rmtree(zarr_path) - vzw = VcfZarrWriter(zarr_path, icf, schema, dimension_separator=dimension_separator) - vzw.init() - vzw.encode( - max_v_chunks=max_v_chunks, - worker_processes=worker_processes, - max_memory=max_memory, - show_progress=show_progress, + vzw = VcfZarrWriter(zarr_path) + vzw.init( + icf, + target_num_partitions=target_num_partitions, + schema=schema, + dimension_separator=dimension_separator, + max_variant_chunks=max_variant_chunks, ) - vzw.finalise() + return vzw.num_partitions, vzw.get_max_encoding_memory() + + +def encode_partition(zarr_path, partition): + writer = VcfZarrWriter(zarr_path) + writer.encode_partition(partition) + + +def encode_finalise(zarr_path, show_progress=False): + writer = VcfZarrWriter(zarr_path) + writer.finalise(show_progress=show_progress) def convert( @@ -2154,7 +2361,7 @@ def validate(vcf_path, zarr_path, show_progress=False): assert pos[start_index] == first_pos vcf = cyvcf2.VCF(vcf_path) if show_progress: - iterator = tqdm.tqdm(vcf, desc=" Verify", total=vcf.num_records) # NEEDS TEST + iterator = tqdm.tqdm(vcf, desc=" Verify", total=vcf.num_records) # NEEDS TEST else: iterator = vcf for j, row in enumerate(iterator, start_index): diff --git a/tests/test_cli.py b/tests/test_cli.py index 3500978..8c22a38 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -22,16 +22,29 @@ compressor=None, show_progress=True, ) + DEFAULT_ENCODE_ARGS = dict( schema_path=None, variants_chunk_size=None, samples_chunk_size=None, - max_v_chunks=None, + max_variant_chunks=None, worker_processes=1, max_memory=None, show_progress=True, ) +DEFAULT_DENCODE_INIT_ARGS = dict( + schema_path=None, + variants_chunk_size=None, + samples_chunk_size=None, + max_variant_chunks=None, + show_progress=True, +) + +DEFAULT_DENCODE_PARTITION_ARGS = dict() + +DEFAULT_DENCODE_FINALISE_ARGS = dict(show_progress=True) + class TestWithMocks: vcf_path = "tests/data/vcf/sample.vcf.gz" @@ -385,6 +398,55 @@ def test_encode(self, mocked, tmp_path): **DEFAULT_ENCODE_ARGS, ) + @mock.patch("bio2zarr.vcf.encode_init", return_value=(10, 1024)) + def test_dencode_init(self, mocked, tmp_path): + icf_path = tmp_path / "icf" + icf_path.mkdir() + zarr_path = tmp_path / "zarr" + runner = ct.CliRunner(mix_stderr=False) + result = runner.invoke( + cli.vcf2zarr, + f"dencode-init {icf_path} {zarr_path} 10", + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert result.stdout == "10\t1 KiB\n" + assert len(result.stderr) == 0 + mocked.assert_called_once_with( + str(icf_path), + str(zarr_path), + target_num_partitions=10, + **DEFAULT_DENCODE_INIT_ARGS, + ) + + @mock.patch("bio2zarr.vcf.encode_partition") + def test_vcf_dencode_partition(self, mocked, tmp_path): + runner = ct.CliRunner(mix_stderr=False) + zarr_path = tmp_path / "zarr" + zarr_path.mkdir() + result = runner.invoke( + cli.vcf2zarr, + f"dencode-partition {zarr_path} 1", + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert len(result.stdout) == 0 + assert len(result.stderr) == 0 + mocked.assert_called_once_with( + str(zarr_path), 1, **DEFAULT_DENCODE_PARTITION_ARGS + ) + + @mock.patch("bio2zarr.vcf.encode_finalise") + def test_vcf_dencode_finalise(self, mocked, 
tmp_path): + runner = ct.CliRunner(mix_stderr=False) + result = runner.invoke( + cli.vcf2zarr, f"dencode-finalise {tmp_path}", catch_exceptions=False + ) + assert result.exit_code == 0 + assert len(result.stdout) == 0 + assert len(result.stderr) == 0 + mocked.assert_called_once_with(str(tmp_path), **DEFAULT_DENCODE_FINALISE_ARGS) + @mock.patch("bio2zarr.vcf.convert") def test_convert_vcf(self, mocked): runner = ct.CliRunner(mix_stderr=False) @@ -490,6 +552,42 @@ def test_encode(self, tmp_path): # Arbitrary check assert "variant_position" in result.stdout + def test_dencode(self, tmp_path): + icf_path = tmp_path / "icf" + zarr_path = tmp_path / "zarr" + runner = ct.CliRunner(mix_stderr=False) + result = runner.invoke( + cli.vcf2zarr, f"explode {self.vcf_path} {icf_path}", catch_exceptions=False + ) + assert result.exit_code == 0 + result = runner.invoke( + cli.vcf2zarr, + f"dencode-init {icf_path} {zarr_path} 5 --variants-chunk-size=3", + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert result.stdout.split()[0] == "3" + + for j in range(3): + result = runner.invoke( + cli.vcf2zarr, + f"dencode-partition {zarr_path} {j}", + catch_exceptions=False, + ) + assert result.exit_code == 0 + + result = runner.invoke( + cli.vcf2zarr, f"dencode-finalise {zarr_path}", catch_exceptions=False + ) + assert result.exit_code == 0 + + result = runner.invoke( + cli.vcf2zarr, f"inspect {zarr_path}", catch_exceptions=False + ) + assert result.exit_code == 0 + # Arbitrary check + assert "variant_position" in result.stdout + def test_convert(self, tmp_path): zarr_path = tmp_path / "zarr" runner = ct.CliRunner(mix_stderr=False) diff --git a/tests/test_vcf.py b/tests/test_vcf.py index ebd17b9..dfd690e 100644 --- a/tests/test_vcf.py +++ b/tests/test_vcf.py @@ -64,17 +64,16 @@ def test_not_enough_memory(self, tmp_path, icf_path, max_memory): with pytest.raises(ValueError, match="Insufficient memory"): vcf.encode(icf_path, zarr_path, max_memory=max_memory) - @pytest.mark.parametrize("max_memory", [135, 269]) + @pytest.mark.parametrize("max_memory", ["150KiB", "200KiB"]) def test_not_enough_memory_for_two( self, tmp_path, icf_path, zarr_path, caplog, max_memory ): other_zarr_path = tmp_path / "zarr" - with caplog.at_level("DEBUG"): + with caplog.at_level("WARNING"): vcf.encode( icf_path, other_zarr_path, max_memory=max_memory, worker_processes=2 ) - # This isn't a particularly strong test, but oh well. 
- assert "Wait: mem_required" in caplog.text + assert "Limiting number of workers to 1 to keep within" in caplog.text ds1 = sg.load_dataset(zarr_path) ds2 = sg.load_dataset(other_zarr_path) xt.assert_equal(ds1, ds2) diff --git a/tests/test_vcf_examples.py b/tests/test_vcf_examples.py index 3667287..ad06ea0 100644 --- a/tests/test_vcf_examples.py +++ b/tests/test_vcf_examples.py @@ -317,9 +317,11 @@ def test_full_pipeline(self, ds, tmp_path, worker_processes): ds2 = sg.load_dataset(out) xt.assert_equal(ds, ds2) - @pytest.mark.parametrize("max_v_chunks", [1, 2, 3]) + @pytest.mark.parametrize("max_variant_chunks", [1, 2, 3]) @pytest.mark.parametrize("variants_chunk_size", [1, 2, 3]) - def test_max_v_chunks(self, ds, tmp_path, max_v_chunks, variants_chunk_size): + def test_max_variant_chunks( + self, ds, tmp_path, max_variant_chunks, variants_chunk_size + ): exploded = tmp_path / "example.exploded" vcf.explode(exploded, [self.data_path]) out = tmp_path / "example.zarr" @@ -327,11 +329,11 @@ def test_max_v_chunks(self, ds, tmp_path, max_v_chunks, variants_chunk_size): exploded, out, variants_chunk_size=variants_chunk_size, - max_v_chunks=max_v_chunks, + max_variant_chunks=max_variant_chunks, ) ds2 = sg.load_dataset(out) xt.assert_equal( - ds.isel(variants=slice(None, variants_chunk_size * max_v_chunks)), ds2 + ds.isel(variants=slice(None, variants_chunk_size * max_variant_chunks)), ds2 ) @pytest.mark.parametrize("worker_processes", [0, 1, 2]) @@ -892,5 +894,6 @@ def test_split_explode(tmp_path): def test_missing_filter(tmp_path): path = "tests/data/vcf/sample_missing_filter.vcf.gz" + zarr_path = tmp_path / "zarr" with pytest.raises(ValueError, match="Filter 'q10' was not defined in the header"): - vcf.convert([path], tmp_path) + vcf.convert([path], zarr_path)