[Feature] support deepseek v3/r1/v3.2 (#78)

* [Feature] support deepseek v3/r1/v3.2

* fix gpt_oss

* update readme

* update readme

---------

Co-authored-by: hanhaowen <hanhaowen@baidu.com>
This commit is contained in:
baoqian426
2026-01-05 22:55:35 +08:00
committed by GitHub
parent 07bc24a555
commit ee0f50e68f
27 changed files with 5760 additions and 621 deletions

View File

@@ -20,4 +20,8 @@ import vllm_kunlun.ops.layernorm
import vllm_kunlun.ops.quantization.awq
import vllm_kunlun.ops.quantization.gptq
import vllm_kunlun.ops.vocab_parallel_embedding
import vllm_kunlun.ops.linear
import vllm_kunlun.ops.linear
import vllm_kunlun.ops.quantization.kernels.scaled_mm.cutlass
import vllm_kunlun.ops.vocab_parallel_embedding
import vllm_kunlun.ops.quantization.compressed_tensors_moe
import vllm_kunlun.ops.fused_moe.layer

View File

@@ -417,7 +417,6 @@ class KunlunOps:
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
linear_weights: torch.Tensor,
ep_rank: int,
moe_top_k: int,
renormalize: bool,

View File

@@ -108,7 +108,7 @@ class SiluAndMul(CustomOp):
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
xtorch_ops.swiglu(x, out)
torch.ops._C.silu_and_mul(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:

View File

@@ -0,0 +1,260 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
import xtorch_ops
logger = init_logger(__name__)
if current_platform.is_cuda():
try:
import vllm._flashmla_C # noqa: F401
_flashmla_C_AVAILABLE = True
except ImportError:
_flashmla_C_AVAILABLE = False
else:
_flashmla_C_AVAILABLE = False
if current_platform.is_cuda():
try:
import vllm._flashmla_extension_C # noqa: F401
_flashmla_extension_C_AVAILABLE = True
except ImportError:
_flashmla_extension_C_AVAILABLE = False
else:
_flashmla_extension_C_AVAILABLE = False
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """Report whether FlashMLA is usable on this platform.

    Returns:
        A ``(supported, reason)`` pair; ``reason`` is ``None`` when supported.
    """
    # The Kunlun MLA path is always available, so never report a reason.
    return (True, None)
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int = 1,
    num_heads_k: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the scheduling metadata consumed by ``flash_mla_with_kvcache``.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals seq_len_q * num_heads_q // num_heads_k.
            Accepted for interface compatibility; unused on this backend.
        num_heads_k: num_heads_k. Also unused here.

    Returns:
        A pair ``(host_seqlens, cache_seqlens)``: the CPU copy stands in for
        the CUDA tile-scheduler metadata, the original device tensor for
        ``num_splits``.
    """
    # No CUDA tile scheduling on Kunlun — downstream kernels only need the
    # sequence lengths, on host and on device.
    host_seqlens = cache_seqlens.cpu()
    return host_seqlens, cache_seqlens
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run MLA decode attention over a paged KV cache.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v (the kv LoRA rank).
        tile_scheduler_metadata: host-side context lengths, as returned by
            ``get_mla_metadata``.
        num_splits: device-side context lengths, as returned by
            ``get_mla_metadata``.
        softmax_scale: scale of QK^T before softmax; defaults to
            1 / sqrt(head_dim).
        causal: whether to apply a causal attention mask.

    Returns:
        ``(out, softmax_lse)`` where ``out`` is
        (batch_size, seq_len_q, num_heads_q, head_dim_v). The underlying
        kernel does not produce a log-sum-exp, so the second element is
        always ``None``.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    # Output buffer is filled in place by the kernel below.
    out = torch.ones(q.size(0), q.size(1), q.size(2), head_dim_v,
                     dtype=q.dtype, device=q.device)
    latent_dim = head_dim_v                 # kv_lora_rank
    rope_dim = q.size(3) - head_dim_v       # decoupled RoPE head dim
    cache_head_dim = k_cache.shape[3]
    block_size = k_cache.shape[1]
    paged_kv = k_cache.view(-1, 1, block_size, cache_head_dim)
    # TODO: avoid extra q slicing/copies once the kernel supports split q_c/q_r.
    xtorch_ops.paged_attention(
        out,
        q,
        paged_kv, None,
        block_table,
        tile_scheduler_metadata,  # context lengths on CPU
        num_splits,               # context lengths on device
        False,                    # is_context
        causal,
        -1,                       # vo_head_dim
        latent_dim,
        rope_dim,
        softmax_scale,
        q_r=q)
    return out, None
def kunlun_flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    cache_seqlens: torch.Tensor,
    cache_seqlens_cpu: torch.Tensor,
    head_dim_v: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    max_seq_kv: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """MLA decode attention via the Kunlun ``fwd_kvcache_mla`` kernel.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_tokens_kv, head_dim).
        cache_seqlens: (batch_size), torch.int32.
        cache_seqlens_cpu: host copy of ``cache_seqlens``.
        head_dim_v: Head dimension of v (the kv LoRA rank).
        softmax_scale: scale of QK^T before softmax; defaults to
            1 / sqrt(head_dim).
        causal: whether to apply a causal attention mask. Must be False when
            ``indices`` is given (sparse attention is inherently causal).
        is_fp8_kvcache: whether the cache is fp8; not supported yet.
        indices: (batch_size, seq_len_q, topk), torch.int32. When set, only
            tokens listed in ``indices`` are attended to; invalid entries are
            -1 or >= total_seq_len_kv.
        max_seq_kv: maximum KV length over the batch.

    Returns:
        ``(out, max_logits, p_sums)`` with shapes
        (batch_size, seq_len_q, num_heads_q, head_dim_v),
        (batch_size, seq_len_q, num_heads_q) float32,
        (batch_size, seq_len_q, num_heads_q) float32.
    """
    assert not is_fp8_kvcache, "By now, the kernel does not support uint8 kv cache."
    assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if indices is not None:
        # NOTE (zyongye): sparse attention is also causal since it only
        # attends to preceding tokens, but `causal` must not be set here.
        assert not causal, \
            "causal must be `false` if sparse attention is enabled."
    # Leaving q_r and pe_cache as None selects the kernel's packed-q mode.
    q_r, pe_cache = None, None
    batch, q_len, heads, full_dim = q.shape
    latent_dim = head_dim_v
    rope_dim = full_dim - latent_dim
    out = torch.zeros((batch, q_len, heads, latent_dim),
                      dtype=q.dtype, device=q.device)
    max_logits = torch.zeros((batch, q_len, heads),
                             dtype=torch.float32, device=q.device)
    p_sums = torch.zeros((batch, q_len, heads),
                         dtype=torch.float32, device=q.device)
    xtorch_ops.fwd_kvcache_mla(
        q_c=q,
        kv_cache=k_cache,
        indices=indices,
        kv_lod_cpu=cache_seqlens_cpu,
        max_seq_kv=max_seq_kv,
        softmax_scale=softmax_scale,
        # q_r=q_r,
        # pe_cache=pe_cache,
        out=out,
        max_logits=max_logits,
        p_sums=p_sums,
        kv_lod_xpu=cache_seqlens,
    )
    return out, max_logits, p_sums
def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    q_lod_xpu: torch.Tensor,
    q_lod_cpu: torch.Tensor,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sparse attention prefill kernel.

    Args:
        - q: [s_q, h_q, d_qk], bfloat16
        - kv: [s_kv, d_qk], bfloat16
        - indices: [s_q, h_kv, topk], int32.
          Invalid indices should be set to -1 or numbers >= s_kv
        - sm_scale: float
        - q_lod_xpu: [batch+1], int32, cumulative per-sequence q lengths
          (batch_num + 1 entries; empty means fixed-length q).
        - q_lod_cpu: host copy of ``q_lod_xpu``.
        - d_v: The dimension of value vectors. Can only be 512

    Returns:
        - (output, max_logits, lse)
          - output: [s_q, h_q, d_v], bfloat16
          - max_logits: [s_q, h_q], float
          - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    num_q, num_heads, _ = q.shape
    out = torch.zeros((num_q, num_heads, d_v), dtype=q.dtype, device=q.device)
    max_logits = torch.zeros((num_q, num_heads),
                             dtype=torch.float32, device=q.device)
    lse = torch.zeros((num_q, num_heads),
                      dtype=torch.float32, device=q.device)
    # NOTE(review): the kv LoDs reuse the q LoDs — presumably q and kv share
    # the same segment layout during prefill; confirm against callers.
    # NOTE(review): AIAK sets is_causal=True here — confirm why.
    xtorch_ops.sparse_prefill_fwd_opt(
        q=q,
        kv=kv,
        indices=indices,
        qlod_cpu=q_lod_cpu,
        qlod_xpu=q_lod_xpu,
        kvlod_cpu=q_lod_cpu,
        kvlod_xpu=q_lod_xpu,
        sm_scale=sm_scale,
        d_v=d_v,
        is_causal=True,
        out=out,
        max_logits=max_logits,
        lse=lse,
    )
    # NOTE: Compared with torch.ops._flashmla_C.sparse_prefill_fwd,
    #   out_scale = 1 / math.log2(math.e)
    #   gpu_max_logits * out_scale = kunlun_lse
    #   gpu_lse * out_scale = kunlun_lse
    return out, max_logits, lse
#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# return ....
#

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
from vllm_kunlun.ops.attention.layer import Attention
# from vllm.attention import Attention
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
@dataclass
class MLAModules:
    """Modules used in MLA.

    Bundle of sub-modules a model wires into ``MultiHeadLatentAttention``.
    Optional fields are ``None`` when the corresponding projection path is
    not used by the model.
    """
    kv_a_layernorm: torch.nn.Module
    kv_b_proj: torch.nn.Module
    rotary_emb: torch.nn.Module
    o_proj: torch.nn.Module
    # Exactly one of the two q paths is exercised at runtime: when
    # q_lora_rank is set, fused_qkv_a_proj + q_a_layernorm + q_b_proj;
    # otherwise kv_a_proj_with_mqa + q_proj (see forward_native below).
    fused_qkv_a_proj: Optional[torch.nn.Module]
    kv_a_proj_with_mqa: Optional[torch.nn.Module]
    q_a_layernorm: Optional[torch.nn.Module]
    q_b_proj: Optional[torch.nn.Module]
    q_proj: Optional[torch.nn.Module]
    # Sparse-attention indexer; None for dense attention models.
    indexer: Optional[torch.nn.Module]
    is_sparse: bool
    topk_indices_buffer: Optional[torch.Tensor]
@CustomOp.register("vllm_kunlun_multi_head_latent_attention")
class MultiHeadLatentAttention(CustomOp):
    """MLA layer registered as CustomOp.

    Note that currently MLA ignores the enable/disable mechanism of CustomOp
    because there is only one in-tree implementation in forward_native.
    TODO: implement this with a new PluggableLayer mechanism.

    This class takes positions and hidden_states as input.
    The input tensors can either contain prefill tokens or decode tokens.
    The class does the following:
    1. MLA Preprocess.
    2. Perform multi-head attention to prefill tokens and
       multi-query attention to decode tokens separately.
    3. Return the output tensor.
    """
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        mla_modules: MLAModules,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Store head-dimension config, unpack ``mla_modules`` onto the layer,
        and build the inner paged-attention ``Attention`` instance."""
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        # Projections and norms supplied by the model definition.
        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
        self.q_a_layernorm = mla_modules.q_a_layernorm
        self.q_b_proj = mla_modules.q_b_proj
        self.q_proj = mla_modules.q_proj
        self.kv_a_layernorm = mla_modules.kv_a_layernorm
        self.kv_b_proj = mla_modules.kv_b_proj
        self.rotary_emb = mla_modules.rotary_emb
        self.o_proj = mla_modules.o_proj
        self.indexer = mla_modules.indexer
        self.is_sparse = mla_modules.is_sparse
        if self.indexer is not None:
            # The indexer contract requires a topk_tokens attribute.
            assert hasattr(self.indexer, "topk_tokens")
            self.topk_tokens = self.indexer.topk_tokens
        self.topk_indices_buffer = mla_modules.topk_indices_buffer
        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        # k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        # kv_lora_rank + qk_rope_head_dim == head_size
        self.mla_attn = Attention(
            num_heads=self.num_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=scale,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            use_sparse=mla_modules.is_sparse,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=self.kv_b_proj,
            indexer=self.indexer,
        )
        self.prefix = prefix
    def forward_native(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project hidden states into latent q/kv, apply RoPE, run MLA
        attention, and return the output projection."""
        q_c = None
        kv_lora = None
        if self.q_lora_rank is not None:
            # LoRA q path: one fused projection produces both the compressed
            # q and the compressed kv (+ rope part), split below.
            assert self.fused_qkv_a_proj is not None, \
                "fused_qkv_a_proj is required when q_lora_rank is not None"
            assert self.q_a_layernorm is not None, \
                "q_a_layernorm is required when q_lora_rank is not None"
            assert self.q_b_proj is not None, \
                "q_b_proj is required when q_lora_rank is not None"
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_lora = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0]
        else:
            # Plain q path: full-rank q projection, separate kv projection.
            assert self.kv_a_proj_with_mqa is not None, \
                "kv_a_proj_with_mqa is required when q_lora_rank is None"
            assert self.q_proj is not None, \
                "q_proj is required when q_lora_rank is None"
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]
        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
                                   dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c)
        q = q.view(-1, self.num_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)
        # RoPE is applied only to the decoupled (rope) slice of q and to k_pe.
        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
            positions, q[..., self.qk_nope_head_dim:], k_pe)
        # NOTE(review): truthiness test on a Module — presumably meant as an
        # is-not-None check; also q_c is None on the non-LoRA path, so the
        # indexer would receive None there. Confirm with sparse-model callers.
        if self.indexer and self.is_sparse:
            # Result is unused here; the indexer presumably writes into
            # topk_indices_buffer as a side effect — TODO confirm.
            _topk_indices = self.indexer(hidden_states, q_c, positions,
                                         self.rotary_emb)
        # hidden_states may arrive as a (quantized_x, scale) tuple; only the
        # token count is needed to size the attention output.
        hidden_states_shape_0 = 0
        if isinstance(hidden_states, tuple):
            x_q, x_scale = hidden_states
            hidden_states_shape_0 = x_q.shape[0]
        else:
            hidden_states_shape_0 = hidden_states.shape[0]
        attn_out = self.mla_attn(
            q,
            kv_c_normed,
            k_pe,
            output_shape=(hidden_states_shape_0,
                          self.num_heads * self.v_head_dim))
        return self.o_proj(attn_out)[0]
    def forward_cuda(self, *args, **kwargs):
        # Single in-tree implementation: CUDA path delegates to native.
        return self.forward_native(*args, **kwargs)

View File

@@ -0,0 +1,114 @@
import torch
import xtorch_ops
def int8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute MQA logits for a single sequence without KV paging.

    Backed by ``xtorch_ops.I8_mqa_logits`` (int8 path; the dtypes below
    mirror the upstream FP8 interface this was adapted from).

    Args:
        q: Query tensor of shape [M, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query
            position, shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query
            position, shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`; entries
        outside the [ks, ke) window are -inf.
    """
    num_q, num_k = q.shape[0], kv[0].shape[0]
    logits = torch.empty((num_q, num_k), dtype=torch.float32, device=q.device)
    # The kernel takes LoD-style (cpu, xpu) length pairs; one segment
    # covers the whole (unpaged) sequence.
    q_lens = torch.tensor([0, num_q], dtype=torch.int32,
                          device=cu_seqlen_ks.device)
    k_lens = torch.tensor([0, num_k], dtype=torch.int32,
                          device=cu_seqlen_ks.device)
    xtorch_ops.I8_mqa_logits(
        q=q,
        fused_kv_cache=kv,
        weights=weights,
        context_q_lens=(q_lens.cpu(), q_lens),
        context_k_lens=(k_lens.cpu(), k_lens),
        logits=logits,
        clean_logits=True,
        use_xfa_boost=False,
    )
    # Mask follows _ref_fp8_mqa_logits in vLLM v0.11.0
    # tests/kernels/attention/test_deepgemm_attention.py: query i may only
    # attend to keys in [cu_seqlen_ks[i], cu_seqlen_ke[i]).
    lo_ok = (torch.arange(0, num_k, device=cu_seqlen_ks.device)[None, :]
             >= cu_seqlen_ks[:, None])
    hi_ok = (torch.arange(0, num_k, device=cu_seqlen_ke.device)[None, :]
             < cu_seqlen_ke[:, None])
    return logits.masked_fill(~(lo_ok & hi_ok), float("-inf"))
def int8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    context_lens_cpu: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute MQA logits using a paged quantized KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv_cache_fp8: Paged KV-cache in packed value+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. Per block,
            the first block_size*D bytes hold the quantized keys and the
            trailing 4 bytes per position hold the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context
            length for each batch element.
        context_lens_cpu: host copy of `context_lens`.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps
            logical block indices to physical blocks in the paged cache.
        schedule_metadata: accepted for interface compatibility with the CUDA
            path; unused by this backend.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
    batch, next_n, _, head_dim = q_fp8.shape
    num_blocks, block_size, _, _ = kv_cache_fp8.shape
    # Flatten each block to raw bytes so keys and scales can be sliced apart.
    flat_cache = kv_cache_fp8.view(num_blocks, -1)
    key_vals = (flat_cache[:, :block_size * head_dim]
                .view(torch.int8)
                .view(-1, block_size, 1, head_dim))
    # Gather the trailing scale bytes block-by-block in block-table order,
    # then reinterpret the bytes as float32 scales.
    scale_chunks = [
        flat_cache[block_tables[row], block_size * head_dim:].view(-1, 4)
        for row in range(block_tables.shape[0])
    ]
    scales = (torch.cat(scale_chunks, dim=0)
              .view(torch.float32)
              .view(-1, max_model_len))
    out = torch.empty((batch, next_n, max_model_len),
                      dtype=torch.float32, device=q_fp8.device)
    xtorch_ops.I8_paged_mqa_logits(
        q=q_fp8,
        fused_kv_cache=[key_vals, scales],
        weights=weights.view(batch, next_n, -1),
        context_lens=[context_lens_cpu, context_lens],
        block_table=block_tables,
        max_context_len=max_model_len,
        clean_logits=True,
        out=out,
        use_xfa_boost=False
    )
    return out.view(-1, max_model_len)

View File

@@ -1,37 +1,14 @@
"""layer.py"""
from contextlib import nullcontext
from typing import Callable, Optional, Union, get_args
import torch
import os
from typing import Callable, Optional
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed import get_ep_group
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.model_executor.layers.fused_moe import FusedMoE as VllmFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase as VllmFusedMoEMethodBase
from vllm.model_executor.layers.fused_moe.layer import (
UnquantizedFusedMoEMethod as VllmUnquantizedFusedMoEMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig)
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm_kunlun.ops.quantization.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
class FusedMoEMethodBase(VllmFusedMoEMethodBase):
"""FusedMoEMethodBase"""
moe: FusedMoEConfig
@CustomOp.register("vllm_kunlun_unquantized_fused_moe")
class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
"""UnquantizedFusedMoEMethod"""
def apply(
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
@@ -45,6 +22,7 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
@@ -52,40 +30,12 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
linear_weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""apply"""
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
return self.forward_kunlun(x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
linear_weights=linear_weights,
e_score_correction_bias=e_score_correction_bias)
def forward_kunlun(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
linear_weights: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""forward_kunlun"""
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
if self.moe.use_ep:
@@ -93,21 +43,18 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
layer.w13_weight,
layer.w2_weight,
router_logits,
linear_weights,
self.moe.ep_rank,
top_k,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group
)
topk_group=topk_group)
else:
return ops.fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
linear_weights,
self.moe.ep_rank,
top_k,
renormalize=renormalize,
@@ -118,12 +65,13 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
w1_bias = layer.w13_bias,
w2_bias = layer.w2_bias,
)
w2_bias = layer.w2_bias)
class FusedMoE(VllmFusedMoE):
"""FusedMoE"""
def __init__(self,
UnquantizedFusedMoEMethod.apply = apply
class VllmFusedMoE(FusedMoE):
def __init__(
self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
@@ -141,198 +89,47 @@ class FusedMoE(VllmFusedMoE):
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
num_redundant_experts: int = 0,
is_sequence_parallel=False,
has_bias: bool = False,
is_sequence_parallel=False,
zero_expert_num: Optional[int] = 0,
zero_expert_type: Optional[str] = None,
):
super().__init__(
num_experts=num_experts, # Global number of experts
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=reduce_results,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
quant_config=quant_config,
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
prefix=prefix,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
enable_eplb=enable_eplb,
num_redundant_experts=num_redundant_experts,
)
vllm_config = get_current_vllm_config()
if vllm_config.model_config is not None:
model_dtype = vllm_config.model_config.dtype
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
model_dtype = params_dtype
moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
num_experts=num_experts, # Global number of experts
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=reduce_results,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
quant_config=quant_config,
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
prefix=prefix,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
enable_eplb=enable_eplb,
num_redundant_experts=num_redundant_experts,
has_bias=has_bias,
# quant_config=quant_config,
)
self.moe_config = moe
self.quant_config = quant_config
is_sequence_parallel=is_sequence_parallel,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type)
self.has_bias=has_bias
self.register_parameter("w13_bias", None)
self.register_parameter("w2_bias", None)
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
# assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if self.enable_eplb:
from vllm_kunlun.ops.quantization.fp8 import (
Fp8MoEMethod)
if not isinstance(quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError("EPLB is only supported for FP8 "
"quantization for now.")
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor = None,
linear_weights: torch.Tensor = None):
"""forward"""
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits)
else:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[self.layer_name]
assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits, linear_weights)
# return torch.ops.vllm.moe_forward(hidden_states, router_logits,
# self.layer_name)
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
linear_weights: torch.Tensor = None):
"""forward_impl"""
assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = (
self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
linear_weights=linear_weights
)
if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
return final_hidden_states
@classmethod
def make_expert_params_mapping(
cls,
ckpt_gate_proj_name: str,
ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int,
num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
num_physical_experts = num_experts + num_redundant_experts
# In the returned mapping:
# - `expert_id` is the physical expert id
# - `weight_name` contains the weight name of the logical expert
# So that we should map the expert id to logical in `weight_name`
physical_to_logical_map = \
EplbState.build_initial_global_physical_to_logical_map(
num_experts, num_redundant_experts)
return [
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_" if weight_name
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
expert_id, shard_id) for expert_id in range(num_physical_experts)
for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]
FusedMoE=VllmFusedMoE

View File

@@ -1,244 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from enum import Enum
from typing import Callable, Optional, Union
import torch
from typing import Any, Literal, Optional, cast, Callable, Optional
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
from compressed_tensors.config import (CompressionFormat,
SparsityCompressionConfig,
SparsityStructure)
from compressed_tensors.quantization import (ActivationOrdering,
QuantizationStrategy)
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils import replace_parameter
# TODO: import position will be changed after 0.9.0
# vllm.model_executor.layers.fused_moe.fused_moe --> vllm.model_executor.layers.fused_moe
def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
"""modify scale -> abs max"""
layer.w13_weight = torch.nn.Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
layer.w13_weight_scale.data * 127, requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
layer.w2_weight_scale.data * 127, requires_grad=False
)
from vllm.model_executor.utils import set_weight_attrs
import re
import xtorch_ops
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Post-load hook patched onto the MoE quantization method.

    Delegates to ``klx_process_weights_after_loading``, which multiplies the
    per-channel weight scales by 127 (scale -> abs-max form) and freezes the
    expert weights for the Kunlun kernels.
    """
    klx_process_weights_after_loading(layer)
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Fused-MoE forward for the Kunlun int8 path (module-level patch target).

    NOTE(review): as captured here the body only allocates routing buffers
    (normed_score / topk_ids / block_statistic) and then ends without a
    return statement — the rest of the pipeline (routing ops, moe_pre_sorted,
    moe_fc, moe_post) appears later in this extracted text. This span looks
    truncated by the diff extraction; verify against the actual file.
    """
    hidden_states = x
    # w13 is [E, 2N, H] (fused gate+up projection), per create_weights.
    global_num_experts, up_gate_size, _ = layer.w13_weight.shape
    # M tokens, N hidden features.
    M, N = hidden_states.shape
    hidden_dim = layer.w2_weight.shape[1]
    # Per-token normalized top-k routing weights (float32).
    normed_score = torch.empty(M,
                               top_k,
                               dtype=torch.float32,
                               device=hidden_states.device)
    # Per-token selected expert ids (int32).
    topk_ids = torch.empty(M,
                           top_k,
                           dtype=torch.int32,
                           device=hidden_states.device)
    # NOTE(review): magic constant — presumably a kernel blocking factor;
    # confirm against the Kunlun op implementation.
    num_blocks = 12
    block_statistic = torch.zeros(
        num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
    )
from safetensors.torch import load_file as safe_load_file
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
    """Dispatcher for compressed-tensors MoE quantization methods on Kunlun.

    Every recognized (or unrecognized) scheme currently falls back to the
    INT8 W8A8 method; the branches exist so the scheme lookup can diverge
    later without changing callers.
    """

    def get_moe_method(quant_config, layer) -> "CompressedTensorsMoEMethod":
        """Select a MoE quantization method for ``quant_config``.

        Inspects the first matching entry of ``target_scheme_map`` and,
        regardless of the outcome, returns the INT8(W8A8) method today.
        """
        scheme_map = getattr(quant_config, "target_scheme_map", None) or {}
        # First candidate key whose value is a dict wins.
        scheme = next(
            (scheme_map[key]
             for key in ("Linear", "FusedMoE", "MoE", "Moe", "Experts")
             if isinstance(scheme_map.get(key), dict)),
            None,
        )
        if scheme:
            weights_args = scheme.get("weights")
            activation_args = scheme.get("input_activations")
            if weights_args and activation_args:
                # Other schemes could be dispatched here; default fallback:
                return CompressedTensorsW8A8Int8MoEMethod(quant_config)
            # incomplete scheme; fallback to INT8(W8A8)
        # target_scheme_map missing; fallback to INT8(W8A8) method
        return CompressedTensorsW8A8Int8MoEMethod(quant_config)
# copied from vllm 0.9.0
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
    # copied from vllm 0.9.0, adapted for Kunlun

    def __init__(
        self,
        quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
        """Build default INT8 (W8A8) quantization descriptors.

        Plain attribute-holder classes are created directly instead of
        ``QuantizationArgs`` to sidestep its validation.  Weights are
        channelwise/static; activations are per-token/dynamic.
        """
        self.quant_config = quant_config
        base_args = dict(
            type='int',
            num_bits=8,
            group_size=128,
            symmetric=True,
            actorder='none',
            observer=None,
            block_structure=None,
        )
        # Default weight quantization: static, per output channel.
        self.weight_quant = type(
            'WeightQuant', (),
            dict(base_args, strategy='channel', dynamic=False,
                 observer_kwargs={}))()
        # Default input-activation quantization: dynamic, per token.
        self.input_quant = type(
            'InputQuant', (),
            dict(base_args, strategy='token', dynamic=True,
                 observer_kwargs={}))()
        # Compare the strategy strings directly.
        per_channel = (
            self.weight_quant.strategy == "channel"
            and self.input_quant.strategy == "token")
        if not per_channel:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found "
                f"{self.weight_quant}, {self.input_quant}")
        self.static_input_scales = not self.input_quant.dynamic
        if self.static_input_scales:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found static input scales.")
def create_weights1(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
# 权重先用浮点占位,便于从 ckpt 加载原始权重
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype), # 通常是 torch.bfloat16
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# 通道 scalefloat32 + 二维 [E, out](与 fused_moe/UT 对齐)
w13_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# 输入 scale 动态计算即可
layer.w13_input_scale = None
layer.w2_input_scale = None
def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=torch.int8), # 直接使用 int8
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=torch.int8), # 直接使用 int8
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# 缩放因子
w13_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# 输入 scale 动态计算
layer.w13_input_scale = None
layer.w2_input_scale = None
    @torch.no_grad()
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-load hook for the int8 MoE weights.

        Currently a no-op: the early ``return`` below disables the
        requantization path, so everything after it is dead code kept for
        reference.  NOTE(review): confirm this is intentional — presumably
        the checkpoint weights are already int8 at load time, making the
        float->int8 requantization below unnecessary.
        """
        return
        # --- dead code below (unreachable because of the return above) ---
        # Convert raw weights to float32 for more robust statistics.
        w13_f = layer.w13_weight.float()
        w2_f = layer.w2_weight.float()
        # Per-column abs_max -> per-column scale (out dim is dim=1, columns dim=-1).
        qmax = 127.0
        w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1) # [E, 2N]
        w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
        w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax # [E, 2N], float32
        w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
        # Quantize: broadcast with the 3-D scale, store back the 2-D scale.
        w13_scale_3d = w13_scale_2d.unsqueeze(-1) # [E, 2N, 1]
        w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
        w13_q = torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8)
        w2_q = torch.round(w2_f / w2_scale_3d ).clamp_(-128, 127).to(torch.int8)
        # Optional: if the fused kernel expects the scale pre-multiplied by
        # 127 (consistent with some UT backends), keep the two lines below:
        w13_scale_2d = w13_scale_2d * 127.0
        w2_scale_2d = w2_scale_2d * 127.0
        # Write back: int8 weights, float32 2-D scales.
        replace_parameter(layer, 'w13_weight', torch.nn.Parameter(w13_q, requires_grad=False))
        replace_parameter(layer, 'w2_weight', torch.nn.Parameter(w2_q, requires_grad=False))
        replace_parameter(layer, 'w13_weight_scale',
                          torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False))
        replace_parameter(layer, 'w2_weight_scale',
                          torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False))
        # Quick sanity check.
        print(f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}")
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,  # added parameter
        expert_load_view: Optional[torch.Tensor] = None,  # added parameter
        logical_to_physical_map: Optional[torch.Tensor] = None,  # added parameter
        logical_replica_count: Optional[torch.Tensor] = None,  # added parameter
        linear_weights: Optional[torch.Tensor] = None,  # added parameter
    ) -> torch.Tensor:
        """Int8 fused-MoE forward.

        NOTE(review): this span is corrupted by the diff extraction — the
        ``moe_ffn_per_token_block`` call below is never closed, and the lines
        that follow (routing, pre-sort, moe_fc pipeline) appear to be
        interleaved from the module-level ``apply`` earlier in this file.
        This region does not parse as-is; reconcile against the real source.
        """
        output = torch.empty_like(x)
        torch.ops._C.moe_ffn_per_token_block(
            x=x,
            inter_weight=layer.w13_weight,
            inter_scale=layer.w13_weight_scale,
            outer_weight=layer.w2_weight,
            outer_scale=layer.w2_weight_scale,
            top_k=top_k,
            global_num_experts=global_num_experts,
            linear_weights=linear_weights,
            expert_map=expert_map,
            activation=activation,
            output=output,
            use_expert_parallel=expert_map is not None,
            ep_size=expert_map.size(0) if expert_map is not None else 1,
            ep_rank=0,
        # NOTE(review): call above is unterminated — lines below belong to a
        # different apply implementation (scrape interleaving).
        router_logits = router_logits.float()
        # Compute normalized top-k routing weights and expert ids.
        if scoring_func == "softmax":
            torch.ops._C.moe_softmax_topk_norm(
                x=router_logits,
                normed_score=normed_score,
                topk_index=topk_ids,
                block_statistic=None,
                stable=True)
        elif scoring_func == "sigmoid":
            torch.ops._C.moe_sigmoid_group_topk_norm(
                x=router_logits,
                norm_score=normed_score,
                topk_index=topk_ids,
                block_static=block_statistic,
                bias=e_score_correction_bias,
                n_group=num_expert_group,
                topk_group=topk_group,
                scale=routed_scaling_factor,
            )
        return output
        # --- unreachable below: remainder of the interleaved module-level apply ---
        print("[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
    --> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod")
        # Gather tokens per expert: moe_expand holds the expert-sorted copies.
        moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
        expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
        sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
        sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
        torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
        torch.ops._C.moe_pre_sorted(
            x=hidden_states,
            topk_index=topk_ids,
            block_statistic=block_statistic,
            moe_expand=moe_expand,
            moe_index=sorted_tokens_idx,
            expert_m=expert_m,
            sorted_tokens_num_lod=sorted_tokens_num_lod)
        # First grouped GEMM: x @ w13 (fused gate+up projection).
        y = torch.empty(M,top_k,
                        layer.w13_weight.shape[1],
                        dtype=hidden_states.dtype,
                        device=hidden_states.device)
        moe_expand = moe_expand.view(M * top_k, hidden_dim)
        x_shape = moe_expand.shape
        # Dynamic per-token int8 quantization of the expanded activations.
        x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
        x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
        torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True)
        torch.ops._C.moe_fc(
            x=x_q,
            x_perchannel_max=x_scale,
            weight=layer.w13_weight,
            w_perchannel_max=layer.w13_weight_scale,
            sorted_tokens_num_lod=sorted_tokens_num_lod,
            sorted_tokens_idx=sorted_tokens_idx,
            moe_topk=top_k,
            y=y,
            topk_ids=topk_ids,
            # sort_mode=False,
            act=None)
        # SiLU-and-mul activation over the fused gate/up halves.
        d = y.shape[-1] // 2
        output_shape = (y.shape[:-1] + (d, ))
        out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
        torch.ops._C.silu_and_mul(out1, y)
        # Second grouped GEMM: activated hidden @ w2 (down projection).
        out = torch.empty(M,top_k,
                          layer.w2_weight.shape[1],
                          dtype=hidden_states.dtype,
                          device=hidden_states.device)
        out1 = out1.reshape(-1, out1.shape[-1])
        x_shape = out1.shape
        x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
        x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
        torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
        torch.ops._C.moe_fc(
            x=x_q,
            x_perchannel_max=x_scale,
            weight=layer.w2_weight,
            w_perchannel_max=layer.w2_weight_scale,
            sorted_tokens_num_lod=sorted_tokens_num_lod,
            sorted_tokens_idx=sorted_tokens_idx,
            moe_topk=top_k,
            y=out,
            topk_ids=topk_ids,
            # sort_mode=False,
            act=None)
        # Scatter the per-expert outputs back and combine with routing weights.
        dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
        output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
        sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
        torch.ops._C.moe_post(
            x=out,
            moe_index=sorted_tokens_idx,
            normed_scale=normed_score,
            dequant_scale=dequant_scale,
            y=output
        )
        return output
# Monkey-patch the INT8 W8A8 MoE method with the Kunlun-specific hooks:
# the module-level process_weights_after_loading (scale -> abs-max rescale)
# and the module-level fused-MoE apply path defined above.
# NOTE(review): the name may resolve to either the class redefined in this
# file or the one imported from vLLM, depending on definition order — verify.
CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = process_weights_after_loading
CompressedTensorsW8A8Int8MoEMethod.apply = apply

View File

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ScaledMMLinearLayerConfig
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import CutlassScaledMMLinearKernel
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
def can_implement_kunlun(
        cls, c: ScaledMMLinearLayerConfig = None) -> tuple[bool, Optional[str]]:
    """Kunlun override for ``CutlassScaledMMLinearKernel.can_implement``.

    Every scaled-mm configuration is accepted on this platform, so the
    result is unconditionally ``(True, None)`` — no failure reason.
    """
    return True, None
def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
    """Convert the linear layer's weight scale into abs-max form.

    Re-registers ``weight`` as a frozen Parameter and multiplies
    ``weight_scale`` by 127 (the int8 qmax), turning an ``abs_max / 127``
    scale back into the abs-max value expected by the Kunlun GEMM kernels.
    """
    frozen_weight = layer.weight.data
    layer.weight = torch.nn.Parameter(frozen_weight, requires_grad=False)
    abs_max_scale = layer.weight_scale.data * 127
    layer.weight_scale = torch.nn.Parameter(abs_max_scale,
                                            requires_grad=False)
def process_weights_after_loading_kunlun(self, layer: torch.nn.Module) -> None:
    """Kunlun replacement for CutlassScaledMMLinearKernel.process_weights_after_loading.

    Mirrors the upstream cutlass preparation (weight transpose, per-channel
    scale conversion for fused modules, static input-scale / asymmetric
    zero-point handling, azp_adj precomputation), then calls
    ``klx_process_weights_after_loading`` to convert the weight scale into
    the abs-max form expected by the Kunlun GEMM kernels.
    """
    # WEIGHT
    # Cutlass kernels need transposed weight.
    weight = getattr(layer, self.w_q_name)
    replace_parameter(
        layer, self.w_q_name,
        torch.nn.Parameter(weight.t().data, requires_grad=False))
    # WEIGHT SCALE
    # Cutlass kernels support only per-tensor and per-channel.
    # If we have a fused module (QKV, MLP) with per tensor scales (thus N
    # scales being passed to the kernel), convert to the per-channel case.
    is_fused_module = len(layer.logical_widths) > 1
    weight_scale = getattr(layer, self.w_s_name)
    if is_fused_module and not self.config.is_channelwise:
        weight_scale = convert_to_channelwise(weight_scale,
                                              layer.logical_widths)
    replace_parameter(
        layer, self.w_s_name,
        torch.nn.Parameter(weight_scale.data, requires_grad=False))
    # INPUT SCALE
    if self.config.is_static_input_scheme:
        input_scale = getattr(layer, self.i_s_name)
        if self.config.input_symmetric:
            # Symmetric: a single max scale suffices; zero point is unused.
            replace_parameter(
                layer, self.i_s_name,
                torch.nn.Parameter(input_scale.max(), requires_grad=False))
            setattr(layer, self.i_zp_name, None)
        else:
            input_zero_point = getattr(layer, self.i_zp_name)
            # reconstruct the ranges
            int8_traits = torch.iinfo(torch.int8)
            azps = input_zero_point.to(dtype=torch.int32)
            range_max = (input_scale * (int8_traits.max - azps)).max()
            range_min = (input_scale * (int8_traits.min - azps)).min()
            scale = (range_max - range_min) / (int8_traits.max -
                                               int8_traits.min)
            replace_parameter(
                layer, self.i_s_name,
                torch.nn.Parameter(scale, requires_grad=False))
            # AZP loaded as int8 but used as int32
            azp = (int8_traits.min -
                   range_min / scale).to(dtype=torch.int32)
            replace_parameter(layer, self.i_zp_name,
                              torch.nn.Parameter(azp, requires_grad=False))
    else:
        # Dynamic input quantization: no stored input scale / zero point.
        setattr(layer, self.i_s_name, None)
        setattr(layer, self.i_zp_name, None)
    # azp_adj is the AZP adjustment term, used to account for weights.
    # It does not depend on scales or azp, so it is the same for
    # static and dynamic quantization.
    # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
    # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
    if not self.config.input_symmetric:
        # Weight was transposed above, so dim=0 sums over input channels.
        weight = getattr(layer, self.w_q_name)
        azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
        if self.config.is_static_input_scheme:
            # cutlass_w8a8 requires azp to be folded into azp_adj
            # in the per-tensor case
            azp_adj = getattr(layer, self.i_zp_name) * azp_adj
        setattr(layer, self.azp_adj_name,
                torch.nn.Parameter(azp_adj, requires_grad=False))
    else:
        setattr(layer, self.azp_adj_name, None)
    # Finally: weight scale -> abs-max form for the Kunlun kernels.
    klx_process_weights_after_loading(layer)
def apply_weights_kunlun(self,
                         layer: torch.nn.Module,
                         x: torch.Tensor,
                         bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Kunlun int8 GEMM path for ``CutlassScaledMMLinearKernel.apply_weights``.

    Args:
        layer: holds ``weight`` (int8, already transposed at load time) and
            ``weight_scale`` (abs-max form, see the post-load processing).
        x: either a raw activation tensor, or an already-quantized
            ``(x_q, x_scale)`` tuple (int8 activations + per-token float32
            scales).
        bias: accepted for interface compatibility but unused here.

    Returns:
        The GEMM output; bfloat16 for the pre-quantized path, ``x.dtype``
        otherwise.
    """
    weight_t = layer.weight.T
    out_features = weight_t.shape[0]
    if isinstance(x, tuple):
        # Input arrives pre-quantized: nothing to quantize here.
        x_q, x_scale = x
        out = torch.empty((x_q.shape[0], out_features),
                          dtype=torch.bfloat16,
                          device=x_q.device)
    else:
        x_q = torch.empty(x.shape, dtype=torch.int8, device=x.device)
        x_scale = torch.empty((x.shape[0], 1), dtype=torch.float32,
                              device=x.device)
        out = torch.empty((x.shape[0], out_features),
                          dtype=x.dtype,
                          device=x.device)
        # Fix: quantize only on the unquantized-tensor path. As previously
        # written, quant2d sat at function scope and would also be invoked
        # with the (x_q, x_scale) tuple as its input.
        torch.ops._C.quant2d(x, x_q, x_scale, force_sdnn=True)
    torch.ops._C.gemm_I8_I8_bf16_nt(x_q, x_scale, layer.weight.T.data,
                                    layer.weight_scale.data, out)
    return out
# Monkey-patch vLLM's CutlassScaledMMLinearKernel with the Kunlun int8 GEMM
# path: always claim support, run the Kunlun post-load weight processing, and
# dispatch apply_weights to the quant2d + gemm_I8_I8_bf16_nt kernels.
# NOTE(review): can_implement is assigned as a plain function rather than a
# classmethod; this only works because the function ignores both arguments —
# confirm against how upstream vLLM invokes can_implement.
CutlassScaledMMLinearKernel.apply_weights = apply_weights_kunlun
CutlassScaledMMLinearKernel.can_implement = can_implement_kunlun
CutlassScaledMMLinearKernel.process_weights_after_loading = process_weights_after_loading_kunlun

View File

@@ -19,7 +19,9 @@ import torch
import xspeedgate_ops
import os
from vllm.model_executor.layers.rotary_embedding import (
RotaryEmbedding, YaRNScalingRotaryEmbedding, DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding)
RotaryEmbedding, YaRNScalingRotaryEmbedding,
DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding,
DeepseekScalingRotaryEmbedding)
from typing import Optional, Tuple
def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
@@ -143,12 +145,15 @@ def vllm_kunlun_mrope_forward_cuda(
return query, key
# Save DeepseekScalingRotaryEmbedding's current implementations before the
# base class is patched (it presumably inherits from RotaryEmbedding, so the
# base-class patch below would otherwise shadow its behavior — verify).
DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward
DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda
# Route the generic rotary embedding through the Kunlun fused implementation.
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
# Pin the saved implementations back onto the Deepseek subclass so it is
# exempt from the base-class patch above.
DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward
DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda
# M-RoPE gets its own Kunlun implementation.
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
def Split_Norm_Rope(
qkv: torch.Tensor,
cos_sin_cache: torch.Tensor,