Minor refactor two-batch overlap (#6682)
This commit is contained in:
@@ -40,7 +40,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
)
|
||||
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
||||
from sglang.srt.two_batch_overlap import (
|
||||
- TboCudaGraphRunnerUtils,
|
||||
+ TboCudaGraphRunnerPlugin,
|
||||
TboForwardBatchPreparer,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
@@ -256,6 +256,7 @@ class CudaGraphRunner:
|
||||
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
|
||||
self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
|
||||
+ self.tbo_plugin = TboCudaGraphRunnerPlugin()
|
||||
|
||||
# pipeline parallelism
|
||||
if self.pp_size > 1:
|
||||
@@ -481,12 +482,9 @@ class CudaGraphRunner:
|
||||
capture_hidden_mode=self.capture_hidden_mode,
|
||||
lora_paths=lora_paths,
|
||||
num_token_non_padded=self.num_token_non_padded,
|
||||
- tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index(
|
||||
- self, num_tokens
|
||||
- ),
|
||||
global_forward_mode=self.capture_forward_mode,
|
||||
)
|
||||
- TboForwardBatchPreparer.prepare(forward_batch)
|
||||
+ self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
|
||||
|
||||
if lora_paths is not None:
|
||||
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
|
||||
@@ -581,7 +579,13 @@ class CudaGraphRunner:
|
||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
||||
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
||||
- self.num_token_non_padded[...] = len(forward_batch.input_ids)
|
||||
+ num_token_non_padded = len(forward_batch.input_ids)
|
||||
+ self.num_token_non_padded[...] = num_token_non_padded
|
||||
+ self.tbo_plugin.replay_prepare(
|
||||
+ forward_mode=forward_batch.forward_mode,
|
||||
+ bs=bs,
|
||||
+ num_token_non_padded=num_token_non_padded,
|
||||
+ )
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
if bs != raw_bs:
|
||||
self.seq_lens_cpu.fill_(1)
|
||||
|
||||
Reference in New Issue
Block a user