fix EAGLE 2 non greedy case (#3407)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
Yineng Zhang
2025-02-09 07:28:34 +08:00
committed by GitHub
parent f90db8bc07
commit fad315cb8e
4 changed files with 71 additions and 22 deletions

View File

@@ -54,7 +54,9 @@ def get_model_config(model_name: str, tp_size: int):
): ):
block_shape = config.quantization_config["weight_block_size"] block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2 assert len(block_shape) == 2
assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1" assert (
vllm_version_num >= 66
), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
shape_configs = { shape_configs = {
"num_experts": E, "num_experts": E,

View File

@@ -462,8 +462,11 @@ class CudaGraphRunner:
), ),
positions=None, positions=None,
retrive_index=None, retrive_index=None,
retrive_next_token=None,
retrive_next_sibling=None,
retrive_cum_len=None, retrive_cum_len=None,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
spec_steps=self.model_runner.server_args.speculative_num_steps,
capture_hidden_mode=CaptureHiddenMode.FULL, capture_hidden_mode=CaptureHiddenMode.FULL,
) )

View File

@@ -4,6 +4,7 @@ import dataclasses
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
import torch import torch
import torch.nn.functional as F
import triton import triton
import triton.language as tl import triton.language as tl
@@ -11,7 +12,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton, create_flashinfer_kv_indices_triton,
) )
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel from sglang.srt.speculative.build_eagle_tree import (
build_tree_kernel,
build_tree_kernel_efficient,
)
from sglang.srt.utils import is_cuda_available
if is_cuda_available():
from sgl_kernel import tree_speculative_sampling_target_only
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -160,8 +168,11 @@ class EagleVerifyInput:
custom_mask: torch.Tensor custom_mask: torch.Tensor
positions: torch.Tensor positions: torch.Tensor
retrive_index: torch.Tensor retrive_index: torch.Tensor
retrive_next_token: torch.Tensor
retrive_next_sibling: torch.Tensor
retrive_cum_len: torch.Tensor retrive_cum_len: torch.Tensor
draft_token_num: int draft_token_num: int
spec_steps: int
capture_hidden_mode: CaptureHiddenMode capture_hidden_mode: CaptureHiddenMode
@classmethod @classmethod
@@ -175,10 +186,45 @@ class EagleVerifyInput:
seq_lens_sum: int, seq_lens_sum: int,
topk: int, topk: int,
spec_steps: int, spec_steps: int,
num_verify_token: int, num_verify_tokens: int,
is_all_greedy: bool,
): ):
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = ( if is_all_greedy:
build_tree_kernel( tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
build_tree_kernel(
verified_id,
score_list, # b, n, topk; n= 1 + (num_steps-1) * self.topk
token_list,
parents_list,
seq_lens,
seq_lens_sum,
topk,
spec_steps,
num_verify_tokens,
)
)
return cls(
draft_tokens,
tree_mask,
position,
retrive_index,
None,
None,
retrive_cum_len,
num_verify_tokens,
spec_steps,
CaptureHiddenMode.FULL,
)
else:
(
tree_mask,
position,
retrive_index,
retrive_next_token,
retrive_next_sibling,
draft_tokens,
) = build_tree_kernel_efficient(
verified_id, verified_id,
score_list, score_list,
token_list, token_list,
@@ -187,18 +233,21 @@ class EagleVerifyInput:
seq_lens_sum, seq_lens_sum,
topk, topk,
spec_steps, spec_steps,
num_verify_token, num_verify_tokens,
)
return cls(
draft_tokens,
tree_mask,
position,
retrive_index,
retrive_next_token,
retrive_next_sibling,
None,
num_verify_tokens,
spec_steps,
CaptureHiddenMode.FULL,
) )
)
return cls(
draft_tokens,
tree_mask,
position,
retrive_index,
retrive_cum_len,
num_verify_token,
CaptureHiddenMode.FULL,
)
def prepare_for_verify(self, batch: ScheduleBatch): def prepare_for_verify(self, batch: ScheduleBatch):
batch.input_ids = self.draft_token batch.input_ids = self.draft_token
@@ -313,12 +362,6 @@ class EagleVerifyInput:
uniform_samples=coins, uniform_samples=coins,
target_probs=target_probs, target_probs=target_probs,
draft_probs=draft_probs, draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True, deterministic=True,
) )

View File

@@ -185,6 +185,7 @@ class EAGLEWorker(TpModelWorker):
self.topk, self.topk,
self.speculative_num_steps, self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens, self.server_args.speculative_num_draft_tokens,
batch.sampling_info.is_all_greedy,
) )
# Free cache locations # Free cache locations