Source code for tca.cli

from __future__ import annotations

import argparse
from dataclasses import asdict
from pathlib import Path

import numpy as np

from tca.config import TurboQuantConfig
from tca.io import load_embeddings, load_metadata, save_json
from tca.pipeline import SearchIndex


[docs] def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(prog="tca", description="TurboCell Atlas CLI") subparsers = parser.add_subparsers(dest="command", required=True) search = subparsers.add_parser("search", help="Run TurboQuant candidate search + exact rerank") search.add_argument("--embeddings", required=True, help="Reference embeddings (.npy or .npz)") search.add_argument("--metadata", required=True, help="Metadata (.jsonl, .json, .csv)") search.add_argument("--query", required=True, help="Query embedding (.npy)") search.add_argument("--output", required=True, help="Output JSON path") search.add_argument("--bit-width", type=int, default=3) search.add_argument("--candidate-k", type=int, default=128) search.add_argument("--rerank-k", type=int, default=20) search.add_argument("--oversample", type=int, default=2) search.add_argument("--seed", type=int, default=0) search.add_argument("--quantizer-kind", choices=["mse", "prod"], default="prod") search.add_argument("--filter", action="append", default=[], help="Metadata filter in key=value form") return parser
def _parse_filters(entries: list[str]) -> dict[str, str]: parsed: dict[str, str] = {} for entry in entries: if "=" not in entry: raise ValueError(f"Invalid filter: {entry}. Expected key=value") key, value = entry.split("=", 1) parsed[key] = value return parsed
[docs] def main() -> None: parser = build_parser() args = parser.parse_args() if args.command == "search": config = TurboQuantConfig( bit_width=args.bit_width, candidate_k=args.candidate_k, rerank_k=args.rerank_k, oversample=args.oversample, seed=args.seed, quantizer_kind=args.quantizer_kind, ).validate() embeddings = load_embeddings(args.embeddings) metadata = load_metadata(args.metadata) query = load_embeddings(args.query) if query.ndim == 2: query = query[0] index = SearchIndex.from_embeddings(embeddings=embeddings, metadata=metadata, config=config) results = index.search(query=query, filters=_parse_filters(args.filter)) payload = { "manifest": index.to_manifest(), "results": [asdict(item) for item in results.items], "backend": results.backend, "candidate_count": results.candidate_count, "filtered_count": results.filtered_count, } save_json(args.output, payload)
if __name__ == "__main__": main()