[Feature] support deepseek v3/r1/v3.2 (#78)
* [Feature] support deepseek v3/r1/v3.2 * fix gpt_oss * update readme * update readme --------- Co-authored-by: hanhaowen <hanhaowen@baidu.com>
This commit is contained in:
@@ -20,4 +20,8 @@ import vllm_kunlun.ops.layernorm
|
||||
import vllm_kunlun.ops.quantization.awq
|
||||
import vllm_kunlun.ops.quantization.gptq
|
||||
import vllm_kunlun.ops.vocab_parallel_embedding
|
||||
import vllm_kunlun.ops.linear
|
||||
import vllm_kunlun.ops.linear
|
||||
import vllm_kunlun.ops.quantization.kernels.scaled_mm.cutlass
|
||||
import vllm_kunlun.ops.vocab_parallel_embedding
|
||||
import vllm_kunlun.ops.quantization.compressed_tensors_moe
|
||||
import vllm_kunlun.ops.fused_moe.layer
|
||||
@@ -417,7 +417,6 @@ class KunlunOps:
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
linear_weights: torch.Tensor,
|
||||
ep_rank: int,
|
||||
moe_top_k: int,
|
||||
renormalize: bool,
|
||||
|
||||
@@ -108,7 +108,7 @@ class SiluAndMul(CustomOp):
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
xtorch_ops.swiglu(x, out)
|
||||
torch.ops._C.silu_and_mul(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
260
vllm_kunlun/ops/attention/flashmla.py
Normal file
260
vllm_kunlun/ops/attention/flashmla.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
import xtorch_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
import vllm._flashmla_C # noqa: F401
|
||||
_flashmla_C_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
else:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
import vllm._flashmla_extension_C # noqa: F401
|
||||
_flashmla_extension_C_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
else:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
|
||||
|
||||
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """Report whether FlashMLA is usable on this platform.

    Returns:
        A ``(is_supported_flag, unsupported_reason)`` pair. On the Kunlun
        backend FlashMLA is always available, so this is ``(True, None)``.
    """
    supported, reason = True, None
    return supported, reason
|
||||
|
||||
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int = 1,
    num_heads_k: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build scheduler metadata for the Kunlun MLA decode kernel.

    Unlike the CUDA FlashMLA implementation (which calls
    ``flash_mla_cuda.get_mla_metadata``), the Kunlun kernel only needs the
    per-sequence context lengths: a host-side copy plus the original device
    tensor. ``num_heads_per_head_k`` and ``num_heads_k`` are accepted for
    interface compatibility but ignored here.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: Number of KV heads.

    Returns:
        (cache_seqlens on CPU, cache_seqlens unchanged on device) — consumed
        by ``flash_mla_with_kvcache`` as context_lens_cpu / context_lens_xpu.
    """
    host_seqlens = cache_seqlens.cpu()
    return host_seqlens, cache_seqlens
|
||||
|
||||
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """MLA decode attention over a paged KV cache, via xtorch_ops.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: always None on this backend — the Kunlun kernel does not
            expose the log-sum-exp; callers must not rely on it.

    NOTE(review): ``descale_q``, ``descale_k``, ``is_fp8_kvcache`` and
    ``indices`` are accepted only for signature compatibility with the CUDA
    FlashMLA interface; they are never read below.
    """
    if softmax_scale is None:
        # FlashMLA convention: default scale is 1/sqrt(head_dim of q).
        softmax_scale = q.shape[-1] ** (-0.5)

    softmax_lse = None
    # Output buffer filled in-place by the kernel (initial values irrelevant).
    out = torch.ones(q.size(0), q.size(1), q.size(2), head_dim_v, dtype=q.dtype, device=q.device)
    # In MLA layout the first head_dim_v channels are the latent (nope) part
    # and the remainder is the decoupled RoPE part.
    kv_lora_rank = head_dim_v
    qk_rope_head_dim = q.size(3) - head_dim_v
    head_dim = k_cache.shape[3]
    page_block_size = k_cache.shape[1]
    # Collapse KV heads: the kernel expects (num_blocks, 1, block_size, head_dim).
    k_cache = k_cache.view(-1, 1, page_block_size, head_dim)

    # todo: optimize memcp
    # q_c = q[..., : kv_lora_rank].contiguous()
    # q_r = q[..., kv_lora_rank :].contiguous()

    is_context = False  # decode path, not context (prefill) processing
    vo_head_dim = -1  # -1: let the kernel derive the V/O head dim

    # tile_scheduler_metadata / num_splits are repurposed by
    # get_mla_metadata() as (context_lens_cpu, context_lens_xpu).
    xtorch_ops.paged_attention(out,
                               q,
                               k_cache, None,
                               block_table,
                               tile_scheduler_metadata,  # context_lens_cpu
                               num_splits,  # context_lens_xpu
                               is_context,
                               causal,
                               vo_head_dim,
                               kv_lora_rank,
                               qk_rope_head_dim,
                               softmax_scale,
                               q_r=q)
    return out, softmax_lse
|
||||
|
||||
def kunlun_flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    cache_seqlens: torch.Tensor,
    cache_seqlens_cpu: torch.Tensor,
    head_dim_v: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    max_seq_kv: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Kunlun-native MLA decode kernel over a flat (non-paged) KV cache.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_tokens_kv, head_dim).
        cache_seqlens: (batch_size), torch.int32.
        cache_seqlens_cpu: host copy of cache_seqlens, passed as kv_lod_cpu.
        head_dim_v: Head dimension of v.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format.
        indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
        max_seq_kv: the maximum KV length among the sequences in the batch.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        max_logits: (batch_size, seq_len_q, num_heads_q), torch.float32.
        p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32.
    """
    assert not is_fp8_kvcache, "By now, the kernel does not support uint8 kv cache."
    assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if indices is not None:
        # NOTE (zyongye): sparse attention is also causal
        # since it only attend to the tokens before
        # but here `causal` should not be specified
        assert not causal, \
            "causal must be `false` if sparse attention is enabled."

    # When q_r and pe_cache are both None the kernel runs in "packed" mode:
    # the latent and RoPE parts stay fused inside q / kv_cache.
    q_r, pe_cache = None, None
    batch_size, seq_len_q, num_heads_q, head_dim = q.shape
    kv_lora_rank = head_dim_v
    # Unused below; kept to document the head-dim split (nope + rope).
    rope_head_dim = head_dim - kv_lora_rank

    # Output buffers written in-place by the kernel.
    out = torch.zeros([batch_size, seq_len_q, num_heads_q, kv_lora_rank],
                      dtype=q.dtype, device=q.device)
    max_logits = torch.zeros([batch_size, seq_len_q, num_heads_q],
                             dtype=torch.float32, device=q.device)
    p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q],
                         dtype=torch.float32, device=q.device)

    xtorch_ops.fwd_kvcache_mla(
        q_c=q,
        kv_cache=k_cache,
        indices=indices,
        kv_lod_cpu=cache_seqlens_cpu,
        max_seq_kv=max_seq_kv,
        softmax_scale=softmax_scale,
        # q_r=q_r,
        # pe_cache=pe_cache,
        out=out,
        max_logits=max_logits,
        p_sums=p_sums,
        kv_lod_xpu=cache_seqlens,
    )

    return out, max_logits, p_sums
|
||||
|
||||
|
||||
def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    q_lod_xpu: torch.Tensor,
    q_lod_cpu: torch.Tensor,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel

    Args:
        - q: [s_q, h_q, d_qk], bfloat16
        - kv: [s_kv, d_qk], bfloat16
        - indices: [s_q, h_kv, topk], int32.
          Invalid indices should be set to -1 or numbers >= s_kv
        - sm_scale: float
        - q_lod_xpu: [batch+1], int32, cumulative per-sequence lengths of q
          (batch_num + 1 entries; empty means fixed-length q).
        - q_lod_cpu: host copy of q_lod_xpu.
        - d_v: The dimension of value vectors. Can only be 512

    Returns:
        - (output, max_logits, lse)
          About the definition of output,
          max_logits and lse, please refer to README.md
        - output: [s_q, h_q, d_v], bfloat16
        - max_logits: [s_q, h_q], float
        - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    s_q, h_q, d_qk = q.shape

    # Output buffers written in-place by the kernel.
    out = torch.zeros([s_q, h_q, d_v], dtype=q.dtype, device=q.device)
    max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
    lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)

    # NOTE: the q lod is reused as the kv lod — assumes s_kv is laid out with
    # the same per-sequence boundaries as q in this prefill path.
    xtorch_ops.sparse_prefill_fwd_opt(
        q=q,
        kv=kv,
        indices=indices,
        qlod_cpu=q_lod_cpu,
        qlod_xpu=q_lod_xpu,
        kvlod_cpu=q_lod_cpu,
        kvlod_xpu=q_lod_xpu,
        sm_scale=sm_scale,
        d_v=d_v,
        is_causal=True,  # NOTE(review): aiak sets this to True — reason unclear, confirm
        out=out,
        max_logits=max_logits,
        lse=lse,
    )

    # NOTE: Compared with torch.ops._flashmla_C.sparse_prefill_fwd,
    # out_scale = 1 / math.log2(math.e)
    # gpu_max_logits * out_scale = kunlun_lse
    # gpu_lse * out_scale = kunlun_lse
    return out, max_logits, lse
|
||||
|
||||
|
||||
#
|
||||
# TODO: Add fake functions
|
||||
#
|
||||
# @register_fake("_flashmla_C::get_mla_metadata")
|
||||
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
# @register_fake("_flashmla_C::fwd_kvcache_mla")
|
||||
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
180
vllm_kunlun/ops/attention/mla.py
Normal file
180
vllm_kunlun/ops/attention/mla.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_kunlun.ops.attention.layer import Attention
|
||||
# from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
|
||||
|
||||
@dataclass
class MLAModules:
    """Bundle of sub-modules that make up one MLA attention layer.

    Built by the model (e.g. DeepSeek V3/R1) and handed to
    MultiHeadLatentAttention so the layer does not construct projections
    itself.
    """
    # Layernorm applied to the compressed (latent) KV before kv_b_proj.
    kv_a_layernorm: torch.nn.Module
    # Up-projection from kv latent to per-head K(nope)/V.
    kv_b_proj: torch.nn.Module
    # Rotary embedding applied to the decoupled RoPE dims of q and k_pe.
    rotary_emb: torch.nn.Module
    # Output projection back to hidden_size.
    o_proj: torch.nn.Module
    # Fused q+kv down-projection; required when q_lora_rank is not None.
    fused_qkv_a_proj: Optional[torch.nn.Module]
    # Separate kv down-projection; required when q_lora_rank is None.
    kv_a_proj_with_mqa: Optional[torch.nn.Module]
    # Layernorm on the q latent; required when q_lora_rank is not None.
    q_a_layernorm: Optional[torch.nn.Module]
    # Up-projection from q latent; required when q_lora_rank is not None.
    q_b_proj: Optional[torch.nn.Module]
    # Direct q projection; required when q_lora_rank is None.
    q_proj: Optional[torch.nn.Module]
    # Optional sparse-attention indexer (DeepSeek V3.2 style).
    indexer: Optional[torch.nn.Module]
    # Whether this layer uses sparse attention.
    is_sparse: bool
    # Shared buffer for top-k indices produced by the indexer, if any.
    topk_indices_buffer: Optional[torch.Tensor]
|
||||
|
||||
|
||||
@CustomOp.register("vllm_kunlun_multi_head_latent_attention")
class MultiHeadLatentAttention(CustomOp):
    """MLA layer registered as CustomOp.

    Note that currently MLA ignores the enable/disable mechanism of CustomOp
    because there is only one in-tree implementation in forward_native.
    TODO: implement this with a new PluggableLayer mechanism.

    This class takes positions and hidden_states as input.
    The input tensors can either contain prefill tokens or decode tokens.
    The class does the following:

    1. MLA Preprocess.
    2. Perform multi-head attention to prefill tokens and
       multi-query attention to decode tokens separately.
    3. Return the output tensor.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        mla_modules: MLAModules,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        # q_lora_rank is None => no q compression (q_proj path in forward).
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
        self.q_a_layernorm = mla_modules.q_a_layernorm
        self.q_b_proj = mla_modules.q_b_proj
        self.q_proj = mla_modules.q_proj
        self.kv_a_layernorm = mla_modules.kv_a_layernorm
        self.kv_b_proj = mla_modules.kv_b_proj
        self.rotary_emb = mla_modules.rotary_emb
        self.o_proj = mla_modules.o_proj
        self.indexer = mla_modules.indexer
        self.is_sparse = mla_modules.is_sparse

        if self.indexer is not None:
            # The indexer decides how many tokens each query attends to.
            assert hasattr(self.indexer, "topk_tokens")
            self.topk_tokens = self.indexer.topk_tokens
            self.topk_indices_buffer = mla_modules.topk_indices_buffer

        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        # k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        # kv_lora_rank + qk_rope_head_dim == head_size
        self.mla_attn = Attention(
            num_heads=self.num_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=scale,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            use_sparse=mla_modules.is_sparse,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=self.kv_b_proj,
            indexer=self.indexer,
        )

        self.prefix = prefix

    def forward_native(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run MLA preprocessing + attention + output projection.

        hidden_states may also be a (quantized_tensor, scale) tuple — only
        its first dim is needed to size the attention output.
        """
        q_c = None
        kv_lora = None

        if self.q_lora_rank is not None:
            # Compressed-q path: one fused down-projection produces the q
            # latent and the kv latent (+ rope part) in a single GEMM.
            assert self.fused_qkv_a_proj is not None, \
                "fused_qkv_a_proj is required when q_lora_rank is not None"
            assert self.q_a_layernorm is not None, \
                "q_a_layernorm is required when q_lora_rank is not None"
            assert self.q_b_proj is not None, \
                "q_b_proj is required when q_lora_rank is not None"
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_lora = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0]
        else:
            # Uncompressed-q path: q is projected directly from hidden_states.
            assert self.kv_a_proj_with_mqa is not None, \
                "kv_a_proj_with_mqa is required when q_lora_rank is None"
            assert self.q_proj is not None, \
                "q_proj is required when q_lora_rank is None"
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]

        # Split kv latent into the compressed cache part and the rope part.
        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
                                   dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c)

        q = q.view(-1, self.num_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)

        # Apply rotary embedding in-place to the rope slice of q, and to k_pe.
        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
            positions, q[..., self.qk_nope_head_dim:], k_pe)

        if self.indexer and self.is_sparse:
            # Result is delivered via the shared topk_indices_buffer;
            # the return value itself is unused here.
            _topk_indices = self.indexer(hidden_states, q_c, positions,
                                         self.rotary_emb)

        hidden_states_shape_0 = 0
        if isinstance(hidden_states, tuple):
            # Quantized activation: (values, scale); size from the values.
            x_q, x_scale = hidden_states
            hidden_states_shape_0 = x_q.shape[0]
        else:
            hidden_states_shape_0 = hidden_states.shape[0]
        attn_out = self.mla_attn(
            q,
            kv_c_normed,
            k_pe,
            output_shape=(hidden_states_shape_0,
                          self.num_heads * self.v_head_dim))
        return self.o_proj(attn_out)[0]

    def forward_cuda(self, *args, **kwargs):
        # Single in-tree implementation: CUDA/XPU dispatch reuses native.
        return self.forward_native(*args, **kwargs)
|
||||
114
vllm_kunlun/ops/deep_gemm.py
Normal file
114
vllm_kunlun/ops/deep_gemm.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import torch
|
||||
import xtorch_ops
|
||||
|
||||
def int8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute quantized MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
        Positions outside [cu_seqlen_ks, cu_seqlen_ke) are set to -inf.
    """
    logits = torch.empty((q.shape[0], kv[0].shape[0]), dtype=torch.float32, device=q.device)
    # Single-sequence case: one lod segment covering all of q / all of kv.
    context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
    context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)

    xtorch_ops.I8_mqa_logits(
        q=q,
        fused_kv_cache=kv,
        weights=weights,
        context_q_lens=(context_q_lens_xpu.cpu(), context_q_lens_xpu),
        context_k_lens=(context_k_lens_xpu.cpu(), context_k_lens_xpu),
        logits=logits,
        clean_logits=True,
        use_xfa_boost=False,
    )
    seq_len_kv = kv[0].shape[0]
    # Mask construction follows the _ref_fp8_mqa_logits reference in
    # https://github.com/vllm-project/vllm/blob/v0.11.0/tests/kernels/attention/test_deepgemm_attention.py
    mask_lo = (torch.arange(0, seq_len_kv, device=cu_seqlen_ks.device)[None, :]
               >= cu_seqlen_ks[:, None])
    mask_hi = (torch.arange(0, seq_len_kv, device=cu_seqlen_ke.device)[None, :]
               < cu_seqlen_ke[:, None])
    mask = mask_lo & mask_hi
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits
|
||||
|
||||
def int8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    context_lens_cpu: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute quantized MQA logits using a paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block,pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context length
            for each batch element.
        context_lens_cpu: host copy of context_lens.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
            block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            accepted for interface compatibility — not read by this backend.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
    batch_size, next_n, _, D = q_fp8.shape
    num_blocks, block_size, _, _ = kv_cache_fp8.shape

    # Split the packed layout: per block, the first block_size*D bytes are
    # the quantized values, the trailing bytes the float32 scales.
    kv_cache_fp8=kv_cache_fp8.view(num_blocks, -1)
    k_val = kv_cache_fp8[:,:block_size*D].view(torch.int8)
    k_val = k_val.view(-1,block_size, 1, D)
    # Gather scales in the logical order given by each sequence's block table
    # so the flattened scale tensor aligns with [B, max_model_len].
    # NOTE(review): this assumes block_tables covers max_model_len tokens per
    # sequence — confirm against the caller.
    k_scale_list = []
    for block_tables_idx in range(block_tables.shape[0]):
        k_scale_item = kv_cache_fp8[block_tables[block_tables_idx], block_size *
                                    D:].view(-1, 4)
        k_scale_list.append(k_scale_item)
    k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).view(-1,max_model_len)
    kv_cache = [k_val, k_scale]

    weights = weights.view(batch_size,next_n,-1)

    # Output buffer written in-place by the kernel.
    logits = torch.empty((batch_size, next_n, max_model_len), dtype=torch.float32, device=q_fp8.device)

    xtorch_ops.I8_paged_mqa_logits(
        q=q_fp8,
        fused_kv_cache=kv_cache,
        weights=weights,
        context_lens=[context_lens_cpu, context_lens],
        block_table=block_tables,
        max_context_len=max_model_len,
        clean_logits=True,
        out=logits,
        use_xfa_boost=False
    )
    # Flatten to [B * next_n, max_model_len] to match the documented contract.
    logits = logits.view(-1, max_model_len)
    return logits
|
||||
@@ -1,37 +1,14 @@
|
||||
"""layer.py"""
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import Callable, Optional, Union, get_args
|
||||
|
||||
import torch
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.eplb.eplb_state import EplbState
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE as VllmFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase as VllmFusedMoEMethodBase
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
UnquantizedFusedMoEMethod as VllmUnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig, FusedMoEParallelConfig)
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_kunlun.ops.quantization.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
|
||||
|
||||
|
||||
class FusedMoEMethodBase(VllmFusedMoEMethodBase):
|
||||
"""FusedMoEMethodBase"""
|
||||
moe: FusedMoEConfig
|
||||
|
||||
@CustomOp.register("vllm_kunlun_unquantized_fused_moe")
|
||||
class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
"""UnquantizedFusedMoEMethod"""
|
||||
def apply(
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
@@ -45,6 +22,7 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
@@ -52,40 +30,12 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
linear_weights: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""apply"""
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
|
||||
|
||||
return self.forward_kunlun(x=x,
|
||||
layer=layer,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
linear_weights=linear_weights,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
def forward_kunlun(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
linear_weights: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
|
||||
"""forward_kunlun"""
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
if self.moe.use_ep:
|
||||
@@ -93,21 +43,18 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
linear_weights,
|
||||
self.moe.ep_rank,
|
||||
top_k,
|
||||
renormalize=renormalize,
|
||||
inplace=True,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group
|
||||
)
|
||||
topk_group=topk_group)
|
||||
else:
|
||||
return ops.fused_moe(x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
linear_weights,
|
||||
self.moe.ep_rank,
|
||||
top_k,
|
||||
renormalize=renormalize,
|
||||
@@ -118,12 +65,13 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
w1_bias = layer.w13_bias,
|
||||
w2_bias = layer.w2_bias,
|
||||
)
|
||||
w2_bias = layer.w2_bias)
|
||||
|
||||
class FusedMoE(VllmFusedMoE):
|
||||
"""FusedMoE"""
|
||||
def __init__(self,
|
||||
UnquantizedFusedMoEMethod.apply = apply
|
||||
|
||||
class VllmFusedMoE(FusedMoE):
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int, # Global number of experts
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
@@ -141,198 +89,47 @@ class FusedMoE(VllmFusedMoE):
|
||||
prefix: str = "",
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
num_redundant_experts: int = 0,
|
||||
is_sequence_parallel=False,
|
||||
has_bias: bool = False,
|
||||
is_sequence_parallel=False,
|
||||
zero_expert_num: Optional[int] = 0,
|
||||
zero_expert_type: Optional[str] = None,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=num_experts, # Global number of experts
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
params_dtype=params_dtype,
|
||||
reduce_results=reduce_results,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
ep_size=ep_size,
|
||||
dp_size=dp_size,
|
||||
prefix=prefix,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=activation,
|
||||
enable_eplb=enable_eplb,
|
||||
num_redundant_experts=num_redundant_experts,
|
||||
)
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config.model_config is not None:
|
||||
model_dtype = vllm_config.model_config.dtype
|
||||
else:
|
||||
# TODO (bnell): This is a hack to get test_mixtral_moe to work
|
||||
# since model_config is not set in the pytest test.
|
||||
model_dtype = params_dtype
|
||||
|
||||
moe = FusedMoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=model_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
num_experts=num_experts, # Global number of experts
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
params_dtype=params_dtype,
|
||||
reduce_results=reduce_results,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
ep_size=ep_size,
|
||||
dp_size=dp_size,
|
||||
prefix=prefix,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=activation,
|
||||
enable_eplb=enable_eplb,
|
||||
num_redundant_experts=num_redundant_experts,
|
||||
has_bias=has_bias,
|
||||
# quant_config=quant_config,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.quant_config = quant_config
|
||||
is_sequence_parallel=is_sequence_parallel,
|
||||
zero_expert_num=zero_expert_num,
|
||||
zero_expert_type=zero_expert_type)
|
||||
self.has_bias=has_bias
|
||||
self.register_parameter("w13_bias", None)
|
||||
self.register_parameter("w2_bias", None)
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
quant_method: Optional[QuantizeMethodBase] = None
|
||||
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
|
||||
else quant_config.get_quant_method(self, prefix))
|
||||
|
||||
assert quant_method is not None
|
||||
# assert isinstance(quant_method, FusedMoEMethodBase)
|
||||
self.quant_method = quant_method
|
||||
|
||||
if self.enable_eplb:
|
||||
from vllm_kunlun.ops.quantization.fp8 import (
|
||||
Fp8MoEMethod)
|
||||
if not isinstance(quant_method, Fp8MoEMethod):
|
||||
# TODO: Add support for additional quantization methods.
|
||||
# The implementation for other quantization methods does not
|
||||
# contain essential differences, but the current quant API
|
||||
# design causes duplicated work when extending to new
|
||||
# quantization methods, so I'm leaving it for now.
|
||||
# If you plan to add support for more quantization methods,
|
||||
# please refer to the implementation in `Fp8MoEMethod`.
|
||||
raise NotImplementedError("EPLB is only supported for FP8 "
|
||||
"quantization for now.")
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
"intermediate_size_per_partition":
|
||||
self.intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
}
|
||||
# need full intermediate size pre-sharding for WNA16 act order
|
||||
if (self.quant_method.__class__.__name__
|
||||
in ("GPTQMarlinMoEMethod",
|
||||
"CompressedTensorsWNA16MarlinMoEMethod",
|
||||
"CompressedTensorsWNA16MoEMethod")):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor = None,
|
||||
linear_weights: torch.Tensor = None):
|
||||
"""forward"""
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we will
|
||||
# switch to using the moe_forward custom op.
|
||||
if current_platform.is_tpu():
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
else:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[self.layer_name]
|
||||
assert self.quant_method is not None
|
||||
return self.forward_impl(hidden_states, router_logits, linear_weights)
|
||||
# return torch.ops.vllm.moe_forward(hidden_states, router_logits,
|
||||
# self.layer_name)
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
linear_weights: torch.Tensor = None):
|
||||
"""forward_impl"""
|
||||
assert self.quant_method is not None
|
||||
if (self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels):
|
||||
return self.forward_impl_chunked(hidden_states, router_logits)
|
||||
|
||||
do_naive_dispatch_combine: bool = (
|
||||
self.dp_size > 1
|
||||
and not self.moe_parallel_config.use_deepep_ht_kernels)
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
linear_weights=linear_weights
|
||||
)
|
||||
|
||||
if do_naive_dispatch_combine:
|
||||
final_hidden_states = get_ep_group().combine(final_hidden_states)
|
||||
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
# Default set to False. (May have to add shared expert outputs.
|
||||
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
cls,
|
||||
ckpt_gate_proj_name: str,
|
||||
ckpt_down_proj_name: str,
|
||||
ckpt_up_proj_name: str,
|
||||
num_experts: int,
|
||||
num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
|
||||
|
||||
num_physical_experts = num_experts + num_redundant_experts
|
||||
|
||||
# In the returned mapping:
|
||||
# - `expert_id` is the physical expert id
|
||||
# - `weight_name` contains the weight name of the logical expert
|
||||
# So that we should map the expert id to logical in `weight_name`
|
||||
physical_to_logical_map = \
|
||||
EplbState.build_initial_global_physical_to_logical_map(
|
||||
num_experts, num_redundant_experts)
|
||||
|
||||
return [
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
("experts.w13_" if weight_name
|
||||
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
|
||||
f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
|
||||
expert_id, shard_id) for expert_id in range(num_physical_experts)
|
||||
for shard_id, weight_name in [
|
||||
("w1", ckpt_gate_proj_name),
|
||||
("w2", ckpt_down_proj_name),
|
||||
("w3", ckpt_up_proj_name),
|
||||
]
|
||||
]
|
||||
FusedMoE=VllmFusedMoE
|
||||
@@ -1,244 +1,169 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from typing import Any, Literal, Optional, cast, Callable, Optional
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
|
||||
|
||||
from compressed_tensors.config import (CompressionFormat,
|
||||
SparsityCompressionConfig,
|
||||
SparsityStructure)
|
||||
from compressed_tensors.quantization import (ActivationOrdering,
|
||||
QuantizationStrategy)
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
# TODO: import position will be changed after 0.9.0
|
||||
# vllm.model_executor.layers.fused_moe.fused_moe --> vllm.model_executor.layers.fused_moe
|
||||
def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
|
||||
"""modify scale -> abs max"""
|
||||
layer.w13_weight = torch.nn.Parameter(layer.w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(layer.w2_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
layer.w13_weight_scale.data * 127, requires_grad=False
|
||||
)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
layer.w2_weight_scale.data * 127, requires_grad=False
|
||||
)
|
||||
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
import re
|
||||
import xtorch_ops
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
klx_process_weights_after_loading(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
hidden_states = x
|
||||
global_num_experts, up_gate_size, _ = layer.w13_weight.shape
|
||||
M, N = hidden_states.shape
|
||||
hidden_dim = layer.w2_weight.shape[1]
|
||||
normed_score = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
num_blocks = 12
|
||||
block_statistic = torch.zeros(
|
||||
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
|
||||
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def get_moe_method(quant_config, layer) -> "CompressedTensorsMoEMethod":
|
||||
tsm = getattr(quant_config, "target_scheme_map", None) or {}
|
||||
linear_cfg = None
|
||||
for k in ("Linear", "FusedMoE", "MoE", "Moe", "Experts"):
|
||||
if k in tsm and isinstance(tsm[k], dict):
|
||||
linear_cfg = tsm[k]; break
|
||||
if not linear_cfg:
|
||||
# print("target_scheme_map missing; fallback to INT8(W8A8) method")
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
wq = linear_cfg.get("weights"); aq = linear_cfg.get("input_activations")
|
||||
if not wq or not aq:
|
||||
# print("incomplete scheme; fallback to INT8(W8A8)")
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
# 其它分流按需;默认回落:
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
|
||||
# copied from vllm 0.9.0
|
||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
self.quant_config = quant_config
|
||||
|
||||
# 直接创建默认的量化配置字典,避免 QuantizationArgs 的验证问题
|
||||
# print("Creating default INT8 quantization config for MoE")
|
||||
|
||||
# 创建默认的权重量化配置字典
|
||||
self.weight_quant = type('WeightQuant', (), {
|
||||
'type': 'int',
|
||||
'num_bits': 8,
|
||||
'strategy': 'channel',
|
||||
'group_size': 128,
|
||||
'symmetric': True,
|
||||
'dynamic': False,
|
||||
'actorder': 'none',
|
||||
'observer': None,
|
||||
'observer_kwargs': {},
|
||||
'block_structure': None
|
||||
})()
|
||||
|
||||
# 创建默认的输入激活量化配置字典
|
||||
self.input_quant = type('InputQuant', (), {
|
||||
'type': 'int',
|
||||
'num_bits': 8,
|
||||
'strategy': 'token',
|
||||
'group_size': 128,
|
||||
'symmetric': True,
|
||||
'dynamic': True,
|
||||
'actorder': 'none',
|
||||
'observer': None,
|
||||
'observer_kwargs': {},
|
||||
'block_structure': None
|
||||
})()
|
||||
|
||||
# 修改比较方式,直接比较字符串
|
||||
per_channel = (
|
||||
self.weight_quant.strategy == "channel"
|
||||
and self.input_quant.strategy == "token")
|
||||
if not per_channel:
|
||||
raise ValueError(
|
||||
"For INT8 Fused MoE layers, we require channelwise, "
|
||||
"dynamic per token quantization. Found "
|
||||
f"{self.weight_quant}, {self.input_quant}")
|
||||
|
||||
self.static_input_scales = not self.input_quant.dynamic
|
||||
if self.static_input_scales:
|
||||
raise ValueError(
|
||||
"For INT8 Fused MoE layers, we require channelwise, "
|
||||
"dynamic per token quantization. Found static input scales.")
|
||||
|
||||
def create_weights1(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
# 权重先用浮点占位,便于从 ckpt 加载原始权重
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype), # 通常是 torch.bfloat16
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# 通道 scale:float32 + 二维 [E, out](与 fused_moe/UT 对齐)
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# 输入 scale 动态计算即可
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=torch.int8), # 直接使用 int8
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8), # 直接使用 int8
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# 缩放因子
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# 输入 scale 动态计算
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
@torch.no_grad()
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
return
|
||||
#原始权重转 float32 做统计更稳健
|
||||
w13_f = layer.w13_weight.float()
|
||||
w2_f = layer.w2_weight.float()
|
||||
|
||||
# 每列(abs_max) -> per-column scale(out 维在 dim=1,列在 dim=-1)
|
||||
qmax = 127.0
|
||||
w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1) # [E, 2N]
|
||||
w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
|
||||
|
||||
w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax # [E, 2N], float32
|
||||
w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
|
||||
|
||||
# 量化:用 3D scale 广播,存回 2D scale
|
||||
w13_scale_3d = w13_scale_2d.unsqueeze(-1) # [E, 2N, 1]
|
||||
w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
|
||||
|
||||
w13_q = torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8)
|
||||
w2_q = torch.round(w2_f / w2_scale_3d ).clamp_(-128, 127).to(torch.int8)
|
||||
|
||||
# 可选:若你的 fused/kernel 期望 scale 预乘 127(与某些 UT 后端一致),打开下面两行:
|
||||
w13_scale_2d = w13_scale_2d * 127.0
|
||||
w2_scale_2d = w2_scale_2d * 127.0
|
||||
|
||||
# 回写参数:权重 int8;scale 用 float32 + 2D
|
||||
replace_parameter(layer, 'w13_weight', torch.nn.Parameter(w13_q, requires_grad=False))
|
||||
replace_parameter(layer, 'w2_weight', torch.nn.Parameter(w2_q, requires_grad=False))
|
||||
replace_parameter(layer, 'w13_weight_scale',
|
||||
torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False))
|
||||
replace_parameter(layer, 'w2_weight_scale',
|
||||
torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False))
|
||||
|
||||
# 简要检查
|
||||
print(f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False, # 添加这个参数
|
||||
expert_load_view: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
logical_replica_count: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
linear_weights: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
) -> torch.Tensor:
|
||||
|
||||
output = torch.empty_like(x)
|
||||
torch.ops._C.moe_ffn_per_token_block(
|
||||
x=x,
|
||||
inter_weight=layer.w13_weight,
|
||||
inter_scale=layer.w13_weight_scale,
|
||||
outer_weight=layer.w2_weight,
|
||||
outer_scale=layer.w2_weight_scale,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
linear_weights=linear_weights,
|
||||
expert_map=expert_map,
|
||||
activation=activation,
|
||||
output=output,
|
||||
use_expert_parallel=expert_map is not None,
|
||||
ep_size=expert_map.size(0) if expert_map is not None else 1,
|
||||
ep_rank=0,
|
||||
router_logits = router_logits.float()
|
||||
if scoring_func == "softmax":
|
||||
torch.ops._C.moe_softmax_topk_norm(
|
||||
x=router_logits,
|
||||
normed_score=normed_score,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=None,
|
||||
stable=True)
|
||||
elif scoring_func == "sigmoid":
|
||||
torch.ops._C.moe_sigmoid_group_topk_norm(
|
||||
x=router_logits,
|
||||
norm_score=normed_score,
|
||||
topk_index=topk_ids,
|
||||
block_static=block_statistic,
|
||||
bias=e_score_correction_bias,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
scale=routed_scaling_factor,
|
||||
)
|
||||
return output
|
||||
|
||||
print("[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
|
||||
--> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod")
|
||||
moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
|
||||
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
||||
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
||||
sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
||||
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
x=hidden_states,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=block_statistic,
|
||||
moe_expand=moe_expand,
|
||||
moe_index=sorted_tokens_idx,
|
||||
expert_m=expert_m,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
||||
|
||||
y = torch.empty(M,top_k,
|
||||
layer.w13_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
moe_expand = moe_expand.view(M * top_k, hidden_dim)
|
||||
|
||||
x_shape = moe_expand.shape
|
||||
x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
|
||||
x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
|
||||
torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=x_q,
|
||||
x_perchannel_max=x_scale,
|
||||
weight=layer.w13_weight,
|
||||
w_perchannel_max=layer.w13_weight_scale,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=y,
|
||||
topk_ids=topk_ids,
|
||||
# sort_mode=False,
|
||||
act=None)
|
||||
|
||||
d = y.shape[-1] // 2
|
||||
output_shape = (y.shape[:-1] + (d, ))
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.silu_and_mul(out1, y)
|
||||
|
||||
out = torch.empty(M,top_k,
|
||||
layer.w2_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
x_shape = out1.shape
|
||||
x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
|
||||
x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
|
||||
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=x_q,
|
||||
x_perchannel_max=x_scale,
|
||||
weight=layer.w2_weight,
|
||||
w_perchannel_max=layer.w2_weight_scale,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=out,
|
||||
topk_ids=topk_ids,
|
||||
# sort_mode=False,
|
||||
act=None)
|
||||
|
||||
dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
|
||||
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
|
||||
|
||||
torch.ops._C.moe_post(
|
||||
x=out,
|
||||
moe_index=sorted_tokens_idx,
|
||||
normed_scale=normed_score,
|
||||
dequant_scale=dequant_scale,
|
||||
y=output
|
||||
)
|
||||
return output
|
||||
|
||||
CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = process_weights_after_loading
|
||||
CompressedTensorsW8A8Int8MoEMethod.apply = apply
|
||||
0
vllm_kunlun/ops/quantization/kernels/__init__.py
Normal file
0
vllm_kunlun/ops/quantization/kernels/__init__.py
Normal file
122
vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py
Normal file
122
vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import CutlassScaledMMLinearKernel
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise)
|
||||
|
||||
def can_implement_kunlun(
|
||||
cls, c: ScaledMMLinearLayerConfig=None) -> tuple[bool, Optional[str]]:
|
||||
return True, None
|
||||
|
||||
def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
|
||||
"""modify scale -> abs max"""
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data * 127, requires_grad=False)
|
||||
|
||||
def process_weights_after_loading_kunlun(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
replace_parameter(
|
||||
layer, self.w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False))
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Cutlass kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale,
|
||||
layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer, self.w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False))
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, self.i_s_name)
|
||||
|
||||
if self.config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer, self.i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False))
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, self.i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max -
|
||||
int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, self.i_s_name,
|
||||
torch.nn.Parameter(scale, requires_grad=False))
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min -
|
||||
range_min / scale).to(dtype=torch.int32)
|
||||
replace_parameter(layer, self.i_zp_name,
|
||||
torch.nn.Parameter(azp, requires_grad=False))
|
||||
|
||||
else:
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
|
||||
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
|
||||
if not self.config.input_symmetric:
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if self.config.is_static_input_scheme:
|
||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||
# in the per-tensor case
|
||||
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
|
||||
setattr(layer, self.azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False))
|
||||
else:
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
klx_process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights_kunlun(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
x_q, x_scale, out = None, None, None
|
||||
w_t_shape = layer.weight.T.shape
|
||||
if isinstance(x, tuple):
|
||||
x_q, x_scale = x
|
||||
out = torch.empty((x_q.shape[0], w_t_shape[0]),
|
||||
dtype=torch.bfloat16,
|
||||
device=x_q.device)
|
||||
else:
|
||||
x_shape = x.shape
|
||||
x_q = torch.empty(x_shape, dtype=torch.int8, device=x.device)
|
||||
x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=x.device)
|
||||
out = torch.empty((x_shape[0], w_t_shape[0]),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
torch.ops._C.quant2d(x, x_q, x_scale, force_sdnn=True)
|
||||
torch.ops._C.gemm_I8_I8_bf16_nt(x_q, x_scale, layer.weight.T.data, layer.weight_scale.data, out)
|
||||
return out
|
||||
|
||||
CutlassScaledMMLinearKernel.apply_weights = apply_weights_kunlun
|
||||
CutlassScaledMMLinearKernel.can_implement = can_implement_kunlun
|
||||
CutlassScaledMMLinearKernel.process_weights_after_loading = process_weights_after_loading_kunlun
|
||||
@@ -19,7 +19,9 @@ import torch
|
||||
import xspeedgate_ops
|
||||
import os
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
RotaryEmbedding, YaRNScalingRotaryEmbedding, DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding)
|
||||
RotaryEmbedding, YaRNScalingRotaryEmbedding,
|
||||
DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding,
|
||||
DeepseekScalingRotaryEmbedding)
|
||||
from typing import Optional, Tuple
|
||||
|
||||
def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
@@ -143,12 +145,15 @@ def vllm_kunlun_mrope_forward_cuda(
|
||||
|
||||
return query, key
|
||||
|
||||
DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward
|
||||
DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda
|
||||
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward
|
||||
DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda
|
||||
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
||||
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
||||
|
||||
|
||||
def Split_Norm_Rope(
|
||||
qkv: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user