From 0c6349610ec33e8796d5509be13993b719db0011 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Tue, 21 Oct 2025 20:19:04 +0800 Subject: [PATCH] [Feature] Reduce host memory usage for attention mask generation (#3048) ### What this PR does / why we need it? Previously, the mask construction process created multiple tensors of size (max_model_len, max_model_len). When max_model_len reached 128k, the host memory usage for a single GPU exceeded hundreds of GB, causing the process to crash with OOM. This update reduces memory consumption significantly by building the mask with in-place operations (`tril_`, `masked_fill_`) and allocating it directly in the target dtype. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passes. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Jade Zheng --- vllm_ascend/attention/attention_mask.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py index 079efff..b1da723 100644 --- a/vllm_ascend/attention/attention_mask.py +++ b/vllm_ascend/attention/attention_mask.py @@ -17,19 +17,15 @@ import torch def _generate_attn_mask(max_seq_len, dtype): # Construct lower triangle matrix. - mask_flag = torch.tril( - torch.ones((max_seq_len, max_seq_len), - dtype=torch.bool)).view(max_seq_len, max_seq_len) + mask_flag = torch.ones((max_seq_len, max_seq_len), + dtype=torch.bool).tril_() # Create upper triangle matrix used to mark mask positions. mask_flag = ~mask_flag # Currently for fp16 dtype, the mask value should be set to -inf. # TODO: Eliminate this part in the future. 
- if dtype == torch.float16: - mask_value = torch.finfo(torch.float32).min - else: - mask_value = 1 - attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)), - mask_flag, mask_value).to(dtype) + mask_value = float('-inf') if dtype == torch.float16 else 1 + attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \ + .masked_fill_(mask_flag, mask_value) return attn_mask