
openpi使用流匹配方式来训练专家模型流匹配原理https://blog.csdn.net/qq_37795208/article/details/159049034https://blog.csdn.net/qq_37795208/article/details/159049034openpi论文解读https://blog.csdn.net/qq_37795208/article/details/159049034https://blog.csdn.net/qq_37795208/article/details/159049034借用上述博客的相关内容PI0.5输入图像、任务文本、机器人当前状态再加上一段“当前还带噪声的动作序列”模型学习预测“应该朝哪个方向把这段动作去噪”。训练时学这个方向推理时从纯噪声开始反复更新 10 步左右最后得到可执行的动作 chunk1.流匹配大致原理2.对应的流匹配计算代码对应openpi/sr/openpi/models/pi0.py中的损失函数计算定义如下#flow matching 的核心 override def compute_loss( self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool False ) - at.Float[at.Array, *b ah]: ## 1. 数据增强与预处理 preprocess_rng, noise_rng, time_rng jax.random.split(rng, 3) observation _model.preprocess_observation(preprocess_rng, observation, traintrain) # 2. 采样噪声与时间步 t batch_shape actions.shape[:-2] noise jax.random.normal(noise_rng, actions.shape)# 纯噪声,可以理解为一个随机的目标 time jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 0.001#采样时间Beta(1.5, 1)再裁剪到 [0.001, 0.999]训练时会随机抽取不同的 t让模型见过不同去噪阶段 time_expanded time[..., None, None] #noise从高斯分布采样的一段噪声动作t:flow matching 的时间步从 0 到 1; #x_t t * 噪声 (1-t) * 真实动作真实和噪声连线间的某一点 x_t time_expanded * noise (1 - time_expanded) * actions #x_t真实动作和噪声在时间步 t 下的线性混合状态 u_t noise - actions #理论上对应的目标速度场 # 4. 前向传播图像文本噪声动作 → 模型预测速度场 v_t # one big forward pass of prefix suffix at once #prefix只有环境的编码suffix只有动作和时间编码 prefix_tokens, prefix_mask, prefix_ar_mask self.embed_prefix(observation)#prefix 固定不变的环境信息图像 文本 机器人状态prefix_tokens图像 文本 编码后的 tokens suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond self.embed_suffix(observation, x_t, time)#suffix 每次都变的动作信息带噪动作 x_t 时间 tsuffix_tokens带噪动作 时间 编码后的 tokens input_mask jnp.concatenate([prefix_mask, suffix_mask], axis1) ar_mask jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis0) attn_mask make_attn_mask(input_mask, ar_mask) positions jnp.cumsum(input_mask, axis1) - 1 (prefix_out, suffix_out), _ self.PaliGemma.llm(#图像文本编码动作编码全部拼起来 → 送入大模型Gemma / PaliGemma [prefix_tokens, suffix_tokens], maskattn_mask, positionspositions, adarms_cond[None, adarms_cond] ) v_t self.action_out_proj(suffix_out[:, -self.action_horizon :]) #模型预测出来的速度场最后一层线性层 → 输出 v_t return jnp.mean(jnp.square(v_t - u_t), axis-1) #最终损失 MSE,Loss MSE( 模型预测速度场 v_t , 真实最优速度场 u_t )2.1首先进行随机采样噪声noise根据不同的时间步随机构造插值点x_t并计算真实的方向向量u_t2.2输入数据的准备这里会在输入前准备两组数据分别为代表环境的图像文本数据和代表动作的带噪声动作和时间数据。2.2.1)图像和文本的token准备这里prefix_tokens是将obs中的图像和语言通过预训练的PaliGemma VLM模型将图像和字符串变成token。这里由于设置了trainfalse图像编码器是冻结的不可训练文本embedding 层可训练prefix_tokens, prefix_mask, prefix_ar_mask self.embed_prefix(observation)#prefix 固定不变的环境信息图像 文本 机器人状态prefix_tokens图像 文本 编码后的 tokens suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond self.embed_suffix(observation, x_t, time)#suffix 每次都变的动作信息带噪动作 x_t 时间 tsuffix_tokens带噪动作 时间 编码后的 tokens其中这里的prefix_tokens中的文本token是包含了机械臂的状态state位姿关节在如下可以看出def tokenize(self, prompt: str, state: np.ndarray | None None) - tuple[np.ndarray, np.ndarray]: cleaned_text prompt.strip().replace(_, ).replace(\n, ) if state is not None: # This is the Pi05 format, where the state is part of the discrete language input. discretized_state np.digitize(state, binsnp.linspace(-1, 1, 256 1)[:-1]) - 1 state_str .join(map(str, discretized_state)) full_prompt fTask: {cleaned_text}, State: {state_str};\nAction: tokens self._tokenizer.encode(full_prompt, add_bosTrue)2.2.2带噪声的动作和时间的映射suffix_tokens是带噪声动作的tokenadarms_cond是时间的adaRMSsuffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond self.embed_suffix(observation, x_t, time)#suffix 每次都变的动作信息带噪动作 x_t 时间 tsuffix_tokens带噪动作 时间 编码后的 tokens2.2.3完整的数据输入过程其中self.PaliGemma.llm不是用于预测token而是将输入信息经过预训练大模型的 Transformer 来融合。# prefix_out: (B, P, 2048) ← prefix 对应的输出通常丢弃 # suffix_out: (B, H, 2048) ← suffix 对应的输出用于预测动作 # 输入环境信息prefix_tokens 动作信息suffix_tokens # 输出经过 27 层 Transformer 处理后融合了环境上下文的动作表示suffix_out # 作用利用预训练大模型的 Transformer 来深度理解和融合环境信息与动作信息 (prefix_out, suffix_out), _ self.PaliGemma.llm(#图像文本编码动作编码全部拼起来 → 送入大模型Gemma / PaliGemma [prefix_tokens, suffix_tokens],# 两个独立的输入 maskattn_mask, # 注意力掩码 (B, PH, PH) positionspositions, # 位置索引 (B, PH) adarms_cond[None, adarms_cond]# 时间条件只给 suffix )2.3训练目标和损失