先让整套训练机制完整运转,模型的能力是后续几篇的事

本篇的目标:让训练循环首次运转

第一篇说明过,训练是一个循环:输入真实文字、判断模型预测的误差及其方向、据此微调参数,如此重复成千上万遍。前两篇我们完成了数据的准备工作。本篇的任务,是把这个循环首次完整地搭建起来,并让它运转。

为了把注意力集中在"循环"本身,本篇所用的模型会刻意做得非常简单——它只依据前一个字来预测下一个字。这种模型称为 Bigram(二元)模型。它的能力相当有限,学不到深层的语言规律,但这并不影响本篇的目标。本篇要确认的,是"输入数据、计算损失、调整参数"这一整套机制能够正确地运转起来。模型的能力,是从第五篇起逐步解决的问题。

每次输入多少:批次与上下文长度

训练时,我们不会把整份语料一次性输入模型,那样既无法容纳,也没有必要。每一步只输入一小段。这就引出两个数值。

其一是上下文长度context lengthcontext lengthThe context length is the longest stretch of text a model can look at when predicting the next token. Anything beyond it is simply invisible to the model. A bigram effectively has a context length of one; this mini GPT starts small here and grows it later. Real models push this to tens of thousands of tokens so they can follow long documents and conversations. block_size:模型一次最多查看多长的文字。我们先设为 8,表示模型每次依据至多 8 个字来做预测。

其二是批次大小batch sizebatch sizeThe batch size is how many independent text segments are fed through the model together in a single training step. Processing them in parallel is far more efficient than one at a time, and averaging the error over a whole batch gives a steadier signal for adjusting the parameters. It influences training speed and stability, not what the model ultimately learns. batch_size:出于效率考虑,每一步同时输入若干小段,让模型并行处理。我们设为 32,表示一次输入 32 段。

关键之处在于,每一小段都要配有一个对应的"正确答案"。我们从语料中随机截取一段长度为 block_size 的文字作为输入 x,再把这段文字整体向后移动一个字作为答案 y。这样一来,x 中的每一个位置,它所对应的"下一个字"恰好就是 y 中同一位置上的字。

import torch

torch.manual_seed(1337)        # 固定随机种子,使结果可复现

block_size = 8
batch_size = 32


def get_batch(split: str):
    """随机截取 batch_size 个小段,返回输入 x 与答案 y"""
    data = train_data if split == "train" else val_data
    # 随机选取 batch_size 个起始位置
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])   # 整体后移一位
    return x, y


xb, yb = get_batch("train")
print(f"输入 x 的形状:{xb.shape}")      # (32, 8)
print(f"答案 y 的形状:{yb.shape}")      # (32, 8)
print(f"第一小段 x:{xb[0].tolist()}")
print(f"第一小段 y:{yb[0].tolist()}")
get_batch 与 torch.stack 详解
### 变量说明 - `data` 一维长张量,整份语料的 token 序列,`shape=(1_000_000,)` - `block_size` 单条样本的上下文长度,如 `8` 或 `256` - `batch_size` 一个 batch 里几条样本,如 `4` - `ix` 长度为 `batch_size` 的随机起点张量,由 `torch.randint` 生成;减 `block_size` 确保切片不越界 ### 列表推导与 stack ```python x = torch.stack([data[i:i + block_size] for i in ix]) ``` 列表推导对每个起点切出长度为 `block_size` 的片段(`shape=(8,)`),`torch.stack` 沿新维度拼成 `shape=(batch_size, block_size)` 的二维张量。 ### 目标张量 y ```python y = torch.stack([data[i + 1:i + block_size + 1] for i in ix]) ``` 起点右移一位。`x` 的每个位置对应 `y` 的同位置,即下一个 token: ```text x: [18, 47, 29, 56, 9, 31, 75, 4] y: [47, 29, 56, 9, 31, 75, 4, 62] ``` 一条长度为 `block_size` 的样本包含 `block_size` 个独立预测任务,训练效率因此被拉满。 ### torch.stack vs torch.cat ```python torch.stack([a, b, c]).shape # (3, block_size) torch.cat([a, b, c]).shape # (3 * block_size,) ``` 这里需要 `(batch_size, block_size)` 的二维张量,所以用 `stack`。

结构最简单的模型:Bigram

接下来搭建模型。在 PyTorch 中,模型是一个继承自 nn.Module 的类。Bigram 模型只需要一样东西:一张表。

import torch.nn as nn
from torch.nn import functional as F


class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()
        # 一张 vocab_size × vocab_size 的表,它就是这个模型的全部参数
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx 形状为 (B, T),每个元素是一个字的编号
        logits = self.token_embedding_table(idx)   # 查表,得到 (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

nn.Embedding(vocab_size, vocab_size) 是一张 vocab_size 行、vocab_size 列的表。每个字的编号对应表中的一行,这一行包含 vocab_size 个数值,代表"在看到这个字之后,下一个字是词表中各个字的可能性评分"。

这个模型的局限正在于此:它做预测时只查阅了当前这一个字所对应的那一行,完全没有参考更早的内容。对于"今天天气真"这五个字,它只取最后一个"真"去查表,前面四个字全部被忽略。这个模型只具备一个字的记忆。

代码中的 B、T、C 是三个维度的简称:B 是批次大小(batch,此处为 32),T 是序列长度,也称时间步(time,此处为 8),C 是每个位置上评分的个数(channel,等于 vocab_size)。这套 (B, T, C) 记法在后续每一篇都会用到,此处先建立印象。

模型如何打分:logits、softmax 与损失

forward 方法中出现了三个新的概念,需要解释清楚,它们是训练机制的核心。

第一个是 logitslogitslogits 是模型在最后一步直接输出的一组原始数值,词表里每个字各对应一个。它们没有经过任何规范化,可正可负、大小不限,本身并不是概率,只是相对的偏好评分——某个字的 logit 越大,模型越倾向于选它。要把这组评分变成真正相加为 1 的概率,还需要紧接着的 softmax。 。模型对"下一个字"的预测,不是直接给出一个字,而是为词表中的每个字给出一个评分。这一组原始评分就称为 logits。评分有高有低,可正可负。

第二个是 softmax。logits 是原始评分,并不直观。softmax 是一个函数,它把这一组评分换算成一组概率——每个值都介于 0 与 1 之间,且全部相加恰好等于 1。经过这一步,我们就能够说"下一个字是'好'的概率是 0.42"这样的话。

第三个是损失(loss)。我们需要一个数值来衡量"模型这一次预测得有多差"。这个数值就是损失。此处采用的损失函数称为交叉熵cross entropycross entropyCross entropy is the loss function used for next-token prediction. It looks only at the probability the model assigned to the correct token: the higher that probability, the lower the loss. Concretely it is the negative logarithm of that probability, so being confident yet wrong is punished very heavily. Driving cross entropy down is exactly the same as making the model assign high probability to the right answers. ,其逻辑相当朴素:考察模型为"正确答案那个字"所给出的概率有多高。概率给得越高,损失越低;概率给得越低,损失越高。如果模型把正确答案的概率压到接近 0,损失就会变得很大。

整个训练过程,目标就是设法让这个损失数值不断变小。损失小,意味着模型为正确答案给出的概率高,也就意味着它预测得准确。

这里有一个可以用来核对的细节。训练刚开始时,参数处于随机状态,模型对词表中每个字几乎一视同仁地随机预测,为正确答案给出的概率大约是 1/vocab_size。此时交叉熵损失约等于 ln(vocab_size)。假设你的 vocab_size 是 3000,那么训练初期的损失应当在 8 附近(因为 ln(3000) ≈ 8)。稍后运行程序时,如果初始损失与这个估算值吻合,就说明整套流程正常。

让模型生成文字

模型还需要具备生成文字的能力。为 Bigram 模型补充一个 generate 方法——它做的就是第一篇所说的连续预测:预测下一个字,接续到末尾,再以新的文字为输入继续预测,如此循环。

    def generate(self, idx, max_new_tokens: int):
        # idx 是当前已有的文字(编号形式),形状为 (B, T)
        for _ in range(max_new_tokens):
            logits, _ = self(idx)              # 前向计算,得到评分
            logits = logits[:, -1, :]          # 只取最后一个位置的评分
            probs = F.softmax(logits, dim=-1)  # 换算成概率
            idx_next = torch.multinomial(probs, num_samples=1)  # 依概率抽取一个字
            idx = torch.cat((idx, idx_next), dim=1)             # 接续到末尾
        return idx

需要留意 torch.multinomial 这一步:它并非固定地选取概率最高的字,而是依照概率分布随机抽取。概率为 0.42 的字,有 42% 的机会被抽中。这一处随机性使得模型每次生成的内容不会完全相同,也使生成结果更自然。

训练循环:核心的五行代码

准备工作就绪。训练循环本身,其核心其实只有五行:

model = BigramLanguageModel(vocab_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for step in range(5000):
    xb, yb = get_batch("train")              # 1. 取一批数据
    logits, loss = model(xb, yb)             # 2. 前向:计算预测与损失
    optimizer.zero_grad(set_to_none=True)    # 3. 清空上一轮的梯度
    loss.backward()                          # 4. 反向:计算每个参数的调整方向
    optimizer.step()                         # 5. 实际调整参数

    if step % 500 == 0:
        print(f"step {step}: loss {loss.item():.4f}")

逐行说明。第 1 行取出一批数据。第 2 行把数据输入模型,得到评分与损失,这一步称为前向传播。第 3 行先把上一轮残留的梯度清零,至于为什么需要清零,下一篇会解释。第 4 行 loss.backward() 是关键,它会计算出"每一个参数应当朝哪个方向、调整多少",这个过程称为反向传播。第 5 行 optimizer.step() 依照计算出的方向,把所有参数实际调整一次。

optimizer 是优化器,负责"调整"这个动作。lr=1e-3 是学习率,控制每次调整的幅度。loss.backward() 与优化器内部究竟如何工作,正是下一篇——反向传播——要完整讲解的内容。在本篇中,可以先把它视为"依据误差方向调整参数"的一个待解释环节,先让循环运转起来。

运行并观察结果

把前两篇的数据准备代码,与本篇的全部代码合并到一个文件中运行。可以看到损失被逐步打印出来,从初始的 8 左右,一步步下降,最终降到 3 到 4 之间趋于稳定。

损失持续下降,就证明这套机制确实运转起来了——模型正被一步步调整得预测更加准确。

训练结束后,让模型生成一段文字:

context = torch.zeros((1, 1), dtype=torch.long)   # 以编号 0 对应的字作为起始
generated = model.generate(context, max_new_tokens=300)[0].tolist()
print(decode(generated))

生成结果不必抱有过高期待。Bigram 模型生成的文字,整体上仍然难以读通。但若仔细观察,它与训练之前的随机字符已有明显不同:常见字出现的频率提高了,标点的位置不再完全失当,相邻的两个字偶尔还算搭配。这正是这个模型的能力上限:它只有一个字的记忆,所能学到的,也就是"某个字之后通常跟随哪个字"这类最表层的规律。

但这并不影响本篇的结论。重要的是:输入数据、计算损失、反向传播、调整参数——这一整套循环,已经完整且正确地运转起来了。后续几篇所要做的,全部是把模型从"只依据一个字"升级为"能够参考整段上下文",而这套训练循环的骨架,几乎不需要再做改动。

本篇要点

  • 训练时数据分小段输入,block_size 是一次查看的长度,batch_size 是一次输入的段数。
  • 输入 x 与答案 y 的关系是"整体错开一位",一小段中包含了多道"预测下一个字"的题目。
  • Bigram 模型只用一张 vocab_size × vocab_size 的表,只依据前一个字预测,能力有限。
  • logits 是模型对每个字的原始评分,softmax 把它换算成概率,交叉熵损失衡量预测的偏差。
  • 训练循环的核心是五行:取数据、前向计算损失、清空梯度、反向传播、调整参数。
  • Bigram 生成的文字仍难以读通,但训练循环已完整运转,这正是本篇的目标。

下一篇

本篇中,loss.backward()optimizer.step() 是作为待解释的环节使用的。而它们恰恰是"模型如何知道参数该朝哪个方向调整"这一问题的答案所在。下一篇将完整拆解这个环节,讲解反向传播——整个训练过程的核心。

参考资料

版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。

(采用 CC BY-NC-SA 4.0 许可协议进行授权)

本文标题:最简单的模型:先让训练循环运转起来

本文链接:https://www.sshipanoo.com/blog/ai/mini-gpt/03-最笨的模型先让循环转起来/

本文最后一次更新为 天前,文章中的某些内容可能已过时!