Transformer推理
1. 自回归生成机制
语言模型的生成过程是一个离散的时间序列预测问题。
1.1 核心逻辑
模型根据已有的 Token 序列 ,预测下一个 的概率分布。一旦采样得到 ,将其拼接到序列末尾,作为下一轮预测的输入。
1.2 上下文窗口限制
Transformer 模型具有固定的上下文窗口长度(context_length)。在推理过程中,生成的序列长度可能会超过模型的处理上限。
代码实现:
idx_cond = generated[:, -self.context_length:]
逻辑:始终只截取序列末尾的
context_length个 Token 输入模型。这是一个滑动窗口操作,确保位置编码(RoPE)的索引始终在 的合法范围内。
2. 采样策略 (Sampling Strategies)
直接选择概率最大的词(Greedy Search)往往会导致生成的文本重复且枯燥。为了生成多样化的文本,我们需要对 Logits(未归一化的预测分值)进行干预。
2.1 Temperature (温度缩放)
温度参数 用于调整概率分布的熵(Entropy)。
:缩小 Logits 之间的差异,概率分布趋向均匀(High Entropy)。生成的文本随机性增加。
:放大 Logits 之间的差异,高分值 Token 的概率显著增加(Low Entropy)。生成的文本更加确定和保守。
代码实现:
if temperature != 1.0:
logits = logits / (temperature + 1e-8)
2.2 Top-p (Nucleus) Sampling
Top-p 采样动态截断概率分布的尾部,仅在累积概率达到 p 的最小集合中进行采样。相比 Top-k 采样,Top-p 能更好地适应不同确定性的上下文。
算法步骤:
排序:将词表 Logits 按降序排列。
累积:计算 Softmax 后的累积概率分布(CDF)。
截断:找到累积概率超过阈值 p 的位置,将该位置之后的 Token 概率置为 0(Logit 置为 )。
重归一化:对剩余的 Logits 重新计算 Softmax。
3. 代码实现
class TransformerLM(nn.Module):
def __init__(self, vocab_size: int, max_seq_len: int, d_model: int,
num_layers: int, num_heads: int, d_ff: int, rope_theta: float,
device=None, dtype=None,
# 新增实验参数
use_rms_norm: bool = True,
norm_mode: str = "pre",
ffn_type: str = "swiglu"):
super().__init__()
self.max_seq_len = max_seq_len
self.context_length = max_seq_len
# 1. Token Embedding 层
self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
# 2. 堆叠 Transformer Blocks
# 将实验参数透传给每一个 Block
self.layers = nn.ModuleList([
TransformerBlock(
d_model, num_heads, d_ff, max_seq_len, rope_theta,
device=device,dtype=dtype,
# use_rms_norm=use_rms_norm,
# norm_mode=norm_mode,
# ffn_type=ffn_type
)
for _ in range(num_layers)
])
# 3. 最终的输出层
# 如果全局禁用了 Norm, 这里的 Final Norm 也要变成 Identity
if use_rms_norm:
self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
else:
"""
forward(input):
return input
"""
self.ln_final = nn.Identity()
# 最后是一个 Linear 层映射回词表大小 (LM Head)
self.lm_head = Linear(d_model, vocab_size, device=device, dtype=dtype)
def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
b, s = token_ids.shape
# 准备位置信息用于 RoPE, shape: [S] -> [1, S] -> [B, S]
token_positions = torch.arange(s, device=token_ids.device).unsqueeze(0).expand(b, s)
# 1. Embedding
x = self.token_embeddings(token_ids)
# 2. 逐层通过 Transformer Blocks
for layer in self.layers:
x = layer(x, token_positions=token_positions)
# 3. 最终归一化 (如果 use_rms_norm=False, 这里就是直通)
x = self.ln_final(x)
# 4. 投影到词表空间得到 logits
return self.lm_head(x)
@torch.no_grad()
def generate(
self,
prompt_ids: torch.Tensor,
max_new_tokens: int,
eos_token_id: int = None,
temperature: float = 1.0,
top_p: float = 1.0
) -> torch.Tensor:
"""
从模型生成文本 ID 序列。
参数:
prompt_ids: 提示词 ID (Batch, Seq_len)
max_new_tokens: 最多生成的词数
eos_token_id: 停止生成的 Token ID (如 <|endoftext|>)
temperature: 温度系数 (越高越随机, 越低越稳定)
top_p: 核采样阈值
"""
# 设置为评估模式
self.eval()
# 将输入拷贝一份, 避免修改原始数据
generated = prompt_ids.clone()
for _ in range(max_new_tokens):
# 1. 裁剪输入: 模型只能处理 context_length 长度的内容
# 如果生成的序列过长, 只取最后的 context_length 个词
idx_cond = generated[:,-self.context_length:]
# 2. 前向传播得到 Logits
# 我们只关心最后一个时间步的预测
logits = self.forward(idx_cond) # (Batch, T, Vocab)
logits = logits[:, -1, :] # (Batch, Vocab)
# 3. 应用温度 (Temperature)
if temperature != 1.0:
logits = logits / (temperature + 1e-8) # 加个 epsilon 防止除以 0
# 4. 应用 Top-P (Nucleus Sampling) 过滤
if top_p < 1.0:
logits = self._top_p_filter(logits, top_p)
# 5. 归一化并采样
probs = softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1) # (Batch, 1)
# 6. 拼接新词
generated = torch.cat((generated, next_token), dim=1)
# 7. 如果遇到了 EOS, 提前结束生成
if eos_token_id is not None and (next_token == eos_token_id).all():
break
return generated
def _top_p_filter(self, logits: torch.Tensor, p: float) -> torch.Tensor:
"""内部工具函数: 执行 Top-P 截断"""
# 对词表分值进行降序排序
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
# 计算累计概率分布
cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)
# 创建掩码: 我们要去掉累计概率超过 p 的 Token
# 逻辑: 保留最小的集合 V(p), 使其概率之和 >= p
# 我们把所有超过 p 的位置标记为 True (需要移除)
sorted_indices_to_remove = cumulative_probs > p
# 关键修正: 确保至少保留第一个词 (最高概率词),
# 并且我们要保留第一个"使概率超过 p" 的那个词。
# 做法是把标记位向右移动一格。
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
# 将被移除的 Token 分数设为负无穷
# 这里需要利用 scatter 将排序后的掩码映射回原始词表索引位置
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, float('-inf'))
return logits
本文是原创文章,采用 CC BY-NC-ND 4.0 协议,完整转载请注明来自 程序员Orion
评论
匿名评论
隐私政策
你无需删除空行,直接评论以获取最佳展示效果