gaussian_toolbox.utils package
Submodules
gaussian_toolbox.utils.dataclass module
A JAX- and dm-tree-friendly dataclass implementation that reuses Python dataclasses.
- mappable_dataclass(cls)
Exposes a dataclass as a collections.abc.Mapping descendant. This allows dataclasses to be traversed by methods from the dm-tree library.
NOTE: this changes the dataclass constructor to a dict-style one (i.e. positional arguments are not supported; however, generators and other iterables can be passed, as with the dict constructor).
- Parameters:
cls – A dataclass to mutate.
- Returns:
Mutated dataclass implementing the collections.abc.Mapping interface.
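The Mapping behaviour described above can be illustrated with a minimal, stdlib-only sketch. It is not the module's actual implementation: `mappable_dataclass_sketch` and `Point` are hypothetical names, and the dict-style constructor is modelled on the NOTE above (at most one positional argument, interpreted as by `dict(...)`).

```python
import dataclasses
import functools
from collections.abc import Mapping


def mappable_dataclass_sketch(cls):
    """Hypothetical sketch: return a Mapping subclass of dataclass `cls`."""
    if not dataclasses.is_dataclass(cls):
        raise TypeError(f"{cls!r} is not a dataclass")
    field_names = tuple(f.name for f in dataclasses.fields(cls))
    orig_init = cls.__init__

    @functools.wraps(orig_init)
    def dict_style_init(self, *args, **kwargs):
        # Like dict(...): at most one positional argument (a mapping or an
        # iterable of key/value pairs), merged with keyword arguments.
        if len(args) > 1:
            raise TypeError("positional field values are not supported")
        orig_init(self, **dict(*args, **kwargs))

    def getitem(self, key):
        # Only dataclass fields are keys; anything else is a KeyError.
        if key not in field_names:
            raise KeyError(key)
        return getattr(self, key)

    namespace = {
        "__init__": dict_style_init,
        "__getitem__": getitem,
        "__iter__": lambda self: iter(field_names),
        "__len__": lambda self: len(field_names),
    }
    # Add Mapping to the bases; it supplies keys(), items(), get(), __contains__.
    return type(cls.__name__, (cls, Mapping), namespace)


@mappable_dataclass_sketch
@dataclasses.dataclass
class Point:
    x: float
    y: float


p = Point(x=1.0, y=2.0)
q = Point({"x": 1.0, "y": 2.0})
r = Point((k, v) for k, v in [("x", 1.0), ("y", 2.0)])
print(dict(p))  # {'x': 1.0, 'y': 2.0}
```

Because the resulting class is a `Mapping`, `dict(p)` and `"x" in p` work, while `Point(1.0, 2.0)` raises `TypeError`.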
- dataclass(cls=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, mappable_dataclass=True, kw_only=False)
JAX-friendly wrapper for dataclasses.dataclass(). This wrapper registers new dataclasses with JAX so that tree utils operate correctly. Additionally, a replace method is provided, making it easy to operate on instances when the class is made immutable (frozen=True).
- Parameters:
cls – A class to decorate.
init – See dataclasses.dataclass().
repr – See dataclasses.dataclass().
eq – See dataclasses.dataclass().
order – See dataclasses.dataclass().
unsafe_hash – See dataclasses.dataclass().
frozen – See dataclasses.dataclass().
mappable_dataclass – If True (the default), methods that make the class implement the collections.abc.Mapping interface will be generated, and the class will include collections.abc.Mapping in its base classes. True is the default because being an instance of Mapping makes the dataclass compatible with, e.g., jax.tree_util.tree_* methods, the tree library, or methods related to tensorflow/python/utils/nest.py. As a side effect, e.g. np.testing.assert_array_equal will only check that the field names are equal, not the content. Use chex.assert_tree_* instead.
- Returns:
A JAX-friendly dataclass.
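The docs do not show the signature of the generated replace method; the sketch below assumes it behaves like the standard library's `dataclasses.replace`, returning a modified copy of a frozen instance. `GaussianParams` is a hypothetical class used only for illustration, and the standard `dataclasses.dataclass` stands in for the wrapper.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class GaussianParams:  # hypothetical example class
    mean: float
    variance: float

    def replace(self, **changes):
        # Frozen instances reject attribute assignment, so an "update" builds
        # a new instance; dataclasses.replace copies every unchanged field.
        return dataclasses.replace(self, **changes)


prior = GaussianParams(mean=0.0, variance=1.0)
posterior = prior.replace(variance=0.5)
print(prior)      # GaussianParams(mean=0.0, variance=1.0)
print(posterior)  # GaussianParams(mean=0.0, variance=0.5)
```

The original instance is left untouched, which is what makes this pattern safe to use inside JAX-transformed functions.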
- register_dataclass_type_with_jax_tree_util(data_class)
Register an existing dataclass so JAX knows how to handle it.
This means that functions in jax.tree_util operate over the fields of the dataclass. See https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees for further information.
- Parameters:
data_class – A class created using dataclasses.dataclass. It must be constructible from keyword arguments corresponding to the members exposed in instance.__dict__.
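As a rough illustration of what such registration involves, the sketch below registers a dataclass through the public `jax.tree_util.register_pytree_node` API: an instance is flattened into the values of `instance.__dict__` and rebuilt from keyword arguments, which is why the keyword-constructibility requirement above matters. `register_sketch` and `Params` are hypothetical names; this is not the module's actual code.

```python
import dataclasses

import jax


def register_sketch(data_class):
    """Hypothetical stand-in: register a dataclass as a JAX pytree node."""

    def flatten(instance):
        # Sort the keys so the leaf order is deterministic; the key tuple is
        # hashable auxiliary data, the values are the pytree children.
        keys = tuple(sorted(instance.__dict__))
        return [instance.__dict__[k] for k in keys], keys

    def unflatten(keys, values):
        # Rebuild from keyword arguments named after the instance attributes.
        return data_class(**dict(zip(keys, values)))

    jax.tree_util.register_pytree_node(data_class, flatten, unflatten)
    return data_class


@register_sketch
@dataclasses.dataclass
class Params:
    b: float
    a: float


params = Params(b=2.0, a=1.0)
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)
leaves = jax.tree_util.tree_leaves(params)
print(doubled)  # Params(b=4.0, a=2.0)
print(leaves)   # [1.0, 2.0]
```

After registration, tree utilities such as `tree_map` operate on the fields and return a new `Params` instance.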
gaussian_toolbox.utils.jax_minimize_wrapper module
A collection of helper functions for optimization with JAX.
UPDATE: This module is obsolete now that jax.scipy.optimize.minimize exists.
- minimize(fun, x0, method=None, args=(), bounds=None, constraints=(), tol=None, callback=None, options=None)
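For new code, the built-in JAX optimizer can be called directly. A minimal sketch, assuming a recent JAX release in which `jax.scipy.optimize.minimize(fun, x0, args=(), *, method, tol=None, options=None)` requires an explicit `method` and supports `"BFGS"`. Unlike the wrapper's signature, it takes no `bounds`, `constraints`, or `callback` arguments, so code relying on those may not port directly. The quadratic objective is a hypothetical example.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def objective(x):
    # Convex quadratic whose unique minimum is at x = (1, -2).
    return jnp.sum((x - jnp.array([1.0, -2.0])) ** 2)


# x0 must be a 1-D array; method is a required keyword argument.
result = minimize(objective, jnp.zeros(2), method="BFGS")
print(result.x, result.success)
```

The returned object exposes the optimum as `result.x` and a convergence flag as `result.success`.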
gaussian_toolbox.utils.linalg module
- invert_matrix(A)
- Return type:
Tuple[Array, Array]
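The documentation does not say what the two returned arrays are. A common pattern in Gaussian computations is to return the inverse together with the log-determinant, both obtained from a Cholesky factor; the sketch below assumes exactly that, for a single symmetric positive-definite matrix. `invert_matrix_sketch` is a hypothetical name, and the real function's output order, batching, and numerical method may differ.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def invert_matrix_sketch(A):
    """Hypothetical: return (A^{-1}, log|A|) for a symmetric positive-definite A."""
    L = jnp.linalg.cholesky(A)  # A = L @ L.T with L lower triangular
    # Solve A X = I using the Cholesky factor instead of a general inverse.
    A_inv = cho_solve((L, True), jnp.eye(A.shape[-1], dtype=A.dtype))
    # |A| = |L|^2 = prod(diag(L))^2, so log|A| = 2 * sum(log(diag(L))).
    ln_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    return A_inv, ln_det


A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
A_inv, ln_det = invert_matrix_sketch(A)
print(A_inv)   # approximately [[0.375, -0.25], [-0.25, 0.5]]
print(ln_det)  # approximately log(8) = 2.0794
```

Working from the Cholesky factor avoids computing the determinant directly, which can overflow or underflow for large matrices.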
- invert_diagonal(A)
- Return type:
Tuple[Array, Array]
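The documentation does not say what the two returned arrays are. Assuming they are the inverse and the log-determinant, a diagonal matrix needs no factorization: the inverse is elementwise and the log-determinant is a sum of logs. `invert_diagonal_sketch` is hypothetical and assumes a 2-D matrix with a positive diagonal; the real function may expect a different input layout, such as batched matrices.

```python
import jax.numpy as jnp


def invert_diagonal_sketch(A):
    """Hypothetical: return (A^{-1}, log|A|) for a diagonal A with positive diagonal."""
    d = jnp.diag(A)  # extract the diagonal entries
    A_inv = jnp.diag(1.0 / d)  # inverse of a diagonal matrix is elementwise
    ln_det = jnp.sum(jnp.log(d))  # |A| is the product of the diagonal entries
    return A_inv, ln_det


A = jnp.diag(jnp.array([2.0, 4.0]))
A_inv, ln_det = invert_diagonal_sketch(A)
print(A_inv)   # diag(0.5, 0.25)
print(ln_det)  # approximately log(8) = 2.0794
```

This costs O(D) work instead of the O(D^3) of a general inversion.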