提交vllm0.11.0开发分支
This commit is contained in:
@@ -12,10 +12,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import vllm_kunlun.ops.rotary_embedding
|
||||
import vllm_kunlun.ops.layernorm
|
||||
import vllm_kunlun.ops.quantization.awq
|
||||
import vllm_kunlun.ops.quantization.gptq
|
||||
import vllm_kunlun.ops.layernorm
|
||||
@@ -1,20 +1,3 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""kunlun custom op entry"""
|
||||
import torch_xmlir
|
||||
import torch
|
||||
@@ -29,7 +12,6 @@ logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
import xtorch_ops
|
||||
|
||||
logger.info(f"Load custom ops library success!")
|
||||
except ImportError as e:
|
||||
logger.warning("Import error msg: %s", e.msg)
|
||||
@@ -37,15 +19,13 @@ except ImportError as e:
|
||||
|
||||
_per_token_smooth_quant = True
|
||||
|
||||
|
||||
def is_per_token_smooth_quant():
|
||||
"""is per token smooth quant"""
|
||||
""" is per token smooth quant """
|
||||
return _per_token_smooth_quant
|
||||
|
||||
|
||||
class KunlunOps:
|
||||
"""KunlunOps"""
|
||||
|
||||
# Attention ops
|
||||
@staticmethod
|
||||
def paged_attention_v1(
|
||||
@@ -70,9 +50,10 @@ class KunlunOps:
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
alibi_sqrt=False,
|
||||
):
|
||||
"""PagedAttentionV1"""
|
||||
alibi_sqrt=False
|
||||
):
|
||||
""" PagedAttentionV1 """
|
||||
# block_size = value_cache.shape[2]
|
||||
xtorch_ops.paged_attention(
|
||||
x=query,
|
||||
k_cache=key_cache,
|
||||
@@ -83,7 +64,7 @@ class KunlunOps:
|
||||
is_context=is_context,
|
||||
is_causal=True,
|
||||
out=output,
|
||||
vo_head_dim=128,
|
||||
vo_head_dim=128
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -112,9 +93,10 @@ class KunlunOps:
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
alibi_sqrt=False,
|
||||
):
|
||||
"""PagedAttentionV2"""
|
||||
alibi_sqrt=False
|
||||
):
|
||||
""" PagedAttentionV2 """
|
||||
# block_size = value_cache.shape[2]
|
||||
xtorch_ops.paged_attention(
|
||||
x=query,
|
||||
k_cache=key_cache,
|
||||
@@ -125,28 +107,31 @@ class KunlunOps:
|
||||
is_context=is_context,
|
||||
is_causal=True,
|
||||
out=output,
|
||||
vo_head_dim=128,
|
||||
vo_head_dim=128
|
||||
)
|
||||
|
||||
|
||||
# Activation ops
|
||||
@staticmethod
|
||||
def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
|
||||
"""silu and mul"""
|
||||
def silu_and_mul(out: torch.Tensor,
|
||||
x: torch.Tensor):
|
||||
""" silu and mul """
|
||||
xtorch_ops.silu_and_mul(
|
||||
x,
|
||||
axis=-1,
|
||||
turn=True,
|
||||
out=out,
|
||||
)
|
||||
)
|
||||
|
||||
# Activation ops
|
||||
@staticmethod
|
||||
def quick_gelu(out: torch.Tensor, x: torch.Tensor):
|
||||
"""quick gelu"""
|
||||
def quick_gelu(out: torch.Tensor,
|
||||
x: torch.Tensor):
|
||||
""" quick gelu """
|
||||
xtorch_ops.quick_gelu(
|
||||
x,
|
||||
out=out,
|
||||
)
|
||||
)
|
||||
|
||||
# Layernorm
|
||||
@staticmethod
|
||||
@@ -157,7 +142,9 @@ class KunlunOps:
|
||||
epsilon,
|
||||
):
|
||||
"""rms_norm"""
|
||||
xtorch_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)
|
||||
xtorch_ops.rmsnorm(
|
||||
x, weight.to(torch.float32), epsilon, out=out
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def fused_add_rms_norm(
|
||||
@@ -175,11 +162,16 @@ class KunlunOps:
|
||||
residual.copy_(fused_input, non_blocking=True)
|
||||
x.copy_(output)
|
||||
|
||||
|
||||
# Rotary embedding
|
||||
@staticmethod
|
||||
def rotary_embedding(
|
||||
positions, query, key, head_size, cos_sin_cache, is_neox_style
|
||||
):
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style):
|
||||
"""
|
||||
refactor RotaryEmbedding forward function
|
||||
"""
|
||||
@@ -196,43 +188,66 @@ class KunlunOps:
|
||||
key_x = key_x.unsqueeze(0)
|
||||
|
||||
xtorch_ops.rotary_embedding_gptj(
|
||||
positions, query_x, key_x, head_size, cos_sin_cache
|
||||
)
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
head_size,
|
||||
cos_sin_cache)
|
||||
query.data = query_x
|
||||
key.data = key_x
|
||||
key.data = key_x
|
||||
if query_x_dim != query_x.dim():
|
||||
query_x = query_x.unsqueeze(0)
|
||||
key_x = key_x.unsqueeze(0)
|
||||
return query, key
|
||||
|
||||
|
||||
# TODO: need opt
|
||||
if cos_sin_cache.dim() == 4:
|
||||
max_seq_len = cos_sin_cache.shape[2]
|
||||
head_dim = cos_sin_cache.shape[3]
|
||||
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(
|
||||
0
|
||||
) # Remove the first two dimensions [1,1,L,D] -> [L,D]
|
||||
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(0) # 移除前两个维度 [1,1,L,D] -> [L,D]
|
||||
cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim)
|
||||
|
||||
# Reshape query and key
|
||||
|
||||
# 重塑 query 和 key 的形状
|
||||
num_tokens = query_x.shape[0]
|
||||
num_heads = query_x.shape[1] // head_size
|
||||
num_kv_heads = key_x.shape[1] // head_size
|
||||
|
||||
# # [num_tokens, num_heads * head_size] -> [num_tokens, num_heads, head_size]
|
||||
# query_x = query_x.view(num_tokens, num_heads, head_size)
|
||||
# # [num_tokens, num_kv_heads * head_size] -> [num_tokens, num_kv_heads, head_size]
|
||||
# key_x = key_x.view(num_tokens, num_kv_heads, head_size)
|
||||
|
||||
# # 确保形状正确
|
||||
# assert query_x.shape == (num_tokens, num_heads, head_size), \
|
||||
# f"Expected query shape [{num_tokens}, {num_heads}, {head_size}], got {query_x.shape}"
|
||||
# assert key_x.shape == (num_tokens, num_kv_heads, head_size), \
|
||||
# f"Expected key shape [{num_tokens}, {num_kv_heads}, {head_size}], got {key_x.shape}"
|
||||
|
||||
torch.ops._C.rotary_embedding(
|
||||
positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
|
||||
)
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style)
|
||||
|
||||
query_x = query_x.view(num_tokens, num_heads * head_size)
|
||||
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
|
||||
|
||||
# query.data = query_x
|
||||
# key.data = key_x
|
||||
return query_x, key_x
|
||||
|
||||
# Rotary embedding
|
||||
@staticmethod
|
||||
def mrotary_embedding(
|
||||
positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
|
||||
):
|
||||
positions,
|
||||
mrope_section,
|
||||
query,
|
||||
key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style):
|
||||
"""
|
||||
refactor RotaryEmbedding forward function
|
||||
"""
|
||||
@@ -241,21 +256,35 @@ class KunlunOps:
|
||||
query_x_dim = query_x.dim()
|
||||
assert is_neox_style
|
||||
xtorch_ops.mrotary_embedding_neox(
|
||||
positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
|
||||
)
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
mrope_section)
|
||||
|
||||
query.data = query_x
|
||||
key.data = key_x
|
||||
key.data = key_x
|
||||
return query, key
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(src, dst, block_mapping):
|
||||
"""swap_blocks"""
|
||||
xtorch_ops.swap_blocks(src, dst, block_mapping)
|
||||
def swap_blocks(
|
||||
src,
|
||||
dst,
|
||||
block_mapping):
|
||||
""" swap_blocks """
|
||||
xtorch_ops.swap_blocks(
|
||||
src,
|
||||
dst,
|
||||
block_mapping
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(key_caches, value_caches, block_mapping):
|
||||
"""copy_blocks"""
|
||||
def copy_blocks(
|
||||
key_caches,
|
||||
value_caches,
|
||||
block_mapping):
|
||||
""" copy_blocks """
|
||||
for i in range(len(key_caches)):
|
||||
key_caches[i] = key_caches[i].contiguous()
|
||||
value_caches[i] = value_caches[i].contiguous()
|
||||
@@ -273,9 +302,16 @@ class KunlunOps:
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype,
|
||||
):
|
||||
"""reshape_and_cache"""
|
||||
xtorch_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
):
|
||||
""" reshape_and_cache """
|
||||
# slot_mapping_cast = slot_mapping.to(torch.int32)
|
||||
xtorch_ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def multi_query_kv_attention(
|
||||
@@ -284,7 +320,7 @@ class KunlunOps:
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
**kargs,
|
||||
**kargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
query: shape = [num_prompt_tokens, num_heads, head_size]
|
||||
@@ -303,7 +339,7 @@ class KunlunOps:
|
||||
KVh = key.size(2)
|
||||
if KVh != Qh:
|
||||
repeat = Qh // KVh
|
||||
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
||||
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
||||
value = value.repeat_interleave(repeat, dim=2)
|
||||
xtorch_ops.attention(
|
||||
q=query,
|
||||
@@ -318,132 +354,85 @@ class KunlunOps:
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def quant_fusedresidual_rmsnorm_op(
|
||||
x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
|
||||
):
|
||||
def quant_fusedresidual_rmsnorm_op(x,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
scale_to_int,
|
||||
eps,
|
||||
dyn_scale: bool,
|
||||
type: int = 1):
|
||||
"""Quantized fused residual layer normalization"""
|
||||
out = torch.empty_like(x, dtype=torch.int8)
|
||||
|
||||
if is_per_token_smooth_quant():
|
||||
out_scale = torch.empty(
|
||||
x.shape[:-1], device=x.device, dtype=torch.float
|
||||
).unsqueeze(-1)
|
||||
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
|
||||
else:
|
||||
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
||||
|
||||
xtorch_ops.quant_fusedresidual_rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
out=out,
|
||||
out_scale=out_scale,
|
||||
residual_tensor=residual,
|
||||
)
|
||||
xtorch_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
|
||||
out=out, out_scale=out_scale , residual_tensor=residual)
|
||||
|
||||
if residual is None:
|
||||
return out, out_scale
|
||||
return out, out_scale, residual
|
||||
|
||||
@staticmethod
|
||||
def quant_rmsnorm_op(
|
||||
x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
|
||||
):
|
||||
def quant_rmsnorm_op(x,
|
||||
weight,
|
||||
bias,
|
||||
scale_to_int,
|
||||
eps,
|
||||
dyn_scale : bool,
|
||||
type: int = 1):
|
||||
"""Quantized RMSNorm"""
|
||||
|
||||
out = torch.empty_like(x, dtype=torch.int8)
|
||||
if is_per_token_smooth_quant():
|
||||
out_scale = torch.empty(
|
||||
x.shape[:-1], device=x.device, dtype=torch.float
|
||||
).unsqueeze(-1)
|
||||
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
|
||||
else:
|
||||
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
||||
|
||||
xtorch_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
|
||||
xtorch_ops.quant_rmsnorm(x, weight, bias, eps,
|
||||
out=out, out_scale=out_scale)
|
||||
return out, out_scale
|
||||
|
||||
@staticmethod
|
||||
def smooth_quant_matmul_column_row_kernels(
|
||||
input_tensor,
|
||||
weight,
|
||||
smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
perTokenScaling,
|
||||
perChannelScaling,
|
||||
otype,
|
||||
):
|
||||
def smooth_quant_matmul_column_row_kernels(input_tensor,
|
||||
weight,
|
||||
smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
perTokenScaling,
|
||||
perChannelScaling,
|
||||
otype):
|
||||
"""smooth_quant_matmul_column_row_kernels"""
|
||||
input_shape = input_tensor.shape
|
||||
weight_shape = weight.shape
|
||||
if input_tensor.dim() == 3:
|
||||
input_tensor = input_tensor.reshape(-1, input_shape[-1])
|
||||
out = torch.empty(
|
||||
(input_shape[0] * input_shape[1], weight_shape[0]),
|
||||
dtype=torch.float16,
|
||||
device=weight.device,
|
||||
)
|
||||
out = torch.empty((input_shape[0] * input_shape[1],
|
||||
weight_shape[0]),
|
||||
dtype=torch.float16,
|
||||
device=weight.device)
|
||||
output_bs_shape = [input_shape[0], input_shape[1]]
|
||||
elif input_tensor.dim() == 2:
|
||||
out = torch.empty(
|
||||
(input_shape[0], weight_shape[0]),
|
||||
dtype=torch.float16,
|
||||
device=weight.device,
|
||||
)
|
||||
out = torch.empty((input_shape[0], weight_shape[0]),
|
||||
dtype=torch.float16,
|
||||
device=weight.device)
|
||||
output_bs_shape = [-1]
|
||||
xtorch_ops.smooth_quant_matmul_column_row_kernels(
|
||||
input_tensor,
|
||||
weight,
|
||||
smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
perTokenScaling,
|
||||
perChannelScaling,
|
||||
out=out,
|
||||
)
|
||||
xtorch_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
|
||||
weight, smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
perTokenScaling,
|
||||
perChannelScaling,
|
||||
out=out)
|
||||
|
||||
out = out.view(*output_bs_shape, weight_shape[0])
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
linear_weights: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""fused_moe"""
|
||||
output = torch.empty(
|
||||
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
expert_num = linear_weights.shape[0]
|
||||
|
||||
torch.ops._C.moe_ffn_block(
|
||||
x=hidden_states,
|
||||
gate_w=linear_weights,
|
||||
inter_w=w1,
|
||||
output_w=w2,
|
||||
expert_num=expert_num,
|
||||
moe_top_k=topk,
|
||||
topk_group=topk_group,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
expert_group_num=num_expert_group,
|
||||
out=output,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def fused_moe_ep(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -460,23 +449,23 @@ class KunlunOps:
|
||||
topk_group: Optional[int] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> torch.Tensor:
|
||||
x = hidden_states
|
||||
batch, hidden_size = x.shape
|
||||
batch, hidden_size = x.shape
|
||||
num_local_experts, up_gate_size, _ = w13_weight.shape
|
||||
|
||||
router_logits = x.to(linear_weights.dtype) @ linear_weights.T
|
||||
|
||||
topk_weights = torch.empty(
|
||||
batch, top_k, dtype=router_logits.dtype, device=router_logits.device
|
||||
)
|
||||
topk_ids = torch.empty(
|
||||
batch, top_k, dtype=torch.int32, device=router_logits.device
|
||||
)
|
||||
block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
|
||||
torch.ops._C.moe_softmax_topk(
|
||||
router_logits, topk_weights, topk_ids, block_static
|
||||
)
|
||||
router_logits = x.to(linear_weights.dtype)@linear_weights.T
|
||||
|
||||
topk_weights = torch.empty(batch,
|
||||
top_k,
|
||||
dtype=router_logits.dtype,
|
||||
device=router_logits.device)
|
||||
topk_ids = torch.empty(batch,
|
||||
top_k,
|
||||
dtype=torch.int32,
|
||||
device=router_logits.device)
|
||||
block_static = torch.empty(0, dtype=torch.int32,device=router_logits.device)
|
||||
torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids, block_static)
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
|
||||
@@ -490,22 +479,50 @@ class KunlunOps:
|
||||
selected_token = topk_ids_flat == experts_id
|
||||
if selected_token.sum():
|
||||
cur_token = repeat_x[selected_token]
|
||||
up_gate = torch.empty(
|
||||
selected_token.sum(),
|
||||
up_gate_size // 2,
|
||||
dtype=cur_token.dtype,
|
||||
device=cur_token.device,
|
||||
)
|
||||
torch.ops._C.swiglu(cur_token @ w13_weight[i].T, up_gate)
|
||||
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
|
||||
dtype=cur_token.dtype, device=cur_token.device)
|
||||
torch.ops._C.swiglu(cur_token@ w13_weight[i].T, up_gate)
|
||||
out[selected_token] = up_gate @ w2_weight[i].T
|
||||
output = (
|
||||
(out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
|
||||
.sum(dim=1)
|
||||
.to(x.dtype)
|
||||
)
|
||||
output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
linear_weights: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""fused_moe"""
|
||||
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
expert_num = linear_weights.shape[0]
|
||||
|
||||
torch.ops._C.moe_ffn_block(
|
||||
x=hidden_states,
|
||||
gate_w=linear_weights,
|
||||
inter_w=w1,
|
||||
output_w=w2,
|
||||
expert_num=expert_num,
|
||||
moe_top_k=topk,
|
||||
topk_group=topk_group,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
expert_group_num=num_expert_group,
|
||||
out=output,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def fused_multi_head_latent_page_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -538,11 +555,10 @@ class KunlunOps:
|
||||
prompt_lods_cpu: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
) -> torch.Tensor:
|
||||
"""mla pa block"""
|
||||
output = torch.empty(
|
||||
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
xtorch_ops.xft_multi_head_latent_page_attention_block(
|
||||
hidden_states,
|
||||
q_lora_rank,
|
||||
@@ -579,3 +595,42 @@ class KunlunOps:
|
||||
v_cache=v_cache,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def fused_gdn_gating(
|
||||
A_log: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
beta: float = 1.0,
|
||||
threshold: float = 20.0,
|
||||
) -> torch.Tensor:
|
||||
"""fused_gdn_gating"""
|
||||
output = xtorch_ops.fused_gdn_gating(
|
||||
A_log,
|
||||
a,
|
||||
dt_bias,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
h0_source: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
use_qk_l2norm_in_kernel: bool,
|
||||
cu_seqlens: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
'''
|
||||
Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 将做完sigmoid_gating和delta_rule_update融合在一起
|
||||
1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)。
|
||||
2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。
|
||||
'''
|
||||
|
||||
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwd(
|
||||
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
|
||||
cu_seqlens)
|
||||
return (o, final_state)
|
||||
@@ -1,8 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Custom activation functions."""
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import LazyDict
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_fatrelu_and_mul")
|
||||
class FatreluAndMul(CustomOp):
|
||||
"""An activation function for FATReLU.
|
||||
|
||||
The function computes x -> FATReLU(x[:d]) * x[d:] where
|
||||
d = x.shape[-1] // 2.
|
||||
This is used in openbmb/MiniCPM-S-1B-sft.
|
||||
|
||||
Shapes:
|
||||
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||
"""
|
||||
|
||||
def __init__(self, threshold: float = 0.):
|
||||
"""
|
||||
Initializes the instance.
|
||||
|
||||
Args:
|
||||
threshold (float, optional): Threshold value for the filter. Defaults to 0..
|
||||
|
||||
Returns:
|
||||
None: This method does not return anything.
|
||||
"""
|
||||
super().__init__()
|
||||
self.threshold = threshold
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
计算输入张量的正向传播,并返回一个新的张量。
|
||||
该函数实现了原生的前向传播过程,即对输入张量进行阈值化处理后,将其乘以另一个张量。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor, shape=[*, d]):
|
||||
输入张量,其中*表示任意维度,d为特征维度。
|
||||
|
||||
Returns:
|
||||
torch.Tensor, shape=[*, d]:
|
||||
返回一个新的张量,其形状与输入张量相同,除了最后一个维度被设置为d/2。
|
||||
如果输入张量的最后一个维度小于等于d/2,则返回的张量将保持不变;否则,将对输入张量进行阈值化处理。
|
||||
"""
|
||||
d = x.shape[-1] // 2
|
||||
x1 = x[..., :d]
|
||||
x2 = x[..., d:]
|
||||
x1 = F.threshold(x1, self.threshold, 0.0)
|
||||
return x1 * x2
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
在CUDA设备上执行前向传播。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(N, C, H, W)。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状为(N, C, H, W)。
|
||||
"""
|
||||
return self.forward_native(x)
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_silu_and_mul")
|
||||
@@ -15,9 +85,532 @@ class SiluAndMul(CustomOp):
|
||||
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||
"""
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return F.silu(x[..., :d]) * x[..., d:]
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""forward_cuda"""
|
||||
import xtorch_ops
|
||||
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
torch.ops._C.swiglu(x, out)
|
||||
return out
|
||||
return out
|
||||
|
||||
def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""forward_kunlun"""
|
||||
import xtorch_ops
|
||||
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
xtorch_ops.swiglu(x, out)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the function on `x` using XPU backend.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of any shape. Must be a floating point tensor.
|
||||
The number of channels should be even.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor with the same shape as input except the last dimension is reduced by half.
|
||||
It has the same dtype as the input and lives on the same device.
|
||||
|
||||
Raises:
|
||||
None
|
||||
"""
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
ops.silu_and_mul(out, x)
|
||||
return out
|
||||
|
||||
def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
前向传播一个神经元,计算输入的信号。
|
||||
参数:
|
||||
x (torch.Tensor): 形状为(-1, d)的张量,其中d是输入的维度。
|
||||
每个元素表示一个输入信号。
|
||||
返回值(torch.Tensor):
|
||||
形状为(-1, d)的张量,其中d是输出的维度。
|
||||
每个元素表示一个输出信号。
|
||||
"""
|
||||
d = x.shape[-1] // 2
|
||||
x_reshaped = x.view(-1, x.shape[-1])
|
||||
s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
|
||||
result = s * x_reshaped[:, d:]
|
||||
return result.view(*x.shape[:-1], d)
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_mul_and_silu")
|
||||
class MulAndSilu(CustomOp):
|
||||
"""An activation function for SwiGLU.
|
||||
|
||||
The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
|
||||
|
||||
Shapes:
|
||||
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化函数,用于实例化类的对象。
|
||||
如果当前平台是 CUDA 或 XPU,则使用 torch.ops._C.mul_and_silu 进行操作;
|
||||
否则,如果当前平台是 CPU,则使用 forward_native 方法进行操作。
|
||||
"""
|
||||
super().__init__()
|
||||
if current_platform.is_cuda_alike():
|
||||
self.op = torch.ops._C.mul_and_silu
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
self.op = ipex_ops.silu_and_mul
|
||||
elif current_platform.is_cpu():
|
||||
self._forward_method = self.forward_native
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return x[..., :d] * F.silu(x[..., d:])
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
在CUDA设备上执行前向传播操作。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,其形状应为(..., d),其中d是特征维度。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,其形状与输入张量相同,但最后一个维度被替换为d/2。
|
||||
|
||||
Raises:
|
||||
无。
|
||||
"""
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
self.op(out, x)
|
||||
return out
|
||||
|
||||
# TODO implement forward_xpu for MulAndSilu
|
||||
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_gelu_and_mul")
|
||||
class GeluAndMul(CustomOp):
|
||||
"""An activation function for GeGLU.
|
||||
|
||||
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
|
||||
|
||||
Shapes:
|
||||
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
|
||||
return: (batch_size, seq_len, d) or (num_tokens, d)
|
||||
"""
|
||||
|
||||
def __init__(self, approximate: str = "none"):
|
||||
"""
|
||||
Initializes the instance.
|
||||
|
||||
Args:
|
||||
approximate (str, optional): The approximation method to use. Defaults to "none".
|
||||
Can be one of "none", "tanh".
|
||||
|
||||
Raises:
|
||||
ValueError: If the `approximate` parameter is not one of "none", "tanh".
|
||||
"""
|
||||
super().__init__()
|
||||
self.approximate = approximate
|
||||
if approximate not in ("none", "tanh"):
|
||||
raise ValueError(f"Unknown approximate mode: {approximate}")
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
在CUDA设备上进行前向传播。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(batch_size, ..., dim),其中dim是特征维度。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状为(batch_size, ..., dim//2),其中dim是特征维度,除以2是因为GELU的输出是两个分量。
|
||||
|
||||
Raises:
|
||||
无。
|
||||
"""
|
||||
# from vllm import _custom_ops as ops
|
||||
import xtorch_ops
|
||||
# d = x.shape[-1] // 2
|
||||
# output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(x, dtype=x.dtype, device=x.device)
|
||||
if self.approximate == "none":
|
||||
# ops.gelu_and_mul(out, x)
|
||||
print(x,x.shape)
|
||||
xtorch_ops.gelu(x, out)
|
||||
elif self.approximate == "tanh":
|
||||
ops.gelu_tanh_and_mul(out, x)
|
||||
return out
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d, _ = self._check_and_make_out(x)
|
||||
# 保守地用 contiguous,避免 view 相关坑
|
||||
x = x.contiguous()
|
||||
x1 = x[..., :d]
|
||||
x2 = x[..., d:]
|
||||
return F.gelu(x1, approximate=self.approximate) * x2
|
||||
|
||||
# def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# """PyTorch-native implementation equivalent to forward()."""
|
||||
# d = x.shape[-1] // 2
|
||||
# return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply gelu activation function on input tensor using iPEX backend.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with shape (N, C, H, W).
|
||||
The data type can be float32 or float64.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor with the same shape and data type as input.
|
||||
The output will have a range of (-0.5, 0.5) for tanh approximation.
|
||||
"""
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
if self.approximate == "none":
|
||||
ops.gelu_and_mul(out, x)
|
||||
elif self.approximate == "tanh":
|
||||
ops.gelu_tanh_and_mul(out, x)
|
||||
return out
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""
|
||||
返回一个字符串,包含有关模型的额外信息。这个函数可以被用于打印出模型的概要信息。
|
||||
默认情况下,这个函数会返回一个包含模型是否使用近似值(approximate)的信息。
|
||||
|
||||
Returns:
|
||||
str (str): 一个字符串,包含有关模型的额外信息。
|
||||
"""
|
||||
return f'approximate={repr(self.approximate)}'
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_gelu_new")
|
||||
class NewGELU(CustomOp):
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
c = math.sqrt(2.0 / math.pi)
|
||||
return 0.5 * x * (1.0 + torch.tanh(c *
|
||||
(x + 0.044715 * torch.pow(x, 3.0))))
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
计算CUDA上的GELU函数。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(N, C, H, W)。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: GELU函数的结果,形状与输入相同。
|
||||
|
||||
Raises:
|
||||
无。
|
||||
"""
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
out = torch.empty_like(x)
|
||||
ops.gelu_new(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the GELU activation function element-wise.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with any shape. The data type is float32 or float64.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor with the same shape as input. The data type is the same as input.
|
||||
|
||||
Raises:
|
||||
None
|
||||
"""
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
return ops.gelu_new(x)
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_gelu_fast")
|
||||
class FastGELU(CustomOp):
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
|
||||
(1.0 + 0.044715 * x * x)))
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
计算输入张量x的CUDA版本GELU(Gaussian Error Linear Unit)。
|
||||
该函数调用了vllm模块中的_custom_ops模块中的gelu_fast函数,完成GELU操作。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(N, C, H, W),类型为float32或float64。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: GELU后的输出张量,形状与x相同,类型与x相同。
|
||||
|
||||
Raises:
|
||||
无。
|
||||
"""
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
out = torch.empty_like(x)
|
||||
ops.gelu_fast(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the GELU function element-wise on input tensor ``x``.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with any shape. The data type can be float or half float.
|
||||
The range of the input values is expected to be -inf to inf.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor with the same shape and data type as input ``x``.
|
||||
The output values are in the range [-0.5, 0.5] for float dtype and [-15, 15] for half float dtype.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input ``x`` is not a torch.Tensor.
|
||||
RuntimeError: If the input ``x`` contains non-finite numbers.
|
||||
"""
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
return ops.gelu_fast(x)
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_quick_gelu")
|
||||
class QuickGELU(CustomOp):
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
使用CUDA设备进行前向计算。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(N, C, H, W)。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状与输入相同,值为GELU函数的结果。
|
||||
|
||||
Raises:
|
||||
无。
|
||||
"""
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
out = torch.empty_like(x)
|
||||
ops.gelu_quick(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the GELU function element-wise on input tensor ``x``.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with any shape. The data type is float32 or float64.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor with the same shape and data type as input ``x``.
|
||||
|
||||
Raises:
|
||||
None
|
||||
"""
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
out = torch.empty_like(x)
|
||||
ops.gelu_quick(out, x)
|
||||
return out
|
||||
|
||||
def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""forward_kunlun"""
|
||||
from vllm._kunlun_ops import KunlunOps as ops
|
||||
out = torch.empty_like(x)
|
||||
ops.quick_gelu(out, x)
|
||||
return out
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_relu2")
|
||||
class ReLUSquaredActivation(CustomOp):
|
||||
"""
|
||||
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
|
||||
"""
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return torch.square(F.relu(x))
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
在CUDA设备上执行前向传播。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(N, C, H, W),数据类型为float32或float64。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状与输入相同,数据类型与输入一致。
|
||||
|
||||
Raises:
|
||||
无。
|
||||
"""
|
||||
return self.forward_native(x)
|
||||
|
||||
|
||||
class ScaledActivation(nn.Module):
|
||||
"""An activation function with post-scale parameters.
|
||||
|
||||
This is used for some quantization methods like AWQ.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
act_module: nn.Module,
|
||||
intermediate_size: int,
|
||||
input_is_parallel: bool = True,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""
|
||||
Initializes the LayerNorm module.
|
||||
|
||||
Args:
|
||||
act_module (nn.Module): The activation function to use after layer norm.
|
||||
Default: nn.GELU()
|
||||
intermediate_size (int): The size of the intermediate representation.
|
||||
input_is_parallel (bool, optional): Whether the input is parallelly processed.
|
||||
Default: True
|
||||
params_dtype (Optional[torch.dtype], optional): The data type of parameters.
|
||||
If None, use the default data type. Default: None
|
||||
"""
|
||||
super().__init__()
|
||||
self.act = act_module
|
||||
self.input_is_parallel = input_is_parallel
|
||||
if input_is_parallel:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_size_per_partition = divide(intermediate_size,
|
||||
tp_size)
|
||||
else:
|
||||
intermediate_size_per_partition = intermediate_size
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.scales = nn.Parameter(
|
||||
torch.empty(intermediate_size_per_partition, dtype=params_dtype))
|
||||
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
前向传播函数,将输入的张量进行缩放和激活操作。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(N, C, H, W)或者(N, C, H, W, D)。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 返回处理后的张量,形状与输入相同。
|
||||
"""
|
||||
return self.act(x) / self.scales
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
"""
|
||||
加载权重,如果输入是并行的,则需要将其平均分配到每个模型参数中。
|
||||
参数:
|
||||
param (nn.Parameter): 需要加载权重的模型参数。
|
||||
loaded_weight (torch.Tensor): 加载的权重张量。
|
||||
返回值:
|
||||
无返回值,直接修改了param的数据。
|
||||
"""
|
||||
param_data = param.data
|
||||
if self.input_is_parallel:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = param_data.shape[0]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
_ACTIVATION_REGISTRY = LazyDict({
|
||||
"gelu":
|
||||
lambda: nn.GELU(),
|
||||
"gelu_fast":
|
||||
lambda: FastGELU(),
|
||||
"gelu_new":
|
||||
lambda: NewGELU(),
|
||||
"gelu_pytorch_tanh":
|
||||
lambda: nn.GELU(approximate="tanh"),
|
||||
"relu":
|
||||
lambda: nn.ReLU(),
|
||||
"relu2":
|
||||
lambda: ReLUSquaredActivation(),
|
||||
"silu":
|
||||
lambda: nn.SiLU(),
|
||||
"quick_gelu":
|
||||
lambda: QuickGELU(),
|
||||
})
|
||||
|
||||
|
||||
def get_act_fn(
|
||||
act_fn_name: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
input_is_parallel: bool = True,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
) -> nn.Module:
|
||||
"""Get an activation function by name."""
|
||||
act_fn_name = act_fn_name.lower()
|
||||
# print(f"activation function name: {act_fn_name}")
|
||||
if act_fn_name not in _ACTIVATION_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Activation function {act_fn_name!r} is not supported.")
|
||||
|
||||
act_fn = _ACTIVATION_REGISTRY[act_fn_name]
|
||||
if (quant_config is not None
|
||||
and act_fn_name in quant_config.get_scaled_act_names()):
|
||||
if intermediate_size is None:
|
||||
raise ValueError("intermediate_size must be specified for scaled "
|
||||
"activation functions.")
|
||||
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
|
||||
params_dtype)
|
||||
return act_fn
|
||||
|
||||
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
|
||||
"gelu": lambda: GeluAndMul(),
|
||||
"silu": lambda: SiluAndMul(),
|
||||
"geglu": lambda: GeluAndMul(),
|
||||
})
|
||||
|
||||
|
||||
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
|
||||
"""Get an activation-and-mul (i.e. SiluAndMul) function by name."""
|
||||
act_fn_name = act_fn_name.lower()
|
||||
if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Activation function {act_fn_name!r} is not supported.")
|
||||
|
||||
return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
|
||||
|
||||
@@ -1,55 +1,28 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Bao Qian, Dong Xinyu, Chen Zhennan, Ma Tianyu
|
||||
# Email: baoqian@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""kunlun attention wrapper for context and decode"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
from itertools import accumulate
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
)
|
||||
from .utils import CommonAttentionState, CommonMetadataBuilder
|
||||
from vllm.attention.backends.utils import (
|
||||
is_block_tables_empty,
|
||||
compute_slot_mapping_start_idx,
|
||||
compute_slot_mapping,
|
||||
)
|
||||
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from .utils import (CommonAttentionState, CommonMetadataBuilder)
|
||||
from vllm.attention.backends.utils import (is_block_tables_empty,
|
||||
compute_slot_mapping_start_idx, compute_slot_mapping)
|
||||
from vllm_kunlun.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import async_tensor_h2d
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KunlunAttentionBackend(AttentionBackend):
|
||||
"""KunlunAttentionBackend"""
|
||||
|
||||
accept_output_buffer = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "KUNLUN_ATTENTION"
|
||||
@@ -80,9 +53,8 @@ class KunlunAttentionBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return PagedAttention.get_kv_cache_shape(
|
||||
num_blocks, block_size, num_kv_heads, head_size
|
||||
)
|
||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
@@ -182,6 +154,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
# It is a list because it is needed to set per prompt
|
||||
@@ -194,27 +167,23 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
"""
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
"""
|
||||
return (
|
||||
(self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None)
|
||||
)
|
||||
'''
|
||||
return ((self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None))
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
"""
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
"""
|
||||
return (
|
||||
self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None)
|
||||
)
|
||||
'''
|
||||
return (self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None))
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
|
||||
@@ -227,60 +196,43 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# metadata structure
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert (self.seq_lens is not None) or (self.encoder_seq_lens is not None)
|
||||
assert (self.seq_lens_tensor is not None) or (
|
||||
self.encoder_seq_lens_tensor is not None
|
||||
)
|
||||
assert ((self.seq_lens is not None)
|
||||
or (self.encoder_seq_lens is not None))
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (
|
||||
None
|
||||
if self.query_start_loc is None
|
||||
else self.query_start_loc[: self.num_prefills + 1]
|
||||
)
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
# flash attention needs both lod information on host and device
|
||||
query_start_loc_host = (
|
||||
None
|
||||
if self.query_start_loc_host is None
|
||||
else self.query_start_loc_host[: self.num_prefills + 1]
|
||||
)
|
||||
kv_prefix_start_loc_host = (
|
||||
None
|
||||
if self.kv_prefix_start_loc_host is None
|
||||
else self.kv_prefix_start_loc_host[: self.num_prefills + 1]
|
||||
+ query_start_loc_host
|
||||
)
|
||||
kv_prefix_start_loc = (
|
||||
None
|
||||
if kv_prefix_start_loc_host is None
|
||||
else kv_prefix_start_loc_host.cuda()
|
||||
)
|
||||
slot_mapping = (
|
||||
None
|
||||
if self.slot_mapping is None
|
||||
else self.slot_mapping[: self.num_prefill_tokens]
|
||||
)
|
||||
seq_lens = None if self.seq_lens is None else self.seq_lens[: self.num_prefills]
|
||||
seq_lens_tensor = (
|
||||
None
|
||||
if self.seq_lens_tensor is None
|
||||
else self.seq_lens_tensor[: self.num_prefills]
|
||||
)
|
||||
context_lens_tensor = (
|
||||
None
|
||||
if self.context_lens_tensor is None
|
||||
else self.context_lens_tensor[: self.num_prefills]
|
||||
)
|
||||
query_start_loc_host = (None if self.query_start_loc_host is None else
|
||||
self.query_start_loc_host[:self.num_prefills + 1])
|
||||
kv_prefix_start_loc_host = (None if self.kv_prefix_start_loc_host is None else
|
||||
self.kv_prefix_start_loc_host[:self.num_prefills + 1] + query_start_loc_host)
|
||||
kv_prefix_start_loc = (None if kv_prefix_start_loc_host is None else kv_prefix_start_loc_host.cuda())
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
# for prefix cache, block table only contains blocks that hit
|
||||
# if self.block_tables is None:
|
||||
# block_tables = None
|
||||
# elif self.block_tables.shape[1] == 0:
|
||||
# block_tables = self.block_tables[:self.num_prefills]
|
||||
# else:
|
||||
# block_tables = self.block_tables[:self.num_prefills][:, -1].clone()
|
||||
|
||||
block_tables = (
|
||||
None
|
||||
if self.block_tables is None
|
||||
else self.block_tables[: self.num_prefills]
|
||||
)
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = KunlunMetadata(
|
||||
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
@@ -305,8 +257,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables,
|
||||
enable_kv_scales_calculation=False,
|
||||
seq_start_loc=self.seq_start_loc,
|
||||
)
|
||||
seq_start_loc=self.seq_start_loc)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
@@ -319,35 +270,25 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# Recover cached decode-phase attention
|
||||
# metadata structure
|
||||
return self._cached_decode_metadata
|
||||
assert (self.seq_lens_tensor is not None) or (
|
||||
self.encoder_seq_lens_tensor is not None
|
||||
)
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (
|
||||
None
|
||||
if self.slot_mapping is None
|
||||
else self.slot_mapping[self.num_prefill_tokens :]
|
||||
)
|
||||
seq_lens_tensor = (
|
||||
None
|
||||
if self.seq_lens_tensor is None
|
||||
else self.seq_lens_tensor[self.num_prefills :]
|
||||
)
|
||||
seq_lens_tensor_cpu = (
|
||||
None
|
||||
if self.seq_lens_tensor_cpu is None
|
||||
else self.seq_lens_tensor_cpu[self.num_prefills :]
|
||||
)
|
||||
block_tables = (
|
||||
None
|
||||
if self.block_tables is None
|
||||
else self.block_tables[self.num_prefills :]
|
||||
)
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
|
||||
self.seq_lens_tensor_cpu[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
|
||||
|
||||
|
||||
# Construct & cache decode-phase attention metadata structure
|
||||
self._cached_decode_metadata = KunlunMetadata(
|
||||
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
@@ -364,16 +305,13 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables,
|
||||
enable_kv_scales_calculation=False,
|
||||
)
|
||||
enable_kv_scales_calculation=False)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
|
||||
"""KunlunMetadataBuilder"""
|
||||
|
||||
_metadata_cls = KunlunMetadata
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
super().__init__(input_builder)
|
||||
self.prefix_cache_kv_lens: List[int] = []
|
||||
@@ -382,120 +320,90 @@ class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
|
||||
"""prepare"""
|
||||
super().prepare()
|
||||
self.prefix_cache_kv_lens = list()
|
||||
|
||||
def _add_seq_group(
|
||||
self,
|
||||
inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool,
|
||||
):
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
for (
|
||||
seq_id,
|
||||
token_len,
|
||||
seq_len,
|
||||
curr_seq_len,
|
||||
query_len,
|
||||
context_len,
|
||||
curr_sliding_window_block,
|
||||
) in zip(
|
||||
inter_data.seq_ids,
|
||||
[len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens,
|
||||
inter_data.seq_lens,
|
||||
inter_data.query_lens,
|
||||
inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks,
|
||||
):
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks):
|
||||
self.context_lens.append(context_len)
|
||||
if is_prompt:
|
||||
mm_maps = inter_data.multi_modal_placeholder_maps
|
||||
if mm_maps:
|
||||
for modality, placeholders in mm_maps.items():
|
||||
self.multimodal_placeholder_maps[modality].extend(placeholders)
|
||||
self.multimodal_placeholder_maps[modality].extend(
|
||||
placeholders)
|
||||
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
assert (
|
||||
query_len == 1
|
||||
), "seq_len: {}, context_len: {}, query_len: {}".format(
|
||||
seq_len, context_len, query_len
|
||||
)
|
||||
assert query_len == 1, (
|
||||
"seq_len: {}, context_len: {}, query_len: {}".format(
|
||||
seq_len, context_len, query_len))
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
# Compute block table.
|
||||
block_table = []
|
||||
assert (
|
||||
not chunked_prefill_enabled
|
||||
), "chunk prefill not supported for kunlun attention"
|
||||
assert not chunked_prefill_enabled, "chunk prefill not supported for kunlun attention"
|
||||
if inter_data.prefix_cache_hit:
|
||||
assert context_len != 0
|
||||
assert context_len % self.block_size == 0
|
||||
block_table = block_tables[seq_id][: context_len // self.block_size]
|
||||
elif (not is_prompt) and block_tables is not None:
|
||||
# block_table = block_tables[seq_id]
|
||||
block_table = block_tables[seq_id][:context_len // self.block_size]
|
||||
elif ((not is_prompt)
|
||||
and block_tables is not None):
|
||||
if curr_sliding_window_block == 0:
|
||||
block_table = block_tables[seq_id]
|
||||
else:
|
||||
block_table = block_tables[seq_id][-curr_sliding_window_block:]
|
||||
block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.block_tables.append(block_table)
|
||||
if is_prompt:
|
||||
self.prefix_cache_kv_lens.append(context_len)
|
||||
|
||||
# Compute slot mapping.
|
||||
is_profile_run = is_block_tables_empty(block_tables)
|
||||
start_idx = compute_slot_mapping_start_idx(
|
||||
is_prompt, query_len, context_len, self.sliding_window
|
||||
)
|
||||
compute_slot_mapping(
|
||||
is_profile_run,
|
||||
self.slot_mapping,
|
||||
seq_id,
|
||||
seq_len,
|
||||
context_len,
|
||||
start_idx,
|
||||
self.block_size,
|
||||
inter_data.block_tables,
|
||||
)
|
||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
||||
context_len,
|
||||
self.sliding_window)
|
||||
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
||||
seq_len, context_len, start_idx,
|
||||
self.block_size, inter_data.block_tables)
|
||||
|
||||
def build(
|
||||
self,
|
||||
seq_lens: List[int],
|
||||
query_lens: List[int],
|
||||
cuda_graph_pad_size: int,
|
||||
batch_size: int,
|
||||
):
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""build"""
|
||||
attn_meta = super().build(seq_lens, query_lens, cuda_graph_pad_size, batch_size)
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
query_start_loc_host = torch.tensor(
|
||||
query_start_loc, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
query_start_loc_host = torch.tensor(query_start_loc, dtype=torch.int32, device='cpu')
|
||||
attn_meta.query_start_loc_host = query_start_loc_host
|
||||
# max_kv_len = max(query_lens + prefix_cache_kv_lens)
|
||||
attn_meta.max_kv_len = max(self.prefix_cache_kv_lens + attn_meta.seq_lens)
|
||||
|
||||
# If kv cache is included and there is a hit
|
||||
# 包含kv cache ,且存在命中的情况
|
||||
if len(self.prefix_cache_kv_lens) != 0 and max(self.prefix_cache_kv_lens) != 0:
|
||||
self.prefix_cache_kv_lens = list(
|
||||
accumulate(self.prefix_cache_kv_lens, initial=0)
|
||||
)
|
||||
prefix_cache_kv_lens_tensor = torch.tensor(
|
||||
self.prefix_cache_kv_lens, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
self.prefix_cache_kv_lens = list(accumulate(self.prefix_cache_kv_lens, initial=0))
|
||||
prefix_cache_kv_lens_tensor = torch.tensor(self.prefix_cache_kv_lens, dtype=torch.int32, device="cpu")
|
||||
attn_meta.kv_prefix_start_loc_host = prefix_cache_kv_lens_tensor
|
||||
attn_meta.seq_lens_tensor_cpu = attn_meta.seq_lens_tensor.to("cpu")
|
||||
return attn_meta
|
||||
|
||||
|
||||
|
||||
def _get_seq_len_block_table_args(
|
||||
attn_metadata: KunlunMetadata,
|
||||
is_prompt: bool,
|
||||
attn_type: AttentionType,
|
||||
) -> tuple:
|
||||
"""
|
||||
'''
|
||||
The particular choice of sequence-length- and block-table-related
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
@@ -517,7 +425,7 @@ def _get_seq_len_block_table_args(
|
||||
* Appropriate sequence-lengths tensor
|
||||
* Appropriate max sequence-length scalar
|
||||
* Appropriate block tables (or None)
|
||||
"""
|
||||
'''
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Decoder self-attention
|
||||
@@ -526,26 +434,23 @@ def _get_seq_len_block_table_args(
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
else:
|
||||
max_seq_len = attn_metadata.max_decode_seq_len
|
||||
return (attn_metadata.seq_lens_tensor, max_seq_len, attn_metadata.block_tables)
|
||||
return (attn_metadata.seq_lens_tensor, max_seq_len,
|
||||
attn_metadata.block_tables)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.cross_block_tables,
|
||||
)
|
||||
return (attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.cross_block_tables)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
# No block tables associated with encoder attention
|
||||
return (
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
None,
|
||||
)
|
||||
return (attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len, None)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
|
||||
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
"""KunlunAttentionImpl"""
|
||||
|
||||
@@ -564,7 +469,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError("kunlunAttention does not support block-sparse attention.")
|
||||
raise ValueError(
|
||||
"kunlunAttention does not support block-sparse attention.")
|
||||
# if logits_soft_cap is not None:
|
||||
# raise ValueError(
|
||||
# "kunlunAttention does not support attention logits soft capping.")
|
||||
@@ -585,8 +491,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
if head_size not in suppored_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {suppored_head_sizes}."
|
||||
)
|
||||
f"Supported head sizes are: {suppored_head_sizes}.")
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -654,21 +560,16 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
|
||||
# Check that appropriate attention metadata attributes are
|
||||
# selected for the desired attention type
|
||||
if attn_type == AttentionType.ENCODER and (
|
||||
not attn_metadata.is_all_encoder_attn_metadata_set
|
||||
):
|
||||
raise AttributeError(
|
||||
"Encoder attention requires setting " "encoder metadata attributes."
|
||||
)
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
|
||||
elif attn_type == AttentionType.ENCODER_DECODER and (
|
||||
not attn_metadata.is_all_cross_attn_metadata_set
|
||||
):
|
||||
raise AttributeError(
|
||||
"Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes."
|
||||
)
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
@@ -682,7 +583,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# which KV cache memory-mapping & which
|
||||
# seqlen datastructures we utilize
|
||||
|
||||
if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
|
||||
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
|
||||
# KV-cache during decoder-self- or
|
||||
# encoder-decoder-cross-attention, but not
|
||||
# during encoder attention.
|
||||
@@ -691,8 +592,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# we still need to break out key_cache and value_cache
|
||||
# i.e. for later use by paged attention
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
if (key is not None) and (value is not None):
|
||||
|
||||
@@ -701,14 +601,10 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
else:
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
value = value.contiguous()
|
||||
KunlunOps.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
)
|
||||
KunlunOps.reshape_and_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype)
|
||||
|
||||
if attn_type == AttentionType.ENCODER:
|
||||
# Encoder attention - chunked prefill is not applicable;
|
||||
@@ -753,20 +649,14 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# Prompt run.
|
||||
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
|
||||
out = KunlunOps.multi_query_kv_attention(
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.query_start_loc_host,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
).view_as(query)
|
||||
prefill_meta.query_start_loc,prefill_meta.query_start_loc_host, query, key, value,
|
||||
alibi_slopes=self.alibi_slopes).view_as(query)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert (
|
||||
attn_type != AttentionType.ENCODER_ONLY
|
||||
), "Encoder-only models should not have decode metadata."
|
||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||
"Encoder-only models should not have decode metadata.")
|
||||
(
|
||||
seq_lens_arg,
|
||||
max_seq_len_arg,
|
||||
@@ -791,4 +681,4 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
@@ -4,13 +4,12 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, List, Dict, Any
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.distributed.kv_transfer import (
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group,
|
||||
)
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
@@ -20,10 +19,8 @@ from torch.library import custom_op, impl
|
||||
|
||||
from vllm.platforms import _Backend
|
||||
|
||||
|
||||
class Attention(VllmAttention):
|
||||
"""Attention"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
@@ -75,8 +72,11 @@ class Attention(VllmAttention):
|
||||
if attn_metadata.enable_kv_scales_calculation:
|
||||
self.calc_kv_scales(query, key, value)
|
||||
if self.use_output:
|
||||
output_shape = output_shape if output_shape is not None else query.shape
|
||||
output = torch.zeros(output_shape, dtype=query.dtype, device=query.device)
|
||||
output_shape = (output_shape
|
||||
if output_shape is not None else query.shape)
|
||||
output = torch.zeros(output_shape,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
# We skip reshaping query, key and value tensors for the MLA
|
||||
# backend since these tensors have different semantics and are
|
||||
@@ -97,13 +97,16 @@ class Attention(VllmAttention):
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(
|
||||
self, query, key, value, self_kv_cache, attn_metadata, output=output
|
||||
)
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self_kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
else:
|
||||
torch.ops.vllm.unified_attention_with_output_kunlun(
|
||||
query, key, value, output, self.layer_name
|
||||
)
|
||||
query, key, value, output, self.layer_name)
|
||||
return output.view(-1, hidden_size)
|
||||
else:
|
||||
if self.use_direct_call:
|
||||
@@ -112,15 +115,13 @@ class Attention(VllmAttention):
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
return self.impl.forward(
|
||||
self, query, key, value, self_kv_cache, attn_metadata
|
||||
)
|
||||
return self.impl.forward(self, query, key, value,
|
||||
self_kv_cache, attn_metadata)
|
||||
else:
|
||||
return unified_attention(query, key, value, self.layer_name)
|
||||
return unified_attention(
|
||||
query, key, value, self.layer_name)
|
||||
|
||||
|
||||
#
|
||||
# Rewritten from the MultiHeadAttention class in vllm.attention.layer
|
||||
# 重写自 vllm.attention.layer 中的 MultiHeadAttention 类
|
||||
class MultiHeadAttention(VllmMultiHeadAttention):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -130,15 +131,14 @@ class MultiHeadAttention(VllmMultiHeadAttention):
|
||||
num_kv_heads: Optional[int] = None,
|
||||
):
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads = num_heads,
|
||||
head_size = head_size,
|
||||
scale = scale,
|
||||
num_kv_heads = num_kv_heads,
|
||||
)
|
||||
|
||||
# kunlun only supports flash_attn
|
||||
# kunlun只支持flash_attn
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -159,31 +159,34 @@ class MultiHeadAttention(VllmMultiHeadAttention):
|
||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||
|
||||
# kunlun only supports flash_attn
|
||||
# kunlun只支持flash_attn
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
out = flash_attn_func(query, key, value, softmax_scale=self.scale)
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query, key, value, scale=self.scale
|
||||
)
|
||||
out = xops.memory_efficient_attention_forward(query,
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
out = F.scaled_dot_product_attention(query,
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
from torch_xla.experimental.custom_kernel import flash_attention
|
||||
|
||||
out = flash_attention(query, key, value, sm_scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
return out.reshape(bsz, q_len, -1)
|
||||
|
||||
|
||||
def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
"""wait_for_kv_layer_from_connector"""
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
@@ -198,10 +201,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
assert isinstance(attn_metadata, dict)
|
||||
connector.wait_for_layer_load(layer_name)
|
||||
|
||||
|
||||
def maybe_save_kv_layer_to_connector(
|
||||
layer_name: str, kv_cache_layer: List[torch.Tensor]
|
||||
):
|
||||
layer_name: str,
|
||||
kv_cache_layer: List[torch.Tensor]):
|
||||
"""maybe_save_kv_layer_to_connector"""
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return
|
||||
@@ -213,8 +215,8 @@ def maybe_save_kv_layer_to_connector(
|
||||
if attn_metadata is None:
|
||||
return
|
||||
assert isinstance(attn_metadata, dict)
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])
|
||||
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer,
|
||||
attn_metadata[layer_name])
|
||||
|
||||
@custom_op("vllm::unified_attention_with_output_kunlun", mutates_args=())
|
||||
def unified_attention_with_output_kunlun(
|
||||
@@ -223,8 +225,7 @@ def unified_attention_with_output_kunlun(
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
output_scale: Optional[torch.Tensor] = None,) -> None:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
@@ -232,26 +233,26 @@ def unified_attention_with_output_kunlun(
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(self, query, key, value, kv_cache, attn_metadata, output=output)
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
|
||||
|
||||
def _fake_unified_attention_with_output_kunlun(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
output_scale: Optional[torch.Tensor] = None,) -> None:
|
||||
return None
|
||||
|
||||
|
||||
unified_attention_with_output_kunlun.register_fake(
|
||||
_fake_unified_attention_with_output_kunlun
|
||||
)
|
||||
|
||||
unified_attention_with_output_kunlun.register_fake(_fake_unified_attention_with_output_kunlun)
|
||||
|
||||
def unified_attention(
|
||||
query: torch.Tensor,
|
||||
@@ -268,7 +269,8 @@ def unified_attention(
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
|
||||
output = self.impl.forward(self, query, key, value, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
return output
|
||||
return output
|
||||
9
vllm_kunlun/ops/fla/__init__.py
Normal file
9
vllm_kunlun/ops/fla/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .chunk import chunk_gated_delta_rule
|
||||
from .fused_recurrent import fused_recurrent_gated_delta_rule
|
||||
from .layernorm_guard import RMSNormGated
|
||||
from .torch_fla import l2norm, torch_chunk_gated_delta_rule
|
||||
__all__ = [
|
||||
"RMSNormGated",
|
||||
"chunk_gated_delta_rule",
|
||||
"fused_recurrent_gated_delta_rule",
|
||||
]
|
||||
247
vllm_kunlun/ops/fla/chunk.py
Normal file
247
vllm_kunlun/ops/fla/chunk.py
Normal file
@@ -0,0 +1,247 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import warnings
|
||||
from typing import Optional
|
||||
import torch.nn.functional as F
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from einops import rearrange
|
||||
|
||||
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
||||
from .chunk_o import chunk_fwd_o
|
||||
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
||||
from .cumsum import chunk_local_cumsum
|
||||
from .l2norm import l2norm_fwd
|
||||
from .solve_tril import solve_tril
|
||||
from .utils import SUPPRESS_LEVEL, input_guard
|
||||
from .wy_fast import recompute_w_u_fwd
|
||||
|
||||
|
||||
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
|
||||
chunk_size=64
|
||||
A = A.transpose(1,2)
|
||||
sequence_length = A.shape[-2]
|
||||
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||
A = F.pad(A, (0, 0, 0, pad_size))
|
||||
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0)
|
||||
|
||||
A = A.masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = A[..., i, :i].clone()
|
||||
sub = A[..., :i, :i].clone()
|
||||
A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
|
||||
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[:,:,:sequence_length,:].transpose(1,2)
|
||||
|
||||
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
A = chunk_scaled_dot_kkt_fwd(k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
output_dtype=q.dtype)
|
||||
|
||||
#torch版
|
||||
for i in range(len(cu_seqlens)-1):
|
||||
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||
A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)
|
||||
w, u = recompute_w_u_fwd(
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
A=A,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||
k=k,
|
||||
w=w,
|
||||
u=u,
|
||||
g=g,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
o = chunk_fwd_o(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_new,
|
||||
h=h,
|
||||
g=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
if SUPPRESS_LEVEL < 3:
|
||||
return g, o, A, final_state, None, None, None
|
||||
elif SUPPRESS_LEVEL >= 3:
|
||||
return g, o, A, final_state, w, h, v_new
|
||||
|
||||
|
||||
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
@torch.amp.custom_fwd(device_type='cuda')
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q = l2norm_fwd(q)
|
||||
k = l2norm_fwd(k)
|
||||
|
||||
g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
scale=scale,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
ctx.scale = scale
|
||||
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
|
||||
@torch.compiler.disable
|
||||
def chunk_gated_delta_rule(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
|
||||
k (torch.Tensor):
|
||||
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
|
||||
v (torch.Tensor):
|
||||
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
g (torch.Tensor):
|
||||
(forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
|
||||
beta (torch.Tensor):
|
||||
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, H, K, V]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
head_first (Optional[bool]):
|
||||
Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
|
||||
Default: `False`.
|
||||
|
||||
Returns:
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> import torch.nn.functional as F
|
||||
>>> from einops import rearrange
|
||||
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
|
||||
# inputs with equal lengths
|
||||
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
|
||||
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
|
||||
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
|
||||
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
|
||||
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> o, ht = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True
|
||||
)
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> o_var, ht_var = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens
|
||||
)
|
||||
"""
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
||||
assert len(
|
||||
beta.shape
|
||||
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||
|
||||
if head_first:
|
||||
raise DeprecationWarning(
|
||||
"head_first is deprecated and will be removed in a future version. "
|
||||
"Please use head_first=False for now instead.",
|
||||
stacklevel=2)
|
||||
q, k, v, beta, g = map(
|
||||
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
|
||||
(q, k, v, beta, g))
|
||||
if not head_first and q.shape[1] < q.shape[2]:
|
||||
warnings.warn(
|
||||
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
|
||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
"when head_first=False was specified. "
|
||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||
stacklevel=2)
|
||||
if cu_seqlens is not None:
|
||||
if q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
if initial_state is not None and initial_state.shape[0] != len(
|
||||
cu_seqlens) - 1:
|
||||
raise ValueError(
|
||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
|
||||
use_qk_l2norm_in_kernel)
|
||||
if head_first:
|
||||
o = rearrange(o, 'b t h ... -> b h t ...')
|
||||
return o, final_state
|
||||
251
vllm_kunlun/ops/fla/chunk_delta_h.py
Normal file
251
vllm_kunlun/ops/fla/chunk_delta_h.py
Normal file
@@ -0,0 +1,251 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices, prepare_chunk_offsets
|
||||
from .op import exp
|
||||
from .utils import is_nvidia_hopper, use_cuda_graph
|
||||
|
||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
|
||||
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
v_new,
|
||||
g,
|
||||
h,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr,
|
||||
SAVE_NEW_VALUE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_v, i_nh = tl.program_id(0), tl.program_id(1)
|
||||
i_n, i_h = i_nh // H, i_nh % H
|
||||
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = i_n * NT
|
||||
|
||||
# [BK, BV]
|
||||
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 64:
|
||||
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 128:
|
||||
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 192:
|
||||
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
|
||||
# calculate offset
|
||||
h += (boh * H + i_h) * K * V
|
||||
v += (bos * H + i_h) * V
|
||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||
w += (bos * H + i_h) * K
|
||||
if SAVE_NEW_VALUE:
|
||||
v_new += (bos * H + i_h) * V
|
||||
stride_v = H * V
|
||||
stride_h = H * K * V
|
||||
stride_k = Hg * K
|
||||
stride_w = H * K
|
||||
if USE_INITIAL_STATE:
|
||||
h0 = h0 + i_nh * K * V
|
||||
if STORE_FINAL_STATE:
|
||||
ht = ht + i_nh * K * V
|
||||
|
||||
# load initial state
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 64:
|
||||
p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
|
||||
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 128:
|
||||
p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
|
||||
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 192:
|
||||
p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
|
||||
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
# main recurrence
|
||||
for i_t in range(NT):
|
||||
p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_v_new = (
|
||||
tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
if SAVE_NEW_VALUE
|
||||
else None
|
||||
)
|
||||
b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
|
||||
if K > 64:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
|
||||
if K > 128:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
|
||||
if K > 192:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
|
||||
|
||||
b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
if SAVE_NEW_VALUE:
|
||||
p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
if USE_G:
|
||||
m_t = (i_t * BT + tl.arange(0, BT)) < T
|
||||
last_idx = min((i_t + 1) * BT, T) - 1
|
||||
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
|
||||
p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
b_g = tl.load(p_g, boundary_check=(0,))
|
||||
b_v_new = b_v_new * tl.where(m_t, tl.exp(b_g_last - b_g), 0)[:, None]
|
||||
b_g_last = tl.exp(b_g_last)
|
||||
b_h1 = b_h1 * b_g_last
|
||||
if K > 64:
|
||||
b_h2 = b_h2 * b_g_last
|
||||
if K > 128:
|
||||
b_h3 = b_h3 * b_g_last
|
||||
if K > 192:
|
||||
b_h4 = b_h4 * b_g_last
|
||||
b_v_new = b_v_new.to(k.dtype.element_ty)
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h1 += tl.dot(b_k, b_v_new)
|
||||
if K > 64:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h2 += tl.dot(b_k, b_v_new)
|
||||
if K > 128:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h3 += tl.dot(b_k, b_v_new)
|
||||
if K > 192:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h4 += tl.dot(b_k, b_v_new)
|
||||
|
||||
# epilogue
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd_h(
|
||||
k: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, u.shape[-1]
|
||||
H = u.shape[-2]
|
||||
BT = chunk_size
|
||||
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
||||
# N: the actual number of sequences in the batch with either equal or variable lengths
|
||||
if cu_seqlens is None:
|
||||
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
||||
else:
|
||||
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(
|
||||
chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
|
||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||
|
||||
h = k.new_empty(B, NT, H, K, V)
|
||||
final_state = k.new_empty(
|
||||
N, H, K, V, dtype=torch.float32) if output_final_state else None
|
||||
|
||||
v_new = torch.empty_like(u) if save_new_value else None
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta['BV']), N * H)
|
||||
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
|
||||
k=k,
|
||||
v=u,
|
||||
w=w,
|
||||
v_new=v_new,
|
||||
g=g,
|
||||
h=h,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_offsets=chunk_offsets,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BV=64,
|
||||
)
|
||||
return h, v_new, final_state
|
||||
180
vllm_kunlun/ops/fla/chunk_o.py
Normal file
180
vllm_kunlun/ops/fla/chunk_o.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .op import exp
|
||||
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
|
||||
|
||||
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
|
||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_G': lambda args: args['g'] is not None,
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
||||
})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({
|
||||
# 'BK': BK,
|
||||
# 'BV': BV
|
||||
# },
|
||||
# num_warps=num_warps,
|
||||
# num_stages=num_stages) for BK in BKV_LIST
|
||||
# for BV in BKV_LIST for num_warps in NUM_WARPS
|
||||
# for num_stages in [2, 3, 4]
|
||||
# ],
|
||||
# key=['H', 'K', 'V', 'BT'],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
scale,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
|
||||
if IS_VARLEN:
|
||||
i_tg = i_t
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
else:
|
||||
NT = tl.cdiv(T, BT)
|
||||
i_tg = i_b * NT + i_t
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
# offset calculation
|
||||
q += (bos * Hg + i_h // (H // Hg)) * K
|
||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||
v += (bos * H + i_h) * V
|
||||
o += (bos * H + i_h) * V
|
||||
h += (i_tg * H + i_h).to(tl.int64) * K * V
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
|
||||
(BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT),
|
||||
(BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV),
|
||||
(BK, BV), (1, 0))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BK] @ [BK, BV] -> [BT, BV]
|
||||
b_o += tl.dot(b_q, b_h)
|
||||
# [BT, BK] @ [BK, BT] -> [BT, BT]
|
||||
b_A += tl.dot(b_q, b_k)
|
||||
|
||||
if USE_G:
|
||||
g += bos * H + i_h
|
||||
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_o = b_o * tl.exp(b_g)[:, None]
|
||||
b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :])
|
||||
|
||||
o_t = i_t * BT + tl.arange(0, BT)
|
||||
# m_t = o_t < T
|
||||
# m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
|
||||
# b_A = tl.where(m_A, b_A, 0)
|
||||
b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0)
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
# to fix mma -> mma layout conversion
|
||||
# already solved by triton v3.2 or higher
|
||||
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_fwd_o(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None, # cumsum of log decay
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
if FLA_GDN_FIX_BT:
|
||||
BT = 64
|
||||
else:
|
||||
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
|
||||
o = torch.empty_like(v)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta['BV']), NT, B * H)
|
||||
|
||||
chunk_fwd_kernel_o[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
scale,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=64,
|
||||
BV=32
|
||||
)
|
||||
return o
|
||||
144
vllm_kunlun/ops/fla/chunk_scaled_dot_kkt.py
Normal file
144
vllm_kunlun/ops/fla/chunk_scaled_dot_kkt.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .op import exp
|
||||
|
||||
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
|
||||
'USE_G': lambda args: args['g_cumsum'] is not None
|
||||
})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for BK in [32, 64, 128] for num_warps in [2, 4, 8]
|
||||
# for num_stages in [2, 3, 4]
|
||||
# ],
|
||||
# key=['H', 'K', 'BT', 'IS_VARLEN'],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
k,
|
||||
beta,
|
||||
g_cumsum,
|
||||
A,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
o_t = i_t * BT + tl.arange(0, BT)
|
||||
#m_t = o_t < T
|
||||
|
||||
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_beta = tl.load(p_beta, boundary_check=(0, ))
|
||||
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K),
|
||||
(Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK),
|
||||
(1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = b_k * b_beta[:, None]
|
||||
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
|
||||
|
||||
if USE_G:
|
||||
p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_g_diff = b_g[:, None] - b_g[None, :]
|
||||
b_A = b_A * tl.exp(b_g_diff) # 使用了triton而非vllm中的exp
|
||||
|
||||
#m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
|
||||
#b_A = tl.where(m_A, b_A, 0)
|
||||
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
|
||||
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
|
||||
(i_t * BT, 0), (BT, BT), (1, 0))
|
||||
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
r"""
|
||||
Compute beta * K * K^T.
|
||||
|
||||
Args:
|
||||
k (torch.Tensor):
|
||||
The key tensor of shape `[B, T, H, K]`.
|
||||
beta (torch.Tensor):
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
g_cumsum (torch.Tensor):
|
||||
The cumulative sum of the gate tensor of shape `[B, T, H]`.
|
||||
Default: None
|
||||
cu_seqlens (torch.LongTensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
The chunk size. Default: 64.
|
||||
output_dtype (torch.dtype):
|
||||
The dtype of the output tensor. Default: `torch.float32`
|
||||
|
||||
Returns:
|
||||
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
|
||||
"""
|
||||
|
||||
B, T, Hg, K = k.shape
|
||||
|
||||
H = beta.shape[-1]
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
|
||||
chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g_cumsum,
|
||||
A=A,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
BT=BT,
|
||||
BK=64,
|
||||
)
|
||||
return A
|
||||
229
vllm_kunlun/ops/fla/cumsum.py
Normal file
229
vllm_kunlun/ops/fla/cumsum.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .utils import check_shared_mem, input_guard
|
||||
|
||||
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
# @triton.autotune(configs=[
|
||||
# triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]
|
||||
# ],
|
||||
# key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE'])
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_local_cumsum_scalar_kernel(
|
||||
s,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
HEAD_FIRST: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
if HEAD_FIRST:
|
||||
p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T, ), (1, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T, ), (1, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
else:
|
||||
p_s = tl.make_block_ptr(s + bos * H + i_h, (T, ), (H, ), (i_t * BT, ),
|
||||
(BT, ), (0, ))
|
||||
p_o = tl.make_block_ptr(o + bos * H + i_h, (T, ), (H, ), (i_t * BT, ),
|
||||
(BT, ), (0, ))
|
||||
# [BT]
|
||||
b_s = tl.load(p_s, boundary_check=(0, )).to(tl.float32)
|
||||
b_o = tl.cumsum(b_s, axis=0)
|
||||
if REVERSE:
|
||||
b_z = tl.sum(b_s, axis=0)
|
||||
b_o = -b_o + b_z[None] + b_s
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, ))
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
# @triton.autotune(configs=[
|
||||
# triton.Config({'BS': BS}, num_warps=num_warps) for BS in BS_LIST
|
||||
# for num_warps in [2, 4, 8]
|
||||
# ],
|
||||
# key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE'])
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_local_cumsum_vector_kernel(
|
||||
s,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
S: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BS: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
HEAD_FIRST: tl.constexpr,
|
||||
):
|
||||
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
if REVERSE:
|
||||
m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
|
||||
else:
|
||||
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
|
||||
|
||||
if HEAD_FIRST:
|
||||
p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
else:
|
||||
p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
# [BT, BS]
|
||||
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_o = tl.dot(m_s, b_s, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_local_cumsum_scalar(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor:
|
||||
if head_first:
|
||||
B, H, T = g.shape
|
||||
else:
|
||||
B, T, H = g.shape
|
||||
assert chunk_size == 2**(chunk_size.bit_length() -
|
||||
1), "chunk_size must be a power of 2"
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
grid = (NT, B * H)
|
||||
chunk_local_cumsum_scalar_kernel[grid](
|
||||
s=g_org,
|
||||
o=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
BT=BT,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse,
|
||||
is_use_mask_zero = True
|
||||
)
|
||||
return g
|
||||
|
||||
|
||||
def chunk_local_cumsum_vector(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor:
|
||||
if head_first:
|
||||
B, H, T, S = g.shape
|
||||
else:
|
||||
B, T, H, S = g.shape
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
assert chunk_size == 2**(chunk_size.bit_length() -
|
||||
1), "chunk_size must be a power of 2"
|
||||
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
|
||||
|
||||
# keep cumulative normalizer in fp32
|
||||
# this kernel is equivalent to
|
||||
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
|
||||
chunk_local_cumsum_vector_kernel[grid](g_org,
|
||||
g,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
S=S,
|
||||
BT=BT,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse)
|
||||
return g
|
||||
|
||||
|
||||
@input_guard
|
||||
def chunk_local_cumsum(g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
**kwargs) -> torch.Tensor:
|
||||
if not head_first and g.shape[1] < g.shape[2]:
|
||||
warnings.warn(
|
||||
f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
|
||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
"when head_first=False was specified. "
|
||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||
stacklevel=2)
|
||||
if cu_seqlens is not None:
|
||||
assert g.shape[
|
||||
0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
|
||||
if len(g.shape) == 3:
|
||||
return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens,
|
||||
head_first, output_dtype)
|
||||
elif len(g.shape) == 4:
|
||||
return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens,
|
||||
head_first, output_dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported input shape {g.shape}. "
|
||||
f"which should be (B, T, H, D) if `head_first=False` "
|
||||
f"or (B, H, T, D) otherwise")
|
||||
153
vllm_kunlun/ops/fla/fused_recurrent.py
Normal file
153
vllm_kunlun/ops/fla/fused_recurrent.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import xtorch_ops
|
||||
|
||||
|
||||
class FusedRecurrentFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
|
||||
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2(
|
||||
q.contiguous(),
|
||||
k.contiguous(),
|
||||
v.contiguous(),
|
||||
g.contiguous(),
|
||||
beta.contiguous(),
|
||||
scale,
|
||||
initial_state,
|
||||
inplace_final_state=inplace_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
h0_indices=ssm_state_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
||||
)
|
||||
return o, final_state
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor = None,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `[B, T, H, K]`.
|
||||
k (torch.Tensor):
|
||||
keys of shape `[B, T, H, K]`.
|
||||
v (torch.Tensor):
|
||||
values of shape `[B, T, HV, V]`.
|
||||
GVA is applied if `HV > H`.
|
||||
g (torch.Tensor):
|
||||
g (decays) of shape `[B, T, HV]`.
|
||||
beta (torch.Tensor):
|
||||
betas of shape `[B, T, HV]`.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, HV, K, V]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
inplace_final_state: bool:
|
||||
Whether to store the final state in-place to save memory.
|
||||
Default: `True`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
ssm_state_indices (Optional[torch.Tensor]):
|
||||
Indices to map the input sequences to the initial/final states.
|
||||
num_accepted_tokens (Optional[torch.Tensor]):
|
||||
Number of accepted tokens for each sequence during decoding.
|
||||
|
||||
Returns:
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, HV, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, HV, K, V]`.
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> import torch.nn.functional as F
|
||||
>>> from einops import rearrange
|
||||
>>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
|
||||
# inputs with equal lengths
|
||||
>>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
|
||||
>>> q = torch.randn(B, T, H, K, device='cuda')
|
||||
>>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
|
||||
>>> v = torch.randn(B, T, HV, V, device='cuda')
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
|
||||
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
|
||||
>>> h0 = torch.randn(B, HV, K, V, device='cuda')
|
||||
>>> o, ht = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
)
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
cu_seqlens=cu_seqlens
|
||||
)
|
||||
"""
|
||||
if cu_seqlens is not None and q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
else:
|
||||
assert scale > 0, "scale must be positive"
|
||||
if beta is None:
|
||||
beta = torch.ones_like(q[..., 0])
|
||||
o, final_state = FusedRecurrentFunction.apply(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale,
|
||||
initial_state,
|
||||
inplace_final_state,
|
||||
cu_seqlens,
|
||||
ssm_state_indices,
|
||||
num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel,
|
||||
)
|
||||
return o, final_state
|
||||
38
vllm_kunlun/ops/fla/index.py
Normal file
38
vllm_kunlun/ops/fla/index.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
from .utils import tensor_cache
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
||||
return cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_indices(cu_seqlens: torch.LongTensor,
|
||||
chunk_size: int) -> torch.LongTensor:
|
||||
indices = torch.cat([
|
||||
torch.arange(n)
|
||||
for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
|
||||
])
|
||||
return torch.stack([indices.eq(0).cumsum(0) - 1, indices],
|
||||
1).to(cu_seqlens)
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor,
|
||||
chunk_size: int) -> torch.LongTensor:
|
||||
return torch.cat([
|
||||
cu_seqlens.new_tensor([0]),
|
||||
triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
|
||||
]).cumsum(-1)
|
||||
143
vllm_kunlun/ops/fla/l2norm.py
Normal file
143
vllm_kunlun/ops/fla/l2norm.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
BT_LIST = [8, 16, 32, 64, 128]
|
||||
|
||||
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
|
||||
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({}, num_warps=num_warps)
|
||||
for num_warps in [1, 2, 4, 8, 16, 32]
|
||||
],
|
||||
key=['D'])
|
||||
@triton.jit
|
||||
def l2norm_fwd_kernel1(
|
||||
x,
|
||||
y,
|
||||
D,
|
||||
BD: tl.constexpr,
|
||||
eps,
|
||||
):
|
||||
i_t = tl.program_id(0)
|
||||
x += i_t * D
|
||||
y += i_t * D
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BD)
|
||||
mask = cols < D
|
||||
b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
|
||||
b_var = tl.sum(b_x * b_x, axis=0)
|
||||
b_rstd = 1 / tl.sqrt(b_var + eps)
|
||||
# tl.store(Rstd + i_t, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
b_y = b_x * b_rstd
|
||||
tl.store(y + cols, b_y, mask=mask)
|
||||
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({'BT': BT}, num_warps=num_warps)
|
||||
for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
|
||||
],
|
||||
key=['D'])
|
||||
@triton.jit(do_not_specialize=["NB"])
|
||||
def l2norm_fwd_kernel(
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
NB,
|
||||
T,
|
||||
D: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BD: tl.constexpr,
|
||||
):
|
||||
i_t = tl.program_id(0)
|
||||
p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
|
||||
b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_var = tl.sum(b_x * b_x, axis=1)
|
||||
b_y = b_x / tl.sqrt(b_var + eps)[:, None]
|
||||
p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
|
||||
tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
|
||||
xoffset = tl.program_id(0) * MBLOCK
|
||||
row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
|
||||
xmask = row_idx < M
|
||||
rindex = tl.arange(0, N)[None, :]
|
||||
xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
|
||||
square = tl.broadcast_to(xs * xs, [MBLOCK, N])
|
||||
square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
|
||||
rsqrt = tl.rsqrt(square_sum + eps)
|
||||
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
|
||||
|
||||
|
||||
def l2norm_fwd(x: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
output_dtype: Optional[torch.dtype] = None):
|
||||
x_shape_og = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
# allocate output
|
||||
if output_dtype is None:
|
||||
y = torch.empty_like(x)
|
||||
else:
|
||||
y = torch.empty_like(x, dtype=output_dtype)
|
||||
assert y.stride(-1) == 1
|
||||
T, D = x.shape[0], x.shape[-1]
|
||||
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
|
||||
if D > BD:
|
||||
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
|
||||
|
||||
if not USE_DEFAULT_FLA_NORM:
|
||||
MBLOCK = 32
|
||||
# M, N = x.shape
|
||||
l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
T,
|
||||
D,
|
||||
MBLOCK,
|
||||
)
|
||||
else:
|
||||
if D <= 512:
|
||||
NB = triton.cdiv(T, 2048)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(T, meta['BT']), )
|
||||
|
||||
l2norm_fwd_kernel[grid](
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
NB=NB,
|
||||
T=T,
|
||||
D=D,
|
||||
BD=BD,
|
||||
)
|
||||
else:
|
||||
l2norm_fwd_kernel1[(T, )](
|
||||
x,
|
||||
y,
|
||||
eps=eps,
|
||||
D=D,
|
||||
BD=BD,
|
||||
)
|
||||
|
||||
return y.view(x_shape_og)
|
||||
343
vllm_kunlun/ops/fla/layernorm_guard.py
Normal file
343
vllm_kunlun/ops/fla/layernorm_guard.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Tri Dao
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
|
||||
# ruff: noqa: E501
|
||||
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
||||
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
||||
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .utils import input_guard
|
||||
|
||||
|
||||
def rms_norm_ref(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
upcast=True):
|
||||
dtype = x.dtype
|
||||
weight = weight.float()
|
||||
bias = bias.float() if bias is not None else None
|
||||
if upcast:
|
||||
x = x.float()
|
||||
z = z.float() if z is not None else z
|
||||
if z is not None and not norm_before_gate:
|
||||
x = x * F.silu(z)
|
||||
if group_size is None:
|
||||
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
|
||||
weight)
|
||||
else:
|
||||
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
|
||||
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
|
||||
eps)
|
||||
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
if z is not None and norm_before_gate:
|
||||
out *= F.silu(z)
|
||||
return out.to(dtype)
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"HAS_BIAS": lambda args: args["B"] is not None,
|
||||
"HAS_Z": lambda args: args["Z"] is not None,
|
||||
})
|
||||
@triton.jit
|
||||
def layer_norm_fwd_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_z_row,
|
||||
M, # number of rows in X
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
X += row * stride_x_row + group * N
|
||||
Y += row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z += row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean += group * M
|
||||
Rstd += group * M
|
||||
W += group * N
|
||||
if HAS_BIAS:
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def layer_norm_fwd(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
eps: float,
|
||||
z: torch.Tensor = None,
|
||||
out: torch.Tensor = None,
|
||||
group_size: int = None,
|
||||
norm_before_gate: bool = True,
|
||||
is_rms_norm: bool = False,
|
||||
):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
# if weight.shape != (N,):
|
||||
# weight = weight.reshape(N)
|
||||
# print("weight",weight.shape)
|
||||
# print("x",x.shape)
|
||||
assert weight.shape == (N, )
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N, )
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
|
||||
device=x.device) if not is_rms_norm else None
|
||||
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
layer_norm_fwd_kernel[grid](x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
class LayerNormFn(torch.autograd.Function):
|
||||
|
||||
@input_guard
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
# if z.shape != x_shape_og:
|
||||
# z = z.reshape(x_shape_og)
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, mean, rstd = layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
|
||||
ctx.x_shape_og = x_shape_og
|
||||
ctx.eps = eps
|
||||
ctx.group_size = group_size
|
||||
ctx.norm_before_gate = norm_before_gate
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
def layernorm_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False):
|
||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
|
||||
norm_before_gate, is_rms_norm)
|
||||
|
||||
|
||||
def rmsnorm_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True):
|
||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
|
||||
norm_before_gate, True)
|
||||
|
||||
|
||||
class LayerNormGated(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
eps: float = 1e-5,
|
||||
group_size: Optional[int] = None,
|
||||
norm_before_gate: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
"""
|
||||
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.group_size = group_size
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
torch.nn.init.zeros_(self.bias)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
return layernorm_fn(x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
group_size=self.group_size,
|
||||
eps=self.eps,
|
||||
norm_before_gate=self.norm_before_gate)
|
||||
|
||||
|
||||
class RMSNormGated(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
eps: float = 1e-5,
|
||||
group_size: Optional[int] = None,
|
||||
norm_before_gate: bool = False,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter("bias", None)
|
||||
self.group_size = group_size
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
return rmsnorm_fn(x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate)
|
||||
39
vllm_kunlun/ops/fla/op.py
Normal file
39
vllm_kunlun/ops/fla/op.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import os
|
||||
|
||||
from vllm.triton_utils import tl, tldevice, triton
|
||||
|
||||
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
|
||||
div = tldevice.fast_dividef
|
||||
exp = tldevice.fast_expf
|
||||
log = tldevice.fast_logf
|
||||
log2 = tldevice.fast_log2f
|
||||
else:
|
||||
|
||||
@triton.jit
|
||||
def div_normal(x, y):
|
||||
return x / y
|
||||
|
||||
div = div_normal
|
||||
exp = tl.exp
|
||||
log = tl.log
|
||||
log2 = tl.log2
|
||||
|
||||
|
||||
if not hasattr(tl, 'gather'):
|
||||
|
||||
@triton.jit
|
||||
def gather(src, index, axis, _builder=None):
|
||||
# This is a fallback implementation when tl.gather is not supported
|
||||
# In order to pass triton compiler, there is no actual gather operation
|
||||
return src
|
||||
else:
|
||||
gather = tl.gather
|
||||
422
vllm_kunlun/ops/fla/solve_tril.py
Normal file
422
vllm_kunlun/ops/fla/solve_tril.py
Normal file
@@ -0,0 +1,422 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import os
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .utils import input_guard
|
||||
|
||||
base_dir = os.path.dirname(__file__)
|
||||
|
||||
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
||||
return cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
def prepare_chunk_indices(
|
||||
cu_seqlens: torch.LongTensor, chunk_size: int
|
||||
) -> torch.LongTensor:
|
||||
indices = torch.cat(
|
||||
[
|
||||
torch.arange(n)
|
||||
for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
|
||||
]
|
||||
)
|
||||
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [1, 2, 4, 8]
|
||||
# for num_stages in [2, 3, 4, 5]
|
||||
# ],
|
||||
# key=["BT"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def solve_tril_16x16_kernel(
|
||||
A,
|
||||
Ad,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
A = A + (bos * H + i_h) * BT
|
||||
Ad = Ad + (bos * H + i_h) * 16
|
||||
|
||||
offset = (i_t * 16) % BT
|
||||
p_A = tl.make_block_ptr(
|
||||
A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0))
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)
|
||||
|
||||
o_i = tl.arange(0, 16)
|
||||
for i in range(1, min(16, T - i_t * 16)):
|
||||
b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
|
||||
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
|
||||
mask = o_i == i
|
||||
b_A = tl.where(mask[:, None], b_a, b_A)
|
||||
b_A += o_i[:, None] == o_i[None, :]
|
||||
tl.store(
|
||||
p_Ai,
|
||||
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def solve_tril_16x16_kernel_modified(
|
||||
i_t,
|
||||
i_bh,
|
||||
i_n,
|
||||
bos,
|
||||
i_b,
|
||||
i_h,
|
||||
subA,
|
||||
subAd,
|
||||
A,
|
||||
Ad,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T, # 32
|
||||
H: tl.constexpr, # 4
|
||||
BT: tl.constexpr, # 64
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
A = A + (bos * H + i_h) * BT
|
||||
print("for A Base offset ", (bos * H + i_h) * BT)
|
||||
|
||||
offset = (i_t * 16) % BT
|
||||
|
||||
range16 = tl.arange(0, 16)
|
||||
newp_A = subA + range16[:, None] * 16 + range16[None, :]
|
||||
b_A = tl.load(newp_A).to(tl.float32)
|
||||
|
||||
o_i = tl.arange(0, 16)
|
||||
for i in range(1, min(16, T - i_t * 16)):
|
||||
print("[naive impl-0]loopIdx:", i)
|
||||
# print("for A start (i_t * 16 + i) * H * BT", (i_t * 16 + i) * H * BT)
|
||||
# print("for A start offset", offset)
|
||||
# print("for A start", (i_t * 16 + i) * H * BT + offset)
|
||||
print("[naive impl-1]b_A value in now loopIdx:", b_A)
|
||||
b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
|
||||
# print("[naive impl-2]b_a value in now loopIdx:", b_a)
|
||||
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
|
||||
print("[naive impl-2-1]b_a value after reduce in now loopIdx:", b_a)
|
||||
mask = o_i == i
|
||||
b_A = tl.where(mask[:, None], b_a, b_A)
|
||||
print("[naive impl-2-2]b_A value after oimask in now loopIdx:", b_A)
|
||||
# print("[naive impl-3]b_A result in now loopIdx:", b_A)
|
||||
# print(f"[naive impl-4] b_A value after allLoop = {b_A}")
|
||||
b_A += o_i[:, None] == o_i[None, :]
|
||||
# print(f"[naive impl-5] b_A value after mask = {b_A}")
|
||||
|
||||
newp_Ad = subAd + range16[:, None] * 16 + range16[None, :]
|
||||
tl.store(
|
||||
newp_Ad,
|
||||
b_A.to(subAd.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [1, 2, 4, 8]
|
||||
# for num_stages in [2, 3, 4, 5]
|
||||
# ],
|
||||
# key=["BT"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def solve_tril_16x16_kernel_modified_in_Loop(
|
||||
i_t,
|
||||
i_bh,
|
||||
i_n,
|
||||
bos,
|
||||
i_b,
|
||||
i_h,
|
||||
subA,
|
||||
subAd,
|
||||
AInLoop,
|
||||
ba_reduce,
|
||||
loopIdx,
|
||||
reduce_res,
|
||||
A,
|
||||
Ad,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T, # 32
|
||||
H: tl.constexpr, # 4
|
||||
BT: tl.constexpr, # 64
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
range16 = tl.arange(0, 16)
|
||||
newp_A = subA + range16[:, None] * 16 + range16[None, :]
|
||||
b_A = tl.load(newp_A).to(tl.float32)
|
||||
# print("[loop impl-0]loopIdx:", loopIdx)
|
||||
# print("[loop impl-1]b_A value in now loopIdx:", b_A)
|
||||
|
||||
o_i = tl.arange(0, 16)
|
||||
i=loopIdx
|
||||
b_a = -tl.load(AInLoop + o_i)
|
||||
# print("[loop impl-2]b_a value in now loopIdx:", b_a)
|
||||
red_res = b_a[:, None] * b_A
|
||||
# print("[Triton]red_res=", red_res)
|
||||
tl.store(reduce_res + range16[:, None] * 16 + range16[None, :], red_res)
|
||||
# b_a = b_a + tl.sum(b_a[:, None] * b_A, 1) # TODO: revert to 0
|
||||
# # print("triton reduce b_a", b_a)
|
||||
# tl.store(ba_reduce + o_i, b_a)
|
||||
|
||||
# mask = o_i == i
|
||||
# # print("mask", mask[:, None])
|
||||
# # print("b_a", b_a)
|
||||
# # print("b_A", b_A)
|
||||
# print("before b_A", b_A)
|
||||
# b_A = tl.where(mask[:, None], b_a, b_A)
|
||||
# print("[loop impl-3]b_A result in now loopIdx:", b_A)
|
||||
|
||||
# tl.store(newp_A, b_A)
|
||||
|
||||
|
||||
def solve_tril_16x16_kernel_new(
|
||||
NT,
|
||||
B,
|
||||
A,
|
||||
Ad,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H,
|
||||
BT,
|
||||
IS_VARLEN,
|
||||
):
|
||||
Ad_modify = Ad
|
||||
for loopX in range(NT):
|
||||
# i_n, i_t = tl.load(chunk_indices ...
|
||||
chunk_indices_load_offset_1 = loopX * 2
|
||||
row_idx = chunk_indices_load_offset_1 // chunk_indices.shape[1]
|
||||
col_idx = chunk_indices_load_offset_1 % chunk_indices.shape[1]
|
||||
i_n = int(chunk_indices[row_idx, col_idx])
|
||||
chunk_indices_load_offset_2 = loopX * 2 + 1
|
||||
row_idx = chunk_indices_load_offset_2 // chunk_indices.shape[1]
|
||||
col_idx = chunk_indices_load_offset_2 % chunk_indices.shape[1]
|
||||
i_t = int(chunk_indices[row_idx, col_idx])
|
||||
|
||||
# bos, eos = tl.load(cu_seqlens ...
|
||||
cu_seqlens_load_offset_1 = i_n
|
||||
bos = int(cu_seqlens[cu_seqlens_load_offset_1])
|
||||
cu_seqlens_load_offset_2 = i_n + 1
|
||||
eos = int(cu_seqlens[cu_seqlens_load_offset_2])
|
||||
T = eos - bos
|
||||
|
||||
for loopY in range(B * H):
|
||||
i_b = loopY // H
|
||||
i_h = loopY % H
|
||||
|
||||
# get subA
|
||||
if (bos * H + i_h) < H:
|
||||
Tstart = loopX * 16 % BT
|
||||
Tend = Tstart + 16
|
||||
BTstart = loopX * 16 % BT
|
||||
BTend = BTstart + 16
|
||||
subA = A[0, Tstart:Tend, loopY, BTstart:BTend].contiguous().clone()
|
||||
# print(f"subA slice A dim[0, {Tstart}:{Tend}, {loopY}, {BTstart}:{BTend}]")
|
||||
if (Tend > T): # bondary check
|
||||
subA[T-16:, :] = 0
|
||||
|
||||
# subA.shape torch.Size([9, 16])
|
||||
# vvv
|
||||
# subA.shape torch.Size([16, 16]) 用0补齐
|
||||
if subA.shape[0] < 16:
|
||||
pad_rows = 16 - subA.shape[0]
|
||||
zeros = torch.zeros((pad_rows, subA.shape[1]), dtype=subA.dtype, device=subA.device)
|
||||
subA = torch.cat([subA, zeros], dim=0)
|
||||
else:
|
||||
assert(0) & "need deal this situation"
|
||||
|
||||
# get subAd
|
||||
if (bos * H + i_h) < H:
|
||||
Tstart = loopX * 16
|
||||
Tend = Tstart + 16
|
||||
BTstart = 0 * 16
|
||||
BTend = BTstart + 16
|
||||
subAd = Ad_modify[0, Tstart:Tend, loopY, BTstart:BTend].contiguous().clone()
|
||||
# print(f'T={T}, Tstart={Tstart}, Tend={Tend}, BTstart={BTstart}, BTend={BTend}')
|
||||
else:
|
||||
assert(0) & "need deal this situation"
|
||||
|
||||
mask = (torch.arange(16, device=subA.device)[:, None] > torch.arange(16, device=subA.device)[None, :])
|
||||
subA = -torch.where(mask, subA, torch.zeros_like(subA))
|
||||
|
||||
for inLoopIdx in range(1, min(16, T - i_t * 16)):
|
||||
# print(f"loopX={loopX}, loopY={loopY}, inLoopIdx={inLoopIdx}")
|
||||
offsetStart=loopX*16 % BT
|
||||
offsetEnd=offsetStart+16
|
||||
|
||||
AInLoop = A[0, (loopX * 16 + inLoopIdx), loopY, offsetStart:offsetEnd]
|
||||
# print(f"AInLoop slice A dim[0, {(loopX * 16 + inLoopIdx)}, {loopY}, {offsetStart}:{offsetEnd}")
|
||||
|
||||
ba_reduce = torch.empty_like(AInLoop)
|
||||
reduce_res = torch.empty_like(subA)
|
||||
solve_tril_16x16_kernel_modified_in_Loop[1, 1](
|
||||
i_t,
|
||||
loopY,
|
||||
i_n,
|
||||
bos,
|
||||
i_b,
|
||||
i_h,
|
||||
subA=subA,
|
||||
subAd=subAd,
|
||||
AInLoop=AInLoop,
|
||||
ba_reduce=ba_reduce,
|
||||
loopIdx=inLoopIdx,
|
||||
reduce_res=reduce_res,
|
||||
A=A,
|
||||
Ad=Ad_modify,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
BT=BT,
|
||||
num_warps=1,
|
||||
num_stages=4,
|
||||
)
|
||||
AInLoop = AInLoop.flatten()
|
||||
b_A = subA # [16x16]
|
||||
b_a = -AInLoop[0:16] # [16]
|
||||
b_a = b_a + torch.sum(reduce_res, 0)
|
||||
ba_reduce = b_a
|
||||
o_i = torch.arange(16, device=ba_reduce.device)
|
||||
mask = (o_i == inLoopIdx)
|
||||
mask_expand = mask[:, None]
|
||||
subA = torch.where(mask_expand, ba_reduce, subA)
|
||||
|
||||
subAd = subA + (torch.arange(16, device=subA.device)[:, None] == torch.arange(16, device=subA.device)[None, :])
|
||||
|
||||
# deal store mask
|
||||
Tstart = loopX * 16
|
||||
Tend = Tstart + 16
|
||||
BTstart = 0 * 16
|
||||
BTend = BTstart + 16
|
||||
# print(f"slice Ad_modify dim[0, {Tend-needMaskRow}:{Tend}, {loopY}, {BTstart}:{BTend}]")
|
||||
if (Tend > T): # bondary mask
|
||||
needMaskRow = Tend - T
|
||||
Ad_modify[0, Tstart:Tend, loopY, BTstart:BTend] = subAd[:T-Tstart, :]
|
||||
else:
|
||||
# assert (Ad_modify[0, Tstart:Tend, loopY, BTstart:BTend].shape == subAd.shape)
|
||||
Ad_modify[0, Tstart:Tend, loopY, BTstart:BTend] = subAd
|
||||
|
||||
# if BT == 16:
|
||||
# return Ad
|
||||
|
||||
return Ad_modify
|
||||
|
||||
# @input_guard
|
||||
def solve_tril(
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
output_dtype: torch.dtype = torch.float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the inverse of the lower triangular matrix
|
||||
A should be strictly lower triangular, i.e., A.triu() == 0.
|
||||
|
||||
Args:
|
||||
A (torch.Tensor):
|
||||
[B, T, H, K]
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None.
|
||||
output_dtype (torch.dtype):
|
||||
The dtype of the output tensor. Default: `torch.float`
|
||||
|
||||
Returns:
|
||||
(I + A)^-1 with the same shape as A
|
||||
"""
|
||||
assert A.shape[-1] in [16, 32, 64]
|
||||
|
||||
B, T, H, BT = A.shape
|
||||
# cnt = 0
|
||||
# for b in range(B):
|
||||
# for t in range(T):
|
||||
# for h in range(H):
|
||||
# for d in range(BT):
|
||||
# A[b, t, h, d] = cnt
|
||||
# cnt += 1
|
||||
|
||||
Ad = -999 * torch.ones(
|
||||
B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype
|
||||
)
|
||||
# cnt = 0
|
||||
# for b in range(B):
|
||||
# for t in range(T):
|
||||
# for h in range(H):
|
||||
# for d in range(16):
|
||||
# Ad[b, t, h, d] = cnt
|
||||
# cnt += 1
|
||||
|
||||
Ad_modify = Ad.clone()
|
||||
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None
|
||||
)
|
||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
|
||||
|
||||
import os
|
||||
if os.getenv("TRITON_INTERPRET", None) == "1":
|
||||
solve_tril_16x16_kernel[NT, B * H](
|
||||
A=A,
|
||||
Ad=Ad,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
BT=BT,
|
||||
num_warps=1,
|
||||
num_stages=4,
|
||||
)
|
||||
return Ad
|
||||
|
||||
Ad_modify = solve_tril_16x16_kernel_new(
|
||||
NT,
|
||||
B,
|
||||
A=A,
|
||||
Ad=Ad_modify,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
BT=BT,
|
||||
IS_VARLEN= True if cu_seqlens is not None else False,
|
||||
# num_warps=1,
|
||||
# num_stages=4,
|
||||
).to(A.dtype)
|
||||
return Ad_modify
|
||||
85
vllm_kunlun/ops/fla/torch_fla.py
Normal file
85
vllm_kunlun/ops/fla/torch_fla.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
|
||||
"""This function is intended to align with the l2norm implementation in the FLA library."""
|
||||
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
|
||||
return x * inv_norm
|
||||
|
||||
def torch_chunk_gated_delta_rule(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
g,
|
||||
beta,
|
||||
chunk_size=64,
|
||||
initial_state=None,
|
||||
output_final_state=False,
|
||||
use_qk_l2norm_in_kernel=False,
|
||||
):
|
||||
initial_dtype = query.dtype
|
||||
if use_qk_l2norm_in_kernel:
|
||||
query = l2norm(query, dim=-1, eps=1e-6)
|
||||
key = l2norm(key, dim=-1, eps=1e-6)
|
||||
query, key, value, beta, g = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
|
||||
]
|
||||
|
||||
batch_size, num_heads, sequence_length, k_head_dim = key.shape
|
||||
v_head_dim = value.shape[-1]
|
||||
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||
query = F.pad(query, (0, 0, 0, pad_size))
|
||||
key = F.pad(key, (0, 0, 0, pad_size))
|
||||
value = F.pad(value, (0, 0, 0, pad_size))
|
||||
beta = F.pad(beta, (0, pad_size))
|
||||
g = F.pad(g, (0, pad_size))
|
||||
total_sequence_length = sequence_length + pad_size
|
||||
scale = 1 / (query.shape[-1] ** 0.5)
|
||||
query = query * scale
|
||||
|
||||
v_beta = value * beta.unsqueeze(-1)
|
||||
k_beta = key * beta.unsqueeze(-1)
|
||||
# reshape to chunks
|
||||
query, key, value, k_beta, v_beta = [
|
||||
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
|
||||
]
|
||||
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
|
||||
|
||||
# chunk decay
|
||||
g = g.cumsum(dim=-1)
|
||||
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
|
||||
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = attn[..., i, :i].clone()
|
||||
sub = attn[..., :i, :i].clone()
|
||||
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
||||
value = attn @ v_beta
|
||||
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
||||
last_recurrent_state = (
|
||||
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
|
||||
if initial_state is None
|
||||
else initial_state.to(value)
|
||||
)
|
||||
core_attn_out = torch.zeros_like(value)
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
|
||||
|
||||
# for each chunk
|
||||
for i in range(0, total_sequence_length // chunk_size):
|
||||
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
||||
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
||||
v_new = v_i - v_prime
|
||||
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
||||
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
||||
last_recurrent_state = (
|
||||
last_recurrent_state * g[:, :, i, -1, None, None].exp()
|
||||
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
|
||||
)
|
||||
|
||||
if not output_final_state:
|
||||
last_recurrent_state = None
|
||||
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
|
||||
core_attn_out = core_attn_out[:, :, :sequence_length]
|
||||
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
||||
return core_attn_out, last_recurrent_state
|
||||
180
vllm_kunlun/ops/fla/utils.py
Normal file
180
vllm_kunlun/ops/fla/utils.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
|
||||
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
|
||||
FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"
|
||||
|
||||
SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
|
||||
|
||||
|
||||
def tensor_cache(
|
||||
fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
|
||||
"""
|
||||
A decorator that caches the most recent results of a function with tensor inputs.
|
||||
|
||||
This decorator will store the output of the decorated function for the most recent set of input tensors.
|
||||
The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
|
||||
|
||||
Args:
|
||||
fn (Callable[..., torch.Tensor]):
|
||||
The function to be decorated. It should take tensor inputs and return tensor outputs.
|
||||
|
||||
Returns:
|
||||
Callable[..., torch.Tensor]:
|
||||
A wrapped version of the input function with single-entry caching.
|
||||
"""
|
||||
|
||||
cache_entries: tuple[Optional[tuple], Optional[dict], Any] = []
|
||||
cache_size = 4
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal cache_entries, cache_size
|
||||
for i, entry in enumerate(cache_entries):
|
||||
last_args, last_kwargs, last_result = entry
|
||||
if len(args) == len(last_args) and len(kwargs) == len(last_kwargs) \
|
||||
and all(a is b for a, b in zip(args, last_args)) \
|
||||
and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
|
||||
cache_entries = cache_entries[:i] + cache_entries[i + 1:] + [
|
||||
(args, kwargs, last_result)
|
||||
]
|
||||
return last_result
|
||||
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
if len(cache_entries) >= cache_size:
|
||||
cache_entries = cache_entries[1:]
|
||||
cache_entries.append((args, kwargs, result))
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def input_guard(
|
||||
fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
|
||||
"""
|
||||
A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
contiguous_args = (i if not isinstance(i, torch.Tensor) else
|
||||
i.contiguous() for i in args)
|
||||
contiguous_kwargs = {
|
||||
k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
|
||||
tensor = None
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
tensor = arg
|
||||
break
|
||||
if tensor is None:
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
tensor = value
|
||||
break
|
||||
|
||||
if tensor is not None:
|
||||
ctx = torch.cuda.device(tensor.device.index)
|
||||
else:
|
||||
ctx = contextlib.nullcontext()
|
||||
|
||||
with ctx:
|
||||
return fn(*contiguous_args, **contiguous_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_available_device() -> str:
|
||||
try:
|
||||
return triton.runtime.driver.active.get_current_target().backend
|
||||
except BaseException:
|
||||
return 'cpu'
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
|
||||
device = get_available_device()
|
||||
mapping = {
|
||||
"cuda": "nvidia",
|
||||
"hip": "amd",
|
||||
"xpu": "intel",
|
||||
}
|
||||
# return the mapped value, or the original if not found
|
||||
return mapping.get(device, device)
|
||||
|
||||
|
||||
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
|
||||
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
|
||||
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
|
||||
device = get_available_device() if get_available_device() != 'hip' else 'cuda'
|
||||
device_torch_lib = getattr(torch, device)
|
||||
device_platform = _check_platform()
|
||||
|
||||
is_amd = (device_platform == 'amd')
|
||||
is_intel = (device_platform == 'nvidia')
|
||||
is_nvidia = (device_platform == 'nvidia')
|
||||
is_intel_alchemist = (is_intel
|
||||
and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
|
||||
is_nvidia_hopper = (is_nvidia
|
||||
and ('NVIDIA H' in torch.cuda.get_device_name(0)
|
||||
or torch.cuda.get_device_capability()[0] >= 9))
|
||||
use_cuda_graph = (is_nvidia
|
||||
and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
|
||||
|
||||
|
||||
def get_all_max_shared_mem():
|
||||
try:
|
||||
return [
|
||||
triton.runtime.driver.active.utils.get_device_properties(i)
|
||||
['max_shared_mem'] for i in range(device_torch_lib.device_count())
|
||||
]
|
||||
except BaseException:
|
||||
return [-1]
|
||||
|
||||
|
||||
class Backend(Enum):
|
||||
ADA = 101376 # RTX 4090
|
||||
AMPERE = 166912 # A100
|
||||
HOPPER = 232448 # H100
|
||||
DEFAULT = 102400 # Default
|
||||
|
||||
@classmethod
|
||||
def get_shared_memory(cls, arch: str) -> int:
|
||||
try:
|
||||
return cls[arch.upper()].value
|
||||
except KeyError:
|
||||
return cls.DEFAULT.value
|
||||
|
||||
|
||||
@functools.cache
|
||||
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
|
||||
try:
|
||||
device_shared_mem_list = get_all_max_shared_mem()
|
||||
max_shared_memory = device_shared_mem_list[tensor_idx]
|
||||
return max_shared_memory >= Backend.get_shared_memory(arch)
|
||||
except Exception:
|
||||
return False
|
||||
247
vllm_kunlun/ops/fla/wy_fast.py
Normal file
247
vllm_kunlun/ops/fla/wy_fast.py
Normal file
@@ -0,0 +1,247 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
|
||||
RESOLUTION = {
|
||||
torch.bool: 0,
|
||||
torch.int16: 0,
|
||||
torch.int32: 0,
|
||||
torch.int64: 0,
|
||||
torch.float16: 1e-3,
|
||||
torch.float32: 1.3e-6,
|
||||
torch.bfloat16: 0.016,
|
||||
torch.complex32: 1e-3,
|
||||
torch.complex64: 1.3e-6,
|
||||
}
|
||||
|
||||
def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
|
||||
assert res.dtype == dtype
|
||||
ref = ref.to(dtype)
|
||||
atol = 1e-3 * reduce_dim
|
||||
rtol = RESOLUTION[dtype]
|
||||
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [2, 4, 8]
|
||||
# for num_stages in [2, 3, 4]
|
||||
# ],
|
||||
# key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def recompute_u_fwd_kernel(
|
||||
k,
|
||||
v,
|
||||
beta,
|
||||
w,
|
||||
u,
|
||||
A,
|
||||
g,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
p_beta = tl.make_block_ptr(
|
||||
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
p_A = tl.make_block_ptr(
|
||||
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
||||
)
|
||||
b_beta = tl.load(p_beta, boundary_check=(0,))
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(
|
||||
v + (bos * H + i_h) * V,
|
||||
(T, V),
|
||||
(H * V, 1),
|
||||
(i_t * BT, i_v * BV),
|
||||
(BT, BV),
|
||||
(1, 0),
|
||||
)
|
||||
p_u = tl.make_block_ptr(
|
||||
u + (bos * H + i_h) * V,
|
||||
(T, V),
|
||||
(H * V, 1),
|
||||
(i_t * BT, i_v * BV),
|
||||
(BT, BV),
|
||||
(1, 0),
|
||||
)
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
|
||||
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
|
||||
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [2, 4, 8]
|
||||
# for num_stages in [2, 3, 4]
|
||||
# ],
|
||||
# key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def recompute_w_fwd_kernel(
|
||||
k,
|
||||
v,
|
||||
beta,
|
||||
w,
|
||||
u,
|
||||
A,
|
||||
g,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
p_beta = tl.make_block_ptr(
|
||||
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
p_A = tl.make_block_ptr(
|
||||
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
||||
)
|
||||
b_beta = tl.load(p_beta, boundary_check=(0,))
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(
|
||||
k + (bos * Hg + i_h // (H // Hg)) * K,
|
||||
(T, K),
|
||||
(Hg * K, 1),
|
||||
(i_t * BT, i_k * BK),
|
||||
(BT, BK),
|
||||
(1, 0),
|
||||
)
|
||||
p_w = tl.make_block_ptr(
|
||||
w + (bos * H + i_h) * K,
|
||||
(T, K),
|
||||
(H * K, 1),
|
||||
(i_t * BT, i_k * BK),
|
||||
(BT, BK),
|
||||
(1, 0),
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
|
||||
b_w = tl.dot(b_A, b_kb)
|
||||
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def recompute_w_u_fwd(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.LongTensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
BT = A.shape[-1]
|
||||
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
BK = 64
|
||||
BV = 64
|
||||
u = torch.empty_like(v)
|
||||
w = k.new_empty(B, T, H, K)
|
||||
recompute_u_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
w=w,
|
||||
u=u,
|
||||
A=A,
|
||||
g=g_cumsum,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
)
|
||||
recompute_w_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
w=w,
|
||||
u=u,
|
||||
A=A,
|
||||
g=g_cumsum,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
)
|
||||
return w, u
|
||||
@@ -1,29 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Author: Dong Xinyu, Chen Zhennan, Bao Qian, Yuan Jizhong
|
||||
# Email: dongxinyu03@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""layer.py"""
|
||||
import torch
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.eplb.eplb_state import EplbState
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE as VllmFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase as VllmFusedMoEMethodBase
|
||||
@@ -101,7 +85,6 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
) -> torch.Tensor:
|
||||
"""forward_kunlun"""
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
|
||||
if self.moe.use_ep:
|
||||
return ops.fused_moe_ep(x,
|
||||
layer.w13_weight,
|
||||
@@ -116,6 +99,96 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group
|
||||
)
|
||||
# fused_moe do not support expert number > 400
|
||||
elif layer.local_num_experts > 400:
|
||||
hidden_states = x
|
||||
global_num_experts = linear_weights.shape[0]
|
||||
M, N = hidden_states.shape
|
||||
hidden_dim = layer.w2_weight.shape[1]
|
||||
normed_score = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
num_blocks = 12
|
||||
block_statistic = torch.zeros(
|
||||
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
router_logits = router_logits.float()
|
||||
torch.ops._C.moe_softmax_topk_norm(
|
||||
x=router_logits,
|
||||
normed_score=normed_score,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=None,
|
||||
stable=True)
|
||||
|
||||
moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
|
||||
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
||||
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
||||
sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
||||
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
x=hidden_states,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=block_statistic,
|
||||
moe_expand=moe_expand,
|
||||
moe_index=sorted_tokens_idx,
|
||||
expert_m=expert_m,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
||||
|
||||
y = torch.empty(M,top_k,
|
||||
layer.w13_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
moe_expand = moe_expand.view(M * top_k, hidden_dim)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=moe_expand,
|
||||
weight=layer.w13_weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=y)
|
||||
|
||||
d = y.shape[-1] // 2
|
||||
output_shape = (y.shape[:-1] + (d, ))
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.swiglu(y, out1)
|
||||
|
||||
out = torch.empty(M,top_k,
|
||||
layer.w2_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=out1,
|
||||
weight=layer.w2_weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=out)
|
||||
|
||||
dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
|
||||
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
|
||||
|
||||
torch.ops._C.moe_post(
|
||||
x=out,
|
||||
moe_index=sorted_tokens_idx,
|
||||
normed_scale=normed_score,
|
||||
dequant_scale=dequant_scale,
|
||||
y=output
|
||||
)
|
||||
return output
|
||||
else:
|
||||
return ops.fused_moe(x,
|
||||
layer.w13_weight,
|
||||
@@ -155,6 +228,7 @@ class FusedMoE(VllmFusedMoE):
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
num_redundant_experts: int = 0,
|
||||
is_sequence_parallel=False,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=num_experts, # Global number of experts
|
||||
@@ -189,7 +263,7 @@ class FusedMoE(VllmFusedMoE):
|
||||
# since model_config is not set in the pytest test.
|
||||
model_dtype = params_dtype
|
||||
|
||||
moe = FusedMoEConfig.make(
|
||||
moe = FusedMoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
@@ -197,7 +271,7 @@ class FusedMoE(VllmFusedMoE):
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=model_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
quant_config=quant_config,
|
||||
# quant_config=quant_config,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.quant_config = quant_config
|
||||
@@ -307,4 +381,35 @@ class FusedMoE(VllmFusedMoE):
|
||||
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states
|
||||
return final_hidden_states
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
cls,
|
||||
ckpt_gate_proj_name: str,
|
||||
ckpt_down_proj_name: str,
|
||||
ckpt_up_proj_name: str,
|
||||
num_experts: int,
|
||||
num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
|
||||
|
||||
num_physical_experts = num_experts + num_redundant_experts
|
||||
|
||||
# In the returned mapping:
|
||||
# - `expert_id` is the physical expert id
|
||||
# - `weight_name` contains the weight name of the logical expert
|
||||
# So that we should map the expert id to logical in `weight_name`
|
||||
physical_to_logical_map = \
|
||||
EplbState.build_initial_global_physical_to_logical_map(
|
||||
num_experts, num_redundant_experts)
|
||||
|
||||
return [
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
("experts.w13_" if weight_name
|
||||
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
|
||||
f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
|
||||
expert_id, shard_id) for expert_id in range(num_physical_experts)
|
||||
for shard_id, weight_name in [
|
||||
("w1", ckpt_gate_proj_name),
|
||||
("w2", ckpt_down_proj_name),
|
||||
("w3", ckpt_up_proj_name),
|
||||
]
|
||||
]
|
||||
|
||||
@@ -12,49 +12,101 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
|
||||
from vllm.model_executor.layers import layernorm
|
||||
from typing import Optional, Union
|
||||
import xtorch_ops
|
||||
|
||||
|
||||
def vllm_kunlun_forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""forward_cuda"""
|
||||
if x.is_contiguous() == False:
|
||||
# kunlun does not support uncontiguous input and they do not think it is a bug
|
||||
# so we must make it contiguous() manually
|
||||
x = x.contiguous()
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""forward_cuda"""
|
||||
if x.is_contiguous() == False:
|
||||
# kunlun does not support uncontiguous input and they do not think it is a bug
|
||||
# so we must make it contiguous() manually
|
||||
x = x.contiguous()
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
if residual is not None:
|
||||
# residual_output = torch.empty_like(residual)
|
||||
torch.ops._C.add_rmsnorm(
|
||||
|
||||
if residual is not None:
|
||||
# residual_output = torch.empty_like(residual)
|
||||
torch.ops._C.add_rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
residual_output=residual,
|
||||
weight=self.weight.data,
|
||||
eps=self.variance_epsilon,
|
||||
output=x
|
||||
)
|
||||
return x, residual
|
||||
out = torch.empty_like(x)
|
||||
torch.ops._C.rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
residual_output=residual,
|
||||
weight=self.weight.data,
|
||||
eps=self.variance_epsilon,
|
||||
output=x,
|
||||
self.weight.data,
|
||||
out,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
out = torch.empty_like(x)
|
||||
torch.ops._C.rmsnorm(
|
||||
x,
|
||||
self.weight.data,
|
||||
out,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return out
|
||||
return out
|
||||
|
||||
|
||||
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
||||
@staticmethod
|
||||
def forward_xpu(
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if x.is_contiguous() == False:
|
||||
# kunlun does not support uncontiguous input and they do not think it is a bug
|
||||
# so we must make it contiguous() manually
|
||||
x = x.contiguous()
|
||||
|
||||
if residual is not None:
|
||||
torch.ops._C.add_rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
residual_output=residual,
|
||||
weight=weight+1,
|
||||
eps=variance_epsilon,
|
||||
output=x
|
||||
)
|
||||
return x, residual
|
||||
|
||||
out = torch.empty_like(x)
|
||||
torch.ops._C.rmsnorm(
|
||||
x,
|
||||
weight+1,
|
||||
out,
|
||||
variance_epsilon,
|
||||
)
|
||||
return out
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if torch.compiler.is_compiling():
|
||||
self.forward_static = self.forward_xpu # only use in cudagraph
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
if not getattr(self, "_is_compiled", False):
|
||||
self.forward_static = torch.compile( # type: ignore
|
||||
self.forward_static, backend="aot_eager")
|
||||
self._is_compiled = True
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
|
||||
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RMSNorm.forward = vllm_kunlun_forward_cuda
|
||||
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm
|
||||
File diff suppressed because it is too large
Load Diff
0
vllm_kunlun/ops/mamba/__init__.py
Normal file
0
vllm_kunlun/ops/mamba/__init__.py
Normal file
1217
vllm_kunlun/ops/mamba/causal_conv1d.py
Normal file
1217
vllm_kunlun/ops/mamba/causal_conv1d.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -24,7 +24,6 @@ _PARTITION_SIZE = 512
|
||||
@dataclass
|
||||
class PagedAttentionMetadata:
|
||||
"""Metadata for PagedAttention."""
|
||||
|
||||
# (batch_size,). The length of sequences (entire tokens seen so far) per
|
||||
# sequence.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
@@ -53,18 +52,18 @@ class PagedAttention:
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
"""
|
||||
Get the shape of the KV cache. Returns different shapes based on whether the computation is on-chip.
|
||||
If on-chip (is_kunlun() is True), returns shape (2, num_blocks, num_kv_heads, block_size, head_size);
|
||||
Otherwise, returns shape (2, num_blocks, block_size * num_kv_heads * head_size).
|
||||
|
||||
获取KV缓存的形状,根据是否在芯片上进行计算返回不同的形状。
|
||||
如果在芯片上(is_kunlun()为True),则返回形状(2, num_blocks, num_kv_heads, block_size, head_size);
|
||||
否则,返回形状(2, num_blocks, block_size * num_kv_heads * head_size)。
|
||||
|
||||
Args:
|
||||
num_blocks (int): The number of blocks.
|
||||
block_size (int): The size of each block.
|
||||
num_kv_heads (int): The number of KV heads.
|
||||
head_size (int): The size of each head.
|
||||
|
||||
num_blocks (int): 块数量。
|
||||
block_size (int): 每个块大小。
|
||||
num_kv_heads (int): KV头数量。
|
||||
head_size (int): 每个头大小。
|
||||
|
||||
Returns:
|
||||
Tuple[int, ...]: The shape of the KV cache, including two elements: the first element is 2, indicating the number of dimensions is 2; the second element is one of num_blocks, num_kv_heads, block_size, and head_size.
|
||||
Tuple[int, ...]: KV缓存的形状,包括两个元素:第一个元素为2,表示维度数量为2;第二个元素为num_blocks、num_kv_heads、block_size和head_size中的任意一个。
|
||||
"""
|
||||
if current_platform.is_kunlun():
|
||||
return (2, num_blocks, num_kv_heads, block_size, head_size)
|
||||
@@ -77,20 +76,20 @@ class PagedAttention:
|
||||
head_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Split a cached tensor (containing key and value) into two parts, each part is a tensor.
|
||||
If running on KUNLUN, the first returned tensor is the key cache, and the second tensor is the value cache.
|
||||
Otherwise, the first tensor is the key cache, and the second tensor is a view of the key cache with shape (num_blocks, num_kv_heads, head_size//x, -1, x),
|
||||
and the third tensor is the value cache with shape (num_blocks, num_kv_heads, head_size, -1).
|
||||
|
||||
将一个缓存张量(包含key和value)分成两部分,每个部分是一个张量。
|
||||
如果在KUNLUN上运行,则返回的第一个张量是key缓存,第二个张量是value缓存。
|
||||
否则,第一个张量是key缓存,第二个张量是key缓存的view,其形状为(num_blocks, num_kv_heads, head_size//x, -1, x),
|
||||
第三个张量是value缓存,其形状为(num_blocks, num_kv_heads, head_size, -1)。
|
||||
|
||||
Args:
|
||||
kv_cache (torch.Tensor): A tensor containing key and value, with shape (2, num_blocks, kv_cache_size).
|
||||
num_kv_heads (int): The number of heads in multi-head attention.
|
||||
head_size (int): The size of each head.
|
||||
|
||||
kv_cache (torch.Tensor): 包含key和value的张量,形状为(2, num_blocks, kv_cache_size)。
|
||||
num_kv_heads (int): 多头注意力中的头数。
|
||||
head_size (int): 每个头的大小。
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
- key_cache (torch.Tensor): A tensor containing the key cache, with shape (num_blocks, num_kv_heads, head_size//x, -1, x).
|
||||
- value_cache (torch.Tensor): A tensor containing the value cache, with shape (num_blocks, num_kv_heads, head_size, -1).
|
||||
- key_cache (torch.Tensor): 形状为(num_blocks, num_kv_heads, head_size//x, -1, x),包含key缓存。
|
||||
- value_cache (torch.Tensor): 形状为(num_blocks, num_kv_heads, head_size, -1),包含value缓存。
|
||||
"""
|
||||
x = 16 // kv_cache.element_size()
|
||||
num_blocks = kv_cache.shape[1]
|
||||
@@ -100,7 +99,8 @@ class PagedAttention:
|
||||
value_cache = kv_cache[1]
|
||||
else:
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x)
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
|
||||
-1, x)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
|
||||
return key_cache, value_cache
|
||||
@@ -152,17 +152,16 @@ class PagedAttention:
|
||||
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||
# use blocksparse paged attention
|
||||
block_size = value_cache.size(-1)
|
||||
assert (
|
||||
blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
|
||||
), (
|
||||
f"{blocksparse_block_size=} needs to be a multiple of"
|
||||
f"{block_size=} used in block_tables."
|
||||
)
|
||||
assert (blocksparse_block_size > 0 and
|
||||
blocksparse_block_size % block_size == 0), \
|
||||
(f"{blocksparse_block_size=} needs to be a multiple of"
|
||||
f"{block_size=} used in block_tables.")
|
||||
|
||||
output = torch.empty_like(query)
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
|
||||
_PARTITION_SIZE)
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
@@ -170,10 +169,9 @@ class PagedAttention:
|
||||
# to parallelize.
|
||||
# TODO(woosuk): Tune this heuristic.
|
||||
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
|
||||
use_v1 = max_seq_len <= 8192 and (
|
||||
max_num_partitions == 1 or num_seqs * num_heads > 512
|
||||
)
|
||||
|
||||
use_v1 = (max_seq_len <= 8192
|
||||
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
|
||||
|
||||
if use_v1:
|
||||
# Run PagedAttention V1.
|
||||
ops.paged_attention_v1(
|
||||
@@ -302,4 +300,4 @@ class PagedAttention:
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
@@ -1,128 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Li Wei, Pan Xiakai, You Zeyu
|
||||
# Email: liwei157@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from typing import Optional
|
||||
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
|
||||
|
||||
|
||||
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
|
||||
"""Convert AWQ-packed int4 weights to Kunlun XPU format.
|
||||
Input: packed[N, K], dtype=int32, saved as AWQ order
|
||||
Output: packed_reordered[N, K], dtype=int32, saved as Kunlun order
|
||||
"""
|
||||
N, K = packed.shape
|
||||
self.align_type = 1 if K % 8 == 0 else 0
|
||||
assert num_bits == 4, "Only int4 supported now"
|
||||
shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32)
|
||||
|
||||
if self.align_type == 0: # NORMAL MODE
|
||||
# Unpack AWQ order:[0, 2, 4, 6, 1, 3, 5, 7]
|
||||
unpacked_awq = (packed.unsqueeze(-1) >> shifts) & 0xF # [N, K, 8]
|
||||
|
||||
# Reverse AWQ order and convert to KUNLUN order
|
||||
AWQ_TO_KUNLUN_ORDER_NORMAL = [4, 0, 5, 1, 6, 2, 7, 3]
|
||||
# [0,2,4,6,1,3,5,7] --> [1, 0, 3, 2, 5, 4, 7, 6]
|
||||
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_NORMAL] # [N, K, 8]
|
||||
|
||||
# Pack to int32, order[6, 7, 4, 5, 2, 3, 0, 1]
|
||||
packed_kunlun = (unpacked_kunlun << shifts).sum(
|
||||
dim=-1, dtype=torch.int32
|
||||
) # [N, K]
|
||||
elif self.align_type == 1: # FAST MODEL
|
||||
# Unpack AWQ order
|
||||
unpacked_awq = (
|
||||
packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts
|
||||
) & 0xF # [N, K//8, 8, 8]
|
||||
|
||||
# Reverse AWQ order and convert to KUNLUN order
|
||||
AWQ_TO_KUNLUN_ORDER_FAST = [
|
||||
32, 0, 36, 4, 33, 1, 37, 5,
|
||||
34, 2, 38, 6, 35, 3, 39, 7,
|
||||
40, 8, 44, 12, 41, 9, 45, 13,
|
||||
42, 10, 46, 14, 43, 11, 47, 15,
|
||||
48, 16, 52, 20, 49, 17, 53, 21,
|
||||
50, 18, 54, 22, 51, 19, 55, 23,
|
||||
56, 24, 60, 28, 57, 25, 61, 29,
|
||||
58, 26, 62, 30, 59, 27, 63, 31
|
||||
]
|
||||
unpacked_awq = unpacked_awq.reshape(N, K // 8, 64)
|
||||
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64]
|
||||
|
||||
# Pack to int32
|
||||
unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
|
||||
packed_kunlun = (
|
||||
(unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K)
|
||||
) # [N, K]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return packed_kunlun
|
||||
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.qweight = torch.nn.Parameter(
|
||||
(
|
||||
self.repack_int4_for_kunlun(layer.qweight.data)
|
||||
if layer.qweight.data.dtype == torch.int32
|
||||
else layer.qweight.data
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.qzeros = torch.nn.Parameter(
|
||||
(
|
||||
self.repack_int4_for_kunlun(layer.qzeros.data)
|
||||
if layer.qzeros.data.dtype == torch.int32
|
||||
else layer.qzeros.data
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
|
||||
def apply(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
scales = layer.scales
|
||||
qzeros = layer.qzeros
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
# num_tokens >= threshold
|
||||
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
|
||||
|
||||
if FP16_MATMUL_HEURISTIC_CONDITION:
|
||||
out = torch.ops._C.awq_dequantize(
|
||||
qweight, scales, qzeros, quant_type=0, align_type=self.align_type
|
||||
)
|
||||
out = torch.matmul(reshaped_x, out)
|
||||
else:
|
||||
out = torch.ops._C.awq_gemm(
|
||||
reshaped_x, qweight, scales, qzeros, align_type=self.align_type
|
||||
)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out.reshape(out_shape)
|
||||
|
||||
|
||||
AWQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun
|
||||
AWQLinearMethod.process_weights_after_loading = process_weights_after_loading
|
||||
AWQLinearMethod.apply = apply
|
||||
@@ -1,37 +1,14 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# Author: Chen Zhennan, Dong Xinyu
|
||||
# Email: chenzhennan@baidu.com
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from typing import Any, Literal, Optional, cast, Callable, Optional
|
||||
|
||||
from compressed_tensors.config import (
|
||||
CompressionFormat,
|
||||
SparsityCompressionConfig,
|
||||
SparsityStructure,
|
||||
)
|
||||
from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from compressed_tensors.config import (CompressionFormat,
|
||||
SparsityCompressionConfig,
|
||||
SparsityStructure)
|
||||
from compressed_tensors.quantization import (ActivationOrdering,
|
||||
QuantizationStrategy)
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
|
||||
# TODO: import position will be changed after 0.9.0
|
||||
# vllm.model_executor.layers.fused_moe.fused_moe --> vllm.model_executor.layers.fused_moe
|
||||
|
||||
@@ -42,7 +19,6 @@ import xtorch_ops
|
||||
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
|
||||
|
||||
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def get_moe_method(quant_config, layer) -> "CompressedTensorsMoEMethod":
|
||||
@@ -50,239 +26,177 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
linear_cfg = None
|
||||
for k in ("Linear", "FusedMoE", "MoE", "Moe", "Experts"):
|
||||
if k in tsm and isinstance(tsm[k], dict):
|
||||
linear_cfg = tsm[k]
|
||||
break
|
||||
linear_cfg = tsm[k]; break
|
||||
if not linear_cfg:
|
||||
# print("target_scheme_map missing; fallback to INT8(W8A8) method")
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
wq = linear_cfg.get("weights")
|
||||
aq = linear_cfg.get("input_activations")
|
||||
wq = linear_cfg.get("weights"); aq = linear_cfg.get("input_activations")
|
||||
if not wq or not aq:
|
||||
# print("incomplete scheme; fallback to INT8(W8A8)")
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
|
||||
# Other branches are handled as needed; default fallback:
|
||||
# 其它分流按需;默认回落:
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
|
||||
|
||||
# copied from vllm 0.9.0
|
||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Directly create a default quantization config dictionary to avoid validation issues with QuantizationArgs
|
||||
|
||||
# 直接创建默认的量化配置字典,避免 QuantizationArgs 的验证问题
|
||||
# print("Creating default INT8 quantization config for MoE")
|
||||
|
||||
# 创建默认的权重量化配置字典
|
||||
self.weight_quant = type('WeightQuant', (), {
|
||||
'type': 'int',
|
||||
'num_bits': 8,
|
||||
'strategy': 'channel',
|
||||
'group_size': 128,
|
||||
'symmetric': True,
|
||||
'dynamic': False,
|
||||
'actorder': 'none',
|
||||
'observer': None,
|
||||
'observer_kwargs': {},
|
||||
'block_structure': None
|
||||
})()
|
||||
|
||||
# 创建默认的输入激活量化配置字典
|
||||
self.input_quant = type('InputQuant', (), {
|
||||
'type': 'int',
|
||||
'num_bits': 8,
|
||||
'strategy': 'token',
|
||||
'group_size': 128,
|
||||
'symmetric': True,
|
||||
'dynamic': True,
|
||||
'actorder': 'none',
|
||||
'observer': None,
|
||||
'observer_kwargs': {},
|
||||
'block_structure': None
|
||||
})()
|
||||
|
||||
# Create a default weight quantization config dictionary
|
||||
self.weight_quant = type(
|
||||
"WeightQuant",
|
||||
(),
|
||||
{
|
||||
"type": "int",
|
||||
"num_bits": 8,
|
||||
"strategy": "channel",
|
||||
"group_size": 128,
|
||||
"symmetric": True,
|
||||
"dynamic": False,
|
||||
"actorder": "none",
|
||||
"observer": None,
|
||||
"observer_kwargs": {},
|
||||
"block_structure": None,
|
||||
},
|
||||
)()
|
||||
|
||||
# Create a default input activation quantization config dictionary
|
||||
self.input_quant = type(
|
||||
"InputQuant",
|
||||
(),
|
||||
{
|
||||
"type": "int",
|
||||
"num_bits": 8,
|
||||
"strategy": "token",
|
||||
"group_size": 128,
|
||||
"symmetric": True,
|
||||
"dynamic": True,
|
||||
"actorder": "none",
|
||||
"observer": None,
|
||||
"observer_kwargs": {},
|
||||
"block_structure": None,
|
||||
},
|
||||
)()
|
||||
|
||||
# Change comparison method to directly compare strings
|
||||
# 修改比较方式,直接比较字符串
|
||||
per_channel = (
|
||||
self.weight_quant.strategy == "channel"
|
||||
and self.input_quant.strategy == "token"
|
||||
)
|
||||
and self.input_quant.strategy == "token")
|
||||
if not per_channel:
|
||||
raise ValueError(
|
||||
"For INT8 Fused MoE layers, we require channelwise, "
|
||||
"dynamic per token quantization. Found "
|
||||
f"{self.weight_quant}, {self.input_quant}"
|
||||
)
|
||||
f"{self.weight_quant}, {self.input_quant}")
|
||||
|
||||
self.static_input_scales = not self.input_quant.dynamic
|
||||
if self.static_input_scales:
|
||||
raise ValueError(
|
||||
"For INT8 Fused MoE layers, we require channelwise, "
|
||||
"dynamic per token quantization. Found static input scales."
|
||||
)
|
||||
"dynamic per token quantization. Found static input scales.")
|
||||
|
||||
def create_weights1(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Use float32 as a placeholder for weights to facilitate loading original weights from ckpt
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
), # generally is torch.bfloat16
|
||||
requires_grad=False,
|
||||
)
|
||||
def create_weights1(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
# 权重先用浮点占位,便于从 ckpt 加载原始权重
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype), # 通常是 torch.bfloat16
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# Channel scale: float32 + 2D [E, out] (aligned with fused_moe/UT)
|
||||
# 通道 scale:float32 + 二维 [E, out](与 fused_moe/UT 对齐)
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# Input scale can be dynamically calculated
|
||||
# 输入 scale 动态计算即可
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=torch.int8,
|
||||
), # directly use int8
|
||||
requires_grad=False,
|
||||
)
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=torch.int8), # 直接使用 int8
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8,
|
||||
), # directly use int8
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8), # 直接使用 int8
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# Scale factors
|
||||
# 缩放因子
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# Input scale can be dynamically calculated
|
||||
# 输入 scale 动态计算
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
return
|
||||
# Convert original weights to float32 for more robust statistics
|
||||
#原始权重转 float32 做统计更稳健
|
||||
w13_f = layer.w13_weight.float()
|
||||
w2_f = layer.w2_weight.float()
|
||||
w2_f = layer.w2_weight.float()
|
||||
|
||||
# Each column (abs_max) -> per-column scale (out dimension is dim=1, column is dim=-1)
|
||||
# 每列(abs_max) -> per-column scale(out 维在 dim=1,列在 dim=-1)
|
||||
qmax = 127.0
|
||||
w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1) # [E, 2N]
|
||||
w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
|
||||
w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
|
||||
|
||||
w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax # [E, 2N], float32
|
||||
w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
|
||||
w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
|
||||
|
||||
# Quantization: broadcast 3D scale and store back to 2D scale
|
||||
# 量化:用 3D scale 广播,存回 2D scale
|
||||
w13_scale_3d = w13_scale_2d.unsqueeze(-1) # [E, 2N, 1]
|
||||
w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
|
||||
w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
|
||||
|
||||
w13_q = torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8)
|
||||
w2_q = torch.round(w2_f / w2_scale_3d).clamp_(-128, 127).to(torch.int8)
|
||||
w2_q = torch.round(w2_f / w2_scale_3d ).clamp_(-128, 127).to(torch.int8)
|
||||
|
||||
# Optional: If your fused/kernel expects scale pre-multiplied by 127 (to be consistent with some UT backends), uncomment the following two lines:
|
||||
# 可选:若你的 fused/kernel 期望 scale 预乘 127(与某些 UT 后端一致),打开下面两行:
|
||||
w13_scale_2d = w13_scale_2d * 127.0
|
||||
w2_scale_2d = w2_scale_2d * 127.0
|
||||
w2_scale_2d = w2_scale_2d * 127.0
|
||||
|
||||
# Write back parameters: weight int8; scale uses float32 + 2D
|
||||
replace_parameter(
|
||||
layer, "w13_weight", torch.nn.Parameter(w13_q, requires_grad=False)
|
||||
)
|
||||
replace_parameter(
|
||||
layer, "w2_weight", torch.nn.Parameter(w2_q, requires_grad=False)
|
||||
)
|
||||
replace_parameter(
|
||||
layer,
|
||||
"w13_weight_scale",
|
||||
torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False),
|
||||
)
|
||||
replace_parameter(
|
||||
layer,
|
||||
"w2_weight_scale",
|
||||
torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False),
|
||||
)
|
||||
|
||||
# Brief check
|
||||
print(
|
||||
f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}"
|
||||
)
|
||||
# 回写参数:权重 int8;scale 用 float32 + 2D
|
||||
replace_parameter(layer, 'w13_weight', torch.nn.Parameter(w13_q, requires_grad=False))
|
||||
replace_parameter(layer, 'w2_weight', torch.nn.Parameter(w2_q, requires_grad=False))
|
||||
replace_parameter(layer, 'w13_weight_scale',
|
||||
torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False))
|
||||
replace_parameter(layer, 'w2_weight_scale',
|
||||
torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False))
|
||||
|
||||
# 简要检查
|
||||
print(f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -300,11 +214,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False, # Add this parameter
|
||||
expert_load_view: Optional[torch.Tensor] = None, # Add this parameter
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None, # Add this parameter
|
||||
logical_replica_count: Optional[torch.Tensor] = None, # Add this parameter
|
||||
linear_weights: Optional[torch.Tensor] = None, # Add this parameter
|
||||
enable_eplb: bool = False, # 添加这个参数
|
||||
expert_load_view: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
logical_replica_count: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
linear_weights: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
) -> torch.Tensor:
|
||||
|
||||
output = torch.empty_like(x)
|
||||
@@ -326,8 +240,5 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
print(
|
||||
"[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
|
||||
--> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod"
|
||||
)
|
||||
print("[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
|
||||
--> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod")
|
||||
@@ -1,108 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Li Wei, You Zeyu
|
||||
# Email: liwei157@baidu.com, youzeyu@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import Optional
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod, ExllamaState
|
||||
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# for torch.compile
|
||||
layer.qzeros = Parameter(
|
||||
self.repack_int4_for_kunlun(layer.qzeros.data, self.quant_config.weight_bits)
|
||||
if self.quant_config.weight_bits == 4 else layer.qzeros.data,
|
||||
requires_grad=False
|
||||
)
|
||||
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
|
||||
layer.scales = Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
# exllama needs to shuffle the weight after the weight is loaded
|
||||
# here we do the shuffle on first forward pass
|
||||
if layer.exllama_state == ExllamaState.UNINITIALIZED:
|
||||
if self.quant_config.desc_act:
|
||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
||||
else:
|
||||
layer.g_idx.data = torch.empty((0, ),
|
||||
dtype=torch.int,
|
||||
device=layer.g_idx.device)
|
||||
layer.exllama_state = ExllamaState.READY
|
||||
|
||||
# No need shuffle on xpu
|
||||
# ops.gptq_shuffle(layer.qweight, layer.g_idx,
|
||||
# self.quant_config.weight_bits)
|
||||
|
||||
|
||||
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
|
||||
N, K = packed.shape
|
||||
assert num_bits == 4, "Only int4 supported now"
|
||||
shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32)
|
||||
|
||||
# Unpack int32 to int4 values
|
||||
unpacked_gptq = (
|
||||
packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts
|
||||
) & 0xF # [N, K//8, 8, 8]
|
||||
|
||||
# Convert to KUNLUN order
|
||||
GPTQ_TO_KUNLUN_ORDER_FAST = [
|
||||
32, 0, 33, 1, 34, 2, 35, 3,
|
||||
36, 4, 37, 5, 38, 6, 39, 7,
|
||||
40, 8, 41, 9, 42, 10, 43, 11,
|
||||
44, 12, 45, 13, 46, 14, 47, 15,
|
||||
48, 16, 49, 17, 50, 18, 51, 19,
|
||||
52, 20, 53, 21, 54, 22, 55, 23,
|
||||
56, 24, 57, 25, 58, 26, 59, 27,
|
||||
60, 28, 61, 29, 62, 30, 63, 31,
|
||||
]
|
||||
unpacked_gptq = unpacked_gptq.reshape(N, K // 8, 64)
|
||||
unpacked_kunlun = unpacked_gptq[..., GPTQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64]
|
||||
|
||||
# Pack to int32
|
||||
unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
|
||||
packed_kunlun = (
|
||||
(unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K)
|
||||
) # [N, K]
|
||||
|
||||
return packed_kunlun
|
||||
|
||||
|
||||
def apply(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
output = torch.ops.xspeedgate_ops.gptq_gemm(
|
||||
reshaped_x,
|
||||
layer.qweight,
|
||||
layer.qzeros,
|
||||
layer.scales,
|
||||
layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits,
|
||||
)
|
||||
if bias is not None:
|
||||
output.add_(bias)
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
GPTQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun
|
||||
GPTQLinearMethod.process_weights_after_loading = process_weights_after_loading
|
||||
GPTQLinearMethod.apply = apply
|
||||
@@ -12,35 +12,33 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
import xspeedgate_ops
|
||||
import os
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
RotaryEmbedding,
|
||||
YaRNScalingRotaryEmbedding,
|
||||
DynamicNTKScalingRotaryEmbedding,
|
||||
MRotaryEmbedding,
|
||||
)
|
||||
RotaryEmbedding, YaRNScalingRotaryEmbedding, DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding)
|
||||
from typing import Optional, Tuple
|
||||
import xtorch_ops
|
||||
|
||||
|
||||
def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
if hasattr(self, "scaling_factor"):
|
||||
self.max_position_embeddings = int(
|
||||
self.max_position_embeddings * self.scaling_factor
|
||||
)
|
||||
if hasattr(self, 'scaling_factor'):
|
||||
self.max_position_embeddings = int(self.max_position_embeddings * self.scaling_factor)
|
||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
if os.getenv("FUSED_QK_ROPE_OP") == "1":
|
||||
#对于glm4-9b-chat,rope跑forward_native,所以需要cache保持特定的形状,这里通过环境变量控制
|
||||
#对于qwen2.5-vl,rope跑mrope,也需要cache保持特定的形状
|
||||
#也就是说跑glm4-9b-chat、qwen2.5-vl,需要设置GLM4_CHAT环境变量为1
|
||||
if os.getenv('ROPE_NATIVE_2D') == "1":
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
if os.getenv('USE_ORI_ROPE') == "0":
|
||||
cache_cos = torch.cat((cos, cos), dim=-1)
|
||||
cache_sin = torch.cat((sin, sin), dim=-1)
|
||||
# [2, self.max_position_embeddings, self.rotary_dim * 2]
|
||||
@@ -51,89 +49,108 @@ def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
|
||||
|
||||
def vllm_kunlun_forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""forward_cuda"""
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""forward_cuda"""
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
|
||||
if (
|
||||
self.cos_sin_cache.device != query.device
|
||||
or self.cos_sin_cache.dtype != query.dtype
|
||||
):
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
if offsets is not None:
|
||||
ops.batched_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
self.rotary_dim,
|
||||
offsets,
|
||||
)
|
||||
if self.cos_sin_cache.device != query.device or \
|
||||
self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
||||
dtype=query.dtype)
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
if offsets is not None:
|
||||
ops.batched_rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style, self.rotary_dim,
|
||||
offsets)
|
||||
else:
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
def apply_interleaved_rope(x: torch.Tensor,
|
||||
mrope_section: list[int]) -> torch.Tensor:
|
||||
"""Apply interleaved MRoPE to 3D rotary embeddings.
|
||||
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
|
||||
interleaved [THTHWHTHW...TT], preserving frequency continuity.
|
||||
"""
|
||||
x_t = x[0].clone()
|
||||
x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
|
||||
x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
|
||||
return x_t
|
||||
|
||||
def vllm_kunlun_apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
|
||||
is_neox_style: bool) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
query, key = ops.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return query, key
|
||||
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
def vllm_kunlun_mrope_forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""PyTorch-native implementation equivalent to forward().
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""PyTorch-native implementation equivalent to forward().
|
||||
|
||||
Args:
|
||||
positions:
|
||||
[num_tokens,] (text only) or
|
||||
[3, num_tokens] (T/H/W positions with multimodal inputs)
|
||||
query: [num_tokens, num_heads * head_size]
|
||||
key: [num_tokens, num_kv_heads * head_size]
|
||||
"""
|
||||
Args:
|
||||
positions:
|
||||
[num_tokens,] (text only) or
|
||||
[3, num_tokens] (T/H/W positions with multimodal inputs)
|
||||
query: [num_tokens, num_heads * head_size]
|
||||
key: [num_tokens, num_kv_heads * head_size]
|
||||
"""
|
||||
assert positions.ndim == 2
|
||||
assert key is not None
|
||||
|
||||
query, key = torch.ops.xspeedgate_ops.mrotary_embedding_fwd_v0(
|
||||
query,
|
||||
key,
|
||||
positions.to(dtype=torch.int32),
|
||||
self.cos_sin_cache,
|
||||
self.mrope_interleaved,
|
||||
self.is_neox_style,
|
||||
self.head_size,
|
||||
self.rotary_dim,
|
||||
self.mrope_section[0],
|
||||
self.mrope_section[1],
|
||||
self.mrope_section[2]
|
||||
)
|
||||
|
||||
assert positions.ndim == 2
|
||||
assert key is not None
|
||||
return query, key
|
||||
|
||||
query, key = torch.ops.xspeedgate_ops.mrotary_embedding_fwd_v0(
|
||||
query,
|
||||
key,
|
||||
positions.to(dtype=torch.int32),
|
||||
self.cos_sin_cache,
|
||||
False, # self.mrope_interleaved,
|
||||
self.head_size,
|
||||
self.rotary_dim,
|
||||
self.mrope_section[0],
|
||||
self.mrope_section[1],
|
||||
self.mrope_section[2],
|
||||
)
|
||||
|
||||
return query, key
|
||||
|
||||
|
||||
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
if os.getenv("KUNLUN_ENABLE_MULTI_LORA") == "1" or os.getenv("FUSED_QK_ROPE_OP") == "1":
|
||||
RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
else:
|
||||
pass
|
||||
# RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
# RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
# RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
||||
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
||||
# MRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
|
||||
# YaRNScalingRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
|
||||
|
||||
def Split_Norm_Rope(
|
||||
@@ -145,36 +162,27 @@ def Split_Norm_Rope(
|
||||
max_position_embeddings: int,
|
||||
q_head_num: int,
|
||||
kv_head_num: int,
|
||||
head_dim: int,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
head_dim:int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
num_tokens = qkv.shape[0]
|
||||
rotary_dim = head_dim
|
||||
if partial_rotary_factor < 1.0:
|
||||
rotary_dim = int(rotary_dim * partial_rotary_factor)
|
||||
q_emb_out = torch.empty(
|
||||
(num_tokens, q_head_num * head_dim), dtype=qkv.dtype, device=qkv.device
|
||||
)
|
||||
k_emb_out = torch.empty(
|
||||
(num_tokens, kv_head_num * head_dim), dtype=qkv.dtype, device=qkv.device
|
||||
)
|
||||
v_out = torch.empty(
|
||||
(num_tokens, kv_head_num * head_dim), dtype=qkv.dtype, device=qkv.device
|
||||
)
|
||||
rotary_dim=head_dim
|
||||
q_emb_out = torch.empty((num_tokens, q_head_num * head_dim), dtype=qkv.dtype, device=qkv.device)
|
||||
k_emb_out = torch.empty((num_tokens, kv_head_num * head_dim), dtype=qkv.dtype, device=qkv.device)
|
||||
v_out = torch.empty((num_tokens, kv_head_num * head_dim), dtype=qkv.dtype, device=qkv.device)
|
||||
torch.ops._C.split_norm_rope_neox(
|
||||
q_emb_out,
|
||||
k_emb_out,
|
||||
v_out,
|
||||
qkv,
|
||||
cos_sin_cache,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
positions,
|
||||
num_tokens,
|
||||
max_position_embeddings,
|
||||
q_head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
)
|
||||
return q_emb_out, k_emb_out, v_out
|
||||
q_emb_out,
|
||||
k_emb_out,
|
||||
v_out,
|
||||
qkv,
|
||||
cos_sin_cache,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
positions,
|
||||
num_tokens,
|
||||
max_position_embeddings,
|
||||
q_head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
)
|
||||
return q_emb_out, k_emb_out, v_out
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user