MoE torch compile (#1497)
@@ -25,6 +25,7 @@ import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
@@ -41,14 +42,15 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
-            # NOTE: the FusedMoE torch-native implementation is not efficient
-            if "FusedMoE" in sub.__class__.__name__:
-                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
             else:
-                sub._forward_method = sub.forward_native
+                # NOTE: temporary workaround for MoE
+                if "FusedMoE" in sub.__class__.__name__:
+                    sub._forward_method = fused_moe_forward_native
+                else:
+                    sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)
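
The hunk above stops skipping FusedMoE layers entirely and instead routes them to fused_moe_forward_native, so MoE layers can also be traced by torch.compile. For readers unfamiliar with the dispatch trick that _to_torch relies on, here is a minimal, self-contained sketch of the same pattern. ToyCustomOp and toy_to_torch are hypothetical stand-ins, not sglang or vLLM classes; only the idea of swapping _forward_method between a CUDA path and a torch-native path mirrors the diff above.

import torch

class ToyCustomOp(torch.nn.Module):
    """Hypothetical stand-in for a CustomOp-style layer.

    Calls go through self._forward_method, so flipping that attribute
    switches between the CUDA-kernel path and the torch-native path that
    torch.compile can trace, without touching any call sites.
    """

    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda

    def forward(self, x):
        return self._forward_method(x)

    def forward_cuda(self, x):
        # Stands in for a fused CUDA kernel.
        return torch.nn.functional.silu(x)

    def forward_native(self, x):
        # Pure-PyTorch fallback, friendly to torch.compile.
        return x * torch.sigmoid(x)


def toy_to_torch(model: torch.nn.Module, reverse: bool = False):
    # Mirrors the recursion in _to_torch above: walk submodules and flip
    # the dispatch attribute on every ToyCustomOp encountered.
    for sub in model._modules.values():
        if isinstance(sub, ToyCustomOp):
            sub._forward_method = sub.forward_cuda if reverse else sub.forward_native
        if isinstance(sub, torch.nn.Module):
            toy_to_torch(sub, reverse)
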
@@ -67,7 +69,9 @@ def patch_model(
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             tp_group.ca_comm = None
-            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+            yield torch.compile(
+                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+            )
         else:
             yield model.forward
     finally:
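
The second change wraps the compiled forward in torch.no_grad(), which keeps gradient tracking out of the compiled inference path. A rough sketch of the same call pattern, with a hypothetical TinyModel standing in for the real model:

import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = TinyModel().eval()

# torch.no_grad() can wrap a callable directly; compiling the wrapped
# forward keeps autograd out of the captured graph, as in the
# patch_model change above.
compiled_forward = torch.compile(
    torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
)

out = compiled_forward(torch.randn(4, 8))
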