Remove unnecessary kernels of num_token_non_padded (#6965)
This commit is contained in:
@@ -35,6 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
PPProxyTensors,
|
||||
enable_num_token_non_padded,
|
||||
)
|
||||
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
||||
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
|
||||
@@ -190,6 +191,9 @@ class CudaGraphRunner:
|
||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
|
||||
self.enable_two_batch_overlap = (
|
||||
model_runner.server_args.enable_two_batch_overlap
|
||||
)
|
||||
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
|
||||
self.tp_size = model_runner.server_args.tp_size
|
||||
self.dp_size = model_runner.server_args.dp_size
|
||||
@@ -327,9 +331,7 @@ class CudaGraphRunner:
|
||||
)
|
||||
|
||||
is_tbo_supported = (
|
||||
forward_batch.can_run_tbo
|
||||
if self.model_runner.server_args.enable_two_batch_overlap
|
||||
else True
|
||||
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
|
||||
)
|
||||
|
||||
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
|
||||
@@ -549,13 +551,7 @@ class CudaGraphRunner:
|
||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
||||
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
||||
num_token_non_padded = len(forward_batch.input_ids)
|
||||
self.num_token_non_padded[...] = num_token_non_padded
|
||||
self.tbo_plugin.replay_prepare(
|
||||
forward_mode=forward_batch.forward_mode,
|
||||
bs=bs,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
)
|
||||
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
if bs != raw_bs:
|
||||
self.seq_lens_cpu.fill_(1)
|
||||
@@ -572,6 +568,14 @@ class CudaGraphRunner:
|
||||
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||
if enable_num_token_non_padded(self.model_runner.server_args):
|
||||
self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
|
||||
if self.enable_two_batch_overlap:
|
||||
self.tbo_plugin.replay_prepare(
|
||||
forward_mode=forward_batch.forward_mode,
|
||||
bs=bs,
|
||||
num_token_non_padded=len(forward_batch.input_ids),
|
||||
)
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
|
||||
@@ -118,6 +118,7 @@ class ForwardMode(IntEnum):
|
||||
|
||||
|
||||
class CaptureHiddenMode(IntEnum):
|
||||
# Do not capture anything.
|
||||
NULL = auto()
|
||||
# Capture hidden states of all tokens.
|
||||
FULL = auto()
|
||||
@@ -253,6 +254,7 @@ class ForwardBatch:
|
||||
# For Qwen2-VL
|
||||
mrope_positions: torch.Tensor = None
|
||||
|
||||
# For two-batch overlap
|
||||
tbo_split_seq_index: Optional[int] = None
|
||||
tbo_parent_token_range: Optional[Tuple[int, int]] = None
|
||||
tbo_children: Optional[List["ForwardBatch"]] = None
|
||||
@@ -265,12 +267,6 @@ class ForwardBatch:
|
||||
):
|
||||
from sglang.srt.two_batch_overlap import TboForwardBatchPreparer
|
||||
|
||||
device = model_runner.device
|
||||
extend_input_logprob_token_ids_gpu = None
|
||||
if batch.extend_input_logprob_token_ids is not None:
|
||||
extend_input_logprob_token_ids_gpu = (
|
||||
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
|
||||
)
|
||||
ret = cls(
|
||||
forward_mode=batch.forward_mode,
|
||||
batch_size=len(batch.seq_lens),
|
||||
@@ -284,6 +280,7 @@ class ForwardBatch:
|
||||
encoder_lens_cpu=batch.encoder_lens_cpu,
|
||||
encoder_out_cache_loc=batch.encoder_out_cache_loc,
|
||||
seq_lens_sum=batch.seq_lens_sum,
|
||||
seq_lens_cpu=batch.seq_lens_cpu,
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
token_ids_logprobs=batch.token_ids_logprobs,
|
||||
@@ -298,12 +295,19 @@ class ForwardBatch:
|
||||
spec_info=batch.spec_info,
|
||||
capture_hidden_mode=batch.capture_hidden_mode,
|
||||
input_embeds=batch.input_embeds,
|
||||
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
|
||||
num_token_non_padded=torch.tensor(
|
||||
len(batch.input_ids), dtype=torch.int32
|
||||
).to(device, non_blocking=True),
|
||||
tbo_split_seq_index=batch.tbo_split_seq_index,
|
||||
)
|
||||
device = model_runner.device
|
||||
|
||||
if batch.extend_input_logprob_token_ids is not None:
|
||||
ret.extend_input_logprob_token_ids_gpu = (
|
||||
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
|
||||
)
|
||||
|
||||
if enable_num_token_non_padded(model_runner.server_args):
|
||||
ret.num_token_non_padded = torch.tensor(
|
||||
len(batch.input_ids), dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
# For DP attention
|
||||
if batch.global_num_tokens is not None:
|
||||
@@ -323,6 +327,7 @@ class ForwardBatch:
|
||||
dtype=model_runner.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if ret.forward_mode.is_idle():
|
||||
ret.positions = torch.empty((0,), device=device)
|
||||
TboForwardBatchPreparer.prepare(ret)
|
||||
@@ -335,10 +340,6 @@ class ForwardBatch:
|
||||
):
|
||||
ret.positions = ret.spec_info.positions
|
||||
|
||||
# Get seq_lens_cpu if needed
|
||||
if ret.seq_lens_cpu is None:
|
||||
ret.seq_lens_cpu = batch.seq_lens_cpu
|
||||
|
||||
# Init position information
|
||||
if ret.forward_mode.is_decode():
|
||||
if ret.positions is None:
|
||||
@@ -605,6 +606,10 @@ class ForwardBatch:
|
||||
return self.tbo_split_seq_index is not None
|
||||
|
||||
|
||||
def enable_num_token_non_padded(server_args):
|
||||
return server_args.enable_ep_moe or server_args.enable_deepep_moe
|
||||
|
||||
|
||||
class PPProxyTensors:
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
|
||||
tensors: Dict[str, torch.Tensor]
|
||||
|
||||
@@ -131,9 +131,6 @@ class TboCudaGraphRunnerPlugin:
|
||||
def replay_prepare(
|
||||
self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
|
||||
):
|
||||
if not global_server_args_dict["enable_two_batch_overlap"]:
|
||||
return
|
||||
|
||||
tbo_split_seq_index, tbo_split_token_index = (
|
||||
compute_split_indices_for_cuda_graph_replay(
|
||||
forward_mode=forward_mode,
|
||||
|
||||
Reference in New Issue
Block a user