Manually flip deepep_mode for cuda_graph (#11666)
This commit is contained in:
@@ -235,6 +235,15 @@ class DeepEPBuffer:
|
|||||||
cls.clean_buffer()
|
cls.clean_buffer()
|
||||||
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
|
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_dispatch_mode(cls, mode: DeepEPMode):
|
||||||
|
if mode.is_low_latency():
|
||||||
|
cls.set_dispatch_mode_as_low_latency()
|
||||||
|
elif mode.is_normal():
|
||||||
|
cls.set_dispatch_mode_as_normal()
|
||||||
|
else:
|
||||||
|
raise Exception("unsupported mode")
|
||||||
|
|
||||||
|
|
||||||
class DeepEPConfig(BaseDispatcherConfig):
|
class DeepEPConfig(BaseDispatcherConfig):
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ from sglang.srt.layers.dp_attention import (
|
|||||||
set_dp_buffer_len,
|
set_dp_buffer_len,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer
|
||||||
|
from sglang.srt.layers.moe.utils import get_deepep_mode, get_moe_a2a_backend
|
||||||
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
CaptureHiddenMode,
|
CaptureHiddenMode,
|
||||||
@@ -240,6 +242,8 @@ class CudaGraphRunner:
|
|||||||
self.attn_tp_size = get_attention_tp_size()
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
self.attn_tp_rank = get_attention_tp_rank()
|
self.attn_tp_rank = get_attention_tp_rank()
|
||||||
|
|
||||||
|
self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
|
||||||
|
|
||||||
# Batch sizes to capture
|
# Batch sizes to capture
|
||||||
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
||||||
log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
|
log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
|
||||||
@@ -653,6 +657,8 @@ class CudaGraphRunner:
|
|||||||
)
|
)
|
||||||
return logits_output_or_pp_proxy_tensors
|
return logits_output_or_pp_proxy_tensors
|
||||||
|
|
||||||
|
self.deepep_adapter.capture(is_extend_in_batch=False)
|
||||||
|
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
self.device_module.synchronize()
|
self.device_module.synchronize()
|
||||||
self.model_runner.tp_group.barrier()
|
self.model_runner.tp_group.barrier()
|
||||||
@@ -796,6 +802,8 @@ class CudaGraphRunner:
|
|||||||
skip_attn_backend_init: bool = False,
|
skip_attn_backend_init: bool = False,
|
||||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
|
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
|
||||||
|
self.deepep_adapter.replay()
|
||||||
|
|
||||||
if not skip_attn_backend_init:
|
if not skip_attn_backend_init:
|
||||||
self.replay_prepare(forward_batch, pp_proxy_tensors)
|
self.replay_prepare(forward_batch, pp_proxy_tensors)
|
||||||
else:
|
else:
|
||||||
@@ -872,3 +880,23 @@ CUDA_GRAPH_CAPTURE_FAILED_MSG = (
|
|||||||
"4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
|
"4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
|
||||||
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepEPCudaGraphRunnerAdapter:
|
||||||
|
def __init__(self):
|
||||||
|
# Record DeepEP mode used during capture to ensure replay consistency
|
||||||
|
self._captured_deepep_mode = None
|
||||||
|
|
||||||
|
def capture(self, is_extend_in_batch: bool):
|
||||||
|
if not get_moe_a2a_backend().is_deepep():
|
||||||
|
return
|
||||||
|
self._captured_deepep_mode = get_deepep_mode().resolve(
|
||||||
|
is_extend_in_batch=is_extend_in_batch
|
||||||
|
)
|
||||||
|
DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)
|
||||||
|
|
||||||
|
def replay(self):
|
||||||
|
if not get_moe_a2a_backend().is_deepep():
|
||||||
|
return
|
||||||
|
assert self._captured_deepep_mode is not None
|
||||||
|
DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
|||||||
from sglang.srt.model_executor.cuda_graph_runner import (
|
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||||
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||||
CudaGraphRunner,
|
CudaGraphRunner,
|
||||||
|
DeepEPCudaGraphRunnerAdapter,
|
||||||
get_batch_sizes_to_capture,
|
get_batch_sizes_to_capture,
|
||||||
get_global_graph_memory_pool,
|
get_global_graph_memory_pool,
|
||||||
model_capture_mode,
|
model_capture_mode,
|
||||||
@@ -61,6 +62,7 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
self.enable_profile_cuda_graph = (
|
self.enable_profile_cuda_graph = (
|
||||||
model_runner.server_args.enable_profile_cuda_graph
|
model_runner.server_args.enable_profile_cuda_graph
|
||||||
)
|
)
|
||||||
|
self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
|
||||||
server_args = model_runner.server_args
|
server_args = model_runner.server_args
|
||||||
|
|
||||||
# Batch sizes to capture
|
# Batch sizes to capture
|
||||||
@@ -264,6 +266,8 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
forward_batch.spec_info.hidden_states = hidden_states_backup
|
forward_batch.spec_info.hidden_states = hidden_states_backup
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
self.deepep_adapter.capture(is_extend_in_batch=False)
|
||||||
|
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
self.model_runner.tp_group.barrier()
|
self.model_runner.tp_group.barrier()
|
||||||
@@ -285,6 +289,8 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
|
|
||||||
def replay(self, forward_batch: ForwardBatch):
|
def replay(self, forward_batch: ForwardBatch):
|
||||||
assert forward_batch.out_cache_loc is not None
|
assert forward_batch.out_cache_loc is not None
|
||||||
|
self.deepep_adapter.replay()
|
||||||
|
|
||||||
raw_bs = forward_batch.batch_size
|
raw_bs = forward_batch.batch_size
|
||||||
raw_num_token = raw_bs * self.num_tokens_per_bs
|
raw_num_token = raw_bs * self.num_tokens_per_bs
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
|||||||
from sglang.srt.model_executor.cuda_graph_runner import (
|
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||||
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||||
CudaGraphRunner,
|
CudaGraphRunner,
|
||||||
|
DeepEPCudaGraphRunnerAdapter,
|
||||||
LogitsProcessorOutput,
|
LogitsProcessorOutput,
|
||||||
get_batch_sizes_to_capture,
|
get_batch_sizes_to_capture,
|
||||||
get_global_graph_memory_pool,
|
get_global_graph_memory_pool,
|
||||||
@@ -61,6 +62,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
)
|
)
|
||||||
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
||||||
self.padded_static_len = -1
|
self.padded_static_len = -1
|
||||||
|
self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.num_tokens_per_bs = self.speculative_num_steps + 1
|
self.num_tokens_per_bs = self.speculative_num_steps + 1
|
||||||
@@ -243,6 +245,8 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
)
|
)
|
||||||
spec_info.positions = None
|
spec_info.positions = None
|
||||||
|
|
||||||
|
self.deepep_adapter.capture(is_extend_in_batch=True)
|
||||||
|
|
||||||
# Forward batch
|
# Forward batch
|
||||||
forward_batch = ForwardBatch(
|
forward_batch = ForwardBatch(
|
||||||
forward_mode=ForwardMode.DRAFT_EXTEND,
|
forward_mode=ForwardMode.DRAFT_EXTEND,
|
||||||
@@ -318,6 +322,8 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
|
|
||||||
def replay(self, forward_batch: ForwardBatch):
|
def replay(self, forward_batch: ForwardBatch):
|
||||||
assert forward_batch.out_cache_loc is not None
|
assert forward_batch.out_cache_loc is not None
|
||||||
|
self.deepep_adapter.replay()
|
||||||
|
|
||||||
# batch_size and num_seqs can be different in case there are finished examples
|
# batch_size and num_seqs can be different in case there are finished examples
|
||||||
# in the batch, which will not be counted as num_seqs
|
# in the batch, which will not be counted as num_seqs
|
||||||
raw_bs = forward_batch.batch_size
|
raw_bs = forward_batch.batch_size
|
||||||
|
|||||||
Reference in New Issue
Block a user