"""Tensorflow based computation backend."""

import math
from collections import Counter
from itertools import product

import numpy as _np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow import abs
from tensorflow import acos as arccos  # NOQA
from tensorflow import acosh as arccosh
from tensorflow import argmax, argmin
from tensorflow import asin as arcsin
from tensorflow import atan2 as arctan2
from tensorflow import broadcast_to
from tensorflow import clip_by_value as clip
from tensorflow import (
    complex64,
    complex128,
    concat,
    cos,
    cosh,
    divide,
    equal,
    exp,
    expand_dims,
    float32,
    float64,
    floor,
    gather,
    greater,
    int32,
    int64,
    less,
    less_equal,
    linspace,
    logical_and,
    logical_or,
    maximum,
    meshgrid,
    minimum,
    ones,
    ones_like,
    pad,
)
from tensorflow import range as arange
from tensorflow import reduce_max as amax
from tensorflow import reduce_mean as mean
from tensorflow import reduce_min as amin
from tensorflow import reduce_prod as prod
from tensorflow import (
    reshape,
    searchsorted,
    shape,
    sign,
    sin,
    sinh,
    sort,
    sqrt,
    squeeze,
    stack,
    tan,
    tanh,
    uint8,
    zeros,
    zeros_like,
)
from tensorflow.experimental.numpy import moveaxis

from ..constants import tf_atol, tf_rtol
from . import autodiff  # NOQA
from . import linalg  # NOQA
from . import random  # NOQA

DTYPES = {int32: 0, int64: 1, float32: 2, float64: 3, complex64: 4, complex128: 5}

angle = tf.math.angle
arctanh = tf.math.atanh
ceil = tf.math.ceil
conj = tf.math.conj
cross = tf.linalg.cross
erf = tf.math.erf
imag = tf.math.imag
isnan = tf.math.is_nan
log = tf.math.log
mod = tf.math.mod
polygamma = tf.math.polygamma
power = tf.math.pow
real = tf.math.real
set_diag = tf.linalg.set_diag
std = tf.math.reduce_std
atol = tf_atol
rtol = tf_rtol
trapz = tfp.math.trapz


def _raise_not_implemented_error(*args, **kwargs):
    raise NotImplementedError


def comb(n, k):
    return math.factorial(n) // math.factorial(k) // math.factorial(n - k)
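
# Sanity check: comb(5, 2) == 10, the binomial coefficient "5 choose 2".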


def to_numpy(x):
    return x.numpy()


def from_numpy(x):
    return tf.convert_to_tensor(x)


def one_hot(labels, num_classes):
    return tf.one_hot(labels, num_classes, dtype=tf.uint8)


def concatenate(x, axis=0, out=None):
    return concat(x, axis=axis)


def convert_to_wider_dtype(tensor_list):
    dtype_list = [DTYPES[x.dtype] for x in tensor_list]
    wider_dtype_index = max(dtype_list)

    wider_dtype = list(DTYPES.keys())[wider_dtype_index]

    tensor_list = [cast(x, dtype=wider_dtype) for x in tensor_list]
    return tensor_list
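
# Usage sketch (tensor names are illustrative): dtypes are promoted to the
# widest one present, following the DTYPES ordering above.
#   ints = tf.constant([1, 2], dtype=tf.int32)
#   flts = tf.constant([0.5], dtype=tf.float64)
#   convert_to_wider_dtype([ints, flts])  # returns both tensors cast to float64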


def repeat(a, repeats, axis=None):
    return tf.repeat(input=a, repeats=repeats, axis=axis)


def array(x, dtype=None):
    return tf.convert_to_tensor(x, dtype=dtype)


def trace(x, axis1=0, axis2=1):
    min_axis = min(axis1, axis2)
    max_axis = max(axis1, axis2)
    if min_axis == 1 and max_axis == 2:
        return tf.einsum("...ii", x)
    if min_axis == -2 and max_axis == -1:
        return tf.einsum("...ii", x)
    if min_axis == 0 and max_axis == 1:
        return tf.einsum("ii...", x)
    if min_axis == 0 and max_axis == 2:
        return tf.einsum("i...i", x)
    raise NotImplementedError()
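
# For instance, trace(tf.eye(3)) gives 3.0 via einsum("ii..."), while batched
# matrices use the last two axes:
#   trace(tf.ones((4, 2, 2)), axis1=-2, axis2=-1)  # shape [4], each entry 2.0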


# TODO (nkoep): Handle the optional axis arguments.
def diagonal(a, axis1=0, axis2=1):
    return tf.linalg.diag_part(a)


def ndim(x):
    return tf.convert_to_tensor(x).ndim


def to_ndarray(x, to_ndim, axis=0):
    if ndim(x) == to_ndim - 1:
        x = tf.expand_dims(x, axis=axis)
    return x


def empty(shape, dtype=float64):
    if not isinstance(dtype, tf.DType):
        raise ValueError("dtype must be one of Tensorflow's types")
    np_dtype = dtype.as_numpy_dtype
    return tf.convert_to_tensor(_np.empty(shape, dtype=np_dtype))


def empty_like(prototype, dtype=None):
    initial_shape = tf.shape(prototype)
    if dtype is None:
        dtype = prototype.dtype
    return empty(initial_shape, dtype=dtype)


def flip(m, axis=None):
    if not isinstance(m, tf.Tensor):
        raise ValueError("m must be a Tensorflow tensor")
    if axis is None:
        axis = range(m.ndim)
    elif not hasattr(axis, "__iter__"):
        axis = (axis,)
    return tf.reverse(m, axis=axis)


def any(x, axis=None):
    return tf.math.reduce_any(tf.cast(x, bool), axis=axis)


def _is_boolean(x):
    if isinstance(x, bool):
        return True
    if isinstance(x, (tuple, list)):
        return isinstance(x[0], bool)
    if tf.is_tensor(x):
        return x.dtype == bool
    return False


def get_mask_i_float(i, n):
    """Create a 1D array of zeros with one element at one, with floating type.

    Parameters
    ----------
    i : int
        Index of the non-zero element.
    n : int
        Length of the created array.

    Returns
    -------
    mask_i_float : array-like, shape=[n,]
        1D array of zeros except at index i, where it is one.
    """
    range_n = arange(n)
    i_float = cast(array([i]), int32)[0]
    mask_i = equal(range_n, i_float)
    mask_i_float = cast(mask_i, float32)
    return mask_i_float


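# For example, get_mask_i_float(2, 4) returns the float32 tensor
# [0., 0., 1., 0.].

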
def _mask_from_indices(indices, mask_shape, dtype=float32):
    """Create a binary mask.

    Parameters
    ----------
    indices : tuple
        Single index or tuple of indices where ones will be.
    mask_shape : tuple
        Shape of the mask.
    dtype : dtype
        Type of the mask.

    Returns
    -------
    tf_mask : array, shape=[mask_shape]
    """
    np_mask = _np.zeros(mask_shape)

    for i_index, index in enumerate(indices):
        if not isinstance(index, tuple):
            if hasattr(index, "__iter__"):
                indices[i_index] = tuple(index)
            else:
                indices[i_index] = (index,)
    for index in indices:
        if len(index) != len(mask_shape):
            raise ValueError("Indices must have the same size as shape")
    for index in indices:
        np_mask[index] = 1
    tf_mask = array(np_mask, dtype=dtype)
    return tf_mask


def _duplicate_array(x, n_samples, axis=0):
    """Stack copies of an array along an additional dimension.

    Parameters
    ----------
    x : array-like, shape=[dim]
        Initial array which will be copied.
    n_samples : int
        Number of copies of the array to create.
    axis : int, optional
        Dimension of the new array along which the copies of x are made.

    Returns
    -------
    tiled_array : array, shape=[dim[:axis], n_samples, dim[axis:]]
        Copies of x stacked along dimension axis.
    """
    multiples = _np.ones(ndim(x) + 1, dtype=_np.int32)
    multiples[axis] = n_samples
    return tile(to_ndarray(x, ndim(x) + 1, axis), multiples)


def _vectorized_mask_from_indices(
    n_samples=1, indices=None, mask_shape=None, axis=0, dtype=float32
):
    """Create a vectorized binary mask.

    Parameters
    ----------
    n_samples : int
        Number of copies of the mask along the additional dimension.
    indices : {tuple, list(tuple)}
        Single tuple, or list of tuples of indices where ones will be.
    mask_shape : tuple
        Shape of the mask.
    axis : int
        Axis along which the mask is vectorized.
    dtype : dtype
        Type of the returned array.

    Returns
    -------
    tf_mask : array, shape=[mask_shape[:axis], n_samples, mask_shape[axis:]]
    """
    mask = _mask_from_indices(indices, mask_shape, dtype)
    return _duplicate_array(mask, n_samples, axis=axis)


def _assignment_single_value(x, value, indices, mode="replace", axis=0):
    """Assign a value at given indices of an array.

    Parameters
    ----------
    x : array-like, shape=[dim]
        Initial array.
    value : float
        Value to be added.
    indices : {int, tuple, list(int), list(tuple)}
        Single int or tuple, or list of ints or tuples of indices where value
        is assigned. If the length of the tuples is shorter than ndim(x),
        value is assigned to each copy along axis.
    mode : string, optional
        Whether the assignment is done by replacing the old value, or by
        adding to it. Possible values are 'replace' and 'sum'.
    axis : int, optional
        Axis along which value is assigned, if vectorized.

    Returns
    -------
    x_new : array-like, shape=[dim]
        Copy of x where value was assigned at all indices (and possibly
        along an axis).
    """
    single_index = not isinstance(indices, list)
    if tf.is_tensor(indices):
        single_index = ndim(indices) <= 1 and sum(indices.shape) <= ndim(x)
    if single_index:
        indices = [indices]

    if isinstance(indices[0], tuple):
        use_vectorization = len(indices[0]) < ndim(x)
    elif tf.is_tensor(indices[0]) and ndim(indices[0]) >= 1:
        use_vectorization = len(indices[0]) < ndim(x)
    else:
        use_vectorization = ndim(x) > 1

    if use_vectorization:
        full_shape = shape(x).numpy()
        n_samples = full_shape[axis]
        tile_shape = list(full_shape[:axis]) + list(full_shape[axis + 1 :])
        mask = _vectorized_mask_from_indices(
            n_samples, indices, tile_shape, axis, x.dtype
        )
    else:
        mask = _mask_from_indices(indices, shape(x), x.dtype)
    if mode == "replace":
        return x + -x * mask + value * mask
    if mode == "sum":
        return x + value * mask
    raise ValueError("mode must be one of 'replace' or 'sum'")


def _assignment(x, values, indices, mode, axis):
    if _is_boolean(indices):
        if ndim(array(indices)) > 1:
            indices_tensor = tf.where(indices)
            indices = [tuple(ind) for ind in indices_tensor]
        else:
            indices_from_booleans = [
                index for index, val in enumerate(indices) if val
            ]
            indices_along_dims = [range(dim) for dim in shape(x)]
            indices_along_dims[axis] = indices_from_booleans
            indices = list(product(*indices_along_dims))
    if tf.rank(values) == 0:
        return _assignment_single_value(x, values, indices, mode, axis)
    values = cast(flatten(array(values)), x.dtype)

    single_index = not isinstance(indices, list)
    if tf.is_tensor(indices):
        single_index = ndim(indices) <= 1 and sum(indices.shape) <= ndim(x)
    if single_index:
        if len(values) > 1:
            indices = [
                tuple(list(indices[:axis]) + [i] + list(indices[axis:]))
                for i in range(x.shape[axis])
            ]
        else:
            indices = [indices]

    if len(values) != len(indices):
        raise ValueError("Either one value or as many values as indices")

    for i_index, index in enumerate(indices):
        x = _assignment_single_value(x, values[i_index], index, mode, axis)
    return x


def assignment(x, values, indices, axis=0):
    """Assign values at given indices of an array.

    Parameters
    ----------
    x : array-like, shape=[dim]
        Initial array.
    values : {float, list(float)}
        Value or list of values to be assigned.
    indices : {int, tuple(int), array-like({int, tuple, boolean})}
        Single index or array of indices where values are assigned.
        If the length of the tuples is shorter than ndim(x) by one, values
        are assigned to each copy along axis.
        If indices is a list of booleans and ndim(x) > 1, values are
        assigned across all dimensions.
    axis : int, optional
        Axis along which values are assigned, if vectorized.

    Returns
    -------
    x_new : array-like, shape=[dim]
        Copy of x with the values assigned at the given indices.

    Notes
    -----
    If a single value is provided, it is assigned at all the indices.
    If a single index is provided, and len(indices) == ndim(x) - 1, then
    values are assigned along axis.

    Examples
    --------
    Most examples translate as
    assignment(x, values, indices) <=> x[indices] = values
    Some special cases are given by vectorization.
    (Beware that copies are always returned.)
    if ndim(x) == 3, assignment(x, 1, (1, 0), 1) <=> x[1, :, 0] = 1
    if ndim(x) == 2, assignment(x, [1, 2], [(0, 1), (2, 3)]) <=>
    x[((0, 2), (1, 3))] = [1, 2]
    """
    return _assignment(x, values, indices, "replace", axis)


def assignment_by_sum(x, values, indices, axis=0):
    """Add values at given indices of an array.

    Parameters
    ----------
    x : array-like, shape=[dim]
        Initial array.
    values : {float, list(float)}
        Value or list of values to be assigned.
    indices : {int, tuple(int), array-like({int, tuple, boolean})}
        Single index or array of indices where values are assigned.
        If the length of the tuples is shorter than ndim(x) by one, values
        are assigned to each copy along axis.
        If indices is a list of booleans and ndim(x) > 1, values are
        assigned across all dimensions.
    axis : int, optional
        Axis along which values are assigned, if vectorized.

    Returns
    -------
    x_new : array-like, shape=[dim]
        Copy of x as the sum of x and the values at the given indices.

    Notes
    -----
    If a single value is provided, it is assigned at all the indices.
    If a single index is provided, and len(indices) == ndim(x) - 1, then
    values are assigned along axis.

    Examples
    --------
    Most examples translate as
    assignment_by_sum(x, values, indices) <=> x[indices] = x[indices] + values
    Some special cases are given by vectorization.
    (Beware that copies are always returned.)
    if ndim(x) == 3, assignment_by_sum(x, 1, (1, 0), 1) <=> x[1, :, 0] += 1
    if ndim(x) == 2, assignment_by_sum(x, [1, 2], [(0, 1), (2, 3)]) <=>
    x[((0, 2), (1, 3))] += [1, 2]
    """
    return _assignment(x, values, indices, "sum", axis)


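# A minimal usage sketch of the two public assignment helpers (values are
# illustrative); both return copies and never mutate their input:
#   x = zeros((2, 2))
#   assignment(x, 1.0, (0, 1))         # [[0., 1.], [0., 0.]]
#   assignment_by_sum(x, 1.0, (0, 1))  # same result, since x starts at zero

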
def array_from_sparse(indices, data, target_shape):
    """Create an array of given shape, with values at specific indices.

    The rest of the array will be filled with zeros.

    Parameters
    ----------
    indices : iterable(tuple(int))
        Index of each element which will be assigned a specific value.
    data : iterable(scalar)
        Value associated at each index.
    target_shape : tuple(int)
        Shape of the output array.

    Returns
    -------
    a : array, shape=target_shape
        Array of zeros with specified values assigned to specified indices.
    """
    return tf.sparse.to_dense(
        tf.sparse.reorder(tf.SparseTensor(indices, data, target_shape))
    )


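# For example, placing 1. at (0, 0) and 2. at (1, 2) in a 2 x 3 array:
#   array_from_sparse([(0, 0), (1, 2)], [1.0, 2.0], (2, 3))
#   # [[1., 0., 0.], [0., 0., 2.]]

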
def get_slice(x, indices):
    """Return a slice of an array, following Numpy's style.

    Parameters
    ----------
    x : array-like, shape=[dim]
        Initial array.
    indices : iterable(iterable(int))
        Indices which are kept along each axis, starting from 0.

    Returns
    -------
    slice : array-like
        Slice of x given by indices.

    Notes
    -----
    This follows Numpy's convention: indices are grouped by axis.

    Examples
    --------
    >>> a = tf.reshape(tf.convert_to_tensor(range(30)), (3, 10))
    >>> get_slice(a, ((0, 2), (8, 9)))
    <tf.Tensor: id=41, shape=(2,), dtype=int32, numpy=array([ 8, 29])>
    """
    if hasattr(indices, "shape"):
        if indices.shape.rank == 0:
            return x[indices]

        if tf.is_tensor(indices) and indices.shape[-1] == 1:
            return tf.gather_nd(x, indices)

    return tf.gather_nd(x, list(zip(*indices)))


def vectorize(x, pyfunc, multiple_args=False, dtype=None, **kwargs):
    if multiple_args:
        return tf.map_fn(lambda y: pyfunc(*y), elems=x, dtype=dtype)
    return tf.map_fn(pyfunc, elems=x, dtype=dtype)


def split(x, indices_or_sections, axis=0):
    if isinstance(indices_or_sections, int):
        return tf.split(x, indices_or_sections, axis=axis)
    indices_or_sections = _np.array(indices_or_sections)
    intervals_length = indices_or_sections[1:] - indices_or_sections[:-1]
    last_interval_length = x.shape[axis] - indices_or_sections[-1]
    if last_interval_length > 0:
        intervals_length = _np.append(intervals_length, last_interval_length)
    intervals_length = _np.insert(intervals_length, 0, indices_or_sections[0])
    return tf.split(x, num_or_size_splits=tuple(intervals_length), axis=axis)


def hsplit(x, n_splits):
    return tf.split(x, num_or_size_splits=n_splits, axis=1)


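# For example, with x of length 7 along the split axis, split(x, [2, 5])
# produces pieces of sizes [2, 3, 2], matching numpy's split-at-indices
# convention.

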
def flatten(x):
    """Collapse the tensor into 1-D.

    Following https://www.tensorflow.org/api_docs/python/tf/reshape
    """
    return tf.reshape(x, [-1])


def matmul(a, b):
    """Matrix-matrix or matrix-vector product of two tensors.

    This wraps both matvec and matmul into a single function, to mimic the
    behavior of torch's and numpy's versions of matmul.
    """
    if ndim(b) < ndim(a) and (ndim(b) == 1 or b.shape[-2] != a.shape[-1]):
        return tf.linalg.matvec(a, b)
    return tf.linalg.matmul(a, b)


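# A small illustration of the dispatch (shapes are the relevant part):
#   a = ones((3, 2, 2))
#   matmul(a, ones((2, 2)))  # matrix-matrix product, shape [3, 2, 2]
#   matmul(a, ones(2))       # matrix-vector product, shape [3, 2]

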
def outer(x, y):
    return tf.einsum("i,j->ij", x, y)


def copy(x):
    return tf.Variable(x)


def hstack(x):
    return tf.concat(x, axis=1)


def vstack(x):
    new_x = []
    for one_x in x:
        if one_x.ndim < 2:
            new_x.append(tf.expand_dims(one_x, axis=0))
        else:
            new_x.append(one_x)
    return tf.concat(new_x, axis=0)


def cast(x, dtype):
    return tf.cast(x, dtype)


def broadcast_arrays(*args, **kwargs):
    tensors = [*args]
    shapes = [t.get_shape().as_list() for t in tensors]
    max_rank = max(len(s) for s in shapes)

    for index, value in enumerate(shapes):
        if len(value) == max_rank:
            continue

        tensor = tensors[index]
        for _ in range(max_rank - len(value)):
            value.insert(0, 1)
            tensor = tf.expand_dims(tensor, axis=0)
        tensors[index] = tensor

    broadcast_shape = []
    for index in range(max_rank):
        dimensions = [s[index] for s in shapes]
        repeats = Counter(dimensions)
        if len(repeats) > 2 or (len(repeats) == 2 and 1 not in list(repeats.keys())):
            raise ValueError(
                "operands could not be broadcast together with shapes", shapes
            )
        broadcast_shape.append(max(repeats.keys()))

    for axis, dimension in enumerate(broadcast_shape):
        tensors = [
            tf.concat([t] * dimension, axis=axis)
            if t.get_shape()[axis] == 1
            else t
            for t in tensors
        ]
    return tensors


def dot(x, y):
    return tf.tensordot(x, y, axes=1)


def isclose(x, y, rtol=rtol, atol=atol):
    if not tf.is_tensor(x):
        x = tf.constant(x)
    if not tf.is_tensor(y):
        y = tf.constant(y)
    x, y = convert_to_wider_dtype([x, y])
    dtype = x.dtype

    rhs = tf.constant(atol, dtype=dtype) + tf.constant(rtol, dtype=dtype) * tf.abs(y)
    return tf.less_equal(tf.abs(tf.subtract(x, y)), rhs)


def allclose(x, y, rtol=rtol, atol=atol):
    return tf.reduce_all(isclose(x, y, rtol=rtol, atol=atol))


def eye(n, m=None):
    if m is None:
        m = n
    return tf.eye(num_rows=n, num_columns=m)


def sum(x, axis=None, keepdims=False, name=None):
    if not tf.is_tensor(x):
        x = tf.convert_to_tensor(x)
    if x.dtype == bool:
        x = cast(x, int32)
    return tf.reduce_sum(x, axis, keepdims, name)


def einsum(equation, *inputs, **kwargs):
    einsum_str = equation
    input_tensors_list = inputs

    input_tensors_list = convert_to_wider_dtype(input_tensors_list)

    einsum_list = einsum_str.split("->")
    input_str = einsum_list[0]
    output_str = einsum_list[1]

    input_str_list = input_str.split(",")

    is_ellipsis = [input_str[:3] == "..." for input_str in input_str_list]
    all_ellipsis = bool(_np.prod(is_ellipsis))

    if all_ellipsis:
        if len(input_str_list) > 2:
            raise NotImplementedError(
                "Ellipsis support not implemented for >2 input tensors"
            )
        ndims = [len(input_str[3:]) for input_str in input_str_list]

        tensor_a = input_tensors_list[0]
        tensor_b = input_tensors_list[1]
        initial_ndim_a = tensor_a.ndim
        initial_ndim_b = tensor_b.ndim
        tensor_a = to_ndarray(tensor_a, to_ndim=ndims[0] + 1)
        tensor_b = to_ndarray(tensor_b, to_ndim=ndims[1] + 1)

        n_tensor_a = tensor_a.shape[0]
        n_tensor_b = tensor_b.shape[0]

        if n_tensor_a != n_tensor_b:
            if n_tensor_a == 1:
                tensor_a = squeeze(tensor_a, axis=0)
                input_prefix_list = ["", "r"]
                output_prefix = "r"
            elif n_tensor_b == 1:
                tensor_b = squeeze(tensor_b, axis=0)
                input_prefix_list = ["r", ""]
                output_prefix = "r"
            else:
                raise ValueError("Shape mismatch for einsum.")
        else:
            input_prefix_list = ["r", "r"]
            output_prefix = "r"

        input_str_list = [
            input_str.replace("...", prefix)
            for input_str, prefix in zip(input_str_list, input_prefix_list)
        ]
        output_str = output_str.replace("...", output_prefix)

        input_str = input_str_list[0] + "," + input_str_list[1]
        einsum_str = input_str + "->" + output_str

        result = tf.einsum(einsum_str, tensor_a, tensor_b, **kwargs)

        cond = (
            n_tensor_a == n_tensor_b == 1
            and initial_ndim_a != tensor_a.ndim
            and initial_ndim_b != tensor_b.ndim
        )

        if cond:
            result = squeeze(result, axis=0)
        return result

    return tf.einsum(equation, *input_tensors_list, **kwargs)


def transpose(x, axes=None):
    return tf.transpose(x, perm=axes)


def all(x, axis=None):
    return tf.math.reduce_all(tf.cast(x, bool), axis=axis)


def cumsum(a, axis=None):
    if axis is None:
        return tf.math.cumsum(flatten(a), axis=0)
    return tf.math.cumsum(a, axis=axis)


def cumprod(a, axis=None):
    if axis is None:
        return tf.math.cumprod(flatten(a), axis=0)
    return tf.math.cumprod(a, axis=axis)


# (sait) there is tf.experimental.tril (we can use it once it moves to stable)
def tril(mat, k=0):
    if k not in (0, -1):
        raise NotImplementedError("Only k=0 and k=-1 supported so far")
    tril = tf.linalg.band_part(mat, -1, 0)
    if k == 0:
        return tril
    zero_diag = tf.zeros(mat.shape[:-1])
    return tf.linalg.set_diag(tril, zero_diag)


# TODO(sait) use tf.experimental.triu once it becomes stable.
def triu(mat, k=0):
    if k not in (0, 1):
        raise NotImplementedError("Only k=0 and k=1 supported so far")
    triu = tf.linalg.band_part(mat, 0, -1)
    if k == 0:
        return triu
    zero_diag = tf.zeros(mat.shape[:-1])
    return tf.linalg.set_diag(triu, zero_diag)


def diag_indices(*args, **kwargs):
    return tuple(map(tf.convert_to_tensor, _np.diag_indices(*args, **kwargs)))


def tril_indices(*args, **kwargs):
    return tuple(map(tf.convert_to_tensor, _np.tril_indices(*args, **kwargs)))


def triu_indices(*args, **kwargs):
    return tuple(map(tf.convert_to_tensor, _np.triu_indices(*args, **kwargs)))


def unique(x):
    return tf.unique(x).y


def where(condition, x=None, y=None):
    if x is None and y is None:
        return tf.where(condition)
    if not tf.is_tensor(x):
        x = tf.constant(x)
    if not tf.is_tensor(y):
        y = tf.constant(y)
    y = cast(y, x.dtype)
    return tf.where(condition, x, y)


def tril_to_vec(x, k=0):
    n = x.shape[-1]
    axis = 1 if x.ndim == 3 else 0
    mask = tf.ones((n, n))
    mask_a = tf.linalg.band_part(mask, -1, 0)
    if k < 0:
        mask_b = tf.linalg.band_part(mask, -k - 1, 0)
    else:
        mask_b = tf.zeros_like(mask_a)
    mask = tf.cast(mask_a - mask_b, dtype=tf.bool)
    return tf.boolean_mask(x, mask, axis=axis)


def triu_to_vec(x, k=0):
    n = x.shape[-1]
    axis = 1 if x.ndim == 3 else 0
    mask = tf.ones((n, n))
    mask_a = tf.linalg.band_part(mask, 0, -1)
    if k > 0:
        mask_b = tf.linalg.band_part(mask, 0, k - 1)
    else:
        mask_b = tf.zeros_like(mask_a)
    mask = tf.cast(mask_a - mask_b, dtype=tf.bool)
    return tf.boolean_mask(x, mask, axis=axis)


def tile(x, multiples):
    t1 = tf.ones(len(multiples) - len(tf.shape(x)))
    t1 = tf.cast(t1, tf.int32)
    t2 = tf.shape(x)
    x_reshape = tf.reshape(x, tf.concat([t1, t2], axis=0))
    return tf.tile(x_reshape, multiples)


def vec_to_diag(vec):
    return tf.linalg.diag(vec)


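# Illustrative round trips between matrices and their triangular parts:
#   m = array([[1.0, 2.0], [3.0, 4.0]])
#   triu_to_vec(m)                  # [1., 2., 4.]
#   tril_to_vec(m, k=-1)            # [3.], strictly lower part only
#   vec_to_diag(array([1.0, 2.0]))  # [[1., 0.], [0., 2.]]

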
def vec_to_triu(vec):
    """Take vec and form a strictly upper triangular matrix.

    Parameters
    ----------
    vec : array_like, shape=[..., n]

    Returns
    -------
    triu : array_like, shape=[..., k, k]
        where k is (1 + sqrt(1 + 8 * n)) / 2
    """
    n = vec.shape[-1]
    triu_shape = vec.shape + (n,)
    _ones = tf.ones(triu_shape)
    vec = tf.reshape(vec, [-1])
    mask_a = tf.linalg.band_part(_ones, 0, -1)
    mask_b = tf.linalg.band_part(_ones, 0, 0)
    mask = tf.subtract(mask_a, mask_b)
    non_zero = tf.not_equal(mask, tf.constant(0.0))
    indices = tf.where(non_zero)
    sparse = tf.SparseTensor(indices, values=vec, dense_shape=triu_shape)
    return tf.sparse.to_dense(sparse)


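# Example with a length-3 vector, for which the strictly upper part of a
# 3 x 3 matrix has exactly 3 slots:
#   vec_to_triu(array([1.0, 2.0, 3.0]))
#   # [[0., 1., 2.], [0., 0., 3.], [0., 0., 0.]]

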
def vec_to_tril(vec):
    """Take vec and form a strictly lower triangular matrix.

    Parameters
    ----------
    vec : array_like, shape=[..., n]

    Returns
    -------
    tril : array_like, shape=[..., k, k]
        where k is (1 + sqrt(1 + 8 * n)) / 2
    """
    n = vec.shape[-1]
    tril_shape = vec.shape + (n,)
    _ones = tf.ones(tril_shape)
    vec = tf.reshape(vec, [-1])
    mask_a = tf.linalg.band_part(_ones, -1, 0)
    mask_b = tf.linalg.band_part(_ones, 0, 0)
    mask = tf.subtract(mask_a, mask_b)
    non_zero = tf.not_equal(mask, tf.constant(0.0))
    indices = tf.where(non_zero)
    sparse = tf.SparseTensor(indices, values=vec, dense_shape=tril_shape)
    return tf.sparse.to_dense(sparse)


def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
    """Build matrix from given components.

    Forms a matrix from diagonal, strictly upper triangular and
    strictly lower triangular parts.

    Parameters
    ----------
    diag : array_like, shape=[..., n]
    tri_upp : array_like, shape=[..., (n * (n - 1)) / 2]
    tri_low : array_like, shape=[..., (n * (n - 1)) / 2]

    Returns
    -------
    mat : array_like, shape=[..., n, n]
    """
    triu_mat = vec_to_triu(tri_upp)
    tril_mat = vec_to_tril(tri_low)
    triu_tril_mat = triu_mat + tril_mat
    mat = tf.linalg.set_diag(triu_tril_mat, diag)
    return mat


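# Worked example for n = 3 (row-major triangular parts):
#   mat_from_diag_triu_tril(
#       array([1.0, 2.0, 3.0]),  # diagonal
#       array([4.0, 5.0, 6.0]),  # strictly upper part
#       array([7.0, 8.0, 9.0]),  # strictly lower part
#   )
#   # [[1., 4., 5.], [7., 2., 6.], [8., 9., 3.]]

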
def _ravel_multi_index(multi_index, shape):
    strides = tf.math.cumprod(shape, exclusive=True, reverse=True)
    return tf.reduce_sum(multi_index * tf.expand_dims(strides, 1), axis=0)


def ravel_tril_indices(n, k=0, m=None):
    if m is None:
        size = (n, n)
    else:
        size = (n, m)
    idxs = tril_indices(n, k, m)
    return _ravel_multi_index(idxs, size)


def kron(a, b):
    # LinearOperatorKronecker expects LinearOperator inputs, so the dense
    # matrices are wrapped before taking the Kronecker product.
    return tf.linalg.LinearOperatorKronecker(
        [tf.linalg.LinearOperatorFullMatrix(a), tf.linalg.LinearOperatorFullMatrix(b)]
    ).to_dense()


def take(a, indices, axis=0):
    return tf.gather(a, indices, axis=axis)
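
# For instance, ravel_tril_indices(3) is expected to give [0, 3, 4, 6, 7, 8]:
# the flat (row-major) positions of the lower-triangular entries of a
# 3 x 3 matrix.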