[AMD] Remove vllm's scaled_fp8_quant and moe_sum when SGLANG_USE_AITER=1 (#7484)
This commit is contained in:
@@ -54,14 +54,11 @@ _is_npu = is_npu()
|
|||||||
_is_fp8_fnuz = is_fp8_fnuz()
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
if not _is_npu:
|
if not (_is_npu or _is_hip):
|
||||||
from sgl_kernel import silu_and_mul
|
from sgl_kernel import silu_and_mul
|
||||||
|
|
||||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||||
|
|
||||||
if _is_hip:
|
|
||||||
from vllm._custom_ops import scaled_fp8_quant
|
|
||||||
|
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
from aiter import ActivationType, QuantType
|
from aiter import ActivationType, QuantType
|
||||||
from aiter.fused_moe import fused_moe
|
from aiter.fused_moe import fused_moe
|
||||||
|
|||||||
@@ -39,11 +39,20 @@ _is_hip = is_hip()
|
|||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
_is_cpu = is_cpu()
|
_is_cpu = is_cpu()
|
||||||
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||||
elif _is_cpu and _is_cpu_amx_available:
|
elif _is_cpu and _is_cpu_amx_available:
|
||||||
pass
|
pass
|
||||||
|
elif _is_hip:
|
||||||
|
from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul
|
||||||
|
|
||||||
|
if _use_aiter:
|
||||||
|
try:
|
||||||
|
from aiter import moe_sum
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
||||||
else:
|
else:
|
||||||
from vllm import _custom_ops as vllm_ops
|
from vllm import _custom_ops as vllm_ops
|
||||||
from vllm._custom_ops import scaled_fp8_quant
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
@@ -1521,11 +1530,7 @@ def fused_experts_impl(
|
|||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
):
|
):
|
||||||
padded_size = padding_size
|
padded_size = padding_size
|
||||||
if (
|
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
||||||
not (use_fp8_w8a8 or use_int8_w8a8)
|
|
||||||
or block_shape is not None
|
|
||||||
or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
|
|
||||||
):
|
|
||||||
padded_size = 0
|
padded_size = 0
|
||||||
|
|
||||||
# Check constraints.
|
# Check constraints.
|
||||||
@@ -1723,6 +1728,17 @@ def fused_experts_impl(
|
|||||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
routed_scaling_factor,
|
routed_scaling_factor,
|
||||||
)
|
)
|
||||||
|
elif _is_hip:
|
||||||
|
if _use_aiter:
|
||||||
|
moe_sum(
|
||||||
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
|
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
vllm_ops.moe_sum(
|
||||||
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
|
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
vllm_ops.moe_sum(
|
vllm_ops.moe_sum(
|
||||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from sglang.srt.layers.quantization.utils import (
|
|||||||
per_tensor_dequantize,
|
per_tensor_dequantize,
|
||||||
replace_parameter,
|
replace_parameter,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
|
from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
@@ -32,8 +32,9 @@ _is_cuda = is_cuda()
|
|||||||
_is_npu = is_npu()
|
_is_npu = is_npu()
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
_is_cpu = is_cpu()
|
_is_cpu = is_cpu()
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
|
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
|
||||||
from vllm import _custom_ops as vllm_ops
|
from vllm import _custom_ops as vllm_ops
|
||||||
from vllm._custom_ops import scaled_fp8_quant
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
|
|
||||||
|
|||||||
@@ -95,10 +95,9 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
|||||||
if _is_hip and (_use_aiter or _use_hip_int4):
|
if _is_hip and (_use_aiter or _use_hip_int4):
|
||||||
from aiter import ActivationType, QuantType
|
from aiter import ActivationType, QuantType
|
||||||
from aiter.fused_moe import fused_moe
|
from aiter.fused_moe import fused_moe
|
||||||
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
|
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
|
||||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
|
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
|
||||||
from vllm._custom_ops import scaled_fp8_quant
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from sglang.srt.layers.quantization import deep_gemm_wrapper
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
align,
|
align,
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
|
get_bool_env_var,
|
||||||
get_device_core_count,
|
get_device_core_count,
|
||||||
get_device_name,
|
get_device_name,
|
||||||
is_cpu,
|
is_cpu,
|
||||||
@@ -39,6 +40,7 @@ from sglang.srt.utils import (
|
|||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_cpu = is_cpu()
|
_is_cpu = is_cpu()
|
||||||
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
@@ -47,6 +49,22 @@ if _is_cuda:
|
|||||||
sgl_per_token_quant_fp8,
|
sgl_per_token_quant_fp8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if _is_hip:
|
||||||
|
if _use_aiter:
|
||||||
|
try:
|
||||||
|
from aiter import ( # v0.1.3
|
||||||
|
dynamic_per_tensor_quant,
|
||||||
|
dynamic_per_token_scaled_quant,
|
||||||
|
static_per_tensor_quant,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
import vllm._C
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -1116,58 +1134,109 @@ def per_token_group_quant_mla_deep_gemm_masked_fp8(
|
|||||||
return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
|
return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
|
||||||
|
|
||||||
|
|
||||||
def scaled_fp8_quant(
|
"""
|
||||||
input: torch.Tensor,
|
Quantize input tensor to FP8 (8-bit floating point) format.
|
||||||
scale: Optional[torch.Tensor] = None,
|
|
||||||
num_token_padding: Optional[int] = None,
|
|
||||||
use_per_token_if_dynamic: bool = False,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Quantize input tensor to FP8 (8-bit floating point) format.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input (torch.Tensor): Input tensor to be quantized
|
input (torch.Tensor): Input tensor to be quantized
|
||||||
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
||||||
If None, scales will be computed dynamically.
|
If None, scales will be computed dynamically.
|
||||||
num_token_padding (Optional[int]): If specified, pad the first dimension
|
num_token_padding (Optional[int]): If specified, pad the first dimension
|
||||||
of the output to at least this value.
|
of the output to at least this value.
|
||||||
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
||||||
determines the quantization granularity:
|
determines the quantization granularity:
|
||||||
- True: compute scale per token
|
- True: compute scale per token
|
||||||
- False: compute single scale per tensor
|
- False: compute single scale per tensor
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||||
- quantized_tensor: The FP8 quantized version of input
|
- quantized_tensor: The FP8 quantized version of input
|
||||||
- scale_tensor: The scaling factors used for quantization
|
- scale_tensor: The scaling factors used for quantization
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AssertionError: If input is not 2D or if static scale's numel != 1
|
AssertionError: If input is not 2D or if static scale's numel != 1
|
||||||
"""
|
"""
|
||||||
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
if _is_hip:
|
||||||
shape = input.shape
|
|
||||||
if num_token_padding:
|
|
||||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
|
||||||
output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
|
|
||||||
|
|
||||||
if scale is None:
|
def scaled_fp8_quant(
|
||||||
# Dynamic scaling
|
input: torch.Tensor,
|
||||||
if use_per_token_if_dynamic:
|
scale: Optional[torch.Tensor] = None,
|
||||||
scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
|
num_token_padding: Optional[int] = None,
|
||||||
sgl_per_token_quant_fp8(input, output, scale)
|
use_per_token_if_dynamic: bool = False,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||||
|
shape = input.shape
|
||||||
|
if num_token_padding:
|
||||||
|
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||||
|
output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
|
||||||
|
|
||||||
|
if scale is None:
|
||||||
|
# Dynamic scaling
|
||||||
|
if use_per_token_if_dynamic:
|
||||||
|
scale = torch.empty(
|
||||||
|
(shape[0], 1), device=input.device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
if _use_aiter:
|
||||||
|
dynamic_per_token_scaled_quant(output, input, scale)
|
||||||
|
else:
|
||||||
|
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
|
||||||
|
output, input.contiguous(), scale, None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||||
|
if _use_aiter:
|
||||||
|
dynamic_per_tensor_quant(output, input, scale)
|
||||||
|
else:
|
||||||
|
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
||||||
else:
|
else:
|
||||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
# Static scaling
|
||||||
sgl_per_tensor_quant_fp8(
|
assert (
|
||||||
input, output, scale, is_static=False
|
scale.numel() == 1
|
||||||
) # False for dynamic
|
), f"Expected scalar scale, got numel={scale.numel()}"
|
||||||
else:
|
if _use_aiter:
|
||||||
# Static scaling
|
static_per_tensor_quant(output, input, scale)
|
||||||
assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
|
else:
|
||||||
sgl_per_tensor_quant_fp8(
|
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
|
||||||
input, output, scale, is_static=True
|
|
||||||
) # True for static
|
|
||||||
|
|
||||||
return output, scale
|
return output, scale
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def scaled_fp8_quant(
|
||||||
|
input: torch.Tensor,
|
||||||
|
scale: Optional[torch.Tensor] = None,
|
||||||
|
num_token_padding: Optional[int] = None,
|
||||||
|
use_per_token_if_dynamic: bool = False,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||||
|
shape = input.shape
|
||||||
|
if num_token_padding:
|
||||||
|
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||||
|
output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
|
||||||
|
|
||||||
|
if scale is None:
|
||||||
|
# Dynamic scaling
|
||||||
|
if use_per_token_if_dynamic:
|
||||||
|
scale = torch.empty(
|
||||||
|
(shape[0], 1), device=input.device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
sgl_per_token_quant_fp8(input, output, scale)
|
||||||
|
else:
|
||||||
|
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||||
|
sgl_per_tensor_quant_fp8(
|
||||||
|
input, output, scale, is_static=False
|
||||||
|
) # False for dynamic
|
||||||
|
else:
|
||||||
|
# Static scaling
|
||||||
|
assert (
|
||||||
|
scale.numel() == 1
|
||||||
|
), f"Expected scalar scale, got numel={scale.numel()}"
|
||||||
|
sgl_per_tensor_quant_fp8(
|
||||||
|
input, output, scale, is_static=True
|
||||||
|
) # True for static
|
||||||
|
|
||||||
|
return output, scale
|
||||||
|
|
||||||
|
|
||||||
fp8_autotune = triton.autotune(
|
fp8_autotune = triton.autotune(
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
|||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
from aiter import ActivationType
|
from aiter import ActivationType
|
||||||
from aiter.fused_moe import fused_moe
|
from aiter.fused_moe import fused_moe
|
||||||
from aiter.fused_moe_bf16_asm import ck_moe_2stages
|
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||||
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
||||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
|
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
@@ -21,8 +21,9 @@ _is_cuda = is_cuda()
|
|||||||
_is_npu = is_npu()
|
_is_npu = is_npu()
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
_is_cpu = is_cpu()
|
_is_cpu = is_cpu()
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
|
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
|
||||||
from vllm._custom_ops import scaled_fp8_quant
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,13 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
|
||||||
from sglang.srt.utils import is_cuda
|
from sglang.srt.utils import is_cuda, is_hip
|
||||||
|
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
_is_hip = is_hip()
|
||||||
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
|
fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
@@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
|
|||||||
def quantize_ref_per_tensor(tensor, inv_scale):
|
def quantize_ref_per_tensor(tensor, inv_scale):
|
||||||
# The reference implementation that fully aligns to
|
# The reference implementation that fully aligns to
|
||||||
# the kernel being tested.
|
# the kernel being tested.
|
||||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
finfo = torch.finfo(fp8_dtype)
|
||||||
scale = inv_scale.reciprocal()
|
scale = inv_scale.reciprocal()
|
||||||
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
|
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
|
||||||
qweight = qweight.to(torch.float8_e4m3fn)
|
qweight = qweight.to(fp8_dtype)
|
||||||
return qweight
|
return qweight
|
||||||
|
|
||||||
def dequantize_per_tensor(tensor, inv_scale, dtype):
|
def dequantize_per_tensor(tensor, inv_scale, dtype):
|
||||||
@@ -48,19 +53,19 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_cuda:
|
if _is_cuda or _is_hip:
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
|
def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
|
||||||
def quantize_ref_per_token(tensor, inv_scale):
|
def quantize_ref_per_token(tensor, inv_scale):
|
||||||
# The reference implementation that fully aligns to
|
# The reference implementation that fully aligns to
|
||||||
# the kernel being tested.
|
# the kernel being tested.
|
||||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
finfo = torch.finfo(fp8_dtype)
|
||||||
scale = inv_scale.reciprocal()
|
scale = inv_scale.reciprocal()
|
||||||
qweight = (tensor.to(torch.float32) * scale).clamp(
|
qweight = (tensor.to(torch.float32) * scale).clamp(
|
||||||
min=finfo.min, max=finfo.max
|
min=finfo.min, max=finfo.max
|
||||||
)
|
)
|
||||||
qweight = qweight.to(torch.float8_e4m3fn)
|
qweight = qweight.to(fp8_dtype)
|
||||||
return qweight
|
return qweight
|
||||||
|
|
||||||
def dequantize_per_token(tensor, inv_scale, dtype):
|
def dequantize_per_token(tensor, inv_scale, dtype):
|
||||||
|
|||||||
Reference in New Issue
Block a user