4. Data Processing Module
This module involves two files: hdf5_vla_dataset.py and dataset.py.
4.1 hdf5_vla_dataset.py
import os
import json
import cv2
import h5py
import numpy as np
from configs.state_vec import STATE_VEC_IDX_MAPPING  # index map of RDT's 128-dim unified state vector
def parse_hdf5_file(self, file_path):
with h5py.File(file_path, 'r') as f:
qpos = f['observations']['qpos'][:]
num_steps = qpos.shape[0]
if num_steps < 128:
return False, None
EPS = 1e-2 # threshold used to check that the arm actually moved
qpos_delta = np.abs(qpos - qpos[0:1]) # difference between every qpos and the first frame
indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
if len(indices) > 0:
first_idx = indices[0] # time index at which the robot starts to move
else:
raise ValueError("Found no qpos that exceeds the threshold.") # raised when the arm never moves in this episode
step_id = np.random.randint(first_idx-1, num_steps)
dir_path = os.path.dirname(file_path)
with open(os.path.join(dir_path, 'expanded_instruction_gpt-4-turbo.json'), 'r') as f_instr:
instruction_dict = json.load(f_instr)
instruction_type = np.random.choice([ # randomly choose one instruction variant
'instruction', 'simplified_instruction', 'expanded_instruction'])
instruction = instruction_dict[instruction_type] # the corresponding text
if isinstance(instruction, list):
instruction = np.random.choice(instruction)
# Assemble the meta
meta = {
"dataset_name": self.DATASET_NAME,
"#steps": num_steps, # number of frames T in this episode
"step_id": step_id,
"instruction": instruction
}
single_side_norm_scale_vec_len = qpos.shape[1] // 2 if qpos.shape[1] % 2 == 0 and qpos.shape[1] // 2 - 1 <= 10 else None
assert single_side_norm_scale_vec_len is not None, "qpos dim must be even, with at most 10 joints per arm"
_norm_vec = [1 for i in range(qpos.shape[1])]
qpos_norm_vec = _norm_vec.copy()
qpos_norm_vec[single_side_norm_scale_vec_len - 1] = 4.7908
qpos_norm_vec[-1] = 4.7888
qpos = qpos / np.array([
qpos_norm_vec
])
action_norm_vec = _norm_vec.copy()
action_norm_vec[single_side_norm_scale_vec_len - 1] = 11.8997
action_norm_vec[-1] = 13.9231
f_action = f['action']
target_qpos = f_action[step_id:step_id + self.CHUNK_SIZE] / np.array([ # CHUNK_SIZE is the action chunk length, 64 here
action_norm_vec
])
state = qpos[step_id:step_id+1]
state_std = np.std(qpos, axis=0)
state_mean = np.mean(qpos, axis=0)
state_norm = np.sqrt(np.mean(qpos**2, axis=0))
actions = target_qpos
if actions.shape[0] < self.CHUNK_SIZE: # ⭐️ if the segment is shorter than the chunk length, pad by repeating the last action
# Pad the actions using the last action
actions = np.concatenate([
actions,
np.tile(actions[-1:], (self.CHUNK_SIZE-actions.shape[0], 1))
], axis=0)
# Fill the state/action into the unified vector
def fill_in_state(values):
values_len = values.shape[-1] // 2 - 1 if values.shape[-1] % 2 == 0 and values.shape[-1] // 2 - 1 <= 10 else None
assert values_len is not None, "values dim must be even, with at most 10 joints per arm"
# Target indices corresponding to your state space
# In this example: values_len joints + 1 gripper for each arm (7 + 1 for the 16-dim qpos here)
UNI_STATE_INDICES = [
STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(values_len)
] + [
STATE_VEC_IDX_MAPPING["left_gripper_open"]
] + [
STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(values_len)
] + [
STATE_VEC_IDX_MAPPING["right_gripper_open"]
]
uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
uni_vec[..., UNI_STATE_INDICES] = values
return uni_vec
state = fill_in_state(state) # map the (1, 16) state into the (1, 128) unified vector
state_indicator = fill_in_state(np.ones_like(state_std))
state_std = fill_in_state(state_std)
state_mean = fill_in_state(state_mean)
state_norm = fill_in_state(state_norm)
actions = fill_in_state(actions)
def parse_img(key):
imgs = []
for i in range(max(step_id-self.IMG_HISORY_SIZE+1, 0), step_id+1): # [step_id-1, step_id]: take two frames, i.e. one extra frame of history
if key not in f['observations']['images']:
key = key.replace("_wrist", "")
img = f['observations']['images'][key][i]
if not isinstance(img, np.ndarray):
img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
imgs.append(img)
imgs = np.stack(imgs)
if imgs.shape[0] < self.IMG_HISORY_SIZE:
# Pad the images using the first image
imgs = np.concatenate([
np.tile(imgs[:1], (self.IMG_HISORY_SIZE-imgs.shape[0], 1, 1, 1)),
imgs
], axis=0)
return imgs
cam_high = parse_img('cam_high') # the same parsing is applied to cam_high, cam_left_wrist and cam_right_wrist, shape: (2, 480, 640, 3)
valid_len = min(step_id - (first_idx - 1) + 1, self.IMG_HISORY_SIZE)
cam_high_mask = np.array(
[False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len) # (True, True)
cam_left_wrist = parse_img('cam_left_wrist') # shape: (2, 480, 640, 3)
cam_left_wrist_mask = cam_high_mask.copy()
cam_right_wrist = parse_img('cam_right_wrist') # shape: (2, 480, 640, 3)
cam_right_wrist_mask = cam_high_mask.copy() # (True, True)
return True, {
"meta": meta,
"state": state, # shape (1, 128)
"state_std": state_std, # shape (128)
"state_mean": state_mean, # shape (128,), mean of each state dimension over the episode
"state_norm": state_norm, # shape (128)
"actions": actions, # shape (64, 128): 64 chunk steps; only the mapped positions of the 128 dims are non-zero, same indices as state_indicator
"state_indicator": state_indicator, # shape (128,): 1 at the mapped positions, 0 elsewhere
"cam_high": cam_high, # shape: (2, 480, 640, 3)
"cam_high_mask": cam_high_mask, # (True, True)
"cam_left_wrist": cam_left_wrist, # shape: (2, 480, 640, 3)
"cam_left_wrist_mask": cam_left_wrist_mask, # (True, True)
"cam_right_wrist": cam_right_wrist, # shape: (2, 480, 640, 3)
"cam_right_wrist_mask": cam_right_wrist_mask # (True, True)
}
4.1.1 Contents of f (the .hdf5 file); below, T = 300 is the number of timesteps, i.e. one complete episode of the motion
{
"observations":{
"images":{
"cam_high": np.array(300, 480, 640, 3),
"cam_low": np.array(300, 480, 640, 3),
"cam_left": np.array(300, 480, 640, 3),
"cam_right": np.array(300, 480, 640, 3)
},
"effort": np.array(300, 16), # joint efforts (torques)
"qpos": np.array(300, 16), # joint positions (7 joints + 1 gripper per arm)
"qvel": np.array(300, 16), # joint velocities
},
"action": np.array(300, 16)
}
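To double-check that a recorded episode matches this layout, you can walk the HDF5 tree with h5py; a minimal sketch (the file path is a hypothetical placeholder):

import h5py

def print_hdf5_layout(path):
    # Recursively print every dataset name, shape and dtype in an episode file
    with h5py.File(path, 'r') as f:
        def visit(name, obj):
            if isinstance(obj, h5py.Dataset):
                print(f"{name}: shape={obj.shape}, dtype={obj.dtype}")
        f.visititems(visit)

# print_hdf5_layout("data/episode_0.hdf5")  # hypothetical path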
⭐️ Why is action returned as a whole chunk rather than a single step? Because at deployment time the chunk is smoothed: actions are executed step by step with interpolation in between (see interpolate_action in section 6).
4.1.2 Filtering out episodes where the arm barely moves
See also the other fine-tuning details in part 5.2 of the companion post "RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation: paper and code summary (1)".
* Episodes with fewer than 128 steps are filtered out.
* Every subsequent qpos is compared against the first one; the first step whose difference exceeds the threshold is taken as the start of motion and recorded as first_idx.
num_steps = qpos.shape[0]
if num_steps < 128:
return False, None
EPS = 1e-2 # threshold used to check that the arm actually moved
# Get the idx of the first qpos whose delta exceeds the threshold
qpos_delta = np.abs(qpos - qpos[0:1]) # difference between every qpos and the first frame
indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
if len(indices) > 0:
first_idx = indices[0] # time index at which the robot starts to move
else:
raise ValueError("Found no qpos that exceeds the threshold.") # raised when the arm never moves in this episode
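A toy run of this filter on synthetic qpos values, just to make first_idx concrete (the numbers are made up):

import numpy as np

EPS = 1e-2
qpos = np.array([            # 5 timesteps x 2 joints, motion starts at t = 3
    [0.000, 0.000],
    [0.001, 0.000],          # still below the threshold
    [0.000, 0.002],
    [0.050, 0.000],          # first step whose delta exceeds EPS
    [0.100, 0.010],
])
qpos_delta = np.abs(qpos - qpos[0:1])
indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
first_idx = indices[0]
print(first_idx)             # -> 3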
4.1.3 Building the meta dict
{
"dataset_name": "custom",
"#steps": num_steps, # number of frames in this episode, 300 here
"step_id": step_id, # sampled uniformly at random from first_idx - 1 up to num_steps
"instruction": instruction # randomly picked from 'instruction', 'simplified_instruction' or 'expanded_instruction'
}
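expanded_instruction_gpt-4-turbo.json holds the three instruction variants sampled above; its exact contents depend on your own annotation pipeline, so the values below are purely illustrative:

import numpy as np

# Hypothetical contents of expanded_instruction_gpt-4-turbo.json
instruction_dict = {
    "instruction": "Pick up the cup and place it on the plate.",
    "simplified_instruction": ["Put the cup on the plate."],
    "expanded_instruction": [
        "Reach for the cup with the right gripper, lift it, and set it down on the plate."
    ],
}
instruction_type = np.random.choice(
    ["instruction", "simplified_instruction", "expanded_instruction"])
instruction = instruction_dict[instruction_type]
if isinstance(instruction, list):   # some variants store several phrasings
    instruction = np.random.choice(instruction)
print(instruction)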
4.1.4 Normalize the raw data, cut a self.CHUNK_SIZE-long segment of action (target_qpos) starting at step_id, and take the qpos at step_id as the current state ⭐️
# 1️⃣ normalize qpos
qpos_norm_vec = [1, 1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 1, 4.7888]
qpos = qpos / np.array([qpos_norm_vec])
# 2️⃣ normalize action; ⭐️ if the segment is shorter than the chunk length, pad by repeating the last action
action_norm_vec = [1, 1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 1, 13.9231]
actions = f_action[step_id:step_id + self.CHUNK_SIZE] / np.array([action_norm_vec]) # self.CHUNK_SIZE=64
if actions.shape[0] < self.CHUNK_SIZE:
actions = np.concatenate([
actions,
np.tile(actions[-1:], (self.CHUNK_SIZE-actions.shape[0], 1))
], axis=0)
# 3️⃣ take the qpos at step_id as the current state
state = qpos[step_id:step_id+1]
state_std = np.std(qpos, axis=0)
state_mean = np.mean(qpos, axis=0)
state_norm = np.sqrt(np.mean(qpos**2, axis=0))
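The gripper scales above (4.7908 / 4.7888 for qpos, 11.8997 / 13.9231 for action) are hard-coded and not derived in this file; presumably they were computed offline from the dataset, for example as the maximum absolute gripper value over all episodes. A sketch under that assumption (the helper name and file list are hypothetical):

import h5py
import numpy as np

def estimate_gripper_scales(episode_paths, left_idx=7, right_idx=15):
    # Assumed procedure: take the max |value| of every dimension over the whole dataset
    max_abs = None
    for path in episode_paths:
        with h5py.File(path, 'r') as f:
            qpos = f['observations']['qpos'][:]
            cur = np.abs(qpos).max(axis=0)
            max_abs = cur if max_abs is None else np.maximum(max_abs, cur)
    return max_abs[left_idx], max_abs[right_idx]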
4.1.5 Mapping the 16-dim data onto the 128-dim unified vector
def fill_in_state(values):
values_len = values.shape[-1] // 2 - 1 if values.shape[-1] % 2 == 0 and values.shape[-1] // 2 - 1 <= 10 else None
assert values_len is not None, "values dim must be even, with at most 10 joints per arm"
# Target indices corresponding to your state space
# In this example: values_len joints + 1 gripper for each arm (7 + 1 for the 16-dim qpos here)
UNI_STATE_INDICES = [
STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(values_len)
] + [
STATE_VEC_IDX_MAPPING["left_gripper_open"]
] + [
STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(values_len)
] + [
STATE_VEC_IDX_MAPPING["right_gripper_open"]
]
uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
uni_vec[..., UNI_STATE_INDICES] = values
return uni_vec
state = fill_in_state(state) # map (1, 16) to (1, 128): the 16 values are copied into their slots of the 128-dim vector, all other slots stay 0
state_indicator = fill_in_state(np.ones_like(state_std))
state_std = fill_in_state(state_std)
state_mean = fill_in_state(state_mean)
state_norm = fill_in_state(state_norm)
actions = fill_in_state(actions)
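STATE_VEC_IDX_MAPPING comes from the RDT repository (configs/state_vec.py) and assigns each semantic quantity, e.g. left_arm_joint_0_pos, a fixed slot of the 128-dim unified vector. A toy stand-in for that mapping, only to make the index arithmetic concrete (the real slot numbers differ):

import numpy as np

STATE_DIM = 128
# Toy stand-in for STATE_VEC_IDX_MAPPING; the real slot assignment is different
STATE_VEC_IDX_MAPPING = {f"left_arm_joint_{i}_pos": i for i in range(10)}
STATE_VEC_IDX_MAPPING["left_gripper_open"] = 10
STATE_VEC_IDX_MAPPING.update({f"right_arm_joint_{i}_pos": 50 + i for i in range(10)})
STATE_VEC_IDX_MAPPING["right_gripper_open"] = 60

values = np.arange(16, dtype=np.float64).reshape(1, 16)   # one 16-dim qpos frame
values_len = values.shape[-1] // 2 - 1                    # 7 joints per arm
indices = (
    [STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(values_len)]
    + [STATE_VEC_IDX_MAPPING["left_gripper_open"]]
    + [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(values_len)]
    + [STATE_VEC_IDX_MAPPING["right_gripper_open"]]
)
uni_vec = np.zeros(values.shape[:-1] + (STATE_DIM,))
uni_vec[..., indices] = values
print(uni_vec.shape, np.count_nonzero(uni_vec))   # (1, 128) with 15 non-zero entries (values[0, 0] is 0)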
4.1.6 Image processing: parse_img decodes the images; if there are fewer frames along the time axis than self.IMG_HISORY_SIZE, the first frame is repeated as padding
cam_high = parse_img('cam_high') # the same parsing is applied to cam_high, cam_left_wrist and cam_right_wrist, shape: (2, 480, 640, 3); cam_low is not used
valid_len = min(step_id - (first_idx - 1) + 1, self.IMG_HISORY_SIZE)
cam_high_mask = np.array([False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len) # frames recorded before the arm started moving get mask = False
cam_left_wrist = parse_img('cam_left_wrist') # shape: (2, 480, 640, 3)
cam_left_wrist_mask = cam_high_mask.copy()
cam_right_wrist = parse_img('cam_right_wrist') # shape: (2, 480, 640, 3)
cam_right_wrist_mask = cam_high_mask.copy() # (True, True)
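A small numeric example of the camera mask: with IMG_HISORY_SIZE = 2 the mask is normally all True; only when step_id lands exactly on first_idx - 1 does the older history frame predate any motion and get masked out:

import numpy as np

IMG_HISORY_SIZE = 2   # spelling follows the repo

def history_mask(step_id, first_idx):
    valid_len = min(step_id - (first_idx - 1) + 1, IMG_HISORY_SIZE)
    return np.array([False] * (IMG_HISORY_SIZE - valid_len) + [True] * valid_len)

print(history_mask(step_id=10, first_idx=5))   # [ True  True]
print(history_mask(step_id=4, first_idx=5))    # [False  True]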
4.2 dataset.py
def __getitem__(self, index):
# For robustness, we will try to load the data until we succeed
while True:
data_dict = None
try:
if self.use_hdf5:
res = self.hdf5_dataset.get_item()
content = res['meta']
states = res['state'] # (1, 128)
actions = res['actions'] # (64, 128)
state_elem_mask = res['state_indicator']
image_metas = [
res['cam_high'], res['cam_high_mask'],
res['cam_right_wrist'], res['cam_right_wrist_mask'],
res['cam_left_wrist'], res['cam_left_wrist_mask'],
]
state_std = res['state_std']
state_mean = res['state_mean']
state_norm = res['state_norm']
else:
(content, _, states, _, actions, _,
state_elem_mask, *image_metas,
state_std, state_mean, state_norm) = self._safe_load(index)
data_dict = {}
data_dict['dataset_name'] = content['dataset_name']
data_dict['data_idx'] = self.dataset_name2id[data_dict['dataset_name']]
data_dict['ctrl_freq'] = self.control_freq[data_dict['dataset_name']] \
if random.random() > self.cond_mask_prob else 0 # with some probability ctrl_freq is masked to 0
if self.state_noise_snr is not None:
states += np.random.normal(
0.0, state_std / np.sqrt(10 ** (self.state_noise_snr / 10)),
states.shape)
ds_state_mean = np.array(self.dataset_stat[data_dict['dataset_name']]['state_mean'])
ds_state_mean = np.tile(ds_state_mean[None], (states.shape[0], 1))
# Randomly mask the states by the mean state
data_dict["states"] = states \
if random.random() > self.cond_mask_prob else ds_state_mean
data_dict["actions"] = actions
data_dict["state_elem_mask"] = state_elem_mask \
if random.random() > self.cond_mask_prob else np.zeros_like(state_elem_mask)
# Stat for the episode that the step belongs to
data_dict["state_norm"] = state_norm
# We replace the invalid images with the background image
# and also randomly mask images by the background image
background_color = np.array([
int(x*255) for x in self.image_processor.image_mean
], dtype=np.uint8).reshape(1, 1, 3)
background_image = np.ones((
self.image_processor.size["height"],
self.image_processor.size["width"], 3), dtype=np.uint8
) * background_color # build a background image from the pretrained image mean
image_metas = list(self.pairwise(image_metas))
mask_probs = [self.cond_mask_prob] * self.num_cameras # randomly mask images with some probability to improve generalization
if self.cam_ext_mask_prob >= 0.0:
mask_probs[0] = self.cam_ext_mask_prob
rearranged_images = []
for i in range(self.img_history_size):
for j in range(self.num_cameras):
images, image_mask = image_metas[j]
image, valid = images[i], image_mask[i]
if valid and (math.prod(image.shape) > 0) and \
(random.random() > mask_probs[j]):
rearranged_images.append((image, True))
else:
rearranged_images.append((background_image.copy(), False)) # replace the image with the background image
preprocessed_images = []
processor = self.image_processor
for image, valid in rearranged_images:
image = Image.fromarray(image)
if self.image_size is not None:
image = transforms.Resize(self.image_size)(image) # (1008, 336)
# assert image.height == 336, "We haven't prepare for training with images of different resolutions."
if valid and self.auto_adjust_image_brightness: # False
pixel_values = list(image.getdata())
average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
if average_brightness <= 0.15:
image = transforms.ColorJitter(brightness=(1.75,1.75))(image)
# Only apply image augmentation to 50% of the images
if valid and self.image_aug and (random.random() > 0.5):
aug_type = random.choice([
"corrput_only", "color_only", "both"])
if aug_type != "corrput_only":
image = transforms.ColorJitter(
brightness=0.3, contrast=0.4, saturation=0.5, hue=0.03)(image)
if aug_type != "color_only":
image = image_corrupt(image)
if self.image_aspect_ratio == 'pad': # True
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
preprocessed_images.append(image)
data_dict["images"] = preprocessed_images
if self.use_precomp_lang_embed:
if content["instruction"][-1] == ".":
content["instruction"] = content["instruction"][:-1]
data_dict["lang_embed"] = torch.load(content["instruction"]) \
if random.random() > self.cond_mask_prob else self.empty_lang_embed
else:
instruction = content["instruction"] \
if random.random() > self.cond_mask_prob else "" # the language condition is sometimes dropped
data_dict["input_ids"] = self.tokenizer(
instruction,
return_tensors="pt",
padding="longest",
truncation=False,
).input_ids[0] # token ids of the instruction, a 1-D tensor (e.g. length 35)
assert len(data_dict["input_ids"]) <= self.tokenizer_max_length, \
f"Instruction length {len(data_dict['input_ids'])} exceeds the maximum length {self.tokenizer_max_length}."
for k, v in data_dict.items():
if isinstance(v, np.ndarray):
data_dict[k] = torch.from_numpy(v)
for k, v in data_dict.items():
assert not isinstance(v, np.ndarray), f"key: {k}, value: {v}"
# data_dict[k] = torch.from_numpy(v)
return data_dict
except BaseException as e:
# Print the error info
if data_dict is not None:
print(f"Error caught when processing sample from {data_dict.get('dataset_name')}:", e)
else:
print(f"Error caught when processing sample:", e)
traceback.print_exc()
# Try increasing the index
index = (index + 1) % len(self)
The hdf5_vla_dataset.py described above is consumed here: the key call in dataset.py is self.hdf5_dataset.get_item(), and __getitem__ finally returns data_dict.
4.2.1 Reorganizing the data
res = self.hdf5_dataset.get_item()
content = res['meta']
states = res['state'] # (1, 128)
actions = res['actions'] # (64, 128)
state_elem_mask = res['state_indicator']
image_metas = [
res['cam_high'], res['cam_high_mask'],
res['cam_right_wrist'], res['cam_right_wrist_mask'],
res['cam_left_wrist'], res['cam_left_wrist_mask'],
]
state_std = res['state_std']
state_mean = res['state_mean']
state_norm = res['state_norm']
data_dict['dataset_name'] = content['dataset_name']
data_dict['data_idx'] = self.dataset_name2id[data_dict['dataset_name']]
data_dict['ctrl_freq'] = self.control_freq[data_dict['dataset_name']] if random.random() > self.cond_mask_prob else 0 # from "configs/custom_configs/custom_dataset_control_freq.json"; with some probability ctrl_freq is masked to 0
4.2.2 Post-processing the state data
# 1️⃣ add Gaussian noise to the state (loosely analogous to the Gaussian heatmaps used in keypoint detection)
if self.state_noise_snr is not None:
states += np.random.normal(
0.0, state_std / np.sqrt(10 ** (self.state_noise_snr / 10)),
states.shape)
# 2️⃣ randomly drop the state: when dropped it is replaced with the precomputed dataset mean; state_elem_mask starts as all ones and is likewise randomly zeroed here
ds_state_mean = np.array(self.dataset_stat[data_dict['dataset_name']]['state_mean'])# "configs/custom_configs/custom_dataset_stat.json"
ds_state_mean = np.tile(ds_state_mean[None], (states.shape[0], 1))
data_dict["states"] = states if random.random() > self.cond_mask_prob else ds_state_mean
data_dict["actions"] = actions
data_dict["state_elem_mask"] = state_elem_mask if random.random() > self.cond_mask_prob else np.zeros_like(state_elem_mask)
data_dict["state_norm"] = state_norm
4.2.3 Post-processing the images
With some probability an image is replaced by a background image built from the pretrained image mean; the processed images are stored in data_dict["images"].
background_color = np.array([
int(x*255) for x in self.image_processor.image_mean
], dtype=np.uint8).reshape(1, 1, 3)
background_image = np.ones((
self.image_processor.size["height"],
self.image_processor.size["width"], 3), dtype=np.uint8
) * background_color # build a background image from the pretrained image mean
image_metas = list(self.pairwise(image_metas))
mask_probs = [self.cond_mask_prob] * self.num_cameras # randomly mask images with some probability to improve generalization
if self.cam_ext_mask_prob >= 0.0:
mask_probs[0] = self.cam_ext_mask_prob
rearranged_images = []
for i in range(self.img_history_size):
for j in range(self.num_cameras):
images, image_mask = image_metas[j]
image, valid = images[i], image_mask[i]
if valid and (math.prod(image.shape) > 0) and \
(random.random() > mask_probs[j]):
rearranged_images.append((image, True))
else:
rearranged_images.append((background_image.copy(), False)) # replace the image with the background image
preprocessed_images = []
processor = self.image_processor
for image, valid in rearranged_images:
image = Image.fromarray(image)
if self.image_size is not None:
image = transforms.Resize(self.image_size)(image) # (1008, 336)
# assert image.height == 336, "We haven't prepare for training with images of different resolutions."
if valid and self.auto_adjust_image_brightness: # False
pixel_values = list(image.getdata())
average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
if average_brightness <= 0.15:
image = transforms.ColorJitter(brightness=(1.75,1.75))(image)
# Only apply image augmentation to 50% of the images
if valid and self.image_aug and (random.random() > 0.5):
aug_type = random.choice([
"corrput_only", "color_only", "both"])
if aug_type != "corrput_only":
image = transforms.ColorJitter(
brightness=0.3, contrast=0.4, saturation=0.5, hue=0.03)(image)
if aug_type != "color_only":
image = image_corrupt(image)
if self.image_aspect_ratio == 'pad': # True
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
preprocessed_images.append(image)
data_dict["images"] = preprocessed_images
4.2.4 Post-processing the text: tokenize the instruction
instruction = content["instruction"] \
if random.random() > self.cond_mask_prob else "" # the language condition is sometimes dropped
data_dict["input_ids"] = self.tokenizer(
instruction,
return_tensors="pt",
padding="longest",
truncation=False,
).input_ids[0] # token ids of the instruction, a 1-D tensor (e.g. length 35)
assert len(data_dict["input_ids"]) <= self.tokenizer_max_length, \
f"Instruction length {len(data_dict['input_ids'])} exceeds the maximum length {self.tokenizer_max_length}."
4.2.5 The resulting data_dict
{
"dataset_name": "custom",
"data_idx": 0, # index assigned to this dataset name
"ctrl_freq": 0,
"states": tensor[np.array(1, 128)],
"actions": tensor[np.array(64, 128)],
"state_elem_mask": tensor[np.array(128,)],
"state_norm": tensor[np.array(128,)],
"images": [6, 3, 384, 384], # 6 = 3 cameras x 2 history frames
"input_ids": tensor[np.array(9,)],
}
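After __getitem__, the DataLoader batches these fields; a dummy-tensor sketch of the batch shapes that section 5.1 (train.py) then consumes (all values are placeholders):

import torch

B = 2
batch = {
    "states": torch.zeros(B, 1, 128),
    "actions": torch.zeros(B, 64, 128),
    "state_elem_mask": torch.zeros(B, 128),
    "state_norm": torch.zeros(B, 128),
    "images": torch.zeros(B, 6, 3, 384, 384),   # 3 cameras x 2 history frames
    "ctrl_freqs": torch.tensor([100, 0]),       # masked samples get ctrl_freq = 0
}
for k, v in batch.items():
    print(k, tuple(v.shape))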
5. Model Training
The relevant scripts are train.py, rdt_runner.py, model.py and blocks.py.
5.1 train.py
images = batch["images"].to(dtype=weight_dtype) # shape: (B, 6, 3, 384, 384)
states = batch["states"].to(dtype=weight_dtype) # (B, 1, D_a), e.g. (2, 1, 128); we only use the last state as input
states = states[:, -1:, :] # (B, 1, 128)
actions = batch["actions"].to(dtype=weight_dtype) # shape (B, 64, 128)
state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype) # shape (B, 128)
ctrl_freqs = batch["ctrl_freqs"] # shape (B,); some entries are 0 (masked) and some are the real control frequency (e.g. 100)
with torch.no_grad():
batch_size, _, C, H, W = images.shape
image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach() # shape (B*6, 729, 1152)
image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size)) # shape (B, 4374, 1152)
lang_attn_mask = batch["lang_attn_mask"] # attention mask for the text encoder; padded token positions are not attended
text_embeds = batch["lang_embeds"].to(dtype=weight_dtype) \
if args.precomp_lang_embed \
else text_encoder(
input_ids=batch["input_ids"],
attention_mask=lang_attn_mask
)["last_hidden_state"].detach()
# shape (B, 46[token num], 4096)
state_elem_mask = state_elem_mask.unsqueeze(1)
loss = rdt(
lang_tokens=text_embeds,
lang_attn_mask=lang_attn_mask,
img_tokens=image_embeds,
state_tokens=states,
action_gt=actions,
action_mask=state_elem_mask,
ctrl_freqs=ctrl_freqs
)
- image_embeds: the vision encoder outputs features of shape (BatchSize*6, 729, 1152), which are reshaped to (BatchSize, 4374, 1152).
- text_embeds: the text encoder outputs features of shape (BatchSize, 46, 4096).
Both image and text embeddings are then fed into the rdt model.
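The 4374 comes from 6 camera views x 729 patch tokens per view (SigLIP so400m/14 at 384x384 yields a 27x27 patch grid). A dummy-tensor check of the reshape:

import torch

B, views, tokens, hidden = 2, 6, 729, 1152
image_embeds = torch.zeros(B * views, tokens, hidden)   # stand-in for the vision encoder output
image_embeds = image_embeds.reshape(B, -1, hidden)
print(image_embeds.shape)                               # torch.Size([2, 4374, 1152])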
5.2 rdt_runner.py
1️⃣ Overall pipeline
def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens,
state_tokens, action_gt, action_mask, ctrl_freqs
) -> torch.Tensor:
'''
lang_tokens: (batch_size, lang_len, lang_token_dim)
lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
which should be True-False bool tensor.
img_tokens: (batch_size, img_len, img_token_dim)
state_tokens: (batch_size, 1, state_token_dim), states
action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
ctrl_freqs: (batch_size,), control frequency for each sample.
return: loss_value, a scalar tensor
'''
batch_size = lang_tokens.shape[0]
device = lang_tokens.device
# Sample noise that we'll add to the actions
noise = torch.randn(
action_gt.shape, dtype=action_gt.dtype, device=device
) # shape (batch_size, 64, 128)
# Sample random diffusion timesteps
timesteps = torch.randint(
0, self.num_train_timesteps,
(batch_size,), device=device
).long() # shape (batch_size,), e.g. values like 226, 57; self.num_train_timesteps = 1000
# Add noise to the clean actions according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_action = self.noise_scheduler.add_noise( # add noise to the clean actions
action_gt, noise, timesteps) # shape (B, 64, 128)
# Concatenate the state and action tokens to form the input sequence
state_action_traj = torch.cat([state_tokens, noisy_action], dim=1) # shape (B, 65, 128): concatenate the state token with the noised actions (the labels) into state_action_traj
# Append the action mask to the input sequence
action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1) # expanded from (B, 1, 128) to (B, 65, 128)
state_action_traj = torch.cat([state_action_traj, action_mask], dim=2) # concatenate the state/noised-action tokens with their mask along the feature dim: (B, 65, 256)
# Align the dimension with the hidden size
lang_cond, img_cond, state_action_traj = self.adapt_conditions(
lang_tokens, img_tokens, state_action_traj) # state_action_traj is projected to the hidden size: (B, 65, hidden_size)
# Predict the denoised result
pred = self.model(state_action_traj, ctrl_freqs,
timesteps, lang_cond, img_cond,
lang_mask=lang_attn_mask)
pred_type = self.prediction_type
if pred_type == 'epsilon':
target = noise
elif pred_type == 'sample':
target = action_gt
else:
raise ValueError(f"Unsupported prediction type {pred_type}")
loss = F.mse_loss(pred, target)
return loss
2️⃣ Data composition
noise = torch.randn(action_gt.shape, dtype=action_gt.dtype, device=device) # shape (batch_size, 64, 128)
timesteps = torch.randint(0, self.num_train_timesteps, (batch_size,), device=device).long() # shape (batch_size,), e.g. values like 226, 57; self.num_train_timesteps = 1000
noisy_action = self.noise_scheduler.add_noise(action_gt, noise, timesteps) # shape (B, 64, 128); add noise to the clean actions according to the noise magnitude at each timestep
state_action_traj = torch.cat([state_tokens, noisy_action], dim=1) # shape (B, 65, 128): concatenate the state token with the noised actions (the labels)
action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1) # expanded from (B, 1, 128) to (B, 65, 128)
state_action_traj = torch.cat([state_action_traj, action_mask], dim=2) # concatenate tokens and mask along the feature dim: (B, 65, 256)
lang_cond, img_cond, state_action_traj = self.adapt_conditions(lang_tokens, img_tokens, state_action_traj) # each input is passed through its adapter MLP to match the hidden size
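noise_scheduler is a diffusers scheduler; assuming a DDPMScheduler (the actual scheduler/config in the repo may differ), the forward-diffusion step used above can be reproduced stand-alone with dummy action tensors:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)   # assumption: matches self.num_train_timesteps
action_gt = torch.zeros(2, 64, 128)                   # (B, horizon, state_dim)
noise = torch.randn_like(action_gt)
timesteps = torch.randint(0, 1000, (2,)).long()
noisy_action = scheduler.add_noise(action_gt, noise, timesteps)
print(noisy_action.shape)                             # torch.Size([2, 64, 128])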
5.3 model.py
def forward(self, x, freq, t, lang_c, img_c, lang_mask=None, img_mask=None):
"""
Forward pass of RDT.
x: (B, T, D), state + action token sequence, T = horizon + 1,
dimension D is assumed to be the same as the hidden size (the state token concatenated with the noised action tokens).
freq: (B,), a scalar indicating control frequency.
t: (B,) or (1,), diffusion timesteps.
lang_c: (B, L_lang, D) or None, language condition tokens (variable length),
dimension D is assumed to be the same as the hidden size.
img_c: (B, L_img, D) or None, image condition tokens (fixed length),
dimension D is assumed to be the same as the hidden size.
lang_mask: (B, L_lang) or None, language condition mask (True for valid).
img_mask: (B, L_img) or None, image condition mask (True for valid).
"""
# here D (the hidden size) is 2048
t = self.t_embedder(t).unsqueeze(1) # (B, 1, D) or (1, 1, D)
freq = self.freq_embedder(freq).unsqueeze(1) # (B, 1, D)
# Append timestep to the input tokens
if t.shape[0] == 1:
t = t.expand(x.shape[0], -1, -1)
x = torch.cat([t, freq, x], dim=1) # (B, T+2, D); here x becomes (B, 67, 2048)
# Add multimodal position embeddings
x = x + self.x_pos_embed
# Note the lang is of variable length
lang_c = lang_c + self.lang_cond_pos_embed[:, :lang_c.shape[1]]
img_c = img_c + self.img_cond_pos_embed
# Forward pass
conds = [lang_c, img_c] # [(B, L_lang, 2048), (B, L_img, 2048)]
masks = [lang_mask, img_mask] # img_mask is None
for i, block in enumerate(self.blocks):
c, mask = conds[i%2], masks[i%2]
x = block(x, c, mask) # (B, T+1, D)
# Inject the language condition at the final layer
x = self.final_layer(x) # (B, T+2, out_channels), out_channels = 128; here x has shape (B, 67, 128)
# Only preserve the action tokens
x = x[:, -self.horizon:] # x shape (B, 64, 128)
return x
Note that -self.horizon: keeps only the last horizon tokens, where horizon is the action length (64 by default). As in DiT, what these tokens carry is the predicted noise.
[Figure: DiT architecture]
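The sequence layout is [timestep, ctrl_freq, state, action_0 ... action_63], i.e. 2 + 1 + 64 = 67 tokens, which is why final_layer outputs (B, 67, 128) and the slice x[:, -self.horizon:] keeps exactly the 64 action tokens. A dummy-tensor illustration:

import torch

B, horizon, D, out_channels = 2, 64, 2048, 128
t_tok = torch.zeros(B, 1, D)              # diffusion-timestep embedding
freq_tok = torch.zeros(B, 1, D)           # control-frequency embedding
sa_toks = torch.zeros(B, 1 + horizon, D)  # state token + noisy action tokens
x = torch.cat([t_tok, freq_tok, sa_toks], dim=1)
print(x.shape)                            # torch.Size([2, 67, 2048])
x = torch.zeros(B, x.shape[1], out_channels)   # stand-in for final_layer(x)
print(x[:, -horizon:].shape)              # torch.Size([2, 64, 128])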
5.4 blocks.py
The self.blocks used above are defined as follows.
#################################################################################
# Cross Attention Layers #
#################################################################################
class CrossAttention(nn.Module):
"""
A cross-attention layer with flash attention.
"""
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0,
proj_drop: float = 0,
norm_layer: nn.Module = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, c: torch.Tensor,
mask: torch.Tensor | None = None) -> torch.Tensor:
B, N, C = x.shape
_, L, _ = c.shape
q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
kv = self.kv(c).reshape(B, L, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
k, v = kv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
# Prepare attn mask (B, L) to mask the condition
if mask is not None:
mask = mask.reshape(B, 1, 1, L)
mask = mask.expand(-1, -1, N, -1)
if self.fused_attn:
x = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
dropout_p=self.attn_drop.p if self.training else 0.,
attn_mask=mask
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if mask is not None:
attn = attn.masked_fill_(mask.logical_not(), float('-inf'))
attn = attn.softmax(dim=-1)
if self.attn_drop.p > 0:
attn = self.attn_drop(attn)
x = attn @ v
x = x.permute(0, 2, 1, 3).reshape(B, N, C)
x = self.proj(x)
if self.proj_drop.p > 0:
x = self.proj_drop(x)
return x
#################################################################################
# RDT Block #
#################################################################################
class RDTBlock(nn.Module):
"""
A RDT block with cross-attention conditioning.
"""
def __init__(self, hidden_size, num_heads, **block_kwargs):
super().__init__()
self.norm1 = RmsNorm(hidden_size, eps=1e-6)
self.attn = Attention(
dim=hidden_size, num_heads=num_heads,
qkv_bias=True, qk_norm=True,
norm_layer=RmsNorm,**block_kwargs)
self.cross_attn = CrossAttention(
hidden_size, num_heads=num_heads,
qkv_bias=True, qk_norm=True,
norm_layer=RmsNorm,**block_kwargs)
self.norm2 = RmsNorm(hidden_size, eps=1e-6)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.ffn = Mlp(in_features=hidden_size,
hidden_features=hidden_size,
act_layer=approx_gelu, drop=0)
self.norm3 = RmsNorm(hidden_size, eps=1e-6)
def forward(self, x, c, mask=None):
origin_x = x
x = self.norm1(x)
x = self.attn(x)
x = x + origin_x
origin_x = x
x = self.norm2(x)
x = self.cross_attn(x, c, mask)
x = x + origin_x
origin_x = x
x = self.norm3(x)
x = self.ffn(x)
x = x + origin_x
return x
class FinalLayer(nn.Module):
"""
The final layer of RDT.
"""
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = RmsNorm(hidden_size, eps=1e-6)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.ffn_final = Mlp(in_features=hidden_size,
hidden_features=hidden_size,
out_features=out_channels,
act_layer=approx_gelu, drop=0)
def forward(self, x):
x = self.norm_final(x)
x = self.ffn_final(x)
return x
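As a sanity check of the cross-attention shape flow (queries come from the 67 state/action tokens, keys/values from the condition tokens), here is a stand-alone sketch using plain scaled_dot_product_attention rather than the repo's class; the separate k/v projections are simplified away:

import torch
import torch.nn.functional as F

B, N, L, D, heads = 2, 67, 32, 2048, 32   # head count is an assumption for illustration
head_dim = D // heads
x = torch.randn(B, N, D)                  # state/action tokens (queries)
c = torch.randn(B, L, D)                  # condition tokens, e.g. language (keys/values)
mask = torch.ones(B, L, dtype=torch.bool) # True = valid condition token

q = x.reshape(B, N, heads, head_dim).permute(0, 2, 1, 3)
k = c.reshape(B, L, heads, head_dim).permute(0, 2, 1, 3)
v = k.clone()                             # stand-in; the real layer projects k and v separately
attn_mask = mask.reshape(B, 1, 1, L).expand(-1, heads, N, -1)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
out = out.permute(0, 2, 1, 3).reshape(B, N, D)
print(out.shape)                          # torch.Size([2, 67, 2048])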
6. Model Inference
#!/home/lin/software/miniconda3/envs/aloha/bin/python
# -- coding: UTF-8
"""
#!/usr/bin/python3
"""
import argparse
import sys
import threading
import time
import yaml
from collections import deque
import numpy as np
import rospy
import torch
from cv_bridge import CvBridge
from geometry_msgs.msg import Twist
from nav_msgs.msg import Odometry
from PIL import Image as PImage
from sensor_msgs.msg import Image, JointState
from std_msgs.msg import Header
import cv2
from agilex_model import create_model
CAMERA_NAMES = ['cam_high', 'cam_right_wrist', 'cam_left_wrist']
observation_window = None
lang_embeddings = None
# debug
preload_images = None
# Initialize the model
def make_policy(args):
with open(args.config_path, "r") as fp:
config = yaml.safe_load(fp)
args.config = config
# pretrained_text_encoder_name_or_path = "google/t5-v1_1-xxl"
pretrained_vision_encoder_name_or_path = "google/siglip-so400m-patch14-384"
model = create_model(
args=args.config,
dtype=torch.bfloat16,
pretrained=args.pretrained_model_name_or_path,
# pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
control_frequency=args.ctrl_freq,
)
return model
def set_seed(seed):
torch.manual_seed(seed)
np.random.seed(seed)
# Interpolate the actions to make the robot move smoothly
def interpolate_action(args, prev_action, cur_action):
steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)
diff = np.abs(cur_action - prev_action)
step = np.ceil(diff / steps).astype(int)
step = np.max(step)
if step <= 1:
return cur_action[np.newaxis, :]
new_actions = np.linspace(prev_action, cur_action, step + 1)
return new_actions[1:]
def get_config(args):
config = {
'episode_len': args.max_publish_step,
'state_dim': 14,
'chunk_size': args.chunk_size,
'camera_names': CAMERA_NAMES,
}
return config
# Get the observation from the ROS topic
def get_ros_observation(args,ros_operator):
rate = rospy.Rate(args.publish_rate)
print_flag = True
while True and not rospy.is_shutdown():
result = ros_operator.get_frame()
if not result:
if print_flag:
print("syn fail when get_ros_observation")
print_flag = False
rate.sleep()
continue
print_flag = True
(img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
puppet_arm_left, puppet_arm_right, robot_base) = result
# print(f"sync success when get_ros_observation")
return (img_front, img_left, img_right,
puppet_arm_left, puppet_arm_right)
# Update the observation window buffer
def update_observation_window(args, config, ros_operator):
# JPEG transformation
# Align with training
def jpeg_mapping(img):
img = cv2.imencode('.jpg', img)[1].tobytes()
img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
return img
global observation_window
if observation_window is None:
observation_window = deque(maxlen=2)
# Append the first dummy image
observation_window.append(
{
'qpos': None,
'images':
{
config["camera_names"][0]: None,
config["camera_names"][1]: None,
config["camera_names"][2]: None,
},
}
)
img_front, img_left, img_right, puppet_arm_left, puppet_arm_right = get_ros_observation(args,ros_operator)
img_front = jpeg_mapping(img_front)
img_left = jpeg_mapping(img_left)
img_right = jpeg_mapping(img_right)
qpos = np.concatenate(
(np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0)
qpos = torch.from_numpy(qpos).float().cuda()
observation_window.append(
{
'qpos': qpos,
'images':
{
config["camera_names"][0]: img_front,
config["camera_names"][1]: img_right,
config["camera_names"][2]: img_left,
},
}
)
# RDT inference
def inference_fn(args, config, policy, t):
global observation_window
global lang_embeddings
# print(f"Start inference_thread_fn: t={t}")
while True and not rospy.is_shutdown():
time1 = time.time()
# fetch images in sequence [front, right, left]
image_arrs = [
observation_window[-2]['images'][config['camera_names'][0]],
observation_window[-2]['images'][config['camera_names'][1]],
observation_window[-2]['images'][config['camera_names'][2]],
observation_window[-1]['images'][config['camera_names'][0]],
observation_window[-1]['images'][config['camera_names'][1]],
observation_window[-1]['images'][config['camera_names'][2]]
]
# fetch debug images in sequence [front, right, left]
# image_arrs = [
# preload_images[config['camera_names'][0]][max(t - 1, 0)],
# preload_images[config['camera_names'][2]][max(t - 1, 0)],
# preload_images[config['camera_names'][1]][max(t - 1, 0)],
# preload_images[config['camera_names'][0]][t],
# preload_images[config['camera_names'][2]][t],
# preload_images[config['camera_names'][1]][t]
# ]
# # encode the images
# for i in range(len(image_arrs)):
# image_arrs[i] = cv2.imdecode(np.frombuffer(image_arrs[i], np.uint8), cv2.IMREAD_COLOR)
# proprio = torch.from_numpy(preload_images['qpos'][t]).float().cuda()
images = [PImage.fromarray(arr) if arr is not None else None
for arr in image_arrs]
# for i, pos in enumerate(['f', 'r', 'l'] * 2):
# images[i].save(f'{t}-{i}-{pos}.png')
# get last qpos in shape [14, ]
proprio = observation_window[-1]['qpos']
# unsqueeze to [1, 14]
proprio = proprio.unsqueeze(0)
# actions shaped as [1, 64, 14] in format [left, right]
actions = policy.step(
proprio=proprio,
images=images,
text_embeds=lang_embeddings
).squeeze(0).cpu().numpy()
# print(f"inference_actions: {actions.squeeze()}")
print(f"Model inference time: {time.time() - time1} s")
# print(f"Finish inference_thread_fn: t={t}")
return actions
# Main loop for the manipulation task
def model_inference(args, config, ros_operator):
global lang_embeddings
# Load rdt model
policy = make_policy(args)
lang_dict = torch.load(args.lang_embeddings_path)
print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
lang_embeddings = lang_dict["embeddings"]
max_publish_step = config['episode_len']
chunk_size = config['chunk_size']
# Initialize position of the puppet arm
# left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875]
# right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875]
# left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258]
# right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]
# ⭐️⭐️ modified for the custom setup ⭐️⭐️
left0 = [0, 0, 0, 0, 0, 0, 0, 1]
right0 = [0, 0, 0, 0, 0, 0, 0, 1]
left1 = [0, 0, 0, 0, 0, 0, 0, 1]
right1 = [0, 0, 0, 0, 0, 0, 0, 1]
ros_operator.puppet_arm_publish_continuous(left0, right0)
input("Press enter to continue")
ros_operator.puppet_arm_publish_continuous(left1, right1)
# Initialize the previous action to be the initial robot state
pre_action = np.zeros(config['state_dim'])
# ⭐️⭐️ modified for the custom setup ⭐️⭐️
pre_action[:16] = np.array(
[0, 0, 0, 0 ,0, 0, 0] +
[0, 0, 0, 0, 0, 0, 0]
)
# pre_action[:14] = np.array(
# [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258] +
# [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]
# )
action = None
# Inference loop
with torch.inference_mode():
while True and not rospy.is_shutdown():
# The current time step
t = 0
rate = rospy.Rate(args.publish_rate)
action_buffer = np.zeros([chunk_size, config['state_dim']]) # state_dim 14
while t < max_publish_step and not rospy.is_shutdown():
# Update observation window
update_observation_window(args, config, ros_operator)
# When coming to the end of the action chunk
if t % chunk_size == 0:
# Start inference
action_buffer = inference_fn(args, config, policy, t).copy()
raw_action = action_buffer[t % chunk_size]
action = raw_action
# Interpolate the original action sequence
if args.use_actions_interpolation:
# print(f"Time {t}, pre {pre_action}, act {action}")
interp_actions = interpolate_action(args, pre_action, action)
else:
interp_actions = action[np.newaxis, :]
# Execute the interpolated actions one by one
for act in interp_actions:
left_action = act[:7]
right_action = act[7:14]
if not args.disable_puppet_arm:
ros_operator.puppet_arm_publish(left_action, right_action) # puppet_arm_publish_continuous_thread
if args.use_robot_base:
vel_action = act[14:16]
ros_operator.robot_base_publish(vel_action)
rate.sleep()
# print(f"doing action: {act}")
t += 1
print("Published Step", t)
pre_action = action.copy()
# ROS operator class
class RosOperator:
def __init__(self, args):
self.robot_base_deque = None
self.puppet_arm_right_deque = None
self.puppet_arm_left_deque = None
self.img_front_deque = None
self.img_right_deque = None
self.img_left_deque = None
self.img_front_depth_deque = None
self.img_right_depth_deque = None
self.img_left_depth_deque = None
self.bridge = None
self.puppet_arm_left_publisher = None
self.puppet_arm_right_publisher = None
self.robot_base_publisher = None
self.puppet_arm_publish_thread = None
self.puppet_arm_publish_lock = None
self.args = args
self.init()
self.init_ros()
def init(self):
# CvBridge for converting between ROS image messages and OpenCV images
self.bridge = CvBridge()
# Image buffers for the three RGB cameras
self.img_left_deque = deque()
self.img_right_deque = deque()
self.img_front_deque = deque()
# Depth-image buffers for the three cameras
self.img_left_depth_deque = deque()
self.img_right_depth_deque = deque()
self.img_front_depth_deque = deque()
# Buffers for the puppet arms' joint-state messages
self.puppet_arm_left_deque = deque()
self.puppet_arm_right_deque = deque()
# Buffer for the robot-base odometry messages
self.robot_base_deque = deque()
# Lock used to synchronize continuous arm publishing and avoid race conditions
self.puppet_arm_publish_lock = threading.Lock()
# Acquire the lock so that later publishing is thread-safe
self.puppet_arm_publish_lock.acquire()
def puppet_arm_publish(self, left, right):
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.header.stamp = rospy.Time.now() # Set the timestamp
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # set the joint names
joint_state_msg.position = left
self.puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = right
self.puppet_arm_right_publisher.publish(joint_state_msg)
def robot_base_publish(self, vel):
vel_msg = Twist()
vel_msg.linear.x = vel[0]
vel_msg.linear.y = 0
vel_msg.linear.z = 0
vel_msg.angular.x = 0
vel_msg.angular.y = 0
vel_msg.angular.z = vel[1]
self.robot_base_publisher.publish(vel_msg)
def puppet_arm_publish_continuous(self, left, right):
# publishing rate
rate = rospy.Rate(self.args.publish_rate)
# current positions of the left/right arms, filled from the subscribed topics
left_arm = None
right_arm = None
"""
In this code, left and left_arm mean the following:
left: the target position (the desired joint angles of the left arm); it is part of the action sequence produced by the model or specified by the user.
left_arm: the current joint angles of the left arm, taken from the subscribed ROS topic, i.e. the arm's real position.
In short:
left is the target state the arm should move to.
left_arm is the current state, where the arm is right now.
During interpolation / continuous publishing, the code gradually adjusts left_arm towards the target left.
"""
# keep looping until the ROS node is shut down
while True and not rospy.is_shutdown():
# update the left-arm position if a new message is available
if len(self.puppet_arm_left_deque) != 0:
left_arm = list(self.puppet_arm_left_deque[-1].position)
# update the right-arm position if a new message is available
if len(self.puppet_arm_right_deque) != 0:
right_arm = list(self.puppet_arm_right_deque[-1].position)
# if either arm's position is still unknown, wait and retry
if left_arm is None or right_arm is None:
rate.sleep()
continue
else:
# once both arm positions are known, leave the loop
break
left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
flag = True
step = 0
while flag and not rospy.is_shutdown():
if self.puppet_arm_publish_lock.acquire(False):
return
left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
flag = False
for i in range(len(left)):
if left_diff[i] < self.args.arm_steps_length[i]: # the remaining distance is within one step, so jump straight to the target
left_arm[i] = left[i]
else:
left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i] # otherwise move towards the target by one small step
flag = True
for i in range(len(right)):
if right_diff[i] < self.args.arm_steps_length[i]:
right_arm[i] = right[i]
else:
right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
flag = True
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.header.stamp = rospy.Time.now() # Set the timestamp
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # set the joint names
joint_state_msg.position = left_arm
self.puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = right_arm
self.puppet_arm_right_publisher.publish(joint_state_msg)
step += 1
print("puppet_arm_publish_continuous:", step)
rate.sleep()
def puppet_arm_publish_linear(self, left, right):
num_step = 100
rate = rospy.Rate(200)
left_arm = None
right_arm = None
while True and not rospy.is_shutdown():
if len(self.puppet_arm_left_deque) != 0:
left_arm = list(self.puppet_arm_left_deque[-1].position)
if len(self.puppet_arm_right_deque) != 0:
right_arm = list(self.puppet_arm_right_deque[-1].position)
if left_arm is None or right_arm is None:
rate.sleep()
continue
else:
break
traj_left_list = np.linspace(left_arm, left, num_step)
traj_right_list = np.linspace(right_arm, right, num_step)
for i in range(len(traj_left_list)):
traj_left = traj_left_list[i]
traj_right = traj_right_list[i]
traj_left[-1] = left[-1]
traj_right[-1] = right[-1]
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.header.stamp = rospy.Time.now() # set the timestamp
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # set the joint names
joint_state_msg.position = traj_left
self.puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = traj_right
self.puppet_arm_right_publisher.publish(joint_state_msg)
rate.sleep()
def puppet_arm_publish_continuous_thread(self, left, right):
if self.puppet_arm_publish_thread is not None:
self.puppet_arm_publish_lock.release()
self.puppet_arm_publish_thread.join()
self.puppet_arm_publish_lock.acquire(False)
self.puppet_arm_publish_thread = None
self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
self.puppet_arm_publish_thread.start()
def get_frame(self):
if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \
(self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)):
return False
if self.args.use_depth_image:
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(),
self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()])
# Otherwise (RGB only): take the newest message from each of the three image queues
# and use the smallest of their timestamps (in seconds) as the frame time
else:
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()])
if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
return False
# keep popping messages older than the chosen frame time from the head of each queue
while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
self.img_left_deque.popleft()
img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough')
while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
self.img_right_deque.popleft()
img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough')
while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
self.img_front_deque.popleft()
img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough')
while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
self.puppet_arm_left_deque.popleft()
puppet_arm_left = self.puppet_arm_left_deque.popleft()
while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
self.puppet_arm_right_deque.popleft()
puppet_arm_right = self.puppet_arm_right_deque.popleft()
img_left_depth = None
if self.args.use_depth_image:
while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_left_depth_deque.popleft()
img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough')
img_right_depth = None
if self.args.use_depth_image:
while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_right_depth_deque.popleft()
img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough')
img_front_depth = None
if self.args.use_depth_image:
while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_front_depth_deque.popleft()
img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough')
robot_base = None
if self.args.use_robot_base:
while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
self.robot_base_deque.popleft()
robot_base = self.robot_base_deque.popleft()
return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
puppet_arm_left, puppet_arm_right, robot_base)
def img_left_callback(self, msg):
if len(self.img_left_deque) >= 2000:
self.img_left_deque.popleft()
self.img_left_deque.append(msg)
def img_right_callback(self, msg):
if len(self.img_right_deque) >= 2000:
self.img_right_deque.popleft()
self.img_right_deque.append(msg)
def img_front_callback(self, msg):
if len(self.img_front_deque) >= 2000:
self.img_front_deque.popleft()
self.img_front_deque.append(msg)
def img_left_depth_callback(self, msg):
if len(self.img_left_depth_deque) >= 2000:
self.img_left_depth_deque.popleft()
self.img_left_depth_deque.append(msg)
def img_right_depth_callback(self, msg):
if len(self.img_right_depth_deque) >= 2000:
self.img_right_depth_deque.popleft()
self.img_right_depth_deque.append(msg)
def img_front_depth_callback(self, msg):
if len(self.img_front_depth_deque) >= 2000:
self.img_front_depth_deque.popleft()
self.img_front_depth_deque.append(msg)
def puppet_arm_left_callback(self, msg):
if len(self.puppet_arm_left_deque) >= 2000:
self.puppet_arm_left_deque.popleft()
self.puppet_arm_left_deque.append(msg)
def puppet_arm_right_callback(self, msg):
if len(self.puppet_arm_right_deque) >= 2000:
self.puppet_arm_right_deque.popleft()
self.puppet_arm_right_deque.append(msg)
def robot_base_callback(self, msg):
if len(self.robot_base_deque) >= 2000:
self.robot_base_deque.popleft()
self.robot_base_deque.append(msg)
def init_ros(self):
"""
Initialize the ROS node and its subscribers.
This method creates the node, subscribes to the configured topics,
and creates the publishers used to send joint-state commands.
"""
# create the ROS node 'joint_state_publisher' (anonymous)
rospy.init_node('joint_state_publisher', anonymous=True)
# subscribe to the left, right and front camera image topics
rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True)
# if depth images are enabled, subscribe to the corresponding depth topics
if self.args.use_depth_image:
rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True)
# subscribe to the puppet-arm joint states and the robot-base odometry
rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True)
# publishers for arm joint-state commands and base velocity commands
self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)
def get_arguments():
parser = argparse.ArgumentParser()
# Maximum number of action-publishing steps
parser.add_argument('--max_publish_step', action='store', type=int,
help='Maximum number of action publishing steps', default=10000, required=False)
# Random seed for reproducibility
parser.add_argument('--seed', action='store', type=int,
help='Random seed', default=None, required=False)
# ROS topic for the front camera image
parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
default='/camera_f/color/image_raw', required=False)
# ROS topic for the left camera image
parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
default='/camera_l/color/image_raw', required=False)
# ROS topic for the right camera image
parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
default='/camera_r/color/image_raw', required=False)
# ROS topic for the front camera depth image
parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic',
default='/camera_f/depth/image_raw', required=False)
# ROS topic for the left camera depth image
parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic',
default='/camera_l/depth/image_raw', required=False)
# ROS topic for the right camera depth image
parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic',
default='/camera_r/depth/image_raw', required=False)
# Command topic for the left puppet arm
parser.add_argument('--puppet_arm_left_cmd_topic', action='store', type=str, help='puppet_arm_left_cmd_topic',
default='/master/joint_left', required=False)
# Command topic for the right puppet arm
parser.add_argument('--puppet_arm_right_cmd_topic', action='store', type=str, help='puppet_arm_right_cmd_topic',
default='/master/joint_right', required=False)
# State topic for the left puppet arm
parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
default='/puppet/joint_left', required=False)
# State topic for the right puppet arm
parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
default='/puppet/joint_right', required=False)
# State topic for the robot base
parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
default='/odom_raw', required=False)
# Command topic for the robot base
parser.add_argument('--robot_base_cmd_topic', action='store', type=str, help='robot_base_topic',
default='/cmd_vel', required=False)
# Whether to use the robot base to move around
parser.add_argument('--use_robot_base', action='store_true',
help='Whether to use the robot base to move around',
default=False, required=False)
# Rate at which actions are published
parser.add_argument('--publish_rate', action='store', type=int,
help='The rate at which to publish the actions',
default=30, required=False)
# Control frequency of the robot
parser.add_argument('--ctrl_freq', action='store', type=int,
help='The control frequency of the robot',
default=25, required=False)
# Action chunk size
parser.add_argument('--chunk_size', action='store', type=int,
help='Action chunk size',
default=64, required=False)
# Maximum change allowed for each joint per step
parser.add_argument('--arm_steps_length', action='store', type=float,
help='The maximum change allowed for each joint per timestep',
default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2], required=False)
# Whether to interpolate actions when the difference is too large
parser.add_argument('--use_actions_interpolation', action='store_true',
help='Whether to interpolate the actions if the difference is too large',
default=False, required=False)
# Whether to use depth images
parser.add_argument('--use_depth_image', action='store_true',
help='Whether to use depth images',
default=False, required=False)
# Whether to disable the puppet arm (useful for safe debugging)
parser.add_argument('--disable_puppet_arm', action='store_true',
help='Whether to disable the puppet arm. This is useful for safely debugging',default=False)
# Path to the config file
parser.add_argument('--config_path', type=str, default="configs/base.yaml",
help='Path to the config file')
# The following argument is commented out (no longer used or not yet implemented)
# parser.add_argument('--cfg_scale', type=float, default=2.0,
# help='the scaling factor used to modify the magnitude of the control features during denoising')
# Name or path of the pretrained model
parser.add_argument('--pretrained_model_name_or_path', type=str, required=True, help='Name or path to the pretrained model')
# Path to the pre-encoded language instruction embeddings
parser.add_argument('--lang_embeddings_path', type=str, required=True,
help='Path to the pre-encoded language instruction embeddings')
args = parser.parse_args()
return args
def main():
args = get_arguments()
ros_operator = RosOperator(args)
if args.seed is not None:
set_seed(args.seed)
config = get_config(args)
model_inference(args, config, ros_operator)
if __name__ == '__main__':
main()