Upgrade to vLLM 0.17.0 (corex v4.1 overlay)

commit 938d0854a5 (parent 8fac6062e4)
Date: 2026-04-29 19:38:22 +08:00

430 changed files with 35969 additions and 14511 deletions


@@ -18,7 +18,6 @@ from .ScaledMMLinearKernel import (
     Int8ScaledMMLinearLayerConfig,
 )
-import vllm.envs as envs
 
 class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
     @classmethod
@@ -38,13 +37,28 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
         config = self.config
 
         # WEIGHT
         # Cutlass kernels need transposed weight.
-        weight = getattr(layer, w_q_name)
-        replace_parameter(
-            layer,
-            w_q_name,
-            # torch.nn.Parameter(weight.t().data, requires_grad=False),
-            torch.nn.Parameter(weight.data if envs.VLLM_W8A8_LINEAR_USE_W4A8 else weight.t().data, requires_grad=False),
-        )
+        weight = getattr(layer, w_q_name)
+        if layer.scheme.is_w4a8_linear:
+            self.format = "NN"
+            replace_parameter(
+                layer, w_q_name,
+                torch.nn.Parameter(weight.data.contiguous(),
+                                   requires_grad=False))
+        else:
+            self.format = "TN"  # by default, weights arrive in T layout
+            m, k = weight.shape
+            if m % 64 == 0 and k % 64 == 0:
+                self.format = "NN"
+                # Original layout is T[m, k]; after repacking it is N[k, m].
+                replace_parameter(
+                    layer, w_q_name,
+                    torch.nn.Parameter(weight.t().data.contiguous(),
+                                       requires_grad=False))
+            else:
+                if k % 64 != 0:
+                    # Copy into a buffer whose inner dim is padded to a
+                    # multiple of 64; keep an [m, k] view into it.
+                    pad_k = (k // 64 + 1) * 64
+                    weight_pad = torch.empty((m, pad_k), dtype=weight.dtype,
+                                             device=weight.device)
+                    _weight = weight_pad[:, :k]
+                    _weight.copy_(weight)
+                    weight = _weight
+                replace_parameter(
+                    layer, w_q_name,
+                    torch.nn.Parameter(weight.t(), requires_grad=False))
 
         # WEIGHT SCALE
         # Cutlass kernels support only per-tensor and per-channel.
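
Note on the repacking rule above: a minimal standalone sketch mirroring the hunk. The helper name repack_int8_weight is illustrative (not the kernel's API), and torch.zeros stands in for the diff's torch.empty so the padded region is deterministic:

import torch

def repack_int8_weight(weight: torch.Tensor):
    # weight arrives in T layout with shape [m, k] (out, in).
    m, k = weight.shape
    if m % 64 == 0 and k % 64 == 0:
        # Both dims 64-aligned: physically repack T[m, k] -> N[k, m].
        return "NN", weight.t().contiguous()
    if k % 64 != 0:
        # Copy into a buffer whose inner dim is padded to a multiple of
        # 64, then keep an [m, k] view into it; only the strides change,
        # not the logical shape.
        pad_k = (k // 64 + 1) * 64
        buf = torch.zeros((m, pad_k), dtype=weight.dtype, device=weight.device)
        buf[:, :k].copy_(weight)
        weight = buf[:, :k]
    # T layout is kept ("TN"); .t() is a strided view, not a copy.
    return "TN", weight.t()

fmt, w = repack_int8_weight(torch.randint(-8, 8, (100, 100), dtype=torch.int8))
print(fmt, w.shape, w.stride())  # TN torch.Size([100, 100]) (1, 128)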
@@ -114,6 +128,7 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
+        is_w4a8_linear: bool = False,
     ) -> torch.Tensor:
         w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
@@ -121,9 +136,15 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
         # * dynamic, i_s is None and x_s computed from x.
         # * static, i_s is scalar and x_s is i_s.
         symmetric = azp_adj is None
-        x_q, x_s, x_zp = ops.scaled_int8_quant(
-            x.contiguous(), i_s, i_zp, symmetric=symmetric
-        )
+        if isinstance(x, tuple):
+            # Input arrives pre-quantized as (x_q, x_s, out_dtype).
+            x_q, x_s, out_dtype = x
+            x_zp = None
+        else:
+            out_dtype = x.dtype
+            x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
+                                                   i_s,
+                                                   i_zp,
+                                                   symmetric=symmetric)
 
         if x_zp is not None:
             # Currently, static is always per-tensor and dynamic is per-token
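
The tuple branch above lets callers hand in activations that were already quantized upstream. For the non-tuple path, here is a pure-PyTorch reference of what dynamic per-token symmetric quantization computes (a stand-in for ops.scaled_int8_quant with i_s=None and i_zp=None, not the fused kernel itself):

import torch

def scaled_int8_quant_ref(x: torch.Tensor):
    # Per-token (per-row) symmetric scale: max(|row|) / 127.
    x_s = (x.abs().amax(dim=-1, keepdim=True).float() / 127.0).clamp(min=1e-8)
    x_q = torch.round(x.float() / x_s).clamp(-128, 127).to(torch.int8)
    return x_q, x_s, None  # symmetric => no zero point

x = torch.randn(4, 256, dtype=torch.float16)
x_q, x_s, _ = scaled_int8_quant_ref(x)
# Dequantizing recovers x up to half a quantization step per element.
assert torch.allclose(x_q.float() * x_s, x.float(), atol=x_s.max().item())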
@@ -134,14 +155,21 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
                 w_q,
                 scale_a=x_s,
                 scale_b=w_s,
-                out_dtype=x.dtype,
+                out_dtype=out_dtype,
                 azp_adj=azp_adj,
                 azp=azp,
                 bias=bias,
             )
 
+        # Zero-pad the activation's last dim so it matches the repacked
+        # weight's expected extent for the chosen layout; zero entries
+        # contribute nothing to the GEMM.
+        if self.format == "NN" and x_q.shape[-1] != w_q.shape[0]:
+            padding = w_q.shape[0] - x_q.shape[-1]
+            x_align = torch.nn.functional.pad(x_q, (0, padding),
+                                              mode='constant', value=0)
+        elif self.format == "TN" and x_q.shape[-1] != w_q.shape[-1]:
+            padding = w_q.shape[-1] - x_q.shape[-1]
+            x_align = torch.nn.functional.pad(x_q, (0, padding),
+                                              mode='constant', value=0)
+        else:
+            x_align = x_q
         return ops.cutlass_scaled_mm(
-            # x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
-            x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias, format="NN" if envs.VLLM_W8A8_LINEAR_USE_W4A8 else "TN"
+            x_align, w_q, scale_a=x_s, scale_b=w_s, out_dtype=out_dtype,
+            bias=bias, format=self.format, is_w4a8_linear=is_w4a8_linear
         )
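
Why zero-padding the activation is safe: the padded entries of x_align are zero, and each contributes 0 to every dot product, so whatever sits in the weight's padded region never reaches the output. A minimal sketch of the alignment step, with align_activation as an illustrative name:

import torch

def align_activation(x_q: torch.Tensor, k_expected: int) -> torch.Tensor:
    # Zero-pad the quantized activation's last dim up to the K extent
    # the kernel expects for the chosen weight layout.
    k = x_q.shape[-1]
    if k == k_expected:
        return x_q
    return torch.nn.functional.pad(x_q, (0, k_expected - k),
                                   mode='constant', value=0)

x_q = torch.randint(-128, 128, (4, 100), dtype=torch.int8)
print(align_activation(x_q, 128).shape)  # torch.Size([4, 128])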