[main][bugfix] Fix bugs and refactor cached mask generation logic (#2442)
### What this PR does / why we need it?
This PR fixes bugs in the cached mask generation logic and refactors it: the cached mask is now pre-constructed and used on the CPU instead of on the NPU device.
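Below is a minimal sketch of that idea, assuming a plain causal mask for illustration; `CachedMaskBuilder` and `get_mask` are hypothetical stand-ins, not the actual `attn_mask_builder.get_splitfuse_attn_mask` API touched by the diff:

```python
import torch

# Hypothetical sketch of the "cached mask on CPU" approach: build one
# triangular mask on the CPU once, then hand out slices of it per step
# instead of constructing a fresh mask tensor on the NPU device each call.
class CachedMaskBuilder:
    def __init__(self, max_seq_len: int):
        self.max_seq_len = max_seq_len
        # Pre-construct the full causal mask a single time, on the CPU.
        self._mask = torch.triu(
            torch.ones(max_seq_len, max_seq_len, dtype=torch.bool),
            diagonal=1)

    def get_mask(self, query_len: int, seq_len: int) -> torch.Tensor:
        # Return a view into the cached CPU tensor; no new allocation.
        return self._mask[seq_len - query_len:seq_len, :seq_len]

builder = CachedMaskBuilder(max_seq_len=2048)
mask = builder.get_mask(query_len=4, seq_len=16)
print(mask.shape, mask.device)  # torch.Size([4, 16]) cpu
```

Serving the mask as a CPU-side view makes the per-step cost a cheap slice rather than a device allocation, which matches the PR's stated goal of dropping on-NPU mask construction.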
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.10.1.1
- vLLM main: 9b5f64238f
Signed-off-by: rjg-lyh <1318825571@qq.com>
```diff
@@ -79,11 +79,10 @@ class EagleProposer:
     def _make_attention_mask(
         self,
         seq_lens,
-        query_lens,
         position,
     ) -> torch.Tensor:
         return self.attn_mask_builder.get_splitfuse_attn_mask(
-            seq_lens, query_lens, position, self.dtype, self.device)
+            seq_lens, position, self.dtype, self.device)
 
     def propose(
         self,
@@ -247,7 +246,6 @@ class EagleProposer:
         positions = positions_cpu.to(device)
         attn_mask = self._make_attention_mask(
             seq_lens=attn_metadata.seq_lens,
-            query_lens=attn_metadata.max_query_len,
             position=positions,
         )
         attn_metadata.attn_mask = attn_mask
```