remove qwen2.py llama.py fix llama output
@@ -43,7 +43,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

from vllm.config import VllmConfig, get_layers_from_vllm_config

import inspect

class KunlunAttentionBackend(AttentionBackend):
    """KunlunAttentionBackend"""
@@ -723,30 +723,45 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
    tmp_block_tables = decode_meta.block_tables
else:
    tmp_block_tables = decode_meta.block_tables * 2  # only tested with Qwen3-Next

xtorch_ops.speculative_attention(
    out=output[:num_decode_tokens],
    # Only MLA supports q len > 1 right now
    q=decode_query.unsqueeze(0),
    k_cache=key_cache,
    v_cache=value_cache,
    context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
    context_lens_xpu=decode_meta.seq_lens_tensor,
    batch_num=decode_meta.block_tables.shape[0],
    # TODO (@xyDong23): Support MTP (q len > 1)
    qlen=1,
    # TODO (@xyDong23): Support max_context_len up to 262144
    max_context_len=131072,
    head_num=self.num_heads,
    head_dim=self.head_size,
    scale=0.0,
    kv_head_num=self.num_kv_heads,
    block_size=key_cache.shape[2],
    max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
    max_window_size=self.sliding_window if self.sliding_window is not None else -1,
    block_tables=tmp_block_tables,
    sink=self.sinks.to(torch.float32) if self.sinks is not None else None,
)

sig = inspect.signature(xtorch_ops.speculative_attention)
if "max_window_size" in sig.parameters:
    xtorch_ops.speculative_attention(
        out=output[:num_decode_tokens],
        # Only MLA supports q len > 1 right now
        q=decode_query.unsqueeze(0),
        k_cache=key_cache,
        v_cache=value_cache,
        context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
        context_lens_xpu=decode_meta.seq_lens_tensor,
        batch_num=decode_meta.block_tables.shape[0],
        # TODO (@xyDong23): Support MTP (q len > 1)
        qlen=1,
        # TODO (@xyDong23): Support max_context_len up to 262144
        max_context_len=131072,
        head_num=self.num_heads,
        head_dim=self.head_size,
        scale=0.0,
        kv_head_num=self.num_kv_heads,
        block_size=key_cache.shape[2],
        max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
        max_window_size=self.sliding_window if self.sliding_window is not None else -1,
        block_tables=tmp_block_tables,
        sink=self.sinks.to(torch.float32) if self.sinks is not None else None,
    )
else:
    xtorch_ops.paged_attention(
        x=decode_query,
        k_cache=key_cache,
        v_cache=value_cache,
        block_tables=tmp_block_tables,
        context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
        context_lens_xpu=decode_meta.seq_lens_tensor,
        is_context=False,
        is_causal=True,
        out=output[:num_decode_tokens],
        vo_head_dim=self.head_size,
    )
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def use_cascade_attention(
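
The new decode path probes the installed xtorch_ops.speculative_attention signature at runtime and only takes the sliding-window call when the op accepts max_window_size, falling back to paged_attention otherwise. Below is a minimal, self-contained sketch of that capability-check pattern; fast_op, fallback_op, and run_decode are placeholder names used only for illustration and are not part of xtorch_ops.

import inspect

def fast_op(x, max_window_size=-1):
    # Stand-in for the newer kernel that understands max_window_size.
    return x * 2

def fallback_op(x):
    # Stand-in for the older kernel without sliding-window support.
    return x + 1

def run_decode(x, sliding_window=None):
    # Probe the preferred op's signature once and dispatch accordingly.
    sig = inspect.signature(fast_op)
    if "max_window_size" in sig.parameters:
        return fast_op(x, max_window_size=sliding_window if sliding_window is not None else -1)
    return fallback_op(x)

print(run_decode(3, sliding_window=128))  # takes the fast_op branch -> 6

This keeps a single dispatch point per kernel generation instead of pinning the backend to one specific xtorch_ops release.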