[PD] Add custom memory pool option to support Mooncake PD with NVLink (#7264)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
python/sglang/srt/disaggregation/mooncake/memory_pool.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+import os
+import threading
+from importlib import resources
+from typing import Dict, Final, Optional
+
+import torch
+from torch.cuda.memory import CUDAPluggableAllocator
+
+
+# TODO(shangming): move this class into mooncake's package for more general use cases
+class MooncakeNVLinkAllocator:
+    _instances: Dict[torch.device, CUDAPluggableAllocator] = {}
+    _lock: Final = threading.Lock()
+
+    @classmethod
+    def _get_so_path(cls) -> str:
+        """Dynamically locate hook.so in the mooncake package installation"""
+        try:
+            # Attempt to locate package resource
+            with resources.path("mooncake", "hook.so") as so_path:
+                if so_path.exists():
+                    return str(so_path)
+        except (ImportError, FileNotFoundError, TypeError):
+            pass
+
+        # Fallback strategy: check in package location via import metadata
+        try:
+            import mooncake
+
+            base_path = os.path.dirname(os.path.abspath(mooncake.__file__))
+            so_path = os.path.join(base_path, "hook.so")
+            if os.path.exists(so_path):
+                return so_path
+        except (ImportError, FileNotFoundError, TypeError):
+            raise ImportError(
+                "SGLANG_MOONCAKE_CUSTOM_MEM_POOL requires mooncake-transfer-engine >= 0.3.3.post2."
+            )
+
+    @classmethod
+    def get_allocator(cls, device: torch.device) -> CUDAPluggableAllocator:
+        with cls._lock:
+            if device not in cls._instances:
+                so_path = cls._get_so_path()
+                cls._instances[device] = CUDAPluggableAllocator(
+                    so_path, "mc_nvlink_malloc", "mc_nvlink_free"
+                )
+            return cls._instances[device]
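For context, this allocator is not used directly: a caller wraps it in a torch.cuda.MemPool and then hands that pool to components such as MetadataBuffers in the next file. That wiring is outside this diff, so the helper below is only a sketch; the environment-variable gate is inferred from the ImportError message above, and the function name is illustrative.

import os

import torch

from sglang.srt.disaggregation.mooncake.memory_pool import MooncakeNVLinkAllocator


def maybe_create_custom_mem_pool(device: torch.device):
    # Illustrative helper, not part of this commit: opt in to the custom pool
    # only when the environment variable named in the error above is set.
    if os.getenv("SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false").lower() not in ("1", "true"):
        return None
    # CUDAPluggableAllocator.allocator() exposes the low-level allocator object
    # that torch.cuda.MemPool accepts; allocations made while this pool is
    # active are then served by mc_nvlink_malloc / mc_nvlink_free from hook.so.
    allocator = MooncakeNVLinkAllocator.get_allocator(device)
    return torch.cuda.MemPool(allocator.allocator())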
Second modified file (path not shown in this view):
@@ -6,6 +6,7 @@ import random
 import threading
 import warnings
 from collections import deque
+from contextlib import nullcontext
 from enum import Enum
 from typing import TYPE_CHECKING, List, Optional
 
@@ -84,24 +85,37 @@ class ReqToMetadataIdxAllocator:
 
 
 class MetadataBuffers:
-    def __init__(self, size: int, max_top_logprobs_num: int = 128):
-        # TODO: abort top_logprobs_num > 128 in PD
-
-        # We transfer the metadata of first output token to decode
-        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
-        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
-        self.output_token_logprobs_val = torch.zeros(
-            (size, 16), dtype=torch.float32, device="cpu"
-        )
-        self.output_token_logprobs_idx = torch.zeros(
-            (size, 16), dtype=torch.int32, device="cpu"
-        )
-        self.output_top_logprobs_val = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
-        )
-        self.output_top_logprobs_idx = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
-        )
+    def __init__(
+        self,
+        size: int,
+        max_top_logprobs_num: int = 128,
+        custom_mem_pool: torch.cuda.MemPool = None,
+    ):
+        self.custom_mem_pool = custom_mem_pool
+        device = "cuda" if self.custom_mem_pool else "cpu"
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.custom_mem_pool
+            else nullcontext()
+        ):
+            # TODO: abort top_logprobs_num > 128 in PD
+
+            # We transfer the metadata of first output token to decode
+            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+            self.output_token_logprobs_val = torch.zeros(
+                (size, 16), dtype=torch.float32, device=device
+            )
+            self.output_token_logprobs_idx = torch.zeros(
+                (size, 16), dtype=torch.int32, device=device
+            )
+            self.output_top_logprobs_val = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.float32, device=device
+            )
+            self.output_top_logprobs_idx = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.int32, device=device
+            )
 
     def get_buf_infos(self):
         ptrs = [
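Putting both pieces together, a usage sketch for the new constructor parameter: with a pool, the metadata tensors are allocated on "cuda" inside torch.cuda.use_mem_pool(); without one, the nullcontext branch keeps the original CPU buffers. The MetadataBuffers import path is an assumption, since this view does not name the modified file.

import torch

from sglang.srt.disaggregation.mooncake.memory_pool import MooncakeNVLinkAllocator
from sglang.srt.disaggregation.utils import MetadataBuffers  # module path assumed

device = torch.device("cuda", 0)
allocator = MooncakeNVLinkAllocator.get_allocator(device)
pool = torch.cuda.MemPool(allocator.allocator())

# The metadata tensors are now allocated on "cuda" from the NVLink-backed pool;
# passing custom_mem_pool=None would keep the original CPU buffers instead.
buffers = MetadataBuffers(size=64, custom_mem_pool=pool)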