Add support for NCCL symmetric memory for TP allreduces (#8238)
This commit is contained in:
@@ -251,6 +251,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| `--disable-cuda-graph-padding` | Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed. | False |
|
| `--disable-cuda-graph-padding` | Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed. | False |
|
||||||
| `--enable-profile-cuda-graph` | Enable profiling of cuda graph capture. | False |
|
| `--enable-profile-cuda-graph` | Enable profiling of cuda graph capture. | False |
|
||||||
| `--enable-nccl-nvls` | Enable NCCL NVLS for prefill heavy requests when available. | False |
|
| `--enable-nccl-nvls` | Enable NCCL NVLS for prefill heavy requests when available. | False |
|
||||||
|
| `--enable-symm-mem` | Enable NCCL symmetric memory for fast collectives. | False |
|
||||||
| `--enable-tokenizer-batch-encode` | Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds. | False |
|
| `--enable-tokenizer-batch-encode` | Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds. | False |
|
||||||
| `--disable-outlines-disk-cache` | Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency. | False |
|
| `--disable-outlines-disk-cache` | Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency. | False |
|
||||||
| `--disable-custom-all-reduce` | Disable the custom all-reduce kernel and fall back to NCCL. | False |
|
| `--disable-custom-all-reduce` | Disable the custom all-reduce kernel and fall back to NCCL. | False |
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ class PyNcclCommunicator:
|
|||||||
self.available = True
|
self.available = True
|
||||||
self.disabled = False
|
self.disabled = False
|
||||||
|
|
||||||
|
self.nccl_version = self.nccl.ncclGetRawVersion()
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
|
logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
|
||||||
|
|
||||||
@@ -259,6 +260,12 @@ class PyNcclCommunicator:
|
|||||||
cudaStream_t(stream.cuda_stream),
|
cudaStream_t(stream.cuda_stream),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def register_comm_window_raw(self, ptr: int, size: int):
    """Register a raw device pointer as an NCCL symmetric-memory window.

    Thin wrapper over ncclCommWindowRegister on this communicator. The final
    argument (1) is the winFlags value — presumably NCCL_WIN_COLL_SYMMETRIC;
    TODO(review): confirm against the NCCL headers for the pinned version.
    """
    return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)
|
|
||||||
|
def deregister_comm_window(self, window):
    """Release a window handle previously returned by register_comm_window_raw."""
    return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def change_state(
|
def change_state(
|
||||||
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
|
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
|
||||||
|
|||||||
@@ -0,0 +1,133 @@
|
|||||||
|
import tempfile
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
from torch.cuda.memory import CUDAPluggableAllocator
|
||||||
|
|
||||||
|
from sglang.srt.distributed.parallel_state import GroupCoordinator
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
|
||||||
|
# C source JIT-compiled by get_nccl_mem_pool(): the two hooks route CUDA
# allocations through ncclMemAlloc/ncclMemFree so the resulting buffers are
# eligible for NCCL window registration (symmetric memory).
# NOTE(review): both hooks discard the ncclResult_t; on allocation failure
# nccl_alloc_plug returns an uninitialized pointer. Kept byte-identical here
# since this string is compiled at runtime — worth confirming upstream.
nccl_allocator_source = """
#include <nccl.h>
extern "C" {

void* nccl_alloc_plug(size_t size, int device, void* stream) {
  void* ptr;
  ncclResult_t err = ncclMemAlloc(&ptr, size);
  return ptr;

}

void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
  ncclResult_t err = ncclMemFree(ptr);
}

}
"""

# Lazily-initialized, process-wide allocator state (see get_nccl_mem_pool).
_allocator = None  # CUDAPluggableAllocator built from the hooks above
_mem_pool = None  # torch.cuda.MemPool drawing from _allocator
_registered_base_addrs = set()  # pool segment bases already window-registered
_graph_pool_id = None  # CUDA-graph memory pool id, set via set_graph_pool_id()
|
def is_symmetric_memory_enabled():
    """Return True when the server was started with --enable-symm-mem."""
    return global_server_args_dict["enable_symm_mem"]
|
|
||||||
|
|
||||||
|
def set_graph_pool_id(graph_pool_id):
    """Record the CUDA-graph memory pool id.

    use_symmetric_memory needs it to pause/resume the graph pool while
    capturing, so symmetric-memory allocations can happen mid-capture.
    """
    global _graph_pool_id
    _graph_pool_id = graph_pool_id
|
|
||||||
|
|
||||||
|
def get_nccl_mem_pool():
    """Return the process-wide torch.cuda.MemPool backed by ncclMemAlloc.

    On first use this JIT-compiles the hooks in ``nccl_allocator_source`` into
    a shared library, wraps them in a CUDAPluggableAllocator, and builds a
    MemPool from it; later calls return the cached pool.

    NOTE(review): not thread-safe on first call — concurrent callers could
    both see ``_mem_pool is None`` and compile twice. Presumed single-threaded
    at init time in the server; confirm before reusing elsewhere.
    """
    global _allocator, _mem_pool
    if _mem_pool is None:
        # `import torch` does not load torch.utils.cpp_extension; import it
        # explicitly (lazily, since it is only needed on this first call).
        import torch.utils.cpp_extension

        out_dir = tempfile.gettempdir()
        nccl_allocator_libname = "nccl_allocator"
        # is_python_module=False: we only need the .so on disk for dlopen
        # by CUDAPluggableAllocator, not an importable extension module.
        torch.utils.cpp_extension.load_inline(
            name=nccl_allocator_libname,
            cpp_sources=nccl_allocator_source,
            with_cuda=True,
            extra_ldflags=["-lnccl"],
            verbose=True,
            is_python_module=False,
            build_directory=out_dir,
        )
        _allocator = CUDAPluggableAllocator(
            f"{out_dir}/{nccl_allocator_libname}.so",
            "nccl_alloc_plug",
            "nccl_free_plug",
        ).allocator()
        _mem_pool = torch.cuda.MemPool(_allocator)
    return _mem_pool
|
|
||||||
|
|
||||||
|
class use_symmetric_memory:
    """Context manager that allocates tensors from the NCCL symmetric pool.

    Inside the ``with`` block, CUDA allocations are redirected to the
    ncclMemAlloc-backed pool; on exit, any new pool segments are registered
    as NCCL communication windows with the group's pynccl communicator.
    Tensors meant to take the symmetric-memory allreduce path must be marked
    with tag(). If --enable-symm-mem is off, every method is a no-op.
    """

    def __init__(self, group_coordinator: GroupCoordinator):
        if not is_symmetric_memory_enabled():
            # Disabled: keep cheap sentinel state so enter/exit short-circuit.
            self.group_coordinator = None
            self._mem_pool_ctx = None
            self.is_graph_capture = None
            self.device = None
            self.pre_2_8_0 = None
        else:
            self.group_coordinator = group_coordinator
            self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
            self.is_graph_capture = torch.cuda.is_current_stream_capturing()
            self.device = torch.cuda.current_device()
            # Pre-2.8.0 torch uses different private pool-pause APIs and has
            # a multi-thread MemPool bug (see the workaround in __exit__).
            self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")

    def __enter__(self):
        if not is_symmetric_memory_enabled():
            return self
        assert (
            self.group_coordinator.pynccl_comm is not None
        ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
        assert (
            self.group_coordinator.pynccl_comm.nccl_version >= 22703
        ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
        if self.is_graph_capture:
            assert (
                _graph_pool_id is not None
            ), "graph_pool_id is not set under graph capture"
            # Pause graph memory pool to use symmetric memory with cuda graph
            if self.pre_2_8_0:
                torch._C._cuda_endAllocateCurrentStreamToPool(
                    self.device, _graph_pool_id
                )
            else:
                torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
        self._mem_pool_ctx.__enter__()
        return self

    def tag(self, tensor: torch.Tensor):
        """Mark *tensor* so GroupCoordinator.all_reduce takes the pynccl path."""
        if not is_symmetric_memory_enabled():
            return
        tensor.symmetric_memory = True

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not is_symmetric_memory_enabled():
            return
        global _registered_base_addrs
        self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
        # Register any segments the pool grew by during the block.
        for segment in get_nccl_mem_pool().snapshot():
            if segment["address"] in _registered_base_addrs:
                continue
            if segment["stream"] == 0 and self.pre_2_8_0:
                # PyTorch version < 2.8.0 has a multi-thread MemPool bug
                # See https://github.com/pytorch/pytorch/issues/152861
                # Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
                # WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream
                continue
            self.group_coordinator.pynccl_comm.register_comm_window_raw(
                segment["address"], segment["total_size"]
            )
            _registered_base_addrs.add(segment["address"])

        if self.is_graph_capture:
            # Resume the graph memory pool paused in __enter__.
            if self.pre_2_8_0:
                torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
            else:
                torch._C._cuda_beginAllocateCurrentThreadToPool(
                    self.device, _graph_pool_id
                )
@@ -67,6 +67,7 @@ def find_nccl_library() -> str:
|
|||||||
|
|
||||||
ncclResult_t = ctypes.c_int
|
ncclResult_t = ctypes.c_int
|
||||||
ncclComm_t = ctypes.c_void_p
|
ncclComm_t = ctypes.c_void_p
|
||||||
|
ncclWindow_t = ctypes.c_void_p
|
||||||
|
|
||||||
|
|
||||||
class ncclUniqueId(ctypes.Structure):
|
class ncclUniqueId(ctypes.Structure):
|
||||||
@@ -279,6 +280,23 @@ class NCCLLibrary:
|
|||||||
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
|
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
exported_functions_symm_mem = [
|
||||||
|
# ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags);
|
||||||
|
Function(
|
||||||
|
"ncclCommWindowRegister",
|
||||||
|
ncclResult_t,
|
||||||
|
[
|
||||||
|
ncclComm_t,
|
||||||
|
buffer_type,
|
||||||
|
ctypes.c_size_t,
|
||||||
|
ctypes.POINTER(ncclWindow_t),
|
||||||
|
ctypes.c_int,
|
||||||
|
],
|
||||||
|
),
|
||||||
|
# ncclResult_t ncclCommWindowDeregister(ncclComm_t comm, ncclWindow_t win);
|
||||||
|
Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
|
||||||
|
]
|
||||||
|
|
||||||
# class attribute to store the mapping from the path to the library
|
# class attribute to store the mapping from the path to the library
|
||||||
# to avoid loading the same library multiple times
|
# to avoid loading the same library multiple times
|
||||||
path_to_library_cache: Dict[str, Any] = {}
|
path_to_library_cache: Dict[str, Any] = {}
|
||||||
@@ -312,7 +330,10 @@ class NCCLLibrary:
|
|||||||
|
|
||||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||||
_funcs: Dict[str, Any] = {}
|
_funcs: Dict[str, Any] = {}
|
||||||
for func in NCCLLibrary.exported_functions:
|
exported_functions = NCCLLibrary.exported_functions
|
||||||
|
if hasattr(self.lib, "ncclCommWindowRegister"):
|
||||||
|
exported_functions.extend(NCCLLibrary.exported_functions_symm_mem)
|
||||||
|
for func in exported_functions:
|
||||||
f = getattr(self.lib, func.name)
|
f = getattr(self.lib, func.name)
|
||||||
f.restype = func.restype
|
f.restype = func.restype
|
||||||
f.argtypes = func.argtypes
|
f.argtypes = func.argtypes
|
||||||
@@ -328,10 +349,14 @@ class NCCLLibrary:
|
|||||||
error_str = self.ncclGetErrorString(result)
|
error_str = self.ncclGetErrorString(result)
|
||||||
raise RuntimeError(f"NCCL error: {error_str}")
|
raise RuntimeError(f"NCCL error: {error_str}")
|
||||||
|
|
||||||
def ncclGetVersion(self) -> str:
|
def ncclGetRawVersion(self) -> int:
|
||||||
version = ctypes.c_int()
|
version = ctypes.c_int()
|
||||||
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
|
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
|
||||||
version_str = str(version.value)
|
# something like 21903
|
||||||
|
return version.value
|
||||||
|
|
||||||
|
def ncclGetVersion(self) -> str:
|
||||||
|
version_str = str(self.ncclGetRawVersion())
|
||||||
# something like 21903 --> "2.19.3"
|
# something like 21903 --> "2.19.3"
|
||||||
major = version_str[0].lstrip("0")
|
major = version_str[0].lstrip("0")
|
||||||
minor = version_str[1:3].lstrip("0")
|
minor = version_str[1:3].lstrip("0")
|
||||||
@@ -460,6 +485,20 @@ class NCCLLibrary:
|
|||||||
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
||||||
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
||||||
|
|
||||||
|
def ncclCommWindowRegister(
|
||||||
|
self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
|
||||||
|
) -> ncclWindow_t:
|
||||||
|
window = ncclWindow_t()
|
||||||
|
self.NCCL_CHECK(
|
||||||
|
self._funcs["ncclCommWindowRegister"](
|
||||||
|
comm, buff, size, ctypes.byref(window), win_flags
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return window
|
||||||
|
|
||||||
|
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
    """Deregister a window handle previously returned by ncclCommWindowRegister."""
    self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"NCCLLibrary",
|
"NCCLLibrary",
|
||||||
|
|||||||
@@ -497,6 +497,17 @@ class GroupCoordinator:
|
|||||||
if self.npu_communicator is not None and not self.npu_communicator.disabled:
|
if self.npu_communicator is not None and not self.npu_communicator.disabled:
|
||||||
return self.npu_communicator.all_reduce(input_)
|
return self.npu_communicator.all_reduce(input_)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.pynccl_comm is not None
|
||||||
|
and hasattr(input_, "symmetric_memory")
|
||||||
|
and input_.symmetric_memory
|
||||||
|
):
|
||||||
|
with self.pynccl_comm.change_state(
|
||||||
|
enable=True, stream=torch.cuda.current_stream()
|
||||||
|
):
|
||||||
|
self.pynccl_comm.all_reduce(input_)
|
||||||
|
return input_
|
||||||
|
|
||||||
outplace_all_reduce_method = None
|
outplace_all_reduce_method = None
|
||||||
if (
|
if (
|
||||||
self.qr_comm is not None
|
self.qr_comm is not None
|
||||||
|
|||||||
@@ -623,8 +623,9 @@ class Engine(EngineBase):
|
|||||||
def _set_envs_and_config(server_args: ServerArgs):
|
def _set_envs_and_config(server_args: ServerArgs):
|
||||||
# Set global environments
|
# Set global environments
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
os.environ["NCCL_CUMEM_ENABLE"] = "0"
|
os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
|
||||||
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
|
if not server_args.enable_symm_mem:
|
||||||
|
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
|
||||||
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||||
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
||||||
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
|
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
|
||||||
|
|||||||
@@ -13,10 +13,14 @@ from sglang.srt.distributed import (
|
|||||||
divide,
|
divide,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
|
parallel_state,
|
||||||
split_tensor_along_last_dim,
|
split_tensor_along_last_dim,
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
||||||
|
use_symmetric_memory,
|
||||||
|
)
|
||||||
from sglang.srt.layers.parameter import (
|
from sglang.srt.layers.parameter import (
|
||||||
BasevLLMParameter,
|
BasevLLMParameter,
|
||||||
BlockQuantScaleParameter,
|
BlockQuantScaleParameter,
|
||||||
@@ -1292,7 +1296,9 @@ class RowParallelLinear(LinearBase):
|
|||||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||||
# bias will not get added more than once in TP>1 case)
|
# bias will not get added more than once in TP>1 case)
|
||||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||||
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
|
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
|
||||||
|
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
|
||||||
|
sm.tag(output_parallel)
|
||||||
if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -14,8 +14,12 @@ from sglang.srt.distributed import (
|
|||||||
get_moe_expert_parallel_world_size,
|
get_moe_expert_parallel_world_size,
|
||||||
get_moe_tensor_parallel_rank,
|
get_moe_tensor_parallel_rank,
|
||||||
get_moe_tensor_parallel_world_size,
|
get_moe_tensor_parallel_world_size,
|
||||||
|
get_tp_group,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
||||||
|
use_symmetric_memory,
|
||||||
|
)
|
||||||
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
||||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
@@ -626,24 +630,27 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
final_hidden_states = self.quant_method.apply(
|
with use_symmetric_memory(get_tp_group()) as sm:
|
||||||
layer=self,
|
final_hidden_states = self.quant_method.apply(
|
||||||
x=hidden_states,
|
layer=self,
|
||||||
topk_output=topk_output,
|
x=hidden_states,
|
||||||
activation=self.activation,
|
topk_output=topk_output,
|
||||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
activation=self.activation,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||||
**(
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
dict(
|
**(
|
||||||
tp_rank=self.moe_tp_rank,
|
dict(
|
||||||
tp_size=self.moe_tp_size,
|
tp_rank=self.moe_tp_rank,
|
||||||
ep_rank=self.moe_ep_rank,
|
tp_size=self.moe_tp_size,
|
||||||
ep_size=self.moe_ep_size,
|
ep_rank=self.moe_ep_rank,
|
||||||
)
|
ep_size=self.moe_ep_size,
|
||||||
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
|
)
|
||||||
else {}
|
if self.quant_method.__class__.__name__
|
||||||
),
|
== "ModelOptNvFp4FusedMoEMethod"
|
||||||
)
|
else {}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
sm.tag(final_hidden_states)
|
||||||
|
|
||||||
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
|
|||||||
@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
|
|||||||
divide,
|
divide,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
|
parallel_state,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
||||||
|
use_symmetric_memory,
|
||||||
|
)
|
||||||
from sglang.srt.layers.amx_utils import PackWeightMethod
|
from sglang.srt.layers.amx_utils import PackWeightMethod
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
||||||
from sglang.srt.layers.parameter import BasevLLMParameter
|
from sglang.srt.layers.parameter import BasevLLMParameter
|
||||||
@@ -464,7 +468,9 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
masked_input = input_
|
masked_input = input_
|
||||||
# Get the embeddings.
|
# Get the embeddings.
|
||||||
output_parallel = self.quant_method.embedding(self, masked_input.long())
|
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
|
||||||
|
output_parallel = self.quant_method.embedding(self, masked_input.long())
|
||||||
|
sm.tag(output_parallel)
|
||||||
# Mask the output embedding.
|
# Mask the output embedding.
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"weight_loader_disable_mmap",
|
"weight_loader_disable_mmap",
|
||||||
"enable_triton_kernel_moe",
|
"enable_triton_kernel_moe",
|
||||||
"enable_multimodal",
|
"enable_multimodal",
|
||||||
|
"enable_symm_mem",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
|
|||||||
@@ -29,6 +29,9 @@ from torch.profiler import ProfilerActivity, profile
|
|||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
|
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
||||||
|
set_graph_pool_id,
|
||||||
|
)
|
||||||
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
||||||
from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
|
from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
@@ -643,11 +646,15 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
run_once()
|
run_once()
|
||||||
|
|
||||||
global global_graph_memory_pool
|
if get_global_graph_memory_pool() is None:
|
||||||
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
|
set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
|
||||||
|
# Set graph pool id globally to be able to use symmetric memory
|
||||||
|
set_graph_pool_id(get_global_graph_memory_pool())
|
||||||
|
with torch.cuda.graph(
|
||||||
|
graph, pool=get_global_graph_memory_pool(), stream=stream
|
||||||
|
):
|
||||||
out = run_once()
|
out = run_once()
|
||||||
|
|
||||||
global_graph_memory_pool = graph.pool()
|
|
||||||
return graph, out
|
return graph, out
|
||||||
|
|
||||||
def recapture_if_needed(self, forward_batch: ForwardBatch):
|
def recapture_if_needed(self, forward_batch: ForwardBatch):
|
||||||
|
|||||||
@@ -34,6 +34,9 @@ from sglang.srt.distributed import (
|
|||||||
parallel_state,
|
parallel_state,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
||||||
|
use_symmetric_memory,
|
||||||
|
)
|
||||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||||
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||||
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||||
@@ -481,7 +484,11 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
if not _is_cuda:
|
if not _is_cuda:
|
||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
current_stream.wait_stream(self.alt_stream)
|
current_stream.wait_stream(self.alt_stream)
|
||||||
final_hidden_states += shared_output
|
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
|
||||||
|
final_hidden_states_out = torch.empty_like(final_hidden_states)
|
||||||
|
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
||||||
|
final_hidden_states = final_hidden_states_out
|
||||||
|
sm.tag(final_hidden_states)
|
||||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
@@ -507,7 +514,11 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
# fused in biased_grouped_topk so we can skip here
|
# fused in biased_grouped_topk so we can skip here
|
||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
|
||||||
|
final_hidden_states_out = torch.empty_like(final_hidden_states)
|
||||||
|
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
||||||
|
final_hidden_states = final_hidden_states_out
|
||||||
|
sm.tag(final_hidden_states)
|
||||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|||||||
@@ -218,6 +218,7 @@ class ServerArgs:
|
|||||||
enable_profile_cuda_graph: bool = False
|
enable_profile_cuda_graph: bool = False
|
||||||
enable_cudagraph_gc: bool = False
|
enable_cudagraph_gc: bool = False
|
||||||
enable_nccl_nvls: bool = False
|
enable_nccl_nvls: bool = False
|
||||||
|
enable_symm_mem: bool = False
|
||||||
enable_tokenizer_batch_encode: bool = False
|
enable_tokenizer_batch_encode: bool = False
|
||||||
disable_outlines_disk_cache: bool = False
|
disable_outlines_disk_cache: bool = False
|
||||||
disable_custom_all_reduce: bool = False
|
disable_custom_all_reduce: bool = False
|
||||||
@@ -1599,6 +1600,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable NCCL NVLS for prefill heavy requests when available.",
|
help="Enable NCCL NVLS for prefill heavy requests when available.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-symm-mem",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable NCCL symmetric memory for fast collectives.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-tokenizer-batch-encode",
|
"--enable-tokenizer-batch-encode",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user