[Eagle] Remove the greedy branch and some redundant code (#4363)
Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
|
||||
from sglang.srt.layers.dp_attention import disable_dp_size
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
@@ -27,11 +30,22 @@ from sglang.srt.speculative.eagle_utils import (
|
||||
fast_topk,
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.utils import get_available_gpu_memory
|
||||
from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
|
||||
|
||||
if is_cuda_available():
|
||||
from sgl_kernel import segment_packbits
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def draft_tp_context(tp_group: GroupCoordinator):
|
||||
# Draft model doesn't use dp and has its own tp group.
|
||||
# We disable mscclpp now because it doesn't support 2 comm groups.
|
||||
with disable_dp_size(), patch_tensor_parallel_group(tp_group):
|
||||
yield
|
||||
|
||||
|
||||
class EAGLEWorker(TpModelWorker):
|
||||
|
||||
def __init__(
|
||||
@@ -76,16 +90,17 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.hot_token_id = None
|
||||
|
||||
# Init draft worker
|
||||
super().__init__(
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
server_args=server_args,
|
||||
nccl_port=nccl_port,
|
||||
dp_rank=dp_rank,
|
||||
is_draft_worker=True,
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
)
|
||||
with empty_context():
|
||||
super().__init__(
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
server_args=server_args,
|
||||
nccl_port=nccl_port,
|
||||
dp_rank=dp_rank,
|
||||
is_draft_worker=True,
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
)
|
||||
|
||||
# Share the embedding and lm_head
|
||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||
@@ -94,12 +109,17 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.hot_token_id = self.hot_token_id.to(head.device)
|
||||
head.data = head.data[self.hot_token_id]
|
||||
self.draft_model_runner.model.set_embed_and_head(embed, head)
|
||||
|
||||
# Init attention backend and cuda graphs
|
||||
self.draft_model_runner.server_args.disable_cuda_graph = (
|
||||
backup_disable_cuda_graph
|
||||
)
|
||||
|
||||
self.init_attention_backend()
|
||||
self.init_cuda_graphs()
|
||||
self.draft_tp_context = (
|
||||
draft_tp_context if server_args.enable_dp_attention else empty_context
|
||||
)
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
self.init_attention_backend()
|
||||
self.init_cuda_graphs()
|
||||
|
||||
def init_attention_backend(self):
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
@@ -109,52 +129,70 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
|
||||
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
|
||||
self.model_runner,
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = True
|
||||
elif self.server_args.attention_backend == "triton":
|
||||
from sglang.srt.layers.attention.triton_backend import (
|
||||
TritonMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
self.draft_attn_backend = TritonMultiStepDraftBackend(
|
||||
self.model_runner,
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = False
|
||||
elif self.server_args.attention_backend == "flashinfer_mla":
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||
FlashInferMLAMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
|
||||
self.model_runner,
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
|
||||
)
|
||||
|
||||
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
|
||||
|
||||
def init_cuda_graphs(self):
|
||||
"""Capture cuda graphs."""
|
||||
self.cuda_graph_runner = None
|
||||
self.cuda_graph_runner_for_draft_extend = None
|
||||
|
||||
if self.server_args.disable_cuda_graph:
|
||||
return
|
||||
|
||||
# Capture draft
|
||||
tic = time.time()
|
||||
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
||||
)
|
||||
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
|
||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
|
||||
)
|
||||
|
||||
# Capture extend
|
||||
if self.draft_extend_attn_backend:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def draft_model_runner(self):
|
||||
return self.model_runner
|
||||
@@ -164,8 +202,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
|
||||
"""Run speculative decoding forward.
|
||||
|
||||
NOTE: Many states of batch is modified as you go through. It is not guaranteed
|
||||
the final output batch doesn't have the same state as the input.
|
||||
NOTE: Many states of batch is modified as you go through. It is not guaranteed that
|
||||
the final output batch have the same state as the input.
|
||||
|
||||
Args:
|
||||
batch: The batch to run forward. The state of the batch is modified as it runs.
|
||||
@@ -173,30 +211,42 @@ class EAGLEWorker(TpModelWorker):
|
||||
A tuple of the final logit output of the target model, next tokens accepeted,
|
||||
the batch id (used for overlap schedule), and number of accepeted tokens.
|
||||
"""
|
||||
assert not batch.spec_algorithm.is_none()
|
||||
if batch.forward_mode.is_decode():
|
||||
spec_info, to_free_cache_loc = self.draft(batch)
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
spec_info, to_free_cache_loc = self.draft(batch)
|
||||
logits_output, verify_output, model_worker_batch = self.verify(
|
||||
batch, spec_info
|
||||
)
|
||||
|
||||
# Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
|
||||
self.token_to_kv_pool_allocator.free(to_free_cache_loc)
|
||||
# if it is None, means all requests are finished
|
||||
if batch.spec_info.verified_id is not None:
|
||||
self.forward_draft_extend_after_decode(batch)
|
||||
|
||||
# If it is None, it means all requests are finished
|
||||
if batch.spec_info.verified_id is not None:
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
self.forward_draft_extend_after_decode(batch)
|
||||
return (
|
||||
logits_output,
|
||||
verify_output.verified_id,
|
||||
model_worker_batch.bid,
|
||||
sum(verify_output.accept_length_per_req_cpu),
|
||||
)
|
||||
|
||||
elif batch.forward_mode.is_idle():
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
logits_output, next_token_ids, _ = (
|
||||
self.target_worker.forward_batch_generation(
|
||||
ForwardBatch.init_new(
|
||||
model_worker_batch, self.target_worker.model_runner
|
||||
)
|
||||
)
|
||||
)
|
||||
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
|
||||
else:
|
||||
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
|
||||
self.forward_draft_extend(
|
||||
batch, logits_output.hidden_states, next_token_ids
|
||||
)
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
self.forward_draft_extend(
|
||||
batch, logits_output.hidden_states, next_token_ids
|
||||
)
|
||||
return logits_output, next_token_ids, bid, 0
|
||||
|
||||
def forward_target_extend(
|
||||
@@ -226,6 +276,13 @@ class EAGLEWorker(TpModelWorker):
|
||||
num_seqs = batch.batch_size()
|
||||
spec_info = batch.spec_info
|
||||
|
||||
# Accumulate penalty
|
||||
if batch.sampling_info.penalizer_orchestrator.is_required:
|
||||
# This is a relaxed version of penalties for speculative decoding.
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
spec_info.verified_id.to(torch.int64)
|
||||
)
|
||||
|
||||
# Allocate cache locations
|
||||
out_cache_loc = batch.alloc_token_slots(
|
||||
num_seqs * self.topk * self.speculative_num_steps
|
||||
@@ -275,9 +332,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
self.server_args.speculative_num_draft_tokens,
|
||||
batch.sampling_info.is_all_greedy,
|
||||
)
|
||||
|
||||
return ret, out_cache_loc
|
||||
|
||||
def draft_forward(self, forward_batch: ForwardBatch):
|
||||
@@ -307,7 +362,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
token_list.append(tree_info[1])
|
||||
parents_list.append(tree_info[2])
|
||||
|
||||
# we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
|
||||
# We don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
|
||||
if i == self.speculative_num_steps - 1:
|
||||
break
|
||||
|
||||
@@ -322,7 +377,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
spec_info.hidden_states = hidden_states
|
||||
|
||||
# Run forward
|
||||
logits_output = self.model_runner.model.forward(
|
||||
logits_output = self.draft_model_runner.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
@@ -351,11 +406,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
# Post process based on verified outputs.
|
||||
# Pick indices that we care (accepeted)
|
||||
logits_output.next_token_logits = logits_output.next_token_logits[
|
||||
res.accepeted_indices_cpu
|
||||
]
|
||||
logits_output.hidden_states = logits_output.hidden_states[
|
||||
res.accepeted_indices_cpu
|
||||
res.accepeted_indices
|
||||
]
|
||||
logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]
|
||||
|
||||
# Prepare the batch for the next draft forwards.
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.spec_info = res.draft_input
|
||||
@@ -407,7 +461,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch_next_token_ids,
|
||||
]
|
||||
|
||||
# Add output logprobs to the request.
|
||||
# Add output logprobs to the request
|
||||
pt = 0
|
||||
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
||||
verified_ids = batch_next_token_ids.tolist()
|
||||
@@ -456,27 +510,38 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
seq_lens_backup = batch.seq_lens
|
||||
# Backup fileds that will be modified in-place
|
||||
seq_lens_backup = batch.seq_lens.clone()
|
||||
req_pool_indices_backup = batch.req_pool_indices
|
||||
accept_length_backup = batch.spec_info.accept_length
|
||||
return_logprob_backup = batch.return_logprob
|
||||
|
||||
# Prepare metadata
|
||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
|
||||
batch.spec_info.prepare_extend_after_decode(
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
# We don't need logprob for this extend.
|
||||
original_return_logprob = batch.return_logprob
|
||||
batch.return_logprob = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
|
||||
# Run
|
||||
logits_output = self.draft_model_runner.forward(forward_batch)
|
||||
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
assert forward_batch.spec_info is batch.spec_info
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
# Restore backup.
|
||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||
batch.return_logprob = original_return_logprob
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.seq_lens = seq_lens_backup
|
||||
batch.req_pool_indices = req_pool_indices_backup
|
||||
batch.spec_info.accept_length = accept_length_backup
|
||||
batch.return_logprob = return_logprob_backup
|
||||
|
||||
def capture_for_decode(
|
||||
self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
|
||||
@@ -489,7 +554,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
if self.enable_nan_detection:
|
||||
logits = logits_output.next_token_logits
|
||||
if torch.any(torch.isnan(logits)):
|
||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||
logger.error("Detected errors during sampling! NaN in the logits.")
|
||||
raise ValueError("Detected errors during sampling! NaN in the logits.")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user