Manually flip deepep_mode for cuda_graph (#11666)
@@ -235,6 +235,15 @@ class DeepEPBuffer:
             cls.clean_buffer()
         cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
+    @classmethod
+    def set_dispatch_mode(cls, mode: DeepEPMode):
+        if mode.is_low_latency():
+            cls.set_dispatch_mode_as_low_latency()
+        elif mode.is_normal():
+            cls.set_dispatch_mode_as_normal()
+        else:
+            raise Exception("unsupported mode")
+
 
 class DeepEPConfig(BaseDispatcherConfig):
     _instance = None
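As a side note for readers unfamiliar with the DeepEP integration: the new classmethod simply maps a high-level DeepEPMode onto the buffer's internal dispatch state. The following is a minimal, self-contained sketch of the same pattern; ToyMode and ToyBuffer are illustrative stand-ins, not part of the SGLang codebase.

from enum import Enum, auto


class ToyMode(Enum):
    NORMAL = auto()
    LOW_LATENCY = auto()

    def is_low_latency(self) -> bool:
        return self is ToyMode.LOW_LATENCY

    def is_normal(self) -> bool:
        return self is ToyMode.NORMAL


class ToyBuffer:
    _dispatch_mode = None

    @classmethod
    def set_dispatch_mode(cls, mode: ToyMode):
        # Mirrors the shape of DeepEPBuffer.set_dispatch_mode: route the enum to the
        # matching low-level setter and reject anything unexpected.
        if mode.is_low_latency():
            cls._dispatch_mode = ToyMode.LOW_LATENCY
        elif mode.is_normal():
            cls._dispatch_mode = ToyMode.NORMAL
        else:
            raise Exception("unsupported mode")


ToyBuffer.set_dispatch_mode(ToyMode.LOW_LATENCY)
assert ToyBuffer._dispatch_mode is ToyMode.LOW_LATENCY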
@@ -40,6 +40,8 @@ from sglang.srt.layers.dp_attention import (
     set_dp_buffer_len,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer
+from sglang.srt.layers.moe.utils import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -240,6 +242,8 @@ class CudaGraphRunner:
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
 
+        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
+
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
@@ -653,6 +657,8 @@ class CudaGraphRunner:
             )
             return logits_output_or_pp_proxy_tensors
 
+        self.deepep_adapter.capture(is_extend_in_batch=False)
+
         for _ in range(2):
             self.device_module.synchronize()
             self.model_runner.tp_group.barrier()
@@ -796,6 +802,8 @@ class CudaGraphRunner:
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        self.deepep_adapter.replay()
+
         if not skip_attn_backend_init:
             self.replay_prepare(forward_batch, pp_proxy_tensors)
         else:
@@ -872,3 +880,23 @@ CUDA_GRAPH_CAPTURE_FAILED_MSG = (
     "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
     "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
 )
+
+
+class DeepEPCudaGraphRunnerAdapter:
+    def __init__(self):
+        # Record DeepEP mode used during capture to ensure replay consistency
+        self._captured_deepep_mode = None
+
+    def capture(self, is_extend_in_batch: bool):
+        if not get_moe_a2a_backend().is_deepep():
+            return
+        self._captured_deepep_mode = get_deepep_mode().resolve(
+            is_extend_in_batch=is_extend_in_batch
+        )
+        DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)
+
+    def replay(self):
+        if not get_moe_a2a_backend().is_deepep():
+            return
+        assert self._captured_deepep_mode is not None
+        DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)
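The adapter above is driven by the graph runners, as the remaining hunks show: capture() is called once right before the CUDA graphs are recorded, and replay() before every replay so the buffer is flipped back to the dispatch mode that was active at capture time. Below is a hedged sketch of that call pattern; MyGraphRunner is a hypothetical runner for illustration only, and the import assumes a build that already contains this change.

from sglang.srt.model_executor.cuda_graph_runner import DeepEPCudaGraphRunnerAdapter


class MyGraphRunner:  # hypothetical runner, for illustration only
    def __init__(self):
        # One adapter per runner; it starts with no captured mode.
        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()

    def capture(self):
        # Pin the resolved DeepEP dispatch mode before recording graphs.
        # The decode runners in this commit pass is_extend_in_batch=False;
        # the draft-extend runner passes True.
        self.deepep_adapter.capture(is_extend_in_batch=False)
        ...  # record CUDA graphs here

    def replay(self, forward_batch):
        # Re-apply the captured mode so eager code run since capture cannot
        # leave the buffer in a different dispatch mode than was captured.
        self.deepep_adapter.replay()
        ...  # launch the recorded graph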
@@ -9,6 +9,7 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
+    DeepEPCudaGraphRunnerAdapter,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
     model_capture_mode,
@@ -61,6 +62,7 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_profile_cuda_graph = (
             model_runner.server_args.enable_profile_cuda_graph
         )
+        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
         server_args = model_runner.server_args
 
         # Batch sizes to capture
@@ -264,6 +266,8 @@ class EAGLEDraftCudaGraphRunner:
             forward_batch.spec_info.hidden_states = hidden_states_backup
             return ret
 
+        self.deepep_adapter.capture(is_extend_in_batch=False)
+
         for _ in range(2):
             torch.cuda.synchronize()
             self.model_runner.tp_group.barrier()
@@ -285,6 +289,8 @@ class EAGLEDraftCudaGraphRunner:
 
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
+        self.deepep_adapter.replay()
+
         raw_bs = forward_batch.batch_size
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
@@ -9,6 +9,7 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
+    DeepEPCudaGraphRunnerAdapter,
     LogitsProcessorOutput,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
@@ -61,6 +62,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         )
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.padded_static_len = -1
+        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
 
         # Attention backend
         self.num_tokens_per_bs = self.speculative_num_steps + 1
@@ -243,6 +245,8 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
)
|
||||
spec_info.positions = None
|
||||
|
||||
self.deepep_adapter.capture(is_extend_in_batch=True)
|
||||
|
||||
# Forward batch
|
||||
forward_batch = ForwardBatch(
|
||||
forward_mode=ForwardMode.DRAFT_EXTEND,
|
||||
@@ -318,6 +322,8 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
|
||||
def replay(self, forward_batch: ForwardBatch):
|
||||
assert forward_batch.out_cache_loc is not None
|
||||
self.deepep_adapter.replay()
|
||||
|
||||
# batch_size and num_seqs can be different in case there are finished examples
|
||||
# in the batch, which will not be counted as num_seqs
|
||||
raw_bs = forward_batch.batch_size
|
||||
|
||||
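On why the mode has to be re-applied manually at replay: CUDA graph replay re-launches the recorded kernels but does not re-run the Python code that configured module-level state such as DeepEPBuffer's dispatch mode, so, as the commit title suggests, any eager (non-graph) batch executed since capture could leave that global state flipped. A minimal, framework-free sketch of the record-and-restore idea follows; every name in it is illustrative and not part of SGLang.

class GlobalState:
    mode = "normal"


class ModeGuardAdapter:
    """Record a global mode at capture time and re-apply it before replay."""

    def __init__(self):
        self._captured_mode = None

    def capture(self, mode: str):
        self._captured_mode = mode
        GlobalState.mode = mode

    def replay(self):
        assert self._captured_mode is not None
        GlobalState.mode = self._captured_mode


adapter = ModeGuardAdapter()
adapter.capture("low_latency")   # "graph" recorded while mode == "low_latency"
GlobalState.mode = "normal"      # some eager batch flips the global mode
adapter.replay()                 # restore the captured mode before replaying
assert GlobalState.mode == "low_latency"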