Allow running with vllm==0.4.3 (#561)

This commit is contained in:
Lianmin Zheng
2024-06-24 15:24:21 -07:00
committed by GitHub
parent 05471f2103
commit 9465b668b9
2 changed files with 49 additions and 8 deletions

View File

@@ -1,13 +1,19 @@
import json
from typing import Dict, Optional, Union
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel
try:
from outlines.caching import cache as disk_cache
from outlines.fsm.guide import RegexGuide
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
except ImportError as e:
print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
raise
try:
from outlines.fsm.json_schema import build_regex_from_object
except ImportError:

View File

@@ -512,8 +512,13 @@ def fused_moe(
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
if hasattr(ops, "topk_softmax"):
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
else:
topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output, topk,
renormalize)
return fused_experts(hidden_states,
w1,
w2,
@@ -525,4 +530,34 @@ def fused_moe(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
a2_scale=a2_scale)
def fused_topk_v0_4_3(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
import vllm._moe_C as moe_kernels
M, _ = hidden_states.shape
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
token_expert_indicies = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids