Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -18,7 +18,6 @@ from .ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
import vllm.envs as envs
|
||||
|
||||
class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
@@ -38,13 +37,28 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
config = self.config
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_q_name,
|
||||
# torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
torch.nn.Parameter(weight.data if envs.VLLM_W8A8_LINEAR_USE_W4A8 else weight.t().data, requires_grad=False),
|
||||
)
|
||||
weight = getattr(layer, w_q_name)
|
||||
if layer.scheme.is_w4a8_linear:
|
||||
self.format = "NN"
|
||||
replace_parameter(layer, w_q_name, torch.nn.Parameter(weight.data.contiguous(), requires_grad=False))
|
||||
else:
|
||||
self.format = "TN" #默认weight都是按T排布
|
||||
m, k = weight.shape
|
||||
if(m % 64 == 0 and k % 64 == 0):
|
||||
self.format= "NN"
|
||||
replace_parameter(
|
||||
layer, w_q_name,
|
||||
torch.nn.Parameter(weight.t().data.contiguous(), requires_grad=False))#原始排布是T[m,k] 处理完后是N[k, m]
|
||||
else:
|
||||
if k % 64 != 0:
|
||||
pad_k = (k // 64 + 1) * 64
|
||||
weight_pad = torch.empty((m, pad_k), dtype=weight.dtype, device=weight.device)
|
||||
_weight = weight_pad[:, :k]
|
||||
_weight.copy_(weight)
|
||||
weight = _weight
|
||||
replace_parameter(
|
||||
layer, w_q_name,
|
||||
torch.nn.Parameter(weight.t(), requires_grad=False))
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Cutlass kernels support only per-tensor and per-channel.
|
||||
@@ -114,6 +128,7 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
is_w4a8_linear: bool = False,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||
|
||||
@@ -121,9 +136,15 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
# * static, i_s is scalar and x_s is i_s.
|
||||
symmetric = azp_adj is None
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(
|
||||
x.contiguous(), i_s, i_zp, symmetric=symmetric
|
||||
)
|
||||
if isinstance(x, tuple):
|
||||
x_q, x_s, out_dtype = x
|
||||
x_zp = None
|
||||
else:
|
||||
out_dtype = x.dtype
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
|
||||
i_s,
|
||||
i_zp,
|
||||
symmetric=symmetric)
|
||||
|
||||
if x_zp is not None:
|
||||
# Currently, static is always per-tensor and dynamic is per-token
|
||||
@@ -134,14 +155,21 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
w_q,
|
||||
scale_a=x_s,
|
||||
scale_b=w_s,
|
||||
out_dtype=x.dtype,
|
||||
out_dtype=out_dtype,
|
||||
azp_adj=azp_adj,
|
||||
azp=azp,
|
||||
bias=bias,
|
||||
)
|
||||
if self.format == "NN" and x_q.shape[-1] != w_q.shape[0]:
|
||||
padding = w_q.shape[0] - x_q.shape[-1]
|
||||
x_align = torch.nn.functional.pad(x_q, (0, padding), mode='constant', value=0)
|
||||
elif self.format == "TN" and x_q.shape[-1] != w_q.shape[-1]:
|
||||
padding = w_q.shape[-1] - x_q.shape[-1]
|
||||
x_align = torch.nn.functional.pad(x_q, (0, padding), mode='constant', value=0)
|
||||
else:
|
||||
x_align = x_q
|
||||
return ops.cutlass_scaled_mm(
|
||||
# x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
|
||||
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias, format="NN" if envs.VLLM_W8A8_LINEAR_USE_W4A8 else "TN"
|
||||
x_align, w_q, scale_a=x_s, scale_b=w_s, out_dtype=out_dtype, bias=bias, format=self.format, is_w4a8_linear=is_w4a8_linear
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user