[5/n] DP Enhancement: Correct num_token_non_padded (#9107)
@@ -33,7 +33,11 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     set_graph_pool_id,
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
-from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.model_executor.forward_batch_info import (
@@ -255,6 +259,9 @@ class CudaGraphRunner:
         self.dp_size = model_runner.server_args.dp_size
         self.pp_size = model_runner.server_args.pp_size
 
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         rank0_log(f"Capture cuda graph bs {self.capture_bs}")
@@ -749,7 +756,17 @@ class CudaGraphRunner:
         self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
         self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
         if enable_num_token_non_padded(self.model_runner.server_args):
-            self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
+            num_token_non_padded = forward_batch.num_token_non_padded
+            if self.require_gathered_buffer:
+                tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
+                num_local_token_non_padded = torch.clamp(
+                    num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
+                    min=0,
+                    max=tokens_per_rank,
+                )
+                self.num_token_non_padded.copy_(num_local_token_non_padded)
+            else:
+                self.num_token_non_padded.copy_(num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
                 forward_mode=self.capture_forward_mode,
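The clamp in the last hunk converts the batch-wide non-padded token count into a per-attention-TP-rank count: each rank owns a contiguous padded slice of tokens_per_rank tokens, so subtracting the offset of the rank's slice and clamping into [0, tokens_per_rank] yields how many of that rank's tokens are real. A minimal standalone sketch of the arithmetic (the loop and the concrete numbers are illustrative, not from the PR):

import torch

# Illustrative values (not from the PR): 4 attention-TP ranks, a padded
# batch of 8 requests, 16 tokens per request slot during replay.
attn_tp_size = 4
bs, num_tokens_per_bs = 8, 16
tokens_per_rank = bs // attn_tp_size * num_tokens_per_bs  # 32 tokens per rank

# Global count of non-padded (real) tokens in the gathered buffer.
num_token_non_padded = torch.tensor(70)

for attn_tp_rank in range(attn_tp_size):
    # Drop the tokens that belong to earlier ranks, then clamp the
    # remainder into this rank's slice: fully filled slices stay at 32,
    # the partially filled slice keeps its remainder, later slices get 0.
    local = torch.clamp(
        num_token_non_padded - tokens_per_rank * attn_tp_rank,
        min=0,
        max=tokens_per_rank,
    )
    print(attn_tp_rank, int(local))
# Prints 32, 32, 6, 0 for ranks 0..3; the local counts sum back to 70.

This per-rank split is presumably why the old line, which copied the global count onto every rank, needed correcting for the require_gathered_buffer path.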