diff --git a/nucliadb/tests/conftest.py b/nucliadb/tests/conftest.py
index 467765b7ec..d61dd23fb8 100644
--- a/nucliadb/tests/conftest.py
+++ b/nucliadb/tests/conftest.py
@@ -36,6 +36,7 @@
"tests.ndbfixtures.standalone",
"tests.ndbfixtures.reader",
"tests.ndbfixtures.writer",
+ "tests.ndbfixtures.search",
"tests.ndbfixtures.train",
"tests.ndbfixtures.ingest",
# subcomponents
diff --git a/nucliadb/tests/fixtures.py b/nucliadb/tests/fixtures.py
index 0a95702f97..cb3ca6ab04 100644
--- a/nucliadb/tests/fixtures.py
+++ b/nucliadb/tests/fixtures.py
@@ -352,6 +352,8 @@ async def knowledge_graph(nucliadb_writer: AsyncClient, nucliadb_grpc: WriterStu
return (nodes, edges)
+# TODO: remove after migrating tests/nucliadb/ to ndbfixtures. The fixture has
+# already been moved to ndbfixtures.common
@pytest.fixture(scope="function")
async def stream_audit(natsd: str, mocker):
from nucliadb_utils.audit.stream import StreamAuditStorage
diff --git a/nucliadb/tests/ingest/fixtures.py b/nucliadb/tests/ingest/fixtures.py
index e4f784fdb1..1d1f66e0e2 100644
--- a/nucliadb/tests/ingest/fixtures.py
+++ b/nucliadb/tests/ingest/fixtures.py
@@ -23,7 +23,7 @@
from dataclasses import dataclass
from datetime import datetime
from os.path import dirname, getsize
-from typing import AsyncIterator, Iterable, Iterator, Optional
+from typing import AsyncIterator, Iterable, Optional
from unittest.mock import AsyncMock, patch
import nats
@@ -49,14 +49,11 @@
from nucliadb_protos.knowledgebox_pb2 import SemanticModelMetadata
from nucliadb_protos.writer_pb2 import BrokerMessage
from nucliadb_utils import const
-from nucliadb_utils.audit.audit import AuditStorage
-from nucliadb_utils.audit.basic import BasicAuditStorage
-from nucliadb_utils.audit.stream import StreamAuditStorage
from nucliadb_utils.cache.nats import NatsPubsub
from nucliadb_utils.cache.pubsub import PubSubDriver
from nucliadb_utils.indexing import IndexingUtility
from nucliadb_utils.nats import NatsConnectionManager
-from nucliadb_utils.settings import audit_settings, indexing_settings, transaction_settings
+from nucliadb_utils.settings import indexing_settings, transaction_settings
from nucliadb_utils.storages.settings import settings as storage_settings
from nucliadb_utils.storages.storage import Storage
from nucliadb_utils.transaction import TransactionUtility
@@ -225,33 +222,6 @@ async def knowledgebox_with_vectorsets(
await KnowledgeBox.delete(maindb_driver, kbid)
-@pytest.fixture(scope="function")
-def audit(basic_audit: BasicAuditStorage) -> Iterator[AuditStorage]:
- # XXX: why aren't we settings the utility?
- yield basic_audit
-
-
-@pytest.fixture(scope="function")
-async def basic_audit() -> AsyncIterator[BasicAuditStorage]:
- audit = BasicAuditStorage()
- await audit.initialize()
- yield audit
- await audit.finalize()
-
-
-@pytest.fixture(scope="function")
-async def stream_audit(nats_server: str) -> AsyncIterator[StreamAuditStorage]:
- audit = StreamAuditStorage(
- [nats_server],
- audit_settings.audit_jetstream_target, # type: ignore
- audit_settings.audit_partitions,
- audit_settings.audit_hash_seed,
- )
- await audit.initialize()
- yield audit
- await audit.finalize()
-
-
@pytest.fixture(scope="function")
async def indexing_utility(
dummy_indexing_utility: IndexingUtility,
diff --git a/nucliadb/tests/ndbfixtures/common.py b/nucliadb/tests/ndbfixtures/common.py
index 7995eb7fc5..13e110d567 100644
--- a/nucliadb/tests/ndbfixtures/common.py
+++ b/nucliadb/tests/ndbfixtures/common.py
@@ -18,31 +18,76 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
from os.path import dirname
-from typing import AsyncIterator
+from typing import AsyncIterator, Iterator
import pytest
+from pytest_mock import MockerFixture
from nucliadb.common.cluster.manager import KBShardManager
from nucliadb.common.maindb.driver import Driver
+from nucliadb_utils.audit.audit import AuditStorage
+from nucliadb_utils.audit.basic import BasicAuditStorage
+from nucliadb_utils.audit.stream import StreamAuditStorage
+from nucliadb_utils.settings import audit_settings
from nucliadb_utils.storages.settings import settings as storage_settings
from nucliadb_utils.storages.storage import Storage
-from nucliadb_utils.utilities import (
- Utility,
- clean_utility,
- set_utility,
-)
+from nucliadb_utils.utilities import Utility, clean_utility, set_utility
+from tests.ndbfixtures.utils import global_utility
+
+# Audit
@pytest.fixture(scope="function")
-async def shard_manager(storage: Storage, maindb_driver: Driver) -> AsyncIterator[KBShardManager]:
- sm = KBShardManager()
- set_utility(Utility.SHARD_MANAGER, sm)
+def audit(basic_audit: BasicAuditStorage) -> Iterator[AuditStorage]:
+ yield basic_audit
- yield sm
- clean_utility(Utility.SHARD_MANAGER)
+@pytest.fixture(scope="function")
+async def basic_audit() -> AsyncIterator[BasicAuditStorage]:
+ audit = BasicAuditStorage()
+ await audit.initialize()
+ with global_utility(Utility.AUDIT, audit):
+ yield audit
+ await audit.finalize()
+
+
+@pytest.fixture(scope="function")
+async def stream_audit(nats_server: str, mocker: MockerFixture) -> AsyncIterator[StreamAuditStorage]:
+ audit = StreamAuditStorage(
+ [nats_server],
+ audit_settings.audit_jetstream_target, # type: ignore
+ audit_settings.audit_partitions,
+ audit_settings.audit_hash_seed,
+ )
+ await audit.initialize()
+
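+    # Spy on the audit entry points so tests can assert on the messages that get sent.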
+ mocker.spy(audit, "send")
+ mocker.spy(audit.js, "publish")
+ mocker.spy(audit, "search")
+ mocker.spy(audit, "chat")
+
+ with global_utility(Utility.AUDIT, audit):
+ yield audit
+
+ await audit.finalize()
+
+
+# Local files
@pytest.fixture(scope="function")
async def local_files():
storage_settings.local_testing_files = f"{dirname(__file__)}"
+
+
+# Shard manager
+
+
+@pytest.fixture(scope="function")
+async def shard_manager(storage: Storage, maindb_driver: Driver) -> AsyncIterator[KBShardManager]:
+ sm = KBShardManager()
+ set_utility(Utility.SHARD_MANAGER, sm)
+
+ yield sm
+
+ clean_utility(Utility.SHARD_MANAGER)
diff --git a/nucliadb/tests/ndbfixtures/node.py b/nucliadb/tests/ndbfixtures/node.py
index 203b280228..8b95242205 100644
--- a/nucliadb/tests/ndbfixtures/node.py
+++ b/nucliadb/tests/ndbfixtures/node.py
@@ -17,13 +17,126 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
+import dataclasses
+import logging
+import os
+import time
import uuid
+from typing import Union
from unittest.mock import patch
+import backoff
+import docker # type: ignore
+import nats
import pytest
+from grpc import insecure_channel
+from grpc_health.v1 import health_pb2_grpc
+from grpc_health.v1.health_pb2 import HealthCheckRequest
+from nats.js.api import ConsumerConfig
+from pytest_docker_fixtures import images # type: ignore
+from pytest_docker_fixtures.containers._base import BaseImage # type: ignore
from nucliadb.common.cluster import manager
from nucliadb.common.cluster.settings import settings as cluster_settings
+from nucliadb.common.nidx import NIDX_ENABLED
+from nucliadb_protos.nodewriter_pb2 import EmptyQuery, ShardId
+from nucliadb_protos.nodewriter_pb2_grpc import NodeWriterStub
+from nucliadb_utils.tests.fixtures import get_testing_storage_backend
+
+logger = logging.getLogger(__name__)
+
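+# Container image definitions for the node reader/writer/sidecar and nidx services used by the fixtures below.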
+images.settings["nucliadb_node_reader"] = {
+ "image": "europe-west4-docker.pkg.dev/nuclia-internal/nuclia/node",
+ "version": "latest",
+ "env": {
+ "FILE_BACKEND": "unset",
+ "HOST_KEY_PATH": "/data/node.key",
+ "DATA_PATH": "/data",
+ "READER_LISTEN_ADDRESS": "0.0.0.0:4445",
+ "NUCLIADB_DISABLE_ANALYTICS": "True",
+ "RUST_BACKTRACE": "full",
+ "RUST_LOG": "nucliadb_*=DEBUG",
+ },
+ "options": {
+ "command": [
+ "/usr/local/bin/node_reader",
+ ],
+ "ports": {"4445": ("0.0.0.0", 0)},
+ "publish_all_ports": False,
+ "mem_limit": "3g", # default is 1g, need to override
+ "platform": "linux/amd64",
+ },
+}
+
+images.settings["nucliadb_node_writer"] = {
+ "image": "europe-west4-docker.pkg.dev/nuclia-internal/nuclia/node",
+ "version": "latest",
+ "env": {
+ "FILE_BACKEND": "unset",
+ "HOST_KEY_PATH": "/data/node.key",
+ "DATA_PATH": "/data",
+ "WRITER_LISTEN_ADDRESS": "0.0.0.0:4446",
+ "NUCLIADB_DISABLE_ANALYTICS": "True",
+ "RUST_BACKTRACE": "full",
+ "RUST_LOG": "nucliadb_*=DEBUG",
+ },
+ "options": {
+ "command": [
+ "/usr/local/bin/node_writer",
+ ],
+ "ports": {"4446": ("0.0.0.0", 0)},
+ "publish_all_ports": False,
+ "mem_limit": "3g", # default is 1g, need to override
+ "platform": "linux/amd64",
+ },
+}
+
+images.settings["nucliadb_node_sidecar"] = {
+ "image": "europe-west4-docker.pkg.dev/nuclia-internal/nuclia/node_sidecar",
+ "version": "latest",
+ "env": {
+ "INDEX_JETSTREAM_SERVERS": "[]",
+ "CACHE_PUBSUB_NATS_URL": "",
+ "HOST_KEY_PATH": "/data/node.key",
+ "DATA_PATH": "/data",
+ "SIDECAR_LISTEN_ADDRESS": "0.0.0.0:4447",
+ "READER_LISTEN_ADDRESS": "0.0.0.0:4445",
+ "WRITER_LISTEN_ADDRESS": "0.0.0.0:4446",
+ "PYTHONUNBUFFERED": "1",
+ "LOG_LEVEL": "DEBUG",
+ },
+ "options": {
+ "command": [
+ "node_sidecar",
+ ],
+ "ports": {"4447": ("0.0.0.0", 0)},
+ "publish_all_ports": False,
+ "platform": "linux/amd64",
+ },
+}
+
+images.settings["nidx"] = {
+ "image": "nidx",
+ "version": "latest",
+ "env": {},
+ "options": {
+ # A few indexers on purpose for faster indexing
+ "command": [
+ "nidx",
+ "api",
+ "searcher",
+ "indexer",
+ "indexer",
+ "indexer",
+ "indexer",
+ "scheduler",
+ "worker",
+ ],
+ "ports": {"10000": ("0.0.0.0", 0), "10001": ("0.0.0.0", 0)},
+ "publish_all_ports": False,
+ "platform": "linux/amd64",
+ },
+}
@pytest.fixture(scope="function")
@@ -50,3 +163,422 @@ async def dummy_index_node_cluster(
dummy=True,
)
yield
+
+
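+# IP address of a container on the default Docker bridge network, used for inter-node addressing.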
+def get_container_host(container_obj):
+ return container_obj.attrs["NetworkSettings"]["IPAddress"]
+
+
+class nucliadbNodeReader(BaseImage):
+ name = "nucliadb_node_reader"
+ port = 4445
+
+ def run(self, volume):
+ self._volume = volume
+ self._mount = "/data"
+ return super(nucliadbNodeReader, self).run()
+
+ def get_image_options(self):
+ options = super(nucliadbNodeReader, self).get_image_options()
+ options["volumes"] = {self._volume.name: {"bind": "/data"}}
+ return options
+
+ def check(self):
+ channel = insecure_channel(f"{self.host}:{self.get_port()}")
+ stub = health_pb2_grpc.HealthStub(channel)
+ pb = HealthCheckRequest(service="nodereader.NodeReader")
+ try:
+ result = stub.Check(pb)
+ return result.status == 1
+ except: # noqa
+ return False
+
+
+class nucliadbNodeWriter(BaseImage):
+ name = "nucliadb_node_writer"
+ port = 4446
+
+ def run(self, volume):
+ self._volume = volume
+ self._mount = "/data"
+ return super(nucliadbNodeWriter, self).run()
+
+ def get_image_options(self):
+ options = super(nucliadbNodeWriter, self).get_image_options()
+ options["volumes"] = {self._volume.name: {"bind": "/data"}}
+ return options
+
+ def check(self):
+ channel = insecure_channel(f"{self.host}:{self.get_port()}")
+ stub = health_pb2_grpc.HealthStub(channel)
+ pb = HealthCheckRequest(service="nodewriter.NodeWriter")
+ try:
+ result = stub.Check(pb)
+ return result.status == 1
+ except: # noqa
+ return False
+
+
+class nucliadbNodeSidecar(BaseImage):
+ name = "nucliadb_node_sidecar"
+ port = 4447
+
+ def run(self, volume):
+ self._volume = volume
+ self._mount = "/data"
+ return super(nucliadbNodeSidecar, self).run()
+
+ def get_image_options(self):
+ options = super(nucliadbNodeSidecar, self).get_image_options()
+ options["volumes"] = {self._volume.name: {"bind": "/data"}}
+ return options
+
+ def check(self):
+ channel = insecure_channel(f"{self.host}:{self.get_port()}")
+ stub = health_pb2_grpc.HealthStub(channel)
+ pb = HealthCheckRequest(service="")
+ try:
+ result = stub.Check(pb)
+ return result.status == 1
+ except: # noqa
+ return False
+
+
+class NidxImage(BaseImage):
+ name = "nidx"
+
+
+nucliadb_node_1_reader = nucliadbNodeReader()
+nucliadb_node_1_writer = nucliadbNodeWriter()
+nucliadb_node_1_sidecar = nucliadbNodeSidecar()
+
+nucliadb_node_2_reader = nucliadbNodeReader()
+nucliadb_node_2_writer = nucliadbNodeWriter()
+nucliadb_node_2_sidecar = nucliadbNodeSidecar()
+
+
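+# Storage backend configurations: each variant returns the environment variables the node containers expect.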
+@dataclasses.dataclass
+class NodeS3Storage:
+ server: str
+
+ def envs(self):
+ return {
+ "FILE_BACKEND": "s3",
+ "S3_CLIENT_ID": "fake",
+ "S3_CLIENT_SECRET": "fake",
+ "S3_BUCKET": "test",
+ "S3_INDEXING_BUCKET": "indexing",
+ "S3_DEADLETTER_BUCKET": "deadletter",
+ "S3_ENDPOINT": self.server,
+ }
+
+
+@dataclasses.dataclass
+class NodeGCSStorage:
+ server: str
+
+ def envs(self):
+ return {
+ "FILE_BACKEND": "gcs",
+ "GCS_BUCKET": "test",
+ "GCS_INDEXING_BUCKET": "indexing",
+ "GCS_DEADLETTER_BUCKET": "deadletter",
+ "GCS_ENDPOINT_URL": self.server,
+ }
+
+
+NodeStorage = Union[NodeGCSStorage, NodeS3Storage]
+
+
+class _NodeRunner:
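+    """Start and manage a two-node nucliadb cluster (writer, reader and sidecar containers) for tests."""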
+ def __init__(self, natsd, storage: NodeStorage):
+ self.docker_client = docker.from_env(version=BaseImage.docker_version)
+ self.natsd = natsd
+ self.storage = storage
+ self.data = {} # type: ignore
+
+ def start(self):
+ docker_platform_name = self.docker_client.api.version()["Platform"]["Name"].upper()
+ if "GITHUB_ACTION" in os.environ:
+ # Valid when using github actions
+ docker_internal_host = "172.17.0.1"
+ elif docker_platform_name == "DOCKER ENGINE - COMMUNITY":
+ # for linux users
+ docker_internal_host = "172.17.0.1"
+ elif "DOCKER DESKTOP" in docker_platform_name:
+ # Valid when using Docker desktop
+ docker_internal_host = "host.docker.internal"
+ else:
+ docker_internal_host = "172.17.0.1"
+
+ self.volume_node_1 = self.docker_client.volumes.create(driver="local")
+ self.volume_node_2 = self.docker_client.volumes.create(driver="local")
+
+ self.storage.server = self.storage.server.replace("localhost", docker_internal_host)
+ images.settings["nucliadb_node_writer"]["env"].update(self.storage.envs())
+ writer1_host, writer1_port = nucliadb_node_1_writer.run(self.volume_node_1)
+ writer2_host, writer2_port = nucliadb_node_2_writer.run(self.volume_node_2)
+
+ reader1_host, reader1_port = nucliadb_node_1_reader.run(self.volume_node_1)
+ reader2_host, reader2_port = nucliadb_node_2_reader.run(self.volume_node_2)
+
+ natsd_server = self.natsd.replace("localhost", docker_internal_host)
+ images.settings["nucliadb_node_sidecar"]["env"].update(
+ {
+ "INDEX_JETSTREAM_SERVERS": f'["{natsd_server}"]',
+ "CACHE_PUBSUB_NATS_URL": f'["{natsd_server}"]',
+ "READER_LISTEN_ADDRESS": f"{docker_internal_host}:{reader1_port}",
+ "WRITER_LISTEN_ADDRESS": f"{docker_internal_host}:{writer1_port}",
+ }
+ )
+ images.settings["nucliadb_node_sidecar"]["env"].update(self.storage.envs())
+
+ sidecar1_host, sidecar1_port = nucliadb_node_1_sidecar.run(self.volume_node_1)
+
+ images.settings["nucliadb_node_sidecar"]["env"]["READER_LISTEN_ADDRESS"] = (
+ f"{docker_internal_host}:{reader2_port}"
+ )
+ images.settings["nucliadb_node_sidecar"]["env"]["WRITER_LISTEN_ADDRESS"] = (
+ f"{docker_internal_host}:{writer2_port}"
+ )
+
+ sidecar2_host, sidecar2_port = nucliadb_node_2_sidecar.run(self.volume_node_2)
+
+ writer1_internal_host = get_container_host(nucliadb_node_1_writer.container_obj)
+ writer2_internal_host = get_container_host(nucliadb_node_2_writer.container_obj)
+
+ self.data.update(
+ {
+ "writer1_internal_host": writer1_internal_host,
+ "writer2_internal_host": writer2_internal_host,
+ "writer1": {
+ "host": writer1_host,
+ "port": writer1_port,
+ },
+ "writer2": {
+ "host": writer2_host,
+ "port": writer2_port,
+ },
+ "reader1": {
+ "host": reader1_host,
+ "port": reader1_port,
+ },
+ "reader2": {
+ "host": reader2_host,
+ "port": reader2_port,
+ },
+ "sidecar1": {
+ "host": sidecar1_host,
+ "port": sidecar1_port,
+ },
+ "sidecar2": {
+ "host": sidecar2_host,
+ "port": sidecar2_port,
+ },
+ }
+ )
+ return self.data
+
+ def stop(self):
+ container_ids = []
+ for component in [
+ nucliadb_node_1_reader,
+ nucliadb_node_1_writer,
+ nucliadb_node_1_sidecar,
+ nucliadb_node_2_writer,
+ nucliadb_node_2_reader,
+ nucliadb_node_2_sidecar,
+ ]:
+ container_obj = getattr(component, "container_obj", None)
+ if container_obj:
+ container_ids.append(container_obj.id)
+ component.stop()
+
+ for container_id in container_ids:
+ for _ in range(5):
+ try:
+ self.docker_client.containers.get(container_id)
+ except docker.errors.NotFound:
+ break
+ time.sleep(2)
+
+ self.volume_node_1.remove()
+ self.volume_node_2.remove()
+
+ def setup_env(self):
+ # reset on every test run in case something touches it
+ cluster_settings.writer_port_map = {
+ self.data["writer1_internal_host"]: self.data["writer1"]["port"],
+ self.data["writer2_internal_host"]: self.data["writer2"]["port"],
+ }
+ cluster_settings.reader_port_map = {
+ self.data["writer1_internal_host"]: self.data["reader1"]["port"],
+ self.data["writer2_internal_host"]: self.data["reader2"]["port"],
+ }
+
+ cluster_settings.node_writer_port = None # type: ignore
+ cluster_settings.node_reader_port = None # type: ignore
+
+ cluster_settings.cluster_discovery_mode = "manual"
+ cluster_settings.cluster_discovery_manual_addresses = [
+ self.data["writer1_internal_host"],
+ self.data["writer2_internal_host"],
+ ]
+
+
+@pytest.fixture(scope="session")
+def gcs_node_storage(gcs):
+ return NodeGCSStorage(server=gcs)
+
+
+@pytest.fixture(scope="session")
+def s3_node_storage(s3):
+ return NodeS3Storage(server=s3)
+
+
+@pytest.fixture(scope="session")
+def node_storage(request):
+ backend = get_testing_storage_backend()
+ if backend == "gcs":
+ return request.getfixturevalue("gcs_node_storage")
+ elif backend == "s3":
+ return request.getfixturevalue("s3_node_storage")
+ else:
+ print(f"Unknown storage backend {backend}, using gcs")
+ return request.getfixturevalue("gcs_node_storage")
+
+
+@pytest.fixture(scope="session")
+def gcs_nidx_storage(gcs):
+ return {
+ "INDEXER__OBJECT_STORE": "gcs",
+ "INDEXER__BUCKET": "indexing",
+ "INDEXER__ENDPOINT": gcs,
+ "STORAGE__OBJECT_STORE": "gcs",
+ "STORAGE__ENDPOINT": gcs,
+ "STORAGE__BUCKET": "nidx",
+ }
+
+
+@pytest.fixture(scope="session")
+def s3_nidx_storage(s3):
+ return {
+ "INDEXER__OBJECT_STORE": "s3",
+ "INDEXER__BUCKET": "indexing",
+ "INDEXER__ENDPOINT": s3,
+ "STORAGE__OBJECT_STORE": "s3",
+ "STORAGE__ENDPOINT": s3,
+ "STORAGE__BUCKET": "nidx",
+ }
+
+
+@pytest.fixture(scope="session")
+def nidx_storage(request):
+ backend = get_testing_storage_backend()
+ if backend == "gcs":
+ return request.getfixturevalue("gcs_nidx_storage")
+ elif backend == "s3":
+ return request.getfixturevalue("s3_nidx_storage")
+
+
+@pytest.fixture(scope="session", autouse=False)
+def _node(natsd: str, node_storage):
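+    # Session-scoped: the cluster is started once, shared by all tests and torn down at the end of the session.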
+ nr = _NodeRunner(natsd, node_storage)
+ try:
+ cluster_info = nr.start()
+ except Exception:
+ nr.stop()
+ raise
+ nr.setup_env()
+ yield cluster_info
+ nr.stop()
+
+
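+# nidx fixture: creates the NATS stream/consumer nidx reads from and runs the nidx container (no-op when nidx is disabled).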
+@pytest.fixture(scope="session")
+async def _nidx(natsd, nidx_storage, pg):
+ if not NIDX_ENABLED:
+ yield
+ return
+
+ # Create needed NATS stream/consumer
+ nc = await nats.connect(servers=[natsd])
+ js = nc.jetstream()
+ await js.add_stream(name="nidx", subjects=["nidx"])
+ await js.add_consumer(stream="nidx", config=ConsumerConfig(name="nidx"))
+ await nc.drain()
+ await nc.close()
+
+ # Run nidx
+ images.settings["nidx"]["env"] = {
+ "RUST_LOG": "info",
+ "METADATA__DATABASE_URL": f"postgresql://postgres:postgres@172.17.0.1:{pg[1]}/postgres",
+ "INDEXER__NATS_SERVER": natsd.replace("localhost", "172.17.0.1"),
+ **nidx_storage,
+ }
+ image = NidxImage()
+ image.run()
+
+ api_port = image.get_port(10000)
+ searcher_port = image.get_port(10001)
+
+ # Configure settings
+ from nucliadb_utils.settings import indexing_settings
+
+ cluster_settings.nidx_api_address = f"localhost:{api_port}"
+ cluster_settings.nidx_searcher_address = f"localhost:{searcher_port}"
+ indexing_settings.index_nidx_subject = "nidx"
+
+ yield
+
+ image.stop()
+
+
+@pytest.fixture(scope="function")
+def node(_nidx, _node, request):
+ # clean up all shard data before each test
+ channel1 = insecure_channel(f"{_node['writer1']['host']}:{_node['writer1']['port']}")
+ channel2 = insecure_channel(f"{_node['writer2']['host']}:{_node['writer2']['port']}")
+ writer1 = NodeWriterStub(channel1)
+ writer2 = NodeWriterStub(channel2)
+
+ logger.debug("cleaning up shards data")
+ try:
+ cleanup_node(writer1)
+ cleanup_node(writer2)
+ except Exception:
+ logger.error(
+ "Error cleaning up shards data. Maybe the node fixture could not start properly?",
+ exc_info=True,
+ )
+
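+        # Cleanup failed: dump running containers and flag any shared host ports before re-raising.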
+ client = docker.client.from_env()
+ containers_by_port = {}
+ for container in client.containers.list():
+ name = container.name
+ command = container.attrs["Config"]["Cmd"]
+ ports = container.ports
+ print(f"container {name} executing {command} is using ports: {ports}")
+
+ for internal_port in container.ports:
+ for host in container.ports[internal_port]:
+ port = host["HostPort"]
+ port_containers = containers_by_port.setdefault(port, [])
+ if container not in port_containers:
+ port_containers.append(container)
+
+ for port, containers in containers_by_port.items():
+ if len(containers) > 1:
+ names = ", ".join([container.name for container in containers])
+                print(f"ATTENTION! Containers {names} share port {port}!")
+ raise
+ finally:
+ channel1.close()
+ channel2.close()
+
+ yield _node
+
+
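+# Retried with exponential backoff (up to 5 attempts) in case the node is not yet ready to serve requests.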
+@backoff.on_exception(backoff.expo, Exception, jitter=backoff.random_jitter, max_tries=5)
+def cleanup_node(writer: NodeWriterStub):
+ for shard in writer.ListShards(EmptyQuery()).ids:
+ writer.DeleteShard(ShardId(id=shard.id))
diff --git a/nucliadb/tests/search/fixtures.py b/nucliadb/tests/ndbfixtures/search.py
similarity index 58%
rename from nucliadb/tests/search/fixtures.py
rename to nucliadb/tests/ndbfixtures/search.py
index d79d623186..15ff2bdb96 100644
--- a/nucliadb/tests/search/fixtures.py
+++ b/nucliadb/tests/ndbfixtures/search.py
@@ -18,134 +18,100 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
-from enum import Enum
-from typing import AsyncIterable, Optional
+import datetime
+from typing import AsyncIterable
+from unittest.mock import patch
import pytest
-from httpx import AsyncClient
-from redis import asyncio as aioredis
+from nucliadb.common.cluster import manager
from nucliadb.common.cluster.manager import KBShardManager, get_index_node
+from nucliadb.common.maindb.driver import Driver
from nucliadb.common.maindb.utils import get_driver
from nucliadb.common.nidx import get_nidx_api_client
from nucliadb.ingest.cache import clear_ingest_cache
-from nucliadb.search import API_PREFIX
+from nucliadb.ingest.settings import settings as ingest_settings
+from nucliadb.search.app import application
from nucliadb.search.predict import DummyPredictEngine
+from nucliadb_models.resource import NucliaDBRoles
from nucliadb_protos.nodereader_pb2 import GetShardRequest
from nucliadb_protos.noderesources_pb2 import Shard
-from nucliadb_utils.settings import nuclia_settings
+from nucliadb_utils.cache.settings import settings as cache_settings
+from nucliadb_utils.settings import (
+ nuclia_settings,
+ nucliadb_settings,
+ running_settings,
+)
+from nucliadb_utils.storages.storage import Storage
from nucliadb_utils.tests import free_port
+from nucliadb_utils.transaction import TransactionUtility
from nucliadb_utils.utilities import (
Utility,
- clean_utility,
clear_global_cache,
- get_utility,
- set_utility,
)
from tests.ingest.fixtures import broker_resource
+from tests.ndbfixtures.utils import create_api_client_factory, global_utility
-
-@pytest.fixture(scope="function")
-def test_settings_search(storage, natsd, node, maindb_driver): # type: ignore
- from nucliadb.ingest.settings import settings as ingest_settings
- from nucliadb_utils.cache.settings import settings as cache_settings
- from nucliadb_utils.settings import (
- nuclia_settings,
- nucliadb_settings,
- running_settings,
- )
-
- cache_settings.cache_pubsub_nats_url = [natsd]
-
- running_settings.debug = False
-
- ingest_settings.disable_pull_worker = True
-
- ingest_settings.nuclia_partitions = 1
-
- nuclia_settings.dummy_processing = True
- nuclia_settings.dummy_predict = True
- nuclia_settings.dummy_learning_services = True
-
- ingest_settings.grpc_port = free_port()
-
- nucliadb_settings.nucliadb_ingest = f"localhost:{ingest_settings.grpc_port}"
+# Main fixtures
@pytest.fixture(scope="function")
-async def dummy_predict() -> AsyncIterable[DummyPredictEngine]:
- original_setting = nuclia_settings.dummy_predict
- nuclia_settings.dummy_predict = True
-
- predict_util = DummyPredictEngine()
- await predict_util.initialize()
- original_predict = get_utility(Utility.PREDICT)
- set_utility(Utility.PREDICT, predict_util)
-
- yield predict_util
-
- nuclia_settings.dummy_predict = original_setting
-
- if original_predict is None:
- clean_utility(Utility.PREDICT)
- else:
- set_utility(Utility.PREDICT, original_predict)
+async def cluster_nucliadb_search(
+ storage: Storage,
+ nats_server: str,
+ node,
+ maindb_driver: Driver,
+ transaction_utility: TransactionUtility,
+):
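+    # Settings are patched rather than mutated so every value is restored when the fixture exits.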
+ with (
+ patch.object(cache_settings, "cache_pubsub_nats_url", [nats_server]),
+ patch.object(running_settings, "debug", False),
+ patch.object(ingest_settings, "disable_pull_worker", True),
+ patch.object(ingest_settings, "nuclia_partitions", 1),
+ patch.object(nuclia_settings, "dummy_processing", True),
+ patch.object(nuclia_settings, "dummy_predict", True),
+ patch.object(nuclia_settings, "dummy_learning_services", True),
+ patch.object(ingest_settings, "grpc_port", free_port()),
+ patch.object(nucliadb_settings, "nucliadb_ingest", f"localhost:{ingest_settings.grpc_port}"),
+ patch.dict(manager.INDEX_NODES, clear=True),
+ ):
+ async with application.router.lifespan_context(application):
+            # Make sure the cluster is ready: wait until both index nodes have registered
+ delay = 0.1
+ timeout = datetime.timedelta(seconds=30)
+ start = datetime.datetime.now()
+
+ await asyncio.sleep(delay)
+ while len(manager.INDEX_NODES) < 2:
+                print("awaiting cluster nodes - ndbfixtures/search.py")
+ await asyncio.sleep(delay)
+ if (datetime.datetime.now() - start) > timeout:
+ raise Exception("No cluster")
+
+ client_factory = create_api_client_factory(application)
+ async with client_factory(roles=[NucliaDBRoles.READER]) as client:
+ yield client
+
+ # Make sure nodes can sync
+ await asyncio.sleep(delay)
+ # TODO: fix this awful global state manipulation
+ clear_ingest_cache()
+ clear_global_cache()
+
+
+# Remaining fixtures, TODO: keep cleaning these up
@pytest.fixture(scope="function")
-async def search_api(test_settings_search, transaction_utility, redis): # type: ignore
- from nucliadb.common.cluster import manager
- from nucliadb.search.app import application
-
- def make_client_fixture(
- roles: Optional[list[Enum]] = None,
- user: str = "",
- version: str = "1",
- root: bool = False,
- extra_headers: Optional[dict[str, str]] = None,
- ) -> AsyncClient:
- roles = roles or []
- client_base_url = "http://test"
-
- if root is False:
- client_base_url = f"{client_base_url}/{API_PREFIX}/v{version}"
-
- client = AsyncClient(app=application, base_url=client_base_url)
- client.headers["X-NUCLIADB-ROLES"] = ";".join([role.value for role in roles])
- client.headers["X-NUCLIADB-USER"] = user
-
- extra_headers = extra_headers or {}
- if len(extra_headers) == 0:
- return client
-
- for header, value in extra_headers.items():
- client.headers[f"{header}"] = value
-
- return client
-
- driver = aioredis.from_url(f"redis://{redis[0]}:{redis[1]}")
- await driver.flushall()
-
- async with application.router.lifespan_context(application):
- # Make sure is clean
- await asyncio.sleep(1)
- count = 0
- while len(manager.INDEX_NODES) < 2:
- print("awaiting cluster nodes - search fixtures.py")
- await asyncio.sleep(1)
- if count == 40:
- raise Exception("No cluster")
- count += 1
-
- yield make_client_fixture
+async def dummy_predict() -> AsyncIterable[DummyPredictEngine]:
+ with (
+ patch.object(nuclia_settings, "dummy_predict", True),
+ ):
+ predict_util = DummyPredictEngine()
+ await predict_util.initialize()
- # Make sure nodes can sync
- await asyncio.sleep(1)
- await driver.flushall()
- await driver.close(close_connection_pool=True)
- clear_ingest_cache()
- clear_global_cache()
- manager.INDEX_NODES.clear()
+ with global_utility(Utility.PREDICT, predict_util):
+ yield predict_util
@pytest.fixture(scope="function")
@@ -238,3 +204,13 @@ async def wait_for_shard(knowledgebox_ingest: str, count: int) -> str:
# Wait an extra couple of seconds for reader/searcher to catch up
await asyncio.sleep(2)
return knowledgebox_ingest
+
+
+# Dependencies from tests/fixtures.py
+
+
+@pytest.fixture(scope="function")
+async def txn(maindb_driver):
+ async with maindb_driver.transaction() as txn:
+ yield txn
+ await txn.abort()
diff --git a/nucliadb/tests/ndbfixtures/utils.py b/nucliadb/tests/ndbfixtures/utils.py
index f360344d36..5c66c9fadb 100644
--- a/nucliadb/tests/ndbfixtures/utils.py
+++ b/nucliadb/tests/ndbfixtures/utils.py
@@ -18,13 +18,19 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
+import logging
+from contextlib import contextmanager
from enum import Enum
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
+from unittest.mock import patch
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from nucliadb.search import API_PREFIX
+from nucliadb_utils.utilities import MAIN
+
+logger = logging.getLogger("fixtures.utils")
def create_api_client_factory(application: FastAPI) -> Callable[..., AsyncClient]:
@@ -56,3 +62,21 @@ def _make_client_fixture(
return client
return _make_client_fixture
+
+
+@contextmanager
+def global_utility(name: str, util: Any):
+    """Hacky set_utility used in tests to provide proper setup/cleanup of utilities.
+
+    Tests can sometimes mess with global state. While fixtures add/remove global
+    utilities, component lifecycles do the same, so utilities can be left dirty
+    or get overwritten. This context manager lets tests set a utility and
+    restore whatever was previously set once the context exits.
+    """
+
+ if name in MAIN:
+ logger.warning(f"Overwriting previously set utility {name}: {MAIN[name]} with {util}")
+
+ with patch.dict(MAIN, values={name: util}, clear=False):
+ yield util
diff --git a/nucliadb/tests/search/conftest.py b/nucliadb/tests/search/conftest.py
index 59c6a57c1f..b4d446ba47 100644
--- a/nucliadb/tests/search/conftest.py
+++ b/nucliadb/tests/search/conftest.py
@@ -19,13 +19,12 @@
#
pytest_plugins = [
"pytest_docker_fixtures",
- "tests.fixtures",
"tests.ndbfixtures.maindb",
"tests.ndbfixtures.processing",
"tests.ndbfixtures.standalone",
"tests.ingest.fixtures", # should be refactored out
- "tests.search.node",
- "tests.search.fixtures",
+ "tests.ndbfixtures.node",
+ "tests.ndbfixtures.search",
"nucliadb_utils.tests.fixtures",
"nucliadb_utils.tests.gcs",
"nucliadb_utils.tests.s3",
diff --git a/nucliadb/tests/search/integration/api/v1/test_ask_audit.py b/nucliadb/tests/search/integration/api/v1/test_ask_audit.py
index 49745d4f11..3122d7e5d7 100644
--- a/nucliadb/tests/search/integration/api/v1/test_ask_audit.py
+++ b/nucliadb/tests/search/integration/api/v1/test_ask_audit.py
@@ -17,7 +17,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
-from typing import Callable
import nats
from httpx import AsyncClient
@@ -25,7 +24,6 @@
from nats.js import JetStreamContext
from nucliadb.search.api.v1.router import KB_PREFIX
-from nucliadb_models.resource import NucliaDBRoles
from nucliadb_protos.audit_pb2 import AuditRequest
@@ -37,11 +35,11 @@ async def get_audit_messages(sub):
async def test_ask_sends_only_one_audit(
- search_api: Callable[..., AsyncClient], multiple_search_resource: str, stream_audit
+ cluster_nucliadb_search: AsyncClient, test_search_resource: str, stream_audit
) -> None:
from nucliadb_utils.settings import audit_settings
- kbid = multiple_search_resource
+ kbid = test_search_resource
# Prepare a test audit stream to receive our messages
partition = stream_audit.get_partition(kbid)
@@ -61,12 +59,11 @@ async def test_ask_sends_only_one_audit(
psub = await jetstream.pull_subscribe(subject, "psub")
- async with search_api(roles=[NucliaDBRoles.READER]) as client:
- resp = await client.post(
- f"/{KB_PREFIX}/{kbid}/ask",
- json={"query": "title"},
- )
- assert resp.status_code == 200
+ resp = await cluster_nucliadb_search.post(
+ f"/{KB_PREFIX}/{kbid}/ask",
+ json={"query": "title"},
+ )
+ assert resp.status_code == 200
# Testing the middleware integration where it collects audit calls and sends a single message
# at requests ends. In this case we expect one seach and one chat sent once
diff --git a/nucliadb/tests/search/integration/api/v1/test_find.py b/nucliadb/tests/search/integration/api/v1/test_find.py
index 7cd41895fc..f327910c7d 100644
--- a/nucliadb/tests/search/integration/api/v1/test_find.py
+++ b/nucliadb/tests/search/integration/api/v1/test_find.py
@@ -17,35 +17,32 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
-from typing import Callable
from httpx import AsyncClient
from nucliadb.search.api.v1.router import KB_PREFIX
-from nucliadb_models.resource import NucliaDBRoles
-async def test_find(search_api: Callable[..., AsyncClient], multiple_search_resource: str) -> None:
+async def test_find(cluster_nucliadb_search: AsyncClient, multiple_search_resource: str) -> None:
kbid = multiple_search_resource
- async with search_api(roles=[NucliaDBRoles.READER]) as client:
- resp = await client.get(
- f"/{KB_PREFIX}/{kbid}/find?query=own+text",
- )
- assert resp.status_code == 200
-
- data = resp.json()
-
- # TODO: uncomment when we have the total stable in tests
- # assert data["total"] == 65
-
- res = next(iter(data["resources"].values()))
- para = next(iter(res["fields"]["/f/file"]["paragraphs"].values()))
- assert para["position"] == {
- "page_number": 0,
- "index": 0,
- "start": 0,
- "end": 45,
- "start_seconds": [0],
- "end_seconds": [10],
- }
+ resp = await cluster_nucliadb_search.get(
+ f"/{KB_PREFIX}/{kbid}/find?query=own+text",
+ )
+ assert resp.status_code == 200
+
+ data = resp.json()
+
+ # TODO: uncomment when we have the total stable in tests
+ # assert data["total"] == 65
+
+ res = next(iter(data["resources"].values()))
+ para = next(iter(res["fields"]["/f/file"]["paragraphs"].values()))
+ assert para["position"] == {
+ "page_number": 0,
+ "index": 0,
+ "start": 0,
+ "end": 45,
+ "start_seconds": [0],
+ "end_seconds": [10],
+ }
diff --git a/nucliadb/tests/search/integration/api/v1/test_search.py b/nucliadb/tests/search/integration/api/v1/test_search.py
index 3f5b1f6045..f47e2fd8e8 100644
--- a/nucliadb/tests/search/integration/api/v1/test_search.py
+++ b/nucliadb/tests/search/integration/api/v1/test_search.py
@@ -19,7 +19,6 @@
#
import asyncio
import os
-from typing import Callable
import pytest
from httpx import AsyncClient
@@ -29,7 +28,6 @@
from nucliadb.common.maindb.utils import get_driver
from nucliadb.search.api.v1.router import KB_PREFIX
from nucliadb.tests.vectors import Q
-from nucliadb_models.resource import NucliaDBRoles
from nucliadb_protos.nodereader_pb2 import (
SearchRequest,
)
@@ -40,50 +38,48 @@
@pytest.mark.flaky(reruns=5)
async def test_multiple_fuzzy_search_resource_all(
- search_api: Callable[..., AsyncClient], multiple_search_resource: str
+ cluster_nucliadb_search: AsyncClient, multiple_search_resource: str
) -> None:
kbid = multiple_search_resource
- async with search_api(roles=[NucliaDBRoles.READER]) as client:
- resp = await client.get(
- f'/{KB_PREFIX}/{kbid}/search?query=own+test+"This is great"&highlight=true&top_k=20',
- )
+ resp = await cluster_nucliadb_search.get(
+ f'/{KB_PREFIX}/{kbid}/search?query=own+test+"This is great"&highlight=true&top_k=20',
+ )
- assert resp.status_code == 200, resp.content
- assert len(resp.json()["paragraphs"]["results"]) == 20
+ assert resp.status_code == 200, resp.content
+ assert len(resp.json()["paragraphs"]["results"]) == 20
- # Expected results:
- # - 'text' should not be highlighted as we are searching by 'test' in the query
- # - 'This is great' should be highlighted because it is an exact query search
- # - 'own' should not be highlighted because it is considered as a stop-word
- assert (
- resp.json()["paragraphs"]["results"][0]["text"]
- == "My own text Ramon. This is great to be here. "
- )
+ # Expected results:
+ # - 'text' should not be highlighted as we are searching by 'test' in the query
+ # - 'This is great' should be highlighted because it is an exact query search
+ # - 'own' should not be highlighted because it is considered as a stop-word
+ assert (
+ resp.json()["paragraphs"]["results"][0]["text"]
+ == "My own text Ramon. This is great to be here. "
+ )
@pytest.mark.flaky(reruns=3)
async def test_search_resource_all(
- search_api: Callable[..., AsyncClient],
+ cluster_nucliadb_search: AsyncClient,
test_search_resource: str,
) -> None:
kbid = test_search_resource
- async with search_api(roles=[NucliaDBRoles.READER]) as client:
- await asyncio.sleep(1)
- resp = await client.get(
- f"/{KB_PREFIX}/{kbid}/search?query=own+text&split=true&highlight=true&text_resource=true",
- )
- assert resp.status_code == 200
- assert resp.json()["fulltext"]["query"] == "own text"
- assert resp.json()["paragraphs"]["query"] == "own text"
- assert resp.json()["paragraphs"]["results"][0]["start_seconds"] == [0]
- assert resp.json()["paragraphs"]["results"][0]["end_seconds"] == [10]
- assert (
- resp.json()["paragraphs"]["results"][0]["text"]
- == "My own text Ramon. This is great to be here. "
- )
- assert len(resp.json()["resources"]) == 1
- assert len(resp.json()["sentences"]["results"]) == 1
+ await asyncio.sleep(1)
+ resp = await cluster_nucliadb_search.get(
+ f"/{KB_PREFIX}/{kbid}/search?query=own+text&split=true&highlight=true&text_resource=true",
+ )
+ assert resp.status_code == 200
+ assert resp.json()["fulltext"]["query"] == "own text"
+ assert resp.json()["paragraphs"]["query"] == "own text"
+ assert resp.json()["paragraphs"]["results"][0]["start_seconds"] == [0]
+ assert resp.json()["paragraphs"]["results"][0]["end_seconds"] == [10]
+ assert (
+ resp.json()["paragraphs"]["results"][0]["text"]
+ == "My own text Ramon. This is great to be here. "
+ )
+ assert len(resp.json()["resources"]) == 1
+ assert len(resp.json()["sentences"]["results"]) == 1
# get shards ids
@@ -135,33 +131,24 @@ async def test_search_resource_all(
async def test_search_with_facets(
- search_api: Callable[..., AsyncClient], multiple_search_resource: str
+ cluster_nucliadb_search: AsyncClient, multiple_search_resource: str
) -> None:
kbid = multiple_search_resource
- async with search_api(roles=[NucliaDBRoles.READER]) as client:
- url = f"/{KB_PREFIX}/{kbid}/search?query=own+text&faceted=/classification.labels"
-
- resp = await client.get(url)
- data = resp.json()
- assert (
- data["fulltext"]["facets"]["/classification.labels"]["/classification.labels/labelset1"]
- == 25
- )
- assert (
- data["paragraphs"]["facets"]["/classification.labels"]["/classification.labels/labelset1"]
- == 25
- )
-
- # also just test short hand filter
- url = f"/{KB_PREFIX}/{kbid}/search?query=own+text&faceted=/l"
- resp = await client.get(url)
- data = resp.json()
- assert (
- data["fulltext"]["facets"]["/classification.labels"]["/classification.labels/labelset1"]
- == 25
- )
- assert (
- data["paragraphs"]["facets"]["/classification.labels"]["/classification.labels/labelset1"]
- == 25
- )
+ url = f"/{KB_PREFIX}/{kbid}/search?query=own+text&faceted=/classification.labels"
+
+ resp = await cluster_nucliadb_search.get(url)
+ data = resp.json()
+ assert data["fulltext"]["facets"]["/classification.labels"]["/classification.labels/labelset1"] == 25
+ assert (
+ data["paragraphs"]["facets"]["/classification.labels"]["/classification.labels/labelset1"] == 25
+ )
+
+ # also just test short hand filter
+ url = f"/{KB_PREFIX}/{kbid}/search?query=own+text&faceted=/l"
+ resp = await cluster_nucliadb_search.get(url)
+ data = resp.json()
+ assert data["fulltext"]["facets"]["/classification.labels"]["/classification.labels/labelset1"] == 25
+ assert (
+ data["paragraphs"]["facets"]["/classification.labels"]["/classification.labels/labelset1"] == 25
+ )
diff --git a/nucliadb/tests/search/integration/api/v1/test_suggest.py b/nucliadb/tests/search/integration/api/v1/test_suggest.py
index 10574f6e99..31bc0b637f 100644
--- a/nucliadb/tests/search/integration/api/v1/test_suggest.py
+++ b/nucliadb/tests/search/integration/api/v1/test_suggest.py
@@ -17,7 +17,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
-from typing import Callable
import pytest
from httpx import AsyncClient
@@ -26,24 +25,22 @@
from nucliadb.common.datamanagers.cluster import KB_SHARDS
from nucliadb.common.maindb.utils import get_driver
from nucliadb.search.api.v1.router import KB_PREFIX
-from nucliadb_models.resource import NucliaDBRoles
from nucliadb_protos.nodereader_pb2 import SuggestFeatures, SuggestRequest
from nucliadb_protos.writer_pb2 import Shards as PBShards
@pytest.mark.flaky(reruns=5)
async def test_suggest_resource_all(
- search_api: Callable[..., AsyncClient], test_search_resource: str
+ cluster_nucliadb_search: AsyncClient, test_search_resource: str
) -> None:
kbid = test_search_resource
- async with search_api(roles=[NucliaDBRoles.READER]) as client:
- resp = await client.get(
- f"/{KB_PREFIX}/{kbid}/suggest?query=own+text",
- )
- assert resp.status_code == 200
- paragraph_results = resp.json()["paragraphs"]["results"]
- assert len(paragraph_results) == 1
+ resp = await cluster_nucliadb_search.get(
+ f"/{KB_PREFIX}/{kbid}/suggest?query=own+text",
+ )
+ assert resp.status_code == 200
+ paragraph_results = resp.json()["paragraphs"]["results"]
+ assert len(paragraph_results) == 1
# get shards ids
diff --git a/nucliadb/tests/search/integration/requesters/test_utils.py b/nucliadb/tests/search/integration/requesters/test_utils.py
index 2009e6f9ba..147ee8efdb 100644
--- a/nucliadb/tests/search/integration/requesters/test_utils.py
+++ b/nucliadb/tests/search/integration/requesters/test_utils.py
@@ -17,7 +17,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-from typing import Callable
import pytest
from httpx import AsyncClient
@@ -35,9 +34,9 @@
@pytest.mark.xfail # pulling start/end position for vectors results needs to be fixed
async def test_vector_result_metadata(
- search_api: Callable[..., AsyncClient], multiple_search_resource: str
+ cluster_nucliadb_search: AsyncClient, test_search_resource: str
) -> None:
- kbid = multiple_search_resource
+ kbid = test_search_resource
pb_query, _, _ = await QueryParser(
kbid=kbid,
diff --git a/nucliadb/tests/search/node.py b/nucliadb/tests/search/node.py
deleted file mode 100644
index 99dfe42d24..0000000000
--- a/nucliadb/tests/search/node.py
+++ /dev/null
@@ -1,556 +0,0 @@
-# Copyright (C) 2021 Bosutech XXI S.L.
-#
-# nucliadb is offered under the AGPL v3.0 and as commercial software.
-# For commercial licensing, contact us at info@nuclia.com.
-#
-# AGPL:
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see <http://www.gnu.org/licenses/>.
-#
-
-import dataclasses
-import logging
-import os
-import time
-from typing import Union
-
-import backoff
-import docker # type: ignore
-import nats
-import pytest
-from grpc import insecure_channel
-from grpc_health.v1 import health_pb2_grpc
-from grpc_health.v1.health_pb2 import HealthCheckRequest
-from nats.js.api import ConsumerConfig
-from pytest_docker_fixtures import images # type: ignore
-from pytest_docker_fixtures.containers._base import BaseImage # type: ignore
-
-from nucliadb.common.cluster.settings import settings as cluster_settings
-from nucliadb.common.nidx import NIDX_ENABLED
-from nucliadb_protos.nodewriter_pb2 import EmptyQuery, ShardId
-from nucliadb_protos.nodewriter_pb2_grpc import NodeWriterStub
-from nucliadb_utils.tests.fixtures import get_testing_storage_backend
-
-logger = logging.getLogger(__name__)
-
-images.settings["nucliadb_node_reader"] = {
- "image": "europe-west4-docker.pkg.dev/nuclia-internal/nuclia/node",
- "version": "latest",
- "env": {
- "FILE_BACKEND": "unset",
- "HOST_KEY_PATH": "/data/node.key",
- "DATA_PATH": "/data",
- "READER_LISTEN_ADDRESS": "0.0.0.0:4445",
- "NUCLIADB_DISABLE_ANALYTICS": "True",
- "RUST_BACKTRACE": "full",
- "RUST_LOG": "nucliadb_*=DEBUG",
- },
- "options": {
- "command": [
- "/usr/local/bin/node_reader",
- ],
- "ports": {"4445": ("0.0.0.0", 0)},
- "publish_all_ports": False,
- "mem_limit": "3g", # default is 1g, need to override
- "platform": "linux/amd64",
- },
-}
-
-images.settings["nucliadb_node_writer"] = {
- "image": "europe-west4-docker.pkg.dev/nuclia-internal/nuclia/node",
- "version": "latest",
- "env": {
- "FILE_BACKEND": "unset",
- "HOST_KEY_PATH": "/data/node.key",
- "DATA_PATH": "/data",
- "WRITER_LISTEN_ADDRESS": "0.0.0.0:4446",
- "NUCLIADB_DISABLE_ANALYTICS": "True",
- "RUST_BACKTRACE": "full",
- "RUST_LOG": "nucliadb_*=DEBUG",
- },
- "options": {
- "command": [
- "/usr/local/bin/node_writer",
- ],
- "ports": {"4446": ("0.0.0.0", 0)},
- "publish_all_ports": False,
- "mem_limit": "3g", # default is 1g, need to override
- "platform": "linux/amd64",
- },
-}
-
-images.settings["nucliadb_node_sidecar"] = {
- "image": "europe-west4-docker.pkg.dev/nuclia-internal/nuclia/node_sidecar",
- "version": "latest",
- "env": {
- "INDEX_JETSTREAM_SERVERS": "[]",
- "CACHE_PUBSUB_NATS_URL": "",
- "HOST_KEY_PATH": "/data/node.key",
- "DATA_PATH": "/data",
- "SIDECAR_LISTEN_ADDRESS": "0.0.0.0:4447",
- "READER_LISTEN_ADDRESS": "0.0.0.0:4445",
- "WRITER_LISTEN_ADDRESS": "0.0.0.0:4446",
- "PYTHONUNBUFFERED": "1",
- "LOG_LEVEL": "DEBUG",
- },
- "options": {
- "command": [
- "node_sidecar",
- ],
- "ports": {"4447": ("0.0.0.0", 0)},
- "publish_all_ports": False,
- "platform": "linux/amd64",
- },
-}
-
-images.settings["nidx"] = {
- "image": "nidx",
- "version": "latest",
- "env": {},
- "options": {
- # A few indexers on purpose for faster indexing
- "command": [
- "nidx",
- "api",
- "searcher",
- "indexer",
- "indexer",
- "indexer",
- "indexer",
- "scheduler",
- "worker",
- ],
- "ports": {"10000": ("0.0.0.0", 0), "10001": ("0.0.0.0", 0)},
- "publish_all_ports": False,
- "platform": "linux/amd64",
- },
-}
-
-
-def get_container_host(container_obj):
- return container_obj.attrs["NetworkSettings"]["IPAddress"]
-
-
-class nucliadbNodeReader(BaseImage):
- name = "nucliadb_node_reader"
- port = 4445
-
- def run(self, volume):
- self._volume = volume
- self._mount = "/data"
- return super(nucliadbNodeReader, self).run()
-
- def get_image_options(self):
- options = super(nucliadbNodeReader, self).get_image_options()
- options["volumes"] = {self._volume.name: {"bind": "/data"}}
- return options
-
- def check(self):
- channel = insecure_channel(f"{self.host}:{self.get_port()}")
- stub = health_pb2_grpc.HealthStub(channel)
- pb = HealthCheckRequest(service="nodereader.NodeReader")
- try:
- result = stub.Check(pb)
- return result.status == 1
- except: # noqa
- return False
-
-
-class nucliadbNodeWriter(BaseImage):
- name = "nucliadb_node_writer"
- port = 4446
-
- def run(self, volume):
- self._volume = volume
- self._mount = "/data"
- return super(nucliadbNodeWriter, self).run()
-
- def get_image_options(self):
- options = super(nucliadbNodeWriter, self).get_image_options()
- options["volumes"] = {self._volume.name: {"bind": "/data"}}
- return options
-
- def check(self):
- channel = insecure_channel(f"{self.host}:{self.get_port()}")
- stub = health_pb2_grpc.HealthStub(channel)
- pb = HealthCheckRequest(service="nodewriter.NodeWriter")
- try:
- result = stub.Check(pb)
- return result.status == 1
- except: # noqa
- return False
-
-
-class nucliadbNodeSidecar(BaseImage):
- name = "nucliadb_node_sidecar"
- port = 4447
-
- def run(self, volume):
- self._volume = volume
- self._mount = "/data"
- return super(nucliadbNodeSidecar, self).run()
-
- def get_image_options(self):
- options = super(nucliadbNodeSidecar, self).get_image_options()
- options["volumes"] = {self._volume.name: {"bind": "/data"}}
- return options
-
- def check(self):
- channel = insecure_channel(f"{self.host}:{self.get_port()}")
- stub = health_pb2_grpc.HealthStub(channel)
- pb = HealthCheckRequest(service="")
- try:
- result = stub.Check(pb)
- return result.status == 1
- except: # noqa
- return False
-
-
-class NidxImage(BaseImage):
- name = "nidx"
-
-
-nucliadb_node_1_reader = nucliadbNodeReader()
-nucliadb_node_1_writer = nucliadbNodeWriter()
-nucliadb_node_1_sidecar = nucliadbNodeSidecar()
-
-nucliadb_node_2_reader = nucliadbNodeReader()
-nucliadb_node_2_writer = nucliadbNodeWriter()
-nucliadb_node_2_sidecar = nucliadbNodeSidecar()
-
-
-@dataclasses.dataclass
-class NodeS3Storage:
- server: str
-
- def envs(self):
- return {
- "FILE_BACKEND": "s3",
- "S3_CLIENT_ID": "fake",
- "S3_CLIENT_SECRET": "fake",
- "S3_BUCKET": "test",
- "S3_INDEXING_BUCKET": "indexing",
- "S3_DEADLETTER_BUCKET": "deadletter",
- "S3_ENDPOINT": self.server,
- }
-
-
-@dataclasses.dataclass
-class NodeGCSStorage:
- server: str
-
- def envs(self):
- return {
- "FILE_BACKEND": "gcs",
- "GCS_BUCKET": "test",
- "GCS_INDEXING_BUCKET": "indexing",
- "GCS_DEADLETTER_BUCKET": "deadletter",
- "GCS_ENDPOINT_URL": self.server,
- }
-
-
-NodeStorage = Union[NodeGCSStorage, NodeS3Storage]
-
-
-class _NodeRunner:
- def __init__(self, natsd, storage: NodeStorage):
- self.docker_client = docker.from_env(version=BaseImage.docker_version)
- self.natsd = natsd
- self.storage = storage
- self.data = {} # type: ignore
-
- def start(self):
- docker_platform_name = self.docker_client.api.version()["Platform"]["Name"].upper()
- if "GITHUB_ACTION" in os.environ:
- # Valid when using github actions
- docker_internal_host = "172.17.0.1"
- elif docker_platform_name == "DOCKER ENGINE - COMMUNITY":
- # for linux users
- docker_internal_host = "172.17.0.1"
- elif "DOCKER DESKTOP" in docker_platform_name:
- # Valid when using Docker desktop
- docker_internal_host = "host.docker.internal"
- else:
- docker_internal_host = "172.17.0.1"
-
- self.volume_node_1 = self.docker_client.volumes.create(driver="local")
- self.volume_node_2 = self.docker_client.volumes.create(driver="local")
-
- self.storage.server = self.storage.server.replace("localhost", docker_internal_host)
- images.settings["nucliadb_node_writer"]["env"].update(self.storage.envs())
- writer1_host, writer1_port = nucliadb_node_1_writer.run(self.volume_node_1)
- writer2_host, writer2_port = nucliadb_node_2_writer.run(self.volume_node_2)
-
- reader1_host, reader1_port = nucliadb_node_1_reader.run(self.volume_node_1)
- reader2_host, reader2_port = nucliadb_node_2_reader.run(self.volume_node_2)
-
- natsd_server = self.natsd.replace("localhost", docker_internal_host)
- images.settings["nucliadb_node_sidecar"]["env"].update(
- {
- "INDEX_JETSTREAM_SERVERS": f'["{natsd_server}"]',
- "CACHE_PUBSUB_NATS_URL": f'["{natsd_server}"]',
- "READER_LISTEN_ADDRESS": f"{docker_internal_host}:{reader1_port}",
- "WRITER_LISTEN_ADDRESS": f"{docker_internal_host}:{writer1_port}",
- }
- )
- images.settings["nucliadb_node_sidecar"]["env"].update(self.storage.envs())
-
- sidecar1_host, sidecar1_port = nucliadb_node_1_sidecar.run(self.volume_node_1)
-
- images.settings["nucliadb_node_sidecar"]["env"]["READER_LISTEN_ADDRESS"] = (
- f"{docker_internal_host}:{reader2_port}"
- )
- images.settings["nucliadb_node_sidecar"]["env"]["WRITER_LISTEN_ADDRESS"] = (
- f"{docker_internal_host}:{writer2_port}"
- )
-
- sidecar2_host, sidecar2_port = nucliadb_node_2_sidecar.run(self.volume_node_2)
-
- writer1_internal_host = get_container_host(nucliadb_node_1_writer.container_obj)
- writer2_internal_host = get_container_host(nucliadb_node_2_writer.container_obj)
-
- self.data.update(
- {
- "writer1_internal_host": writer1_internal_host,
- "writer2_internal_host": writer2_internal_host,
- "writer1": {
- "host": writer1_host,
- "port": writer1_port,
- },
- "writer2": {
- "host": writer2_host,
- "port": writer2_port,
- },
- "reader1": {
- "host": reader1_host,
- "port": reader1_port,
- },
- "reader2": {
- "host": reader2_host,
- "port": reader2_port,
- },
- "sidecar1": {
- "host": sidecar1_host,
- "port": sidecar1_port,
- },
- "sidecar2": {
- "host": sidecar2_host,
- "port": sidecar2_port,
- },
- }
- )
- return self.data
-
- def stop(self):
- container_ids = []
- for component in [
- nucliadb_node_1_reader,
- nucliadb_node_1_writer,
- nucliadb_node_1_sidecar,
- nucliadb_node_2_writer,
- nucliadb_node_2_reader,
- nucliadb_node_2_sidecar,
- ]:
- container_obj = getattr(component, "container_obj", None)
- if container_obj:
- container_ids.append(container_obj.id)
- component.stop()
-
- for container_id in container_ids:
- for _ in range(5):
- try:
- self.docker_client.containers.get(container_id)
- except docker.errors.NotFound:
- break
- time.sleep(2)
-
- self.volume_node_1.remove()
- self.volume_node_2.remove()
-
- def setup_env(self):
- # reset on every test run in case something touches it
- cluster_settings.writer_port_map = {
- self.data["writer1_internal_host"]: self.data["writer1"]["port"],
- self.data["writer2_internal_host"]: self.data["writer2"]["port"],
- }
- cluster_settings.reader_port_map = {
- self.data["writer1_internal_host"]: self.data["reader1"]["port"],
- self.data["writer2_internal_host"]: self.data["reader2"]["port"],
- }
-
- cluster_settings.node_writer_port = None # type: ignore
- cluster_settings.node_reader_port = None # type: ignore
-
- cluster_settings.cluster_discovery_mode = "manual"
- cluster_settings.cluster_discovery_manual_addresses = [
- self.data["writer1_internal_host"],
- self.data["writer2_internal_host"],
- ]
-
-
-@pytest.fixture(scope="session")
-def gcs_node_storage(gcs):
- return NodeGCSStorage(server=gcs)
-
-
-@pytest.fixture(scope="session")
-def s3_node_storage(s3):
- return NodeS3Storage(server=s3)
-
-
-@pytest.fixture(scope="session")
-def node_storage(request):
- backend = get_testing_storage_backend()
- if backend == "gcs":
- return request.getfixturevalue("gcs_node_storage")
- elif backend == "s3":
- return request.getfixturevalue("s3_node_storage")
- else:
- print(f"Unknown storage backend {backend}, using gcs")
- return request.getfixturevalue("gcs_node_storage")
-
-
-@pytest.fixture(scope="session")
-def gcs_nidx_storage(gcs):
- return {
- "INDEXER__OBJECT_STORE": "gcs",
- "INDEXER__BUCKET": "indexing",
- "INDEXER__ENDPOINT": gcs,
- "STORAGE__OBJECT_STORE": "gcs",
- "STORAGE__ENDPOINT": gcs,
- "STORAGE__BUCKET": "nidx",
- }
-
-
-@pytest.fixture(scope="session")
-def s3_nidx_storage(s3):
- return {
- "INDEXER__OBJECT_STORE": "s3",
- "INDEXER__BUCKET": "indexing",
- "INDEXER__ENDPOINT": s3,
- "STORAGE__OBJECT_STORE": "s3",
- "STORAGE__ENDPOINT": s3,
- "STORAGE__BUCKET": "nidx",
- }
-
-
-@pytest.fixture(scope="session")
-def nidx_storage(request):
- backend = get_testing_storage_backend()
- if backend == "gcs":
- return request.getfixturevalue("gcs_nidx_storage")
- elif backend == "s3":
- return request.getfixturevalue("s3_nidx_storage")
-
-
-@pytest.fixture(scope="session", autouse=False)
-def _node(natsd: str, node_storage):
- nr = _NodeRunner(natsd, node_storage)
- try:
- cluster_info = nr.start()
- except Exception:
- nr.stop()
- raise
- nr.setup_env()
- yield cluster_info
- nr.stop()
-
-
-@pytest.fixture(scope="session")
-async def _nidx(natsd, nidx_storage, pg):
- if not NIDX_ENABLED:
- yield
- return
-
- # Create needed NATS stream/consumer
- nc = await nats.connect(servers=[natsd])
- js = nc.jetstream()
- await js.add_stream(name="nidx", subjects=["nidx"])
- await js.add_consumer(stream="nidx", config=ConsumerConfig(name="nidx"))
- await nc.drain()
- await nc.close()
-
- # Run nidx
- images.settings["nidx"]["env"] = {
- "RUST_LOG": "info",
- "METADATA__DATABASE_URL": f"postgresql://postgres:postgres@172.17.0.1:{pg[1]}/postgres",
- "INDEXER__NATS_SERVER": natsd.replace("localhost", "172.17.0.1"),
- **nidx_storage,
- }
- image = NidxImage()
- image.run()
-
- api_port = image.get_port(10000)
- searcher_port = image.get_port(10001)
-
- # Configure settings
- from nucliadb_utils.settings import indexing_settings
-
- cluster_settings.nidx_api_address = f"localhost:{api_port}"
- cluster_settings.nidx_searcher_address = f"localhost:{searcher_port}"
- indexing_settings.index_nidx_subject = "nidx"
-
- yield
-
- image.stop()
-
-
-@pytest.fixture(scope="function")
-def node(_nidx, _node, request):
- # clean up all shard data before each test
- channel1 = insecure_channel(f"{_node['writer1']['host']}:{_node['writer1']['port']}")
- channel2 = insecure_channel(f"{_node['writer2']['host']}:{_node['writer2']['port']}")
- writer1 = NodeWriterStub(channel1)
- writer2 = NodeWriterStub(channel2)
-
- logger.debug("cleaning up shards data")
- try:
- cleanup_node(writer1)
- cleanup_node(writer2)
- except Exception:
- logger.error(
- "Error cleaning up shards data. Maybe the node fixture could not start properly?",
- exc_info=True,
- )
-
- client = docker.client.from_env()
- containers_by_port = {}
- for container in client.containers.list():
- name = container.name
- command = container.attrs["Config"]["Cmd"]
- ports = container.ports
- print(f"container {name} executing {command} is using ports: {ports}")
-
- for internal_port in container.ports:
- for host in container.ports[internal_port]:
- port = host["HostPort"]
- port_containers = containers_by_port.setdefault(port, [])
- if container not in port_containers:
- port_containers.append(container)
-
- for port, containers in containers_by_port.items():
- if len(containers) > 1:
- names = ", ".join([container.name for container in containers])
- print(f"ATENTION! Containers {names} share port {port}!")
- raise
- finally:
- channel1.close()
- channel2.close()
-
- yield _node
-
-
-@backoff.on_exception(backoff.expo, Exception, jitter=backoff.random_jitter, max_tries=5)
-def cleanup_node(writer: NodeWriterStub):
- for shard in writer.ListShards(EmptyQuery()).ids:
- writer.DeleteShard(ShardId(id=shard.id))