init
This commit is contained in:
390
vllm_vacc/vllm/attention/backends/mla/utils.py
Normal file
390
vllm_vacc/vllm/attention/backends/mla/utils.py
Normal file
@@ -0,0 +1,390 @@
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generic, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl, T)
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW8A8Fp8)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
# from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
# apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_dequantize, scaled_quantize)
|
||||
import os
|
||||
|
||||
W_Q_W_QR_WUV_WUK_USE_FP8 = True
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def is_layer_fp8(layer: LinearBase) -> bool:
|
||||
return isinstance(layer.quant_method, Fp8LinearMethod) or\
|
||||
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
|
||||
|
||||
def quantization_scheme_supported(layer: LinearBase) -> bool:
|
||||
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
|
||||
is_layer_fp8(layer)
|
||||
|
||||
# TODO(lucas) This is very gross, we need a more wide scale refactor of
|
||||
# all the FP8 code with a more standard way of
|
||||
# defining schemes/group-shapes, we should also potentially force
|
||||
# quant_methods to support a decompress function
|
||||
#
|
||||
# returns input_group_shape, weight_group_shape
|
||||
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
|
||||
Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||
if isinstance(layer.quant_method, Fp8LinearMethod):
|
||||
if layer.quant_method.block_quant is not None:
|
||||
weight_block_size = \
|
||||
layer.quant_method.quant_config.weight_block_size
|
||||
# per-token-group (1, X), block-quantized (X, Y)
|
||||
return (1, weight_block_size[-1]), weight_block_size
|
||||
else:
|
||||
return (-1, -1), (-1, -1) # per-tensor, per-tensor
|
||||
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# this is hacky but we always assume the for
|
||||
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
|
||||
# we ignore if it is static-per-tensor since we are going to
|
||||
# requantize after later anyways
|
||||
strategy = layer.scheme.strategy
|
||||
if strategy == QuantizationStrategy.TENSOR:
|
||||
return (1, -1), (-1, -1) # per-token, per-tensor
|
||||
elif strategy == QuantizationStrategy.CHANNEL:
|
||||
return (1, -1), (-1, 1) # per-token, per-channel
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"QuantizationStrategy.{strategy} is not supported for "
|
||||
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Can't determine scale group shapes for "
|
||||
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
|
||||
)
|
||||
|
||||
def get_scales(layer: LinearBase) -> torch.Tensor:
|
||||
if hasattr(layer, "weight_scale_inv"):
|
||||
return layer.weight_scale_inv
|
||||
return layer.weight_scale
|
||||
|
||||
def get_fp8_layer_weight(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer) # 已经expand过了
|
||||
weight_scale_group_shape=weight_scale_group_shape.copy() #config中读出来的[128,128], 需要 .copy(), 否则会把config改掉
|
||||
|
||||
# 重新校准一下 weight_scale_group_shape
|
||||
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
|
||||
weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]
|
||||
|
||||
if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
|
||||
weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
|
||||
|
||||
return weight, scales
|
||||
else:
|
||||
return layer.weight, None
|
||||
|
||||
def get_fp8_layer_weight_test(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer) # 已经expand过了
|
||||
weight_scale_group_shape=weight_scale_group_shape.copy() #config中读出来的[128,128], 需要 .copy(), 否则会把config改掉
|
||||
|
||||
# 重新校准一下 weight_scale_group_shape
|
||||
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
|
||||
weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]
|
||||
|
||||
if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
|
||||
weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
|
||||
|
||||
# for test
|
||||
weight = scaled_dequantize(weight, scales, weight_scale_group_shape)
|
||||
# print(f'{weight.shape}, {scales.shape}, {weight_scale_group_shape}')
|
||||
return weight, scales
|
||||
else:
|
||||
return layer.weight, None
|
||||
|
||||
def check_eq(name, tensor0, tensor1):
|
||||
assert tensor0.shape == tensor1.shape
|
||||
isEqual = torch.equal(tensor0.reshape([-1]).float(), tensor1.reshape([-1]).float())
|
||||
print(f"{os.getpid()} check {name} {tensor0.shape} equal: {isEqual}")
|
||||
return isEqual
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer) # 已经expand过了
|
||||
weight_scale_group_shape=weight_scale_group_shape.copy() #config中读出来的[128,128], 需要 .copy(), 否则会把config改掉
|
||||
|
||||
# 重新校准一下 weight_scale_group_shape
|
||||
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
|
||||
weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]
|
||||
|
||||
if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
|
||||
weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
|
||||
|
||||
return scaled_dequantize(weight, scales,
|
||||
weight_scale_group_shape)
|
||||
else:
|
||||
return layer.weight
|
||||
|
||||
if not (quantization_scheme_supported(self.kv_b_proj) and\
|
||||
quantization_scheme_supported(self.q_proj) and\
|
||||
quantization_scheme_supported(self.o_proj)):
|
||||
raise NotImplementedError(
|
||||
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
|
||||
", please run with VLLM_MLA_DISABLE=1")
|
||||
|
||||
weight_dtype = self.kv_b_proj.weight.dtype
|
||||
assert self.o_proj.weight.dtype == weight_dtype
|
||||
assert self.q_proj.weight.dtype == weight_dtype
|
||||
|
||||
if W_Q_W_QR_WUV_WUK_USE_FP8: #and not envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
# 512,1024(=4x256)
|
||||
kv_b_proj_weight, kv_b_proj_scale = \
|
||||
[t.T for t in get_fp8_layer_weight(self.kv_b_proj)]
|
||||
|
||||
# kv_b_proj_weight = kv_b_proj_weight.transpose(-1,-2).contiguous().transpose(-1,-2)
|
||||
N, K = kv_b_proj_weight.shape[0], kv_b_proj_weight.shape[1]
|
||||
|
||||
# 512,1024 => 512,4,256
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
kv_b_proj_scale = kv_b_proj_scale.view(
|
||||
kv_b_proj_scale.shape[0] * self.kv_lora_rank // N,
|
||||
self.num_heads,
|
||||
kv_b_proj_scale.shape[1] * N // (self.kv_lora_rank * self.num_heads),
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
W_UK = W_UK.contiguous()
|
||||
|
||||
scale_0 = kv_b_proj_scale.shape[-1] * self.qk_nope_head_dim // (self.qk_nope_head_dim + self.v_head_dim)
|
||||
scale_1 = kv_b_proj_scale.shape[-1] - scale_0
|
||||
|
||||
W_UK_scale, W_UV_scale = kv_b_proj_scale.split(
|
||||
[scale_0, scale_1], dim=-1)
|
||||
W_UK_scale = W_UK_scale.view(W_UK_scale.shape[0], -1).unsqueeze(-1).contiguous()
|
||||
W_UV_scale = W_UV_scale.view(W_UV_scale.shape[0], -1).unsqueeze(-1)
|
||||
|
||||
# weight: [1536, 768] scale: 12,6
|
||||
q_proj_weight, q_proj_scale = \
|
||||
[t.T for t in get_fp8_layer_weight(self.q_proj)]
|
||||
|
||||
#self.W_Q_QR = q_proj_weight.contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
#self.W_Q_QR_scales = q_proj_scale.reshape(12, 6, 1).repeat(1, 1, 4).reshape(12, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
|
||||
q_proj_weight = q_proj_weight\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
# w_q[1536, 512] + w_qr[1536, 256]
|
||||
W_Q = q_proj_weight[..., :self.qk_nope_head_dim].flatten(start_dim=1)
|
||||
W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
# w_q_scale 12,16 + w_qr_scale 12,8
|
||||
# expand: 12,6(4+2) -> 12,24(16+8)
|
||||
# Q_scale: [s0x4, s1x2, s2x2, s3x4, s4x2, s5x2]
|
||||
repeat_pattern = torch.tensor([4, 2, 2, 4, 2, 2], device=q_proj_scale.device)
|
||||
W_Q_scale = torch.repeat_interleave(q_proj_scale, repeat_pattern, dim=1)
|
||||
# Q_R_scale: [s1x2, s2x2, s4x2, s5x2]
|
||||
selected_indices = [1, 2, 4, 5]
|
||||
repeat_times = 2
|
||||
selected = q_proj_scale[:, selected_indices]
|
||||
W_QR_scale = selected.repeat_interleave(repeat_times, dim=1)
|
||||
|
||||
# temp_WQ_Scale = W_Q_scale.reshape(12, 4, -1).contiguous()
|
||||
# temp_W_QR_scale = W_QR_scale.reshape(12, 4, -1).contiguous()
|
||||
# temp_scale = torch.cat([temp_WQ_Scale, temp_W_QR_scale], dim=2).contiguous().reshape(12, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
# self.W_Q_QR_scales = temp_scale
|
||||
# print("W_Q_scale:", W_Q_scale.shape)
|
||||
# print("W_QR_scale:", W_QR_scale.shape)
|
||||
# print("temp_scale:", temp_scale.shape)
|
||||
# exit(0)
|
||||
|
||||
# Note: to be vnnl compatible
|
||||
# 1. expand w_uv scale for core split friendly
|
||||
if W_UV.shape[-1] % 4 == 0:
|
||||
W_UV_scale = W_UV_scale.expand((W_UV_scale.shape[0], W_UV_scale.shape[1], 4))
|
||||
# 2. change w_q, w_qr, w_uv weight&scale to K-contiguous (shape unchanged)
|
||||
W_Q = W_Q.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
W_Q_scale = W_Q_scale.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
W_QR = W_QR.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
W_QR_scale = W_QR_scale.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
|
||||
W_UV = W_UV.permute(2,1,0).contiguous().permute(2,1,0)
|
||||
W_UV_scale = W_UV_scale.permute(2,1,0).contiguous().permute(2,1,0)
|
||||
|
||||
self.W_Q = W_Q
|
||||
self.W_Q_scales = W_Q_scale
|
||||
|
||||
self.W_QR = W_QR
|
||||
self.W_QR_scales = W_QR_scale
|
||||
|
||||
# temp_Q_scale = self.W_Q_scales.contiguous()
|
||||
# temp_W_QR_scale = self.W_QR_scales.contiguous()
|
||||
# self.W_Q_QR = q_proj_weight.reshape(1536, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
# self.W_Q_QR_scales = torch.concat([temp_Q_scale,temp_W_QR_scale],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
#self.W_Q_QR = torch.concat([self.W_Q.contiguous(),self.W_QR.contiguous()],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
#self.W_Q_QR_scales = torch.concat([W_Q_scale,W_QR_scale],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
|
||||
self.W_UV = W_UV
|
||||
self.W_UV_scales = W_UV_scale
|
||||
|
||||
self.W_UK = W_UK
|
||||
self.W_UK_scales = W_UK_scale
|
||||
return
|
||||
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
|
||||
# can be W_Q or W_UQ depending q_lora_rank, the former if
|
||||
# q_lora_rank is None, the latter otherwise. From the Attention backend
|
||||
# perspective though we call these both W_Q and rely on the layer
|
||||
# to pass in the correct matrix
|
||||
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
|
||||
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
# W_QR is small so for simplicity we dont bother requantizing it
|
||||
self.W_QR = self.W_QR.to(act_dtype)
|
||||
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
assert False, "please set VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0"
|
||||
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
# This assumes it wise to requantize using the same group shapes
|
||||
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
|
||||
# weights were originally quantized
|
||||
requant_input_group_shape, requant_weight_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(self.q_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.o_proj)
|
||||
self.reqaunt_input_group_shape = requant_input_group_shape
|
||||
self.reqaunt_weight_group_shape = requant_weight_group_shape
|
||||
|
||||
#
|
||||
# Perform matrix-absorption following
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/551
|
||||
# for decode, as a result we end up with absorbed weights for decode
|
||||
# and another copy of raw weights for prefill.
|
||||
#
|
||||
self.W_UK, self.W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
|
||||
# depending q_lora_rank, the former if q_lora_rank is None, the
|
||||
# latter otherwise
|
||||
# basically if q_lora_rank is none we are absorbing into q_proj
|
||||
# instead of UQ
|
||||
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_Q_UK, W_Q_UK_scales = scaled_quantize(
|
||||
W_Q_UK,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform_fp8_dtype)
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_Q_UK = W_Q_UK.T.contiguous()
|
||||
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
|
||||
else:
|
||||
self.W_Q_UK = W_Q_UK.to(act_dtype)
|
||||
|
||||
W_O = get_and_maybe_dequant_weights(self.o_proj)\
|
||||
.view(-1, self.num_heads, self.v_head_dim)
|
||||
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
|
||||
.flatten(start_dim=0, end_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_UV_O, W_UV_O_scales = scaled_quantize(
|
||||
W_UV_O,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform_fp8_dtype)
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_UV_O = W_UV_O.T.contiguous()
|
||||
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
|
||||
else:
|
||||
self.W_UV_O = W_UV_O.to(act_dtype)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
else:
|
||||
# print('W_UV', W_UV.dtype) #float32
|
||||
#if is_fp8(weight_dtype):
|
||||
# raise NotImplementedError(
|
||||
# "Currently fp8 requires matrix absorption")
|
||||
# self.W_UV = W_UV
|
||||
# self.W_UK = W_UK
|
||||
self.W_UV = W_UV.to(act_dtype) # fp32 to bfp16
|
||||
self.W_UK = W_UK.to(act_dtype)
|
||||
W_Q = W_Q.to(act_dtype)
|
||||
self.W_Q = W_Q.flatten(start_dim=1)
|
||||
Reference in New Issue
Block a user