## `examples/research_projects/pytorch_xla/inference/flux/README.md` (36 additions, 1 deletion)

The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first run takes longer because the programs must be compiled and stored in the cache; subsequent runs reuse the cache and start much faster, with the later passes being the fastest.

On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel).

> **Note:** `flux_inference.py` uses `xmp.spawn` (one process per chip) and requires the full model to fit on a single chip. If you run into OOM errors (e.g., on v5e with 16GB HBM per chip), use the SPMD version instead — see below.

### SPMD version (for v5e-8 and similar)

On TPU configurations where a single chip cannot hold the full FLUX transformer (~16GB in bf16), use `flux_inference_spmd.py`. This script uses PyTorch/XLA SPMD to shard the transformer across chips with a `(data, model)` mesh: the 4-way `model` axis splits the weights so each chip holds ~4GB, and the remaining `data` axis provides data parallelism.
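As a back-of-envelope check of the mesh arithmetic above (the ~16GB bf16 transformer size and 16GB-HBM v5e chips are figures from this README; the helper below is purely illustrative):

```python
# Assumed figures (from the text): transformer ~16 GB in bf16; v5e chip = 16 GB HBM.
def per_chip_weight_gb(total_gb, model_parallel):
    """Weights each chip holds under N-way model parallelism."""
    return total_gb / model_parallel

num_devices = 8                         # v5litepod-8
model_axis = 4                          # 4-way model parallel, as in the script
data_axis = num_devices // model_axis   # leftover devices become the data axis
print((data_axis, model_axis))          # (2, 4), the SPMD mesh shape in the log
print(per_chip_weight_gb(16.0, model_axis))  # 4.0 GB of weights per chip
```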

```bash
python flux_inference_spmd.py --schnell
```

Key differences from `flux_inference.py`:
- **Single-process SPMD** instead of multi-process `xmp.spawn` — the XLA compiler handles all collective communication transparently.
- **Transformer weights are sharded** across the `"model"` mesh axis using `xs.mark_sharding`.
- **VAE lives on CPU**, moved to XLA only for decode (then moved back), since the transformer stays on device throughout.
- **Text encoding** runs on CPU before loading the transformer.
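The sharding rule applied by `xs.mark_sharding` in the script marks the largest dimension of every parameter with at least two dims with the `"model"` axis. That selection logic is plain Python and can be sketched without torch_xla (the shapes below are illustrative, not real FLUX layer shapes):

```python
def sharding_spec(shape):
    """Partition spec sharding the largest dim on the "model" mesh axis,
    mirroring the named_parameters loop in flux_inference_spmd.py."""
    spec = [None] * len(shape)
    largest_dim = max(range(len(shape)), key=lambda d: shape[d])
    spec[largest_dim] = "model"
    return tuple(spec)

print(sharding_spec((3072, 12288)))   # (None, 'model'), shard the wide dim
print(sharding_spec((12288, 3072)))   # ('model', None)
```

In the script itself, the resulting tuple is what gets passed to `xs.mark_sharding(param, mesh, tuple(spec))`.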

On a v5litepod-8 (v5e, 8 chips, 16GB HBM each) with FLUX.1-schnell, expect ~1.76 sec/image at steady state (after compilation):

```
2026-04-15 02:24:30 [info ] SPMD mesh: (2, 4), axes: ('data', 'model'), devices: 8
2026-04-15 02:24:30 [info ] encoding prompt on CPU...
2026-04-15 02:26:20 [info ] loading VAE on CPU...
2026-04-15 02:26:20 [info ] loading flux transformer from black-forest-labs/FLUX.1-schnell
2026-04-15 02:27:22 [info ] starting compilation run...
2026-04-15 02:52:55 [info ] compilation took 1533.4575625509997 sec.
2026-04-15 02:52:56 [info ] starting inference run...
2026-04-15 02:56:11 [info ] inference time: 195.74092420299985
2026-04-15 02:56:13 [info ] inference time: 1.7625778899996476
2026-04-15 02:56:13 [info ] avg. inference over 2 iterations took 98.75175104649975 sec.
```

The first inference iteration includes VAE compilation (~195s). The second iteration shows the true steady-state speed (~1.76s).
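The averaged figure in the log is dominated by that first iteration; recomputing it from the two logged times (copied verbatim from the output above) makes this explicit:

```python
times = [195.74092420299985, 1.7625778899996476]  # iter 1 includes VAE compile
avg = sum(times) / len(times)
print(avg)  # ~98.75, matching the logged average
# Steady-state speed is better read from the second iteration alone (~1.76 s).
```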
> **Reviewer:** Perhaps you can include a dummy inference in the compilation part, so that the VAE is compiled and the timings look more regular.
>
> **Author:** I didn't get this part. Elaborate more? The block under `logger.info("starting compilation run...")` has the VAE compilation included.


### v6e-4 results (original `flux_inference.py`)

```bash
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
...
```

## `flux_inference_spmd.py` (new file, 193 lines)
"""FLUX inference on TPU using PyTorch/XLA SPMD.

Uses SPMD to shard the transformer across multiple TPU chips, enabling
inference on devices where the model doesn't fit on a single chip (e.g., v5e).
The VAE is loaded on CPU at startup, moved to XLA for decode, then moved back.
"""

from argparse import ArgumentParser
from pathlib import Path
from time import perf_counter

import numpy as np
import structlog
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.custom_kernel import FlashAttention

from diffusers import AutoencoderKL, FluxPipeline


cache_path = Path("/tmp/data/compiler_cache_eXp")
cache_path.mkdir(parents=True, exist_ok=True)
xr.initialize_cache(str(cache_path), readonly=False)
xr.use_spmd()

logger = structlog.get_logger()
metrics_filepath = "/tmp/metrics_report.txt"
VAE_SCALE_FACTOR = 8


def _vae_decode(latents, vae, height, width, device):
    """Move VAE to XLA, decode latents, move VAE back to CPU."""
    vae.to(device)
> **Reviewer:** I do not know much about this, but isn't moving the VAE back and forth between the XLA device and the CPU quite expensive in time? Wouldn't it be better just to keep it in XLA?
>
> **Author:** It would barely fit, otherwise. Plus we have to free some stuff anyway to do the actual computation. Once it's compiled it doesn't take much of a hit barring some displacement overhead, which is likely justifiable given the cheap pricing of v5es. Does it make sense?

    latents = FluxPipeline._unpack_latents(latents, height, width, VAE_SCALE_FACTOR)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    with torch.no_grad():
        image = vae.decode(latents, return_dict=False)[0]
    vae.to("cpu")
    return image


def main(args):
    # --- SPMD mesh: 4-way model parallel to fit transformer + VAE on v5e chips ---
    num_devices = xr.global_runtime_device_count()
    if num_devices >= 4:
        mesh = xs.Mesh(np.arange(num_devices), (num_devices // 4, 4), ("data", "model"))
    else:
        raise NotImplementedError(f"4-way model parallelism needs at least 4 devices, got {num_devices}")
    xs.set_global_mesh(mesh)
    logger.info(f"SPMD mesh: {mesh.mesh_shape}, axes: {mesh.axis_names}, devices: {num_devices}")

    # --- Profiler ---
    profile_path = Path("/tmp/data/profiler_out_eXp")
    profile_path.mkdir(parents=True, exist_ok=True)
    profiler_port = 9012
    profile_duration = args.profile_duration
    if args.profile:
        logger.info(f"starting profiler on port {profiler_port}")
        _ = xp.start_server(profiler_port)

    device = xm.xla_device()

    # --- Checkpoint ---
    if args.schnell:
        ckpt_id = "black-forest-labs/FLUX.1-schnell"
    else:
        ckpt_id = "black-forest-labs/FLUX.1-dev"

    # --- Text encoding (CPU) ---
    prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
> **Reviewer:** nit: again, for clarity I would avoid the "Trillium" word if we test on v5.
>
> **Author:** This is probably fine. It's quite separate.

    logger.info("encoding prompt on CPU...")
    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu")
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, _ = text_pipe.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=512
        )
    image_processor = text_pipe.image_processor
    del text_pipe

    # --- Load VAE on CPU (moved to XLA only for decode) ---
    logger.info("loading VAE on CPU...")
    vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16)

    # --- Load transformer and shard ---
    logger.info(f"loading flux transformer from {ckpt_id}")
    flux_pipe = FluxPipeline.from_pretrained(
        ckpt_id,
        text_encoder=None,
        tokenizer=None,
        text_encoder_2=None,
        tokenizer_2=None,
        vae=None,
        torch_dtype=torch.bfloat16,
    ).to(device)

    for name, param in flux_pipe.transformer.named_parameters():
        if param.dim() >= 2:
            spec = [None] * param.dim()
            largest_dim = max(range(param.dim()), key=lambda d: param.shape[d])
            spec[largest_dim] = "model"
            xs.mark_sharding(param, mesh, tuple(spec))

    flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
    FlashAttention.DEFAULT_BLOCK_SIZES = {
> **Reviewer:** this looks like black magic, consider adding a comment explaining where these come from
>
> **Author:** Cc: @entrpn for those as it's copied from `flux_inference.py`.
>
> **Contributor:** If I remember correctly, these block sizes have been optimized for Trillium through some tests we ran internally. They can be kept as is as long as the v5e's vmem can handle it. These could be optimized in the future specifically for v5e.

        "block_q": 1536,
        "block_k_major": 1536,
        "block_k": 1536,
        "block_b": 1536,
        "block_q_major_dkv": 1536,
        "block_k_major_dkv": 1536,
        "block_q_dkv": 1536,
        "block_k_dkv": 1536,
        "block_q_dq": 1536,
        "block_k_dq": 1536,
        "block_k_major_dq": 1536,
    }

    width = args.width
    height = args.height
    guidance = args.guidance
    n_steps = 4 if args.schnell else 28

    prompt_embeds = prompt_embeds.to(device)
    pooled_prompt_embeds = pooled_prompt_embeds.to(device)
    xs.mark_sharding(prompt_embeds, mesh, ("data", None, None))
    xs.mark_sharding(pooled_prompt_embeds, mesh, ("data", None))

    # --- Compilation run ---
    logger.info("starting compilation run...")
    ts = perf_counter()
    latents = flux_pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=28,
        guidance_scale=guidance,
        height=height,
        width=width,
        output_type="latent",
    ).images
    image = _vae_decode(latents, vae, height, width, device)
    image = image_processor.postprocess(image)[0]
    logger.info(f"compilation took {perf_counter() - ts} sec.")
    image.save("/tmp/compile_out.png")

    # --- Inference loop ---
    seed = 4096 if args.seed is None else args.seed
    xm.set_rng_state(seed=seed, device=device)
    times = []
    logger.info("starting inference run...")
    for _ in range(args.itters):
        ts = perf_counter()

        if args.profile:
            xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
        latents = flux_pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=n_steps,
            guidance_scale=guidance,
            height=height,
            width=width,
            output_type="latent",
        ).images
        image = _vae_decode(latents, vae, height, width, device)
        image = image_processor.postprocess(image)[0]
        inference_time = perf_counter() - ts
        logger.info(f"inference time: {inference_time}")
        times.append(inference_time)

    logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
    image.save("/tmp/inference_out.png")
    metrics_report = met.metrics_report()
    with open(metrics_filepath, "w+") as fout:
        fout.write(metrics_report)
    logger.info(f"saved metric information as {metrics_filepath}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev")
    parser.add_argument("--width", type=int, default=1024, help="width of the image to generate")
    parser.add_argument("--height", type=int, default=1024, help="height of the image to generate")
    parser.add_argument("--guidance", type=float, default=3.5, help="guidance strength for dev")
    parser.add_argument("--seed", type=int, default=None, help="seed for inference")
    parser.add_argument("--profile", action="store_true", help="enable profiling")
    parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.")
    parser.add_argument("--itters", type=int, default=15, help="number of inference iterations to average over")
    args = parser.parse_args()
    main(args)