Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -87,6 +87,10 @@ def _rocm_aiter_fused_moe_impl(
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
num_local_tokens: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype | None = None,
|
||||
hidden_pad: int = 0,
|
||||
intermediate_pad: int = 0,
|
||||
bias1: torch.Tensor | None = None,
|
||||
bias2: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
from aiter import ActivationType, QuantType
|
||||
from aiter.fused_moe import fused_moe
|
||||
@@ -110,6 +114,10 @@ def _rocm_aiter_fused_moe_impl(
|
||||
a2_scale,
|
||||
num_local_tokens=num_local_tokens,
|
||||
dtype=output_dtype,
|
||||
hidden_pad=hidden_pad,
|
||||
intermediate_pad=intermediate_pad,
|
||||
bias1=bias1,
|
||||
bias2=bias2,
|
||||
)
|
||||
|
||||
|
||||
@@ -307,6 +315,28 @@ def _rocm_aiter_grouped_topk_fake(
|
||||
pass
|
||||
|
||||
|
||||
def _rocm_aiter_fused_topk_impl(
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    gate_up: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward a top-k routing request to AITER's ``fused_topk``.

    The four arguments are handed over positionally, unchanged and in the
    same order, and whatever ``fused_topk`` returns is returned as-is.

    NOTE(review): the meaning of the fourth positional slot is defined by
    aiter, not by this wrapper -- confirm that ``gate_up`` is what aiter
    expects there.

    Args:
        x: First tensor argument forwarded to ``fused_topk``.
        router_logits: Second tensor argument forwarded to ``fused_topk``.
        top_k: Third argument forwarded to ``fused_topk``.
        gate_up: Fourth argument forwarded to ``fused_topk``.

    Returns:
        The pair produced by ``fused_topk``: (topk_weights, topk_indices).
    """
    # Imported lazily because aiter may not always be available.
    from aiter.fused_moe import fused_topk as aiter_fused_topk

    # fused_topk returns (topk_weights, topk_indices)
    topk_result = aiter_fused_topk(x, router_logits, top_k, gate_up)
    return topk_result
|
||||
|
||||
|
||||
def _rocm_aiter_fused_topk_fake(
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
gate_up: bool,
|
||||
) -> None:
|
||||
# tuple[torch.Tensor, torch.Tensor]:
|
||||
pass
|
||||
|
||||
|
||||
# Cache whether aiter supports FP8 MLA parameters
|
||||
_AITER_MLA_SUPPORTS_FP8: bool | None = None
|
||||
|
||||
@@ -994,6 +1024,70 @@ class rocm_aiter_ops:
|
||||
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
|
||||
cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
|
||||
|
||||
@staticmethod
|
||||
def get_aiter_activation_type(activation_str: str):
    """
    Translate an activation name into aiter's ActivationType enum.

    The name is matched case-insensitively after surrounding whitespace is
    stripped. Recognised names: "no", "none", "silu", "gelu", "swiglu".

    Args:
        activation_str (str): Activation type as string.

    Returns:
        The matching aiter ActivationType member, or None when aiter cannot
        be imported, the argument is not a string, or the name is unknown.
    """
    # aiter may not always be available, so it is imported lazily.
    try:
        from aiter import ActivationType
    except ImportError:
        return None

    if isinstance(activation_str, str):
        by_name = {
            "none": ActivationType.No,
            "no": ActivationType.No,
            "silu": ActivationType.Silu,
            "gelu": ActivationType.Gelu,
            "swiglu": ActivationType.Swiglu,
        }
        normalized = activation_str.strip().lower()
        return by_name.get(normalized)

    return None
|
||||
|
||||
@staticmethod
|
||||
def get_aiter_quant_type(quant_type_str: str):
    """
    Translate a quantization name into aiter's QuantType enum.

    The name is matched case-insensitively after surrounding whitespace is
    stripped. Recognised names: "no", "per_tensor", "per_token", "per_1x32",
    "per_1x128", "per_128x128".

    Args:
        quant_type_str (str): Quantization type as string.

    Returns:
        The matching aiter QuantType member, or None when aiter cannot be
        imported, the argument is not a string, or the name is unknown.
    """
    # aiter may not always be available, so it is imported lazily.
    try:
        from aiter import QuantType
    except ImportError:
        return None

    if isinstance(quant_type_str, str):
        by_name = {
            "no": QuantType.No,
            "per_tensor": QuantType.per_Tensor,
            "per_token": QuantType.per_Token,
            "per_1x32": QuantType.per_1x32,
            "per_1x128": QuantType.per_1x128,
            "per_128x128": QuantType.per_128x128,
        }
        normalized = quant_type_str.strip().lower()
        return by_name.get(normalized)

    return None
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_enabled(cls) -> bool:
|
||||
@@ -1127,6 +1221,14 @@ class rocm_aiter_ops:
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_fused_topk",
|
||||
op_func=_rocm_aiter_fused_topk_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=_rocm_aiter_fused_topk_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_mla_decode_fwd",
|
||||
op_func=_rocm_aiter_mla_decode_fwd_impl,
|
||||
@@ -1360,6 +1462,10 @@ class rocm_aiter_ops:
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
num_local_tokens: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype | None = None,
|
||||
hidden_pad: int = 0,
|
||||
intermediate_pad: int = 0,
|
||||
bias1: torch.Tensor | None = None,
|
||||
bias2: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.rocm_aiter_fused_moe(
|
||||
hidden_states,
|
||||
@@ -1377,6 +1483,10 @@ class rocm_aiter_ops:
|
||||
a2_scale,
|
||||
num_local_tokens,
|
||||
output_dtype,
|
||||
hidden_pad,
|
||||
intermediate_pad,
|
||||
bias1,
|
||||
bias2,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -1481,6 +1591,15 @@ class rocm_aiter_ops:
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def fused_topk(
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    gate_up: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Invoke the registered ``rocm_aiter_fused_topk`` custom op.

    All arguments are passed positionally and unchanged to
    ``torch.ops.vllm.rocm_aiter_fused_topk``; its result is returned as-is.

    Returns:
        The two-tensor result of the custom op.
    """
    fused_topk_op = torch.ops.vllm.rocm_aiter_fused_topk
    return fused_topk_op(x, router_logits, top_k, gate_up)
|
||||
|
||||
@staticmethod
|
||||
def mla_decode_fwd(
|
||||
q: torch.Tensor,
|
||||
@@ -1701,6 +1820,47 @@ class rocm_aiter_ops:
|
||||
|
||||
return shuffle_weight(tensor, layout=layout)
|
||||
|
||||
@staticmethod
|
||||
def shuffle_weight_a16w4(
    tensor: "torch.Tensor",
    nLane: int,
    gate_up: bool,
) -> "torch.Tensor":
    """
    Shuffles the weight tensor into (A16W4) layout for AITER kernels.

    This is a thin wrapper: the three arguments are forwarded positionally,
    unchanged, to ``aiter.ops.shuffle.shuffle_weight_a16w4``.

    Args:
        tensor: The input weight tensor to be shuffled.
        nLane: Forwarded as the second positional argument to aiter.
            NOTE(review): presumably the lane count of the target layout --
            confirm against aiter.
        gate_up: Forwarded as the third positional argument to aiter.
            NOTE(review): presumably True for w13 and False for w2, as
            documented for shuffle_scale_a16w4 -- confirm against aiter.

    Returns:
        torch.Tensor: The shuffled tensor.
    """
    # Imported lazily because aiter may not always be available.
    from aiter.ops.shuffle import shuffle_weight_a16w4

    return shuffle_weight_a16w4(tensor, nLane, gate_up)
|
||||
|
||||
@staticmethod
|
||||
def shuffle_scale_a16w4(
    tensor: "torch.Tensor",
    num_experts: int,
    gate_up: bool,
) -> "torch.Tensor":
    """
    Rearrange a scale tensor into the (A16W4) layout used by AITER kernels.

    The three arguments are forwarded positionally, unchanged, to
    ``aiter.ops.shuffle.shuffle_scale_a16w4``.

    Args:
        tensor: The input scale tensor to be shuffled.
        num_experts: Number of experts, needed for reshaping logic.
        gate_up: Whether the scale is for w13 (True) or w2 (False).

    Returns:
        torch.Tensor: The shuffled scale tensor.
    """
    # Imported lazily because aiter may not always be available.
    from aiter.ops.shuffle import shuffle_scale_a16w4 as aiter_shuffle_scale

    shuffled_scale = aiter_shuffle_scale(tensor, num_experts, gate_up)
    return shuffled_scale
|
||||
|
||||
@staticmethod
|
||||
def shuffle_weights(
|
||||
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
|
||||
|
||||
Reference in New Issue
Block a user