add an example of spmd for flux on v5e-8 #13474

Open
sayakpaul wants to merge 6 commits into main from flux-spmd-tpu

Conversation

@sayakpaul (Member)

What does this PR do?

Add an example of model parallelism for Flux using PyTorch XLA. Tested on v5e-8.

Cc: @entrpn if you could review.

@sayakpaul sayakpaul requested a review from tengomucho April 15, 2026 03:29
@github-actions github-actions bot added examples size/L PR with diff > 200 LOC labels Apr 15, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

```py
ckpt_id = "black-forest-labs/FLUX.1-dev"

# --- Text encoding (CPU) ---
prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
```

nit: again, for clarity I would avoid the "Trillium" word if we test on v5.

@sayakpaul (Member Author)

This is probably fine. It's quite separate.

```py
xs.mark_sharding(param, mesh, tuple(spec))

flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
FlashAttention.DEFAULT_BLOCK_SIZES = {
```

This looks like black magic; consider adding a comment explaining where these values come from.

@sayakpaul (Member Author)

Cc: @entrpn for those as it's copied from flux_inference.py.


```py
def _vae_decode(latents, vae, height, width, device):
    """Move VAE to XLA, decode latents, move VAE back to CPU."""
    vae.to(device)
```

I do not know much about this, but isn't moving the VAE back and forth between the XLA device and CPU quite expensive in time? Wouldn't it be better to just keep it in XLA?

@sayakpaul (Member Author)

It would barely fit, otherwise. Plus we have to free some stuff anyway to do the actual computation. Once it's compiled it doesn't take much of a hit barring some displacement overhead which is likely justifiable given the cheap pricing of v5es. Does it make sense?
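For illustration, the move/decode/move-back "shuttle" pattern under discussion can be sketched without torch_xla at all. `FakeVAE` and the `"xla"` device string here are hypothetical stand-ins for the real `AutoencoderKL` and XLA device; only the shape of the pattern matters:

```python
# Hedged sketch of the CPU <-> device "shuttle" pattern: keep a large module on
# CPU and move it to the accelerator only for the step that needs it.

class FakeVAE:
    def __init__(self):
        self.device = "cpu"

    def to(self, device):
        self.device = device
        return self

    def decode(self, latents):
        assert self.device != "cpu", "decode should run on the accelerator"
        return [x * 2 for x in latents]  # stand-in for the real decode

def vae_decode(latents, vae, device="xla"):
    vae.to(device)   # one-time host -> device transfer
    image = vae.decode(latents)
    vae.to("cpu")    # free device memory for the transformer
    return image

print(vae_decode([1, 2, 3], FakeVAE()))  # -> [2, 4, 6]
```

The transfer cost is paid once per decode, which is the trade-off the thread above is weighing against keeping the VAE resident on the device.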

```
2026-04-15 02:56:13 [info ] avg. inference over 2 iterations took 98.75175104649975 sec.
```

The first inference iteration includes VAE compilation (~195s). The second iteration shows the true steady-state speed (~1.76s).

Perhaps you can include a dummy inference in the compilation part, so that VAE is compiled and timings look more regular.

@sayakpaul (Member Author)

I didn't get this part. Elaborate more? The block under `logger.info("starting compilation run...")` has the VAE compilation included.
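For what it's worth, the reviewer's suggestion can be sketched abstractly: run one untimed warm-up call that absorbs compilation, so the timed loop reports only steady-state numbers. `CompiledFn` is a hypothetical stand-in for an XLA-compiled pipeline (slow on first call, fast once cached):

```python
import time

class CompiledFn:
    """Simulates an XLA-style function: slow on the first call (compilation),
    fast afterwards (cached executable)."""
    def __init__(self):
        self.compiled = False

    def __call__(self):
        if not self.compiled:
            time.sleep(0.05)  # stands in for the ~195s VAE compile
            self.compiled = True
        return "image"

fn = CompiledFn()
fn()  # warm-up / compilation run: deliberately not timed

# The timed iterations now measure only steady-state inference.
times = []
for _ in range(2):
    start = time.perf_counter()
    fn()
    times.append(time.perf_counter() - start)

print(f"avg. inference over {len(times)} iterations took {sum(times) / len(times):.6f} sec.")
```

With the warm-up excluded, the reported average no longer mixes the one-time compile into the per-iteration number, which is what makes the 98.75s average above look irregular.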


Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@sayakpaul sayakpaul requested a review from tengomucho April 15, 2026 10:22
@sayakpaul (Member Author)

Additionally, @entrpn I am seeing recompilations with the following

```
W torch_xla/csrc/runtime/pjrt_computation_client.cpp:682] Failed to deserialize executable: UNIMPLEMENTED: Deserializing serialized executable not supported.
```

Is that expected?
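As far as I can tell (an assumption, not verified on this exact setup), that warning is the persistent compilation cache failing to deserialize a stored executable, in which case XLA silently falls back to recompiling. A minimal config sketch of explicitly initializing the cache with torch_xla, runnable only on a TPU host:

```python
# Config sketch; requires torch_xla on a TPU host, so not runnable here.
import torch_xla.runtime as xr

# Persist compiled executables across runs. When an entry cannot be
# deserialized (as in the warning above), XLA recompiles instead of loading it.
xr.initialize_cache("/tmp/xla_compile_cache", readonly=False)
```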


Labels

examples, size/L (PR with diff > 200 LOC)

3 participants