BPE编码器原理与实现
01.BPE原理
1.1 一句话说明
BPE 就是不断地把出现频率最高的相邻字节对h合并成一个新 token,直到词表达到目标大小。
1.2 为什么要做BPE
BPE缩短了序列长度:
单个 token 的计算开销与它包含的字节数无关(1 字节 token 和 10 字节 token,进入模型后都是相同维度的向量,计算量相同)。
BPE 加速模型的真正原因:用多字节 token 代替多个单字节 token → 序列长度变短 → 自注意力的复杂度O(L^2)降低 → 整体计算量下降。
例如: 原本用字节级编码(每个字节一个 token),一个 10 字符的单词就需要 10 个 token。BPE 将常见多字节组合(如
"Hello")合并成一个 token,从而大幅缩短序列长度。
1.3 预分词
什么是预分词?:在 BPE 统计相邻对频率 之前,先把文本切成一个个更粗粒度的块(比如单词、标点、数字)。
为什么做预分词?:
防止跨类型合并:例如
dog末尾的g与后面的.不应合并,否则会破坏语义边界。保护空格:通常将单词前的空格与单词本身绑定为一个整体,使得空格可以像普通 token 一样参与合并,同时保留显式的0词边界信息。
提升合理性:语言模型通常以自然词为单位,预分词让 BPE 的合并行为更符合人类语言的边界。
1.4 特殊字符(Token)
什么是特殊Token:特殊 token 是人为添加到词表中的、不参与 BPE 合并的原子标记,例如
<|endoftext|>、<|sep|>、[PAD]、[CLS]等。作用:
分隔不同的文本片段(如文档结束符)
传递控制信号(如分类任务的
[CLS])填充序列至相同长度(
[PAD])
训练阶段的处理:
切分语料:使用带捕获组的正则
re.split(f"({special_regex})", text)将文本按特殊 token 切开。隔离统计:只保留普通文本片段(
p not in special_tokens)送入 BPE 统计频率。特殊 token 本身不参与任何相邻对统计。最终加入:训练结束后,将特殊 token 的 UTF‑8 字节表示追加到词表中,分配新的 ID。
编码阶段的处理:
优先匹配:构造正则时按长度从长到短排序,确保
"<|ab|>"不会被拆成"<|a|>" + "b|>"。切分与路由:遍历文本,匹配到的特殊 token 直接转为 ID(通过
byte_to_id);特殊 token 之间的普通文本片段调用_encode_text_segment进行正常 BPE 编码。不参与合并:特殊 token 在编码时始终保持原子性,不会被 BPE 规则拆分或与其他 token 合并。
1.5 BPE训练步骤/举例:
输入:原始文本语料
输出:词表(vocab: dict[int, bytes]) + 合并顺序(merges: list[tuple[bytes, bytes]])
初始化
初始词表(vocab)为0-255,每个字节是一个独立token。
预分词(Pre‑tokenization)
输入文本:
"Hello world!"GPT‑2 正则匹配结果:
["Hello", " world", "!"]
统计原始单词频率
对预分词后的每个块,将其编码为 UTF‑8 字节序列,再拆成单字节元组。
"Hello"→(b'H', b'e', b'l', b'l', b'o')" world"→(b' ', b'w', b'o', b'r', b'l', b'd')"!"→(b'!',)统计整个语料中每个元组出现的次数(即“单词”频率)。
假设语料中
"Hello"出现 3 次," world"出现 5 次,"!"出现 2 次。
初始相邻对统计
遍历每个单词的字节列表,枚举所有相邻对,累加频率(单词频率乘该对在单词内出现次数)。
单词
(b'H', b'e', b'l', b'l', b'o')频率 3产生对:
(b'H',b'e')、(b'e',b'l')、(b'l',b'l')、(b'l',b'o'),每个对增加 3 次。单词
(b' ', b'w', b'o', b'r', b'l', b'd')频率 5产生
(b' ',b'w')、(b'w',b'o')等。
合并循环(假设要合并 3 次)
第 1 轮:找出全局频率最高的相邻对。
假设
(b' ', b'w')频率最高(例如 5 次),选定它。创建新 token
b' w'(空格+w),分配新 ID(256)。更新:所有单词中凡出现
(b' ', b'w')的地方替换为b' w'。受影响单词:
" world"→(b' w', b'o', b'r', b'l', b'd')。重新统计相邻对(仅受影响的部分增量更新,这里简化描述)。
第 2 轮:重新计算后,假设
(b' w', b'o')频率最高(例如 5 次)。合并 →
b' wo'(ID 257)。更新:
" world"→(b' wo', b'r', b'l', b'd')。
第 3 轮:假设
(b'l', b'o')在"Hello"中频繁出现,频率最高。合并 →
b'lo'(ID 258)。更新:
"Hello"→(b'H', b'e', b'l', b'lo')(注意b'l'和b'lo'相邻,后续可能继续合并)。
最终词表构成
ID 0–255:原始单字节
ID 256:
b' w'ID 257:
b' wo'ID 258:
b'lo'…(继续合并直到词表大小)
特殊 token(如
<|endoftext|>)在训练结束后追加到词表末尾。
1.6 倒排索引优化性能
原理:只需遍历一次,剩下的在倒排索引里面进行。
倒排索引让 BPE 从“每轮全量扫描”变成“每轮只更新局部”,复杂度从 O(合并次数 × 语料大小) 降为近似 O(语料大小 + 合并次数 × 平均受影响单词数)。
什么是倒排索引:倒排索引是一种:从内容(值)快速定位到位置(键)的数据结构。
与传统“文档 → 词”的正向索引相反,它存储“词 → 包含该词的文档列表”没有倒排索引时:每轮合并需要遍历所有单词,逐个扫描其字节列表,查找是否包含
best_pair→ O(总单词数 × 平均单词长度)。有倒排索引时:
初始化时只遍历一次所有单词,填好
indices。每轮合并:直接通过
indices[best_pair]瞬间拿到所有受影响的单词下标 → O(1)只更新这些单词,同时维护
indices中受影响的条目(删除旧 pair、添加新 pair)。
1.7 BPE编码步骤/举例
输入:原始字符串(可能含特殊 token,如 <|endoftext|>)
输出:整数 ID 列表
特殊 token 切分(若有)
用正则
(special1|special2|...)扫描文本,按捕获组分割。结果列表交替出现:普通文本片段、特殊 token、普通文本片段…
特殊 token 直接通过
byte_to_id映射为 ID(不经过 BPE)。
对每个普通文本片段:
a. 预分词:用 GPT‑2 正则将片段切成单词/标点块。
b. 对每个块:将块编码为 UTF‑8 字节序列,再拆成单字节列表(如
b"Hello"→[b'H',b'e',b'l',b'l',b'o'])。反复合并:在当前字节列表中,找到
merges字典里 rank 最小(即最早训练) 的相邻对,合并所有出现位置。直到无法合并(没有相邻对在
merges中)。将最终的每个字节块通过
byte_to_id转为 ID。
c. 将各块的 ID 顺序拼接。
按原顺序 合并所有片段的 ID 序列(普通片段 ID + 特殊 token ID + 普通片段 ID…)。
1.8 BPE解码步骤/举例
每个 ID 查
id_to_byte得到对应的字节序列(如b'Hello',b'<|endoftext|>'等)。将所有字节序列按顺序拼接成一个完整的
bytes对象。用
decode("utf-8", errors="replace")转换为字符串。
02. 训练阶段的关键步骤
初始化词表为 0-255 字节
读取语料,遇到特殊 token 时 用带捕获组的正则
re.split(f"({special_regex})", text)切分,特殊 token 单独拎出来,不参与后续 BPE 统计,保证它们独立。用 GPT‑2 正则对普通文本做 【预分词】,目的是 防止跨越类型合并(例如字母和标点符号不会被合并),同时保护空格(将空格与后面的单词绑定为一个整体)
统计每个“单词”的出现频率,存储为
raw_counts。构建倒排索引
indices和频率表stats,加速合并。循环合并
num_merges次:从
stats中选出 【频率最高、同频字典序最大】 的 pair。更新所有包含该 pair 的单词:替换 pair、调整相邻对统计。
利用
indices快速定位受影响的单词。
将合并产生的 token 和特殊 token 加入词表。
03. 编码(推理)阶段的核心逻辑
如果文本中包含特殊 token:用 【正则 + 捕获组】 切分,特殊 token 直接转 ID,普通片段走 BPE。
对普通片段:先用 【GPT‑2 正则】 切分成单词/标点块。
对每个块:初始为 【字节列表】,然后反复:
找到当前序列中 【合并优先级最高】 的 pair(即
merges中 rank 最小的)。如果找不到,退出。
否则,合并所有出现的位置,更新列表。
最后查
byte_to_id得到 ID 序列。
04. 两个最容易出错的细节
训练时合并比较要用 【元组】 而不是列表(否则永远 false:元组是可哈希的,可以用作字典的键;列表是不可哈希的,不能作为字典的键)。
特殊 token 正则必须 【按长度从长到短排序】,避免
"<|a|>"错误匹配"<|ab|>"的前半部分。
05. 我踩过的一个实际坑(环境/运行)
WSL 内存不足导致断开 → 用 【流式读取 +
.wslconfig限制内存】 解决。
06.实际代码
BPE训练
# 基于[CS336-1讲义【这就是小C】.pdf](https://docs.qq.com/pdf/DYXNHZnhzdWpsb0t0?fromtype=pdf&nlc=1&errorpage_redirect_count=1)
# 添加了分段读取文档训练,解决了wsl内存不够用,断连无法跑通的问题。
import os
from collections import defaultdict, Counter
import regex as re
import json
def read_text_in_chunks(file_path, chunk_size=1024*1024):
"""生成器:分块读取文本文件,按行切分,确保每行完整。"""
with open(file_path, "r", encoding="utf-8") as f:
buffer = ""
while True:
chunk = f.read(chunk_size)
if not chunk:
break
buffer += chunk
lines = buffer.split("\n")
# 除了可能不完整的最后一行,其余完整的行都 yield
for line in lines[:-1]:
yield line + "\n"
buffer = lines[-1] # 保留不完整的行
if buffer:
yield buffer
def train_bpe(
input_path: str | os.PathLike,
vocab_size: int,
special_tokens: list[str],
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
"""
训练字节级 BPE 分词器(流式读取,避免内存爆炸)。
"""
# --- 1. 初始化基础词表 ---
vocab = {i: bytes([i]) for i in range(256)}
num_merges = vocab_size - 256 - len(special_tokens)
if num_merges <= 0:
# 若目标词表太小,直接返回基础词表+特殊token
for s_tok in special_tokens:
vocab[len(vocab)] = s_tok.encode("utf-8")
return vocab, []
# --- 2. 预编译正则(特殊 token 拆分 + GPT‑2 预分词)---
if special_tokens:
special_regex = "|".join(re.escape(t) for t in special_tokens)
else:
special_regex = None
# GPT‑2 官方预分词正则(已验证)
gpt2_pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
# --- 3. 流式统计 raw_counts(单词频率)---
raw_counts = Counter()
for chunk in read_text_in_chunks(input_path):
# 3a. 根据特殊 token 切分当前 chunk
if special_regex:
parts = re.split(f"({special_regex})", chunk)
train_segments = [p for p in parts if p not in special_tokens]
else:
train_segments = [chunk]
# 3b. 对每个普通片段做预分词,统计单词频率
for segment in train_segments:
words = gpt2_pat.findall(segment)
for word in words:
# 将单词转为字节元组 (b'H', b'i')
token_tuple = tuple(bytes([b]) for b in word.encode("utf-8"))
raw_counts[token_tuple] += 1
# 如果没有统计到任何单词(空文件),直接返回基础词表
if not raw_counts:
for s_tok in special_tokens:
vocab[len(vocab)] = s_tok.encode("utf-8")
return vocab, []
# --- 4. 构建高效数据结构(words_list, counts_list, stats, indices)---
words_list = []
counts_list = []
for word_tuple, freq in raw_counts.items():
words_list.append(list(word_tuple)) # 转为可修改的 list
counts_list.append(freq)
stats = defaultdict(int)
indices = defaultdict(set) # pair -> set of word indices
for idx, word in enumerate(words_list):
freq = counts_list[idx]
for i in range(len(word) - 1):
pair = (word[i], word[i+1])
stats[pair] += freq
indices[pair].add(idx)
merges = [] # 记录合并顺序
# --- 5. 迭代合并 ---
for _ in range(num_merges):
if not stats:
break
# 5a. 选择最佳 pair(频率最高,同频字典序最大)
best_pair = max(stats.items(), key=lambda x: (x[1], x[0]))[0]
if stats[best_pair] <= 0:
break
merges.append(best_pair)
new_token = best_pair[0] + best_pair[1]
# 5b. 获取所有包含该 pair 的单词索引(拷贝,因为循环中会修改 indices)
relevant_indices = list(indices[best_pair])
# 5c. 逐一更新受影响的单词
for idx in relevant_indices:
word = words_list[idx]
freq = counts_list[idx]
i = 0
while i < len(word) - 1:
if word[i] == best_pair[0] and word[i+1] == best_pair[1]:
# 1) 减少旧的相邻 pair 频率(左邻、右邻)
if i > 0:
prev_pair = (word[i-1], word[i])
stats[prev_pair] -= freq
if stats[prev_pair] == 0:
del stats[prev_pair]
if i < len(word) - 2:
next_pair = (word[i+1], word[i+2])
stats[next_pair] -= freq
if stats[next_pair] == 0:
del stats[next_pair]
# 2) 合并:替换第一个字节,删除第二个
word[i] = new_token
del word[i+1]
# 3) 添加新产生的相邻 pair
if i > 0:
new_prev = (word[i-1], word[i])
stats[new_prev] += freq
indices[new_prev].add(idx)
if i < len(word) - 1:
new_next = (word[i], word[i+1])
stats[new_next] += freq
indices[new_next].add(idx)
# 注意:合并后 i 不移动,因为当前位置已经是 new_token,
# 下一轮会检查 (new_token, word[i+1])
else:
i += 1
# 5d. 清理已完全合并的 best_pair(从 stats 和 indices 中删除)
if best_pair in stats:
del stats[best_pair]
if best_pair in indices:
del indices[best_pair]
# --- 6. 构建最终词表 ---
for pair in merges:
new_id = len(vocab)
vocab[new_id] = pair[0] + pair[1]
for s_tok in special_tokens:
vocab[len(vocab)] = s_tok.encode("utf-8")
return vocab, merges
def bytes_to_unicode():
"""
创建一个映射,将 0-255 字节映射为一组可见的 Unicode 字符。
这是 GPT-2 源码的标准做法。
"""
bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def save_tokenizer_files(vocab, merges, out_dir):
os.makedirs(out_dir, exist_ok=True)
# 初始化映射表
byte_encoder = bytes_to_unicode()
# 词表保存
# 使用 byte_encoder 将 bytes 转换为可见字符串
json_vocab = {
k: "".join(byte_encoder[b] for b in v)
for k, v in vocab.items()
}
with open(os.path.join(out_dir, "vocab.json"), "w", encoding="utf-8") as f:
json.dump(json_vocab, f, indent=4)
# 合并规则保存
with open(os.path.join(out_dir, "merges.txt"), "w", encoding="utf-8") as f:
for p1, p2 in merges:
# 同样转换 p1 和 p2
s1 = "".join(byte_encoder[b] for b in p1)
s2 = "".join(byte_encoder[b] for b in p2)
f.write(f"{s1} {s2}\n")
def main():
input_path = "data/TinyStoriesV2-GPT4-train.txt"
vocab_size = 10000 # 作业要求的词表大小
# input_path = ""
# input_path = ""
# vocab_size = 1000 # 作业要求的词表大小
special_tokens = ["<endoftext>"]
output_dir = "data/TinyStoriesV2-GPT4-train"
print(f"开始训练 BPE 分词器 (目标词表大小:{vocab_size})...")
print("这可能需要几分钟,具体取决于你的 CPU 速度和倒排索引的效率")
# 调用你之前写好的逻辑
vocab, merges = train_bpe(input_path, vocab_size, special_tokens)
# 保存结果
save_tokenizer_files(vocab, merges, output_dir)
if __name__ == "__main__":
main()
BPETokenizer
# 代码来源于教程:
# [CS336-1讲义【这就是小C】.pdf](https://docs.qq.com/pdf/DYXNHZnhzdWpsb0t0?fromtype=pdf&nlc=1&errorpage_redirect_count=1)
import regex as re # 使用 regex 而非内置 re,因为它支持 Unicode 类别 (如 \p{L})
from collections.abc import Iterable
"""
For special_tokens:
推理/编码阶段 (Tokenizer.encode)
在模型使用分词器将文本转为 ID 时, 必须优先匹配特殊 Token。
代码逻辑:
正则匹配:构建一个包含所有特殊 Token 的正则表达式。
优先级:先扫描文本, 一旦发现特殊 Token, 直接将其转为对应的 ID。
普通处理:特殊 Token 之间的文本,再走正常的 GPT-预分词和 BPE 合并流程。
"""
class BPETokenizer:
"""
字节级 BPE (Byte-Pair Encoding) 分词器实现。
该分词器将任意字符串编码为整数 ID 序列, 并能将 ID 序列还原。
它采用字节级处理, 确保不会出现未知词 (OOV) 错误。
"""
def __init__(self, vocab: dict[int,bytes], merges: list[tuple[bytes, bytes]], special_tokens):
"""
初始化分词器。
参数:
vocab: 词汇表,建立整数 ID 到 字节块(bytes) 的映射。
merges: 合并规则列表。列表中的每一项是一个二元组 (bytes_a, bytes_b),
表示在训练过程中 bytes_a 和 bytes_b 被合并的顺序。
special_tokens: 特殊标记列表 (如 <|endoftext|>), 这些标记不会被 BPE 规则拆分。
"""
# 1. 建立双向映射, 方便查表
self.vocab = vocab # ID -> 字节块
self.id_to_byte = vocab
self.byte_to_id = {v: k for k, v in vocab.items()} # 字节块 -> ID
# 2. 将合并规则转换为Rank字典。
# BPE 编码时, 必须优先应用在训练阶段较早出现的合并规则。
# 字典结构为:{(byte_a, byte_b): 顺序索引}
self.merges = {pair: i for i,pair in enumerate(merges)}
self.special_tokens = special_tokens or []
# 3. 构建特殊 Token 的正则表达式
if self.special_tokens:
# 关键:必须按照长度从长到短排序 (reverse=True)。
# 这样正则引擎会优先匹配最长的特殊标记, 防止重叠标记 (如 <|a|><|b|>) 被错误拆分
sorted_special = sorted(self.special_tokens, key=len, reverse=True)
# 使用 re.escape 确保标记中的特殊字符 (如 | 或 [ ) 被当作普通字符处理
special_pattern = "|".join(re.escape(t) for t in sorted_special)
self.special_regex = re.compile(special_pattern)
else:
self.special_regex = None
# 4. GPT-2 官方预分词正则表达式。
# 它的作用是在应用 BPE 合并前, 先将文本切分成单词、标点、数字等逻辑块。
# 这样做是为了防止 BPE 规则跨越单词或标点 (例如: 防止将 "dog" 的末尾和 "." 合并) 。
self.gpt2_pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
def encode(self, text: str) -> list[int]:
"""
将输入的原始字符串编码为整数 ID 列表。
该方法的核心逻辑是:
1. 作为一个“协调者”,它辅助处理文本中的“特殊标记 (Special Tokens) ”和“普通文本”。
2. 特殊标记 (如 <|endoftext|>) 被视为原子,直接映射为 ID,不参与 BPE 的拆分和合并。
3. 普通文本片段则被交给底层逻辑执行预分词和 BPE 算法。
参数:
text: 需要编码的原始字符串 (例如 "Hello<|end|>World) 。
返回:
list[int]: 编码后的整数 ID 序列。
"""
# --- 步骤 1:边界情况检查 ---
# 如果输入是空字符或 None,直接返回空列表。
# 这是为了防止后续逻辑在处理空文本时产生错误。
if not text:
return []
# --- 步骤 2:情况 A - 快速路径 (Fast Path) ---
# 如果我们在初始化时没有定义任何特殊标记 (或者特殊标记列表为空) ,
# 那么整个文本都可以被视为一段连续的“普通文本”
# 我们直接调用内部方法 _encode_text_segment 进行 BPE 处理并返回结果/
if not self.special_regex:
return self._encode_text_segment(text)
# --- 步骤3:情况 B - 处理含有特殊标记的复杂文本 ---
# 此时文本中可能混有普通文字和特殊标记,我们需要像“剪刀”一样把它们切开。
tokens = []
# last_pos 用于记录上一次匹配结束的位置,帮助我们定位“特殊标记”之间的“缝隙”。
last_pos = 0
# 使用 finditer 遍历文本中所有符号特殊标记模式的匹配项。
# finditer 的好处是它提供了 match.start() 和 match.end(),
# 这让我们能够精确地知道特殊标记在哪里开始,在哪里结束。
for match in self.special_regex.finditer(text):
# 3.1 提取并处理“前置普通文本”
# 这里的区间是 [last_pos, match.start()]。
# " hello <|endoftext|> world"
# 这段文本是夹在两个特殊标记之间 (或者开头到第一个特殊标记之间) 的普通文字。
pre_text = text[last_pos:match.start()]
# 如果这两个标记之间确实有文字 (长度 > 0)
if pre_text:
# 调用核心 BPE 逻辑。_encode_text_segment 会执行:
# 1. GPT-2 预分词正则切分。
# 2. 字节化。
# 3. 按照 merges 规则进行贪婪合并。
tokens.extend(self._encode_text_segment(pre_text))
# pre_tokens : [1,2,3,...] self._encode_text_segment: [4,5,6] tokens.extend -> [1,2,3,...,4,5,6]
# token.append() : [1,2,3,...[4,5,6]]
# 3.2 处理“当前特殊标记”
# match.group() 拿到的就是被识别出来的特殊标记字符串 (如 "<|endoftext|>") 。
special_tok = match.group()
# 核心原则:特殊标记不参与 BPE 合并!
# 我们直接将其编码为 UTF-8 字节,然后在词表中查找其 ID。
# 注意:这些标记在 train_bpe 阶段必须已经被手动加入到了词表中。
tokens.append(self.byte_to_id[special_tok.encode("utf-8")])
# 3.3 更新游标
# 将游标移动到当前匹配项的末尾,为寻找下一个片段做准备。
last_pos = match.end()
# --- 步骤 4:处理“收尾文本” ---
# 如果最后一个特殊标记后面还有文字 (例如 "Hello<|end|>World" 中的 "World") ,
# 或者整个文本根本没有特殊标记匹配 (虽然逻辑上 Case A 已处理,但这里是双重保险) ,
# 我们需要处理从 last_pos 到字符串末尾的所有剩余字符。
remaining_text = text[last_pos:]
if remaining_text:
# 剩余部分同样作为普通文本片段进行 BPE 编码。
tokens.extend(self._encode_text_segment(remaining_text))
# 返回拼接好的所有 ID 列表
return tokens
def _encode_text_segment(self, text: str) -> list[int]:
"""
内部核心函数:对不含特殊 Token 的纯文本片段应用 BPE 合并逻辑
"""
ids = []
# 使用 GPT-2 正则进行预分词,将文本拆成单词/标点符号块
# 例如:"Hello world!" -> ["Hello", " world", "!"]
pre_tokens = self.gpt2_pat.findall(text)
for p_tok in pre_tokens:
# 第一步: 将当前片段转为字节序列,并将每个字节看作一个独立的“部分 (Part) ”
# 例如: "Hello" -> [b'H', b'e', b'l', b'l', b'o']
byte_parts = [bytes([b]) for b in p_tok.encode("utf-8")]
# 第二步:反复执行合并,直到没有符号条件的合并规则为止
while len(byte_parts) >= 2:
# 在当前序列的所有相邻对中,寻找合并优先级最高 (Rank 最小) 的一对,即按照构造merge时添加pair的顺序进行合并。
best_pair = None
min_rank = float('inf')
for i in range(len(byte_parts) - 1):
pair = (byte_parts[i], byte_parts[i+1])
if pair in self.merges:
rank = self.merges[pair]
if rank < min_rank:
min_rank = rank
best_pair = pair
# 如果找不到任何可以合并的规则,退出当前片段
if best_pair is None:
break
# 第三步:执行合并操作。
# 遍历当前序列,将所有出现的 best_pair 替换成合并后的长字节块。
new_byte_parts = []
i = 0
# [b'H', b'e', b'l', b'l', b'o', b'H', b'e'] -> [b'He', b'l', b'l', b'o', b'He']
while i < len(byte_parts):
# 如果当前两个部分匹配最高优规则
if i < len(byte_parts) - 1 and (byte_parts[i], byte_parts[i+1]) == best_pair:
new_byte_parts.append(best_pair[0] + best_pair[1])
i += 2 # 跳过下一项,因为已经合并了
else:
new_byte_parts.append(byte_parts[i])
i += 1
byte_parts = new_byte_parts # 更新序列,进入下一轮 while 循环
# 第四步:将合并到极限后的所有字节块转换为词表中的 ID
for part in byte_parts:
ids.append(self.byte_to_id[part])
return ids
def decode(self, ids: list[int]) -> str:
"""
将 ID 列表解码为原始字符串。
"""
# 1. 根据 ID 查表找回字节块
byte_segments = [self.id_to_byte[i] for i in ids]
# 2. 将所有字节块按顺序拼接成一个完整的字节六
full_bytes = b"".join(byte_segments)
# 3. 将字节流解码为 UTF-8 字符串。
# 使用 errors="replace" 非常关键: 因为 BPE 可能会生成不完整的字节序列
# (例如 3 字节的中文字符只产生了一部分) , 此时不报错而是插入替换符 ()。
return full_bytes.decode("utf-8", errors="replace")
def encode_iterable(self, iterable: Iterable[str]) -> Iterable[int]:
"""
内存高效的迭代编码器。
参数:
iterable: 一个可迭代的字符串对象 (例如文件句柄) 。
返回:
一个生成器, 逐个产出编码后的 ID。 用于处理无法一次性读入内存的大文件。
"""
for chunk in iterable:
# 对每一块文本进行编码, 并通过 yield 吐出结果
yield from self.encode(chunk)