Add support for NCCL symmetric memory for TP allreduces (#8238)
This commit is contained in:
@@ -29,6 +29,9 @@ from torch.profiler import ProfilerActivity, profile
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
||||
from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
@@ -643,11 +646,15 @@ class CudaGraphRunner:
|
||||
|
||||
run_once()
|
||||
|
||||
global global_graph_memory_pool
|
||||
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
|
||||
if get_global_graph_memory_pool() is None:
|
||||
set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
|
||||
# Set graph pool id globally to be able to use symmetric memory
|
||||
set_graph_pool_id(get_global_graph_memory_pool())
|
||||
with torch.cuda.graph(
|
||||
graph, pool=get_global_graph_memory_pool(), stream=stream
|
||||
):
|
||||
out = run_once()
|
||||
|
||||
global_graph_memory_pool = graph.pool()
|
||||
return graph, out
|
||||
|
||||
def recapture_if_needed(self, forward_batch: ForwardBatch):
|
||||
|
||||
Reference in New Issue
Block a user