From 3611c13ec66864675b2e03d93753acfca4d5ed86 Mon Sep 17 00:00:00 2001 From: Joan Antoni RE Date: Fri, 27 Dec 2024 10:23:09 +0100 Subject: [PATCH 1/9] Move search fixtures to ndbfixtures --- nucliadb/tests/conftest.py | 1 + nucliadb/tests/ndbfixtures/node.py | 532 +++++++++++++++++ .../fixtures.py => ndbfixtures/search.py} | 0 nucliadb/tests/search/conftest.py | 4 +- nucliadb/tests/search/node.py | 556 ------------------ 5 files changed, 535 insertions(+), 558 deletions(-) rename nucliadb/tests/{search/fixtures.py => ndbfixtures/search.py} (100%) delete mode 100644 nucliadb/tests/search/node.py diff --git a/nucliadb/tests/conftest.py b/nucliadb/tests/conftest.py index 467765b7ec..d61dd23fb8 100644 --- a/nucliadb/tests/conftest.py +++ b/nucliadb/tests/conftest.py @@ -36,6 +36,7 @@ "tests.ndbfixtures.standalone", "tests.ndbfixtures.reader", "tests.ndbfixtures.writer", + "tests.ndbfixtures.search", "tests.ndbfixtures.train", "tests.ndbfixtures.ingest", # subcomponents diff --git a/nucliadb/tests/ndbfixtures/node.py b/nucliadb/tests/ndbfixtures/node.py index 203b280228..8b95242205 100644 --- a/nucliadb/tests/ndbfixtures/node.py +++ b/nucliadb/tests/ndbfixtures/node.py @@ -17,13 +17,126 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # +import dataclasses +import logging +import os +import time import uuid +from typing import Union from unittest.mock import patch +import backoff +import docker # type: ignore +import nats import pytest +from grpc import insecure_channel +from grpc_health.v1 import health_pb2_grpc +from grpc_health.v1.health_pb2 import HealthCheckRequest +from nats.js.api import ConsumerConfig +from pytest_docker_fixtures import images # type: ignore +from pytest_docker_fixtures.containers._base import BaseImage # type: ignore from nucliadb.common.cluster import manager from nucliadb.common.cluster.settings import settings as cluster_settings +from nucliadb.common.nidx import NIDX_ENABLED +from nucliadb_protos.nodewriter_pb2 import EmptyQuery, ShardId +from nucliadb_protos.nodewriter_pb2_grpc import NodeWriterStub +from nucliadb_utils.tests.fixtures import get_testing_storage_backend + +logger = logging.getLogger(__name__) + +images.settings["nucliadb_node_reader"] = { + "image": "europe-west4-docker.pkg.dev/nuclia-internal/nuclia/node", + "version": "latest", + "env": { + "FILE_BACKEND": "unset", + "HOST_KEY_PATH": "/data/node.key", + "DATA_PATH": "/data", + "READER_LISTEN_ADDRESS": "0.0.0.0:4445", + "NUCLIADB_DISABLE_ANALYTICS": "True", + "RUST_BACKTRACE": "full", + "RUST_LOG": "nucliadb_*=DEBUG", + }, + "options": { + "command": [ + "/usr/local/bin/node_reader", + ], + "ports": {"4445": ("0.0.0.0", 0)}, + "publish_all_ports": False, + "mem_limit": "3g", # default is 1g, need to override + "platform": "linux/amd64", + }, +} + +images.settings["nucliadb_node_writer"] = { + "image": "europe-west4-docker.pkg.dev/nuclia-internal/nuclia/node", + "version": "latest", + "env": { + "FILE_BACKEND": "unset", + "HOST_KEY_PATH": "/data/node.key", + "DATA_PATH": "/data", + "WRITER_LISTEN_ADDRESS": "0.0.0.0:4446", + "NUCLIADB_DISABLE_ANALYTICS": "True", + "RUST_BACKTRACE": "full", + "RUST_LOG": "nucliadb_*=DEBUG", + }, + "options": { + "command": [ + "/usr/local/bin/node_writer", + ], + "ports": {"4446": ("0.0.0.0", 0)}, + "publish_all_ports": False, + "mem_limit": "3g", # default is 1g, need to override + "platform": "linux/amd64", + }, +} + +images.settings["nucliadb_node_sidecar"] = { + "image": "europe-west4-docker.pkg.dev/nuclia-internal/nuclia/node_sidecar", + "version": "latest", + "env": { + "INDEX_JETSTREAM_SERVERS": "[]", + "CACHE_PUBSUB_NATS_URL": "", + "HOST_KEY_PATH": "/data/node.key", + "DATA_PATH": "/data", + "SIDECAR_LISTEN_ADDRESS": "0.0.0.0:4447", + "READER_LISTEN_ADDRESS": "0.0.0.0:4445", + "WRITER_LISTEN_ADDRESS": "0.0.0.0:4446", + "PYTHONUNBUFFERED": "1", + "LOG_LEVEL": "DEBUG", + }, + "options": { + "command": [ + "node_sidecar", + ], + "ports": {"4447": ("0.0.0.0", 0)}, + "publish_all_ports": False, + "platform": "linux/amd64", + }, +} + +images.settings["nidx"] = { + "image": "nidx", + "version": "latest", + "env": {}, + "options": { + # A few indexers on purpose for faster indexing + "command": [ + "nidx", + "api", + "searcher", + "indexer", + "indexer", + "indexer", + "indexer", + "scheduler", + "worker", + ], + "ports": {"10000": ("0.0.0.0", 0), "10001": ("0.0.0.0", 0)}, + "publish_all_ports": False, + "platform": "linux/amd64", + }, +} @pytest.fixture(scope="function") @@ -50,3 +163,422 @@ async def dummy_index_node_cluster( dummy=True, ) yield + + +def get_container_host(container_obj): + return container_obj.attrs["NetworkSettings"]["IPAddress"] + + +class nucliadbNodeReader(BaseImage): + name = "nucliadb_node_reader" + port = 4445 + + def run(self, volume): + self._volume = volume + self._mount = "/data" + return super(nucliadbNodeReader, self).run() + + def get_image_options(self): + options = super(nucliadbNodeReader, self).get_image_options() + options["volumes"] = {self._volume.name: {"bind": "/data"}} + return options + + def check(self): + channel = insecure_channel(f"{self.host}:{self.get_port()}") + stub = health_pb2_grpc.HealthStub(channel) + pb = HealthCheckRequest(service="nodereader.NodeReader") + try: + result = stub.Check(pb) + return result.status == 1 + except: # noqa + return False + + +class nucliadbNodeWriter(BaseImage): + name = "nucliadb_node_writer" + port = 4446 + + def run(self, volume): + self._volume = volume + self._mount = "/data" + return super(nucliadbNodeWriter, self).run() + + def get_image_options(self): + options = super(nucliadbNodeWriter, self).get_image_options() + options["volumes"] = {self._volume.name: {"bind": "/data"}} + return options + + def check(self): + channel = insecure_channel(f"{self.host}:{self.get_port()}") + stub = health_pb2_grpc.HealthStub(channel) + pb = HealthCheckRequest(service="nodewriter.NodeWriter") + try: + result = stub.Check(pb) + return result.status == 1 + except: # noqa + return False + + +class nucliadbNodeSidecar(BaseImage): + name = "nucliadb_node_sidecar" + port = 4447 + + def run(self, volume): + self._volume = volume + self._mount = "/data" + return super(nucliadbNodeSidecar, self).run() + + def get_image_options(self): + options = super(nucliadbNodeSidecar, self).get_image_options() + options["volumes"] = {self._volume.name: {"bind": "/data"}} + return options + + def check(self): + channel = insecure_channel(f"{self.host}:{self.get_port()}") + stub = health_pb2_grpc.HealthStub(channel) + pb = HealthCheckRequest(service="") + try: + result = stub.Check(pb) + return result.status == 1 + except: # noqa + return False + + +class NidxImage(BaseImage): + name = "nidx" + + +nucliadb_node_1_reader = nucliadbNodeReader() +nucliadb_node_1_writer = nucliadbNodeWriter() +nucliadb_node_1_sidecar = nucliadbNodeSidecar() + +nucliadb_node_2_reader = nucliadbNodeReader() +nucliadb_node_2_writer = nucliadbNodeWriter() +nucliadb_node_2_sidecar = nucliadbNodeSidecar() + + +@dataclasses.dataclass +class NodeS3Storage: + server: str + + def envs(self): + return { + "FILE_BACKEND": "s3", + "S3_CLIENT_ID": "fake", + "S3_CLIENT_SECRET": "fake", + "S3_BUCKET": "test", + "S3_INDEXING_BUCKET": "indexing", + "S3_DEADLETTER_BUCKET": "deadletter", + "S3_ENDPOINT": self.server, + } + + +@dataclasses.dataclass +class NodeGCSStorage: + server: str + + def envs(self): + return { + "FILE_BACKEND": "gcs", + "GCS_BUCKET": "test", + "GCS_INDEXING_BUCKET": "indexing", + "GCS_DEADLETTER_BUCKET": "deadletter", + "GCS_ENDPOINT_URL": self.server, + } + + +NodeStorage = Union[NodeGCSStorage, NodeS3Storage] + + +class _NodeRunner: + def __init__(self, natsd, storage: NodeStorage): + self.docker_client = docker.from_env(version=BaseImage.docker_version) + self.natsd = natsd + self.storage = storage + self.data = {} # type: ignore + + def start(self): + docker_platform_name = self.docker_client.api.version()["Platform"]["Name"].upper() + if "GITHUB_ACTION" in os.environ: + # Valid when using github actions + docker_internal_host = "172.17.0.1" + elif docker_platform_name == "DOCKER ENGINE - COMMUNITY": + # for linux users + docker_internal_host = "172.17.0.1" + elif "DOCKER DESKTOP" in docker_platform_name: + # Valid when using Docker desktop + docker_internal_host = "host.docker.internal" + else: + docker_internal_host = "172.17.0.1" + + self.volume_node_1 = self.docker_client.volumes.create(driver="local") + self.volume_node_2 = self.docker_client.volumes.create(driver="local") + + self.storage.server = self.storage.server.replace("localhost", docker_internal_host) + images.settings["nucliadb_node_writer"]["env"].update(self.storage.envs()) + writer1_host, writer1_port = nucliadb_node_1_writer.run(self.volume_node_1) + writer2_host, writer2_port = nucliadb_node_2_writer.run(self.volume_node_2) + + reader1_host, reader1_port = nucliadb_node_1_reader.run(self.volume_node_1) + reader2_host, reader2_port = nucliadb_node_2_reader.run(self.volume_node_2) + + natsd_server = self.natsd.replace("localhost", docker_internal_host) + images.settings["nucliadb_node_sidecar"]["env"].update( + { + "INDEX_JETSTREAM_SERVERS": f'["{natsd_server}"]', + "CACHE_PUBSUB_NATS_URL": f'["{natsd_server}"]', + "READER_LISTEN_ADDRESS": f"{docker_internal_host}:{reader1_port}", + "WRITER_LISTEN_ADDRESS": f"{docker_internal_host}:{writer1_port}", + } + ) + images.settings["nucliadb_node_sidecar"]["env"].update(self.storage.envs()) + + sidecar1_host, sidecar1_port = nucliadb_node_1_sidecar.run(self.volume_node_1) + + images.settings["nucliadb_node_sidecar"]["env"]["READER_LISTEN_ADDRESS"] = ( + f"{docker_internal_host}:{reader2_port}" + ) + images.settings["nucliadb_node_sidecar"]["env"]["WRITER_LISTEN_ADDRESS"] = ( + f"{docker_internal_host}:{writer2_port}" + ) + + sidecar2_host, sidecar2_port = nucliadb_node_2_sidecar.run(self.volume_node_2) + + writer1_internal_host = get_container_host(nucliadb_node_1_writer.container_obj) + writer2_internal_host = get_container_host(nucliadb_node_2_writer.container_obj) + + self.data.update( + { + "writer1_internal_host": writer1_internal_host, + "writer2_internal_host": writer2_internal_host, + "writer1": { + "host": writer1_host, + "port": writer1_port, + }, + "writer2": { + "host": writer2_host, + "port": writer2_port, + }, + "reader1": { + "host": reader1_host, + "port": reader1_port, + }, + "reader2": { + "host": reader2_host, + "port": reader2_port, + }, + "sidecar1": { + "host": sidecar1_host, + "port": sidecar1_port, + }, + "sidecar2": { + "host": sidecar2_host, + "port": sidecar2_port, + }, + } + ) + return self.data + + def stop(self): + container_ids = [] + for component in [ + nucliadb_node_1_reader, + nucliadb_node_1_writer, + nucliadb_node_1_sidecar, + nucliadb_node_2_writer, + nucliadb_node_2_reader, + nucliadb_node_2_sidecar, + ]: + container_obj = getattr(component, "container_obj", None) + if container_obj: + container_ids.append(container_obj.id) + component.stop() + + for container_id in container_ids: + for _ in range(5): + try: + self.docker_client.containers.get(container_id) + except docker.errors.NotFound: + break + time.sleep(2) + + self.volume_node_1.remove() + self.volume_node_2.remove() + + def setup_env(self): + # reset on every test run in case something touches it + cluster_settings.writer_port_map = { + self.data["writer1_internal_host"]: self.data["writer1"]["port"], + self.data["writer2_internal_host"]: self.data["writer2"]["port"], + } + cluster_settings.reader_port_map = { + self.data["writer1_internal_host"]: self.data["reader1"]["port"], + self.data["writer2_internal_host"]: self.data["reader2"]["port"], + } + + cluster_settings.node_writer_port = None # type: ignore + cluster_settings.node_reader_port = None # type: ignore + + cluster_settings.cluster_discovery_mode = "manual" + cluster_settings.cluster_discovery_manual_addresses = [ + self.data["writer1_internal_host"], + self.data["writer2_internal_host"], + ] + + +@pytest.fixture(scope="session") +def gcs_node_storage(gcs): + return NodeGCSStorage(server=gcs) + + +@pytest.fixture(scope="session") +def s3_node_storage(s3): + return NodeS3Storage(server=s3) + + +@pytest.fixture(scope="session") +def node_storage(request): + backend = get_testing_storage_backend() + if backend == "gcs": + return request.getfixturevalue("gcs_node_storage") + elif backend == "s3": + return request.getfixturevalue("s3_node_storage") + else: + print(f"Unknown storage backend {backend}, using gcs") + return request.getfixturevalue("gcs_node_storage") + + +@pytest.fixture(scope="session") +def gcs_nidx_storage(gcs): + return { + "INDEXER__OBJECT_STORE": "gcs", + "INDEXER__BUCKET": "indexing", + "INDEXER__ENDPOINT": gcs, + "STORAGE__OBJECT_STORE": "gcs", + "STORAGE__ENDPOINT": gcs, + "STORAGE__BUCKET": "nidx", + } + + +@pytest.fixture(scope="session") +def s3_nidx_storage(s3): + return { + "INDEXER__OBJECT_STORE": "s3", + "INDEXER__BUCKET": "indexing", + "INDEXER__ENDPOINT": s3, + "STORAGE__OBJECT_STORE": "s3", + "STORAGE__ENDPOINT": s3, + "STORAGE__BUCKET": "nidx", + } + + +@pytest.fixture(scope="session") +def nidx_storage(request): + backend = get_testing_storage_backend() + if backend == "gcs": + return request.getfixturevalue("gcs_nidx_storage") + elif backend == "s3": + return request.getfixturevalue("s3_nidx_storage") + + +@pytest.fixture(scope="session", autouse=False) +def _node(natsd: str, node_storage): + nr = _NodeRunner(natsd, node_storage) + try: + cluster_info = nr.start() + except Exception: + nr.stop() + raise + nr.setup_env() + yield cluster_info + nr.stop() + + +@pytest.fixture(scope="session") +async def _nidx(natsd, nidx_storage, pg): + if not NIDX_ENABLED: + yield + return + + # Create needed NATS stream/consumer + nc = await nats.connect(servers=[natsd]) + js = nc.jetstream() + await js.add_stream(name="nidx", subjects=["nidx"]) + await js.add_consumer(stream="nidx", config=ConsumerConfig(name="nidx")) + await nc.drain() + await nc.close() + + # Run nidx + images.settings["nidx"]["env"] = { + "RUST_LOG": "info", + "METADATA__DATABASE_URL": f"postgresql://postgres:postgres@172.17.0.1:{pg[1]}/postgres", + "INDEXER__NATS_SERVER": natsd.replace("localhost", "172.17.0.1"), + **nidx_storage, + } + image = NidxImage() + image.run() + + api_port = image.get_port(10000) + searcher_port = image.get_port(10001) + + # Configure settings + from nucliadb_utils.settings import indexing_settings + + cluster_settings.nidx_api_address = f"localhost:{api_port}" + cluster_settings.nidx_searcher_address = f"localhost:{searcher_port}" + indexing_settings.index_nidx_subject = "nidx" + + yield + + image.stop() + + +@pytest.fixture(scope="function") +def node(_nidx, _node, request): + # clean up all shard data before each test + channel1 = insecure_channel(f"{_node['writer1']['host']}:{_node['writer1']['port']}") + channel2 = insecure_channel(f"{_node['writer2']['host']}:{_node['writer2']['port']}") + writer1 = NodeWriterStub(channel1) + writer2 = NodeWriterStub(channel2) + + logger.debug("cleaning up shards data") + try: + cleanup_node(writer1) + cleanup_node(writer2) + except Exception: + logger.error( + "Error cleaning up shards data. Maybe the node fixture could not start properly?", + exc_info=True, + ) + + client = docker.client.from_env() + containers_by_port = {} + for container in client.containers.list(): + name = container.name + command = container.attrs["Config"]["Cmd"] + ports = container.ports + print(f"container {name} executing {command} is using ports: {ports}") + + for internal_port in container.ports: + for host in container.ports[internal_port]: + port = host["HostPort"] + port_containers = containers_by_port.setdefault(port, []) + if container not in port_containers: + port_containers.append(container) + + for port, containers in containers_by_port.items(): + if len(containers) > 1: + names = ", ".join([container.name for container in containers]) + print(f"ATENTION! Containers {names} share port {port}!") + raise + finally: + channel1.close() + channel2.close() + + yield _node + + +@backoff.on_exception(backoff.expo, Exception, jitter=backoff.random_jitter, max_tries=5) +def cleanup_node(writer: NodeWriterStub): + for shard in writer.ListShards(EmptyQuery()).ids: + writer.DeleteShard(ShardId(id=shard.id)) diff --git a/nucliadb/tests/search/fixtures.py b/nucliadb/tests/ndbfixtures/search.py similarity index 100% rename from nucliadb/tests/search/fixtures.py rename to nucliadb/tests/ndbfixtures/search.py diff --git a/nucliadb/tests/search/conftest.py b/nucliadb/tests/search/conftest.py index 59c6a57c1f..9b3b67a554 100644 --- a/nucliadb/tests/search/conftest.py +++ b/nucliadb/tests/search/conftest.py @@ -24,8 +24,8 @@ "tests.ndbfixtures.processing", "tests.ndbfixtures.standalone", "tests.ingest.fixtures", # should be refactored out - "tests.search.node", - "tests.search.fixtures", + "tests.ndbfixtures.node", + "tests.ndbfixtures.search", "nucliadb_utils.tests.fixtures", "nucliadb_utils.tests.gcs", "nucliadb_utils.tests.s3", diff --git a/nucliadb/tests/search/node.py b/nucliadb/tests/search/node.py deleted file mode 100644 index 99dfe42d24..0000000000 --- a/nucliadb/tests/search/node.py +++ /dev/null @@ -1,556 +0,0 @@ -# Copyright (C) 2021 Bosutech XXI S.L. -# -# nucliadb is offered under the AGPL v3.0 and as commercial software. -# For commercial licensing, contact us at info@nuclia.com. -# -# AGPL: -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -# - -import dataclasses -import logging -import os -import time -from typing import Union - -import backoff -import docker # type: ignore -import nats -import pytest -from grpc import insecure_channel -from grpc_health.v1 import health_pb2_grpc -from grpc_health.v1.health_pb2 import HealthCheckRequest -from nats.js.api import ConsumerConfig -from pytest_docker_fixtures import images # type: ignore -from pytest_docker_fixtures.containers._base import BaseImage # type: ignore - -from nucliadb.common.cluster.settings import settings as cluster_settings -from nucliadb.common.nidx import NIDX_ENABLED -from nucliadb_protos.nodewriter_pb2 import EmptyQuery, ShardId -from nucliadb_protos.nodewriter_pb2_grpc import NodeWriterStub -from nucliadb_utils.tests.fixtures import get_testing_storage_backend - -logger = logging.getLogger(__name__) - -images.settings["nucliadb_node_reader"] = { - "image": "europe-west4-docker.pkg.dev/nuclia-internal/nuclia/node", - "version": "latest", - "env": { - "FILE_BACKEND": "unset", - "HOST_KEY_PATH": "/data/node.key", - "DATA_PATH": "/data", - "READER_LISTEN_ADDRESS": "0.0.0.0:4445", - "NUCLIADB_DISABLE_ANALYTICS": "True", - "RUST_BACKTRACE": "full", - "RUST_LOG": "nucliadb_*=DEBUG", - }, - "options": { - "command": [ - "/usr/local/bin/node_reader", - ], - "ports": {"4445": ("0.0.0.0", 0)}, - "publish_all_ports": False, - "mem_limit": "3g", # default is 1g, need to override - "platform": "linux/amd64", - }, -} - -images.settings["nucliadb_node_writer"] = { - "image": "europe-west4-docker.pkg.dev/nuclia-internal/nuclia/node", - "version": "latest", - "env": { - "FILE_BACKEND": "unset", - "HOST_KEY_PATH": "/data/node.key", - "DATA_PATH": "/data", - "WRITER_LISTEN_ADDRESS": "0.0.0.0:4446", - "NUCLIADB_DISABLE_ANALYTICS": "True", - "RUST_BACKTRACE": "full", - "RUST_LOG": "nucliadb_*=DEBUG", - }, - "options": { - "command": [ - "/usr/local/bin/node_writer", - ], - "ports": {"4446": ("0.0.0.0", 0)}, - "publish_all_ports": False, - "mem_limit": "3g", # default is 1g, need to override - "platform": "linux/amd64", - }, -} - -images.settings["nucliadb_node_sidecar"] = { - "image": "europe-west4-docker.pkg.dev/nuclia-internal/nuclia/node_sidecar", - "version": "latest", - "env": { - "INDEX_JETSTREAM_SERVERS": "[]", - "CACHE_PUBSUB_NATS_URL": "", - "HOST_KEY_PATH": "/data/node.key", - "DATA_PATH": "/data", - "SIDECAR_LISTEN_ADDRESS": "0.0.0.0:4447", - "READER_LISTEN_ADDRESS": "0.0.0.0:4445", - "WRITER_LISTEN_ADDRESS": "0.0.0.0:4446", - "PYTHONUNBUFFERED": "1", - "LOG_LEVEL": "DEBUG", - }, - "options": { - "command": [ - "node_sidecar", - ], - "ports": {"4447": ("0.0.0.0", 0)}, - "publish_all_ports": False, - "platform": "linux/amd64", - }, -} - -images.settings["nidx"] = { - "image": "nidx", - "version": "latest", - "env": {}, - "options": { - # A few indexers on purpose for faster indexing - "command": [ - "nidx", - "api", - "searcher", - "indexer", - "indexer", - "indexer", - "indexer", - "scheduler", - "worker", - ], - "ports": {"10000": ("0.0.0.0", 0), "10001": ("0.0.0.0", 0)}, - "publish_all_ports": False, - "platform": "linux/amd64", - }, -} - - -def get_container_host(container_obj): - return container_obj.attrs["NetworkSettings"]["IPAddress"] - - -class nucliadbNodeReader(BaseImage): - name = "nucliadb_node_reader" - port = 4445 - - def run(self, volume): - self._volume = volume - self._mount = "/data" - return super(nucliadbNodeReader, self).run() - - def get_image_options(self): - options = super(nucliadbNodeReader, self).get_image_options() - options["volumes"] = {self._volume.name: {"bind": "/data"}} - return options - - def check(self): - channel = insecure_channel(f"{self.host}:{self.get_port()}") - stub = health_pb2_grpc.HealthStub(channel) - pb = HealthCheckRequest(service="nodereader.NodeReader") - try: - result = stub.Check(pb) - return result.status == 1 - except: # noqa - return False - - -class nucliadbNodeWriter(BaseImage): - name = "nucliadb_node_writer" - port = 4446 - - def run(self, volume): - self._volume = volume - self._mount = "/data" - return super(nucliadbNodeWriter, self).run() - - def get_image_options(self): - options = super(nucliadbNodeWriter, self).get_image_options() - options["volumes"] = {self._volume.name: {"bind": "/data"}} - return options - - def check(self): - channel = insecure_channel(f"{self.host}:{self.get_port()}") - stub = health_pb2_grpc.HealthStub(channel) - pb = HealthCheckRequest(service="nodewriter.NodeWriter") - try: - result = stub.Check(pb) - return result.status == 1 - except: # noqa - return False - - -class nucliadbNodeSidecar(BaseImage): - name = "nucliadb_node_sidecar" - port = 4447 - - def run(self, volume): - self._volume = volume - self._mount = "/data" - return super(nucliadbNodeSidecar, self).run() - - def get_image_options(self): - options = super(nucliadbNodeSidecar, self).get_image_options() - options["volumes"] = {self._volume.name: {"bind": "/data"}} - return options - - def check(self): - channel = insecure_channel(f"{self.host}:{self.get_port()}") - stub = health_pb2_grpc.HealthStub(channel) - pb = HealthCheckRequest(service="") - try: - result = stub.Check(pb) - return result.status == 1 - except: # noqa - return False - - -class NidxImage(BaseImage): - name = "nidx" - - -nucliadb_node_1_reader = nucliadbNodeReader() -nucliadb_node_1_writer = nucliadbNodeWriter() -nucliadb_node_1_sidecar = nucliadbNodeSidecar() - -nucliadb_node_2_reader = nucliadbNodeReader() -nucliadb_node_2_writer = nucliadbNodeWriter() -nucliadb_node_2_sidecar = nucliadbNodeSidecar() - - -@dataclasses.dataclass -class NodeS3Storage: - server: str - - def envs(self): - return { - "FILE_BACKEND": "s3", - "S3_CLIENT_ID": "fake", - "S3_CLIENT_SECRET": "fake", - "S3_BUCKET": "test", - "S3_INDEXING_BUCKET": "indexing", - "S3_DEADLETTER_BUCKET": "deadletter", - "S3_ENDPOINT": self.server, - } - - -@dataclasses.dataclass -class NodeGCSStorage: - server: str - - def envs(self): - return { - "FILE_BACKEND": "gcs", - "GCS_BUCKET": "test", - "GCS_INDEXING_BUCKET": "indexing", - "GCS_DEADLETTER_BUCKET": "deadletter", - "GCS_ENDPOINT_URL": self.server, - } - - -NodeStorage = Union[NodeGCSStorage, NodeS3Storage] - - -class _NodeRunner: - def __init__(self, natsd, storage: NodeStorage): - self.docker_client = docker.from_env(version=BaseImage.docker_version) - self.natsd = natsd - self.storage = storage - self.data = {} # type: ignore - - def start(self): - docker_platform_name = self.docker_client.api.version()["Platform"]["Name"].upper() - if "GITHUB_ACTION" in os.environ: - # Valid when using github actions - docker_internal_host = "172.17.0.1" - elif docker_platform_name == "DOCKER ENGINE - COMMUNITY": - # for linux users - docker_internal_host = "172.17.0.1" - elif "DOCKER DESKTOP" in docker_platform_name: - # Valid when using Docker desktop - docker_internal_host = "host.docker.internal" - else: - docker_internal_host = "172.17.0.1" - - self.volume_node_1 = self.docker_client.volumes.create(driver="local") - self.volume_node_2 = self.docker_client.volumes.create(driver="local") - - self.storage.server = self.storage.server.replace("localhost", docker_internal_host) - images.settings["nucliadb_node_writer"]["env"].update(self.storage.envs()) - writer1_host, writer1_port = nucliadb_node_1_writer.run(self.volume_node_1) - writer2_host, writer2_port = nucliadb_node_2_writer.run(self.volume_node_2) - - reader1_host, reader1_port = nucliadb_node_1_reader.run(self.volume_node_1) - reader2_host, reader2_port = nucliadb_node_2_reader.run(self.volume_node_2) - - natsd_server = self.natsd.replace("localhost", docker_internal_host) - images.settings["nucliadb_node_sidecar"]["env"].update( - { - "INDEX_JETSTREAM_SERVERS": f'["{natsd_server}"]', - "CACHE_PUBSUB_NATS_URL": f'["{natsd_server}"]', - "READER_LISTEN_ADDRESS": f"{docker_internal_host}:{reader1_port}", - "WRITER_LISTEN_ADDRESS": f"{docker_internal_host}:{writer1_port}", - } - ) - images.settings["nucliadb_node_sidecar"]["env"].update(self.storage.envs()) - - sidecar1_host, sidecar1_port = nucliadb_node_1_sidecar.run(self.volume_node_1) - - images.settings["nucliadb_node_sidecar"]["env"]["READER_LISTEN_ADDRESS"] = ( - f"{docker_internal_host}:{reader2_port}" - ) - images.settings["nucliadb_node_sidecar"]["env"]["WRITER_LISTEN_ADDRESS"] = ( - f"{docker_internal_host}:{writer2_port}" - ) - - sidecar2_host, sidecar2_port = nucliadb_node_2_sidecar.run(self.volume_node_2) - - writer1_internal_host = get_container_host(nucliadb_node_1_writer.container_obj) - writer2_internal_host = get_container_host(nucliadb_node_2_writer.container_obj) - - self.data.update( - { - "writer1_internal_host": writer1_internal_host, - "writer2_internal_host": writer2_internal_host, - "writer1": { - "host": writer1_host, - "port": writer1_port, - }, - "writer2": { - "host": writer2_host, - "port": writer2_port, - }, - "reader1": { - "host": reader1_host, - "port": reader1_port, - }, - "reader2": { - "host": reader2_host, - "port": reader2_port, - }, - "sidecar1": { - "host": sidecar1_host, - "port": sidecar1_port, - }, - "sidecar2": { - "host": sidecar2_host, - "port": sidecar2_port, - }, - } - ) - return self.data - - def stop(self): - container_ids = [] - for component in [ - nucliadb_node_1_reader, - nucliadb_node_1_writer, - nucliadb_node_1_sidecar, - nucliadb_node_2_writer, - nucliadb_node_2_reader, - nucliadb_node_2_sidecar, - ]: - container_obj = getattr(component, "container_obj", None) - if container_obj: - container_ids.append(container_obj.id) - component.stop() - - for container_id in container_ids: - for _ in range(5): - try: - self.docker_client.containers.get(container_id) - except docker.errors.NotFound: - break - time.sleep(2) - - self.volume_node_1.remove() - self.volume_node_2.remove() - - def setup_env(self): - # reset on every test run in case something touches it - cluster_settings.writer_port_map = { - self.data["writer1_internal_host"]: self.data["writer1"]["port"], - self.data["writer2_internal_host"]: self.data["writer2"]["port"], - } - cluster_settings.reader_port_map = { - self.data["writer1_internal_host"]: self.data["reader1"]["port"], - self.data["writer2_internal_host"]: self.data["reader2"]["port"], - } - - cluster_settings.node_writer_port = None # type: ignore - cluster_settings.node_reader_port = None # type: ignore - - cluster_settings.cluster_discovery_mode = "manual" - cluster_settings.cluster_discovery_manual_addresses = [ - self.data["writer1_internal_host"], - self.data["writer2_internal_host"], - ] - - -@pytest.fixture(scope="session") -def gcs_node_storage(gcs): - return NodeGCSStorage(server=gcs) - - -@pytest.fixture(scope="session") -def s3_node_storage(s3): - return NodeS3Storage(server=s3) - - -@pytest.fixture(scope="session") -def node_storage(request): - backend = get_testing_storage_backend() - if backend == "gcs": - return request.getfixturevalue("gcs_node_storage") - elif backend == "s3": - return request.getfixturevalue("s3_node_storage") - else: - print(f"Unknown storage backend {backend}, using gcs") - return request.getfixturevalue("gcs_node_storage") - - -@pytest.fixture(scope="session") -def gcs_nidx_storage(gcs): - return { - "INDEXER__OBJECT_STORE": "gcs", - "INDEXER__BUCKET": "indexing", - "INDEXER__ENDPOINT": gcs, - "STORAGE__OBJECT_STORE": "gcs", - "STORAGE__ENDPOINT": gcs, - "STORAGE__BUCKET": "nidx", - } - - -@pytest.fixture(scope="session") -def s3_nidx_storage(s3): - return { - "INDEXER__OBJECT_STORE": "s3", - "INDEXER__BUCKET": "indexing", - "INDEXER__ENDPOINT": s3, - "STORAGE__OBJECT_STORE": "s3", - "STORAGE__ENDPOINT": s3, - "STORAGE__BUCKET": "nidx", - } - - -@pytest.fixture(scope="session") -def nidx_storage(request): - backend = get_testing_storage_backend() - if backend == "gcs": - return request.getfixturevalue("gcs_nidx_storage") - elif backend == "s3": - return request.getfixturevalue("s3_nidx_storage") - - -@pytest.fixture(scope="session", autouse=False) -def _node(natsd: str, node_storage): - nr = _NodeRunner(natsd, node_storage) - try: - cluster_info = nr.start() - except Exception: - nr.stop() - raise - nr.setup_env() - yield cluster_info - nr.stop() - - -@pytest.fixture(scope="session") -async def _nidx(natsd, nidx_storage, pg): - if not NIDX_ENABLED: - yield - return - - # Create needed NATS stream/consumer - nc = await nats.connect(servers=[natsd]) - js = nc.jetstream() - await js.add_stream(name="nidx", subjects=["nidx"]) - await js.add_consumer(stream="nidx", config=ConsumerConfig(name="nidx")) - await nc.drain() - await nc.close() - - # Run nidx - images.settings["nidx"]["env"] = { - "RUST_LOG": "info", - "METADATA__DATABASE_URL": f"postgresql://postgres:postgres@172.17.0.1:{pg[1]}/postgres", - "INDEXER__NATS_SERVER": natsd.replace("localhost", "172.17.0.1"), - **nidx_storage, - } - image = NidxImage() - image.run() - - api_port = image.get_port(10000) - searcher_port = image.get_port(10001) - - # Configure settings - from nucliadb_utils.settings import indexing_settings - - cluster_settings.nidx_api_address = f"localhost:{api_port}" - cluster_settings.nidx_searcher_address = f"localhost:{searcher_port}" - indexing_settings.index_nidx_subject = "nidx" - - yield - - image.stop() - - -@pytest.fixture(scope="function") -def node(_nidx, _node, request): - # clean up all shard data before each test - channel1 = insecure_channel(f"{_node['writer1']['host']}:{_node['writer1']['port']}") - channel2 = insecure_channel(f"{_node['writer2']['host']}:{_node['writer2']['port']}") - writer1 = NodeWriterStub(channel1) - writer2 = NodeWriterStub(channel2) - - logger.debug("cleaning up shards data") - try: - cleanup_node(writer1) - cleanup_node(writer2) - except Exception: - logger.error( - "Error cleaning up shards data. Maybe the node fixture could not start properly?", - exc_info=True, - ) - - client = docker.client.from_env() - containers_by_port = {} - for container in client.containers.list(): - name = container.name - command = container.attrs["Config"]["Cmd"] - ports = container.ports - print(f"container {name} executing {command} is using ports: {ports}") - - for internal_port in container.ports: - for host in container.ports[internal_port]: - port = host["HostPort"] - port_containers = containers_by_port.setdefault(port, []) - if container not in port_containers: - port_containers.append(container) - - for port, containers in containers_by_port.items(): - if len(containers) > 1: - names = ", ".join([container.name for container in containers]) - print(f"ATENTION! Containers {names} share port {port}!") - raise - finally: - channel1.close() - channel2.close() - - yield _node - - -@backoff.on_exception(backoff.expo, Exception, jitter=backoff.random_jitter, max_tries=5) -def cleanup_node(writer: NodeWriterStub): - for shard in writer.ListShards(EmptyQuery()).ids: - writer.DeleteShard(ShardId(id=shard.id)) From aac99f5254c7676e93e49778aafbb4a6ac05c7b4 Mon Sep 17 00:00:00 2001 From: Joan Antoni RE Date: Fri, 27 Dec 2024 17:04:10 +0100 Subject: [PATCH 2/9] Better utility set/clear for ndbfixtures --- nucliadb/tests/ndbfixtures/utils.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/nucliadb/tests/ndbfixtures/utils.py b/nucliadb/tests/ndbfixtures/utils.py index f360344d36..5c66c9fadb 100644 --- a/nucliadb/tests/ndbfixtures/utils.py +++ b/nucliadb/tests/ndbfixtures/utils.py @@ -18,13 +18,19 @@ # along with this program. If not, see . # +import logging +from contextlib import contextmanager from enum import Enum -from typing import Callable, Optional +from typing import Any, Callable, Optional +from unittest.mock import patch from fastapi import FastAPI from httpx import ASGITransport, AsyncClient from nucliadb.search import API_PREFIX +from nucliadb_utils.utilities import MAIN + +logger = logging.getLogger("fixtures.utils") def create_api_client_factory(application: FastAPI) -> Callable[..., AsyncClient]: @@ -56,3 +62,21 @@ def _make_client_fixture( return client return _make_client_fixture + + +@contextmanager +def global_utility(name: str, util: Any): + """Hacky set_utility used in tests to provide proper setup/cleanup of utilities. + + Tests can sometimes mess with global state. While fixtures add/remove global + utilities, component lifecycles do the same. Sometimes, we can left + utilities unclean or overwrite utilities. This context manager allows tests + to remove utilities letting the previously set one. + + """ + + if name in MAIN: + logger.warning(f"Overwriting previously set utility {name}: {MAIN[name]} with {util}") + + with patch.dict(MAIN, values={name: util}, clear=False): + yield util From fbfe513f3640c45f77de73fe3e5798fa46ede82e Mon Sep 17 00:00:00 2001 From: Joan Antoni RE Date: Fri, 27 Dec 2024 17:09:25 +0100 Subject: [PATCH 3/9] Decouple search from tests/fixtures --- nucliadb/tests/fixtures.py | 2 + nucliadb/tests/ingest/fixtures.py | 34 +------------- nucliadb/tests/ndbfixtures/common.py | 67 +++++++++++++++++++++++----- nucliadb/tests/ndbfixtures/search.py | 10 +++++ nucliadb/tests/search/conftest.py | 1 - 5 files changed, 70 insertions(+), 44 deletions(-) diff --git a/nucliadb/tests/fixtures.py b/nucliadb/tests/fixtures.py index 0a95702f97..cb3ca6ab04 100644 --- a/nucliadb/tests/fixtures.py +++ b/nucliadb/tests/fixtures.py @@ -352,6 +352,8 @@ async def knowledge_graph(nucliadb_writer: AsyncClient, nucliadb_grpc: WriterStu return (nodes, edges) +# TODO: remove after migrating tests/nucliadb/ to ndbfixtures. fixture already +# moved to ndbfixtures.common @pytest.fixture(scope="function") async def stream_audit(natsd: str, mocker): from nucliadb_utils.audit.stream import StreamAuditStorage diff --git a/nucliadb/tests/ingest/fixtures.py b/nucliadb/tests/ingest/fixtures.py index e4f784fdb1..1d1f66e0e2 100644 --- a/nucliadb/tests/ingest/fixtures.py +++ b/nucliadb/tests/ingest/fixtures.py @@ -23,7 +23,7 @@ from dataclasses import dataclass from datetime import datetime from os.path import dirname, getsize -from typing import AsyncIterator, Iterable, Iterator, Optional +from typing import AsyncIterator, Iterable, Optional from unittest.mock import AsyncMock, patch import nats @@ -49,14 +49,11 @@ from nucliadb_protos.knowledgebox_pb2 import SemanticModelMetadata from nucliadb_protos.writer_pb2 import BrokerMessage from nucliadb_utils import const -from nucliadb_utils.audit.audit import AuditStorage -from nucliadb_utils.audit.basic import BasicAuditStorage -from nucliadb_utils.audit.stream import StreamAuditStorage from nucliadb_utils.cache.nats import NatsPubsub from nucliadb_utils.cache.pubsub import PubSubDriver from nucliadb_utils.indexing import IndexingUtility from nucliadb_utils.nats import NatsConnectionManager -from nucliadb_utils.settings import audit_settings, indexing_settings, transaction_settings +from nucliadb_utils.settings import indexing_settings, transaction_settings from nucliadb_utils.storages.settings import settings as storage_settings from nucliadb_utils.storages.storage import Storage from nucliadb_utils.transaction import TransactionUtility @@ -225,33 +222,6 @@ async def knowledgebox_with_vectorsets( await KnowledgeBox.delete(maindb_driver, kbid) -@pytest.fixture(scope="function") -def audit(basic_audit: BasicAuditStorage) -> Iterator[AuditStorage]: - # XXX: why aren't we settings the utility? - yield basic_audit - - -@pytest.fixture(scope="function") -async def basic_audit() -> AsyncIterator[BasicAuditStorage]: - audit = BasicAuditStorage() - await audit.initialize() - yield audit - await audit.finalize() - - -@pytest.fixture(scope="function") -async def stream_audit(nats_server: str) -> AsyncIterator[StreamAuditStorage]: - audit = StreamAuditStorage( - [nats_server], - audit_settings.audit_jetstream_target, # type: ignore - audit_settings.audit_partitions, - audit_settings.audit_hash_seed, - ) - await audit.initialize() - yield audit - await audit.finalize() - - @pytest.fixture(scope="function") async def indexing_utility( dummy_indexing_utility: IndexingUtility, diff --git a/nucliadb/tests/ndbfixtures/common.py b/nucliadb/tests/ndbfixtures/common.py index 7995eb7fc5..13e110d567 100644 --- a/nucliadb/tests/ndbfixtures/common.py +++ b/nucliadb/tests/ndbfixtures/common.py @@ -18,31 +18,76 @@ # along with this program. If not, see . # from os.path import dirname -from typing import AsyncIterator +from typing import AsyncIterator, Iterator import pytest +from pytest_mock import MockerFixture from nucliadb.common.cluster.manager import KBShardManager from nucliadb.common.maindb.driver import Driver +from nucliadb_utils.audit.audit import AuditStorage +from nucliadb_utils.audit.basic import BasicAuditStorage +from nucliadb_utils.audit.stream import StreamAuditStorage +from nucliadb_utils.settings import audit_settings from nucliadb_utils.storages.settings import settings as storage_settings from nucliadb_utils.storages.storage import Storage -from nucliadb_utils.utilities import ( - Utility, - clean_utility, - set_utility, -) +from nucliadb_utils.utilities import Utility, clean_utility, set_utility +from tests.ndbfixtures.utils import global_utility + +# Audit @pytest.fixture(scope="function") -async def shard_manager(storage: Storage, maindb_driver: Driver) -> AsyncIterator[KBShardManager]: - sm = KBShardManager() - set_utility(Utility.SHARD_MANAGER, sm) +def audit(basic_audit: BasicAuditStorage) -> Iterator[AuditStorage]: + yield basic_audit - yield sm - clean_utility(Utility.SHARD_MANAGER) +@pytest.fixture(scope="function") +async def basic_audit() -> AsyncIterator[BasicAuditStorage]: + audit = BasicAuditStorage() + await audit.initialize() + with global_utility(Utility.AUDIT, audit): + yield audit + await audit.finalize() + + +@pytest.fixture(scope="function") +async def stream_audit(nats_server: str, mocker: MockerFixture) -> AsyncIterator[StreamAuditStorage]: + audit = StreamAuditStorage( + [nats_server], + audit_settings.audit_jetstream_target, # type: ignore + audit_settings.audit_partitions, + audit_settings.audit_hash_seed, + ) + await audit.initialize() + + mocker.spy(audit, "send") + mocker.spy(audit.js, "publish") + mocker.spy(audit, "search") + mocker.spy(audit, "chat") + + with global_utility(Utility.AUDIT, audit): + yield audit + + await audit.finalize() + + +# Local files @pytest.fixture(scope="function") async def local_files(): storage_settings.local_testing_files = f"{dirname(__file__)}" + + +# Shard manager + + +@pytest.fixture(scope="function") +async def shard_manager(storage: Storage, maindb_driver: Driver) -> AsyncIterator[KBShardManager]: + sm = KBShardManager() + set_utility(Utility.SHARD_MANAGER, sm) + + yield sm + + clean_utility(Utility.SHARD_MANAGER) diff --git a/nucliadb/tests/ndbfixtures/search.py b/nucliadb/tests/ndbfixtures/search.py index d79d623186..1263cb3b26 100644 --- a/nucliadb/tests/ndbfixtures/search.py +++ b/nucliadb/tests/ndbfixtures/search.py @@ -238,3 +238,13 @@ async def wait_for_shard(knowledgebox_ingest: str, count: int) -> str: # Wait an extra couple of seconds for reader/searcher to catch up await asyncio.sleep(2) return knowledgebox_ingest + + +# Dependencies from tests/fixtures.py + + +@pytest.fixture(scope="function") +async def txn(maindb_driver): + async with maindb_driver.transaction() as txn: + yield txn + await txn.abort() diff --git a/nucliadb/tests/search/conftest.py b/nucliadb/tests/search/conftest.py index 9b3b67a554..b4d446ba47 100644 --- a/nucliadb/tests/search/conftest.py +++ b/nucliadb/tests/search/conftest.py @@ -19,7 +19,6 @@ # pytest_plugins = [ "pytest_docker_fixtures", - "tests.fixtures", "tests.ndbfixtures.maindb", "tests.ndbfixtures.processing", "tests.ndbfixtures.standalone", From 659bd554e9c02cc49261181b933b46c02531401a Mon Sep 17 00:00:00 2001 From: Joan Antoni RE Date: Fri, 27 Dec 2024 17:17:10 +0100 Subject: [PATCH 4/9] Use patch.object instead of changing global settings --- nucliadb/tests/ndbfixtures/search.py | 72 +++++++++++----------------- 1 file changed, 29 insertions(+), 43 deletions(-) diff --git a/nucliadb/tests/ndbfixtures/search.py b/nucliadb/tests/ndbfixtures/search.py index 1263cb3b26..25c6d7719f 100644 --- a/nucliadb/tests/ndbfixtures/search.py +++ b/nucliadb/tests/ndbfixtures/search.py @@ -20,6 +20,7 @@ import asyncio from enum import Enum from typing import AsyncIterable, Optional +from unittest.mock import patch import pytest from httpx import AsyncClient @@ -29,67 +30,52 @@ from nucliadb.common.maindb.utils import get_driver from nucliadb.common.nidx import get_nidx_api_client from nucliadb.ingest.cache import clear_ingest_cache +from nucliadb.ingest.settings import settings as ingest_settings from nucliadb.search import API_PREFIX from nucliadb.search.predict import DummyPredictEngine from nucliadb_protos.nodereader_pb2 import GetShardRequest from nucliadb_protos.noderesources_pb2 import Shard -from nucliadb_utils.settings import nuclia_settings +from nucliadb_utils.cache.settings import settings as cache_settings +from nucliadb_utils.settings import ( + nuclia_settings, + nucliadb_settings, + running_settings, +) from nucliadb_utils.tests import free_port from nucliadb_utils.utilities import ( Utility, - clean_utility, clear_global_cache, - get_utility, - set_utility, ) from tests.ingest.fixtures import broker_resource +from tests.ndbfixtures.utils import global_utility @pytest.fixture(scope="function") def test_settings_search(storage, natsd, node, maindb_driver): # type: ignore - from nucliadb.ingest.settings import settings as ingest_settings - from nucliadb_utils.cache.settings import settings as cache_settings - from nucliadb_utils.settings import ( - nuclia_settings, - nucliadb_settings, - running_settings, - ) - - cache_settings.cache_pubsub_nats_url = [natsd] - - running_settings.debug = False - - ingest_settings.disable_pull_worker = True - - ingest_settings.nuclia_partitions = 1 - - nuclia_settings.dummy_processing = True - nuclia_settings.dummy_predict = True - nuclia_settings.dummy_learning_services = True - - ingest_settings.grpc_port = free_port() - - nucliadb_settings.nucliadb_ingest = f"localhost:{ingest_settings.grpc_port}" + with ( + patch.object(cache_settings, "cache_pubsub_nats_url", [natsd]), + patch.object(running_settings, "debug", False), + patch.object(ingest_settings, "disable_pull_worker", True), + patch.object(ingest_settings, "nuclia_partitions", 1), + patch.object(nuclia_settings, "dummy_processing", True), + patch.object(nuclia_settings, "dummy_predict", True), + patch.object(nuclia_settings, "dummy_learning_services", True), + patch.object(ingest_settings, "grpc_port", free_port()), + patch.object(nucliadb_settings, "nucliadb_ingest", f"localhost:{ingest_settings.grpc_port}"), + ): + yield @pytest.fixture(scope="function") async def dummy_predict() -> AsyncIterable[DummyPredictEngine]: - original_setting = nuclia_settings.dummy_predict - nuclia_settings.dummy_predict = True - - predict_util = DummyPredictEngine() - await predict_util.initialize() - original_predict = get_utility(Utility.PREDICT) - set_utility(Utility.PREDICT, predict_util) - - yield predict_util - - nuclia_settings.dummy_predict = original_setting - - if original_predict is None: - clean_utility(Utility.PREDICT) - else: - set_utility(Utility.PREDICT, original_predict) + with ( + patch.object(nuclia_settings, "dummy_predict", True), + ): + predict_util = DummyPredictEngine() + await predict_util.initialize() + + with global_utility(Utility.PREDICT, predict_util): + yield predict_util @pytest.fixture(scope="function") From 75fac4abe0c1dd977120c6dcb836fdf658c4ff63 Mon Sep 17 00:00:00 2001 From: Joan Antoni RE Date: Fri, 27 Dec 2024 17:25:10 +0100 Subject: [PATCH 5/9] Use create_api_client_factory in search fixture --- nucliadb/tests/ndbfixtures/search.py | 35 +++------------------------- 1 file changed, 3 insertions(+), 32 deletions(-) diff --git a/nucliadb/tests/ndbfixtures/search.py b/nucliadb/tests/ndbfixtures/search.py index 25c6d7719f..b55566f45b 100644 --- a/nucliadb/tests/ndbfixtures/search.py +++ b/nucliadb/tests/ndbfixtures/search.py @@ -18,12 +18,10 @@ # along with this program. If not, see . import asyncio -from enum import Enum -from typing import AsyncIterable, Optional +from typing import AsyncIterable from unittest.mock import patch import pytest -from httpx import AsyncClient from redis import asyncio as aioredis from nucliadb.common.cluster.manager import KBShardManager, get_index_node @@ -31,7 +29,6 @@ from nucliadb.common.nidx import get_nidx_api_client from nucliadb.ingest.cache import clear_ingest_cache from nucliadb.ingest.settings import settings as ingest_settings -from nucliadb.search import API_PREFIX from nucliadb.search.predict import DummyPredictEngine from nucliadb_protos.nodereader_pb2 import GetShardRequest from nucliadb_protos.noderesources_pb2 import Shard @@ -47,7 +44,7 @@ clear_global_cache, ) from tests.ingest.fixtures import broker_resource -from tests.ndbfixtures.utils import global_utility +from tests.ndbfixtures.utils import create_api_client_factory, global_utility @pytest.fixture(scope="function") @@ -83,32 +80,6 @@ async def search_api(test_settings_search, transaction_utility, redis): # type: from nucliadb.common.cluster import manager from nucliadb.search.app import application - def make_client_fixture( - roles: Optional[list[Enum]] = None, - user: str = "", - version: str = "1", - root: bool = False, - extra_headers: Optional[dict[str, str]] = None, - ) -> AsyncClient: - roles = roles or [] - client_base_url = "http://test" - - if root is False: - client_base_url = f"{client_base_url}/{API_PREFIX}/v{version}" - - client = AsyncClient(app=application, base_url=client_base_url) - client.headers["X-NUCLIADB-ROLES"] = ";".join([role.value for role in roles]) - client.headers["X-NUCLIADB-USER"] = user - - extra_headers = extra_headers or {} - if len(extra_headers) == 0: - return client - - for header, value in extra_headers.items(): - client.headers[f"{header}"] = value - - return client - driver = aioredis.from_url(f"redis://{redis[0]}:{redis[1]}") await driver.flushall() @@ -123,7 +94,7 @@ def make_client_fixture( raise Exception("No cluster") count += 1 - yield make_client_fixture + yield create_api_client_factory(application) # Make sure nodes can sync await asyncio.sleep(1) From f6fd8fcc5d6aa3c72bf011aef22757ad02705538 Mon Sep 17 00:00:00 2001 From: Joan Antoni RE Date: Fri, 27 Dec 2024 17:26:05 +0100 Subject: [PATCH 6/9] Search was not even using redis anymore... --- nucliadb/tests/ndbfixtures/search.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/nucliadb/tests/ndbfixtures/search.py b/nucliadb/tests/ndbfixtures/search.py index b55566f45b..416c1279ca 100644 --- a/nucliadb/tests/ndbfixtures/search.py +++ b/nucliadb/tests/ndbfixtures/search.py @@ -22,7 +22,6 @@ from unittest.mock import patch import pytest -from redis import asyncio as aioredis from nucliadb.common.cluster.manager import KBShardManager, get_index_node from nucliadb.common.maindb.utils import get_driver @@ -76,13 +75,10 @@ async def dummy_predict() -> AsyncIterable[DummyPredictEngine]: @pytest.fixture(scope="function") -async def search_api(test_settings_search, transaction_utility, redis): # type: ignore +async def search_api(test_settings_search, transaction_utility): # type: ignore from nucliadb.common.cluster import manager from nucliadb.search.app import application - driver = aioredis.from_url(f"redis://{redis[0]}:{redis[1]}") - await driver.flushall() - async with application.router.lifespan_context(application): # Make sure is clean await asyncio.sleep(1) @@ -98,8 +94,6 @@ async def search_api(test_settings_search, transaction_utility, redis): # type: # Make sure nodes can sync await asyncio.sleep(1) - await driver.flushall() - await driver.close(close_connection_pool=True) clear_ingest_cache() clear_global_cache() manager.INDEX_NODES.clear() From 5e2f9a55c0e24aa8bd4a38652fac103352df11ae Mon Sep 17 00:00:00 2001 From: Joan Antoni RE Date: Fri, 27 Dec 2024 17:33:11 +0100 Subject: [PATCH 7/9] Don't wait search cluster for nothing --- nucliadb/tests/ndbfixtures/search.py | 44 +++++++++++++++------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/nucliadb/tests/ndbfixtures/search.py b/nucliadb/tests/ndbfixtures/search.py index 416c1279ca..94c0ec6de4 100644 --- a/nucliadb/tests/ndbfixtures/search.py +++ b/nucliadb/tests/ndbfixtures/search.py @@ -18,16 +18,19 @@ # along with this program. If not, see . import asyncio +import datetime from typing import AsyncIterable from unittest.mock import patch import pytest +from nucliadb.common.cluster import manager from nucliadb.common.cluster.manager import KBShardManager, get_index_node from nucliadb.common.maindb.utils import get_driver from nucliadb.common.nidx import get_nidx_api_client from nucliadb.ingest.cache import clear_ingest_cache from nucliadb.ingest.settings import settings as ingest_settings +from nucliadb.search.app import application from nucliadb.search.predict import DummyPredictEngine from nucliadb_protos.nodereader_pb2 import GetShardRequest from nucliadb_protos.noderesources_pb2 import Shard @@ -76,27 +79,26 @@ async def dummy_predict() -> AsyncIterable[DummyPredictEngine]: @pytest.fixture(scope="function") async def search_api(test_settings_search, transaction_utility): # type: ignore - from nucliadb.common.cluster import manager - from nucliadb.search.app import application - - async with application.router.lifespan_context(application): - # Make sure is clean - await asyncio.sleep(1) - count = 0 - while len(manager.INDEX_NODES) < 2: - print("awaiting cluster nodes - search fixtures.py") - await asyncio.sleep(1) - if count == 40: - raise Exception("No cluster") - count += 1 - - yield create_api_client_factory(application) - - # Make sure nodes can sync - await asyncio.sleep(1) - clear_ingest_cache() - clear_global_cache() - manager.INDEX_NODES.clear() + with patch.dict(manager.INDEX_NODES, clear=True): + async with application.router.lifespan_context(application): + # Make sure is clean + delay = 0.1 + timeout = datetime.timedelta(seconds=30) + start = datetime.datetime.now() + + await asyncio.sleep(delay) + while len(manager.INDEX_NODES) < 2: + print("awaiting cluster nodes - search fixtures.py") + await asyncio.sleep(delay) + if (datetime.datetime.now() - start) > timeout: + raise Exception("No cluster") + + yield create_api_client_factory(application) + + # Make sure nodes can sync + await asyncio.sleep(delay) + clear_ingest_cache() + clear_global_cache() @pytest.fixture(scope="function") From 2860654f13fefb6dcb1a9e545595ac2ef4bbf315 Mon Sep 17 00:00:00 2001 From: Joan Antoni RE Date: Fri, 27 Dec 2024 17:38:20 +0100 Subject: [PATCH 8/9] Don't need 25 massive resources for almost unit tests --- nucliadb/tests/search/integration/api/v1/test_ask_audit.py | 4 ++-- nucliadb/tests/search/integration/requesters/test_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nucliadb/tests/search/integration/api/v1/test_ask_audit.py b/nucliadb/tests/search/integration/api/v1/test_ask_audit.py index 49745d4f11..1ad6c337b4 100644 --- a/nucliadb/tests/search/integration/api/v1/test_ask_audit.py +++ b/nucliadb/tests/search/integration/api/v1/test_ask_audit.py @@ -37,11 +37,11 @@ async def get_audit_messages(sub): async def test_ask_sends_only_one_audit( - search_api: Callable[..., AsyncClient], multiple_search_resource: str, stream_audit + search_api: Callable[..., AsyncClient], test_search_resource: str, stream_audit ) -> None: from nucliadb_utils.settings import audit_settings - kbid = multiple_search_resource + kbid = test_search_resource # Prepare a test audit stream to receive our messages partition = stream_audit.get_partition(kbid) diff --git a/nucliadb/tests/search/integration/requesters/test_utils.py b/nucliadb/tests/search/integration/requesters/test_utils.py index 2009e6f9ba..c2d0cca00f 100644 --- a/nucliadb/tests/search/integration/requesters/test_utils.py +++ b/nucliadb/tests/search/integration/requesters/test_utils.py @@ -35,9 +35,9 @@ @pytest.mark.xfail # pulling start/end position for vectors results needs to be fixed async def test_vector_result_metadata( - search_api: Callable[..., AsyncClient], multiple_search_resource: str + search_api: Callable[..., AsyncClient], test_search_resource: str ) -> None: - kbid = multiple_search_resource + kbid = test_search_resource pb_query, _, _ = await QueryParser( kbid=kbid, From a79e67b17304e280aeffa8a3bea55e14428555fc Mon Sep 17 00:00:00 2001 From: Joan Antoni RE Date: Fri, 27 Dec 2024 18:17:41 +0100 Subject: [PATCH 9/9] New cluster_nucliadb_search fixture (wip for deploy_modes) --- nucliadb/tests/ndbfixtures/search.py | 55 +++++---- .../integration/api/v1/test_ask_audit.py | 15 +-- .../search/integration/api/v1/test_find.py | 45 ++++---- .../search/integration/api/v1/test_search.py | 109 ++++++++---------- .../search/integration/api/v1/test_suggest.py | 17 ++- .../integration/requesters/test_utils.py | 3 +- 6 files changed, 117 insertions(+), 127 deletions(-) diff --git a/nucliadb/tests/ndbfixtures/search.py b/nucliadb/tests/ndbfixtures/search.py index 94c0ec6de4..15ff2bdb96 100644 --- a/nucliadb/tests/ndbfixtures/search.py +++ b/nucliadb/tests/ndbfixtures/search.py @@ -26,12 +26,14 @@ from nucliadb.common.cluster import manager from nucliadb.common.cluster.manager import KBShardManager, get_index_node +from nucliadb.common.maindb.driver import Driver from nucliadb.common.maindb.utils import get_driver from nucliadb.common.nidx import get_nidx_api_client from nucliadb.ingest.cache import clear_ingest_cache from nucliadb.ingest.settings import settings as ingest_settings from nucliadb.search.app import application from nucliadb.search.predict import DummyPredictEngine +from nucliadb_models.resource import NucliaDBRoles from nucliadb_protos.nodereader_pb2 import GetShardRequest from nucliadb_protos.noderesources_pb2 import Shard from nucliadb_utils.cache.settings import settings as cache_settings @@ -40,7 +42,9 @@ nucliadb_settings, running_settings, ) +from nucliadb_utils.storages.storage import Storage from nucliadb_utils.tests import free_port +from nucliadb_utils.transaction import TransactionUtility from nucliadb_utils.utilities import ( Utility, clear_global_cache, @@ -48,11 +52,19 @@ from tests.ingest.fixtures import broker_resource from tests.ndbfixtures.utils import create_api_client_factory, global_utility +# Main fixtures + @pytest.fixture(scope="function") -def test_settings_search(storage, natsd, node, maindb_driver): # type: ignore +async def cluster_nucliadb_search( + storage: Storage, + nats_server: str, + node, + maindb_driver: Driver, + transaction_utility: TransactionUtility, +): with ( - patch.object(cache_settings, "cache_pubsub_nats_url", [natsd]), + patch.object(cache_settings, "cache_pubsub_nats_url", [nats_server]), patch.object(running_settings, "debug", False), patch.object(ingest_settings, "disable_pull_worker", True), patch.object(ingest_settings, "nuclia_partitions", 1), @@ -61,25 +73,8 @@ def test_settings_search(storage, natsd, node, maindb_driver): # type: ignore patch.object(nuclia_settings, "dummy_learning_services", True), patch.object(ingest_settings, "grpc_port", free_port()), patch.object(nucliadb_settings, "nucliadb_ingest", f"localhost:{ingest_settings.grpc_port}"), + patch.dict(manager.INDEX_NODES, clear=True), ): - yield - - -@pytest.fixture(scope="function") -async def dummy_predict() -> AsyncIterable[DummyPredictEngine]: - with ( - patch.object(nuclia_settings, "dummy_predict", True), - ): - predict_util = DummyPredictEngine() - await predict_util.initialize() - - with global_utility(Utility.PREDICT, predict_util): - yield predict_util - - -@pytest.fixture(scope="function") -async def search_api(test_settings_search, transaction_utility): # type: ignore - with patch.dict(manager.INDEX_NODES, clear=True): async with application.router.lifespan_context(application): # Make sure is clean delay = 0.1 @@ -93,14 +88,32 @@ async def search_api(test_settings_search, transaction_utility): # type: ignore if (datetime.datetime.now() - start) > timeout: raise Exception("No cluster") - yield create_api_client_factory(application) + client_factory = create_api_client_factory(application) + async with client_factory(roles=[NucliaDBRoles.READER]) as client: + yield client # Make sure nodes can sync await asyncio.sleep(delay) + # TODO: fix this awful global state manipulation clear_ingest_cache() clear_global_cache() +# Rest, TODO keep cleaning + + +@pytest.fixture(scope="function") +async def dummy_predict() -> AsyncIterable[DummyPredictEngine]: + with ( + patch.object(nuclia_settings, "dummy_predict", True), + ): + predict_util = DummyPredictEngine() + await predict_util.initialize() + + with global_utility(Utility.PREDICT, predict_util): + yield predict_util + + @pytest.fixture(scope="function") async def test_search_resource( indexing_utility_registered, diff --git a/nucliadb/tests/search/integration/api/v1/test_ask_audit.py b/nucliadb/tests/search/integration/api/v1/test_ask_audit.py index 1ad6c337b4..3122d7e5d7 100644 --- a/nucliadb/tests/search/integration/api/v1/test_ask_audit.py +++ b/nucliadb/tests/search/integration/api/v1/test_ask_audit.py @@ -17,7 +17,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # -from typing import Callable import nats from httpx import AsyncClient @@ -25,7 +24,6 @@ from nats.js import JetStreamContext from nucliadb.search.api.v1.router import KB_PREFIX -from nucliadb_models.resource import NucliaDBRoles from nucliadb_protos.audit_pb2 import AuditRequest @@ -37,7 +35,7 @@ async def get_audit_messages(sub): async def test_ask_sends_only_one_audit( - search_api: Callable[..., AsyncClient], test_search_resource: str, stream_audit + cluster_nucliadb_search: AsyncClient, test_search_resource: str, stream_audit ) -> None: from nucliadb_utils.settings import audit_settings @@ -61,12 +59,11 @@ async def test_ask_sends_only_one_audit( psub = await jetstream.pull_subscribe(subject, "psub") - async with search_api(roles=[NucliaDBRoles.READER]) as client: - resp = await client.post( - f"/{KB_PREFIX}/{kbid}/ask", - json={"query": "title"}, - ) - assert resp.status_code == 200 + resp = await cluster_nucliadb_search.post( + f"/{KB_PREFIX}/{kbid}/ask", + json={"query": "title"}, + ) + assert resp.status_code == 200 # Testing the middleware integration where it collects audit calls and sends a single message # at requests ends. In this case we expect one seach and one chat sent once diff --git a/nucliadb/tests/search/integration/api/v1/test_find.py b/nucliadb/tests/search/integration/api/v1/test_find.py index 7cd41895fc..f327910c7d 100644 --- a/nucliadb/tests/search/integration/api/v1/test_find.py +++ b/nucliadb/tests/search/integration/api/v1/test_find.py @@ -17,35 +17,32 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # -from typing import Callable from httpx import AsyncClient from nucliadb.search.api.v1.router import KB_PREFIX -from nucliadb_models.resource import NucliaDBRoles -async def test_find(search_api: Callable[..., AsyncClient], multiple_search_resource: str) -> None: +async def test_find(cluster_nucliadb_search: AsyncClient, multiple_search_resource: str) -> None: kbid = multiple_search_resource - async with search_api(roles=[NucliaDBRoles.READER]) as client: - resp = await client.get( - f"/{KB_PREFIX}/{kbid}/find?query=own+text", - ) - assert resp.status_code == 200 - - data = resp.json() - - # TODO: uncomment when we have the total stable in tests - # assert data["total"] == 65 - - res = next(iter(data["resources"].values())) - para = next(iter(res["fields"]["/f/file"]["paragraphs"].values())) - assert para["position"] == { - "page_number": 0, - "index": 0, - "start": 0, - "end": 45, - "start_seconds": [0], - "end_seconds": [10], - } + resp = await cluster_nucliadb_search.get( + f"/{KB_PREFIX}/{kbid}/find?query=own+text", + ) + assert resp.status_code == 200 + + data = resp.json() + + # TODO: uncomment when we have the total stable in tests + # assert data["total"] == 65 + + res = next(iter(data["resources"].values())) + para = next(iter(res["fields"]["/f/file"]["paragraphs"].values())) + assert para["position"] == { + "page_number": 0, + "index": 0, + "start": 0, + "end": 45, + "start_seconds": [0], + "end_seconds": [10], + } diff --git a/nucliadb/tests/search/integration/api/v1/test_search.py b/nucliadb/tests/search/integration/api/v1/test_search.py index 3f5b1f6045..f47e2fd8e8 100644 --- a/nucliadb/tests/search/integration/api/v1/test_search.py +++ b/nucliadb/tests/search/integration/api/v1/test_search.py @@ -19,7 +19,6 @@ # import asyncio import os -from typing import Callable import pytest from httpx import AsyncClient @@ -29,7 +28,6 @@ from nucliadb.common.maindb.utils import get_driver from nucliadb.search.api.v1.router import KB_PREFIX from nucliadb.tests.vectors import Q -from nucliadb_models.resource import NucliaDBRoles from nucliadb_protos.nodereader_pb2 import ( SearchRequest, ) @@ -40,50 +38,48 @@ @pytest.mark.flaky(reruns=5) async def test_multiple_fuzzy_search_resource_all( - search_api: Callable[..., AsyncClient], multiple_search_resource: str + cluster_nucliadb_search: AsyncClient, multiple_search_resource: str ) -> None: kbid = multiple_search_resource - async with search_api(roles=[NucliaDBRoles.READER]) as client: - resp = await client.get( - f'/{KB_PREFIX}/{kbid}/search?query=own+test+"This is great"&highlight=true&top_k=20', - ) + resp = await cluster_nucliadb_search.get( + f'/{KB_PREFIX}/{kbid}/search?query=own+test+"This is great"&highlight=true&top_k=20', + ) - assert resp.status_code == 200, resp.content - assert len(resp.json()["paragraphs"]["results"]) == 20 + assert resp.status_code == 200, resp.content + assert len(resp.json()["paragraphs"]["results"]) == 20 - # Expected results: - # - 'text' should not be highlighted as we are searching by 'test' in the query - # - 'This is great' should be highlighted because it is an exact query search - # - 'own' should not be highlighted because it is considered as a stop-word - assert ( - resp.json()["paragraphs"]["results"][0]["text"] - == "My own text Ramon. This is great to be here. " - ) + # Expected results: + # - 'text' should not be highlighted as we are searching by 'test' in the query + # - 'This is great' should be highlighted because it is an exact query search + # - 'own' should not be highlighted because it is considered as a stop-word + assert ( + resp.json()["paragraphs"]["results"][0]["text"] + == "My own text Ramon. This is great to be here. " + ) @pytest.mark.flaky(reruns=3) async def test_search_resource_all( - search_api: Callable[..., AsyncClient], + cluster_nucliadb_search: AsyncClient, test_search_resource: str, ) -> None: kbid = test_search_resource - async with search_api(roles=[NucliaDBRoles.READER]) as client: - await asyncio.sleep(1) - resp = await client.get( - f"/{KB_PREFIX}/{kbid}/search?query=own+text&split=true&highlight=true&text_resource=true", - ) - assert resp.status_code == 200 - assert resp.json()["fulltext"]["query"] == "own text" - assert resp.json()["paragraphs"]["query"] == "own text" - assert resp.json()["paragraphs"]["results"][0]["start_seconds"] == [0] - assert resp.json()["paragraphs"]["results"][0]["end_seconds"] == [10] - assert ( - resp.json()["paragraphs"]["results"][0]["text"] - == "My own text Ramon. This is great to be here. " - ) - assert len(resp.json()["resources"]) == 1 - assert len(resp.json()["sentences"]["results"]) == 1 + await asyncio.sleep(1) + resp = await cluster_nucliadb_search.get( + f"/{KB_PREFIX}/{kbid}/search?query=own+text&split=true&highlight=true&text_resource=true", + ) + assert resp.status_code == 200 + assert resp.json()["fulltext"]["query"] == "own text" + assert resp.json()["paragraphs"]["query"] == "own text" + assert resp.json()["paragraphs"]["results"][0]["start_seconds"] == [0] + assert resp.json()["paragraphs"]["results"][0]["end_seconds"] == [10] + assert ( + resp.json()["paragraphs"]["results"][0]["text"] + == "My own text Ramon. This is great to be here. " + ) + assert len(resp.json()["resources"]) == 1 + assert len(resp.json()["sentences"]["results"]) == 1 # get shards ids @@ -135,33 +131,24 @@ async def test_search_resource_all( async def test_search_with_facets( - search_api: Callable[..., AsyncClient], multiple_search_resource: str + cluster_nucliadb_search: AsyncClient, multiple_search_resource: str ) -> None: kbid = multiple_search_resource - async with search_api(roles=[NucliaDBRoles.READER]) as client: - url = f"/{KB_PREFIX}/{kbid}/search?query=own+text&faceted=/classification.labels" - - resp = await client.get(url) - data = resp.json() - assert ( - data["fulltext"]["facets"]["/classification.labels"]["/classification.labels/labelset1"] - == 25 - ) - assert ( - data["paragraphs"]["facets"]["/classification.labels"]["/classification.labels/labelset1"] - == 25 - ) - - # also just test short hand filter - url = f"/{KB_PREFIX}/{kbid}/search?query=own+text&faceted=/l" - resp = await client.get(url) - data = resp.json() - assert ( - data["fulltext"]["facets"]["/classification.labels"]["/classification.labels/labelset1"] - == 25 - ) - assert ( - data["paragraphs"]["facets"]["/classification.labels"]["/classification.labels/labelset1"] - == 25 - ) + url = f"/{KB_PREFIX}/{kbid}/search?query=own+text&faceted=/classification.labels" + + resp = await cluster_nucliadb_search.get(url) + data = resp.json() + assert data["fulltext"]["facets"]["/classification.labels"]["/classification.labels/labelset1"] == 25 + assert ( + data["paragraphs"]["facets"]["/classification.labels"]["/classification.labels/labelset1"] == 25 + ) + + # also just test short hand filter + url = f"/{KB_PREFIX}/{kbid}/search?query=own+text&faceted=/l" + resp = await cluster_nucliadb_search.get(url) + data = resp.json() + assert data["fulltext"]["facets"]["/classification.labels"]["/classification.labels/labelset1"] == 25 + assert ( + data["paragraphs"]["facets"]["/classification.labels"]["/classification.labels/labelset1"] == 25 + ) diff --git a/nucliadb/tests/search/integration/api/v1/test_suggest.py b/nucliadb/tests/search/integration/api/v1/test_suggest.py index 10574f6e99..31bc0b637f 100644 --- a/nucliadb/tests/search/integration/api/v1/test_suggest.py +++ b/nucliadb/tests/search/integration/api/v1/test_suggest.py @@ -17,7 +17,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # -from typing import Callable import pytest from httpx import AsyncClient @@ -26,24 +25,22 @@ from nucliadb.common.datamanagers.cluster import KB_SHARDS from nucliadb.common.maindb.utils import get_driver from nucliadb.search.api.v1.router import KB_PREFIX -from nucliadb_models.resource import NucliaDBRoles from nucliadb_protos.nodereader_pb2 import SuggestFeatures, SuggestRequest from nucliadb_protos.writer_pb2 import Shards as PBShards @pytest.mark.flaky(reruns=5) async def test_suggest_resource_all( - search_api: Callable[..., AsyncClient], test_search_resource: str + cluster_nucliadb_search: AsyncClient, test_search_resource: str ) -> None: kbid = test_search_resource - async with search_api(roles=[NucliaDBRoles.READER]) as client: - resp = await client.get( - f"/{KB_PREFIX}/{kbid}/suggest?query=own+text", - ) - assert resp.status_code == 200 - paragraph_results = resp.json()["paragraphs"]["results"] - assert len(paragraph_results) == 1 + resp = await cluster_nucliadb_search.get( + f"/{KB_PREFIX}/{kbid}/suggest?query=own+text", + ) + assert resp.status_code == 200 + paragraph_results = resp.json()["paragraphs"]["results"] + assert len(paragraph_results) == 1 # get shards ids diff --git a/nucliadb/tests/search/integration/requesters/test_utils.py b/nucliadb/tests/search/integration/requesters/test_utils.py index c2d0cca00f..147ee8efdb 100644 --- a/nucliadb/tests/search/integration/requesters/test_utils.py +++ b/nucliadb/tests/search/integration/requesters/test_utils.py @@ -17,7 +17,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Callable import pytest from httpx import AsyncClient @@ -35,7 +34,7 @@ @pytest.mark.xfail # pulling start/end position for vectors results needs to be fixed async def test_vector_result_metadata( - search_api: Callable[..., AsyncClient], test_search_resource: str + cluster_nucliadb_search: AsyncClient, test_search_resource: str ) -> None: kbid = test_search_resource