[5/n] DP Enhancement: Correct num_token_non_padded (#9107)
This commit is contained in:
@@ -33,7 +33,11 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
|||||||
set_graph_pool_id,
|
set_graph_pool_id,
|
||||||
)
|
)
|
||||||
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
||||||
from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
DPPaddingMode,
|
||||||
|
get_attention_tp_rank,
|
||||||
|
get_attention_tp_size,
|
||||||
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
@@ -255,6 +259,9 @@ class CudaGraphRunner:
|
|||||||
self.dp_size = model_runner.server_args.dp_size
|
self.dp_size = model_runner.server_args.dp_size
|
||||||
self.pp_size = model_runner.server_args.pp_size
|
self.pp_size = model_runner.server_args.pp_size
|
||||||
|
|
||||||
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
|
self.attn_tp_rank = get_attention_tp_rank()
|
||||||
|
|
||||||
# Batch sizes to capture
|
# Batch sizes to capture
|
||||||
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
||||||
rank0_log(f"Capture cuda graph bs {self.capture_bs}")
|
rank0_log(f"Capture cuda graph bs {self.capture_bs}")
|
||||||
@@ -749,7 +756,17 @@ class CudaGraphRunner:
|
|||||||
self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
|
self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
|
||||||
self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
|
self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
|
||||||
if enable_num_token_non_padded(self.model_runner.server_args):
|
if enable_num_token_non_padded(self.model_runner.server_args):
|
||||||
self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
|
num_token_non_padded = forward_batch.num_token_non_padded
|
||||||
|
if self.require_gathered_buffer:
|
||||||
|
tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
|
||||||
|
num_local_token_non_padded = torch.clamp(
|
||||||
|
num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
|
||||||
|
min=0,
|
||||||
|
max=tokens_per_rank,
|
||||||
|
)
|
||||||
|
self.num_token_non_padded.copy_(num_local_token_non_padded)
|
||||||
|
else:
|
||||||
|
self.num_token_non_padded.copy_(num_token_non_padded)
|
||||||
if self.enable_two_batch_overlap:
|
if self.enable_two_batch_overlap:
|
||||||
self.tbo_plugin.replay_prepare(
|
self.tbo_plugin.replay_prepare(
|
||||||
forward_mode=self.capture_forward_mode,
|
forward_mode=self.capture_forward_mode,
|
||||||
|
|||||||
Reference in New Issue
Block a user