[Eagle] Fix kernel call after updating speculative sampling kernels (#7231)

This commit is contained in:
Lianmin Zheng
2025-06-16 07:25:59 -07:00
committed by GitHub
parent 7ddf8e83d2
commit 53a525bf33
7 changed files with 24 additions and 33 deletions

View File

@@ -49,7 +49,7 @@ runtime_common = [
srt = [
"sglang[runtime_common]",
"sgl-kernel==0.1.8.post2",
"sgl-kernel==0.1.9",
"flashinfer_python==0.2.6.post1",
"torch==2.7.1",
"torchaudio==2.7.1",

View File

@@ -605,7 +605,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if _is_cuda:
assert_pkg_version(
"sgl-kernel",
"0.1.8.post2",
"0.1.9",
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)

View File

@@ -92,7 +92,7 @@ def build_tree_kernel_efficient(
sgl_build_tree_kernel_efficient(
parent_list,
top_scores_index,
seq_lens.to(torch.int32),
seq_lens,
tree_mask,
positions,
retrive_index,

View File

@@ -23,7 +23,7 @@ from sglang.srt.managers.schedule_batch import (
)
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
if is_cuda():
from sgl_kernel import (
@@ -32,6 +32,7 @@ if is_cuda():
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
from sgl_kernel.top_k import fast_topk
elif is_hip():
from sgl_kernel import verify_tree_greedy
@@ -327,11 +328,11 @@ class EagleVerifyInput:
predicts=predict, # mutable
accept_index=accept_index, # mutable
accept_token_num=accept_length, # mutable
candidates=candidates.to(torch.int32),
retrive_index=self.retrive_index.to(torch.int32),
retrive_next_token=self.retrive_next_token.to(torch.int32),
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
target_predict=target_predict.to(torch.int32),
candidates=candidates,
retrive_index=self.retrive_index,
retrive_next_token=self.retrive_next_token,
retrive_next_sibling=self.retrive_next_sibling,
target_predict=target_predict,
)
else:
# apply temperature and get target probs
@@ -370,12 +371,12 @@ class EagleVerifyInput:
predicts=predict, # mutable
accept_index=accept_index, # mutable
accept_token_num=accept_length, # mutable
candidates=candidates.to(torch.int32),
retrive_index=self.retrive_index.to(torch.int32),
retrive_next_token=self.retrive_next_token.to(torch.int32),
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
candidates=candidates,
retrive_index=self.retrive_index,
retrive_next_token=self.retrive_next_token,
retrive_next_sibling=self.retrive_next_sibling,
uniform_samples=coins,
# uniform_samples_for_final_sampling=coins_for_final_sampling,
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
@@ -1005,16 +1006,6 @@ def select_top_k_tokens(
return input_ids, hidden_states, scores, tree_info
def fast_topk_torch(values, topk, dim):
if topk == 1:
# Use max along the specified dimension to get both value and index
max_value, max_index = torch.max(values, dim=dim)
return max_value.unsqueeze(1), max_index.unsqueeze(1)
else:
# Use topk for efficiency with larger k values
return torch.topk(values, topk, dim=dim)
def _generate_simulated_accept_index(
accept_index,
predict,

View File

@@ -828,7 +828,7 @@ def load_token_map(token_map_path: str) -> List[int]:
)
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
hot_token_id = torch.load(token_map_path, weights_only=True)
return torch.tensor(hot_token_id, dtype=torch.int32)
return torch.tensor(hot_token_id, dtype=torch.int64)
@torch.compile(dynamic=True)