H-RDT(H-RDT: Human Manipulation Enhanced Bimanual Robotic Manipulation) 代码及模型阅读(二)

根据前一篇博客的介绍H-RDT(H-RDT: Human Manipulation Enhanced Bimanual Robotic Manipulation) 代码及模型阅读 我们了解了RDT模型的结构及原理, 这篇文章我将带大家自己阅读其内部对应的代码。本篇博客将分成三大部分介绍该模型分别是: 数据、模型结构及训练,模型推理及评估。

一、 数据

根据train.py我们看到数据集分为训练数据集以及验证数据集

    # Create dataset and dataloader
    train_dataset = VLAConsumerDataset(
        config=config,
        image_transform=image_transform,
        num_cameras=config["common"]["num_cameras"],
        image_aug=args.image_aug,
        image_corrupt_severity=args.image_corrupt_severity if hasattr(args, 'image_corrupt_severity') else None,
        upsample_rate=args.upsample_rate, # 3
        val=False,
        use_precomp_lang_embed=args.precomp_lang_embed, # True
        task_name=args.task_name,
        dataset_name=args.dataset_name,
    )
    
    val_dataset = VLAConsumerDataset(
        config=config,
        image_transform=image_transform,
        num_cameras=config["common"]["num_cameras"],
        image_aug=False,
        image_corrupt_severity=None,
        upsample_rate=args.upsample_rate if hasattr(args, 'upsample_rate') else None,
        val=True,
        use_precomp_lang_embed=args.precomp_lang_embed,
        task_name=args.task_name,
        dataset_name=args.dataset_name,
    )

下面我们看下数据集VLAConsumerDataset这个类, 代码如下所示:

class VLAConsumerDataset(Dataset):
    """用于监督训练的视觉-语言-动作数据集
    该数据集将从缓冲目录加载数据。
    """

    def __init__(
        self,
        config,
        image_transform,
        num_cameras,
        image_size=None,
        auto_adjust_image_brightness=True,
        image_aug=False,
        image_corrupt_severity=None,
        state_noise_snr=None,
        use_precomp_lang_embed=True,
        upsample_rate=None,
        val=False,
        task_name="open_laptop",
        dataset_name="egodex",  # 添加数据集名称参数
    ):
        super(VLAConsumerDataset, self).__init__()
        self.dataset_name = dataset_name
        DATASET_NAMES = {self.dataset_name}
        
        # 创建数据集名称和ID之间的映射关系
        self.dataset_name2id = {name: i for i, name in enumerate(DATASET_NAMES)}
        self.dataset_id2name = {i: name for i, name in enumerate(DATASET_NAMES)}

        self.state_noise_snr = state_noise_snr
        self.num_cameras = num_cameras
        self.img_history_size = config["common"]["img_history_size"]  # 图像历史大小
        self.image_transform = image_transform   # 图像变换函数

        # 根据数据集名称初始化相应的数据集
        if self.dataset_name == "egodex":
            self.hdf5_dataset = EgoDexDataset(
                config=config,
                upsample_rate=upsample_rate,
                val=val,
                use_precomp_lang_embed=use_precomp_lang_embed,
                # 注意:如有需要可覆盖默认路径
                data_root="/data/datasets/rdt_data/EgoDex",
                stat_path="/data/datasets/rdt_data/EgoDex/egodex_stat.json",
            )
        elif self.dataset_name == "robotwin_agilex":
            self.hdf5_dataset = RobotwinAgilexDataset(
                mode="multi_task",
                config=config,
                # 注意:覆盖默认路径
                multi_task_root_dir="/data/datasets/rdt_data/RoboTwin2.0/dataset",
            )
            '''
            self.hdf5_dataset = RobotwinAgilexDataset(
                mode="single_task",
                task_name=task_name,
                hdf5_folder="Aloha-AgileX/data",
                max_episodes=50,
                config=config
                # 注意:覆盖默认路径
                # single_task_root_dir="/path/to/your/robotwin2/single",
            )
            '''
        elif self.dataset_name == "R1-6-kinova":
            self.hdf5_dataset = RobotwinAgilexDataset(
                mode="multi_task",
                config=config,
                multi_task_root_dir="/data/datasets/custom_data/R1-6-kinova",
            )
        else:
            raise ValueError(f"未知的数据集名称: {self.dataset_name}")
            
        print(f"初始化数据集: {self.dataset_name}")

        self.use_precomp_lang_embed = use_precomp_lang_embed # 是否使用预计算的语言嵌入

        self.image_size = image_size # 图像尺寸
        self.auto_adjust_image_brightness = auto_adjust_image_brightness # 是否自动调整图像亮度
        # self.image_aug_transform = get_image_augmentation()
        self.image_aug = image_aug # 是否进行图像增强

    def get_dataset_name2id(self):
        return self.dataset_name2id

    def get_dataset_id2name(self):
        return self.dataset_id2name

    @staticmethod
    def pairwise(iterable):
        # 将可迭代对象两两配对
        a = iter(iterable)
        return zip(a, a)

    def __len__(self) -> int:
        return len(self.hdf5_dataset)

    def __getitem__(self, index):
        # 从后端数据集中获取数据
        try:
            res = self.hdf5_dataset.get_item(index)
        except Exception as e:
            print(f"加载episode {index}时出错: {e}")
            return None
            
        # 添加对res为None的检查,如果为None则重试几次
        retry_count = 0
        max_retries = 5
        while res is None and retry_count < max_retries:
            retry_count += 1
            print(f"获取到None数据项,正在第{retry_count}次重试...")
            try:
                res = self.hdf5_dataset.get_item(index)
            except Exception as e:
                print(f"重试数据加载时出错: {e}")
                
        # 如果多次重试后仍然为None,返回默认值以防止训练中断
        if res is None:
            print(f"警告: 多次重试后仍无法获取有效数据,返回默认值")

        data_dict = {}
        data_dict['dataset_name'] = res['dataset_name']
        data_dict['data_idx'] = self.dataset_name2id[data_dict['dataset_name']]

        # 处理状态和动作数据
        data_dict["states"] = res['states']
        data_dict["actions"] = res['actions']
        data_dict["action_norm"] = res['action_norm']

        # 处理图像
        if self.dataset_name in ['egodex']:
            # 单摄像头/拼接图像处理
            image_metas = []
            images = res['current_images'][0]
            valid_mask = res.get('current_images_mask', [np.ones(self.img_history_size, dtype=bool)])[0] # [[True]]
            image_metas.append((images, valid_mask))
            
            rearranged_images = []
            for hist_idx in range(self.img_history_size):
                images, valid_mask = image_metas[0]
                if valid_mask[hist_idx]:
                    rearranged_images.append((images[hist_idx], True))
                else:
                    rearranged_images.append((None, False))
        else:
            # 多视角处理(原始逻辑)
            image_metas = []
            for cam_idx in range(self.num_cameras):
                images = res['current_images'][cam_idx]
                valid_mask = res.get('current_images_mask', np.ones((self.num_cameras, self.img_history_size), dtype=bool))[cam_idx]
                image_metas.append((images, valid_mask))

            rearranged_images = []
            for hist_idx in range(self.img_history_size):
                for cam_idx in range(self.num_cameras):
                    images, valid_mask = image_metas[cam_idx]
                    if valid_mask[hist_idx]:
                        rearranged_images.append((images[hist_idx], True))
                    else:
                        rearranged_images.append((None, False))

        all_pixel_values = []
        for image, valid in rearranged_images:
            image = Image.fromarray(image) if image is not None else None
            #
            # if valid and self.auto_adjust_image_brightness:
            #     pixel_values = list(image.getdata())
            #     average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
            #     if average_brightness <= 0.15:
            #         image = transforms.ColorJitter(brightness=(1.75,1.75))(image)

            # 仅对50%的图像应用图像增强
            if valid and self.image_aug and (random.random() > 0.5):
                aug_type = random.choice([
                    "corrput_only", "color_only", "both"])
                if aug_type != "corrput_only":
                    image = transforms.ColorJitter(
                        brightness=0.3, contrast=0.4, saturation=0.5, hue=0.03)(image)
                if aug_type != "color_only":
                    image = image_corrupt(image)
                # image = self.image_aug_transform(image)

            pixel_values = self.image_transform(image) # {"dino": (3, 384, 384)), "siglip": (3, 384, 384)}
            all_pixel_values.append(pixel_values)

        # 处理dino-siglip格式的图像
        pv_example = all_pixel_values[0]
        merged_pixel_values = {
            k: torch.stack(
                [pv[k] for pv in all_pixel_values]
            )
            for k in pv_example
        }
        data_dict["images"] = merged_pixel_values

        if self.use_precomp_lang_embed:
            # 所有数据集都应提供lang_embeds作为张量
            if "lang_embeds" in res:
                data_dict["lang_embeds"] = res["lang_embeds"]
            elif torch.is_tensor(res["instruction"]):
                data_dict["lang_embeds"] = res["instruction"]
            else:
                # 遗留方式:从文件路径加载
                data_dict["lang_embeds"] = torch.load(res["instruction"])["embeddings"].squeeze(0)

        # 将所有numpy数组转换为torch张量
        for k, v in data_dict.items():
            if isinstance(v, np.ndarray):
                data_dict[k] = torch.from_numpy(v)

        # 验证所有数据都是张量
        for k, v in data_dict.items():
            assert not isinstance(v, np.ndarray), f"键: {k}, 值: {v}"

        """
        {
            'action_norm': tensor((16, 48), dtype=torch.float64),
            'actions': tensor(16, 48),dtype=torch.float64), 
            'data_idx': 0,
            'dataset_name': 'egodex', 
            'images': {
                'dino': tensor((1, 3, 384. 384)), 
                'siglip': tensor((1, 3, 384, 384))
            }, 
            'lang_embeds': tensor(19, 4096), dtype=torch.bfloat16), 
            'states': tensor((1, 48), dtype=torch.float64)
        }
        """
        return data_dict

接着对数据进行后处理

class DataCollatorForVLAConsumerDataset(object):
    """用于监督训练的样本整理器."""

    def __init__(self, use_precomp_lang_embed=True) -> None:
        self.use_precomp_lang_embed = use_precomp_lang_embed
        
    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # 使用通用字段初始化批次
        batch = {
            "states": [],
            "actions": [],
            "action_norm": [],
            "images": [],
            "data_indices": [],
        }
        
        if self.use_precomp_lang_embed:
            lang_embeds = []
            lang_embed_lens = []

        # 处理批次中的每个实例
        for instance in instances:
            # 处理数值数据
            keys_to_check = [
                'states', 'actions',
                'action_norm',
            ]
            for key in keys_to_check:
                if isinstance(instance[key], torch.Tensor):
                    item = instance[key]
                else:
                    item = torch.from_numpy(instance[key])
                batch[key].append(item)

            # 处理图像
            batch["images"].append(instance["images"])
            batch["data_indices"].append(instance["data_idx"])

            if self.use_precomp_lang_embed and "lang_embeds" in instance:
                lang_embeds.append(instance["lang_embeds"])
                lang_embed_lens.append(instance["lang_embeds"].shape[0])

        # 堆叠数值数据的张量
        keys_to_stack = [
            'states', 'actions',
            'action_norm',
        ]
        for key in keys_to_stack:
            batch[key] = torch.stack(batch[key], dim=0)

        # 处理dino-siglip格式的图像
        pv_example = batch["images"][0]
        merged_pixel_values = {
            k: torch.stack(
                [pv[k] for pv in batch["images"]]
            )
            for k in pv_example
        }
        batch["images"] = merged_pixel_values

        if self.use_precomp_lang_embed:
            lang_embeds = torch.nn.utils.rnn.pad_sequence(
                lang_embeds,
                batch_first=True,
                padding_value=0)
            input_lang_attn_mask = torch.zeros(
                lang_embeds.shape[0], lang_embeds.shape[1], dtype=torch.bool)
            for i, l in enumerate(lang_embed_lens):
                input_lang_attn_mask[i, :l] = True
            batch["lang_embeds"] = lang_embeds
            batch["lang_attn_mask"] = input_lang_attn_mask
        """
        微调时48改为14
        {
            "states": Tensor(32, 1, 48),
            "actions": Tensor(32, 16, 48),
            "action_norm": Tensor(32, 16, 48), 
            "images": {
                "dino": Tensor(32, 1, 3, 384, 384),
                "siglip": Tensor(32, 1, 3, 384, 384)
            }
            "data_indices": list(32) -> 都是0第一个数据索引,
            "lang_embeds": Tensor(32, 44, 4096),
            "lang_attn_mask": Tensor(32, 44)
        }
        """
        return batch

从上述代码可以看到该代码目前包含了两个数据集分别是pretrain数据集-> EGODEX 以及finetune数据集->robotwin_agilex. 我们先看下pretrain数据集, 该数据集代码如下所示, 并加上详细的注释:

#!/usr/bin/env python3
"""
EgoDex 数据集加载器
实现 48 维手部动作表示、单视角图像和语言嵌入数据的加载
"""

import h5py
import numpy as np
import torch
import accelerate
import os
import cv2
from pathlib import Path
import warnings
import random
import json
warnings.filterwarnings("ignore")


class EgoDexDataset:
    """EgoDex 数据集加载器类"""
    
    def __init__(self, 
                 data_root=None, 
                 config=None,
                 upsample_rate=3,
                 val=False,
                 use_precomp_lang_embed=True,
                 stat_path=None):
        """
        初始化 EgoDex 数据集加载器

        Args:
            data_root: 数据根目录路径 (例如: "/share/hongzhe/datasets/egodex")
            config: 配置字典,包含各种参数设置
            upsample_rate: 时间数据上采样率 (帧采样间隔)
            val: 是否为验证集 (True 表示测试集, False 表示训练集)
            use_precomp_lang_embed: 是否使用预计算的语言嵌入
            stat_path: 统计文件路径 (默认: datasets/pretrain/egodex_stat.json)
        """
        # 数据集名称
        self.DATASET_NAME = "egodex"
        # 数据根目录路径
        self.data_root = Path(data_root)
        # 配置参数
        self.config = config
        # 上采样率,控制帧采样间隔
        self.upsample_rate = upsample_rate
        # 是否为验证/测试集
        self.val = val
        # 是否使用预计算的语言嵌入
        self.use_precomp_lang_embed = use_precomp_lang_embed
        
        # 根据配置设置参数
        if config:
            # 动作块大小,即未来动作序列的长度
            self.chunk_size = config['common']['action_chunk_size'] # 16
            # 状态维度,即动作向量的维度
            self.state_dim = config['common']['action_dim'] # 48
            # 图像历史帧数量
            self.img_history_size = config['common']['img_history_size'] # 1
        else:
            # 默认参数
            self.chunk_size = 16
            self.state_dim = 48
            self.img_history_size = 1
        
        # 如果未提供统计文件路径,则设置默认路径
        if stat_path is None:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            stat_path = os.path.join(current_dir, 'egodex_stat.json')
        
        # 加载数据文件列表
        self.data_files = self._load_file_list()
        split_name = "test" if self.val else "train"
        print(f"Loaded {len(self.data_files)} {split_name} data files")
        
        # 加载用于归一化的统计数据
        self.action_min = None
        self.action_max = None
        if os.path.exists(stat_path):
            with open(stat_path, 'r') as f:
                stat = json.load(f)
            if 'egodex' in stat:
                # 加载动作数据的最小值和最大值,用于归一化
                self.action_min = np.array(stat['egodex']['min'])
                self.action_max = np.array(stat['egodex']['max'])
    
    def get_dataset_name(self):
        """返回数据集名称"""
        return self.DATASET_NAME
    
    def _load_file_list(self):
        """
        加载数据文件列表
        
        Returns:
            data_files: 包含所有数据文件信息的列表
        """
        data_files = []
        print("scan train/val file ..." )
        if not self.val:
            # 训练集: part1-part5 + extra
            for part in ['part1', 'part2', 'part3', 'part4', 'part5', 'extra']:
                part_dir = self.data_root / part
                if part_dir.exists():
                    data_files.extend(self._scan_directory(part_dir))
        else:
            # 测试集: test
            test_dir = self.data_root / 'test'
            if test_dir.exists():
                data_files.extend(self._scan_directory(test_dir))
        
        return data_files
    
    def _scan_directory(self, directory):
        """
        扫描目录中的文件
        
        Args:
            directory: 要扫描的目录路径
            
        Returns:
            files: 包含文件信息的列表
        """
        files = []
        # 遍历目录中的任务子目录
        for task_dir in directory.iterdir():
            if task_dir.is_dir():
                # 收集所有 hdf5 文件
                hdf5_files = list(task_dir.glob('*.hdf5'))
                for hdf5_file in hdf5_files:
                    # 获取文件名(不含扩展名)
                    file_index = hdf5_file.stem
                    # 构造对应的 mp4 和 pt 文件路径
                    mp4_file = task_dir / f"{file_index}.mp4"
                    pt_file = task_dir / f"{file_index}.pt"
                    
                    # 确保所有必需的文件都存在
                    if (hdf5_file.exists() and mp4_file.exists() and 
                        pt_file.exists()):
                        files.append({
                            'hdf5': hdf5_file,
                            'mp4': mp4_file,
                            'pt': pt_file,
                            'task': task_dir.name,
                            'file_index': file_index
                        })
        return files
    
    def construct_48d_action(self, hdf5_file, frame_indices):
        """
        直接提取预计算的 48 维手部动作表示
        
        Args:
            hdf5_file: HDF5 文件对象
            frame_indices: 要提取的帧索引列表
            
        Returns:
            actions: (T, 48) 的动作序列
        """
        # 检查是否存在预计算的动作数据
        if 'actions_48d' not in hdf5_file:
            raise ValueError("Missing precomputed actions_48d data in HDF5 file, please run precompute_48d_actions.py first")
        
        # 直接读取预计算的 48 维动作数据
        precomputed_actions = hdf5_file['actions_48d'][:]
        
        # 提取指定帧的动作
        selected_actions = precomputed_actions[frame_indices]
        
        return selected_actions.astype(np.float32)
    
    def parse_img_data(self, mp4_path, idx):
        """
        按照 cvpr_real_dataset.py 的采样逻辑加载图像帧
        
        Args:
            mp4_path: MP4 文件路径
            idx: 当前帧索引
            
        Returns:
            frames: (img_history_size, H, W, 3) 的图像帧
        """
        # 打开视频文件
        cap = cv2.VideoCapture(str(mp4_path))
        # 获取视频总帧数
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        
        # 按照 cvpr_real_dataset.py 的逻辑计算采样范围
        start_i = max(idx - self.img_history_size * self.upsample_rate + 1, 0)
        num_frames = (idx - start_i) // self.upsample_rate + 1
        
        frames = []
        
        try:
            # 按照指定间隔采样帧
            for i, frame_idx in enumerate(range(start_i, idx + 1, self.upsample_rate)):
                if frame_idx < total_frames:
                    # 定位到指定帧并读取
                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                    ret, frame = cap.read()
                    if ret:
                        # BGR 转 RGB
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frames.append(frame)
                    else:
                        print(f"Warning: Not enough frames in {mp4_path}")
                        break
                else:
                    # 如果帧索引超出总帧数,使用最后一帧
                    print(f"Warning: Frame index exceeds total frames in {mp4_path}")
                    break
        except Exception as e:
            print(f"Error loading image frames: {e}")
        
        # 释放视频文件
        cap.release()
        
        # 转换为 numpy 数组
        if frames:
            frames = np.array(frames)
        else:
            frames = np.zeros((1, 1080, 1920, 3), dtype=np.uint8)
        
        # 如有必要,进行填充 (按照 cvpr_real_dataset.py 的逻辑)
        if frames.shape[0] < self.img_history_size:
            pad_frames = np.repeat(frames[:1], self.img_history_size - frames.shape[0], axis=0)
            frames = np.concatenate([pad_frames, frames], axis=0)
        
        return frames
    
    def __len__(self):
        """返回数据集大小"""
        return len(self.data_files)
    
    def get_item(self, idx=None):
        """
        获取一个数据样本
        
        Args:
            idx: 数据样本索引,如果为 None 则随机选择
            
        Returns:
            包含所有必需字段的数据字典
        """
        # 如果未指定索引,则随机选择
        if idx is None:
            idx = random.randint(0, len(self.data_files) - 1)
        
        # 获取文件信息
        file_info = self.data_files[idx % len(self.data_files)]
        
        try:
            # 加载 HDF5 数据
            with h5py.File(file_info['hdf5'], 'r') as f:
                # 获取总帧数
                transforms_group = f['transforms']
                total_frames = list(transforms_group.values())[0].shape[0]
                
                # 按照 cvpr_real_dataset.py 的逻辑计算随机索引
                max_index = total_frames - 2
                if max_index < 0:
                    print(f"Warning: Not enough frames in {file_info['hdf5']}")
                    return None
                
                # 随机采样索引
                index = random.randint(0, max_index)
                
                # 使用当前索引构建 48 维动作
                current_action = self.construct_48d_action(f, [index])
                
                # 构建未来动作序列
                action_end = min(index + self.chunk_size * self.upsample_rate, max_index + 1)
                action_indices = list(range(index + 1, action_end + 1, self.upsample_rate))
                
                # 如果动作帧不足,则重复最后一个动作帧
                while len(action_indices) < self.chunk_size:
                    action_indices.append(action_indices[-1] if action_indices else index + 1)
                
                # 提取动作序列
                actions = self.construct_48d_action(f, action_indices[:self.chunk_size])
                
                # 如果动作形状仍不正确,则用最后一个动作进行填充
                if actions.shape[0] < self.chunk_size:
                    last_action = actions[-1:] if len(actions) > 0 else current_action
                    padding = np.repeat(last_action, self.chunk_size - actions.shape[0], axis=0)
                    actions = np.concatenate([actions, padding], axis=0)
            
            # 归一化动作数据
            if self.action_min is not None and self.action_max is not None:
                current_action = (current_action - self.action_min) / (self.action_max - self.action_min) * 2 - 1
                current_action = np.clip(current_action, -1, 1)
                actions = (actions - self.action_min) / (self.action_max - self.action_min) * 2 - 1
                actions = np.clip(actions, -1, 1)
            
            # 使用新的采样逻辑加载单视角图像帧
            image_frames = self.parse_img_data(file_info['mp4'], index)
            
            # 只取所需的历史帧数量
            image_frames = image_frames[-self.img_history_size:]
            
            # 加载语言嵌入
            lang_embed_path = file_info['pt']
            
            # 返回数据字典
            return {
                'states': current_action,  # (1, 48) 当前状态
                'actions': actions,  # (chunk_size, 48) 未来动作序列
                'action_norm': np.ones_like(actions),  # 动作指示器
                'current_images': [image_frames],  # [(img_history_size, H, W, 3)] 单视角图像
                'current_images_mask': [np.ones(self.img_history_size, dtype=bool)],  # 图像掩码
                'instruction': str(lang_embed_path),  # 语言嵌入文件路径
                'dataset_name': self.DATASET_NAME,
                'task': file_info['task'],
                'file_info': {
                    'hdf5_path': str(file_info['hdf5']),
                    'mp4_path': str(file_info['mp4']),
                    'pt_path': str(file_info['pt']),
                    'total_frames': total_frames,
                    'selected_index': index,
                    'action_indices': action_indices
                }
            }
            
        except Exception as e:
            print(f"Error loading data {file_info['hdf5']}: {e}")
            return None

    def __getitem__(self, idx):
        """PyTorch Dataset 接口"""
        return self.get_item(idx)


if __name__ == "__main__":
    # 测试数据集
    dataset = EgoDexDataset(
        data_root="/share/hongzhe/datasets/egodex",
        val=False,
        upsample_rate=3
    )
    
    print(f"Dataset size: {len(dataset)}")
    
    # 测试加载样本
    sample = dataset.get_item(0)
    print("Sample data structure:")
    for key, value in sample.items():
        if isinstance(value, np.ndarray):
            print(f"  {key}: {value.shape}")
        else:
            print(f"  {key}: {type(value)}")

接着我们看下finetune数据集代码, 这里使用的是roboTwin2 huggingface的数据集, 代码如下:

import os
import random
import h5py
import numpy as np
import cv2
import json
import pandas as pd
import time
import glob
from typing import List, Dict, Optional
import sys
import warnings
import traceback
import torch

warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message=".*multichannel.*"
)

# Import imgaug libraries
import imgaug as ia
import imgaug.augmenters as iaa

if not hasattr(np, 'bool'):
    np.bool = bool

class RobotwinAgilexDataset:
    """
    Dataset for loading RobotWin Agilex robot data with joint actions.
    Supports both single-task and multi-task modes with balanced sampling.
    """
    def __init__(
        self, 
        mode="multi_task",  # "single_task" or "multi_task"
        single_task_root_dir=None,  # For single task mode
        multi_task_root_dir=None,  # For multi task mode
        task_name=None,  # Required for single_task mode
        hdf5_folder=None,  # Required for single_task mode, e.g., "demo_clean/data"
        max_episodes=None,  # Maximum number of episodes to load for single task
        config=None,
        stat_path=None,
        upsample_rate=3,
        val=False,
        image_corrupt_severity=5
    ):
        """
        Initialize RobotWin dataset with single/multi-task support.
        
        Args:
            mode: "single_task" or "multi_task"
            single_task_root_dir: Root directory for single task mode
            multi_task_root_dir: Root directory for multi task mode
            task_name: Task name for single task mode (e.g., "beat_block_hammer")
            hdf5_folder: HDF5 folder path for single task mode (e.g., "demo_clean/data")
            max_episodes: Maximum episodes to load for single task (None = all)
            config: Configuration dictionary
            stat_path: Path to normalization statistics file
            upsample_rate: Temporal data upsampling rate
            val: Whether this is validation set
            image_corrupt_severity: Image corruption severity level
        """
        # 设置数据集基本信息
        self.DATASET_NAME = "robotwin_agilex"
        self.mode = mode
        self.single_task_root_dir = single_task_root_dir
        self.multi_task_root_dir = multi_task_root_dir
        self.task_name = task_name
        self.hdf5_folder = hdf5_folder
        self.max_episodes = max_episodes
        self.upsample_rate = upsample_rate
        self.val = val
        self.image_corrupt_severity = image_corrupt_severity
        
        # 验证模式参数
        if mode == "single_task":
            # 单任务模式需要提供任务名和HDF5文件夹路径
            if not task_name or not hdf5_folder:
                raise ValueError("single_task mode requires task_name and hdf5_folder parameters")
        elif mode == "multi_task":
            pass  # 多任务模式不需要额外验证
        else:
            raise ValueError(f"Invalid mode: {mode}. Must be 'single_task' or 'multi_task'")
        
        # 从配置中获取基本参数
        self.chunk_size = config['common']['action_chunk_size']  # 动作块大小
        self.state_dim = config['common']['action_dim']          # 状态维度
        self.img_history_size = config['common']['img_history_size']  # 图像历史大小
        
        # 设置相机参数(仅支持3个相机)
        self.num_cameras = config['common']['num_cameras']
        if self.num_cameras == 3:
            # 定义三个相机:高处相机、右手腕相机、左手腕相机
            self.cameras = ["cam_high", "cam_right_wrist", "cam_left_wrist"]
        else:
            raise ValueError(f"Unsupported num_cameras={self.num_cameras}, only 3 cameras supported.")

        # 相机映射关系
        self.camera_mapping = {
            "cam_high": "head_camera", 
            "cam_left_wrist": "left_camera",
            "cam_right_wrist": "right_camera"
        }
        
        # 如果未提供统计文件路径,则使用默认路径
        if stat_path is None:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            stat_path = os.path.join(current_dir, 'stats.json')
        
        # 加载动作数据的统计信息(用于归一化)
        with open(stat_path, 'r') as file:
            stat = json.load(file)
        self.action_min = np.array(stat['robotwin_agilex']['min'])
        self.action_max = np.array(stat['robotwin_agilex']['max'])
        
        # 初始化数据结构
        if mode == "single_task":
            self.episode_files = []  # 单任务模式下的episode文件列表
        else:
            self.task_to_episodes = {}  # 多任务模式下任务到episode文件的映射
            self.task_weights = {}      # 任务采样权重
        
        self.total_episodes = 0     # 总episode数
        
        # 初始化数据集
        self._initialize_dataset()
    
    def _scan_single_task_folder(self):
        """
        扫描单任务文件夹获取所有HDF5文件
        """
        # 构建任务数据目录路径
        task_dir = os.path.join(self.single_task_root_dir, self.task_name, self.hdf5_folder)
        if not os.path.exists(task_dir):
            print(f"Warning: Task folder {task_dir} does not exist")
            return []
        
        # 查找目录中的所有HDF5文件
        hdf5_files = []
        for f in os.listdir(task_dir):
            if f.endswith(".hdf5"):
                hdf5_path = os.path.join(task_dir, f)
                hdf5_files.append(hdf5_path)
        
        # 排序以确保一致的顺序
        hdf5_files.sort()
        
        # 如果指定了最大episode数,则限制文件数量
        if self.max_episodes is not None:
            hdf5_files = hdf5_files[:self.max_episodes]
        
        print(f"Single task {self.task_name}: Found {len(hdf5_files)} HDF5 files")
        return hdf5_files
    
    def _scan_multi_task_folders(self):
        """
        扫描多任务文件夹获取所有HDF5文件
        """
        if not os.path.exists(self.multi_task_root_dir):
            print(f"Warning: Multi-task root directory {self.multi_task_root_dir} does not exist")
            return {}
        
        # 获取所有任务文件夹
        task_folders = [d for d in os.listdir(self.multi_task_root_dir) 
                       if os.path.isdir(os.path.join(self.multi_task_root_dir, d))]
        
        task_to_episodes = {}
        
        # 遍历每个任务文件夹
        for task_folder in task_folders:
            task_dir = os.path.join(self.multi_task_root_dir, task_folder)
            hdf5_files = []
            
            # 递归查找任务文件夹中的所有HDF5文件
            for root, dirs, files in os.walk(task_dir):
                for file in files:
                    if file.endswith(".hdf5"):
                        hdf5_path = os.path.join(root, file)
                        hdf5_files.append(hdf5_path)
            
            # 随机打乱文件顺序
            random.shuffle(hdf5_files)
            task_to_episodes[task_folder] = hdf5_files
            print(f"Multi-task {task_folder}: Found {len(hdf5_files)} HDF5 files")
        
        return task_to_episodes
    
    def _initialize_dataset(self):
        """
        初始化数据集,重新扫描文件夹并更新采样权重
        """
        print("Initializing dataset...")
        
        if self.mode == "single_task":
            # 单任务模式:扫描单个任务文件夹
            self.episode_files = self._scan_single_task_folder()
            self.total_episodes = len(self.episode_files)
            
            if self.total_episodes == 0:
                raise ValueError("Error: No HDF5 files found, please check data path")
            
            print(f"Single task dataset initialized. Total {self.total_episodes} episodes")
        
        else:
            # 多任务模式:扫描所有任务文件夹
            self.task_to_episodes = self._scan_multi_task_folders()
            
            # 计算总episode数
            all_task_count = sum(len(episodes) for episodes in self.task_to_episodes.values())
            
            if all_task_count == 0:
                raise ValueError("Error: No HDF5 files found, please check data path")
            
            # 计算采样权重 - 所有任务权重相等(1:1:1:1:1:1采样)
            task_count = len(self.task_to_episodes)
            for task_name in self.task_to_episodes.keys():
                self.task_weights[task_name] = 1.0 / task_count
            
            self.total_episodes = all_task_count
            
            print(f"Multi-task dataset initialized. Total {all_task_count} episodes across {task_count} tasks")
            print(f"Task weights: {self.task_weights}")
    
    def __len__(self):
        """
        返回数据集的近似长度
        """
        return self.total_episodes * 200  # 假设每个HDF5文件有200个样本
    
    def get_dataset_name(self):
        """
        返回数据集名称
        """
        return self.DATASET_NAME

    def parse_img_data(self, dataset, idx):
        """
        处理单个相机的图像数据
        
        Args:
            dataset: HDF5数据集,包含单个相机的图像
            idx: 当前帧索引
            
        Returns:
            处理后的图像序列,形状为 [history_size, H, W, 3]
        """
        # 计算起始索引
        start_i = max(idx - self.img_history_size * self.upsample_rate + 1, 0)
        num_frames = (idx - start_i) // self.upsample_rate + 1

        # 使用标准分辨率存储图像(320x240)
        frames = np.zeros((num_frames, 240, 320, 3), dtype=np.uint8)
        
        try:
            # 按照上采样率提取图像帧
            for i, frame_idx in enumerate(range(start_i, idx + 1, self.upsample_rate)):
                if frame_idx < len(dataset):
                    img_data = dataset[frame_idx]
                    
                    # 解码图像
                    decoded_img = self.decode_image_with_opencv(img_data)
                    if decoded_img is None:
                        raise Exception(f"[DEBUG] decode error")

                    if decoded_img is not None:
                        frames[i] = decoded_img

        except Exception as e:
            print(f"[DEBUG] decode_image_with_opencv error: {e}")

        # 如果帧数不足历史大小,用第一帧进行填充
        if num_frames < self.img_history_size:
            pad_frames = np.repeat(frames[:1], self.img_history_size - num_frames, axis=0)
            frames = np.concatenate([pad_frames, frames])
        
        return frames

    def decode_image_with_opencv(self, img_data):
        """
        使用OpenCV解码图像数据,保持RGB格式
        
        Args:
            img_data (bytes): 二进制图像数据
        
        Returns:
            np.ndarray: 解码后的图像数组,形状=(240, 320, 3),RGB格式
        """
        try:
            # 使用OpenCV解码
            nparr = np.frombuffer(img_data, np.uint8)
            bgr_img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            
            # 确保解码成功
            if bgr_img is None:
                raise Exception("OpenCV decoding failed")
                
            # OpenCV默认使用BGR格式,转换为RGB
            rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
            
            # 如果尺寸不匹配,调整到标准尺寸(320x240)
            if rgb_img.shape[:2] != (240, 320):
                rgb_img = cv2.resize(rgb_img, (320, 240))
                
            return rgb_img
            
        except Exception as e:
            print(f"[DEBUG] Image decoding failed: {e}")
            return None

    def load_language_embedding(self, hdf5_file_path):
        """
        从集中式的lang_embeddings文件夹加载预编码的语言指令嵌入
        
        Args:
            hdf5_file_path (str): HDF5 episode文件路径
            
        Returns:
            torch.Tensor: 任务指令嵌入,如果未找到则返回None
        """
        try:
            # 从HDF5文件路径提取任务名
            task_name = None
            
            if self.mode == "single_task":
                # 单任务模式使用提供的任务名
                task_name = self.task_name
            else:
                # 多任务模式从文件路径提取任务名
                if self.multi_task_root_dir and self.multi_task_root_dir in hdf5_file_path:
                    # 获取相对于multi_task_root_dir的相对路径
                    relative_path = os.path.relpath(hdf5_file_path, self.multi_task_root_dir)
                    # 任务名是相对路径中的第一个目录
                    task_name = relative_path.split(os.sep)[0]
            
            if task_name is None:
                print(f"Warning: Could not extract task name from path {hdf5_file_path}")
                return None
            
            # 构建嵌入文件路径
            current_dir = os.path.dirname(os.path.abspath(__file__))
            embedding_path = os.path.join(current_dir, 'lang_embeddings', f"{task_name}.pt")
            
            if not os.path.exists(embedding_path):
                print(f"Warning: Task embedding file not found: {embedding_path}")
                return None
            
            # 加载嵌入数据
            embedding_data = torch.load(embedding_path, map_location='cpu')
            
            # 提取嵌入张量
            embeddings = embedding_data.get('embeddings', None)
            
            if embeddings is None:
                print(f"Warning: No embeddings found in {embedding_path}")
                return None
            
            # 如果存在批处理维度则移除(从3D转换为2D)
            if embeddings.dim() == 3:
                embeddings = embeddings.squeeze(0)
            
            return embeddings
            
        except Exception as e:
            print(f"Error loading language embedding from {hdf5_file_path}: {e}")
            return None

    def extract_episode_item(self, hdf5_file):
        """
        从HDF5文件中提取单个样本
        
        Args:
            hdf5_file: HDF5文件路径
            
        Returns:
            包含提取数据的字典,如果提取失败则返回None
        """
        try:
            # 以只读模式打开HDF5文件
            with h5py.File(hdf5_file, 'r', swmr=True, libver='latest') as f:
                # 从新的HDF5结构加载关节动作数据
                try:
                    left_arm = f["joint_action/left_arm"][:]
                    left_gripper = f["joint_action/left_gripper"][:]
                    right_arm = f["joint_action/right_arm"][:]
                    right_gripper = f["joint_action/right_gripper"][:]
                    
                    # 处理维度不匹配问题
                    # 如果左臂是2D但左夹爪是1D,扩展夹爪维度
                    if len(left_arm.shape) == 2 and len(left_gripper.shape) == 1:
                        left_gripper = left_gripper.reshape(-1, 1)
                    
                    # 如果右臂是2D但右夹爪是1D,扩展夹爪维度
                    if len(right_arm.shape) == 2 and len(right_gripper.shape) == 1:
                        right_gripper = right_gripper.reshape(-1, 1)
                    
                    # 连接所有部分形成完整的动作向量
                    action_data = np.concatenate([left_arm, left_gripper, right_arm, right_gripper], axis=1)
                    
                except Exception as e:
                    print(f"Error loading joint action data: {e}")
                    return None
                
                # 调整数据索引
                max_index = len(action_data) - 2
                index = random.randint(0, max_index)
                
                # 当前状态(使用关节动作数据)
                action_current = action_data[index]
                
                # 未来动作序列
                action_end = min(index + self.chunk_size * self.upsample_rate, max_index + 1)
                action_chunk = action_data[index+1:action_end+1:self.upsample_rate]
                
                # 如果动作序列不足,重复最后一帧进行填充
                if action_chunk.shape[0] < self.chunk_size:
                    last_part = np.repeat(action_chunk[-1:], self.chunk_size - action_chunk.shape[0], axis=0)
                    action_chunk = np.concatenate([action_chunk, last_part], axis=0)
                
                # 获取多视角相机数据
                try:
                    current_images = []
                    
                    # 定义3视角设置的相机路径(无前相机)
                    camera_paths = {
                        "cam_high": "observation/head_camera/rgb",
                        "cam_left_wrist": "observation/left_camera/rgb",
                        "cam_right_wrist": "observation/right_camera/rgb"
                    }
                    
                    # 从每个配置的相机加载图像
                    for cam_idx, cam_name in enumerate(self.cameras):
                        cam_path = camera_paths.get(cam_name)
                        if cam_path and cam_path in f:
                            camera_data = f[cam_path]
                            img_frames = self.parse_img_data(camera_data, index)
                            current_images.append(img_frames)
                        else:
                            print(f"Warning: Camera {cam_name} not found in {hdf5_file}")
                            return None
                    
                    # 确保相机数量正确
                    if len(current_images) != self.num_cameras:
                        print(f"Error: Expected {self.num_cameras} cameras, but got {len(current_images)}")
                        return None
                    
                    # 转换为numpy数组,形状为 [num_cameras, history_size, H, W, 3]
                    img_frames_np = np.array(current_images)
                    
                    # 为每个相机创建图像掩码
                    mask_length = self.img_history_size
                    current_images_mask = [
                        np.array([True]*mask_length, dtype=bool) for _ in range(self.num_cameras)
                    ]
                    
                except Exception as e:
                    print(f"Error accessing camera data in {hdf5_file}: {e}")
                    traceback.print_exc()
                    return None
                
                # 加载预编码的语言指令
                language_embedding = self.load_language_embedding(hdf5_file)
                if language_embedding is None:
                    print(f"Warning: Failed to load language embedding for {hdf5_file}")
                    return None
                
                # 创建状态指示器和动作归一化数组
                state_indicator = np.ones_like(action_current)
                action_norm = np.ones_like(action_chunk)
                
                # 创建结果字典
                result = {
                    "current_images": img_frames_np,  # 当前帧图像(可能经过增强)
                    "current_images_mask": current_images_mask,  # 图像掩码
                    "actions": action_chunk,  # 动作序列
                    "states": np.expand_dims(action_current, axis=0),  # 状态
                    "state_indicator": state_indicator,
                    "action_norm": action_norm,
                    "instruction": language_embedding,  # 预编码的语言指令
                    "dataset_name": self.DATASET_NAME,  # 数据集名称
                }
                
                return result

        except Exception as e:
            print(f"Error processing {hdf5_file}: {e}")
            traceback.print_exc()
            return None

    def get_item(self, index=None):
        """
        获取数据项,如果index为None则随机选择一个
        
        Args:
            index: 可选,指定索引。如果为None,则随机选择
            
        Returns:
            处理后的数据字典,如果失败则返回None
        """
        if self.mode == "single_task":
            # 单任务模式:从episode文件中随机选择
            if not self.episode_files:
                self._initialize_dataset()
            
            if not self.episode_files:
                print("Error: No available episodes")
                return None
            
            # 随机选择一个HDF5文件
            episode_file = random.choice(self.episode_files)
            
        else:
            # 多任务模式:平衡任务采样
            if not self.task_to_episodes:
                self._initialize_dataset()
            
            # 根据任务权重随机选择一个任务
            task_name = random.choices(
                list(self.task_weights.keys()),
                weights=list(self.task_weights.values()),
                k=1
            )[0]
            
            # 从选中的任务中随机选择一个样本
            task_episodes = self.task_to_episodes.get(task_name, [])
            if not task_episodes:
                print(f"Warning: Task {task_name} has no available samples")
                # 从其他任务中选择
                alternative_tasks = [t for t in self.task_to_episodes.keys() if t != task_name and self.task_to_episodes.get(t, [])]
                if not alternative_tasks:
                    print("Error: No available samples")
                    return None
                task_name = random.choice(alternative_tasks)
                task_episodes = self.task_to_episodes.get(task_name, [])
            
            # 随机选择一个HDF5文件
            episode_file = random.choice(task_episodes)
        
        # 尝试提取样本数据(最多3次尝试)
        for _ in range(3):
            item = self.extract_episode_item(episode_file)
            if item is not None:
                return item
            # 如果当前样本提取失败,随机选择另一个
            if self.mode == "single_task":
                episode_file = random.choice(self.episode_files)
            else:
                task_episodes = self.task_to_episodes.get(task_name, [])
                if task_episodes:
                    episode_file = random.choice(task_episodes)
        
        print(f"Warning: Failed to extract sample, returning None")
        return None

二、 模型训练

import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
from torch.distributions.uniform import Uniform
import math
from typing import Dict, Optional, Tuple, Any
import logging

# 导入自定义模块
from utils.hub_mixin import CompatiblePyTorchModelHubMixin
from models.hrdt.model import HRDT


class SigmoidTimestepSampler:
    """
    LogitNormal采样器
    采样过程: u ~ N(mean, std), t = sigmoid(u)
    用于扩散模型的时间步采样
    """
    def __init__(self, timestep_max=1.0, mean=0.0, std=1.0):
        """
        初始化采样器参数
        Args:
            timestep_max: 时间步最大值
            mean: 正态分布均值
            std: 正态分布标准差
        """
        self.timestep_max = timestep_max
        self.mean = mean  # Normal distribution mean
        self.std = std    # Normal distribution standard deviation
    
    def sample(self, shape):
        """
        LogitNormal采样, 即 sigmoid(randn(m,s))
        
        1. u ~ N(mean, std)
        2. t = sigmoid(u) 
        """
        # 生成正态分布随机数 u ~ N(mean, std)
        u = torch.normal(mean=self.mean, std=self.std, size=shape)
        # 应用sigmoid变换得到(0,1)范围的时间步
        t = torch.sigmoid(u)
        # 缩放到[0, timestep_max]范围
        return t * self.timestep_max
    
    def visualize_distribution(self, num_samples=10000):
        """
        可视化采样分布
        """
        samples = self.sample((num_samples,))
        return {
            'samples': samples,
            'mean': samples.mean().item(),
            'std': samples.std().item(),
            'min': samples.min().item(),
            'max': samples.max().item(),
            'config': f'LogitNormal(mean={self.mean}, std={self.std})'
        }


class ActionEncoder(nn.Module):
    """动作编码器,组合状态和动作适配器"""
    
    def __init__(self, state_dim, action_dim, hidden_size, config):
        """
        初始化动作编码器
        Args:
            state_dim: 状态维度
            action_dim: 动作维度
            hidden_size: 隐藏层大小
            config: 配置字典
        """
        super().__init__()
        # 构建状态适配器
        self.state_adaptor = self.build_condition_adapter(
            config['st_adaptor'],
            in_features=state_dim,
            out_features=hidden_size
        )
        # 构建动作适配器
        self.action_adaptor = self.build_condition_adapter(
            config['act_adaptor'],
            in_features=action_dim,
            out_features=hidden_size
        )
    
    def build_condition_adapter(self, projector_type, in_features, out_features):
        """
        构建条件适配器
        Args:
            projector_type: 投影器类型 ('linear' 或 'mlpNx_silu')
            in_features: 输入特征维度
            out_features: 输出特征维度
        Returns:
            构建的投影器模块
        """
        projector = None
        if projector_type == 'linear':
            projector = nn.Linear(in_features, out_features)
        else:
            # 匹配mlpNx_silu格式,如mlp2x_silu表示2层MLP
            mlp_silu_match = re.match(r'^mlp(\d+)x_silu$', projector_type)
            if mlp_silu_match:
                mlp_depth = int(mlp_silu_match.group(1))
                modules = [nn.Linear(in_features, out_features)]
                for _ in range(1, mlp_depth):
                    modules.append(nn.SiLU())
                    modules.append(nn.Linear(out_features, out_features))
                projector = nn.Sequential(*modules)

        if projector is None:
            raise ValueError(f'Unknown projector type: {projector_type}')

        return projector
    
    def encode_state(self, state_tokens):
        """编码状态"""
        return self.state_adaptor(state_tokens)
    
    def encode_action(self, action_tokens):
        """编码动作"""
        return self.action_adaptor(action_tokens)


class HRDTRunner(
        nn.Module,
        CompatiblePyTorchModelHubMixin,
        repo_url="https://huggingface.co/hongzhe2002/H-RDT/"
    ):
    """
    H-RDT模型运行器,负责模型训练和推理
    """
    def __init__(self, *, state_dim, action_dim,
                 pred_horizon, config, act_pos_emb_config=None, img_pos_emb_config=None, lang_pos_emb_config=None,
                 max_img_len=None, max_lang_len=None,
                 training_mode='lang',
                 mode='pretrain',
                 pretrained_backbone_path=None,
                 dtype=torch.bfloat16):
        """
        初始化HRDTRunner
        Args:
            state_dim: 状态维度
            action_dim: 动作维度
            pred_horizon: 预测时间范围
            config: 配置字典
            act_pos_emb_config: 动作位置编码配置
            img_pos_emb_config: 图像位置编码配置
            lang_pos_emb_config: 语言位置编码配置
            max_img_len: 最大图像长度
            max_lang_len: 最大语言长度
            training_mode: 训练模式
            mode: 运行模式 ('pretrain' 或 'finetune')
            pretrained_backbone_path: 预训练骨干网络路径
            dtype: 数据类型
        """
        super(HRDTRunner, self).__init__()
        # 创建扩散模型
        hidden_size = config['hrdt']['hidden_size']
        self.gradient_checkpointing = False
        self.hidden_size = hidden_size
        self.training_mode = training_mode
        self.mode = mode  # 'pretrain' or 'finetune'
        
        # 验证模式
        if mode not in ['pretrain', 'finetune']:
            raise ValueError(f"mode must be 'pretrain' or 'finetune', got {mode}")

        # 创建H-RDT模型
        self.model = HRDT(
            horizon=pred_horizon,
            config=config['hrdt'],
            x_pos_emb_config=act_pos_emb_config,
            img_pos_emb_config=img_pos_emb_config,
            lang_pos_emb_config=lang_pos_emb_config,
            max_img_len=max_img_len,
            max_lang_len=max_lang_len,
            training_mode=training_mode,
            dtype=dtype,
        )

        # 图像特征适配器 - 使用配置中的维度
        self.img_adapter = self.build_condition_adapter(
            config.get('img_adapter', 'mlp2x_silu'),
            in_features=config.get('vision', {}).get('feature_dim', 2048),  # 默认ResNet50维度
            out_features=hidden_size
        )
        
        # 动作编码器(状态和动作适配器)
        self.action_encoder = ActionEncoder(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_size=hidden_size,
            config=config
        )

        # 语言特征适配器 - 使用配置中的维度
        self.lang_adapter = self.build_condition_adapter(
            config.get('lang_adapter', 'mlp2x_silu'),
            in_features=config.get('text', {}).get('feature_dim', 768),  # 默认DistilBERT维度
            out_features=hidden_size
        )

        # 创建噪声调度器
        noise_scheduler_config = config['noise_scheduler']
        self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']
        self.timestep_max = noise_scheduler_config['timestep_max']
        
        sampler_type = noise_scheduler_config.get('sampler_type', 'sigmoid')
        if sampler_type == 'uniform':
            self.timestep_sampler = Uniform(0, self.timestep_max)
        elif sampler_type == 'sigmoid':
            mean = noise_scheduler_config.get('sigmoid_mean', 0.0)
            std = noise_scheduler_config.get('sigmoid_std', 1.0)
            self.timestep_sampler = SigmoidTimestepSampler(self.timestep_max, mean, std)
        else:
            raise ValueError(f"Unknown sampler type: {sampler_type}")

        self.pred_horizon = pred_horizon
        self.action_dim = action_dim

        # TimeNoise配置
        self.time_noise_a = config["time_noise"]["a"]
        self.time_noise_beta_m = config["time_noise"]["beta_m"]
        
        self.img_pos_emb_config = img_pos_emb_config

        # 如果在微调模式下,加载预训练骨干网络权重
        if mode == 'finetune' and pretrained_backbone_path is not None:
            self.load_pretrained_backbone(pretrained_backbone_path)

        # 打印模型大小
        print("Model params: %e" % sum(p.numel() for p in self.parameters()))

    def load_pretrained_backbone(self, pretrained_path):
        """
        加载预训练骨干网络权重,保持动作编码器和解码器为新初始化状态
        Args:
            pretrained_path: 预训练模型路径
        """
        logging.info(f"Loading pretrained backbone from {pretrained_path}")
        
        # 加载检查点
        if pretrained_path.endswith('.safetensors'):
            import safetensors.torch
            checkpoint = safetensors.torch.load_file(pretrained_path)
        else:
            checkpoint = torch.load(pretrained_path, map_location='cpu')
        
        # 提取状态字典
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        
        # 过滤掉动作编码器和动作解码器权重
        backbone_state_dict = {}
        action_related_keys = []
        
        for key, value in state_dict.items():
            # 跳过动作编码器(state_adaptor, act_adaptor)和动作解码器权重
            if any(pattern in key for pattern in [
                'st_adaptor', 'act_adaptor', 'action_encoder', 
                'final_layer', 'action_decoder'
            ]):
                action_related_keys.append(key)
                continue

            # 如果形状不匹配,跳过img_pos_emb
            if key == 'model.img_pos_emb':
                current_shape = self.model.img_pos_emb.shape
                if value.shape != current_shape:
                    logging.info(f"Skipping img_pos_emb due to shape mismatch: {value.shape} vs {current_shape}")
                    continue

            backbone_state_dict[key] = value
        
        # 加载骨干网络权重
        missing_keys, unexpected_keys = self.load_state_dict(backbone_state_dict, strict=False)
        
        logging.info(f"Loaded backbone with {len(backbone_state_dict)} parameters")
        logging.info(f"Skipped action-related keys: {action_related_keys}")
        logging.info(f"Missing keys: {missing_keys}")
        logging.info(f"Unexpected keys: {unexpected_keys}")
        logging.info("Action encoder and decoder initialized from scratch for finetune mode")

    @classmethod
    def from_pretrained_for_finetune(cls, pretrained_path, state_dim, action_dim, pred_horizon, config, **kwargs):
        """
        创建微调模式下的模型,带有预训练骨干网络
        """
        return cls(
            state_dim=state_dim,
            action_dim=action_dim,
            pred_horizon=pred_horizon,
            config=config,
            mode='finetune',
            pretrained_backbone_path=pretrained_path,
            **kwargs
        )

    def build_condition_adapter(
        self, projector_type, in_features, out_features):
        """
        构建条件适配器(与ActionEncoder中的方法重复,可能是为了代码复用)
        """
        projector = None
        if projector_type == 'linear':
            projector = nn.Linear(in_features, out_features)
        else:
            mlp_silu_match = re.match(r'^mlp(\d+)x_silu$', projector_type)
            if mlp_silu_match:
                mlp_depth = int(mlp_silu_match.group(1))
                modules = [nn.Linear(in_features, out_features)]
                for _ in range(1, mlp_depth):
                    modules.append(nn.SiLU())
                    modules.append(nn.Linear(out_features, out_features))
                projector = nn.Sequential(*modules)

        if projector is None:
            raise ValueError(f'Unknown projector type: {projector_type}')

        return projector
    
    def gradient_checkpointing_enable(self, value=True):
        """
        启用梯度检查点以提高内存效率
        """
        self.gradient_checkpointing = value
        if hasattr(self.model, "gradient_checkpointing_enable"):
            self.model.gradient_checkpointing_enable(value)

    def compute_loss(self, state_tokens=None, action_gt=None, image_tokens=None, lang_tokens=None, lang_attn_mask=None):
        """
        计算损失函数
        Args:
            img_tokens: (batch_size, img_len, img_token_dim)
            state_tokens: (batch_size, chunk_size, action_dim), 
            action_gt: (batch_size, chunk_size, action_dim), 监督的真值动作
            lang_tokens: (batch_size, L, hidden_size), 语言特征(未池化)
            lang_attn_mask: (batch_size, L), 语言token的注意力掩码
        Returns:
            包含损失值的字典
        """
        batch_size = image_tokens.shape[0] # 32
        device = image_tokens.device
        dtype = image_tokens.dtype

        # 生成噪声并添加到真值动作上
        noise = torch.randn(action_gt.shape, dtype=dtype, device=device)
        timesteps = self.timestep_sampler.sample((batch_size,)).to(device) # 32
        
        # 扩展时间步张量
        broadcasted = timesteps.view(-1, 1, 1) # 32, 1, 1
        # 创建带噪声的动作
        noisy_action = (action_gt * broadcasted + noise * (1 - broadcasted)).to(dtype=dtype)  # 32, 16, 48

        # 处理图像特征
        img_c = self.img_adapter(image_tokens)

        # 处理语言特征 - 处理None情况
        lang_c = None
        if lang_tokens is not None:
            lang_c = self.lang_adapter(lang_tokens)  # [B, L, D] - 保持未池化用于交叉注意力

        # 使用动作编码器处理状态/动作
        state_traj = self.action_encoder.encode_state(state_tokens) # 32,1, 2176
        action_traj = self.action_encoder.encode_action(noisy_action) # 32, 16, 2176
        state_action_traj = torch.cat([state_traj, action_traj], dim=1) # 32, 17, 2176
        
        # 模型前向传播
        pred = self.model(state_action_traj, timesteps, img_c=img_c, lang_c=lang_c, lang_attn_mask=lang_attn_mask) # 32, 16,48
        # 目标是噪声
        target = action_gt - noise
        
        # 计算均方误差损失
        diff_loss = F.mse_loss(pred, target)
        
        return {"diff_loss": diff_loss, "loss": diff_loss}

    @torch.no_grad()
    def predict_action(self, state_tokens=None, image_tokens=None, lang_tokens=None, lang_attn_mask=None):
        '''
        预测动作序列
        Args:
            state_tokens: (batch_size, chunk_size, action_dim)
            image_tokens: (batch_size, img_len, in_feat_dim)
            lang_tokens: 语言特征 [B, L, hidden_size] (未池化)
            lang_attn_mask: (batch_size, L), 语言token的注意力掩码
            
        Returns: 
            (batch_size, chunk_size, action_dim), 预测的动作序列
        '''
        batch_size = image_tokens.shape[0]
        device = image_tokens.device
        dtype = image_tokens.dtype

        # 处理图像特征
        img_c = self.img_adapter(image_tokens)

        # 处理语言特征 - 处理None情况
        lang_c = None
        if lang_tokens is not None:
            lang_c = self.lang_adapter(lang_tokens)  # [B, L, D] - 保持未池化用于交叉注意力

        # 编码状态
        state_traj = self.action_encoder.encode_state(state_tokens)
        # 初始化噪声动作
        noisy_action = torch.randn((batch_size, self.pred_horizon, self.action_dim), dtype=dtype, device=device)
        timestep = torch.tensor([0.0], dtype=dtype, device=device)
        step_size = 1.0 / self.num_inference_timesteps

        # 逐步去噪生成动作序列
        for _ in range(self.num_inference_timesteps):
            action_traj = self.action_encoder.encode_action(noisy_action)
            state_action_traj = torch.cat([state_traj, action_traj], dim=1)
            pred = self.model(state_action_traj, timestep, img_c=img_c, lang_c=lang_c, lang_attn_mask=lang_attn_mask)
            # 更新动作估计
            noisy_action = pred * step_size + noisy_action
            # 更新时间步
            timestep = timestep + step_size

        return noisy_action

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """前向传播,调用计算损失函数"""
        return self.compute_loss(*args, **kwargs)

模型核心代码代码如下:

from collections import OrderedDict
from typing import List, Tuple, Optional

import re
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.hrdt.blocks import ActionDecoder, HRDTBlock, TimestepEmbedder
from models.hrdt.pos_emb import get_multimodal_pos_embed


class HRDT(nn.Module):
    """
    机器人扩散变换器模型 (Robotics Diffusion Transformer)
    
    修改内容:
    1. 状态和噪声动作块一起作为输入处理
    2. AdaLN现在仅使用时间步长(不使用句子标记)
    3. 图像特征直接输入到块中的交叉注意力机制
    4. 训练模式控制使用哪些交叉注意力层
    """
    def __init__(
        self,
        horizon: int,                    # 预测动作序列的长度
        config: dict,                    # 模型配置字典
        x_pos_emb_config: List[Tuple],   # 输入序列位置嵌入配置 [(模态名, 长度), ...]
        img_pos_emb_config: List[Tuple] = None,   # 图像位置嵌入配置
        lang_pos_emb_config: List[Tuple] = None,  # 语言位置嵌入配置
        max_img_len: int = None,         # 图像序列最大长度
        max_lang_len: int = None,        # 语言序列最大长度
        training_mode: str = 'lang',     # 训练模式 ('lang')
        dtype=torch.bfloat16,            # 模型参数数据类型
    ):
        super().__init__()
        # 保存模型参数
        self.horizon = horizon           # 动作序列预测长度
        self.hidden_size = config["hidden_size"]  # 隐藏层维度
        self.n_heads = config["num_heads"]        # 注意力头数
        self.dtype = dtype               # 数据类型
        self.gradient_checkpointing = False       # 是否启用梯度检查点
        self.training_mode = training_mode        # 训练模式

        # 验证训练模式参数
        if training_mode not in ['lang']:
            raise ValueError(f"training_mode必须是'lang',得到的是{training_mode}")

        # 移除AdaLN适配器 - 时间步嵌入直接传递给块

        # 时间步嵌入器:将扩散过程的时间步转换为嵌入向量
        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype)

        # 创建H-RDT层并根据训练模式配置
        self.depth = config["depth"]     # Transformer块的层数
        self.blocks = nn.ModuleList([
            HRDTBlock(layer_idx, config=config, training_mode=training_mode)
            for layer_idx in range(self.depth)
        ])
        
        # 动作解码器:最终将隐藏状态转换为动作预测
        self.action_decoder = ActionDecoder(config=config)

        # 位置嵌入配置
        self.x_pos_emb_config = x_pos_emb_config      # 输入序列位置嵌入配置
        self.lang_pos_emb_config = lang_pos_emb_config # 语言位置嵌入配置
        self.img_pos_emb_config = img_pos_emb_config   # 图像位置嵌入配置
        
        # 创建可学习的位置嵌入参数
        # x_pos_emb: 状态+动作序列的位置嵌入 (batch_size=1, seq_len=1+horizon, hidden_size)
        self.x_pos_emb = nn.Parameter(torch.zeros(
            1, 1 + self.horizon, self.hidden_size))
        
        # 语言位置嵌入 (batch_size=1, seq_len=max_lang_len, hidden_size)
        self.lang_pos_emb = nn.Parameter(torch.zeros(
            1, max_lang_len, self.hidden_size))
            
        # 图像位置嵌入 (batch_size=1, seq_len=max_img_len, hidden_size)
        self.img_pos_emb = nn.Parameter(torch.zeros(
            1, max_img_len, self.hidden_size))

        # 初始化模型权重
        self.initialize_weights()

    def build_condition_adapter(
        self, 
        projector_type,    # 投影器类型 ('linear' 或 'mlp{N}x_silu')
        in_features,       # 输入特征维度
        out_features       # 输出特征维度
    ):
        """
        构建条件适配器,用于将不同模态的特征投影到统一的隐藏空间维度
        """
        projector = None
        
        # 线性投影器
        if projector_type == 'linear':
            projector = nn.Linear(in_features, out_features)
        else:
            # 匹配MLP+SILU类型的投影器 (如 mlp2x_silu)
            mlp_silu_match = re.match(r'^mlp(\d+)x_silu$', projector_type)
            if mlp_silu_match:
                mlp_depth = int(mlp_silu_match.group(1))  # 获取MLP层数
                modules = [nn.Linear(in_features, out_features)]
                
                # 构建多层MLP,每层后跟SiLU激活函数
                for _ in range(1, mlp_depth):
                    modules.append(nn.SiLU())
                    modules.append(nn.Linear(out_features, out_features))
                projector = nn.Sequential(*modules)

        # 如果没有匹配到有效的投影器类型,抛出异常
        if projector is None:
            raise ValueError(f'未知的投影器类型: {projector_type}')

        return projector

    def initialize_weights(self):
        """
        初始化模型的各种权重参数
        """
        # 初始化线性层权重
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)  # Xavier均匀初始化
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)         # 偏置初始化为0
        self.apply(_basic_init)

        # 使用正余弦函数初始化位置嵌入
        # 输入序列位置嵌入
        x_pos_emb = get_multimodal_pos_embed(
            embed_dim=self.hidden_size,                    # 嵌入维度
            mm_lens=OrderedDict(self.x_pos_emb_config)     # 多模态长度配置
        )
        # 将计算得到的位置嵌入复制到模型参数中
        self.x_pos_emb.data.copy_(
            torch.from_numpy(x_pos_emb).float().unsqueeze(0))

        # 语言位置嵌入
        lang_pos_emb = get_multimodal_pos_embed(
            embed_dim=self.hidden_size,
            mm_lens=OrderedDict(self.lang_pos_emb_config)
        )
        self.lang_pos_emb.data.copy_(
            torch.from_numpy(lang_pos_emb).float().unsqueeze(0))

        # 图像位置嵌入
        img_pos_embed = get_multimodal_pos_embed(
            embed_dim=self.hidden_size,
            mm_lens=OrderedDict(self.img_pos_emb_config)
        )
        self.img_pos_emb.data.copy_(
            torch.from_numpy(img_pos_embed).float().unsqueeze(0))

        # 初始化时间步嵌入MLP层权重
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)  # 第一层权重正态分布初始化
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)  # 第三层权重正态分布初始化

        # 将AdaLN调制层的权重初始化为0,确保初始状态不会对输入产生影响
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # 将输出层的权重初始化为0
        nn.init.constant_(self.action_decoder.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.action_decoder.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.action_decoder.ffn.fc2.weight, 0)
        nn.init.constant_(self.action_decoder.ffn.fc2.bias, 0)

        # 将所有参数转换为指定的数据类型
        self.to(self.dtype)

    def gradient_checkpointing_enable(self, value: bool = True):
        """
        启用梯度检查点以节省内存
        
        Args:
            value (bool): 是否启用梯度检查点,默认为True
        """
        self.gradient_checkpointing = value

    def forward(
        self, 
        x,                    # 输入张量 (B, 1 + T, D) - 状态和动作序列
        t,                    # 扩散时间步 (B,) 或 (1,)
        img_c=None,           # 图像特征用于交叉注意力 (B, S_img, D),可选
        lang_c=None,          # 语言标记用于交叉注意力 (B, S_lang, D),可选
        sentence_c=None,      # 句子标记(为向后兼容而保留,实际被忽略)
        task_c=None,          # 任务标记(未使用)
        lang_attn_mask=None   # 语言标记的注意力掩码 (B, S_lang),可选
    ):
        """
        H-RDT模型的前向传播过程

        Args:
            x: (B, 1 + T, D), 状态和动作标记序列, T = horizon
            t: (B,) or (1,), 扩散时间步
            img_c: (B, S_img, D), 用于交叉注意力的图像特征, 可选
            lang_c: (B, S_lang, D), 用于交叉注意力的语言标记, 可选
            sentence_c: 忽略 (为向后兼容)
            lang_attn_mask: (B, S_lang), 语言标记的注意力掩码, 可选
        Returns:
            x: (B, T, D_out), 预测的去噪动作标记
        """
        # 使用正弦嵌入对时间步进行编码
        # t: (batch_size,) -> t_emb: (batch_size, hidden_size)
        t_emb = self.t_embedder(t)  # (B, D) or (1, D)
        
        # 如果时间嵌入是单个向量,则扩展为批量大小
        if t_emb.shape[0] == 1:
            t_emb = t_emb.expand(x.shape[0], -1)  # (B, D)

        # 添加位置嵌入到输入序列
        x = x + self.x_pos_emb
        
        # 如果提供了图像特征,则添加图像位置嵌入
        if img_c is not None:
            img_c = img_c + self.img_pos_emb[:, :img_c.shape[1]] # batch_size, 729, 2176
        
        # 如果提供了语言特征,则添加语言位置嵌入
        if lang_c is not None:
            lang_c = lang_c + self.lang_pos_emb[:, :lang_c.shape[1]] # batch_size, 44, 2176
        
        # 将时间嵌入直接传递给各个块(不使用句子标记)
        for i, block in enumerate(self.blocks):
            # 构建交叉注意力上下文字典
            cross_contexts = {
                'img_c': img_c,              # 图像上下文
                'lang_c': lang_c,            # 语言上下文
                'lang_attn_mask': lang_attn_mask  # 语言注意力掩码
            }
            
            # 根据是否启用梯度检查点和是否处于训练模式决定执行方式
            # 梯度检查点可以节省内存但会增加计算时间
            if self.gradient_checkpointing and self.training:
                # 使用梯度检查点执行块计算
                x = torch.utils.checkpoint.checkpoint(block, x, t_emb, cross_contexts, use_reentrant=False)
            else:
                # 直接执行块计算
                x = block(x, t_emb, cross_contexts)

        # 最终层仅使用时间步(不使用交叉注意力)
        x = self.action_decoder(x, t_emb)

        # 提取动作预测(去掉状态部分,只保留动作序列部分)
        x = x[:, -self.horizon:]

        return x

下面是模块代码部分:

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import Mlp

from models.hrdt.norm import RMSNorm
from models.hrdt.attention import Attention, CrossAttention


def modulate(x, shift, scale):
    """
    对输入张量进行调制(仿射变换)
    
    Args:
        x: 输入张量
        shift: 偏移量(beta参数)
        scale: 缩放因子(gamma参数)
    
    Returns:
        调制后的张量: x * (1 + scale) + shift
    """
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class TimestepEmbedder(nn.Module):
    """
    时间步嵌入器:将标量时间步转换为向量表示
    
    来源:
    https://github.com/facebookresearch/DiT/blob/main/models.py
    """
    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=torch.bfloat16):
        """
        初始化时间步嵌入器
        
        Args:
            hidden_size: 隐藏层维度
            frequency_embedding_size: 频率嵌入维度
            dtype: 数据类型
        """
        super().__init__()
        # 多层感知机用于处理频率嵌入
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),  # 使用SiLU激活函数
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.dtype = dtype

    def timestep_embedding(self, t, dim, max_period=10000):
        """
        创建正弦时间步嵌入
        
        Args:
            t: 形状为(N,)的1维张量,包含每个批次元素的索引,可能是小数
            dim: 输出维度
            max_period: 控制嵌入的最小频率
            
        Returns:
            形状为(N, D)的位置嵌入张量
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # 计算频率参数
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(
                start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        # 计算角度参数
        args = t[:, None].float() * freqs[None]
        # 拼接cos和sin嵌入
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # 如果维度是奇数,补充一个零维度
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding.to(self.dtype)

    def forward(self, t):
        """
        前向传播
        
        Args:
            t: 时间步张量
            
        Returns:
            时间步嵌入向量
        """
        # 生成时间步频率嵌入
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        # 通过MLP处理得到最终嵌入
        t_emb = self.mlp(t_freq)
        return t_emb


class FeedForward(nn.Module):
    """
    带SiLU激活的前馈网络
    
    参考:
    https://github.com/meta-llama/llama3/blob/main/llama/model.py
    """
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        初始化前馈网络
        
        Args:
            dim: 输入维度
            hidden_dim: 隐藏层维度
            multiple_of: 隐藏层维度倍数
            ffn_dim_multiplier: FFN维度乘数
        """
        super().__init__()
        # 计算隐藏层维度
        hidden_dim = int(2 * hidden_dim / 3)
        # 应用自定义维度乘数
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # 调整隐藏层维度为multiple_of的倍数
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        # 三个线性变换层
        self.w1 = nn.Linear(
            dim, hidden_dim, bias=False
        )
        self.w2 = nn.Linear(
            hidden_dim, dim, bias=False
        )
        self.w3 = nn.Linear(
            dim, hidden_dim, bias=False
        )

    def forward(self, x):
        """
        前向传播
        
        Args:
            x: 输入张量
            
        Returns:
            处理后的张量
        """
        # SwiGLU激活函数: silu(w1(x)) * w3(x) -> w2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class HRDTBlock(nn.Module):
    """
    H-RDT模块块,包含自注意力、两个交叉注意力层和前馈网络
    训练模式控制使用哪个交叉注意力层:
    - 'lang': 图像 + 语言交叉注意力
    """
    def __init__(self, layer_idx: int, config: dict, training_mode: str = 'lang'):
        """
        初始化H-RDT块
        
        Args:
            layer_idx: 层索引
            config: 配置字典
            training_mode: 训练模式 ('lang')
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config["hidden_size"]
        self.norm_eps = config["norm_eps"]
        self.training_mode = training_mode
        
        # 验证训练模式
        if training_mode not in ['lang']:
            raise ValueError(f"training_mode must be 'lang', got {training_mode}")
        
        # 自注意力层
        self.attn_norm = nn.LayerNorm(
            self.hidden_size, eps=self.norm_eps)
        self.attn = Attention(config)
        
        # 图像交叉注意力层 (始终存在)
        self.img_cross_norm = nn.LayerNorm(
            self.hidden_size, eps=self.norm_eps)
        self.img_cond_norm = nn.LayerNorm(
            self.hidden_size, eps=self.norm_eps)
        self.img_cross_attn = CrossAttention(config)
        
        # 语言交叉注意力层
        self.lang_cross_norm = nn.LayerNorm(
            self.hidden_size, eps=self.norm_eps)
        self.lang_cond_norm = nn.LayerNorm(
            self.hidden_size, eps=self.norm_eps)
        self.lang_cross_attn = CrossAttention(config)
        
        # 前馈网络
        self.ffn_norm = nn.LayerNorm(
            self.hidden_size, eps=self.norm_eps)
        self.ffn = FeedForward(
            dim=self.hidden_size,
            hidden_dim=4*self.hidden_size,
            multiple_of=config["multiple_of"],
            ffn_dim_multiplier=config["ffn_dim_multiplier"],
        )
        
        # AdaLN调制 - 保持原有的9个参数结构
        # self_attn(3) + cross_attn(3) + mlp(3) = 9 总计
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_size, 9*self.hidden_size, bias=True)
        )
        
    def forward(
            self,
            x: torch.Tensor,
            t: torch.Tensor,
            cross_contexts: dict = None,
        ):
        """
        前向传播,基于训练模式使用两个交叉注意力层
        
        https://blog.csdn.net/A1233242/article/details/149250857
        
        Args:
            x: 输入状态-动作序列
            t: 时间步嵌入 (不再使用句子token)
            cross_contexts: 包含交叉注意力上下文的字典
                - 'img_c': 用于交叉注意力的图像特征 (始终使用)
                - 'lang_c': 用于交叉注意力的语言token (如果training_mode='lang')
                - 'lang_attn_mask': 语言的注意力掩码
                
        Returns:
            处理后的张量
        """
        if cross_contexts is None:
            cross_contexts = {}
            
        # 自适应层归一化 - 分解为偏移、缩放和门控参数
        # shift -> beta; scale -> gamma; gate -> alpha
        shift_attn, scale_attn, gate_attn, \
        shift_cross, scale_cross, gate_cross, \
        shift_mlp, scale_mlp, gate_mlp \
            = self.adaLN_modulation(t).chunk(9, dim=1)
            
        # 自注意力处理
        h = x + gate_attn.unsqueeze(1) * self.attn(
            modulate(self.attn_norm(x), shift_attn, scale_attn))
        
        # 图像交叉注意力 (始终存在)
        img_c = cross_contexts.get('img_c')
        if img_c is not None:
            h = h + gate_cross.unsqueeze(1) * self.img_cross_attn(
                modulate(self.img_cross_norm(h), shift_cross, scale_cross),
                self.img_cond_norm(img_c), None)
        
        # 语言交叉注意力
        lang_c = cross_contexts.get('lang_c')
        lang_attn_mask = cross_contexts.get('lang_attn_mask')
        if lang_c is not None:
            # 使用相同的调制参数对语言应用额外的交叉注意力
            h = h + self.lang_cross_attn(
                self.lang_cross_norm(h),
                self.lang_cond_norm(lang_c), lang_attn_mask)
        
        # 前馈网络处理
        out = h + gate_mlp.unsqueeze(1) * self.ffn(
            modulate(self.ffn_norm(h), shift_mlp, scale_mlp))
        
        return out


class ActionDecoder(nn.Module):
    """
    H-RDT的动作解码器层 (之前称为FinalLayer).
    """
    def __init__(self, config):
        """
        初始化动作解码器
        
        Args:
            config: 配置字典
        """
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.norm_eps = config["norm_eps"]
        self.output_size = config["output_size"]

        # 前馈网络归一化层
        self.ffn_norm = nn.LayerNorm(
            self.hidden_size, eps=self.norm_eps)
        # 多层感知机
        self.ffn = Mlp(
            in_features=self.hidden_size,
            hidden_features=self.hidden_size*4,
            out_features=self.output_size,
            act_layer=nn.SiLU, drop=0.0
        )

        # AdaLN调制
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_size, 2*self.hidden_size, bias=True)
        )

    def forward(
            self,
            x: torch.Tensor,
            t: torch.Tensor
        ):
        """
        前向传播
        
        Args:
            x: 输入张量
            t: 时间步嵌入
            
        Returns:
            解码后的动作输出
        """
        # 分解调制参数
        shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
        # 应用调制和归一化
        x = modulate(self.ffn_norm(x), shift, scale)
        # 通过前馈网络
        x = self.ffn(x)
        return x


# 为向后兼容保留FinalLayer
FinalLayer = ActionDecoder

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容