Integrate ROCm ater package for ck moe function feasibility (#2854)
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Lin, Soga <soga.lin@amd.com>
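This change gates a new CK (Composable Kernel) fused-MoE path from the ROCm ater package behind the CK_MOE environment variable; when it is unset, the existing Triton fused_experts path and the optional MOE_PADDING weight padding are kept. A minimal sketch of the gating this diff introduces, using only the is_hip helper and environment variables that appear below (the _env_flag helper and the two result variables are illustrative, not part of the change):

    import os

    from sglang.srt.utils import is_hip


    def _env_flag(name: str) -> bool:
        # "1" enables the flag; unset or "0" disables it, mirroring the
        # bool(int(os.getenv(name, "0"))) pattern used in the diff below.
        return bool(int(os.getenv(name, "0")))


    # On ROCm, CK_MOE takes priority; MOE_PADDING only applies when CK_MOE is off.
    # On non-HIP builds both flags are ignored.
    use_ck_moe = is_hip() and _env_flag("CK_MOE")
    use_moe_padding = is_hip() and not use_ck_moe and _env_flag("MOE_PADDING")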
@@ -40,6 +40,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
+    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
@@ -616,18 +617,30 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            if is_hip():
+                if bool(int(os.getenv("CK_MOE", "0"))):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
@@ -708,18 +721,30 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )
 
-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            if is_hip():
+                if bool(int(os.getenv("CK_MOE", "0"))):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return
 
     def apply(
@@ -752,27 +777,55 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        # Expert fusion with FP8 quantization
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            use_fp8_w8a8=True,
-            w1_scale=(
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            ),
-            w2_scale=(
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            ),
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-        )
+        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+
+        else:
+            # Expert fusion with FP8 quantization
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
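Because CK_MOE is read both when expert weights are processed after loading and again in apply(), the flag has to be set before the model is loaded. A hedged usage sketch, assuming a ROCm build of sglang with the ater package installed and using sglang's offline Engine API (the model path is a placeholder):

    import os

    # Opt into the ater CK fused-MoE kernels; weight permutation and the
    # fused_experts_ck path above are only taken when this is "1" on ROCm.
    os.environ["CK_MOE"] = "1"

    import sglang as sgl

    # Placeholder path; any FP8 MoE checkpoint supported by sglang would do.
    engine = sgl.Engine(model_path="path/to/fp8-moe-model")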