CS336 任务一 实验
一、学习率
#!/bin/bash
# Learning-rate sweep.
#
# Runs cs336_basics/train.py once per peak learning rate in LR_LIST, with all
# other hyper-parameters held fixed. Each run gets its own run name and output
# directory and is logged to the same WandB project. A run that exits non-zero
# is reported and the sweep continues with the next learning rate.

# --- 1. Peak learning rates to try ---
LR_LIST=(1e-4 3e-4 6e-4 1e-3 3e-3 6e-3 1e-2 3e-2)

# --- 2. Hyper-parameters shared by every run ---
MAX_ITERS=7000
WARMUP_ITERS=700
MAX_NORM=1.0
BATCH_SIZE=32
CONTEXT_LEN=256
VOCAB_SIZE=10000

# Root directory for per-run outputs, and the WandB project name.
OUTPUT_ROOT="model_result/sweep_lr"
WANDB_PROJECT="cs336-pretraining-TinyStories-LR"

#######################################
# Prints one tenth of a learning rate (used as the minimum learning rate).
# The value is handed to awk with -v instead of being spliced into the awk
# program text, so it can never be parsed as awk code.
# Arguments: $1 - peak learning rate (e.g. 3e-4)
# Outputs:   the scaled value in awk's default number format (e.g. 3e-05)
#######################################
min_lr_for() {
  awk -v lr="$1" 'BEGIN { print lr * 0.1 }'
}

# Learning rates whose training run exited non-zero.
FAILED_LRS=()

# --- 3. Run the sweep ---
for LR in "${LR_LIST[@]}"; do
  MIN_LR=$(min_lr_for "$LR")
  RUN_NAME="lr${LR}_min${MIN_LR}_step${MAX_ITERS}"
  OUT_DIR="${OUTPUT_ROOT}/${RUN_NAME}"

  echo "--------------------------------------------------------------"
  echo "📊 实验启动: max_lr=$LR, min_lr=$MIN_LR"
  echo "📛 监控名称: $RUN_NAME"
  echo "--------------------------------------------------------------"

  # Test the command directly rather than inspecting $? afterwards, so no
  # statement can slip in between the run and its status check.
  if ! python cs336_basics/train.py \
    --train_data_path data/TinyStoriesV2-GPT4-train.bin \
    --valid_data_path data/TinyStoriesV2-GPT4-valid.bin \
    --run_name "$RUN_NAME" \
    --vocab_size "$VOCAB_SIZE" \
    --num_layers 4 --num_heads 16 --d_model 512 --d_ff 1344 \
    --max_iters "$MAX_ITERS" \
    --batch_size "$BATCH_SIZE" \
    --context_length "$CONTEXT_LEN" \
    --lr "$LR" \
    --min_lr "$MIN_LR" \
    --warmup_iters "$WARMUP_ITERS" \
    --max_norm "$MAX_NORM" \
    --out_dir "$OUT_DIR" \
    --device cuda \
    --wandb_project "$WANDB_PROJECT"; then
    echo "❌ 警告: 学习率 $LR 导致训练中断, 跳过。"
    FAILED_LRS+=("$LR")
  fi
done

# Summarise failures so they are not lost in the scroll-back of a long sweep.
if (( ${#FAILED_LRS[@]} > 0 )); then
  echo "learning rates whose run failed: ${FAILED_LRS[*]}" >&2
fi

echo "🎉 所有学习率扫参实验跑完! 请打开 WandB 观察曲线。"
二、BatchSize
#!/bin/bash
# Batch-size sweep with linear learning-rate scaling.
#
# Runs cs336_basics/train.py once per batch size in BS_LIST. The peak learning
# rate for each run is BASE_LR * (batch size / BASE_BS), and the minimum
# learning rate is 10% of that peak. A run that exits non-zero (for example
# because of a GPU out-of-memory error) is reported and the sweep continues.

# --- 1. Core settings ---
# Reference point taken from the earlier learning-rate experiment.
BASE_BS=32        # batch size used in the reference run
BASE_LR=0.0003    # i.e. 3e-4
MAX_ITERS=7000    # same number of iterations for every run
WARMUP_ITERS=700  # 10% of MAX_ITERS
MAX_NORM=1.0
VOCAB_SIZE=10000
CONTEXT_LEN=256

# --- 2. Batch sizes to test ---
# Larger sizes (128, 256, 512) did not fit in GPU memory on the author's card,
# so only these are run for now. Any size that still fails is skipped.
BS_LIST=(1 8 32 64)
# BS_LIST=(1)

# --- 3. Paths ---
OUTPUT_ROOT="model_result/sweep_bs"
WANDB_PROJECT="cs336-pretraining-TinyStories-BS"

#######################################
# Prints the linearly scaled peak learning rate for a batch size:
#   BASE_LR * (batch size / BASE_BS)
# Values are handed to awk with -v instead of being spliced into the awk
# program text.
# Globals:   BASE_LR, BASE_BS (read)
# Arguments: $1 - batch size
# Outputs:   the learning rate in awk's default number format
#######################################
scaled_lr_for() {
  awk -v base_lr="$BASE_LR" -v bs="$1" -v base_bs="$BASE_BS" \
    'BEGIN { print base_lr * (bs / base_bs) }'
}

#######################################
# Prints one tenth of the learning rate given as $1.
#######################################
min_lr_for() {
  awk -v lr="$1" 'BEGIN { print lr * 0.1 }'
}

# Batch sizes whose training run exited non-zero.
FAILED_BS=()

# --- 4. Run the sweep ---
for BS in "${BS_LIST[@]}"; do
  LR=$(scaled_lr_for "$BS")
  MIN_LR=$(min_lr_for "$LR")
  RUN_NAME="bs${BS}_lr${LR}_step${MAX_ITERS}"
  OUT_DIR="${OUTPUT_ROOT}/${RUN_NAME}"

  echo "============================================================"
  echo "🚀 启动实验: Batch Size = $BS"
  echo "📈 线性缩放学习率: $LR (Min: $MIN_LR)"
  echo "============================================================"

  # On failure, record the batch size and fall through to the next one.
  if ! python cs336_basics/train.py \
    --train_data_path data/TinyStoriesV2-GPT4-train.bin \
    --valid_data_path data/TinyStoriesV2-GPT4-valid.bin \
    --run_name "$RUN_NAME" \
    --vocab_size "$VOCAB_SIZE" \
    --num_layers 4 --num_heads 16 --d_model 512 --d_ff 1344 \
    --max_iters "$MAX_ITERS" \
    --batch_size "$BS" \
    --context_length "$CONTEXT_LEN" \
    --lr "$LR" \
    --min_lr "$MIN_LR" \
    --warmup_iters "$WARMUP_ITERS" \
    --max_norm "$MAX_NORM" \
    --out_dir "$OUT_DIR" \
    --device cuda \
    --wandb_project "$WANDB_PROJECT"; then
    echo "⚠️ 警告: Batch Size $BS 运行失败 (可能是 OOM), 正在尝试下一组..."
    FAILED_BS+=("$BS")
  fi
done

# Summarise failures so they are not lost in the scroll-back of a long sweep.
if (( ${#FAILED_BS[@]} > 0 )); then
  echo "batch sizes whose run failed: ${FAILED_BS[*]}" >&2
fi

echo "🎉 所有 Batch Size 消融实验已跑完!"
三、消融实验
baseline
# Baseline training run for the ablation study (batch size 32, lr 6e-4).
# The arguments are collected in an array so each option sits on its own line
# without backslash continuations.
# NOTE(review): unlike the sweep scripts, no --max_norm is passed here, so
# whatever default train.py uses applies — confirm that is intended.
baseline_args=(
  --train_data_path data/TinyStoriesV2-GPT4-train.bin
  --valid_data_path data/TinyStoriesV2-GPT4-valid.bin
  --run_name "baseline_bs32_lr6e4"
  --vocab_size 10000
  --num_layers 4 --num_heads 16 --d_model 512 --d_ff 1344
  --max_iters 7000
  --batch_size 32
  --context_length 256
  --lr 6e-4
  --min_lr 6e-5
  --warmup_iters 700
  --out_dir model_result/TinyStories_baseline
  --wandb_project "cs336-pretraining-TinyStories-Ablations"
  --device cuda
)
python cs336_basics/train.py "${baseline_args[@]}"
NoRms
# Ablation run: same settings as the baseline command, plus the
# --no_rms_norm flag, with its own run name and output directory.
# The arguments are collected in an array so each option sits on its own line
# without backslash continuations.
no_rms_norm_args=(
  --train_data_path data/TinyStoriesV2-GPT4-train.bin
  --valid_data_path data/TinyStoriesV2-GPT4-valid.bin
  --no_rms_norm
  --run_name "ablation_no_rms_norm"
  --vocab_size 10000
  --num_layers 4 --num_heads 16 --d_model 512 --d_ff 1344
  --max_iters 7000
  --batch_size 32
  --context_length 256
  --lr 6e-4
  --min_lr 6e-5
  --warmup_iters 700
  --out_dir model_result/TinyStories_ablation_no_rms_norm
  --wandb_project "cs336-pretraining-TinyStories-Ablations"
  --device cuda
)
python cs336_basics/train.py "${no_rms_norm_args[@]}"
PostNorm
# Ablation run: same settings as the baseline command, plus
# --norm_mode "post", with its own run name and output directory.
# The arguments are collected in an array so each option sits on its own line
# without backslash continuations.
post_norm_args=(
  --train_data_path data/TinyStoriesV2-GPT4-train.bin
  --valid_data_path data/TinyStoriesV2-GPT4-valid.bin
  --norm_mode "post"
  --run_name "ablation_post_norm"
  --vocab_size 10000
  --num_layers 4 --num_heads 16 --d_model 512 --d_ff 1344
  --max_iters 7000
  --batch_size 32
  --context_length 256
  --lr 6e-4
  --min_lr 6e-5
  --warmup_iters 700
  --out_dir model_result/TinyStories_ablation_post_norm
  --wandb_project "cs336-pretraining-TinyStories-Ablations"
  --device cuda
)
python cs336_basics/train.py "${post_norm_args[@]}"
NoRope
# Ablation run: same settings as the baseline command, plus the --no_rope
# flag, with its own run name and output directory.
# The arguments are collected in an array so each option sits on its own line
# without backslash continuations.
no_rope_args=(
  --train_data_path data/TinyStoriesV2-GPT4-train.bin
  --valid_data_path data/TinyStoriesV2-GPT4-valid.bin
  --no_rope
  --run_name "ablation_no_rope"
  --vocab_size 10000
  --num_layers 4 --num_heads 16 --d_model 512 --d_ff 1344
  --max_iters 7000
  --batch_size 32
  --context_length 256
  --lr 6e-4
  --min_lr 6e-5
  --warmup_iters 700
  --out_dir model_result/TinyStories_ablation_no_rope
  --wandb_project "cs336-pretraining-TinyStories-Ablations"
  --device cuda
)
python cs336_basics/train.py "${no_rope_args[@]}"
Silu
# Ablation run: same settings as the baseline command, plus
# --ffn_type "silu", with its own run name and output directory.
#
# The original first line ended in `train.py\` with no space before the
# backslash; when the following line has no leading whitespace the shell joins
# them into `train.py--train_data_path` and python is asked to run a file that
# does not exist. Collecting the arguments in an array removes the line
# continuations and with them that whole class of mistake.
silu_args=(
  --train_data_path data/TinyStoriesV2-GPT4-train.bin
  --valid_data_path data/TinyStoriesV2-GPT4-valid.bin
  --ffn_type "silu"
  --run_name "ablation_silu"
  --vocab_size 10000
  --num_layers 4 --num_heads 16 --d_model 512 --d_ff 1344
  --max_iters 7000
  --batch_size 32
  --context_length 256
  --lr 6e-4
  --min_lr 6e-5
  --warmup_iters 700
  --out_dir model_result/TinyStories_ablation_silu
  --wandb_project "cs336-pretraining-TinyStories-Ablations"
  --device cuda
)
python cs336_basics/train.py "${silu_args[@]}"
四、模型使用
import argparse
import json
import torch
from cs336_basics.Layers import TransformerLM
from cs336_basics.criterion import load_checkpoint
from cs336_basics.BPETokenizer import BPETokenizer # 按你实际文件名改
# ===== Configuration =====
# Checkpoint to load and the tokenizer files (vocabulary and merges).
CKPT_PATH = "model_result/TinyStories_baseline/ckpt_final.pt"
VOCAB_PATH = "data/TinyStoriesV2-GPT4-train/vocab.json"
MERGES_PATH = "data/TinyStoriesV2-GPT4-train/merges.txt"
# Author's note: must match training — the token is spelled without vertical bars.
SPECIAL_TOKENS = ["<endoftext>"]

# Model hyper-parameters; these must match the values used for training.
CFG = {
    "vocab_size": 10000,
    "max_seq_len": 256,
    "d_model": 512,
    "num_layers": 4,
    "num_heads": 16,
    "d_ff": 1344,
    "rope_theta": 10000.0,
}

# Sampling settings used by generate().
CONTEXT_LENGTH = 256
TEMPERATURE = 0.8
TOP_P = 0.95
MAX_NEW_TOKENS = 256

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
# =========================
def bytes_to_unicode():
    """Map every byte value (0-255) to a distinct visible Unicode character.

    Per the original author's note, this is the standard GPT-2 mapping and the
    same one used when the vocabulary was saved during training.

    Bytes in the ranges '!'..'~', '¡'..'¬' and '®'..'ÿ' map to the character
    with the same code point. Every other byte is assigned, in increasing byte
    order, the next unused code point starting at 256.

    Returns:
        dict[int, str]: 256 entries, byte value -> single-character string.
    """
    # Bytes that map to themselves.
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    # Remaining bytes get fresh code points 256, 257, ... in byte order.
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(c) for c in cs]
    return dict(zip(bs, cs))
def load_tokenizer():
    """Rebuild the BPE tokenizer from the vocabulary and merges files.

    Both files store tokens as visible characters produced by
    ``bytes_to_unicode``; this function inverts that mapping to recover the
    raw bytes.

    Reads:
        VOCAB_PATH: JSON object mapping id strings to visible-character tokens.
        MERGES_PATH: text file with one merge per line, written as two
            visible-character tokens separated by a single space. Empty lines
            are skipped.

    Returns:
        A ``BPETokenizer`` built from ``{int id: bytes}``, a list of
        ``(bytes, bytes)`` merges, and ``SPECIAL_TOKENS``.

    Raises:
        ValueError: if a non-empty merges line does not contain exactly two
            space-separated tokens.
        KeyError: if a token contains a character outside the byte mapping.
    """
    # Inverse mapping: visible character -> original byte value.
    byte_decoder = {v: k for k, v in bytes_to_unicode().items()}

    def to_bytes(text):
        # Convert a visible-character token back into its raw bytes.
        return bytes(byte_decoder[ch] for ch in text)

    # Vocabulary: {id string: visible token} -> {int id: bytes}.
    with open(VOCAB_PATH, encoding="utf-8") as f:
        json_vocab = json.load(f)
    vocab = {int(idx): to_bytes(token_str) for idx, token_str in json_vocab.items()}

    # Merges: each line "s1 s2" -> (bytes, bytes).
    merges = []
    with open(MERGES_PATH, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.rstrip("\n")
            if not line:
                continue
            parts = line.split(" ")
            if len(parts) != 2:
                # Name the file and line instead of a bare unpacking error.
                raise ValueError(
                    f"{MERGES_PATH}:{lineno}: expected two space-separated "
                    f"tokens, got {line!r}"
                )
            s1, s2 = parts
            merges.append((to_bytes(s1), to_bytes(s2)))
    return BPETokenizer(vocab, merges, SPECIAL_TOKENS)
@torch.no_grad()
def generate(model, tokenizer, prompt):
    """Sample a continuation of ``prompt`` with temperature and top-p sampling.

    Generation stops after ``MAX_NEW_TOKENS`` new tokens, or earlier when the
    end-of-text token is sampled.

    Args:
        model: called as ``model(token_ids)``; the result is indexed as
            ``[:, -1, :]`` to get the logits for the next token.
        tokenizer: must provide ``encode``, ``decode`` and a ``byte_to_id``
            mapping (presumably bytes -> token id; confirm against
            BPETokenizer).
        prompt: text to continue.

    Returns:
        The decoded text of the prompt plus all generated tokens (including
        the end-of-text token if one was sampled).

    Raises:
        ValueError: if the prompt encodes to no tokens, since there would be
            no last position to read next-token logits from.
    """
    model.eval()
    # None when the tokenizer has no entry for the end-of-text token.
    eos_id = tokenizer.byte_to_id.get("<endoftext>".encode("utf-8"))
    ids = tokenizer.encode(prompt)
    if not ids:
        raise ValueError("prompt encodes to zero tokens; provide a non-empty prompt")
    x = torch.tensor(ids, dtype=torch.long, device=DEVICE).unsqueeze(0)
    for _ in range(MAX_NEW_TOKENS):
        # Feed only the most recent CONTEXT_LENGTH tokens; keep the logits of
        # the last position.
        logits = model(x[:, -CONTEXT_LENGTH:])[:, -1, :]
        # The small epsilon guards against division by zero.
        logits = logits / (TEMPERATURE + 1e-8)
        probs = torch.softmax(logits, dim=-1)

        # Top-p: sort descending, then drop every token whose preceding
        # cumulative probability already exceeds TOP_P. The most likely token
        # is always kept because nothing precedes it.
        sp, si = torch.sort(probs, descending=True)
        mask = torch.cumsum(sp, dim=-1) - sp > TOP_P
        sp[mask] = 0.0
        sp = sp / sp.sum(dim=-1, keepdim=True)

        # Sample a position in the sorted order, then map it back to a token id.
        next_id = si.gather(-1, torch.multinomial(sp, 1))
        x = torch.cat([x, next_id], dim=1)
        if eos_id is not None and next_id.item() == eos_id:
            break
    return tokenizer.decode(x[0].tolist())
def main():
    """Command-line entry point: load the model and print one generated sample.

    Command-line arguments:
        --prompt: text to continue (default: "Once upon a time").

    Builds the tokenizer, constructs a ``TransformerLM`` from ``CFG`` on
    ``DEVICE``, calls ``load_checkpoint`` with ``CKPT_PATH`` (presumably
    restoring the trained weights into the model; no optimizer is passed), and
    prints the generated text between two separator lines.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="Once upon a time")
    args = parser.parse_args()

    tokenizer = load_tokenizer()
    model = TransformerLM(device=DEVICE, **CFG).to(DEVICE)
    load_checkpoint(CKPT_PATH, model, optimizer=None)

    print("=" * 60)
    print(generate(model, tokenizer, args.prompt))
    print("=" * 60)


if __name__ == "__main__":
    main()
本文是原创文章,采用 CC BY-NC-ND 4.0 协议,完整转载请注明来自 程序员Orion
评论
匿名评论
隐私政策
你无需删除空行,直接评论以获取最佳展示效果