diff --git a/vllm_kunlun/models/deepseek_v2.py b/vllm_kunlun/models/deepseek_v2.py index c1c00f3..de63226 100644 --- a/vllm_kunlun/models/deepseek_v2.py +++ b/vllm_kunlun/models/deepseek_v2.py @@ -442,98 +442,6 @@ class DeepseekV2Attention(nn.Module): output, _ = self.o_proj(attn_output) return output -@torch.inference_mode() -def cp_gather_indexer_k_quant_cache( - kv_cache, # [num_blocks, block_size, head_dim + 1] - block_table, # [batch_size, num_blocks] - cu_seq_lens, # [batch_size + 1, ] - batch_size, - head_dim, -): - num_blocks, block_size, _ = kv_cache.shape - kv_cache = kv_cache.view(num_blocks, -1) - - expected_value = [] - expected_scale = [] - for b in range(batch_size): - s = cu_seq_lens[b + 1] - cu_seq_lens[b] - if s == 0: - continue - tot = cdiv(s, block_size) - blocks = block_table[b, :tot] - - value = [] - scale = [] - full_block = torch.arange(tot - 1, - device=kv_cache.device, - dtype=torch.int32) - non_remaining_value = kv_cache[blocks[full_block], :block_size * - head_dim].view(-1, head_dim) - non_remaining_scale = kv_cache[blocks[full_block], - block_size * head_dim:].view(-1, 4) - - remaining = s - (tot - 1) * block_size - - value = torch.cat([ - non_remaining_value, - kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim) - ], - dim=0) - scale = torch.cat([ - non_remaining_scale, - kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim + - remaining * 4].view(-1, 4) - ], - dim=0) - - expected_value.append(value) - expected_scale.append(scale) - - gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim) - gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4) - gather_value = gather_value.view(torch.int8) - gather_scale = gather_scale.view(torch.float32) - return gather_value, gather_scale - -@torch.inference_mode() -def kunlun_indexer_k_quant_cache( - k, #[num_tokens, head_dim] - kv_cache, # [num_blocks, cache_block_size, head_dim + 1] - slot_mapping, # [num_tokens] - quant_block_size, -): - num_blocks, cache_block_size, cache_stride = kv_cache.shape - # num_tokens, head_dim = k.shape - head_dim = k.shape[1] - num_tokens = slot_mapping.shape[0] - assert head_dim % quant_block_size == 0 - kv_cache = kv_cache.view(num_blocks, -1) - - k_fp8 = torch.empty( - k.shape, - device=k.device, - dtype=torch.int8, - ) - k_scale = torch.empty( - [k.shape[0], 1], - device=k.device, - dtype=torch.float32, - ) - - torch.ops._C.quant2d(k, k_fp8, k_scale, force_sdnn=True) - k_scale /= 127 - for token_idx in range(num_tokens): - slot_idx = slot_mapping[token_idx] - if slot_idx < 0: - continue - block_idx = slot_idx // cache_block_size - block_offset = slot_idx % cache_block_size - v_offset = block_offset * head_dim - kv_cache[block_idx, v_offset:v_offset + head_dim] = k_fp8[token_idx, :].view(torch.uint8).contiguous() - s_offset = cache_block_size * head_dim + block_offset * 4 - kv_cache[block_idx, s_offset:s_offset + 4] = k_scale[token_idx, :].view(torch.uint8).contiguous() - kv_cache = kv_cache.view(num_blocks, cache_block_size, cache_stride) - @custom_op("vllm::sparse_attn_indexer_vllm_kunlun", mutates_args=()) def sparse_attn_indexer_vllm_kunlun( hidden_states: torch.Tensor, @@ -578,12 +486,6 @@ def sparse_attn_indexer_vllm_kunlun( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - # kunlun_indexer_k_quant_cache( - # k, - # kv_cache, - # slot_mapping, - # quant_block_size, - # ) torch.ops.xspeedgate_ops.indexer_k_quant_and_cache( k, @@ -685,8 +587,16 @@ def sparse_attn_indexer_vllm_kunlun( mask = positions <= index_end_pos # mask: [B * N, L] logits = logits.masked_fill(~mask, float('-inf')) - topk_indices = torch.argsort(logits, dim=-1, - descending=True)[..., :min(topk_tokens, logits.shape[-1])]# [B * N, K] + + del positions, mask + topk_indices = torch.ops._C.fast_topkv2(logits, decode_metadata.seq_lens, topk_tokens) # [B * N, K] + need_mask = decode_metadata.seq_lens_cpu.min() < topk_tokens + if need_mask: + positions_topk = torch.arange(topk_tokens, + device=current_device).unsqueeze(0).expand( + batch_size * next_n, -1) + mask_topk = positions_topk <= index_end_pos + topk_indices = topk_indices.masked_fill(~mask_topk, -1) # [B * N, K] # ensure we don't set indices for the top k # that is out of range(masked already) diff --git a/vllm_kunlun/ops/attention/flashmla.py b/vllm_kunlun/ops/attention/flashmla.py index 2375799..0eeb4a5 100644 --- a/vllm_kunlun/ops/attention/flashmla.py +++ b/vllm_kunlun/ops/attention/flashmla.py @@ -159,10 +159,8 @@ def kunlun_flash_mla_with_kvcache( assert not causal, \ "causal must be `false` if sparse attention is enabled." - q_r, pe_cache = None, None # 当q_r和pe_cache为空时,为packed模式 batch_size, seq_len_q, num_heads_q, head_dim = q.shape kv_lora_rank = head_dim_v - rope_head_dim = head_dim - kv_lora_rank out = torch.zeros([batch_size, seq_len_q, num_heads_q, kv_lora_rank], dtype=q.dtype, device=q.device) diff --git a/vllm_kunlun/ops/deep_gemm.py b/vllm_kunlun/ops/deep_gemm.py index 063c5b6..ffc7c90 100644 --- a/vllm_kunlun/ops/deep_gemm.py +++ b/vllm_kunlun/ops/deep_gemm.py @@ -87,15 +87,13 @@ def int8_paged_mqa_logits( batch_size, next_n, _, D = q_fp8.shape num_blocks, block_size, _, _ = kv_cache_fp8.shape - kv_cache_fp8=kv_cache_fp8.view(num_blocks, -1) - k_val = kv_cache_fp8[:,:block_size*D].view(torch.int8) - k_val = k_val.view(-1,block_size, 1, D) - k_scale_list = [] - for block_tables_idx in range(block_tables.shape[0]): - k_scale_item = kv_cache_fp8[block_tables[block_tables_idx], block_size * - D:].view(-1, 4) - k_scale_list.append(k_scale_item) - k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).view(-1,max_model_len) + kv_cache_fp8 = kv_cache_fp8.view(num_blocks, -1) + k_val = kv_cache_fp8[:, :block_size * D].view(torch.int8) + k_val = k_val.view(-1, block_size, 1, D) + + block_indices = block_tables.flatten() + k_scale = kv_cache_fp8[block_indices, block_size * D:].view(-1, 4).view(torch.float32) + k_scale = k_scale.view(-1, max_model_len) kv_cache = [k_val, k_scale] weights = weights.view(batch_size,next_n,-1) diff --git a/vllm_kunlun/ops/rotary_embedding.py b/vllm_kunlun/ops/rotary_embedding.py index a4c6289..7704832 100644 --- a/vllm_kunlun/ops/rotary_embedding.py +++ b/vllm_kunlun/ops/rotary_embedding.py @@ -76,6 +76,24 @@ def vllm_kunlun_forward_cuda( self.cos_sin_cache, self.is_neox_style) return query, key +def vllm_ds_rope_forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None]: + return torch.ops.xspeedgate_ops.flashinfer_rotary_embedding( + positions=positions, + rotary_dim=self.rotary_dim, + head_size=self.head_size, + cos_sin_cache=self.cos_sin_cache, + is_neox_style=self.is_neox_style, + query=query, + key=key, + offsets=offsets, + ) + def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: """Apply interleaved MRoPE to 3D rotary embeddings. @@ -145,12 +163,10 @@ def vllm_kunlun_mrope_forward_cuda( return query, key -DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward -DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda RotaryEmbedding.forward = vllm_kunlun_forward_cuda -DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward -DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda +DeepseekScalingRotaryEmbedding.forward = vllm_ds_rope_forward_cuda +DeepseekScalingRotaryEmbedding.forward_cuda = vllm_ds_rope_forward_cuda MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py index ca7273c..0fb258f 100644 --- a/vllm_kunlun/vllm_utils_wrapper.py +++ b/vllm_kunlun/vllm_utils_wrapper.py @@ -2235,3 +2235,40 @@ def _fake_fwd_kvcache_mla( return None fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla) + +################################################## +# ------------------ fast_topkv2 ------------- +################################################## +@custom_op("_C::fast_topkv2", mutates_args=()) +def fast_topkv2( + score: torch.Tensor, + lengths: torch.Tensor, + topk: Optional[int] = 2048) -> torch.Tensor: + assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now" + topk_indices = xtorch_ops.fast_topkv2( + score=score, + lengths=lengths, + topk=topk) + return topk_indices + +@impl("_C::fast_topkv2", "CUDA") +def fast_topkv2_cuda( + score: torch.Tensor, + lengths: torch.Tensor, + topk: Optional[int] = 2048) -> torch.Tensor: + assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now" + topk_indices = xtorch_ops.fast_topkv2( + score=score, + lengths=lengths, + topk=topk) + return topk_indices + +def _fake_fast_topkv2( + score: torch.Tensor, + lengths: torch.Tensor, + topk: Optional[int] = 2048) -> torch.Tensor: + assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now" + topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32) + return topk_indices + +fast_topkv2.register_fake(_fake_fast_topkv2) \ No newline at end of file