Improve torch compile for fused moe (#2327)
This commit is contained in:
@@ -105,20 +105,29 @@ def fused_moe_forward_native(
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
) -> torch.Tensor:
|
||||
assert custom_routing_function is None
|
||||
topk_weights, topk_ids = select_experts_native(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
)
|
||||
|
||||
if use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
topk_weights, topk_ids = grouped_topk(
|
||||
x,
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
)
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
|
||||
else:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
x, router_logits, top_k, renormalize
|
||||
)
|
||||
|
||||
w13_weights = layer.w13_weight[topk_ids]
|
||||
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
||||
w2_weights = layer.w2_weight[topk_ids]
|
||||
x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
|
||||
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
|
||||
x1 = F.silu(x1)
|
||||
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
||||
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
||||
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
||||
|
||||
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
|
||||
def _to_torch(model: torch.nn.Module, reverse: bool = False):
|
||||
def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
|
||||
for sub in model._modules.values():
|
||||
if isinstance(sub, CustomOp):
|
||||
if reverse:
|
||||
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
|
||||
else:
|
||||
# NOTE: Temporarily workaround MoE
|
||||
if "FusedMoE" in sub.__class__.__name__:
|
||||
sub._forward_method = fused_moe_forward_native
|
||||
if batch_size == 1:
|
||||
# The performance of torch.compile on this layer is not always good when bs > 1,
|
||||
# so we decide to skip it for now.
|
||||
sub._forward_method = fused_moe_forward_native
|
||||
else:
|
||||
sub._forward_method = sub.forward_native
|
||||
setattr(sub, "is_torch_compile", True)
|
||||
if isinstance(sub, torch.nn.Module):
|
||||
_to_torch(sub, reverse)
|
||||
_to_torch(sub, reverse, batch_size)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patch_model(
|
||||
model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
|
||||
model: torch.nn.Module,
|
||||
enable_compile: bool,
|
||||
batch_size: int,
|
||||
tp_group: "GroupCoordinator",
|
||||
):
|
||||
"""Patch the model to make it compatible with with torch.compile"""
|
||||
backup_ca_comm = None
|
||||
|
||||
try:
|
||||
if enable_compile:
|
||||
_to_torch(model)
|
||||
_to_torch(model, reverse=False, batch_size=batch_size)
|
||||
monkey_patch_vllm_all_gather()
|
||||
backup_ca_comm = tp_group.ca_comm
|
||||
# Use custom-allreduce here.
|
||||
@@ -70,13 +76,15 @@ def patch_model(
|
||||
# even with ENABLE_INTRA_NODE_COMM=1.
|
||||
# tp_group.ca_comm = None
|
||||
yield torch.compile(
|
||||
torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
|
||||
torch.no_grad()(model.forward),
|
||||
mode="max-autotune-no-cudagraphs",
|
||||
dynamic=False,
|
||||
)
|
||||
else:
|
||||
yield model.forward
|
||||
finally:
|
||||
if enable_compile:
|
||||
_to_torch(model, reverse=True)
|
||||
_to_torch(model, reverse=True, batch_size=batch_size)
|
||||
monkey_patch_vllm_all_gather(reverse=True)
|
||||
tp_group.ca_comm = backup_ca_comm
|
||||
|
||||
@@ -237,6 +245,7 @@ class CudaGraphRunner:
|
||||
with patch_model(
|
||||
self.model_runner.model,
|
||||
bs in self.compile_bs,
|
||||
bs,
|
||||
self.model_runner.tp_group,
|
||||
) as forward:
|
||||
(
|
||||
|
||||
@@ -622,7 +622,7 @@ class ModelRunner:
|
||||
tic = time.time()
|
||||
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f}s")
|
||||
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
|
||||
|
||||
def apply_torch_tp(self):
|
||||
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
|
||||
|
||||
Reference in New Issue
Block a user