[Performance, Hardware] MoE weights padding to AMD MI300x GPUs (#1836)
This commit is contained in:
@@ -14,6 +14,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -263,7 +264,7 @@ def invoke_fused_moe_kernel(
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.shape[1],
|
||||
B.shape[2],
|
||||
B.shape[2] - padding_size,
|
||||
sorted_token_ids.shape[0],
|
||||
topk_ids.numel(),
|
||||
A.stride(0),
|
||||
@@ -464,7 +465,7 @@ def fused_experts(
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
@@ -481,7 +482,7 @@ def fused_experts(
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
w1.shape,
|
||||
w2.shape,
|
||||
(w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
|
||||
topk_ids.shape[1],
|
||||
"float8" if use_fp8 else None,
|
||||
override_config=override_config,
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from sglang.srt.layers.fused_moe.fused_moe import padding_size
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -506,6 +509,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
|
||||
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
||||
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
return
|
||||
|
||||
# If checkpoint is fp8, we need to handle that the
|
||||
@@ -572,6 +588,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
start += shard_size
|
||||
|
||||
layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
|
||||
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
||||
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
return
|
||||
|
||||
def apply(
|
||||
|
||||
Reference in New Issue
Block a user