1. 自回归生成机制

语言模型的生成过程是一个离散的时间序列预测问题。

1.1 核心逻辑

模型根据已有的 Token 序列 x1:tx_{1:t},预测下一个 Tokenxt+1Token x_{t+1} 的概率分布。一旦采样得到 xt+1x_{t+1},将其拼接到序列末尾,作为下一轮预测的输入。
xt+1P(xx1:t)x_{t+1} \sim P(x|x_{1:t})

1.2 上下文窗口限制

Transformer 模型具有固定的上下文窗口长度(context_length)。在推理过程中,生成的序列长度可能会超过模型的处理上限。

  • 代码实现

idx_cond = generated[:, -self.context_length:]
  • 逻辑:始终只截取序列末尾的 context_length 个 Token 输入模型。这是一个滑动窗口操作,确保位置编码(RoPE)的索引始终在 [0,contextlength1][0, \text{contextlength} - 1] 的合法范围内。

2. 采样策略 (Sampling Strategies)

直接选择概率最大的词(Greedy Search)往往会导致生成的文本重复且枯燥。为了生成多样化的文本,我们需要对 Logits(未归一化的预测分值)进行干预。

2.1 Temperature (温度缩放)

温度参数 τ\tau 用于调整概率分布的熵(Entropy)。
Pi=exp(zi/τ)jexp(zj/τ)P_i = \frac{\exp(z_i / \tau)}{\sum_j \exp(z_j / \tau)}

  • τ>1\tau > 1:缩小 Logits 之间的差异,概率分布趋向均匀(High Entropy)。生成的文本随机性增加。

  • τ<1\tau < 1:放大 Logits 之间的差异,高分值 Token 的概率显著增加(Low Entropy)。生成的文本更加确定和保守。

  • 代码实现

if temperature != 1.0:
    logits = logits / (temperature + 1e-8)

2.2 Top-p (Nucleus) Sampling

Top-p 采样动态截断概率分布的尾部,仅在累积概率达到 p 的最小集合中进行采样。相比 Top-k 采样,Top-p 能更好地适应不同确定性的上下文。
算法步骤:

  1. 排序:将词表 Logits 按降序排列。

  2. 累积:计算 Softmax 后的累积概率分布(CDF)。

  3. 截断:找到累积概率超过阈值 p 的位置,将该位置之后的 Token 概率置为 0(Logit 置为 -\infty)。

  4. 重归一化:对剩余的 Logits 重新计算 Softmax。

3. 代码实现

class TransformerLM(nn.Module):
    def __init__(self, vocab_size: int, max_seq_len: int, d_model: int,
                 num_layers: int, num_heads: int, d_ff: int, rope_theta: float,
                 device=None, dtype=None,
                 # 新增实验参数
                 use_rms_norm: bool = True,
                 norm_mode: str = "pre",
                 ffn_type: str = "swiglu"):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.context_length = max_seq_len
        # 1. Token Embedding 层
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)

        # 2. 堆叠 Transformer Blocks
        # 将实验参数透传给每一个 Block
        self.layers = nn.ModuleList([
            TransformerBlock(
                d_model, num_heads, d_ff, max_seq_len, rope_theta,
                device=device,dtype=dtype,
                # use_rms_norm=use_rms_norm,
                # norm_mode=norm_mode,
                # ffn_type=ffn_type
            )
            for _ in range(num_layers)
        ])

        # 3. 最终的输出层
        # 如果全局禁用了 Norm, 这里的 Final Norm 也要变成 Identity
        if use_rms_norm:
            self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        else:
            """
            forward(input):
                return input
            """
            self.ln_final = nn.Identity()
        
        # 最后是一个 Linear 层映射回词表大小 (LM Head)
        self.lm_head = Linear(d_model, vocab_size, device=device, dtype=dtype)
    
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:

        b, s = token_ids.shape

        # 准备位置信息用于 RoPE, shape: [S] -> [1, S] -> [B, S]
        token_positions = torch.arange(s, device=token_ids.device).unsqueeze(0).expand(b, s)

        # 1. Embedding
        x = self.token_embeddings(token_ids)

        # 2. 逐层通过 Transformer Blocks
        for layer in self.layers:
            x = layer(x, token_positions=token_positions)

        # 3. 最终归一化 (如果 use_rms_norm=False, 这里就是直通)
        x = self.ln_final(x)

        # 4. 投影到词表空间得到 logits
        return self.lm_head(x)
    
    @torch.no_grad()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int,
        eos_token_id: int = None,
        temperature: float = 1.0,
        top_p: float = 1.0
    ) -> torch.Tensor:
        """
        从模型生成文本 ID 序列。

        参数: 
            prompt_ids: 提示词 ID (Batch, Seq_len)
            max_new_tokens: 最多生成的词数
            eos_token_id: 停止生成的 Token ID (如 <|endoftext|>)
            temperature: 温度系数 (越高越随机, 越低越稳定)
            top_p: 核采样阈值
        """
        # 设置为评估模式
        self.eval()

        # 将输入拷贝一份, 避免修改原始数据
        generated = prompt_ids.clone()

        for _ in range(max_new_tokens):
            # 1. 裁剪输入: 模型只能处理 context_length 长度的内容
            # 如果生成的序列过长, 只取最后的 context_length 个词
            idx_cond = generated[:,-self.context_length:]

            # 2. 前向传播得到 Logits
            # 我们只关心最后一个时间步的预测
            logits = self.forward(idx_cond) # (Batch, T, Vocab)
            logits = logits[:, -1, :] # (Batch, Vocab)

            # 3. 应用温度 (Temperature)
            if temperature != 1.0:
                logits = logits / (temperature + 1e-8) # 加个 epsilon 防止除以 0
            
            # 4. 应用 Top-P (Nucleus Sampling) 过滤
            if top_p < 1.0:
                logits = self._top_p_filter(logits, top_p)
            
            # 5. 归一化并采样
            probs = softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1) # (Batch, 1)
            
            # 6. 拼接新词
            generated = torch.cat((generated, next_token), dim=1)

            # 7. 如果遇到了 EOS, 提前结束生成
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
        
        return generated 


    def _top_p_filter(self, logits: torch.Tensor, p: float) -> torch.Tensor:
        """内部工具函数: 执行 Top-P 截断"""
        # 对词表分值进行降序排序
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)

        # 计算累计概率分布
        cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)

        # 创建掩码: 我们要去掉累计概率超过 p 的 Token
        # 逻辑: 保留最小的集合 V(p), 使其概率之和 >= p
        # 我们把所有超过 p 的位置标记为 True (需要移除)
        sorted_indices_to_remove = cumulative_probs > p

        # 关键修正: 确保至少保留第一个词 (最高概率词),
        # 并且我们要保留第一个"使概率超过 p" 的那个词。
        # 做法是把标记位向右移动一格。
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # 将被移除的 Token 分数设为负无穷
        # 这里需要利用 scatter 将排序后的掩码映射回原始词表索引位置
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, float('-inf'))

        return logits