ROCm: update aiter and its usage to fused moe (bloat16, fp8, fp8 block-quant) (#4053)
This commit is contained in:
@@ -18,7 +18,7 @@ ARG TRITON_COMMIT="improve_fa_decode_3.0.0"
|
|||||||
|
|
||||||
|
|
||||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||||
ARG AITER_COMMIT="dev/testx"
|
ARG AITER_COMMIT="testx"
|
||||||
|
|
||||||
RUN git clone ${SGL_REPO} \
|
RUN git clone ${SGL_REPO} \
|
||||||
&& cd sglang \
|
&& cd sglang \
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ srt = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||||
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
|
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
|
||||||
srt_hip = ["sglang[runtime_common]", "sgl-kernel==0.0.3.post6", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
|
srt_hip = ["sglang[runtime_common]", "sgl-kernel==0.0.3.post6", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
|
||||||
|
|
||||||
# xpu is not enabled in public vllm and torch whl,
|
# xpu is not enabled in public vllm and torch whl,
|
||||||
|
|||||||
@@ -29,6 +29,9 @@ import logging
|
|||||||
|
|
||||||
is_hip_ = is_hip()
|
is_hip_ = is_hip()
|
||||||
|
|
||||||
|
if is_hip_:
|
||||||
|
from aiter import ck_moe
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -173,18 +176,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
if is_hip_ and get_bool_env_var("CK_MOE"):
|
||||||
import aiter
|
|
||||||
from aiter.fused_moe import fused_experts_ck
|
|
||||||
|
|
||||||
assert activation == "silu", f"{activation=} is not supported."
|
|
||||||
assert not no_combine, "unsupported"
|
assert not no_combine, "unsupported"
|
||||||
|
return ck_moe(
|
||||||
return fused_experts_ck(
|
x,
|
||||||
hidden_states=x,
|
layer.w13_weight,
|
||||||
w1=layer.w13_weight,
|
layer.w2_weight,
|
||||||
w2=layer.w2_weight,
|
topk_weights,
|
||||||
topk_weights=topk_weights,
|
topk_ids,
|
||||||
topk_ids=topk_ids,
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
32,
|
||||||
|
None,
|
||||||
|
activation,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
|
|||||||
@@ -51,6 +51,10 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
|
|||||||
|
|
||||||
is_hip_ = is_hip()
|
is_hip_ = is_hip()
|
||||||
|
|
||||||
|
if is_hip_:
|
||||||
|
from aiter.fused_moe_bf16_asm import asm_moe
|
||||||
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -533,6 +537,20 @@ class Fp8MoEMethod:
|
|||||||
)
|
)
|
||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
|
||||||
|
if is_hip_ and get_bool_env_var("CK_MOE"):
|
||||||
|
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
|
||||||
|
w13_weight_scale1 = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
w2_weight_scale1 = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, hidden_size, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale1", w13_weight_scale1)
|
||||||
|
layer.register_parameter("w2_weight_scale1", w2_weight_scale1)
|
||||||
|
|
||||||
# Add the quantization method used (per tensor/grouped/channel)
|
# Add the quantization method used (per tensor/grouped/channel)
|
||||||
# to ensure the weight scales are loaded in properly
|
# to ensure the weight scales are loaded in properly
|
||||||
extra_weight_attrs.update(
|
extra_weight_attrs.update(
|
||||||
@@ -602,6 +620,15 @@ class Fp8MoEMethod:
|
|||||||
w2_weight_scale, requires_grad=False
|
w2_weight_scale, requires_grad=False
|
||||||
)
|
)
|
||||||
layer.w2_input_scale = None
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
|
if get_bool_env_var("CK_MOE"):
|
||||||
|
# Pre-shuffle weights
|
||||||
|
layer.w13_weight.data = shuffle_weight(
|
||||||
|
layer.w13_weight.contiguous(), (16, 16)
|
||||||
|
)
|
||||||
|
layer.w2_weight.data = shuffle_weight(
|
||||||
|
layer.w2_weight.contiguous(), (16, 16)
|
||||||
|
)
|
||||||
return
|
return
|
||||||
# If checkpoint is fp16 or bfloat16, quantize in place.
|
# If checkpoint is fp16 or bfloat16, quantize in place.
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
@@ -640,6 +667,9 @@ class Fp8MoEMethod:
|
|||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
# ROCm (CK_MOE): using column-wise scaling
|
||||||
|
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
|
||||||
|
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
|
||||||
elif get_bool_env_var("MOE_PADDING"):
|
elif get_bool_env_var("MOE_PADDING"):
|
||||||
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
@@ -744,6 +774,9 @@ class Fp8MoEMethod:
|
|||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
# ROCm (CK_MOE): using column-wise scaling
|
||||||
|
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
|
||||||
|
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
|
||||||
elif get_bool_env_var("MOE_PADDING"):
|
elif get_bool_env_var("MOE_PADDING"):
|
||||||
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
@@ -790,34 +823,38 @@ class Fp8MoEMethod:
|
|||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu":
|
||||||
import aiter
|
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
|
||||||
from aiter.fused_moe import fused_experts_ck
|
|
||||||
|
|
||||||
assert activation == "silu", f"{activation=} is not supported."
|
|
||||||
assert not no_combine, f"{no_combine=} is not supported."
|
assert not no_combine, f"{no_combine=} is not supported."
|
||||||
|
if self.block_quant:
|
||||||
return fused_experts_ck(
|
return asm_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
topk_weights=topk_weights,
|
topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids,
|
||||||
use_fp8_w8a8=True,
|
layer.w13_weight_scale_inv,
|
||||||
w1_scale=(
|
layer.w2_weight_scale_inv,
|
||||||
layer.w13_weight_scale_inv
|
None,
|
||||||
if self.block_quant
|
None,
|
||||||
else layer.w13_weight_scale
|
False,
|
||||||
),
|
None,
|
||||||
w2_scale=(
|
block_shape=tuple(self.quant_config.weight_block_size),
|
||||||
layer.w2_weight_scale_inv
|
expert_mask=None,
|
||||||
if self.block_quant
|
)
|
||||||
else layer.w2_weight_scale
|
else:
|
||||||
),
|
return asm_moe(
|
||||||
a1_scale=layer.w13_input_scale,
|
x,
|
||||||
a2_scale=layer.w2_input_scale,
|
layer.w13_weight,
|
||||||
)
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
layer.w13_weight_scale1,
|
||||||
|
layer.w2_weight_scale1,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Expert fusion with FP8 quantization
|
# Expert fusion with FP8 quantization
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
|
|||||||
Reference in New Issue
Block a user