[Eagle] Fix kernel call after updating speculative sampling kernels (#7231)
This commit is contained in:
@@ -49,7 +49,7 @@ runtime_common = [
|
||||
|
||||
srt = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel==0.1.8.post2",
|
||||
"sgl-kernel==0.1.9",
|
||||
"flashinfer_python==0.2.6.post1",
|
||||
"torch==2.7.1",
|
||||
"torchaudio==2.7.1",
|
||||
|
||||
@@ -605,7 +605,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
if _is_cuda:
|
||||
assert_pkg_version(
|
||||
"sgl-kernel",
|
||||
"0.1.8.post2",
|
||||
"0.1.9",
|
||||
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
||||
)
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ def build_tree_kernel_efficient(
|
||||
sgl_build_tree_kernel_efficient(
|
||||
parent_list,
|
||||
top_scores_index,
|
||||
seq_lens.to(torch.int32),
|
||||
seq_lens,
|
||||
tree_mask,
|
||||
positions,
|
||||
retrive_index,
|
||||
|
||||
@@ -23,7 +23,7 @@ from sglang.srt.managers.schedule_batch import (
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
|
||||
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import (
|
||||
@@ -32,6 +32,7 @@ if is_cuda():
|
||||
tree_speculative_sampling_target_only,
|
||||
verify_tree_greedy,
|
||||
)
|
||||
from sgl_kernel.top_k import fast_topk
|
||||
elif is_hip():
|
||||
from sgl_kernel import verify_tree_greedy
|
||||
|
||||
@@ -327,11 +328,11 @@ class EagleVerifyInput:
|
||||
predicts=predict, # mutable
|
||||
accept_index=accept_index, # mutable
|
||||
accept_token_num=accept_length, # mutable
|
||||
candidates=candidates.to(torch.int32),
|
||||
retrive_index=self.retrive_index.to(torch.int32),
|
||||
retrive_next_token=self.retrive_next_token.to(torch.int32),
|
||||
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
|
||||
target_predict=target_predict.to(torch.int32),
|
||||
candidates=candidates,
|
||||
retrive_index=self.retrive_index,
|
||||
retrive_next_token=self.retrive_next_token,
|
||||
retrive_next_sibling=self.retrive_next_sibling,
|
||||
target_predict=target_predict,
|
||||
)
|
||||
else:
|
||||
# apply temperature and get target probs
|
||||
@@ -370,12 +371,12 @@ class EagleVerifyInput:
|
||||
predicts=predict, # mutable
|
||||
accept_index=accept_index, # mutable
|
||||
accept_token_num=accept_length, # mutable
|
||||
candidates=candidates.to(torch.int32),
|
||||
retrive_index=self.retrive_index.to(torch.int32),
|
||||
retrive_next_token=self.retrive_next_token.to(torch.int32),
|
||||
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
|
||||
candidates=candidates,
|
||||
retrive_index=self.retrive_index,
|
||||
retrive_next_token=self.retrive_next_token,
|
||||
retrive_next_sibling=self.retrive_next_sibling,
|
||||
uniform_samples=coins,
|
||||
# uniform_samples_for_final_sampling=coins_for_final_sampling,
|
||||
uniform_samples_for_final_sampling=coins_for_final_sampling,
|
||||
target_probs=target_probs,
|
||||
draft_probs=draft_probs,
|
||||
threshold_single=global_server_args_dict[
|
||||
@@ -1005,16 +1006,6 @@ def select_top_k_tokens(
|
||||
return input_ids, hidden_states, scores, tree_info
|
||||
|
||||
|
||||
def fast_topk_torch(values: torch.Tensor, topk: int, dim: int):
    """Return the ``topk`` largest entries of ``values`` along ``dim``.

    Args:
        values: Tensor to select from.
        topk: Number of entries to keep along ``dim``.
        dim: Dimension along which to select.

    Returns:
        A ``(values, indices)`` pair. For ``topk == 1`` this is a plain tuple
        produced via ``torch.max``; otherwise it is whatever ``torch.topk``
        returns. In both cases ``dim`` is kept in the output with size
        ``topk``.
    """
    if topk == 1:
        # torch.max yields both the maximum and its index in a single call.
        # keepdim=True keeps the reduced dimension in place at ``dim`` so the
        # output shape matches torch.topk(values, 1, dim=dim) for every
        # ``dim``, not only when the reduced axis happens to be axis 1.
        max_value, max_index = torch.max(values, dim=dim, keepdim=True)
        return max_value, max_index

    # General case: defer to torch.topk.
    return torch.topk(values, topk, dim=dim)
|
||||
|
||||
|
||||
def _generate_simulated_accept_index(
|
||||
accept_index,
|
||||
predict,
|
||||
|
||||
@@ -828,7 +828,7 @@ def load_token_map(token_map_path: str) -> List[int]:
|
||||
)
|
||||
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
||||
hot_token_id = torch.load(token_map_path, weights_only=True)
|
||||
return torch.tensor(hot_token_id, dtype=torch.int32)
|
||||
return torch.tensor(hot_token_id, dtype=torch.int64)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
|
||||
Reference in New Issue
Block a user