Skip to content

[Feat] Adds LongCat-AudioDiT pipeline #13390

Merged
dg845 merged 16 commits intohuggingface:mainfrom
RuixiangMa:longcataudiodit
Apr 15, 2026
Merged

[Feat] Adds LongCat-AudioDiT pipeline #13390
dg845 merged 16 commits intohuggingface:mainfrom
RuixiangMa:longcataudiodit

Conversation

@RuixiangMa
Copy link
Copy Markdown
Contributor

@RuixiangMa RuixiangMa commented Apr 2, 2026

What does this PR do?

Adds LongCat-AudioDiT model support to diffusers.

Although LongCat-AudioDiT can be used for TTS-like generation, it is fundamentally a diffusion-based audio generation model (text conditioning + iterative latent denoising + VAE decoding) rather than a conventional autoregressive TTS model, so i think it fits naturally into diffusers.

Test

import soundfile as sf
import torch
from diffusers import LongCatAudioDiTPipeline

pipeline = LongCatAudioDiTPipeline.from_pretrained(
    "meituan-longcat/LongCat-AudioDiT-1B",
    torch_dtype=torch.float16,
)
pipeline = pipeline.to("cuda")

audio = pipeline(
    prompt="A calm ocean wave ambience with soft wind in the background.",
    audio_end_in_s=5.0,
    num_inference_steps=16,
    guidance_scale=4.0,
    output_type="pt",
).audios

output = audio[0, 0].float().cpu().numpy()
sf.write("longcat.wav", output, pipeline.sample_rate)

Result

longcat.wav

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa RuixiangMa changed the title Longcataudiodit [Feat] Adds LongCat-AudioDiT support Apr 2, 2026
@RuixiangMa RuixiangMa changed the title [Feat] Adds LongCat-AudioDiT support [Feat] Adds LongCat-AudioDiT pipeline Apr 2, 2026
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@dg845 dg845 requested review from dg845 and yiyixuxu April 4, 2026 00:31
Comment thread src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py Outdated
)


def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, I think we should inline _pixel_shuffle_1d in UpsampleShortcut following #13390 (comment).

Comment thread src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
Comment thread src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/transformers/transformer_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/transformers/transformer_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/transformers/transformer_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/transformers/transformer_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/transformers/transformer_longcat_audio_dit.py Outdated
Comment thread src/diffusers/models/transformers/transformer_longcat_audio_dit.py
Comment on lines +515 to +519
self.time_embed = AudioDiTTimestepEmbedding(dim)
self.input_embed = AudioDiTEmbedder(latent_dim, dim)
self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
self.blocks = nn.ModuleList(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.time_embed = AudioDiTTimestepEmbedding(dim)
self.input_embed = AudioDiTEmbedder(latent_dim, dim)
self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
self.blocks = nn.ModuleList(
self.time_embed = AudioDiTTimestepEmbedding(dim)
self.input_embed = AudioDiTEmbedder(latent_dim, dim)
self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
self.blocks = nn.ModuleList(

See #13390 (comment).

Comment thread src/diffusers/models/transformers/transformer_longcat_audio_dit.py Outdated
Comment on lines +584 to +589
batch_size = hidden_states.shape[0]
if timestep.ndim == 0:
timestep = timestep.repeat(batch_size)
timestep_embed = self.time_embed(timestep)
text_mask = encoder_attention_mask.bool()
encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
batch_size = hidden_states.shape[0]
if timestep.ndim == 0:
timestep = timestep.repeat(batch_size)
timestep_embed = self.time_embed(timestep)
text_mask = encoder_attention_mask.bool()
encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
batch_size = hidden_states.shape[0]
if timestep.ndim == 0:
timestep = timestep.repeat(batch_size)
timestep_embed = self.time_embed(timestep)
text_mask = encoder_attention_mask.bool()
encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)

Can you also refactor forward here so that it is better organized, following #13390 (comment)? See for example the QwenImageTransformer2DModel.forward method:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reorganized parts of forward incrementally; kept the current structure otherwise to avoid unnecessary behavioral churn.

Copy link
Copy Markdown
Collaborator

@dg845 dg845 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your continued work on this! Left some suggestions that should help LongCatAudioDiTPipeline support model offloading, layerwise casting, etc.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 11, 2026
@dg845
Copy link
Copy Markdown
Collaborator

dg845 commented Apr 14, 2026

@bot /style

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 14, 2026

Style bot fixed some files and pushed the changes.

@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 14, 2026

@classmethod
@validate_hf_hub_args
def from_pretrained(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a conversion script?
our pipeline should not define from_pretrained method

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a conversion script? our pipeline should not define from_pretrained method

Added it and tested.

@yiyixuxu
Copy link
Copy Markdown
Collaborator

@claude can you help with a review here?

@github-actions
Copy link
Copy Markdown
Contributor

Claude Code is working…

I'll analyze this and get back to you.

View job run

Comment thread src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py Outdated
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 14, 2026
@dg845
Copy link
Copy Markdown
Collaborator

dg845 commented Apr 15, 2026

@bot /style

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 15, 2026

Style bot fixed some files and pushed the changes.

@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 15, 2026
Comment thread tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py Outdated
Copy link
Copy Markdown
Collaborator

@dg845 dg845 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this PR! Does a HF Hub repo with the diffusers-format checkpoint currently exist? If not, would you be willing to create one?

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 15, 2026
@RuixiangMa
Copy link
Copy Markdown
Contributor Author

RuixiangMa commented Apr 15, 2026

Thanks for working on this PR! Does a HF Hub repo with the diffusers-format checkpoint currently exist? If not, would you be willing to create one?

There isn't a diffusers-format checkpoint yet. I'll try to create one.

@dg845
Copy link
Copy Markdown
Collaborator

dg845 commented Apr 15, 2026

Merging as the CI failures are unrelated.

@dg845 dg845 merged commit c41a3c3 into huggingface:main Apr 15, 2026
12 of 14 checks passed
@RuixiangMa RuixiangMa deleted the longcataudiodit branch April 15, 2026 13:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation models pipelines size/L PR with diff > 200 LOC tests utils

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants