Fix CUDA Graph Check under Deepep with DP FFN (#7451)
This commit is contained in:
@@ -48,6 +48,7 @@ from sglang.srt.utils import (
|
||||
rank0_log,
|
||||
require_attn_tp_gather,
|
||||
require_gathered_buffer,
|
||||
require_mlp_sync,
|
||||
require_mlp_tp_gather,
|
||||
)
|
||||
|
||||
@@ -212,6 +213,7 @@ class CudaGraphRunner:
|
||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
||||
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
||||
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
|
||||
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
||||
self.enable_two_batch_overlap = (
|
||||
model_runner.server_args.enable_two_batch_overlap
|
||||
@@ -337,22 +339,22 @@ class CudaGraphRunner:
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.require_mlp_tp_gather:
|
||||
total_batch_size = (
|
||||
cuda_graph_bs = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
else sum(forward_batch.global_num_tokens_cpu)
|
||||
)
|
||||
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
|
||||
total_batch_size in self.graphs
|
||||
if self.disable_padding
|
||||
else total_batch_size <= self.max_bs
|
||||
)
|
||||
else:
|
||||
is_bs_supported = (
|
||||
forward_batch.batch_size in self.graphs
|
||||
if self.disable_padding
|
||||
else forward_batch.batch_size <= self.max_bs
|
||||
)
|
||||
cuda_graph_bs = forward_batch.batch_size
|
||||
|
||||
is_bs_supported = (
|
||||
cuda_graph_bs in self.graphs
|
||||
if self.disable_padding
|
||||
else cuda_graph_bs <= self.max_bs
|
||||
)
|
||||
|
||||
if self.require_mlp_sync:
|
||||
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
|
||||
|
||||
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
|
||||
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
|
||||
|
||||
@@ -23,6 +23,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||
from sglang.srt.utils import (
|
||||
require_attn_tp_gather,
|
||||
require_gathered_buffer,
|
||||
require_mlp_sync,
|
||||
require_mlp_tp_gather,
|
||||
)
|
||||
|
||||
@@ -46,6 +47,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
||||
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
||||
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
|
||||
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
||||
self.dp_size = self.model_runner.dp_size
|
||||
self.tp_size = self.model_runner.tp_size
|
||||
@@ -127,24 +129,23 @@ class EAGLEDraftCudaGraphRunner:
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.require_mlp_tp_gather:
|
||||
if not forward_batch.can_run_dp_cuda_graph:
|
||||
return False
|
||||
total_batch_size = (
|
||||
cuda_graph_bs = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
else sum(forward_batch.global_num_tokens_cpu)
|
||||
)
|
||||
is_bs_supported = (
|
||||
total_batch_size in self.graphs
|
||||
if self.disable_padding
|
||||
else total_batch_size <= self.max_bs
|
||||
)
|
||||
else:
|
||||
is_bs_supported = (
|
||||
forward_batch.batch_size in self.graphs
|
||||
if self.disable_padding
|
||||
else forward_batch.batch_size <= self.max_bs
|
||||
)
|
||||
cuda_graph_bs = forward_batch.batch_size
|
||||
|
||||
is_bs_supported = (
|
||||
cuda_graph_bs in self.graphs
|
||||
if self.disable_padding
|
||||
else cuda_graph_bs <= self.max_bs
|
||||
)
|
||||
|
||||
if self.require_mlp_sync:
|
||||
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
|
||||
|
||||
return is_bs_supported
|
||||
|
||||
def capture(self):
|
||||
|
||||
@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
|
||||
from sglang.srt.utils import (
|
||||
require_attn_tp_gather,
|
||||
require_gathered_buffer,
|
||||
require_mlp_sync,
|
||||
require_mlp_tp_gather,
|
||||
)
|
||||
|
||||
@@ -42,6 +43,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
||||
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
||||
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
|
||||
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
||||
self.tp_size = self.model_runner.tp_size
|
||||
self.dp_size = model_runner.server_args.dp_size
|
||||
@@ -130,29 +132,24 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.require_mlp_tp_gather:
|
||||
if not forward_batch.can_run_dp_cuda_graph:
|
||||
return False
|
||||
total_batch_size = (
|
||||
cuda_graph_bs = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
else sum(forward_batch.global_num_tokens_cpu)
|
||||
)
|
||||
is_bs_supported = (
|
||||
total_batch_size in self.graphs
|
||||
if self.disable_padding
|
||||
else total_batch_size <= self.max_bs
|
||||
)
|
||||
return is_bs_supported
|
||||
else:
|
||||
batch_size = forward_batch.seq_lens.numel()
|
||||
cuda_graph_bs = forward_batch.seq_lens.numel()
|
||||
|
||||
is_bs_supported = (
|
||||
batch_size in self.graphs
|
||||
if self.disable_padding
|
||||
else batch_size <= self.max_bs
|
||||
)
|
||||
is_bs_supported = (
|
||||
cuda_graph_bs in self.graphs
|
||||
if self.disable_padding
|
||||
else cuda_graph_bs <= self.max_bs
|
||||
)
|
||||
|
||||
return is_bs_supported
|
||||
if self.require_mlp_sync:
|
||||
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
|
||||
|
||||
return is_bs_supported
|
||||
|
||||
def capture(self):
|
||||
CudaGraphRunner.capture(self)
|
||||
|
||||
Reference in New Issue
Block a user