Fp8 MoE optimizations on AMD (#2388)
@@ -1,9 +1,11 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
+import os
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
@@ -24,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
+from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -420,7 +423,7 @@ class Fp8MoEMethod:
 
     def process_weights_after_loading(self, layer: Module) -> None:
 
-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
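For the non-FP8 checkpoint path above, the AMD-specific decision is the FP8 storage format: MI300-class hardware natively supports float8_e4m3fnuz rather than float8_e4m3fn. The sketch below shows that dtype switch together with a plain per-tensor scaled cast; it is only a stand-in for the fused per-expert quantization the real code performs, and the helper name and the use_hip flag are illustrative assumptions.

import torch

def quantize_fp8(weight: torch.Tensor, use_hip: bool):
    # e4m3fnuz (no negative zero, exponent bias 8) is the FP8 format supported
    # natively on MI300-class hardware; e4m3fn is the CUDA-side format.
    fp8_dtype = torch.float8_e4m3fnuz if use_hip else torch.float8_e4m3fn
    finfo = torch.finfo(fp8_dtype)
    # Per-tensor scale so the largest magnitude maps onto the FP8 max value.
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    q = (weight / scale).clamp(finfo.min, finfo.max).to(fp8_dtype)
    return q, scale

w = torch.randn(1024, 512)
q, s = quantize_fp8(w, use_hip=False)
print(q.dtype, s.item())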
@@ -444,6 +447,19 @@ class Fp8MoEMethod:
                 )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
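The padding above is gated by the MOE_PADDING environment variable and only widens the last (contraction) dimension of each expert weight with zero columns, which is what reduces memory-channel contention on MI300x. A self-contained sketch of the same mechanics follows; the tensor shapes and the pad width of 128 are illustrative assumptions, not values read from the sglang source, where the actual amount is the imported padding_size.

import os

import torch
import torch.nn.functional as F

# Illustrative stand-ins; shapes and the pad amount are assumptions.
PADDING_SIZE = 128
num_experts, out_dim, in_dim = 8, 1024, 512
w13 = torch.randn(num_experts, out_dim, in_dim)

if bool(int(os.getenv("MOE_PADDING", "0"))):
    # F.pad with (0, PADDING_SIZE) pads only the last dimension:
    # (left, right) for dim -1, so in_dim grows by PADDING_SIZE zero columns.
    w13 = F.pad(w13, (0, PADDING_SIZE), "constant", 0)

print(w13.shape)  # (8, 1024, 512), or (8, 1024, 640) when MOE_PADDING=1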
@@ -472,6 +488,7 @@ class Fp8MoEMethod:
                 layer.w2_input_scale = torch.nn.Parameter(
                     layer.w2_input_scale.max(), requires_grad=False
                 )
 
+            # If ROCm, normalize the weights and scales to e4m3fnuz
             if is_hip():
                 # Normalize the weights and scales
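The normalization that this comment points at is needed because the two FP8 encodings assign different values to the same bits: e4m3fnuz uses an exponent bias of 8 instead of 7, so every bit pattern decodes to half its e4m3fn value, and the dequantization scale has to be doubled to compensate. A minimal sketch of that reinterpretation follows; it mirrors the idea behind the normalize_e4m3fn_to_e4m3fnuz helper but is not a copy of it.

import torch

def to_e4m3fnuz(weight_fn: torch.Tensor, scale: torch.Tensor):
    """Reinterpret e4m3fn-encoded weights as e4m3fnuz and fix up the scale.

    Illustrative sketch: the same bit pattern decodes to half the e4m3fn
    value under e4m3fnuz, so doubling the dequant scale keeps
    weight * scale unchanged.
    """
    bits = weight_fn.view(torch.int8)
    # 0x80 encodes -0.0 in e4m3fn but NaN in e4m3fnuz, so remap it to +0.0.
    bits = torch.where(bits == -128, torch.zeros_like(bits), bits)
    weight_fnuz = bits.view(torch.float8_e4m3fnuz)
    return weight_fnuz, scale * 2.0

w = torch.randn(4, 4).to(torch.float8_e4m3fn)
scale = torch.tensor(0.01)
w_fnuz, new_scale = to_e4m3fnuz(w, scale)
print(w_fnuz.dtype, new_scale.item())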
@@ -523,6 +540,19 @@ class Fp8MoEMethod:
             layer.w13_weight_scale = torch.nn.Parameter(
                 max_w13_scales, requires_grad=False
             )
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
 
     def apply(
@@ -540,6 +570,7 @@ class Fp8MoEMethod:
         from sglang.srt.layers.fused_moe_triton import FusedMoE
         from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
 
+        # Expert selection
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -551,6 +582,7 @@ class Fp8MoEMethod:
             custom_routing_function=custom_routing_function,
         )
 
+        # Expert fusion with FP8 quantization
         return fused_experts(
             x,
             layer.w13_weight,
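Taken together, the last two hunks label the two halves of the MoE forward pass: pick the top-k experts per token from the router logits, then let the fused FP8 kernel apply those experts and combine the results. The dense reference below sketches what that computation does, for intuition only; the function names, the SiLU-gated w13/w2 layout, and the renormalized top-k weights are assumptions about the model structure, not the FusedMoE.select_experts / fused_experts API.

import torch
import torch.nn.functional as F

def select_experts(router_logits: torch.Tensor, top_k: int):
    # Softmax over experts, keep the top-k per token, then renormalize.
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

def moe_reference(x, w13, w2, router_logits, top_k=2):
    # Dense, per-token reference of what the fused kernel computes.
    topk_weights, topk_ids = select_experts(router_logits, top_k)
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for k in range(top_k):
            e = topk_ids[t, k]
            gate_up = x[t] @ w13[e].t()      # w13 = fused [gate; up] projection
            gate, up = gate_up.chunk(2, dim=-1)
            act = F.silu(gate) * up
            out[t] += topk_weights[t, k] * (act @ w2[e].t())
    return out

tokens, hidden, experts, inter = 4, 64, 8, 128
x = torch.randn(tokens, hidden)
w13 = torch.randn(experts, 2 * inter, hidden)
w2 = torch.randn(experts, hidden, inter)
logits = torch.randn(tokens, experts)
print(moe_reference(x, w13, w2, logits).shape)  # (4, 64)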