[Feat] Support async_scheduler and disable_padded_drafter_batch in eagle (#4893)
### What this PR does / why we need it?
We refactored eagle_proposer.py to follow the framework of eagle.py in vLLM v0.12.0, adding support for the async scheduler and the `disable_padded_drafter_batch` option (correct drafter-logit handling for both padded and unpadded drafter batches). A usage sketch follows the version notes below.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
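
For reviewers, a minimal sketch of how these two features might be enabled together from the user side. The parameter names shown here (the `speculative_config` keys, `async_scheduling`, the drafter checkpoint) are assumptions based on upstream vLLM, not taken from this PR:

```python
# Hypothetical usage sketch; exact option names may differ in vllm-ascend.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",           # assumed target model
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",   # assumed eagle drafter
        "num_speculative_tokens": 2,
        # Skip padding the drafter batch; the sliced ("unpadded") attention
        # metadata added in this PR is meant to support this path.
        "disable_padded_drafter_batch": True,
    },
    async_scheduling=True,  # assumed engine flag for the async scheduler
)

out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)
```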
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
Co-authored-by: drslark <slarksblood@qq.com>
```diff
@@ -730,6 +730,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
         self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
         slots = attn_metadata.slot_mapping
         if get_ascend_device_type() == AscendDeviceType._910_95:
+            # TODO: Once eagle runs through this path, it may fail because of the 0th dim of slot_mapping.
+            # Check whether the 0th dim of slot_mapping must equal the 0th dim of key;
+            # if so, the slots should be sliced.
             torch_npu.npu_scatter_pa_kv_cache(
                 key=key[:attn_metadata.num_actual_tokens],
                 value=value[:attn_metadata.num_actual_tokens].contiguous(),
```
```diff
@@ -742,7 +745,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 value=value[:attn_metadata.num_actual_tokens],
                 key_cache=self.key_cache,
                 value_cache=self.value_cache,
-                slot_indices=slots)
+                slot_indices=slots[:attn_metadata.num_actual_tokens])
         return key, value

     def forward_impl(
```
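The `slot_indices` slicing above matters because, with a padded drafter batch, the tail of `slot_mapping` holds padding entries that must not be written into the paged KV cache. Below is a minimal, framework-free sketch of the idea; `scatter_actual_tokens` and the flat cache layout are invented for illustration and merely stand in for `torch_npu.npu_scatter_pa_kv_cache`:

```python
import torch


def scatter_actual_tokens(key: torch.Tensor, value: torch.Tensor,
                          slots: torch.Tensor, num_actual_tokens: int,
                          key_cache: torch.Tensor,
                          value_cache: torch.Tensor) -> None:
    """Illustrative only: write only the real (non-padded) tokens into the cache."""
    # Drop the padded tail; padded positions may hold stale or invalid slot ids.
    k = key[:num_actual_tokens]
    v = value[:num_actual_tokens]
    s = slots[:num_actual_tokens].long()
    # Flat caches indexed by slot id, stand-ins for the NPU kernel's
    # key_cache / value_cache arguments.
    key_cache[s] = k
    value_cache[s] = v


# Toy shapes: 8 padded tokens, only 5 are real.
key = torch.randn(8, 4)
value = torch.randn(8, 4)
slots = torch.tensor([3, 7, 1, 0, 5, -1, -1, -1])  # padded entries are invalid
key_cache = torch.zeros(16, 4)
value_cache = torch.zeros(16, 4)
scatter_actual_tokens(key, value, slots, num_actual_tokens=5,
                      key_cache=key_cache, value_cache=value_cache)
```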
```diff
@@ -119,6 +119,35 @@ class AscendCommonAttentionMetadata:
     prefill_context_parallel_metadata: Optional[
         AscendPrefillContextParallelMetadata] = None
+
+    # TODO: Remove it when vLLM no longer uses this function.
+    def unpadded(self, num_actual_tokens: int,
+                 num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
+        # Only used by eagle for now; it will also be used for enforce_eager in the future.
+        return AscendCommonAttentionMetadata(
+            query_start_loc=self.query_start_loc[:num_actual_reqs + 1],
+            query_start_loc_cpu=self.query_start_loc_cpu[:num_actual_reqs + 1],
+            seq_lens=self.seq_lens[:num_actual_reqs],
+            seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
+            num_computed_tokens_cpu=self.
+            num_computed_tokens_cpu[:num_actual_reqs],
+            num_reqs=num_actual_reqs,
+            num_actual_tokens=num_actual_tokens,
+            max_query_len=self.max_query_len,
+            decode_token_per_req=self.decode_token_per_req,
+            block_table_tensor=self.block_table_tensor[:num_actual_reqs],
+            slot_mapping=self.slot_mapping[:num_actual_tokens],
+            actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
+            positions=self.positions[:num_actual_tokens],
+            attn_mask=self.attn_mask,
+            spec_attn_mask=self.spec_attn_mask,
+            attn_state=self.attn_state,
+            is_only_prefill=self.is_only_prefill,
+            graph_pad_size=-1,  # Should be -1 when not running in fullgraph mode.
+            num_input_tokens=num_actual_tokens,
+            prefill_context_parallel_metadata=self.
+            prefill_context_parallel_metadata,
+        )


 def filter_chunked_req_indices(
     seq_len: torch.Tensor,
```
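To illustrate how a caller is meant to use the new `unpadded()` helper, here is a simplified stand-in for `AscendCommonAttentionMetadata`; only the `unpadded(num_actual_tokens, num_actual_reqs)` slicing pattern comes from the diff above, while the reduced field set and the example values are invented for the sketch:

```python
from dataclasses import dataclass

import torch


@dataclass
class PaddedMetadata:
    query_start_loc: torch.Tensor  # shape [num_padded_reqs + 1]
    seq_lens: torch.Tensor         # shape [num_padded_reqs]
    slot_mapping: torch.Tensor     # shape [num_padded_tokens]
    num_reqs: int
    num_actual_tokens: int

    def unpadded(self, num_actual_tokens: int,
                 num_actual_reqs: int) -> "PaddedMetadata":
        # Same idea as AscendCommonAttentionMetadata.unpadded(): slice every
        # per-request / per-token field down to the actual sizes so the drafter
        # can run without the padding that was added for the target model.
        return PaddedMetadata(
            query_start_loc=self.query_start_loc[:num_actual_reqs + 1],
            seq_lens=self.seq_lens[:num_actual_reqs],
            slot_mapping=self.slot_mapping[:num_actual_tokens],
            num_reqs=num_actual_reqs,
            num_actual_tokens=num_actual_tokens,
        )


padded = PaddedMetadata(
    query_start_loc=torch.tensor([0, 2, 5, 5, 5]),  # 4 padded reqs, 2 real
    seq_lens=torch.tensor([2, 3, 0, 0]),
    slot_mapping=torch.arange(8),                   # 8 padded tokens, 5 real
    num_reqs=4,
    num_actual_tokens=8,
)
drafter_meta = padded.unpadded(num_actual_tokens=5, num_actual_reqs=2)
assert drafter_meta.query_start_loc.tolist() == [0, 2, 5]
```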