【AI Infra/MLsys】CMU 10-714 Needle:训练 PTB 爆显存?深扒 Python GC 与 GPU 显存的“跨服交流障碍”

关键词:Needle, AI Infra, Python GC, Reference Cycle, CUDA OOM, Memory Leak

0. 背景

前两个月听学长建议,下定决心转 AI Infra 方向。为了快速补齐系统知识并准备实习面试,我花了半个月时间肝完了 CMU 10-714 (Deep Learning Systems) 的课程作业。这门课要求手写一个名为 Needle 的深度学习框架,并用它来复现 PyTorch 的功能。

hw4_extra 中,任务是实现 Transformer 并训练 Language Model (PTB 数据集)。当时因为赶进度,虽然模型跑通了,但留了一个“显存缓慢泄漏”的 Bug。最近两天拿到了实习 offer,闲下来正好把这个遗留问题彻底解决,顺便作为毕设的预热。

机器配置

  • OS: Ubuntu 20.04
  • GPU: 2x NVIDIA Tesla V100 (16GB)
  • Environment: Jupyter Notebook + Needle (CUDA Backend)

1. 案发现场:显存的“灵异”增长

在实现 epoch_general_ptb 训练循环时,我的代码大致如下

def epoch_general_ptb(data, model, seq_len=40, loss_fn=nn.SoftmaxLoss(), opt=None,
        clip=None, device=None, dtype="float32"):
    # ... (初始化代码略) ...
    for i in range(0, nbatch - 1, seq_len):
        X, y = ndl.data.get_batch(data, i, seq_len, device=device, dtype=dtype)
        if h is not None:
             if isinstance(h, tuple):
                 h = tuple(x.detach() for x in h)
             else:
                 h = h.detach()
        logits, h = model(X, h)
        loss = loss_fn(logits, y)
        if opt:
            opt.reset_grad()
            loss.backward()
            if clip is not None:
                opt.clip_grad_norm(clip)
            opt.step()
        avg_loss += loss.numpy() * y.shape[0]
        avg_acc += np.sum(logits.numpy().argmax(axis=1) == y.numpy())
    return avg_acc/total_samples, avg_loss/total_samples

代码逻辑看起来已经非常严谨。然而,当你盯着 watch -n 1 nvidia-smi 时,恐怖的事情发生了:

显存占用并没有在第一个 Batch 后稳定下来,而是从起步的 3GB 开始,以每秒约 30MB 的速度匀速爬升。

3.0GB -> 4.5GB -> 8.2GB -> ... -> 16.1GB (CUDA Out Of Memory)

不到总 Epoch 数的一半,程序直接崩盘。

2. 并非“常规嫌疑犯”

作为具备一定经验的炼丹师,我首先排查了两个最容易导致 Transformer 显存爆炸的经典错误,但发现我的代码并没有这两个问题。为了防止误判,这里简单列举一下这两个“常规嫌疑犯”:

2.1 嫌疑人 A:无限生长的 BPTT

如果在新的 batch 循环中没有切断上一轮 batch 中 Hidden State 的计算图,h 会记录从 Epoch 开始到现在的完整历史,导致显存爆炸。
我的代码(已做正确处理)

# 每个 batch 开始前显式 detach
if h is not None:
    if isinstance(h, tuple):
        h = tuple(x.detach() for x in h) # 切断与上一个 batch 的连接
    else:
        h = h.detach() # 切断与上一个 batch 的连接

2.2 嫌疑人 B:囤积 Tensor 的 Loss

如果在统计 Loss 时直接累加 Tensor 对象,会将整个计算图挂在 avg_loss 变量上。
我的代码(已做正确处理)

# 使用 .numpy() 获取纯数值
avg_loss += loss.numpy() * batch_cnt 

既然常规逻辑都对,为什么显存还是炸了?

3. 真凶:Python GC 与 GPU 的“信息不对称”

经过深度调试,我发现问题的根源在于 Python 的垃圾回收机制(GC)跟不上 GPU 显存的消耗速度,且二者存在巨大的信息差。

3.1 核心矛盾:小对象 vs 大显存

Python 的自动 GC 主要是基于引用计数(Reference Counting)分代回收(Generational GC)

  • Python 的视角:Needle 中的一个 Tensor 对象,在 Python 解释器里只是一个轻量级的 Wrapper,可能只占几百字节。Python 觉得内存很空闲,完全没必要触发 GC。
  • GPU 的视角:这个轻量级 Wrapper 背后,可能管理着几百 MB 的显存数据。

结果:Python 觉得“我还行,再攒攒垃圾再扔”,而 GPU 已经在喊“救命,我装不下了”。

3.2 循环引用(Reference Cycle)

如果是简单的变量,引用计数归零会立即释放。但在深度学习框架(无论是 Needle 还是 PyTorch)的计算图中,循环引用无处不在:

Output Tensor -> Op Node -> Input Tensor -> ... -> Output Tensor

这种引用闭环导致引用计数永远不为 0。这些“僵尸对象”必须等待 Python 的 Cyclic GC(循环垃圾收集器) 介入扫描才能被回收。由于 Python 并不感知 GPU 显存压力,这个扫描往往来得太晚。

3.3 解决方案:手动挡内存管理

为了强迫 Python 释放显存,我在每个 Batch 结束时加入了极其激进的内存管理代码:

# 1. 手动斩断引用
del logits, loss, X, y 
# 2. 强制 GC 工作
import gc
gc.collect()

加入这两行后,显存占用瞬间被死死按在了 3GB - 9GB之间,直到训练结束

4. 深度复盘:只有 gc.collect() 够吗?del 到底起到了什么作用?

我在调试过程中对比了“手动 GC”与“手动 GC + del”的两种方案。结果非常有意思:两者都不会爆显存,但显存的“水位线”表现不同。

这里有一个极易被忽略的细节:Hidden State h 是计算图的“锚点”。

4.1 方案 A:只加 gc.collect(),不加 del -> 显存水位最高

  • 状态:在 Loop 结尾调用 gc.collect() 时:
    1. 上一轮的尸体:由于 h 在本轮开头已经被 detach()(且被重新赋值),上一轮的计算图变成了无主垃圾。GC 会成功回收它们。
    2. 这一轮的树干h 依然指向当前轮次的 Transformer 输出,它像一个锚点,死死拉住了当前轮次的主干计算图,GC 无法回收
    3. 这一轮的枝叶losslogits 变量依然活着,它们拉住了计算图的后半段(Softmax 等),GC 也无法回收
  • 结果:显存里同时躺着【模型权重 + 这一轮完整的计算图】。且因为 loss 等变量要等到下一轮赋值时才释放,显存峰值会更高。

4.2 方案 B:先 delgc.collect() -> 显存水位降低(剪除枝叶)

    # ... Loop 结尾 ...
    del logits, loss, X, y  # [动作 1]:手动斩断“枝叶”
    gc.collect()            # [动作 2]:立刻回收垃圾
  • 差异在哪里?
    虽然 h 依然拉着计算图的“主干”(Transformer Layers),但是通过 del,我们显式地切断了 losslogits 对计算图“枝叶”(Projection, Softmax, CrossEntropy)的引用。
  • GC 做了什么?
    1. 回收上一轮的全部残留(由开头的 detach 触发)。
    2. 回收这一轮X, y, logits, loss 以及它们对应的计算图节点(由 del 触发)。
  • 结果
    虽然当前轮次的 Transformer 主干显存要等到下一轮开头 h.detach() 时才能释放,但我们通过手动 del 提前释放了部分显存(取决于 Vocab 大小)

    更重要的是,这避免了变量重绑定带来的短暂双倍占用(即在下一轮 loss = ... 执行前的一瞬间,新旧 loss 可能共存)。

4.3 结论

  • 这一轮的 GC:负责回收 上一轮的计算图(树干) + 这一轮的枝叶(Loss/Logits)
  • 这一轮的 h:负责暂时“撑住”这一轮的树干,直到下一轮 detach 接力。

总结del + gc 的组合拳,本质上是将显存释放的时间点尽可能提前,消除变量重绑定期间的显存波峰,这在显存极限操作(如 V100 跑大模型)中是防止 OOM 的关键一根稻草。

5. 最终代码

这是修改后完美运行的代码,显存稳如老狗:

def epoch_general_ptb(data, model, seq_len=40, loss_fn=nn.SoftmaxLoss(), opt=None,
        clip=None, device=None, dtype="float32"):

    # ... (初始化代码略) ...
    import gc # 引入垃圾回收模块

    for i in range(0, nbatch - 1, seq_len):
        X, y = ndl.data.get_batch(data, i, seq_len, device=device, dtype=dtype)

        # [常规操作] Detach Hidden State
        if h is not None:
             if isinstance(h, tuple):
                 h = tuple(x.detach() for x in h)
             else:
                 h = h.detach()

        logits, h = model(X, h)
        loss = loss_fn(logits, y)

        if opt:
            opt.reset_grad()
            loss.backward()
            if clip is not None:
                opt.clip_grad_norm(clip)
            opt.step()

        # [常规操作] 累加 Loss 时只取数值
        avg_loss += loss.numpy() * y.shape[0]
        avg_acc += np.sum(logits.numpy().argmax(axis=1) == y.numpy())

        # [关键操作] 手动内存管理防爆
        # 必须先 del 移除引用,再 gc 回收循环引用
        del logits, loss, X, y
        gc.collect()

    # 循环结束后清理最后的残留
    del h
    gc.collect()

    return avg_acc/total_samples, avg_loss/total_samples

6. 额外的小插曲:消失的 1.6GB 与“删不掉”的 300MB

evaluate_ptb 运行结束回到 Jupyter 命令行后,我注意到显存依然停留在 1952MiB。此时训练已经结束,为什么显存没有释放?

6.1 1952MiB 里装的是什么?

Jupyter Notebook 的特性是变量持久化。虽然代码单元格跑完了,但定义的全局变量依然“活”在内存里:

  1. Model Parameters: model 对象依然存在,占用显存。

  2. Optimizer States: 如果优化器对象还在引用参数(或者框架缓存了动量信息),这部分通常是模型参数量的 2-3 倍(Adam 需要保存 $$m$$ 和 $$v$$ 两个动量矩阵)。

  3. GPU Dataset: train_datacorpus 是被搬运到 device 上的,这部分数据依然驻留。

  4. Framework Caching: 深度学习框架通常不会立即归还申请过的显存,而是作为 Cache 保留以备下次使用。

6.2 为什么剩下 314MiB 删不掉?

当我执行了 del model 并配合 gc.collect() 后,显存瞬间从 1952MiB 掉到了 314MiB。这证明模型和中间变量的显存已经被成功回收。

但无论我如何 gc.collect(),这最后的 314MiB 始终雷打不动。这不是内存泄漏,而是 CUDA Context(上下文) 的硬性开销。

  • 什么是 CUDA Context:只要 Python 进程(PID)调用了 ndl.cuda(),GPU 驱动就会创建一个上下文环境,用于加载动态链接库(如 cuBLAS 用于矩阵乘法)、管理设备状态和映射内存。
  • 入场费:这 ~300MB 可以理解为使用 GPU 的“入场费”。
  • 生命周期:它与 Python 进程绑定。只要 Jupyter Kernel 不重启(PID 不变),这部分显存就永远不会释放。

结论:显存管理非常健康。1.6GB 的模型相关数据已成功回收,剩下的 300MB 是驱动层面的正常驻留。如果想归零,唯有点击 Kernel -> Restart Kernel

7. 总结 (Takeaways)

  1. AI Infra 的视角:不要只盯着 Python 代码的逻辑正确性。在深度学习系统中,Python 只是指挥官,真正的资源在底层 C++/CUDA。当两者对资源紧张程度的感知不一致时,必须人工介入。
  2. GC 的局限性:在显存敏感的场景(大模型、长序列 RNN),依赖 Python 自动 GC 是危险的。
  3. del 的重要性gc.collect() 只能回收“垃圾”。如果不 del 掉变量,它们就不是垃圾,GC 想收也收不走。清空显存 = 移除引用 (del) + 回收孤岛 (gc)

8. 附录

相关代码仓库(也是我的CMU10-714实现)
https://github.com/Crzax/CMU10-714-DeepLearningSystems

© 版权声明
THE END
努力变成更好的自己
点赞6 分享
评论 抢沙发
头像
欢迎您留下宝贵的见解!
提交
头像

昵称

取消
昵称表情代码图片

    暂无评论内容