[Bugfix] Disable torch.compile() (#370)
### What this PR does / why we need it? To resolve this [patch](https://github.com/vllm-project/vllm-ascend/pull/236/files#diff-43b96b39b5a52fe209d86449ad703a7ff5e1349ebaf1aa12ece8d82163ee5b61R24-R49), we need to set the `torch.compile()` backend to `eager` to disable compilation, falling back to the default PyTorch eager-mode execution. --------- Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -988,7 +988,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
self.num_heads,
|
||||
self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device="npu")
|
||||
device=query.device)
|
||||
if (attn_metadata.block_tables is None
|
||||
or attn_metadata.block_tables.numel() == 0):
|
||||
assert attn_metadata.attn_mask is not None
|
||||
@@ -1015,11 +1015,12 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
)
|
||||
elif attn_metadata.decode_metadata:
|
||||
assert kv_cache is not None
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=query.dtype,
|
||||
device="npu")
|
||||
# if torch.empty is used here, the preemptive scheduling case of
|
||||
# test_mtp_correctness.py will fail to run.
|
||||
attn_output = torch.randn(
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
|
||||
Reference in New Issue
Block a user