[Eagle] Refactor eagle speculative decoding (#3986)
Co-authored-by: Ke Bao <ISPObaoke@163.com>
@@ -1,7 +1,7 @@
 import logging
 import os
 import time
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from huggingface_hub import snapshot_download
@@ -22,11 +22,13 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
 from sglang.srt.speculative.eagle_utils import (
     EagleDraftInput,
     EagleVerifyInput,
+    EagleVerifyOutput,
     assign_draft_cache_locs,
     fast_topk,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import get_available_gpu_memory

 logger = logging.getLogger(__name__)

@@ -42,12 +44,16 @@ class EAGLEWorker(TpModelWorker):
         nccl_port: int,
         target_worker: TpModelWorker,
     ):
+        # Override context length with target model's context length
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
+
         # Do not capture cuda graph in `super().__init__()`
         # We will capture it later
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True

-        # Load hot token ids
+        # Lossy optimization by using hot tokens
         if server_args.speculative_token_map is not None:
             self.hot_token_id = load_token_map(server_args.speculative_token_map)
             server_args.json_model_override_args = (
@@ -56,6 +62,12 @@ class EAGLEWorker(TpModelWorker):
         else:
             self.hot_token_id = None

+        # We share the allocator with a target worker. Draft/target worker
+        # owns its own KV cache.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
         # Init target worker
         super().__init__(
             gpu_id=gpu_id,
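
The block added above replaces per-call memory-pool swapping (the `_set_mem_pool` helper removed further down) with pools that are shared once at construction time. A minimal sketch of the ownership split the new comment describes, using hypothetical names rather than the real sglang classes:

# Minimal sketch (hypothetical names, not the sglang API): both workers read and
# write the same request-to-token mapping and KV-slot allocator, so slots
# allocated by one worker are visible to the other, while each model still owns
# the KV-cache tensors it writes into.
class TargetWorkerSketch:
    def __init__(self):
        self.req_to_token_pool = {}                         # request id -> token slots
        self.token_to_kv_pool_allocator = set(range(1024))  # free KV slot ids

    def get_memory_pool(self):
        return self.req_to_token_pool, self.token_to_kv_pool_allocator


class DraftWorkerSketch:
    def __init__(self, target: "TargetWorkerSketch"):
        # Shared bookkeeping, separate KV buffers per model.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target.get_memory_pool()
        )
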
@@ -64,9 +76,10 @@ class EAGLEWorker(TpModelWorker):
             nccl_port=nccl_port,
             dp_rank=dp_rank,
             is_draft_worker=True,
+            req_to_token_pool=self.req_to_token_pool,
+            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
         )
         self.target_worker = target_worker
-        self.finish_extend_len = []

         # Parse arguments
         self.topk = server_args.speculative_eagle_topk
@@ -75,6 +88,9 @@ class EAGLEWorker(TpModelWorker):
             server_args.speculative_algorithm
         )
         self.server_args = server_args
+        self.use_nan_detection = self.server_args.enable_nan_detection
+        self.device = self.model_runner.device
+        self.gpu_id = self.model_runner.gpu_id

         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -82,8 +98,10 @@ class EAGLEWorker(TpModelWorker):
             head = head.clone()
             self.hot_token_id = self.hot_token_id.to(head.device)
             head.data = head.data[self.hot_token_id]
-        self.model_runner.model.set_embed_and_head(embed, head)
-        self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
+        self.draft_model_runner.model.set_embed_and_head(embed, head)
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
+        )

         # Create multi-step attn backends and cuda graph runners
         if server_args.attention_backend == "flashinfer":
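
The hot-token path above slices the shared lm_head down to a reduced vocabulary for the draft model. A small illustration of the idea, with hypothetical shapes and token ids (`torch.topk` stands in for the `fast_topk` helper used in the real code): draft scores are computed over the hot sub-vocabulary only, and chosen indices are mapped back to full-vocabulary token ids.

import torch

vocab_size, hidden_size = 32000, 4096
head = torch.randn(vocab_size, hidden_size)       # shared target lm_head weight
hot_token_id = torch.tensor([17, 42, 999, 1234])  # hot vocabulary kept for the draft

draft_head = head.clone()[hot_token_id]           # (hot_vocab, hidden_size)
hidden = torch.randn(hidden_size)                 # last hidden state of the draft
probs = torch.softmax(hidden @ draft_head.T, dim=-1)
topk_p, topk_index = torch.topk(probs, k=2, dim=-1)
real_token_ids = hot_token_id[topk_index]         # map back to full-vocab ids
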
@@ -111,7 +129,7 @@ class EAGLEWorker(TpModelWorker):
                 f"EAGLE is not supported in attention backend {server_args.attention_backend}"
             )

-        self.model_runner.draft_attn_backend = self.draft_attn_backend
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
         self.init_cuda_graphs()

     def init_cuda_graphs(self):
@@ -122,55 +140,81 @@ class EAGLEWorker(TpModelWorker):
             return

         tic = time.time()
-        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+        logger.info(
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
+        logger.info(
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )

-    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
+    @property
+    def draft_model_runner(self):
+        return self.model_runner
+
+    def forward_batch_speculative_generation(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
+        """Run a speculative decoding forward pass.
+
+        NOTE: Many states of the batch are modified along the way. The final
+        output batch is not guaranteed to have the same state as the input.
+
+        Args:
+            batch: The batch to run forward. The state of the batch is modified as it runs.
+        Returns:
+            A tuple of the final logits output of the target model, the accepted next tokens,
+            the batch id (used for the overlap schedule), and the number of accepted tokens.
+        """
         assert not batch.spec_algorithm.is_none()
         if batch.forward_mode.is_decode():
-            # Draft
-            spec_info: EagleVerifyInput = self.draft(batch)
-
-            # Verify
-            (
-                next_draft_input,
-                logits_output,
-                verified_id,
-                self.finish_extend_len,
-                accept_length_cpu,
-                model_worker_batch,
-            ) = self.verify(batch, spec_info)
-            batch.spec_info = next_draft_input
-            # if it is None, means all requsets are finished
+            spec_info, to_free_cache_loc = self.draft(batch)
+            logits_output, verify_output, model_worker_batch = self.verify(
+                batch, spec_info
+            )
+            # Free the cache locations here to avoid synchronization and hide kernel launch overhead.
+            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
+            # If verified_id is None, it means all requests are finished
            if batch.spec_info.verified_id is not None:
                 self.forward_draft_extend_after_decode(batch)

             return (
                 logits_output,
-                verified_id,
-                model_worker_batch,
-                sum(accept_length_cpu),
+                verify_output.verified_id,
+                model_worker_batch.bid,
+                sum(verify_output.accept_length_per_req_cpu),
             )
-
         else:
-            # Forward with the target model and get hidden states.
-            # We need the full hidden states to prefill the KV cache of the draft model.
-            model_worker_batch = batch.get_model_worker_batch()
-            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
-                model_worker_batch
-            )
-
-            # Forward with the draft model.
-            batch.spec_info = EagleDraftInput(
-                hidden_states=logits_output.hidden_states,
-                verified_id=next_token_ids,
-            )
-            self.forward_draft_extend(batch)
-            return logits_output, next_token_ids, model_worker_batch, 0
+            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
+            self.forward_draft_extend(
+                batch, logits_output.hidden_states, next_token_ids
+            )
+            return logits_output, next_token_ids, bid, 0
+
+    def forward_target_extend(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
+        """Run the target extend.
+
+        Args:
+            batch: The batch to run. Its states could be modified.
+
+        Returns:
+            logits_output: The logits output. It contains the full hidden states.
+            next_token_ids: The next token ids generated.
+            bid: The model batch ID. Used for the overlap schedule.
+        """
+        # Forward with the target model and get hidden states.
+        # We need the full hidden states to prefill the KV cache of the draft model.
+        model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+            model_worker_batch
+        )
+        return logits_output, next_token_ids, model_worker_batch.bid

     def draft(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
-
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
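
Putting the decode branch together: the worker drafts a token tree, verifies it with one target forward, frees the draft's scratch KV slots, and, if any request is still alive, runs a draft extend to seed the next round. A schematic of that control flow, with plain callables standing in for the worker methods (illustrative only, not the real API):

# Schematic of the decode path described by the docstring above.
def speculative_decode_step(batch, draft, verify, free_kv, extend_after_decode):
    spec_info, to_free = draft(batch)                    # propose a token tree
    logits_output, verify_output, mwb = verify(batch, spec_info)
    free_kv(to_free)                                     # freed late to overlap with queued kernels
    if batch.spec_info.verified_id is not None:          # some requests still running
        extend_after_decode(batch)                       # draft extend to seed the next round
    return (
        logits_output,
        verify_output.verified_id,
        mwb.bid,
        sum(verify_output.accept_length_per_req_cpu),
    )
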
@@ -188,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
         )
-
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
@@ -196,11 +239,12 @@ class EAGLEWorker(TpModelWorker):
         # Get forward batch
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
         can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
             forward_batch
         )

         if can_cuda_graph:
             score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                 forward_batch
@@ -208,7 +252,9 @@ class EAGLEWorker(TpModelWorker):
         else:
             # Initialize attention backend
             self.draft_attn_backend.init_forward_metadata(forward_batch)
-
+            forward_batch = ForwardBatch.init_new(
+                model_worker_batch, self.draft_model_runner
+            )
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)

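
Both branches above run the same multi-step draft forward; the only difference is whether it executes as a replay of a pre-captured CUDA graph or eagerly. A minimal sketch of this replay-or-fallback dispatch (the `runner` and `run_steps` interfaces are hypothetical stand-ins):

# Replay-or-fallback dispatch, mirroring the branch above.
def run_draft(runner, attn_backend, forward_batch, run_steps):
    if runner is not None and runner.can_run(forward_batch):
        # Supported batch shape: replay the pre-captured CUDA graph.
        return runner.replay(forward_batch)
    # Unsupported shape: build attention metadata and run the steps eagerly.
    attn_backend.init_forward_metadata(forward_batch)
    return run_steps(forward_batch)
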
@@ -225,10 +271,7 @@ class EAGLEWorker(TpModelWorker):
             batch.sampling_info.is_all_greedy,
         )

-        # Free cache locations
-        batch.token_to_kv_pool.free(out_cache_loc)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
-        return ret
+        return ret, out_cache_loc

     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
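
Note that `draft` now returns `out_cache_loc` instead of freeing it itself; as the comment in `forward_batch_speculative_generation` explains, the caller frees these slots only after the verify kernels have been launched. A sketch of that ordering with hypothetical helpers; freeing is host-side bookkeeping, so deferring it lets the CPU cost overlap with GPU work already queued on the stream:

def decode_once(draft, verify, allocator, batch):
    # draft() keeps its scratch KV slots alive and hands them back to the caller.
    spec_info, out_cache_loc = draft(batch)
    # verify() enqueues its kernels on the GPU stream and returns immediately.
    result = verify(batch, spec_info)
    # Host-side free: its CPU cost hides behind the queued GPU work.
    allocator.free(out_cache_loc)
    return result
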
@@ -278,6 +321,7 @@ class EAGLEWorker(TpModelWorker):
             logits_output = self.model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
+            self._detect_nan_if_needed(logits_output)
             probs = torch.softmax(logits_output.next_token_logits, dim=-1)
             topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
             if self.hot_token_id is not None:
@@ -294,71 +338,88 @@ class EAGLEWorker(TpModelWorker):
         logits_output, _ = self.target_worker.forward_batch_generation(
             model_worker_batch, skip_sample=True
         )
+        self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
-        res = spec_info.verify(batch, logits_output)
-        batch.forward_mode = ForwardMode.DECODE
-        return res + (model_worker_batch,)
+        res: EagleVerifyOutput = spec_info.verify(
+            batch, logits_output, self.token_to_kv_pool_allocator
+        )
+
+        # Post process based on verified outputs.
+        # Pick the indices that we care about (i.e., the accepted ones).
+        logits_output.next_token_logits = logits_output.next_token_logits[
+            res.accepeted_indices_cpu
+        ]
+        logits_output.hidden_states = logits_output.hidden_states[
+            res.accepeted_indices_cpu
+        ]
+        # Prepare the batch for the next draft forwards.
+        batch.forward_mode = ForwardMode.DECODE
+        batch.spec_info = res.draft_input
+
+        return logits_output, res, model_worker_batch

-    def forward_draft_extend(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
+    def forward_draft_extend(
+        self,
+        batch: ScheduleBatch,
+        hidden_states: torch.Tensor,
+        next_token_ids: List[int],
+    ):
+        """Run draft model extend. This API modifies the states of the batch.
+
+        Args:
+            batch: The batch to run.
+            hidden_states: Hidden states from the target model forward.
+            next_token_ids: Next token ids generated from the target forward.
+        """
+        batch.spec_info = EagleDraftInput(
+            hidden_states=hidden_states,
+            verified_id=next_token_ids,
+        )
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
-
-    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
-        batch.token_to_kv_pool = runner.token_to_kv_pool
-        batch.req_to_token_pool = runner.req_to_token_pool
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         seq_lens_backup = batch.seq_lens
+        req_pool_indices_backup = batch.req_pool_indices

-        self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
         batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         # We don't need logprob for this extend.
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
+        batch.req_pool_indices = req_pool_indices_backup

     def capture_for_decode(
-        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
     ):
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
-        spec_info = forward_batch.spec_info
-        spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
-        spec_info.hidden_states = logits_output.hidden_states
+        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
+        draft_input.hidden_states = logits_output.hidden_states

-    # Don't support prefix share now.
-    def finish_request(self, reqs: Union[Req, List[Req]]):
-        if not isinstance(reqs, List):
-            reqs = [reqs]
-        for req in reqs:
-            if req.rid not in self.finish_extend_len:
-                continue
-            req_len = (
-                len(req.origin_input_ids)
-                + len(req.output_ids)
-                - self.finish_extend_len[req.rid]
-                - 1
-            )
-            kv_indices = self.model_runner.req_to_token_pool.req_to_token[
-                req.req_pool_idx
-            ][:req_len]
-            self.model_runner.token_to_kv_pool.free(kv_indices)
-            self.model_runner.req_to_token_pool.free(req.req_pool_idx)
+    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
+        if self.use_nan_detection:
+            logits = logits_output.next_token_logits
+            if torch.any(torch.isnan(logits)):
+                logger.warning("Detected errors during sampling! NaN in the logits.")
+                raise ValueError("Detected errors during sampling! NaN in the logits.")


 def load_token_map(token_map_path: str) -> List[int]:
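
`capture_for_decode` now writes directly into an `EagleDraftInput` rather than digging it out of the forward batch: the top-k probabilities and indices of the last verified position seed the first branching step of the next draft round. A worked sketch of that computation with illustrative shapes (`torch.topk` again standing in for `fast_topk`):

import torch

batch_size, vocab, topk = 2, 10, 4
next_token_logits = torch.randn(batch_size, vocab)    # from the draft extend forward

probs = torch.softmax(next_token_logits, dim=-1)
topk_p, topk_index = torch.topk(probs, topk, dim=-1)  # fast_topk is a fused equivalent
# topk_p, topk_index, and the hidden states become the EagleDraftInput consumed
# by the first select_top_k_tokens step of the next draft pass.
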