From f911cc7135ba0b5b606b0cd652aa57c760615c68 Mon Sep 17 00:00:00 2001
From: Xin Huang <42597328+huan233usc@users.noreply.github.com>
Date: Sat, 17 Feb 2024 20:05:32 -0800
Subject: [PATCH] Initial impl for branch (#90)

Usage:

  ds.add_branch("a")
  ds.set_current_branch("a")

After setting the current branch, all writes and reads go to that branch.
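A slightly fuller sketch of the intended flow (illustrative only: "exp1" and
the pyarrow table `data` are placeholder names; dataset creation is unchanged):

  ds.add_branch("exp1")          # create a branch at the current snapshot
  ds.set_current_branch("exp1")  # later reads/writes target "exp1"
  ds.local().append(data)        # this commit lands on "exp1" only
  ds.set_current_branch("main")  # switch back; "main" is unaffected
  ds.remove_branch("exp1")       # allowed now; fails while "exp1" is current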
""" - def __init__(self, location: str, metadata_file: str, - metadata: meta.StorageMetadata): + def __init__(self, + location: str, + metadata_file: str, + metadata: meta.StorageMetadata, + current_branch: Optional[str] = None): super().__init__(location) self._fs = create_fs(location) self._metadata = metadata @@ -77,12 +84,21 @@ def __init__(self, location: str, metadata_file: str, self._physical_schema) self._primary_keys = set(self._metadata.schema.primary_keys) + self._current_branch = current_branch or _MAIN_BRANCH + self._max_snapshot_id = max( + [ref.snapshot_id for ref in self._metadata.refs.values()] + + [self._metadata.current_snapshot_id]) @property def metadata(self) -> meta.StorageMetadata: """Return the storage metadata.""" return self._metadata + @property + def current_branch(self) -> str: + """Return the current branch.""" + return self._current_branch + @property def primary_keys(self) -> List[str]: """Return the storage primary keys.""" @@ -103,6 +119,13 @@ def physical_schema(self) -> pa.Schema: """Return the physcal schema that uses reference for record fields.""" return self._physical_schema + def current_snapshot_id(self, branch: str) -> int: + """Returns the snapshot id for the current branch.""" + if branch != _MAIN_BRANCH: + return self.lookup_reference(branch).snapshot_id + + return self.metadata.current_snapshot_id + def serializer(self) -> DictSerializer: """Return a serializer (deserializer) for the dataset.""" return DictSerializer.create(self.logical_schema) @@ -112,7 +135,10 @@ def snapshot(self, snapshot_id: Optional[int] = None) -> meta.Snapshot: if not specified. """ if snapshot_id is None: - snapshot_id = self._metadata.current_snapshot_id + if self.current_branch == _MAIN_BRANCH: + snapshot_id = self._metadata.current_snapshot_id + else: + snapshot_id = self.version_to_snapshot_id(self.current_branch) if snapshot_id in self._metadata.snapshots: return self._metadata.snapshots[snapshot_id] @@ -185,7 +211,8 @@ def reload(self) -> bool: return False metadata = _read_metadata(self._fs, self._location, entry_point) - self.__init__(self.location, entry_point.metadata_file, metadata) # type: ignore[misc] # pylint: disable=unnecessary-dunder-call + self.__init__( # type: ignore[misc] # pylint: disable=unnecessary-dunder-call + self.location, entry_point.metadata_file, metadata, self.current_branch) logging.info( f"Storage reloaded to snapshot: {self._metadata.current_snapshot_id}") return True @@ -199,9 +226,9 @@ def version_to_snapshot_id(self, version: Version) -> int: if isinstance(version, int): return version - return self._lookup_reference(version).snapshot_id + return self.lookup_reference(version).snapshot_id - def _lookup_reference(self, tag_or_branch: str) -> meta.SnapshotReference: + def lookup_reference(self, tag_or_branch: str) -> meta.SnapshotReference: """Lookup a snapshot reference.""" if tag_or_branch in self._metadata.refs: return self._metadata.refs[tag_or_branch] @@ -210,24 +237,48 @@ def _lookup_reference(self, tag_or_branch: str) -> meta.SnapshotReference: def add_tag(self, tag: str, snapshot_id: Optional[int] = None) -> None: """Add tag to a snapshot""" + self._add_reference(tag, meta.SnapshotReference.TAG, snapshot_id) + + def add_branch(self, branch: str) -> None: + """Add branch to a snapshot""" + self._add_reference(branch, meta.SnapshotReference.BRANCH, None) + + def set_current_branch(self, branch: str) -> None: + """Set current branch for the snapshot.""" + if branch != _MAIN_BRANCH: + snapshot_ref = 
+      if snapshot_ref.type != meta.SnapshotReference.BRANCH:
+        raise errors.UserInputError(f"{branch} is not a branch.")
+
+    self._current_branch = branch
+
+  def _add_reference(self,
+                     ref_name: str,
+                     ref_type: meta.SnapshotReference.ReferenceType.ValueType,
+                     snapshot_id: Optional[int] = None) -> None:
+    """Add a reference to a snapshot."""
     if snapshot_id is None:
       snapshot_id = self._metadata.current_snapshot_id
 
     if snapshot_id not in self._metadata.snapshots:
       raise errors.VersionNotFoundError(f"Snapshot {snapshot_id} is not found")
 
-    if len(tag) == 0:
-      raise errors.UserInputError("Tag cannot be empty")
+    if not ref_name:
+      raise errors.UserInputError("Reference name cannot be empty.")
+
+    if ref_name in _RESERVED_REFERENCE:
+      raise errors.UserInputError(f"{ref_name} is reserved.")
 
-    if tag in self._metadata.refs:
-      raise errors.VersionAlreadyExistError(f"Reference {tag} already exist")
+    if ref_name in self._metadata.refs:
+      raise errors.VersionAlreadyExistError(
+          f"Reference {ref_name} already exists")
 
     new_metadata = meta.StorageMetadata()
     new_metadata.CopyFrom(self._metadata)
-    tag_ref = meta.SnapshotReference(reference_name=tag,
-                                     snapshot_id=snapshot_id,
-                                     type=meta.SnapshotReference.TAG)
-    new_metadata.refs[tag].CopyFrom(tag_ref)
+    ref = meta.SnapshotReference(reference_name=ref_name,
+                                 snapshot_id=snapshot_id,
+                                 type=ref_type)
+    new_metadata.refs[ref_name].CopyFrom(ref)
     new_metadata_path = self.new_metadata_path()
     self._write_metadata(new_metadata_path, new_metadata)
     self._metadata = new_metadata
@@ -235,19 +286,33 @@ def add_tag(self, tag: str, snapshot_id: Optional[int] = None) -> None:
 
   def remove_tag(self, tag: str) -> None:
     """Remove tag from metadata"""
-    if (tag not in self._metadata.refs or
-        self._metadata.refs[tag].type != meta.SnapshotReference.TAG):
-      raise errors.VersionNotFoundError(f"Tag {tag} is not found")
+    self._remove_reference(tag, meta.SnapshotReference.TAG)
+
+  def remove_branch(self, branch: str) -> None:
+    """Remove branch from metadata"""
+    if branch == self._current_branch:
+      raise errors.UserInputError("Cannot remove the current branch.")
+
+    self._remove_reference(branch, meta.SnapshotReference.BRANCH)
+
+  def _remove_reference(
+      self, ref_name: str,
+      ref_type: meta.SnapshotReference.ReferenceType.ValueType) -> None:
+    if (ref_name not in self._metadata.refs or
+        self._metadata.refs[ref_name].type != ref_type):
+      raise errors.VersionNotFoundError(
+          f"Reference {ref_name} is not found or has a wrong type "
+          "(tag vs branch)")
 
     new_metadata = meta.StorageMetadata()
     new_metadata.CopyFrom(self._metadata)
-    del new_metadata.refs[tag]
+    del new_metadata.refs[ref_name]
     new_metadata_path = self.new_metadata_path()
     self._write_metadata(new_metadata_path, new_metadata)
     self._metadata = new_metadata
     self._metadata_file = new_metadata_path
 
-  def commit(self, patch: rt.Patch) -> None:
+  def commit(self, patch: rt.Patch, branch: str) -> None:
     """Commit changes to the storage.
 
     TODO: only support a single writer; to ensure atomicity in commit by
 
     Args:
       patch: a patch describing changes made to the storage.
+      branch: the branch this commit is writing to.
""" - current_snapshot = self.snapshot() - new_metadata = meta.StorageMetadata() new_metadata.CopyFrom(self._metadata) new_snapshot_id = self._next_snapshot_id() - new_metadata.current_snapshot_id = new_snapshot_id + if branch != _MAIN_BRANCH: + branch_snapshot = self.lookup_reference(branch) + # To block the case delete branch and add a tag during commit + # TODO: move this check out of commit() + if branch_snapshot.type != meta.SnapshotReference.BRANCH: + raise errors.UserInputError("Branch {branch} is no longer exists.") + + new_metadata.refs[branch].snapshot_id = new_snapshot_id + current_snapshot = self.snapshot(branch_snapshot.snapshot_id) + else: + new_metadata.current_snapshot_id = new_snapshot_id + current_snapshot = self.snapshot(self._metadata.current_snapshot_id) + new_metadata.last_update_time.CopyFrom(proto_now()) new_metadata_path = self.new_metadata_path() @@ -417,7 +493,8 @@ def _initialize_files(self, metadata_path: str) -> None: raise errors.StorageExistError(str(e)) from None def _next_snapshot_id(self) -> int: - return self._metadata.current_snapshot_id + 1 + self._max_snapshot_id = self._max_snapshot_id + 1 + return self._max_snapshot_id def _write_metadata( self, @@ -473,7 +550,7 @@ def __init__(self, storage: Storage): self._txn_id = uuid_() # The storage snapshot ID when the transaction starts. self._snapshot_id: Optional[int] = None - + self._branch = storage.current_branch self._result: Optional[JobResult] = None def commit(self, patch: Optional[rt.Patch]) -> None: @@ -483,7 +560,9 @@ def commit(self, patch: Optional[rt.Patch]) -> None: # Check that no other commit has taken place. assert self._snapshot_id is not None self._storage.reload() - if self._snapshot_id != self._storage.metadata.current_snapshot_id: + current_snapshot_id = self._storage.current_snapshot_id(self._branch) + + if self._snapshot_id != current_snapshot_id: self._result = JobResult( JobResult.State.FAILED, None, "Abort commit because the storage has been modified.") @@ -493,7 +572,7 @@ def commit(self, patch: Optional[rt.Patch]) -> None: self._result = JobResult(JobResult.State.SKIPPED) return - self._storage.commit(patch) + self._storage.commit(patch, self._branch) self._result = JobResult(JobResult.State.SUCCEEDED, patch.storage_statistics_update) @@ -509,7 +588,7 @@ def __enter__(self) -> Transaction: # All mutations start with a transaction, so storage is always reloaded for # mutations. 
self._storage.reload() - self._snapshot_id = self._storage.metadata.current_snapshot_id + self._snapshot_id = self._storage.current_snapshot_id(self._branch) logging.info(f"Start transaction {self._txn_id}") return self diff --git a/python/tests/core/loaders/test_parquet.py b/python/tests/core/loaders/test_parquet.py index 5366493..7c42f8b 100644 --- a/python/tests/core/loaders/test_parquet.py +++ b/python/tests/core/loaders/test_parquet.py @@ -71,4 +71,4 @@ def test_append_parquet(self, tmp_path): ]).combine_chunks().sort_by("int64") assert not ds.index_files(version="empty") - assert ds.index_files(version="after_append") == [file0, file1] + assert sorted(ds.index_files(version="after_append")) == [file0, file1] diff --git a/python/tests/core/ops/test_delete.py b/python/tests/core/ops/test_delete.py index 44c6a1d..0be3469 100644 --- a/python/tests/core/ops/test_delete.py +++ b/python/tests/core/ops/test_delete.py @@ -42,7 +42,7 @@ def test_delete_all_types(self, tmp_path, all_types_schema, for batch in input_data: append_op.write(batch) - storage.commit(append_op.finish()) + storage.commit(append_op.finish(), "main") old_data_files = storage.data_files() delete_op = FileSetDeleteOp( @@ -54,7 +54,7 @@ def test_delete_all_types(self, tmp_path, all_types_schema, _default_file_options) patch = delete_op.delete() assert patch is not None - storage.commit(patch) + storage.commit(patch, "main") # Verify storage metadata after patch. new_data_files = storage.data_files() diff --git a/python/tests/core/ops/test_read.py b/python/tests/core/ops/test_read.py index c9fb71f..ddb5eb5 100644 --- a/python/tests/core/ops/test_read.py +++ b/python/tests/core/ops/test_read.py @@ -41,7 +41,7 @@ def test_read_all_types(self, tmp_path, all_types_schema, for batch in input_data: append_op.write(batch) - storage.commit(append_op.finish()) + storage.commit(append_op.finish(), "main") read_op = FileSetReadOp(str(location), storage.metadata, storage.data_files()) @@ -79,7 +79,7 @@ def test_read_with_record_filters(self, tmp_path, record_fields_schema, for batch in input_data: append_op.write(batch) - storage.commit(append_op.finish()) + storage.commit(append_op.finish(), "main") data_files = storage.data_files() read_op = FileSetReadOp(str(location), storage.metadata, data_files) diff --git a/python/tests/core/test_runners.py b/python/tests/core/test_runners.py index c094bb0..8e5e573 100644 --- a/python/tests/core/test_runners.py +++ b/python/tests/core/test_runners.py @@ -200,6 +200,81 @@ def test_add_read_remove_tag(self, sample_dataset): assert "Version insert1 is not found" in str(excinfo.value) + def test_concurrent_write_to_different_branch(self, sample_dataset): + ds = sample_dataset + ds.add_branch("exp1") + local_runner = ds.local() + lock1 = threading.Lock() + lock2 = threading.Lock() + lock1.acquire() + lock2.acquire() + + sample_data = _generate_data([1, 2]) + + def make_iter(): + yield sample_data + lock2.release() + lock1.acquire() + yield sample_data + lock1.release() + + job_result = [None] + + def append_data(): + job_result[0] = local_runner.append_from(make_iter) + + t = threading.Thread(target=append_data) + t.start() + lock2.acquire() + ds.set_current_branch("exp1") + local_runner.append(sample_data) + lock2.release() + lock1.release() + t.join() + + ds.set_current_branch("main") + assert local_runner.read_all() == pa.concat_tables( + [sample_data, sample_data]) + assert local_runner.read_all(version="exp1") == pa.concat_tables( + [sample_data]) + + def test_add_read_with_branch(self, 
sample_dataset): + ds = sample_dataset + local_runner = ds.local() + + sample_data1 = _generate_data([1, 2]) + local_runner.append(sample_data1) + + ds.add_branch("exp1") + + assert local_runner.read_all() == sample_data1 + + create_time0 = datetime.utcfromtimestamp( + ds.storage.metadata.snapshots[0].create_time.seconds).replace( + tzinfo=pytz.utc) + create_time1 = datetime.utcfromtimestamp( + ds.storage.metadata.snapshots[1].create_time.seconds).replace( + tzinfo=pytz.utc) + assert ds.versions().to_pydict() == { + "snapshot_id": [1, 0], + "tag_or_branch": ["exp1", None], + "create_time": [create_time1, create_time0] + } + + sample_data2 = _generate_data([3, 4]) + local_runner.append(sample_data2) + + ds.set_current_branch("exp1") + + sample_data3 = _generate_data([5, 6]) + local_runner.append(sample_data3) + + ds.set_current_branch("main") + assert local_runner.read_all() == pa.concat_tables( + [sample_data1, sample_data2]) + assert local_runner.read_all(version="exp1") == pa.concat_tables( + [sample_data1, sample_data3]) + def test_dataset_with_file_type(self, tmp_path): schema = pa.schema([("id", pa.int64()), ("name", pa.string()), ("file", File(directory="test_folder"))]) diff --git a/python/tests/core/test_storage.py b/python/tests/core/test_storage.py index cc7efc0..de4d96b 100644 --- a/python/tests/core/test_storage.py +++ b/python/tests/core/test_storage.py @@ -137,7 +137,7 @@ def test_commit(self, tmp_path): record_uncompressed_bytes=30) patch = rt.Patch(addition=added_manifest_files, storage_statistics_update=added_storage_statistics) - storage.commit(patch) + storage.commit(patch, "main") assert storage.snapshot(0) is not None new_snapshot = storage.snapshot(1) @@ -155,7 +155,7 @@ def test_commit(self, tmp_path): index_manifest_files=["data/index_manifest1"], record_manifest_files=["data/record_manifest1"]), storage_statistics_update=added_storage_statistics2) - storage.commit(patch) + storage.commit(patch, "main") new_snapshot = storage.snapshot(2) assert new_snapshot.manifest_files == meta.ManifestFiles( @@ -177,7 +177,7 @@ def test_commit(self, tmp_path): index_compressed_bytes=-10, index_uncompressed_bytes=-20, record_uncompressed_bytes=-30)) - storage.commit(patch) + storage.commit(patch, "main") new_snapshot = storage.snapshot(3) assert new_snapshot.manifest_files.index_manifest_files == [ "data/index_manifest1" @@ -202,7 +202,7 @@ def create_index_manifest_writer(): def commit_add_index_manifest(manifest_path: str): patch = rt.Patch(addition=meta.ManifestFiles( index_manifest_files=[storage.short_path(manifest_path)])) - storage.commit(patch) + storage.commit(patch, "main") manifest_writer = create_index_manifest_writer() manifest_writer.write( @@ -375,3 +375,67 @@ def test_tags(self, tmp_path): "tag_or_branch": ["tag2"], "create_time": [create_time1] } + + def test_branches(self, tmp_path): + location = tmp_path / "dataset" + storage = Storage.create(location=str(location), + schema=_SCHEMA, + primary_keys=["int64"], + record_fields=[]) + + create_time1 = datetime.utcfromtimestamp( + storage.metadata.snapshots[0].create_time.seconds).replace( + tzinfo=pytz.utc) + assert storage.versions().to_pydict() == { + "snapshot_id": [0], + "tag_or_branch": [None], + "create_time": [create_time1] + } + + storage.add_branch("branch1") + + with pytest.raises(errors.UserInputError, match=r".*already exist.*"): + storage.add_branch("branch1") + + storage.add_branch("branch2") + + snapshot_id1 = storage.version_to_snapshot_id("branch1") + snapshot_id2 = 
storage.version_to_snapshot_id("branch2") + + metadata = storage.metadata + assert len(metadata.refs) == 2 + assert snapshot_id1 == snapshot_id2 == metadata.current_snapshot_id + + versions = storage.versions().to_pydict() + versions["tag_or_branch"].sort() + assert versions == { + "snapshot_id": [0, 0], + "tag_or_branch": ["branch1", "branch2"], + "create_time": [create_time1, create_time1] + } + + storage.remove_branch("branch1") + + with pytest.raises(errors.UserInputError, match=r".*not found.*"): + storage.remove_branch("branch1") + assert len(storage.metadata.refs) == 1 + + patch = rt.Patch(addition=meta.ManifestFiles( + index_manifest_files=["data/index_manifest1"], + record_manifest_files=["data/record_manifest1"]), + storage_statistics_update=meta.StorageStatistics( + num_rows=100, + index_compressed_bytes=100, + index_uncompressed_bytes=200, + record_uncompressed_bytes=300)) + storage.commit(patch, "branch2") + + create_time2 = datetime.utcfromtimestamp( + storage.metadata.snapshots[1].create_time.seconds).replace( + tzinfo=pytz.utc) + + assert storage.versions().to_pydict() == { + "snapshot_id": [1, 0], + "tag_or_branch": ["branch2", None], + "create_time": [create_time2, create_time1] + } diff --git a/python/tests/ray/test_runners.py b/python/tests/ray/test_runners.py index b109380..c7b61ec 100644 --- a/python/tests/ray/test_runners.py +++ b/python/tests/ray/test_runners.py @@ -83,81 +83,83 @@ class TestRayReadWriteRunner: ]) def test_write_read_dataset(self, sample_dataset, enable_row_range_block, batch_size): + sample_dataset.add_branch("branch1") runner = sample_dataset.ray(ray_options=RayOptions( max_parallelism=4, enable_row_range_block=enable_row_range_block)) - - # Test append. input_data0 = generate_data([1, 2, 3]) - runner.append(input_data0) - - assert_equal( - runner.read_all(batch_size=batch_size).sort_by("int64"), - input_data0.sort_by("int64")) - input_data1 = generate_data([4, 5]) input_data2 = generate_data([6, 7]) input_data3 = generate_data([8]) input_data4 = generate_data([9, 10, 11]) - - runner.append_from([ - lambda: iter([input_data1, input_data2]), lambda: iter([input_data3]), - lambda: iter([input_data4]) - ]) - - assert_equal( - runner.read_all(batch_size=batch_size).sort_by("int64"), - pa.concat_tables( - [input_data0, input_data1, input_data2, input_data3, - input_data4]).sort_by("int64")) - - # Test insert. - result = runner.insert(generate_data([7, 12])) - assert result.state == JobResult.State.FAILED - assert "Primary key to insert already exist" in result.error_message - - runner.upsert(generate_data([7, 12])) - assert_equal( - runner.read_all(batch_size=batch_size).sort_by("int64"), - pa.concat_tables([ - input_data0, input_data1, input_data2, input_data3, input_data4, - generate_data([12]) - ]).sort_by("int64")) - - # Test delete. - runner.delete(pc.field("int64") < 10) - assert_equal( - runner.read_all(batch_size=batch_size).sort_by("int64"), - pa.concat_tables([generate_data([10, 11, 12])]).sort_by("int64")) - - # Test reading views. - view = sample_dataset.map_batches(fn=_sample_map_udf, - output_schema=sample_dataset.schema, - output_record_fields=["binary"]) - assert_equal( - view.ray(DEFAULT_RAY_OPTIONS).read_all( - batch_size=batch_size).sort_by("int64"), - pa.concat_tables([ - pa.Table.from_pydict({ - "int64": [10, 11, 12], - "float64": [v / 10 + 1 for v in [10, 11, 12]], - "binary": [f"b{v}".encode("utf-8") for v in [10, 11, 12]] - }) - ]).sort_by("int64")) - - # Test a transform on a view. 
- transform_on_view = view.map_batches(fn=_sample_map_udf, - output_schema=view.schema, - output_record_fields=["binary"]) - assert_equal( - transform_on_view.ray(DEFAULT_RAY_OPTIONS).read_all( - batch_size=batch_size).sort_by("int64"), - pa.concat_tables([ - pa.Table.from_pydict({ - "int64": [10, 11, 12], - "float64": [v / 10 + 2 for v in [10, 11, 12]], - "binary": [f"b{v}".encode("utf-8") for v in [10, 11, 12]] - }) - ]).sort_by("int64")) + for branch in ["branch1", "main"]: + sample_dataset.set_current_branch(branch) + # Test append. + runner.append(input_data0) + + assert_equal( + runner.read_all(batch_size=batch_size).sort_by("int64"), + input_data0.sort_by("int64")) + + runner.append_from([ + lambda: iter([input_data1, input_data2]), lambda: iter([input_data3]), + lambda: iter([input_data4]) + ]) + + assert_equal( + runner.read_all(batch_size=batch_size).sort_by("int64"), + pa.concat_tables( + [input_data0, input_data1, input_data2, input_data3, + input_data4]).sort_by("int64")) + + # Test insert. + result = runner.insert(generate_data([7, 12])) + assert result.state == JobResult.State.FAILED + assert "Primary key to insert already exist" in result.error_message + + runner.upsert(generate_data([7, 12])) + assert_equal( + runner.read_all(batch_size=batch_size).sort_by("int64"), + pa.concat_tables([ + input_data0, input_data1, input_data2, input_data3, input_data4, + generate_data([12]) + ]).sort_by("int64")) + + # Test delete. + runner.delete(pc.field("int64") < 10) + assert_equal( + runner.read_all(batch_size=batch_size).sort_by("int64"), + pa.concat_tables([generate_data([10, 11, 12])]).sort_by("int64")) + + # Test reading views. + view = sample_dataset.map_batches(fn=_sample_map_udf, + output_schema=sample_dataset.schema, + output_record_fields=["binary"]) + + assert_equal( + view.ray(DEFAULT_RAY_OPTIONS).read_all( + batch_size=batch_size).sort_by("int64"), + pa.concat_tables([ + pa.Table.from_pydict({ + "int64": [10, 11, 12], + "float64": [v / 10 + 1 for v in [10, 11, 12]], + "binary": [f"b{v}".encode("utf-8") for v in [10, 11, 12]] + }) + ]).sort_by("int64")) + + # Test a transform on a view. + transform_on_view = view.map_batches(fn=_sample_map_udf, + output_schema=view.schema, + output_record_fields=["binary"]) + assert_equal( + transform_on_view.ray(DEFAULT_RAY_OPTIONS).read_all( + batch_size=batch_size).sort_by("int64"), + pa.concat_tables([ + pa.Table.from_pydict({ + "int64": [10, 11, 12], + "float64": [v / 10 + 2 for v in [10, 11, 12]], + "binary": [f"b{v}".encode("utf-8") for v in [10, 11, 12]] + }) + ]).sort_by("int64")) @pytest.mark.parametrize("enable_row_range_block", [(True,), (False,)]) def test_read_batch_size(self, tmp_path, sample_schema,