Optimize topk operation in llama4 (#5128)
This commit is contained in:
@@ -48,7 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
|
||||
from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
|
||||
from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -63,7 +63,7 @@ class Llama4MoE(nn.Module):
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
|
||||
router_scores_aK, router_indices_aK = fast_topk(gating_output, topk, dim=-1)
|
||||
router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
|
||||
hidden_states.dtype
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ from sglang.srt.managers.schedule_batch import (
|
||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
||||
from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2
|
||||
from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2
|
||||
|
||||
if is_cuda_available():
|
||||
from sgl_kernel import (
|
||||
@@ -772,16 +772,6 @@ def select_top_k_tokens(
|
||||
return input_ids, hidden_states, scores, tree_info
|
||||
|
||||
|
||||
def fast_topk(values, topk, dim):
|
||||
if topk == 1:
|
||||
# Use max along the specified dimension to get both value and index
|
||||
max_value, max_index = torch.max(values, dim=dim)
|
||||
return max_value.unsqueeze(1), max_index.unsqueeze(1)
|
||||
else:
|
||||
# Use topk for efficiency with larger k values
|
||||
return torch.topk(values, topk, dim=dim)
|
||||
|
||||
|
||||
def _generate_simulated_accept_index(
|
||||
accept_index,
|
||||
predict,
|
||||
|
||||
@@ -31,11 +31,15 @@ from sglang.srt.speculative.eagle_utils import (
|
||||
EagleVerifyInput,
|
||||
EagleVerifyOutput,
|
||||
assign_draft_cache_locs,
|
||||
fast_topk,
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
|
||||
from sglang.srt.utils import (
|
||||
empty_context,
|
||||
fast_topk,
|
||||
get_available_gpu_memory,
|
||||
is_cuda_available,
|
||||
)
|
||||
|
||||
if is_cuda_available():
|
||||
from sgl_kernel import segment_packbits
|
||||
|
||||
@@ -1819,3 +1819,12 @@ class DeepEPMode(Enum):
|
||||
return DeepEPMode.low_latency
|
||||
else:
|
||||
return DeepEPMode.normal
|
||||
|
||||
|
||||
def fast_topk(values, topk, dim):
|
||||
if topk == 1:
|
||||
# Use max along the specified dimension to get both value and index
|
||||
return torch.max(values, dim=dim, keepdim=True)
|
||||
else:
|
||||
# Use topk for efficiency with larger k values
|
||||
return torch.topk(values, topk, dim=dim)
|
||||
|
||||
Reference in New Issue
Block a user