Add Qwen3-30B-A3B-Thinking-2507 support on AMD GPUs. (#9456)
@@ -49,13 +49,15 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
-    from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+    from sgl_kernel import gelu_and_mul, silu_and_mul
 
     if _use_aiter:
         try:
             from aiter import moe_sum
         except ImportError:
             raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+    else:
+        from vllm import _custom_ops as vllm_ops
 
 
 if _is_cuda or _is_hip:
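After this hunk, the ROCm branch always takes its fused activation kernels from sgl_kernel, and vLLM's _custom_ops is only imported as the moe_sum fallback when aiter is disabled. A minimal standalone sketch of the resulting gating, assuming _use_aiter is derived from the SGLANG_USE_AITER environment variable named in the error message (that derivation is not shown in this hunk):

    import os

    # Assumption: the SGLANG_USE_AITER env var drives _use_aiter.
    _use_aiter = os.environ.get("SGLANG_USE_AITER", "false").lower() in ("1", "true")

    # Fused activations now always come from sgl_kernel on ROCm.
    from sgl_kernel import gelu_and_mul, silu_and_mul

    if _use_aiter:
        try:
            from aiter import moe_sum
        except ImportError:
            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
    else:
        # vLLM custom ops remain only as the moe_sum fallback.
        from vllm import _custom_ops as vllm_ops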
@@ -1537,7 +1539,7 @@ def fused_experts_impl(
                     gemm1_alpha,
                     gemm1_limit,
                 )
-            elif _is_cuda:
+            elif _is_cuda or _is_hip:
                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.silu_and_mul(
@@ -1546,7 +1548,7 @@ def fused_experts_impl(
         elif activation == "gelu":
             assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
             assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
-            if _is_cuda:
+            if _is_cuda or _is_hip:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.gelu_and_mul(
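Both activation hunks widen the same dispatch: the sgl_kernel fused kernels, previously CUDA-only, now also serve ROCm, with vLLM's ops kept as the fallback. For reference, a hedged PyTorch sketch of what the fused kernels compute (the exact GELU variant, erf versus tanh, is a kernel detail this diff does not pin down): the input of width N packs [gate, up] halves along the last dimension, and the output of width N // 2 lands in intermediate_cache2.

    import torch
    import torch.nn.functional as F

    def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
        # x packs [gate, up] along the last dim; the kernel fuses the SiLU
        # on the gate half with the elementwise product against the up half.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
        # Same layout, with GELU in place of SiLU.
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d]) * x[..., d:]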
@@ -1619,10 +1621,19 @@ def fused_experts_impl(
                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 )
             else:
-                vllm_ops.moe_sum(
-                    intermediate_cache3.view(*intermediate_cache3.shape),
-                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
-                )
+                # According to micro benchmark results, torch.compile can get better performance for small token.
+                if tokens_in_chunk <= 32:
+                    moe_sum_reduce_torch_compile(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
+                else:
+                    moe_sum_reduce_triton(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
         else:
             vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
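The new ROCm non-aiter path mirrors the CUDA path above it: the per-expert partial outputs in intermediate_cache3 are summed over the top-k dimension into the output slice and scaled, with the torch.compile variant chosen for chunks of at most 32 tokens per the in-diff benchmark note. A hedged reference for what both moe_sum_reduce variants compute (the [tokens, top-k, hidden] layout is inferred from the call sites, not spelled out in this diff):

    import torch

    def moe_sum_reduce_ref(
        x: torch.Tensor,    # inferred layout: [num_tokens, topk, hidden_dim]
        out: torch.Tensor,  # [num_tokens, hidden_dim], preallocated slice
        routed_scaling_factor: float,
    ) -> None:
        # Sum the top-k expert partial outputs per token, then apply the
        # routed scaling factor, writing in place into the output.
        out.copy_(x.sum(dim=1) * routed_scaling_factor)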