大模型的工作原理:分布式训练入门

我有个同事的签名是"大模型是真的大",大模型(如GPT、LLaMA)之所以被称为“大”,不仅因为它们的参数量高达数十亿甚至上万亿,更因为它们需要强大的计算资源来完成训练。这些资源远超单机能承载的范围,因此,大模型的训练离不开分布式训练技术。

分布式训练概述

分布式训练通过将训练任务分摊到多个计算设备(如GPU、NPU或CPU)上,以加速训练过程。其主要目标包括:

  1. 提高计算效率:减少训练所需时间。
  2. 扩展模型规模:支持更大的模型和更复杂的数据集。(单个设备的显存不足以支撑大模型)
  3. 高效利用资源:通过并行计算,充分利用硬件能力。

在实际操作中,分布式训练需要解决以下关键问题:

  • 如何将数据分配到多个设备上?
  • 如何在多个设备之间共享和同步模型参数?
  • 如何保证训练的准确性和效率?

分布式训练的关键概念

数据并行(Data Parallelism)

数据并行是最常见的分布式训练方法。其核心思想是将数据切分成多个小批次(mini-batches),并将这些小批次分发到不同的设备上进行计算。

流程

  1. 每个设备(如GPU)获取不同的小批次数据。
  2. 每个设备独立计算其对应数据的小批次的梯度。
  3. 汇总所有设备的梯度,并更新全局模型参数。
  4. 广播更新后的参数给所有设备。

优点:实现简单,对数据量较大的场景效果显著。
缺点:对模型规模的扩展性有限,设备间通信开销较大。

以下是一个示意图:

image.png

模型并行(Model Parallelism)

当模型参数过大,单个设备无法容纳时,可以采用模型并行。模型并行将模型拆分成多个部分,并分配到不同设备上。

流程

  1. 将模型的不同层(或块)分配到不同的设备。
  2. 每个设备只计算属于自己的那部分模型。
  3. 设备间通过通信共享中间结果。

优点:适用于超大模型。
缺点:实现复杂,设备间的同步通信成本较高。

以下是模型并行的示意图:

image.png

混合并行(Hybrid Parallelism)

混合并行结合了数据并行和模型并行的优点。在这种方法中,既对数据进行切分,又将模型分割到多个设备上。

优点:充分利用硬件资源,适用于超大规模的训练任务。
缺点:实现和调试更复杂。

以下是混合并行的示意图:

image.png

Pipeline 并行(Pipeline Parallelism)

Pipeline 并行是一种特殊的模型并行,它将模型的不同层分配到不同设备上,但同时允许多个小批次的数据在流水线中流动。

优点:提高了设备利用率,减少了空闲时间。
缺点:需要处理流水线中的梯度同步问题。

以下是Pipeline并行的示意图:

image.png

开源框架支持的分布式训练方法

目前主流的深度学习框架都支持分布式训练,比如 PyTorch、TensorFlow、DeepSpeed 和 Hugging Face Transformers。以下是一些常用的工具和方法。

1. PyTorch 的分布式训练

PyTorch 提供了torch.distributed模块,用于实现分布式训练。

以下是一个简单的数据并行示例:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# 初始化分布式环境
dist.init_process_group("nccl", rank=0, world_size=1)

# 模型和数据
model = torch.nn.Linear(10, 1).to("cuda:0")
ddp_model = DDP(model, device_ids=[0])
data = torch.randn(20, 10).to("cuda:0")
target = torch.randn(20, 1).to("cuda:0")

# 损失函数和优化器
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

# 训练
outputs = ddp_model(data)
loss = criterion(outputs, target)
loss.backward()
optimizer.step()

2. DeepSpeed 的零冗余优化

DeepSpeed 是一种高效的分布式训练框架,特别适用于超大规模模型。其核心特性包括 ZeRO(Zero Redundancy Optimizer)。

ZeRO 的关键在于分布式地存储优化器状态、梯度和参数,从而显著降低每个设备的内存需求。

使用 DeepSpeed 训练模型的示例:

import deepspeed

# 定义模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = MyModel()

# 配置 DeepSpeed
ds_config = {
    "train_batch_size": 8,
    "fp16": {
        "enabled": True
    },
    "zero_optimization": {
        "stage": 2
    }
}

# 初始化
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config_params=ds_config
)

# 训练
data = torch.randn(8, 10).to(model_engine.local_rank)
target = torch.randn(8, 1).to(model_engine.local_rank)
loss = torch.nn.MSELoss()(model_engine(data), target)
model_engine.backward(loss)
model_engine.step()

3. Hugging Face Transformers 的 Trainer

Hugging Face Transformers 提供了一个开箱即用的Trainer类,支持分布式训练。以下是一个训练 GPT 模型的示例:

from transformers import Trainer, TrainingArguments, GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10,
    save_total_limit=2,
    fp16=True,
    deepspeed="./ds_config.json",  # 支持 DeepSpeed
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=my_dataset,
)

trainer.train()

Ray:分布式训练中的“全能工具”

在深度学习的大规模分布式训练中,Ray 是一个不可忽视的工具。它不仅是一个分布式计算框架,还通过高层封装,提供了许多强大的工具和库,如 Ray Train、Ray Tune 和 Ray Serve,帮助开发者快速构建和管理分布式应用。所以单独写一章。

什么是 Ray?

Ray 是一个通用的分布式计算框架,核心目标是让开发者可以轻松实现分布式程序。它支持各种场景,从机器学习训练、超参数调优到大规模数据处理和在线推理。

Ray 的核心特点:

  1. 简单易用:开发者只需用 Python 编写代码,Ray 会自动帮你处理分布式调度。
  2. 扩展性强:可以在单机上调试,部署到多节点集群时只需简单调整。
  3. 高效的资源管理:支持动态资源分配和任务调度。
  4. 组件丰富:Ray 包含多个高层库,如 Ray Train、Ray Tune 和 Ray Serve,分别对应训练、超参数调优和在线推理。

Ray 的核心概念

在使用 Ray 时,需要理解以下几个核心概念:

  1. Task(任务)
    一个 Ray Task 是一个可以异步运行的函数。它会自动分配到集群中的空闲计算资源。

    示例代码:

    import ray
    
    ray.init()  # 初始化 Ray
    
    @ray.remote
    def slow_function(x):
        import time
        time.sleep(1)
        return x ** 2
    
    futures = [slow_function.remote(i) for i in range(10)]
    results = ray.get(futures)
    print(results)
    

    解释:上述代码中,slow_function 被声明为远程任务(@ray.remote),会并行执行在集群中的不同节点上。

  2. Actor(角色)
    Actor 是 Ray 中的有状态任务,可以用来保存中间状态。例如,深度学习的模型实例可以作为一个 Actor 存在。

    示例代码:

    @ray.remote
    class Counter:
        def __init__(self):
            self.count = 0
    
        def increment(self):
            self.count += 1
            return self.count
    
    counter = Counter.remote()
    print(ray.get(counter.increment.remote()))  # 输出 1
    print(ray.get(counter.increment.remote()))  # 输出 2
    
  3. Cluster(集群)
    Ray 的集群管理非常灵活,你可以在本地运行单节点,也可以扩展到上千节点的分布式集群。

Ray Train:用于分布式训练

Ray Train 是 Ray 为分布式训练任务设计的高层库。它支持各种深度学习框架(如 PyTorch 和 TensorFlow),并通过高效的资源管理和分布式调度简化训练过程。

核心功能

  • 自动分布式支持:数据并行训练。
  • 易于集成:与现有的 PyTorch 或 TensorFlow 代码无缝对接。
  • 灵活扩展:支持 CPU/GPU 混合环境。

以下是一个使用 Ray Train 进行分布式训练的 PyTorch 示例:

import ray
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
import torch
import torch.nn as nn
import torch.optim as optim

# 定义简单的模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 定义训练函数
def train_loop_per_worker(config):
    model = MyModel()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"])
    loss_fn = nn.MSELoss()

    # 模拟数据
    data = torch.randn(100, 10)
    target = torch.randn(100, 1)

    for _ in range(5):  # 模拟训练
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        print(f"Loss: {loss.item()}")

# 使用 Ray Train 进行分布式训练
ray.init()
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=4),  # 使用 4 个 worker
    train_loop_config={"lr": 0.01},  # 传递训练超参数
)
trainer.fit()

Ray Tune:超参数调优

大模型训练中,找到最佳超参数(如学习率、batch size)非常重要。Ray 提供了 Ray Tune,这是一个分布式超参数调优框架,支持多种搜索算法和调度策略。

示例代码:

from ray import tune
from ray.tune.schedulers import ASHAScheduler

def train(config):
    import torch
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    data = torch.randn(100, 10)
    target = torch.randn(100, 1)
    loss_fn = torch.nn.MSELoss()
    for epoch in range(10):
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        optimizer.step()
        tune.report(loss=loss.item())  # 上报结果给 Ray Tune

search_space = {"lr": tune.grid_search([0.01, 0.1, 1.0])}
scheduler = ASHAScheduler()
tune.run(
    train,
    config=search_space,
    scheduler=scheduler,
    num_samples=3
)

Ray Serve:分布式推理

训练完成后,大模型的推理服务同样需要分布式支持。Ray 的 Serve 模块提供了高效的分布式推理能力。

以下是一个简单的 Ray Serve 示例:

from ray import serve
import ray

ray.init()
serve.start()

@serve.deployment
def predict(request):
    return {"message": "Hello from Ray Serve!"}

predict.deploy()

import requests
response = requests.get("http://127.0.0.1:8000/predict")
print(response.json())

为什么选择 Ray?

Ray 是分布式训练和部署的一站式解决方案:

  • 如果你想高效地训练大模型,Ray Train 提供了数据并行和资源调度能力。
  • 如果你需要优化超参数,Ray Tune 可以让你轻松实现大规模调优。
  • 如果你需要部署分布式推理服务,Ray Serve 是理想选择。

相比其他工具(如 PyTorch DDP、DeepSpeed),Ray 的优势在于更广的应用场景和更高的灵活性。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,463评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,868评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,213评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,666评论 1 290
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,759评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,725评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,716评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,484评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,928评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,233评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,393评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,073评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,718评论 3 324
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,308评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,538评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,338评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,260评论 2 352