From 5f65e2b830a4cf752ae8ab5739ae7ad958eced83 Mon Sep 17 00:00:00 2001
From: HAI
Date: Wed, 30 Oct 2024 12:17:32 -0700
Subject: [PATCH] [Performance, Hardware] MoE weights padding to AMD MI300x
 GPUs (#1836)

---
 .../sglang/srt/layers/fused_moe/fused_moe.py |  7 +++--
 python/sglang/srt/layers/fused_moe/layer.py  | 28 +++++++++++++++++++
 2 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/fused_moe/fused_moe.py b/python/sglang/srt/layers/fused_moe/fused_moe.py
index 717be5ce9..646cea14d 100644
--- a/python/sglang/srt/layers/fused_moe/fused_moe.py
+++ b/python/sglang/srt/layers/fused_moe/fused_moe.py
@@ -14,6 +14,7 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
 
 
 @triton.jit
@@ -263,7 +264,7 @@ def invoke_fused_moe_kernel(
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padding_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -464,7 +465,7 @@ def fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
 ):
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -481,7 +482,7 @@ def fused_experts(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        w2.shape,
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
         topk_ids.shape[1],
         "float8" if use_fp8 else None,
         override_config=override_config,
diff --git a/python/sglang/srt/layers/fused_moe/layer.py b/python/sglang/srt/layers/fused_moe/layer.py
index 0511db5a1..19012185d 100644
--- a/python/sglang/srt/layers/fused_moe/layer.py
+++ b/python/sglang/srt/layers/fused_moe/layer.py
@@ -1,9 +1,11 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
+import os
 from abc import abstractmethod
 from typing import List, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.layers.fused_moe.fused_moe import padding_size
 from sglang.srt.utils import is_hip
 
 logger = init_logger(__name__)
@@ -506,6 +509,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
@@ -572,6 +588,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 start += shard_size
 
             layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
 
     def apply(
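
Note on the padding path: the sketch below reproduces the F.pad step from the
patch in isolation, assuming MOE_PADDING=1 is set in the environment. The
expert count and hidden/intermediate sizes are made up for illustration and
are not values taken from the patch.

import os
import torch
import torch.nn.functional as F

# The patch gates the padding on this env var (and on is_hip()).
os.environ["MOE_PADDING"] = "1"
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

# Illustrative shapes only: 2 experts, hidden size 512, intermediate size 1024.
num_experts, hidden, intermediate = 2, 512, 1024
w13 = torch.randn(num_experts, 2 * intermediate, hidden)
w2 = torch.randn(num_experts, hidden, intermediate)

# F.pad's pad tuple starts at the last dimension, so (0, padding_size)
# appends padding_size zeros at the end of the innermost dim, exactly as the
# patch does for layer.w13_weight and layer.w2_weight.
w13 = F.pad(w13, (0, padding_size), "constant", 0)
w2 = F.pad(w2, (0, padding_size), "constant", 0)

print(w13.shape)  # torch.Size([2, 2048, 640])
print(w2.shape)   # torch.Size([2, 512, 1152])

The kernel side compensates for the extra columns: invoke_fused_moe_kernel
passes B.shape[2] - padding_size as the reduction extent, and fused_experts
checks hidden_states.shape[1] against w1.shape[2] - padding_size, so the zero
tail is carried in memory but never read by the matmul. Per the in-code
comment, the larger row stride is meant to reduce memory channel contention
on MI300x GPUs; the torch.cuda.empty_cache() calls release the unpadded
copies after each reassignment, and the whole path stays opt-in via
MOE_PADDING=1.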