diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
index 7e6ae193a..2ad62d2d4 100644
--- a/docker/Dockerfile.rocm
+++ b/docker/Dockerfile.rocm
@@ -16,6 +16,10 @@ ARG SGL_BRANCH=${SGL_DEFAULT}
 
 ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
 ARG TRITON_COMMIT="845d75a"
+
+ARG ATER_REPO="https://github.com/HaiShaw/ater"
+ARG CK_COMMITS="fa05ae"
+
 RUN git clone ${SGL_REPO} \
     && cd sglang \
     && if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \
@@ -46,6 +50,11 @@ RUN git clone ${TRITON_REPO} \
     && cd python \
     && python3 setup.py install
 
+RUN git clone ${ATER_REPO} \
+    && cd ater \
+    && git submodule update --init --recursive \
+    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop
+
 # Performance environment variable.
 ENV HIP_FORCE_DEV_KERNARG=1
 
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 8d0b7035e..e1064bcda 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
+import os
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -18,7 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_hip, permute_weight, set_weight_attrs
 
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -97,6 +98,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+        return
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -148,14 +163,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )
 
-        return fused_experts(
-            hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-        )
+        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+            )
+        else:
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+            )
 
     def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError("The CPU backend currently does not support MoE.")
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index f9e4a8a4f..22a43675b 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -40,6 +40,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
+    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
@@ -616,18 +617,30 @@
         layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
         layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
-        # If ROCm, apply weight padding (min. Mem channel contention) only if set
-        if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-            layer.w13_weight = torch.nn.Parameter(
-                F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
-            layer.w2_weight = torch.nn.Parameter(
-                F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
+        if is_hip():
+            if bool(int(os.getenv("CK_MOE", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    permute_weight(layer.w13_weight.data),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    permute_weight(layer.w2_weight.data),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+            elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
@@ -708,18 +721,30 @@
             max_w13_scales, requires_grad=False
         )
 
-        # If ROCm, apply weight padding (min. Mem channel contention) only if set
-        if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-            layer.w13_weight = torch.nn.Parameter(
-                F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
-            layer.w2_weight = torch.nn.Parameter(
-                F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
+        if is_hip():
+            if bool(int(os.getenv("CK_MOE", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    permute_weight(layer.w13_weight.data),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    permute_weight(layer.w2_weight.data),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+            elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
         return
 
     def apply(
@@ -752,27 +777,55 @@
             correction_bias=correction_bias,
         )
 
-        # Expert fusion with FP8 quantization
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            use_fp8_w8a8=True,
-            w1_scale=(
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            ),
-            w2_scale=(
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            ),
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-        )
+        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+
+        else:
+            # Expert fusion with FP8 quantization
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index af9bdd60b..51ca91a96 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1340,6 +1340,25 @@ def parse_tool_response(text, tools, **kwargs):
     return text, call_info_list
 
 
+def permute_weight(x: torch.Tensor) -> torch.Tensor:
+    b_ = x.shape[0]
+    n_ = x.shape[1]
+    k_ = x.shape[2]
+
+    x_ = x
+    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
+    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
+    else:
+        return x_
+
+    x_ = x_.permute(0, 1, 3, 4, 2, 5)
+    x_ = x_.contiguous()
+    x_ = x_.view(*x.shape)
+    return x_
+
+
 class MultiprocessingSerializer:
     @staticmethod
     def serialize(obj):
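
Note on enabling the new path (commentary, not part of the patch): every CK-specific branch above is gated by the same runtime condition. A hypothetical helper, written out here only to make the dispatch rule explicit — the patch itself inlines this expression at all four sites:

    import os

    from sglang.srt.utils import is_hip


    def use_ck_moe() -> bool:
        # ATER's CK fused-MoE kernels are used only on ROCm (is_hip()) and
        # only when the user opts in with CK_MOE=1; any other value, or an
        # unset variable, falls back to the existing Triton fused_experts
        # path, which this patch leaves unchanged.
        return is_hip() and bool(int(os.getenv("CK_MOE", "0")))

In the fp8 path, CK_MOE also takes precedence over MOE_PADDING: the new if/elif chain makes the two weight transforms mutually exclusive, so a weight tensor is either permuted for CK or padded for the Triton kernels, never both.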
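The permute_weight helper added to utils.py is a pure relayout: for bf16/fp16 it views each expert's (N, K) matrix as 16-row by 32-column blocks (16 by 64 for fp8/int8), reorders elements within those blocks into the layout the prebuilt CK kernels read, and returns a tensor of the original shape; any other dtype is passed through untouched. A minimal sanity-check sketch, assuming the patch is installed — the sizes are invented for illustration, and since the relayout is ordinary tensor arithmetic it also runs on CPU:

    import torch

    from sglang.srt.utils import permute_weight

    # Invented sizes: E experts, N rows, K cols, with N % 16 == 0 and
    # K % 32 == 0 as the bf16/fp16 branch requires.
    E, N, K = 4, 64, 128
    w = torch.randn(E, N, K, dtype=torch.bfloat16)
    w_ck = permute_weight(w)

    # A relayout never resizes the tensor...
    assert w_ck.shape == w.shape
    # ...and never changes the multiset of values, only their order.
    assert torch.equal(
        w_ck.float().flatten().sort().values,
        w.float().flatten().sort().values,
    )

Because the permutation happens once in process_weights_after_loading, the per-token forward pass pays nothing for the relayout.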