Source code for geomstats.varifold.kernel

"""Kernel pairings."""

import importlib

import geomstats.backend as gs

from .base import Pairing

if gs.__name__.endswith("pytorch"):
    import torch

    _compile = torch.compile

else:

    def _compile(fn):
        return fn


[docs] def GaussianBinetPairing(sigma, backend="auto"): r"""Instantiate a Gaussian-Binet kernel pairing. This pairing is defined by .. math:: K(x, y, u, v) = exp(-||x - y||^2 / sigma^2) * <u, v>^2 Parameters ---------- sigma : float Positive bandwidth parameter of the Gaussian kernel. backend : {"auto", "torch", "keops", "keops_genred", "keops_lazy"} Implementation backend. - "auto": Select an implementation automatically (typically prefers a KeOps-based implementation when available, otherwise falls back to a Torch/NumPy implementation). - "backend": Dense implementation using the current geomstats backend. - "keops": Alias for "keops_genred". - "keops_genred": KeOps implementation using Genred reductions. - "keops_lazy": KeOps LazyTensor-based implementation. Returns ------- Pairing An object implementing the kernel pairing. Notes ----- The dense ("backend") implementation materializes pairwise matrices and is memory-bound for large inputs. KeOps-based implementations avoid forming the full kernel matrix and are more efficient for large-scale problems. The "auto" backend does not guarantee optimal performance in all cases, but provides a reasonable default based on available dependencies. """ if backend == "auto": has_keops = importlib.util.find_spec("pykeops") is not None backend = "keops_genred" if has_keops else "backend" if backend == "keops": backend = "keops_genred" if backend == "backend": return _GaussianBinetPairing(sigma=sigma) if backend == "keops_genred": import geomstats.varifold.keops.genred as gkeops return gkeops.GaussianBinetPairing(sigma) if backend == "keops_lazy": import geomstats.varifold.keops.lazy as lkeops return lkeops.SurfaceKernelPairing(lkeops.GaussianBinetKernel(sigma=sigma)) raise ValueError(f"Unknown backend: {backend}")
class _GaussianBinetPairing(Pairing): r"""Instantiate a Gaussian–Binet kernel pairing. This pairing is defined by .. math:: K(x, y, u, v) = exp(-||x - y||^2 / sigma^2) <u, v>^2 Parameters ---------- sigma : float Positive bandwidth parameter of the Gaussian kernel. Notes ----- It materializes pairwise matrices and is memory-bound for large inputs. """ def __init__(self, sigma): super().__init__() def _kernel(x, y, u, v): x_norm2 = gs.sum(x**2, axis=1)[:, None] y_norm2 = gs.sum(y**2, axis=1)[None, :] dist2 = x_norm2 + y_norm2 - 2 * (x @ y.T) K_xy = gs.exp(-dist2 / sigma**2) uv = u @ v.T return K_xy * uv**2 self._kernel = _compile(_kernel) def kernel_prod(self, x, y, u, v, b): """Apply the kernel pairing to a vector.""" return self._kernel(x, y, u, v) @ b