ROCm: update AITER (#5816)
.github/workflows/pr-test-amd.yml (vendored, 12 lines changed)
@@ -38,12 +38,12 @@ jobs:
 else
 DEVICE_FLAG="--device /dev/dri"
 fi
-docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
 docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
 -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
 --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
 -w /sglang-checkout --name ci_sglang \
-lmsysorg/sglang:v0.4.5.post3-rocm630
+ghcr.io/saienduri/sglang-aiter-v0.1.1:428

 - name: Install dependencies
 run: |

@@ -82,12 +82,12 @@ jobs:
 else
 DEVICE_FLAG="--device /dev/dri"
 fi
-docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
 docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
 -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
 --cap-add=SYS_PTRACE -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} --security-opt seccomp=unconfined \
 -w /sglang-checkout --name ci_sglang \
-lmsysorg/sglang:v0.4.5.post3-rocm630
+ghcr.io/saienduri/sglang-aiter-v0.1.1:428

 - name: Install dependencies
 run: |

@@ -120,12 +120,12 @@ jobs:
 else
 DEVICE_FLAG="--device /dev/dri"
 fi
-docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
 docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
 -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
 --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
 -w /sglang-checkout --name ci_sglang \
-lmsysorg/sglang:v0.4.5.post3-rocm630
+ghcr.io/saienduri/sglang-aiter-v0.1.1:428

 - name: Install dependencies
 run: |
3rdparty/amd/tuning/benchmark_moe_rocm.py (vendored, 2 lines changed)
@@ -15,7 +15,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
 get_config_file_name,
 )

-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0


 def main(model, tp_size, dtype: str, batches):
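For context, the renamed flag keeps the same semantics; a minimal sketch of how it is read (the padding line itself is copied from the hunk above, only the import is added):

    import os

    # SGLANG_MOE_PADDING=1 selects a 128-element pad for MoE weights; unset or "0" disables it.
    padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0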
@@ -18,7 +18,7 @@ ARG TRITON_COMMIT="improve_fa_decode_3.0.0"


 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
-ARG AITER_COMMIT="testx"
+ARG AITER_COMMIT="v0.1.1"

 RUN git clone ${SGL_REPO} \
 && cd sglang \

@@ -74,7 +74,7 @@ ENV SGLANG_SET_CPU_AFFINITY=1
 ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
 ENV NCCL_MIN_NCHANNELS=112

-ENV MOE_PADDING=1
+ENV SGLANG_MOE_PADDING=1
 ENV VLLM_FP8_PADDING=1
 ENV VLLM_FP8_ACT_PADDING=1
 ENV VLLM_FP8_WEIGHT_PADDING=1
@@ -45,7 +45,7 @@ if _is_cuda or _is_hip:


 logger = logging.getLogger(__name__)
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
 enable_moe_align_block_size_triton = bool(
 int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )

@@ -1327,7 +1327,7 @@ def fused_experts_impl(
 if (
 not (use_fp8_w8a8 or use_int8_w8a8)
 or block_shape is not None
-or (_is_hip and get_bool_env_var("CK_MOE"))
+or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE"))
 ):
 padded_size = 0
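The renamed switches are still read through sglang's get_bool_env_var helper; a minimal sketch of that style of check, written from scratch here for illustration and assuming truthy strings such as "1" or "true" enable a flag:

    import os

    def get_bool_env_var(name: str, default: str = "false") -> bool:
        # Hypothetical re-implementation for illustration: "true" or "1" (any case) counts as enabled.
        return os.getenv(name, default).strip().lower() in ("true", "1")

    # On ROCm, this now gates the aiter fast path instead of the old CK_MOE name.
    use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")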
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
 QuantizationConfig,
 QuantizeMethodBase,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs

 if torch.cuda.is_available():
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -30,7 +30,9 @@ import logging
 _is_hip = is_hip()

 if _is_hip:
-from aiter import ck_moe
+from aiter import ActivationType
+from aiter.fused_moe_bf16_asm import ck_moe_2stages
+from aiter.ops.shuffle import shuffle_weight

 logger = logging.getLogger(__name__)


@@ -102,14 +104,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 set_weight_attrs(w2_weight, extra_weight_attrs)

 def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-if _is_hip and get_bool_env_var("CK_MOE"):
+if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
 layer.w13_weight = torch.nn.Parameter(
-permute_weight(layer.w13_weight.data),
+shuffle_weight(layer.w13_weight.data, (16, 16)),
 requires_grad=False,
 )
 torch.cuda.empty_cache()
 layer.w2_weight = torch.nn.Parameter(
-permute_weight(layer.w2_weight.data),
+shuffle_weight(layer.w2_weight.data, (16, 16)),
 requires_grad=False,
 )
 torch.cuda.empty_cache()
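A condensed sketch of the new weight-preparation step on ROCm, assuming the aiter package is installed; the helper name preshuffle_moe_weights is illustrative only, while the shuffle_weight call and (16, 16) tile shape are taken from the hunk above:

    import torch
    from aiter.ops.shuffle import shuffle_weight  # ROCm-only dependency

    def preshuffle_moe_weights(layer: torch.nn.Module) -> None:
        # Re-wrap both expert weight tensors in the tile layout aiter's kernels expect.
        for name in ("w13_weight", "w2_weight"):
            weight = getattr(layer, name)
            setattr(layer, name, torch.nn.Parameter(
                shuffle_weight(weight.data, (16, 16)), requires_grad=False
            ))
            torch.cuda.empty_cache()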
@@ -182,21 +184,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 routed_scaling_factor=routed_scaling_factor,
 )

-if _is_hip and get_bool_env_var("CK_MOE"):
+if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
 assert not no_combine, "unsupported"
-return ck_moe(
+return ck_moe_2stages(
 x,
 layer.w13_weight,
 layer.w2_weight,
 topk_weights,
 topk_ids,
-None,
-None,
-None,
-None,
-32,
-None,
-activation,
+activation=(
+ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+),
 )
 else:
 return fused_experts(
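For reference, the reworked unquantized dispatch reads roughly as below; the call shape is copied from the hunk, while the wrapper function run_unquantized_moe and its parameter names are illustrative only:

    from aiter import ActivationType
    from aiter.fused_moe_bf16_asm import ck_moe_2stages

    def run_unquantized_moe(x, w13, w2, topk_weights, topk_ids, activation="silu"):
        # ck_moe_2stages takes an explicit ActivationType keyword instead of
        # ck_moe's long positional tail of None placeholders.
        return ck_moe_2stages(
            x,
            w13,
            w2,
            topk_weights,
            topk_ids,
            activation=(
                ActivationType.Silu if activation == "silu" else ActivationType.Gelu
            ),
        )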
@@ -527,7 +525,7 @@ class FusedMoE(torch.nn.Module):
 # Case input scale: input_scale loading is only supported for fp8
 if "input_scale" in weight_name:
 # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
-if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
 loaded_weight = loaded_weight * 2.0

 # this is needed for compressed-tensors only

@@ -569,7 +567,7 @@ class FusedMoE(torch.nn.Module):
 quant_method = getattr(param, "quant_method", None)
 if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
 # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
-if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
 loaded_weight = loaded_weight * 0.5

 self._load_per_channel_weight_scale(

@@ -592,7 +590,7 @@ class FusedMoE(torch.nn.Module):
 )
 elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
 # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
-if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
 loaded_weight = loaded_weight * 2.0

 self._load_per_tensor_weight_scale(
@@ -72,8 +72,8 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()

 if _is_hip:
-from aiter import ActivationType
-from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
+from aiter import ActivationType, QuantType
+from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
 from aiter.ops.shuffle import shuffle_weight

 if not _is_cuda:

@@ -484,7 +484,7 @@ class Fp8MoEMethod:
 if self.quant_config.is_checkpoint_fp8_serialized:
 params_dtype = (
 torch.uint32
-if get_bool_env_var("USE_INT4_WEIGHT")
+if get_bool_env_var("SGLANG_INT4_WEIGHT")
 else torch.float8_e4m3fn
 )
 tp_size = get_tensor_model_parallel_world_size()

@@ -511,7 +511,7 @@ class Fp8MoEMethod:
 )

 # WEIGHTS
-if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
 # INT4 MoE weight - INT32 packed
 w13_weight = torch.nn.Parameter(
 torch.empty(

@@ -585,7 +585,7 @@ class Fp8MoEMethod:

 if (
 _is_hip
-): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
+): # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
 # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
 w13_weight_scale1 = torch.nn.Parameter(
 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),

@@ -612,7 +612,7 @@ class Fp8MoEMethod:
 set_weight_attrs(w13_weight_scale, extra_weight_attrs)
 set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
 extra_weight_attrs.update(
 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
 )

@@ -644,7 +644,7 @@ class Fp8MoEMethod:
 layer.w2_input_scale = None

 def process_weights_after_loading(self, layer: Module) -> None:
-if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
 self.process_weights_hip_int4(layer)
 return
@@ -675,7 +675,7 @@ class Fp8MoEMethod:
 )
 layer.w2_input_scale = None

-if get_bool_env_var("CK_MOE"):
+if get_bool_env_var("SGLANG_AITER_MOE"):
 # Pre-shuffle weights
 layer.w13_weight.data = shuffle_weight(
 layer.w13_weight.contiguous(), (16, 16)

@@ -798,17 +798,15 @@ class Fp8MoEMethod:
 return

 def process_weights_hip_int4(self, layer: Module):
-# TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+# TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
 # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
 # Weight Permutation
 layer.w13_weight = torch.nn.Parameter(
-# permute_weight(layer.w13_weight.data),
 shuffle_weight(layer.w13_weight.data, (16, 16)),
 requires_grad=False,
 )
 torch.cuda.empty_cache()
 layer.w2_weight = torch.nn.Parameter(
-# permute_weight(layer.w2_weight.data),
 shuffle_weight(layer.w2_weight.data, (16, 16)),
 requires_grad=False,
 )

@@ -847,23 +845,21 @@ class Fp8MoEMethod:
 padding_size, # Avoid circular import
 )

-if get_bool_env_var("CK_MOE"):
+if get_bool_env_var("SGLANG_AITER_MOE"):
 layer.w13_weight = torch.nn.Parameter(
-# permute_weight(layer.w13_weight.data),
 shuffle_weight(layer.w13_weight.data, (16, 16)),
 requires_grad=False,
 )
 torch.cuda.empty_cache()
 layer.w2_weight = torch.nn.Parameter(
-# permute_weight(layer.w2_weight.data),
 shuffle_weight(layer.w2_weight.data, (16, 16)),
 requires_grad=False,
 )
 torch.cuda.empty_cache()
-# ROCm (CK_MOE): using column-wise scaling
+# ROCm (SGLANG_AITER_MOE): using column-wise scaling
 layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
 layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-elif get_bool_env_var("MOE_PADDING"):
+elif get_bool_env_var("SGLANG_MOE_PADDING"):
 # If ROCm, apply weight padding (min. Mem channel contention) only if set
 layer.w13_weight = torch.nn.Parameter(
 F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -912,15 +908,16 @@ class Fp8MoEMethod:
 )

 if _is_hip:
-if get_bool_env_var("USE_INT4_WEIGHT"):
-# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+if get_bool_env_var("SGLANG_INT4_WEIGHT"):
+# TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
 assert not no_combine, f"{no_combine=} is not supported."
-return ck_moe_2stages_win4(
+return ck_moe_2stages(
 x,
 layer.w13_weight,
 layer.w2_weight,
 topk_weights,
 topk_ids,
+QuantType.per_Token,
 layer.w13_weight_scale1,
 layer.w2_weight_scale1,
 activation=(

@@ -930,13 +927,13 @@ class Fp8MoEMethod:
 ),
 )

-if get_bool_env_var("CK_MOE"):
+if get_bool_env_var("SGLANG_AITER_MOE"):
 assert not no_combine, f"{no_combine=} is not supported."
 if self.block_quant:
-# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+# TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
 assert (
 activation == "silu"
-), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
 return asm_moe(
 x,
 layer.w13_weight,

@@ -955,6 +952,7 @@ class Fp8MoEMethod:
 layer.w2_weight,
 topk_weights,
 topk_ids,
+QuantType.per_Token,
 layer.w13_weight_scale1,
 layer.w2_weight_scale1,
 activation=(
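The quantized paths now pass an explicit QuantType.per_Token ahead of the per-channel weight scales; a sketch of the resulting INT4/FP8 call shape, with the wrapper run_int4_fp8_moe and its parameters being illustrative only:

    from aiter import ActivationType, QuantType
    from aiter.fused_moe_bf16_asm import ck_moe_2stages

    def run_int4_fp8_moe(x, w13, w2, topk_weights, topk_ids,
                         w13_scale1, w2_scale1, activation="silu"):
        # QuantType.per_Token is passed explicitly, followed by the
        # column-wise (per-channel) weight scales, as in the hunks above.
        return ck_moe_2stages(
            x,
            w13,
            w2,
            topk_weights,
            topk_ids,
            QuantType.per_Token,
            w13_scale1,
            w2_scale1,
            activation=(
                ActivationType.Silu if activation == "silu" else ActivationType.Gelu
            ),
        )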
@@ -31,7 +31,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()

-if _is_hip and get_bool_env_var("CK_MOE"):
+if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
 from aiter import gemm_a8w8_blockscale

 if _is_cuda:

@@ -132,7 +132,7 @@ def apply_w8a8_block_fp8_linear(
 output = fp8_blockwise_scaled_mm(
 q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
 )
-elif _is_hip and get_bool_env_var("CK_MOE"):
+elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
 q_input, x_scale = per_token_group_quant_fp8(
 input_2d, block_size[1], column_major_scales=False
 )
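Taken together, the ROCm fast paths are now selected through SGLANG_-prefixed variables; a minimal sketch of opting in from Python, set before sglang modules are imported since some of these flags are read at import time:

    import os

    os.environ["SGLANG_AITER_MOE"] = "1"     # aiter fused-MoE kernels (was CK_MOE)
    os.environ["SGLANG_MOE_PADDING"] = "1"   # optional MoE weight padding (was MOE_PADDING)
    os.environ["SGLANG_INT4_WEIGHT"] = "1"   # INT4-weight / FP8-compute MoE (was USE_INT4_WEIGHT)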