fix EAGLE 2 non greedy case (#3407)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -462,8 +462,11 @@ class CudaGraphRunner:
|
||||
),
|
||||
positions=None,
|
||||
retrive_index=None,
|
||||
retrive_next_token=None,
|
||||
retrive_next_sibling=None,
|
||||
retrive_cum_len=None,
|
||||
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
|
||||
spec_steps=self.model_runner.server_args.speculative_num_steps,
|
||||
capture_hidden_mode=CaptureHiddenMode.FULL,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import dataclasses
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@@ -11,7 +12,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
create_flashinfer_kv_indices_triton,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
|
||||
from sglang.srt.speculative.build_eagle_tree import (
|
||||
build_tree_kernel,
|
||||
build_tree_kernel_efficient,
|
||||
)
|
||||
from sglang.srt.utils import is_cuda_available
|
||||
|
||||
if is_cuda_available():
|
||||
from sgl_kernel import tree_speculative_sampling_target_only
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
@@ -160,8 +168,11 @@ class EagleVerifyInput:
|
||||
custom_mask: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
retrive_index: torch.Tensor
|
||||
retrive_next_token: torch.Tensor
|
||||
retrive_next_sibling: torch.Tensor
|
||||
retrive_cum_len: torch.Tensor
|
||||
draft_token_num: int
|
||||
spec_steps: int
|
||||
capture_hidden_mode: CaptureHiddenMode
|
||||
|
||||
@classmethod
|
||||
@@ -175,10 +186,45 @@ class EagleVerifyInput:
|
||||
seq_lens_sum: int,
|
||||
topk: int,
|
||||
spec_steps: int,
|
||||
num_verify_token: int,
|
||||
num_verify_tokens: int,
|
||||
is_all_greedy: bool,
|
||||
):
|
||||
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
|
||||
build_tree_kernel(
|
||||
if is_all_greedy:
|
||||
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
|
||||
build_tree_kernel(
|
||||
verified_id,
|
||||
score_list, # b, n, topk; n= 1 + (num_steps-1) * self.topk
|
||||
token_list,
|
||||
parents_list,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
topk,
|
||||
spec_steps,
|
||||
num_verify_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
draft_tokens,
|
||||
tree_mask,
|
||||
position,
|
||||
retrive_index,
|
||||
None,
|
||||
None,
|
||||
retrive_cum_len,
|
||||
num_verify_tokens,
|
||||
spec_steps,
|
||||
CaptureHiddenMode.FULL,
|
||||
)
|
||||
else:
|
||||
(
|
||||
tree_mask,
|
||||
position,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
draft_tokens,
|
||||
) = build_tree_kernel_efficient(
|
||||
verified_id,
|
||||
score_list,
|
||||
token_list,
|
||||
@@ -187,18 +233,21 @@ class EagleVerifyInput:
|
||||
seq_lens_sum,
|
||||
topk,
|
||||
spec_steps,
|
||||
num_verify_token,
|
||||
num_verify_tokens,
|
||||
)
|
||||
|
||||
return cls(
|
||||
draft_tokens,
|
||||
tree_mask,
|
||||
position,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
None,
|
||||
num_verify_tokens,
|
||||
spec_steps,
|
||||
CaptureHiddenMode.FULL,
|
||||
)
|
||||
)
|
||||
return cls(
|
||||
draft_tokens,
|
||||
tree_mask,
|
||||
position,
|
||||
retrive_index,
|
||||
retrive_cum_len,
|
||||
num_verify_token,
|
||||
CaptureHiddenMode.FULL,
|
||||
)
|
||||
|
||||
def prepare_for_verify(self, batch: ScheduleBatch):
|
||||
batch.input_ids = self.draft_token
|
||||
@@ -313,12 +362,6 @@ class EagleVerifyInput:
|
||||
uniform_samples=coins,
|
||||
target_probs=target_probs,
|
||||
draft_probs=draft_probs,
|
||||
threshold_single=global_server_args_dict[
|
||||
"speculative_accept_threshold_single"
|
||||
],
|
||||
threshold_acc=global_server_args_dict[
|
||||
"speculative_accept_threshold_acc"
|
||||
],
|
||||
deterministic=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -185,6 +185,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
self.server_args.speculative_num_draft_tokens,
|
||||
batch.sampling_info.is_all_greedy,
|
||||
)
|
||||
|
||||
# Free cache locations
|
||||
|
||||
Reference in New Issue
Block a user