Remove unnecessary kernels of num_token_non_padded (#6965)
This commit is contained in:
@@ -35,6 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
ForwardBatch,
|
ForwardBatch,
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
PPProxyTensors,
|
PPProxyTensors,
|
||||||
|
enable_num_token_non_padded,
|
||||||
)
|
)
|
||||||
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
||||||
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
|
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
|
||||||
@@ -190,6 +191,9 @@ class CudaGraphRunner:
|
|||||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||||
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||||
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
|
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
|
||||||
|
self.enable_two_batch_overlap = (
|
||||||
|
model_runner.server_args.enable_two_batch_overlap
|
||||||
|
)
|
||||||
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
|
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
|
||||||
self.tp_size = model_runner.server_args.tp_size
|
self.tp_size = model_runner.server_args.tp_size
|
||||||
self.dp_size = model_runner.server_args.dp_size
|
self.dp_size = model_runner.server_args.dp_size
|
||||||
@@ -327,9 +331,7 @@ class CudaGraphRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
is_tbo_supported = (
|
is_tbo_supported = (
|
||||||
forward_batch.can_run_tbo
|
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
|
||||||
if self.model_runner.server_args.enable_two_batch_overlap
|
|
||||||
else True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
|
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
|
||||||
@@ -549,13 +551,7 @@ class CudaGraphRunner:
|
|||||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||||
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
||||||
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
||||||
num_token_non_padded = len(forward_batch.input_ids)
|
|
||||||
self.num_token_non_padded[...] = num_token_non_padded
|
|
||||||
self.tbo_plugin.replay_prepare(
|
|
||||||
forward_mode=forward_batch.forward_mode,
|
|
||||||
bs=bs,
|
|
||||||
num_token_non_padded=num_token_non_padded,
|
|
||||||
)
|
|
||||||
if forward_batch.seq_lens_cpu is not None:
|
if forward_batch.seq_lens_cpu is not None:
|
||||||
if bs != raw_bs:
|
if bs != raw_bs:
|
||||||
self.seq_lens_cpu.fill_(1)
|
self.seq_lens_cpu.fill_(1)
|
||||||
@@ -572,6 +568,14 @@ class CudaGraphRunner:
|
|||||||
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
||||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||||
|
if enable_num_token_non_padded(self.model_runner.server_args):
|
||||||
|
self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
|
||||||
|
if self.enable_two_batch_overlap:
|
||||||
|
self.tbo_plugin.replay_prepare(
|
||||||
|
forward_mode=forward_batch.forward_mode,
|
||||||
|
bs=bs,
|
||||||
|
num_token_non_padded=len(forward_batch.input_ids),
|
||||||
|
)
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
|
|||||||
@@ -118,6 +118,7 @@ class ForwardMode(IntEnum):
|
|||||||
|
|
||||||
|
|
||||||
class CaptureHiddenMode(IntEnum):
|
class CaptureHiddenMode(IntEnum):
|
||||||
|
# Do not capture anything.
|
||||||
NULL = auto()
|
NULL = auto()
|
||||||
# Capture hidden states of all tokens.
|
# Capture hidden states of all tokens.
|
||||||
FULL = auto()
|
FULL = auto()
|
||||||
@@ -253,6 +254,7 @@ class ForwardBatch:
|
|||||||
# For Qwen2-VL
|
# For Qwen2-VL
|
||||||
mrope_positions: torch.Tensor = None
|
mrope_positions: torch.Tensor = None
|
||||||
|
|
||||||
|
# For two-batch overlap
|
||||||
tbo_split_seq_index: Optional[int] = None
|
tbo_split_seq_index: Optional[int] = None
|
||||||
tbo_parent_token_range: Optional[Tuple[int, int]] = None
|
tbo_parent_token_range: Optional[Tuple[int, int]] = None
|
||||||
tbo_children: Optional[List["ForwardBatch"]] = None
|
tbo_children: Optional[List["ForwardBatch"]] = None
|
||||||
@@ -265,12 +267,6 @@ class ForwardBatch:
|
|||||||
):
|
):
|
||||||
from sglang.srt.two_batch_overlap import TboForwardBatchPreparer
|
from sglang.srt.two_batch_overlap import TboForwardBatchPreparer
|
||||||
|
|
||||||
device = model_runner.device
|
|
||||||
extend_input_logprob_token_ids_gpu = None
|
|
||||||
if batch.extend_input_logprob_token_ids is not None:
|
|
||||||
extend_input_logprob_token_ids_gpu = (
|
|
||||||
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
|
|
||||||
)
|
|
||||||
ret = cls(
|
ret = cls(
|
||||||
forward_mode=batch.forward_mode,
|
forward_mode=batch.forward_mode,
|
||||||
batch_size=len(batch.seq_lens),
|
batch_size=len(batch.seq_lens),
|
||||||
@@ -284,6 +280,7 @@ class ForwardBatch:
|
|||||||
encoder_lens_cpu=batch.encoder_lens_cpu,
|
encoder_lens_cpu=batch.encoder_lens_cpu,
|
||||||
encoder_out_cache_loc=batch.encoder_out_cache_loc,
|
encoder_out_cache_loc=batch.encoder_out_cache_loc,
|
||||||
seq_lens_sum=batch.seq_lens_sum,
|
seq_lens_sum=batch.seq_lens_sum,
|
||||||
|
seq_lens_cpu=batch.seq_lens_cpu,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
token_ids_logprobs=batch.token_ids_logprobs,
|
token_ids_logprobs=batch.token_ids_logprobs,
|
||||||
@@ -298,12 +295,19 @@ class ForwardBatch:
|
|||||||
spec_info=batch.spec_info,
|
spec_info=batch.spec_info,
|
||||||
capture_hidden_mode=batch.capture_hidden_mode,
|
capture_hidden_mode=batch.capture_hidden_mode,
|
||||||
input_embeds=batch.input_embeds,
|
input_embeds=batch.input_embeds,
|
||||||
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
|
|
||||||
num_token_non_padded=torch.tensor(
|
|
||||||
len(batch.input_ids), dtype=torch.int32
|
|
||||||
).to(device, non_blocking=True),
|
|
||||||
tbo_split_seq_index=batch.tbo_split_seq_index,
|
tbo_split_seq_index=batch.tbo_split_seq_index,
|
||||||
)
|
)
|
||||||
|
device = model_runner.device
|
||||||
|
|
||||||
|
if batch.extend_input_logprob_token_ids is not None:
|
||||||
|
ret.extend_input_logprob_token_ids_gpu = (
|
||||||
|
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
if enable_num_token_non_padded(model_runner.server_args):
|
||||||
|
ret.num_token_non_padded = torch.tensor(
|
||||||
|
len(batch.input_ids), dtype=torch.int32
|
||||||
|
).to(device, non_blocking=True)
|
||||||
|
|
||||||
# For DP attention
|
# For DP attention
|
||||||
if batch.global_num_tokens is not None:
|
if batch.global_num_tokens is not None:
|
||||||
@@ -323,6 +327,7 @@ class ForwardBatch:
|
|||||||
dtype=model_runner.dtype,
|
dtype=model_runner.dtype,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if ret.forward_mode.is_idle():
|
if ret.forward_mode.is_idle():
|
||||||
ret.positions = torch.empty((0,), device=device)
|
ret.positions = torch.empty((0,), device=device)
|
||||||
TboForwardBatchPreparer.prepare(ret)
|
TboForwardBatchPreparer.prepare(ret)
|
||||||
@@ -335,10 +340,6 @@ class ForwardBatch:
|
|||||||
):
|
):
|
||||||
ret.positions = ret.spec_info.positions
|
ret.positions = ret.spec_info.positions
|
||||||
|
|
||||||
# Get seq_lens_cpu if needed
|
|
||||||
if ret.seq_lens_cpu is None:
|
|
||||||
ret.seq_lens_cpu = batch.seq_lens_cpu
|
|
||||||
|
|
||||||
# Init position information
|
# Init position information
|
||||||
if ret.forward_mode.is_decode():
|
if ret.forward_mode.is_decode():
|
||||||
if ret.positions is None:
|
if ret.positions is None:
|
||||||
@@ -605,6 +606,10 @@ class ForwardBatch:
|
|||||||
return self.tbo_split_seq_index is not None
|
return self.tbo_split_seq_index is not None
|
||||||
|
|
||||||
|
|
||||||
|
def enable_num_token_non_padded(server_args):
|
||||||
|
return server_args.enable_ep_moe or server_args.enable_deepep_moe
|
||||||
|
|
||||||
|
|
||||||
class PPProxyTensors:
|
class PPProxyTensors:
|
||||||
# adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
|
# adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
|
||||||
tensors: Dict[str, torch.Tensor]
|
tensors: Dict[str, torch.Tensor]
|
||||||
|
|||||||
@@ -131,9 +131,6 @@ class TboCudaGraphRunnerPlugin:
|
|||||||
def replay_prepare(
|
def replay_prepare(
|
||||||
self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
|
self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
|
||||||
):
|
):
|
||||||
if not global_server_args_dict["enable_two_batch_overlap"]:
|
|
||||||
return
|
|
||||||
|
|
||||||
tbo_split_seq_index, tbo_split_token_index = (
|
tbo_split_seq_index, tbo_split_token_index = (
|
||||||
compute_split_indices_for_cuda_graph_replay(
|
compute_split_indices_for_cuda_graph_replay(
|
||||||
forward_mode=forward_mode,
|
forward_mode=forward_mode,
|
||||||
|
|||||||
Reference in New Issue
Block a user