Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -13,6 +13,11 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
get_mla_dims,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
LinearBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
@@ -37,13 +42,17 @@ from vllm.v1.attention.backends.utils import (
|
||||
)
|
||||
from vllm.v1.attention.ops.flashmla import (
|
||||
FlashMLASchedMeta,
|
||||
flash_mla_sparse_fwd,
|
||||
flash_mla_sparse_prefill,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
import functools
|
||||
from vllm import envs
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
|
||||
import ixformer.inference.functions as ixf_ops
|
||||
import numpy as np
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
@@ -74,7 +83,15 @@ structured as:
|
||||
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
|
||||
part is not quantized for accuracy.
|
||||
"""
|
||||
|
||||
def dynamic_per_batched_tensor_quant(
|
||||
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
|
||||
):
|
||||
DTYPE_MAX = torch.finfo(dtype).max
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
|
||||
scale = DTYPE_MAX / amax
|
||||
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
|
||||
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
|
||||
|
||||
class FlashMLASparseBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
@@ -558,6 +575,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
|
||||
self.qk_nope_head_dim = mla_args["qk_nope_head_dim"]
|
||||
self.qk_rope_head_dim = mla_args["qk_rope_head_dim"]
|
||||
self.qk_head_dim = mla_args["qk_head_dim"]
|
||||
self.v_head_dim = mla_args["v_head_dim"]
|
||||
self.kv_b_proj = mla_args["kv_b_proj"]
|
||||
self.softmax_scale = scale
|
||||
assert indexer is not None
|
||||
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
|
||||
@@ -580,6 +602,65 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
(self.prefill_workspace_shape, torch.bfloat16)
|
||||
)
|
||||
)
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
def get_layer_weight(layer):
|
||||
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
||||
for attr in WEIGHT_NAMES:
|
||||
if hasattr(layer, attr):
|
||||
return getattr(layer, attr)
|
||||
raise AttributeError(
|
||||
f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
|
||||
)
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if layer.quant_method is not None and not isinstance(
|
||||
layer.quant_method, UnquantizedLinearMethod
|
||||
):
|
||||
# NOTE: This should only be used offline, since it's O(N^3)
|
||||
eye = torch.eye(
|
||||
layer.input_size_per_partition,
|
||||
dtype=act_dtype,
|
||||
device=get_layer_weight(layer).device,
|
||||
)
|
||||
dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}"
|
||||
)
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
|
||||
)
|
||||
self.W_UV = W_UV
|
||||
self.W_UK = W_UK
|
||||
# self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
|
||||
def _v_up_proj(self, x: torch.Tensor):
|
||||
|
||||
return torch.einsum("bnl,lnv->bnv", x, self.W_UV)
|
||||
def _k_up_proj(self, q_nope):
|
||||
|
||||
return torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK).view(-1, self.num_heads, self.kv_lora_rank)
|
||||
|
||||
def _forward_bf16_kv(
|
||||
self,
|
||||
@@ -590,12 +671,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
) -> torch.Tensor:
|
||||
# Convert per-request indices to global slots (decode) or workspace
|
||||
# offsets (prefill).
|
||||
topk_indices = triton_convert_req_index_to_global_index(
|
||||
topk_indices = ops.dsa_convert_req_index_to_global_index(
|
||||
attn_metadata.req_id_per_token,
|
||||
attn_metadata.block_table,
|
||||
topk_indices,
|
||||
BLOCK_SIZE=attn_metadata.block_size,
|
||||
NUM_TOPK_TOKENS=topk_indices.shape[1],
|
||||
attn_metadata.block_size,
|
||||
)
|
||||
|
||||
return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)
|
||||
@@ -790,22 +870,10 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
-1, 1, kv_c_and_k_pe_cache.shape[-1]
|
||||
)
|
||||
|
||||
# NOTE(Chen): kernel requires num_local_head to be a multiple of
|
||||
# 64 on hopper and 128 on blackwell
|
||||
if self.num_heads % self.prefill_padding != 0:
|
||||
assert self.prefill_padding % self.num_heads == 0
|
||||
logger.warning_once(
|
||||
f"Padding num_heads from {self.num_heads} to "
|
||||
f"{self.prefill_padding} for BF16 sparse prefill kernel"
|
||||
)
|
||||
q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
|
||||
q_padded[:, : self.num_heads, :] = q
|
||||
q = q_padded
|
||||
|
||||
topk_indices = topk_indices.view(num_tokens, 1, -1)
|
||||
output = flash_mla_sparse_fwd(
|
||||
output = flash_mla_sparse_prefill(
|
||||
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
|
||||
)[0]
|
||||
)
|
||||
output = output[:, : self.num_heads, :]
|
||||
return output
|
||||
|
||||
@@ -843,5 +911,5 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
attn_out = self._forward_fp8_kv_separate_prefill_decode(
|
||||
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
|
||||
)
|
||||
|
||||
return attn_out, None
|
||||
|
||||
return attn_out
|
||||
|
||||
Reference in New Issue
Block a user