Let bench_one_batch support enable_dp_attention (#4058)
diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
@@ -60,6 +60,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -184,6 +185,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
 
     return input_ids, reqs
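The same `req.logprob_start_len` line is added in the two following helpers as well. As a rough illustration (not from the patch), setting it to the last prompt position starts logprob bookkeeping at the final input token, skipping logprobs for the rest of the prompt:

    # Rough sketch, not from the patch: logprob_start_len = len(prompt) - 1
    # starts logprob bookkeeping at the last prompt token.
    origin_input_ids = [11, 42, 7, 99, 3]          # a 5-token prompt
    logprob_start_len = len(origin_input_ids) - 1  # -> 4 (last position)
    print(logprob_start_len)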
@@ -199,6 +201,7 @@ def prepare_extend_inputs_for_correctness_test(
             i, : bench_args.cut_len
         ]
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
     return reqs
 
 
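In this extend-phase helper, the surrounding (unchanged) lines slice the first `cut_len` tokens into `prefix_indices`, so `extend_input_len` counts only the uncached remainder. A minimal standalone sketch of that arithmetic (illustrative values, not from the patch):

    # Minimal sketch of the extend-length arithmetic (illustrative only).
    fill_ids = list(range(10))     # full prompt: 10 tokens
    prefix_indices = fill_ids[:4]  # cut_len = 4 tokens already in the KV cache
    extend_input_len = len(fill_ids) - len(prefix_indices)
    print(extend_input_len)        # -> 6 tokens actually run in the extend pass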
@@ -220,6 +223,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
 
     return reqs
@@ -238,6 +242,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -249,6 +254,7 @@ def extend(reqs, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -256,6 +262,20 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits
 
 
+def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
+    if model_runner.server_args.enable_dp_attention:
+        Scheduler.prepare_dp_attn_batch_raw(
+            batch,
+            dp_size=model_runner.server_args.dp_size,
+            attn_tp_size=1,
+            tp_cpu_group=model_runner.tp_group.cpu_group,
+            get_idle_batch=None,
+            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
+            spec_algorithm=SpeculativeAlgorithm.NONE,
+            speculative_num_draft_tokens=None,
+        )
+
+
 def correctness_test(
     server_args,
     port_args,
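With `_maybe_prepare_dp_attn_batch` wired into both `extend` and `decode`, the benchmark can exercise DP attention end to end. A hedged example invocation (flag spellings assumed from sglang's server arguments and may vary by version):

    python -m sglang.bench_one_batch --model-path <model> --tp 2 --dp 2 --enable-dp-attention

Note the helper hard-codes `attn_tp_size=1` and passes `get_idle_batch=None`, which fits the benchmark's single-batch, no-scheduler setting.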
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
@@ -1466,14 +1466,36 @@ class Scheduler(
         self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
 
     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_dp_attn_batch_raw(
+            local_batch,
+            dp_size=self.server_args.dp_size,
+            attn_tp_size=self.attn_tp_size,
+            tp_cpu_group=self.tp_cpu_group,
+            get_idle_batch=self.get_idle_batch,
+            disable_cuda_graph=self.server_args.disable_cuda_graph,
+            spec_algorithm=self.spec_algorithm,
+            speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+        )
+
+    @staticmethod
+    def prepare_dp_attn_batch_raw(
+        local_batch: ScheduleBatch,
+        dp_size,
+        attn_tp_size: int,
+        tp_cpu_group,
+        get_idle_batch,
+        disable_cuda_graph: bool,
+        spec_algorithm,
+        speculative_num_draft_tokens,
+    ):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
             global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
-                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
+                num_tokens = num_tokens * speculative_num_draft_tokens
             global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
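The scheduler change is a pure extract-method refactor: `prepare_dp_attn_batch` keeps its signature and delegates to a new `@staticmethod` that takes every former `self.*` dependency as an explicit argument, so code without a `Scheduler` instance (here, the benchmark) can reuse the logic. A minimal standalone sketch of the pattern (illustrative names, not from the patch):

    # Illustrative sketch of the extract-to-staticmethod pattern: the instance
    # method stays a thin wrapper over a static function that receives all of
    # its former `self.*` dependencies as arguments.
    class Worker:
        def __init__(self, dp_size: int):
            self.dp_size = dp_size

        def prepare(self, batch):
            # Existing callers keep using the instance method unchanged.
            return self.prepare_raw(batch, dp_size=self.dp_size)

        @staticmethod
        def prepare_raw(batch, dp_size: int):
            # Pure function of its arguments: reusable without a Worker,
            # e.g. from a benchmark script.
            return [batch] * dp_size


    print(Worker(dp_size=2).prepare("b"))      # via the instance method
    print(Worker.prepare_raw("b", dp_size=2))  # directly, no instance needed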
@@ -1492,7 +1514,7 @@ class Scheduler(
         else:
             can_cuda_graph = 0
 
-        if not self.spec_algorithm.is_none():
+        if not spec_algorithm.is_none():
             # TODO(sang): Support cuda graph when idle batch is there.
             if local_batch is None or local_batch.forward_mode.is_idle():
                 can_cuda_graph = 0
@@ -1510,13 +1532,13 @@ class Scheduler(
             dtype=torch.int64,
         )
         global_info = torch.empty(
-            (self.server_args.dp_size, self.attn_tp_size, 4),
+            (dp_size, attn_tp_size, 4),
             dtype=torch.int64,
         )
         torch.distributed.all_gather_into_tensor(
             global_info.flatten(),
             local_info,
-            group=self.tp_cpu_group,
+            group=tp_cpu_group,
         )
         global_num_tokens = global_info[:, 0, 0].tolist()
         can_cuda_graph = min(global_info[:, 0, 1].tolist())
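Because the static method can no longer read `self.server_args`, the gather buffer shape now comes from the `dp_size`/`attn_tp_size` parameters. The pattern itself is unchanged: each rank packs four int64 stats, all ranks all-gather them over the CPU group, and only the attn-TP-rank-0 column is read back. A runnable single-process sketch of that pattern (gloo backend, world size 1, illustrative values):

    # Single-process sketch of the all-gather pattern above (illustrative).
    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    dp_size, attn_tp_size = 1, 1
    # [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch]
    local_info = torch.tensor([8, 1, 8, 0], dtype=torch.int64)
    global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
    dist.all_gather_into_tensor(global_info.flatten(), local_info)

    # Column 0 (attn-TP rank 0) of each DP rank carries that rank's stats.
    print(global_info[:, 0, 0].tolist())  # -> [8]
    dist.destroy_process_group()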
@@ -1524,14 +1546,14 @@ class Scheduler(
         is_extend_in_batch = global_info[:, 0, 3].tolist()
 
         if local_batch is None and max(global_num_tokens) > 0:
-            local_batch = self.get_idle_batch()
+            local_batch = get_idle_batch()
 
         if local_batch is not None:
             local_batch.global_num_tokens = global_num_tokens
             local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
 
             # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
+            if not disable_cuda_graph:
                 local_batch.can_run_dp_cuda_graph = can_cuda_graph
 
         return local_batch, any(is_extend_in_batch)