mmap behavior optimization #17

Draft
sfiisf wants to merge 4 commits into refine_offload from feat/mmap-optimization

sfiisf (Collaborator) commented Apr 23, 2026

Usage is the same as in #13: set the environment variable MMAP_MEM_THRESHOLD_GB=x. When CPU memory is below x GB, any offload is redirected to mmap.
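
A minimal sketch of how such a threshold check could look. The helper names are hypothetical and not taken from this PR, and two points are assumptions: that "CPU memory" means currently available memory, and that GB is read as 2**30 bytes. A value of 0 disables mmap, matching the benchmark configurations later in this thread.

```python
import os

GIB = 1024 ** 3  # assumption: the "GB" in the variable name is treated as GiB


def mmap_threshold_bytes(env=os.environ):
    """Read MMAP_MEM_THRESHOLD_GB; a missing or zero value disables mmap."""
    raw = env.get("MMAP_MEM_THRESHOLD_GB", "0")
    return int(float(raw) * GIB)


def should_offload_to_mmap(available_cpu_bytes, env=os.environ):
    """Hypothetical helper: True when an offload should go to mmap."""
    threshold = mmap_threshold_bytes(env)
    return threshold > 0 and available_cpu_bytes < threshold


# With x=512 a 64 GiB host always offloads to mmap; with x=0 it never does.
always = should_offload_to_mmap(64 * GIB, {"MMAP_MEM_THRESHOLD_GB": "512"})
never = should_offload_to_mmap(64 * GIB, {"MMAP_MEM_THRESHOLD_GB": "0"})
print(always, never)
```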


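mmap currently has a few problems:

  1. Offload misbehaves when an OOM occurs:
    On OOM, memory_to_free is set to 1e30. If model_loaded_size > available_memory - mmap_mem_threshold, the code always takes the partially_unload path. The intent was presumably that a large enough memory_to_free lets partially_unload offload the whole model, but in practice it only unloads a small part. VRAM therefore stays occupied, the next request hits OOM again, and the same logic repeats.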
        if min(memory_to_free, model_loaded_size) > available_memory - mmap_mem_threshold or memory_to_free < model_loaded_size:
            partially_unload = True
        else:
            partially_unload = False
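
To make the failure mode concrete, the condition above can be evaluated on its own. This is only a sketch: the function name is hypothetical and the sizes are illustrative, not measured values from this PR.

```python
GIB = 1024 ** 3


def is_partial_unload(memory_to_free, model_loaded_size,
                      available_memory, mmap_mem_threshold):
    """Same boolean expression as the snippet above, wrapped for testing."""
    return (
        min(memory_to_free, model_loaded_size) > available_memory - mmap_mem_threshold
        or memory_to_free < model_loaded_size
    )


# OOM path: memory_to_free is forced to 1e30, so min(...) collapses to
# model_loaded_size and the second clause is always False.
# Illustrative sizes: 20 GiB model, 64 GiB free CPU memory.
with_big_threshold = is_partial_unload(1e30, 20 * GIB, 64 * GIB, 512 * GIB)
with_zero_threshold = is_partial_unload(1e30, 20 * GIB, 64 * GIB, 0)
print(with_big_threshold, with_zero_threshold)
```

With a threshold larger than the available memory, the right-hand side goes negative, so every OOM request lands on the partial path regardless of how much memory was asked for.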

https://github.com/siliconflow/cce/issues/176#issuecomment-4249756325

  2. CPU memory can still blow up:
    So far, the .cpu() call inside to_mmap has been observed to cause a noticeable memory peak.
def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
...
    cpu_tensor = t.cpu()
    torch.save(cpu_tensor, temp_file)
...
    mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)

https://github.com/siliconflow/ComfyGridRuntime/issues/181#issuecomment-4220391298

sfiisf (Collaborator, Author) commented Apr 23, 2026

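Starting with torch 2.7, torch.cuda.gds provides a thin wrapper over NVIDIA cuFile / GPUDirect Storage (DMA-like: GPU<->disk transfers do not go through the CPU): https://docs.pytorch.org/docs/stable/generated/torch.cuda.gds.GdsFile.html
Note, however, that save_storage writes raw binary data, which cannot be read back directly with torch.load or similar. The matching call for reading back to the GPU is load_storage, but that loses the offload semantics, so this PR tries mmap + frombuffer to read the data back on the CPU.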
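
Because the file is a header-less byte dump, the reader has to supply dtype and shape itself. A stdlib-only sketch of that idea follows; array and memoryview.cast are stand-ins for the GPU write and for torch.frombuffer, and the helper names are hypothetical.

```python
import mmap
import os
import tempfile
from array import array


def write_raw_floats(path, values):
    """Dump float32 values as raw bytes with no header (stand-in for save_storage)."""
    with open(path, "wb") as f:
        array("f", values).tofile(f)


def map_raw_floats(path, count):
    """Map the file read-only and view it as float32 (stand-in for torch.frombuffer)."""
    num_bytes = count * array("f").itemsize
    with open(path, "rb") as f:
        mm = mmap.mmap(f.fileno(), length=num_bytes, access=mmap.ACCESS_READ)
    # The element type is not stored in the file; the caller must remember it.
    return mm, memoryview(mm).cast("f")


path = os.path.join(tempfile.mkdtemp(), "raw.bin")
write_raw_floats(path, [1.0, 2.5, -3.0])
mm, view = map_raw_floats(path, 3)
values = view.tolist()

# Release the view before closing the mapping, otherwise close() raises BufferError.
view.release()
mm.close()
os.remove(path)
print(values)
```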

Tested with a sample along these lines:

import mmap
import os

import torch

SHAPE = (1024, 1024, 1024)
DEVICE = "cuda:2"

# gds
src = torch.randn(*SHAPE, device=DEVICE)
file = torch.cuda.gds.GdsFile("temp1.pt", os.O_CREAT | os.O_RDWR)
file.save_storage(src.untyped_storage(), offset=0)
num_bytes = src.numel() * src.element_size()
src_dtype = src.dtype
src_shape = src.shape

fo = open("temp1.pt", "rb")
mm = mmap.mmap(fo.fileno(), length=num_bytes, access=mmap.ACCESS_READ)
dest = torch.frombuffer(mm, dtype=src_dtype).reshape(src_shape)

# old
src = torch.randn(*SHAPE, device=DEVICE)
src_cpu = src.cpu()
torch.save(src_cpu, "temp2.pt")
dest = torch.load("temp2.pt", map_location="cpu", mmap=True, weights_only=False)

The results show a substantial improvement:

gds mmap time cost: 3614.706 ms
gds mmap cpu rss start: 0.498 GB
gds mmap cpu rss end:   0.641 GB
gds mmap cpu rss peak:  0.641 GB
gds mmap cpu rss current delta: 0.143 GB
gds mmap cpu rss peak delta:    0.142 GB
cpu save mmap time cost: 9936.152 ms
cpu save mmap cpu rss start: 0.641 GB
cpu save mmap cpu rss end:   4.642 GB
cpu save mmap cpu rss peak:  4.642 GB
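
A quick arithmetic check on the log above, as a sketch. It assumes the tensor is float32 (the default dtype of torch.randn) and copies the timings and RSS figures verbatim from the log.

```python
SHAPE = (1024, 1024, 1024)
BYTES_PER_ELEMENT = 4  # assumption: float32, the torch.randn default dtype

# Size of the test tensor in GiB.
tensor_gib = SHAPE[0] * SHAPE[1] * SHAPE[2] * BYTES_PER_ELEMENT / 2 ** 30

# Figures copied from the benchmark log.
gds_ms, cpu_ms = 3614.706, 9936.152
gds_rss_delta = 0.641 - 0.498   # RSS growth on the gds + mmap path
cpu_rss_delta = 4.642 - 0.641   # RSS growth on the .cpu() + torch.save path

speedup = cpu_ms / gds_ms
print(f"tensor size: {tensor_gib:.1f} GiB")
print(f"gds RSS growth: {gds_rss_delta:.3f} GB, cpu RSS growth: {cpu_rss_delta:.3f} GB")
print(f"gds path is {speedup:.2f}x faster")
```

The RSS growth on the .cpu() path is about the size of the whole tensor, which is consistent with .cpu() materializing a full host copy before the save.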

sfiisf (Collaborator, Author) commented Apr 23, 2026

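Current gds implementation: 00c0028
Test workflow: https://github.com/anveshane/Comfyui_turbodiffusion/blob/main/turbowan_workflow.json
Three configurations are compared: MMAP_MEM_THRESHOLD_GB=0 (mmap disabled); MMAP_MEM_THRESHOLD_GB=512 (everything goes through mmap, using .cpu() + torch.save); and everything through mmap using gds. The results are listed below.

Elapsed time rises noticeably once mmap is enabled (roughly 690-765 s per prompt versus 29-110 s without mmap), and the gds path is somewhat slower than the .cpu() + save path. Peak memory, however, drops noticeably: after the first prompts the RSS peak settles at about 11,100 MiB with gds, versus about 21,300-24,600 MiB with .cpu() + save and about 37,700 MiB without mmap.

To lower the peak further, one option is to pair save_storage with load_storage and replace offloading to the CPU entirely, i.e. a disk_offload that bypasses the CPU completely. That would likely require larger changes, so it will be tested further and considered as needs arise.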

MMAP_MEM_THRESHOLD_GB=0:
Prompt af7337e5-7a82-4cc8-b48d-d3a13bfbd108 memory stats (109.56s):
  RSS start/end/peak: 1397.63 MiB / 37158.70 MiB / 37158.70 MiB
  USS start/end/peak: 1322.99 MiB / 37093.86 MiB / 37093.86 MiB
Prompt 0e10e1e5-ee24-4360-bd1f-30af953bf119 memory stats (32.95s):
  RSS start/end/peak: 37158.70 MiB / 37527.13 MiB / 37527.13 MiB
  USS start/end/peak: 37093.91 MiB / 37454.35 MiB / 37454.44 MiB
Prompt 0445ec03-6a6a-42d8-8eea-8c7ea07d4f16 memory stats (29.47s):
  RSS start/end/peak: 37527.13 MiB / 37719.08 MiB / 37719.08 MiB
  USS start/end/peak: 37454.45 MiB / 37654.56 MiB / 37654.56 MiB



MMAP_MEM_THRESHOLD_GB=512 all_2cpu_save:
Prompt 71859d6b-ebae-4586-b75b-71bd0b01600f memory stats (699.34s):
  RSS start/end/peak: 1396.25 MiB / 24430.46 MiB / 30129.74 MiB
  USS start/end/peak: 955.58 MiB / 23880.52 MiB / 29583.76 MiB
Prompt a4550d38-d314-4ff1-a241-1adde9bec94f memory stats (693.21s):
  RSS start/end/peak: 24430.52 MiB / 24423.28 MiB / 24639.08 MiB
  USS start/end/peak: 23880.59 MiB / 23880.72 MiB / 24093.73 MiB
Prompt fa46e5f0-bf95-43ee-8137-c478d13da96f memory stats (725.49s):
  RSS start/end/peak: 24423.35 MiB / 22552.54 MiB / 24611.89 MiB
  USS start/end/peak: 23880.73 MiB / 22018.51 MiB / 24068.34 MiB
Prompt 51cf9b12-1820-4042-998d-10c20f6729e3 memory stats (694.32s):
  RSS start/end/peak: 21853.77 MiB / 21121.22 MiB / 21853.77 MiB
  USS start/end/peak: 21319.45 MiB / 20590.23 MiB / 21319.45 MiB
Prompt 263d6b45-8009-4f73-9ac4-d04362bce335 memory stats (711.84s):
  RSS start/end/peak: 21121.36 MiB / 21091.86 MiB / 21335.42 MiB
  USS start/end/peak: 20590.22 MiB / 20559.59 MiB / 20803.21 MiB


MMAP_MEM_THRESHOLD_GB=512 all_gds:
Prompt 7f0618c0-a778-49f3-bb4b-9844724a637c memory stats (739.28s):
  RSS start/end/peak: 1400.79 MiB / 17361.43 MiB / 23111.11 MiB
  USS start/end/peak: 957.23 MiB / 16823.97 MiB / 22572.88 MiB
Prompt 24dfc9db-9bb9-4ad6-9cbd-d89ab6a0909b memory stats (762.83s):
  RSS start/end/peak: 17361.59 MiB / 10944.14 MiB / 17478.92 MiB
  USS start/end/peak: 16823.98 MiB / 10398.46 MiB / 16988.87 MiB
Prompt ede03cc6-874b-4df1-9d22-975f568f7887 memory stats (749.50s):
  RSS start/end/peak: 10944.35 MiB / 10934.63 MiB / 11148.28 MiB
  USS start/end/peak: 10398.20 MiB / 10397.01 MiB / 10565.04 MiB
Prompt 90f6be16-2f29-476a-b214-db6b340e7f8d memory stats (760.65s):
  RSS start/end/peak: 10934.68 MiB / 10927.91 MiB / 11121.48 MiB
  USS start/end/peak: 10397.09 MiB / 10396.98 MiB / 10591.96 MiB
Prompt f256e8ba-a0d8-4c30-9457-708c2f728022 memory stats (764.62s):
  RSS start/end/peak: 10928.00 MiB / 10929.80 MiB / 11117.88 MiB
  USS start/end/peak: 10396.98 MiB / 10397.15 MiB / 10588.71 MiB

Comment thread comfy/model_patcher.py Outdated

# Load with mmap - this doesn't load all data into RAM
mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
file = torch.cuda.gds.GdsFile(temp_file, os.O_CREAT | os.O_RDWR)
Review comment (Collaborator):

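A prerequisite check is probably needed here to see whether cuda is available.
If this gets merged into the main branch later, it will affect other chips such as musa and npu.

Also, I don't think a hardware abstraction layer is needed for now (in case you were considering one after reading this). That kind of abstraction is better done after we have actually adapted to a few vendors' chips.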
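
One possible shape for such a check, as a sketch. The function name is hypothetical, and the check only confirms that a CUDA build of torch with the torch.cuda.gds module is present and that a CUDA device is visible. It does not verify that cuFile itself works on the target filesystem.

```python
import importlib


def gds_offload_supported():
    """Hypothetical guard: True only on a CUDA build with torch.cuda.gds present."""
    try:
        torch = importlib.import_module("torch")
        importlib.import_module("torch.cuda.gds")
    except (ImportError, OSError):
        # torch missing, too old to ship torch.cuda.gds, or failed to load.
        return False
    # torch.version.cuda is None on builds without CUDA (e.g. CPU-only, ROCm).
    if getattr(torch.version, "cuda", None) is None:
        return False
    return bool(torch.cuda.is_available())


# Callers would fall back to the existing .cpu() + torch.save path when False.
USE_GDS = gds_offload_supported()
print("gds path enabled:", USE_GDS)
```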

Comment thread comfy/model_patcher.py Outdated
del t
gc.collect()

fo = open(temp_file, "rb")
Review comment (Collaborator):

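This likely leaks a resource (a file descriptor), which is only reclaimed passively when gc gets to it.

After the function returns:
├── fd1 (from fo=open)        → leaked! nobody closes it
├── fd2 (dup inside mmap)     → valid while the mmap object is alive
│   └── the mmap object is held indirectly by the Tensor
│       └── Tensor GC → mmap GC → munmap + close(fd2)
│
└── file temp_file            → _cleanup deletes it when the Tensor is GC'd
    └── but fd1 is still open! On Windows the file may fail to delete

I asked an AI, which suggested the following version: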

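def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
    # ... create the file, call save_storage, etc. ...

    with open(temp_file, "rb") as fo:
        mm = mmap.mmap(fo.fileno(), length=num, access=mmap.ACCESS_READ)
        mmap_tensor = torch.frombuffer(mm, dtype=t_type).reshape(t_shape).cpu()

    # ✅ fo is closed here; mm stays alive (it holds its own dup'ed fd)
    # mm still has to be cleaned up when the Tensor is destroyed

    # Wrap cleanup so it handles both mm and the file
    def _cleanup():
        try:
            # Ideally the Tensor no longer references mm's memory at this point.
            # In practice that requires mm to be GC'd, so this is best effort.
            if os.path.exists(temp_file):
                os.remove(temp_file)
        except Exception:
            pass

    # Better: track the tensor's lifetime with a weak-reference finalizer
    weakref.finalize(mmap_tensor, _cleanup)

    # Optional: keep a reference to mm to prevent premature GC (if needed),
    # although the Tensor usually holds enough information internally

    return mmap_tensor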
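
The two claims in this thread (the mapping survives closing fo, and a finalizer can unmap and delete the file) can be checked with the standard library alone. In this sketch MappedBuffer is a hypothetical stand-in for the tensor that owns the mapping. It is not the PR's implementation.

```python
import gc
import mmap
import os
import tempfile
import weakref


def _cleanup(mm, path):
    """Unmap first, then delete the backing file."""
    mm.close()
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


class MappedBuffer:
    """Hypothetical owner of the mapping, standing in for the mmap-backed tensor."""

    def __init__(self, path, length):
        with open(path, "rb") as fo:
            self.mm = mmap.mmap(fo.fileno(), length=length, access=mmap.ACCESS_READ)
        # fo is closed here; the mmap object keeps its own duplicated descriptor.
        self.fo_closed = fo.closed
        # The callback receives mm and path, never self, so it cannot keep self alive.
        self._finalizer = weakref.finalize(self, _cleanup, self.mm, path)


path = os.path.join(tempfile.mkdtemp(), "offload.bin")
with open(path, "wb") as f:
    f.write(b"hello mmap")

buf = MappedBuffer(path, 10)
first = buf.mm[:5]       # still readable although fo is closed
closed = buf.fo_closed
del buf                  # dropping the owner triggers the finalizer
gc.collect()
file_gone = not os.path.exists(path)
print(first, closed, file_gone)
```

Passing mm explicitly to the finalizer lets it be closed deterministically before the file is removed, which is the ordering the Windows concern above depends on.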

ccndcn (Collaborator) commented May 7, 2026


3 participants