diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index 988f502c8..6cc85fd64 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -1,13 +1,18 @@ import json from typing import Dict, Optional, Union -from outlines.caching import cache as disk_cache -from outlines.caching import disable_cache -from outlines.fsm.guide import RegexGuide -from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm -from outlines.models.transformers import TransformerTokenizer from pydantic import BaseModel +try: + from outlines.caching import cache as disk_cache + from outlines.caching import disable_cache + from outlines.fsm.guide import RegexGuide + from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm + from outlines.models.transformers import TransformerTokenizer +except ImportError as e: + print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n') + raise + try: from outlines.fsm.json_schema import build_regex_from_object except ImportError: diff --git a/python/sglang/srt/layers/fused_moe.py b/python/sglang/srt/layers/fused_moe.py index a6a46e50b..6ebe206df 100644 --- a/python/sglang/srt/layers/fused_moe.py +++ b/python/sglang/srt/layers/fused_moe.py @@ -512,8 +512,13 @@ def fused_moe( # Check constraints. 
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) + if hasattr(ops, "topk_softmax"): + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + else: + topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output, topk, + renormalize) + return fused_experts(hidden_states, w1, w2, @@ -525,4 +530,34 @@ def fused_moe( w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, - a2_scale=a2_scale) \ No newline at end of file + a2_scale=a2_scale) + + + +def fused_topk_v0_4_3( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + import vllm._moe_C as moe_kernels + M, _ = hidden_states.shape + + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) + token_expert_indicies = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) + moe_kernels.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids \ No newline at end of file