BlackJAX: Composable Bayesian inference in JAX
Authors:
Alberto Cabezas,
Adrien Corenflos,
Junpeng Lao,
Rémi Louf,
Antoine Carnec,
Kaustubh Chaudhari,
Reuben Cohn-Gordon,
Jeremie Coullon,
Wei Deng,
Sam Duffield,
Gerardo Durán-Martín,
Marcin Elantkowski,
Dan Foreman-Mackey,
Michele Gregori,
Carlos Iguaran,
Ravin Kumar,
Martin Lysy,
Kevin Murphy,
Juan Camilo Orduz,
Karm Patel,
Xi Wang,
Rob Zinkov
Abstract:
BlackJAX is a library implementing sampling and variational inference algorithms commonly used in Bayesian computation. It is designed for ease of use, speed, and modularity by taking a functional approach to the algorithms' implementation. BlackJAX is written in Python, using JAX to compile and run NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs. The library integrates well with probabilistic programming languages by working directly with the (un-normalized) target log density function. BlackJAX is intended as a collection of low-level, composable implementations of basic statistical 'atoms' that can be combined to perform well-defined Bayesian inference, but also provides high-level routines for ease of use. It is designed for users who need cutting-edge methods, researchers who want to create complex sampling methods, and people who want to learn how these work.
Submitted 22 February, 2024; v1 submitted 16 February, 2024;
originally announced February 2024.
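
Illustrative sketch (not part of the listing): the abstract describes a functional design in which samplers work directly with an un-normalized target log density and are assembled from small composable pieces. The plain-JAX example below shows that general pattern with a random-walk Metropolis kernel. It does not use or reproduce BlackJAX's own API; the names `RWState`, `init`, `build_kernel`, and `run_chain` are hypothetical and chosen only for this illustration, and the standard-normal target is an assumed toy example.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RWState(NamedTuple):
    """Hypothetical sampler state: current position and its log density."""
    position: jax.Array
    logdensity: jax.Array


def init(position, logdensity_fn):
    # Cache the log density so the kernel evaluates the target once per step.
    return RWState(position, logdensity_fn(position))


def build_kernel(logdensity_fn, step_size):
    """Return a pure function (rng_key, state) -> (new_state, accepted)."""

    def kernel(rng_key, state):
        key_proposal, key_accept = jax.random.split(rng_key)
        # Symmetric Gaussian proposal centred on the current position.
        noise = jax.random.normal(key_proposal, state.position.shape)
        proposal = state.position + step_size * noise
        proposal_logdensity = logdensity_fn(proposal)
        # Metropolis rule: only the log density difference is needed,
        # so the normalizing constant of the target cancels.
        log_ratio = proposal_logdensity - state.logdensity
        accepted = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        new_state = RWState(
            jnp.where(accepted, proposal, state.position),
            jnp.where(accepted, proposal_logdensity, state.logdensity),
        )
        return new_state, accepted

    return kernel


def run_chain(rng_key, kernel, initial_state, num_steps):
    """Compose the kernel with jax.lax.scan to produce a whole chain."""

    def one_step(state, key):
        state, accepted = kernel(key, state)
        return state, (state.position, accepted)

    keys = jax.random.split(rng_key, num_steps)
    _, (positions, accepted) = jax.lax.scan(one_step, initial_state, keys)
    return positions, accepted


def logdensity_fn(x):
    # Assumed toy target: un-normalized standard normal log density.
    return -0.5 * jnp.sum(x**2)


kernel = build_kernel(logdensity_fn, step_size=1.0)
initial_state = init(jnp.zeros(2), logdensity_fn)
positions, accepted = run_chain(jax.random.PRNGKey(0), kernel, initial_state, 2000)
acceptance_rate = jnp.mean(accepted)
```

Because the kernel is a pure function of a key and a state, it can be wrapped in `jax.jit`, mapped over several chains with `jax.vmap`, or swapped for a different kernel without changing `run_chain`. That separation is the kind of composability the abstract refers to, shown here under the assumptions stated above.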