Let bench_one_batch support enable_dp_attention (#4058)
This commit is contained in:
@@ -60,6 +60,7 @@ from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.entrypoints.engine import _set_envs_and_config
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.managers.scheduler import Scheduler
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
@@ -184,6 +185,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
|
||||
req.prefix_indices = []
|
||||
req.fill_ids = req.origin_input_ids
|
||||
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
||||
req.logprob_start_len = len(req.origin_input_ids) - 1
|
||||
reqs.append(req)
|
||||
|
||||
return input_ids, reqs
|
||||
@@ -199,6 +201,7 @@ def prepare_extend_inputs_for_correctness_test(
|
||||
i, : bench_args.cut_len
|
||||
]
|
||||
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
||||
req.logprob_start_len = len(req.origin_input_ids) - 1
|
||||
return reqs
|
||||
|
||||
|
||||
@@ -220,6 +223,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
|
||||
req.prefix_indices = []
|
||||
req.fill_ids = req.origin_input_ids
|
||||
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
||||
req.logprob_start_len = len(req.origin_input_ids) - 1
|
||||
reqs.append(req)
|
||||
|
||||
return reqs
|
||||
@@ -238,6 +242,7 @@ def extend(reqs, model_runner):
|
||||
enable_custom_logit_processor=False,
|
||||
)
|
||||
batch.prepare_for_extend()
|
||||
_maybe_prepare_dp_attn_batch(batch, model_runner)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output = model_runner.forward(forward_batch)
|
||||
@@ -249,6 +254,7 @@ def extend(reqs, model_runner):
|
||||
def decode(input_token_ids, batch, model_runner):
|
||||
batch.output_ids = input_token_ids
|
||||
batch.prepare_for_decode()
|
||||
_maybe_prepare_dp_attn_batch(batch, model_runner)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output = model_runner.forward(forward_batch)
|
||||
@@ -256,6 +262,20 @@ def decode(input_token_ids, batch, model_runner):
|
||||
return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
    """Apply the scheduler's DP-attention batch preparation when it is enabled.

    Does nothing unless ``model_runner.server_args.enable_dp_attention`` is
    truthy. Otherwise ``batch`` is handed to
    ``Scheduler.prepare_dp_attn_batch_raw`` together with settings read from
    ``model_runner``; whatever that call returns is discarded.

    Args:
        batch: The batch passed through to the scheduler helper.
        model_runner: Object providing ``server_args`` (``enable_dp_attention``,
            ``dp_size``, ``disable_cuda_graph``) and ``tp_group.cpu_group``.
    """
    server_args = model_runner.server_args
    if not server_args.enable_dp_attention:
        return

    Scheduler.prepare_dp_attn_batch_raw(
        batch,
        dp_size=server_args.dp_size,
        # NOTE(review): attention TP size is pinned to 1 here rather than read
        # from server_args — confirm this is intended for the benchmark.
        attn_tp_size=1,
        tp_cpu_group=model_runner.tp_group.cpu_group,
        # No idle-batch factory is supplied by this benchmark path.
        get_idle_batch=None,
        disable_cuda_graph=server_args.disable_cuda_graph,
        # Speculative decoding is explicitly turned off for this call.
        spec_algorithm=SpeculativeAlgorithm.NONE,
        speculative_num_draft_tokens=None,
    )
|
||||
|
||||
|
||||
def correctness_test(
|
||||
server_args,
|
||||
port_args,
|
||||
|
||||
Reference in New Issue
Block a user