01.BPE原理

1.1 一句话说明

BPE 就是不断地把出现频率最高的相邻字节对h合并成一个新 token,直到词表达到目标大小。

1.2 为什么要做BPE

BPE缩短了序列长度:

  • 单个 token 的计算开销与它包含的字节数无关(1 字节 token 和 10 字节 token,进入模型后都是相同维度的向量,计算量相同)。

  • BPE 加速模型的真正原因:用多字节 token 代替多个单字节 token → 序列长度变短 → 自注意力的复杂度O(L^2)降低 → 整体计算量下降。

  • 例如: 原本用字节级编码(每个字节一个 token),一个 10 字符的单词就需要 10 个 token。BPE 将常见多字节组合(如 "Hello")合并成一个 token,从而大幅缩短序列长度。

1.3 预分词

  • 什么是预分词?:在 BPE 统计相邻对频率 之前,先把文本切成一个个更粗粒度的块(比如单词、标点、数字)。

  • 为什么做预分词?

    • 防止跨类型合并:例如 dog 末尾的 g 与后面的 . 不应合并,否则会破坏语义边界。

    • 保护空格:通常将单词前的空格与单词本身绑定为一个整体,使得空格可以像普通 token 一样参与合并,同时保留显式的0词边界信息。

    • 提升合理性:语言模型通常以自然词为单位,预分词让 BPE 的合并行为更符合人类语言的边界。

1.4 特殊字符(Token)

  • 什么是特殊Token:特殊 token 是人为添加到词表中的、不参与 BPE 合并的原子标记,例如 <|endoftext|><|sep|>[PAD][CLS] 等。

  • 作用

    • 分隔不同的文本片段(如文档结束符)

    • 传递控制信号(如分类任务的 [CLS]

    • 填充序列至相同长度([PAD]

  • 训练阶段的处理

    • 切分语料:使用带捕获组的正则 re.split(f"({special_regex})", text) 将文本按特殊 token 切开。

    • 隔离统计:只保留普通文本片段(p not in special_tokens)送入 BPE 统计频率。特殊 token 本身不参与任何相邻对统计。

    • 最终加入:训练结束后,将特殊 token 的 UTF‑8 字节表示追加到词表中,分配新的 ID。

  • 编码阶段的处理

    • 优先匹配:构造正则时按长度从长到短排序,确保 "<|ab|>" 不会被拆成 "<|a|>" + "b|>"

    • 切分与路由:遍历文本,匹配到的特殊 token 直接转为 ID(通过 byte_to_id);特殊 token 之间的普通文本片段调用 _encode_text_segment 进行正常 BPE 编码。

    • 不参与合并:特殊 token 在编码时始终保持原子性,不会被 BPE 规则拆分或与其他 token 合并。

1.5 BPE训练步骤/举例:

输入:原始文本语料
输出:词表(vocab: dict[int, bytes]) + 合并顺序(merges: list[tuple[bytes, bytes]]

  1. 初始化

    • 初始词表(vocab)为0-255,每个字节是一个独立token。

  2. 预分词(Pre‑tokenization)

    • 输入文本"Hello world!"

    • GPT‑2 正则匹配结果:["Hello", " world", "!"]

  3. 统计原始单词频率

    • 对预分词后的每个块,将其编码为 UTF‑8 字节序列,再拆成单字节元组。

    • "Hello" → (b'H', b'e', b'l', b'l', b'o')

    • " world" → (b' ', b'w', b'o', b'r', b'l', b'd')

    • "!" → (b'!',)

    • 统计整个语料中每个元组出现的次数(即“单词”频率)。

    • 假设语料中 "Hello" 出现 3 次," world" 出现 5 次,"!" 出现 2 次。

  4. 初始相邻对统计

    • 遍历每个单词的字节列表,枚举所有相邻对,累加频率(单词频率乘该对在单词内出现次数)。

    • 单词 (b'H', b'e', b'l', b'l', b'o') 频率 3

    • 产生对:(b'H',b'e')(b'e',b'l')(b'l',b'l')(b'l',b'o'),每个对增加 3 次。

    • 单词 (b' ', b'w', b'o', b'r', b'l', b'd') 频率 5

    • 产生 (b' ',b'w')(b'w',b'o') 等。

  5. 合并循环(假设要合并 3 次)

    • 第 1 轮:找出全局频率最高的相邻对。

      • 假设 (b' ', b'w') 频率最高(例如 5 次),选定它。

      • 创建新 token b' w'(空格+w),分配新 ID(256)。

      • 更新:所有单词中凡出现 (b' ', b'w') 的地方替换为 b' w'

      • 受影响单词:" world" → (b' w', b'o', b'r', b'l', b'd')

      • 重新统计相邻对(仅受影响的部分增量更新,这里简化描述)。

    • 第 2 轮:重新计算后,假设 (b' w', b'o') 频率最高(例如 5 次)。

      • 合并 → b' wo'(ID 257)。

      • 更新:" world" → (b' wo', b'r', b'l', b'd')

    • 第 3 轮:假设 (b'l', b'o') 在 "Hello" 中频繁出现,频率最高。

      • 合并 → b'lo'(ID 258)。

      • 更新:"Hello" → (b'H', b'e', b'l', b'lo')(注意 b'l' 和 b'lo' 相邻,后续可能继续合并)。

  6. 最终词表构成

    • ID 0–255:原始单字节

    • ID 256:b' w'

    • ID 257:b' wo'

    • ID 258:b'lo'

    • …(继续合并直到词表大小)

  7. 特殊 token(如 <|endoftext|>)在训练结束后追加到词表末尾。

1.6 倒排索引优化性能

  • 原理:只需遍历一次,剩下的在倒排索引里面进行。

    • 倒排索引让 BPE 从“每轮全量扫描”变成“每轮只更新局部”,复杂度从 O(合并次数 × 语料大小) 降为近似 O(语料大小 + 合并次数 × 平均受影响单词数)。

  • 什么是倒排索引:倒排索引是一种:从内容(值)快速定位到位置(键)的数据结构。
    与传统“文档 → 词”的正向索引相反,它存储“词 → 包含该词的文档列表”

  • 没有倒排索引时:每轮合并需要遍历所有单词,逐个扫描其字节列表,查找是否包含 best_pair → O(总单词数 × 平均单词长度)。

  • 有倒排索引时

    • 初始化时只遍历一次所有单词,填好 indices

    • 每轮合并:直接通过 indices[best_pair] 瞬间拿到所有受影响的单词下标 → O(1)

    • 只更新这些单词,同时维护 indices 中受影响的条目(删除旧 pair、添加新 pair)。

1.7 BPE编码步骤/举例

输入:原始字符串(可能含特殊 token,如 <|endoftext|>
输出:整数 ID 列表

  1. 特殊 token 切分(若有)

    • 用正则 (special1|special2|...) 扫描文本,按捕获组分割。

    • 结果列表交替出现:普通文本片段、特殊 token、普通文本片段…

    • 特殊 token 直接通过 byte_to_id 映射为 ID(不经过 BPE)。

  2. 对每个普通文本片段
    a. 预分词:用 GPT‑2 正则将片段切成单词/标点块。
    b. 对每个块

    • 将块编码为 UTF‑8 字节序列,再拆成单字节列表(如 b"Hello" → [b'H',b'e',b'l',b'l',b'o'])。

    • 反复合并:在当前字节列表中,找到 merges 字典里 rank 最小(即最早训练) 的相邻对,合并所有出现位置。

    • 直到无法合并(没有相邻对在 merges 中)。

    • 将最终的每个字节块通过 byte_to_id 转为 ID。
      c. 将各块的 ID 顺序拼接

  3. 按原顺序 合并所有片段的 ID 序列(普通片段 ID + 特殊 token ID + 普通片段 ID…)。

1.8 BPE解码步骤/举例

  1. 每个 ID 查 id_to_byte 得到对应的字节序列(如 b'Hello'b'<|endoftext|>' 等)。

  2. 将所有字节序列按顺序拼接成一个完整的 bytes 对象。

  3. 用 decode("utf-8", errors="replace") 转换为字符串。

02. 训练阶段的关键步骤

  1. 初始化词表为 0-255 字节

  2. 读取语料,遇到特殊 token 时 用带捕获组的正则 re.split(f"({special_regex})", text) 切分,特殊 token 单独拎出来,不参与后续 BPE 统计,保证它们独立。

  3. 用 GPT‑2 正则对普通文本做 【预分词】,目的是 防止跨越类型合并(例如字母和标点符号不会被合并),同时保护空格(将空格与后面的单词绑定为一个整体)

  4. 统计每个“单词”的出现频率,存储为 raw_counts

  5. 构建倒排索引 indices 和频率表 stats,加速合并。

  6. 循环合并 num_merges 次:

    1. stats 中选出 【频率最高、同频字典序最大】 的 pair。

    2. 更新所有包含该 pair 的单词:替换 pair、调整相邻对统计。

    3. 利用 indices 快速定位受影响的单词。

  7. 将合并产生的 token 和特殊 token 加入词表。

03. 编码(推理)阶段的核心逻辑

  • 如果文本中包含特殊 token:用 【正则 + 捕获组】 切分,特殊 token 直接转 ID,普通片段走 BPE。

  • 对普通片段:先用 【GPT‑2 正则】 切分成单词/标点块。

  • 对每个块:初始为 【字节列表】,然后反复:

    • 找到当前序列中 【合并优先级最高】 的 pair(即 merges 中 rank 最小的)。

    • 如果找不到,退出。

    • 否则,合并所有出现的位置,更新列表。

  • 最后查 byte_to_id 得到 ID 序列。

04. 两个最容易出错的细节

  • 训练时合并比较要用 【元组】 而不是列表(否则永远 false:元组是可哈希的,可以用作字典的键;列表是不可哈希的,不能作为字典的键)。

  • 特殊 token 正则必须 【按长度从长到短排序】,避免 "<|a|>" 错误匹配 "<|ab|>" 的前半部分。

05. 我踩过的一个实际坑(环境/运行)

  • WSL 内存不足导致断开 → 用 【流式读取 + .wslconfig 限制内存】 解决。

06.实际代码

BPE训练

# 基于[CS336-1讲义【这就是小C】.pdf](https://docs.qq.com/pdf/DYXNHZnhzdWpsb0t0?fromtype=pdf&nlc=1&errorpage_redirect_count=1)
# 添加了分段读取文档训练,解决了wsl内存不够用,断连无法跑通的问题。
import os
from collections import defaultdict, Counter
import regex as re
import json

def read_text_in_chunks(file_path, chunk_size=1024*1024):
    """生成器:分块读取文本文件,按行切分,确保每行完整。"""
    with open(file_path, "r", encoding="utf-8") as f:
        buffer = ""
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            buffer += chunk
            lines = buffer.split("\n")
            # 除了可能不完整的最后一行,其余完整的行都 yield
            for line in lines[:-1]:
                yield line + "\n"
            buffer = lines[-1]   # 保留不完整的行
        if buffer:
            yield buffer

def train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """
    训练字节级 BPE 分词器(流式读取,避免内存爆炸)。
    """
    # --- 1. 初始化基础词表 ---
    vocab = {i: bytes([i]) for i in range(256)}
    num_merges = vocab_size - 256 - len(special_tokens)
    if num_merges <= 0:
        # 若目标词表太小,直接返回基础词表+特殊token
        for s_tok in special_tokens:
            vocab[len(vocab)] = s_tok.encode("utf-8")
        return vocab, []

    # --- 2. 预编译正则(特殊 token 拆分 + GPT‑2 预分词)---
    if special_tokens:
        special_regex = "|".join(re.escape(t) for t in special_tokens)
    else:
        special_regex = None

    # GPT‑2 官方预分词正则(已验证)
    gpt2_pat = re.compile(
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    )

    # --- 3. 流式统计 raw_counts(单词频率)---
    raw_counts = Counter()

    for chunk in read_text_in_chunks(input_path):
        # 3a. 根据特殊 token 切分当前 chunk
        if special_regex:
            parts = re.split(f"({special_regex})", chunk)
            train_segments = [p for p in parts if p not in special_tokens]
        else:
            train_segments = [chunk]

        # 3b. 对每个普通片段做预分词,统计单词频率
        for segment in train_segments:
            words = gpt2_pat.findall(segment)
            for word in words:
                # 将单词转为字节元组 (b'H', b'i')
                token_tuple = tuple(bytes([b]) for b in word.encode("utf-8"))
                raw_counts[token_tuple] += 1

    # 如果没有统计到任何单词(空文件),直接返回基础词表
    if not raw_counts:
        for s_tok in special_tokens:
            vocab[len(vocab)] = s_tok.encode("utf-8")
        return vocab, []

    # --- 4. 构建高效数据结构(words_list, counts_list, stats, indices)---
    words_list = []
    counts_list = []
    for word_tuple, freq in raw_counts.items():
        words_list.append(list(word_tuple))   # 转为可修改的 list
        counts_list.append(freq)

    stats = defaultdict(int)
    indices = defaultdict(set)          # pair -> set of word indices

    for idx, word in enumerate(words_list):
        freq = counts_list[idx]
        for i in range(len(word) - 1):
            pair = (word[i], word[i+1])
            stats[pair] += freq
            indices[pair].add(idx)

    merges = []   # 记录合并顺序

    # --- 5. 迭代合并 ---
    for _ in range(num_merges):
        if not stats:
            break

        # 5a. 选择最佳 pair(频率最高,同频字典序最大)
        best_pair = max(stats.items(), key=lambda x: (x[1], x[0]))[0]
        if stats[best_pair] <= 0:
            break

        merges.append(best_pair)
        new_token = best_pair[0] + best_pair[1]

        # 5b. 获取所有包含该 pair 的单词索引(拷贝,因为循环中会修改 indices)
        relevant_indices = list(indices[best_pair])

        # 5c. 逐一更新受影响的单词
        for idx in relevant_indices:
            word = words_list[idx]
            freq = counts_list[idx]

            i = 0
            while i < len(word) - 1:
                if word[i] == best_pair[0] and word[i+1] == best_pair[1]:
                    # 1) 减少旧的相邻 pair 频率(左邻、右邻)
                    if i > 0:
                        prev_pair = (word[i-1], word[i])
                        stats[prev_pair] -= freq
                        if stats[prev_pair] == 0:
                            del stats[prev_pair]
                    if i < len(word) - 2:
                        next_pair = (word[i+1], word[i+2])
                        stats[next_pair] -= freq
                        if stats[next_pair] == 0:
                            del stats[next_pair]

                    # 2) 合并:替换第一个字节,删除第二个
                    word[i] = new_token
                    del word[i+1]

                    # 3) 添加新产生的相邻 pair
                    if i > 0:
                        new_prev = (word[i-1], word[i])
                        stats[new_prev] += freq
                        indices[new_prev].add(idx)
                    if i < len(word) - 1:
                        new_next = (word[i], word[i+1])
                        stats[new_next] += freq
                        indices[new_next].add(idx)

                    # 注意:合并后 i 不移动,因为当前位置已经是 new_token,
                    # 下一轮会检查 (new_token, word[i+1])
                else:
                    i += 1

        # 5d. 清理已完全合并的 best_pair(从 stats 和 indices 中删除)
        if best_pair in stats:
            del stats[best_pair]
        if best_pair in indices:
            del indices[best_pair]

    # --- 6. 构建最终词表 ---
    for pair in merges:
        new_id = len(vocab)
        vocab[new_id] = pair[0] + pair[1]

    for s_tok in special_tokens:
        vocab[len(vocab)] = s_tok.encode("utf-8")

    return vocab, merges
def bytes_to_unicode():
    """
    创建一个映射,将 0-255 字节映射为一组可见的 Unicode 字符。
    这是 GPT-2 源码的标准做法。
    """
    bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

def save_tokenizer_files(vocab, merges, out_dir):
    os.makedirs(out_dir, exist_ok=True)

    # 初始化映射表
    byte_encoder = bytes_to_unicode()

    # 词表保存
    # 使用 byte_encoder 将 bytes 转换为可见字符串
    json_vocab = {
        k: "".join(byte_encoder[b] for b in v)
        for k, v in vocab.items()
    }
    with open(os.path.join(out_dir, "vocab.json"), "w", encoding="utf-8") as f:
        json.dump(json_vocab, f, indent=4)

    # 合并规则保存
    with open(os.path.join(out_dir, "merges.txt"), "w", encoding="utf-8") as f:
        for p1, p2 in merges:
            # 同样转换 p1 和 p2
            s1 = "".join(byte_encoder[b] for b in p1)
            s2 = "".join(byte_encoder[b] for b in p2)
            f.write(f"{s1} {s2}\n")

def main():
    input_path = "data/TinyStoriesV2-GPT4-train.txt"
    vocab_size = 10000 # 作业要求的词表大小
    # input_path = ""
    # input_path = ""
    # vocab_size = 1000 # 作业要求的词表大小

    special_tokens = ["<endoftext>"]
    output_dir = "data/TinyStoriesV2-GPT4-train"

    print(f"开始训练 BPE 分词器 (目标词表大小:{vocab_size})...")
    print("这可能需要几分钟,具体取决于你的 CPU 速度和倒排索引的效率")

    # 调用你之前写好的逻辑
    vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

    # 保存结果
    save_tokenizer_files(vocab, merges, output_dir)

if __name__ == "__main__":
    main()

BPETokenizer

# 代码来源于教程:
# [CS336-1讲义【这就是小C】.pdf](https://docs.qq.com/pdf/DYXNHZnhzdWpsb0t0?fromtype=pdf&nlc=1&errorpage_redirect_count=1)
import regex as re  # 使用 regex 而非内置 re,因为它支持 Unicode 类别 (如 \p{L})
from collections.abc import Iterable

"""
For special_tokens:
    推理/编码阶段 (Tokenizer.encode)
        在模型使用分词器将文本转为 ID 时, 必须优先匹配特殊 Token。
    代码逻辑:
        正则匹配:构建一个包含所有特殊 Token 的正则表达式。
        优先级:先扫描文本, 一旦发现特殊 Token, 直接将其转为对应的 ID。
        普通处理:特殊 Token 之间的文本,再走正常的 GPT-预分词和 BPE 合并流程。
"""

class BPETokenizer:
    """
    字节级 BPE (Byte-Pair Encoding) 分词器实现。

    该分词器将任意字符串编码为整数 ID 序列, 并能将 ID 序列还原。
    它采用字节级处理, 确保不会出现未知词 (OOV) 错误。
    """

    def __init__(self, vocab: dict[int,bytes], merges: list[tuple[bytes, bytes]], special_tokens):
        """
        初始化分词器。
        
        参数:
            vocab: 词汇表,建立整数 ID 到 字节块(bytes) 的映射。
            merges: 合并规则列表。列表中的每一项是一个二元组 (bytes_a, bytes_b),
                    表示在训练过程中 bytes_a 和 bytes_b 被合并的顺序。
            special_tokens: 特殊标记列表 (如 <|endoftext|>), 这些标记不会被 BPE 规则拆分。
        """
        # 1. 建立双向映射, 方便查表
        self.vocab = vocab # ID -> 字节块
        self.id_to_byte = vocab
        self.byte_to_id = {v: k for k, v in vocab.items()} # 字节块 -> ID
        
        # 2. 将合并规则转换为Rank字典。
        # BPE 编码时, 必须优先应用在训练阶段较早出现的合并规则。
        # 字典结构为:{(byte_a, byte_b): 顺序索引}
        self.merges = {pair: i for i,pair in enumerate(merges)}
         
        self.special_tokens = special_tokens or []

        # 3. 构建特殊 Token 的正则表达式
        if self.special_tokens:
            # 关键:必须按照长度从长到短排序 (reverse=True)。
            # 这样正则引擎会优先匹配最长的特殊标记, 防止重叠标记 (如 <|a|><|b|>) 被错误拆分
            sorted_special = sorted(self.special_tokens, key=len, reverse=True)
            # 使用 re.escape 确保标记中的特殊字符 (如 | 或 [ ) 被当作普通字符处理
            special_pattern = "|".join(re.escape(t) for t in sorted_special)
            self.special_regex = re.compile(special_pattern)
        else:
            self.special_regex = None
        # 4. GPT-2 官方预分词正则表达式。
        # 它的作用是在应用 BPE 合并前, 先将文本切分成单词、标点、数字等逻辑块。
        # 这样做是为了防止 BPE 规则跨越单词或标点 (例如: 防止将 "dog" 的末尾和 "." 合并) 。
        self.gpt2_pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
    
    def encode(self, text: str) -> list[int]:
        """
        将输入的原始字符串编码为整数 ID 列表。

        该方法的核心逻辑是:
        1. 作为一个“协调者”,它辅助处理文本中的“特殊标记 (Special Tokens) ”和“普通文本”。
        2. 特殊标记 (如 <|endoftext|>) 被视为原子,直接映射为 ID,不参与 BPE 的拆分和合并。
        3. 普通文本片段则被交给底层逻辑执行预分词和 BPE 算法。

        参数:
            text: 需要编码的原始字符串 (例如 "Hello<|end|>World) 。
        
        返回:
            list[int]: 编码后的整数 ID 序列。
        """
        # --- 步骤 1:边界情况检查 ---
        # 如果输入是空字符或 None,直接返回空列表。
        # 这是为了防止后续逻辑在处理空文本时产生错误。
        if not text:
            return []
        
        # --- 步骤 2:情况 A - 快速路径 (Fast Path) ---
        # 如果我们在初始化时没有定义任何特殊标记 (或者特殊标记列表为空) ,
        # 那么整个文本都可以被视为一段连续的“普通文本”
        # 我们直接调用内部方法 _encode_text_segment 进行 BPE 处理并返回结果/
        if not self.special_regex:
            return self._encode_text_segment(text)
        
        # --- 步骤3:情况 B - 处理含有特殊标记的复杂文本 ---
        # 此时文本中可能混有普通文字和特殊标记,我们需要像“剪刀”一样把它们切开。
        tokens = []

        # last_pos 用于记录上一次匹配结束的位置,帮助我们定位“特殊标记”之间的“缝隙”。
        last_pos = 0

        # 使用 finditer 遍历文本中所有符号特殊标记模式的匹配项。
        # finditer 的好处是它提供了 match.start() 和 match.end(),
        # 这让我们能够精确地知道特殊标记在哪里开始,在哪里结束。
        for match in self.special_regex.finditer(text):

            # 3.1 提取并处理“前置普通文本”
            # 这里的区间是 [last_pos, match.start()]。
            # " hello <|endoftext|> world"
            # 这段文本是夹在两个特殊标记之间 (或者开头到第一个特殊标记之间) 的普通文字。
            pre_text = text[last_pos:match.start()]

            # 如果这两个标记之间确实有文字 (长度 > 0)
            if pre_text:
                # 调用核心 BPE 逻辑。_encode_text_segment 会执行:
                # 1. GPT-2 预分词正则切分。
                # 2. 字节化。
                # 3. 按照 merges 规则进行贪婪合并。
                tokens.extend(self._encode_text_segment(pre_text))
                # pre_tokens : [1,2,3,...] self._encode_text_segment: [4,5,6] tokens.extend -> [1,2,3,...,4,5,6]
                # token.append() : [1,2,3,...[4,5,6]]
            
            # 3.2 处理“当前特殊标记”
            # match.group() 拿到的就是被识别出来的特殊标记字符串 (如 "<|endoftext|>") 。
            special_tok = match.group()

            # 核心原则:特殊标记不参与 BPE 合并!
            # 我们直接将其编码为 UTF-8 字节,然后在词表中查找其 ID。
            # 注意:这些标记在 train_bpe 阶段必须已经被手动加入到了词表中。
            tokens.append(self.byte_to_id[special_tok.encode("utf-8")])

            # 3.3 更新游标
            # 将游标移动到当前匹配项的末尾,为寻找下一个片段做准备。
            last_pos = match.end()
        
        # --- 步骤 4:处理“收尾文本” ---
        # 如果最后一个特殊标记后面还有文字 (例如 "Hello<|end|>World" 中的 "World") ,
        # 或者整个文本根本没有特殊标记匹配 (虽然逻辑上 Case A 已处理,但这里是双重保险) ,
        # 我们需要处理从 last_pos 到字符串末尾的所有剩余字符。
        remaining_text = text[last_pos:]
        if remaining_text:
            # 剩余部分同样作为普通文本片段进行 BPE 编码。
            tokens.extend(self._encode_text_segment(remaining_text))
        
        # 返回拼接好的所有 ID 列表
        return tokens
    def _encode_text_segment(self, text: str) -> list[int]:
        """
        内部核心函数:对不含特殊 Token 的纯文本片段应用 BPE 合并逻辑
        """
        ids = []
        # 使用 GPT-2 正则进行预分词,将文本拆成单词/标点符号块
        # 例如:"Hello world!" -> ["Hello", " world", "!"]
        pre_tokens = self.gpt2_pat.findall(text)

        for p_tok in pre_tokens:
            # 第一步: 将当前片段转为字节序列,并将每个字节看作一个独立的“部分 (Part) ”
            # 例如: "Hello" -> [b'H', b'e', b'l', b'l', b'o']
            byte_parts = [bytes([b]) for b in p_tok.encode("utf-8")]

            # 第二步:反复执行合并,直到没有符号条件的合并规则为止
            while len(byte_parts) >= 2:
                # 在当前序列的所有相邻对中,寻找合并优先级最高 (Rank 最小) 的一对,即按照构造merge时添加pair的顺序进行合并。
                best_pair = None
                min_rank = float('inf')

                for i in range(len(byte_parts) - 1):
                    pair =  (byte_parts[i], byte_parts[i+1])
                    if pair in self.merges:
                        rank = self.merges[pair]
                        if rank < min_rank:
                            min_rank = rank
                            best_pair = pair

                # 如果找不到任何可以合并的规则,退出当前片段
                if best_pair is None:
                    break

                # 第三步:执行合并操作。
                # 遍历当前序列,将所有出现的 best_pair 替换成合并后的长字节块。
                new_byte_parts = []
                i = 0
                # [b'H', b'e', b'l', b'l', b'o', b'H', b'e'] -> [b'He', b'l', b'l', b'o', b'He']
                while i < len(byte_parts):
                    # 如果当前两个部分匹配最高优规则
                    if i < len(byte_parts) - 1 and (byte_parts[i], byte_parts[i+1]) == best_pair:
                        new_byte_parts.append(best_pair[0] + best_pair[1])
                        i += 2 # 跳过下一项,因为已经合并了
                    else:
                        new_byte_parts.append(byte_parts[i])
                        i += 1
                byte_parts = new_byte_parts # 更新序列,进入下一轮 while 循环
            
            # 第四步:将合并到极限后的所有字节块转换为词表中的 ID
            for part in byte_parts:
                ids.append(self.byte_to_id[part])
            
        return ids
    
    def decode(self, ids: list[int]) -> str:
        """
        将 ID 列表解码为原始字符串。
        """
        # 1. 根据 ID 查表找回字节块
        byte_segments = [self.id_to_byte[i] for i in ids]

        # 2. 将所有字节块按顺序拼接成一个完整的字节六
        full_bytes = b"".join(byte_segments)

        # 3. 将字节流解码为 UTF-8 字符串。
        # 使用 errors="replace" 非常关键: 因为 BPE 可能会生成不完整的字节序列
        # (例如 3 字节的中文字符只产生了一部分) , 此时不报错而是插入替换符 ()。
        return full_bytes.decode("utf-8", errors="replace")
    
    def encode_iterable(self, iterable: Iterable[str]) -> Iterable[int]:
        """
        内存高效的迭代编码器。

        参数:
            iterable: 一个可迭代的字符串对象 (例如文件句柄) 。
        返回:
            一个生成器, 逐个产出编码后的 ID。 用于处理无法一次性读入内存的大文件。
        """
        for chunk in iterable:
            # 对每一块文本进行编码, 并通过 yield 吐出结果
            yield from self.encode(chunk)