[Core] Support pooling (#229)
This PR added pooling support for vllm-ascend
Tested with `bge-base-en-v1.5` by encode:
```
from vllm import LLM
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
model = LLM(model="./bge-base-en-v1.5", enforce_eager=True)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.encode(prompts)
# Print the outputs.
for output in outputs:
print(output.outputs.embedding) # list of 4096 floats
```
Tested by embedding:
```
from vllm import LLM, SamplingParams
llm = LLM(model="./bge-base-en-v1.5", task="embed")
(output,) = llm.embed("Hello, my name is")
embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")
```
Related: https://github.com/vllm-project/vllm-ascend/issues/200
## Known issue
The accuracy is not correct since this feature rely on `enc-dec`
support. It'll be done in the following PR by @MengqingCao
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm_ascend.model_runner import ModelInputForNPUBuilder
|
||||
from vllm_ascend.worker.model_runner import ModelInputForNPUBuilder
|
||||
|
||||
|
||||
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
|
||||
@@ -211,6 +211,9 @@ class AscendMetadata(AttentionMetadata):
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]] = None
|
||||
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
@@ -258,6 +261,9 @@ class AscendMetadata(AttentionMetadata):
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
|
||||
# Construct & cache prefill-phase attention metadata structure.
|
||||
self._cached_prefill_metadata = AscendMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
@@ -265,6 +271,7 @@ class AscendMetadata(AttentionMetadata):
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
@@ -297,7 +304,8 @@ class AscendMetadata(AttentionMetadata):
|
||||
self.seq_lens[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
# Construct & cache decode-phase attention metadata structure.
|
||||
self._cached_decode_metadata = AscendMetadata(
|
||||
num_prefills=0,
|
||||
@@ -305,6 +313,7 @@ class AscendMetadata(AttentionMetadata):
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
block_tables=block_tables,
|
||||
@@ -322,7 +331,6 @@ class AscendMetadata(AttentionMetadata):
|
||||
|
||||
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
|
||||
_metadata_cls = AscendMetadata
|
||||
_attn_mask_builder = None # noqa
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForNPUBuilder"):
|
||||
@@ -451,7 +459,11 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
return self._metadata_cls( # type: ignore
|
||||
seq_lens_tensor = torch.tensor(seq_lens,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
return AscendMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
@@ -459,6 +471,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
@@ -528,12 +541,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
shape = [batch_size, seq_len * num_heads * head_size]
|
||||
"""
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
attn_type = self.attn_type
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
# View q k v to BSH.
|
||||
num_tokens = query.shape[0]
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
Reference in New Issue
Block a user