LLM-Era Compute for 21-cm Cosmology: Accelerating Bayesian Inference for the SKA Era
Driven by the rapid development of large language models, sustained investment in GPU/TPU infrastructure and associated software stacks is reshaping what is computationally feasible across data-intensive science. In radio cosmology, end-to-end simulation and inference workloads map naturally onto structured linear algebra, with intrinsic batch dimensions spanning time samples, frequency channels, and interferometric baselines. This alignment makes accelerator architectures not merely convenient but essential for scaling these analyses to the data volumes and model complexity demanded by next-generation experiments such as the SKA. Using the global 21-cm experiment REACH as a case study, we demonstrate the advantages of implementing forward models and likelihood evaluations in JAX and coupling them to Bayesian sampling frameworks such as BlackJAX. By expressing the computation as a differentiable array program that can be JIT-compiled, kernel-fused, and parallelised on GPU hardware, we achieve speedups of order 10^2 relative to traditional CPU-based workflows. Crucially, the accelerated likelihood exhibits near-constant wall-clock time across a wide range of data volumes and model complexities. We show that this approach has already made previously impractical regimes feasible on realistic compute budgets and timescales, including large time-resolved analyses, evidence-driven model optimisation, and large-scale simulation-based inference (SBI). Furthermore, JAX’s automatic differentiation provides exact gradients of the forward model and likelihood with minimal additional implementation effort, enabling practical gradient-informed samplers for high-dimensional analyses. We outline ongoing work to extend these ideas to interferometric 21-cm power-spectrum inference within the BayesEoR framework, where analyses are expected to involve of order thousands of parameters and terabyte-scale datasets.
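
As a minimal sketch of the pattern described above (a likelihood written as a differentiable array program, batched over time samples, JIT-compiled, and differentiated automatically), consider the toy example below. It is illustrative only and is not the REACH pipeline: the five-parameter forward model (power-law foreground plus Gaussian absorption trough), the names forward_model, log_likelihood_one and log_likelihood, the channel and time-bin counts, and the noise level are all assumptions made for this sketch.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy setup: 128 frequency channels (MHz) and 16 time bins.
freqs = jnp.linspace(50.0, 200.0, 128)
n_times = 16


def forward_model(theta, nu):
    """Toy sky-averaged spectrum: power-law foreground plus Gaussian trough.

    This parameterisation is an assumption for illustration only.
    """
    amp_fg, index, depth, centre, width = theta
    foreground = amp_fg * (nu / 100.0) ** (-index)
    signal = -depth * jnp.exp(-0.5 * ((nu - centre) / width) ** 2)
    return foreground + signal


def log_likelihood_one(theta, data_row, sigma):
    """Gaussian log-likelihood for the spectrum of a single time bin."""
    resid = data_row - forward_model(theta, freqs)
    norm = resid.size * jnp.log(sigma * jnp.sqrt(2.0 * jnp.pi))
    return -0.5 * jnp.sum((resid / sigma) ** 2) - norm


def log_likelihood(theta, data, sigma):
    """Sum over time bins; vmap maps the batch axis of `data` (axis 0)."""
    per_time = jax.vmap(log_likelihood_one, in_axes=(None, 0, None))(
        theta, data, sigma
    )
    return jnp.sum(per_time)


# JIT-compile the likelihood and its gradient with respect to theta.
log_likelihood_jit = jax.jit(log_likelihood)
grad_log_likelihood = jax.jit(jax.grad(log_likelihood))

# Simulate mock data from assumed "true" parameters plus white noise.
theta_true = jnp.array([1500.0, 2.5, 0.15, 80.0, 10.0])
sigma = 0.05
key = jax.random.PRNGKey(0)
noise = sigma * jax.random.normal(key, (n_times, freqs.size))
data = forward_model(theta_true, freqs)[None, :] + noise

logl = log_likelihood_jit(theta_true, data, sigma)
grad = grad_log_likelihood(theta_true, data, sigma)
print(logl.shape, grad.shape)
```

In a sketch like this, the scalar function of theta obtained by fixing data and sigma is the kind of log-density that a gradient-informed sampler would consume, and the same code runs unchanged on CPU or GPU depending on the installed JAX backend.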