JAX
The Differential Programming with JAX course is nice. Meta Optimal Transport is a nice JAX repo to run and study.
Robert Lange has nice JAX repos.
Notes
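A small illustration of the differentiable programming idea behind the course above: `jax.grad` turns an ordinary numerical Python function into one that returns its gradient, which can then drive plain gradient descent. This is a minimal sketch, not material from the course. The toy data (points on y = 3x + 1), the learning rate and the step count are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Toy data on the line y = 3x + 1 (illustrative assumption).
xs = jnp.linspace(-1.0, 1.0, 20)
ys = 3.0 * xs + 1.0

def loss(params):
    """Mean squared error of a linear model y = w * x + b."""
    w, b = params
    preds = w * xs + b
    return jnp.mean((preds - ys) ** 2)

# jax.grad returns d(loss)/d(params) with the same structure as params
# (here a tuple of two scalars); jax.jit compiles it with XLA.
grad_loss = jax.jit(jax.grad(loss))

params = (0.0, 0.0)
learning_rate = 0.5
for _ in range(200):
    g_w, g_b = grad_loss(params)
    # Plain gradient descent step on each parameter.
    params = (params[0] - learning_rate * g_w, params[1] - learning_rate * g_b)

print(params)  # approaches (3.0, 1.0)
```

Nothing in `loss` is written with differentiation in mind; the derivative comes from the JAX transform, which is what makes larger differentiable pipelines composable.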
Links
- audax - Home for audio ML in JAX. Has common features, learnable frontends, pretrained supervised and self-supervised models.
- tinygp - Extremely lightweight library for building Gaussian Process models in Python, built on top of jax.
- GPJax - Didactic Gaussian process package for researchers in Jax.
- Mctx - Monte Carlo tree search in JAX.
- Pipelined Swarm Training - Swarm training framework using Haiku + JAX + Ray for layer parallel transformer language models on unreliable, heterogeneous nodes.
- JAX MuZero - JAX implementation of the MuZero agent.
- Jax Influence - Scalable implementation of Influence Functions in JAX.
- BlackJAX - Library of samplers for JAX that works on CPU as well as GPU. (Twitter)
- GPax - Jax/Flax codebase for Gaussian processes including meta and multi-task Gaussian processes.
- jax-fenics-adjoint - Differentiable interface to FEniCS/Firedrake for JAX using dolfin-adjoint/pyadjoint.
- jax-ekf - Generic EKF, with support for non-Euclidean manifolds.
- PaLM - Jax - Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways - in Jax.
- Pre-trained image classification models for Jax/Haiku
- Flaxformer: transformer architectures in JAX/Flax
- KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX
- flowjax - Normalizing flow implementations in jax.
- Jax3D - Library for neural rendering in JAX that aims to be a nimble NeRF ecosystem.
- DALL·E 2 in JAX
- JAXNS - Nested sampling in JAX.
- AUX - Audio processing library in JAX, for JAX.
- Nice DeepMind Jax libraries
- Machine Learning with JAX - From Zero to Hero (2021)
- Flax - Neural network library for JAX designed for flexibility. (Docs)
- JAX talks by HuggingFace
- Homomorphic Encryption in JAX
- JAX implementation of Learning to learn by gradient descent by gradient descent
- Normalizing Flows in JAX
- Big Vision - Designed for training large-scale vision models on Cloud TPU VMs. Based on Jax/Flax libraries.
- Jax vs. Julia (Vs PyTorch) (2022) (HN)
- minGPT in JAX
- flaxvision - Selection of neural network models ported from torchvision for JAX & Flax.
- JAX version of clip guided diffusion scripts
- Functorch - Jax-like composable function transforms for PyTorch. (HN)
- Ninjax - Module system for JAX that offers full state access and allows to easily combine modules from other libraries.
- Functional Transformer - Pure-functional implementation of a machine learning transformer model in Python/JAX.
- JAX + Units - Provides an interface between JAX and Pint to allow JAX to support operations with units.
- Infinite Recommendation Networks (∞-AE) in JAX
- Differential Programming with JAX course (Code)
- Algorithms for Privacy-Preserving Machine Learning in JAX
- Connex - Small JAX library built on Equinox whose aim is to incorporate artificial analogues of biological neural network attributes into deep learning research and architecture design.
- Rax - Composable Learning to Rank using JAX.
- JaX is faster than PyTorch but harder to debug
- JAX Meta Learning - Collection of meta-learning algorithms in JAX.
- Gymnax - RL Environments in JAX.
- Pax - Framework to configure and run machine learning experiments on top of Jax.
- SymPy2Jax - Turn SymPy expressions into trainable JAX expressions.
- JAX Typing - Type annotations and runtime checking for shape and dtype of JAX arrays, and PyTrees.
- CoDeX - Data compression in JAX.
- DiBS - Python JAX implementation of DiBS, a fully differentiable method for joint Bayesian inference of the DAG and parameters of general, causal Bayesian networks.
- Generative Adversarial Networks in JAX
- Neural implicit queries - Perform geometric queries on neural implicit surfaces like ray casting, intersection testing, fast mesh extraction, closest points, and more.
- CLIP-JAX - Train CLIP models using JAX and transformers.
- BLOOM Inference in JAX
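Several entries above (Functorch, for instance) describe "JAX-like composable function transforms". A minimal sketch of what that composition looks like in JAX itself: `jax.grad` differentiates a scalar function, `jax.vmap` maps the result over a batch axis, and `jax.jit` compiles the whole stack. The function `f` is an arbitrary choice made only for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    """Scalar function f(x) = sin(x) * x**2, chosen only for illustration."""
    return jnp.sin(x) * x ** 2

# grad: f'(x) = cos(x) * x**2 + 2 * x * sin(x)
df = jax.grad(f)

# vmap: apply df across a leading batch axis without a Python loop
batched_df = jax.vmap(df)

# jit: compile the composed transform with XLA
fast_batched_df = jax.jit(batched_df)

xs = jnp.linspace(0.0, 2.0, 5)
print(fast_batched_df(xs))

# Transforms nest freely: the second derivative is grad applied twice.
d2f = jax.grad(jax.grad(f))
print(d2f(1.0))  # analytically sin(1) + 4 * cos(1)
```

Each transform takes a function and returns a function, so they can be stacked in any order that makes sense for the shapes involved.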