[PD] Add custom memory pool option to support Mooncake PD with NVLink (#7264)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
shangmingc
2025-06-18 08:21:37 +08:00
committed by GitHub
parent ceaa85c9e6
commit c26d7349d3
4 changed files with 163 additions and 47 deletions

View File

@@ -0,0 +1,47 @@
import os
import threading
from importlib import resources
from typing import Dict, Final, Optional
import torch
from torch.cuda.memory import CUDAPluggableAllocator
# TODO(shangming): move this class into mooncake's package for more general use cases
class MooncakeNVLinkAllocator:
_instances: Dict[torch.device, CUDAPluggableAllocator] = {}
_lock: Final = threading.Lock()
@classmethod
def _get_so_path(cls) -> str:
"""Dynamically locate hook.so in the mooncake package installation"""
try:
# Attempt to locate package resource
with resources.path("mooncake", "hook.so") as so_path:
if so_path.exists():
return str(so_path)
except (ImportError, FileNotFoundError, TypeError):
pass
# Fallback strategy: check in package location via import metadata
try:
import mooncake
base_path = os.path.dirname(os.path.abspath(mooncake.__file__))
so_path = os.path.join(base_path, "hook.so")
if os.path.exists(so_path):
return so_path
except (ImportError, FileNotFoundError, TypeError):
raise ImportError(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL require mooncake-transfer-engine >= 0.3.3.post2."
)
@classmethod
def get_allocator(cls, device: torch.device) -> CUDAPluggableAllocator:
with cls._lock:
if device not in cls._instances:
so_path = cls._get_so_path()
cls._instances[device] = CUDAPluggableAllocator(
so_path, "mc_nvlink_malloc", "mc_nvlink_free"
)
return cls._instances[device]

View File

@@ -6,6 +6,7 @@ import random
import threading
import warnings
from collections import deque
from contextlib import nullcontext
from enum import Enum
from typing import TYPE_CHECKING, List, Optional
@@ -84,24 +85,37 @@ class ReqToMetadataIdxAllocator:
class MetadataBuffers:
def __init__(self, size: int, max_top_logprobs_num: int = 128):
# TODO: abort top_logprobs_num > 128 in PD
def __init__(
self,
size: int,
max_top_logprobs_num: int = 128,
custom_mem_pool: torch.cuda.MemPool = None,
):
self.custom_mem_pool = custom_mem_pool
device = "cuda" if self.custom_mem_pool else "cpu"
# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device="cpu"
)
self.output_token_logprobs_idx = torch.zeros(
(size, 16), dtype=torch.int32, device="cpu"
)
self.output_top_logprobs_val = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
)
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
)
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
else nullcontext()
):
# TODO: abort top_logprobs_num > 128 in PD
# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device=device
)
self.output_token_logprobs_idx = torch.zeros(
(size, 16), dtype=torch.int32, device=device
)
self.output_top_logprobs_val = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.float32, device=device
)
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device=device
)
def get_buf_infos(self):
ptrs = [