Wan训练框架总览

发布时间:2026/5/29 6:50:06

Wan训练框架总览 WanVideo 模型架构深度分析1. 总览1.1 核心模型是什么WanVideo 使用Flow Matching Diffusion Transformer(流匹配扩散变换器)作为核心生成模型,代号DiT(Diffusion Transformer)。整个系统由4 个独立子模型组成,并通过WanVideoPipeline统一编排:子模型类名职责DiTWanModel/VaceWanModel/MotWanModel核心去噪网络,预测速度场 vVAEWanVideoVAE视频帧 ↔ 潜变量空间(压缩比 8×8 空间,4× 时序)T5WanTextEncoder(UMT5-XXL)文本 → 语义向量(dim=4096)CLIPWanImageEncoder图像 → 视觉特征(I2V 任务)1.2 模型解决的任务以本文档关注的配置为例(Wan2.1-Fun-V1.1-1.3B-Control-Camera):输入:文本 prompt + 首帧参考图像 + 相机控制轨迹(方向 + 速度)输出:满足相机控制约束的 480×832 视频(N 帧)任务类型:Camera-Controlled Image-to-Video(相机控制图生视频)1.3 模型主执行路径训练时(单阶段 sft):data dict → get_pipeline_inputs() # 构建三元组 → transfer_data_to_device() # 设备/精度迁移 → for unit in pipe.units: # 顺序执行各 Unit VAE Encoder → input_latents T5 Encoder → context CLIP Encoder → clip_feature Camera Unit → camera_embeds → FlowMatchSFTLoss() # 采样时间步 t,加噪,DiT 前向,MSE → loss.backward()推理时:WanVideoPipeline.__call__() → 执行所有 units(预处理) → for t in scheduler.timesteps: # 去噪循环(默认约 50 步) model_fn_wan_video(dit, latents, context, ...) CFG 加权(如开启) scheduler.step() 更新 latents → VAE.decode(latents) # 潜变量 → 视频帧 → 返回 PIL.Image 列表2. 模型主干结构分解2.1 顶层结构树WanVideoPipeline diffsynth/pipelines/wan_video.py ├── scheduler: FlowMatchScheduler 扩散噪声调度器 ├── tokenizer: HuggingfaceTokenizer 分词器(T5 前置) ├── text_encoder: WanTextEncoder 文本编码器(UMT5-XXL) ├── image_encoder: WanImageEncoder 图像编码器(CLIP) ├── dit: WanModel 主扩散变换器 ├── dit2: WanModel 次扩散变换器(混合模型中的低时间步分支) ├── vae: WanVideoVAE 视频 VAE ├── vace: VaceWanModel VACE 编辑分支 DiT ├── vap: MotWanModel 视频动作预测分支 ├── audio_encoder: WanS2VAudioEncoder 音频编码器(S2V 任务) ├── animate_adapter: WanAnimateAdapter 动画适配器 ├── units: List[PipelineUnit] 25 个执行单元 ├── post_units: List[PipelineUnit] 1 个后处理单元 └── model_fn: model_fn_wan_video 去噪网络调用包装函数2.2 Pipeline Units 列表(完整 25 个)#Unit 名称职责激活条件1ShapeChecker验证/规范化分辨率和帧数始终2NoiseInitializer初始化高斯噪声 latent始终3PromptEmbedderT5 编码 prompt → context始终4S2V音频处理(Wav2Vec 编码)audio_processor != None5InputVideoEmbedderVAE 编码输入视频input_video != None6ImageEmbedderVAEVAE 编码首帧图像 → yinput_image != None7ImageEmbedderCLIPCLIP 编码首帧 → clip_featureinput_image != None8ImageEmbedderFused融合 VAE/CLIP 特征特定模型9FunControl处理控制信号(ControlNet 风格)fun_control != None10FunReference处理参考帧reference_image != None11FunCameraControl相机控制轨迹编码camera_control_* != None12SpeedControl速度控制信号处理speed_control != None13VACEVACE 视频编辑条件编码vace_video != None14AnimateVideoSplit动画视频分割animate 任务15AnimatePoseLatents姿态潜变量处理animate 任务16AnimateFacePixelValues人脸像素值处理animate_face 任务17AnimateInpaint动画修复animate 任务18VAP视频动作预测处理vap 任务19UnifiedSequenceParallel序列并行处理多卡并行20TeaCacheTeaCache 加速(推理用)推理时21CfgMerger合并正负向 context(CFG)cfg_scale != 122LongCatVideo长视频拼接处理LongCat 模型23WanToDance_ProcessInputsWanToDance 输入处理WanToDance 任务24WanToDance_RefImageEmbedderWanToDance 参考图编码WanToDance 任务25WanToDance_ImageKeyframesEmbedder关键帧编码WanToDance 任务3. 主流程分析3.1 训练 step 完整调用链train.py:__main__ └── launcher_map["sft"](accelerator, dataset, model, model_logger, args=args) └── launch_training_task() [runner.py] ├── optimizer = AdamW(model.trainable_modules(), lr=1e-5) ├── dataloader = DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0]) ├── model, optimizer, dataloader = accelerator.prepare(...) └── for epoch in range(num_epochs): for data in dataloader: with accelerator.accumulate(model): loss = model(data) [training_module.py → WanTrainingModule.forward()] ├── inputs = get_pipeline_inputs(data) │ ├── inputs_posi = {"prompt": data["prompt"]} │ ├── inputs_nega = {} │ └── inputs_shared = {input_video, height, width, │ num_frames, cfg_scale=1, ...} │ └── parse_extra_inputs() 注入 input_image / │ camera_control_direction / │ camera_control_speed ├── inputs = transfer_data_to_device(inputs, cuda, bfloat16) ├── for unit in pipe.units: │ inputs = pipe.unit_runner(unit, pipe, *inputs) │ Unit_5: VAE.encode(input_video) → input_latents [B,16,T/4,H/16,W/16] │ Unit_6: VAE.encode(input_image) → y [B,16,1,H/16,W/16] │ Unit_7: CLIP.encode(input_image) → clip_feature [B,1,1152] │ Unit_3: T5.encode(prompt) → context [B,512,4096] │ Unit_11: camera_embed(direction,speed) → camera_embeds └── loss = task_to_loss["sft"](pipe, *inputs) └── FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi) ├── t = randin

相关新闻