# Source code for geomstats.varifold.keops.genred

"""Kernels and kernel pairings using KeOps genred."""

import geomstats.backend as gs
from geomstats.varifold.base import Pairing

if gs.__name__.endswith("pytorch"):
    from pykeops.torch import Genred
else:
    from pykeops.numpy import Genred


def GaussianKernel(sigma):
    r"""Build a Gaussian kernel evaluator backed by a KeOps reduction.

    The kernel is

    .. math::

        K(x, y)=e^{-\|x-y\|^2 / \sigma^2}

    Parameters
    ----------
    sigma : float
        Kernel parameter.

    Returns
    -------
    kernel_eval : callable
        Function ``kernel_eval(point_a, point_b)`` that runs the compiled
        ``Sum`` reduction over the ``j`` (``Vj``) index (``axis=1``).
    """
    formula = "Exp(-SqDist(x,y)*a)"
    aliases = ["a=Pm(1)", "x=Vi(3)", "y=Vj(3)"]
    reduction = Genred(formula, aliases, reduction_op="Sum", axis=1)

    # The formula multiplies the squared distance by a = 1 / sigma^2.
    inv_sq_sigma = 1 / (gs.array([sigma]) ** 2)

    def kernel_eval(point_a, point_b):
        return reduction(inv_sq_sigma, point_a, point_b)

    return kernel_eval
def CauchyKernel(sigma):
    r"""Build a Cauchy kernel evaluator backed by a KeOps reduction.

    The kernel is

    .. math::

        K(x, y)=\frac{1}{1+\|x-y\|^2 / \sigma^2}

    Parameters
    ----------
    sigma : float
        Kernel parameter.

    Returns
    -------
    kernel_eval : callable
        Function ``kernel_eval(point_a, point_b)`` that runs the compiled
        ``Sum`` reduction over the ``j`` (``Vj``) index (``axis=1``).
    """
    formula = "IntCst(1)/(IntCst(1)+SqDist(x,y)*a)"
    aliases = ["a=Pm(1)", "x=Vi(3)", "y=Vj(3)"]
    reduction = Genred(formula, aliases, reduction_op="Sum", axis=1)

    # The formula multiplies the squared distance by a = 1 / sigma^2.
    inv_sq_sigma = 1 / (gs.array([sigma]) ** 2)

    def kernel_eval(point_a, point_b):
        return reduction(inv_sq_sigma, point_a, point_b)

    return kernel_eval
def LinearKernel():
    r"""Build a linear kernel evaluator backed by a KeOps reduction.

    The kernel is the inner product

    .. math::

        K(u, v) = \langle u, v \rangle

    Returns
    -------
    kernel_eval : callable
        Function ``kernel_eval(point_a, point_b)`` that runs the compiled
        ``Sum`` reduction over the ``j`` (``Vj``) index (``axis=1``).
    """
    formula = "(u|v)"
    aliases = ["u=Vi(3)", "v=Vj(3)"]
    reduction = Genred(formula, aliases, reduction_op="Sum", axis=1)

    def kernel_eval(point_a, point_b):
        return reduction(point_a, point_b)

    return kernel_eval
def BinetKernel():
    r"""Build a Binet kernel evaluator backed by a KeOps reduction.

    The kernel is the squared inner product

    .. math::

        K(u, v) = \langle u, v \rangle^2

    Because the inner product is squared, ``K(u, v) == K(-u, v)``.

    Returns
    -------
    kernel_eval : callable
        Function ``kernel_eval(point_a, point_b)`` that runs the compiled
        ``Sum`` reduction over the ``j`` (``Vj``) index (``axis=1``).
    """
    formula = "Square((u|v))"
    aliases = ["u=Vi(3)", "v=Vj(3)"]
    reduction = Genred(formula, aliases, reduction_op="Sum", axis=1)

    def kernel_eval(point_a, point_b):
        return reduction(point_a, point_b)

    return kernel_eval
def OrientedGaussianKernel(sigma=1.0):
    r"""Gaussian kernel restricted to the hypersphere (oriented version).

    The kernel is

    .. math::

        K(u, v)=e^{2 (\langle u, v \rangle - 1) / \sigma^2}

    It depends on the sign of :math:`\langle u, v \rangle`, so ``u`` and
    ``-u`` are treated differently.

    Parameters
    ----------
    sigma : float
        Kernel parameter.

    Returns
    -------
    kernel_eval : callable
        Function ``kernel_eval(point_a, point_b)`` that runs the compiled
        ``Sum`` reduction over the ``j`` (``Vj``) index (``axis=1``).
    """
    expr = Genred(
        "Exp(IntCst(2)*b*((u|v)-IntCst(1)))",
        [
            "b=Pm(1)",
            "u=Vi(3)",
            "v=Vj(3)",
        ],
        reduction_op="Sum",
        axis=1,
    )
    # Value bound to the ``b`` alias of the formula: b = 1 / sigma^2.
    b_param = 1 / gs.array([sigma]) ** 2

    def kernel_eval(point_a, point_b):
        return expr(b_param, point_a, point_b)

    return kernel_eval
def UnorientedGaussianKernel(sigma=1.0):
    r"""Gaussian kernel restricted to the hypersphere (unoriented version).

    The kernel is

    .. math::

        K(u, v)=e^{2 (\langle u, v \rangle ^2 - 1) / \sigma^2 }

    Because the inner product is squared, ``K(u, v) == K(-u, v)``.

    Parameters
    ----------
    sigma : float
        Kernel parameter.

    Returns
    -------
    kernel_eval : callable
        Function ``kernel_eval(point_a, point_b)`` that runs the compiled
        ``Sum`` reduction over the ``j`` (``Vj``) index (``axis=1``).
    """
    formula = "Exp(IntCst(2)*b*(Square((u|v))-IntCst(1)))"
    aliases = ["b=Pm(1)", "u=Vi(3)", "v=Vj(3)"]
    reduction = Genred(formula, aliases, reduction_op="Sum", axis=1)

    # Value bound to the ``b`` alias of the formula: b = 1 / sigma^2.
    inv_sq_sigma = 1 / (gs.array([sigma]) ** 2)

    def kernel_eval(point_a, point_b):
        return reduction(inv_sq_sigma, point_a, point_b)

    return kernel_eval
class GaussianBinetPairing(Pairing):
    r"""Instantiate a Gaussian–Binet kernel pairing.

    The underlying kernel is

    .. math::

        K(x, y, u, v) = e^{-\|x - y\|^2 / \sigma^2} \langle u, v \rangle^2

    The compiled KeOps reduction sums :math:`K(x_i, y_j, u_i, v_j)\, b_j`
    over the ``j`` index, where ``b`` is a scalar weight attached to each
    ``j`` entry (alias ``b=Vj(1)``).

    Parameters
    ----------
    sigma : float
        Positive bandwidth parameter of the Gaussian kernel.
    """

    def __init__(self, sigma):
        super().__init__()
        self._expr = Genred(
            "Exp(-SqDist(x,y)*a)*Square((u|v))*b",
            [
                "a=Pm(1)",
                "x=Vi(3)",
                "y=Vj(3)",
                "u=Vi(3)",
                "v=Vj(3)",
                "b=Vj(1)",
            ],
            reduction_op="Sum",
            axis=1,
        )
        # Value bound to the ``a`` alias of the formula: a = 1 / sigma^2.
        self._a_param = 1 / gs.array([sigma]) ** 2

    def kernel_prod(self, *kernel_args):
        """Apply the kernel pairing to a vector.

        Parameters
        ----------
        *kernel_args
            Arrays bound, in order, to the ``x``, ``y``, ``u``, ``v`` and
            ``b`` aliases of the KeOps formula. ``x`` and ``u`` are declared
            ``Vi(3)``, ``y`` and ``v`` are ``Vj(3)``, and ``b`` is ``Vj(1)``.
            The ``a`` parameter is supplied internally from ``sigma``.

        Returns
        -------
        array
            Result of the ``Sum`` reduction over the ``j`` index
            (``axis=1``).
        """
        return self._expr(self._a_param, *kernel_args)