fix EAGLE 2 non greedy case (#3407)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -54,7 +54,9 @@ def get_model_config(model_name: str, tp_size: int):
|
|||||||
):
|
):
|
||||||
block_shape = config.quantization_config["weight_block_size"]
|
block_shape = config.quantization_config["weight_block_size"]
|
||||||
assert len(block_shape) == 2
|
assert len(block_shape) == 2
|
||||||
assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
|
assert (
|
||||||
|
vllm_version_num >= 66
|
||||||
|
), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
|
||||||
|
|
||||||
shape_configs = {
|
shape_configs = {
|
||||||
"num_experts": E,
|
"num_experts": E,
|
||||||
|
|||||||
@@ -462,8 +462,11 @@ class CudaGraphRunner:
|
|||||||
),
|
),
|
||||||
positions=None,
|
positions=None,
|
||||||
retrive_index=None,
|
retrive_index=None,
|
||||||
|
retrive_next_token=None,
|
||||||
|
retrive_next_sibling=None,
|
||||||
retrive_cum_len=None,
|
retrive_cum_len=None,
|
||||||
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
|
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
|
||||||
|
spec_steps=self.model_runner.server_args.speculative_num_steps,
|
||||||
capture_hidden_mode=CaptureHiddenMode.FULL,
|
capture_hidden_mode=CaptureHiddenMode.FULL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import dataclasses
|
|||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
@@ -11,7 +12,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
|
|||||||
create_flashinfer_kv_indices_triton,
|
create_flashinfer_kv_indices_triton,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
|
from sglang.srt.speculative.build_eagle_tree import (
|
||||||
|
build_tree_kernel,
|
||||||
|
build_tree_kernel_efficient,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import is_cuda_available
|
||||||
|
|
||||||
|
if is_cuda_available():
|
||||||
|
from sgl_kernel import tree_speculative_sampling_target_only
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
@@ -160,8 +168,11 @@ class EagleVerifyInput:
|
|||||||
custom_mask: torch.Tensor
|
custom_mask: torch.Tensor
|
||||||
positions: torch.Tensor
|
positions: torch.Tensor
|
||||||
retrive_index: torch.Tensor
|
retrive_index: torch.Tensor
|
||||||
|
retrive_next_token: torch.Tensor
|
||||||
|
retrive_next_sibling: torch.Tensor
|
||||||
retrive_cum_len: torch.Tensor
|
retrive_cum_len: torch.Tensor
|
||||||
draft_token_num: int
|
draft_token_num: int
|
||||||
|
spec_steps: int
|
||||||
capture_hidden_mode: CaptureHiddenMode
|
capture_hidden_mode: CaptureHiddenMode
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -175,10 +186,45 @@ class EagleVerifyInput:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
spec_steps: int,
|
spec_steps: int,
|
||||||
num_verify_token: int,
|
num_verify_tokens: int,
|
||||||
|
is_all_greedy: bool,
|
||||||
):
|
):
|
||||||
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
|
if is_all_greedy:
|
||||||
build_tree_kernel(
|
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
|
||||||
|
build_tree_kernel(
|
||||||
|
verified_id,
|
||||||
|
score_list, # b, n, topk; n= 1 + (num_steps-1) * self.topk
|
||||||
|
token_list,
|
||||||
|
parents_list,
|
||||||
|
seq_lens,
|
||||||
|
seq_lens_sum,
|
||||||
|
topk,
|
||||||
|
spec_steps,
|
||||||
|
num_verify_tokens,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
draft_tokens,
|
||||||
|
tree_mask,
|
||||||
|
position,
|
||||||
|
retrive_index,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
retrive_cum_len,
|
||||||
|
num_verify_tokens,
|
||||||
|
spec_steps,
|
||||||
|
CaptureHiddenMode.FULL,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
tree_mask,
|
||||||
|
position,
|
||||||
|
retrive_index,
|
||||||
|
retrive_next_token,
|
||||||
|
retrive_next_sibling,
|
||||||
|
draft_tokens,
|
||||||
|
) = build_tree_kernel_efficient(
|
||||||
verified_id,
|
verified_id,
|
||||||
score_list,
|
score_list,
|
||||||
token_list,
|
token_list,
|
||||||
@@ -187,18 +233,21 @@ class EagleVerifyInput:
|
|||||||
seq_lens_sum,
|
seq_lens_sum,
|
||||||
topk,
|
topk,
|
||||||
spec_steps,
|
spec_steps,
|
||||||
num_verify_token,
|
num_verify_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
draft_tokens,
|
||||||
|
tree_mask,
|
||||||
|
position,
|
||||||
|
retrive_index,
|
||||||
|
retrive_next_token,
|
||||||
|
retrive_next_sibling,
|
||||||
|
None,
|
||||||
|
num_verify_tokens,
|
||||||
|
spec_steps,
|
||||||
|
CaptureHiddenMode.FULL,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
return cls(
|
|
||||||
draft_tokens,
|
|
||||||
tree_mask,
|
|
||||||
position,
|
|
||||||
retrive_index,
|
|
||||||
retrive_cum_len,
|
|
||||||
num_verify_token,
|
|
||||||
CaptureHiddenMode.FULL,
|
|
||||||
)
|
|
||||||
|
|
||||||
def prepare_for_verify(self, batch: ScheduleBatch):
|
def prepare_for_verify(self, batch: ScheduleBatch):
|
||||||
batch.input_ids = self.draft_token
|
batch.input_ids = self.draft_token
|
||||||
@@ -313,12 +362,6 @@ class EagleVerifyInput:
|
|||||||
uniform_samples=coins,
|
uniform_samples=coins,
|
||||||
target_probs=target_probs,
|
target_probs=target_probs,
|
||||||
draft_probs=draft_probs,
|
draft_probs=draft_probs,
|
||||||
threshold_single=global_server_args_dict[
|
|
||||||
"speculative_accept_threshold_single"
|
|
||||||
],
|
|
||||||
threshold_acc=global_server_args_dict[
|
|
||||||
"speculative_accept_threshold_acc"
|
|
||||||
],
|
|
||||||
deterministic=True,
|
deterministic=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.topk,
|
self.topk,
|
||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
self.server_args.speculative_num_draft_tokens,
|
self.server_args.speculative_num_draft_tokens,
|
||||||
|
batch.sampling_info.is_all_greedy,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Free cache locations
|
# Free cache locations
|
||||||
|
|||||||
Reference in New Issue
Block a user