The James Webb Space Telescope (JWST) has transformed extragalactic astronomy. In particular, it has uncovered a puzzling population of compact, red objects in the distant universe known as "Little Red Dots" (LRDs). These sources exhibit distinctive features not seen in typical galaxies, so their nature remains a subject of intense debate. Are they supermassive black holes hiding behind screens of dust? Massive dead galaxies appearing too early in the universe? Or something entirely new? To find out, we need to perform computationally heavy statistical analysis on the astronomical data. However, traditional tools have either been too slow or relied on simplifying assumptions about the complex JWST data, which can lead to inaccurate results.
In this talk, I will introduce the modern Python stack that now makes this possible: JAX and NumPyro. JAX allows you to write standard Python code that runs on GPUs and automatically computes derivatives, while NumPyro leverages that power for incredibly fast statistical modeling. We will start with the basics, using simple examples to demonstrate how JAX can speed up existing workflows and how NumPyro makes Bayesian inference accessible.
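To make the JAX half of that claim concrete, here is a minimal sketch of what "standard Python code with automatic derivatives" can look like. Everything in it is illustrative rather than taken from the talk: the toy Gaussian emission-line model, the function names `gaussian_line` and `chi_squared`, and the synthetic, noise-free data are all assumptions chosen for brevity.

```python
import jax
import jax.numpy as jnp


def gaussian_line(params, wavelength):
    """Toy emission-line profile (hypothetical model, for illustration only)."""
    amplitude, center, width = params
    return amplitude * jnp.exp(-0.5 * ((wavelength - center) / width) ** 2)


def chi_squared(params, wavelength, flux, sigma):
    """Sum of squared, noise-weighted residuals between data and model."""
    residual = (flux - gaussian_line(params, wavelength)) / sigma
    return jnp.sum(residual ** 2)


# jax.grad differentiates with respect to the first argument (params);
# jax.jit compiles the result with XLA for whatever backend is available.
grad_chi_squared = jax.jit(jax.grad(chi_squared))

# Synthetic, noise-free data generated from known parameters (an assumption).
wavelength = jnp.linspace(6500.0, 6630.0, 200)
true_params = jnp.array([1.0, 6563.0, 5.0])
flux = gaussian_line(true_params, wavelength)
sigma = jnp.full_like(wavelength, 0.1)

# Gradient of the fit statistic at a slightly wrong starting guess.
guess = jnp.array([0.8, 6560.0, 6.0])
gradient = grad_chi_squared(guess, wavelength, flux, sigma)
print(gradient)  # one partial derivative per parameter
```

The same unmodified function runs on a CPU or a GPU, and gradients like this one are what gradient-based samplers rely on to explore a posterior efficiently.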
Then, we will look at the "Little Red Dot" mystery as a case study. I will show how we built a custom inference engine (unite) to process thousands of JWST observations. By leveraging JAX's speed and NumPyro's flexibility, we were able to efficiently and accurately test complex physical models against the data, uncovering evidence that these unique objects may in fact be supermassive black holes embedded in dense gas clouds: essentially, stars powered not by fusion but by black holes.
This talk is for anyone interested in high-performance Python, and especially for (data) scientists interested in modern methods for designing scalable inference pipelines. You will leave with a solid introduction to JAX and NumPyro and an appreciation for how these tools are already helping solve the Universe's greatest mysteries.