Advanced Datashare worker¶
In this section we'll augment the worker template app (translation and classification) with a vector store, allowing us to perform semantic similarity searches between queries and Datashare docs.
Make sure you've followed the basic worker example to understand the basics!
Clone the template repository¶
Start over and clone the template repository once again:
Install extra dependencies¶
We'll use LanceDB to implement our vector store, so we need to add it, along with sentence-transformers, to our dependencies:
Note
In a production setup, since Elasticsearch implements its own vector database, it might have been convenient to use it. For this example, we're using LanceDB as it's embedded and doesn't require any deployment update.
Embedding Datashare documents¶
For the purpose of this demo, we'll split the task of embedding docs into two tasks:
- the `create_vectorization_tasks` task, which scans the index, gets the IDs of Datashare docs, batches them and creates `vectorize_docs` tasks
- the `vectorize_docs` tasks (triggered by the `create_vectorization_tasks` task), which receive doc IDs, fetch the doc contents from the index and add them to the vector database
Note
We could have performed the vectorization in a single task, but having a first task split a large workload into batches/chunks is a commonly used pattern to distribute heavy workloads across workers (learn more in the task workflow guide).
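In the code below, the batching itself is done with the `async_batches` helper imported from `datashare_python.utils`. Its implementation isn't shown in this guide; as a rough idea, a simplified stand-in (illustration only, not the actual utility) could look like this:
from collections.abc import AsyncIterable, AsyncIterator
from typing import TypeVar

T = TypeVar("T")

async def async_batches(
    items: AsyncIterable[T], *, batch_size: int
) -> AsyncIterator[tuple[T, ...]]:
    # Simplified stand-in for datashare_python.utils.async_batches (illustration only):
    # accumulate items from the async iterable and yield them as fixed-size tuples,
    # flushing the last, possibly smaller, batch at the end.
    batch: list[T] = []
    async for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield tuple(batch)
            batch = []
    if batch:
        yield tuple(batch)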
The `create_vectorization_tasks` task¶
The `create_vectorization_tasks` task is defined in the `tasks/vectorize.py` file as follows:
import asyncio
import logging
from collections.abc import AsyncIterable
from typing import AsyncIterator
import numpy as np
from icij_common.es import (
DOC_CONTENT,
ESClient,
HITS,
ID_,
QUERY,
SOURCE,
ids_query,
make_document_query,
match_all,
)
from icij_worker.ds_task_client import DatashareTaskClient
from lancedb import AsyncConnection as LanceDBConnection, AsyncTable
from lancedb.embeddings import get_registry
from lancedb.index import FTS, IvfPq
from lancedb.pydantic import LanceModel, Vector
from sentence_transformers import SentenceTransformer
from datashare_python.constants import PYTHON_TASK_GROUP
from datashare_python.tasks.dependencies import (
lifespan_es_client,
lifespan_task_client,
lifespan_vector_db,
)
from datashare_python.utils import async_batches
logger = logging.getLogger(__name__)
async def recreate_vector_table(
vector_db: LanceDBConnection, schema: type[LanceModel]
) -> AsyncTable:
table_name = "ds_docs"
existing_tables = await vector_db.table_names()
if table_name in existing_tables:
logging.info("deleting previous vector db...")
await vector_db.drop_table(table_name)
table = await vector_db.create_table(table_name, schema=schema)
return table
def make_record_schema(model: str) -> type[LanceModel]:
model = get_registry().get("huggingface").create(name=model)
class RecordSchema(LanceModel):
doc_id: str
content: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
return RecordSchema
async def create_vectorization_tasks(
project: str,
*,
model: str = "BAAI/bge-small-en-v1.5",
es_client: ESClient | None = None,
task_client: DatashareTaskClient | None = None,
vector_db: LanceDBConnection | None = None,
batch_size: int = 16,
) -> list[str]:
if es_client is None:
es_client = lifespan_es_client()
if task_client is None:
task_client = lifespan_task_client()
if vector_db is None:
vector_db = lifespan_vector_db()
schema = make_record_schema(model)
await recreate_vector_table(vector_db, schema)
query = make_document_query(match_all())
docs_pages = es_client.poll_search_pages(
index=project, body=query, sort="_doc:asc", _source=False
)
doc_ids = (doc[ID_] async for doc in _flatten_search_pages(docs_pages))
batches = async_batches(doc_ids, batch_size=batch_size)
logging.info("spawning vectorization tasks...")
args = {"project": project}
task_ids = []
async for batch in batches:
args["docs"] = list(batch)
task_id = await task_client.create_task(
"vectorize_docs", args, group=PYTHON_TASK_GROUP.name
)
task_ids.append(task_id)
logging.info("created %s vectorization tasks !", len(task_ids))
return task_ids
async def _flatten_search_pages(pages: AsyncIterable[dict]) -> AsyncIterator[dict]:
async for page in pages:
for doc in page[HITS][HITS]:
yield doc
The function starts by creating a schema for our vector DB table using the convenient LanceDB embedding function feature, which will automatically create the record's vector field from the provided source field (`content` in our case) using our HuggingFace embedding model:
def make_record_schema(model: str) -> type[LanceModel]:
model = get_registry().get("huggingface").create(name=model)
class RecordSchema(LanceModel):
doc_id: str
content: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
return RecordSchema
We then (re)create a vector table using the DB connection provided by dependency injection (see the next section to learn more):
async def recreate_vector_table(
vector_db: LanceDBConnection, schema: type[LanceModel]
) -> AsyncTable:
table_name = "ds_docs"
existing_tables = await vector_db.table_names()
if table_name in existing_tables:
logging.info("deleting previous vector db...")
await vector_db.drop_table(table_name)
table = await vector_db.create_table(table_name, schema=schema)
return table
Next, `create_vectorization_tasks` queries the index, matching all documents, and groups the resulting document IDs into batches of `batch_size`:
docs_pages = es_client.poll_search_pages(
index=project, body=query, sort="_doc:asc", _source=False
)
doc_ids = (doc[ID_] async for doc in _flatten_search_pages(docs_pages))
batches = async_batches(doc_ids, batch_size=batch_size)
Finally, for each batch, it spawns a vectorization task using the datashare task client and returns the list of created tasks:
args = {"project": project}
task_ids = []
async for batch in batches:
args["docs"] = list(batch)
task_id = await task_client.create_task(
"vectorize_docs", args, group=PYTHON_TASK_GROUP.name
)
task_ids.append(task_id)
logging.info("created %s vectorization tasks !", len(task_ids))
return task_ids
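Note that the parent task can be enqueued the same way, through the Datashare task client. A minimal sketch (assuming you already hold a task client instance; the `enqueue_vectorization` helper below is purely illustrative):
from icij_worker.ds_task_client import DatashareTaskClient
from datashare_python.constants import PYTHON_TASK_GROUP

async def enqueue_vectorization(task_client: DatashareTaskClient, project: str) -> str:
    # Enqueue the parent task; when processed by a worker, it will scan the index
    # and spawn the vectorize_docs batch tasks as described above.
    return await task_client.create_task(
        "create_vectorization_tasks", {"project": project}, group=PYTHON_TASK_GROUP.name
    )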
The `lifespan_vector_db` dependency injection¶
To avoid re-creating a DB connection each time the worker processes a task, we leverage dependency injection: the connection is created at startup and then retrieved inside our function.
This pattern is already used for the Elasticsearch client and the Datashare task client; to use it for the vector DB connection, we'll need to update the `dependencies.py` file.
First we need to implement the dependency setup function:
from lancedb import AsyncConnection as LanceDBConnection, connect_async
from datashare_python.constants import DATA_DIR
_VECTOR_DB_CONNECTION: LanceDBConnection | None = None
_DB_PATH = DATA_DIR / "vector.db"
async def vector_db_setup(**_):
global _VECTOR_DB_CONNECTION
_VECTOR_DB_CONNECTION = await connect_async(_DB_PATH)
The function creates a connection to the vector DB located on the filesystem and stores it in a global variable.
We then have to implement a function to make this global available to the rest of the codebase:
def lifespan_vector_db() -> LanceDBConnection:
if _VECTOR_DB_CONNECTION is None:
raise DependencyInjectionError("vector db connection")
return _VECTOR_DB_CONNECTION
Finally, we implement a teardown function to close the connection gracefully, using the `AsyncConnection.__aexit__` method:
async def vector_db_teardown(exc_type, exc_val, exc_tb):
await lifespan_vector_db().__aexit__(exc_type, exc_val, exc_tb)
global _VECTOR_DB_CONNECTION
_VECTOR_DB_CONNECTION = None
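For the worker to run these functions at startup and shutdown, the new dependency also has to be registered alongside the existing ones (Elasticsearch client, Datashare task client, ...) in `dependencies.py`. The exact shape depends on how `APP_LIFESPAN_DEPS` is declared in your template; assuming it is an ordered list of `(label, setup, teardown)` entries like the existing ones, the addition could look like this sketch:
# Sketch only: mirror the exact shape of the existing entries of your dependencies.py.
APP_LIFESPAN_DEPS = [
    # ... existing dependencies (loggers, ES client, Datashare task client, ...)
    ("vector DB connection", vector_db_setup, vector_db_teardown),
]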
Read the dependency injection guide to learn more!
The `vectorize_docs` task¶
Next, we implement the `vectorize_docs` task as follows:
async def vectorize_docs(
docs: list[str],
project: str,
*,
es_client: ESClient | None = None,
vector_db: LanceDBConnection | None = None,
) -> int:
if es_client is None:
es_client = lifespan_es_client()
if vector_db is None:
vector_db = lifespan_vector_db()
n_docs = len(docs)
logging.info("vectorizing %s docs...", n_docs)
query = {QUERY: ids_query(docs)}
docs_pages = es_client.poll_search_pages(
index=project, body=query, sort="_doc:asc", _source_includes=[DOC_CONTENT]
)
es_docs = _flatten_search_pages(docs_pages)
table = await vector_db.open_table("ds_docs")
records = [
{"doc_id": d[ID_], "content": d[SOURCE][DOC_CONTENT]} async for d in es_docs
]
await table.add(records)
logging.info("vectorized %s docs !", n_docs)
return n_docs
The task function starts by retrieving the contents of the batch's documents, querying the index by doc IDs:
query = {QUERY: ids_query(docs)}
docs_pages = es_client.poll_search_pages(
index=project, body=query, sort="_doc:asc", _source_includes=[DOC_CONTENT]
)
es_docs = _flatten_search_pages(docs_pages)
Finally, we add each doc's content to the vector DB table. Because we created the table using a schema and the embedding function feature, the embedding vector will automatically be created from the `content` source field:
table = await vector_db.open_table("ds_docs")
records = [
{"doc_id": d[ID_], "content": d[SOURCE][DOC_CONTENT]} async for d in es_docs
]
await table.add(records)
Semantic similarity search¶
Now that we've built a vector store from Datashare's docs, we need to query it. Let's create a `find_most_similar` task which finds the most similar docs for a provided set of queries.
The task function opens the vector table and makes sure its search indexes exist, then loads the embedding model and vectorizes the input queries:
async def find_most_similar(
queries: list[str],
model: str,
*,
vector_db: LanceDBConnection | None = None,
n_similar: int = 2,
) -> list[list[dict]]:
if vector_db is None:
vector_db = lifespan_vector_db()
n_queries = len(queries)
logging.info("performing similarity search for %s queries...", n_queries)
table = await vector_db.open_table("ds_docs")
# Create indexes for hybrid search
try:
await table.create_index(
"vector", config=IvfPq(distance_type="cosine"), replace=False
)
await table.create_index("content", config=FTS(), replace=False)
except RuntimeError:
logging.debug("skipping index creation as they already exist")
vectorizer = SentenceTransformer(model)
vectors = vectorizer.encode(queries)
futures = (
_find_most_similar(table, q, v, n_similar) for q, v in zip(queries, vectors)
)
results = await asyncio.gather(*futures)
results = sum(results, start=[])
logging.info("completed similarity search for %s queries !", n_queries)
return results
async def _find_most_similar(
table: AsyncTable, query: str, vector: np.ndarray, n_similar: int
) -> list[dict]:
# pylint: disable=unused-argument
most_similar = (
await table.query()
        # The async client seems to be bugged and doesn't really support hybrid queries
        # .nearest_to_text(query, columns=["content"])
.nearest_to(vector)
.limit(n_similar)
.select(["doc_id"])
.to_list()
)
most_similar = [
{"doc_id": s["doc_id"], "distance": s["_distance"]} for s in most_similar
]
return most_similar
It then performs the similarity search for each query using its vector (a true hybrid text + vector search is left commented out, as the async client doesn't fully support it):
async def _find_most_similar(
table: AsyncTable, query: str, vector: np.ndarray, n_similar: int
) -> list[dict]:
# pylint: disable=unused-argument
most_similar = (
await table.query()
        # The async client seems to be bugged and doesn't really support hybrid queries
        # .nearest_to_text(query, columns=["content"])
.nearest_to(vector)
.limit(n_similar)
.select(["doc_id"])
.to_list()
)
most_similar = [
{"doc_id": s["doc_id"], "distance": s["_distance"]} for s in most_similar
]
return most_similar
Registering the new tasks¶
To turn our functions into Datashare tasks, we have to register them in the `app` async app variable of the `app.py` file, using the `@task` decorator:
from typing import Optional
from icij_worker import AsyncApp
from icij_worker.typing_ import PercentProgress
from pydantic import parse_obj_as
from datashare_python.constants import PYTHON_TASK_GROUP
from datashare_python.objects import ClassificationConfig, TranslationConfig
from datashare_python.tasks import (
classify_docs as classify_docs_,
create_classification_tasks as create_classification_tasks_,
create_translation_tasks as create_translation_tasks_,
translate_docs as translate_docs_,
)
from datashare_python.tasks.dependencies import APP_LIFESPAN_DEPS
from datashare_python.tasks.vectorize import (
create_vectorization_tasks as create_vectorization_tasks_,
find_most_similar as find_most_similar_,
    vectorize_docs as vectorize_docs_,
)
app = AsyncApp("ml", dependencies=APP_LIFESPAN_DEPS)
@app.task(group=PYTHON_TASK_GROUP)
async def create_vectorization_tasks(
project: str, model: str = "BAAI/bge-small-en-v1.5"
) -> list[str]:
return await create_vectorization_tasks_(project, model=model)
@app.task(group=PYTHON_TASK_GROUP)
async def vectorize_docs(docs: list[str], project: str) -> int:
    return await vectorize_docs_(docs, project)
@app.task(group=PYTHON_TASK_GROUP)
async def find_most_similar(
queries: list[str], model: str, n_similar: int = 2
) -> list[list[dict]]:
return await find_most_similar_(queries, model, n_similar=n_similar)
Testing¶
Finally, we implement some tests in the `tests/tasks/test_vectorize.py` file:
from pathlib import Path
from typing import List
import pytest
from icij_common.es import ESClient
from lancedb import AsyncConnection as LanceDBConnection, connect_async
from datashare_python.objects import Document
from datashare_python.tasks.vectorize import (
create_vectorization_tasks,
find_most_similar,
make_record_schema,
recreate_vector_table,
vectorize_docs,
)
from datashare_python.tests.conftest import TEST_PROJECT
from datashare_python.utils import DSTaskClient
@pytest.fixture
async def test_vector_db(tmpdir) -> LanceDBConnection:
db = await connect_async(Path(tmpdir) / "test_vectors.db")
return db
@pytest.mark.integration
async def test_create_vectorization_tasks(
populate_es: List[Document], # pylint: disable=unused-argument
test_es_client: ESClient,
test_task_client: DSTaskClient,
test_vector_db: LanceDBConnection,
):
# When
task_ids = await create_vectorization_tasks(
project=TEST_PROJECT,
es_client=test_es_client,
task_client=test_task_client,
vector_db=test_vector_db,
batch_size=2,
)
# Then
assert len(task_ids) == 2
@pytest.mark.integration
async def test_vectorize_docs(
populate_es: List[Document], # pylint: disable=unused-argument
test_es_client: ESClient,
test_vector_db: LanceDBConnection,
):
# Given
model = "BAAI/bge-small-en-v1.5"
docs = ["doc-0", "doc-3"]
schema = make_record_schema(model)
await recreate_vector_table(test_vector_db, schema)
# When
n_vectorized = await vectorize_docs(
docs,
TEST_PROJECT,
es_client=test_es_client,
vector_db=test_vector_db,
)
# Then
assert n_vectorized == 2
table = await test_vector_db.open_table("ds_docs")
records = await table.query().to_list()
assert len(records) == 2
doc_ids = sorted(d["doc_id"] for d in records)
assert doc_ids == ["doc-0", "doc-3"]
assert all("vector" in r for r in records)
@pytest.mark.integration
async def test_find_most_similar(test_vector_db: LanceDBConnection):
# Given
model = "BAAI/bge-small-en-v1.5"
schema = make_record_schema(model)
table = await recreate_vector_table(test_vector_db, schema)
docs = [
{"doc_id": "novel", "content": "I'm a doc about novels"},
{"doc_id": "monkey", "content": "I'm speaking about monkeys"},
]
await table.add(docs)
queries = ["doc about books", "doc speaking about animal"]
# When
n_similar = 1
most_similar = await find_most_similar(
queries, model, vector_db=test_vector_db, n_similar=n_similar
)
# Then
assert len(most_similar) == 2
similar_ids = [s["doc_id"] for s in most_similar]
assert similar_ids == ["novel", "monkey"]
assert all("distance" in s for s in most_similar)
We can then run the tests after starting the test services using the `datashare-python` Docker Compose wrapper:
collected 3 items
datashare_python/tests/tasks/test_vectorize.py ... [100%]
====== 3 passed in 6.87s ======
Summary¶
We've successfully added a vector store to Datashare!
Rather than copy-pasting the above code blocks, you can replace/update your codebase with the following files: