[Feature] support deepseek v3/r1/v3.2 (#78)
* [Feature] support deepseek v3/r1/v3.2 * fix gpt_oss * update readme * update readme --------- Co-authored-by: hanhaowen <hanhaowen@baidu.com>
This commit is contained in:
@@ -10,34 +10,15 @@ import vllm.envs as envs
|
||||
OLD_IMPORT_HOOK = builtins.__import__
|
||||
def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0):
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 模块映射表
|
||||
module_mappings = {
|
||||
"vllm.model_executor.layers.fused_moe.layer": "vllm_kunlun.ops.fused_moe.layer",
|
||||
"vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe": "vllm_kunlun.ops.quantization.compressed_tensors_moe",
|
||||
"vllm.compilation.wrapper": "vllm_kunlun.compilation.wrapper",
|
||||
"vllm.v1.worker.gpu_model_runner": "vllm_kunlun.v1.worker.gpu_model_runner"
|
||||
"vllm.v1.worker.utils": "vllm_kunlun.v1.worker.utils",
|
||||
"vllm.model_executor.model_loader.bitsandbytes_loader": "vllm_kunlun.models.model_loader.bitsandbytes_loader",
|
||||
"vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
|
||||
"vllm.model_executor.layers.sampler": "vllm_kunlun.ops.sample.sampler",
|
||||
"vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
|
||||
}
|
||||
|
||||
# 需要保持原始导入的模块
|
||||
original_imports = [
|
||||
"vllm.model_executor.layers.fused_moe.base",
|
||||
"vllm.model_executor.layers.fused_moe.config",
|
||||
"vllm.model_executor.layers.fused_moe.layer"
|
||||
]
|
||||
|
||||
if module_name in original_imports:
|
||||
if module_name == "vllm.model_executor.layers.fused_moe.layer" and fromlist:
|
||||
if "FusedMoEMethodBase" in fromlist:
|
||||
return OLD_IMPORT_HOOK(
|
||||
module_name,
|
||||
globals=globals,
|
||||
locals=locals,
|
||||
fromlist=fromlist,
|
||||
level=level
|
||||
)
|
||||
|
||||
if module_name in module_mappings:
|
||||
if module_name in sys.modules:
|
||||
return sys.modules[module_name]
|
||||
@@ -45,25 +26,6 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
|
||||
module = importlib.import_module(target_module)
|
||||
sys.modules[module_name] = module
|
||||
sys.modules[target_module] = module
|
||||
return module
|
||||
|
||||
relative_mappings = {
|
||||
("compressed_tensors_moe", "compressed_tensors"): "vllm_kunlun.ops.quantization.compressed_tensors_moe",
|
||||
("layer", "fused_moe"): "vllm_kunlun.ops.fused_moe.layer",
|
||||
}
|
||||
|
||||
if level == 1:
|
||||
parent = globals.get('__package__', '').split('.')[-1] if globals else ''
|
||||
key = (module_name, parent)
|
||||
if key in relative_mappings:
|
||||
if module_name in sys.modules:
|
||||
return sys.modules[module_name]
|
||||
target_module = relative_mappings[key]
|
||||
module = importlib.import_module(target_module)
|
||||
sys.modules[module_name] = module
|
||||
sys.modules[target_module] = module
|
||||
return module
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -77,79 +39,16 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
|
||||
|
||||
def import_hook():
|
||||
"""Apply import hook for VLLM Kunlun"""
|
||||
if not int(os.environ.get("DISABLE_KUNLUN_HOOK", "0")):
|
||||
builtins.__import__ = _custom_import
|
||||
|
||||
try:
|
||||
modules_to_preload = [
|
||||
"vllm_kunlun.ops.quantization.compressed_tensors_moe",
|
||||
"vllm_kunlun.ops.fused_moe.custom_ops",
|
||||
"vllm_kunlun.ops.fused_moe.layer",
|
||||
"vllm_kunlun.ops.quantization.fp8",
|
||||
]
|
||||
for module_name in modules_to_preload:
|
||||
importlib.import_module(module_name)
|
||||
except Exception:
|
||||
pass
|
||||
builtins.__import__ = _custom_import
|
||||
|
||||
def register():
|
||||
"""Register the Kunlun platform"""
|
||||
from .utils import redirect_output
|
||||
from .vllm_utils_wrapper import direct_register_custom_op, patch_annotations_for_schema
|
||||
patch_bitsandbytes_loader()
|
||||
import_hook()
|
||||
if envs.VLLM_USE_V1:
|
||||
# patch_V1blockTable()
|
||||
patch_V1top_p_K()
|
||||
# TODO fixed fast top & k for vLLM 0.10.2,
|
||||
pass
|
||||
else:
|
||||
patch_sampler()
|
||||
return "vllm_kunlun.platforms.kunlun.KunlunPlatform"
|
||||
|
||||
def register_model():
|
||||
"""Register models for training and inference"""
|
||||
from .models import register_model as _reg
|
||||
_reg()
|
||||
|
||||
# [monkey patach sampler]
|
||||
import sys
|
||||
import sys, importlib, warnings
|
||||
|
||||
def patch_bitsandbytes_loader():
|
||||
try:
|
||||
# 载入你插件里自定义的 direct_register_custom_op 实现
|
||||
custom_utils = importlib.import_module("vllm_kunlun.models.model_loader.bitsandbytes_loader")
|
||||
# 覆盖 vllm.utils
|
||||
sys.modules["vllm.model_executor.model_loader.bitsandbytes_loader"] = custom_utils
|
||||
print("[vllm_kunlun] bitsandbytes_loader patched ->", custom_utils.__file__)
|
||||
except Exception as e:
|
||||
warnings.warn(f"[vllm_kunlun] bitsandbytes_loader patch failed: {e!r}")
|
||||
|
||||
def patch_sampler():
|
||||
try:
|
||||
custom_sampler = importlib.import_module("vllm_kunlun.ops.sample.sampler")
|
||||
sys.modules["vllm.model_executor.layers.sampler"] = custom_sampler
|
||||
print("[vllm_kunlun] sampler patched ->", custom_sampler.__file__)
|
||||
except Exception as e:
|
||||
warnings.warn(f"[vllm_kunlun] sampler patch failed: {e!r}")
|
||||
|
||||
|
||||
def patch_V1top_p_K():
|
||||
try:
|
||||
custom_sampler = importlib.import_module("vllm_kunlun.v1.sample.ops.topk_topp_sampler")
|
||||
sys.modules["vllm.v1.sample.ops.topk_topp_sampler"] = custom_sampler
|
||||
print("[vllm_kunlun] V1sampler top p & k patched ->", custom_sampler.__file__)
|
||||
except Exception as e:
|
||||
warnings.warn(f"[vllm_kunlun] V1 sampler top p & k patch failed: {e!r}")
|
||||
|
||||
def patch_V1blockTable():
|
||||
try:
|
||||
custom_sampler = importlib.import_module("vllm_kunlun.v1.worker.block_table")
|
||||
sys.modules["vllm.v1.worker.block_table"] = custom_sampler
|
||||
print("[vllm_kunlun] V1 block table patched ->", custom_sampler.__file__)
|
||||
except Exception as e:
|
||||
warnings.warn(f"[vllm_kunlun] V1 block table patch failed: {e!r}")
|
||||
|
||||
# 在模块导入时自动应用补丁
|
||||
import_hook()
|
||||
_reg()
|
||||
@@ -80,6 +80,14 @@ def register_model():
|
||||
ModelRegistry.register_model(
|
||||
"GptOssForCausalLM",
|
||||
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
|
||||
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV3ForCausalLM",
|
||||
"vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV32ForCausalLM",
|
||||
"vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM")
|
||||
|
||||
def register_quant_method():
|
||||
"""to do"""
|
||||
|
||||
1445
vllm_kunlun/models/deepseek_v2.py
Normal file
1445
vllm_kunlun/models/deepseek_v2.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -16,7 +16,7 @@ from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm_kunlun.ops.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
@@ -176,7 +176,7 @@ class MLPBlock(torch.nn.Module):
|
||||
x = sequence_parallel_chunk(x)
|
||||
|
||||
g = self.router(x)
|
||||
x = self.experts(hidden_states=x, router_logits=g, linear_weights=self.router.weight)
|
||||
x = self.experts(hidden_states=x, router_logits=g)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
|
||||
|
||||
@@ -21,7 +21,7 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm_kunlun.ops.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
@@ -185,8 +185,7 @@ class MiMoV2MoE(nn.Module):
|
||||
gate_input = hidden_states
|
||||
router_logits = self.gate(gate_input)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits, linear_weights=self.gate.weight
|
||||
)
|
||||
hidden_states=hidden_states, router_logits=router_logits)
|
||||
|
||||
return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
|
||||
|
||||
|
||||
@@ -20,4 +20,8 @@ import vllm_kunlun.ops.layernorm
|
||||
import vllm_kunlun.ops.quantization.awq
|
||||
import vllm_kunlun.ops.quantization.gptq
|
||||
import vllm_kunlun.ops.vocab_parallel_embedding
|
||||
import vllm_kunlun.ops.linear
|
||||
import vllm_kunlun.ops.linear
|
||||
import vllm_kunlun.ops.quantization.kernels.scaled_mm.cutlass
|
||||
import vllm_kunlun.ops.vocab_parallel_embedding
|
||||
import vllm_kunlun.ops.quantization.compressed_tensors_moe
|
||||
import vllm_kunlun.ops.fused_moe.layer
|
||||
@@ -417,7 +417,6 @@ class KunlunOps:
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
linear_weights: torch.Tensor,
|
||||
ep_rank: int,
|
||||
moe_top_k: int,
|
||||
renormalize: bool,
|
||||
|
||||
@@ -108,7 +108,7 @@ class SiluAndMul(CustomOp):
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
xtorch_ops.swiglu(x, out)
|
||||
torch.ops._C.silu_and_mul(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
260
vllm_kunlun/ops/attention/flashmla.py
Normal file
260
vllm_kunlun/ops/attention/flashmla.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
import xtorch_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
import vllm._flashmla_C # noqa: F401
|
||||
_flashmla_C_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
else:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
import vllm._flashmla_extension_C # noqa: F401
|
||||
_flashmla_extension_C_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
else:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
|
||||
|
||||
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Return: is_supported_flag, unsupported_reason (optional).
|
||||
"""
|
||||
return True, None
|
||||
|
||||
def get_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_heads_per_head_k: int = 1,
|
||||
num_heads_k: int = 1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
cache_seqlens: (batch_size), dtype torch.int32.
|
||||
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
|
||||
num_heads_k: num_heads_k.
|
||||
|
||||
Returns:
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||
num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
# return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
|
||||
cache_seqlens_cpu = cache_seqlens.cpu()
|
||||
return cache_seqlens_cpu, cache_seqlens
|
||||
|
||||
def flash_mla_with_kvcache(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
head_dim_v: int,
|
||||
tile_scheduler_metadata: torch.Tensor,
|
||||
num_splits: torch.Tensor,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
descale_q: Optional[torch.Tensor] = None,
|
||||
descale_k: Optional[torch.Tensor] = None,
|
||||
is_fp8_kvcache: bool = False,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
cache_seqlens: (batch_size), torch.int32.
|
||||
head_dim_v: Head dimension of v.
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
|
||||
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
|
||||
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
|
||||
Returns:
|
||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
|
||||
softmax_lse = None
|
||||
out = torch.ones(q.size(0), q.size(1), q.size(2), head_dim_v, dtype= q.dtype, device=q.device)
|
||||
kv_lora_rank = head_dim_v
|
||||
qk_rope_head_dim = q.size(3) - head_dim_v
|
||||
head_dim = k_cache.shape[3]
|
||||
page_block_size = k_cache.shape[1]
|
||||
k_cache = k_cache.view(-1, 1, page_block_size, head_dim)
|
||||
|
||||
# todo: optimize memcp
|
||||
# q_c = q[..., : kv_lora_rank].contiguous()
|
||||
# q_r = q[..., kv_lora_rank :].contiguous()
|
||||
|
||||
is_context = False
|
||||
vo_head_dim = -1
|
||||
|
||||
xtorch_ops.paged_attention(out,
|
||||
q,
|
||||
k_cache, None,
|
||||
block_table,
|
||||
tile_scheduler_metadata, # context_lens_cpu
|
||||
num_splits, # context_lens_xpu
|
||||
is_context,
|
||||
causal,
|
||||
vo_head_dim,
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
softmax_scale,
|
||||
q_r=q)
|
||||
return out, softmax_lse
|
||||
|
||||
def kunlun_flash_mla_with_kvcache(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
cache_seqlens_cpu: torch.Tensor,
|
||||
head_dim_v: int,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
is_fp8_kvcache: bool = False,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
max_seq_kv: int = 1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
k_cache: (num_tokens_kv, head_dim).
|
||||
cache_seqlens: (batch_size), torch.int32.
|
||||
head_dim_v: Head dimension of v.
|
||||
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format.
|
||||
indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
|
||||
max_seq_kv: seq中最大的kv长度
|
||||
|
||||
Returns:
|
||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
max_logits: (batch_size, seq_len_q, num_heads_q), torch.float32.
|
||||
p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32.
|
||||
"""
|
||||
assert not is_fp8_kvcache, "By now, the kernel does not support uint8 kv cache."
|
||||
assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
if indices is not None:
|
||||
# NOTE (zyongye): sparse attention is also causal
|
||||
# since it only attend to the tokens before
|
||||
# but here `causal` should not be specified
|
||||
assert not causal, \
|
||||
"causal must be `false` if sparse attention is enabled."
|
||||
|
||||
q_r, pe_cache = None, None # 当q_r和pe_cache为空时,为packed模式
|
||||
batch_size, seq_len_q, num_heads_q, head_dim = q.shape
|
||||
kv_lora_rank = head_dim_v
|
||||
rope_head_dim = head_dim - kv_lora_rank
|
||||
|
||||
out = torch.zeros([batch_size, seq_len_q, num_heads_q, kv_lora_rank],
|
||||
dtype=q.dtype, device=q.device)
|
||||
max_logits = torch.zeros([batch_size, seq_len_q, num_heads_q],
|
||||
dtype=torch.float32, device=q.device)
|
||||
p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q],
|
||||
dtype=torch.float32, device=q.device)
|
||||
|
||||
xtorch_ops.fwd_kvcache_mla(
|
||||
q_c=q,
|
||||
kv_cache=k_cache,
|
||||
indices=indices,
|
||||
kv_lod_cpu=cache_seqlens_cpu,
|
||||
max_seq_kv=max_seq_kv,
|
||||
softmax_scale=softmax_scale,
|
||||
# q_r=q_r,
|
||||
# pe_cache=pe_cache,
|
||||
out=out,
|
||||
max_logits=max_logits,
|
||||
p_sums=p_sums,
|
||||
kv_lod_xpu=cache_seqlens,
|
||||
)
|
||||
|
||||
return out, max_logits, p_sums
|
||||
|
||||
|
||||
def flash_mla_sparse_prefill(
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
sm_scale: float,
|
||||
q_lod_xpu: torch.Tensor,
|
||||
q_lod_cpu: torch.Tensor,
|
||||
d_v: int = 512,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sparse attention prefill kernel
|
||||
|
||||
Args:
|
||||
- q: [s_q, h_q, d_qk], bfloat16
|
||||
- kv: [s_kv, d_qk], bfloat16
|
||||
- indices: [s_q, h_kv, topk], int32.
|
||||
Invalid indices should be set to -1 or numbers >= s_kv
|
||||
- sm_scale: float
|
||||
- q_lod_xpu: [batch+1], int32, q的每个seq长度的累加信息, 长度为batch_num + 1 (为空则表示q定长).
|
||||
- d_v: The dimension of value vectors. Can only be 512
|
||||
|
||||
Returns:
|
||||
- (output, max_logits, lse)
|
||||
About the definition of output,
|
||||
max_logits and lse, please refer to README.md
|
||||
- output: [s_q, h_q, d_v], bfloat16
|
||||
- max_logits: [s_q, h_q], float
|
||||
- lse: [s_q, h_q], float, 2-based log-sum-exp
|
||||
"""
|
||||
s_q, h_q, d_qk = q.shape
|
||||
|
||||
out = torch.zeros([s_q, h_q, d_v], dtype=q.dtype, device=q.device)
|
||||
max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
|
||||
lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
|
||||
|
||||
xtorch_ops.sparse_prefill_fwd_opt(
|
||||
q=q,
|
||||
kv=kv,
|
||||
indices=indices,
|
||||
qlod_cpu=q_lod_cpu,
|
||||
qlod_xpu=q_lod_xpu,
|
||||
kvlod_cpu=q_lod_cpu,
|
||||
kvlod_xpu=q_lod_xpu,
|
||||
sm_scale=sm_scale,
|
||||
d_v=d_v,
|
||||
is_causal=True, #aiak这个值为true,这是为啥
|
||||
out=out,
|
||||
max_logits=max_logits,
|
||||
lse=lse,
|
||||
)
|
||||
|
||||
# NOTE: Compared with torch.ops._flashmla_C.sparse_prefill_fwd,
|
||||
# out_scale = 1 / math.log2(math.e)
|
||||
# gpu_max_logits * out_scale = kunlun_lse
|
||||
# gpu_lse * out_scale = kunlun_lse
|
||||
return out, max_logits, lse
|
||||
|
||||
|
||||
#
|
||||
# TODO: Add fake functions
|
||||
#
|
||||
# @register_fake("_flashmla_C::get_mla_metadata")
|
||||
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
# @register_fake("_flashmla_C::fwd_kvcache_mla")
|
||||
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
180
vllm_kunlun/ops/attention/mla.py
Normal file
180
vllm_kunlun/ops/attention/mla.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_kunlun.ops.attention.layer import Attention
|
||||
# from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLAModules:
|
||||
"""Modules used in MLA.
|
||||
"""
|
||||
kv_a_layernorm: torch.nn.Module
|
||||
kv_b_proj: torch.nn.Module
|
||||
rotary_emb: torch.nn.Module
|
||||
o_proj: torch.nn.Module
|
||||
fused_qkv_a_proj: Optional[torch.nn.Module]
|
||||
kv_a_proj_with_mqa: Optional[torch.nn.Module]
|
||||
q_a_layernorm: Optional[torch.nn.Module]
|
||||
q_b_proj: Optional[torch.nn.Module]
|
||||
q_proj: Optional[torch.nn.Module]
|
||||
indexer: Optional[torch.nn.Module]
|
||||
is_sparse: bool
|
||||
topk_indices_buffer: Optional[torch.Tensor]
|
||||
|
||||
|
||||
@CustomOp.register("vllm_kunlun_multi_head_latent_attention")
|
||||
class MultiHeadLatentAttention(CustomOp):
|
||||
"""MLA layer registered as CustomOp.
|
||||
Note that currently MLA ignores the enable/disable mechanism of CustomOp
|
||||
because there is only one in-tree implementation in forward_native.
|
||||
TODO: implement this with a new PluggableLayer mechanism.
|
||||
|
||||
This class takes positions and hidden_states as input.
|
||||
The input tensors can either contain prefill tokens or decode tokens.
|
||||
The class does the following:
|
||||
|
||||
1. MLA Preprocess.
|
||||
2. Perform multi-head attention to prefill tokens and
|
||||
multi-query attention to decode tokens separately.
|
||||
3. Return the output tensor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
scale: float,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
mla_modules: MLAModules,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.num_heads = num_heads
|
||||
self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
|
||||
self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
|
||||
self.q_a_layernorm = mla_modules.q_a_layernorm
|
||||
self.q_b_proj = mla_modules.q_b_proj
|
||||
self.q_proj = mla_modules.q_proj
|
||||
self.kv_a_layernorm = mla_modules.kv_a_layernorm
|
||||
self.kv_b_proj = mla_modules.kv_b_proj
|
||||
self.rotary_emb = mla_modules.rotary_emb
|
||||
self.o_proj = mla_modules.o_proj
|
||||
self.indexer = mla_modules.indexer
|
||||
self.is_sparse = mla_modules.is_sparse
|
||||
|
||||
if self.indexer is not None:
|
||||
assert hasattr(self.indexer, "topk_tokens")
|
||||
self.topk_tokens = self.indexer.topk_tokens
|
||||
self.topk_indices_buffer = mla_modules.topk_indices_buffer
|
||||
|
||||
# In the MLA backend, kv_cache includes both k_c and
|
||||
# pe (i.e. decoupled position embeddings). In particular,
|
||||
# the concat_and_cache_mla op requires
|
||||
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
|
||||
# i.e.
|
||||
# kv_lora_rank + qk_rope_head_dim == head_size
|
||||
self.mla_attn = Attention(
|
||||
num_heads=self.num_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
use_sparse=mla_modules.is_sparse,
|
||||
# MLA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
indexer=self.indexer,
|
||||
)
|
||||
|
||||
self.prefix = prefix
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
q_c = None
|
||||
kv_lora = None
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
assert self.fused_qkv_a_proj is not None, \
|
||||
"fused_qkv_a_proj is required when q_lora_rank is not None"
|
||||
assert self.q_a_layernorm is not None, \
|
||||
"q_a_layernorm is required when q_lora_rank is not None"
|
||||
assert self.q_b_proj is not None, \
|
||||
"q_b_proj is required when q_lora_rank is not None"
|
||||
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
|
||||
q_c, kv_lora = qkv_lora.split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||
dim=-1,
|
||||
)
|
||||
q_c = self.q_a_layernorm(q_c)
|
||||
q = self.q_b_proj(q_c)[0]
|
||||
else:
|
||||
assert self.kv_a_proj_with_mqa is not None, \
|
||||
"kv_a_proj_with_mqa is required when q_lora_rank is None"
|
||||
assert self.q_proj is not None, \
|
||||
"q_proj is required when q_lora_rank is None"
|
||||
kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
q = self.q_proj(hidden_states)[0]
|
||||
|
||||
kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c)
|
||||
|
||||
q = q.view(-1, self.num_heads, self.qk_head_dim)
|
||||
# Add head dim of 1 to k_pe
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
|
||||
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
|
||||
positions, q[..., self.qk_nope_head_dim:], k_pe)
|
||||
|
||||
if self.indexer and self.is_sparse:
|
||||
_topk_indices = self.indexer(hidden_states, q_c, positions,
|
||||
self.rotary_emb)
|
||||
|
||||
hidden_states_shape_0 = 0
|
||||
if isinstance(hidden_states, tuple):
|
||||
x_q, x_scale = hidden_states
|
||||
hidden_states_shape_0 = x_q.shape[0]
|
||||
else:
|
||||
hidden_states_shape_0 = hidden_states.shape[0]
|
||||
attn_out = self.mla_attn(
|
||||
q,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
output_shape=(hidden_states_shape_0,
|
||||
self.num_heads * self.v_head_dim))
|
||||
return self.o_proj(attn_out)[0]
|
||||
|
||||
def forward_cuda(self, *args, **kwargs):
|
||||
return self.forward_native(*args, **kwargs)
|
||||
114
vllm_kunlun/ops/deep_gemm.py
Normal file
114
vllm_kunlun/ops/deep_gemm.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import torch
|
||||
import xtorch_ops
|
||||
|
||||
def int8_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv: tuple[torch.Tensor, torch.Tensor],
|
||||
weights: torch.Tensor,
|
||||
cu_seqlen_ks: torch.Tensor,
|
||||
cu_seqlen_ke: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute FP8 MQA logits for a single sequence without KV paging.
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape [M, H, D]. Casted to
|
||||
`torch.float8_e4m3fn` by caller.
|
||||
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
|
||||
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
|
||||
[N, 1]) with dtype `torch.float32`.
|
||||
weights: weights of shape [M, H], dtype `torch.float32`.
|
||||
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
|
||||
Returns:
|
||||
Logits tensor of shape [M, N], dtype `torch.float32`.
|
||||
"""
|
||||
logits = torch.empty((q.shape[0], kv[0].shape[0]), dtype=torch.float32, device=q.device)
|
||||
context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
||||
context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
||||
|
||||
xtorch_ops.I8_mqa_logits(
|
||||
q=q,
|
||||
fused_kv_cache=kv,
|
||||
weights=weights,
|
||||
context_q_lens=(context_q_lens_xpu.cpu(), context_q_lens_xpu),
|
||||
context_k_lens=(context_k_lens_xpu.cpu(), context_k_lens_xpu),
|
||||
logits=logits,
|
||||
clean_logits=True,
|
||||
use_xfa_boost=False,
|
||||
)
|
||||
seq_len_kv = kv[0].shape[0]
|
||||
# mask参考 https://github.com/vllm-project/vllm/blob/v0.11.0/tests/kernels/attention/test_deepgemm_attention.py 的_ref_fp8_mqa_logits函数的实现
|
||||
mask_lo = (torch.arange(0, seq_len_kv, device=cu_seqlen_ks.device)[None, :]
|
||||
>= cu_seqlen_ks[:, None])
|
||||
mask_hi = (torch.arange(0, seq_len_kv, device=cu_seqlen_ke.device)[None, :]
|
||||
< cu_seqlen_ke[:, None])
|
||||
mask = mask_lo & mask_hi
|
||||
logits = logits.masked_fill(~mask, float("-inf"))
|
||||
|
||||
return logits
|
||||
|
||||
def int8_paged_mqa_logits(
|
||||
q_fp8: torch.Tensor,
|
||||
kv_cache_fp8: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
context_lens_cpu: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
schedule_metadata: torch.Tensor,
|
||||
max_model_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Compute FP8 MQA logits using paged KV-cache.
|
||||
|
||||
Args:
|
||||
q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
|
||||
`torch.float8_e4m3fn` by caller.
|
||||
kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
|
||||
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
|
||||
4 bytes per (block,pos) store the `float` dequant scale.
|
||||
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
|
||||
context_lens: Tensor of shape [B], dtype int32; effective context length
|
||||
for each batch element.
|
||||
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
|
||||
block indices to physical blocks in the paged cache.
|
||||
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
|
||||
used to distribute work across SMs.
|
||||
max_model_len: Maximum sequence length used to size the logits output.
|
||||
|
||||
Returns:
|
||||
Logits tensor of shape [B * next_n, max_model_len], dtype
|
||||
`torch.float32`.
|
||||
"""
|
||||
batch_size, next_n, _, D = q_fp8.shape
|
||||
num_blocks, block_size, _, _ = kv_cache_fp8.shape
|
||||
|
||||
kv_cache_fp8=kv_cache_fp8.view(num_blocks, -1)
|
||||
k_val = kv_cache_fp8[:,:block_size*D].view(torch.int8)
|
||||
k_val = k_val.view(-1,block_size, 1, D)
|
||||
k_scale_list = []
|
||||
for block_tables_idx in range(block_tables.shape[0]):
|
||||
k_scale_item = kv_cache_fp8[block_tables[block_tables_idx], block_size *
|
||||
D:].view(-1, 4)
|
||||
k_scale_list.append(k_scale_item)
|
||||
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).view(-1,max_model_len)
|
||||
kv_cache = [k_val, k_scale]
|
||||
|
||||
weights = weights.view(batch_size,next_n,-1)
|
||||
|
||||
logits = torch.empty((batch_size, next_n, max_model_len), dtype=torch.float32, device=q_fp8.device)
|
||||
|
||||
xtorch_ops.I8_paged_mqa_logits(
|
||||
q=q_fp8,
|
||||
fused_kv_cache=kv_cache,
|
||||
weights=weights,
|
||||
context_lens=[context_lens_cpu, context_lens],
|
||||
block_table=block_tables,
|
||||
max_context_len=max_model_len,
|
||||
clean_logits=True,
|
||||
out=logits,
|
||||
use_xfa_boost=False
|
||||
)
|
||||
logits = logits.view(-1, max_model_len)
|
||||
return logits
|
||||
@@ -1,37 +1,14 @@
|
||||
"""layer.py"""
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import Callable, Optional, Union, get_args
|
||||
|
||||
import torch
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.eplb.eplb_state import EplbState
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE as VllmFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase as VllmFusedMoEMethodBase
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
UnquantizedFusedMoEMethod as VllmUnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig, FusedMoEParallelConfig)
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_kunlun.ops.quantization.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
|
||||
|
||||
|
||||
class FusedMoEMethodBase(VllmFusedMoEMethodBase):
|
||||
"""FusedMoEMethodBase"""
|
||||
moe: FusedMoEConfig
|
||||
|
||||
@CustomOp.register("vllm_kunlun_unquantized_fused_moe")
|
||||
class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
"""UnquantizedFusedMoEMethod"""
|
||||
def apply(
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
@@ -45,6 +22,7 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
@@ -52,40 +30,12 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
linear_weights: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""apply"""
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
|
||||
|
||||
return self.forward_kunlun(x=x,
|
||||
layer=layer,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
linear_weights=linear_weights,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
def forward_kunlun(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
linear_weights: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
|
||||
"""forward_kunlun"""
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
if self.moe.use_ep:
|
||||
@@ -93,21 +43,18 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
linear_weights,
|
||||
self.moe.ep_rank,
|
||||
top_k,
|
||||
renormalize=renormalize,
|
||||
inplace=True,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group
|
||||
)
|
||||
topk_group=topk_group)
|
||||
else:
|
||||
return ops.fused_moe(x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
linear_weights,
|
||||
self.moe.ep_rank,
|
||||
top_k,
|
||||
renormalize=renormalize,
|
||||
@@ -118,12 +65,13 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
w1_bias = layer.w13_bias,
|
||||
w2_bias = layer.w2_bias,
|
||||
)
|
||||
w2_bias = layer.w2_bias)
|
||||
|
||||
class FusedMoE(VllmFusedMoE):
|
||||
"""FusedMoE"""
|
||||
def __init__(self,
|
||||
UnquantizedFusedMoEMethod.apply = apply
|
||||
|
||||
class VllmFusedMoE(FusedMoE):
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int, # Global number of experts
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
@@ -141,198 +89,47 @@ class FusedMoE(VllmFusedMoE):
|
||||
prefix: str = "",
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
num_redundant_experts: int = 0,
|
||||
is_sequence_parallel=False,
|
||||
has_bias: bool = False,
|
||||
is_sequence_parallel=False,
|
||||
zero_expert_num: Optional[int] = 0,
|
||||
zero_expert_type: Optional[str] = None,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=num_experts, # Global number of experts
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
params_dtype=params_dtype,
|
||||
reduce_results=reduce_results,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
ep_size=ep_size,
|
||||
dp_size=dp_size,
|
||||
prefix=prefix,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=activation,
|
||||
enable_eplb=enable_eplb,
|
||||
num_redundant_experts=num_redundant_experts,
|
||||
)
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config.model_config is not None:
|
||||
model_dtype = vllm_config.model_config.dtype
|
||||
else:
|
||||
# TODO (bnell): This is a hack to get test_mixtral_moe to work
|
||||
# since model_config is not set in the pytest test.
|
||||
model_dtype = params_dtype
|
||||
|
||||
moe = FusedMoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=model_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
num_experts=num_experts, # Global number of experts
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
params_dtype=params_dtype,
|
||||
reduce_results=reduce_results,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
ep_size=ep_size,
|
||||
dp_size=dp_size,
|
||||
prefix=prefix,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=activation,
|
||||
enable_eplb=enable_eplb,
|
||||
num_redundant_experts=num_redundant_experts,
|
||||
has_bias=has_bias,
|
||||
# quant_config=quant_config,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.quant_config = quant_config
|
||||
is_sequence_parallel=is_sequence_parallel,
|
||||
zero_expert_num=zero_expert_num,
|
||||
zero_expert_type=zero_expert_type)
|
||||
self.has_bias=has_bias
|
||||
self.register_parameter("w13_bias", None)
|
||||
self.register_parameter("w2_bias", None)
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
quant_method: Optional[QuantizeMethodBase] = None
|
||||
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
|
||||
else quant_config.get_quant_method(self, prefix))
|
||||
|
||||
assert quant_method is not None
|
||||
# assert isinstance(quant_method, FusedMoEMethodBase)
|
||||
self.quant_method = quant_method
|
||||
|
||||
if self.enable_eplb:
|
||||
from vllm_kunlun.ops.quantization.fp8 import (
|
||||
Fp8MoEMethod)
|
||||
if not isinstance(quant_method, Fp8MoEMethod):
|
||||
# TODO: Add support for additional quantization methods.
|
||||
# The implementation for other quantization methods does not
|
||||
# contain essential differences, but the current quant API
|
||||
# design causes duplicated work when extending to new
|
||||
# quantization methods, so I'm leaving it for now.
|
||||
# If you plan to add support for more quantization methods,
|
||||
# please refer to the implementation in `Fp8MoEMethod`.
|
||||
raise NotImplementedError("EPLB is only supported for FP8 "
|
||||
"quantization for now.")
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
"intermediate_size_per_partition":
|
||||
self.intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
}
|
||||
# need full intermediate size pre-sharding for WNA16 act order
|
||||
if (self.quant_method.__class__.__name__
|
||||
in ("GPTQMarlinMoEMethod",
|
||||
"CompressedTensorsWNA16MarlinMoEMethod",
|
||||
"CompressedTensorsWNA16MoEMethod")):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor = None,
|
||||
linear_weights: torch.Tensor = None):
|
||||
"""forward"""
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we will
|
||||
# switch to using the moe_forward custom op.
|
||||
if current_platform.is_tpu():
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
else:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[self.layer_name]
|
||||
assert self.quant_method is not None
|
||||
return self.forward_impl(hidden_states, router_logits, linear_weights)
|
||||
# return torch.ops.vllm.moe_forward(hidden_states, router_logits,
|
||||
# self.layer_name)
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
linear_weights: torch.Tensor = None):
|
||||
"""forward_impl"""
|
||||
assert self.quant_method is not None
|
||||
if (self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels):
|
||||
return self.forward_impl_chunked(hidden_states, router_logits)
|
||||
|
||||
do_naive_dispatch_combine: bool = (
|
||||
self.dp_size > 1
|
||||
and not self.moe_parallel_config.use_deepep_ht_kernels)
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
linear_weights=linear_weights
|
||||
)
|
||||
|
||||
if do_naive_dispatch_combine:
|
||||
final_hidden_states = get_ep_group().combine(final_hidden_states)
|
||||
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
# Default set to False. (May have to add shared expert outputs.
|
||||
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
cls,
|
||||
ckpt_gate_proj_name: str,
|
||||
ckpt_down_proj_name: str,
|
||||
ckpt_up_proj_name: str,
|
||||
num_experts: int,
|
||||
num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
|
||||
|
||||
num_physical_experts = num_experts + num_redundant_experts
|
||||
|
||||
# In the returned mapping:
|
||||
# - `expert_id` is the physical expert id
|
||||
# - `weight_name` contains the weight name of the logical expert
|
||||
# So that we should map the expert id to logical in `weight_name`
|
||||
physical_to_logical_map = \
|
||||
EplbState.build_initial_global_physical_to_logical_map(
|
||||
num_experts, num_redundant_experts)
|
||||
|
||||
return [
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
("experts.w13_" if weight_name
|
||||
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
|
||||
f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
|
||||
expert_id, shard_id) for expert_id in range(num_physical_experts)
|
||||
for shard_id, weight_name in [
|
||||
("w1", ckpt_gate_proj_name),
|
||||
("w2", ckpt_down_proj_name),
|
||||
("w3", ckpt_up_proj_name),
|
||||
]
|
||||
]
|
||||
FusedMoE=VllmFusedMoE
|
||||
@@ -1,244 +1,169 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from typing import Any, Literal, Optional, cast, Callable, Optional
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
|
||||
|
||||
from compressed_tensors.config import (CompressionFormat,
|
||||
SparsityCompressionConfig,
|
||||
SparsityStructure)
|
||||
from compressed_tensors.quantization import (ActivationOrdering,
|
||||
QuantizationStrategy)
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
# TODO: import position will be changed after 0.9.0
|
||||
# vllm.model_executor.layers.fused_moe.fused_moe --> vllm.model_executor.layers.fused_moe
|
||||
def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
|
||||
"""modify scale -> abs max"""
|
||||
layer.w13_weight = torch.nn.Parameter(layer.w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(layer.w2_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
layer.w13_weight_scale.data * 127, requires_grad=False
|
||||
)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
layer.w2_weight_scale.data * 127, requires_grad=False
|
||||
)
|
||||
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
import re
|
||||
import xtorch_ops
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
klx_process_weights_after_loading(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
hidden_states = x
|
||||
global_num_experts, up_gate_size, _ = layer.w13_weight.shape
|
||||
M, N = hidden_states.shape
|
||||
hidden_dim = layer.w2_weight.shape[1]
|
||||
normed_score = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
num_blocks = 12
|
||||
block_statistic = torch.zeros(
|
||||
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
|
||||
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def get_moe_method(quant_config, layer) -> "CompressedTensorsMoEMethod":
|
||||
tsm = getattr(quant_config, "target_scheme_map", None) or {}
|
||||
linear_cfg = None
|
||||
for k in ("Linear", "FusedMoE", "MoE", "Moe", "Experts"):
|
||||
if k in tsm and isinstance(tsm[k], dict):
|
||||
linear_cfg = tsm[k]; break
|
||||
if not linear_cfg:
|
||||
# print("target_scheme_map missing; fallback to INT8(W8A8) method")
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
wq = linear_cfg.get("weights"); aq = linear_cfg.get("input_activations")
|
||||
if not wq or not aq:
|
||||
# print("incomplete scheme; fallback to INT8(W8A8)")
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
# 其它分流按需;默认回落:
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
|
||||
# copied from vllm 0.9.0
|
||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
self.quant_config = quant_config
|
||||
|
||||
# 直接创建默认的量化配置字典,避免 QuantizationArgs 的验证问题
|
||||
# print("Creating default INT8 quantization config for MoE")
|
||||
|
||||
# 创建默认的权重量化配置字典
|
||||
self.weight_quant = type('WeightQuant', (), {
|
||||
'type': 'int',
|
||||
'num_bits': 8,
|
||||
'strategy': 'channel',
|
||||
'group_size': 128,
|
||||
'symmetric': True,
|
||||
'dynamic': False,
|
||||
'actorder': 'none',
|
||||
'observer': None,
|
||||
'observer_kwargs': {},
|
||||
'block_structure': None
|
||||
})()
|
||||
|
||||
# 创建默认的输入激活量化配置字典
|
||||
self.input_quant = type('InputQuant', (), {
|
||||
'type': 'int',
|
||||
'num_bits': 8,
|
||||
'strategy': 'token',
|
||||
'group_size': 128,
|
||||
'symmetric': True,
|
||||
'dynamic': True,
|
||||
'actorder': 'none',
|
||||
'observer': None,
|
||||
'observer_kwargs': {},
|
||||
'block_structure': None
|
||||
})()
|
||||
|
||||
# 修改比较方式,直接比较字符串
|
||||
per_channel = (
|
||||
self.weight_quant.strategy == "channel"
|
||||
and self.input_quant.strategy == "token")
|
||||
if not per_channel:
|
||||
raise ValueError(
|
||||
"For INT8 Fused MoE layers, we require channelwise, "
|
||||
"dynamic per token quantization. Found "
|
||||
f"{self.weight_quant}, {self.input_quant}")
|
||||
|
||||
self.static_input_scales = not self.input_quant.dynamic
|
||||
if self.static_input_scales:
|
||||
raise ValueError(
|
||||
"For INT8 Fused MoE layers, we require channelwise, "
|
||||
"dynamic per token quantization. Found static input scales.")
|
||||
|
||||
def create_weights1(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
# 权重先用浮点占位,便于从 ckpt 加载原始权重
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype), # 通常是 torch.bfloat16
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# 通道 scale:float32 + 二维 [E, out](与 fused_moe/UT 对齐)
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# 输入 scale 动态计算即可
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=torch.int8), # 直接使用 int8
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8), # 直接使用 int8
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# 缩放因子
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# 输入 scale 动态计算
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
@torch.no_grad()
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
return
|
||||
#原始权重转 float32 做统计更稳健
|
||||
w13_f = layer.w13_weight.float()
|
||||
w2_f = layer.w2_weight.float()
|
||||
|
||||
# 每列(abs_max) -> per-column scale(out 维在 dim=1,列在 dim=-1)
|
||||
qmax = 127.0
|
||||
w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1) # [E, 2N]
|
||||
w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
|
||||
|
||||
w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax # [E, 2N], float32
|
||||
w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
|
||||
|
||||
# 量化:用 3D scale 广播,存回 2D scale
|
||||
w13_scale_3d = w13_scale_2d.unsqueeze(-1) # [E, 2N, 1]
|
||||
w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
|
||||
|
||||
w13_q = torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8)
|
||||
w2_q = torch.round(w2_f / w2_scale_3d ).clamp_(-128, 127).to(torch.int8)
|
||||
|
||||
# 可选:若你的 fused/kernel 期望 scale 预乘 127(与某些 UT 后端一致),打开下面两行:
|
||||
w13_scale_2d = w13_scale_2d * 127.0
|
||||
w2_scale_2d = w2_scale_2d * 127.0
|
||||
|
||||
# 回写参数:权重 int8;scale 用 float32 + 2D
|
||||
replace_parameter(layer, 'w13_weight', torch.nn.Parameter(w13_q, requires_grad=False))
|
||||
replace_parameter(layer, 'w2_weight', torch.nn.Parameter(w2_q, requires_grad=False))
|
||||
replace_parameter(layer, 'w13_weight_scale',
|
||||
torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False))
|
||||
replace_parameter(layer, 'w2_weight_scale',
|
||||
torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False))
|
||||
|
||||
# 简要检查
|
||||
print(f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False, # 添加这个参数
|
||||
expert_load_view: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
logical_replica_count: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
linear_weights: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
) -> torch.Tensor:
|
||||
|
||||
output = torch.empty_like(x)
|
||||
torch.ops._C.moe_ffn_per_token_block(
|
||||
x=x,
|
||||
inter_weight=layer.w13_weight,
|
||||
inter_scale=layer.w13_weight_scale,
|
||||
outer_weight=layer.w2_weight,
|
||||
outer_scale=layer.w2_weight_scale,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
linear_weights=linear_weights,
|
||||
expert_map=expert_map,
|
||||
activation=activation,
|
||||
output=output,
|
||||
use_expert_parallel=expert_map is not None,
|
||||
ep_size=expert_map.size(0) if expert_map is not None else 1,
|
||||
ep_rank=0,
|
||||
router_logits = router_logits.float()
|
||||
if scoring_func == "softmax":
|
||||
torch.ops._C.moe_softmax_topk_norm(
|
||||
x=router_logits,
|
||||
normed_score=normed_score,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=None,
|
||||
stable=True)
|
||||
elif scoring_func == "sigmoid":
|
||||
torch.ops._C.moe_sigmoid_group_topk_norm(
|
||||
x=router_logits,
|
||||
norm_score=normed_score,
|
||||
topk_index=topk_ids,
|
||||
block_static=block_statistic,
|
||||
bias=e_score_correction_bias,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
scale=routed_scaling_factor,
|
||||
)
|
||||
return output
|
||||
|
||||
print("[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
|
||||
--> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod")
|
||||
moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
|
||||
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
||||
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
||||
sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
||||
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
x=hidden_states,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=block_statistic,
|
||||
moe_expand=moe_expand,
|
||||
moe_index=sorted_tokens_idx,
|
||||
expert_m=expert_m,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
||||
|
||||
y = torch.empty(M,top_k,
|
||||
layer.w13_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
moe_expand = moe_expand.view(M * top_k, hidden_dim)
|
||||
|
||||
x_shape = moe_expand.shape
|
||||
x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
|
||||
x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
|
||||
torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=x_q,
|
||||
x_perchannel_max=x_scale,
|
||||
weight=layer.w13_weight,
|
||||
w_perchannel_max=layer.w13_weight_scale,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=y,
|
||||
topk_ids=topk_ids,
|
||||
# sort_mode=False,
|
||||
act=None)
|
||||
|
||||
d = y.shape[-1] // 2
|
||||
output_shape = (y.shape[:-1] + (d, ))
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.silu_and_mul(out1, y)
|
||||
|
||||
out = torch.empty(M,top_k,
|
||||
layer.w2_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
x_shape = out1.shape
|
||||
x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
|
||||
x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
|
||||
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=x_q,
|
||||
x_perchannel_max=x_scale,
|
||||
weight=layer.w2_weight,
|
||||
w_perchannel_max=layer.w2_weight_scale,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=out,
|
||||
topk_ids=topk_ids,
|
||||
# sort_mode=False,
|
||||
act=None)
|
||||
|
||||
dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
|
||||
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
|
||||
|
||||
torch.ops._C.moe_post(
|
||||
x=out,
|
||||
moe_index=sorted_tokens_idx,
|
||||
normed_scale=normed_score,
|
||||
dequant_scale=dequant_scale,
|
||||
y=output
|
||||
)
|
||||
return output
|
||||
|
||||
CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = process_weights_after_loading
|
||||
CompressedTensorsW8A8Int8MoEMethod.apply = apply
|
||||
0
vllm_kunlun/ops/quantization/kernels/__init__.py
Normal file
0
vllm_kunlun/ops/quantization/kernels/__init__.py
Normal file
122
vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py
Normal file
122
vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import CutlassScaledMMLinearKernel
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise)
|
||||
|
||||
def can_implement_kunlun(
|
||||
cls, c: ScaledMMLinearLayerConfig=None) -> tuple[bool, Optional[str]]:
|
||||
return True, None
|
||||
|
||||
def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
|
||||
"""modify scale -> abs max"""
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data * 127, requires_grad=False)
|
||||
|
||||
def process_weights_after_loading_kunlun(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
replace_parameter(
|
||||
layer, self.w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False))
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Cutlass kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale,
|
||||
layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer, self.w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False))
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, self.i_s_name)
|
||||
|
||||
if self.config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer, self.i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False))
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, self.i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max -
|
||||
int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, self.i_s_name,
|
||||
torch.nn.Parameter(scale, requires_grad=False))
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min -
|
||||
range_min / scale).to(dtype=torch.int32)
|
||||
replace_parameter(layer, self.i_zp_name,
|
||||
torch.nn.Parameter(azp, requires_grad=False))
|
||||
|
||||
else:
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
|
||||
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
|
||||
if not self.config.input_symmetric:
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if self.config.is_static_input_scheme:
|
||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||
# in the per-tensor case
|
||||
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
|
||||
setattr(layer, self.azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False))
|
||||
else:
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
klx_process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights_kunlun(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
x_q, x_scale, out = None, None, None
|
||||
w_t_shape = layer.weight.T.shape
|
||||
if isinstance(x, tuple):
|
||||
x_q, x_scale = x
|
||||
out = torch.empty((x_q.shape[0], w_t_shape[0]),
|
||||
dtype=torch.bfloat16,
|
||||
device=x_q.device)
|
||||
else:
|
||||
x_shape = x.shape
|
||||
x_q = torch.empty(x_shape, dtype=torch.int8, device=x.device)
|
||||
x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=x.device)
|
||||
out = torch.empty((x_shape[0], w_t_shape[0]),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
torch.ops._C.quant2d(x, x_q, x_scale, force_sdnn=True)
|
||||
torch.ops._C.gemm_I8_I8_bf16_nt(x_q, x_scale, layer.weight.T.data, layer.weight_scale.data, out)
|
||||
return out
|
||||
|
||||
CutlassScaledMMLinearKernel.apply_weights = apply_weights_kunlun
|
||||
CutlassScaledMMLinearKernel.can_implement = can_implement_kunlun
|
||||
CutlassScaledMMLinearKernel.process_weights_after_loading = process_weights_after_loading_kunlun
|
||||
@@ -19,7 +19,9 @@ import torch
|
||||
import xspeedgate_ops
|
||||
import os
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
RotaryEmbedding, YaRNScalingRotaryEmbedding, DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding)
|
||||
RotaryEmbedding, YaRNScalingRotaryEmbedding,
|
||||
DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding,
|
||||
DeepseekScalingRotaryEmbedding)
|
||||
from typing import Optional, Tuple
|
||||
|
||||
def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
@@ -143,12 +145,15 @@ def vllm_kunlun_mrope_forward_cuda(
|
||||
|
||||
return query, key
|
||||
|
||||
DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward
|
||||
DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda
|
||||
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward
|
||||
DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda
|
||||
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
||||
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
||||
|
||||
|
||||
def Split_Norm_Rope(
|
||||
qkv: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
|
||||
@@ -177,6 +177,8 @@ class KunlunPlatform(Platform):
|
||||
# if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
|
||||
# we default to FlashMLA backend, so we need to force the blocksize
|
||||
# here
|
||||
use_sparse = hasattr(vllm_config.model_config.hf_config,
|
||||
"index_topk")
|
||||
use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \
|
||||
or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
|
||||
from vllm.attention.ops.flashmla import is_flashmla_supported
|
||||
@@ -185,6 +187,11 @@ class KunlunPlatform(Platform):
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLA backend.")
|
||||
if use_sparse and cache_config.block_size != 64:
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLASparse "
|
||||
"backend.")
|
||||
|
||||
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
|
||||
and parallel_config.data_parallel_size > 1
|
||||
@@ -224,6 +231,14 @@ class KunlunPlatform(Platform):
|
||||
Returns:
|
||||
str: Class name of the attention backend.
|
||||
"""
|
||||
if use_mla:
|
||||
if use_sparse:
|
||||
logger.info_once("Using Sparse MLA backend on V1 engine.")
|
||||
# return ("vllm.v1.attention.backends.mla.flashmla_sparse."
|
||||
# "FlashMLASparseBackend")
|
||||
return ("vllm_kunlun.v1.attention.backends.mla.flashmla_sparse."
|
||||
"FlashMLASparseBackend")
|
||||
return "vllm_kunlun.v1.attention.backends.mla.flashmla.FlashMLABackend"
|
||||
if use_v1:
|
||||
return "vllm_kunlun.v1.attention.backends.kunlun_attn.KunlunAttentionBackend"
|
||||
elif not use_mla:
|
||||
|
||||
0
vllm_kunlun/v1/attention/backends/mla/__init__.py
Normal file
0
vllm_kunlun/v1/attention/backends/mla/__init__.py
Normal file
1867
vllm_kunlun/v1/attention/backends/mla/common.py
Normal file
1867
vllm_kunlun/v1/attention/backends/mla/common.py
Normal file
File diff suppressed because it is too large
Load Diff
202
vllm_kunlun/v1/attention/backends/mla/flashmla.py
Normal file
202
vllm_kunlun/v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
||||
from vllm_kunlun.ops.attention.flashmla import (flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm_kunlun.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["FlashMLAMetadata"]:
|
||||
return FlashMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
|
||||
return FlashMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLAImpl"]:
|
||||
return FlashMLAImpl
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
tile_scheduler_metadata: torch.Tensor
|
||||
num_splits: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
|
||||
FlashMLAMetadata)
|
||||
|
||||
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config)
|
||||
|
||||
self.cg_buf_tile_scheduler_metadata = None
|
||||
self.cg_buf_num_splits = None
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(self.device)
|
||||
num_sms = device_properties.multi_processor_count
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.cg_buf_tile_scheduler_metadata = torch.zeros(
|
||||
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
|
||||
# TileSchedulerMetaDataSize = 8
|
||||
(num_sms, 8),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.cg_buf_num_splits = torch.empty(
|
||||
(vllm_config.scheduler_config.max_num_seqs + 1),
|
||||
device=self.device,
|
||||
dtype=torch.int32)
|
||||
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int) -> FlashMLADecodeMetadata:
|
||||
tile_scheduler_metadata, num_splits = \
|
||||
get_mla_metadata(
|
||||
seq_lens_device,
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
||||
# so we can only use the persistent buffer if a cudagraph is actually
|
||||
# being used.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
assert self.cg_buf_tile_scheduler_metadata is not None
|
||||
assert self.cg_buf_num_splits is not None
|
||||
|
||||
sm_parts = tile_scheduler_metadata.size(0)
|
||||
# Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
|
||||
assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
|
||||
tile_scheduler_metadata_view = \
|
||||
self.cg_buf_tile_scheduler_metadata[:sm_parts]
|
||||
tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
|
||||
tile_scheduler_metadata = tile_scheduler_metadata_view
|
||||
|
||||
# Num splits is per-batch, varying size (batch_size,)
|
||||
n = num_splits.size(0)
|
||||
# make sure static buffer is large enough
|
||||
assert n <= self.cg_buf_num_splits.size(0)
|
||||
num_splits_view = self.cg_buf_num_splits[:n]
|
||||
num_splits_view.copy_(num_splits)
|
||||
# Num splits needs to monotonically increasing
|
||||
# (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
|
||||
# it needs to monotonically increasing by 1)
|
||||
self.cg_buf_num_splits[n:].fill_(num_splits[-1])
|
||||
num_splits = num_splits_view
|
||||
|
||||
return FlashMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
tile_scheduler_metadata=tile_scheduler_metadata,
|
||||
num_splits=num_splits,
|
||||
)
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
is_supported, reason = is_flashmla_supported()
|
||||
assert is_supported, reason
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashMLAImpl")
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# TODO: (zyongye) decode function for mla here
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
assert isinstance(q, torch.Tensor)
|
||||
o, lse = flash_mla_with_kvcache(
|
||||
q=q.unsqueeze(1), # Add seqlen dim of 1 (decode)
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=attn_metadata.decode.
|
||||
tile_scheduler_metadata,
|
||||
num_splits=attn_metadata.decode.num_splits,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
descale_q=layer._q_scale.reshape(1),
|
||||
descale_k=layer._k_scale.reshape(1),
|
||||
)
|
||||
|
||||
return o, lse
|
||||
752
vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py
Normal file
752
vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py
Normal file
@@ -0,0 +1,752 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
from vllm_kunlun.ops.attention.flashmla import (flash_mla_sparse_prefill,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
kunlun_flash_mla_with_kvcache)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.distributed import get_tp_group
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
"""
|
||||
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
|
||||
|
||||
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
|
||||
structured as:
|
||||
- **First 512 bytes:** The "quantized NoPE" part, containing 512
|
||||
`float8_e4m3` values.
|
||||
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
|
||||
The first `float32` is the scale for the first 128 `float8_e4m3` values,
|
||||
the second for the next 128, and so on.
|
||||
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
|
||||
part is not quantized for accuracy.
|
||||
"""
|
||||
|
||||
|
||||
def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor:
|
||||
# Convert base-2 LSE to natural-log LSE
|
||||
# Keep FP32 for numerical stability during the merge.
|
||||
return (lse_base2.to(torch.float32) * math.log(2.0))
|
||||
|
||||
|
||||
class FlashMLASparseBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA_SPARSE_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type[AttentionMetadata]:
|
||||
return FlashMLASparseMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
|
||||
return FlashMLASparseMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLASparseImpl"]:
|
||||
return FlashMLASparseImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if cache_dtype_str == "fp8_ds_mla":
|
||||
# custom storage fromat is 656 bytes
|
||||
# see FlashMLA readme.md for details
|
||||
return (num_blocks, block_size, 656)
|
||||
else:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLASparsePrefillMetadata:
|
||||
# NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because
|
||||
# the kernel is not from flashmla
|
||||
block_table: torch.Tensor = None
|
||||
has_context: bool = False
|
||||
context_lens: Optional[torch.Tensor] = None
|
||||
|
||||
# Sequence lengths (context + query) for prefill requests
|
||||
# Shape: [num_prefill_reqs]
|
||||
seq_lens: torch.Tensor = None
|
||||
|
||||
# Request ID for each token: -1 for decode tokens, request index
|
||||
# (0, 1, 2, ...) for prefill tokens.
|
||||
# Shape: [num_actual_tokens]
|
||||
request_ids: torch.Tensor = None
|
||||
query_start_loc: torch.Tensor = None
|
||||
query_start_loc_cpu: torch.Tensor = None
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseDecodeAndContextMetadata:
|
||||
scheduler_metadata: torch.Tensor = None
|
||||
num_splits: torch.Tensor = None
|
||||
cache_lens: torch.Tensor = None
|
||||
prefill_context_lengths: Optional[torch.Tensor] = None
|
||||
prefill_new_k_start_locs: Optional[torch.Tensor] = None
|
||||
dummy_block_table: torch.Tensor = None
|
||||
|
||||
seq_lens: torch.Tensor = None
|
||||
seq_lens_cpu: torch.Tensor = None
|
||||
max_seq_len: int = -1 # needed for reshape in spec decode
|
||||
|
||||
def filter_prefill_indices(
|
||||
self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.prefill_context_lengths is not None
|
||||
prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1)
|
||||
context_indices = torch.where(indices < prefill_context_lengths,
|
||||
indices, -1)
|
||||
new_token_indices = torch.where(indices >= prefill_context_lengths,
|
||||
indices - prefill_context_lengths, -1)
|
||||
return context_indices, new_token_indices
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseMetadata:
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
block_table: torch.Tensor
|
||||
req_id_per_token: torch.Tensor
|
||||
block_size: int = 64
|
||||
topk_tokens: int = 2048
|
||||
|
||||
num_prefills: int = 0
|
||||
num_decodes: int = 0
|
||||
num_prefill_tokens: int = 0
|
||||
num_decode_tokens: int = 0
|
||||
|
||||
decode_metadata: Optional[FlashMLASparseDecodeAndContextMetadata] = None
|
||||
prefill_metadata: Optional[MLASparsePrefillMetadata] = None
|
||||
|
||||
@dataclass
|
||||
class FP8KernelMetadata:
|
||||
scheduler_metadata: Optional[torch.Tensor]
|
||||
num_splits: torch.Tensor
|
||||
dummy_block_table: torch.Tensor
|
||||
cache_lens: torch.Tensor
|
||||
|
||||
fp8_extra_metadata: Optional[FP8KernelMetadata] = None
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _convert_req_index_to_global_index_kernel(
|
||||
req_id_ptr, # int32 [num_tokens]
|
||||
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
# shapes (compile-time where possible)
|
||||
max_num_blocks_per_req: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, # tile width along columns
|
||||
# strides (in elements)
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
):
|
||||
# program_id(0) -> token_id (row)
|
||||
# program_id(1) -> tile index along columns
|
||||
token_id = tl.program_id(0)
|
||||
tile_id = tl.program_id(1)
|
||||
|
||||
# Each program covers BLOCK_N consecutive columns
|
||||
indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
# Load request id for this token (no mask: grid is exact)
|
||||
req = tl.load(req_id_ptr + token_id)
|
||||
|
||||
# Load token indices for this tile
|
||||
ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
|
||||
tok = tl.load(ti_ptr) # int32
|
||||
|
||||
# Only token == -1 should propagate as -1
|
||||
is_invalid_tok = tok < 0
|
||||
|
||||
# Compute block id and in-block offset
|
||||
block_id = tok // BLOCK_SIZE
|
||||
inblock_off = tok % BLOCK_SIZE
|
||||
|
||||
# Guard block_table access
|
||||
valid_block = block_id < max_num_blocks_per_req
|
||||
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
|
||||
base = tl.load(bt_ptr, mask=valid_block, other=0)
|
||||
|
||||
# If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
|
||||
out_val = tl.where(is_invalid_tok | (~valid_block), -1,
|
||||
base * BLOCK_SIZE + inblock_off)
|
||||
|
||||
# Store results
|
||||
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
|
||||
tl.store(out_ptr_ij, out_val)
|
||||
|
||||
|
||||
def triton_convert_req_index_to_global_index(
|
||||
req_id: torch.Tensor, # int32 [num_tokens]
|
||||
block_table: torch.
|
||||
Tensor, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
BLOCK_SIZE: int = 64,
|
||||
NUM_TOPK_TOKENS: int = 2048,
|
||||
BLOCK_N: int = 128, # tile width along columns
|
||||
):
|
||||
"""
|
||||
out[token_id, indice_id] =
|
||||
block_table[req_id[token_id],
|
||||
token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
|
||||
+ token_indices[token_id, indice_id] % BLOCK_SIZE
|
||||
|
||||
Only when token_indices[token_id, indice_id] == -1 do we output -1.
|
||||
For safety, we also output -1 if the derived block_id would be
|
||||
out-of-bounds.
|
||||
"""
|
||||
assert req_id.dtype == torch.int32
|
||||
assert block_table.dtype == torch.int32
|
||||
assert token_indices.dtype == torch.int32
|
||||
assert token_indices.shape[1] == NUM_TOPK_TOKENS
|
||||
assert NUM_TOPK_TOKENS % BLOCK_N == 0, \
|
||||
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \
|
||||
f"BLOCK_N ({BLOCK_N})"
|
||||
|
||||
num_tokens = req_id.shape[0]
|
||||
num_requests, max_num_blocks_per_req = block_table.shape
|
||||
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
|
||||
|
||||
# Ensure contiguous tensors on the same device
|
||||
req_id_c = req_id.contiguous()
|
||||
block_table_c = block_table.contiguous()
|
||||
token_indices_c = token_indices.contiguous()
|
||||
out = torch.empty_like(token_indices_c)
|
||||
|
||||
# Strides in elements
|
||||
bt_stride0, bt_stride1 = block_table_c.stride()
|
||||
ti_stride0, ti_stride1 = token_indices_c.stride()
|
||||
out_stride0, out_stride1 = out.stride()
|
||||
|
||||
# Exact 2D grid: tokens × column tiles
|
||||
grid = (num_tokens, tiles_per_row)
|
||||
|
||||
_convert_req_index_to_global_index_kernel[grid](
|
||||
req_id_c,
|
||||
block_table_c,
|
||||
token_indices_c,
|
||||
out,
|
||||
# shapes / constexprs
|
||||
max_num_blocks_per_req,
|
||||
BLOCK_SIZE,
|
||||
BLOCK_N,
|
||||
# strides
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
)
|
||||
return out
|
||||
|
||||
def kunlun_convert_req_index_to_global_index(
|
||||
req_id: torch.Tensor, # int32 [num_tokens]
|
||||
block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
BLOCK_SIZE: int = 64,
|
||||
NUM_TOPK_TOKENS: int = 2048,
|
||||
):
|
||||
assert req_id.dtype == torch.int32
|
||||
assert block_table.dtype == torch.int32
|
||||
assert token_indices.dtype == torch.int32
|
||||
assert token_indices.shape[1] == NUM_TOPK_TOKENS
|
||||
|
||||
num_tokens = req_id.shape[0]
|
||||
num_requests, max_num_blocks_per_req = block_table.shape
|
||||
|
||||
out = torch.zeros_like(token_indices)
|
||||
|
||||
# Compute block_id and inblock_off for all tokens at once
|
||||
block_id = token_indices // BLOCK_SIZE
|
||||
inblock_off = token_indices % BLOCK_SIZE
|
||||
|
||||
# Create mask for invalid tokens (tok < 0)
|
||||
invalid_tok_mask = token_indices < 0
|
||||
|
||||
# Create mask for out-of-bounds block_id
|
||||
oob_block_mask = block_id >= max_num_blocks_per_req
|
||||
|
||||
# Combine masks - output -1 for either condition
|
||||
invalid_mask = invalid_tok_mask | oob_block_mask
|
||||
|
||||
# Get request IDs expanded to match token_indices shape
|
||||
req_ids_expanded = req_id.unsqueeze(1).expand(-1, NUM_TOPK_TOKENS)
|
||||
|
||||
# Gather base addresses from block_table
|
||||
# Clamp block_id to avoid index errors (we'll mask these out anyway)
|
||||
block_id_clamped = torch.clamp(block_id, 0, max_num_blocks_per_req - 1)
|
||||
|
||||
# Use advanced indexing to get base addresses
|
||||
base_addrs = block_table[req_ids_expanded, block_id_clamped]
|
||||
|
||||
# Compute the global indices
|
||||
global_indices = base_addrs * BLOCK_SIZE + inblock_off
|
||||
|
||||
# Apply mask: set invalid positions to -1
|
||||
out = torch.where(invalid_mask, torch.tensor(-1, dtype=torch.int32, device=token_indices.device), global_indices)
|
||||
|
||||
return out
|
||||
|
||||
def kunlun_concat_and_cache_mla(
|
||||
kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
|
||||
k_pe: torch.Tensor, #[num_tokens, pe_dim]
|
||||
kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
||||
slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
|
||||
kv_cache_dtype: str,
|
||||
scale: torch.Tensor
|
||||
):
|
||||
num_tokens = slot_mapping.shape[0]
|
||||
kv_lora_rank = kv_c.shape[1]
|
||||
pe_dim = k_pe.shape[1]
|
||||
block_size = kv_cache.shape[1]
|
||||
|
||||
def kunlun_fp8_ds_mla():
|
||||
for token_idx in range(num_tokens):
|
||||
slot = slot_mapping[token_idx].item()
|
||||
if slot < 0: continue
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
kv_c_i = kv_c[token_idx].view(4,kv_lora_rank//4).contiguous()
|
||||
kv_c_i_int8 = torch.zeros(
|
||||
kv_c_i.shape,
|
||||
device=kv_c.device,
|
||||
dtype=torch.int8,
|
||||
)
|
||||
kv_c_i_scale = torch.zeros(
|
||||
[kv_c_i.shape[0], 1],
|
||||
device=kv_c.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
torch.ops._C.quant2d(kv_c_i, kv_c_i_int8, kv_c_i_scale, force_sdnn=True)
|
||||
kv_c_i_scale /= 127
|
||||
kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c_i_int8.view(-1).view(torch.uint8).contiguous()
|
||||
kv_cache[block_idx, block_offset, kv_lora_rank:kv_lora_rank + 16] = kv_c_i_scale.view(-1).view(torch.uint8).contiguous()
|
||||
kv_cache[block_idx, block_offset, kv_lora_rank+16:] = k_pe[token_idx, :].view(torch.uint8).contiguous()
|
||||
|
||||
def kunlun_mla():
|
||||
for token_idx in range(num_tokens):
|
||||
slot = slot_mapping[token_idx].item()
|
||||
if slot < 0: continue
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx, :].contiguous()
|
||||
kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx, :].contiguous()
|
||||
|
||||
if (kv_cache_dtype == "fp8_ds_mla"):
|
||||
assert kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla"
|
||||
assert pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla"
|
||||
assert kv_cache.shape[2] == 656 // kv_cache.element_size(), "kv_cache.shape[2] must be 656 bytes for fp8_ds_mla"
|
||||
assert kv_c.element_size() == 2, "kv_c.element_size() must be 2 for fp8_ds_mla"
|
||||
assert k_pe.element_size() == 2, "k_pe.element_size() must be 2 for fp8_ds_mla"
|
||||
kunlun_fp8_ds_mla()
|
||||
else:
|
||||
assert kv_cache.shape[2] == kv_lora_rank + pe_dim
|
||||
kunlun_mla()
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashMLASparseMetadata]):
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
self.vllm_config = vllm_config
|
||||
self.layer_names = layer_names
|
||||
cache_config = vllm_config.cache_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.device = device
|
||||
|
||||
# Treat requests with query length <= 1 as decodes to match the
|
||||
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
|
||||
# 从最新版本vllm中引入的
|
||||
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
sm_count = props.multi_processor_count
|
||||
|
||||
self.num_heads = self.model_config.get_num_attention_heads(
|
||||
parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
|
||||
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
|
||||
|
||||
self.topk_tokens_tensor = torch.tensor([self.topk_tokens],
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
# self.max_model_len_tensor = torch.tensor(
|
||||
# [self.model_config.max_model_len],
|
||||
# device=device,
|
||||
# dtype=torch.int32)
|
||||
|
||||
# this is ignored by `flash_mla_with_kvcache` if indices not None
|
||||
self.dummy_block_table = torch.empty((1, 1),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
# Equation taken from FlashMLA/csrc/pybind.cpp
|
||||
h_q, h_k = self.num_heads, 1
|
||||
s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest
|
||||
max_num_sm_parts = int(
|
||||
max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1))
|
||||
if current_platform.is_device_capability(100):
|
||||
max_num_sm_parts *= 2
|
||||
self.tile_scheduler_metadata_buffer = torch.zeros(
|
||||
# TileSchedulerMetaDataSize = 8
|
||||
# see: FlashMLA/csrc/params.h
|
||||
(max_num_sm_parts, 8),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.num_splits_buffer = torch.zeros(
|
||||
# We pack all the tokens into one batch for sparse attention.
|
||||
# Otherwise, we can exceed the sm of `get_mla_metadata`.
|
||||
(
|
||||
2, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.req_id_per_token_buffer = torch.zeros(
|
||||
(vllm_config.scheduler_config.max_num_batched_tokens, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> FlashMLASparseMetadata:
|
||||
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
|
||||
dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
req_id_per_token = np.repeat(
|
||||
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
|
||||
# Zero-fill for cudagraphs
|
||||
self.req_id_per_token_buffer.fill_(0)
|
||||
self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\
|
||||
.copy_(torch.from_numpy(req_id_per_token), non_blocking=True)
|
||||
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
|
||||
|
||||
fp8_extra_metadata = None
|
||||
|
||||
if self.use_fp8_kv_cache:
|
||||
cache_seqlens_cpu, cache_seqlens = get_mla_metadata(
|
||||
cache_seqlens=self.topk_tokens_tensor,
|
||||
)
|
||||
fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
|
||||
scheduler_metadata=None,
|
||||
num_splits=None,
|
||||
# cache_lens and block_table are basically unused in sparse case
|
||||
# but the decode kernel will treat -1 and indices >= cache_lens
|
||||
# as invalid so we make sure cache_lens is large enough to not
|
||||
# accidentally mark indices invalid, we will use -1 exclusively
|
||||
# to mark invalid indices
|
||||
cache_lens=cache_seqlens_cpu,
|
||||
dummy_block_table=self.dummy_block_table)
|
||||
|
||||
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold or 1,
|
||||
require_uniform=True,
|
||||
)
|
||||
)
|
||||
|
||||
# For pure decode batches, prefill_request_id will be None
|
||||
# For mixed batches, it will have -1 for decode and request_id for prefill
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
prefill_metadata = MLASparsePrefillMetadata(
|
||||
query_start_loc = common_attn_metadata.query_start_loc[num_decodes:] - common_attn_metadata.query_start_loc[num_decodes], #因为prefiil、decode请求是分离,所以需要对q进行切分,故需调整该值
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[num_decodes:] - common_attn_metadata.query_start_loc_cpu[num_decodes],
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu[:num_decodes].max())
|
||||
|
||||
decode_metadata = FlashMLASparseDecodeAndContextMetadata(
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||
seq_lens_cpu=common_attn_metadata.seq_lens_cpu[:num_decodes],
|
||||
)
|
||||
|
||||
|
||||
metadata = FlashMLASparseMetadata(
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
block_table=common_attn_metadata.block_table_tensor,
|
||||
req_id_per_token=req_id_per_token,
|
||||
block_size=self.kv_cache_spec.block_size,
|
||||
topk_tokens=self.topk_tokens,
|
||||
fp8_extra_metadata=fp8_extra_metadata,
|
||||
num_prefills=num_prefills,
|
||||
num_decodes=num_decodes,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
decode_metadata=decode_metadata,
|
||||
prefill_metadata=prefill_metadata
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
topk_indice_buffer: Optional[torch.Tensor] = None,
|
||||
indexer: Optional["Indexer"] = None,
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
self.softmax_scale = scale
|
||||
assert indexer is not None
|
||||
self.topk_indices_buffer = indexer.topk_indices_buffer
|
||||
self.padding = 128 if current_platform.is_device_capability(
|
||||
100) else 64
|
||||
|
||||
def _forward_bf16_kv(
|
||||
self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
|
||||
|
||||
num_tokens = q.shape[0]
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.contiguous().view(
|
||||
-1, kv_c_and_k_pe_cache.shape[-1])
|
||||
|
||||
# num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
def _bf16_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
|
||||
# Reshape q: (num_decode_tokens, num_heads, head_dim)
|
||||
# -> (num_decodes, seq_len, num_heads, head_dim)
|
||||
q = reshape_query_for_spec_decode(q, num_decodes)
|
||||
seq_len = q.shape[1]
|
||||
# Reshape topk_indices: (num_decode_tokens, topk)
|
||||
# -> (num_decodes, seq_len, topk)
|
||||
topk_indices = topk_indices.view(num_decodes, seq_len, -1)
|
||||
decode_metadata = attn_metadata.decode_metadata
|
||||
_attn_out, _, _ = kunlun_flash_mla_with_kvcache(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache,
|
||||
head_dim_v=512,
|
||||
cache_seqlens=decode_metadata.seq_lens,
|
||||
cache_seqlens_cpu=decode_metadata.seq_lens_cpu,
|
||||
is_fp8_kvcache=False,
|
||||
indices=topk_indices,
|
||||
softmax_scale=self.softmax_scale,
|
||||
max_seq_kv=decode_metadata.max_seq_len
|
||||
)
|
||||
# Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
|
||||
# -> (num_decode_tokens, num_heads, head_dim_v)
|
||||
return reshape_attn_output_for_spec_decode(_attn_out)
|
||||
|
||||
def _bf16_prefill(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
|
||||
prefill_metadata = attn_metadata.prefill_metadata
|
||||
topk_indices = topk_indices.view(num_prefill_tokens, 1, -1)
|
||||
# NOTE: 只有prefill阶段attn_metadata.query_start_loc是符合klx算子需求的
|
||||
_attn_out = flash_mla_sparse_prefill(
|
||||
q=q,
|
||||
kv=kv_c_and_k_pe_cache,
|
||||
indices=topk_indices,
|
||||
sm_scale=self.softmax_scale,
|
||||
q_lod_xpu=prefill_metadata.query_start_loc,
|
||||
q_lod_cpu=prefill_metadata.query_start_loc_cpu
|
||||
)[0]
|
||||
return _attn_out
|
||||
|
||||
topk_indices_global = torch.ops.xspeedgate_ops.convert_req_index_to_global_index(
|
||||
req_id=attn_metadata.req_id_per_token,
|
||||
block_table=attn_metadata.block_table,
|
||||
token_indices=topk_indices,
|
||||
block_size=attn_metadata.block_size,
|
||||
num_topk_tokens=attn_metadata.topk_tokens,
|
||||
)
|
||||
|
||||
attn_out = torch.empty(
|
||||
(num_tokens, self.num_heads, self.kv_lora_rank),
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
)
|
||||
if has_prefill:
|
||||
prefill_q = q[num_decode_tokens:]
|
||||
prefill_topk_indices_global = topk_indices_global[num_decode_tokens:]
|
||||
attn_out[num_decode_tokens:] = _bf16_prefill(prefill_q, prefill_topk_indices_global)
|
||||
|
||||
# 处理decode部分 - 需要正确的block table映射print
|
||||
if has_decode:
|
||||
decode_q = q[:num_decode_tokens]
|
||||
decode_topk_indices_global = topk_indices_global[:num_decode_tokens]
|
||||
attn_out[:num_decode_tokens] = _bf16_decode(decode_q, decode_topk_indices_global)
|
||||
|
||||
return attn_out
|
||||
|
||||
|
||||
def _forward_fp8_kv(self, q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
|
||||
# TODO: When fwd_kvcache_mla supports uint8 kv cache, execute this function.
|
||||
assert attn_metadata.fp8_extra_metadata is not None
|
||||
extra_metadata = attn_metadata.fp8_extra_metadata
|
||||
|
||||
_attn_out, _ = flash_mla_with_kvcache(
|
||||
q=q.unsqueeze(0), # unsqueeze to add batch_dim
|
||||
k_cache=kv_c_and_k_pe_cache,
|
||||
block_table=extra_metadata.dummy_block_table,
|
||||
head_dim_v=512,
|
||||
cache_seqlens=extra_metadata.cache_lens,
|
||||
tile_scheduler_metadata=extra_metadata.scheduler_metadata, # None
|
||||
num_splits=extra_metadata.num_splits, # None
|
||||
is_fp8_kvcache=True,
|
||||
indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim
|
||||
softmax_scale=self.softmax_scale,
|
||||
max_seq_kv=attn_metadata.max_seq_len
|
||||
)
|
||||
|
||||
return _attn_out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
q: torch.Tensor,
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
|
||||
# MQA 576/512 approach for both prefill and decode
|
||||
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for MLACommonImpl")
|
||||
|
||||
if attn_metadata is None:
|
||||
# The zero fill is required when used with DP + EP
|
||||
# to ensure all ranks within a DP group compute the
|
||||
# same expert outputs.
|
||||
return output.fill_(0)
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
|
||||
q = q[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
q_nope = q_nope.transpose(0, 1)
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
ql_nope = ql_nope.transpose(0, 1)
|
||||
|
||||
topk_indices = self.topk_indices_buffer[:num_actual_toks]
|
||||
|
||||
q = torch.cat([ql_nope, q_pe], dim=-1)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
torch.ops._C.concat_and_cache_mla(
|
||||
kv_c=k_c_normed,
|
||||
k_pe=k_pe.squeeze(1),
|
||||
kv_cache=kv_cache,
|
||||
slot_mapping=attn_metadata.slot_mapping.flatten(),
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype != "fp8_ds_mla":
|
||||
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
|
||||
attn_metadata)
|
||||
else:
|
||||
# attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
|
||||
# attn_metadata)
|
||||
raise NotImplementedError
|
||||
|
||||
self._v_up_proj(attn_out, out=output[:num_actual_toks])
|
||||
return output
|
||||
133
vllm_kunlun/v1/attention/backends/mla/indexer.py
Normal file
133
vllm_kunlun/v1/attention/backends/mla/indexer.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
|
||||
DeepseekV32IndexerMetadataBuilder,
|
||||
DeepseekV32IndexerPrefillMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class DeepSeekV32IndexerDecodeMetadata:
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
seq_lens_cpu: torch.Tensor
|
||||
decode_lens: torch.Tensor
|
||||
requires_padding: bool
|
||||
schedule_metadata: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerMetadata:
|
||||
|
||||
# FIXME (zyongye)
|
||||
# hacky way to access the data now, need to be in chunked meta
|
||||
seq_lens: torch.Tensor
|
||||
seq_lens_cpu: torch.Tensor
|
||||
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
# The dimension of the attention heads
|
||||
head_dim: int
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
|
||||
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
|
||||
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
|
||||
|
||||
def kunlun_build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> DeepseekV32IndexerMetadata:
|
||||
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold)
|
||||
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
chunk_seq_ids = split_prefill_chunks(
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
self.max_prefill_buffer_size,
|
||||
num_decodes,
|
||||
)
|
||||
chunks = [
|
||||
self.build_one_prefill_chunk(
|
||||
reqs_start, reqs_end, query_start_loc_cpu,
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
common_attn_metadata.block_table_tensor)
|
||||
for reqs_start, reqs_end in chunk_seq_ids
|
||||
]
|
||||
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
||||
chunks=chunks, )
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
|
||||
out=self.decode_lens_buffer[:num_decodes])
|
||||
decode_lens = self.decode_lens_buffer[:num_decodes]
|
||||
decode_lens_cpu = torch.diff(
|
||||
common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])
|
||||
|
||||
# Use CPU to avoid GPU sync; breaking async scheduling
|
||||
requires_padding = (decode_lens_cpu.max()
|
||||
> decode_lens_cpu.min()).item()
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||
|
||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||
block_table=common_attn_metadata.
|
||||
block_table_tensor[:num_decodes, ...],
|
||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||
seq_lens_cpu=common_attn_metadata.seq_lens[:num_decodes].cpu(),
|
||||
decode_lens=decode_lens,
|
||||
requires_padding=requires_padding,
|
||||
schedule_metadata=self.scheduler_metadata_buffer,
|
||||
)
|
||||
|
||||
attn_metadata = DeepseekV32IndexerMetadata(
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
seq_lens_cpu=common_attn_metadata.seq_lens.cpu(),
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
head_dim=128,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
# if get_tensor_model_parallel_rank() == 0:
|
||||
# logger.info(f"attn_metadata: {attn_metadata}")
|
||||
return attn_metadata
|
||||
|
||||
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
@@ -24,6 +24,7 @@ class TopKTopPSampler(nn.Module):
|
||||
|
||||
def __init__(self, logprobs_mode):
|
||||
super().__init__()
|
||||
self.logprobs_mode = logprobs_mode
|
||||
logger.info_once(
|
||||
"Using FlashInfer for top-p & top-k sampling.")
|
||||
self.forward = self.forward_kunlun
|
||||
@@ -40,9 +41,14 @@ class TopKTopPSampler(nn.Module):
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
logits = apply_top_k_top_p(logits, k, p)
|
||||
logits = self.apply_top_k_top_p(logits, k, p)
|
||||
logits_to_return = None
|
||||
if self.logprobs_mode == "processed_logits":
|
||||
logits_to_return = logits
|
||||
elif self.logprobs_mode == "processed_logprobs":
|
||||
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators), None
|
||||
return random_sample(probs, generators), logits_to_return
|
||||
|
||||
def forward_kunlun(
|
||||
self,
|
||||
@@ -52,16 +58,13 @@ class TopKTopPSampler(nn.Module):
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""More optimized implementation for top-k and top-p sampling."""
|
||||
if k is None and p is None:
|
||||
# We prefer `random_sample` over `flashinfer_sample` when sorting is
|
||||
# not needed. This is because `random_sample` does not require
|
||||
# CPU-GPU synchronization while `flashinfer_sample` does.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators), None
|
||||
if generators:
|
||||
logger.warning_once("FlashInfer 0.2.3+ does not support "
|
||||
"per-request generators. Falling back to "
|
||||
"PyTorch-native implementation.")
|
||||
if (k is None and p is None) or generators:
|
||||
if generators:
|
||||
logger.debug_once(
|
||||
"FlashInfer 0.2.3+ does not support "
|
||||
"per-request generators. Falling back to "
|
||||
"PyTorch-native implementation."
|
||||
)
|
||||
return self.forward_native(logits, generators, k, p)
|
||||
# flashinfer sampling functions expect contiguous logits.
|
||||
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
|
||||
@@ -196,6 +199,7 @@ def flashinfer_sample(
|
||||
probs, top_k=k, deterministic=True)
|
||||
else:
|
||||
# Both top-k and top-p.
|
||||
k = k.to(torch.int32)
|
||||
next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
|
||||
probs, top_k=k, top_p=p, deterministic=True)
|
||||
|
||||
|
||||
344
vllm_kunlun/v1/worker/utils.py
Normal file
344
vllm_kunlun/v1/worker/utils.py
Normal file
@@ -0,0 +1,344 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.multimodal.cache import processor_only_cache_from_config
|
||||
from vllm.multimodal.registry import MultiModalRegistry
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
|
||||
class MultiModalBudget:
|
||||
"""Helper class to calculate budget information for multi-modal models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
mm_registry: MultiModalRegistry,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.model_config = model_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.mm_registry = mm_registry
|
||||
self.cache = cache = processor_only_cache_from_config(
|
||||
model_config, mm_registry)
|
||||
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||
|
||||
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config,
|
||||
cache=cache)
|
||||
|
||||
max_tokens_by_modality = mm_registry \
|
||||
.get_max_tokens_per_item_by_nonzero_modality(model_config,
|
||||
cache=cache)
|
||||
|
||||
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
|
||||
scheduler_config,
|
||||
max_tokens_by_modality,
|
||||
)
|
||||
|
||||
self.encoder_compute_budget = encoder_compute_budget
|
||||
self.encoder_cache_size = encoder_cache_size
|
||||
|
||||
max_items_per_prompt_by_modality = dict[str, int]()
|
||||
max_items_per_batch_by_modality = dict[str, int]()
|
||||
|
||||
for modality, max_tokens in max_tokens_by_modality.items():
|
||||
(
|
||||
max_items_per_prompt,
|
||||
max_items_per_batch,
|
||||
) = self.get_max_items(modality, max_tokens)
|
||||
|
||||
max_items_per_prompt_by_modality[modality] = max_items_per_prompt
|
||||
max_items_per_batch_by_modality[modality] = max_items_per_batch
|
||||
|
||||
self.max_tokens_by_modality = max_tokens_by_modality
|
||||
self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
|
||||
self.max_items_per_batch_by_modality = max_items_per_batch_by_modality
|
||||
|
||||
def get_modality_with_max_tokens(self) -> str:
|
||||
max_tokens_by_modality = self.max_tokens_by_modality
|
||||
modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])
|
||||
|
||||
return modality
|
||||
|
||||
def get_encoder_budget(self) -> int:
|
||||
return min(self.encoder_compute_budget, self.encoder_cache_size)
|
||||
|
||||
def get_max_items(
|
||||
self,
|
||||
modality: str,
|
||||
max_tokens_per_item: int,
|
||||
) -> tuple[int, int]:
|
||||
if max_tokens_per_item == 0:
|
||||
return 0, 0
|
||||
|
||||
# Check how many items of this modality can be supported by
|
||||
# the encoder budget.
|
||||
encoder_budget = self.get_encoder_budget()
|
||||
|
||||
# TODO: handle encoder-decoder models once we support them.
|
||||
if encoder_budget == 0:
|
||||
return 0, 0
|
||||
|
||||
max_encoder_items_per_batch = encoder_budget // max_tokens_per_item
|
||||
|
||||
# Check how many items of this modality can be supported by
|
||||
# the decoder budget.
|
||||
mm_limit = self.mm_limits[modality]
|
||||
|
||||
max_items_per_prompt = max(
|
||||
1,
|
||||
min(mm_limit, self.max_model_len // max_tokens_per_item),
|
||||
)
|
||||
|
||||
scheduler_config = self.scheduler_config
|
||||
max_num_reqs = self.max_num_reqs
|
||||
|
||||
if not scheduler_config.enable_chunked_prefill:
|
||||
max_num_reqs = min(
|
||||
max_num_reqs,
|
||||
scheduler_config.max_num_batched_tokens // max_tokens_per_item,
|
||||
)
|
||||
|
||||
max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt
|
||||
|
||||
max_items_per_batch = max(
|
||||
1,
|
||||
min(max_encoder_items_per_batch, max_decoder_items_per_batch),
|
||||
)
|
||||
|
||||
return max_items_per_prompt, max_items_per_batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionGroup:
|
||||
backend: type[AttentionBackend]
|
||||
# When ubatching is enabled we will have a metadata builder for each ubatch
|
||||
# so that if they use internal persistant buffers for cudagraphs, and they
|
||||
# won't have to worry about conflicting with the other ubatches.
|
||||
metadata_builders: list[AttentionMetadataBuilder]
|
||||
layer_names: list[str]
|
||||
kv_cache_spec: KVCacheSpec
|
||||
|
||||
@staticmethod
|
||||
def create_with_metadata_builders(
|
||||
backend: type[AttentionBackend],
|
||||
layer_names: list[str],
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
num_metadata_builders: int = 1,
|
||||
) -> 'AttentionGroup':
|
||||
metadata_builders = [
|
||||
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config,
|
||||
device)
|
||||
for _ in range(num_metadata_builders)
|
||||
]
|
||||
return AttentionGroup(backend, metadata_builders, layer_names,
|
||||
kv_cache_spec)
|
||||
|
||||
def get_metadata_builder(self,
|
||||
ubatch_id: int = 0) -> AttentionMetadataBuilder:
|
||||
assert len(self.metadata_builders) > ubatch_id
|
||||
return self.metadata_builders[ubatch_id]
|
||||
|
||||
|
||||
def sanity_check_mm_encoder_outputs(
|
||||
mm_embeddings: MultiModalEmbeddings,
|
||||
expected_num_items: int,
|
||||
) -> None:
|
||||
"""
|
||||
Perform sanity checks for the result of
|
||||
[`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
|
||||
"""
|
||||
assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
|
||||
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
|
||||
f"or a single 3D tensor, but got {type(mm_embeddings)} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `get_multimodal_embeddings` method.")
|
||||
|
||||
assert len(mm_embeddings) == expected_num_items, (
|
||||
"Expected number of multimodal embeddings to match number of "
|
||||
f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `get_multimodal_embeddings` method.")
|
||||
|
||||
assert all(e.ndim == 2 for e in mm_embeddings), (
|
||||
"Expected multimodal embeddings to be a sequence of 2D tensors, "
|
||||
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `get_multimodal_embeddings` method.")
|
||||
|
||||
|
||||
def scatter_mm_placeholders(
|
||||
embeds: torch.Tensor,
|
||||
is_embed: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Scatter the multimodal embeddings into a contiguous tensor that represents
|
||||
the placeholder tokens.
|
||||
|
||||
[`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].
|
||||
|
||||
Args:
|
||||
embeds: The multimodal embeddings.
|
||||
Shape: `(num_embeds, embed_dim)`
|
||||
is_embed: A boolean mask indicating which positions in the placeholder
|
||||
tokens need to be filled with multimodal embeddings.
|
||||
Shape: `(num_placeholders, num_embeds)`
|
||||
"""
|
||||
if is_embed is None:
|
||||
return embeds
|
||||
|
||||
placeholders = embeds.new_full(
|
||||
(is_embed.shape[0], embeds.shape[-1]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
placeholders[is_embed] = embeds
|
||||
return placeholders
|
||||
|
||||
|
||||
def gather_mm_placeholders(
|
||||
placeholders: torch.Tensor,
|
||||
is_embed: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reconstructs the embeddings from the placeholder tokens.
|
||||
|
||||
This is the operation of [`scatter_mm_placeholders`]
|
||||
[vllm.v1.worker.utils.scatter_mm_placeholders].
|
||||
"""
|
||||
if is_embed is None:
|
||||
return placeholders
|
||||
|
||||
return placeholders[is_embed]
|
||||
|
||||
|
||||
def add_kv_sharing_layers_to_kv_cache_groups(
|
||||
shared_kv_cache_layers: dict[str, str],
|
||||
kv_cache_groups: list[KVCacheGroupSpec],
|
||||
runner_only_attn_layers: Optional[set[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
|
||||
for layers that do not allocate its own KV cache, based on the mapping in
|
||||
`shared_kv_cache_layers`. Adds these layers to the corresponding KV cache
|
||||
group, which is needed to ensure that attention metadata is assigned later.
|
||||
|
||||
Args:
|
||||
shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
|
||||
If an Attention layer `layer_name` is in the keys of this dict, it
|
||||
means this layer will perform attention using the keys and values
|
||||
from the KV cache of `shared_kv_cache_layers[layer_name]`.
|
||||
kv_cache_groups: The KV cache groups of the model.
|
||||
"""
|
||||
layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
|
||||
for kv_cache_group in kv_cache_groups:
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
layer_to_kv_cache_group[layer_name] = kv_cache_group
|
||||
|
||||
for layer_name, target_layer_name in shared_kv_cache_layers.items():
|
||||
tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
|
||||
tgt_kv_cache_group.layer_names.append(layer_name)
|
||||
|
||||
if runner_only_attn_layers is not None:
|
||||
runner_only_attn_layers.add(layer_name)
|
||||
|
||||
|
||||
def bind_kv_cache(
|
||||
kv_caches: dict[str, torch.Tensor],
|
||||
forward_context: dict[str, "Attention"],
|
||||
runner_kv_caches: list[torch.Tensor],
|
||||
num_attn_module: Optional[int] = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Bind the allocated KV cache to both ModelRunner and forward context so
|
||||
that the KV cache can be used in the forward pass.
|
||||
|
||||
This function:
|
||||
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
|
||||
kv_caches.
|
||||
2) Associates each attention layer in the `forward_context` with its
|
||||
corresponding KV cache in kv_caches.
|
||||
|
||||
Args:
|
||||
kv_caches: The allocated kv_caches with layer names as keys.
|
||||
forward_context: The global forward context containing all Attention
|
||||
layers with layer names as keys.
|
||||
runner_kv_caches: The kv_cache declared by ModelRunner.
|
||||
"""
|
||||
# Bind kv_caches to ModelRunner
|
||||
assert len(runner_kv_caches) == 0
|
||||
|
||||
# Convert kv_caches dict to a list of tensors in the order of layer_index.
|
||||
index2name = defaultdict(list)
|
||||
for layer_name in kv_caches:
|
||||
index2name[extract_layer_index(layer_name,
|
||||
num_attn_module)].append(layer_name)
|
||||
|
||||
for layer_index in sorted(index2name.keys()):
|
||||
layer_names = index2name[layer_index]
|
||||
if len(layer_names) > 1:
|
||||
# One typical case is encoder-decoder model, e.g., bart.
|
||||
# The cross attention and self attention in the same decoder layer
|
||||
# has different layer_name but the same layer_index.
|
||||
|
||||
# TODO - analyze where runner_kv_caches is used and the right
|
||||
# way to ensure it properly reflects multiple attention layers
|
||||
# in the same decoder block.
|
||||
if current_platform.is_kunlun() or current_platform.is_cuda() or current_platform.is_xpu():
|
||||
# We know that the GPU runner is not impacted by this
|
||||
# case. Some test code depends on runner_kv_caches, but
|
||||
# not in a way that's impacted by ignoring this.
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError
|
||||
layer_name = layer_names[0]
|
||||
runner_kv_caches.append(kv_caches[layer_name])
|
||||
|
||||
# Bind kv_caches to forward context
|
||||
for layer_name, kv_cache in kv_caches.items():
|
||||
# NOTE: Use list because of v0 PP virtual engine.
|
||||
forward_context[layer_name].kv_cache = [kv_cache]
|
||||
|
||||
|
||||
def is_residual_scattered_for_sp(vllm_config: VllmConfig,
|
||||
num_input_tokens: int) -> bool:
|
||||
"""Check if the residual tensor is scattered for sequence parallelism.
|
||||
|
||||
The residual tensor is scattered across tensor parallel ranks when sequence
|
||||
parallelism and tensor parallelism is enabled, and the number of
|
||||
input tokens is one of the compilation sizes.
|
||||
"""
|
||||
if not vllm_config.compilation_config.pass_config.\
|
||||
enable_sequence_parallelism:
|
||||
return False
|
||||
|
||||
tp = vllm_config.parallel_config.tensor_parallel_size
|
||||
|
||||
if tp == 1:
|
||||
return False
|
||||
|
||||
# When sequence parallelism is enabled, we always pad num_input_tokens
|
||||
# to be a multiple of tensor_parallel_size (tp) earlier.
|
||||
assert num_input_tokens % tp == 0
|
||||
|
||||
# Currently, SP is only enabled for static size fx graphs.
|
||||
return (num_input_tokens in vllm_config.compilation_config.compile_sizes)
|
||||
@@ -1493,4 +1493,45 @@ def _fake_gptq_shuffle(
|
||||
return None
|
||||
|
||||
|
||||
gptq_shuffle.register_fake(_fake_gptq_shuffle)
|
||||
gptq_shuffle.register_fake(_fake_gptq_shuffle)
|
||||
|
||||
##################################################
|
||||
# ---------------- concat_and_cache_mla ------------------
|
||||
##################################################
|
||||
@custom_op("_C::concat_and_cache_mla", mutates_args=())
|
||||
def concat_and_cache_mla(
|
||||
kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
|
||||
k_pe: torch.Tensor, #[num_tokens, pe_dim]
|
||||
kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
||||
slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
|
||||
) -> None:
|
||||
xtorch_ops.concat_and_cache_mla(
|
||||
kv_c=kv_c,
|
||||
k_pe=k_pe,
|
||||
slot_mapping=slot_mapping,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
|
||||
@impl("_C::concat_and_cache_mla", "CUDA")
|
||||
def concat_and_cache_mla_cuda(
|
||||
kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
|
||||
k_pe: torch.Tensor, #[num_tokens, pe_dim]
|
||||
kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
||||
slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
|
||||
) -> None:
|
||||
xtorch_ops.concat_and_cache_mla(
|
||||
kv_c=kv_c,
|
||||
k_pe=k_pe,
|
||||
slot_mapping=slot_mapping,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
|
||||
def _fake_concat_and_cache_mla(
|
||||
kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
|
||||
k_pe: torch.Tensor, #[num_tokens, pe_dim]
|
||||
kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
||||
slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
concat_and_cache_mla.register_fake(_fake_concat_and_cache_mla)
|
||||
Reference in New Issue
Block a user