Source code for tca.pipeline

from __future__ import annotations

from dataclasses import asdict
from pathlib import Path
from typing import Any

import numpy as np

from tca.config import TurboQuantConfig
from tca.planning import plan_query
from tca.quantization import EncodedProd, TurboQuantMSE, TurboQuantProd, exact_topk
from tca.types import MetadataRecord, SearchItem, SearchResults


[docs] class SearchIndex: def __init__( self, embeddings: np.ndarray, metadata: list[MetadataRecord], config: TurboQuantConfig, ) -> None: config.validate() if embeddings.ndim != 2: raise ValueError("embeddings must be a 2D array") if len(metadata) != embeddings.shape[0]: raise ValueError("metadata length must match number of embeddings") self.embeddings = np.asarray(embeddings, dtype=np.float32) self.metadata = metadata self.config = config self.dimension = embeddings.shape[1] self.ids = np.array( [str(record.get("cell_id", idx)) for idx, record in enumerate(metadata)], dtype=object, ) if config.quantizer_kind == "mse": self.quantizer = TurboQuantMSE( dimension=self.dimension, bit_width=config.bit_width, seed=config.seed, monte_carlo_samples=config.monte_carlo_samples, lloyd_max_iter=config.lloyd_max_iter, lloyd_tol=config.lloyd_tol, ) self.encoded = self.quantizer.encode(self.embeddings) else: self.quantizer = TurboQuantProd( dimension=self.dimension, bit_width=config.bit_width, seed=config.seed, monte_carlo_samples=config.monte_carlo_samples, lloyd_max_iter=config.lloyd_max_iter, lloyd_tol=config.lloyd_tol, ) self.encoded = self.quantizer.encode(self.embeddings) self.embedding_norms = np.linalg.norm(self.embeddings, axis=1).astype(np.float32) self.metadata_columns: dict[str, np.ndarray] = {} if metadata: keys = set().union(*(record.keys() for record in metadata)) for key in keys: self.metadata_columns[key] = np.asarray([record.get(key) for record in metadata], dtype=object)
[docs] @classmethod def from_embeddings( cls, embeddings: np.ndarray, metadata: list[MetadataRecord], config: TurboQuantConfig | None = None, ) -> "SearchIndex": return cls(embeddings=embeddings, metadata=metadata, config=config or TurboQuantConfig())
def _mask_for_filters(self, filters: dict[str, Any] | None) -> np.ndarray: mask = np.ones(len(self.metadata), dtype=bool) if not filters: return mask for key, value in filters.items(): allowed = value if isinstance(value, (list, tuple, set)) else [value] column = self.metadata_columns.get(key) if column is None: return np.zeros(len(self.metadata), dtype=bool) mask &= np.isin(column, list(allowed)) return mask def _candidate_scores(self, query: np.ndarray, active_indices: np.ndarray) -> np.ndarray: if isinstance(self.quantizer, TurboQuantProd): subset = EncodedProd( indices=self.encoded.indices[active_indices], signs=self.encoded.signs[active_indices], residual_norms=self.encoded.residual_norms[active_indices], ) return self.quantizer.approximate_inner_products(query, subset) return self.quantizer.approximate_inner_products(query, self.encoded.indices[active_indices])
[docs] def search( self, query: np.ndarray, filters: dict[str, Any] | None = None, *, query_mode: str = "auto", ) -> SearchResults: query_arr = np.asarray(query, dtype=np.float32) if query_arr.ndim != 1 or query_arr.shape[0] != self.dimension: raise ValueError("query must be a 1D vector with the same dimensionality as the index") mask = self._mask_for_filters(filters) active_indices = np.flatnonzero(mask) if active_indices.size == 0: return SearchResults(items=[], backend="turboquant", candidate_count=0, filtered_count=0, diagnostics={"query_mode": query_mode}) approx_scores = self._candidate_scores(query_arr, active_indices) query_plan = plan_query(approx_scores=approx_scores, config=self.config, query_mode=query_mode) candidate_count = min(active_indices.size, query_plan.candidate_k * query_plan.oversample) approx_order = np.argpartition(-approx_scores, kth=candidate_count - 1)[:candidate_count] candidate_indices = active_indices[approx_order] final_ids, final_scores = exact_topk( query=query_arr, bank=self.embeddings[candidate_indices], top_k=min(self.config.rerank_k, candidate_indices.size), ids=candidate_indices.tolist(), ) items = [ SearchItem( rank=rank + 1, item_id=str(self.ids[idx]), score=float(score), metadata=self.metadata[int(idx)], ) for rank, (idx, score) in enumerate(zip(final_ids.tolist(), final_scores.tolist(), strict=True)) ] return SearchResults( items=items, backend="turboquant", candidate_count=int(candidate_indices.size), filtered_count=int(active_indices.size), diagnostics={ "query_mode": query_plan.mode, "planned_candidate_k": query_plan.candidate_k, "planned_oversample": query_plan.oversample, "probe_count": query_plan.probe_count, "score_gap": query_plan.score_gap, "score_spread": query_plan.score_spread, }, )
[docs] def search_exact(self, query: np.ndarray, filters: dict[str, Any] | None = None) -> SearchResults: query_arr = np.asarray(query, dtype=np.float32) mask = self._mask_for_filters(filters) active_indices = np.flatnonzero(mask) if active_indices.size == 0: return SearchResults(items=[], backend="exact", candidate_count=0, filtered_count=0, diagnostics={}) final_ids, final_scores = exact_topk( query=query_arr, bank=self.embeddings[active_indices], top_k=min(self.config.rerank_k, active_indices.size), ids=active_indices.tolist(), ) items = [ SearchItem( rank=rank + 1, item_id=str(self.ids[idx]), score=float(score), metadata=self.metadata[int(idx)], ) for rank, (idx, score) in enumerate(zip(final_ids.tolist(), final_scores.tolist(), strict=True)) ] return SearchResults( items=items, backend="exact", candidate_count=int(active_indices.size), filtered_count=int(active_indices.size), diagnostics={}, )
[docs] def to_manifest(self) -> dict[str, Any]: return { "config": self.config.to_dict(), "n_cells": int(self.embeddings.shape[0]), "dimension": int(self.dimension), "quantizer_kind": self.config.quantizer_kind, }