Add Qwen3-30B-A3B-Thinking-2507 support on AMD GPUs. (#9456)
This commit is contained in:
@@ -49,13 +49,15 @@ if _is_cuda:
|
|||||||
elif _is_cpu and _is_cpu_amx_available:
|
elif _is_cpu and _is_cpu_amx_available:
|
||||||
pass
|
pass
|
||||||
elif _is_hip:
|
elif _is_hip:
|
||||||
from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul
|
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||||
|
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
try:
|
try:
|
||||||
from aiter import moe_sum
|
from aiter import moe_sum
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
||||||
|
else:
|
||||||
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
|
||||||
|
|
||||||
if _is_cuda or _is_hip:
|
if _is_cuda or _is_hip:
|
||||||
@@ -1537,7 +1539,7 @@ def fused_experts_impl(
|
|||||||
gemm1_alpha,
|
gemm1_alpha,
|
||||||
gemm1_limit,
|
gemm1_limit,
|
||||||
)
|
)
|
||||||
elif _is_cuda:
|
elif _is_cuda or _is_hip:
|
||||||
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||||
else:
|
else:
|
||||||
vllm_ops.silu_and_mul(
|
vllm_ops.silu_and_mul(
|
||||||
@@ -1546,7 +1548,7 @@ def fused_experts_impl(
|
|||||||
elif activation == "gelu":
|
elif activation == "gelu":
|
||||||
assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
|
assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
|
||||||
assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
|
assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
|
||||||
if _is_cuda:
|
if _is_cuda or _is_hip:
|
||||||
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||||
else:
|
else:
|
||||||
vllm_ops.gelu_and_mul(
|
vllm_ops.gelu_and_mul(
|
||||||
@@ -1619,10 +1621,19 @@ def fused_experts_impl(
|
|||||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
vllm_ops.moe_sum(
|
# According to micro-benchmark results, torch.compile performs better when the chunk holds only a small number of tokens.
|
||||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
if tokens_in_chunk <= 32:
|
||||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
moe_sum_reduce_torch_compile(
|
||||||
)
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
|
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
|
routed_scaling_factor,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
moe_sum_reduce_triton(
|
||||||
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
|
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
|
routed_scaling_factor,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
vllm_ops.moe_sum(
|
vllm_ops.moe_sum(
|
||||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
|
|||||||
Reference in New Issue
Block a user