DP Enhancement (#8280)
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.dp_attention import DPPaddingMode
|
||||
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||
CudaGraphRunner,
|
||||
@@ -97,13 +98,6 @@ class EAGLEDraftCudaGraphRunner:
|
||||
)
|
||||
|
||||
if self.require_gathered_buffer:
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
@@ -111,12 +105,30 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token * self.dp_size,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
else:
|
||||
assert self.require_attn_tp_gather
|
||||
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(1,), dtype=torch.int32
|
||||
)
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
else:
|
||||
self.global_num_tokens_gpu = None
|
||||
self.global_num_tokens_for_logprob_gpu = None
|
||||
self.gathered_buffer = None
|
||||
|
||||
# Capture
|
||||
try:
|
||||
@@ -130,9 +142,9 @@ class EAGLEDraftCudaGraphRunner:
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.require_mlp_tp_gather:
|
||||
cuda_graph_bs = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
else sum(forward_batch.global_num_tokens_cpu)
|
||||
else max(forward_batch.global_num_tokens_cpu)
|
||||
)
|
||||
else:
|
||||
cuda_graph_bs = forward_batch.batch_size
|
||||
@@ -168,26 +180,20 @@ class EAGLEDraftCudaGraphRunner:
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
||||
for i in range(self.dp_size)
|
||||
],
|
||||
[num_tokens] * self.dp_size,
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
||||
for i in range(self.dp_size)
|
||||
],
|
||||
[num_tokens] * self.dp_size,
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
elif self.require_attn_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
@@ -233,6 +239,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
return_logprob=False,
|
||||
positions=positions,
|
||||
global_num_tokens_gpu=global_num_tokens,
|
||||
dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
|
||||
gathered_buffer=gathered_buffer,
|
||||
spec_algorithm=self.model_runner.spec_algorithm,
|
||||
spec_info=spec_info,
|
||||
@@ -290,12 +297,13 @@ class EAGLEDraftCudaGraphRunner:
|
||||
|
||||
# Pad
|
||||
if self.require_mlp_tp_gather:
|
||||
total_batch_size = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
max_num_tokens = max(forward_batch.global_num_tokens_cpu)
|
||||
max_batch_size = (
|
||||
max_num_tokens // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
else sum(forward_batch.global_num_tokens_cpu)
|
||||
else max_num_tokens
|
||||
)
|
||||
index = bisect.bisect_left(self.capture_bs, total_batch_size)
|
||||
index = bisect.bisect_left(self.capture_bs, max_batch_size)
|
||||
else:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
@@ -316,12 +324,10 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
|
||||
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
|
||||
|
||||
# TODO(ch-wan): support num_token_non_padded
|
||||
if self.require_gathered_buffer:
|
||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
forward_batch.global_num_tokens_for_logprob_gpu
|
||||
)
|
||||
forward_batch.gathered_buffer = self.gathered_buffer
|
||||
self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
|
||||
self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
|
||||
|
||||
# Attention backend
|
||||
if bs != raw_bs:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.dp_attention import DPPaddingMode
|
||||
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||
CudaGraphRunner,
|
||||
@@ -109,13 +110,6 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
)
|
||||
|
||||
if self.require_gathered_buffer:
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
@@ -123,12 +117,31 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token * self.dp_size,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
else:
|
||||
assert self.require_attn_tp_gather
|
||||
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(1,), dtype=torch.int32
|
||||
)
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
else:
|
||||
self.global_num_tokens_gpu = None
|
||||
self.global_num_tokens_for_logprob_gpu = None
|
||||
self.gathered_buffer = None
|
||||
|
||||
# Capture
|
||||
try:
|
||||
with model_capture_mode():
|
||||
@@ -141,9 +154,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.require_mlp_tp_gather:
|
||||
cuda_graph_bs = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
else sum(forward_batch.global_num_tokens_cpu)
|
||||
else max(forward_batch.global_num_tokens_cpu)
|
||||
)
|
||||
else:
|
||||
cuda_graph_bs = forward_batch.seq_lens.numel()
|
||||
@@ -180,27 +193,19 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
||||
for i in range(self.dp_size)
|
||||
],
|
||||
[num_tokens] * self.dp_size,
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
||||
for i in range(self.dp_size)
|
||||
],
|
||||
[bs] * self.dp_size,
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
|
||||
elif self.require_attn_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
@@ -211,18 +216,14 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
torch.tensor(
|
||||
[num_tokens],
|
||||
[bs],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
else:
|
||||
global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
global_num_tokens_for_logprob = None
|
||||
|
||||
spec_info = EagleDraftInput(
|
||||
hidden_states=hidden_states,
|
||||
@@ -243,8 +244,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
seq_lens_sum=seq_lens.sum().item(),
|
||||
return_logprob=False,
|
||||
positions=positions,
|
||||
global_num_tokens_gpu=global_num_tokens,
|
||||
global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
|
||||
global_num_tokens_gpu=self.global_num_tokens_gpu,
|
||||
global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
|
||||
dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
|
||||
gathered_buffer=gathered_buffer,
|
||||
spec_algorithm=self.model_runner.spec_algorithm,
|
||||
spec_info=spec_info,
|
||||
@@ -306,12 +308,13 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
raw_bs = forward_batch.batch_size
|
||||
num_tokens = forward_batch.input_ids.shape[0]
|
||||
if self.require_mlp_tp_gather:
|
||||
total_batch_size = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
max_num_tokens = max(forward_batch.global_num_tokens_cpu)
|
||||
max_batch_size = (
|
||||
max_num_tokens // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
else sum(forward_batch.global_num_tokens_cpu)
|
||||
else max_num_tokens
|
||||
)
|
||||
index = bisect.bisect_left(self.capture_bs, total_batch_size)
|
||||
index = bisect.bisect_left(self.capture_bs, max_batch_size)
|
||||
else:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
|
||||
@@ -334,12 +337,10 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
|
||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||
|
||||
# TODO(ch-wan): support num_token_non_padded
|
||||
if self.require_gathered_buffer:
|
||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
forward_batch.global_num_tokens_for_logprob_gpu
|
||||
)
|
||||
forward_batch.gathered_buffer = self.gathered_buffer
|
||||
self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
|
||||
self.global_num_tokens_for_logprob_gpu.fill_(bs)
|
||||
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
if bs != raw_bs:
|
||||
|
||||
@@ -71,9 +71,20 @@ class EagleDraftInput:
|
||||
kv_indptr: torch.Tensor = None
|
||||
kv_indices: torch.Tensor = None
|
||||
|
||||
# Shape info for padding
|
||||
num_tokens_per_batch: int = -1
|
||||
num_tokens_for_logprob_per_batch: int = -1
|
||||
|
||||
# Inputs for draft extend
|
||||
# shape: (b,)
|
||||
seq_lens_for_draft_extend: torch.Tensor = None
|
||||
req_pool_indices_for_draft_extend: torch.Tensor = None
|
||||
|
||||
def prepare_for_extend(self, batch: ScheduleBatch):
|
||||
|
||||
if batch.forward_mode.is_idle():
|
||||
return
|
||||
|
||||
# Prefill only generates 1 token.
|
||||
assert len(self.verified_id) == len(batch.seq_lens)
|
||||
|
||||
@@ -95,7 +106,7 @@ class EagleDraftInput:
|
||||
capture_hidden_mode: CaptureHiddenMode,
|
||||
):
|
||||
return cls(
|
||||
verified_id=None,
|
||||
verified_id=torch.empty((0,), device=device, dtype=torch.int32),
|
||||
hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
|
||||
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
|
||||
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
|
||||
@@ -109,7 +120,10 @@ class EagleDraftInput:
|
||||
batch: ScheduleBatch,
|
||||
speculative_num_steps: int,
|
||||
):
|
||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||
|
||||
if batch.forward_mode.is_idle():
|
||||
return
|
||||
|
||||
batch.input_ids = self.verified_id
|
||||
batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
|
||||
batch.extend_num_tokens = sum(batch.extend_lens)
|
||||
@@ -316,7 +330,7 @@ class EagleVerifyInput:
|
||||
def verify(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
logits_output: torch.Tensor,
|
||||
logits_output: LogitsProcessorOutput,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
page_size: int,
|
||||
vocab_mask: Optional[torch.Tensor] = None, # For grammar
|
||||
@@ -599,13 +613,14 @@ class EagleVerifyInput:
|
||||
batch.out_cache_loc = tgt_cache_loc
|
||||
batch.seq_lens.add_(accept_length + 1)
|
||||
|
||||
draft_input = EagleDraftInput()
|
||||
draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
|
||||
draft_input.verified_id = verified_id
|
||||
draft_input.accept_length = accept_length
|
||||
draft_input.accept_length_cpu = accept_length.tolist()
|
||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
||||
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
|
||||
draft_input = EagleDraftInput(
|
||||
hidden_states=batch.spec_info.hidden_states[accept_index],
|
||||
verified_id=verified_id,
|
||||
accept_length=accept_length,
|
||||
accept_length_cpu=accept_length.tolist(),
|
||||
seq_lens_for_draft_extend=batch.seq_lens,
|
||||
req_pool_indices_for_draft_extend=batch.req_pool_indices,
|
||||
)
|
||||
|
||||
return EagleVerifyOutput(
|
||||
draft_input=draft_input,
|
||||
@@ -628,7 +643,6 @@ class EagleVerifyInput:
|
||||
batch.seq_lens.add_(accept_length + 1)
|
||||
|
||||
accept_length_cpu = accept_length.tolist()
|
||||
draft_input = EagleDraftInput()
|
||||
if len(unfinished_accept_index) > 0:
|
||||
unfinished_accept_index = torch.cat(unfinished_accept_index)
|
||||
unfinished_index_device = torch.tensor(
|
||||
@@ -659,18 +673,26 @@ class EagleVerifyInput:
|
||||
next_power_of_2(self.draft_token_num),
|
||||
)
|
||||
|
||||
draft_input.hidden_states = batch.spec_info.hidden_states[
|
||||
unfinished_accept_index
|
||||
]
|
||||
draft_input.verified_id = predict[unfinished_accept_index]
|
||||
draft_input.accept_length_cpu = draft_input_accept_length_cpu
|
||||
draft_input.accept_length = accept_length[unfinished_index_device]
|
||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
|
||||
unfinished_index_device
|
||||
]
|
||||
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
|
||||
unfinished_index_device
|
||||
]
|
||||
draft_input = EagleDraftInput(
|
||||
hidden_states=batch.spec_info.hidden_states[
|
||||
unfinished_accept_index
|
||||
],
|
||||
verified_id=predict[unfinished_accept_index],
|
||||
accept_length_cpu=draft_input_accept_length_cpu,
|
||||
accept_length=accept_length[unfinished_index_device],
|
||||
seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
|
||||
req_pool_indices_for_draft_extend=batch.req_pool_indices[
|
||||
unfinished_index_device
|
||||
],
|
||||
)
|
||||
else:
|
||||
draft_input = EagleDraftInput.create_idle_input(
|
||||
device=batch.device,
|
||||
hidden_size=batch.model_config.hidden_size,
|
||||
dtype=batch.model_config.dtype,
|
||||
topk=self.topk,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
)
|
||||
|
||||
return EagleVerifyOutput(
|
||||
draft_input=draft_input,
|
||||
|
||||
@@ -297,7 +297,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
def forward_batch_speculative_generation(
|
||||
self, batch: ScheduleBatch
|
||||
) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
|
||||
) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
|
||||
"""Run speculative decoding forward.
|
||||
|
||||
NOTE: Many states of the batch are modified as you go through. It is not guaranteed that
|
||||
@@ -325,11 +325,16 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.verify(batch, spec_info)
|
||||
)
|
||||
|
||||
if self.check_forward_draft_extend_after_decode(batch):
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
self.forward_draft_extend_after_decode(
|
||||
batch,
|
||||
)
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
# NOTE: We should use `check_forward_draft_extend_after_decode`
|
||||
# when DP attention is enabled, but it is slow. Skip it for now.
|
||||
if (
|
||||
self.server_args.enable_dp_attention
|
||||
or batch.spec_info.verified_id.shape[0] > 0
|
||||
):
|
||||
# decode is not finished
|
||||
self.forward_draft_extend_after_decode(batch)
|
||||
|
||||
return (
|
||||
logits_output,
|
||||
verify_output.verified_id,
|
||||
@@ -339,10 +344,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
|
||||
def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
local_need_forward = (
|
||||
batch.spec_info.verified_id is not None
|
||||
and batch.spec_info.verified_id.shape[0] > 0
|
||||
)
|
||||
local_need_forward = batch.spec_info.verified_id.shape[0] > 0
|
||||
if not self.server_args.enable_dp_attention:
|
||||
return local_need_forward
|
||||
|
||||
@@ -361,7 +363,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
def forward_target_extend(
|
||||
self, batch: ScheduleBatch
|
||||
) -> Tuple[LogitsProcessorOutput, List[int], int]:
|
||||
) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
|
||||
"""Run the target extend.
|
||||
|
||||
Args:
|
||||
@@ -376,7 +378,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
# We need the full hidden states to prefill the KV cache of the draft model.
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
model_worker_batch.spec_num_draft_tokens = 1
|
||||
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
@@ -508,13 +509,15 @@ class EAGLEWorker(TpModelWorker):
|
||||
self._draft_preprocess_decode(batch)
|
||||
|
||||
spec_info = batch.spec_info
|
||||
assert isinstance(spec_info, EagleDraftInput)
|
||||
|
||||
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
spec_info.num_tokens_per_batch = self.topk
|
||||
spec_info.num_tokens_for_logprob_per_batch = self.topk
|
||||
batch.return_hidden_states = False
|
||||
|
||||
# Get forward batch
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch.spec_num_draft_tokens = self.topk
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
@@ -527,6 +530,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
forward_batch
|
||||
)
|
||||
else:
|
||||
forward_batch.can_run_dp_cuda_graph = False
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
# Initialize attention backend
|
||||
self.draft_attn_backend.init_forward_metadata(forward_batch)
|
||||
@@ -578,6 +582,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
def draft_forward(self, forward_batch: ForwardBatch):
|
||||
# Parse args
|
||||
spec_info = forward_batch.spec_info
|
||||
assert isinstance(spec_info, EagleDraftInput)
|
||||
out_cache_loc = forward_batch.out_cache_loc
|
||||
topk_p, topk_index, hidden_states = (
|
||||
spec_info.topk_p,
|
||||
@@ -621,8 +626,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
spec_info.hidden_states = hidden_states
|
||||
|
||||
# Run forward
|
||||
logits_output = self.draft_model_runner.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
logits_output, _ = self.draft_model_runner.forward(
|
||||
forward_batch, skip_attn_backend_init=True
|
||||
)
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||
@@ -642,10 +647,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
else ForwardMode.IDLE
|
||||
)
|
||||
batch.spec_info = spec_info
|
||||
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=spec_info.seq_lens_cpu
|
||||
)
|
||||
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
|
||||
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
|
||||
|
||||
if batch.has_grammar:
|
||||
@@ -782,8 +787,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
hidden_states: torch.Tensor,
|
||||
next_token_ids: List[int],
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
next_token_ids: torch.Tensor,
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
"""Run draft model extend. This API modifies the states of the batch.
|
||||
|
||||
@@ -795,6 +800,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.spec_info = EagleDraftInput(
|
||||
hidden_states=hidden_states,
|
||||
verified_id=next_token_ids,
|
||||
num_tokens_per_batch=1,
|
||||
num_tokens_for_logprob_per_batch=1,
|
||||
)
|
||||
batch.return_hidden_states = False
|
||||
batch.spec_info.prepare_for_extend(batch)
|
||||
@@ -802,7 +809,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=seq_lens_cpu
|
||||
)
|
||||
model_worker_batch.spec_num_draft_tokens = 1
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -814,37 +820,45 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
assert isinstance(batch.spec_info, EagleDraftInput)
|
||||
# Back up fields that will be modified in-place
|
||||
seq_lens_backup = batch.seq_lens.clone()
|
||||
req_pool_indices_backup = batch.req_pool_indices
|
||||
accept_length_backup = batch.spec_info.accept_length
|
||||
return_logprob_backup = batch.return_logprob
|
||||
|
||||
input_is_idle = batch.forward_mode.is_idle()
|
||||
if not input_is_idle:
|
||||
# Prepare metadata
|
||||
if batch.spec_info.verified_id is not None:
|
||||
batch.spec_info.prepare_extend_after_decode(
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
else:
|
||||
batch = batch.copy()
|
||||
batch.prepare_for_idle()
|
||||
hidden_size = (
|
||||
self.model_config.hidden_size * 3
|
||||
if self.speculative_algorithm.is_eagle3()
|
||||
else self.model_config.hidden_size
|
||||
)
|
||||
batch.spec_info = EagleDraftInput.create_idle_input(
|
||||
device=self.device,
|
||||
hidden_size=hidden_size,
|
||||
dtype=self.model_config.dtype,
|
||||
topk=self.topk,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
)
|
||||
|
||||
if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
|
||||
batch = batch.copy()
|
||||
batch.prepare_for_idle()
|
||||
hidden_size = (
|
||||
self.model_config.hidden_size * 3
|
||||
if self.speculative_algorithm.is_eagle3()
|
||||
else self.model_config.hidden_size
|
||||
)
|
||||
batch.spec_info = EagleDraftInput.create_idle_input(
|
||||
device=self.device,
|
||||
hidden_size=hidden_size,
|
||||
dtype=self.model_config.dtype,
|
||||
topk=self.topk,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
)
|
||||
|
||||
batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
|
||||
batch.spec_info.num_tokens_for_logprob_per_batch = 1
|
||||
batch.spec_info.prepare_extend_after_decode(
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
batch.forward_mode = (
|
||||
ForwardMode.DRAFT_EXTEND
|
||||
if not batch.forward_mode.is_idle()
|
||||
else ForwardMode.IDLE
|
||||
)
|
||||
|
||||
batch.return_hidden_states = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
@@ -869,12 +883,13 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
forward_batch.spec_info.hidden_states = logits_output.hidden_states
|
||||
else:
|
||||
forward_batch.can_run_dp_cuda_graph = False
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
self.draft_model_runner.attn_backend.init_forward_metadata(
|
||||
forward_batch
|
||||
)
|
||||
logits_output = self.draft_model_runner.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
logits_output, _ = self.draft_model_runner.forward(
|
||||
forward_batch, skip_attn_backend_init=True
|
||||
)
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user