[Feature] support deepseek v3/r1/v3.2 (#78)

* [Feature] support deepseek v3/r1/v3.2

* fix gpt_oss

* update readme

* update readme

---------

Co-authored-by: hanhaowen <hanhaowen@baidu.com>
This commit is contained in:
baoqian426
2026-01-05 22:55:35 +08:00
committed by GitHub
parent 07bc24a555
commit ee0f50e68f
27 changed files with 5760 additions and 621 deletions

View File

@@ -10,34 +10,15 @@ import vllm.envs as envs
OLD_IMPORT_HOOK = builtins.__import__
def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0):
try:
start_time = time.time()
# 模块映射表
module_mappings = {
"vllm.model_executor.layers.fused_moe.layer": "vllm_kunlun.ops.fused_moe.layer",
"vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe": "vllm_kunlun.ops.quantization.compressed_tensors_moe",
"vllm.compilation.wrapper": "vllm_kunlun.compilation.wrapper",
"vllm.v1.worker.gpu_model_runner": "vllm_kunlun.v1.worker.gpu_model_runner"
"vllm.v1.worker.utils": "vllm_kunlun.v1.worker.utils",
"vllm.model_executor.model_loader.bitsandbytes_loader": "vllm_kunlun.models.model_loader.bitsandbytes_loader",
"vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
"vllm.model_executor.layers.sampler": "vllm_kunlun.ops.sample.sampler",
"vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
}
# 需要保持原始导入的模块
original_imports = [
"vllm.model_executor.layers.fused_moe.base",
"vllm.model_executor.layers.fused_moe.config",
"vllm.model_executor.layers.fused_moe.layer"
]
if module_name in original_imports:
if module_name == "vllm.model_executor.layers.fused_moe.layer" and fromlist:
if "FusedMoEMethodBase" in fromlist:
return OLD_IMPORT_HOOK(
module_name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level
)
if module_name in module_mappings:
if module_name in sys.modules:
return sys.modules[module_name]
@@ -45,25 +26,6 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
module = importlib.import_module(target_module)
sys.modules[module_name] = module
sys.modules[target_module] = module
return module
relative_mappings = {
("compressed_tensors_moe", "compressed_tensors"): "vllm_kunlun.ops.quantization.compressed_tensors_moe",
("layer", "fused_moe"): "vllm_kunlun.ops.fused_moe.layer",
}
if level == 1:
parent = globals.get('__package__', '').split('.')[-1] if globals else ''
key = (module_name, parent)
if key in relative_mappings:
if module_name in sys.modules:
return sys.modules[module_name]
target_module = relative_mappings[key]
module = importlib.import_module(target_module)
sys.modules[module_name] = module
sys.modules[target_module] = module
return module
except Exception:
pass
@@ -77,79 +39,16 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
def import_hook():
"""Apply import hook for VLLM Kunlun"""
if not int(os.environ.get("DISABLE_KUNLUN_HOOK", "0")):
builtins.__import__ = _custom_import
try:
modules_to_preload = [
"vllm_kunlun.ops.quantization.compressed_tensors_moe",
"vllm_kunlun.ops.fused_moe.custom_ops",
"vllm_kunlun.ops.fused_moe.layer",
"vllm_kunlun.ops.quantization.fp8",
]
for module_name in modules_to_preload:
importlib.import_module(module_name)
except Exception:
pass
builtins.__import__ = _custom_import
def register():
"""Register the Kunlun platform"""
from .utils import redirect_output
from .vllm_utils_wrapper import direct_register_custom_op, patch_annotations_for_schema
patch_bitsandbytes_loader()
import_hook()
if envs.VLLM_USE_V1:
# patch_V1blockTable()
patch_V1top_p_K()
# TODO fixed fast top & k for vLLM 0.10.2,
pass
else:
patch_sampler()
return "vllm_kunlun.platforms.kunlun.KunlunPlatform"
def register_model():
"""Register models for training and inference"""
from .models import register_model as _reg
_reg()
# [monkey patach sampler]
import sys
import sys, importlib, warnings
def patch_bitsandbytes_loader():
try:
# 载入你插件里自定义的 direct_register_custom_op 实现
custom_utils = importlib.import_module("vllm_kunlun.models.model_loader.bitsandbytes_loader")
# 覆盖 vllm.utils
sys.modules["vllm.model_executor.model_loader.bitsandbytes_loader"] = custom_utils
print("[vllm_kunlun] bitsandbytes_loader patched ->", custom_utils.__file__)
except Exception as e:
warnings.warn(f"[vllm_kunlun] bitsandbytes_loader patch failed: {e!r}")
def patch_sampler():
try:
custom_sampler = importlib.import_module("vllm_kunlun.ops.sample.sampler")
sys.modules["vllm.model_executor.layers.sampler"] = custom_sampler
print("[vllm_kunlun] sampler patched ->", custom_sampler.__file__)
except Exception as e:
warnings.warn(f"[vllm_kunlun] sampler patch failed: {e!r}")
def patch_V1top_p_K():
try:
custom_sampler = importlib.import_module("vllm_kunlun.v1.sample.ops.topk_topp_sampler")
sys.modules["vllm.v1.sample.ops.topk_topp_sampler"] = custom_sampler
print("[vllm_kunlun] V1sampler top p & k patched ->", custom_sampler.__file__)
except Exception as e:
warnings.warn(f"[vllm_kunlun] V1 sampler top p & k patch failed: {e!r}")
def patch_V1blockTable():
try:
custom_sampler = importlib.import_module("vllm_kunlun.v1.worker.block_table")
sys.modules["vllm.v1.worker.block_table"] = custom_sampler
print("[vllm_kunlun] V1 block table patched ->", custom_sampler.__file__)
except Exception as e:
warnings.warn(f"[vllm_kunlun] V1 block table patch failed: {e!r}")
# 在模块导入时自动应用补丁
import_hook()
_reg()

View File

@@ -80,6 +80,14 @@ def register_model():
ModelRegistry.register_model(
"GptOssForCausalLM",
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM")
ModelRegistry.register_model(
"DeepseekV32ForCausalLM",
"vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM")
def register_quant_method():
"""to do"""

File diff suppressed because it is too large Load Diff

View File

@@ -16,7 +16,7 @@ from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm_kunlun.ops.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
@@ -176,7 +176,7 @@ class MLPBlock(torch.nn.Module):
x = sequence_parallel_chunk(x)
g = self.router(x)
x = self.experts(hidden_states=x, router_logits=g, linear_weights=self.router.weight)
x = self.experts(hidden_states=x, router_logits=g)
if self.is_sequence_parallel:
x = tensor_model_parallel_all_gather(x.contiguous(), 0)

View File

@@ -21,7 +21,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
)
from vllm.logger import init_logger
from vllm_kunlun.ops.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -185,8 +185,7 @@ class MiMoV2MoE(nn.Module):
gate_input = hidden_states
router_logits = self.gate(gate_input)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits, linear_weights=self.gate.weight
)
hidden_states=hidden_states, router_logits=router_logits)
return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states

View File

@@ -20,4 +20,8 @@ import vllm_kunlun.ops.layernorm
import vllm_kunlun.ops.quantization.awq
import vllm_kunlun.ops.quantization.gptq
import vllm_kunlun.ops.vocab_parallel_embedding
import vllm_kunlun.ops.linear
import vllm_kunlun.ops.linear
import vllm_kunlun.ops.quantization.kernels.scaled_mm.cutlass
import vllm_kunlun.ops.vocab_parallel_embedding
import vllm_kunlun.ops.quantization.compressed_tensors_moe
import vllm_kunlun.ops.fused_moe.layer

View File

@@ -417,7 +417,6 @@ class KunlunOps:
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
linear_weights: torch.Tensor,
ep_rank: int,
moe_top_k: int,
renormalize: bool,

View File

@@ -108,7 +108,7 @@ class SiluAndMul(CustomOp):
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
xtorch_ops.swiglu(x, out)
torch.ops._C.silu_and_mul(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:

View File

@@ -0,0 +1,260 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
import xtorch_ops
logger = init_logger(__name__)
if current_platform.is_cuda():
try:
import vllm._flashmla_C # noqa: F401
_flashmla_C_AVAILABLE = True
except ImportError:
_flashmla_C_AVAILABLE = False
else:
_flashmla_C_AVAILABLE = False
if current_platform.is_cuda():
try:
import vllm._flashmla_extension_C # noqa: F401
_flashmla_extension_C_AVAILABLE = True
except ImportError:
_flashmla_extension_C_AVAILABLE = False
else:
_flashmla_extension_C_AVAILABLE = False
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
"""
Return: is_supported_flag, unsupported_reason (optional).
"""
return True, None
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int = 1,
num_heads_k: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
# return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
cache_seqlens_cpu = cache_seqlens.cpu()
return cache_seqlens_cpu, cache_seqlens
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
softmax_lse = None
out = torch.ones(q.size(0), q.size(1), q.size(2), head_dim_v, dtype= q.dtype, device=q.device)
kv_lora_rank = head_dim_v
qk_rope_head_dim = q.size(3) - head_dim_v
head_dim = k_cache.shape[3]
page_block_size = k_cache.shape[1]
k_cache = k_cache.view(-1, 1, page_block_size, head_dim)
# todo: optimize memcp
# q_c = q[..., : kv_lora_rank].contiguous()
# q_r = q[..., kv_lora_rank :].contiguous()
is_context = False
vo_head_dim = -1
xtorch_ops.paged_attention(out,
q,
k_cache, None,
block_table,
tile_scheduler_metadata, # context_lens_cpu
num_splits, # context_lens_xpu
is_context,
causal,
vo_head_dim,
kv_lora_rank,
qk_rope_head_dim,
softmax_scale,
q_r=q)
return out, softmax_lse
def kunlun_flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
cache_seqlens: torch.Tensor,
cache_seqlens_cpu: torch.Tensor,
head_dim_v: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
max_seq_kv: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_tokens_kv, head_dim).
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format.
indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
max_seq_kv: seq中最大的kv长度
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
max_logits: (batch_size, seq_len_q, num_heads_q), torch.float32.
p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32.
"""
assert not is_fp8_kvcache, "By now, the kernel does not support uint8 kv cache."
assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if indices is not None:
# NOTE (zyongye): sparse attention is also causal
# since it only attend to the tokens before
# but here `causal` should not be specified
assert not causal, \
"causal must be `false` if sparse attention is enabled."
q_r, pe_cache = None, None # 当q_r和pe_cache为空时为packed模式
batch_size, seq_len_q, num_heads_q, head_dim = q.shape
kv_lora_rank = head_dim_v
rope_head_dim = head_dim - kv_lora_rank
out = torch.zeros([batch_size, seq_len_q, num_heads_q, kv_lora_rank],
dtype=q.dtype, device=q.device)
max_logits = torch.zeros([batch_size, seq_len_q, num_heads_q],
dtype=torch.float32, device=q.device)
p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q],
dtype=torch.float32, device=q.device)
xtorch_ops.fwd_kvcache_mla(
q_c=q,
kv_cache=k_cache,
indices=indices,
kv_lod_cpu=cache_seqlens_cpu,
max_seq_kv=max_seq_kv,
softmax_scale=softmax_scale,
# q_r=q_r,
# pe_cache=pe_cache,
out=out,
max_logits=max_logits,
p_sums=p_sums,
kv_lod_xpu=cache_seqlens,
)
return out, max_logits, p_sums
def flash_mla_sparse_prefill(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
q_lod_xpu: torch.Tensor,
q_lod_cpu: torch.Tensor,
d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel
Args:
- q: [s_q, h_q, d_qk], bfloat16
- kv: [s_kv, d_qk], bfloat16
- indices: [s_q, h_kv, topk], int32.
Invalid indices should be set to -1 or numbers >= s_kv
- sm_scale: float
- q_lod_xpu: [batch+1], int32, q的每个seq长度的累加信息, 长度为batch_num + 1 (为空则表示q定长).
- d_v: The dimension of value vectors. Can only be 512
Returns:
- (output, max_logits, lse)
About the definition of output,
max_logits and lse, please refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
s_q, h_q, d_qk = q.shape
out = torch.zeros([s_q, h_q, d_v], dtype=q.dtype, device=q.device)
max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
xtorch_ops.sparse_prefill_fwd_opt(
q=q,
kv=kv,
indices=indices,
qlod_cpu=q_lod_cpu,
qlod_xpu=q_lod_xpu,
kvlod_cpu=q_lod_cpu,
kvlod_xpu=q_lod_xpu,
sm_scale=sm_scale,
d_v=d_v,
is_causal=True, #aiak这个值为true这是为啥
out=out,
max_logits=max_logits,
lse=lse,
)
# NOTE: Compared with torch.ops._flashmla_C.sparse_prefill_fwd,
# out_scale = 1 / math.log2(math.e)
# gpu_max_logits * out_scale = kunlun_lse
# gpu_lse * out_scale = kunlun_lse
return out, max_logits, lse
#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# return ....
#

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
from vllm_kunlun.ops.attention.layer import Attention
# from vllm.attention import Attention
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
@dataclass
class MLAModules:
"""Modules used in MLA.
"""
kv_a_layernorm: torch.nn.Module
kv_b_proj: torch.nn.Module
rotary_emb: torch.nn.Module
o_proj: torch.nn.Module
fused_qkv_a_proj: Optional[torch.nn.Module]
kv_a_proj_with_mqa: Optional[torch.nn.Module]
q_a_layernorm: Optional[torch.nn.Module]
q_b_proj: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module]
indexer: Optional[torch.nn.Module]
is_sparse: bool
topk_indices_buffer: Optional[torch.Tensor]
@CustomOp.register("vllm_kunlun_multi_head_latent_attention")
class MultiHeadLatentAttention(CustomOp):
"""MLA layer registered as CustomOp.
Note that currently MLA ignores the enable/disable mechanism of CustomOp
because there is only one in-tree implementation in forward_native.
TODO: implement this with a new PluggableLayer mechanism.
This class takes positions and hidden_states as input.
The input tensors can either contain prefill tokens or decode tokens.
The class does the following:
1. MLA Preprocess.
2. Perform multi-head attention to prefill tokens and
multi-query attention to decode tokens separately.
3. Return the output tensor.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
mla_modules: MLAModules,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
self.q_a_layernorm = mla_modules.q_a_layernorm
self.q_b_proj = mla_modules.q_b_proj
self.q_proj = mla_modules.q_proj
self.kv_a_layernorm = mla_modules.kv_a_layernorm
self.kv_b_proj = mla_modules.kv_b_proj
self.rotary_emb = mla_modules.rotary_emb
self.o_proj = mla_modules.o_proj
self.indexer = mla_modules.indexer
self.is_sparse = mla_modules.is_sparse
if self.indexer is not None:
assert hasattr(self.indexer, "topk_tokens")
self.topk_tokens = self.indexer.topk_tokens
self.topk_indices_buffer = mla_modules.topk_indices_buffer
# In the MLA backend, kv_cache includes both k_c and
# pe (i.e. decoupled position embeddings). In particular,
# the concat_and_cache_mla op requires
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
# i.e.
# kv_lora_rank + qk_rope_head_dim == head_size
self.mla_attn = Attention(
num_heads=self.num_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=scale,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sparse=mla_modules.is_sparse,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
kv_b_proj=self.kv_b_proj,
indexer=self.indexer,
)
self.prefix = prefix
def forward_native(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
q_c = None
kv_lora = None
if self.q_lora_rank is not None:
assert self.fused_qkv_a_proj is not None, \
"fused_qkv_a_proj is required when q_lora_rank is not None"
assert self.q_a_layernorm is not None, \
"q_a_layernorm is required when q_lora_rank is not None"
assert self.q_b_proj is not None, \
"q_b_proj is required when q_lora_rank is not None"
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
assert self.kv_a_proj_with_mqa is not None, \
"kv_a_proj_with_mqa is required when q_lora_rank is None"
assert self.q_proj is not None, \
"q_proj is required when q_lora_rank is None"
kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c)
q = q.view(-1, self.num_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
if self.indexer and self.is_sparse:
_topk_indices = self.indexer(hidden_states, q_c, positions,
self.rotary_emb)
hidden_states_shape_0 = 0
if isinstance(hidden_states, tuple):
x_q, x_scale = hidden_states
hidden_states_shape_0 = x_q.shape[0]
else:
hidden_states_shape_0 = hidden_states.shape[0]
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states_shape_0,
self.num_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
def forward_cuda(self, *args, **kwargs):
return self.forward_native(*args, **kwargs)

View File

@@ -0,0 +1,114 @@
import torch
import xtorch_ops
def int8_mqa_logits(
q: torch.Tensor,
kv: tuple[torch.Tensor, torch.Tensor],
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
"""Compute FP8 MQA logits for a single sequence without KV paging.
Args:
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
logits = torch.empty((q.shape[0], kv[0].shape[0]), dtype=torch.float32, device=q.device)
context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
xtorch_ops.I8_mqa_logits(
q=q,
fused_kv_cache=kv,
weights=weights,
context_q_lens=(context_q_lens_xpu.cpu(), context_q_lens_xpu),
context_k_lens=(context_k_lens_xpu.cpu(), context_k_lens_xpu),
logits=logits,
clean_logits=True,
use_xfa_boost=False,
)
seq_len_kv = kv[0].shape[0]
# mask参考 https://github.com/vllm-project/vllm/blob/v0.11.0/tests/kernels/attention/test_deepgemm_attention.py 的_ref_fp8_mqa_logits函数的实现
mask_lo = (torch.arange(0, seq_len_kv, device=cu_seqlen_ks.device)[None, :]
>= cu_seqlen_ks[:, None])
mask_hi = (torch.arange(0, seq_len_kv, device=cu_seqlen_ke.device)[None, :]
< cu_seqlen_ke[:, None])
mask = mask_lo & mask_hi
logits = logits.masked_fill(~mask, float("-inf"))
return logits
def int8_paged_mqa_logits(
q_fp8: torch.Tensor,
kv_cache_fp8: torch.Tensor,
weights: torch.Tensor,
context_lens: torch.Tensor,
context_lens_cpu: torch.Tensor,
block_tables: torch.Tensor,
schedule_metadata: torch.Tensor,
max_model_len: int,
) -> torch.Tensor:
"""Compute FP8 MQA logits using paged KV-cache.
Args:
q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
4 bytes per (block,pos) store the `float` dequant scale.
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
context_lens: Tensor of shape [B], dtype int32; effective context length
for each batch element.
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
block indices to physical blocks in the paged cache.
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
used to distribute work across SMs.
max_model_len: Maximum sequence length used to size the logits output.
Returns:
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
"""
batch_size, next_n, _, D = q_fp8.shape
num_blocks, block_size, _, _ = kv_cache_fp8.shape
kv_cache_fp8=kv_cache_fp8.view(num_blocks, -1)
k_val = kv_cache_fp8[:,:block_size*D].view(torch.int8)
k_val = k_val.view(-1,block_size, 1, D)
k_scale_list = []
for block_tables_idx in range(block_tables.shape[0]):
k_scale_item = kv_cache_fp8[block_tables[block_tables_idx], block_size *
D:].view(-1, 4)
k_scale_list.append(k_scale_item)
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).view(-1,max_model_len)
kv_cache = [k_val, k_scale]
weights = weights.view(batch_size,next_n,-1)
logits = torch.empty((batch_size, next_n, max_model_len), dtype=torch.float32, device=q_fp8.device)
xtorch_ops.I8_paged_mqa_logits(
q=q_fp8,
fused_kv_cache=kv_cache,
weights=weights,
context_lens=[context_lens_cpu, context_lens],
block_table=block_tables,
max_context_len=max_model_len,
clean_logits=True,
out=logits,
use_xfa_boost=False
)
logits = logits.view(-1, max_model_len)
return logits

View File

@@ -1,37 +1,14 @@
"""layer.py"""
from contextlib import nullcontext
from typing import Callable, Optional, Union, get_args
import torch
import os
from typing import Callable, Optional
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed import get_ep_group
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.model_executor.layers.fused_moe import FusedMoE as VllmFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase as VllmFusedMoEMethodBase
from vllm.model_executor.layers.fused_moe.layer import (
UnquantizedFusedMoEMethod as VllmUnquantizedFusedMoEMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig)
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm_kunlun.ops.quantization.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
class FusedMoEMethodBase(VllmFusedMoEMethodBase):
"""FusedMoEMethodBase"""
moe: FusedMoEConfig
@CustomOp.register("vllm_kunlun_unquantized_fused_moe")
class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
"""UnquantizedFusedMoEMethod"""
def apply(
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
@@ -45,6 +22,7 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
@@ -52,40 +30,12 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
linear_weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""apply"""
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
return self.forward_kunlun(x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
linear_weights=linear_weights,
e_score_correction_bias=e_score_correction_bias)
def forward_kunlun(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
linear_weights: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""forward_kunlun"""
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
if self.moe.use_ep:
@@ -93,21 +43,18 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
layer.w13_weight,
layer.w2_weight,
router_logits,
linear_weights,
self.moe.ep_rank,
top_k,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group
)
topk_group=topk_group)
else:
return ops.fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
linear_weights,
self.moe.ep_rank,
top_k,
renormalize=renormalize,
@@ -118,12 +65,13 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
w1_bias = layer.w13_bias,
w2_bias = layer.w2_bias,
)
w2_bias = layer.w2_bias)
class FusedMoE(VllmFusedMoE):
"""FusedMoE"""
def __init__(self,
UnquantizedFusedMoEMethod.apply = apply
class VllmFusedMoE(FusedMoE):
def __init__(
self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
@@ -141,198 +89,47 @@ class FusedMoE(VllmFusedMoE):
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
num_redundant_experts: int = 0,
is_sequence_parallel=False,
has_bias: bool = False,
is_sequence_parallel=False,
zero_expert_num: Optional[int] = 0,
zero_expert_type: Optional[str] = None,
):
super().__init__(
num_experts=num_experts, # Global number of experts
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=reduce_results,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
quant_config=quant_config,
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
prefix=prefix,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
enable_eplb=enable_eplb,
num_redundant_experts=num_redundant_experts,
)
vllm_config = get_current_vllm_config()
if vllm_config.model_config is not None:
model_dtype = vllm_config.model_config.dtype
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
model_dtype = params_dtype
moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
num_experts=num_experts, # Global number of experts
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=reduce_results,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
quant_config=quant_config,
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
prefix=prefix,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
enable_eplb=enable_eplb,
num_redundant_experts=num_redundant_experts,
has_bias=has_bias,
# quant_config=quant_config,
)
self.moe_config = moe
self.quant_config = quant_config
is_sequence_parallel=is_sequence_parallel,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type)
self.has_bias=has_bias
self.register_parameter("w13_bias", None)
self.register_parameter("w2_bias", None)
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
# assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if self.enable_eplb:
from vllm_kunlun.ops.quantization.fp8 import (
Fp8MoEMethod)
if not isinstance(quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError("EPLB is only supported for FP8 "
"quantization for now.")
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor = None,
linear_weights: torch.Tensor = None):
"""forward"""
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits)
else:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[self.layer_name]
assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits, linear_weights)
# return torch.ops.vllm.moe_forward(hidden_states, router_logits,
# self.layer_name)
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
linear_weights: torch.Tensor = None):
"""forward_impl"""
assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = (
self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
linear_weights=linear_weights
)
if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
return final_hidden_states
@classmethod
def make_expert_params_mapping(
cls,
ckpt_gate_proj_name: str,
ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int,
num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
num_physical_experts = num_experts + num_redundant_experts
# In the returned mapping:
# - `expert_id` is the physical expert id
# - `weight_name` contains the weight name of the logical expert
# So that we should map the expert id to logical in `weight_name`
physical_to_logical_map = \
EplbState.build_initial_global_physical_to_logical_map(
num_experts, num_redundant_experts)
return [
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_" if weight_name
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
expert_id, shard_id) for expert_id in range(num_physical_experts)
for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]
FusedMoE=VllmFusedMoE

View File

@@ -1,244 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from enum import Enum
from typing import Callable, Optional, Union
import torch
from typing import Any, Literal, Optional, cast, Callable, Optional
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
from compressed_tensors.config import (CompressionFormat,
SparsityCompressionConfig,
SparsityStructure)
from compressed_tensors.quantization import (ActivationOrdering,
QuantizationStrategy)
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils import replace_parameter
# TODO: import position will be changed after 0.9.0
# vllm.model_executor.layers.fused_moe.fused_moe --> vllm.model_executor.layers.fused_moe
def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
"""modify scale -> abs max"""
layer.w13_weight = torch.nn.Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
layer.w13_weight_scale.data * 127, requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
layer.w2_weight_scale.data * 127, requires_grad=False
)
from vllm.model_executor.utils import set_weight_attrs
import re
import xtorch_ops
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
klx_process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
hidden_states = x
global_num_experts, up_gate_size, _ = layer.w13_weight.shape
M, N = hidden_states.shape
hidden_dim = layer.w2_weight.shape[1]
normed_score = torch.empty(M,
top_k,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
top_k,
dtype=torch.int32,
device=hidden_states.device)
num_blocks = 12
block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
)
from safetensors.torch import load_file as safe_load_file
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def get_moe_method(quant_config, layer) -> "CompressedTensorsMoEMethod":
tsm = getattr(quant_config, "target_scheme_map", None) or {}
linear_cfg = None
for k in ("Linear", "FusedMoE", "MoE", "Moe", "Experts"):
if k in tsm and isinstance(tsm[k], dict):
linear_cfg = tsm[k]; break
if not linear_cfg:
# print("target_scheme_map missing; fallback to INT8(W8A8) method")
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
wq = linear_cfg.get("weights"); aq = linear_cfg.get("input_activations")
if not wq or not aq:
# print("incomplete scheme; fallback to INT8(W8A8)")
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
# 其它分流按需;默认回落:
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
# copied from vllm 0.9.0
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
# 直接创建默认的量化配置字典,避免 QuantizationArgs 的验证问题
# print("Creating default INT8 quantization config for MoE")
# 创建默认的权重量化配置字典
self.weight_quant = type('WeightQuant', (), {
'type': 'int',
'num_bits': 8,
'strategy': 'channel',
'group_size': 128,
'symmetric': True,
'dynamic': False,
'actorder': 'none',
'observer': None,
'observer_kwargs': {},
'block_structure': None
})()
# 创建默认的输入激活量化配置字典
self.input_quant = type('InputQuant', (), {
'type': 'int',
'num_bits': 8,
'strategy': 'token',
'group_size': 128,
'symmetric': True,
'dynamic': True,
'actorder': 'none',
'observer': None,
'observer_kwargs': {},
'block_structure': None
})()
# 修改比较方式,直接比较字符串
per_channel = (
self.weight_quant.strategy == "channel"
and self.input_quant.strategy == "token")
if not per_channel:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.")
def create_weights1(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
# 权重先用浮点占位,便于从 ckpt 加载原始权重
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype), # 通常是 torch.bfloat16
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# 通道 scalefloat32 + 二维 [E, out](与 fused_moe/UT 对齐)
w13_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# 输入 scale 动态计算即可
layer.w13_input_scale = None
layer.w2_input_scale = None
def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=torch.int8), # 直接使用 int8
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=torch.int8), # 直接使用 int8
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# 缩放因子
w13_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# 输入 scale 动态计算
layer.w13_input_scale = None
layer.w2_input_scale = None
@torch.no_grad()
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
return
#原始权重转 float32 做统计更稳健
w13_f = layer.w13_weight.float()
w2_f = layer.w2_weight.float()
# 每列(abs_max) -> per-column scaleout 维在 dim=1列在 dim=-1
qmax = 127.0
w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1) # [E, 2N]
w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax # [E, 2N], float32
w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
# 量化:用 3D scale 广播,存回 2D scale
w13_scale_3d = w13_scale_2d.unsqueeze(-1) # [E, 2N, 1]
w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
w13_q = torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8)
w2_q = torch.round(w2_f / w2_scale_3d ).clamp_(-128, 127).to(torch.int8)
# 可选:若你的 fused/kernel 期望 scale 预乘 127与某些 UT 后端一致),打开下面两行:
w13_scale_2d = w13_scale_2d * 127.0
w2_scale_2d = w2_scale_2d * 127.0
# 回写参数:权重 int8scale 用 float32 + 2D
replace_parameter(layer, 'w13_weight', torch.nn.Parameter(w13_q, requires_grad=False))
replace_parameter(layer, 'w2_weight', torch.nn.Parameter(w2_q, requires_grad=False))
replace_parameter(layer, 'w13_weight_scale',
torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False))
replace_parameter(layer, 'w2_weight_scale',
torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False))
# 简要检查
print(f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}")
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False, # 添加这个参数
expert_load_view: Optional[torch.Tensor] = None, # 添加这个参数
logical_to_physical_map: Optional[torch.Tensor] = None, # 添加这个参数
logical_replica_count: Optional[torch.Tensor] = None, # 添加这个参数
linear_weights: Optional[torch.Tensor] = None, # 添加这个参数
) -> torch.Tensor:
output = torch.empty_like(x)
torch.ops._C.moe_ffn_per_token_block(
x=x,
inter_weight=layer.w13_weight,
inter_scale=layer.w13_weight_scale,
outer_weight=layer.w2_weight,
outer_scale=layer.w2_weight_scale,
top_k=top_k,
global_num_experts=global_num_experts,
linear_weights=linear_weights,
expert_map=expert_map,
activation=activation,
output=output,
use_expert_parallel=expert_map is not None,
ep_size=expert_map.size(0) if expert_map is not None else 1,
ep_rank=0,
router_logits = router_logits.float()
if scoring_func == "softmax":
torch.ops._C.moe_softmax_topk_norm(
x=router_logits,
normed_score=normed_score,
topk_index=topk_ids,
block_statistic=None,
stable=True)
elif scoring_func == "sigmoid":
torch.ops._C.moe_sigmoid_group_topk_norm(
x=router_logits,
norm_score=normed_score,
topk_index=topk_ids,
block_static=block_statistic,
bias=e_score_correction_bias,
n_group=num_expert_group,
topk_group=topk_group,
scale=routed_scaling_factor,
)
return output
print("[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
--> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod")
moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod)
y = torch.empty(M,top_k,
layer.w13_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
moe_expand = moe_expand.view(M * top_k, hidden_dim)
x_shape = moe_expand.shape
x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True)
torch.ops._C.moe_fc(
x=x_q,
x_perchannel_max=x_scale,
weight=layer.w13_weight,
w_perchannel_max=layer.w13_weight_scale,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=y,
topk_ids=topk_ids,
# sort_mode=False,
act=None)
d = y.shape[-1] // 2
output_shape = (y.shape[:-1] + (d, ))
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out = torch.empty(M,top_k,
layer.w2_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
out1 = out1.reshape(-1, out1.shape[-1])
x_shape = out1.shape
x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
torch.ops._C.moe_fc(
x=x_q,
x_perchannel_max=x_scale,
weight=layer.w2_weight,
w_perchannel_max=layer.w2_weight_scale,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=out,
topk_ids=topk_ids,
# sort_mode=False,
act=None)
dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
torch.ops._C.moe_post(
x=out,
moe_index=sorted_tokens_idx,
normed_scale=normed_score,
dequant_scale=dequant_scale,
y=output
)
return output
CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = process_weights_after_loading
CompressedTensorsW8A8Int8MoEMethod.apply = apply

View File

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ScaledMMLinearLayerConfig
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import CutlassScaledMMLinearKernel
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
def can_implement_kunlun(
cls, c: ScaledMMLinearLayerConfig=None) -> tuple[bool, Optional[str]]:
return True, None
def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
"""modify scale -> abs max"""
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data * 127, requires_grad=False)
def process_weights_after_loading_kunlun(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Cutlass kernels need transposed weight.
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer, self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False))
# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale,
layer.logical_widths)
replace_parameter(
layer, self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False))
# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name)
if self.config.input_symmetric:
replace_parameter(
layer, self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False))
setattr(layer, self.i_zp_name, None)
else:
input_zero_point = getattr(layer, self.i_zp_name)
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max -
int8_traits.min)
replace_parameter(
layer, self.i_s_name,
torch.nn.Parameter(scale, requires_grad=False))
# AZP loaded as int8 but used as int32
azp = (int8_traits.min -
range_min / scale).to(dtype=torch.int32)
replace_parameter(layer, self.i_zp_name,
torch.nn.Parameter(azp, requires_grad=False))
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
if not self.config.input_symmetric:
weight = getattr(layer, self.w_q_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if self.config.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
setattr(layer, self.azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False))
else:
setattr(layer, self.azp_adj_name, None)
klx_process_weights_after_loading(layer)
def apply_weights_kunlun(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
x_q, x_scale, out = None, None, None
w_t_shape = layer.weight.T.shape
if isinstance(x, tuple):
x_q, x_scale = x
out = torch.empty((x_q.shape[0], w_t_shape[0]),
dtype=torch.bfloat16,
device=x_q.device)
else:
x_shape = x.shape
x_q = torch.empty(x_shape, dtype=torch.int8, device=x.device)
x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=x.device)
out = torch.empty((x_shape[0], w_t_shape[0]),
dtype=x.dtype,
device=x.device)
torch.ops._C.quant2d(x, x_q, x_scale, force_sdnn=True)
torch.ops._C.gemm_I8_I8_bf16_nt(x_q, x_scale, layer.weight.T.data, layer.weight_scale.data, out)
return out
CutlassScaledMMLinearKernel.apply_weights = apply_weights_kunlun
CutlassScaledMMLinearKernel.can_implement = can_implement_kunlun
CutlassScaledMMLinearKernel.process_weights_after_loading = process_weights_after_loading_kunlun

View File

@@ -19,7 +19,9 @@ import torch
import xspeedgate_ops
import os
from vllm.model_executor.layers.rotary_embedding import (
RotaryEmbedding, YaRNScalingRotaryEmbedding, DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding)
RotaryEmbedding, YaRNScalingRotaryEmbedding,
DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding,
DeepseekScalingRotaryEmbedding)
from typing import Optional, Tuple
def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
@@ -143,12 +145,15 @@ def vllm_kunlun_mrope_forward_cuda(
return query, key
DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward
DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward
DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
def Split_Norm_Rope(
qkv: torch.Tensor,
cos_sin_cache: torch.Tensor,

View File

@@ -177,6 +177,8 @@ class KunlunPlatform(Platform):
# if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
# we default to FlashMLA backend, so we need to force the blocksize
# here
use_sparse = hasattr(vllm_config.model_config.hf_config,
"index_topk")
use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \
or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
from vllm.attention.ops.flashmla import is_flashmla_supported
@@ -185,6 +187,11 @@ class KunlunPlatform(Platform):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashMLA backend.")
if use_sparse and cache_config.block_size != 64:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashMLASparse "
"backend.")
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
@@ -224,6 +231,14 @@ class KunlunPlatform(Platform):
Returns:
str: Class name of the attention backend.
"""
if use_mla:
if use_sparse:
logger.info_once("Using Sparse MLA backend on V1 engine.")
# return ("vllm.v1.attention.backends.mla.flashmla_sparse."
# "FlashMLASparseBackend")
return ("vllm_kunlun.v1.attention.backends.mla.flashmla_sparse."
"FlashMLASparseBackend")
return "vllm_kunlun.v1.attention.backends.mla.flashmla.FlashMLABackend"
if use_v1:
return "vllm_kunlun.v1.attention.backends.kunlun_attn.KunlunAttentionBackend"
elif not use_mla:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,202 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional, Union
import torch
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm_kunlun.ops.attention.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm_kunlun.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "FLASHMLA"
@staticmethod
def get_metadata_cls() -> type["FlashMLAMetadata"]:
return FlashMLAMetadata
@staticmethod
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
return FlashMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
tile_scheduler_metadata: torch.Tensor
num_splits: torch.Tensor
@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
pass
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
FlashMLAMetadata)
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
device_properties = torch.cuda.get_device_properties(self.device)
num_sms = device_properties.multi_processor_count
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.cg_buf_tile_scheduler_metadata = torch.zeros(
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
# TileSchedulerMetaDataSize = 8
(num_sms, 8),
device=self.device,
dtype=torch.int32,
)
self.cg_buf_num_splits = torch.empty(
(vllm_config.scheduler_config.max_num_seqs + 1),
device=self.device,
dtype=torch.int32)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens_device,
self.num_q_heads,
1, # MQA for the decode path
)
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
assert self.cg_buf_tile_scheduler_metadata is not None
assert self.cg_buf_num_splits is not None
sm_parts = tile_scheduler_metadata.size(0)
# Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
tile_scheduler_metadata_view = \
self.cg_buf_tile_scheduler_metadata[:sm_parts]
tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
tile_scheduler_metadata = tile_scheduler_metadata_view
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]
num_splits_view.copy_(num_splits)
# Num splits needs to monotonically increasing
# (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
# it needs to monotonically increasing by 1)
self.cg_buf_num_splits[n:].fill_(num_splits[-1])
num_splits = num_splits_view
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
)
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
is_supported, reason = is_flashmla_supported()
assert is_supported, reason
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashMLAImpl")
def _forward_decode(
self,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
# TODO: (zyongye) decode function for mla here
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q = torch.cat(q, dim=-1)
assert isinstance(q, torch.Tensor)
o, lse = flash_mla_with_kvcache(
q=q.unsqueeze(1), # Add seqlen dim of 1 (decode)
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
)
return o, lse

View File

@@ -0,0 +1,752 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional
import numpy as np
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata)
from vllm.attention.backends.utils import get_mla_dims
from vllm_kunlun.ops.attention.flashmla import (flash_mla_sparse_prefill,
flash_mla_with_kvcache,
get_mla_metadata,
kunlun_flash_mla_with_kvcache)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.distributed import get_tp_group
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
logger = init_logger(__name__)
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512
`float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
The first `float32` is the scale for the first 128 `float8_e4m3` values,
the second for the next 128, and so on.
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
part is not quantized for accuracy.
"""
def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor:
# Convert base-2 LSE to natural-log LSE
# Keep FP32 for numerical stability during the merge.
return (lse_base2.to(torch.float32) * math.log(2.0))
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "FLASHMLA_SPARSE_VLLM_V1"
@staticmethod
def get_metadata_cls() -> type[AttentionMetadata]:
return FlashMLASparseMetadata
@staticmethod
def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
return FlashMLASparseMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashMLASparseImpl"]:
return FlashMLASparseImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if cache_dtype_str == "fp8_ds_mla":
# custom storage fromat is 656 bytes
# see FlashMLA readme.md for details
return (num_blocks, block_size, 656)
else:
return (num_blocks, block_size, head_size)
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@dataclass
class MLASparsePrefillMetadata:
# NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because
# the kernel is not from flashmla
block_table: torch.Tensor = None
has_context: bool = False
context_lens: Optional[torch.Tensor] = None
# Sequence lengths (context + query) for prefill requests
# Shape: [num_prefill_reqs]
seq_lens: torch.Tensor = None
# Request ID for each token: -1 for decode tokens, request index
# (0, 1, 2, ...) for prefill tokens.
# Shape: [num_actual_tokens]
request_ids: torch.Tensor = None
query_start_loc: torch.Tensor = None
query_start_loc_cpu: torch.Tensor = None
@dataclass
class FlashMLASparseDecodeAndContextMetadata:
scheduler_metadata: torch.Tensor = None
num_splits: torch.Tensor = None
cache_lens: torch.Tensor = None
prefill_context_lengths: Optional[torch.Tensor] = None
prefill_new_k_start_locs: Optional[torch.Tensor] = None
dummy_block_table: torch.Tensor = None
seq_lens: torch.Tensor = None
seq_lens_cpu: torch.Tensor = None
max_seq_len: int = -1 # needed for reshape in spec decode
def filter_prefill_indices(
self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.prefill_context_lengths is not None
prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1)
context_indices = torch.where(indices < prefill_context_lengths,
indices, -1)
new_token_indices = torch.where(indices >= prefill_context_lengths,
indices - prefill_context_lengths, -1)
return context_indices, new_token_indices
@dataclass
class FlashMLASparseMetadata:
num_reqs: int
max_query_len: int
max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
block_table: torch.Tensor
req_id_per_token: torch.Tensor
block_size: int = 64
topk_tokens: int = 2048
num_prefills: int = 0
num_decodes: int = 0
num_prefill_tokens: int = 0
num_decode_tokens: int = 0
decode_metadata: Optional[FlashMLASparseDecodeAndContextMetadata] = None
prefill_metadata: Optional[MLASparsePrefillMetadata] = None
@dataclass
class FP8KernelMetadata:
scheduler_metadata: Optional[torch.Tensor]
num_splits: torch.Tensor
dummy_block_table: torch.Tensor
cache_lens: torch.Tensor
fp8_extra_metadata: Optional[FP8KernelMetadata] = None
@triton.jit
def _convert_req_index_to_global_index_kernel(
req_id_ptr, # int32 [num_tokens]
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
# shapes (compile-time where possible)
max_num_blocks_per_req: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr, # tile width along columns
# strides (in elements)
bt_stride0,
bt_stride1,
ti_stride0,
ti_stride1,
out_stride0,
out_stride1,
):
# program_id(0) -> token_id (row)
# program_id(1) -> tile index along columns
token_id = tl.program_id(0)
tile_id = tl.program_id(1)
# Each program covers BLOCK_N consecutive columns
indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
# Load request id for this token (no mask: grid is exact)
req = tl.load(req_id_ptr + token_id)
# Load token indices for this tile
ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
tok = tl.load(ti_ptr) # int32
# Only token == -1 should propagate as -1
is_invalid_tok = tok < 0
# Compute block id and in-block offset
block_id = tok // BLOCK_SIZE
inblock_off = tok % BLOCK_SIZE
# Guard block_table access
valid_block = block_id < max_num_blocks_per_req
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
base = tl.load(bt_ptr, mask=valid_block, other=0)
# If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
out_val = tl.where(is_invalid_tok | (~valid_block), -1,
base * BLOCK_SIZE + inblock_off)
# Store results
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
tl.store(out_ptr_ij, out_val)
def triton_convert_req_index_to_global_index(
req_id: torch.Tensor, # int32 [num_tokens]
block_table: torch.
Tensor, # int32 [num_requests, max_num_blocks_per_req]
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
BLOCK_SIZE: int = 64,
NUM_TOPK_TOKENS: int = 2048,
BLOCK_N: int = 128, # tile width along columns
):
"""
out[token_id, indice_id] =
block_table[req_id[token_id],
token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
+ token_indices[token_id, indice_id] % BLOCK_SIZE
Only when token_indices[token_id, indice_id] == -1 do we output -1.
For safety, we also output -1 if the derived block_id would be
out-of-bounds.
"""
assert req_id.dtype == torch.int32
assert block_table.dtype == torch.int32
assert token_indices.dtype == torch.int32
assert token_indices.shape[1] == NUM_TOPK_TOKENS
assert NUM_TOPK_TOKENS % BLOCK_N == 0, \
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \
f"BLOCK_N ({BLOCK_N})"
num_tokens = req_id.shape[0]
num_requests, max_num_blocks_per_req = block_table.shape
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
# Ensure contiguous tensors on the same device
req_id_c = req_id.contiguous()
block_table_c = block_table.contiguous()
token_indices_c = token_indices.contiguous()
out = torch.empty_like(token_indices_c)
# Strides in elements
bt_stride0, bt_stride1 = block_table_c.stride()
ti_stride0, ti_stride1 = token_indices_c.stride()
out_stride0, out_stride1 = out.stride()
# Exact 2D grid: tokens × column tiles
grid = (num_tokens, tiles_per_row)
_convert_req_index_to_global_index_kernel[grid](
req_id_c,
block_table_c,
token_indices_c,
out,
# shapes / constexprs
max_num_blocks_per_req,
BLOCK_SIZE,
BLOCK_N,
# strides
bt_stride0,
bt_stride1,
ti_stride0,
ti_stride1,
out_stride0,
out_stride1,
)
return out
def kunlun_convert_req_index_to_global_index(
req_id: torch.Tensor, # int32 [num_tokens]
block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req]
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
BLOCK_SIZE: int = 64,
NUM_TOPK_TOKENS: int = 2048,
):
assert req_id.dtype == torch.int32
assert block_table.dtype == torch.int32
assert token_indices.dtype == torch.int32
assert token_indices.shape[1] == NUM_TOPK_TOKENS
num_tokens = req_id.shape[0]
num_requests, max_num_blocks_per_req = block_table.shape
out = torch.zeros_like(token_indices)
# Compute block_id and inblock_off for all tokens at once
block_id = token_indices // BLOCK_SIZE
inblock_off = token_indices % BLOCK_SIZE
# Create mask for invalid tokens (tok < 0)
invalid_tok_mask = token_indices < 0
# Create mask for out-of-bounds block_id
oob_block_mask = block_id >= max_num_blocks_per_req
# Combine masks - output -1 for either condition
invalid_mask = invalid_tok_mask | oob_block_mask
# Get request IDs expanded to match token_indices shape
req_ids_expanded = req_id.unsqueeze(1).expand(-1, NUM_TOPK_TOKENS)
# Gather base addresses from block_table
# Clamp block_id to avoid index errors (we'll mask these out anyway)
block_id_clamped = torch.clamp(block_id, 0, max_num_blocks_per_req - 1)
# Use advanced indexing to get base addresses
base_addrs = block_table[req_ids_expanded, block_id_clamped]
# Compute the global indices
global_indices = base_addrs * BLOCK_SIZE + inblock_off
# Apply mask: set invalid positions to -1
out = torch.where(invalid_mask, torch.tensor(-1, dtype=torch.int32, device=token_indices.device), global_indices)
return out
def kunlun_concat_and_cache_mla(
kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
k_pe: torch.Tensor, #[num_tokens, pe_dim]
kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
kv_cache_dtype: str,
scale: torch.Tensor
):
num_tokens = slot_mapping.shape[0]
kv_lora_rank = kv_c.shape[1]
pe_dim = k_pe.shape[1]
block_size = kv_cache.shape[1]
def kunlun_fp8_ds_mla():
for token_idx in range(num_tokens):
slot = slot_mapping[token_idx].item()
if slot < 0: continue
block_idx = slot // block_size
block_offset = slot % block_size
kv_c_i = kv_c[token_idx].view(4,kv_lora_rank//4).contiguous()
kv_c_i_int8 = torch.zeros(
kv_c_i.shape,
device=kv_c.device,
dtype=torch.int8,
)
kv_c_i_scale = torch.zeros(
[kv_c_i.shape[0], 1],
device=kv_c.device,
dtype=torch.float32,
)
torch.ops._C.quant2d(kv_c_i, kv_c_i_int8, kv_c_i_scale, force_sdnn=True)
kv_c_i_scale /= 127
kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c_i_int8.view(-1).view(torch.uint8).contiguous()
kv_cache[block_idx, block_offset, kv_lora_rank:kv_lora_rank + 16] = kv_c_i_scale.view(-1).view(torch.uint8).contiguous()
kv_cache[block_idx, block_offset, kv_lora_rank+16:] = k_pe[token_idx, :].view(torch.uint8).contiguous()
def kunlun_mla():
for token_idx in range(num_tokens):
slot = slot_mapping[token_idx].item()
if slot < 0: continue
block_idx = slot // block_size
block_offset = slot % block_size
kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx, :].contiguous()
kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx, :].contiguous()
if (kv_cache_dtype == "fp8_ds_mla"):
assert kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla"
assert pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla"
assert kv_cache.shape[2] == 656 // kv_cache.element_size(), "kv_cache.shape[2] must be 656 bytes for fp8_ds_mla"
assert kv_c.element_size() == 2, "kv_c.element_size() must be 2 for fp8_ds_mla"
assert k_pe.element_size() == 2, "k_pe.element_size() must be 2 for fp8_ds_mla"
kunlun_fp8_ds_mla()
else:
assert kv_cache.shape[2] == kv_lora_rank + pe_dim
kunlun_mla()
@dataclass
class FlashMLASparseMetadataBuilder(
AttentionMetadataBuilder[FlashMLASparseMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.layer_names = layer_names
cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.device = device
# Treat requests with query length <= 1 as decodes to match the
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
# 从最新版本vllm中引入的
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
props = torch.cuda.get_device_properties(device)
sm_count = props.multi_processor_count
self.num_heads = self.model_config.get_num_attention_heads(
parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
self.topk_tokens_tensor = torch.tensor([self.topk_tokens],
device=device,
dtype=torch.int32)
# self.max_model_len_tensor = torch.tensor(
# [self.model_config.max_model_len],
# device=device,
# dtype=torch.int32)
# this is ignored by `flash_mla_with_kvcache` if indices not None
self.dummy_block_table = torch.empty((1, 1),
dtype=torch.int32,
device=self.device)
# Equation taken from FlashMLA/csrc/pybind.cpp
h_q, h_k = self.num_heads, 1
s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest
max_num_sm_parts = int(
max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1))
if current_platform.is_device_capability(100):
max_num_sm_parts *= 2
self.tile_scheduler_metadata_buffer = torch.zeros(
# TileSchedulerMetaDataSize = 8
# see: FlashMLA/csrc/params.h
(max_num_sm_parts, 8),
dtype=torch.int32,
device=device)
self.num_splits_buffer = torch.zeros(
# We pack all the tokens into one batch for sparse attention.
# Otherwise, we can exceed the sm of `get_mla_metadata`.
(
2, ),
dtype=torch.int32,
device=device)
self.req_id_per_token_buffer = torch.zeros(
(vllm_config.scheduler_config.max_num_batched_tokens, ),
dtype=torch.int32,
device=device)
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlashMLASparseMetadata:
num_tokens = common_attn_metadata.num_actual_tokens
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\
.copy_(torch.from_numpy(req_id_per_token), non_blocking=True)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
fp8_extra_metadata = None
if self.use_fp8_kv_cache:
cache_seqlens_cpu, cache_seqlens = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor,
)
fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=None,
num_splits=None,
# cache_lens and block_table are basically unused in sparse case
# but the decode kernel will treat -1 and indices >= cache_lens
# as invalid so we make sure cache_lens is large enough to not
# accidentally mark indices invalid, we will use -1 exclusively
# to mark invalid indices
cache_lens=cache_seqlens_cpu,
dummy_block_table=self.dummy_block_table)
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold or 1,
require_uniform=True,
)
)
# For pure decode batches, prefill_request_id will be None
# For mixed batches, it will have -1 for decode and request_id for prefill
prefill_metadata = None
if num_prefills > 0:
prefill_metadata = MLASparsePrefillMetadata(
query_start_loc = common_attn_metadata.query_start_loc[num_decodes:] - common_attn_metadata.query_start_loc[num_decodes], #因为prefiil、decode请求是分离所以需要对q进行切分故需调整该值
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[num_decodes:] - common_attn_metadata.query_start_loc_cpu[num_decodes],
)
decode_metadata = None
if num_decodes > 0:
max_seq_len = int(common_attn_metadata.seq_lens_cpu[:num_decodes].max())
decode_metadata = FlashMLASparseDecodeAndContextMetadata(
max_seq_len=max_seq_len,
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
seq_lens_cpu=common_attn_metadata.seq_lens_cpu[:num_decodes],
)
metadata = FlashMLASparseMetadata(
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
block_table=common_attn_metadata.block_table_tensor,
req_id_per_token=req_id_per_token,
block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens,
fp8_extra_metadata=fp8_extra_metadata,
num_prefills=num_prefills,
num_decodes=num_decodes,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
decode_metadata=decode_metadata,
prefill_metadata=prefill_metadata
)
return metadata
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
topk_indice_buffer: Optional[torch.Tensor] = None,
indexer: Optional["Indexer"] = None,
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer = indexer.topk_indices_buffer
self.padding = 128 if current_platform.is_device_capability(
100) else 64
def _forward_bf16_kv(
self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
num_tokens = q.shape[0]
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.contiguous().view(
-1, kv_c_and_k_pe_cache.shape[-1])
# num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decodes = attn_metadata.num_decodes
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
def _bf16_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
# Reshape q: (num_decode_tokens, num_heads, head_dim)
# -> (num_decodes, seq_len, num_heads, head_dim)
q = reshape_query_for_spec_decode(q, num_decodes)
seq_len = q.shape[1]
# Reshape topk_indices: (num_decode_tokens, topk)
# -> (num_decodes, seq_len, topk)
topk_indices = topk_indices.view(num_decodes, seq_len, -1)
decode_metadata = attn_metadata.decode_metadata
_attn_out, _, _ = kunlun_flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache,
head_dim_v=512,
cache_seqlens=decode_metadata.seq_lens,
cache_seqlens_cpu=decode_metadata.seq_lens_cpu,
is_fp8_kvcache=False,
indices=topk_indices,
softmax_scale=self.softmax_scale,
max_seq_kv=decode_metadata.max_seq_len
)
# Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
# -> (num_decode_tokens, num_heads, head_dim_v)
return reshape_attn_output_for_spec_decode(_attn_out)
def _bf16_prefill(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
prefill_metadata = attn_metadata.prefill_metadata
topk_indices = topk_indices.view(num_prefill_tokens, 1, -1)
# NOTE: 只有prefill阶段attn_metadata.query_start_loc是符合klx算子需求的
_attn_out = flash_mla_sparse_prefill(
q=q,
kv=kv_c_and_k_pe_cache,
indices=topk_indices,
sm_scale=self.softmax_scale,
q_lod_xpu=prefill_metadata.query_start_loc,
q_lod_cpu=prefill_metadata.query_start_loc_cpu
)[0]
return _attn_out
topk_indices_global = torch.ops.xspeedgate_ops.convert_req_index_to_global_index(
req_id=attn_metadata.req_id_per_token,
block_table=attn_metadata.block_table,
token_indices=topk_indices,
block_size=attn_metadata.block_size,
num_topk_tokens=attn_metadata.topk_tokens,
)
attn_out = torch.empty(
(num_tokens, self.num_heads, self.kv_lora_rank),
dtype=q.dtype,
device=q.device,
)
if has_prefill:
prefill_q = q[num_decode_tokens:]
prefill_topk_indices_global = topk_indices_global[num_decode_tokens:]
attn_out[num_decode_tokens:] = _bf16_prefill(prefill_q, prefill_topk_indices_global)
# 处理decode部分 - 需要正确的block table映射print
if has_decode:
decode_q = q[:num_decode_tokens]
decode_topk_indices_global = topk_indices_global[:num_decode_tokens]
attn_out[:num_decode_tokens] = _bf16_decode(decode_q, decode_topk_indices_global)
return attn_out
def _forward_fp8_kv(self, q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
# TODO: When fwd_kvcache_mla supports uint8 kv cache, execute this function.
assert attn_metadata.fp8_extra_metadata is not None
extra_metadata = attn_metadata.fp8_extra_metadata
_attn_out, _ = flash_mla_with_kvcache(
q=q.unsqueeze(0), # unsqueeze to add batch_dim
k_cache=kv_c_and_k_pe_cache,
block_table=extra_metadata.dummy_block_table,
head_dim_v=512,
cache_seqlens=extra_metadata.cache_lens,
tile_scheduler_metadata=extra_metadata.scheduler_metadata, # None
num_splits=extra_metadata.num_splits, # None
is_fp8_kvcache=True,
indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim
softmax_scale=self.softmax_scale,
max_seq_kv=attn_metadata.max_seq_len
)
return _attn_out
def forward(
self,
layer: AttentionLayer,
q: torch.Tensor,
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLACommonImpl")
if attn_metadata is None:
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)
topk_indices = self.topk_indices_buffer[:num_actual_toks]
q = torch.cat([ql_nope, q_pe], dim=-1)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
torch.ops._C.concat_and_cache_mla(
kv_c=k_c_normed,
k_pe=k_pe.squeeze(1),
kv_cache=kv_cache,
slot_mapping=attn_metadata.slot_mapping.flatten(),
)
if self.kv_cache_dtype != "fp8_ds_mla":
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
attn_metadata)
else:
# attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
# attn_metadata)
raise NotImplementedError
self._v_up_proj(attn_out, out=output[:num_actual_toks])
return output

View File

@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from dataclasses import dataclass
from typing import ClassVar, Optional
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
DeepseekV32IndexerMetadataBuilder,
DeepseekV32IndexerPrefillMetadata)
logger = init_logger(__name__)
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
block_table: torch.Tensor
seq_lens: torch.Tensor
seq_lens_cpu: torch.Tensor
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
@dataclass
class DeepseekV32IndexerMetadata:
# FIXME (zyongye)
# hacky way to access the data now, need to be in chunked meta
seq_lens: torch.Tensor
seq_lens_cpu: torch.Tensor
num_reqs: int
max_query_len: int
max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
# The dimension of the attention heads
head_dim: int
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
num_decodes: int
num_decode_tokens: int
num_prefills: int
num_prefill_tokens: int
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
def kunlun_build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> DeepseekV32IndexerMetadata:
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
prefill_metadata = None
if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks(
common_attn_metadata.seq_lens_cpu,
self.max_prefill_buffer_size,
num_decodes,
)
chunks = [
self.build_one_prefill_chunk(
reqs_start, reqs_end, query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.block_table_tensor)
for reqs_start, reqs_end in chunk_seq_ids
]
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
chunks=chunks, )
decode_metadata = None
if num_decodes > 0:
torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
out=self.decode_lens_buffer[:num_decodes])
decode_lens = self.decode_lens_buffer[:num_decodes]
decode_lens_cpu = torch.diff(
common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])
# Use CPU to avoid GPU sync; breaking async scheduling
requires_padding = (decode_lens_cpu.max()
> decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.
block_table_tensor[:num_decodes, ...],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
seq_lens_cpu=common_attn_metadata.seq_lens[:num_decodes].cpu(),
decode_lens=decode_lens,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
)
attn_metadata = DeepseekV32IndexerMetadata(
seq_lens=common_attn_metadata.seq_lens,
seq_lens_cpu=common_attn_metadata.seq_lens.cpu(),
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
head_dim=128,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
prefill=prefill_metadata,
decode=decode_metadata,
)
# if get_tensor_model_parallel_rank() == 0:
# logger.info(f"attn_metadata: {attn_metadata}")
return attn_metadata
DeepseekV32IndexerMetadataBuilder.build = kunlun_build

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import os
import torch
import torch.nn as nn
from packaging import version
@@ -24,6 +24,7 @@ class TopKTopPSampler(nn.Module):
def __init__(self, logprobs_mode):
super().__init__()
self.logprobs_mode = logprobs_mode
logger.info_once(
"Using FlashInfer for top-p & top-k sampling.")
self.forward = self.forward_kunlun
@@ -40,9 +41,14 @@ class TopKTopPSampler(nn.Module):
The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p)
logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
logits_to_return = logits
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators), None
return random_sample(probs, generators), logits_to_return
def forward_kunlun(
self,
@@ -52,16 +58,13 @@ class TopKTopPSampler(nn.Module):
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""More optimized implementation for top-k and top-p sampling."""
if k is None and p is None:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators), None
if generators:
logger.warning_once("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation.")
if (k is None and p is None) or generators:
if generators:
logger.debug_once(
"FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation."
)
return self.forward_native(logits, generators, k, p)
# flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
@@ -196,6 +199,7 @@ def flashinfer_sample(
probs, top_k=k, deterministic=True)
else:
# Both top-k and top-p.
k = k.to(torch.int32)
next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
probs, top_k=k, top_p=p, deterministic=True)

View File

@@ -0,0 +1,344 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
if TYPE_CHECKING:
from vllm.attention.layer import Attention
class MultiModalBudget:
"""Helper class to calculate budget information for multi-modal models."""
def __init__(
self,
model_config: ModelConfig,
scheduler_config: SchedulerConfig,
mm_registry: MultiModalRegistry,
) -> None:
super().__init__()
self.model_config = model_config
self.scheduler_config = scheduler_config
self.mm_registry = mm_registry
self.cache = cache = processor_only_cache_from_config(
model_config, mm_registry)
self.max_model_len = model_config.max_model_len
self.max_num_reqs = scheduler_config.max_num_seqs
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config,
cache=cache)
max_tokens_by_modality = mm_registry \
.get_max_tokens_per_item_by_nonzero_modality(model_config,
cache=cache)
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
scheduler_config,
max_tokens_by_modality,
)
self.encoder_compute_budget = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
max_items_per_prompt_by_modality = dict[str, int]()
max_items_per_batch_by_modality = dict[str, int]()
for modality, max_tokens in max_tokens_by_modality.items():
(
max_items_per_prompt,
max_items_per_batch,
) = self.get_max_items(modality, max_tokens)
max_items_per_prompt_by_modality[modality] = max_items_per_prompt
max_items_per_batch_by_modality[modality] = max_items_per_batch
self.max_tokens_by_modality = max_tokens_by_modality
self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
self.max_items_per_batch_by_modality = max_items_per_batch_by_modality
def get_modality_with_max_tokens(self) -> str:
max_tokens_by_modality = self.max_tokens_by_modality
modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])
return modality
def get_encoder_budget(self) -> int:
return min(self.encoder_compute_budget, self.encoder_cache_size)
def get_max_items(
self,
modality: str,
max_tokens_per_item: int,
) -> tuple[int, int]:
if max_tokens_per_item == 0:
return 0, 0
# Check how many items of this modality can be supported by
# the encoder budget.
encoder_budget = self.get_encoder_budget()
# TODO: handle encoder-decoder models once we support them.
if encoder_budget == 0:
return 0, 0
max_encoder_items_per_batch = encoder_budget // max_tokens_per_item
# Check how many items of this modality can be supported by
# the decoder budget.
mm_limit = self.mm_limits[modality]
max_items_per_prompt = max(
1,
min(mm_limit, self.max_model_len // max_tokens_per_item),
)
scheduler_config = self.scheduler_config
max_num_reqs = self.max_num_reqs
if not scheduler_config.enable_chunked_prefill:
max_num_reqs = min(
max_num_reqs,
scheduler_config.max_num_batched_tokens // max_tokens_per_item,
)
max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt
max_items_per_batch = max(
1,
min(max_encoder_items_per_batch, max_decoder_items_per_batch),
)
return max_items_per_prompt, max_items_per_batch
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
# When ubatching is enabled we will have a metadata builder for each ubatch
# so that if they use internal persistant buffers for cudagraphs, and they
# won't have to worry about conflicting with the other ubatches.
metadata_builders: list[AttentionMetadataBuilder]
layer_names: list[str]
kv_cache_spec: KVCacheSpec
@staticmethod
def create_with_metadata_builders(
backend: type[AttentionBackend],
layer_names: list[str],
kv_cache_spec: KVCacheSpec,
vllm_config: VllmConfig,
device: torch.device,
num_metadata_builders: int = 1,
) -> 'AttentionGroup':
metadata_builders = [
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config,
device)
for _ in range(num_metadata_builders)
]
return AttentionGroup(backend, metadata_builders, layer_names,
kv_cache_spec)
def get_metadata_builder(self,
ubatch_id: int = 0) -> AttentionMetadataBuilder:
assert len(self.metadata_builders) > ubatch_id
return self.metadata_builders[ubatch_id]
def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings,
expected_num_items: int,
) -> None:
"""
Perform sanity checks for the result of
[`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
"""
assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
f"or a single 3D tensor, but got {type(mm_embeddings)} "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.")
assert len(mm_embeddings) == expected_num_items, (
"Expected number of multimodal embeddings to match number of "
f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.")
assert all(e.ndim == 2 for e in mm_embeddings), (
"Expected multimodal embeddings to be a sequence of 2D tensors, "
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.")
def scatter_mm_placeholders(
embeds: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
"""
Scatter the multimodal embeddings into a contiguous tensor that represents
the placeholder tokens.
[`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].
Args:
embeds: The multimodal embeddings.
Shape: `(num_embeds, embed_dim)`
is_embed: A boolean mask indicating which positions in the placeholder
tokens need to be filled with multimodal embeddings.
Shape: `(num_placeholders, num_embeds)`
"""
if is_embed is None:
return embeds
placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed] = embeds
return placeholders
def gather_mm_placeholders(
placeholders: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
"""
Reconstructs the embeddings from the placeholder tokens.
This is the operation of [`scatter_mm_placeholders`]
[vllm.v1.worker.utils.scatter_mm_placeholders].
"""
if is_embed is None:
return placeholders
return placeholders[is_embed]
def add_kv_sharing_layers_to_kv_cache_groups(
shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list[KVCacheGroupSpec],
runner_only_attn_layers: Optional[set[str]] = None,
) -> None:
"""
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
for layers that do not allocate its own KV cache, based on the mapping in
`shared_kv_cache_layers`. Adds these layers to the corresponding KV cache
group, which is needed to ensure that attention metadata is assigned later.
Args:
shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
If an Attention layer `layer_name` is in the keys of this dict, it
means this layer will perform attention using the keys and values
from the KV cache of `shared_kv_cache_layers[layer_name]`.
kv_cache_groups: The KV cache groups of the model.
"""
layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
for kv_cache_group in kv_cache_groups:
for layer_name in kv_cache_group.layer_names:
layer_to_kv_cache_group[layer_name] = kv_cache_group
for layer_name, target_layer_name in shared_kv_cache_layers.items():
tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
tgt_kv_cache_group.layer_names.append(layer_name)
if runner_only_attn_layers is not None:
runner_only_attn_layers.add(layer_name)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor],
num_attn_module: Optional[int] = 1,
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name,
num_attn_module)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
# TODO - analyze where runner_kv_caches is used and the right
# way to ensure it properly reflects multiple attention layers
# in the same decoder block.
if current_platform.is_kunlun() or current_platform.is_cuda() or current_platform.is_xpu():
# We know that the GPU runner is not impacted by this
# case. Some test code depends on runner_kv_caches, but
# not in a way that's impacted by ignoring this.
pass
else:
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
def is_residual_scattered_for_sp(vllm_config: VllmConfig,
num_input_tokens: int) -> bool:
"""Check if the residual tensor is scattered for sequence parallelism.
The residual tensor is scattered across tensor parallel ranks when sequence
parallelism and tensor parallelism is enabled, and the number of
input tokens is one of the compilation sizes.
"""
if not vllm_config.compilation_config.pass_config.\
enable_sequence_parallelism:
return False
tp = vllm_config.parallel_config.tensor_parallel_size
if tp == 1:
return False
# When sequence parallelism is enabled, we always pad num_input_tokens
# to be a multiple of tensor_parallel_size (tp) earlier.
assert num_input_tokens % tp == 0
# Currently, SP is only enabled for static size fx graphs.
return (num_input_tokens in vllm_config.compilation_config.compile_sizes)

View File

@@ -1493,4 +1493,45 @@ def _fake_gptq_shuffle(
return None
gptq_shuffle.register_fake(_fake_gptq_shuffle)
gptq_shuffle.register_fake(_fake_gptq_shuffle)
##################################################
# ---------------- concat_and_cache_mla ------------------
##################################################
@custom_op("_C::concat_and_cache_mla", mutates_args=())
def concat_and_cache_mla(
kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
k_pe: torch.Tensor, #[num_tokens, pe_dim]
kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
) -> None:
xtorch_ops.concat_and_cache_mla(
kv_c=kv_c,
k_pe=k_pe,
slot_mapping=slot_mapping,
kv_cache=kv_cache,
)
@impl("_C::concat_and_cache_mla", "CUDA")
def concat_and_cache_mla_cuda(
kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
k_pe: torch.Tensor, #[num_tokens, pe_dim]
kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
) -> None:
xtorch_ops.concat_and_cache_mla(
kv_c=kv_c,
k_pe=k_pe,
slot_mapping=slot_mapping,
kv_cache=kv_cache,
)
def _fake_concat_and_cache_mla(
kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
k_pe: torch.Tensor, #[num_tokens, pe_dim]
kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
) -> None:
return None
concat_and_cache_mla.register_fake(_fake_concat_and_cache_mla)