Remove vllm ops scaled fp8 quant and accelerate per token quant by 20-28% (#4215)
Co-authored-by: Stefan He <bhe@linkedin.com>
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -40,3 +42,60 @@ class CustomOp(nn.Module):
|
||||
return self.forward_hip
|
||||
else:
|
||||
return self.forward_native
|
||||
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
|
||||
|
||||
def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP8 (8-bit floating point) format.
|
||||
|
||||
Args:
|
||||
input (torch.Tensor): Input tensor to be quantized
|
||||
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
||||
If None, scales will be computed dynamically.
|
||||
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
||||
determines the quantization granularity:
|
||||
- True: compute scale per token
|
||||
- False: compute single scale per tensor
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- quantized_tensor: The FP8 quantized version of input
|
||||
- scale_tensor: The scaling factors used for quantization
|
||||
|
||||
Raises:
|
||||
AssertionError: If input is not 2D or if static scale's numel != 1
|
||||
"""
|
||||
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||
shape = input.shape
|
||||
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
||||
|
||||
if scale is None:
|
||||
# Dynamic scaling
|
||||
if use_per_token_if_dynamic:
|
||||
scale = torch.empty(
|
||||
(shape[0], 1), device=input.device, dtype=torch.float32
|
||||
)
|
||||
sgl_per_token_quant_fp8(input, output, scale)
|
||||
else:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
sgl_per_tensor_quant_fp8(
|
||||
input, output, scale, is_static=False
|
||||
) # False for dynamic
|
||||
else:
|
||||
# Static scaling
|
||||
assert (
|
||||
scale.numel() == 1
|
||||
), f"Expected scalar scale, got numel={scale.numel()}"
|
||||
sgl_per_tensor_quant_fp8(
|
||||
input, output, scale, is_static=True
|
||||
) # True for static
|
||||
|
||||
return output, scale
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import (
|
||||
@@ -26,7 +26,13 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||
from sglang.srt.utils import is_hip, set_weight_attrs
|
||||
from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -719,12 +725,20 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
||||
)
|
||||
|
||||
for expert in range(layer.num_experts_per_partition):
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
if _is_cuda:
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
else:
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
return
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||
@@ -42,6 +42,7 @@ _is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||
|
||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
sglang_per_token_group_quant_fp8,
|
||||
)
|
||||
@@ -486,7 +487,7 @@ def moe_align_block_size(
|
||||
cumsum_buffer,
|
||||
)
|
||||
else:
|
||||
ops.moe_align_block_size(
|
||||
vllm_ops.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
@@ -527,7 +528,10 @@ def invoke_fused_moe_kernel(
|
||||
if block_shape is None:
|
||||
# activation tensor-wise fp8 quantization, dynamic or static
|
||||
padded_size = padding_size
|
||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||
if _is_cuda:
|
||||
A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
|
||||
else:
|
||||
A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
|
||||
else:
|
||||
# activation block-wise fp8 quantization
|
||||
assert len(block_shape) == 2
|
||||
@@ -1095,12 +1099,16 @@ def fused_experts_impl(
|
||||
if _is_cuda:
|
||||
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||
else:
|
||||
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
||||
vllm_ops.silu_and_mul(
|
||||
intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
elif activation == "gelu":
|
||||
if _is_cuda:
|
||||
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||
else:
|
||||
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
||||
vllm_ops.gelu_and_mul(
|
||||
intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation=}")
|
||||
|
||||
@@ -1132,7 +1140,7 @@ def fused_experts_impl(
|
||||
if no_combine:
|
||||
pass
|
||||
elif _is_hip:
|
||||
ops.moe_sum(
|
||||
vllm_ops.moe_sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
|
||||
88
python/sglang/test/test_custom_ops.py
Normal file
88
python/sglang/test/test_custom_ops.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/8ca7a71df787ad711ad3ac70a5bd2eb2bb398938/tests/quantization/test_fp8.py
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from sglang.srt.custom_op import scaled_fp8_quant
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_scaled_fp8_quant_per_tensor(dtype) -> None:
|
||||
|
||||
def quantize_ref_per_tensor(tensor, inv_scale):
|
||||
# The reference implementation that fully aligns to
|
||||
# the kernel being tested.
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
scale = inv_scale.reciprocal()
|
||||
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
qweight = qweight.to(torch.float8_e4m3fn)
|
||||
return qweight
|
||||
|
||||
def dequantize_per_tensor(tensor, inv_scale, dtype):
|
||||
fake_qweight = tensor.to(dtype)
|
||||
dq_weight = fake_qweight * inv_scale
|
||||
return dq_weight
|
||||
|
||||
# Note that we use a shape % 8 != 0 to cover edge cases,
|
||||
# because scaled_fp8_quant is vectorized by 8.
|
||||
x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
|
||||
|
||||
# Test Per Tensor Dynamic quantization
|
||||
# scale = max(abs(x)) / FP8_E4M3_MAX
|
||||
y, scale = scaled_fp8_quant(x, None)
|
||||
ref_y = quantize_ref_per_tensor(x, scale)
|
||||
torch.testing.assert_close(y, ref_y)
|
||||
torch.testing.assert_close(
|
||||
dequantize_per_tensor(y, scale, dtype),
|
||||
dequantize_per_tensor(ref_y, scale, dtype),
|
||||
)
|
||||
|
||||
# Test Per Tensor Static quantization
|
||||
y, _ = scaled_fp8_quant(x, scale)
|
||||
ref_y = quantize_ref_per_tensor(x, scale)
|
||||
torch.testing.assert_close(y, ref_y)
|
||||
torch.testing.assert_close(
|
||||
dequantize_per_tensor(y, scale, dtype),
|
||||
dequantize_per_tensor(ref_y, scale, dtype),
|
||||
)
|
||||
|
||||
|
||||
if is_cuda:
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
|
||||
def quantize_ref_per_token(tensor, inv_scale):
|
||||
# The reference implementation that fully aligns to
|
||||
# the kernel being tested.
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
scale = inv_scale.reciprocal()
|
||||
qweight = (tensor.to(torch.float32) * scale).clamp(
|
||||
min=finfo.min, max=finfo.max
|
||||
)
|
||||
qweight = qweight.to(torch.float8_e4m3fn)
|
||||
return qweight
|
||||
|
||||
def dequantize_per_token(tensor, inv_scale, dtype):
|
||||
fake_qweight = tensor.to(dtype)
|
||||
dq_weight = fake_qweight * inv_scale
|
||||
return dq_weight
|
||||
|
||||
# Note that we use a shape % 8 = 0,
|
||||
# because per_token_quant_fp8 is vectorized by 8 elements.
|
||||
x = (torch.randn(size=(11, 16), device="cuda") * 13).to(dtype)
|
||||
|
||||
# Test Per Tensor Dynamic quantization
|
||||
# scale = max(abs(x)) / FP8_E4M3_MAX
|
||||
y, scale = scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)
|
||||
ref_y = quantize_ref_per_token(x, scale)
|
||||
torch.testing.assert_close(y, ref_y)
|
||||
torch.testing.assert_close(
|
||||
dequantize_per_token(y, scale, dtype),
|
||||
dequantize_per_token(ref_y, scale, dtype),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the specific test function directly
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user