[Core] Support pooling (#229)

This PR adds pooling support to vllm-ascend.

Tested with `bge-base-en-v1.5` via `encode`:
```
from vllm import LLM

# Sample prompts.
prompts = [
  "Hello, my name is",
  "The president of the United States is",
  "The capital of France is",
  "The future of AI is",
]
# Create an LLM.
model = LLM(model="./bge-base-en-v1.5", enforce_eager=True)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.encode(prompts)
# Print the outputs.
for output in outputs:
    print(output.outputs.embedding)  # list of 768 floats
```

Also tested with the `embed` task:
```
from vllm import LLM

llm = LLM(model="./bge-base-en-v1.5", task="embed")
(output,) = llm.embed("Hello, my name is")

embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")
```
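
As an extra sanity check (not part of this PR's tests), the pooled vectors can be consumed directly. A minimal sketch with `torch`, assuming the same local `bge-base-en-v1.5` checkpoint and placeholder prompts:
```
import torch
import torch.nn.functional as F

from vllm import LLM

llm = LLM(model="./bge-base-en-v1.5", task="embed", enforce_eager=True)
outputs = llm.embed(["The capital of France is", "Paris is the capital of France"])

# Each output carries one pooled embedding vector.
a, b = (torch.tensor(o.outputs.embedding) for o in outputs)
print(f"cosine similarity: {F.cosine_similarity(a, b, dim=0).item():.4f}")
```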

Related: https://github.com/vllm-project/vllm-ascend/issues/200

## Known issue
The accuracy is not correct yet, since this feature relies on `enc-dec` support. That will be addressed in a follow-up PR by @MengqingCao.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Commit ae49bfd13a (parent 8fda31cafe), authored by wangxiyuan on 2025-03-04 15:59:34 +08:00, committed via GitHub.
7 changed files with 258 additions and 71 deletions.


```
@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 if TYPE_CHECKING:
-    from vllm_ascend.model_runner import ModelInputForNPUBuilder
+    from vllm_ascend.worker.model_runner import ModelInputForNPUBuilder
 def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
@@ -211,6 +211,9 @@ class AscendMetadata(AttentionMetadata):
     # the computed tokens + new tokens None if it is a decoding.
     seq_lens: Optional[List[int]] = None
+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor] = None
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
@@ -258,6 +261,9 @@ class AscendMetadata(AttentionMetadata):
         block_tables = (None if self.block_tables is None else
                         self.block_tables[:self.num_prefills])
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[:self.num_prefills])
         # Construct & cache prefill-phase attention metadata structure.
         self._cached_prefill_metadata = AscendMetadata(
             num_prefills=self.num_prefills,
@@ -265,6 +271,7 @@ class AscendMetadata(AttentionMetadata):
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
             max_decode_seq_len=0,
@@ -297,7 +304,8 @@ class AscendMetadata(AttentionMetadata):
                     self.seq_lens[self.num_prefills:])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[self.num_prefills:])
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[self.num_prefills:])
         # Construct & cache decode-phase attention metadata structure.
         self._cached_decode_metadata = AscendMetadata(
             num_prefills=0,
@@ -305,6 +313,7 @@ class AscendMetadata(AttentionMetadata):
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
             block_tables=block_tables,
@@ -322,7 +331,6 @@ class AscendMetadata(AttentionMetadata):
 class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
     _metadata_cls = AscendMetadata
     _attn_mask_builder = None  # noqa
     def __init__(self, input_builder: "ModelInputForNPUBuilder"):
@@ -451,7 +459,11 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
             self.multimodal_placeholder_maps.items()
         }
-        return self._metadata_cls(  # type: ignore
+        seq_lens_tensor = torch.tensor(seq_lens,
+                                       dtype=torch.long,
+                                       device=device)
+        return AscendMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
             multi_modal_placeholder_index_maps=placeholder_index_maps,
@@ -459,6 +471,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=self.num_decode_tokens,
             seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
@@ -528,12 +541,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
             shape = [batch_size, seq_len * num_heads * head_size]
         """
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
-        attn_type = self.attn_type
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "PallasAttentionBackendImpl")
         # View q k v to BSH.
         num_tokens = query.shape[0]
         query = query.view(-1, self.num_heads, self.head_size)
```
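
For readers skimming the diff: the new `seq_lens_tensor` field is just the existing `seq_lens` list materialized as a `torch` tensor and sliced per phase, roughly as below (illustrative sketch, not code from this PR; the sample lengths are made up):
```
import torch

seq_lens = [5, 12, 7, 1, 1]   # 3 prefill sequences followed by 2 decode sequences
num_prefills = 3

# Mirrors the construction added in build() and the slicing in
# prefill_metadata / decode_metadata.
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long)
prefill_lens = seq_lens_tensor[:num_prefills]   # tensor([ 5, 12,  7])
decode_lens = seq_lens_tensor[num_prefills:]    # tensor([1, 1])
print(prefill_lens, decode_lens)
```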