Fix torch.compile for MoE (#2033)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
@@ -98,7 +98,9 @@ def fused_moe_forward_native(
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
) -> torch.Tensor:
|
||||
assert custom_routing_function is None
|
||||
topk_weights, topk_ids = select_experts_native(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@@ -114,4 +116,4 @@ def fused_moe_forward_native(
|
||||
x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
|
||||
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
||||
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
||||
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
|
||||
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
||||
|
||||
@@ -28,8 +28,9 @@ from sglang.utils import get_exception_traceback
|
||||
DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
|
||||
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
|
||||
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
|
||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
|
||||
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
|
||||
@@ -740,7 +741,7 @@ def run_mmlu_test(
|
||||
try:
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
assert metrics["score"] >= 0.65
|
||||
self.assertGreaterEqual(metrics["score"], 0.65)
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user