Fix torch compile run (#7391)
Co-authored-by: wunhuang <wunhuang@amd.com> Co-authored-by: Sai Enduri <saimanas.enduri@amd.com>
This commit is contained in:
@@ -18,7 +18,7 @@ ARG TRITON_COMMIT="improve_fa_decode_3.0.0"
|
|||||||
|
|
||||||
|
|
||||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||||
ARG AITER_COMMIT="v0.1.2"
|
ARG AITER_COMMIT="v0.1.3"
|
||||||
|
|
||||||
RUN git clone ${SGL_REPO} \
|
RUN git clone ${SGL_REPO} \
|
||||||
&& cd sglang \
|
&& cd sglang \
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
|||||||
|
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
from aiter import ActivationType
|
from aiter import ActivationType
|
||||||
|
from aiter.fused_moe import fused_moe
|
||||||
from aiter.fused_moe_bf16_asm import ck_moe_2stages
|
from aiter.fused_moe_bf16_asm import ck_moe_2stages
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
|
||||||
@@ -204,7 +205,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_weights, dtype=torch.float32
|
topk_weights, dtype=torch.float32
|
||||||
) # topk_weights must be FP32 (float32)
|
) # topk_weights must be FP32 (float32)
|
||||||
|
|
||||||
return ck_moe_2stages(
|
return fused_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
|
|||||||
@@ -1052,15 +1052,15 @@ class Fp8MoEMethod:
|
|||||||
if _use_hip_int4:
|
if _use_hip_int4:
|
||||||
# TODO: add triton kernel and add check _use_aiter
|
# TODO: add triton kernel and add check _use_aiter
|
||||||
assert not no_combine, f"{no_combine=} is not supported."
|
assert not no_combine, f"{no_combine=} is not supported."
|
||||||
return ck_moe_2stages(
|
return fused_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
QuantType.per_Token,
|
quant_type=QuantType.per_Token,
|
||||||
layer.w13_weight_scale1,
|
w1_scale=layer.w13_weight_scale1,
|
||||||
layer.w2_weight_scale1,
|
w2_scale=layer.w2_weight_scale1,
|
||||||
activation=(
|
activation=(
|
||||||
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
||||||
),
|
),
|
||||||
@@ -1086,15 +1086,15 @@ class Fp8MoEMethod:
|
|||||||
expert_mask=None,
|
expert_mask=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return ck_moe_2stages(
|
return fused_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
QuantType.per_Token,
|
quant_type=QuantType.per_Token,
|
||||||
layer.w13_weight_scale1,
|
w1_scale=layer.w13_weight_scale1,
|
||||||
layer.w2_weight_scale1,
|
w2_scale=layer.w2_weight_scale1,
|
||||||
activation=(
|
activation=(
|
||||||
ActivationType.Silu
|
ActivationType.Silu
|
||||||
if activation == "silu"
|
if activation == "silu"
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# Pull the image
|
# Pull the image
|
||||||
IMAGE="lmsysorg/sglang:v0.4.6.post5-rocm630"
|
IMAGE="ghcr.io/saienduri/sglang:aiter-1.3"
|
||||||
echo "Pulling Docker image: $IMAGE"
|
echo "Pulling Docker image: $IMAGE"
|
||||||
docker pull "$IMAGE"
|
docker pull "$IMAGE"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user