From 4689714b8861ab483705a8b9fc92c09f762750bb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 15 Apr 2026 08:58:15 +0530 Subject: [PATCH 1/3] add an example of spmd for flux on v5e-8 --- .../pytorch_xla/inference/flux/README.md | 37 +++- .../inference/flux/flux_inference_spmd.py | 190 ++++++++++++++++++ 2 files changed, 226 insertions(+), 1 deletion(-) create mode 100644 examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py diff --git a/examples/research_projects/pytorch_xla/inference/flux/README.md b/examples/research_projects/pytorch_xla/inference/flux/README.md index 0bbd650bb6b7..2c5a2800f4de 100644 --- a/examples/research_projects/pytorch_xla/inference/flux/README.md +++ b/examples/research_projects/pytorch_xla/inference/flux/README.md @@ -51,7 +51,42 @@ python flux_inference.py The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, the compilation time is longer, while the cache stores the compiled programs. On subsequent runs, compilation is much faster and the subsequent passes being the fastest. -On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel): +On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel). + +> **Note:** `flux_inference.py` uses `xmp.spawn` (one process per chip) and requires the full model to fit on a single chip. If you run into OOM errors (e.g., on v5e with 16GB HBM per chip), use the SPMD version instead — see below. + +### SPMD version (for v5e-8 and similar) + +On TPU configurations where a single chip cannot hold the full FLUX transformer (~16GB in bf16), use `flux_inference_spmd.py`. This script uses PyTorch/XLA SPMD to shard the transformer across multiple chips using a `(data, model)` mesh — 4-way model parallel so each chip holds ~4GB of weights, with the remaining chips for data parallelism. 
+ +```bash +python flux_inference_spmd.py --schnell +``` + +Key differences from `flux_inference.py`: +- **Single-process SPMD** instead of multi-process `xmp.spawn` — the XLA compiler handles all collective communication transparently. +- **Transformer weights are sharded** across the `"model"` mesh axis using `xs.mark_sharding`. +- **VAE lives on CPU**, moved to XLA only for decode (then moved back), since the transformer stays on device throughout. +- **Text encoding** runs on CPU before loading the transformer. + +On a v5litepod-8 (v5e, 8 chips, 16GB HBM each) with FLUX.1-schnell, expect ~1.76 sec/image at steady state (after compilation): + +``` +2026-04-15 02:24:30 [info ] SPMD mesh: (2, 4), axes: ('data', 'model'), devices: 8 +2026-04-15 02:24:30 [info ] encoding prompt on CPU... +2026-04-15 02:26:20 [info ] loading VAE on CPU... +2026-04-15 02:26:20 [info ] loading flux transformer from black-forest-labs/FLUX.1-schnell +2026-04-15 02:27:22 [info ] starting compilation run... +2026-04-15 02:52:55 [info ] compilation took 1533.4575625509997 sec. +2026-04-15 02:52:56 [info ] starting inference run... +2026-04-15 02:56:11 [info ] inference time: 195.74092420299985 +2026-04-15 02:56:13 [info ] inference time: 1.7625778899996476 +2026-04-15 02:56:13 [info ] avg. inference over 2 iterations took 98.75175104649975 sec. +``` + +The first inference iteration includes VAE compilation (~195s). The second iteration shows the true steady-state speed (~1.76s). + +### v6e-4 results (original `flux_inference.py`) ```bash WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. diff --git a/examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py b/examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py new file mode 100644 index 000000000000..700dc6ab69a1 --- /dev/null +++ b/examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py @@ -0,0 +1,190 @@ +"""FLUX inference on TPU using PyTorch/XLA SPMD. 
+ +Uses SPMD to shard the transformer across multiple TPU chips, enabling +inference on devices where the model doesn't fit on a single chip (e.g., v5e). +The VAE is loaded on CPU at startup, moved to XLA for decode, then moved back. +""" + +from argparse import ArgumentParser +from pathlib import Path +from time import perf_counter + +import numpy as np +import structlog +import torch +import torch_xla.core.xla_model as xm +import torch_xla.debug.metrics as met +import torch_xla.debug.profiler as xp +import torch_xla.distributed.spmd as xs +import torch_xla.runtime as xr +from torch_xla.experimental.custom_kernel import FlashAttention + +from diffusers import AutoencoderKL, FluxPipeline + + +cache_path = Path("/tmp/data/compiler_cache_tRiLlium_eXp") +cache_path.mkdir(parents=True, exist_ok=True) +xr.initialize_cache(str(cache_path), readonly=False) +xr.use_spmd() + +logger = structlog.get_logger() +metrics_filepath = "/tmp/metrics_report.txt" +VAE_SCALE_FACTOR = 8 + + +def _vae_decode(latents, vae, height, width, device): + """Move VAE to XLA, decode latents, move VAE back to CPU.""" + vae.to(device) + latents = FluxPipeline._unpack_latents(latents, height, width, VAE_SCALE_FACTOR) + latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor + with torch.no_grad(): + image = vae.decode(latents, return_dict=False)[0] + vae.to("cpu") + return image + + +def main(args): + # --- SPMD mesh: 4-way model parallel to fit transformer + VAE on v5e chips --- + num_devices = xr.global_runtime_device_count() + mesh = xs.Mesh(np.arange(num_devices), (num_devices // 4, 4), ("data", "model")) + xs.set_global_mesh(mesh) + logger.info(f"SPMD mesh: {mesh.mesh_shape}, axes: {mesh.axis_names}, devices: {num_devices}") + + # --- Profiler --- + profile_path = Path("/tmp/data/profiler_out_tRiLlium_eXp") + profile_path.mkdir(parents=True, exist_ok=True) + profiler_port = 9012 + profile_duration = args.profile_duration + if args.profile: + logger.info(f"starting profiler on 
port {profiler_port}") + _ = xp.start_server(profiler_port) + + device = xm.xla_device() + + # --- Checkpoint --- + if args.schnell: + ckpt_id = "black-forest-labs/FLUX.1-schnell" + else: + ckpt_id = "black-forest-labs/FLUX.1-dev" + + # --- Text encoding (CPU) --- + prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side" + logger.info("encoding prompt on CPU...") + text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu") + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, _ = text_pipe.encode_prompt( + prompt=prompt, prompt_2=None, max_sequence_length=512 + ) + image_processor = text_pipe.image_processor + del text_pipe + + # --- Load VAE on CPU (moved to XLA only for decode) --- + logger.info("loading VAE on CPU...") + vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16) + + # --- Load transformer and shard --- + logger.info(f"loading flux transformer from {ckpt_id}") + flux_pipe = FluxPipeline.from_pretrained( + ckpt_id, + text_encoder=None, + tokenizer=None, + text_encoder_2=None, + tokenizer_2=None, + vae=None, + torch_dtype=torch.bfloat16, + ).to(device) + + for name, param in flux_pipe.transformer.named_parameters(): + if param.dim() >= 2: + spec = [None] * param.dim() + largest_dim = max(range(param.dim()), key=lambda d: param.shape[d]) + spec[largest_dim] = "model" + xs.mark_sharding(param, mesh, tuple(spec)) + + flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True) + FlashAttention.DEFAULT_BLOCK_SIZES = { + "block_q": 1536, + "block_k_major": 1536, + "block_k": 1536, + "block_b": 1536, + "block_q_major_dkv": 1536, + "block_k_major_dkv": 1536, + "block_q_dkv": 1536, + "block_k_dkv": 1536, + "block_q_dq": 1536, + "block_k_dq": 1536, + "block_k_major_dq": 1536, + } + + width = args.width + height = args.height + guidance = args.guidance + n_steps = 4 if 
args.schnell else 28
+
+    prompt_embeds = prompt_embeds.to(device)
+    pooled_prompt_embeds = pooled_prompt_embeds.to(device)
+    xs.mark_sharding(prompt_embeds, mesh, ("data", None, None))
+    xs.mark_sharding(pooled_prompt_embeds, mesh, ("data", None))
+
+    # --- Compilation run ---
+    logger.info("starting compilation run...")
+    ts = perf_counter()
+    latents = flux_pipe(
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        num_inference_steps=n_steps,
+        guidance_scale=guidance,
+        height=height,
+        width=width,
+        output_type="latent",
+    ).images
+    image = _vae_decode(latents, vae, height, width, device)
+    image = image_processor.postprocess(image)[0]
+    logger.info(f"compilation took {perf_counter() - ts} sec.")
+    image.save("/tmp/compile_out.png")
+
+    # --- Inference loop ---
+    seed = 4096 if args.seed is None else args.seed
+    xm.set_rng_state(seed=seed, device=device)
+    times = []
+    logger.info("starting inference run...")
+    for _ in range(args.itters):
+        ts = perf_counter()
+
+        if args.profile:
+            xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
+        latents = flux_pipe(
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            num_inference_steps=n_steps,
+            guidance_scale=guidance,
+            height=height,
+            width=width,
+            output_type="latent",
+        ).images
+        image = _vae_decode(latents, vae, height, width, device)
+        image = image_processor.postprocess(image)[0]
+        inference_time = perf_counter() - ts
+        logger.info(f"inference time: {inference_time}")
+        times.append(inference_time)
+
+    logger.info(f"avg. 
inference over {args.itters} iterations took {sum(times) / len(times)} sec.") + image.save("/tmp/inference_out.png") + metrics_report = met.metrics_report() + with open(metrics_filepath, "w+") as fout: + fout.write(metrics_report) + logger.info(f"saved metric information as {metrics_filepath}") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev") + parser.add_argument("--width", type=int, default=1024, help="width of the image to generate") + parser.add_argument("--height", type=int, default=1024, help="height of the image to generate") + parser.add_argument("--guidance", type=float, default=3.5, help="guidance strength for dev") + parser.add_argument("--seed", type=int, default=None, help="seed for inference") + parser.add_argument("--profile", action="store_true", help="enable profiling") + parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.") + parser.add_argument("--itters", type=int, default=15, help="items to run inference and get avg time in sec.") + args = parser.parse_args() + main(args) From b99078d22729464ada10fd03cf1ec76ae90c15b7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Apr 2026 14:55:56 +0530 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: Sayak Paul --- .../pytorch_xla/inference/flux/flux_inference_spmd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py b/examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py index 700dc6ab69a1..ebb4e7e3085c 100644 --- a/examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py +++ b/examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py @@ -22,7 +22,7 @@ from diffusers import AutoencoderKL, FluxPipeline -cache_path = Path("/tmp/data/compiler_cache_tRiLlium_eXp") 
+cache_path = Path("/tmp/data/compiler_cache_eXp")
 cache_path.mkdir(parents=True, exist_ok=True)
 xr.initialize_cache(str(cache_path), readonly=False)
 xr.use_spmd()
@@ -51,7 +51,7 @@ def main(args):
     logger.info(f"SPMD mesh: {mesh.mesh_shape}, axes: {mesh.axis_names}, devices: {num_devices}")
 
     # --- Profiler ---
-    profile_path = Path("/tmp/data/profiler_out_tRiLlium_eXp")
+    profile_path = Path("/tmp/data/profiler_out_eXp")
     profile_path.mkdir(parents=True, exist_ok=True)
     profiler_port = 9012
     profile_duration = args.profile_duration

From 294fd1adeccd190a368672b2481739100d11dbac Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 15 Apr 2026 15:52:34 +0530
Subject: [PATCH 3/3] add check

---
 .../pytorch_xla/inference/flux/flux_inference_spmd.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py b/examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py
index ebb4e7e3085c..9d1eeeae1b0d 100644
--- a/examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py
+++ b/examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py
@@ -46,7 +46,10 @@ def _vae_decode(latents, vae, height, width, device):
 
 def main(args):
     # --- SPMD mesh: 4-way model parallel to fit transformer + VAE on v5e chips ---
     num_devices = xr.global_runtime_device_count()
-    mesh = xs.Mesh(np.arange(num_devices), (num_devices // 4, 4), ("data", "model"))
+    if num_devices >= 4:
+        mesh = xs.Mesh(np.arange(num_devices), (num_devices // 4, 4), ("data", "model"))
+    else:
+        raise NotImplementedError(f"this script requires at least 4 devices, got {num_devices}")
     xs.set_global_mesh(mesh)
     logger.info(f"SPMD mesh: {mesh.mesh_shape}, axes: {mesh.axis_names}, devices: {num_devices}")