Source code for tca.quantization

from __future__ import annotations

from dataclasses import dataclass
import math
from typing import Iterable

import numpy as np
from numpy.typing import ArrayLike, NDArray


# Shared ndarray aliases used in the annotations of this module.
FloatArray = NDArray[np.float32]  # float32 data: vectors, matrices, codebooks, scores
IntArray = NDArray[np.int32]  # int32 codebook indices
SignArray = NDArray[np.int8]  # int8 signs of the projected residual


def _l2_normalize(x: FloatArray, eps: float = 1e-12) -> FloatArray:
    """Scale every vector along the last axis of ``x`` to unit Euclidean length.

    Norms smaller than ``eps`` are raised to ``eps`` before dividing, so an
    all-zero vector maps to zeros instead of NaN. The result is float32.
    """
    lengths = np.linalg.norm(x, axis=-1, keepdims=True)
    safe_lengths = np.maximum(lengths, eps)
    unit = x / safe_lengths
    return unit.astype(np.float32, copy=False)


def _orthogonal_matrix(dimension: int, seed: int) -> FloatArray:
    """Build a reproducible random orthogonal ``dimension x dimension`` matrix.

    A Gaussian matrix drawn from ``seed`` is QR-factorised, and each column of
    Q is multiplied by the sign of the matching diagonal entry of R.
    """
    generator = np.random.default_rng(seed)
    gaussian = generator.normal(size=(dimension, dimension)).astype(np.float32)
    ortho, upper = np.linalg.qr(gaussian)

    column_signs = np.sign(np.diag(upper))
    # An exactly-zero diagonal entry counts as positive so no column is zeroed.
    column_signs[column_signs == 0] = 1

    return (ortho * column_signs).astype(np.float32, copy=False)


def _sample_rotated_coordinate_distribution(
    dimension: int,
    n_samples: int,
    seed: int,
) -> FloatArray:
    """Draw ``n_samples`` values of one coordinate of a random unit vector.

    Each value is the first component of a ``dimension``-dimensional standard
    Gaussian vector after L2 normalisation. Draws are reproducible via ``seed``.
    """
    generator = np.random.default_rng(seed)
    gaussians = generator.normal(size=(n_samples, dimension)).astype(np.float32)
    first_coordinate = _l2_normalize(gaussians)[:, 0]
    return first_coordinate.astype(np.float32, copy=False)


def fit_scalar_codebook(
    bit_width: int,
    dimension: int,
    n_samples: int = 20000,
    seed: int = 0,
    max_iter: int = 100,
    tol: float = 1e-6,
) -> FloatArray:
    """Fit a scalar codebook to sampled unit-vector coordinates with Lloyd's algorithm.

    Samples come from ``_sample_rotated_coordinate_distribution``. Centroids
    start at evenly spaced interior quantiles of the sample and are then
    refined by alternating nearest-centroid assignment and cell-mean updates.

    Args:
        bit_width: Bits per coordinate; the codebook has ``2**bit_width`` entries.
        dimension: Dimension of the unit vectors whose coordinate is sampled.
        n_samples: Number of Monte Carlo samples to fit on.
        seed: Seed for the sample generator.
        max_iter: Maximum number of Lloyd iterations.
        tol: Stop once no centroid moves by more than this amount.

    Returns:
        float32 array of ``2**bit_width`` centroids.

    Raises:
        ValueError: If ``bit_width``, ``dimension`` or ``n_samples`` is below 1.
    """
    if bit_width < 1:
        raise ValueError("bit_width must be >= 1")
    if dimension < 1:
        raise ValueError("dimension must be >= 1")
    # Without samples the quantile initialisation has nothing to work on and
    # the loop would just carry NaN centroids through to the caller.
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")

    n_centroids = 2**bit_width
    data = np.sort(_sample_rotated_coordinate_distribution(dimension, n_samples, seed))

    # Interior quantiles only: dropping 0 and 1 keeps the initial centroids
    # away from the extreme samples.
    quantiles = np.linspace(0.0, 1.0, n_centroids + 2, dtype=np.float32)[1:-1]
    centroids = np.quantile(data, quantiles).astype(np.float32)

    for _ in range(max_iter):
        # Midpoints between adjacent centroids are the cell boundaries.
        boundaries = (centroids[:-1] + centroids[1:]) / 2.0

        # ``data`` is sorted, so every cell is a contiguous slice. cuts[i] is
        # the number of samples <= boundaries[i], which yields the same
        # membership as labelling each sample with
        # searchsorted(boundaries, sample, side="left").
        cuts = np.searchsorted(data, boundaries, side="right").tolist()

        updated = centroids.copy()
        for idx, (start, stop) in enumerate(zip([0, *cuts], [*cuts, data.size])):
            # An empty cell keeps its previous centroid.
            if stop > start:
                updated[idx] = np.mean(data[start:stop], dtype=np.float32)

        shift = float(np.max(np.abs(updated - centroids)))
        centroids = updated
        if shift <= tol:
            break

    return centroids.astype(np.float32, copy=False)
[docs] @dataclass(slots=True) class EncodedMSE: indices: IntArray
class TurboQuantMSE:
    """Vector quantizer that rotates inputs and scalar-quantizes each coordinate.

    A fixed random orthogonal rotation is applied to every input vector, and
    each rotated coordinate is replaced by the index of its nearest centroid
    in one shared scalar codebook.

    Attributes:
        dimension: Length of the vectors this quantizer handles.
        bit_width: Bits per coordinate; the codebook has ``2**bit_width`` entries.
        rotation: ``(dimension, dimension)`` float32 orthogonal matrix.
        codebook: float32 centroids shared by all coordinates.
    """

    def __init__(
        self,
        dimension: int,
        bit_width: int,
        *,
        seed: int = 0,
        monte_carlo_samples: int = 20000,
        lloyd_max_iter: int = 100,
        lloyd_tol: float = 1e-6,
    ) -> None:
        """Create the rotation and fit the scalar codebook.

        Args:
            dimension: Vector length; must be at least 1.
            bit_width: Bits per coordinate.
            seed: Seed shared by the rotation and the codebook sampler.
            monte_carlo_samples: Sample count used to fit the codebook.
            lloyd_max_iter: Iteration cap for the codebook fit.
            lloyd_tol: Convergence tolerance for the codebook fit.

        Raises:
            ValueError: If ``dimension`` is below 1 (or ``bit_width`` is
                rejected by ``fit_scalar_codebook``).
        """
        if dimension < 1:
            raise ValueError("dimension must be >= 1")
        self.dimension = dimension
        self.bit_width = bit_width
        self.rotation = _orthogonal_matrix(dimension, seed)
        # One-entry cache for the most recently rotated query.
        self._cached_query_bytes: bytes | None = None
        self._cached_query_rot: FloatArray | None = None
        self.codebook = fit_scalar_codebook(
            bit_width=bit_width,
            dimension=dimension,
            n_samples=monte_carlo_samples,
            seed=seed,
            max_iter=lloyd_max_iter,
            tol=lloyd_tol,
        )

    @staticmethod
    def _query_cache_key(query_arr: FloatArray) -> bytes:
        """Return a cache key that covers both the shape and the contents."""
        # Raw bytes alone cannot tell a (d,) query from a (1, d) one. The
        # shape repr ends at its first ")", so the prefix is unambiguous.
        return repr(query_arr.shape).encode("ascii") + query_arr.tobytes()

    def encode(self, x: ArrayLike) -> EncodedMSE:
        """Quantize one vector ``(dimension,)`` or a batch ``(n, dimension)``.

        Returns:
            ``EncodedMSE`` whose ``indices`` has the same shape as ``x``.
        """
        array = np.asarray(x, dtype=np.float32)
        was_1d = array.ndim == 1
        if was_1d:
            array = array[None, :]
        rotated = array @ self.rotation.T
        # Midpoints between adjacent centroids act as decision thresholds.
        boundaries = (self.codebook[:-1] + self.codebook[1:]) / 2.0
        indices = np.searchsorted(boundaries, rotated, side="left").astype(np.int32)
        return EncodedMSE(indices=indices[0] if was_1d else indices)

    def decode(self, encoded: EncodedMSE | ArrayLike) -> FloatArray:
        """Reconstruct float32 vectors from codebook indices.

        Args:
            encoded: An ``EncodedMSE`` or a bare index array, 1-D for a single
                vector or 2-D for a batch.

        Returns:
            Reconstruction with the same shape as the indices.
        """
        if isinstance(encoded, EncodedMSE):
            indices = encoded.indices
        else:
            indices = np.asarray(encoded, dtype=np.int32)
        was_1d = indices.ndim == 1
        if was_1d:
            indices = indices[None, :]
        quantized = self.codebook[indices]
        # encode() multiplies by rotation.T; multiplying by rotation undoes it.
        decoded = (quantized @ self.rotation).astype(np.float32, copy=False)
        return decoded[0] if was_1d else decoded

    def prepare_query(self, query: ArrayLike) -> FloatArray:
        """Rotate ``query`` into codebook coordinates, reusing the last result.

        The most recent query is cached, so scoring the same query repeatedly
        does not repeat the matrix product. The cache key includes the query
        shape, so queries with equal bytes but different shapes do not collide.

        Returns:
            ``query @ rotation.T`` as float32. The array is shared with the
            cache and should be treated as read-only.
        """
        query_arr = np.asarray(query, dtype=np.float32)
        key = self._query_cache_key(query_arr)
        if self._cached_query_bytes == key and self._cached_query_rot is not None:
            return self._cached_query_rot
        self._cached_query_bytes = key
        self._cached_query_rot = (query_arr @ self.rotation.T).astype(np.float32, copy=False)
        return self._cached_query_rot

    def approximate_inner_products(self, query: ArrayLike, encoded: EncodedMSE | ArrayLike) -> FloatArray:
        """Estimate ``<query, x>`` for encoded vectors without decoding them.

        Args:
            query: A single 1-D query vector.
            encoded: An ``EncodedMSE`` or a bare index array (1-D or 2-D).

        Returns:
            A 0-d float32 array for 1-D indices, otherwise a float32 array
            with one score per encoded row.

        Raises:
            ValueError: If ``query`` is not 1-D.
        """
        if isinstance(encoded, EncodedMSE):
            indices = encoded.indices
        else:
            indices = np.asarray(encoded, dtype=np.int32)
        query_rot = self.prepare_query(query)
        # The broadcast below is only an inner product for a 1-D query.
        if query_rot.ndim != 1:
            raise ValueError("query must be a 1-D vector")
        was_1d = indices.ndim == 1
        if was_1d:
            indices = indices[None, :]
        quantized = self.codebook[indices]
        scores = np.sum(quantized * query_rot[None, :], axis=1, dtype=np.float32)
        if was_1d:
            return np.asarray(float(scores[0]), dtype=np.float32)
        return scores.astype(np.float32, copy=False)
[docs] @dataclass(slots=True) class EncodedProd: indices: IntArray signs: SignArray residual_norms: FloatArray
class TurboQuantProd:
    """Two-stage quantizer: MSE quantization plus sign bits for the residual.

    ``bit_width - 1`` bits per coordinate go to an inner ``TurboQuantMSE``.
    The residual it leaves is passed through a fixed Gaussian projection and
    only the signs of the projected residual and the residual norm are kept.

    Attributes:
        dimension: Length of the vectors this quantizer handles.
        bit_width: Total bits per coordinate (MSE bits plus one sign bit).
        mse_quantizer: Inner ``TurboQuantMSE`` using ``bit_width - 1`` bits.
        projection: ``(dimension, dimension)`` float32 Gaussian matrix.
        qjl_scale: ``sqrt(pi / 2) / dimension``, the scale of the residual term.
    """

    def __init__(
        self,
        dimension: int,
        bit_width: int,
        *,
        seed: int = 0,
        monte_carlo_samples: int = 20000,
        lloyd_max_iter: int = 100,
        lloyd_tol: float = 1e-6,
    ) -> None:
        """Build the inner MSE quantizer and the residual projection.

        Args:
            dimension: Vector length; must be at least 1.
            bit_width: Total bits per coordinate; must be at least 2.
            seed: Seed for the inner quantizer; the projection uses ``seed + 1``.
            monte_carlo_samples: Sample count used to fit the codebook.
            lloyd_max_iter: Iteration cap for the codebook fit.
            lloyd_tol: Convergence tolerance for the codebook fit.

        Raises:
            ValueError: If ``bit_width`` is below 2 or ``dimension`` below 1.
        """
        if bit_width < 2:
            raise ValueError("TurboQuantProd requires bit_width >= 2")
        if dimension < 1:
            raise ValueError("dimension must be >= 1")
        self.dimension = dimension
        self.bit_width = bit_width
        self.mse_quantizer = TurboQuantMSE(
            dimension=dimension,
            bit_width=bit_width - 1,
            seed=seed,
            monte_carlo_samples=monte_carlo_samples,
            lloyd_max_iter=lloyd_max_iter,
            lloyd_tol=lloyd_tol,
        )
        # A different seed keeps the projection independent of the rotation.
        rng = np.random.default_rng(seed + 1)
        self.projection = rng.normal(size=(dimension, dimension)).astype(np.float32)
        self.qjl_scale = np.float32(math.sqrt(math.pi / 2.0) / dimension)
        # One-entry cache for the most recently prepared query.
        self._cached_query_bytes: bytes | None = None
        self._cached_query_rot: FloatArray | None = None
        self._cached_proj_query: FloatArray | None = None

    def encode(self, x: ArrayLike) -> EncodedProd:
        """Quantize one vector ``(dimension,)`` or a batch ``(n, dimension)``.

        Returns:
            ``EncodedProd``. For a single vector ``indices`` and ``signs`` are
            1-D and ``residual_norms`` has shape ``(1,)``; for a batch they
            have shapes ``(n, dimension)``, ``(n, dimension)`` and ``(n,)``.
        """
        array = np.asarray(x, dtype=np.float32)
        was_1d = array.ndim == 1
        if was_1d:
            array = array[None, :]
        mse_encoded = self.mse_quantizer.encode(array)
        mse_decoded = self.mse_quantizer.decode(mse_encoded)
        residual = array - mse_decoded
        signs = np.sign(residual @ self.projection.T).astype(np.int8)
        # A projection of exactly zero is stored as +1 so every sign is +/-1.
        signs[signs == 0] = 1
        residual_norms = np.linalg.norm(residual, axis=1).astype(np.float32)
        if was_1d:
            # The inner quantizer saw a (1, d) batch; strip that axis so
            # indices match signs and decode() recognises a single vector.
            return EncodedProd(
                indices=mse_encoded.indices[0],
                signs=signs[0],
                residual_norms=np.asarray([residual_norms[0]], dtype=np.float32),
            )
        return EncodedProd(
            indices=mse_encoded.indices,
            signs=signs,
            residual_norms=residual_norms,
        )

    def decode(self, encoded: EncodedProd) -> FloatArray:
        """Reconstruct float32 vectors from an ``EncodedProd``.

        The result is the inner MSE reconstruction plus
        ``qjl_scale * residual_norm * (signs @ projection)``.

        Returns:
            Reconstruction with the same shape as ``encoded.indices``.
        """
        indices = np.asarray(encoded.indices, dtype=np.int32)
        # float32 signs keep the matrix product in float32 on every path.
        signs = np.asarray(encoded.signs, dtype=np.float32)
        norms = np.asarray(encoded.residual_norms, dtype=np.float32)
        was_1d = indices.ndim == 1
        if was_1d:
            indices = indices[None, :]
            signs = signs[None, :]
        mse_part = self.mse_quantizer.decode(EncodedMSE(indices=indices))
        qjl_part = self.qjl_scale * norms[:, None] * (signs @ self.projection)
        decoded = (mse_part + qjl_part).astype(np.float32, copy=False)
        return decoded[0] if was_1d else decoded

    def prepare_query(self, query: ArrayLike) -> tuple[FloatArray, FloatArray]:
        """Precompute the two query transforms used for scoring.

        The most recent query is cached. The cache key includes the query
        shape, so queries with equal bytes but different shapes do not collide.

        Returns:
            ``(rotated_query, projected_query)``: the query as prepared by the
            inner MSE quantizer, and ``projection @ query``. Both arrays are
            shared with the cache and should be treated as read-only.
        """
        query_arr = np.asarray(query, dtype=np.float32)
        # The shape repr ends at its first ")", so the prefix is unambiguous.
        key = repr(query_arr.shape).encode("ascii") + query_arr.tobytes()
        if (
            self._cached_query_bytes == key
            and self._cached_query_rot is not None
            and self._cached_proj_query is not None
        ):
            return self._cached_query_rot, self._cached_proj_query
        self._cached_query_bytes = key
        self._cached_query_rot = self.mse_quantizer.prepare_query(query_arr)
        self._cached_proj_query = (self.projection @ query_arr).astype(np.float32, copy=False)
        return self._cached_query_rot, self._cached_proj_query

    def approximate_inner_products(self, query: ArrayLike, encoded: EncodedProd) -> FloatArray:
        """Estimate ``<query, x>`` for encoded vectors without decoding them.

        The score is the MSE-stage inner product plus
        ``qjl_scale * residual_norm * <signs, projection @ query>``.

        Args:
            query: A single 1-D query vector.
            encoded: Output of :meth:`encode` for one vector or a batch.

        Returns:
            A 0-d float32 array for a single encoded vector, otherwise a
            float32 array with one score per encoded row.
        """
        query_rot, proj_query = self.prepare_query(query)
        indices = np.asarray(encoded.indices, dtype=np.int32)
        signs = np.asarray(encoded.signs, dtype=np.int8)
        norms = np.asarray(encoded.residual_norms, dtype=np.float32)
        was_1d = indices.ndim == 1
        if was_1d:
            indices = indices[None, :]
            signs = signs[None, :]
            norms = norms.reshape(1)
        mse_term = np.sum(
            self.mse_quantizer.codebook[indices] * query_rot[None, :],
            axis=1,
            dtype=np.float32,
        )
        residual_term = self.qjl_scale * norms * (signs @ proj_query)
        scores = mse_term + residual_term.astype(np.float32, copy=False)
        if was_1d:
            return np.asarray(float(scores[0]), dtype=np.float32)
        return scores.astype(np.float32, copy=False)
def exact_topk(
    query: ArrayLike,
    bank: ArrayLike,
    top_k: int,
    ids: Iterable[int] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Return the ``top_k`` bank rows with the largest inner product with ``query``.

    Args:
        query: Query vector.
        bank: Matrix with one candidate vector per row.
        top_k: Number of results wanted; capped at the number of bank rows.
        ids: Optional external id for every bank row, in row order. When given,
            these are returned in place of row positions.

    Returns:
        ``(identifiers, scores)``: int32 identifiers and float32 scores, both
        ordered by descending score. Both are empty when ``top_k`` is 0 or the
        bank has no rows.

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    # A negative count would otherwise act as a slice bound and silently
    # drop trailing results.
    if top_k < 0:
        raise ValueError("top_k must be >= 0")

    bank_array = np.asarray(bank, dtype=np.float32)
    query_array = np.asarray(query, dtype=np.float32)

    k = min(top_k, bank_array.shape[0])
    if k == 0:
        # argpartition rejects kth=-1 on an empty array, so answer directly.
        return np.asarray([], dtype=np.int32), np.asarray([], dtype=np.float32)

    scores = bank_array @ query_array
    # Partition first so only the k winners need a full sort.
    order = np.argpartition(-scores, kth=k - 1)[:k]
    order = order[np.argsort(-scores[order])]
    top_scores = scores[order].astype(np.float32)

    if ids is None:
        return order.astype(np.int32), top_scores
    ids_array = np.asarray(list(ids), dtype=np.int32)
    return ids_array[order], top_scores