
Advanced Datashare worker

In this section, we'll augment the worker template app (translation and classification) with a vector store, allowing us to perform semantic similarity searches between queries and Datashare docs.

Make sure you've followed the basic worker example to understand the basics!

Clone the template repository

Start over and clone the template repository once again:

git clone git@github.com:ICIJ/datashare-python.git

Install extra dependencies

We'll use LanceDB to implement our vector store. We need to add it, together with sentence-transformers, to our dependencies:

uv add lancedb sentence-transformers

Note

In a production setup, since Elasticsearch implements its own vector database, it might be convenient to use it directly. For this example, we're using LanceDB as it's embedded and doesn't require any deployment update.

Embedding Datashare documents

For the purpose of this demo, we'll split the embedding of docs into two tasks:

  • a create_vectorization_tasks task, which scans the index, gets the IDs of Datashare docs, batches them, and creates vectorize_docs tasks
  • the vectorize_docs tasks (triggered by create_vectorization_tasks), which receive doc IDs, fetch the doc contents from the index, and add them to the vector database

Note

We could have performed vectorization in a single task, but having a first task split a large workload into batches/chunks is a commonly used pattern to distribute heavy workloads across workers, as sketched below (learn more in the task workflow guide).
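
As an illustration, here is a conceptual sketch of the pattern, with hypothetical parent_task/child_task names (create_vectorization_tasks below is the concrete version):

async def parent_task(task_client, all_ids: list[str], batch_size: int) -> list[str]:
    # a lightweight parent task enumerates the work, batches it, and spawns
    # one child task per batch, so heavy processing is spread across workers
    task_ids = []
    for i in range(0, len(all_ids), batch_size):
        batch = all_ids[i : i + batch_size]
        task_ids.append(await task_client.create_task("child_task", {"ids": batch}))
    return task_ids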

The create_vectorization_tasks task

The create_vectorization_tasks task is defined in the tasks/vectorize.py file as follows:

tasks/vectorize.py
import asyncio
import logging
from collections.abc import AsyncIterable, AsyncIterator

import numpy as np
from icij_common.es import (
    DOC_CONTENT,
    ESClient,
    HITS,
    ID_,
    QUERY,
    SOURCE,
    ids_query,
    make_document_query,
    match_all,
)
from icij_worker.ds_task_client import DatashareTaskClient
from lancedb import AsyncConnection as LanceDBConnection, AsyncTable
from lancedb.embeddings import get_registry
from lancedb.index import FTS, IvfPq
from lancedb.pydantic import LanceModel, Vector
from sentence_transformers import SentenceTransformer

from datashare_python.constants import PYTHON_TASK_GROUP
from datashare_python.tasks.dependencies import (
    lifespan_es_client,
    lifespan_task_client,
    lifespan_vector_db,
)
from datashare_python.utils import async_batches

logger = logging.getLogger(__name__)


async def recreate_vector_table(
    vector_db: LanceDBConnection, schema: type[LanceModel]
) -> AsyncTable:
    table_name = "ds_docs"
    existing_tables = await vector_db.table_names()
    if table_name in existing_tables:
        logging.info("deleting previous vector db...")
        await vector_db.drop_table(table_name)
    table = await vector_db.create_table(table_name, schema=schema)
    return table


def make_record_schema(model: str) -> type[LanceModel]:
    model = get_registry().get("huggingface").create(name=model)

    class RecordSchema(LanceModel):
        doc_id: str
        content: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    return RecordSchema


async def create_vectorization_tasks(
    project: str,
    *,
    model: str = "BAAI/bge-small-en-v1.5",
    es_client: ESClient | None = None,
    task_client: DatashareTaskClient | None = None,
    vector_db: LanceDBConnection | None = None,
    batch_size: int = 16,
) -> list[str]:
    if es_client is None:
        es_client = lifespan_es_client()
    if task_client is None:
        task_client = lifespan_task_client()
    if vector_db is None:
        vector_db = lifespan_vector_db()
    schema = make_record_schema(model)
    await recreate_vector_table(vector_db, schema)
    query = make_document_query(match_all())
    docs_pages = es_client.poll_search_pages(
        index=project, body=query, sort="_doc:asc", _source=False
    )
    doc_ids = (doc[ID_] async for doc in _flatten_search_pages(docs_pages))
    batches = async_batches(doc_ids, batch_size=batch_size)
    logging.info("spawning vectorization tasks...")
    args = {"project": project}
    task_ids = []
    async for batch in batches:
        args["docs"] = list(batch)
        task_id = await task_client.create_task(
            "vectorize_docs", args, group=PYTHON_TASK_GROUP.name
        )
        task_ids.append(task_id)
    logging.info("created %s vectorization tasks !", len(task_ids))
    return task_ids


async def _flatten_search_pages(pages: AsyncIterable[dict]) -> AsyncIterator[dict]:
    async for page in pages:
        for doc in page[HITS][HITS]:
            yield doc

The function starts by creating a schema for our vector DB table using the convenient LanceDB embedding function feature, which automatically creates the record's vector field from the provided source field (content in our case) using our HuggingFace embedding model:

tasks/vectorize.py
def make_record_schema(model: str) -> type[LanceModel]:
    model = get_registry().get("huggingface").create(name=model)

    class RecordSchema(LanceModel):
        doc_id: str
        content: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    return RecordSchema
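
For illustration, here is a minimal sketch (with a hypothetical demo_docs table name) of what the embedding function feature does: records are added without any vector value, and LanceDB computes it from the content source field at insertion time:

schema = make_record_schema("BAAI/bge-small-en-v1.5")
table = await vector_db.create_table("demo_docs", schema=schema)
# no "vector" value is provided: LanceDB computes it from "content",
# using the HuggingFace model declared in the schema
await table.add([{"doc_id": "doc-0", "content": "some document content"}])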

We then (re)create a vector table using the DB connection provided by dependency injection (see the next section to learn more):

tasks/vectorize.py
async def recreate_vector_table(
    vector_db: LanceDBConnection, schema: type[LanceModel]
) -> AsyncTable:
    table_name = "ds_docs"
    existing_tables = await vector_db.table_names()
    if table_name in existing_tables:
        logging.info("deleting previous vector db...")
        await vector_db.drop_table(table_name)
    table = await vector_db.create_table(table_name, schema=schema)
    return table

Next, create_vectorization_tasks queries the index, matching all documents:

tasks/vectorize.py
    query = make_document_query(match_all())

and scrolls through the result pages, creating batches of batch_size:

tasks/vectorize.py
    docs_pages = es_client.poll_search_pages(
        index=project, body=query, sort="_doc:asc", _source=False
    )
    doc_ids = (doc[ID_] async for doc in _flatten_search_pages(docs_pages))
    batches = async_batches(doc_ids, batch_size=batch_size)
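
The async_batches helper isn't shown above; here is a minimal sketch of what such a helper can look like (the actual implementation in datashare_python.utils may differ):

from collections.abc import AsyncIterable, AsyncIterator
from typing import TypeVar

T = TypeVar("T")


async def async_batches(
    items: AsyncIterable[T], batch_size: int
) -> AsyncIterator[tuple[T, ...]]:
    batch = []
    async for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield tuple(batch)
            batch = []
    if batch:  # yield the trailing partial batch, if any
        yield tuple(batch)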

Finally, for each batch, it spawns a vectorization task using the Datashare task client and returns the list of created task IDs:

tasks/vectorize.py
    args = {"project": project}
    task_ids = []
    async for batch in batches:
        args["docs"] = list(batch)
        task_id = await task_client.create_task(
            "vectorize_docs", args, group=PYTHON_TASK_GROUP.name
        )
        task_ids.append(task_id)
    logging.info("created %s vectorization tasks !", len(task_ids))
    return task_ids

The lifespan_vector_db dependency injection

To avoid re-creating a DB connection each time the worker processes a task, we leverage dependency injection: the connection is created at startup and retrieved inside our function.

This pattern is already used for the Elasticsearch client and the Datashare task client. To use it for the vector DB connection, we'll need to update the dependencies.py file.

First we need to implement the dependency setup function:

dependencies.py
from lancedb import AsyncConnection as LanceDBConnection, connect_async

from datashare_python.constants import DATA_DIR

_VECTOR_DB_CONNECTION: LanceDBConnection | None = None
_DB_PATH = DATA_DIR / "vector.db"


async def vector_db_setup(**_):
    global _VECTOR_DB_CONNECTION
    _VECTOR_DB_CONNECTION = await connect_async(_DB_PATH)

The function creates a connection to the vector DB located on the filesystem and stores the connection in a global variable.

We then have to implement a function making this global variable available to the rest of the codebase:

dependencies.py
def lifespan_vector_db() -> LanceDBConnection:
    if _VECTOR_DB_CONNECTION is None:
        raise DependencyInjectionError("vector db connection")
    return _VECTOR_DB_CONNECTION

We also need to make sure the connection is properly closed when the worker stops, by implementing the dependency teardown. We simply call the AsyncConnection.__aexit__ method:

dependencies.py
async def vector_db_teardown(exc_type, exc_val, exc_tb):
    await lifespan_vector_db().__aexit__(exc_type, exc_val, exc_tb)
    global _VECTOR_DB_CONNECTION
    _VECTOR_DB_CONNECTION = None
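
For the new dependency to be set up and torn down with the worker, its setup and teardown functions must also be registered in the worker's dependency list. A minimal sketch, assuming APP_LIFESPAN_DEPS follows the (label, setup, teardown) tuple convention used for the other dependencies:

dependencies.py
APP_LIFESPAN_DEPS = [
    # ... existing dependencies (ES client, task client, ...)
    ("vector db connection", vector_db_setup, vector_db_teardown),
]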

Read the dependency injection guide to learn more!

The vectorize_docs task

Next, we implement the vectorize_docs task as follows:

tasks/vectorize.py
async def vectorize_docs(
    docs: list[str],
    project: str,
    *,
    es_client: ESClient | None = None,
    vector_db: LanceDBConnection | None = None,
) -> int:
    if es_client is None:
        es_client = lifespan_es_client()
    if vector_db is None:
        vector_db = lifespan_vector_db()
    n_docs = len(docs)
    logging.info("vectorizing %s docs...", n_docs)
    query = {QUERY: ids_query(docs)}
    docs_pages = es_client.poll_search_pages(
        index=project, body=query, sort="_doc:asc", _source_includes=[DOC_CONTENT]
    )
    es_docs = _flatten_search_pages(docs_pages)
    table = await vector_db.open_table("ds_docs")
    records = [
        {"doc_id": d[ID_], "content": d[SOURCE][DOC_CONTENT]} async for d in es_docs
    ]
    await table.add(records)
    logging.info("vectorized %s docs !", n_docs)
    return n_docs

The task function starts by retrieving the contents of the batch documents, querying the index by doc IDs:

tasks/vectorize.py
    query = {QUERY: ids_query(docs)}
    docs_pages = es_client.poll_search_pages(
        index=project, body=query, sort="_doc:asc", _source_includes=[DOC_CONTENT]
    )
    es_docs = _flatten_search_pages(docs_pages)

Finally, we add each doc's content to the vector DB table. Because we created the table using a schema and the embedding function feature, the embedding vector is automatically computed from the content source field:

tasks/vectorize.py
    table = await vector_db.open_table("ds_docs")
    records = [
        {"doc_id": d[ID_], "content": d[SOURCE][DOC_CONTENT]} async for d in es_docs
    ]
    await table.add(records)

Now that we've built a vector store from Datashare's docs, we need to query it. Let's create a find_most_similar task, which finds the most similar docs for a provided set of queries.

The task function starts by loading the embedding model and vectorizing the input queries:

tasks/vectorize.py
async def find_most_similar(
    queries: list[str],
    model: str,
    *,
    vector_db: LanceDBConnection | None = None,
    n_similar: int = 2,
) -> list[dict]:
    if vector_db is None:
        vector_db = lifespan_vector_db()
    n_queries = len(queries)
    logging.info("performing similarity search for %s queries...", n_queries)
    table = await vector_db.open_table("ds_docs")
    # Create indexes for hybrid search
    try:
        await table.create_index(
            "vector", config=IvfPq(distance_type="cosine"), replace=False
        )
        await table.create_index("content", config=FTS(), replace=False)
    except RuntimeError:
        logger.debug("skipping index creation as they already exist")
    vectorizer = SentenceTransformer(model)
    vectors = vectorizer.encode(queries)
    futures = (
        _find_most_similar(table, q, v, n_similar) for q, v in zip(queries, vectors)
    )
    results = await asyncio.gather(*futures)
    results = sum(results, start=[])
    logging.info("completed similarity search for %s queries !", n_queries)
    return results


async def _find_most_similar(
    table: AsyncTable, query: str, vector: np.ndarray, n_similar: int
) -> list[dict]:
    # pylint: disable=unused-argument
    most_similar = (
        await table.query()
        # The async client seems to be bugged and doesn't really support hybrid queries
        # .nearest_to_text(query, columns=["content"])
        .nearest_to(vector)
        .limit(n_similar)
        .select(["doc_id"])
        .to_list()
    )
    most_similar = [
        {"doc_id": s["doc_id"], "distance": s["_distance"]} for s in most_similar
    ]
    return most_similar

It then performs a vector similarity search for each query. Ideally, this would be a hybrid search using both the input query vector and its text, but the full-text leg is disabled for now, as the async client doesn't fully support hybrid queries:

tasks/vectorize.py
async def _find_most_similar(
    table: AsyncTable, query: str, vector: np.ndarray, n_similar: int
) -> list[dict]:
    # pylint: disable=unused-argument
    most_similar = (
        await table.query()
        # The async client seems to be bugged and doesn't really support hybrid queries
        # .nearest_to_text(query, columns=["content"])
        .nearest_to(vector)
        .limit(n_similar)
        .select(["doc_id"])
        .to_list()
    )
    most_similar = [
        {"doc_id": s["doc_id"], "distance": s["_distance"]} for s in most_similar
    ]
    return most_similar
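
For reference, if the async client gains full hybrid support, the full-text leg can be re-enabled so that each query combines text and vector search. An untested sketch based on the commented-out line above:

most_similar = (
    await table.query()
    .nearest_to_text(query, columns=["content"])  # full-text leg
    .nearest_to(vector)  # vector leg
    .limit(n_similar)
    .select(["doc_id"])
    .to_list()
)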

Registering the new tasks

To turn our functions into Datashare tasks, we have to register them in the app variable (an AsyncApp) of the app.py file, using the @app.task decorator:

app.py
from typing import Optional

from icij_worker import AsyncApp
from icij_worker.typing_ import PercentProgress
from pydantic import parse_obj_as

from datashare_python.constants import PYTHON_TASK_GROUP
from datashare_python.objects import ClassificationConfig, TranslationConfig
from datashare_python.tasks import (
    classify_docs as classify_docs_,
    create_classification_tasks as create_classification_tasks_,
    create_translation_tasks as create_translation_tasks_,
    translate_docs as translate_docs_,
)
from datashare_python.tasks.dependencies import APP_LIFESPAN_DEPS
from datashare_python.tasks.vectorize import (
    create_vectorization_tasks as create_vectorization_tasks_,
    find_most_similar as find_most_similar_,
    vectorize_docs as vectorize_docs_,
)

app = AsyncApp("ml", dependencies=APP_LIFESPAN_DEPS)


@app.task(group=PYTHON_TASK_GROUP)
async def create_vectorization_tasks(
    project: str, model: str = "BAAI/bge-small-en-v1.5"
) -> list[str]:
    return await create_vectorization_tasks_(project, model=model)


@app.task(group=PYTHON_TASK_GROUP)
async def vectorize_docs(docs: list[str], project: str) -> int:
    return await vectorize_docs_(docs, project)


@app.task(group=PYTHON_TASK_GROUP)
async def find_most_similar(
    queries: list[str], model: str, n_similar: int = 2
) -> list[dict]:
    return await find_most_similar_(queries, model, n_similar=n_similar)
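
Once the tasks are registered and a worker is running, they can be triggered like any other Datashare task. As an illustration, here is a hedged sketch using the DatashareTaskClient, assuming it can be constructed from the Datashare base URL (a hypothetical localhost:8080 here) and needs no extra authentication:

import asyncio

from icij_worker.ds_task_client import DatashareTaskClient

from datashare_python.constants import PYTHON_TASK_GROUP


async def trigger_vectorization() -> str:
    # hypothetical base URL: adjust to your Datashare deployment; depending
    # on the version, the client may also need an API key or async setup
    task_client = DatashareTaskClient("http://localhost:8080")
    return await task_client.create_task(
        "create_vectorization_tasks",
        {"project": "local-datashare"},
        group=PYTHON_TASK_GROUP.name,
    )


task_id = asyncio.run(trigger_vectorization())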

Testing

Finally, we implement some tests in the tests/tasks/test_vectorize.py file:

tests/tasks/test_vectorize.py
from pathlib import Path
from typing import List

import pytest
from icij_common.es import ESClient
from lancedb import AsyncConnection as LanceDBConnection, connect_async

from datashare_python.objects import Document
from datashare_python.tasks.vectorize import (
    create_vectorization_tasks,
    find_most_similar,
    make_record_schema,
    recreate_vector_table,
    vectorize_docs,
)
from datashare_python.tests.conftest import TEST_PROJECT
from datashare_python.utils import DSTaskClient


@pytest.fixture
async def test_vector_db(tmpdir) -> LanceDBConnection:
    db = await connect_async(Path(tmpdir) / "test_vectors.db")
    return db


@pytest.mark.integration
async def test_create_vectorization_tasks(
    populate_es: List[Document],  # pylint: disable=unused-argument
    test_es_client: ESClient,
    test_task_client: DSTaskClient,
    test_vector_db: LanceDBConnection,
):
    # When
    task_ids = await create_vectorization_tasks(
        project=TEST_PROJECT,
        es_client=test_es_client,
        task_client=test_task_client,
        vector_db=test_vector_db,
        batch_size=2,
    )
    # Then
    assert len(task_ids) == 2


@pytest.mark.integration
async def test_vectorize_docs(
    populate_es: List[Document],  # pylint: disable=unused-argument
    test_es_client: ESClient,
    test_vector_db: LanceDBConnection,
):
    # Given
    model = "BAAI/bge-small-en-v1.5"
    docs = ["doc-0", "doc-3"]
    schema = make_record_schema(model)
    await recreate_vector_table(test_vector_db, schema)

    # When
    n_vectorized = await vectorize_docs(
        docs,
        TEST_PROJECT,
        es_client=test_es_client,
        vector_db=test_vector_db,
    )
    # Then
    assert n_vectorized == 2
    table = await test_vector_db.open_table("ds_docs")
    records = await table.query().to_list()
    assert len(records) == 2
    doc_ids = sorted(d["doc_id"] for d in records)
    assert doc_ids == ["doc-0", "doc-3"]
    assert all("vector" in r for r in records)


@pytest.mark.integration
async def test_find_most_similar(test_vector_db: LanceDBConnection):
    # Given
    model = "BAAI/bge-small-en-v1.5"
    schema = make_record_schema(model)
    table = await recreate_vector_table(test_vector_db, schema)
    docs = [
        {"doc_id": "novel", "content": "I'm a doc about novels"},
        {"doc_id": "monkey", "content": "I'm speaking about monkeys"},
    ]
    await table.add(docs)
    queries = ["doc about books", "doc speaking about animal"]

    # When
    n_similar = 1
    most_similar = await find_most_similar(
        queries, model, vector_db=test_vector_db, n_similar=n_similar
    )
    # Then
    assert len(most_similar) == 2
    similar_ids = [s["doc_id"] for s in most_similar]
    assert similar_ids == ["novel", "monkey"]
    assert all("distance" in s for s in most_similar)

We can then run the tests after starting the test services using the datashare-python Docker Compose wrapper:

./datashare-python up -d postgresql redis elasticsearch rabbitmq datashare_web
uv run --frozen pytest datashare_python/tests/tasks/test_vectorize.py

===== test session starts =====
collected 3 items

datashare_python/tests/tasks/test_vectorize.py ... [100%]

====== 3 passed in 6.87s ======

Summary

We've successfully added a vector store to Datashare!

Rather than copy-pasting the above code blocks, you can replace/update your codebase with the following files: