RDT (Robotics Diffusion Transformer)
RDT (Robotics Diffusion Transformer) 是一种结合扩散模型和Transformer的大模型方法,支持多任务统一建模,具有强大的泛化能力。
📊 数据格式转换
转换命令
将仿真采集的原始数据格式转换为RDT算法中用到的HDF5格式:
python3 policies/act/data_process/raw_to_hdf5.py -md mujoco -dir data -tn ${task_name} -vn ${video_names}
转换示例
# 示例转换
python3 policies/act/data_process/raw_to_hdf5.py -md mujoco -dir data -tn block_place -vn cam_0 cam_1
数据移动
将hdf5文件移动到RDT需要的地址:
mv data/hdf5/${task_name} policies/RDT/training_data
示例
mv data/hdf5/block_place policies/RDT/training_data
数据目录结构
将多个任务的数据都放在training_data,RDT会在一个模型中训练多个任务,目录结构为:
training_data
├── instructions
│ ├── ${task_1}.json
│ ├── ${task_2}.json
│ ├── ...
├── ${task_1}
│ ├── instructions
│ │ ├── lang_embed_0.pt
│ │ ├── ...
│ ├── episode_0.hdf5
│ ├── episode_1.hdf5
│ ├── ...
├── ${task_2}
│ ├── instructions
│ │ ├── lang_embed_0.pt
│ │ ├── ...
│ ├── episode_0.hdf5
│ ├── episode_1.hdf5
│ ├── ...
├── ...
🎓 模型训练
GPU要求
训练至少需要25G内存(batch size = 4),推理需要0.5G内存
环境安装
conda create -n rdt python=3.10.0
conda activate rdt
cd DISCOVERSE
pip install -r requirements.txt
pip install -e .
pip install torch==2.1.0 torchvision==0.16.0 packaging==24.0 ninja
pip install flash-attn==2.7.2.post1 --no-build-isolation
如果安装flash-attn失败,可以从官方下载对应的.whl安装: https://github.com/Dao-AILab/flash-attention/releases
# 安装flash_attn-*.whl:
pip install flash_attn-*.whl
cd DISCOVERSE/policies/RDT
pip install -r requirements.txt
pip install huggingface_hub==0.25.2
下载模型
cd DISCOVERSE/policies/RDT
mkdir -p weights/RDT && cd weights/RDT
huggingface-cli download google/t5-v1_1-xxl --local-dir t5-v1_1-xxl
huggingface-cli download google/siglip-so400m-patch14-384 --local-dir siglip-so400m-patch14-384
huggingface-cli download robotics-diffusion-transformer/rdt-1b --local-dir rdt-1b
生成language embedding
cd DISCOVERSE
python3 policies/RDT/scripts/encode_lang_batch_once.py ${task_name} ${gpu_id}
示例
python3 policies/RDT/scripts/encode_lang_batch_once.py block_place 0
配置文件
复制policies/RDT/model_config/model_name.yml,并重命名model_name
训练微调
cd DISCOVERSE/policies/RDT
python3 scripts/encode_lang_batch_once.py {task_name} {gpu_id}
示例
python3 scripts/encode_lang_batch_once.py block_place 0
🚀 策略推理
推理命令
cd DISCOVERSE/policies/RDT
bash eval.sh {robot} {task_name} {model_name} {ckpt_id}
推理示例
bash eval.sh airbot block_place model_name 20000