MoE torch compile (#1497)

Ke Bao
2024-09-24 16:46:59 +08:00
committed by GitHub
parent 2854a5ea9f
commit 8d4ed42ad5
2 changed files with 126 additions and 5 deletions

@@ -25,6 +25,7 @@ import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
+from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
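
The imported `fused_moe_forward_native` is defined in the second changed file (`sglang/srt/layers/fused_moe/patch.py`, per the import path), which is not shown in this view. Purely as a hedged illustration of what a torch-native fused-MoE forward can look like — top-k softmax routing plus per-token expert matmuls in plain torch ops that `torch.compile` can trace — here is a self-contained sketch; the function name, signature, and weight layout are assumptions, not the actual contents of that file.

```python
# Hypothetical sketch only: the real fused_moe_forward_native lives in the
# second changed file, which this view does not show. Names, signature, and
# weight layout below are assumptions.
import torch
import torch.nn.functional as F


def moe_forward_native_sketch(
    x: torch.Tensor,              # [tokens, hidden]
    router_logits: torch.Tensor,  # [tokens, num_experts]
    w13: torch.Tensor,            # [num_experts, 2 * inter, hidden], gate/up fused
    w2: torch.Tensor,             # [num_experts, hidden, inter]
    top_k: int,
) -> torch.Tensor:
    # Top-k routing with renormalized softmax weights.
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # Gather the selected experts' weights per token (pure indexing, traceable).
    w13_sel = w13[topk_ids]                  # [tokens, top_k, 2 * inter, hidden]
    w1, w3 = torch.chunk(w13_sel, 2, dim=2)  # gate and up projections
    w2_sel = w2[topk_ids]                    # [tokens, top_k, hidden, inter]

    # SwiGLU expert MLP written with einsum; plain torch ops compile cleanly,
    # unlike the fused CUDA kernel used by forward_cuda.
    x1 = F.silu(torch.einsum("th,teoh->teo", x, w1))
    x3 = torch.einsum("th,teoh->teo", x, w3)
    out = torch.einsum("teo,teho->teh", x1 * x3, w2_sel)

    # Weighted combination of the top-k expert outputs.
    return torch.einsum("teh,te->th", out, topk_weights.to(out.dtype))
```
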
@@ -41,14 +42,15 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
-            # NOTE: FusedMoE torch native implementation is not efficient
-            if "FusedMoE" in sub.__class__.__name__:
-                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
             else:
-                sub._forward_method = sub.forward_native
+                # NOTE: Temporary workaround for MoE
+                if "FusedMoE" in sub.__class__.__name__:
+                    sub._forward_method = fused_moe_forward_native
+                else:
+                    sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)
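
Swapping `_forward_method` works because vLLM's `CustomOp.forward` simply delegates to whatever `self._forward_method` points at, so `_to_torch` can retarget each op per instance without rebuilding the model. A minimal stand-in for that dispatch pattern (not vLLM's actual class):

```python
import torch


class ToyOp(torch.nn.Module):
    """Minimal stand-in for vLLM's CustomOp dispatch: forward() routes through
    _forward_method, so a backend swap is just an attribute assignment."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda  # default backend

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2  # stands in for a fused CUDA kernel

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x + x  # pure-torch equivalent that torch.compile can trace


op = ToyOp()
op._forward_method = op.forward_native  # what _to_torch does per submodule
assert torch.equal(op(torch.ones(3)), torch.full((3,), 2.0))
```
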
@@ -67,7 +69,9 @@ def patch_model(
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             tp_group.ca_comm = None
-            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+            yield torch.compile(
+                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+            )
         else:
             yield model.forward
     finally:
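
The second change wraps the forward in `torch.no_grad()` before compiling: `torch.no_grad()` also works as a callable decorator, so the returned function runs under the no-grad context and the compiled inference path never records autograd state. A small standalone sketch of the same pattern:

```python
import torch


def forward(x: torch.Tensor) -> torch.Tensor:
    return x @ x.T


# torch.no_grad() is usable as a decorator: the wrapped callable runs its body
# inside the no-grad context every time it is invoked.
compiled = torch.compile(torch.no_grad()(forward), mode="max-autotune-no-cudagraphs")

x = torch.randn(4, 4, requires_grad=True)
y = compiled(x)
assert not y.requires_grad  # no autograd graph is built through the call
```
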