Allow running with vllm==0.4.3 (#561)
@@ -1,13 +1,19 @@
 import json
 from typing import Dict, Optional, Union
 
-from outlines.caching import cache as disk_cache
-from outlines.caching import disable_cache
-from outlines.fsm.guide import RegexGuide
-from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
-from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel
 
+try:
+    from outlines.caching import cache as disk_cache
+    from outlines.caching import disable_cache
+    from outlines.fsm.guide import RegexGuide
+    from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
+    from outlines.models.transformers import TransformerTokenizer
+except ImportError as e:
+    print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
+    raise
+
 try:
     from outlines.fsm.json_schema import build_regex_from_object
 except ImportError:
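The guard above follows the usual pattern for an optional dependency with a minimum-version requirement: attempt the imports, print an actionable install hint, then re-raise so startup still fails with a full traceback. A minimal standalone sketch of the same pattern (illustrative only, not code from this commit):

    # Sketch: fail fast with an install hint when an optional dependency
    # is missing or too old (outlines >= 0.0.44 in this commit's case).
    try:
        from outlines.fsm.guide import RegexGuide
    except ImportError as e:
        # Point the user at the fix, then re-raise so the process exits
        # with a traceback instead of continuing in a broken state.
        print(f'Error: {e}. Install it with `pip install "outlines>=0.0.44"`')
        raise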
@@ -512,8 +512,13 @@ def fused_moe(
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
-    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                        renormalize)
+    if hasattr(ops, "topk_softmax"):
+        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                            renormalize)
+    else:
+        topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output, topk,
+                                                   renormalize)
+
     return fused_experts(hidden_states,
                          w1,
                          w2,
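The dispatch uses hasattr feature detection rather than parsing version strings: if the installed vllm exposes topk_softmax on its ops module, the regular fused_topk path is taken; on vllm 0.4.3, where it is absent, the new fused_topk_v0_4_3 fallback runs instead. A runnable sketch of the same idea with stand-in namespaces (new_build and old_build are hypothetical, not from this commit):

    from types import SimpleNamespace

    # Two stand-ins for different builds of the same compiled extension.
    new_build = SimpleNamespace(topk_softmax=lambda *args: None)  # kernel exposed
    old_build = SimpleNamespace()                                 # kernel absent

    for ops in (new_build, old_build):
        if hasattr(ops, "topk_softmax"):
            print("-> fused_topk (newer vllm)")
        else:
            print("-> fused_topk_v0_4_3 (vllm 0.4.3 fallback)")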
@@ -525,4 +530,34 @@ def fused_moe(
                          w1_scale=w1_scale,
                          w2_scale=w2_scale,
                          a1_scale=a1_scale,
                          a2_scale=a2_scale)
+
+
+def fused_topk_v0_4_3(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    import vllm._moe_C as moe_kernels
+
+    M, _ = hidden_states.shape
+
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
+    moe_kernels.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
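For reference, topk_softmax fuses a softmax over the expert logits with a top-k selection, which is what fused_topk_v0_4_3 reproduces by calling into the private vllm._moe_C extension directly. A rough unfused PyTorch equivalent of that computation (a sketch for clarity, not code from this commit):

    import torch

    def topk_softmax_reference(
        gating_output: torch.Tensor, topk: int, renormalize: bool
    ):
        # Softmax over the expert dimension, in float32 like the kernel path.
        probs = torch.softmax(gating_output.float(), dim=-1)
        # Keep the `topk` highest expert probabilities per token.
        topk_weights, topk_ids = probs.topk(topk, dim=-1)
        if renormalize:
            # Rescale the kept weights so each token's weights sum to 1.
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids.to(torch.int32)

For example, topk_softmax_reference(torch.randn(4, 8), topk=2, renormalize=True) returns per-token weight and expert-id tensors of shape (4, 2), matching the shapes the fused kernel writes into topk_weights and topk_ids.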