Allow running with vllm==0.4.3 (#561)
This commit is contained in:
@@ -1,13 +1,19 @@
|
||||
import json
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from outlines.caching import cache as disk_cache
|
||||
from outlines.caching import disable_cache
|
||||
from outlines.fsm.guide import RegexGuide
|
||||
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
||||
from outlines.models.transformers import TransformerTokenizer
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from outlines.caching import cache as disk_cache
|
||||
from outlines.fsm.guide import RegexGuide
|
||||
from outlines.caching import disable_cache
|
||||
from outlines.fsm.guide import RegexGuide
|
||||
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
||||
from outlines.models.transformers import TransformerTokenizer
|
||||
except ImportError as e:
|
||||
print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
|
||||
raise
|
||||
|
||||
try:
|
||||
from outlines.fsm.json_schema import build_regex_from_object
|
||||
except ImportError:
|
||||
|
||||
@@ -512,8 +512,13 @@ def fused_moe(
|
||||
# Check constraints.
|
||||
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
||||
|
||||
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
|
||||
renormalize)
|
||||
if hasattr(ops, "topk_softmax"):
|
||||
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
|
||||
renormalize)
|
||||
else:
|
||||
topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output, topk,
|
||||
renormalize)
|
||||
|
||||
return fused_experts(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
@@ -525,4 +530,34 @@ def fused_moe(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale)
|
||||
a2_scale=a2_scale)
|
||||
|
||||
|
||||
|
||||
def fused_topk_v0_4_3(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
):
|
||||
import vllm._moe_C as moe_kernels
|
||||
M, _ = hidden_states.shape
|
||||
|
||||
topk_weights = torch.empty(
|
||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
||||
token_expert_indicies = torch.empty(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
moe_kernels.topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(), # TODO(woosuk): Optimize this.
|
||||
)
|
||||
del token_expert_indicies # Not used. Will be used in the future.
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
Reference in New Issue
Block a user