diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
index 300501c6f..c5e6e04a6 100644
--- a/docker/Dockerfile.rocm
+++ b/docker/Dockerfile.rocm
@@ -18,7 +18,7 @@
 ARG TRITON_COMMIT="improve_fa_decode_3.0.0"
 
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
-ARG AITER_COMMIT="dev/testx"
+ARG AITER_COMMIT="testx"
 
 RUN git clone ${SGL_REPO} \
     && cd sglang \
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 528524811..fb98fca12 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -51,7 +51,7 @@ srt = [
 ]
 
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
-# => base docker rocm/vllm-dev:20241022, not from public vllm whl
+# => base docker rocm/vllm-dev:20250114, not from public vllm whl
 srt_hip = ["sglang[runtime_common]", "sgl-kernel==0.0.3.post6", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
 
 # xpu is not enabled in public vllm and torch whl,
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index cf9a706b8..a365f8481 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -29,6 +29,9 @@
 import logging
 
 is_hip_ = is_hip()
 
+if is_hip_:
+    from aiter import ck_moe
+
 logger = logging.getLogger(__name__)
 
@@ -173,18 +176,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         )
 
         if is_hip_ and get_bool_env_var("CK_MOE"):
-            import aiter
-            from aiter.fused_moe import fused_experts_ck
-
-            assert activation == "silu", f"{activation=} is not supported."
             assert not no_combine, "unsupported"
-
-            return fused_experts_ck(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
+            return ck_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                None,
+                None,
+                None,
+                None,
+                32,
+                None,
+                activation,
             )
         else:
             return fused_experts(
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 2707026e8..c61adbdae 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -51,6 +51,10 @@
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 is_hip_ = is_hip()
 
+if is_hip_:
+    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter.ops.shuffle import shuffle_weight
+
 logger = logging.getLogger(__name__)
 
@@ -533,6 +537,20 @@ class Fp8MoEMethod:
             )
             layer.register_parameter("w13_weight_scale", w13_weight_scale)
             layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            if is_hip_ and get_bool_env_var("CK_MOE"):
+                # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
+                w13_weight_scale1 = torch.nn.Parameter(
+                    torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                w2_weight_scale1 = torch.nn.Parameter(
+                    torch.ones(num_experts, hidden_size, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                layer.register_parameter("w13_weight_scale1", w13_weight_scale1)
+                layer.register_parameter("w2_weight_scale1", w2_weight_scale1)
+
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
         extra_weight_attrs.update(
@@ -602,6 +620,15 @@ class Fp8MoEMethod:
                     w2_weight_scale, requires_grad=False
                 )
                 layer.w2_input_scale = None
+
+                if get_bool_env_var("CK_MOE"):
+                    # Pre-shuffle weights
+                    layer.w13_weight.data = shuffle_weight(
+                        layer.w13_weight.contiguous(), (16, 16)
+                    )
+                    layer.w2_weight.data = shuffle_weight(
+                        layer.w2_weight.contiguous(), (16, 16)
+                    )
             return
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -640,6 +667,9 @@ class Fp8MoEMethod:
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
+                # ROCm (CK_MOE): using column-wise scaling
+                layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+                layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
             elif get_bool_env_var("MOE_PADDING"):
                 # If ROCm, apply weight padding (min. Mem channel contention) only if set
                 layer.w13_weight = torch.nn.Parameter(
@@ -744,6 +774,9 @@ class Fp8MoEMethod:
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
+                # ROCm (CK_MOE): using column-wise scaling
+                layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+                layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
             elif get_bool_env_var("MOE_PADDING"):
                 # If ROCm, apply weight padding (min. Mem channel contention) only if set
                 layer.w13_weight = torch.nn.Parameter(
@@ -790,34 +823,38 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        if is_hip_ and get_bool_env_var("CK_MOE"):
-            import aiter
-            from aiter.fused_moe import fused_experts_ck
-
-            assert activation == "silu", f"{activation=} is not supported."
+        if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu":
+            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
             assert not no_combine, f"{no_combine=} is not supported."
-
-            return fused_experts_ck(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                use_fp8_w8a8=True,
-                w1_scale=(
-                    layer.w13_weight_scale_inv
-                    if self.block_quant
-                    else layer.w13_weight_scale
-                ),
-                w2_scale=(
-                    layer.w2_weight_scale_inv
-                    if self.block_quant
-                    else layer.w2_weight_scale
-                ),
-                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale,
-            )
-
+            if self.block_quant:
+                return asm_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    layer.w13_weight_scale_inv,
+                    layer.w2_weight_scale_inv,
+                    None,
+                    None,
+                    False,
+                    None,
+                    block_shape=tuple(self.quant_config.weight_block_size),
+                    expert_mask=None,
+                )
+            else:
+                return asm_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    layer.w13_weight_scale1,
+                    layer.w2_weight_scale1,
+                    None,
+                    None,
+                    False,
+                )
         else:
             # Expert fusion with FP8 quantization
            return fused_experts(
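
Note on the column-wise scale parameters added in fp8.py: the sketch below is not part of the patch. The tensor sizes and the environment-variable check are illustrative stand-ins (the real code uses sglang's get_bool_env_var and model-config dimensions); it only shows how a per-tensor FP8 weight scale is broadcast into the per-column w13_weight_scale1 / w2_weight_scale1 buffers that the aiter asm_moe path consumes in the non-block-quant case.

import os

import torch

# Illustrative sizes only; real values come from the model config.
num_experts, intermediate_size, hidden_size = 8, 64, 32

# Per-tensor FP8 weight scales (one scalar per expert), as after weight loading.
w13_weight_scale = torch.rand(num_experts, dtype=torch.float32)
w2_weight_scale = torch.rand(num_experts, dtype=torch.float32)

# Column-wise scale buffers created in create_weights(), initialised to ones.
w13_weight_scale1 = torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32)
w2_weight_scale1 = torch.ones(num_experts, hidden_size, dtype=torch.float32)

# Stand-in for get_bool_env_var("CK_MOE"): the aiter path is only taken when enabled.
if os.getenv("CK_MOE", "0").lower() in ("1", "true"):
    # Broadcast each expert's scalar scale across all of its output columns,
    # mirroring `layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)`.
    w13_weight_scale1 *= w13_weight_scale.unsqueeze(-1)
    w2_weight_scale1 *= w2_weight_scale.unsqueeze(-1)

print(w13_weight_scale1.shape, w2_weight_scale1.shape)  # (8, 128), (8, 32)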