feat: mtp support dp-attention (#6081)
Co-authored-by: austindeng <austindeng@tencent.com> Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com> Co-authored-by: Qiaolin Yu <liin1211@outlook.com> Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -7,8 +7,12 @@ from typing import List, Optional, Tuple
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
|
||||
from sglang.srt.layers.dp_attention import disable_dp_size
|
||||
from sglang.srt.distributed import (
|
||||
GroupCoordinator,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
patch_tensor_parallel_group,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
@@ -57,7 +61,7 @@ logger = logging.getLogger(__name__)
|
||||
def draft_tp_context(tp_group: GroupCoordinator):
|
||||
# Draft model doesn't use dp and has its own tp group.
|
||||
# We disable mscclpp now because it doesn't support 2 comm groups.
|
||||
with disable_dp_size(), patch_tensor_parallel_group(tp_group):
|
||||
with patch_tensor_parallel_group(tp_group):
|
||||
yield
|
||||
|
||||
|
||||
@@ -76,6 +80,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.server_args = server_args
|
||||
self.topk = server_args.speculative_eagle_topk
|
||||
self.speculative_num_steps = server_args.speculative_num_steps
|
||||
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
|
||||
self.enable_nan_detection = server_args.enable_nan_detection
|
||||
self.gpu_id = gpu_id
|
||||
self.device = server_args.device
|
||||
@@ -302,32 +307,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
A tuple of the final logit output of the target model, next tokens accepted,
|
||||
the batch id (used for overlap schedule), and number of accepted tokens.
|
||||
"""
|
||||
if batch.forward_mode.is_decode():
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
spec_info = self.draft(batch)
|
||||
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
|
||||
self.verify(batch, spec_info)
|
||||
)
|
||||
|
||||
# If it is None, it means all requests are finished
|
||||
if batch.spec_info.verified_id is not None:
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
self.forward_draft_extend_after_decode(batch)
|
||||
return (
|
||||
logits_output,
|
||||
verify_output.verified_id,
|
||||
model_worker_batch.bid,
|
||||
sum(verify_output.accept_length_per_req_cpu),
|
||||
can_run_cuda_graph,
|
||||
)
|
||||
elif batch.forward_mode.is_idle():
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
logits_output, next_token_ids, _ = (
|
||||
self.target_worker.forward_batch_generation(model_worker_batch)
|
||||
)
|
||||
|
||||
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
|
||||
else:
|
||||
if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
|
||||
logits_output, next_token_ids, bid, seq_lens_cpu = (
|
||||
self.forward_target_extend(batch)
|
||||
)
|
||||
@@ -336,6 +316,51 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
|
||||
)
|
||||
return logits_output, next_token_ids, bid, 0, False
|
||||
else:
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
spec_info = self.draft(batch)
|
||||
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
|
||||
self.verify(batch, spec_info)
|
||||
)
|
||||
need_forward, can_run_draft_extend_cuda_graph = (
|
||||
self.check_forward_draft_extend_after_decode(batch)
|
||||
)
|
||||
if need_forward:
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
self.forward_draft_extend_after_decode(
|
||||
batch, can_run_draft_extend_cuda_graph
|
||||
)
|
||||
return (
|
||||
logits_output,
|
||||
verify_output.verified_id,
|
||||
model_worker_batch.bid,
|
||||
sum(verify_output.accept_length_per_req_cpu),
|
||||
can_run_cuda_graph,
|
||||
)
|
||||
|
||||
def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
local_need_forward = (
|
||||
batch.spec_info.verified_id is not None
|
||||
and batch.spec_info.verified_id.shape[0] > 0
|
||||
)
|
||||
if not self.server_args.enable_dp_attention:
|
||||
return local_need_forward, True
|
||||
|
||||
global_need_forward = torch.tensor(
|
||||
[
|
||||
(local_need_forward),
|
||||
],
|
||||
dtype=torch.int64,
|
||||
)
|
||||
torch.distributed.all_reduce(
|
||||
global_need_forward, group=get_tp_group().cpu_group
|
||||
)
|
||||
global_need_forward_cnt = global_need_forward[0].item()
|
||||
need_forward = global_need_forward_cnt > 0
|
||||
can_run_draft_extend_cuda_graph = (
|
||||
global_need_forward_cnt == get_tensor_model_parallel_world_size()
|
||||
)
|
||||
return need_forward, can_run_draft_extend_cuda_graph
|
||||
|
||||
def forward_target_extend(
|
||||
self, batch: ScheduleBatch
|
||||
@@ -354,6 +379,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
# We need the full hidden states to prefill the KV cache of the draft model.
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
model_worker_batch.spec_num_draft_tokens = 1
|
||||
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
@@ -364,7 +390,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
model_worker_batch.seq_lens_cpu,
|
||||
)
|
||||
|
||||
def draft(self, batch: ScheduleBatch):
|
||||
def _draft_preprocess_decode(self, batch: ScheduleBatch):
|
||||
# Parse args
|
||||
num_seqs = batch.batch_size()
|
||||
spec_info = batch.spec_info
|
||||
@@ -466,10 +492,32 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
||||
batch.return_hidden_states = False
|
||||
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
|
||||
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
|
||||
|
||||
def _draft_preprocess_idle(self, batch: ScheduleBatch):
|
||||
batch.spec_info = EagleDraftInput.create_idle_input(
|
||||
device=self.device,
|
||||
hidden_size=self.model_config.hidden_size,
|
||||
topk=self.topk,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
)
|
||||
|
||||
def draft(self, batch: ScheduleBatch):
|
||||
# Parse args
|
||||
if batch.forward_mode.is_idle():
|
||||
self._draft_preprocess_idle(batch)
|
||||
else:
|
||||
self._draft_preprocess_decode(batch)
|
||||
|
||||
spec_info = batch.spec_info
|
||||
|
||||
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
batch.return_hidden_states = False
|
||||
|
||||
# Get forward batch
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch.spec_num_draft_tokens = self.topk
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -481,12 +529,18 @@ class EAGLEWorker(TpModelWorker):
|
||||
forward_batch
|
||||
)
|
||||
else:
|
||||
# Initialize attention backend
|
||||
self.draft_attn_backend.init_forward_metadata(forward_batch)
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
# Initialize attention backend
|
||||
self.draft_attn_backend.init_forward_metadata(forward_batch)
|
||||
# Run forward steps
|
||||
score_list, token_list, parents_list = self.draft_forward(forward_batch)
|
||||
|
||||
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
|
||||
if batch.forward_mode.is_idle():
|
||||
return EagleVerifyInput.create_idle_input(
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
self.speculative_num_draft_tokens,
|
||||
)
|
||||
|
||||
(
|
||||
tree_mask,
|
||||
@@ -504,7 +558,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.seq_lens_sum,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
self.server_args.speculative_num_draft_tokens,
|
||||
self.speculative_num_draft_tokens,
|
||||
)
|
||||
|
||||
return EagleVerifyInput(
|
||||
@@ -584,11 +638,16 @@ class EAGLEWorker(TpModelWorker):
|
||||
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
|
||||
spec_info.prepare_for_verify(batch, self.page_size)
|
||||
batch.return_hidden_states = False
|
||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||
batch.forward_mode = (
|
||||
ForwardMode.TARGET_VERIFY
|
||||
if not batch.forward_mode.is_idle()
|
||||
else ForwardMode.IDLE
|
||||
)
|
||||
batch.spec_info = spec_info
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=spec_info.seq_lens_cpu
|
||||
)
|
||||
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
|
||||
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
|
||||
|
||||
if batch.has_grammar:
|
||||
@@ -646,7 +705,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.add_logprob_values(batch, res, logits_output)
|
||||
|
||||
# Prepare the batch for the next draft forwards.
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.forward_mode = (
|
||||
ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
|
||||
)
|
||||
batch.spec_info = res.draft_input
|
||||
|
||||
return logits_output, res, model_worker_batch, can_run_cuda_graph
|
||||
@@ -743,6 +804,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=seq_lens_cpu
|
||||
)
|
||||
model_worker_batch.spec_num_draft_tokens = 1
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -753,19 +815,37 @@ class EAGLEWorker(TpModelWorker):
|
||||
assert forward_batch.spec_info is batch.spec_info
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
def forward_draft_extend_after_decode(
|
||||
self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool
|
||||
):
|
||||
# Backup fields that will be modified in-place
|
||||
seq_lens_backup = batch.seq_lens.clone()
|
||||
req_pool_indices_backup = batch.req_pool_indices
|
||||
accept_length_backup = batch.spec_info.accept_length
|
||||
return_logprob_backup = batch.return_logprob
|
||||
|
||||
# Prepare metadata
|
||||
batch.spec_info.prepare_extend_after_decode(
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
input_is_idle = batch.forward_mode.is_idle()
|
||||
if not input_is_idle:
|
||||
# Prepare metadata
|
||||
if batch.spec_info.verified_id is not None:
|
||||
batch.spec_info.prepare_extend_after_decode(
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
else:
|
||||
batch = batch.copy()
|
||||
batch.prepare_for_idle()
|
||||
batch.spec_info = EagleDraftInput.create_idle_input(
|
||||
device=self.device,
|
||||
hidden_size=self.model_config.hidden_size,
|
||||
topk=self.topk,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
)
|
||||
|
||||
batch.return_hidden_states = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -776,7 +856,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
# Run
|
||||
can_cuda_graph = (
|
||||
self.cuda_graph_runner_for_draft_extend
|
||||
can_run_draft_extend_cuda_graph
|
||||
and self.cuda_graph_runner_for_draft_extend
|
||||
and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
|
||||
)
|
||||
if can_cuda_graph:
|
||||
@@ -789,7 +870,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
forward_batch.spec_info.hidden_states = logits_output.hidden_states
|
||||
else:
|
||||
self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
self.draft_model_runner.attn_backend.init_forward_metadata(
|
||||
forward_batch
|
||||
)
|
||||
logits_output = self.draft_model_runner.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
@@ -799,7 +883,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
# Restore backup.
|
||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.forward_mode = (
|
||||
ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
|
||||
)
|
||||
batch.seq_lens = seq_lens_backup
|
||||
batch.req_pool_indices = req_pool_indices_backup
|
||||
batch.spec_info.accept_length = accept_length_backup
|
||||
|
||||
Reference in New Issue
Block a user