[PD] Optimize custom mem pool usage and bump mooncake version (#7393)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -1,47 +0,0 @@
|
|||||||
import os
|
|
||||||
import threading
|
|
||||||
from importlib import resources
|
|
||||||
from typing import Dict, Final, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.cuda.memory import CUDAPluggableAllocator
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(shangming): move this class into mooncake's package for more general use cases
|
|
||||||
class MooncakeNVLinkAllocator:
    """Process-wide cache of ``CUDAPluggableAllocator`` instances backed by
    mooncake's ``hook.so`` (``mc_nvlink_malloc`` / ``mc_nvlink_free``).

    One allocator is created lazily per ``torch.device`` and reused for the
    lifetime of the process.
    """

    # Per-device allocator cache; all access is serialized by _lock.
    _instances: Dict[torch.device, CUDAPluggableAllocator] = {}
    _lock: Final = threading.Lock()

    @classmethod
    def _get_so_path(cls) -> str:
        """Dynamically locate hook.so in the mooncake package installation.

        Returns:
            Absolute filesystem path to ``hook.so``.

        Raises:
            ImportError: if mooncake is not installed, or the installed
                version does not ship ``hook.so`` (too old).
        """
        try:
            # Attempt to locate package resource via importlib.resources,
            # which also handles non-filesystem (zip/egg) installs.
            with resources.path("mooncake", "hook.so") as so_path:
                if so_path.exists():
                    return str(so_path)
        except (ImportError, FileNotFoundError, TypeError):
            pass

        # Fallback strategy: check in package location via import metadata.
        try:
            import mooncake

            base_path = os.path.dirname(os.path.abspath(mooncake.__file__))
            so_path = os.path.join(base_path, "hook.so")
            if os.path.exists(so_path):
                return so_path
        except (ImportError, FileNotFoundError, TypeError):
            pass

        # Bug fix: previously, when `import mooncake` succeeded but hook.so
        # was missing from the package directory, this function fell off the
        # end and implicitly returned None (despite the `-> str` annotation),
        # causing a confusing failure inside CUDAPluggableAllocator. Raise a
        # clear ImportError in every not-found case instead.
        raise ImportError(
            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL require mooncake-transfer-engine >= 0.3.3.post2."
        )

    @classmethod
    def get_allocator(cls, device: torch.device) -> CUDAPluggableAllocator:
        """Return the cached allocator for *device*, creating it on first use.

        Thread-safe: lookup and creation run under ``_lock``, so concurrent
        callers for the same device get the same instance.
        """
        with cls._lock:
            if device not in cls._instances:
                so_path = cls._get_so_path()
                cls._instances[device] = CUDAPluggableAllocator(
                    so_path, "mc_nvlink_malloc", "mc_nvlink_free"
                )
            return cls._instances[device]
|
|
||||||
@@ -270,12 +270,10 @@ class MHATokenToKVPool(KVCache):
|
|||||||
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||||
)
|
)
|
||||||
if self.enable_custom_mem_pool:
|
if self.enable_custom_mem_pool:
|
||||||
from sglang.srt.disaggregation.mooncake.memory_pool import (
|
|
||||||
MooncakeNVLinkAllocator,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO(shangming): abstract custom allocator class for more backends
|
# TODO(shangming): abstract custom allocator class for more backends
|
||||||
allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
|
from mooncake.allocator import NVLinkAllocator
|
||||||
|
|
||||||
|
allocator = NVLinkAllocator.get_allocator(self.device)
|
||||||
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
||||||
else:
|
else:
|
||||||
self.custom_mem_pool = None
|
self.custom_mem_pool = None
|
||||||
@@ -602,12 +600,10 @@ class MLATokenToKVPool(KVCache):
|
|||||||
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||||
)
|
)
|
||||||
if self.enable_custom_mem_pool:
|
if self.enable_custom_mem_pool:
|
||||||
from sglang.srt.disaggregation.mooncake.memory_pool import (
|
|
||||||
MooncakeNVLinkAllocator,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO(shangming): abstract custom allocator class for more backends
|
# TODO(shangming): abstract custom allocator class for more backends
|
||||||
allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
|
from mooncake.allocator import NVLinkAllocator
|
||||||
|
|
||||||
|
allocator = NVLinkAllocator.get_allocator(self.device)
|
||||||
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
||||||
else:
|
else:
|
||||||
self.custom_mem_pool = None
|
self.custom_mem_pool = None
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ pip install -e "python[dev]"
|
|||||||
pip list
|
pip list
|
||||||
|
|
||||||
# Install additional dependencies
|
# Install additional dependencies
|
||||||
pip install mooncake-transfer-engine==0.3.2.post1 nvidia-cuda-nvrtc-cu12
|
pip install mooncake-transfer-engine==0.3.4 nvidia-cuda-nvrtc-cu12
|
||||||
|
|
||||||
# For lmms_evals evaluating MMMU
|
# For lmms_evals evaluating MMMU
|
||||||
git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
|
git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
|
||||||
|
|||||||
Reference in New Issue
Block a user