Integrate ROCm ater package for CK MoE support (#2854)
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Lin, Soga <soga.lin@amd.com>
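This change gates an alternate MoE path behind the CK_MOE environment variable on ROCm: expert weights are pre-permuted at load time into the layout the CK kernels expect, and at call time fused_experts_ck from the ater package replaces the Triton fused_experts. A minimal sketch of the dispatch idea (not the PR's code), using torch.version.hip as a stand-in for sglang's is_hip():

import os

import torch


def moe_backend() -> str:
    # CK_MOE=1 on a ROCm build selects ater's CK kernels; anything else
    # keeps the default Triton path. Mirrors the gate used in the diff.
    on_rocm = torch.version.hip is not None
    if on_rocm and bool(int(os.getenv("CK_MOE", "0"))):
        return "ater_ck"
    return "triton"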
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
+import os
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -18,7 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_hip, permute_weight, set_weight_attrs
 
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -97,6 +98,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+        return
+
     def apply(
         self,
         layer: torch.nn.Module,
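The new process_weights_after_loading hook rebuilds each expert weight as a fresh Parameter holding the permuted copy, calling torch.cuda.empty_cache() after each one so the pre-permute allocation is released before the next weight is processed. The pattern in isolation, as a sketch with transform standing in for permute_weight:

import torch


def repack_param(layer: torch.nn.Module, name: str, transform) -> None:
    # Replace a weight with a transformed copy and drop the old allocation.
    old = getattr(layer, name)
    new = torch.nn.Parameter(transform(old.data), requires_grad=False)
    setattr(layer, name, new)
    torch.cuda.empty_cache()  # free the pre-transform copy before the next weight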
@@ -148,14 +163,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )
 
-        return fused_experts(
-            hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-        )
+        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+            )
+        else:
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+            )
 
     def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError("The CPU backend currently does not support MoE.")
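Note that the gate parses the variable with bool(int(...)), so only integer strings are valid: CK_MOE=1 enables the CK path, while CK_MOE=true would raise ValueError. For example:

import os

os.environ["CK_MOE"] = "1"
assert bool(int(os.getenv("CK_MOE", "0")))  # enabled

os.environ["CK_MOE"] = "0"
assert not bool(int(os.getenv("CK_MOE", "0")))  # disabled (the default)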
@@ -40,6 +40,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
+    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
@@ -616,18 +617,30 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            if is_hip():
+                if bool(int(os.getenv("CK_MOE", "0"))):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
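Folding both toggles under a single if is_hip(): makes them mutually exclusive: CK_MOE takes precedence, and MOE_PADDING weight padding now applies only when CK_MOE is off. The resulting precedence, sketched:

import os


def rocm_weight_prep() -> str:
    # Mirrors the elif chain above: CK permutation wins over padding.
    if bool(int(os.getenv("CK_MOE", "0"))):
        return "permute_for_ck"
    if bool(int(os.getenv("MOE_PADDING", "0"))):
        return "pad"
    return "as_is"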
@@ -708,18 +721,30 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )
 
-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            if is_hip():
+                if bool(int(os.getenv("CK_MOE", "0"))):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return
 
     def apply(
@@ -752,27 +777,55 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        # Expert fusion with FP8 quantization
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            use_fp8_w8a8=True,
-            w1_scale=(
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            ),
-            w2_scale=(
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            ),
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-        )
+        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+
+        else:
+            # Expert fusion with FP8 quantization
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
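In the FP8 path both branches plumb the same scales; the CK call differs only in passing no inplace flag and no block_shape, which stay Triton-only here. The shared scale selection, as a sketch with layer and block_quant as in the diff:

def pick_scales(layer, block_quant: bool):
    # Block-quantized checkpoints carry inverse per-block scales; per-tensor
    # checkpoints carry a single scale per weight.
    w1_scale = layer.w13_weight_scale_inv if block_quant else layer.w13_weight_scale
    w2_scale = layer.w2_weight_scale_inv if block_quant else layer.w2_weight_scale
    return w1_scale, w2_scale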
@@ -1340,6 +1340,25 @@ def parse_tool_response(text, tools, **kwargs):
     return text, call_info_list
 
 
+def permute_weight(x: torch.Tensor) -> torch.Tensor:
+    b_ = x.shape[0]
+    n_ = x.shape[1]
+    k_ = x.shape[2]
+
+    x_ = x
+    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
+    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
+    else:
+        return x_
+
+    x_ = x_.permute(0, 1, 3, 4, 2, 5)
+    x_ = x_.contiguous()
+    x_ = x_.view(*x.shape)
+    return x_
+
+
 class MultiprocessingSerializer:
     @staticmethod
     def serialize(obj):
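permute_weight only reorders elements: a bf16/fp16 tensor of shape [E, N, K] is viewed as [E, N/16, 16, K/32, 4, 8], transposed, and flattened back to the original shape, so N must be divisible by 16 and K by 32 (64 for fp8/int8), and unsupported dtypes pass through unchanged. A quick check, assuming permute_weight from the hunk above is in scope:

import torch

E, N, K = 8, 64, 128  # bf16 path: N % 16 == 0 and K % 32 == 0
w = torch.randn(E, N, K, dtype=torch.bfloat16)
w_ck = permute_weight(w)
assert w_ck.shape == w.shape  # layout changes, shape does not
assert w_ck.is_contiguous()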