Fix torch.compile for MoE (#2033)

This commit is contained in:
Lianmin Zheng
2024-11-14 01:30:24 -08:00
committed by GitHub
parent b275ce0043
commit c3eac1b010
10 changed files with 89 additions and 12 deletions

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Callable, Optional
import torch
from torch.nn import functional as F
@@ -98,7 +98,9 @@ def fused_moe_forward_native(
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
assert custom_routing_function is None
topk_weights, topk_ids = select_experts_native(
hidden_states=x,
router_logits=router_logits,
@@ -114,4 +116,4 @@ def fused_moe_forward_native(
x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))

View File

@@ -28,8 +28,9 @@ from sglang.utils import get_exception_traceback
DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
@@ -740,7 +741,7 @@ def run_mmlu_test(
try:
metrics = run_eval(args)
print(f"{metrics=}")
assert metrics["score"] >= 0.65
self.assertGreaterEqual(metrics["score"], 0.65)
finally:
pass