Remove vllm ops scaled fp8 quant and accelerate per token quant by 20-28% (#4215)

Co-authored-by: Stefan He <bhe@linkedin.com>
This commit is contained in:
Stefan He
2025-03-12 00:08:03 -07:00
committed by GitHub
parent 7130a7cea9
commit e0917e6bd0
5 changed files with 202 additions and 37 deletions

View File

@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Tuple
import torch
from torch.nn import Module
from vllm import _custom_ops as ops
from vllm import _custom_ops as vllm_ops
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
@@ -26,7 +26,13 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.utils import is_hip, set_weight_attrs
from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
_is_cuda = is_cuda()
if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
logger = logging.getLogger(__name__)
@@ -719,12 +725,20 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
)
for expert in range(layer.num_experts_per_partition):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
if _is_cuda:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
else:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return

View File

@@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm import _custom_ops as vllm_ops
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
@@ -42,6 +42,7 @@ _is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
@@ -486,7 +487,7 @@ def moe_align_block_size(
cumsum_buffer,
)
else:
ops.moe_align_block_size(
vllm_ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
@@ -527,7 +528,10 @@ def invoke_fused_moe_kernel(
if block_shape is None:
# activation tensor-wise fp8 quantization, dynamic or static
padded_size = padding_size
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
if _is_cuda:
A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
else:
A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2
@@ -1095,12 +1099,16 @@ def fused_experts_impl(
if _is_cuda:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
vllm_ops.silu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
elif activation == "gelu":
if _is_cuda:
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
vllm_ops.gelu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
else:
raise ValueError(f"Unsupported activation: {activation=}")
@@ -1132,7 +1140,7 @@ def fused_experts_impl(
if no_combine:
pass
elif _is_hip:
ops.moe_sum(
vllm_ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
)