Fix warning msg print (#3421)
### What this PR does / why we need it?

Avoid printing the following warning message:

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach ...

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
@@ -79,7 +79,7 @@ class AscendMLAPrefillMetadata:
|
|||||||
chunk_seq_lens: torch.Tensor
|
chunk_seq_lens: torch.Tensor
|
||||||
|
|
||||||
attn_mask: torch.Tensor
|
attn_mask: torch.Tensor
|
||||||
query_lens: list[int]
|
query_lens: torch.Tensor
|
||||||
seq_lens: list[int]
|
seq_lens: list[int]
|
||||||
context_lens: torch.Tensor
|
context_lens: torch.Tensor
|
||||||
input_positions: torch.Tensor
|
input_positions: torch.Tensor
|
||||||
@@ -380,7 +380,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
1).unsqueeze(2)
|
1).unsqueeze(2)
|
||||||
prefill_metadata = AscendMLAPrefillMetadata(
|
prefill_metadata = AscendMLAPrefillMetadata(
|
||||||
attn_mask=common_attn_metadata.attn_mask,
|
attn_mask=common_attn_metadata.attn_mask,
|
||||||
query_lens=query_lens[reqs_start:],
|
query_lens=query_lens[reqs_start:].to(torch.int32),
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
context_lens=seq_lens[reqs_start:],
|
context_lens=seq_lens[reqs_start:],
|
||||||
input_positions=prefill_input_positions,
|
input_positions=prefill_input_positions,
|
||||||
@@ -837,9 +837,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
k_rope=k_pe,
|
k_rope=k_pe,
|
||||||
value=value,
|
value=value,
|
||||||
mask=self.prefill_mask,
|
mask=self.prefill_mask,
|
||||||
seqlen=torch.tensor(
|
seqlen=attn_metadata.prefill.query_lens,
|
||||||
attn_metadata.prefill.query_lens,
|
|
||||||
dtype=torch.int32),
|
|
||||||
head_num=self.num_heads,
|
head_num=self.num_heads,
|
||||||
kv_head_num=self.num_heads,
|
kv_head_num=self.num_heads,
|
||||||
pre_out=None,
|
pre_out=None,
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class AscendMLATorchairPrefillMetadata:
|
|||||||
chunk_seq_lens: torch.Tensor
|
chunk_seq_lens: torch.Tensor
|
||||||
|
|
||||||
attn_mask: torch.Tensor
|
attn_mask: torch.Tensor
|
||||||
query_lens: list[int]
|
query_lens: torch.Tensor
|
||||||
seq_lens: list[int]
|
seq_lens: list[int]
|
||||||
context_lens: torch.Tensor
|
context_lens: torch.Tensor
|
||||||
input_positions: torch.Tensor
|
input_positions: torch.Tensor
|
||||||
@@ -473,7 +473,7 @@ class AscendMLATorchairMetadataBuilder:
|
|||||||
1).unsqueeze(2)
|
1).unsqueeze(2)
|
||||||
prefill_metadata = AscendMLATorchairPrefillMetadata(
|
prefill_metadata = AscendMLATorchairPrefillMetadata(
|
||||||
attn_mask=common_attn_metadata.attn_mask,
|
attn_mask=common_attn_metadata.attn_mask,
|
||||||
query_lens=query_lens[tokens_start:],
|
query_lens=query_lens[tokens_start:].to(torch.int32),
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
context_lens=seq_lens[tokens_start:],
|
context_lens=seq_lens[tokens_start:],
|
||||||
input_positions=prefill_input_positions,
|
input_positions=prefill_input_positions,
|
||||||
@@ -880,9 +880,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
|||||||
k_rope=k_pe,
|
k_rope=k_pe,
|
||||||
value=value,
|
value=value,
|
||||||
mask=self.prefill_mask,
|
mask=self.prefill_mask,
|
||||||
seqlen=torch.tensor(
|
seqlen=attn_metadata.prefill.query_lens,
|
||||||
attn_metadata.prefill.query_lens,
|
|
||||||
dtype=torch.int32),
|
|
||||||
head_num=self.num_heads,
|
head_num=self.num_heads,
|
||||||
kv_head_num=self.num_heads,
|
kv_head_num=self.num_heads,
|
||||||
pre_out=None,
|
pre_out=None,
|
||||||
|
|||||||
Reference in New Issue
Block a user