Integrate ROCm ater package for ck moe function feasibility (#2854)
Co-authored-by: wunhuang <wunhuang@amd.com> Co-authored-by: Lin, Soga <soga.lin@amd.com>
This commit is contained in:
@@ -16,6 +16,10 @@ ARG SGL_BRANCH=${SGL_DEFAULT}
|
|||||||
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
|
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
|
||||||
ARG TRITON_COMMIT="845d75a"
|
ARG TRITON_COMMIT="845d75a"
|
||||||
|
|
||||||
|
|
||||||
|
ARG ATER_REPO="https://github.com/HaiShaw/ater"
|
||||||
|
ARG CK_COMMITS="fa05ae"
|
||||||
|
|
||||||
RUN git clone ${SGL_REPO} \
|
RUN git clone ${SGL_REPO} \
|
||||||
&& cd sglang \
|
&& cd sglang \
|
||||||
&& if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \
|
&& if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \
|
||||||
@@ -46,6 +50,11 @@ RUN git clone ${TRITON_REPO} \
|
|||||||
&& cd python \
|
&& cd python \
|
||||||
&& python3 setup.py install
|
&& python3 setup.py install
|
||||||
|
|
||||||
|
RUN git clone ${ATER_REPO} \
|
||||||
|
&& cd ater \
|
||||||
|
&& git submodule update --init --recursive \
|
||||||
|
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop
|
||||||
|
|
||||||
# Performance environment variable.
|
# Performance environment variable.
|
||||||
|
|
||||||
ENV HIP_FORCE_DEV_KERNARG=1
|
ENV HIP_FORCE_DEV_KERNARG=1
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
||||||
|
|
||||||
|
import os
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple
|
||||||
@@ -18,7 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import is_hip, permute_weight, set_weight_attrs
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
@@ -97,6 +98,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
layer.register_parameter("w2_weight", w2_weight)
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
|
||||||
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
|
permute_weight(layer.w13_weight.data),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
|
permute_weight(layer.w2_weight.data),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -148,14 +163,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
return fused_experts(
|
if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
|
||||||
hidden_states=x,
|
import ater
|
||||||
w1=layer.w13_weight,
|
from ater.fused_moe import fused_experts_ck
|
||||||
w2=layer.w2_weight,
|
|
||||||
topk_weights=topk_weights,
|
return fused_experts_ck(
|
||||||
topk_ids=topk_ids,
|
hidden_states=x,
|
||||||
inplace=True,
|
w1=layer.w13_weight,
|
||||||
)
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return fused_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=True,
|
||||||
|
)
|
||||||
|
|
||||||
def forward_cpu(self, *args, **kwargs):
|
def forward_cpu(self, *args, **kwargs):
|
||||||
raise NotImplementedError("The CPU backend currently does not support MoE.")
|
raise NotImplementedError("The CPU backend currently does not support MoE.")
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
is_hip,
|
is_hip,
|
||||||
|
permute_weight,
|
||||||
print_warning_once,
|
print_warning_once,
|
||||||
set_weight_attrs,
|
set_weight_attrs,
|
||||||
)
|
)
|
||||||
@@ -616,18 +617,30 @@ class Fp8MoEMethod:
|
|||||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||||
|
|
||||||
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
if is_hip():
|
||||||
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
|
if bool(int(os.getenv("CK_MOE", "0"))):
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
|
permute_weight(layer.w13_weight.data),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
layer.w2_weight = torch.nn.Parameter(
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
|
permute_weight(layer.w2_weight.data),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
elif bool(int(os.getenv("MOE_PADDING", "0"))):
|
||||||
|
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
||||||
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
|
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
|
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
return
|
return
|
||||||
|
|
||||||
# If checkpoint is fp8, we need to handle that the
|
# If checkpoint is fp8, we need to handle that the
|
||||||
@@ -708,18 +721,30 @@ class Fp8MoEMethod:
|
|||||||
max_w13_scales, requires_grad=False
|
max_w13_scales, requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
if is_hip():
|
||||||
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
|
if bool(int(os.getenv("CK_MOE", "0"))):
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
|
permute_weight(layer.w13_weight.data),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
layer.w2_weight = torch.nn.Parameter(
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
|
permute_weight(layer.w2_weight.data),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
elif bool(int(os.getenv("MOE_PADDING", "0"))):
|
||||||
|
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
||||||
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
|
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
|
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
return
|
return
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
@@ -752,27 +777,55 @@ class Fp8MoEMethod:
|
|||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Expert fusion with FP8 quantization
|
if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
|
||||||
return fused_experts(
|
import ater
|
||||||
x,
|
from ater.fused_moe import fused_experts_ck
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
return fused_experts_ck(
|
||||||
topk_weights=topk_weights,
|
x,
|
||||||
topk_ids=topk_ids,
|
layer.w13_weight,
|
||||||
inplace=True,
|
layer.w2_weight,
|
||||||
use_fp8_w8a8=True,
|
topk_weights=topk_weights,
|
||||||
w1_scale=(
|
topk_ids=topk_ids,
|
||||||
layer.w13_weight_scale_inv
|
use_fp8_w8a8=True,
|
||||||
if self.block_quant
|
w1_scale=(
|
||||||
else layer.w13_weight_scale
|
layer.w13_weight_scale_inv
|
||||||
),
|
if self.block_quant
|
||||||
w2_scale=(
|
else layer.w13_weight_scale
|
||||||
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
|
),
|
||||||
),
|
w2_scale=(
|
||||||
a1_scale=layer.w13_input_scale,
|
layer.w2_weight_scale_inv
|
||||||
a2_scale=layer.w2_input_scale,
|
if self.block_quant
|
||||||
block_shape=self.quant_config.weight_block_size,
|
else layer.w2_weight_scale
|
||||||
)
|
),
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Expert fusion with FP8 quantization
|
||||||
|
return fused_experts(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=True,
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
w1_scale=(
|
||||||
|
layer.w13_weight_scale_inv
|
||||||
|
if self.block_quant
|
||||||
|
else layer.w13_weight_scale
|
||||||
|
),
|
||||||
|
w2_scale=(
|
||||||
|
layer.w2_weight_scale_inv
|
||||||
|
if self.block_quant
|
||||||
|
else layer.w2_weight_scale
|
||||||
|
),
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
block_shape=self.quant_config.weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||||
|
|||||||
@@ -1340,6 +1340,25 @@ def parse_tool_response(text, tools, **kwargs):
|
|||||||
return text, call_info_list
|
return text, call_info_list
|
||||||
|
|
||||||
|
|
||||||
|
def permute_weight(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
b_ = x.shape[0]
|
||||||
|
n_ = x.shape[1]
|
||||||
|
k_ = x.shape[2]
|
||||||
|
|
||||||
|
x_ = x
|
||||||
|
if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
|
||||||
|
x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
|
||||||
|
elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
|
||||||
|
x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
|
||||||
|
else:
|
||||||
|
return x_
|
||||||
|
|
||||||
|
x_ = x_.permute(0, 1, 3, 4, 2, 5)
|
||||||
|
x_ = x_.contiguous()
|
||||||
|
x_ = x_.view(*x.shape)
|
||||||
|
return x_
|
||||||
|
|
||||||
|
|
||||||
class MultiprocessingSerializer:
|
class MultiprocessingSerializer:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def serialize(obj):
|
def serialize(obj):
|
||||||
|
|||||||
Reference in New Issue
Block a user