[AMD] Remove vllm's scaled_fp8_quant and moe_sum when SGLANG_USE_AITER=1 (#7484)
This commit is contained in:
@@ -54,14 +54,11 @@ _is_npu = is_npu()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if not _is_npu:
|
||||
if not (_is_npu or _is_hip):
|
||||
from sgl_kernel import silu_and_mul
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
|
||||
if _is_hip:
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
|
||||
if _use_aiter:
|
||||
from aiter import ActivationType, QuantType
|
||||
from aiter.fused_moe import fused_moe
|
||||
|
||||
@@ -39,11 +39,20 @@ _is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
_is_cpu = is_cpu()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||
elif _is_cpu and _is_cpu_amx_available:
|
||||
pass
|
||||
elif _is_hip:
|
||||
from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul
|
||||
|
||||
if _use_aiter:
|
||||
try:
|
||||
from aiter import moe_sum
|
||||
except ImportError:
|
||||
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
||||
else:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
@@ -1521,11 +1530,7 @@ def fused_experts_impl(
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
):
|
||||
padded_size = padding_size
|
||||
if (
|
||||
not (use_fp8_w8a8 or use_int8_w8a8)
|
||||
or block_shape is not None
|
||||
or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
|
||||
):
|
||||
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
||||
padded_size = 0
|
||||
|
||||
# Check constraints.
|
||||
@@ -1723,6 +1728,17 @@ def fused_experts_impl(
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
routed_scaling_factor,
|
||||
)
|
||||
elif _is_hip:
|
||||
if _use_aiter:
|
||||
moe_sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
else:
|
||||
vllm_ops.moe_sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
else:
|
||||
vllm_ops.moe_sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
|
||||
@@ -20,7 +20,7 @@ from sglang.srt.layers.quantization.utils import (
|
||||
per_tensor_dequantize,
|
||||
replace_parameter,
|
||||
)
|
||||
from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
|
||||
from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
@@ -32,8 +32,9 @@ _is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
_is_cpu = is_cpu()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
|
||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
|
||||
|
||||
@@ -95,10 +95,9 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
if _is_hip and (_use_aiter or _use_hip_int4):
|
||||
from aiter import ActivationType, QuantType
|
||||
from aiter.fused_moe import fused_moe
|
||||
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
|
||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.utils import (
|
||||
align,
|
||||
direct_register_custom_op,
|
||||
get_bool_env_var,
|
||||
get_device_core_count,
|
||||
get_device_name,
|
||||
is_cpu,
|
||||
@@ -39,6 +40,7 @@ from sglang.srt.utils import (
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
_is_cpu = is_cpu()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import (
|
||||
@@ -47,6 +49,22 @@ if _is_cuda:
|
||||
sgl_per_token_quant_fp8,
|
||||
)
|
||||
|
||||
if _is_hip:
|
||||
if _use_aiter:
|
||||
try:
|
||||
from aiter import ( # v0.1.3
|
||||
dynamic_per_tensor_quant,
|
||||
dynamic_per_token_scaled_quant,
|
||||
static_per_tensor_quant,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
||||
else:
|
||||
try:
|
||||
import vllm._C
|
||||
except ImportError:
|
||||
raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1116,58 +1134,109 @@ def per_token_group_quant_mla_deep_gemm_masked_fp8(
|
||||
return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
|
||||
|
||||
|
||||
def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
num_token_padding: Optional[int] = None,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP8 (8-bit floating point) format.
|
||||
"""
|
||||
Quantize input tensor to FP8 (8-bit floating point) format.
|
||||
|
||||
Args:
|
||||
input (torch.Tensor): Input tensor to be quantized
|
||||
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
||||
If None, scales will be computed dynamically.
|
||||
num_token_padding (Optional[int]): If specified, pad the first dimension
|
||||
of the output to at least this value.
|
||||
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
||||
determines the quantization granularity:
|
||||
- True: compute scale per token
|
||||
- False: compute single scale per tensor
|
||||
Args:
|
||||
input (torch.Tensor): Input tensor to be quantized
|
||||
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
||||
If None, scales will be computed dynamically.
|
||||
num_token_padding (Optional[int]): If specified, pad the first dimension
|
||||
of the output to at least this value.
|
||||
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
||||
determines the quantization granularity:
|
||||
- True: compute scale per token
|
||||
- False: compute single scale per tensor
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- quantized_tensor: The FP8 quantized version of input
|
||||
- scale_tensor: The scaling factors used for quantization
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- quantized_tensor: The FP8 quantized version of input
|
||||
- scale_tensor: The scaling factors used for quantization
|
||||
|
||||
Raises:
|
||||
AssertionError: If input is not 2D or if static scale's numel != 1
|
||||
"""
|
||||
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||
shape = input.shape
|
||||
if num_token_padding:
|
||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||
output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
|
||||
Raises:
|
||||
AssertionError: If input is not 2D or if static scale's numel != 1
|
||||
"""
|
||||
if _is_hip:
|
||||
|
||||
if scale is None:
|
||||
# Dynamic scaling
|
||||
if use_per_token_if_dynamic:
|
||||
scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
|
||||
sgl_per_token_quant_fp8(input, output, scale)
|
||||
def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
num_token_padding: Optional[int] = None,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||
shape = input.shape
|
||||
if num_token_padding:
|
||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||
output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
|
||||
|
||||
if scale is None:
|
||||
# Dynamic scaling
|
||||
if use_per_token_if_dynamic:
|
||||
scale = torch.empty(
|
||||
(shape[0], 1), device=input.device, dtype=torch.float32
|
||||
)
|
||||
if _use_aiter:
|
||||
dynamic_per_token_scaled_quant(output, input, scale)
|
||||
else:
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
|
||||
output, input.contiguous(), scale, None
|
||||
)
|
||||
else:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
if _use_aiter:
|
||||
dynamic_per_tensor_quant(output, input, scale)
|
||||
else:
|
||||
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
||||
else:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
sgl_per_tensor_quant_fp8(
|
||||
input, output, scale, is_static=False
|
||||
) # False for dynamic
|
||||
else:
|
||||
# Static scaling
|
||||
assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
|
||||
sgl_per_tensor_quant_fp8(
|
||||
input, output, scale, is_static=True
|
||||
) # True for static
|
||||
# Static scaling
|
||||
assert (
|
||||
scale.numel() == 1
|
||||
), f"Expected scalar scale, got numel={scale.numel()}"
|
||||
if _use_aiter:
|
||||
static_per_tensor_quant(output, input, scale)
|
||||
else:
|
||||
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
|
||||
|
||||
return output, scale
|
||||
return output, scale
|
||||
|
||||
else:
|
||||
|
||||
def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
num_token_padding: Optional[int] = None,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||
shape = input.shape
|
||||
if num_token_padding:
|
||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||
output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
|
||||
|
||||
if scale is None:
|
||||
# Dynamic scaling
|
||||
if use_per_token_if_dynamic:
|
||||
scale = torch.empty(
|
||||
(shape[0], 1), device=input.device, dtype=torch.float32
|
||||
)
|
||||
sgl_per_token_quant_fp8(input, output, scale)
|
||||
else:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
sgl_per_tensor_quant_fp8(
|
||||
input, output, scale, is_static=False
|
||||
) # False for dynamic
|
||||
else:
|
||||
# Static scaling
|
||||
assert (
|
||||
scale.numel() == 1
|
||||
), f"Expected scalar scale, got numel={scale.numel()}"
|
||||
sgl_per_tensor_quant_fp8(
|
||||
input, output, scale, is_static=True
|
||||
) # True for static
|
||||
|
||||
return output, scale
|
||||
|
||||
|
||||
fp8_autotune = triton.autotune(
|
||||
|
||||
@@ -37,7 +37,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
if _use_aiter:
|
||||
from aiter import ActivationType
|
||||
from aiter.fused_moe import fused_moe
|
||||
from aiter.fused_moe_bf16_asm import ck_moe_2stages
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import torch
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
|
||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
@@ -21,8 +21,9 @@ _is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
_is_cpu = is_cpu()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
|
||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
|
||||
|
||||
|
||||
@@ -3,8 +3,13 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.srt.utils import is_cuda
|
||||
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
|
||||
def quantize_ref_per_tensor(tensor, inv_scale):
|
||||
# The reference implementation that fully aligns to
|
||||
# the kernel being tested.
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
finfo = torch.finfo(fp8_dtype)
|
||||
scale = inv_scale.reciprocal()
|
||||
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
qweight = qweight.to(torch.float8_e4m3fn)
|
||||
qweight = qweight.to(fp8_dtype)
|
||||
return qweight
|
||||
|
||||
def dequantize_per_tensor(tensor, inv_scale, dtype):
|
||||
@@ -48,19 +53,19 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
|
||||
)
|
||||
|
||||
|
||||
if is_cuda:
|
||||
if _is_cuda or _is_hip:
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
|
||||
def quantize_ref_per_token(tensor, inv_scale):
|
||||
# The reference implementation that fully aligns to
|
||||
# the kernel being tested.
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
finfo = torch.finfo(fp8_dtype)
|
||||
scale = inv_scale.reciprocal()
|
||||
qweight = (tensor.to(torch.float32) * scale).clamp(
|
||||
min=finfo.min, max=finfo.max
|
||||
)
|
||||
qweight = qweight.to(torch.float8_e4m3fn)
|
||||
qweight = qweight.to(fp8_dtype)
|
||||
return qweight
|
||||
|
||||
def dequantize_per_token(tensor, inv_scale, dtype):
|
||||
|
||||
Reference in New Issue
Block a user