gaussian_toolbox.utils package#

Submodules#

gaussian_toolbox.utils.dataclass module#

JAX/dm-tree friendly dataclass implementation reusing Python dataclasses.

mappable_dataclass(cls)#

Exposes a dataclass as a collections.abc.Mapping descendant.

This allows dataclasses to be traversed by functions from the dm-tree library.

NOTE: This changes the dataclass constructor to dict-style construction (i.e. positional args aren’t supported; however, generators/iterables can be used).

Parameters:

cls – A dataclass to mutate.

Returns:

Mutated dataclass implementing collections.abc.Mapping interface.
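
The behaviour described above can be illustrated with a minimal standard-library sketch. It is not the implementation used by gaussian_toolbox: the helper name mapping_view and the Params class are hypothetical, and the dict-style constructor is modelled on the NOTE above.

```python
import dataclasses
from collections.abc import Mapping


def mapping_view(cls):
    """Hypothetical helper: derive a Mapping-compatible subclass of a dataclass."""
    field_names = tuple(f.name for f in dataclasses.fields(cls))
    orig_init = cls.__init__

    def _init(self, *args, **kwargs):
        # Dict-style construction: accepts whatever dict() accepts, so
        # plain positional field values are rejected with a TypeError.
        orig_init(self, **dict(*args, **kwargs))

    def _getitem(self, key):
        if key not in field_names:
            raise KeyError(key)
        return getattr(self, key)

    def _iter(self):
        return iter(field_names)

    def _len(self):
        return len(field_names)

    namespace = {
        "__init__": _init,
        "__getitem__": _getitem,
        "__iter__": _iter,
        "__len__": _len,
    }
    return type(cls.__name__, (cls, Mapping), namespace)


@mapping_view
@dataclasses.dataclass
class Params:
    mean: float
    scale: float


p = Params(mean=0.0, scale=2.0)                 # keyword construction
q = Params([("mean", 0.0), ("scale", 2.0)])     # iterable of (name, value) pairs
as_dict = dict(p)                               # works because p is a Mapping
```

Because the instance satisfies the Mapping interface, libraries that walk nested mappings can treat it like a dictionary keyed by field name.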

dataclass(cls=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, mappable_dataclass=True, kw_only=False)#

JAX-friendly wrapper for dataclasses.dataclass().

This wrapper registers new dataclasses with JAX so that tree utils operate correctly. Additionally, a replace method is provided, making it easy to operate on the class when it is made immutable (frozen=True).

Parameters:
  • cls – A class to decorate.

  • init – See dataclasses.dataclass().

  • repr – See dataclasses.dataclass().

  • eq – See dataclasses.dataclass().

  • order – See dataclasses.dataclass().

  • unsafe_hash – See dataclasses.dataclass().

  • frozen – See dataclasses.dataclass().

  • kw_only – See dataclasses.dataclass().

  • mappable_dataclass – If True (the default), methods to make the class implement the collections.abc.Mapping interface will be generated and the class will include collections.abc.Mapping in its base classes. True is the default, because being an instance of Mapping makes chex.dataclass compatible with e.g. jax.tree_util.tree_* methods, the tree library, or methods related to tensorflow/python/utils/nest.py. As a side-effect, e.g. np.testing.assert_array_equal will only check the field names are equal and not the content. Use chex.assert_tree_* instead.

Returns:

A JAX-friendly dataclass.
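
To show why a replace method is convenient on a frozen class, here is a small sketch using only the standard library. The Density class and its fields are hypothetical, and the replace method shown simply delegates to dataclasses.replace; the wrapper documented above may implement it differently.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Density:
    mean: float
    variance: float

    def replace(self, **changes):
        # Return a new instance with the given fields changed;
        # the original instance is left untouched.
        return dataclasses.replace(self, **changes)


d = Density(mean=0.0, variance=1.0)
d2 = d.replace(variance=4.0)

try:
    d.variance = 4.0  # direct assignment is rejected on frozen dataclasses
    assignment_failed = False
except dataclasses.FrozenInstanceError:
    assignment_failed = True
```

The functional update style fits JAX well, since transformed functions are expected to return new values rather than mutate their inputs.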

register_dataclass_type_with_jax_tree_util(data_class)#

Register an existing dataclass so JAX knows how to handle it.

This means that functions in jax.tree_util operate over the fields of the dataclass. See https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees for further information.

Parameters:

data_class – A class created using dataclasses.dataclass. It must be constructable from keyword arguments corresponding to the members exposed in instance.__dict__.
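
The registration step can be sketched with the public jax.tree_util.register_pytree_node function. This is an illustration under assumptions, not the body of register_dataclass_type_with_jax_tree_util: the GaussianParams class and the helper names _flatten and _unflatten are hypothetical.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class GaussianParams:
    mean: jnp.ndarray
    scale: jnp.ndarray


def _flatten(obj):
    # Children are the field values; the field names travel as static aux data.
    names = tuple(f.name for f in dataclasses.fields(obj))
    children = tuple(getattr(obj, name) for name in names)
    return children, names


def _unflatten(names, children):
    # Rebuild the instance from keyword arguments, as the docstring requires.
    return GaussianParams(**dict(zip(names, children)))


jax.tree_util.register_pytree_node(GaussianParams, _flatten, _unflatten)

params = GaussianParams(mean=jnp.zeros(2), scale=jnp.ones(2))
# tree_map now operates over the fields and returns a GaussianParams instance.
doubled = jax.tree_util.tree_map(lambda x: 2.0 * x, params)
leaves = jax.tree_util.tree_leaves(params)
```

Rebuilding from keyword arguments is why the class must be constructable from keywords that match its fields.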

gaussian_toolbox.utils.jax_minimize_wrapper module#

A collection of helper functions for optimization with JAX.

UPDATE: This is obsolete now that jax.scipy.optimize.minimize exists!

minimize(fun, x0, method=None, args=(), bounds=None, constraints=(), tol=None, callback=None, options=None)#
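
Since the note above points to jax.scipy.optimize.minimize as the replacement, a short sketch of that function may help. The objective and target values are made up for illustration. The JAX function takes a narrower set of arguments than the wrapper signature above; method is required there, and BFGS is the method its documentation lists as supported.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Hypothetical target for a simple quadratic objective.
target = jnp.array([1.0, -2.0])


def objective(x):
    # Squared distance to the target; minimized at x == target.
    return jnp.sum((x - target) ** 2)


x0 = jnp.zeros(2)  # the initial guess must be a 1-D float array
result = minimize(objective, x0, method="BFGS")
x_opt = result.x   # estimated minimizer
```

Bounds, constraints, and callbacks from the wrapper signature have no counterpart in this call, so code relying on them would need a different optimizer.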

gaussian_toolbox.utils.linalg module#

invert_matrix(A)#
Return type:

Tuple[Array, Array]

invert_diagonal(A)#
Return type:

Tuple[Array, Array]

Module contents#