feat: mtp support dp-attention (#6081)
Co-authored-by: austindeng <austindeng@tencent.com> Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com> Co-authored-by: Qiaolin Yu <liin1211@outlook.com> Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -38,6 +38,10 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.output_buffers = {}
|
||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
|
||||
self.dp_size = self.model_runner.dp_size
|
||||
self.tp_size = self.model_runner.tp_size
|
||||
self.topk = model_runner.server_args.speculative_eagle_topk
|
||||
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
||||
@@ -53,7 +57,9 @@ class EAGLEDraftCudaGraphRunner:
|
||||
# Attention backend
|
||||
self.max_bs = max(self.capture_bs)
|
||||
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
||||
self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
|
||||
self.model_runner.draft_attn_backend.init_cuda_graph_state(
|
||||
self.max_bs, self.max_num_token
|
||||
)
|
||||
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
|
||||
0
|
||||
].get_cuda_graph_seq_len_fill_value()
|
||||
@@ -78,10 +84,26 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
|
||||
self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
|
||||
self.hidden_states = torch.zeros(
|
||||
(self.max_num_token, self.model_runner.model_config.hidden_size),
|
||||
(self.max_bs, self.model_runner.model_config.hidden_size),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
|
||||
# Capture
|
||||
try:
|
||||
with model_capture_mode():
|
||||
@@ -92,11 +114,26 @@ class EAGLEDraftCudaGraphRunner:
|
||||
)
|
||||
|
||||
def can_run(self, forward_batch: "ForwardBatch"):
    """Return whether a captured draft CUDA graph can serve ``forward_batch``.

    Without DP attention the decision uses the local
    ``forward_batch.batch_size``. With DP attention the batch must be flagged
    ``can_run_dp_cuda_graph`` and the decision uses the token counts in
    ``forward_batch.global_num_tokens_cpu`` summed over all entries.

    Args:
        forward_batch: The batch about to be executed.

    Returns:
        ``True`` if the effective batch size is usable: an exact key of
        ``self.graphs`` when padding is disabled, otherwise any size up to
        ``self.max_bs``.
    """
    if self.enable_dp_attention:
        # TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head
        if not forward_batch.can_run_dp_cuda_graph:
            return False
        # NOTE(review): other DP-related branches in this runner also test
        # `enable_sp_layernorm`; confirm whether this check should too.
        total_tokens = sum(forward_batch.global_num_tokens_cpu)
        effective_bs = (
            total_tokens // self.num_tokens_per_bs
            if self.model_runner.spec_algorithm.is_eagle()
            else total_tokens
        )
    else:
        effective_bs = forward_batch.batch_size

    if self.disable_padding:
        # Without padding only exactly captured sizes can be replayed.
        return effective_bs in self.graphs
    return effective_bs <= self.max_bs
|
||||
|
||||
def capture(self):
|
||||
@@ -116,8 +153,40 @@ class EAGLEDraftCudaGraphRunner:
|
||||
topk_index = self.topk_index[:num_seqs]
|
||||
hidden_states = self.hidden_states[:num_seqs]
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
||||
for i in range(self.dp_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
||||
for i in range(self.dp_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
else:
|
||||
global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
global_num_tokens_for_logprob = None
|
||||
|
||||
spec_info = EagleDraftInput(
|
||||
topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
|
||||
topk_p=topk_p,
|
||||
topk_index=topk_index,
|
||||
hidden_states=hidden_states,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
)
|
||||
|
||||
# Forward batch
|
||||
@@ -133,11 +202,14 @@ class EAGLEDraftCudaGraphRunner:
|
||||
seq_lens_sum=seq_lens.sum().item(),
|
||||
return_logprob=False,
|
||||
positions=positions,
|
||||
global_num_tokens_gpu=global_num_tokens,
|
||||
gathered_buffer=gathered_buffer,
|
||||
spec_algorithm=self.model_runner.spec_algorithm,
|
||||
spec_info=spec_info,
|
||||
capture_hidden_mode=(
|
||||
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
|
||||
),
|
||||
global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
|
||||
)
|
||||
|
||||
# Attention backend
|
||||
@@ -147,6 +219,9 @@ class EAGLEDraftCudaGraphRunner:
|
||||
|
||||
# Run and capture
|
||||
def run_once():
|
||||
# Clean intermediate result cache for DP attention
|
||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||
|
||||
# Backup two fields, which will be modified in-place in `draft_forward`.
|
||||
output_cache_loc_backup = forward_batch.out_cache_loc
|
||||
hidden_states_backup = forward_batch.spec_info.hidden_states
|
||||
@@ -184,7 +259,15 @@ class EAGLEDraftCudaGraphRunner:
|
||||
raw_num_token = raw_bs * self.num_tokens_per_bs
|
||||
|
||||
# Pad
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
total_batch_size = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
else sum(forward_batch.global_num_tokens_cpu)
|
||||
)
|
||||
index = bisect.bisect_left(self.capture_bs, total_batch_size)
|
||||
else:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
if bs != raw_bs:
|
||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||
@@ -203,6 +286,13 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
|
||||
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
forward_batch.global_num_tokens_for_logprob_gpu
|
||||
)
|
||||
forward_batch.gathered_buffer = self.gathered_buffer
|
||||
|
||||
# Attention backend
|
||||
if bs != raw_bs:
|
||||
forward_batch.batch_size = bs
|
||||
@@ -210,8 +300,10 @@ class EAGLEDraftCudaGraphRunner:
|
||||
forward_batch.req_pool_indices = self.req_pool_indices[:bs]
|
||||
forward_batch.positions = self.positions[:num_tokens]
|
||||
|
||||
if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
|
||||
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
||||
# Special handle for seq_len_cpu used when flashinfer mla is used
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
if bs != raw_bs:
|
||||
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||
forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
|
||||
|
||||
|
||||
@@ -35,6 +35,8 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.output_buffers = {}
|
||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
|
||||
self.tp_size = self.model_runner.tp_size
|
||||
self.dp_size = model_runner.server_args.dp_size
|
||||
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
||||
@@ -51,7 +53,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
||||
|
||||
self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
|
||||
self.max_num_token
|
||||
self.max_bs, self.max_num_token
|
||||
)
|
||||
self.seq_len_fill_value = (
|
||||
self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||
@@ -90,6 +92,21 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
(self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
|
||||
)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
|
||||
# Capture
|
||||
try:
|
||||
with model_capture_mode():
|
||||
@@ -100,15 +117,30 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
)
|
||||
|
||||
def can_run(self, forward_batch: "ForwardBatch"):
    """Return whether a captured draft-extend CUDA graph can serve the batch.

    When DP attention or SP layernorm is enabled, the batch must be flagged
    ``can_run_dp_cuda_graph`` and the decision uses the token counts in
    ``forward_batch.global_num_tokens_cpu`` summed over all entries.
    Otherwise it uses the number of elements in ``forward_batch.seq_lens``.

    Args:
        forward_batch: The batch about to be executed.

    Returns:
        ``True`` if the effective batch size is usable: an exact key of
        ``self.graphs`` when padding is disabled, otherwise any size up to
        ``self.max_bs``.
    """
    if self.enable_dp_attention or self.enable_sp_layernorm:
        if not forward_batch.can_run_dp_cuda_graph:
            return False
        total_tokens = sum(forward_batch.global_num_tokens_cpu)
        effective_bs = (
            total_tokens // self.num_tokens_per_bs
            if self.model_runner.spec_algorithm.is_eagle()
            else total_tokens
        )
    else:
        effective_bs = forward_batch.seq_lens.numel()

    if self.disable_padding:
        # Without padding only exactly captured sizes can be replayed.
        return effective_bs in self.graphs
    return effective_bs <= self.max_bs
|
||||
|
||||
def capture(self):
|
||||
CudaGraphRunner.capture(self)
|
||||
@@ -128,6 +160,35 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
positions = self.positions[:num_tokens]
|
||||
hidden_states = self.hidden_states[:num_tokens]
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
||||
for i in range(self.dp_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
||||
for i in range(self.dp_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
else:
|
||||
global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
global_num_tokens_for_logprob = None
|
||||
|
||||
spec_info = EagleDraftInput(
|
||||
hidden_states=hidden_states,
|
||||
accept_length=accept_length,
|
||||
@@ -147,6 +208,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
seq_lens_sum=seq_lens.sum().item(),
|
||||
return_logprob=False,
|
||||
positions=positions,
|
||||
global_num_tokens_gpu=global_num_tokens,
|
||||
global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
|
||||
gathered_buffer=gathered_buffer,
|
||||
spec_algorithm=self.model_runner.spec_algorithm,
|
||||
spec_info=spec_info,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
@@ -167,6 +231,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
|
||||
# Run and capture
|
||||
def run_once():
|
||||
# Clean intermediate result cache for DP attention
|
||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||
|
||||
# Backup two fields, which will be modified in-place in `draft_forward`.
|
||||
output_cache_loc_backup = forward_batch.out_cache_loc
|
||||
hidden_states_backup = forward_batch.spec_info.hidden_states
|
||||
@@ -203,24 +270,42 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
# in the batch, which will not be counted as num_seqs
|
||||
raw_bs = forward_batch.batch_size
|
||||
num_tokens = forward_batch.input_ids.shape[0]
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
total_batch_size = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
else sum(forward_batch.global_num_tokens_cpu)
|
||||
)
|
||||
index = bisect.bisect_left(self.capture_bs, total_batch_size)
|
||||
else:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
if bs * self.num_tokens_per_bs != num_tokens:
|
||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||
self.out_cache_loc.zero_()
|
||||
self.accept_length.fill_(1)
|
||||
self.extend_seq_lens.fill_(1)
|
||||
|
||||
# Common inputs
|
||||
self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
|
||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
|
||||
if forward_batch.extend_seq_lens is not None:
|
||||
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
|
||||
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
|
||||
self.positions[:num_tokens].copy_(forward_batch.positions)
|
||||
self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
|
||||
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
|
||||
if forward_batch.spec_info.accept_length is not None:
|
||||
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
|
||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
forward_batch.global_num_tokens_for_logprob_gpu
|
||||
)
|
||||
forward_batch.gathered_buffer = self.gathered_buffer
|
||||
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
if bs != raw_bs:
|
||||
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
||||
|
||||
@@ -25,6 +25,8 @@ from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import (
|
||||
fast_topk,
|
||||
@@ -69,6 +71,8 @@ class EagleDraftInput:
|
||||
kv_indices: torch.Tensor = None
|
||||
|
||||
def prepare_for_extend(self, batch: ScheduleBatch):
|
||||
if batch.forward_mode.is_idle():
|
||||
return
|
||||
# Prefill only generate 1 token.
|
||||
assert len(self.verified_id) == len(batch.seq_lens)
|
||||
|
||||
@@ -80,6 +84,24 @@ class EagleDraftInput:
|
||||
)
|
||||
pt += extend_len
|
||||
|
||||
@classmethod
def create_idle_input(
    cls,
    device: torch.device,
    hidden_size: int,
    topk: int,
    capture_hidden_mode: "CaptureHiddenMode",
    dtype: torch.dtype = torch.float32,
):
    """Build a draft input whose tensors all have zero rows.

    Args:
        device: Device on which the empty tensors are allocated.
        hidden_size: Width of the empty ``hidden_states`` tensor.
        topk: Width of the empty ``topk_p`` / ``topk_index`` tensors.
        capture_hidden_mode: Stored unchanged on the returned instance.
        dtype: Element type of ``hidden_states``. Defaults to
            ``torch.float32``, the previously hard-coded value; pass the
            model dtype to match non-idle hidden states.

    Returns:
        A new ``cls`` instance with ``verified_id=None`` and zero-row
        ``hidden_states``, ``topk_p`` and ``topk_index`` tensors.
    """
    return cls(
        verified_id=None,
        hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
        topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
        topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
        capture_hidden_mode=capture_hidden_mode,
    )
|
||||
|
||||
def prepare_extend_after_decode(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
@@ -193,7 +215,35 @@ class EagleVerifyInput:
|
||||
seq_lens_cpu: torch.Tensor
|
||||
grammar: BaseGrammarObject = None
|
||||
|
||||
@classmethod
def create_idle_input(
    cls,
    topk: int,
    spec_steps: int,
    num_verify_tokens: int,
    device: "torch.device | str" = "cuda",
):
    """Build a verify input whose tensors all have zero rows.

    Args:
        topk: Stored as ``topk`` on the returned instance.
        spec_steps: Stored as ``spec_steps`` on the returned instance.
        num_verify_tokens: Stored as ``draft_token_num`` and used as the
            second dimension of the empty retrieve tables.
        device: Device for the empty tensors. Defaults to ``"cuda"``, the
            previously hard-coded value.

    Returns:
        A new ``cls`` instance with zero-row tensors, ``seq_lens_sum=0``,
        ``retrive_cum_len=None`` and ``CaptureHiddenMode.FULL``.
    """

    def _empty_retrieve_table():
        # A fresh tensor per call so the three tables never alias each other.
        return torch.full(
            (0, num_verify_tokens), -1, dtype=torch.long, device=device
        )

    return cls(
        draft_token=torch.empty((0,), dtype=torch.long, device=device),
        custom_mask=torch.full((0,), True, dtype=torch.bool, device=device),
        positions=torch.empty((0,), dtype=torch.int64, device=device),
        retrive_index=_empty_retrieve_table(),
        retrive_next_token=_empty_retrieve_table(),
        retrive_next_sibling=_empty_retrieve_table(),
        retrive_cum_len=None,
        topk=topk,
        draft_token_num=num_verify_tokens,
        spec_steps=spec_steps,
        capture_hidden_mode=CaptureHiddenMode.FULL,
        seq_lens_sum=0,
        # No device argument here, as in the original, so this tensor is
        # created on torch's default device rather than on `device`.
        seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
    )
|
||||
|
||||
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
|
||||
|
||||
if batch.forward_mode.is_idle():
|
||||
return
|
||||
|
||||
batch.input_ids = self.draft_token
|
||||
|
||||
if page_size == 1:
|
||||
@@ -279,6 +329,25 @@ class EagleVerifyInput:
|
||||
tokens. I.e., logits_output.next_token_logits only contains
|
||||
accepted token logits.
|
||||
"""
|
||||
if batch.forward_mode.is_idle():
|
||||
return EagleVerifyOutput(
|
||||
draft_input=EagleDraftInput.create_idle_input(
|
||||
device=batch.device,
|
||||
hidden_size=batch.model_config.hidden_size,
|
||||
topk=self.topk,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
),
|
||||
logits_output=logits_output,
|
||||
verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
|
||||
accept_length_per_req_cpu=[],
|
||||
accepted_indices=torch.full(
|
||||
(0, self.spec_steps + 1),
|
||||
-1,
|
||||
dtype=torch.int32,
|
||||
device=batch.device,
|
||||
),
|
||||
)
|
||||
|
||||
bs = self.retrive_index.shape[0]
|
||||
candidates = self.draft_token.reshape(bs, self.draft_token_num)
|
||||
sampling_info = batch.sampling_info
|
||||
@@ -992,10 +1061,11 @@ def select_top_k_tokens(
|
||||
topk_index = topk_index.reshape(-1, topk**2)
|
||||
input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
|
||||
|
||||
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
|
||||
0, hidden_states.shape[0], step=topk, device="cuda"
|
||||
).repeat_interleave(topk)
|
||||
hidden_states = hidden_states[selected_input_index, :]
|
||||
if hidden_states.shape[0] > 0:
|
||||
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
|
||||
0, hidden_states.shape[0], step=topk, device="cuda"
|
||||
).repeat_interleave(topk)
|
||||
hidden_states = hidden_states[selected_input_index, :]
|
||||
|
||||
tree_info = (
|
||||
expand_scores, # shape: (b, topk, topk)
|
||||
|
||||
@@ -7,8 +7,12 @@ from typing import List, Optional, Tuple
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
|
||||
from sglang.srt.layers.dp_attention import disable_dp_size
|
||||
from sglang.srt.distributed import (
|
||||
GroupCoordinator,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
patch_tensor_parallel_group,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
@@ -57,7 +61,7 @@ logger = logging.getLogger(__name__)
|
||||
def draft_tp_context(tp_group: GroupCoordinator):
|
||||
# Draft model doesn't use dp and has its own tp group.
|
||||
# We disable mscclpp now because it doesn't support 2 comm groups.
|
||||
with disable_dp_size(), patch_tensor_parallel_group(tp_group):
|
||||
with patch_tensor_parallel_group(tp_group):
|
||||
yield
|
||||
|
||||
|
||||
@@ -76,6 +80,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.server_args = server_args
|
||||
self.topk = server_args.speculative_eagle_topk
|
||||
self.speculative_num_steps = server_args.speculative_num_steps
|
||||
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
|
||||
self.enable_nan_detection = server_args.enable_nan_detection
|
||||
self.gpu_id = gpu_id
|
||||
self.device = server_args.device
|
||||
@@ -302,32 +307,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
A tuple of the final logit output of the target model, next tokens accepted,
|
||||
the batch id (used for overlap schedule), and number of accepted tokens.
|
||||
"""
|
||||
if batch.forward_mode.is_decode():
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
spec_info = self.draft(batch)
|
||||
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
|
||||
self.verify(batch, spec_info)
|
||||
)
|
||||
|
||||
# If it is None, it means all requests are finished
|
||||
if batch.spec_info.verified_id is not None:
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
self.forward_draft_extend_after_decode(batch)
|
||||
return (
|
||||
logits_output,
|
||||
verify_output.verified_id,
|
||||
model_worker_batch.bid,
|
||||
sum(verify_output.accept_length_per_req_cpu),
|
||||
can_run_cuda_graph,
|
||||
)
|
||||
elif batch.forward_mode.is_idle():
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
logits_output, next_token_ids, _ = (
|
||||
self.target_worker.forward_batch_generation(model_worker_batch)
|
||||
)
|
||||
|
||||
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
|
||||
else:
|
||||
if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
|
||||
logits_output, next_token_ids, bid, seq_lens_cpu = (
|
||||
self.forward_target_extend(batch)
|
||||
)
|
||||
@@ -336,6 +316,51 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
|
||||
)
|
||||
return logits_output, next_token_ids, bid, 0, False
|
||||
else:
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
spec_info = self.draft(batch)
|
||||
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
|
||||
self.verify(batch, spec_info)
|
||||
)
|
||||
need_forward, can_run_draft_extend_cuda_graph = (
|
||||
self.check_forward_draft_extend_after_decode(batch)
|
||||
)
|
||||
if need_forward:
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
self.forward_draft_extend_after_decode(
|
||||
batch, can_run_draft_extend_cuda_graph
|
||||
)
|
||||
return (
|
||||
logits_output,
|
||||
verify_output.verified_id,
|
||||
model_worker_batch.bid,
|
||||
sum(verify_output.accept_length_per_req_cpu),
|
||||
can_run_cuda_graph,
|
||||
)
|
||||
|
||||
def check_forward_draft_extend_after_decode(self, batch: "ScheduleBatch"):
    """Decide whether the draft-extend pass after decode has to run.

    Returns:
        A ``(need_forward, can_run_draft_extend_cuda_graph)`` pair.

        Without DP attention, ``need_forward`` is whether this rank has a
        non-empty ``verified_id`` and the second element is always ``True``.

        With DP attention, the local flag is all-reduced over the TP CPU
        group. ``need_forward`` is true if the reduced count is positive,
        and the CUDA-graph flag is true only if the count equals the
        tensor-parallel world size.
    """
    verified_id = batch.spec_info.verified_id
    has_local_work = verified_id is not None and verified_id.shape[0] > 0

    if not self.server_args.enable_dp_attention:
        return has_local_work, True

    # One integer per rank; all_reduce is called without an explicit op.
    work_flag = torch.tensor([has_local_work], dtype=torch.int64)
    torch.distributed.all_reduce(work_flag, group=get_tp_group().cpu_group)
    ranks_with_work = work_flag[0].item()

    any_rank_has_work = ranks_with_work > 0
    every_rank_has_work = (
        ranks_with_work == get_tensor_model_parallel_world_size()
    )
    return any_rank_has_work, every_rank_has_work
|
||||
|
||||
def forward_target_extend(
|
||||
self, batch: ScheduleBatch
|
||||
@@ -354,6 +379,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
# We need the full hidden states to prefill the KV cache of the draft model.
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
model_worker_batch.spec_num_draft_tokens = 1
|
||||
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
@@ -364,7 +390,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
model_worker_batch.seq_lens_cpu,
|
||||
)
|
||||
|
||||
def draft(self, batch: ScheduleBatch):
|
||||
def _draft_preprocess_decode(self, batch: ScheduleBatch):
|
||||
# Parse args
|
||||
num_seqs = batch.batch_size()
|
||||
spec_info = batch.spec_info
|
||||
@@ -466,10 +492,32 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
||||
batch.return_hidden_states = False
|
||||
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
|
||||
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
|
||||
|
||||
def _draft_preprocess_idle(self, batch: "ScheduleBatch"):
    """Replace ``batch.spec_info`` with an idle ``EagleDraftInput``.

    The input is built from this worker's device, model hidden size and
    top-k setting, with ``CaptureHiddenMode.LAST``.
    """
    idle_kwargs = dict(
        device=self.device,
        hidden_size=self.model_config.hidden_size,
        topk=self.topk,
        capture_hidden_mode=CaptureHiddenMode.LAST,
    )
    idle_draft_input = EagleDraftInput.create_idle_input(**idle_kwargs)
    batch.spec_info = idle_draft_input
|
||||
|
||||
def draft(self, batch: ScheduleBatch):
|
||||
# Parse args
|
||||
if batch.forward_mode.is_idle():
|
||||
self._draft_preprocess_idle(batch)
|
||||
else:
|
||||
self._draft_preprocess_decode(batch)
|
||||
|
||||
spec_info = batch.spec_info
|
||||
|
||||
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
batch.return_hidden_states = False
|
||||
|
||||
# Get forward batch
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch.spec_num_draft_tokens = self.topk
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -481,12 +529,18 @@ class EAGLEWorker(TpModelWorker):
|
||||
forward_batch
|
||||
)
|
||||
else:
|
||||
# Initialize attention backend
|
||||
self.draft_attn_backend.init_forward_metadata(forward_batch)
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
# Initialize attention backend
|
||||
self.draft_attn_backend.init_forward_metadata(forward_batch)
|
||||
# Run forward steps
|
||||
score_list, token_list, parents_list = self.draft_forward(forward_batch)
|
||||
|
||||
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
|
||||
if batch.forward_mode.is_idle():
|
||||
return EagleVerifyInput.create_idle_input(
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
self.speculative_num_draft_tokens,
|
||||
)
|
||||
|
||||
(
|
||||
tree_mask,
|
||||
@@ -504,7 +558,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.seq_lens_sum,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
self.server_args.speculative_num_draft_tokens,
|
||||
self.speculative_num_draft_tokens,
|
||||
)
|
||||
|
||||
return EagleVerifyInput(
|
||||
@@ -584,11 +638,16 @@ class EAGLEWorker(TpModelWorker):
|
||||
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
|
||||
spec_info.prepare_for_verify(batch, self.page_size)
|
||||
batch.return_hidden_states = False
|
||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||
batch.forward_mode = (
|
||||
ForwardMode.TARGET_VERIFY
|
||||
if not batch.forward_mode.is_idle()
|
||||
else ForwardMode.IDLE
|
||||
)
|
||||
batch.spec_info = spec_info
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=spec_info.seq_lens_cpu
|
||||
)
|
||||
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
|
||||
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
|
||||
|
||||
if batch.has_grammar:
|
||||
@@ -646,7 +705,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.add_logprob_values(batch, res, logits_output)
|
||||
|
||||
# Prepare the batch for the next draft forwards.
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.forward_mode = (
|
||||
ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
|
||||
)
|
||||
batch.spec_info = res.draft_input
|
||||
|
||||
return logits_output, res, model_worker_batch, can_run_cuda_graph
|
||||
@@ -743,6 +804,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=seq_lens_cpu
|
||||
)
|
||||
model_worker_batch.spec_num_draft_tokens = 1
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -753,19 +815,37 @@ class EAGLEWorker(TpModelWorker):
|
||||
assert forward_batch.spec_info is batch.spec_info
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
def forward_draft_extend_after_decode(
|
||||
self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool
|
||||
):
|
||||
# Backup fields that will be modified in-place
|
||||
seq_lens_backup = batch.seq_lens.clone()
|
||||
req_pool_indices_backup = batch.req_pool_indices
|
||||
accept_length_backup = batch.spec_info.accept_length
|
||||
return_logprob_backup = batch.return_logprob
|
||||
|
||||
# Prepare metadata
|
||||
batch.spec_info.prepare_extend_after_decode(
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
input_is_idle = batch.forward_mode.is_idle()
|
||||
if not input_is_idle:
|
||||
# Prepare metadata
|
||||
if batch.spec_info.verified_id is not None:
|
||||
batch.spec_info.prepare_extend_after_decode(
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
else:
|
||||
batch = batch.copy()
|
||||
batch.prepare_for_idle()
|
||||
batch.spec_info = EagleDraftInput.create_idle_input(
|
||||
device=self.device,
|
||||
hidden_size=self.model_config.hidden_size,
|
||||
topk=self.topk,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
)
|
||||
|
||||
batch.return_hidden_states = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -776,7 +856,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
# Run
|
||||
can_cuda_graph = (
|
||||
self.cuda_graph_runner_for_draft_extend
|
||||
can_run_draft_extend_cuda_graph
|
||||
and self.cuda_graph_runner_for_draft_extend
|
||||
and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
|
||||
)
|
||||
if can_cuda_graph:
|
||||
@@ -789,7 +870,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
forward_batch.spec_info.hidden_states = logits_output.hidden_states
|
||||
else:
|
||||
self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
self.draft_model_runner.attn_backend.init_forward_metadata(
|
||||
forward_batch
|
||||
)
|
||||
logits_output = self.draft_model_runner.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
@@ -799,7 +883,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
# Restore backup.
|
||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.forward_mode = (
|
||||
ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
|
||||
)
|
||||
batch.seq_lens = seq_lens_backup
|
||||
batch.req_pool_indices = req_pool_indices_backup
|
||||
batch.spec_info.accept_length = accept_length_backup
|
||||
|
||||
Reference in New Issue
Block a user