Integrate ROCm ater package for ck moe function feasibility (#2854)

Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Lin, Soga <soga.lin@amd.com>
This commit is contained in:
kk
2025-01-13 16:23:07 +08:00
committed by GitHub
parent a18ab81ddd
commit e808c1df3e
4 changed files with 162 additions and 54 deletions

View File

@@ -40,6 +40,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
from sglang.srt.utils import (
get_bool_env_var,
is_hip,
permute_weight,
print_warning_once,
set_weight_attrs,
)
@@ -616,18 +617,30 @@ class Fp8MoEMethod:
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
# On ROCm, pad the weights (to minimize memory channel contention) only if the MOE_PADDING env var is set
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
if is_hip():
if bool(int(os.getenv("CK_MOE", "0"))):
layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
permute_weight(layer.w2_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
elif bool(int(os.getenv("MOE_PADDING", "0"))):
# On ROCm, pad the weights (to minimize memory channel contention) only if the MOE_PADDING env var is set
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
return
# If checkpoint is fp8, we need to handle that the
@@ -708,18 +721,30 @@ class Fp8MoEMethod:
max_w13_scales, requires_grad=False
)
# On ROCm, pad the weights (to minimize memory channel contention) only if the MOE_PADDING env var is set
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
if is_hip():
if bool(int(os.getenv("CK_MOE", "0"))):
layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
permute_weight(layer.w2_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
elif bool(int(os.getenv("MOE_PADDING", "0"))):
# On ROCm, pad the weights (to minimize memory channel contention) only if the MOE_PADDING env var is set
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
return
def apply(
@@ -752,27 +777,55 @@ class Fp8MoEMethod:
correction_bias=correction_bias,
)
# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
import ater
from ater.fused_moe import fused_experts_ck
return fused_experts_ck(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv
if self.block_quant
else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
else:
# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv
if self.block_quant
else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
class Fp8KVCacheMethod(BaseKVCacheMethod):