GitHub topics: jax
astro-informatics/s2fft
S2FFT: Differentiable and accelerated spherical transforms
Language: Python - Size: 58.1 MB - Last synced at: 1 day ago - Pushed at: 1 day ago - Stars: 150 - Forks: 9

MaxMSun/lqrax
JAX-enabled continuous-time LQR solver
Language: Python - Size: 2.57 MB - Last synced at: 16 days ago - Pushed at: 16 days ago - Stars: 1 - Forks: 0

debangshu-mukherjee/rheedium
A JAX-based package for differentiable RHEED simulations and reconstructions.
Language: Python - Size: 255 KB - Last synced at: 4 days ago - Pushed at: 16 days ago - Stars: 0 - Forks: 0

google-deepmind/dks
Multi-framework implementation of Deep Kernel Shaping and Tailored Activation Transformations, which are methods that modify neural network models (and their initializations) to make them easier to train.
Language: Python - Size: 1.23 MB - Last synced at: 16 days ago - Pushed at: 16 days ago - Stars: 70 - Forks: 5

google-research/sofima
Scalable Optical Flow-based Image Montaging and Alignment
Language: Jupyter Notebook - Size: 4.73 MB - Last synced at: 9 days ago - Pushed at: 3 months ago - Stars: 69 - Forks: 16

mmarcinmichal/trax Fork of google/trax
Trax — Deep Learning with Clear Code and Speed
Language: Python - Size: 162 MB - Last synced at: 17 days ago - Pushed at: 17 days ago - Stars: 0 - Forks: 0

evanatyourservice/psgd_jax
Implementation of PSGD optimizer in JAX
Language: Python - Size: 329 KB - Last synced at: 4 days ago - Pushed at: 4 months ago - Stars: 33 - Forks: 2

arpastrana/jax_fdm
Auto-differentiable and hardware-accelerated force density method
Language: Python - Size: 114 MB - Last synced at: 8 days ago - Pushed at: 3 months ago - Stars: 88 - Forks: 6

samuela/git-re-basin
Code release for "Git Re-Basin: Merging Models modulo Permutation Symmetries"
Language: Python - Size: 1.7 MB - Last synced at: 3 days ago - Pushed at: about 2 years ago - Stars: 479 - Forks: 41

LouisDesdoigts/dLux
Differentiable optical models as parameterised neural networks in Jax using Zodiax
Language: Python - Size: 740 MB - Last synced at: 17 days ago - Pushed at: 17 days ago - Stars: 56 - Forks: 8

google-deepmind/PGMax
Loopy belief propagation for factor graphs on discrete variables in JAX
Language: Jupyter Notebook - Size: 14.6 MB - Last synced at: 10 days ago - Pushed at: 7 months ago - Stars: 150 - Forks: 11

pyrddlgym-project/pyRDDLGym-jax
JAX compilation of RDDL description files, and a differentiable planner in JAX.
Language: Python - Size: 12.1 MB - Last synced at: 17 days ago - Pushed at: 17 days ago - Stars: 5 - Forks: 1

FLAIROx/JaxGL
Simple JAX Graphics Library.
Language: Python - Size: 60.5 KB - Last synced at: 8 days ago - Pushed at: 6 months ago - Stars: 36 - Forks: 0

JoeyTeng/jaxrenderer
Differentiable Rasteriser implemented in JAX. Reference: https://github.com/erwincoumans/tinyrenderer, https://github.com/ssloy/tinyrenderer/wiki; PR: https://github.com/google/brax/pull/367
Language: Jupyter Notebook - Size: 47.4 MB - Last synced at: 4 days ago - Pushed at: over 1 year ago - Stars: 72 - Forks: 7

frankroeder/goal_conditioned_rl
Goal-conditioned reinforcement learning like 🔥
Language: Python - Size: 29.3 KB - Last synced at: 2 days ago - Pushed at: over 1 year ago - Stars: 12 - Forks: 0

jcmgray/autoray
Abstract your array operations.
Language: Python - Size: 1.74 MB - Last synced at: about 8 hours ago - Pushed at: 2 months ago - Stars: 148 - Forks: 11

mpi4jax/mpi4jax
Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡
Language: Python - Size: 5.06 MB - Last synced at: 1 day ago - Pushed at: about 2 months ago - Stars: 478 - Forks: 31

calvinikchen/expanding-ejecta
Pipeline for measuring supernova morphology and the Hubble constant through the expanding ejecta method using an intensity interferometer.
Language: Jupyter Notebook - Size: 7.07 MB - Last synced at: 19 days ago - Pushed at: 19 days ago - Stars: 1 - Forks: 1

instadeepai/matrax
A collection of matrix games in JAX
Language: Python - Size: 326 KB - Last synced at: 18 days ago - Pushed at: 6 months ago - Stars: 11 - Forks: 3

EmptyJackson/unifloral
Unified Implementations of Offline Reinforcement Learning Algorithms
Language: Python - Size: 47.9 KB - Last synced at: 19 days ago - Pushed at: 19 days ago - Stars: 58 - Forks: 3

TolgaOk/jaxdp
A Dynamic Programming package for discrete MDPs implemented in JAX
Language: Python - Size: 549 KB - Last synced at: 19 days ago - Pushed at: 19 days ago - Stars: 5 - Forks: 1

PennyLaneAI/pennylane-cirq
The PennyLane-Cirq plugin integrates Google's Cirq software library with PennyLane's quantum machine learning capabilities.
Language: Python - Size: 2.07 MB - Last synced at: 9 days ago - Pushed at: 9 days ago - Stars: 56 - Forks: 18

narendasan/VibeRL
VibeRL is a toolkit for reinforcement learning, designed to facilitate the use of standalone RL implementations such as CleanRL and ReJAX in experiments primarily in JAX
Language: Python - Size: 922 KB - Last synced at: 12 days ago - Pushed at: 12 days ago - Stars: 0 - Forks: 0

JeyRunner/flaxfit
Fitting jax flax models made simple.
Language: Python - Size: 138 KB - Last synced at: 20 days ago - Pushed at: 20 days ago - Stars: 0 - Forks: 0

ZhengYinan-AIR/FISOR
[ICLR 2024] The official implementation of "Safe Offline Reinforcement Learning with Feasibility-Guided Diffusion Model"
Language: Python - Size: 13.1 MB - Last synced at: 18 days ago - Pushed at: 3 months ago - Stars: 97 - Forks: 7

AaltoML/BayesNewton
Bayes-Newton—A Gaussian process library in JAX, with a unifying view of approximate Bayesian inference as variants of Newton's method.
Language: Python - Size: 1.34 MB - Last synced at: 6 days ago - Pushed at: over 1 year ago - Stars: 231 - Forks: 29

DarshanDeshpande/jax-models
Unofficial JAX implementations of deep learning research papers
Language: Python - Size: 201 KB - Last synced at: 1 day ago - Pushed at: almost 3 years ago - Stars: 156 - Forks: 9

cbg-ethz/Jnotype
Probabilistic modeling of high-dimensional binary data in JAX
Language: Python - Size: 276 KB - Last synced at: 8 days ago - Pushed at: 8 days ago - Stars: 3 - Forks: 0

NeuralQXLab/nqxpack
Save/Load files from NetKet, flax and other scientific ML libraries
Language: Python - Size: 52.7 KB - Last synced at: 20 days ago - Pushed at: 20 days ago - Stars: 1 - Forks: 0

tianjuxue/jax-am
Additive manufacturing simulation with JAX.
Language: Jupyter Notebook - Size: 57.7 MB - Last synced at: 19 days ago - Pushed at: 9 months ago - Stars: 296 - Forks: 58

JPGoodale/hippox
High-order Polynomial Projection Operators for JAX
Language: Python - Size: 113 KB - Last synced at: 9 days ago - Pushed at: about 2 years ago - Stars: 7 - Forks: 0

brentyi/jax_dataclasses
Pytrees + dataclasses ❤️
Language: Python - Size: 71.3 KB - Last synced at: 6 days ago - Pushed at: 20 days ago - Stars: 62 - Forks: 6

brentyi/jaxlie
Rigid transforms + Lie groups for JAX
Language: Python - Size: 13.3 MB - Last synced at: 21 days ago - Pushed at: 21 days ago - Stars: 253 - Forks: 16

gerdm/bayes
Neat Bayesian machine learning examples
Language: Jupyter Notebook - Size: 45.9 MB - Last synced at: 8 days ago - Pushed at: 4 months ago - Stars: 56 - Forks: 8

google-deepmind/dm_pix
PIX is an image processing library in JAX, for JAX.
Language: Python - Size: 761 KB - Last synced at: 14 days ago - Pushed at: 2 months ago - Stars: 415 - Forks: 29

rezaakb/pinns-jax
PINNs-JAX, Physics-informed Neural Networks (PINNs) implemented in JAX.
Language: Python - Size: 137 KB - Last synced at: 20 days ago - Pushed at: 8 months ago - Stars: 47 - Forks: 6

mia-jinns/jinns
Physics Informed Neural Networks (PINNs) + SPINNs + HyperPINNs with JAX 📓 Check out our various notebooks to get started ⚠️ Mirror repository of jinns (development happens on Gitlab)
Language: Jupyter Notebook - Size: 110 MB - Last synced at: 10 days ago - Pushed at: 10 days ago - Stars: 29 - Forks: 7

alonfnt/bayex
Minimal Implementation of Bayesian Optimization in JAX
Language: Python - Size: 370 KB - Last synced at: 21 days ago - Pushed at: 21 days ago - Stars: 94 - Forks: 3

zombie-einstein/jaxpr-viz
Jaxpr Visualisation Tool
Language: Python - Size: 423 KB - Last synced at: 21 days ago - Pushed at: 5 months ago - Stars: 24 - Forks: 1

ergodicio/tsadar-app
Streamlit application for TSADAR - AD-based Thomson Scattering Analysis software. It is hosted on AWS for simple browser-based access and runs on GPUs for rapid analysis
Language: Python - Size: 411 KB - Last synced at: 21 days ago - Pushed at: 21 days ago - Stars: 1 - Forks: 0

awslabs/fortuna 📦
A Library for Uncertainty Quantification.
Language: Python - Size: 4.56 MB - Last synced at: 21 days ago - Pushed at: 21 days ago - Stars: 913 - Forks: 48

BirkhoffG/jax-dataloader
Pytorch-like dataloaders for JAX.
Language: Jupyter Notebook - Size: 697 KB - Last synced at: 22 days ago - Pushed at: 22 days ago - Stars: 80 - Forks: 3

DBraun/audiotree
Audio data loading and augmentations in JAX
Language: Python - Size: 265 KB - Last synced at: 22 days ago - Pushed at: 22 days ago - Stars: 3 - Forks: 0

patrick-kidger/sympy2jax
Turn SymPy expressions into trainable JAX expressions.
Language: Python - Size: 36.1 KB - Last synced at: 2 days ago - Pushed at: 22 days ago - Stars: 337 - Forks: 14

VishwamAI/ProtienFlex
ProteinFlex is a comprehensive platform for protein structure analysis and drug discovery, leveraging advanced AI and machine learning techniques. The platform combines state-of-the-art protein structure prediction with interactive visualization and sophisticated drug discovery tools.
Language: Python - Size: 739 KB - Last synced at: 6 days ago - Pushed at: 22 days ago - Stars: 0 - Forks: 0

ziatdinovmax/NeuroBayes
Fully and Partially Bayesian Neural Nets
Language: Python - Size: 116 MB - Last synced at: 17 days ago - Pushed at: about 1 month ago - Stars: 68 - Forks: 8

mjendrusch/salad
protein structure generation with sparse all-atom denoising models
Language: Python - Size: 28.9 MB - Last synced at: 22 days ago - Pushed at: 22 days ago - Stars: 30 - Forks: 5

bsc-quantic/tn4ml
Tensor Networks for Machine Learning
Language: Python - Size: 38 MB - Last synced at: 22 days ago - Pushed at: 22 days ago - Stars: 16 - Forks: 3

carrycooldude/JAX-Dataloader
A lightweight DataLoader for JAX to load data from various file formats, including CSV, JSON, and more. The goal of this project is to port TensorFlow Dataset (TFDS) functionality into JAX while supporting multiple data sources and preprocessing.
Language: Python - Size: 23.3 MB - Last synced at: 4 days ago - Pushed at: 29 days ago - Stars: 6 - Forks: 0

sabrinastronomy/Grad2Dens
@jax optimized code to convert from a localized 21cm emission measurement to a density field at high-z.
Language: Python - Size: 117 MB - Last synced at: 22 days ago - Pushed at: 22 days ago - Stars: 0 - Forks: 0

i-a-morozov/sympint
JAX composable symplectic integrators: multi-map, implicit midpoint and Tao
Language: Python - Size: 6.01 MB - Last synced at: 22 days ago - Pushed at: 22 days ago - Stars: 0 - Forks: 0

tomasplsek/CADET
Machine learning pipeline for detection of X-ray cavities on Chandra images of early-type galaxies and galaxy clusters.
Language: Jupyter Notebook - Size: 185 MB - Last synced at: 22 days ago - Pushed at: 22 days ago - Stars: 8 - Forks: 3

tanyuqian/redco
NAACL '24 (Best Demo Paper RunnerUp) / MlSys @ NeurIPS '23 - RedCoast: A Lightweight Tool to Automate Distributed Training and Inference
Language: Python - Size: 11.5 MB - Last synced at: 4 days ago - Pushed at: 5 months ago - Stars: 65 - Forks: 7

MichalBortkiewicz/JaxGCRL
Goal-Conditioned Reinforcement Learning with JAX
Language: Python - Size: 32.5 MB - Last synced at: 22 days ago - Pushed at: about 2 months ago - Stars: 146 - Forks: 19

mdda/getting-to-aha-with-tpus
Reasoning-from-Zero using gemma.JAX.nnx on TPUs
Language: Python - Size: 292 KB - Last synced at: 23 days ago - Pushed at: 23 days ago - Stars: 9 - Forks: 0

n2cholas/jax-resnet
Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
Language: Python - Size: 49.8 KB - Last synced at: 4 days ago - Pushed at: almost 3 years ago - Stars: 109 - Forks: 8

lockwo/awesome-jax
Curated list of JAX Resources and Packages
Size: 84 KB - Last synced at: 4 days ago - Pushed at: 24 days ago - Stars: 4 - Forks: 0

tjhunter/levy-stable-jax
Lévy's alpha-stable distribution for the Jax numerical framework
Language: Jupyter Notebook - Size: 11.8 MB - Last synced at: 2 days ago - Pushed at: 11 months ago - Stars: 2 - Forks: 0

i-a-morozov/tohubohu
JAX composable differentiable chaos indicators
Language: Python - Size: 28.9 MB - Last synced at: 24 days ago - Pushed at: 24 days ago - Stars: 0 - Forks: 0

google-parfait/dataset_grouper
Libraries for efficient and scalable group-structured dataset pipelines.
Language: Python - Size: 59.6 KB - Last synced at: 10 days ago - Pushed at: 5 months ago - Stars: 25 - Forks: 3

kyegomez/VO-ROPE
An implementation of the all-new rope from jianlin
Language: Python - Size: 34.2 KB - Last synced at: 21 days ago - Pushed at: 24 days ago - Stars: 4 - Forks: 0

PapayaResearch/synthax
A Fast Modular Synthesizer in JAX
Language: Python - Size: 5.18 MB - Last synced at: 16 days ago - Pushed at: about 1 year ago - Stars: 49 - Forks: 1

veb-101/keras-vision
Porting vision models to Keras 3 for easy accessibility. Contains MobileViT v1, MobileViT v2, fastvit
Language: Jupyter Notebook - Size: 4.45 MB - Last synced at: 25 days ago - Pushed at: 25 days ago - Stars: 11 - Forks: 2

pwolle/seli
Fast NN research in Jax
Language: Python - Size: 3.4 MB - Last synced at: 25 days ago - Pushed at: 25 days ago - Stars: 1 - Forks: 0

poets-ai/elegy
A High Level API for Deep Learning in JAX
Language: Python - Size: 33.9 MB - Last synced at: 3 days ago - Pushed at: over 2 years ago - Stars: 475 - Forks: 32

google/jaxopt
Hardware accelerated, batchable and differentiable optimizers in JAX.
Language: Python - Size: 3.35 MB - Last synced at: 25 days ago - Pushed at: about 1 month ago - Stars: 960 - Forks: 70

SymJAX/SymJAX
Documentation:
Language: Python - Size: 27.3 MB - Last synced at: 15 days ago - Pushed at: almost 2 years ago - Stars: 119 - Forks: 5

HFooladi/molax
Molecular active learning with JAX
Language: Python - Size: 47.9 KB - Last synced at: 26 days ago - Pushed at: 26 days ago - Stars: 0 - Forks: 0

google/fedjax
FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research.
Language: Python - Size: 841 KB - Last synced at: 25 days ago - Pushed at: about 1 month ago - Stars: 261 - Forks: 42

alonfnt/pcax
Differentiable Principal Component Analysis (PCA) implementation in JAX
Language: Python - Size: 33.2 KB - Last synced at: 26 days ago - Pushed at: 26 days ago - Stars: 26 - Forks: 1

alonfnt/tsnex
Minimal t-distributed stochastic neighbor embedding (t-SNE) implementation in JAX.
Language: Python - Size: 30.3 KB - Last synced at: 26 days ago - Pushed at: 26 days ago - Stars: 5 - Forks: 0

EthanSchmitt7/TurbaNet
TurbaNet is a lightweight and user-friendly API wrapper for the JAX library, designed to simplify and accelerate the setup of swarm-based training, evaluation, and simulation of small neural networks.
Language: Python - Size: 21.5 KB - Last synced at: 21 days ago - Pushed at: 21 days ago - Stars: 1 - Forks: 1

homerjed/sbgm
Score-based Diffusion models in JAX.
Language: Python - Size: 14.6 MB - Last synced at: 8 days ago - Pushed at: 4 months ago - Stars: 8 - Forks: 1

homerjed/sbiax
Fast, lightweight and parallelised simulation-based inference in JAX.
Language: Python - Size: 19.2 MB - Last synced at: 1 day ago - Pushed at: 1 day ago - Stars: 18 - Forks: 3

Sea-Snell/JAXSeq
Train very large language models in Jax.
Language: Python - Size: 252 KB - Last synced at: 8 days ago - Pushed at: over 1 year ago - Stars: 204 - Forks: 18

USCbiostats/PM520
PM520 Advanced Statistical Computing
Language: Jupyter Notebook - Size: 2.81 MB - Last synced at: 17 days ago - Pushed at: 27 days ago - Stars: 14 - Forks: 2

vdutor/SphericalHarmonics
Zonal Spherical Harmonics in d Dimensions in TensorFlow, PyTorch and Jax
Language: Python - Size: 3.43 MB - Last synced at: 21 days ago - Pushed at: 12 months ago - Stars: 29 - Forks: 5

UQatKIT/Eikonax
A Fully Differentiable Solver for the Anisotropic Eikonal Equation
Language: Python - Size: 4 MB - Last synced at: 1 day ago - Pushed at: 3 months ago - Stars: 3 - Forks: 0

google/jax-cfd
Computational Fluid Dynamics in JAX
Language: Jupyter Notebook - Size: 6.41 MB - Last synced at: 25 days ago - Pushed at: about 1 month ago - Stars: 821 - Forks: 119

jla524/fromthetensor
From the Tensor to Stable Diffusion, a rough outline for a 1 week course.
Size: 2.8 MB - Last synced at: 27 days ago - Pushed at: 28 days ago - Stars: 1,057 - Forks: 44

camml-lab/reax
REAX — Scalable, flexible training for JAX, inspired by the simplicity of PyTorch Lightning.
Language: Python - Size: 354 KB - Last synced at: 27 days ago - Pushed at: 28 days ago - Stars: 0 - Forks: 0

bahremsd/tmmax
A fast transfer matrix method written in jax for modelling optical multilayer thin films
Language: Jupyter Notebook - Size: 3.98 MB - Last synced at: 11 days ago - Pushed at: about 1 month ago - Stars: 9 - Forks: 3

kirkegaardlab/recloc
Official code for the paper Local Clustering and Global Spreading of Receptors for Optimal Spatial Gradient Sensing (PRL 2025). Includes simulations and visualizations for optimizing receptor placement on cell surfaces.
Language: Python - Size: 28.3 KB - Last synced at: 27 days ago - Pushed at: 28 days ago - Stars: 2 - Forks: 0

jeremiecoullon/SGMCMCJax
Lightweight library of stochastic gradient MCMC algorithms written in JAX.
Language: Python - Size: 1.37 MB - Last synced at: 1 day ago - Pushed at: over 1 year ago - Stars: 103 - Forks: 8

wcxve/inferax
Statistical Inference with JAX.
Language: Python - Size: 26.4 KB - Last synced at: 2 days ago - Pushed at: 3 days ago - Stars: 1 - Forks: 0

google/CommonLoopUtils
CLU lets you write beautiful training loops in JAX.
Language: Jupyter Notebook - Size: 1.48 MB - Last synced at: 25 days ago - Pushed at: about 1 month ago - Stars: 337 - Forks: 34

CoastEgo/microlux
Implementation of automatic differentiation in VBozza's BinaryLensing
Language: Jupyter Notebook - Size: 4.49 MB - Last synced at: 29 days ago - Pushed at: 29 days ago - Stars: 10 - Forks: 2

joseph-nagel/jax-quickstart
Getting to know the JAX framework
Language: Jupyter Notebook - Size: 11.7 KB - Last synced at: 29 days ago - Pushed at: 29 days ago - Stars: 0 - Forks: 0

NVIDIA-Merlin/dataloader
The merlin dataloader lets you rapidly load tabular data for training deep learning models with TensorFlow, PyTorch or JAX
Language: Python - Size: 28.7 MB - Last synced at: 6 days ago - Pushed at: about 1 year ago - Stars: 418 - Forks: 26

v0lta/Jax-Wavelet-Toolbox
Differentiable and GPU-enabled fast wavelet transforms in JAX.
Language: Python - Size: 1.15 MB - Last synced at: 3 days ago - Pushed at: 10 months ago - Stars: 42 - Forks: 2

EMI-Group/tensorneat
GPU-accelerated NeuroEvolution of Augmenting Topologies (NEAT)
Language: Python - Size: 35.2 MB - Last synced at: 29 days ago - Pushed at: 29 days ago - Stars: 140 - Forks: 18

luyug/GradCache
Run Effective Large Batch Contrastive Learning Beyond GPU/TPU Memory Constraint
Language: Python - Size: 43.9 KB - Last synced at: 26 days ago - Pushed at: about 1 year ago - Stars: 386 - Forks: 24

ThomasMullen/diffilqrax
Diffilqrax is an open-source Python library for differentiable iLQR and LQR in JAX
Language: Python - Size: 9.39 MB - Last synced at: 29 days ago - Pushed at: 3 months ago - Stars: 2 - Forks: 0

rystrauss/dopamax
Reinforcement learning in pure JAX.
Language: Python - Size: 262 KB - Last synced at: 22 days ago - Pushed at: 3 months ago - Stars: 12 - Forks: 1

valence-labs/mess
MESS: Modern Electronic Structure Simulations
Language: Python - Size: 2.17 MB - Last synced at: 10 days ago - Pushed at: about 1 month ago - Stars: 28 - Forks: 2

micdoh/XLRON
X-elerated Learning and Resource Allocation for Optical Networks
Language: Jupyter Notebook - Size: 8.24 MB - Last synced at: 30 days ago - Pushed at: 30 days ago - Stars: 17 - Forks: 3

uwplasma/ESSOS
e-Stellarator Simulation and Optimization Suite
Language: Python - Size: 5.34 MB - Last synced at: 3 days ago - Pushed at: 3 days ago - Stars: 10 - Forks: 3

NobuoTsukamoto/jax_examples
Jax, Flax, examples (ImageClassification, SemanticSegmentation, and more...)
Language: Python - Size: 4.35 MB - Last synced at: 8 days ago - Pushed at: about 1 month ago - Stars: 10 - Forks: 0

renecotyfanboy/jaxspec
jaxspec is an X-ray spectra Bayesian analysis package, relying on JAX to enable just in time compilation
Language: Python - Size: 317 MB - Last synced at: 19 days ago - Pushed at: 24 days ago - Stars: 30 - Forks: 3

abess-team/skscope
skscope: Sparse-Constrained OPtimization via itErative-solvers
Language: Python - Size: 23.2 MB - Last synced at: 34 minutes ago - Pushed at: 7 months ago - Stars: 332 - Forks: 15
