Fp8 MoE optimizations on AMD (#2388)

Author: HAI
Date: 2024-12-07 05:18:26 -08:00
Committed by: GitHub
Parent: aaac33fd8d
Commit: 95f93f493a
2 changed files with 97 additions and 22 deletions


@@ -1,9 +1,11 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
import logging
import os
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
@@ -24,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
@@ -420,7 +423,7 @@ class Fp8MoEMethod:
    def process_weights_after_loading(self, layer: Module) -> None:
-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
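The fp16/bf16 path quantizes the weights in place with per-tensor symmetric scaling. As a minimal sketch of what that amounts to (the standalone helper below is illustrative only; the commit's actual quantization goes through vLLM's custom ops imported as `ops` above):

```python
import torch

def quantize_fp8_per_tensor(w: torch.Tensor, fp8_dtype: torch.dtype):
    # Per-tensor symmetric quantization: the scale maps the weight's absmax
    # onto the format's max representable value
    # (448 for float8_e4m3fn, 240 for float8_e4m3fnuz).
    w = w.float()
    finfo = torch.finfo(fp8_dtype)
    scale = w.abs().max().clamp(min=1e-12) / finfo.max
    q = (w / scale).clamp(finfo.min, finfo.max).to(fp8_dtype)
    return q, scale
```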
@@ -444,6 +447,19 @@ class Fp8MoEMethod:
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            # If ROCm, apply weight padding (to reduce memory channel contention) only if MOE_PADDING is set
            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
                layer.w13_weight = torch.nn.Parameter(
                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
                    requires_grad=False,
                )
                torch.cuda.empty_cache()
                layer.w2_weight = torch.nn.Parameter(
                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
                    requires_grad=False,
                )
                torch.cuda.empty_cache()
            return
        # If checkpoint is fp8, we need to handle that the
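`F.pad` with a `(left, right)` pair pads only the last dimension, so each expert's weight matrix gains trailing zero columns. A small sketch of the shape effect, using a hypothetical `PAD` constant in place of `padding_size`:

```python
import torch
import torch.nn.functional as F

PAD = 128  # hypothetical stand-in for padding_size

# Expert weights laid out as (num_experts, rows, cols); padding the last
# dim appends PAD zero columns to every row of every expert.
w = torch.randn(8, 512, 4096)
w_padded = F.pad(w, (0, PAD), "constant", 0)
assert w_padded.shape == (8, 512, 4096 + PAD)
```

The padding is opt-in: it runs only on ROCm and only when `MOE_PADDING=1` is set in the environment, since the default of `os.getenv("MOE_PADDING", "0")` disables it.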
@@ -472,6 +488,7 @@ class Fp8MoEMethod:
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        # If ROCm, normalize the weights and scales to e4m3fnuz
        if is_hip():
            # Normalize the weights and scales
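e4m3fnuz (the fp8 variant used by MI300-class hardware) uses a different exponent bias than e4m3fn, so the same bit pattern decodes to half the value, and the e4m3fn encoding of `-0.0` is NaN in fnuz. A sketch of what a normalization step along the lines of vLLM's `normalize_e4m3fn_to_e4m3fnuz` helper would do, shown only to make the scale doubling concrete (the helper below is hypothetical):

```python
import torch

def to_e4m3fnuz(weight: torch.Tensor, scale: torch.Tensor):
    # The same bit pattern represents half the value in e4m3fnuz
    # (exponent bias 8 vs. 7), so reinterpret the bits and double the
    # scale to keep the dequantized values unchanged.
    bits = weight.view(torch.int8)  # in-place bit view of `weight`
    # 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz; map it to +0.0.
    bits[bits == -128] = 0
    return bits.view(torch.float8_e4m3fnuz), scale * 2.0
```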
@@ -523,6 +540,19 @@ class Fp8MoEMethod:
        layer.w13_weight_scale = torch.nn.Parameter(
            max_w13_scales, requires_grad=False
        )
        # If ROCm, apply weight padding (to reduce memory channel contention) only if MOE_PADDING is set
        if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
            layer.w13_weight = torch.nn.Parameter(
                F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
            layer.w2_weight = torch.nn.Parameter(
                F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
        return

    def apply(
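For context on `max_w13_scales` above: the fused MoE kernel wants a single w13 scale per expert, while fp8 checkpoints carry one scale per gate/up shard, so the loading code takes the per-expert max and requantizes each shard to it. A sketch of that requantization step under those assumptions (the helper name is hypothetical, not the commit's code):

```python
import torch

def requantize_shard(q_shard, old_scale, new_scale,
                     fp8_dtype=torch.float8_e4m3fn):
    # Dequantize with the shard's own scale, then requantize with the
    # shared per-expert max scale; values are preserved (up to rounding)
    # while both gate/up shards now agree on one scale.
    w = q_shard.to(torch.float32) * old_scale
    finfo = torch.finfo(fp8_dtype)
    return (w / new_scale).clamp(finfo.min, finfo.max).to(fp8_dtype)
```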
@@ -540,6 +570,7 @@ class Fp8MoEMethod:
        from sglang.srt.layers.fused_moe_triton import FusedMoE
        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts

        # Expert selection
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
@@ -551,6 +582,7 @@ class Fp8MoEMethod:
            custom_routing_function=custom_routing_function,
        )

        # Expert fusion with FP8 quantization
        return fused_experts(
            x,
            layer.w13_weight,
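`FusedMoE.select_experts` handles routing before the fused kernel runs. As a rough sketch of plain softmax top-k routing (the real implementation also supports grouped top-k and the `custom_routing_function` hook seen above; the function name here is illustrative):

```python
import torch

def naive_select_experts(router_logits: torch.Tensor, top_k: int):
    # Softmax over the expert dimension, keep the top-k experts per
    # token, and renormalize their weights to sum to one.
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```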