Fix CUDA Graph Check under DeepEP with DP FFN (#7451)
This commit is contained in:
@@ -48,6 +48,7 @@ from sglang.srt.utils import (
|
|||||||
rank0_log,
|
rank0_log,
|
||||||
require_attn_tp_gather,
|
require_attn_tp_gather,
|
||||||
require_gathered_buffer,
|
require_gathered_buffer,
|
||||||
|
require_mlp_sync,
|
||||||
require_mlp_tp_gather,
|
require_mlp_tp_gather,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -212,6 +213,7 @@ class CudaGraphRunner:
|
|||||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||||
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
||||||
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
||||||
|
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
|
||||||
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
||||||
self.enable_two_batch_overlap = (
|
self.enable_two_batch_overlap = (
|
||||||
model_runner.server_args.enable_two_batch_overlap
|
model_runner.server_args.enable_two_batch_overlap
|
||||||
@@ -337,23 +339,23 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
if self.require_mlp_tp_gather:
|
if self.require_mlp_tp_gather:
|
||||||
total_batch_size = (
|
cuda_graph_bs = (
|
||||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||||
if self.model_runner.spec_algorithm.is_eagle()
|
if self.model_runner.spec_algorithm.is_eagle()
|
||||||
else sum(forward_batch.global_num_tokens_cpu)
|
else sum(forward_batch.global_num_tokens_cpu)
|
||||||
)
|
)
|
||||||
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
|
|
||||||
total_batch_size in self.graphs
|
|
||||||
if self.disable_padding
|
|
||||||
else total_batch_size <= self.max_bs
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
|
cuda_graph_bs = forward_batch.batch_size
|
||||||
|
|
||||||
is_bs_supported = (
|
is_bs_supported = (
|
||||||
forward_batch.batch_size in self.graphs
|
cuda_graph_bs in self.graphs
|
||||||
if self.disable_padding
|
if self.disable_padding
|
||||||
else forward_batch.batch_size <= self.max_bs
|
else cuda_graph_bs <= self.max_bs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.require_mlp_sync:
|
||||||
|
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
|
||||||
|
|
||||||
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
|
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
|
||||||
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
|
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
|
||||||
# because the full_text_row_masked_out_mask tensor will always be ones
|
# because the full_text_row_masked_out_mask tensor will always be ones
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
require_attn_tp_gather,
|
require_attn_tp_gather,
|
||||||
require_gathered_buffer,
|
require_gathered_buffer,
|
||||||
|
require_mlp_sync,
|
||||||
require_mlp_tp_gather,
|
require_mlp_tp_gather,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -46,6 +47,7 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||||
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
||||||
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
||||||
|
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
|
||||||
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
||||||
self.dp_size = self.model_runner.dp_size
|
self.dp_size = self.model_runner.dp_size
|
||||||
self.tp_size = self.model_runner.tp_size
|
self.tp_size = self.model_runner.tp_size
|
||||||
@@ -127,24 +129,23 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
if self.require_mlp_tp_gather:
|
if self.require_mlp_tp_gather:
|
||||||
if not forward_batch.can_run_dp_cuda_graph:
|
cuda_graph_bs = (
|
||||||
return False
|
|
||||||
total_batch_size = (
|
|
||||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||||
if self.model_runner.spec_algorithm.is_eagle()
|
if self.model_runner.spec_algorithm.is_eagle()
|
||||||
else sum(forward_batch.global_num_tokens_cpu)
|
else sum(forward_batch.global_num_tokens_cpu)
|
||||||
)
|
)
|
||||||
is_bs_supported = (
|
|
||||||
total_batch_size in self.graphs
|
|
||||||
if self.disable_padding
|
|
||||||
else total_batch_size <= self.max_bs
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
|
cuda_graph_bs = forward_batch.batch_size
|
||||||
|
|
||||||
is_bs_supported = (
|
is_bs_supported = (
|
||||||
forward_batch.batch_size in self.graphs
|
cuda_graph_bs in self.graphs
|
||||||
if self.disable_padding
|
if self.disable_padding
|
||||||
else forward_batch.batch_size <= self.max_bs
|
else cuda_graph_bs <= self.max_bs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.require_mlp_sync:
|
||||||
|
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
|
||||||
|
|
||||||
return is_bs_supported
|
return is_bs_supported
|
||||||
|
|
||||||
def capture(self):
|
def capture(self):
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
require_attn_tp_gather,
|
require_attn_tp_gather,
|
||||||
require_gathered_buffer,
|
require_gathered_buffer,
|
||||||
|
require_mlp_sync,
|
||||||
require_mlp_tp_gather,
|
require_mlp_tp_gather,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -42,6 +43,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||||
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
||||||
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
||||||
|
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
|
||||||
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
||||||
self.tp_size = self.model_runner.tp_size
|
self.tp_size = self.model_runner.tp_size
|
||||||
self.dp_size = model_runner.server_args.dp_size
|
self.dp_size = model_runner.server_args.dp_size
|
||||||
@@ -130,28 +132,23 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
if self.require_mlp_tp_gather:
|
if self.require_mlp_tp_gather:
|
||||||
if not forward_batch.can_run_dp_cuda_graph:
|
cuda_graph_bs = (
|
||||||
return False
|
|
||||||
total_batch_size = (
|
|
||||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||||
if self.model_runner.spec_algorithm.is_eagle()
|
if self.model_runner.spec_algorithm.is_eagle()
|
||||||
else sum(forward_batch.global_num_tokens_cpu)
|
else sum(forward_batch.global_num_tokens_cpu)
|
||||||
)
|
)
|
||||||
is_bs_supported = (
|
|
||||||
total_batch_size in self.graphs
|
|
||||||
if self.disable_padding
|
|
||||||
else total_batch_size <= self.max_bs
|
|
||||||
)
|
|
||||||
return is_bs_supported
|
|
||||||
else:
|
else:
|
||||||
batch_size = forward_batch.seq_lens.numel()
|
cuda_graph_bs = forward_batch.seq_lens.numel()
|
||||||
|
|
||||||
is_bs_supported = (
|
is_bs_supported = (
|
||||||
batch_size in self.graphs
|
cuda_graph_bs in self.graphs
|
||||||
if self.disable_padding
|
if self.disable_padding
|
||||||
else batch_size <= self.max_bs
|
else cuda_graph_bs <= self.max_bs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.require_mlp_sync:
|
||||||
|
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
|
||||||
|
|
||||||
return is_bs_supported
|
return is_bs_supported
|
||||||
|
|
||||||
def capture(self):
|
def capture(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user