Initial commit for vLLM-Kunlun Plugin
This commit is contained in:
21
vllm_kunlun/ops/__init__.py
Normal file
21
vllm_kunlun/ops/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
|
||||
import vllm_kunlun.ops.rotary_embedding
|
||||
import vllm_kunlun.ops.layernorm
|
||||
import vllm_kunlun.ops.quantization.awq
|
||||
import vllm_kunlun.ops.quantization.gptq
|
||||
597
vllm_kunlun/ops/_kunlun_ops.py
Normal file
597
vllm_kunlun/ops/_kunlun_ops.py
Normal file
@@ -0,0 +1,597 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""kunlun custom op entry"""
|
||||
import torch_xmlir
|
||||
import torch
|
||||
import os
|
||||
from typing import Optional, List, Dict
|
||||
import vllm.envs as envs
|
||||
import os
|
||||
import ctypes
|
||||
from vllm.logger import init_logger
|
||||
|
||||
# Module-level logger used by this file.
logger = init_logger(__name__)

# The Kunlun kernel extension is imported defensively: a failed import is
# reported as a warning instead of being raised at module import time.
try:
    import xtorch_ops
except ImportError as import_err:
    logger.warning("Import error msg: %s", import_err.msg)
else:
    logger.info("Load custom ops library success!")
|
||||
|
||||
|
||||
_per_token_smooth_quant = True
|
||||
|
||||
|
||||
def is_per_token_smooth_quant():
|
||||
"""is per token smooth quant"""
|
||||
return _per_token_smooth_quant
|
||||
|
||||
|
||||
class KunlunOps:
    """Namespace of static wrappers around the Kunlun custom kernels.

    Every public entry point is a ``@staticmethod``. Most of them forward
    their arguments to the ``xtorch_ops`` extension module; a few call
    kernels registered under ``torch.ops._C``. Many signatures carry
    parameters that are accepted but not forwarded to the kernel; each
    docstring lists which ones are ignored.
    """

    # ------------------------------------------------------------------
    # Attention ops
    # ------------------------------------------------------------------
    @staticmethod
    def _paged_attention(
        output,
        query,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        context_lens_cpu,
        is_context,
    ):
        """Shared body of ``paged_attention_v1`` and ``paged_attention_v2``.

        Calls ``xtorch_ops.paged_attention`` with causal masking enabled and
        lets the kernel write its result into ``output``. Returns ``None``.
        """
        xtorch_ops.paged_attention(
            x=query,
            k_cache=key_cache,
            v_cache=value_cache,
            block_tables=block_tables,
            context_lens_cpu=context_lens_cpu,
            context_lens_xpu=context_lens,
            is_context=is_context,
            is_causal=True,
            out=output,
            # NOTE(review): the head dim passed here is a hard-coded 128 and
            # is not derived from the cache tensors -- confirm for models
            # with a different head size.
            vo_head_dim=128,
        )

    @staticmethod
    def paged_attention_v1(
        output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        context_lens,
        context_lens_cpu,
        is_context,
        block_size,
        max_context_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
        alibi_sqrt=False,
    ):
        """PagedAttentionV1: run paged attention and store it in ``output``.

        Only ``output``, ``query``, ``key_cache``, ``value_cache``,
        ``block_tables``, ``context_lens``, ``context_lens_cpu`` and
        ``is_context`` reach the kernel. All remaining parameters are
        accepted but currently ignored. Returns ``None``.
        """
        KunlunOps._paged_attention(
            output,
            query,
            key_cache,
            value_cache,
            block_tables,
            context_lens,
            context_lens_cpu,
            is_context,
        )

    @staticmethod
    def paged_attention_v2(
        output,
        exp_sums,
        max_logits,
        tmp_output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        context_lens,
        context_lens_cpu,
        is_context,
        block_size,
        max_context_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
        alibi_sqrt=False,
    ):
        """PagedAttentionV2: same kernel call as ``paged_attention_v1``.

        ``exp_sums``, ``max_logits`` and ``tmp_output`` are accepted but not
        used, as are all parameters that ``paged_attention_v1`` ignores.
        Returns ``None``; the result is written into ``output``.
        """
        KunlunOps._paged_attention(
            output,
            query,
            key_cache,
            value_cache,
            block_tables,
            context_lens,
            context_lens_cpu,
            is_context,
        )

    # ------------------------------------------------------------------
    # Activation ops
    # ------------------------------------------------------------------
    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
        """Call ``xtorch_ops.silu_and_mul`` on ``x`` along the last axis.

        The kernel is invoked with ``axis=-1`` and ``turn=True`` and writes
        into ``out``. Returns ``None``.
        """
        xtorch_ops.silu_and_mul(
            x,
            axis=-1,
            turn=True,
            out=out,
        )

    @staticmethod
    def quick_gelu(out: torch.Tensor, x: torch.Tensor):
        """Call ``xtorch_ops.quick_gelu`` on ``x``, writing into ``out``."""
        xtorch_ops.quick_gelu(
            x,
            out=out,
        )

    # ------------------------------------------------------------------
    # Layernorm
    # ------------------------------------------------------------------
    @staticmethod
    def rms_norm(
        out,
        x,
        weight,
        epsilon,
    ):
        """Call ``xtorch_ops.rmsnorm`` on ``x``, writing into ``out``.

        ``weight`` is converted to ``torch.float32`` before being handed to
        the kernel.
        """
        xtorch_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)

    @staticmethod
    def fused_add_rms_norm(
        x,
        residual,
        weight,
        epsilon,
    ):
        """Residual add followed by RMSNorm, updating both inputs in place.

        After the call ``residual`` holds ``x + residual`` (computed from the
        values seen after the kernel call) and ``x`` holds the kernel output.
        Returns ``None``.
        """
        output = torch.empty_like(x)
        xtorch_ops.add_rmsnorm(
            x, residual, weight.to(torch.float32), epsilon, out=output
        )
        # NOTE(review): this sum is taken after the kernel call, so it assumes
        # ``add_rmsnorm`` leaves ``x`` and ``residual`` untouched -- confirm.
        fused_input = x + residual
        residual.copy_(fused_input, non_blocking=True)
        x.copy_(output)

    # ------------------------------------------------------------------
    # Rotary embedding
    # ------------------------------------------------------------------
    @staticmethod
    def rotary_embedding(
        positions, query, key, head_size, cos_sin_cache, is_neox_style
    ):
        """Apply rotary position embedding to ``query`` and ``key``.

        * ``is_neox_style`` false: uses ``xtorch_ops.rotary_embedding_gptj``
          on contiguous copies, rebinds ``query.data`` / ``key.data`` to the
          results and returns ``(query, key)``.
        * ``is_neox_style`` true: uses ``torch.ops._C.rotary_embedding`` and
          returns the processed tensors viewed as
          ``(num_tokens, num_heads * head_size)``.
        """
        query_x = query.contiguous()
        key_x = key.contiguous()

        if not is_neox_style:
            # A float16 cache is promoted to float32 for the GPT-J kernel.
            if cos_sin_cache.dtype == torch.float16:
                cos_sin_cache = cos_sin_cache.to(torch.float32)
            positions = positions.to(torch.int)
            if positions.dim() == 1:
                # A leading dimension is added when positions are 1-D.
                positions = positions.unsqueeze(0)
                query_x = query_x.unsqueeze(0)
                key_x = key_x.unsqueeze(0)

            xtorch_ops.rotary_embedding_gptj(
                positions, query_x, key_x, head_size, cos_sin_cache
            )
            # NOTE(review): when ``positions`` was 1-D the tensors assigned
            # here keep the extra leading dimension added above. The previous
            # code re-unsqueezed local copies afterwards without using them;
            # that dead code was removed and the returned shapes are the same
            # as before. Confirm whether a squeeze was intended.
            query.data = query_x
            key.data = key_x
            return query, key

        # TODO: need opt
        if cos_sin_cache.dim() == 4:
            max_seq_len = cos_sin_cache.shape[2]
            head_dim = cos_sin_cache.shape[3]
            # Remove the first two dimensions [1,1,L,D] -> [L,D]
            cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(0)
            cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim)

        # Head counts are derived from the flattened second dimension.
        num_tokens = query_x.shape[0]
        num_heads = query_x.shape[1] // head_size
        num_kv_heads = key_x.shape[1] // head_size

        torch.ops._C.rotary_embedding(
            positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
        )

        query_x = query_x.view(num_tokens, num_heads * head_size)
        key_x = key_x.view(num_tokens, num_kv_heads * head_size)
        return query_x, key_x

    @staticmethod
    def mrotary_embedding(
        positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
    ):
        """Apply multi-section rotary embedding (NeoX style only).

        Runs ``xtorch_ops.mrotary_embedding_neox`` on contiguous copies of
        ``query`` / ``key``, rebinds ``query.data`` / ``key.data`` to them and
        returns ``(query, key)``.

        Raises:
            AssertionError: if ``is_neox_style`` is false.
        """
        assert is_neox_style, "mrotary_embedding only supports is_neox_style=True"
        query_x = query.contiguous()
        key_x = key.contiguous()
        xtorch_ops.mrotary_embedding_neox(
            positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
        )

        query.data = query_x
        key.data = key_x
        return query, key

    # ------------------------------------------------------------------
    # KV-cache management
    # ------------------------------------------------------------------
    @staticmethod
    def swap_blocks(src, dst, block_mapping):
        """Forward to ``xtorch_ops.swap_blocks``."""
        xtorch_ops.swap_blocks(src, dst, block_mapping)

    @staticmethod
    def copy_blocks(key_caches, value_caches, block_mapping):
        """Make every cache contiguous, then call ``xtorch_ops.copy_blocks``.

        ``key_caches`` and ``value_caches`` are updated in place: each entry
        is replaced by its ``.contiguous()`` version before the kernel call.
        """
        for idx, key_cache in enumerate(key_caches):
            key_caches[idx] = key_cache.contiguous()
            value_caches[idx] = value_caches[idx].contiguous()
        xtorch_ops.copy_blocks(
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
    ):
        """Forward to ``xtorch_ops.reshape_and_cache``.

        ``kv_cache_dtype`` is accepted but not passed to the kernel.
        """
        xtorch_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)

    # ------------------------------------------------------------------
    # Prefill attention
    # ------------------------------------------------------------------
    @staticmethod
    def multi_query_kv_attention(
        usual_seq_lod_xpu: torch.Tensor,
        usual_seq_lod_cpu: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        **kargs,
    ) -> torch.Tensor:
        """Run prefill attention through ``xtorch_ops.attention``.

        query: shape = [num_prompt_tokens, num_heads, head_size]; 3-D inputs
        get a leading dimension added before the kernel call. When ``key``
        has a different size than ``query`` along dim 2, ``key`` and
        ``value`` are expanded with ``repeat_interleave`` to match.

        ``**kargs`` is accepted for call compatibility but none of its
        entries are forwarded; the kernel is always invoked with
        ``is_causal=True`` and ``is_prefill=True``.

        Returns:
            A new tensor shaped like the (possibly unsqueezed) ``query``.
        """
        if query.dim() == 3:
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
        output = torch.empty_like(query)

        # The 4-way unpack also rejects inputs that are not 4-D at this point.
        _, _, num_q_heads, _ = query.shape
        num_kv_heads = key.size(2)
        if num_kv_heads != num_q_heads:
            repeat = num_q_heads // num_kv_heads
            key = key.repeat_interleave(repeat, dim=2)  # [B, T, Qh, Hd]
            value = value.repeat_interleave(repeat, dim=2)

        xtorch_ops.attention(
            q=query,
            k_cache=key,
            v_cache=value,
            out=output,
            is_causal=True,
            is_prefill=True,
            context_seq_lod_cpu=usual_seq_lod_cpu,
            context_seq_lod_xpu=usual_seq_lod_xpu,
        )
        return output

    # ------------------------------------------------------------------
    # Quantised ops
    # ------------------------------------------------------------------
    @staticmethod
    def _alloc_quant_buffers(x):
        """Allocate the int8 output and float scale buffers for a quant norm.

        With per-token smooth quant enabled the scale buffer has shape
        ``x.shape[:-1] + (1,)``; otherwise it is a flat buffer of 12 floats.
        """
        out = torch.empty_like(x, dtype=torch.int8)
        if is_per_token_smooth_quant():
            out_scale = torch.empty(
                x.shape[:-1], device=x.device, dtype=torch.float
            ).unsqueeze(-1)
        else:
            # NOTE(review): the meaning of the constant 12 is not visible here.
            out_scale = torch.empty(12, device=x.device, dtype=torch.float)
        return out, out_scale

    @staticmethod
    def quant_fusedresidual_rmsnorm_op(
        x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
    ):
        """Quantized fused residual layer normalization.

        ``scale_to_int``, ``dyn_scale`` and ``type`` are accepted but unused.

        Returns:
            ``(out, out_scale)`` when ``residual`` is ``None``, otherwise
            ``(out, out_scale, residual)``.
        """
        out, out_scale = KunlunOps._alloc_quant_buffers(x)

        xtorch_ops.quant_fusedresidual_rmsnorm(
            x,
            residual,
            weight,
            bias,
            eps,
            out=out,
            out_scale=out_scale,
            residual_tensor=residual,
        )

        if residual is None:
            return out, out_scale
        return out, out_scale, residual

    @staticmethod
    def quant_rmsnorm_op(
        x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
    ):
        """Quantized RMSNorm.

        ``scale_to_int``, ``dyn_scale`` and ``type`` are accepted but unused.

        Returns:
            ``(out, out_scale)`` as filled in by ``xtorch_ops.quant_rmsnorm``.
        """
        out, out_scale = KunlunOps._alloc_quant_buffers(x)
        xtorch_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
        return out, out_scale

    @staticmethod
    def smooth_quant_matmul_column_row_kernels(
        input_tensor,
        weight,
        smoother,
        input_scale,
        weight_scale,
        perTokenScaling,
        perChannelScaling,
        otype,
    ):
        """Smooth-quant matmul via ``xtorch_ops``.

        A 3-D input is flattened to 2-D for the kernel and the result is
        viewed back to ``(dim0, dim1, weight.shape[0])``; a 2-D input yields
        ``(rows, weight.shape[0])``. The output buffer is always allocated as
        ``torch.float16`` on ``weight.device``; ``otype`` is not consulted.

        Raises:
            ValueError: if ``input_tensor`` is neither 2-D nor 3-D.
        """
        input_shape = input_tensor.shape
        weight_shape = weight.shape
        ndim = input_tensor.dim()

        if ndim == 3:
            input_tensor = input_tensor.reshape(-1, input_shape[-1])
            num_rows = input_shape[0] * input_shape[1]
            output_bs_shape = [input_shape[0], input_shape[1]]
        elif ndim == 2:
            num_rows = input_shape[0]
            output_bs_shape = [-1]
        else:
            # Previously this fell through and failed later on an unbound
            # local; fail early with an actionable message instead.
            raise ValueError(
                "smooth_quant_matmul_column_row_kernels expects a 2-D or 3-D "
                f"input tensor, got {ndim}-D"
            )

        out = torch.empty(
            (num_rows, weight_shape[0]),
            dtype=torch.float16,
            device=weight.device,
        )
        xtorch_ops.smooth_quant_matmul_column_row_kernels(
            input_tensor,
            weight,
            smoother,
            input_scale,
            weight_scale,
            perTokenScaling,
            perChannelScaling,
            out=out,
        )

        out = out.view(*output_bs_shape, weight_shape[0])
        return out

    # ------------------------------------------------------------------
    # Mixture of experts
    # ------------------------------------------------------------------
    @staticmethod
    def fused_moe(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        gating_output: torch.Tensor,
        linear_weights: torch.Tensor,
        topk: int,
        renormalize: bool,
        inplace: bool = False,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        w1_bias: Optional[torch.Tensor] = None,
        w2_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the fused MoE block through ``torch.ops._C.moe_ffn_block``.

        The expert count is taken from ``linear_weights.shape[0]``.
        ``gating_output``, ``inplace``, ``w1_bias`` and ``w2_bias`` are
        accepted but not forwarded to the kernel.

        Returns:
            A new tensor with the shape, dtype and device of
            ``hidden_states``.
        """
        output = torch.empty(
            hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
        )
        expert_num = linear_weights.shape[0]

        torch.ops._C.moe_ffn_block(
            x=hidden_states,
            gate_w=linear_weights,
            inter_w=w1,
            output_w=w2,
            expert_num=expert_num,
            moe_top_k=topk,
            topk_group=topk_group,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            expert_group_num=num_expert_group,
            out=output,
        )
        return output

    @staticmethod
    def fused_moe_ep(
        hidden_states: torch.Tensor,
        w13_weight: torch.Tensor,
        w2_weight: torch.Tensor,
        gating_output: torch.Tensor,
        linear_weights: torch.Tensor,
        ep_rank: int,
        top_k: int,
        renormalize: bool,
        inplace: bool = False,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        w1_bias: Optional[torch.Tensor] = None,
        w2_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Expert-parallel MoE computed expert by expert on the local rank.

        Router logits are recomputed as ``hidden_states @ linear_weights.T``
        (``gating_output`` is not used), top-k routing comes from
        ``torch.ops._C.moe_softmax_topk``, and only experts with global id
        ``ep_rank * num_local_experts + i`` are evaluated here; rows routed
        to other experts contribute zeros. ``inplace``, ``use_grouped_topk``,
        ``num_expert_group``, ``topk_group``, ``w1_bias`` and ``w2_bias`` are
        accepted but unused.

        Returns:
            Tensor of shape ``(batch, hidden_size)`` in the dtype of
            ``hidden_states``.
        """
        x = hidden_states
        batch, hidden_size = x.shape
        num_local_experts, up_gate_size, _ = w13_weight.shape

        router_logits = x.to(linear_weights.dtype) @ linear_weights.T

        topk_weights = torch.empty(
            batch, top_k, dtype=router_logits.dtype, device=router_logits.device
        )
        topk_ids = torch.empty(
            batch, top_k, dtype=torch.int32, device=router_logits.device
        )
        block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
        torch.ops._C.moe_softmax_topk(
            router_logits, topk_weights, topk_ids, block_static
        )

        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)

        topk_weights = topk_weights.to(x.dtype)
        out = torch.zeros(batch * top_k, hidden_size, dtype=x.dtype, device=x.device)
        # Each token is repeated top_k times so that row r of ``repeat_x``
        # lines up with entry r of the flattened expert ids.
        repeat_x = x.repeat_interleave(top_k, dim=0)
        topk_ids_flat = topk_ids.flatten()
        for i in range(num_local_experts):
            experts_id = ep_rank * num_local_experts + i
            selected_token = topk_ids_flat == experts_id
            # Count once and reuse it for both the emptiness check and the
            # buffer size instead of reducing the mask twice.
            num_selected = int(selected_token.sum())
            if num_selected:
                cur_token = repeat_x[selected_token]
                up_gate = torch.empty(
                    num_selected,
                    up_gate_size // 2,
                    dtype=cur_token.dtype,
                    device=cur_token.device,
                )
                torch.ops._C.swiglu(cur_token @ w13_weight[i].T, up_gate)
                out[selected_token] = up_gate @ w2_weight[i].T

        # Weighted sum of the top_k expert outputs for every token.
        output = (
            (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
            .sum(dim=1)
            .to(x.dtype)
        )
        return output

    # ------------------------------------------------------------------
    # Multi-head latent attention
    # ------------------------------------------------------------------
    @staticmethod
    def fused_multi_head_latent_page_attention(
        hidden_states: torch.Tensor,
        q_lora_rank: int,
        kv_lora_rank: int,
        q_a_proj_w: torch.Tensor,
        q_a_layernorm_w: torch.Tensor,
        q_b_proj_w: torch.Tensor,
        q_proj_w: torch.Tensor,
        kv_a_proj_w: torch.Tensor,
        kv_a_layernorm_w: torch.Tensor,
        kv_b_proj_w: torch.Tensor,
        o_proj_w: torch.Tensor,
        head_num: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        max_context_len: int,
        layernorm_eps: float,
        scale: float,
        is_causal: bool,
        is_context: bool,
        mp_size: int,
        local_rank: int,
        rotary_pos_embedding: torch.Tensor,
        pa_block_tables: torch.Tensor,
        position: torch.Tensor,
        context_lens_cpu: torch.Tensor,
        slot_mapping: torch.Tensor,
        prompt_lods_cpu: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
    ) -> torch.Tensor:
        """MLA paged-attention block.

        Forwards every argument, in order, to
        ``xtorch_ops.xft_multi_head_latent_page_attention_block``.

        Returns:
            A new tensor with the shape, dtype and device of
            ``hidden_states``, filled in by the kernel.
        """
        output = torch.empty(
            hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
        )
        xtorch_ops.xft_multi_head_latent_page_attention_block(
            hidden_states,
            q_lora_rank,
            kv_lora_rank,
            q_a_proj_w,
            q_a_layernorm_w,
            q_b_proj_w,
            q_proj_w,
            kv_a_proj_w,
            kv_a_layernorm_w,
            kv_b_proj_w,
            o_proj_w,
            head_num,
            qk_nope_head_dim,
            qk_rope_head_dim,
            v_head_dim,
            max_context_len,
            layernorm_eps,
            scale,
            is_causal,
            is_context,
            mp_size,
            local_rank,
            rotary_pos_embedding,
            pa_block_tables,
            position,
            # NOTE(review): the two positional ``None`` values fill kernel
            # arguments whose meaning is not visible here -- confirm against
            # the xtorch_ops signature.
            None,
            context_lens_cpu,
            slot_mapping,
            None,
            prompt_lods_cpu,
            out=output,
            k_cache=k_cache,
            v_cache=v_cache,
        )
        return output
|
||||
23
vllm_kunlun/ops/activation.py
Normal file
23
vllm_kunlun/ops/activation.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Custom activation functions."""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU activation registered as the ``kunlun_silu_and_mul`` custom op.

    Computes x -> silu(x[:d]) * x[d:] with d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and fill it via ``torch.ops._C.swiglu``."""
        half_width = x.shape[-1] // 2
        # Same leading dimensions as the input, last dimension halved.
        result_shape = (*x.shape[:-1], half_width)
        result = torch.empty(result_shape, dtype=x.dtype, device=x.device)
        torch.ops._C.swiglu(x, result)
        return result
|
||||
3
vllm_kunlun/ops/attention/__init__.py
Normal file
3
vllm_kunlun/ops/attention/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# from .backends import KunlunMetadata
|
||||
|
||||
# __all__ = ['KunlunMetadata']
|
||||
3
vllm_kunlun/ops/attention/backends/__init__.py
Normal file
3
vllm_kunlun/ops/attention/backends/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# from .kunlun_attn import KunlunMetadata
|
||||
|
||||
# __all__ = ['KunlunMetadata']
|
||||
803
vllm_kunlun/ops/attention/backends/kunlun_attn.py
Normal file
803
vllm_kunlun/ops/attention/backends/kunlun_attn.py
Normal file
@@ -0,0 +1,803 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Bao Qian, Dong Xinyu, Chen Zhennan, Ma Tianyu
|
||||
# Email: baoqian@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""kunlun attention wrapper for context and decode"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
from itertools import accumulate
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
)
|
||||
from .utils import CommonAttentionState, CommonMetadataBuilder
|
||||
from vllm.attention.backends.utils import (
|
||||
is_block_tables_empty,
|
||||
compute_slot_mapping_start_idx,
|
||||
compute_slot_mapping,
|
||||
)
|
||||
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import async_tensor_h2d
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KunlunAttentionBackend(AttentionBackend):
    """Attention backend descriptor for Kunlun.

    Names the implementation, metadata, builder and state classes that make
    up this backend and forwards KV-cache shape/swap/copy requests to
    ``PagedAttention``.
    """

    accept_output_buffer = False

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "KUNLUN_ATTENTION"

    @staticmethod
    def get_impl_cls() -> Type["KunlunAttentionImpl"]:
        """Return the attention implementation class."""
        return KunlunAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["KunlunMetadata"]:
        """Return the attention metadata class."""
        return KunlunMetadata

    @staticmethod
    def get_builder_cls() -> Type["KunlunMetadataBuilder"]:
        """Return the metadata builder class."""
        return KunlunMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as computed by ``PagedAttention``."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size
        )
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Delegate the block swap to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate the block copy to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""KunlunMetadata"""
|
||||
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ----------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# FIXME: It is for flash attn.
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]] = None
|
||||
|
||||
# FIXME: It is for flash attn.
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
# Max number of key/value length in the batch, especially for prefix cache
|
||||
max_kv_len: Optional[int] = None
|
||||
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
query_start_loc_host: Optional[torch.Tensor] = None
|
||||
# serve only for prefix cache
|
||||
kv_prefix_start_loc_host: Optional[torch.Tensor] = None
|
||||
kv_prefix_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# Self-attention prefill/decode metadata cache
|
||||
_cached_prefill_metadata: Optional["KunlunMetadata"] = None
|
||||
_cached_decode_metadata: Optional["KunlunMetadata"] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
# It is a list because it is needed to set per prompt
|
||||
# when alibi slopes is used. It is because of the limitation
|
||||
# from xformer API.
|
||||
# will not appear in the __repr__ and __init__
|
||||
self.attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.cross_attn_bias: Optional[List[AttentionBias]] = None
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
"""
|
||||
All attention metadata required for encoder attention is set.
|
||||
"""
|
||||
return (
|
||||
(self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None)
|
||||
)
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
"""
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
"""
|
||||
return (
|
||||
self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None)
|
||||
)
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
|
||||
"""prefill_metadata"""
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
# Recover cached prefill-phase attention
|
||||
# metadata structure
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert (self.seq_lens is not None) or (self.encoder_seq_lens is not None)
|
||||
assert (self.seq_lens_tensor is not None) or (
|
||||
self.encoder_seq_lens_tensor is not None
|
||||
)
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (
|
||||
None
|
||||
if self.query_start_loc is None
|
||||
else self.query_start_loc[: self.num_prefills + 1]
|
||||
)
|
||||
# flash attention needs both lod information on host and device
|
||||
query_start_loc_host = (
|
||||
None
|
||||
if self.query_start_loc_host is None
|
||||
else self.query_start_loc_host[: self.num_prefills + 1]
|
||||
)
|
||||
kv_prefix_start_loc_host = (
|
||||
None
|
||||
if self.kv_prefix_start_loc_host is None
|
||||
else self.kv_prefix_start_loc_host[: self.num_prefills + 1]
|
||||
+ query_start_loc_host
|
||||
)
|
||||
kv_prefix_start_loc = (
|
||||
None
|
||||
if kv_prefix_start_loc_host is None
|
||||
else kv_prefix_start_loc_host.cuda()
|
||||
)
|
||||
slot_mapping = (
|
||||
None
|
||||
if self.slot_mapping is None
|
||||
else self.slot_mapping[: self.num_prefill_tokens]
|
||||
)
|
||||
seq_lens = None if self.seq_lens is None else self.seq_lens[: self.num_prefills]
|
||||
seq_lens_tensor = (
|
||||
None
|
||||
if self.seq_lens_tensor is None
|
||||
else self.seq_lens_tensor[: self.num_prefills]
|
||||
)
|
||||
context_lens_tensor = (
|
||||
None
|
||||
if self.context_lens_tensor is None
|
||||
else self.context_lens_tensor[: self.num_prefills]
|
||||
)
|
||||
# for prefix cache, block table only contains blocks that hit
|
||||
# if self.block_tables is None:
|
||||
# block_tables = None
|
||||
# elif self.block_tables.shape[1] == 0:
|
||||
# block_tables = self.block_tables[:self.num_prefills]
|
||||
# else:
|
||||
# block_tables = self.block_tables[:self.num_prefills][:, -1].clone()
|
||||
|
||||
block_tables = (
|
||||
None
|
||||
if self.block_tables is None
|
||||
else self.block_tables[: self.num_prefills]
|
||||
)
|
||||
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = KunlunMetadata(
|
||||
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
max_kv_len=self.max_kv_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_host=query_start_loc_host,
|
||||
kv_prefix_start_loc=kv_prefix_start_loc,
|
||||
kv_prefix_start_loc_host=kv_prefix_start_loc_host,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables,
|
||||
enable_kv_scales_calculation=False,
|
||||
seq_start_loc=self.seq_start_loc,
|
||||
)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
def decode_metadata(self) -> Optional["KunlunMetadata"]:
    """Return the decode-phase slice of this metadata, or None.

    None is returned when the batch carries no decode tokens. Otherwise the
    decode view is built on first access, stored in
    ``_cached_decode_metadata`` and served from that cache afterwards.
    """
    if self.num_decode_tokens == 0:
        return None

    cached = self._cached_decode_metadata
    if cached is not None:
        # Already built for this batch; reuse it.
        return cached

    # At least one of the decoder / encoder length tensors must be present.
    assert not (
        self.seq_lens_tensor is None and self.encoder_seq_lens_tensor is None
    )

    # Optional fields stay None; otherwise drop the leading prefill part.
    # The slot mapping is sliced by prefill *tokens*, the other fields by
    # prefill *sequences*.
    slot_mapping = self.slot_mapping
    if slot_mapping is not None:
        slot_mapping = slot_mapping[self.num_prefill_tokens :]

    seq_lens_tensor = self.seq_lens_tensor
    if seq_lens_tensor is not None:
        seq_lens_tensor = seq_lens_tensor[self.num_prefills :]

    seq_lens_tensor_cpu = self.seq_lens_tensor_cpu
    if seq_lens_tensor_cpu is not None:
        seq_lens_tensor_cpu = seq_lens_tensor_cpu[self.num_prefills :]

    block_tables = self.block_tables
    if block_tables is not None:
        block_tables = block_tables[self.num_prefills :]

    # Build and cache the decode-only metadata object.
    self._cached_decode_metadata = KunlunMetadata(
        multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=self.num_decode_tokens,
        slot_mapping=slot_mapping,
        seq_lens_tensor=seq_lens_tensor,
        seq_lens_tensor_cpu=seq_lens_tensor_cpu,
        max_prefill_seq_len=0,
        max_decode_seq_len=self.max_decode_seq_len,
        block_tables=block_tables,
        use_cuda_graph=self.use_cuda_graph,
        # Encoder / cross-attention fields are passed through unchanged.
        encoder_seq_lens=self.encoder_seq_lens,
        encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
        max_encoder_seq_len=self.max_encoder_seq_len,
        cross_slot_mapping=self.cross_slot_mapping,
        cross_block_tables=self.cross_block_tables,
        enable_kv_scales_calculation=False,
    )
    return self._cached_decode_metadata
|
||||
|
||||
|
||||
class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
    """Metadata builder for the Kunlun attention backend.

    On top of the common builder state it records, for every prefill
    sequence, the ``context_len`` reported for that sequence; ``build`` turns
    those lengths into host-side prefix offsets.
    """

    _metadata_cls = KunlunMetadata

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        super().__init__(input_builder)
        # context_len of each prefill sequence added since the last prepare().
        self.prefix_cache_kv_lens: List[int] = []

    def prepare(self):
        """Reset the per-batch state, including the prefix-cache lengths."""
        super().prepare()
        self.prefix_cache_kv_lens = []

    def _add_seq_group(
        self,
        inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
        chunked_prefill_enabled: bool,
    ):
        """Fold one sequence group into the builder's running lists.

        Updates the token/sequence counters, the per-sequence block tables,
        the prefix-cache lengths (prefill only) and the slot mapping.
        Chunked prefill is rejected with an assertion.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        per_sequence = zip(
            inter_data.seq_ids,
            [len(t) for t in inter_data.input_tokens],
            inter_data.orig_seq_lens,
            inter_data.seq_lens,
            inter_data.query_lens,
            inter_data.context_lens,
            inter_data.curr_sliding_window_blocks,
        )
        for (
            seq_id,
            token_len,
            seq_len,
            curr_seq_len,
            query_len,
            context_len,
            sliding_blocks,
        ) in per_sequence:
            self.context_lens.append(context_len)

            if not is_prompt:
                # Decode: exactly one new token per sequence.
                assert (
                    query_len == 1
                ), "seq_len: {}, context_len: {}, query_len: {}".format(
                    seq_len, context_len, query_len
                )
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)
            else:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)

            # Select the block table for this sequence.
            assert (
                not chunked_prefill_enabled
            ), "chunk prefill not supported for kunlun attention"
            if inter_data.prefix_cache_hit:
                # Only whole blocks covered by the cached context are kept.
                assert context_len != 0
                assert context_len % self.block_size == 0
                cached_blocks = context_len // self.block_size
                seq_block_table = block_tables[seq_id][:cached_blocks]
            elif is_prompt or block_tables is None:
                seq_block_table = []
            elif sliding_blocks == 0:
                seq_block_table = block_tables[seq_id]
            else:
                seq_block_table = block_tables[seq_id][-sliding_blocks:]
            self.block_tables.append(seq_block_table)

            if is_prompt:
                self.prefix_cache_kv_lens.append(context_len)

            # Append this sequence's entries to the slot mapping.
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window
            )
            compute_slot_mapping(
                is_block_tables_empty(block_tables),
                self.slot_mapping,
                seq_id,
                seq_len,
                context_len,
                start_idx,
                self.block_size,
                inter_data.block_tables,
            )

    def build(
        self,
        seq_lens: List[int],
        query_lens: List[int],
        cuda_graph_pad_size: int,
        batch_size: int,
    ):
        """Build the metadata and attach the Kunlun-specific host-side fields.

        Adds ``query_start_loc_host``, ``max_kv_len``, ``seq_lens_tensor_cpu``
        and, when at least one prefill sequence has a non-zero context
        length, ``kv_prefix_start_loc_host``.
        """
        attn_meta = super().build(seq_lens, query_lens, cuda_graph_pad_size, batch_size)

        # Cumulative query offsets kept on the CPU.
        attn_meta.query_start_loc_host = torch.tensor(
            list(accumulate(query_lens, initial=0)),
            dtype=torch.int32,
            device="cpu",
        )

        # Largest length among the recorded context lengths and seq_lens.
        attn_meta.max_kv_len = max(self.prefix_cache_kv_lens + attn_meta.seq_lens)

        # Only emit prefix offsets when some sequence has a non-zero context.
        context_lens = self.prefix_cache_kv_lens
        if context_lens and max(context_lens) != 0:
            self.prefix_cache_kv_lens = list(accumulate(context_lens, initial=0))
            attn_meta.kv_prefix_start_loc_host = torch.tensor(
                self.prefix_cache_kv_lens, dtype=torch.int32, device="cpu"
            )

        attn_meta.seq_lens_tensor_cpu = attn_meta.seq_lens_tensor.to("cpu")
        return attn_meta
|
||||
|
||||
|
||||
def _get_seq_len_block_table_args(
    attn_metadata: KunlunMetadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    """Pick the sequence-length and block-table fields for an attention type.

    * DECODER: decoder self-attention fields; the max length is the prefill
      or decode maximum depending on ``is_prompt``.
    * ENCODER_DECODER: encoder sequence lengths plus the cross-attention
      block tables.
    * ENCODER: encoder sequence lengths and no block tables.

    Arguments:

    * attn_metadata: attention metadata the fields are read from
    * is_prompt: True for prefill, False otherwise
    * attn_type: which kind of attention is being run

    Returns:

    * a 3-tuple ``(seq_lens_tensor, max_seq_len, block_tables_or_None)``

    Raises:

    * AttributeError for any other attention type
    """
    if attn_type == AttentionType.DECODER:
        max_seq_len = (
            attn_metadata.max_prefill_seq_len
            if is_prompt
            else attn_metadata.max_decode_seq_len
        )
        return (attn_metadata.seq_lens_tensor, max_seq_len, attn_metadata.block_tables)

    # Both remaining supported types use the encoder lengths; they differ
    # only in whether a (cross-attention) block table is returned.
    if attn_type == AttentionType.ENCODER_DECODER:
        block_tables = attn_metadata.cross_block_tables
    elif attn_type == AttentionType.ENCODER:
        block_tables = None
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    return (
        attn_metadata.encoder_seq_lens_tensor,
        attn_metadata.max_encoder_seq_len,
        block_tables,
    )
|
||||
|
||||
|
||||
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
    """Attention implementation that dispatches to Kunlun kernels.

    ``forward`` writes new keys/values into the paged KV cache via
    ``KunlunOps.reshape_and_cache``, runs prefill attention through
    ``KunlunOps.multi_query_kv_attention`` and decode attention through
    ``PagedAttention.forward_decode``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Store the attention hyper-parameters and validate the head size.

        Raises:
            ValueError: if ``blocksparse_params`` is given, or if
                ``head_size`` is not one of
                ``PagedAttention.get_supported_head_sizes()``.
        """
        if blocksparse_params is not None:
            raise ValueError("kunlunAttention does not support block-sparse attention.")
        # NOTE: logits_soft_cap, attn_type and kv_sharing_target_layer_name
        # are accepted but not stored or checked in this constructor (the
        # soft-cap rejection that used to live here is disabled).
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            # Slopes are kept as a float32 tensor.
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        # Query heads must split evenly across the KV heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in suppored_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {suppored_head_sizes}."
            )

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: "KunlunAttnMetadata",
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with KunlunAttn and PagedAttention.

        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * KunlunAttnImpl.forward() may be invoked for both self- and cross-
          attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode

        A note on how the attn_type (attention type enum) argument impacts
        attention forward() behavior:

            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
                Used for encoder branch of encoder-decoder models.
            * ENCODER_ONLY: no kv_caching, uses the normal attention
                attributes (seq_lens/seq_lens_tensor/max_seq_len).
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)

        Args:
            layer: the calling attention layer (not read in this method).
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            k_scale: forwarded to ``PagedAttention.forward_decode``.
            v_scale: forwarded to ``PagedAttention.forward_decode``.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            AttributeError: if the encoder or cross-attention metadata
                required by ``attn_type`` is not set.
        """

        # Check that appropriate attention metadata attributes are
        # selected for the desired attention type
        if attn_type == AttentionType.ENCODER and (
            not attn_metadata.is_all_encoder_attn_metadata_set
        ):
            raise AttributeError(
                "Encoder attention requires setting " "encoder metadata attributes."
            )

        elif attn_type == AttentionType.ENCODER_DECODER and (
            not attn_metadata.is_all_cross_attn_metadata_set
        ):
            raise AttributeError(
                "Encoder/decoder cross-attention "
                "requires setting cross-attention "
                "metadata attributes."
            )

        # Split the flattened hidden dimension into (heads, head_size).
        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            # key and value must be provided (or omitted) together.
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize

        # NOTE(review): key_cache / value_cache are only bound inside this
        # branch, but the decode path further down reads them
        # unconditionally. Decode with an empty kv_cache or with
        # attn_type == ENCODER would hit an unbound local - confirm callers
        # never do that.
        if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size
            )

            if (key is not None) and (value is not None):

                # Cross-attention writes through its own slot mapping.
                if attn_type == AttentionType.ENCODER_DECODER:
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    updated_slot_mapping = attn_metadata.slot_mapping
                # value is made contiguous before being handed to the
                # cache-write op (presumably a kernel requirement - TODO
                # confirm).
                value = value.contiguous()
                KunlunOps.reshape_and_cache(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    updated_slot_mapping,
                    self.kv_cache_dtype,
                )

        if attn_type == AttentionType.ENCODER:
            # Encoder attention - chunked prefill is not applicable;
            # derive token-count from query shape & and treat them
            # as 100% prefill tokens
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
            num_encoder_tokens = attn_metadata.num_encoder_tokens
            num_decode_tokens = 0
        elif attn_type == AttentionType.DECODER:
            # Decoder self-attention supports chunked prefill.
            # NOTE(review): KunlunMetadataBuilder._add_seq_group asserts that
            # chunked prefill is disabled, so the statement above may not
            # hold for this backend - confirm.
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            num_encoder_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens
            # Only enforce this shape-constraint for decoder
            # self-attention
            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
            assert value.shape[0] == num_prefill_tokens + num_decode_tokens
        else:  # attn_type == AttentionType.ENCODER_DECODER
            # Encoder/decoder cross-attention requires no chunked
            # prefill (100% prefill or 100% decode tokens, no mix)
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            if attn_metadata.num_encoder_tokens is not None:
                num_encoder_tokens = attn_metadata.num_encoder_tokens
            else:
                num_encoder_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens

        # Output buffer shaped like the 3-D query; rows [0, num_prefill_tokens)
        # belong to prefill and the remaining rows to decode.
        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        if key is not None and value is not None:
            key = key[:num_encoder_tokens]
            value = value[:num_encoder_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            # NOTE(review): only the "no KV cache / empty block table" case
            # is handled here. When kv_cache is non-empty and the prefill
            # block table is non-empty, nothing in this method writes
            # output[:num_prefill_tokens], which then stays as returned by
            # torch.empty_like. Also, block_tables is dereferenced without
            # a None check. Confirm whether a prefix-cache prefill path is
            # expected here.
            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                out = KunlunOps.multi_query_kv_attention(
                    prefill_meta.query_start_loc,
                    prefill_meta.query_start_loc_host,
                    query,
                    key,
                    value,
                    alibi_slopes=self.alibi_slopes,
                ).view_as(query)
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            assert (
                attn_type != AttentionType.ENCODER_ONLY
            ), "Encoder-only models should not have decode metadata."
            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = _get_seq_len_block_table_args(decode_meta, False, attn_type)

            # Decode attention reads K/V from the paged cache; the host-side
            # copy of the sequence lengths is passed alongside the device one.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                decode_meta.seq_lens_tensor_cpu,
                False,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                k_scale,
                v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)
|
||||
604
vllm_kunlun/ops/attention/backends/utils.py
Normal file
604
vllm_kunlun/ops/attention/backends/utils.py
Normal file
@@ -0,0 +1,604 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention backend utils"""
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
|
||||
TypeVar, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
|
||||
AttentionState)
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner_base import ModelRunnerBase
|
||||
|
||||
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
                                 "with encoder/decoder models.")

# Placeholder written into the slot mapping for entries that have no real
# KV-cache slot: profiling runs, tokens masked by the sliding window, and
# graph-capture padding.
PAD_SLOT_ID = -1

# Switch to numpy implementation of compute_slot_mapping
# if we have at least this many elements. Could be tuned further.
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
|
||||
|
||||
def is_block_tables_empty(block_tables: Union[None, Dict]):
    """Tell whether ``block_tables`` carries no block table at all.

    Returns True for ``None`` and for a dict whose values are all ``None``
    (an empty dict included); returns False for anything else.
    """
    if block_tables is None:
        return True
    if not isinstance(block_tables, dict):
        return False
    for table in block_tables.values():
        if table is not None:
            return False
    return True
|
||||
|
||||
|
||||
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
                                   context_len: int, sliding_window: int):
    """Return the first token index that gets a real slot.

    The result is 0 unless this is a prompt and a sliding window is set; in
    that case it is ``query_len - sliding_window`` clamped at 0.
    ``context_len`` is accepted but not used.
    """
    if not is_prompt or sliding_window is None:
        return 0
    return max(0, query_len - sliding_window)
|
||||
|
||||
|
||||
def _compute_slot_mapping_python(slot_mapping: List[int],
|
||||
block_table: List[int], range_start: int,
|
||||
range_end: int, block_size: int):
|
||||
for i in range(range_start, range_end):
|
||||
block_number = block_table[i // block_size]
|
||||
block_offset = i % block_size
|
||||
slot = block_number * block_size + block_offset
|
||||
slot_mapping.append(slot)
|
||||
|
||||
|
||||
def _compute_slot_mapping_numpy(slot_mapping: List[int],
|
||||
block_table: List[int], range_start: int,
|
||||
range_end: int, block_size: int):
|
||||
block_table_array = np.array(block_table)
|
||||
idx = np.arange(range_start, range_end)
|
||||
block_offset = idx % block_size
|
||||
idx //= block_size
|
||||
seq_slot_mapping_array = block_table_array[idx]
|
||||
seq_slot_mapping_array *= block_size
|
||||
seq_slot_mapping_array += block_offset
|
||||
slot_mapping.extend(seq_slot_mapping_array)
|
||||
|
||||
|
||||
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
                         seq_id: int, seq_len: int, context_len: int,
                         start_idx: int, block_size: int,
                         block_tables: Dict[int, List[int]]):
    """Extend ``slot_mapping`` in place with the slots of one sequence.

    During a profile run (block tables not initialised yet, or
    ``{seq_id: None}`` for embeddings) ``seq_len`` dummy ``PAD_SLOT_ID``
    entries are appended. Otherwise token indices in
    ``[context_len, start_idx)`` are padded with ``PAD_SLOT_ID`` and the
    indices from ``max(start_idx, context_len)`` up to ``seq_len`` are
    mapped through ``block_tables[seq_id]``.

    Example: prompt length 10, sliding window 8, block size 4 - the first
    two tokens are masked and the mapping is
    [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
    """
    if is_profile_run:
        slot_mapping.extend([PAD_SLOT_ID] * seq_len)
        return

    # Tokens in front of the sliding window get no slot.
    slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))

    first_mapped = max(start_idx, context_len)
    seq_block_table = block_tables[seq_id]

    # The numpy path only pays off once there are enough elements.
    if seq_len - first_mapped < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
        fill_slots = _compute_slot_mapping_python
    else:
        fill_slots = _compute_slot_mapping_numpy
    fill_slots(slot_mapping, seq_block_table, first_mapped, seq_len,
               block_size)
|
||||
|
||||
|
||||
# Type variable for the AttentionMetadata subclass a metadata builder produces.
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
|
||||
|
||||
|
||||
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
    """Collects per-sequence attention inputs and builds backend metadata.

    Subclasses set ``_metadata_cls`` to the metadata class that ``build``
    instantiates. ``prepare`` must be called before sequence groups are
    added, since it creates the accumulator lists.
    """

    _metadata_cls: Type[TAttentionMetadata]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        """Reset all per-batch accumulators."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Fold one sequence group into the accumulators.

        Updates the token/sequence counters, the per-sequence block tables
        and the slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                # Decode: exactly one new token per sequence.
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Returns:
            An instance of ``_metadata_cls`` populated from the accumulated
            state.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad with one empty block table per padding slot so the list
            # stays aligned with the padded slot mapping. (``[] * n`` is
            # always ``[]`` and would add nothing.) Empty tables are skipped
            # by the copy loop below, so the padded rows of the graph buffer
            # are left untouched.
            self.block_tables.extend(
                [] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return self._metadata_cls(  # type: ignore
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
|
||||
|
||||
|
||||
class CommonAttentionState(AttentionState):
|
||||
"""CommonAttentionState"""
|
||||
|
||||
def __init__(self, runner: "ModelRunnerBase"):
    """Remember the owning runner; graph capture starts out inactive."""
    self._is_graph_capturing = False
    self.runner = runner
|
||||
|
||||
@contextmanager
def graph_capture(self, max_batch_size: int):
    """Context manager that provides the buffers used during graph capture.

    On entry it marks the state as capturing and allocates the slot-mapping,
    sequence-length (device and CPU) and block-table buffers. On exit the
    flag is cleared and the buffers are dropped - also when the body raises,
    so a failed capture does not leave the state stuck in capturing mode
    with its buffers still allocated.

    Args:
        max_batch_size: length of the slot-mapping and sequence-length
            buffers.
    """
    device = self.runner.device

    self._is_graph_capturing = True

    self._graph_slot_mapping = torch.full((max_batch_size, ),
                                          PAD_SLOT_ID,
                                          dtype=torch.int32,
                                          device=device)
    self._graph_seq_lens = torch.ones(max_batch_size,
                                      dtype=torch.int32,
                                      device=device)
    # Host-side copy of the sequence lengths.
    self._graph_seq_lens_cpu = self._graph_seq_lens.to('cpu')
    self._graph_block_tables = torch.from_numpy(
        self.runner.graph_block_tables).to(device=device)

    try:
        yield
    finally:
        self._is_graph_capturing = False
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_seq_lens_cpu
        del self._graph_block_tables
|
||||
|
||||
def graph_clone(self, batch_size: int) -> "CommonAttentionState":
    """Create a fresh state object of the same class for the same runner.

    Only valid while a graph capture is in progress. ``batch_size`` is
    accepted but not used.
    """
    assert self._is_graph_capturing
    state_cls = self.__class__
    return state_cls(self.runner)
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
        self, batch_size: int, is_encoder_decoder_model: bool = False):
    """Build decode-only metadata over the capture buffers.

    The metadata views the first ``batch_size`` entries of the buffers
    allocated by ``graph_capture``. For encoder/decoder models the backend
    name is checked and the metadata is extended via
    ``_update_captured_metadata_for_enc_dec_model``.
    """
    assert self._is_graph_capturing

    make_metadata = self.runner.attn_backend.make_metadata
    attn_metadata = make_metadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=batch_size,
        slot_mapping=self._graph_slot_mapping[:batch_size],
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
        seq_lens=None,
        seq_lens_tensor=self._graph_seq_lens[:batch_size],
        seq_lens_tensor_cpu=self._graph_seq_lens_cpu[:batch_size],
        max_query_len=1,
        max_decode_query_len=1,
        max_prefill_seq_len=0,
        max_decode_seq_len=self.runner.max_seq_len_to_capture,
        query_start_loc=None,
        seq_start_loc=None,
        context_lens_tensor=None,
        block_tables=self._graph_block_tables[:batch_size],
        use_cuda_graph=True,
    )

    if not is_encoder_decoder_model:
        return attn_metadata

    # The encoder decoder model works only with XFormers and
    # Flash Attention backend. Assert the same.
    supported_backends = ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"]
    assert self.runner.attn_backend.get_name() in supported_backends, (
        f"Expected attn_backend name to be either 'XFORMERS',"
        f"'ROCM_FLASH', or 'FLASH_ATTN', but "
        f"got '{self.runner.attn_backend.get_name()}'")
    self._update_captured_metadata_for_enc_dec_model(
        batch_size=batch_size, attn_metadata=attn_metadata)
    return attn_metadata
|
||||
|
||||
def get_graph_input_buffers(
        self,
        attn_metadata,
        is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
    """Collect the metadata tensors that act as graph input buffers.

    Args:
        attn_metadata: Attention metadata used during capture.
        is_encoder_decoder_model: When True, the backend name is asserted
            and the encoder/cross-attention buffers are added as well.

    Returns:
        Dict with ``slot_mapping``, ``seq_lens_tensor``,
        ``seq_lens_tensor_cpu`` and ``block_tables`` (plus the
        encoder-decoder entries when requested).
    """
    input_buffers: Dict[str, Any] = {
        "slot_mapping": attn_metadata.slot_mapping,
    }
    decode_meta = attn_metadata.decode_metadata
    input_buffers.update({
        "seq_lens_tensor": decode_meta.seq_lens_tensor,
        "seq_lens_tensor_cpu": decode_meta.seq_lens_tensor_cpu,
        "block_tables": decode_meta.block_tables,
    })
    if not is_encoder_decoder_model:
        return input_buffers

    # Encoder-decoder capture is only accepted for these backend names.
    backend_name = self.runner.attn_backend.get_name()
    assert backend_name in ("XFORMERS", "FLASH_ATTN", "ROCM_FLASH"), (
        "Expected attn_backend name to be either 'XFORMERS',"
        "'ROCM_FLASH', or 'FLASH_ATTN', but "
        f"got '{backend_name}'")
    self._add_additional_input_buffers_for_enc_dec_model(
        attn_metadata=attn_metadata, input_buffers=input_buffers)
    return input_buffers
|
||||
|
||||
def prepare_graph_input_buffers(
        self,
        input_buffers,
        attn_metadata,
        is_encoder_decoder_model: bool = False) -> None:
    """Copy the current decode metadata into the captured input buffers.

    Args:
        input_buffers: Mapping of buffer name to tensor; updated in place.
        attn_metadata: Metadata whose ``decode_metadata`` supplies the data.
        is_encoder_decoder_model: When True, the backend name is asserted
            and the encoder/cross-attention buffers are refreshed too.
    """
    decode_meta = attn_metadata.decode_metadata
    # NOTE(review): "seq_lens_tensor_cpu" is exposed by
    # get_graph_input_buffers but is not refreshed here - confirm whether
    # that is intentional.
    for buffer_name in ("seq_lens_tensor", "block_tables"):
        input_buffers[buffer_name].copy_(
            getattr(decode_meta, buffer_name), non_blocking=True)

    if not is_encoder_decoder_model:
        return

    # Encoder-decoder replay is only accepted for these backend names.
    backend_name = self.runner.attn_backend.get_name()
    assert backend_name in ("XFORMERS", "FLASH_ATTN"), (
        "Expected attn_backend name to be either 'XFORMERS' or "
        "'FLASH_ATTN', but "
        f"got '{backend_name}'")
    self._prepare_input_buffers_for_enc_dec_model(
        attn_metadata, input_buffers)
|
||||
|
||||
def begin_forward(self, model_input) -> None:
    """No-op hook: nothing is prepared before a forward pass."""
    return None
|
||||
|
||||
def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
|
||||
attn_metadata):
|
||||
"""
|
||||
Updates the attention metadata parameters for CUDA graph capture in an
|
||||
encoder-decoder model.
|
||||
|
||||
This method modifies attention-related tensors and metadata required
|
||||
for CUDA graph capture in encoder-decoder models. Specifically, it
|
||||
updates the cross-attention and encoder sequence tensors in the
|
||||
AttentionMetadata object.
|
||||
"""
|
||||
# During decode phase the cross_slot_mapping will be empty. Hence set
|
||||
# an empty tensor for CUDA Graph capture.
|
||||
attn_metadata.cross_slot_mapping = torch.tensor(
|
||||
[], dtype=torch.int).cuda()
|
||||
attn_metadata.cross_block_tables = torch.full(
|
||||
(batch_size, self.runner.get_max_block_per_batch()),
|
||||
1,
|
||||
dtype=torch.int).cuda()
|
||||
attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
|
||||
1,
|
||||
dtype=torch.int).cuda()
|
||||
attn_metadata.encoder_seq_lens_tensor = torch.full(
|
||||
(batch_size, ), 1, dtype=torch.int).cuda()
|
||||
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
|
||||
attn_metadata.num_encoder_tokens = 0
|
||||
|
||||
def _add_additional_input_buffers_for_enc_dec_model(
|
||||
self, attn_metadata, input_buffers: Dict[str, Any]):
|
||||
"""
|
||||
Saves additional input buffers specific to the encoder-decoder model
|
||||
from the attention metadata.
|
||||
|
||||
This method extracts and stores encoder-decoder related input buffers
|
||||
from the `attn_metadata` into the `input_buffers` dictionary. The
|
||||
buffers include encoder sequence lengths, cross-slot mappings, and
|
||||
cross-block tables, which are essential for the encoder-decoder model
|
||||
during CUDA graph replay.
|
||||
"""
|
||||
input_buffers["encoder_seq_lens_tensor"] = (
|
||||
attn_metadata.decode_metadata.encoder_seq_lens_tensor)
|
||||
input_buffers["seq_lens_tensor_cpu"].copy_(
|
||||
attn_metadata.decode_metadata.seq_lens_tensor_cpu, non_blocking=True)
|
||||
input_buffers["cross_slot_mapping"] = (
|
||||
attn_metadata.decode_metadata.cross_slot_mapping)
|
||||
input_buffers["cross_block_tables"] = (
|
||||
attn_metadata.decode_metadata.cross_block_tables)
|
||||
|
||||
def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
|
||||
input_buffers: Dict[str,
|
||||
Any]):
|
||||
"""
|
||||
Populates input buffers with data from the encoder-decoder model's
|
||||
attention metadata.
|
||||
|
||||
This method fills the input buffers with encoder-decoder specific
|
||||
tensors. It copies data from the `attn_metadata` and keyword arguments
|
||||
(`kwargs`) into corresponding buffers in the `input_buffers` dictionary.
|
||||
The copied data includes attention-related metadata as well as input
|
||||
IDs and positional information for the encoder.
|
||||
"""
|
||||
input_buffers["encoder_seq_lens_tensor"].copy_(
|
||||
attn_metadata.decode_metadata.encoder_seq_lens_tensor,
|
||||
non_blocking=True)
|
||||
input_buffers["cross_slot_mapping"].copy_(
|
||||
attn_metadata.decode_metadata.cross_slot_mapping,
|
||||
non_blocking=True)
|
||||
input_buffers["cross_block_tables"].copy_(
|
||||
attn_metadata.decode_metadata.cross_block_tables,
|
||||
non_blocking=True)
|
||||
|
||||
|
||||
def is_all_encoder_attn_metadata_set(attn_metadata):
    '''
    Return True when all metadata required for encoder attention is set.

    Checks that ``encoder_seq_lens``, ``encoder_seq_lens_tensor`` and
    ``max_encoder_seq_len`` are all not None.
    '''
    return ((attn_metadata.encoder_seq_lens is not None)
            and (attn_metadata.encoder_seq_lens_tensor is not None)
            and (attn_metadata.max_encoder_seq_len is not None))


def is_all_cross_attn_metadata_set(attn_metadata):
    '''
    Return True when all metadata required for enc/dec cross-attention is set.

    Superset of the encoder attention requirements: additionally needs
    ``cross_slot_mapping`` and ``cross_block_tables`` to be not None.

    The encoder part is evaluated with the module-level helper rather than
    by reading an ``is_all_encoder_attn_metadata_set`` attribute from the
    metadata object, so this works for any metadata that carries the plain
    fields and cannot be fooled by an always-truthy bound method.
    '''
    return (is_all_encoder_attn_metadata_set(attn_metadata)
            and (attn_metadata.cross_slot_mapping is not None)
            and (attn_metadata.cross_block_tables is not None))
|
||||
|
||||
|
||||
def get_seq_len_block_table_args(
    attn_metadata,
    is_prompt: bool,
    attn_type: str,
) -> tuple:
    '''
    Select the sequence-length and block-table fields for an attention type.

    * Decoder self-attention: decoder sequence lengths, the prefill or
      decode max length (depending on ``is_prompt``) and decoder block tables.
    * Encoder/decoder cross-attention: encoder sequence lengths, max encoder
      length and the cross-attention block tables.
    * Encoder attention: encoder sequence lengths, max encoder length and
      no block tables.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)

    Raises:

    * AttributeError: if ``attn_type`` is none of the three handled types
    '''
    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: the max length depends on the phase.
        max_seq_len = (attn_metadata.max_prefill_seq_len
                       if is_prompt else attn_metadata.max_decode_seq_len)
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Cross-attention KVs follow the encoder sequence length and use
        # the dedicated "cross" block tables.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)

    if attn_type == AttentionType.ENCODER:
        # Encoder attention has no block tables.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
def get_num_prefill_decode_query_kv_tokens(
    attn_metadata,
    attn_type: str,
) -> Tuple[int, int, int]:
    """
    Compute prefill/decode token counts for the query and key/value sides.

    Args:
        attn_metadata (AttentionMetadata): Attention Metadata object.
        attn_type (AttentionType): The type of attention being used.
    Returns:
        Tuple[int, int, int]: A tuple containing three integers:
            - The number of prefill query tokens.
            - The number of prefill key/value tokens.
            - The number of decode query tokens.

    Raises:
        AssertionError: If the number of encoder tokens in `attn_metadata`
        is `None` when required for the calculations.
    """
    if attn_type == AttentionType.ENCODER:
        # Encoder attention only runs during prefill; the same input
        # serves as both query and key, and there are no decode tokens.
        assert attn_metadata.num_encoder_tokens is not None
        encoder_tokens = attn_metadata.num_encoder_tokens
        return (encoder_tokens, encoder_tokens, 0)

    if attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.num_encoder_tokens is not None
        # Queries come from the decoder; keys come from the encoder.
        return (attn_metadata.num_prefill_tokens,
                attn_metadata.num_encoder_tokens,
                attn_metadata.num_decode_tokens)

    # Every other attention type: query and key/value share the
    # prefill token count.
    prefill_tokens = attn_metadata.num_prefill_tokens
    return (prefill_tokens, prefill_tokens, attn_metadata.num_decode_tokens)
|
||||
274
vllm_kunlun/ops/attention/layer.py
Normal file
274
vllm_kunlun/ops/attention/layer.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""layer.py"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, List, Dict, Any
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.distributed.kv_transfer import (
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group,
|
||||
)
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
from vllm.attention import Attention as VllmAttention
|
||||
from vllm.attention.layer import MultiHeadAttention as VllmMultiHeadAttention
|
||||
from torch.library import custom_op, impl
|
||||
|
||||
from vllm.platforms import _Backend
|
||||
|
||||
|
||||
class Attention(VllmAttention):
    """Attention layer built on vLLM's ``Attention``.

    The constructor is a pure pass-through. ``forward`` dispatches either to
    ``self.impl.forward`` directly or to the ``unified_attention`` /
    ``unified_attention_with_output_kunlun`` entry points.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        **extra_impl_args,
    ) -> None:
        """Forward every argument unchanged to the base constructor.

        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.
        """
        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=logits_soft_cap,
            per_layer_sliding_window=per_layer_sliding_window,
            use_mla=use_mla,
            prefix=prefix,
            attn_type=attn_type,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            **extra_impl_args,
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """Run attention for this layer.

        Args:
            query: Query tensor.
            key: Key tensor (may be None).
            value: Value tensor (may be None).
            output_shape: Shape of the output buffer; defaults to
                ``query.shape``. Only used when ``self.use_output`` is set.

        Returns:
            The attention output. On the output-buffer path it is viewed as
            ``(-1, output_shape[-1])``.
        """
        if self.calculate_kv_scales:
            scale_metadata = get_forward_context().attn_metadata
            if scale_metadata.enable_kv_scales_calculation:
                self.calc_kv_scales(query, key, value)

        if not self.use_output:
            # The implementation returns the result itself.
            if not self.use_direct_call:
                return unified_attention(query, key, value, self.layer_name)
            forward_context = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            layer_kv_cache = self.kv_cache[forward_context.virtual_engine]
            return self.impl.forward(
                self, query, key, value, layer_kv_cache, attn_metadata
            )

        # Output-buffer path: allocate the buffer and let the implementation
        # write into it.
        if output_shape is None:
            output_shape = query.shape
        output = torch.zeros(
            output_shape, dtype=query.dtype, device=query.device)
        hidden_size = output_shape[-1]

        # The MLA backend gives these tensors different semantics and
        # processes them differently, so they are only reshaped otherwise.
        if not self.use_mla:
            # NOTE(woosuk): done outside the custom op to minimize the CPU
            # overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)

        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            layer_kv_cache = self.kv_cache[forward_context.virtual_engine]
            self.impl.forward(
                self,
                query,
                key,
                value,
                layer_kv_cache,
                attn_metadata,
                output=output,
            )
        else:
            torch.ops.vllm.unified_attention_with_output_kunlun(
                query, key, value, output, self.layer_name
            )
        return output.view(-1, hidden_size)
|
||||
|
||||
|
||||
#
|
||||
# Rewritten from the MultiHeadAttention class in vllm.attention.layer
|
||||
class MultiHeadAttention(VllmMultiHeadAttention):
    """Multi-head attention without KV cache, pinned to the flash-attn backend.

    The constructor defers to the vLLM base class and then forces
    ``self.attn_backend`` to ``_Backend.FLASH_ATTN``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """Initialise the base layer and select the flash-attn backend."""
        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
        )

        # kunlun only supports flash_attn
        self.attn_backend = _Backend.FLASH_ATTN

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size

        Returns a tensor of shape ``(batch_size, q_len, -1)``.

        Raises:
            NotImplementedError: If ``self.attn_backend`` is not one of the
                backends handled below. Previously this fell through with
                ``out`` unbound and surfaced as an ``UnboundLocalError``.
        """
        # TODO(Isotr0py): Use existing backend implementations and support FA3
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        # kunlun only supports flash_attn
        if self.attn_backend == _Backend.FLASH_ATTN:
            from flash_attn import flash_attn_func

            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(
                query, key, value, scale=self.scale
            )
        elif self.attn_backend == _Backend.TORCH_SDPA:
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        else:
            # Fail with a clear message instead of an unbound-local error.
            raise NotImplementedError(
                "Unsupported attention backend for MultiHeadAttention: "
                f"{self.attn_backend}")

        return out.reshape(bsz, q_len, -1)
|
||||
|
||||
|
||||
def wait_for_kv_layer_from_connector(layer_name: str):
    """Block until the KV connector has loaded the given layer.

    Does nothing when there is no KV transfer group, when the group is not
    a v1 group, or when the forward context carries no attention metadata.
    """
    v1_transfer_active = has_kv_transfer_group() and is_v1_kv_transfer_group()
    if not v1_transfer_active:
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)
|
||||
|
||||
|
||||
def maybe_save_kv_layer_to_connector(
    layer_name: str, kv_cache_layer: List[torch.Tensor]
):
    """Hand a layer's KV cache to the KV connector when one is active.

    Does nothing when there is no KV transfer group, when the group is not
    a v1 group, or when the forward context carries no attention metadata.
    Otherwise the per-layer metadata entry is passed along with the cache.
    """
    v1_transfer_active = has_kv_transfer_group() and is_v1_kv_transfer_group()
    if not v1_transfer_active:
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    layer_metadata = attn_metadata[layer_name]
    connector.save_kv_layer(layer_name, kv_cache_layer, layer_metadata)
|
||||
|
||||
|
||||
# ``output`` is written in place by the implementation, so it must be
# declared in ``mutates_args``; torch.library requires this list to be
# accurate, otherwise the op's behaviour is undefined.
@custom_op("vllm::unified_attention_with_output_kunlun",
           mutates_args=("output",))
def unified_attention_with_output_kunlun(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
) -> None:
    """Run a layer's attention implementation, writing into ``output``.

    The layer is looked up by name in the forward context's
    ``no_compile_layers``; its KV cache for the current virtual engine is
    passed to ``impl.forward`` together with the layer's attention metadata.
    KV-connector load/save hooks run before and after the call.

    Args:
        query: Query tensor.
        key: Key tensor.
        value: Value tensor.
        output: Pre-allocated buffer that receives the result.
        layer_name: Key identifying the attention layer.
        output_scale: Accepted for interface compatibility; not used here.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        # Per-layer metadata: pick this layer's entry.
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(self, query, key, value, kv_cache, attn_metadata, output=output)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
|
||||
|
||||
def _fake_unified_attention_with_output_kunlun(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# Attach the no-op fake implementation to the custom op.
_register_fake = unified_attention_with_output_kunlun.register_fake
_register_fake(_fake_unified_attention_with_output_kunlun)
|
||||
|
||||
|
||||
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run a layer's attention implementation and return its result.

    The layer is looked up by name in the forward context's
    ``no_compile_layers``; its KV cache for the current virtual engine is
    passed to ``impl.forward`` together with the layer's attention metadata.
    KV-connector load/save hooks run before and after the call.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    layer_metadata = ctx.attn_metadata
    if isinstance(layer_metadata, dict):
        # Per-layer metadata: pick this layer's entry.
        layer_metadata = layer_metadata[layer_name]
    layer = ctx.no_compile_layers[layer_name]
    layer_kv_cache = layer.kv_cache[ctx.virtual_engine]
    result = layer.impl.forward(
        layer, query, key, value, layer_kv_cache, layer_metadata)

    maybe_save_kv_layer_to_connector(layer_name, layer_kv_cache)
    return result
|
||||
0
vllm_kunlun/ops/fused_moe/__init__.py
Normal file
0
vllm_kunlun/ops/fused_moe/__init__.py
Normal file
310
vllm_kunlun/ops/fused_moe/layer.py
Normal file
310
vllm_kunlun/ops/fused_moe/layer.py
Normal file
@@ -0,0 +1,310 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Author: Dong Xinyu, Chen Zhennan, Bao Qian, Yuan Jizhong
|
||||
# Email: dongxinyu03@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""layer.py"""
|
||||
import torch
|
||||
from typing import Callable, Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.distributed import get_ep_group
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE as VllmFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase as VllmFusedMoEMethodBase
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
UnquantizedFusedMoEMethod as VllmUnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig, FusedMoEParallelConfig)
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_kunlun.ops.quantization.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
|
||||
|
||||
|
||||
class FusedMoEMethodBase(VllmFusedMoEMethodBase):
    """Subclass of vLLM's fused-MoE method base; adds no behaviour."""

    # MoE configuration; declared here but not assigned by this class.
    moe: FusedMoEConfig
|
||||
|
||||
@CustomOp.register("vllm_kunlun_unquantized_fused_moe")
class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
    """Unquantized fused-MoE method that dispatches to the Kunlun ops."""

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
        linear_weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the MoE layer by delegating to ``forward_kunlun``.

        Only the routing-related arguments and ``linear_weights`` are
        forwarded; the remaining parameters are accepted but not used here.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `UnquantizedFusedMoEMethod` yet.")

        return self.forward_kunlun(
            layer=layer,
            x=x,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            router_logits=router_logits,
            linear_weights=linear_weights,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
        )

    def forward_kunlun(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            use_grouped_topk: bool,
            top_k: int,
            router_logits: torch.Tensor,
            linear_weights: torch.Tensor,
            renormalize: bool,
            topk_group: Optional[int] = None,
            num_expert_group: Optional[int] = None,
            custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
        """Call the Kunlun fused-MoE kernel for this layer.

        Uses ``KunlunOps.fused_moe_ep`` (with ``self.moe.ep_rank``) when
        ``self.moe.use_ep`` is set and ``KunlunOps.fused_moe`` otherwise.
        ``custom_routing_function`` is accepted but not passed on.
        """
        from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops

        # Keyword arguments common to both kernels.
        shared_kwargs = dict(
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
        )

        if not self.moe.use_ep:
            return ops.fused_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                router_logits,
                linear_weights,
                top_k,
                **shared_kwargs,
            )

        return ops.fused_moe_ep(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            linear_weights,
            self.moe.ep_rank,
            top_k,
            **shared_kwargs,
        )
|
||||
|
||||
class FusedMoE(VllmFusedMoE):
|
||||
"""FusedMoE"""
|
||||
def __init__(self,
             num_experts: int,  # Global number of experts
             top_k: int,
             hidden_size: int,
             intermediate_size: int,
             params_dtype: Optional[torch.dtype] = None,
             reduce_results: bool = False,
             renormalize: bool = True,
             use_grouped_topk: bool = False,
             num_expert_group: Optional[int] = 0,
             topk_group: Optional[int] = 0,
             quant_config: Optional[QuantizationConfig] = None,
             tp_size: Optional[int] = None,
             ep_size: Optional[int] = None,
             dp_size: Optional[int] = None,
             prefix: str = "",
             custom_routing_function: Optional[Callable] = None,
             scoring_func: str = "softmax",
             e_score_correction_bias: Optional[torch.Tensor] = None,
             apply_router_weight_on_input: bool = False,
             activation: str = "silu",
             enable_eplb: bool = False,
             num_redundant_experts: int = 0,
             ):
    """Initialise the base layer, then rebuild config, method and weights.

    After the base constructor runs, this builds a ``FusedMoEConfig``,
    selects the quantization method (the unquantized method when
    ``quant_config`` is None) and calls ``create_weights`` on it.

    Raises:
        NotImplementedError: If EPLB is enabled and the selected method is
            not ``Fp8MoEMethod``.
    """
    super().__init__(
        num_experts=num_experts,  # Global number of experts
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        params_dtype=params_dtype,
        reduce_results=reduce_results,
        renormalize=renormalize,
        use_grouped_topk=use_grouped_topk,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        quant_config=quant_config,
        tp_size=tp_size,
        ep_size=ep_size,
        dp_size=dp_size,
        prefix=prefix,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        apply_router_weight_on_input=apply_router_weight_on_input,
        activation=activation,
        enable_eplb=enable_eplb,
        num_redundant_experts=num_redundant_experts,
    )

    vllm_config = get_current_vllm_config()
    if vllm_config.model_config is None:
        # TODO (bnell): This is a hack to get test_mixtral_moe to work
        # since model_config is not set in the pytest test.
        model_dtype = params_dtype
    else:
        model_dtype = vllm_config.model_config.dtype

    moe_cfg = FusedMoEConfig.make(
        num_experts=self.global_num_experts,
        experts_per_token=top_k,
        hidden_dim=hidden_size,
        num_local_experts=self.local_num_experts,
        moe_parallel_config=self.moe_parallel_config,
        in_dtype=model_dtype,
        max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
        quant_config=quant_config,
    )
    self.moe_config = moe_cfg
    self.quant_config = quant_config

    # Note: get_quant_method will look at the layer's local_num_experts
    # for heuristic purposes, so it must be initialized first.
    quant_method: Optional[QuantizeMethodBase]
    if quant_config is None:
        quant_method = UnquantizedFusedMoEMethod(moe_cfg)
    else:
        quant_method = quant_config.get_quant_method(self, prefix)

    assert quant_method is not None
    # The isinstance check against FusedMoEMethodBase is deliberately not
    # enforced here.
    self.quant_method = quant_method

    if self.enable_eplb:
        from vllm_kunlun.ops.quantization.fp8 import Fp8MoEMethod

        if not isinstance(quant_method, Fp8MoEMethod):
            # TODO: Add support for additional quantization methods.
            # Other methods would not differ in essence, but the current
            # quant API design causes duplicated work when extending to
            # them. See `Fp8MoEMethod` for a reference implementation.
            raise NotImplementedError("EPLB is only supported for FP8 "
                                      "quantization for now.")

    moe_quant_params = {
        "num_experts": self.local_num_experts,
        "hidden_size": hidden_size,
        "intermediate_size_per_partition":
        self.intermediate_size_per_partition,
        "params_dtype": params_dtype,
        "weight_loader": self.weight_loader,
    }
    # need full intermediate size pre-sharding for WNA16 act order
    full_size_methods = (
        "GPTQMarlinMoEMethod",
        "CompressedTensorsWNA16MarlinMoEMethod",
        "CompressedTensorsWNA16MoEMethod",
    )
    if self.quant_method.__class__.__name__ in full_size_methods:
        moe_quant_params["intermediate_size_full"] = intermediate_size

    self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor = None,
            linear_weights: torch.Tensor = None):
    """Run the MoE layer via ``forward_impl``.

    On TPU, ``forward_impl`` is called on this object without
    ``linear_weights``. Otherwise the layer registered under
    ``self.layer_name`` in the forward context is used and
    ``linear_weights`` is passed through.
    """
    # TODO: Once the OOM issue for the TPU backend is resolved, we will
    # switch to using the moe_forward custom op.
    if current_platform.is_tpu():
        return self.forward_impl(hidden_states, router_logits)

    forward_context: ForwardContext = get_forward_context()
    # Rebind to the layer instance stored in the forward context.
    self = forward_context.no_compile_layers[self.layer_name]
    assert self.quant_method is not None
    # The torch.ops.vllm.moe_forward custom op is intentionally not used
    # here; forward_impl is called directly.
    return self.forward_impl(hidden_states, router_logits, linear_weights)
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
                 linear_weights: torch.Tensor = None):
    """Compute the MoE output for ``hidden_states``.

    The pplx / DeepEP low-latency kernels are served by
    ``forward_impl_chunked`` (which does not receive ``linear_weights``).
    Otherwise the tokens are optionally dispatched across the EP group,
    run through ``quant_method.apply``, optionally combined again and
    finally all-reduced when ``reduce_results`` is set and TP or EP is
    larger than one.
    """
    assert self.quant_method is not None

    parallel_cfg = self.moe_parallel_config
    if parallel_cfg.use_pplx_kernels or parallel_cfg.use_deepep_ll_kernels:
        return self.forward_impl_chunked(hidden_states, router_logits)

    # Naive dispatch/combine is only needed with data parallelism and when
    # the DeepEP high-throughput kernels are not in use.
    naive_dispatch_combine: bool = (
        self.dp_size > 1 and not parallel_cfg.use_deepep_ht_kernels)

    if naive_dispatch_combine:
        hidden_states, router_logits = get_ep_group().dispatch(
            hidden_states, router_logits)

    # Expert computation (matrix multiplies) happens inside the quant method.
    result = self.quant_method.apply(
        layer=self,
        x=hidden_states,
        router_logits=router_logits,
        top_k=self.top_k,
        renormalize=self.renormalize,
        use_grouped_topk=self.use_grouped_topk,
        global_num_experts=self.global_num_experts,
        expert_map=self.expert_map,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        custom_routing_function=self.custom_routing_function,
        scoring_func=self.scoring_func,
        e_score_correction_bias=self.e_score_correction_bias,
        activation=self.activation,
        apply_router_weight_on_input=self.apply_router_weight_on_input,
        enable_eplb=self.enable_eplb,
        expert_load_view=self.expert_load_view,
        logical_to_physical_map=self.logical_to_physical_map,
        logical_replica_count=self.logical_replica_count,
        linear_weights=linear_weights,
    )

    if naive_dispatch_combine:
        result = get_ep_group().combine(result)

    needs_reduce = self.reduce_results and (self.tp_size > 1
                                            or self.ep_size > 1)
    if needs_reduce:
        # reduce_results defaults to False (shared expert outputs may have
        # to be added before reducing).
        result = self.maybe_all_reduce_tensor_model_parallel(result)

    return result
|
||||
60
vllm_kunlun/ops/layernorm.py
Normal file
60
vllm_kunlun/ops/layernorm.py
Normal file
@@ -0,0 +1,60 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from typing import Optional, Union
|
||||
import xtorch_ops
|
||||
|
||||
|
||||
def vllm_kunlun_forward_cuda(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """RMSNorm forward pass backed by the Kunlun ``torch.ops._C`` kernels.

    Args:
        x: Input activations. A non-contiguous tensor is copied into a
            contiguous one first, so in that case the caller's tensor is
            not the one written to.
        residual: Optional residual. When given, the fused add+norm kernel
            is used and both ``x`` and ``residual`` are overwritten in place.

    Returns:
        The normalised tensor when ``residual`` is None, otherwise the
        tuple ``(x, residual)`` holding the kernel outputs. When
        ``variance_size_override`` is set the result of
        ``forward_native`` is returned instead.
    """
    if not x.is_contiguous():
        # The Kunlun kernels do not accept non-contiguous input, so make a
        # contiguous copy manually.
        x = x.contiguous()

    if self.variance_size_override is not None:
        return self.forward_native(x, residual)

    if residual is None:
        out = torch.empty_like(x)
        torch.ops._C.rmsnorm(
            x,
            self.weight.data,
            out,
            self.variance_epsilon,
        )
        return out

    # Fused residual-add + RMSNorm; outputs alias the inputs (in place).
    torch.ops._C.add_rmsnorm(
        x,
        residual,
        residual_output=residual,
        weight=self.weight.data,
        eps=self.variance_epsilon,
        output=x,
    )
    return x, residual
|
||||
|
||||
|
||||
# Route both the CUDA-specific entry point and the generic forward of vLLM's
# RMSNorm through the Kunlun implementation.
RMSNorm.forward_cuda = RMSNorm.forward = vllm_kunlun_forward_cuda
|
||||
24
vllm_kunlun/ops/linear.py
Normal file
24
vllm_kunlun/ops/linear.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear as VllmReplicatedLinear
|
||||
|
||||
class ReplicatedLinear(VllmReplicatedLinear):
    """Replicated linear layer that caches dtype-converted weight copies."""

    def get_weights(self):
        """Return the layer weight as a float32 parameter.

        The converted copy is created on first use and registered on the
        module as ``kunlun_linear_weights`` so later calls reuse it.
        """
        if not hasattr(self, 'kunlun_linear_weights'):
            weights = torch.nn.Parameter(self.weight.to(torch.float32))
            self.register_parameter("kunlun_linear_weights", weights)
        return self.kunlun_linear_weights

    def get_weights_half(self):
        """Return the layer weight as a float16 parameter.

        The converted copy is created on first use and registered on the
        module as ``kunlun_linear_weights_half`` so later calls reuse it.
        """
        if not hasattr(self, 'kunlun_linear_weights_half'):
            weights = torch.nn.Parameter(self.weight.to(torch.float16))
            # Register and return the copy, mirroring get_weights(); without
            # this the first call would return None and nothing is cached.
            self.register_parameter("kunlun_linear_weights_half", weights)
        return self.kunlun_linear_weights_half
|
||||
305
vllm_kunlun/ops/paged_attn.py
Normal file
305
vllm_kunlun/ops/paged_attn.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
if current_platform.is_kunlun():
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
else:
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.triton_utils.importing import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
|
||||
@dataclass
class PagedAttentionMetadata:
    """Per-batch metadata consumed by PagedAttention."""

    # Shape (batch_size,): for every sequence, the total number of tokens
    # seen so far.
    seq_lens_tensor: Optional[torch.Tensor]
    # Longest sequence length in the batch; 0 for a prefill-only batch.
    max_decode_seq_len: int
    # Shape (batch_size, max_blocks_per_seq): maps each sequence to its list
    # of physical block addresses. For example [0, 1, 2] means the tokens
    # live in blocks 0, 1 and 2 of the KV cache, each block holding up to
    # block_size tokens. When captured in a CUDA graph the second dimension
    # is padded up to max_blocks_per_seq.
    block_tables: Optional[torch.Tensor]
|
||||
|
||||
|
||||
class PagedAttention:
    """Static helpers for the paged KV cache: layout, writes and attention."""

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes advertised as supported."""
        return [32, 64, 80, 96, 112, 120, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of the combined key/value cache tensor.

        Args:
            num_blocks: Number of cache blocks.
            block_size: Number of tokens per block.
            num_kv_heads: Number of KV heads.
            head_size: Size of each head.

        Returns:
            On Kunlun: ``(2, num_blocks, num_kv_heads, block_size, head_size)``.
            Otherwise: ``(2, num_blocks, block_size * num_kv_heads * head_size)``.
            The leading 2 indexes the key cache (0) and value cache (1).
        """
        if current_platform.is_kunlun():
            return (2, num_blocks, num_kv_heads, block_size, head_size)
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split the combined cache into ``(key_cache, value_cache)``.

        On Kunlun the two halves ``kv_cache[0]`` and ``kv_cache[1]`` are
        returned without reshaping. On other platforms the key half is
        viewed as ``(num_blocks, num_kv_heads, head_size // x, -1, x)`` with
        ``x = 16 // element_size`` and the value half as
        ``(num_blocks, num_kv_heads, head_size, -1)``.

        Args:
            kv_cache: Combined cache whose first dimension has size 2.
            num_kv_heads: Number of KV heads.
            head_size: Size of each head.

        Returns:
            The key cache and the value cache.
        """
        if current_platform.is_kunlun():
            return kv_cache[0], kv_cache[1]

        # x is the number of elements that fit in 16 bytes.
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]
        key_cache = kv_cache[0].view(num_blocks, num_kv_heads,
                                     head_size // x, -1, x)
        value_cache = kv_cache[1].view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        """Write ``key``/``value`` into the caches via ``ops.reshape_and_cache``.

        ``slot_mapping`` is flattened before being handed to the op.
        """
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        context_lens_cpu: torch.Tensor,
        is_context,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        """Run decode-time paged attention and return the output tensor.

        Chooses between ``ops.paged_attention_v1`` and
        ``ops.paged_attention_v2`` with a simple heuristic and passes
        ``context_lens_cpu`` / ``is_context`` straight through to the op.

        Returns:
            A tensor shaped like ``query`` holding the attention output.
        """
        # NOTE(review): with the Kunlun layout produced by
        # get_kv_cache_shape/split_kv_cache the value cache is
        # (num_blocks, num_kv_heads, block_size, head_size), so axis 3 is
        # head_size rather than block_size - confirm what KunlunOps expects.
        block_size = value_cache.shape[3]

        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            assert (
                blocksparse_block_size > 0
                and blocksparse_block_size % block_size == 0
            ), (
                f"{blocksparse_block_size=} needs to be a multiple of "
                f"{block_size=} used in block_tables."
            )

        output = torch.empty_like(query)
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
        use_v1 = max_seq_len <= 8192 and (
            max_num_partitions == 1 or num_seqs * num_heads > 512
        )

        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                context_lens_cpu,
                is_context,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                context_lens_cpu,
                is_context,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache_dtype: str,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        query_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        max_query_len: int,
        alibi_slopes: Optional[torch.Tensor],
        sliding_window: Optional[int],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> torch.Tensor:
        """Run prefix (context) attention through ``context_attention_fwd``.

        ``max_seq_len`` is always passed to the kernel as ``None``.

        Returns:
            A tensor shaped like ``query`` holding the attention output.
        """
        output = torch.empty_like(query)
        max_seq_len = None
        context_attention_fwd(
            query,
            key,
            value,
            output,
            kv_cache_dtype,
            key_cache,
            value_cache,
            block_tables,
            # query_start_loc is (batch_size + 1,)
            query_start_loc,
            seq_lens_tensor,
            max_seq_len,
            max_query_len,
            k_scale,
            v_scale,
            alibi_slopes,
            sliding_window,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks between two caches, key half first, then value half."""
        ops.swap_blocks(src_kv_cache[0], dst_kv_cache[0], src_to_dst)
        ops.swap_blocks(src_kv_cache[1], dst_kv_cache[1], src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks inside every cache in ``kv_caches`` via ``ops.copy_blocks``."""
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
0
vllm_kunlun/ops/quantization/__init__.py
Normal file
0
vllm_kunlun/ops/quantization/__init__.py
Normal file
128
vllm_kunlun/ops/quantization/awq.py
Normal file
128
vllm_kunlun/ops/quantization/awq.py
Normal file
@@ -0,0 +1,128 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Li Wei, Pan Xiakai, You Zeyu
|
||||
# Email: liwei157@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from typing import Optional
|
||||
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
|
||||
|
||||
|
||||
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
    """Reorder the int4 nibbles of AWQ-packed weights for the Kunlun XPU.

    Input: packed[N, K], dtype=int32, nibbles stored in AWQ order.
    Output: tensor[N, K], dtype=int32, nibbles stored in Kunlun order.

    As a side effect ``self.align_type`` is set to 1 (fast mode) when K is
    a multiple of 8 and to 0 (normal mode) otherwise.
    """
    rows, cols = packed.shape
    self.align_type = 1 if cols % 8 == 0 else 0
    assert num_bits == 4, "Only int4 supported now"
    # Bit offsets of the eight 4-bit fields inside one int32 word.
    nibble_shifts = torch.arange(
        0, 32, num_bits, device=packed.device, dtype=torch.int32
    )

    if self.align_type == 0:  # NORMAL MODE
        # Permutation applied to the eight nibbles of every word. With the
        # AWQ order [0, 2, 4, 6, 1, 3, 5, 7] this yields
        # [1, 0, 3, 2, 5, 4, 7, 6].
        word_order = [4, 0, 5, 1, 6, 2, 7, 3]
        nibbles = (packed.unsqueeze(-1) >> nibble_shifts) & 0xF  # [N, K, 8]
        reordered = nibbles[..., word_order]  # [N, K, 8]
        # Re-pack the eight nibbles into one int32 per word -> [N, K]
        return (reordered << nibble_shifts).sum(dim=-1, dtype=torch.int32)

    # FAST MODE: nibbles are shuffled across groups of eight words (64 nibbles).
    groups = cols // 8
    group_order = [
        32, 0, 36, 4, 33, 1, 37, 5,
        34, 2, 38, 6, 35, 3, 39, 7,
        40, 8, 44, 12, 41, 9, 45, 13,
        42, 10, 46, 14, 43, 11, 47, 15,
        48, 16, 52, 20, 49, 17, 53, 21,
        50, 18, 54, 22, 51, 19, 55, 23,
        56, 24, 60, 28, 57, 25, 61, 29,
        58, 26, 62, 30, 59, 27, 63, 31,
    ]
    grouped = packed.view(rows, groups, 8).unsqueeze(-1)
    nibbles = (grouped >> nibble_shifts) & 0xF  # [N, K//8, 8, 8]
    flat = nibbles.reshape(rows, groups, 64)
    reordered = flat[..., group_order].reshape(rows, groups, 8, 8)
    repacked = (reordered << nibble_shifts).sum(dim=-1, dtype=torch.int32)
    return repacked.reshape(rows, cols)  # [N, K]
|
||||
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Re-wrap the AWQ tensors of ``layer`` as frozen parameters.

    ``qweight`` and ``qzeros`` are repacked with
    ``repack_int4_for_kunlun`` when their dtype is int32 and kept as-is
    otherwise; ``scales`` is re-wrapped unchanged.
    """

    def _to_kunlun(tensor: torch.Tensor) -> torch.Tensor:
        # Only int32-packed tensors carry AWQ-ordered int4 nibbles.
        if tensor.dtype == torch.int32:
            return self.repack_int4_for_kunlun(tensor)
        return tensor

    layer.qweight = torch.nn.Parameter(
        _to_kunlun(layer.qweight.data), requires_grad=False
    )
    layer.qzeros = torch.nn.Parameter(
        _to_kunlun(layer.qzeros.data), requires_grad=False
    )
    layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
|
||||
def apply(
    self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Apply the AWQ-quantized linear layer to ``x``.

    With 256 or more tokens (product of all leading dimensions of ``x``)
    the weight is dequantized once and multiplied with ``torch.matmul``;
    below that the fused ``awq_gemm`` kernel is used. ``bias`` is added in
    place when given, and the result is reshaped to the leading dimensions
    of ``x`` plus ``qweight.shape[-1] * pack_factor`` output features.
    """
    qweight = layer.qweight
    scales = layer.scales
    qzeros = layer.qzeros

    out_features = qweight.shape[-1] * self.quant_config.pack_factor
    out_shape = x.shape[:-1] + (out_features,)
    x_2d = x.reshape(-1, x.shape[-1])

    # Heuristic: large batches are cheaper as dequantize + dense matmul.
    num_tokens = x.shape[:-1].numel()
    if num_tokens >= 256:
        dense_weight = torch.ops._C.awq_dequantize(
            qweight, scales, qzeros, quant_type=0, align_type=self.align_type
        )
        out = torch.matmul(x_2d, dense_weight)
    else:
        out = torch.ops._C.awq_gemm(
            x_2d, qweight, scales, qzeros, align_type=self.align_type
        )

    if bias is not None:
        out.add_(bias)
    return out.reshape(out_shape)
|
||||
|
||||
|
||||
# Install the Kunlun implementations on vLLM's AWQLinearMethod.
AWQLinearMethod.apply = apply
AWQLinearMethod.process_weights_after_loading = process_weights_after_loading
AWQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun
|
||||
333
vllm_kunlun/ops/quantization/compressed_tensors_moe.py
Normal file
333
vllm_kunlun/ops/quantization/compressed_tensors_moe.py
Normal file
@@ -0,0 +1,333 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# Author: Chen Zhennan, Dong Xinyu
|
||||
# Email: chenzhennan@baidu.com
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from typing import Any, Literal, Optional, cast, Callable, Optional
|
||||
|
||||
from compressed_tensors.config import (
|
||||
CompressionFormat,
|
||||
SparsityCompressionConfig,
|
||||
SparsityStructure,
|
||||
)
|
||||
from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
|
||||
# TODO: import position will be changed after 0.9.0
|
||||
# vllm.model_executor.layers.fused_moe.fused_moe --> vllm.model_executor.layers.fused_moe
|
||||
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
import re
|
||||
import xtorch_ops
|
||||
|
||||
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
|
||||
|
||||
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
    """Factory base class for compressed-tensors MoE quantization methods."""

    @staticmethod
    def get_moe_method(quant_config, layer) -> "CompressedTensorsMoEMethod":
        """Pick the MoE method for ``quant_config``.

        The scheme is looked up in ``quant_config.target_scheme_map`` under
        the first of the keys "Linear", "FusedMoE", "MoE", "Moe", "Experts"
        that maps to a dict. Every case currently resolves to the INT8
        (W8A8) method; ``layer`` is not used.

        Declared as a staticmethod because the function takes no
        ``self``/``cls``; calls through the class keep working unchanged.
        """
        scheme_map = getattr(quant_config, "target_scheme_map", None) or {}
        scheme = None
        for key in ("Linear", "FusedMoE", "MoE", "Moe", "Experts"):
            if key in scheme_map and isinstance(scheme_map[key], dict):
                scheme = scheme_map[key]
                break

        # Missing scheme map or incomplete scheme: fall back to INT8 (W8A8).
        if (
            not scheme
            or not scheme.get("weights")
            or not scheme.get("input_activations")
        ):
            return CompressedTensorsW8A8Int8MoEMethod(quant_config)

        # Other branches are handled as needed; default fallback:
        return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
|
||||
|
||||
# copied from vllm 0.9.0
|
||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Directly create a default quantization config dictionary to avoid validation issues with QuantizationArgs
|
||||
# print("Creating default INT8 quantization config for MoE")
|
||||
|
||||
# Create a default weight quantization config dictionary
|
||||
self.weight_quant = type(
|
||||
"WeightQuant",
|
||||
(),
|
||||
{
|
||||
"type": "int",
|
||||
"num_bits": 8,
|
||||
"strategy": "channel",
|
||||
"group_size": 128,
|
||||
"symmetric": True,
|
||||
"dynamic": False,
|
||||
"actorder": "none",
|
||||
"observer": None,
|
||||
"observer_kwargs": {},
|
||||
"block_structure": None,
|
||||
},
|
||||
)()
|
||||
|
||||
# Create a default input activation quantization config dictionary
|
||||
self.input_quant = type(
|
||||
"InputQuant",
|
||||
(),
|
||||
{
|
||||
"type": "int",
|
||||
"num_bits": 8,
|
||||
"strategy": "token",
|
||||
"group_size": 128,
|
||||
"symmetric": True,
|
||||
"dynamic": True,
|
||||
"actorder": "none",
|
||||
"observer": None,
|
||||
"observer_kwargs": {},
|
||||
"block_structure": None,
|
||||
},
|
||||
)()
|
||||
|
||||
# Change comparison method to directly compare strings
|
||||
per_channel = (
|
||||
self.weight_quant.strategy == "channel"
|
||||
and self.input_quant.strategy == "token"
|
||||
)
|
||||
if not per_channel:
|
||||
raise ValueError(
|
||||
"For INT8 Fused MoE layers, we require channelwise, "
|
||||
"dynamic per token quantization. Found "
|
||||
f"{self.weight_quant}, {self.input_quant}"
|
||||
)
|
||||
|
||||
self.static_input_scales = not self.input_quant.dynamic
|
||||
if self.static_input_scales:
|
||||
raise ValueError(
|
||||
"For INT8 Fused MoE layers, we require channelwise, "
|
||||
"dynamic per token quantization. Found static input scales."
|
||||
)
|
||||
|
||||
def create_weights1(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Use float32 as a placeholder for weights to facilitate loading original weights from ckpt
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
), # generally is torch.bfloat16
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# Channel scale: float32 + 2D [E, out] (aligned with fused_moe/UT)
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# Input scale can be dynamically calculated
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=torch.int8,
|
||||
), # directly use int8
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8,
|
||||
), # directly use int8
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# Scale factors
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# Input scale can be dynamically calculated
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
@torch.no_grad()
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
return
|
||||
# Convert original weights to float32 for more robust statistics
|
||||
w13_f = layer.w13_weight.float()
|
||||
w2_f = layer.w2_weight.float()
|
||||
|
||||
# Each column (abs_max) -> per-column scale (out dimension is dim=1, column is dim=-1)
|
||||
qmax = 127.0
|
||||
w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1) # [E, 2N]
|
||||
w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
|
||||
|
||||
w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax # [E, 2N], float32
|
||||
w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
|
||||
|
||||
# Quantization: broadcast 3D scale and store back to 2D scale
|
||||
w13_scale_3d = w13_scale_2d.unsqueeze(-1) # [E, 2N, 1]
|
||||
w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
|
||||
|
||||
w13_q = torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8)
|
||||
w2_q = torch.round(w2_f / w2_scale_3d).clamp_(-128, 127).to(torch.int8)
|
||||
|
||||
# Optional: If your fused/kernel expects scale pre-multiplied by 127 (to be consistent with some UT backends), uncomment the following two lines:
|
||||
w13_scale_2d = w13_scale_2d * 127.0
|
||||
w2_scale_2d = w2_scale_2d * 127.0
|
||||
|
||||
# Write back parameters: weight int8; scale uses float32 + 2D
|
||||
replace_parameter(
|
||||
layer, "w13_weight", torch.nn.Parameter(w13_q, requires_grad=False)
|
||||
)
|
||||
replace_parameter(
|
||||
layer, "w2_weight", torch.nn.Parameter(w2_q, requires_grad=False)
|
||||
)
|
||||
replace_parameter(
|
||||
layer,
|
||||
"w13_weight_scale",
|
||||
torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False),
|
||||
)
|
||||
replace_parameter(
|
||||
layer,
|
||||
"w2_weight_scale",
|
||||
torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False),
|
||||
)
|
||||
|
||||
# Brief check
|
||||
print(
|
||||
f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}"
|
||||
)
|
||||
|
||||
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
    linear_weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the fused MoE FFN through ``torch.ops._C.moe_ffn_per_token_block``.

    A result buffer shaped like ``x`` is allocated, handed to the custom op
    as ``output`` and returned.

    Only ``layer``, ``x``, ``top_k``, ``global_num_experts``, ``expert_map``,
    ``activation`` and ``linear_weights`` are consumed. Every other parameter
    (``router_logits``, ``renormalize``, the grouped-top-k options, the EPLB
    options, ...) is accepted for signature compatibility and ignored.

    Args:
        layer: Module providing ``w13_weight``, ``w13_weight_scale``,
            ``w2_weight`` and ``w2_weight_scale``.
        x: Input activations.
        top_k: Forwarded to the op as ``top_k``.
        global_num_experts: Forwarded to the op unchanged.
        expert_map: When not ``None``, expert parallelism is flagged and
            ``expert_map.size(0)`` is forwarded as ``ep_size``.
        activation: Forwarded to the op unchanged.
        linear_weights: Forwarded to the op unchanged.

    Returns:
        The buffer passed to the op as ``output``.
    """
    result = torch.empty_like(x)

    expert_parallel = expert_map is not None
    if expert_parallel:
        ep_size = expert_map.size(0)
    else:
        ep_size = 1

    # NOTE(review): ep_rank is hard-coded to 0 and ep_size is the length of
    # expert_map -- confirm both against the kernel's expectations.
    torch.ops._C.moe_ffn_per_token_block(
        x=x,
        inter_weight=layer.w13_weight,
        inter_scale=layer.w13_weight_scale,
        outer_weight=layer.w2_weight,
        outer_scale=layer.w2_weight_scale,
        top_k=top_k,
        global_num_experts=global_num_experts,
        linear_weights=linear_weights,
        expert_map=expert_map,
        activation=activation,
        output=result,
        use_expert_parallel=expert_parallel,
        ep_size=ep_size,
        ep_rank=0,
    )
    return result
|
||||
|
||||
|
||||
# Import-time notice announcing the CompressedTensorsMoEMethod monkey patch
# (the patch assignment itself is not part of this statement).
# The message is built from adjacent string literals rather than a
# backslash continuation inside one literal, so the emitted text no longer
# picks up whatever indentation the source file happens to use.
print(
    "[Monkey Patch Applied] >>> "
    "vllm.model_executor.layers.quantization.compressed_tensors."
    "compressed_tensors_moe.CompressedTensorsMoEMethod "
    "--> vllm_xpu.model_executor.layers.quantization."
    "compressed_tensors_moe.py:CompressedTensorsMoEMethod"
)
|
||||
108
vllm_kunlun/ops/quantization/gptq.py
Normal file
108
vllm_kunlun/ops/quantization/gptq.py
Normal file
@@ -0,0 +1,108 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Li Wei, You Zeyu
|
||||
# Email: liwei157@baidu.com, youzeyu@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import Optional
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod, ExllamaState
|
||||
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Finalize the GPTQ tensors of ``layer`` after checkpoint loading.

    Steps, in order:

    1. ``qzeros`` is repacked with ``repack_int4_for_kunlun`` when the
       configured weight width is 4 bits; otherwise it is kept as loaded.
    2. ``qzeros``, ``qweight``, ``g_idx`` and ``scales`` are re-wrapped as
       non-trainable ``Parameter`` objects (for torch.compile).
    3. If the exllama state is still ``UNINITIALIZED``, ``g_idx`` becomes
       the int-cast ``argsort`` of itself when ``desc_act`` is set, or an
       empty int tensor otherwise, and the state is moved to ``READY``.

    Args:
        layer: Module holding ``qzeros``, ``qweight``, ``g_idx``, ``scales``
            and ``exllama_state``; modified in place.
    """
    weight_bits = self.quant_config.weight_bits

    if weight_bits == 4:
        zeros = self.repack_int4_for_kunlun(layer.qzeros.data, weight_bits)
    else:
        zeros = layer.qzeros.data

    # Re-wrap as plain Parameters (for torch.compile).
    layer.qzeros = Parameter(zeros, requires_grad=False)
    for attr in ("qweight", "g_idx", "scales"):
        loaded = getattr(layer, attr).data
        setattr(layer, attr, Parameter(loaded, requires_grad=False))

    if layer.exllama_state != ExllamaState.UNINITIALIZED:
        return

    if self.quant_config.desc_act:
        layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
    else:
        layer.g_idx.data = torch.empty(
            (0,), dtype=torch.int, device=layer.g_idx.device
        )
    layer.exllama_state = ExllamaState.READY
    # No qweight shuffle is performed here (original note: not needed on XPU).
|
||||
|
||||
|
||||
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
    """Reorder the 4-bit values packed in an int32 tensor into Kunlun order.

    The last dimension is processed in groups of 8 int32 words, i.e. 64
    nibbles per group, numbered ``word * 8 + nibble`` with nibble 0 in the
    least-significant bits. Within each group the output sequence
    interleaves the two halves of the input sequence:

        out[2 * k]     = in[32 + k]
        out[2 * k + 1] = in[k]          for k in 0..31

    Args:
        packed: 2-D tensor ``[N, K]``; ``K`` has to be viewable as
            ``[K // 8, 8]``.
        num_bits: Bit width of the packed values; only 4 is accepted.

    Returns:
        An int32 tensor of shape ``[N, K]`` holding the reordered nibbles.
    """
    rows, cols = packed.shape
    assert num_bits == 4, "Only int4 supported now"
    groups = cols // 8
    bit_offsets = torch.arange(
        0, 32, num_bits, device=packed.device, dtype=torch.int32
    )

    # Split every int32 word into its 8 nibbles: [rows, groups, 8 words, 8].
    words = packed.view(rows, groups, 8)
    nibbles = (words.unsqueeze(-1) >> bit_offsets) & 0xF

    # Flatten each group to 64 nibbles and interleave upper/lower halves;
    # stacking on a new last axis yields (in[32+k], in[k]) pairs in order.
    flat = nibbles.reshape(rows, groups, 64)
    upper_half = flat[..., 32:]
    lower_half = flat[..., :32]
    interleaved = torch.stack((upper_half, lower_half), dim=-1)

    # Pack 8 consecutive nibbles back into each int32 word. The shifted
    # nibbles occupy disjoint bits, so the sum acts as a bitwise OR.
    regrouped = interleaved.reshape(rows, groups, 8, 8)
    repacked = (regrouped << bit_offsets).sum(dim=-1, dtype=torch.int32)
    return repacked.reshape(rows, cols)
|
||||
|
||||
|
||||
def apply(
    self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Compute the GPTQ linear layer via ``xspeedgate_ops.gptq_gemm``.

    ``x`` is flattened to two dimensions, passed to the custom GEMM with the
    layer's quantized tensors, optionally offset in place by ``bias``, and
    reshaped so that the leading dimensions of ``x`` are restored and the
    last dimension equals ``layer.qweight.shape[-1]``.

    Args:
        layer: Module providing ``qweight``, ``qzeros``, ``scales``,
            ``g_idx`` and ``exllama_state``.
        x: Input activations; the last dimension is the feature dimension.
        bias: Optional bias added in place to the GEMM result.

    Returns:
        The GEMM result reshaped to ``x.shape[:-1] + (qweight.shape[-1],)``.
    """
    leading_dims = x.shape[:-1]
    target_shape = leading_dims + (layer.qweight.shape[-1],)
    flat_input = x.reshape(-1, x.shape[-1])
    state_is_ready = layer.exllama_state == ExllamaState.READY

    result = torch.ops.xspeedgate_ops.gptq_gemm(
        flat_input,
        layer.qweight,
        layer.qzeros,
        layer.scales,
        layer.g_idx,
        state_is_ready,
        self.quant_config.weight_bits,
    )
    if bias is not None:
        result.add_(bias)
    return result.reshape(target_shape)
|
||||
|
||||
|
||||
# Install the Kunlun implementations on vLLM's GPTQLinearMethod at import time.
# The repack helper is attached as a method because the patched
# process_weights_after_loading calls it through ``self``.
GPTQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun

# Replace the post-load hook and the forward computation.
GPTQLinearMethod.process_weights_after_loading = process_weights_after_loading
GPTQLinearMethod.apply = apply
|
||||
180
vllm_kunlun/ops/rotary_embedding.py
Normal file
180
vllm_kunlun/ops/rotary_embedding.py
Normal file
@@ -0,0 +1,180 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
|
||||
import torch
|
||||
import xspeedgate_ops
|
||||
import os
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
RotaryEmbedding,
|
||||
YaRNScalingRotaryEmbedding,
|
||||
DynamicNTKScalingRotaryEmbedding,
|
||||
MRotaryEmbedding,
|
||||
)
|
||||
from typing import Optional, Tuple
|
||||
import xtorch_ops
|
||||
|
||||
|
||||
def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the rotary cos/sin lookup table.

    Side effect: when the instance has a ``scaling_factor`` attribute,
    ``self.max_position_embeddings`` is overwritten with
    ``int(max_position_embeddings * scaling_factor)`` before the table is
    built.

    With ``P = self.max_position_embeddings`` and ``F = len(inv_freq)``:

    * ``FUSED_QK_ROPE_OP == "1"``: cos and sin are each duplicated along the
      last axis and stacked, giving shape ``[2, 1, P, 2 * F]``.
    * otherwise: cos and sin are concatenated along the last axis, giving
      shape ``[1, 1, P, 2 * F]``.

    Returns:
        The float cache tensor described above.
    """
    inv_freq = self._compute_inv_freq(self.base)

    if hasattr(self, "scaling_factor"):
        scaled_len = self.max_position_embeddings * self.scaling_factor
        self.max_position_embeddings = int(scaled_len)

    positions = torch.arange(self.max_position_embeddings, dtype=torch.float)
    # Outer product position x inverse frequency -> angles of shape [P, F].
    angles = torch.outer(positions, inv_freq)
    cos_table, sin_table = angles.cos(), angles.sin()

    if os.getenv("FUSED_QK_ROPE_OP") != "1":
        return torch.cat((cos_table, sin_table), dim=-1).unsqueeze(0).unsqueeze(1)

    doubled_cos = torch.cat((cos_table, cos_table), dim=-1)
    doubled_sin = torch.cat((sin_table, sin_table), dim=-1)
    return torch.stack((doubled_cos, doubled_sin), dim=0).unsqueeze(1)
|
||||
|
||||
|
||||
def vllm_kunlun_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply rotary embedding to ``query``/``key`` through ``KunlunOps``.

    ``self.cos_sin_cache`` is first moved to the device and dtype of
    ``query`` if either differs (the converted tensor replaces the
    attribute).

    * Without ``offsets``, ``KunlunOps.rotary_embedding`` is called and its
      returned ``(query, key)`` pair is what this function returns.
    * With ``offsets``, ``KunlunOps.batched_rotary_embedding`` is called for
      its in-place effect; the incoming ``query`` / ``key`` objects are
      returned.

    Args:
        positions: Token positions forwarded to the op.
        query: Query tensor.
        key: Key tensor, forwarded as given.
        offsets: Optional per-token offsets selecting the batched op.

    Returns:
        The ``(query, key)`` pair after rotary embedding.
    """
    from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops

    cache = self.cos_sin_cache
    if cache.device != query.device or cache.dtype != query.dtype:
        self.cos_sin_cache = cache.to(query.device, dtype=query.dtype)

    if offsets is None:
        query, key = ops.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key

    # The batched variant updates query and key in place.
    ops.batched_rotary_embedding(
        positions,
        query,
        key,
        self.head_size,
        self.cos_sin_cache,
        self.is_neox_style,
        self.rotary_dim,
        offsets,
    )
    return query, key
|
||||
|
||||
|
||||
def vllm_kunlun_mrope_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply multimodal rotary embedding via ``mrotary_embedding_fwd_v0``.

    Args:
        positions: Must be 2-D (asserted); cast to int32 before the call.
            The original docstring gives ``[3, num_tokens]`` (T/H/W
            positions) for multimodal input.
        query: Per the original docstring,
            ``[num_tokens, num_heads * head_size]``.
        key: Per the original docstring,
            ``[num_tokens, num_kv_heads * head_size]``; must not be ``None``
            (asserted).

    Returns:
        The ``(query, key)`` pair returned by the custom op.
    """
    assert positions.ndim == 2
    assert key is not None

    section_t, section_h, section_w = (
        self.mrope_section[0],
        self.mrope_section[1],
        self.mrope_section[2],
    )
    int_positions = positions.to(dtype=torch.int32)

    rotated_q, rotated_k = torch.ops.xspeedgate_ops.mrotary_embedding_fwd_v0(
        query,
        key,
        int_positions,
        self.cos_sin_cache,
        False,  # interleaved flag is fixed to False; self.mrope_interleaved is not consulted
        self.head_size,
        self.rotary_dim,
        section_t,
        section_h,
        section_w,
    )
    return rotated_q, rotated_k
|
||||
|
||||
|
||||
# Route both entry points of the base rotary embedding to the Kunlun kernel.
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
RotaryEmbedding.forward = vllm_kunlun_forward_cuda

# The custom cos/sin cache layout is only installed when one of these
# environment switches is set to "1"; otherwise the stock builder is kept.
if any(
    os.getenv(flag) == "1"
    for flag in ("KUNLUN_ENABLE_MULTI_LORA", "FUSED_QK_ROPE_OP")
):
    RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache

# Same redirection for the multimodal variant.
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda

# YaRN reuses the base class's inverse-frequency computation.
YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
|
||||
|
||||
|
||||
def Split_Norm_Rope(
    qkv: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    q_norm_weight: torch.Tensor,
    k_norm_weight: torch.Tensor,
    positions: torch.Tensor,
    max_position_embeddings: int,
    q_head_num: int,
    kv_head_num: int,
    head_dim: int,
    partial_rotary_factor: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the fused ``split_norm_rope_neox`` op on a packed QKV tensor.

    Three output buffers are allocated with the dtype and device of ``qkv``
    and filled by ``torch.ops._C.split_norm_rope_neox``. Going by the op's
    name it splits Q/K/V, normalizes Q and K and applies neox-style RoPE;
    the op itself is not visible here -- TODO confirm.

    Args:
        qkv: Packed projection output; ``qkv.shape[0]`` is the token count.
        cos_sin_cache: Rotary lookup table forwarded to the op.
        q_norm_weight: Forwarded to the op.
        k_norm_weight: Forwarded to the op.
        positions: Token positions forwarded to the op.
        max_position_embeddings: Forwarded to the op.
        q_head_num: Number of query heads.
        kv_head_num: Number of key/value heads.
        head_dim: Size of each head.
        partial_rotary_factor: When below 1.0 the rotary dimension passed to
            the op is ``int(head_dim * partial_rotary_factor)``; otherwise it
            is ``head_dim``.

    Returns:
        ``(q, k, v)`` with shapes ``[num_tokens, q_head_num * head_dim]``,
        ``[num_tokens, kv_head_num * head_dim]`` and
        ``[num_tokens, kv_head_num * head_dim]``.
    """
    num_tokens = qkv.shape[0]

    if partial_rotary_factor < 1.0:
        rotary_dim = int(head_dim * partial_rotary_factor)
    else:
        rotary_dim = head_dim

    def _alloc(heads: int) -> torch.Tensor:
        # Uninitialized buffer for `heads` heads, matching qkv's dtype/device.
        return torch.empty(
            (num_tokens, heads * head_dim), dtype=qkv.dtype, device=qkv.device
        )

    q_out = _alloc(q_head_num)
    k_out = _alloc(kv_head_num)
    v_out = _alloc(kv_head_num)

    torch.ops._C.split_norm_rope_neox(
        q_out,
        k_out,
        v_out,
        qkv,
        cos_sin_cache,
        q_norm_weight,
        k_norm_weight,
        positions,
        num_tokens,
        max_position_embeddings,
        q_head_num,
        kv_head_num,
        head_dim,
        rotary_dim,
    )
    return q_out, k_out, v_out
|
||||
0
vllm_kunlun/ops/sample/__init__.py
Normal file
0
vllm_kunlun/ops/sample/__init__.py
Normal file
1431
vllm_kunlun/ops/sample/sampler.py
Normal file
1431
vllm_kunlun/ops/sample/sampler.py
Normal file
File diff suppressed because it is too large
Load Diff
477
vllm_kunlun/ops/vocab_parallel_embedding.py
Normal file
477
vllm_kunlun/ops/vocab_parallel_embedding.py
Normal file
@@ -0,0 +1,477 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
from vllm.model_executor.parameter import BasevLLMParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Default multiple that vocabulary sizes are rounded up to; used as the default
# for pad_vocab_size(pad_to=...) and for the embedding layers' padding_size.
DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||
|
||||
|
||||
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Embedding "quantization" method that keeps the weights unquantized."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register an uninitialized ``weight`` parameter on ``layer``.

        The parameter has shape
        ``[sum(output_partition_sizes), input_size_per_partition]``, is
        tagged with ``input_dim=1`` / ``output_dim=0`` and then with every
        entry of ``extra_weight_attrs``.
        """
        num_rows = sum(output_partition_sizes)
        storage = torch.empty(num_rows,
                              input_size_per_partition,
                              dtype=params_dtype)
        weight = Parameter(storage, requires_grad=False)

        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run ``x`` through the dispatched unquantized GEMM."""
        gemm = dispatch_unquantized_gemm()
        return gemm(layer, x, layer.weight, bias)

    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        """Look up the rows of ``layer.weight`` indexed by ``input_``."""
        table = layer.weight
        return F.embedding(input_, table)
|
||||
|
||||
|
||||
def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Round ``vocab_size`` up to the nearest multiple of ``pad_to``."""
    # Ceiling division, then scale back up to a multiple of pad_to.
    num_chunks = (vocab_size + pad_to - 1) // pad_to
    return num_chunks * pad_to
|
||||
|
||||
|
||||
def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int,
        rank: int,
        offset: int = 0) -> Sequence[int]:
    """Return the half-open ``(start, end)`` vocab range owned by ``rank``.

    Each rank owns ``per_partition_vocab_size`` consecutive ids; both bounds
    are shifted by ``offset``.
    """
    first = rank * per_partition_vocab_size + offset
    return first, first + per_partition_vocab_size
|
||||
|
||||
|
||||
def vocab_range_from_global_vocab_size(global_vocab_size: int,
                                       rank: int,
                                       world_size: int,
                                       offset: int = 0) -> Sequence[int]:
    """Return the ``(start, end)`` vocab range of ``rank`` out of ``world_size``.

    The global vocabulary is split with ``divide`` into ``world_size``
    shards; the range of the ``rank``-th shard, shifted by ``offset``, is
    returned.
    """
    shard_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(shard_size,
                                                     rank,
                                                     offset=offset)
|
||||
|
||||
|
||||
@dataclass
class VocabParallelEmbeddingShardIndices:
    """Start/end vocabulary indices describing one tensor-parallel shard.

    Each range is half-open. The ``padded_*`` ranges include padding, the
    plain ranges do not; construction asserts that the two are consistent.
    """
    # Ranges including padding.
    padded_org_vocab_start_index: int
    padded_org_vocab_end_index: int
    padded_added_vocab_start_index: int
    padded_added_vocab_end_index: int

    # Ranges with padding removed.
    org_vocab_start_index: int
    org_vocab_end_index: int
    added_vocab_start_index: int
    added_vocab_end_index: int

    @property
    def num_org_elements(self) -> int:
        """Number of original-vocab entries in the shard, without padding."""
        start, end = self.org_vocab_start_index, self.org_vocab_end_index
        return end - start

    @property
    def num_added_elements(self) -> int:
        """Number of added-vocab entries in the shard, without padding."""
        start, end = self.added_vocab_start_index, self.added_vocab_end_index
        return end - start

    @property
    def num_org_elements_padded(self) -> int:
        """Number of original-vocab slots in the shard, padding included."""
        start = self.padded_org_vocab_start_index
        return self.padded_org_vocab_end_index - start

    @property
    def num_added_elements_padded(self) -> int:
        """Number of added-vocab slots in the shard, padding included."""
        start = self.padded_added_vocab_start_index
        return self.padded_added_vocab_end_index - start

    @property
    def num_org_vocab_padding(self) -> int:
        """Padding slots that follow the original-vocab entries."""
        return self.num_org_elements_padded - self.num_org_elements

    @property
    def num_added_vocab_padding(self) -> int:
        """Padding slots that follow the added-vocab entries."""
        return self.num_added_elements_padded - self.num_added_elements

    @property
    def num_elements_padded(self) -> int:
        """Total number of slots in the shard (original + added, padded)."""
        return self.num_org_elements_padded + self.num_added_elements_padded

    def __post_init__(self):
        # Every range must be well-formed (start <= end).
        assert (self.padded_org_vocab_start_index
                <= self.padded_org_vocab_end_index)
        assert (self.padded_added_vocab_start_index
                <= self.padded_added_vocab_end_index)

        assert self.org_vocab_start_index <= self.org_vocab_end_index
        assert self.added_vocab_start_index <= self.added_vocab_end_index

        # Unpadded bounds never exceed their padded counterparts.
        assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
        assert (self.added_vocab_start_index
                <= self.padded_added_vocab_start_index)
        assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
        assert self.added_vocab_end_index <= self.padded_added_vocab_end_index

        # Hence the unpadded counts fit inside the padded counts.
        assert self.num_org_elements <= self.num_org_elements_padded
        assert self.num_added_elements <= self.num_added_elements_padded
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend="aot_eager")
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Translate global token ids into shard-local row indices.

    Ids inside the shard's original range are shifted down by
    ``org_vocab_start_index``. Ids inside the added range are shifted so
    that they land directly after the original rows and their padding.
    Ids outside both ranges map to 0.

    Returns:
        ``(local_ids, mask)`` where ``mask`` is True for ids that do NOT
        belong to this shard.
    """
    # Only pointwise ops are used, so torch.compile can fuse them.
    in_org = (input_ >= org_vocab_start_index) & (input_
                                                  < org_vocab_end_index)
    in_added = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)

    # Added rows start after the original rows plus their padding.
    num_org_rows = org_vocab_end_index - org_vocab_start_index
    added_shift = (added_vocab_start_index - num_org_rows -
                   num_org_vocab_padding)

    per_token_shift = (org_vocab_start_index * in_org) + (added_shift *
                                                          in_added)
    owned = in_org | in_added
    local_ids = owned * (input_ - per_token_shift)
    return local_ids, ~owned
|
||||
|
||||
|
||||
@CustomOp.register("vllm_kunlun_vocab_parallel_embedding")
class VocabParallelEmbedding(CustomOp):
    """Embedding parallelized in the vocabulary dimension.

    Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
    make sure it is divisible by the number of model parallel GPUs.

    In order to support various loading methods, we ensure that LoRA-added
    embeddings are always at the end of TP-sharded tensors. In other words,
    we shard base embeddings and LoRA embeddings separately (both padded),
    and place them in the same tensor.
    In this example, we will have the original vocab size = 1010,
    added vocab size = 16 and padding to 64. Therefore, the total
    vocab size with padding will be 1088 (because we first pad 1010 to
    1024, add 16, and then pad to 1088).
    Therefore, the tensor format looks like the following:
    TP1, rank 0 (no sharding):
                            |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
    corresponding token_id: |  0  |  1  | ... | 1009 |  -1  | ... |  -1  | 1010 | ... | 1025 |  -1  | ... |  -1  |
                     index: |  0  |  1  | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

    TP2, rank 0:
                            |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
    corresponding token_id: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 1010 | ... | 1025 |  -1  | ... |  -1 |
                     index: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 512  | ... | 527  |  528 | ... | 543 |
    TP2, rank 1:
                            |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
    corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1  | ...  | -1  |  -1  | ... |  -1  | -1  | ... |   -1 |
                     index: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 512  | ... | 527  | 528 | ... |  543 |

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer
        prefix: full name of the layer in the state dict
    """  # noqa: E501

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Keep the input dimensions.
        tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        num_added_embeddings = num_embeddings - self.org_vocab_size
        # Pad the base vocab first, then the base+added total.
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings,
            self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded

        self.shard_indices = self._get_indices(self.num_embeddings_padded,
                                               self.org_vocab_size_padded,
                                               self.num_embeddings,
                                               self.org_vocab_size, tp_rank,
                                               self.tp_size)
        self.embedding_dim = embedding_dim

        quant_method = None
        if quant_config is not None:
            quant_method = quant_config.get_quant_method(self, prefix=prefix)
        if quant_method is None:
            quant_method = UnquantizedEmbeddingMethod()

        # If we are making an embedding layer, then our quantization linear
        # method must implement the embedding operation. If we are another
        # layer type like ParallelLMHead, this is not important.
        is_embedding_layer = type(self) is VocabParallelEmbedding
        quant_method_implements_embedding = method_has_implemented_embedding(
            type(quant_method))
        if is_embedding_layer and not quant_method_implements_embedding:
            raise NotImplementedError(
                f"The class {type(quant_method).__name__} must implement "
                "the 'embedding' method, see UnquantizedEmbeddingMethod.")

        self.quant_method: QuantizeMethodBase = quant_method

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Divide the weight matrix along the vocabulary dimension.
        self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
                                                   self.tp_size)
        assert (self.shard_indices.num_elements_padded ==
                self.num_embeddings_per_partition)
        self.num_org_embeddings_per_partition = (
            self.shard_indices.org_vocab_end_index -
            self.shard_indices.org_vocab_start_index)
        self.num_added_embeddings_per_partition = (
            self.shard_indices.added_vocab_end_index -
            self.shard_indices.added_vocab_start_index)

        self.quant_method.create_weights(self,
                                         self.embedding_dim,
                                         [self.num_embeddings_per_partition],
                                         self.embedding_dim,
                                         self.num_embeddings_padded,
                                         params_dtype=params_dtype,
                                         weight_loader=self.weight_loader)

    @classmethod
    def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
                     vocab_size: int, org_vocab_size: int, tp_rank: int,
                     tp_size: int) -> VocabParallelEmbeddingShardIndices:
        """Get start and end indices for vocab parallel embedding, following the
        layout outlined in the class docstring, based on the given tp_rank and
        tp_size."""
        num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
        padded_org_vocab_start_index, padded_org_vocab_end_index = (
            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
                                               tp_size))
        # Added-token ids start right after the unpadded original vocab.
        padded_added_vocab_start_index, padded_added_vocab_end_index = (
            vocab_range_from_global_vocab_size(num_added_embeddings_padded,
                                               tp_rank,
                                               tp_size,
                                               offset=org_vocab_size))
        # remove padding
        org_vocab_start_index = min(padded_org_vocab_start_index,
                                    org_vocab_size)
        org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
        added_vocab_start_index = min(padded_added_vocab_start_index,
                                      vocab_size)
        added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
        return VocabParallelEmbeddingShardIndices(
            padded_org_vocab_start_index, padded_org_vocab_end_index,
            padded_added_vocab_start_index, padded_added_vocab_end_index,
            org_vocab_start_index, org_vocab_end_index,
            added_vocab_start_index, added_vocab_end_index)

    def get_sharded_to_full_mapping(self) -> Optional[list[int]]:
        """Get a mapping that can be used to reindex the gathered
        logits for sampling.

        During sampling, we gather logits from all ranks. The relationship
        of index->token_id will follow the same format as outlined in the class
        docstring. However, after the gather, we want to reindex the final
        logits tensor to map index->token_id one-to-one (the index is always
        equal the token_id it corresponds to). The indices returned by this
        method allow us to do that.

        Returns:
            ``None`` when ``tp_size < 2``; otherwise a list of length
            ``num_embeddings_padded`` listing gathered indices in the order
            base entries, added entries, padding.
        """
        if self.tp_size < 2:
            return None

        base_embeddings: list[int] = []
        added_embeddings: list[int] = []
        padding: list[int] = []
        for tp_rank in range(self.tp_size):
            shard_indices = self._get_indices(self.num_embeddings_padded,
                                              self.org_vocab_size_padded,
                                              self.num_embeddings,
                                              self.org_vocab_size, tp_rank,
                                              self.tp_size)
            # Slice of the gathered tensor contributed by this rank.
            range_start = self.num_embeddings_per_partition * tp_rank
            range_end = self.num_embeddings_per_partition * (tp_rank + 1)
            base_embeddings.extend(
                range(range_start,
                      range_start + shard_indices.num_org_elements))
            padding.extend(
                range(range_start + shard_indices.num_org_elements,
                      range_start + shard_indices.num_org_elements_padded))
            added_embeddings.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded,
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements))
            padding.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements,
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements_padded))
            assert (range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements_padded == range_end)
        ret = base_embeddings + added_embeddings + padding
        assert len(ret) == self.num_embeddings_padded
        return ret

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        Parameters without an ``output_dim`` attribute are copied whole.
        Otherwise the shard's original-vocab range is narrowed out of
        ``loaded_weight`` along ``output_dim`` (scaled down by the pack
        factor when the parameter is packed along that dimension), copied
        into the leading rows of ``param``, and the remaining rows are
        zero-filled.
        """
        output_dim = getattr(param, "output_dim", None)
        packed_dim = getattr(param, "packed_dim", None)

        # If the parameter is a gguf weight, then load it directly.
        if getattr(param, "is_gguf_weight_type", None):
            param.data.copy_(loaded_weight)
            param.weight_type = loaded_weight.item()
            return
        elif isinstance(param, UninitializedParameter):
            shape = list(loaded_weight.shape)
            if output_dim is not None:
                shape[output_dim] = self.num_embeddings_per_partition
            param.materialize(tuple(shape), dtype=loaded_weight.dtype)

        # If parameter does not have output dim, then it should
        # be copied onto all gpus (e.g. g_idx for act_order gptq).
        if output_dim is None:
            assert param.data.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)
            return

        # Shard indexes for loading the weight
        start_idx = self.shard_indices.org_vocab_start_index
        shard_size = self.shard_indices.org_vocab_end_index - start_idx

        # If param packed on the same dim we are sharding on, then
        # need to adjust offsets of loaded weight by pack_factor.
        if packed_dim is not None and packed_dim == output_dim:
            packed_factor = param.packed_factor if isinstance(
                param, BasevLLMParameter) else param.pack_factor
            # Use the resolved local value: parameters that are not a
            # BasevLLMParameter only expose `pack_factor`, so reading
            # `param.packed_factor` here would raise AttributeError.
            assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
                                                       packed_factor)
            start_idx = start_idx // packed_factor
            shard_size = shard_size // packed_factor
        else:
            assert loaded_weight.shape[output_dim] == self.org_vocab_size

        # Copy the data. Select chunk corresponding to current shard.
        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
        param[loaded_weight.shape[0]:].data.fill_(0)

    def forward(self, input_):
        """Embed ``input_`` token ids and all-reduce across the TP group."""
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.quant_method.embedding(self,
                                                      masked_input.long())
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
        # Reduce across all the model parallel GPUs.
        output = tensor_model_parallel_all_reduce(output_parallel)
        return output

    def extra_repr(self) -> str:
        """Return the summary string shown in the module's ``repr``."""
        s = f"num_embeddings={self.num_embeddings_per_partition}"
        s += f", embedding_dim={self.embedding_dim}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f', num_embeddings_padded={self.num_embeddings_padded}'
        s += f', tp_size={self.tp_size}'
        return s
|
||||
|
||||
|
||||
class ParallelLMHead(VocabParallelEmbedding):
    """Vocabulary-parallel LM head.

    Holds the output-logit weight matrix (and optional bias) consumed by the
    sampler. Both are padded so that they divide evenly across the model
    parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(num_embeddings,
                         embedding_dim,
                         params_dtype=params_dtype,
                         org_num_embeddings=org_num_embeddings,
                         padding_size=padding_size,
                         quant_config=quant_config,
                         prefix=prefix)
        self.quant_config = quant_config

        if not bias:
            self.register_parameter("bias", None)
            return

        # One bias entry per local vocabulary row, loaded via weight_loader.
        bias_storage = torch.empty(self.num_embeddings_per_partition,
                                   dtype=params_dtype)
        self.bias = Parameter(bias_storage)
        bias_attrs = {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        }
        set_weight_attrs(self.bias, bias_attrs)

    def tie_weights(self, embed_tokens: VocabParallelEmbedding):
        """Share weights with the word embeddings.

        For a "gguf" quant config ``embed_tokens`` itself is returned;
        otherwise this head adopts ``embed_tokens.weight`` and returns
        ``self``.
        """
        uses_gguf = (self.quant_config
                     and self.quant_config.get_name() == "gguf")
        if uses_gguf:
            # GGUF quantized embed_tokens.
            return embed_tokens
        self.weight = embed_tokens.weight
        return self

    def forward(self, input_):
        """Always raises: the head is never called as a module."""
        del input_
        raise RuntimeError("LMHead's weights should be used in the sampler.")
|
||||
Reference in New Issue
Block a user