Add kernels to optimize RoPE and the decoding stage (#143)
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -442,98 +442,6 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def cp_gather_indexer_k_quant_cache(
|
|
||||||
kv_cache, # [num_blocks, block_size, head_dim + 1]
|
|
||||||
block_table, # [batch_size, num_blocks]
|
|
||||||
cu_seq_lens, # [batch_size + 1, ]
|
|
||||||
batch_size,
|
|
||||||
head_dim,
|
|
||||||
):
|
|
||||||
num_blocks, block_size, _ = kv_cache.shape
|
|
||||||
kv_cache = kv_cache.view(num_blocks, -1)
|
|
||||||
|
|
||||||
expected_value = []
|
|
||||||
expected_scale = []
|
|
||||||
for b in range(batch_size):
|
|
||||||
s = cu_seq_lens[b + 1] - cu_seq_lens[b]
|
|
||||||
if s == 0:
|
|
||||||
continue
|
|
||||||
tot = cdiv(s, block_size)
|
|
||||||
blocks = block_table[b, :tot]
|
|
||||||
|
|
||||||
value = []
|
|
||||||
scale = []
|
|
||||||
full_block = torch.arange(tot - 1,
|
|
||||||
device=kv_cache.device,
|
|
||||||
dtype=torch.int32)
|
|
||||||
non_remaining_value = kv_cache[blocks[full_block], :block_size *
|
|
||||||
head_dim].view(-1, head_dim)
|
|
||||||
non_remaining_scale = kv_cache[blocks[full_block],
|
|
||||||
block_size * head_dim:].view(-1, 4)
|
|
||||||
|
|
||||||
remaining = s - (tot - 1) * block_size
|
|
||||||
|
|
||||||
value = torch.cat([
|
|
||||||
non_remaining_value,
|
|
||||||
kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
|
|
||||||
],
|
|
||||||
dim=0)
|
|
||||||
scale = torch.cat([
|
|
||||||
non_remaining_scale,
|
|
||||||
kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
|
|
||||||
remaining * 4].view(-1, 4)
|
|
||||||
],
|
|
||||||
dim=0)
|
|
||||||
|
|
||||||
expected_value.append(value)
|
|
||||||
expected_scale.append(scale)
|
|
||||||
|
|
||||||
gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
|
|
||||||
gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
|
|
||||||
gather_value = gather_value.view(torch.int8)
|
|
||||||
gather_scale = gather_scale.view(torch.float32)
|
|
||||||
return gather_value, gather_scale
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def kunlun_indexer_k_quant_cache(
|
|
||||||
k, #[num_tokens, head_dim]
|
|
||||||
kv_cache, # [num_blocks, cache_block_size, head_dim + 1]
|
|
||||||
slot_mapping, # [num_tokens]
|
|
||||||
quant_block_size,
|
|
||||||
):
|
|
||||||
num_blocks, cache_block_size, cache_stride = kv_cache.shape
|
|
||||||
# num_tokens, head_dim = k.shape
|
|
||||||
head_dim = k.shape[1]
|
|
||||||
num_tokens = slot_mapping.shape[0]
|
|
||||||
assert head_dim % quant_block_size == 0
|
|
||||||
kv_cache = kv_cache.view(num_blocks, -1)
|
|
||||||
|
|
||||||
k_fp8 = torch.empty(
|
|
||||||
k.shape,
|
|
||||||
device=k.device,
|
|
||||||
dtype=torch.int8,
|
|
||||||
)
|
|
||||||
k_scale = torch.empty(
|
|
||||||
[k.shape[0], 1],
|
|
||||||
device=k.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.ops._C.quant2d(k, k_fp8, k_scale, force_sdnn=True)
|
|
||||||
k_scale /= 127
|
|
||||||
for token_idx in range(num_tokens):
|
|
||||||
slot_idx = slot_mapping[token_idx]
|
|
||||||
if slot_idx < 0:
|
|
||||||
continue
|
|
||||||
block_idx = slot_idx // cache_block_size
|
|
||||||
block_offset = slot_idx % cache_block_size
|
|
||||||
v_offset = block_offset * head_dim
|
|
||||||
kv_cache[block_idx, v_offset:v_offset + head_dim] = k_fp8[token_idx, :].view(torch.uint8).contiguous()
|
|
||||||
s_offset = cache_block_size * head_dim + block_offset * 4
|
|
||||||
kv_cache[block_idx, s_offset:s_offset + 4] = k_scale[token_idx, :].view(torch.uint8).contiguous()
|
|
||||||
kv_cache = kv_cache.view(num_blocks, cache_block_size, cache_stride)
|
|
||||||
|
|
||||||
@custom_op("vllm::sparse_attn_indexer_vllm_kunlun", mutates_args=())
|
@custom_op("vllm::sparse_attn_indexer_vllm_kunlun", mutates_args=())
|
||||||
def sparse_attn_indexer_vllm_kunlun(
|
def sparse_attn_indexer_vllm_kunlun(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -578,12 +486,6 @@ def sparse_attn_indexer_vllm_kunlun(
|
|||||||
has_prefill = attn_metadata.num_prefills > 0
|
has_prefill = attn_metadata.num_prefills > 0
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
|
|
||||||
# kunlun_indexer_k_quant_cache(
|
|
||||||
# k,
|
|
||||||
# kv_cache,
|
|
||||||
# slot_mapping,
|
|
||||||
# quant_block_size,
|
|
||||||
# )
|
|
||||||
|
|
||||||
torch.ops.xspeedgate_ops.indexer_k_quant_and_cache(
|
torch.ops.xspeedgate_ops.indexer_k_quant_and_cache(
|
||||||
k,
|
k,
|
||||||
@@ -685,8 +587,16 @@ def sparse_attn_indexer_vllm_kunlun(
|
|||||||
mask = positions <= index_end_pos
|
mask = positions <= index_end_pos
|
||||||
# mask: [B * N, L]
|
# mask: [B * N, L]
|
||||||
logits = logits.masked_fill(~mask, float('-inf'))
|
logits = logits.masked_fill(~mask, float('-inf'))
|
||||||
topk_indices = torch.argsort(logits, dim=-1,
|
|
||||||
descending=True)[..., :min(topk_tokens, logits.shape[-1])]# [B * N, K]
|
del positions, mask
|
||||||
|
topk_indices = torch.ops._C.fast_topkv2(logits, decode_metadata.seq_lens, topk_tokens) # [B * N, K]
|
||||||
|
need_mask = decode_metadata.seq_lens_cpu.min() < topk_tokens
|
||||||
|
if need_mask:
|
||||||
|
positions_topk = torch.arange(topk_tokens,
|
||||||
|
device=current_device).unsqueeze(0).expand(
|
||||||
|
batch_size * next_n, -1)
|
||||||
|
mask_topk = positions_topk <= index_end_pos
|
||||||
|
topk_indices = topk_indices.masked_fill(~mask_topk, -1) # [B * N, K]
|
||||||
|
|
||||||
# ensure we don't set indices for the top k
|
# ensure we don't set indices for the top k
|
||||||
# that is out of range(masked already)
|
# that is out of range(masked already)
|
||||||
|
|||||||
@@ -159,10 +159,8 @@ def kunlun_flash_mla_with_kvcache(
|
|||||||
assert not causal, \
|
assert not causal, \
|
||||||
"causal must be `false` if sparse attention is enabled."
|
"causal must be `false` if sparse attention is enabled."
|
||||||
|
|
||||||
q_r, pe_cache = None, None # 当q_r和pe_cache为空时,为packed模式
|
|
||||||
batch_size, seq_len_q, num_heads_q, head_dim = q.shape
|
batch_size, seq_len_q, num_heads_q, head_dim = q.shape
|
||||||
kv_lora_rank = head_dim_v
|
kv_lora_rank = head_dim_v
|
||||||
rope_head_dim = head_dim - kv_lora_rank
|
|
||||||
|
|
||||||
out = torch.zeros([batch_size, seq_len_q, num_heads_q, kv_lora_rank],
|
out = torch.zeros([batch_size, seq_len_q, num_heads_q, kv_lora_rank],
|
||||||
dtype=q.dtype, device=q.device)
|
dtype=q.dtype, device=q.device)
|
||||||
|
|||||||
@@ -87,15 +87,13 @@ def int8_paged_mqa_logits(
|
|||||||
batch_size, next_n, _, D = q_fp8.shape
|
batch_size, next_n, _, D = q_fp8.shape
|
||||||
num_blocks, block_size, _, _ = kv_cache_fp8.shape
|
num_blocks, block_size, _, _ = kv_cache_fp8.shape
|
||||||
|
|
||||||
kv_cache_fp8=kv_cache_fp8.view(num_blocks, -1)
|
kv_cache_fp8 = kv_cache_fp8.view(num_blocks, -1)
|
||||||
k_val = kv_cache_fp8[:,:block_size*D].view(torch.int8)
|
k_val = kv_cache_fp8[:, :block_size * D].view(torch.int8)
|
||||||
k_val = k_val.view(-1,block_size, 1, D)
|
k_val = k_val.view(-1, block_size, 1, D)
|
||||||
k_scale_list = []
|
|
||||||
for block_tables_idx in range(block_tables.shape[0]):
|
block_indices = block_tables.flatten()
|
||||||
k_scale_item = kv_cache_fp8[block_tables[block_tables_idx], block_size *
|
k_scale = kv_cache_fp8[block_indices, block_size * D:].view(-1, 4).view(torch.float32)
|
||||||
D:].view(-1, 4)
|
k_scale = k_scale.view(-1, max_model_len)
|
||||||
k_scale_list.append(k_scale_item)
|
|
||||||
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).view(-1,max_model_len)
|
|
||||||
kv_cache = [k_val, k_scale]
|
kv_cache = [k_val, k_scale]
|
||||||
|
|
||||||
weights = weights.view(batch_size,next_n,-1)
|
weights = weights.view(batch_size,next_n,-1)
|
||||||
|
|||||||
@@ -76,6 +76,24 @@ def vllm_kunlun_forward_cuda(
|
|||||||
self.cos_sin_cache, self.is_neox_style)
|
self.cos_sin_cache, self.is_neox_style)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
def vllm_ds_rope_forward_cuda(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor | None = None,
|
||||||
|
offsets: torch.Tensor | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
|
return torch.ops.xspeedgate_ops.flashinfer_rotary_embedding(
|
||||||
|
positions=positions,
|
||||||
|
rotary_dim=self.rotary_dim,
|
||||||
|
head_size=self.head_size,
|
||||||
|
cos_sin_cache=self.cos_sin_cache,
|
||||||
|
is_neox_style=self.is_neox_style,
|
||||||
|
query=query,
|
||||||
|
key=key,
|
||||||
|
offsets=offsets,
|
||||||
|
)
|
||||||
|
|
||||||
def apply_interleaved_rope(x: torch.Tensor,
|
def apply_interleaved_rope(x: torch.Tensor,
|
||||||
mrope_section: list[int]) -> torch.Tensor:
|
mrope_section: list[int]) -> torch.Tensor:
|
||||||
"""Apply interleaved MRoPE to 3D rotary embeddings.
|
"""Apply interleaved MRoPE to 3D rotary embeddings.
|
||||||
@@ -145,12 +163,10 @@ def vllm_kunlun_mrope_forward_cuda(
|
|||||||
|
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward
|
|
||||||
DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda
|
|
||||||
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||||
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||||
DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward
|
DeepseekScalingRotaryEmbedding.forward = vllm_ds_rope_forward_cuda
|
||||||
DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda
|
DeepseekScalingRotaryEmbedding.forward_cuda = vllm_ds_rope_forward_cuda
|
||||||
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
||||||
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
||||||
|
|
||||||
|
|||||||
@@ -2235,3 +2235,40 @@ def _fake_fwd_kvcache_mla(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla)
|
fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla)
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# ------------------ fast_topkv2 -------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::fast_topkv2", mutates_args=())
|
||||||
|
def fast_topkv2(
|
||||||
|
score: torch.Tensor,
|
||||||
|
lengths: torch.Tensor,
|
||||||
|
topk: Optional[int] = 2048) -> torch.Tensor:
|
||||||
|
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||||
|
topk_indices = xtorch_ops.fast_topkv2(
|
||||||
|
score=score,
|
||||||
|
lengths=lengths,
|
||||||
|
topk=topk)
|
||||||
|
return topk_indices
|
||||||
|
|
||||||
|
@impl("_C::fast_topkv2", "CUDA")
|
||||||
|
def fast_topkv2_cuda(
|
||||||
|
score: torch.Tensor,
|
||||||
|
lengths: torch.Tensor,
|
||||||
|
topk: Optional[int] = 2048) -> torch.Tensor:
|
||||||
|
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||||
|
topk_indices = xtorch_ops.fast_topkv2(
|
||||||
|
score=score,
|
||||||
|
lengths=lengths,
|
||||||
|
topk=topk)
|
||||||
|
return topk_indices
|
||||||
|
|
||||||
|
def _fake_fast_topkv2(
|
||||||
|
score: torch.Tensor,
|
||||||
|
lengths: torch.Tensor,
|
||||||
|
topk: Optional[int] = 2048) -> torch.Tensor:
|
||||||
|
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||||
|
topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
|
||||||
|
return topk_indices
|
||||||
|
|
||||||
|
fast_topkv2.register_fake(_fake_fast_topkv2)
|
||||||
Reference in New Issue
Block a user