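"""Fused MLA (multi-head latent attention) preprocessing for Ascend NPUs.

When SGLANG_NPU_USE_MLAPO is set, the quantized q/kv a-projection, the q_a / kv_a
layernorms and the q_b projection are repacked into NZ-format weights, and the whole
pre-attention pipeline (projections, layernorms, rope, KV-cache write) is dispatched as
a single torch.ops.npu.mla_preprocess kernel.
"""
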
import torch
import torch.nn.functional as F

from sglang.srt.utils import get_bool_env_var, is_npu

_is_npu = is_npu()
_ENABLE_MLA_PREPROCESS_FLAG = get_bool_env_var("SGLANG_NPU_USE_MLAPO")
# ACL format id for the fractal NZ weight layout consumed by the fused NPU kernels.
_NPU_FORMAT_NZ = 29


def is_mla_preprocess_enabled() -> bool:
    return _is_npu and _ENABLE_MLA_PREPROCESS_FLAG


if is_mla_preprocess_enabled():
    import sgl_kernel_npu  # imported for its side effects (op registration)
    import torch_npu

    # Allow internal (NZ) tensor formats and disable operator JIT compilation.
    torch.npu.config.allow_internal_format = True
    torch.npu.set_compile_mode(jit_compile=False)

def round_up(val: int, align: int) -> int:
    if align == 0:
        return 0
    return -(val // -align) * align
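
# e.g. round_up(100, 16) == 112 and round_up(0, 16) == 0 (ceiling division written as a
# negated floor division).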


def transdata(nd_mat, block_size: tuple = (16, 16)):
    """Pad a 2-D matrix to block multiples and rearrange it into the blocked NZ layout."""
    r = round_up(nd_mat.shape[0], block_size[0])
    c = round_up(nd_mat.shape[1], block_size[1])
    r_pad = r - nd_mat.shape[0]
    c_pad = c - nd_mat.shape[1]
    # F.pad pads the last dimension first: (left, right) for columns, then (top, bottom)
    # for rows, so the column padding must come before the row padding.
    nd_mat = F.pad(nd_mat, (0, c_pad, 0, r_pad))
    nz_mat = torch.permute(
        torch.reshape(
            nd_mat,
            (r // block_size[0], block_size[0], c // block_size[1], block_size[1]),
        ),
        [2, 0, 1, 3],
    )
    nz_mat = torch.reshape(
        nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])
    )
    return nz_mat
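
# For example: a (32, 64) matrix with block_size=(16, 32) is reshaped to (2, 16, 2, 32),
# permuted to (2, 2, 16, 32) and flattened to (2, 32, 32): column blocks become the
# leading dimension, with the row blocks stacked inside each column block.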


def trans_rope_weight(weight, rope_dim):
    """Reorder the trailing `rope_dim` rows along dim -2 in place and return the weight."""
    weight_1 = weight[..., -rope_dim::2, :].contiguous()
    weight_2 = weight[..., -rope_dim + 1 :: 2, :].contiguous()
    weight[..., -rope_dim:, :] = torch.cat([weight_1, weight_2], dim=-2)
    return weight.contiguous()
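
# Effect on the last `rope_dim` rows (dim -2): even-indexed rows first, then odd-indexed
# rows, turning an interleaved rotary layout (x0, y0, x1, y1, ...) into the split layout
# (x0, x1, ..., y0, y1, ...). The "rotary layout" reading is an interpretation; what the
# code guarantees is the even-then-odd reordering.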


class NPUFusedMLAPreprocess(torch.nn.Module):
    """Runs the MLA pre-attention pipeline (a-projection, layernorms, q_b projection,
    rope and KV-cache write) as a single fused torch.ops.npu.mla_preprocess kernel."""

    def __init__(
        self,
        fused_qkv_a_proj_with_mqa,
        q_a_layernorm,
        kv_a_layernorm,
        q_b_proj,
        w_kc,
        rotary_emb,
        layer_id,
        num_local_heads,
        qk_nope_head_dim,
        qk_rope_head_dim,
    ):
        super().__init__()
        self.qkv_a_proj = fused_qkv_a_proj_with_mqa
        self.q_a_layernorm = q_a_layernorm
        self.kv_a_layernorm = kv_a_layernorm
        self.q_b_proj = q_b_proj
        self.w_kc = w_kc.contiguous()
        self.rotary_emb = rotary_emb
        self.layer_id = layer_id
        self.has_preprocess_weights = False

        self.q_lora_rank = self.q_b_proj.input_size  # 1536
        self.kv_lora_rank = self.kv_a_layernorm.hidden_size  # 512
        self.num_local_heads = num_local_heads  # heads on this tensor-parallel rank
        self.qk_nope_head_dim = qk_nope_head_dim  # 128
        self.qk_rope_head_dim = qk_rope_head_dim  # 64

    def preprocess_weights(self, hidden_states):
        """Repack quantized weights, scales and biases into the layouts expected by
        torch.ops.npu.mla_preprocess; runs once, lazily, on the first forward call."""
        # Placeholder tensor required by the current kernel signature (see the TODO in
        # forward()).
        self.dummy = torch.empty(
            (hidden_states.shape[-1]),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        self.qkv_a_proj_input_offset = self.qkv_a_proj.input_offset.to(dtype=torch.int8)
        self.q_b_proj_input_offset = self.q_b_proj.input_offset.to(dtype=torch.int8)

        # matmul_0 weight [7168, 2112]
        fused_qkv_a_proj_with_mqa_weight_q = self.qkv_a_proj.weight.data[
            :, : self.q_lora_rank
        ].clone()  # [7168, 1536]
        fused_qkv_a_proj_with_mqa_weight_kv = self.qkv_a_proj.weight.data[
            :, self.q_lora_rank :
        ].clone()  # [7168, 576]
        # rope fit
        fused_qkv_a_proj_with_mqa_weight_kv_t = (
            fused_qkv_a_proj_with_mqa_weight_kv.t().contiguous()
        )
        fused_qkv_a_proj_with_mqa_weight_kv_t = trans_rope_weight(
            fused_qkv_a_proj_with_mqa_weight_kv_t, self.qk_rope_head_dim
        )
        fused_qkv_a_proj_with_mqa_weight_kv = (
            fused_qkv_a_proj_with_mqa_weight_kv_t.t().contiguous()
        )
        # cat nz
        fused_qkv_a_proj_with_mqa_weight_new = torch.cat(
            (fused_qkv_a_proj_with_mqa_weight_kv, fused_qkv_a_proj_with_mqa_weight_q),
            dim=-1,
        )
        fused_qkv_a_proj_with_mqa_weight = (
            fused_qkv_a_proj_with_mqa_weight_new.t().contiguous()
        )
        fused_qkv_a_proj_with_mqa_weight_nz = (
            transdata(fused_qkv_a_proj_with_mqa_weight, block_size=(16, 32))
            .unsqueeze(0)
            .contiguous()
        )
        self.qkv_a_proj_weight_nz = torch_npu.npu_format_cast(
            fused_qkv_a_proj_with_mqa_weight_nz, _NPU_FORMAT_NZ
        )
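
        # Note: the fused a-projection output is reordered so the KV/rope columns (576 in
        # the commented shapes) come before the Q columns (1536), hence the *_kvq names
        # below; deq_scale and quant_bias are concatenated in the same kv-then-q order.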

        # matmul_0 deq_scale [2112]
        fused_qkv_a_proj_with_mqa_deq_scale_q = self.qkv_a_proj.deq_scale.data[
            : self.q_lora_rank
        ].clone()  # [1536]
        fused_qkv_a_proj_with_mqa_deq_scale_kv = self.qkv_a_proj.deq_scale.data[
            self.q_lora_rank :
        ].clone()  # [576]
        # rope fit
        fused_qkv_a_proj_with_mqa_deq_scale_kv = (
            fused_qkv_a_proj_with_mqa_deq_scale_kv.reshape(
                self.kv_lora_rank + self.qk_rope_head_dim, -1
            ).contiguous()
        )
        fused_qkv_a_proj_with_mqa_deq_scale_kv = trans_rope_weight(
            fused_qkv_a_proj_with_mqa_deq_scale_kv, self.qk_rope_head_dim
        )
        fused_qkv_a_proj_with_mqa_deq_scale_kv = (
            fused_qkv_a_proj_with_mqa_deq_scale_kv.view(
                self.kv_lora_rank + self.qk_rope_head_dim
            ).contiguous()
        )
        self.qkv_a_proj_deq_scale_kvq = torch.cat(
            (
                fused_qkv_a_proj_with_mqa_deq_scale_kv,
                fused_qkv_a_proj_with_mqa_deq_scale_q,
            ),
            dim=-1,
        )

        # matmul_0 quant_bias [2112]
        fused_qkv_a_proj_with_mqa_quant_bias_q = self.qkv_a_proj.quant_bias.data[
            : self.q_lora_rank
        ].clone()  # [1536]
        fused_qkv_a_proj_with_mqa_quant_bias_kv = self.qkv_a_proj.quant_bias.data[
            self.q_lora_rank :
        ].clone()  # [576]
        # rope fit
        fused_qkv_a_proj_with_mqa_quant_bias_kv = (
            fused_qkv_a_proj_with_mqa_quant_bias_kv.reshape(
                self.kv_lora_rank + self.qk_rope_head_dim, -1
            ).contiguous()
        )
        fused_qkv_a_proj_with_mqa_quant_bias_kv = trans_rope_weight(
            fused_qkv_a_proj_with_mqa_quant_bias_kv, self.qk_rope_head_dim
        )
        fused_qkv_a_proj_with_mqa_quant_bias_kv = (
            fused_qkv_a_proj_with_mqa_quant_bias_kv.view(
                self.kv_lora_rank + self.qk_rope_head_dim
            ).contiguous()
        )
        self.qkv_a_proj_quant_bias_kvq = torch.cat(
            (
                fused_qkv_a_proj_with_mqa_quant_bias_kv,
                fused_qkv_a_proj_with_mqa_quant_bias_q,
            ),
            dim=-1,
        )

        # matmul_1 weight [1536, num_head * 192]
        q_b_proj_weight = self.q_b_proj.weight.data.clone()
        q_b_proj_weight = q_b_proj_weight.t().reshape(
            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
        )
        q_b_proj_weight = trans_rope_weight(q_b_proj_weight, self.qk_rope_head_dim)
        q_b_proj_weight = q_b_proj_weight.reshape(
            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1
        )
        q_b_proj_weight_nz = (
            transdata(q_b_proj_weight, block_size=(16, 32)).unsqueeze(0).contiguous()
        )
        self.q_b_proj_weight_nz = torch_npu.npu_format_cast(
            q_b_proj_weight_nz, _NPU_FORMAT_NZ
        )

        # matmul_1 deq_scale [num_head * 192]
        q_b_proj_deq_scale = self.q_b_proj.deq_scale.data.clone()
        q_b_proj_deq_scale = q_b_proj_deq_scale.reshape(
            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
        )
        q_b_proj_deq_scale = trans_rope_weight(
            q_b_proj_deq_scale, self.qk_rope_head_dim
        )
        self.q_b_proj_deq_scale = q_b_proj_deq_scale.reshape(
            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)
        )

        # matmul_1 quant_bias [num_head * 192]
        q_b_proj_quant_bias = self.q_b_proj.quant_bias.data.clone()
        q_b_proj_quant_bias = q_b_proj_quant_bias.reshape(
            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
        )
        q_b_proj_quant_bias = trans_rope_weight(
            q_b_proj_quant_bias, self.qk_rope_head_dim
        )
        self.q_b_proj_quant_bias = q_b_proj_quant_bias.reshape(
            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)
        )

    def get_sin_cos(self, positions):
        cos_sin = self.rotary_emb.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 2)
        sin = sin.repeat(1, 2)
        return cos, sin
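
    # Assuming the rotary cache stores [cos(rope_dim / 2) | sin(rope_dim / 2)] per
    # position, the chunk + repeat above yields cos / sin of shape [num_tokens, rope_dim].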

    def get_kv_cache_and_cache_idx(self, forward_batch):
        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(self.layer_id)
        slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
        return k_cache, v_cache, slot_mapping
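
    # slot_mapping maps each input token to its destination slot in the paged KV pool
    # (forward_batch.out_cache_loc), cast to int32 for the NPU kernel.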

    def forward(self, positions, hidden_states, forward_batch, zero_allocator):
        input_dtype = hidden_states.dtype
        if not self.has_preprocess_weights:
            self.preprocess_weights(hidden_states)
            self.has_preprocess_weights = True
            self.dtype = hidden_states.dtype

        cos, sin = self.get_sin_cos(positions)
        k_cache, v_cache, slot_mapping = self.get_kv_cache_and_cache_idx(forward_batch)

        q_nope_out = torch.empty(
            (hidden_states.shape[0], self.w_kc.shape[0], k_cache.shape[-1]),
            dtype=input_dtype,
            device=hidden_states.device,
        )
        q_rope_out = torch.empty(
            (hidden_states.shape[0], self.w_kc.shape[0], v_cache.shape[-1]),
            dtype=input_dtype,
            device=hidden_states.device,
        )
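
        # Both buffers are preallocated as [num_tokens, w_kc.shape[0], D], where
        # w_kc.shape[0] is assumed to be the per-rank head count and D follows the last
        # dim of the corresponding KV-cache buffer; how these map onto the nope (latent)
        # and rope parts is defined by the fused op's "krope_ctkv" cache mode.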

        # TODO: dummy inputs to be removed
        # https://github.com/sgl-project/sgl-kernel-npu/issues/78
        torch.ops.npu.mla_preprocess(
            hidden_states,
            self.dummy,
            self.dummy,
            self.qkv_a_proj_weight_nz,
            self.qkv_a_proj_deq_scale_kvq,
            self.q_a_layernorm.weight,
            self.q_a_layernorm.bias,
            self.q_b_proj_weight_nz,
            self.q_b_proj_deq_scale,
            self.kv_a_layernorm.weight,
            cos,
            sin,
            self.w_kc,
            k_cache,
            v_cache,
            slot_mapping,
            quant_scale0=self.qkv_a_proj.input_scale,
            quant_offset0=self.qkv_a_proj_input_offset,
            bias0=self.qkv_a_proj_quant_bias_kvq,
            quant_scale1=self.q_b_proj.input_scale,
            quant_offset1=self.q_b_proj_input_offset,
            bias1=self.q_b_proj_quant_bias,
            cache_mode="krope_ctkv",
            quant_mode="per_tensor_quant_asymm",
            q_out0=q_nope_out,
            kv_cache_out0=k_cache,
            q_out1=q_rope_out,
            kv_cache_out1=v_cache,
        )
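
        # The op writes q_nope_out / q_rope_out and updates k_cache / v_cache in place;
        # the tuple below hands those buffers (plus the pass-through arguments) to the
        # downstream MLA attention call, as this wrapper assumes.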
        return (
            q_rope_out,
            v_cache,
            q_nope_out,
            k_cache,
            forward_batch,
            zero_allocator,
            positions,
        )