# Source code for tca.config

from __future__ import annotations

from dataclasses import asdict, dataclass
import json
from pathlib import Path
from typing import Any


[docs] @dataclass(slots=True) class TurboQuantConfig: bit_width: int = 3 candidate_k: int = 128 rerank_k: int = 20 oversample: int = 2 seed: int = 0 quantizer_kind: str = "prod" lloyd_max_iter: int = 100 lloyd_tol: float = 1e-6 monte_carlo_samples: int = 20000 store_original_embeddings: bool = True auto_score_gap_threshold: float = 0.06 auto_score_spread_threshold: float = 0.015 max_candidate_k: int = 2048 max_oversample: int = 8
[docs] def validate(self) -> "TurboQuantConfig": if self.bit_width < 1: raise ValueError("bit_width must be >= 1") if self.candidate_k < 1: raise ValueError("candidate_k must be >= 1") if self.rerank_k < 1: raise ValueError("rerank_k must be >= 1") if self.oversample < 1: raise ValueError("oversample must be >= 1") if self.max_candidate_k < self.candidate_k: raise ValueError("max_candidate_k must be >= candidate_k") if self.max_oversample < self.oversample: raise ValueError("max_oversample must be >= oversample") if self.quantizer_kind not in {"mse", "prod"}: raise ValueError("quantizer_kind must be either 'mse' or 'prod'") if self.quantizer_kind == "prod" and self.bit_width < 2: raise ValueError("TurboQuant 'prod' requires bit_width >= 2") return self
[docs] def to_dict(self) -> dict[str, Any]: return asdict(self)
[docs] def to_json(self, path: str | Path) -> None: Path(path).write_text(json.dumps(self.to_dict(), indent=2), encoding="utf-8")
[docs] @classmethod def from_json(cls, path: str | Path) -> "TurboQuantConfig": payload = json.loads(Path(path).read_text(encoding="utf-8")) return cls(**payload).validate()