Remove qwen2.py and llama.py; fix llama output

hanhaowen
2025-12-31 11:31:26 +08:00
parent b3c30a3cb9
commit b015bb76fd
11 changed files with 65 additions and 1263 deletions


@@ -43,7 +43,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
 from vllm.config import VllmConfig, get_layers_from_vllm_config
+import inspect
 class KunlunAttentionBackend(AttentionBackend):
     """KunlunAttentionBackend"""
@@ -723,30 +723,45 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
                 tmp_block_tables = decode_meta.block_tables
             else:
                 tmp_block_tables = decode_meta.block_tables * 2  # only test in Qwen3-Next
-            xtorch_ops.speculative_attention(
-                out=output[:num_decode_tokens],
-                # Only MLA support q len > 1 right now
-                q=decode_query.unsqueeze(0),
-                k_cache=key_cache,
-                v_cache=value_cache,
-                context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
-                context_lens_xpu=decode_meta.seq_lens_tensor,
-                batch_num=decode_meta.block_tables.shape[0],
-                # TODO (@xyDong23): Support MTP(q lens >1)
-                qlen=1,
-                # TODO (@xyDong23): Support max_context_len to (262144)
-                max_context_len=131072,
-                head_num=self.num_heads,
-                head_dim=self.head_size,
-                scale=0.0,
-                kv_head_num=self.num_kv_heads,
-                block_size=key_cache.shape[2],
-                max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
-                max_window_size=self.sliding_window if self.sliding_window is not None else -1,
-                block_tables=tmp_block_tables,
-                sink=self.sinks.to(torch.float32) if self.sinks is not None else None,
-            )
+            sig = inspect.signature(xtorch_ops.speculative_attention)
+            if "max_window_size" in sig.parameters:
+                xtorch_ops.speculative_attention(
+                    out=output[:num_decode_tokens],
+                    # Only MLA support q len > 1 right now
+                    q=decode_query.unsqueeze(0),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
+                    context_lens_xpu=decode_meta.seq_lens_tensor,
+                    batch_num=decode_meta.block_tables.shape[0],
+                    # TODO (@xyDong23): Support MTP(q lens >1)
+                    qlen=1,
+                    # TODO (@xyDong23): Support max_context_len to (262144)
+                    max_context_len=131072,
+                    head_num=self.num_heads,
+                    head_dim=self.head_size,
+                    scale=0.0,
+                    kv_head_num=self.num_kv_heads,
+                    block_size=key_cache.shape[2],
+                    max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
+                    max_window_size=self.sliding_window if self.sliding_window is not None else -1,
+                    block_tables=tmp_block_tables,
+                    sink=self.sinks.to(torch.float32) if self.sinks is not None else None,
+                )
+            else:
+                xtorch_ops.paged_attention(
+                    x=decode_query,
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    block_tables=tmp_block_tables,
+                    context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
+                    context_lens_xpu=decode_meta.seq_lens_tensor,
+                    is_context=False,
+                    is_causal=True,
+                    out=output[:num_decode_tokens],
+                    vo_head_dim=self.head_size,
+                )
         # Reshape the output tensor.
         return output.view(-1, self.num_heads * self.head_size)
     def use_cascade_attention(
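Because the signature probe in the new branch sits on the decode hot path, it is re-evaluated on every forward call. A possible refinement (not part of this commit) is to compute the probe once and cache the result; the sketch below assumes `xtorch_ops` is importable at module scope, and the helper name is hypothetical:

```python
import functools
import inspect

import xtorch_ops  # assumed importable, since the backend already calls xtorch_ops.* ops


@functools.lru_cache(maxsize=None)
def _speculative_attention_has_window() -> bool:
    """Probe the installed xtorch_ops build once; later calls hit the cache."""
    params = inspect.signature(xtorch_ops.speculative_attention).parameters
    return "max_window_size" in params
```

The decode branch would then test `_speculative_attention_has_window()` instead of rebuilding the signature each step; the dispatch behavior is unchanged, only the repeated reflection is avoided.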