[Feature] Reduce host memory usage for attention mask generation (#3048)

### What this PR does / why we need it?

Previously, the mask construction process created multiple intermediate
tensors of size (max_model_len, max_model_len). When max_model_len
reached 128k, host memory usage for a single GPU worker exceeded
hundreds of GB, crashing the process with an OOM. This update builds
the mask in place and directly in the target dtype, avoiding the large
intermediate copies and significantly reducing memory consumption.
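For scale, a single dense (max_model_len, max_model_len) tensor is already enormous at a 128k context. The following back-of-the-envelope sketch (illustrative only, not code from this PR) shows why a few such intermediates reach hundreds of GB:

```python
# Rough host-memory footprint of one (max_model_len, max_model_len) mask tensor.
def mask_bytes(max_model_len: int, bytes_per_elem: int) -> int:
    return max_model_len * max_model_len * bytes_per_elem

GIB = 1024 ** 3
# One fp32 mask at a 128k context is 64 GiB on its own, so a handful
# of such intermediate tensors quickly adds up to hundreds of GB.
print(mask_bytes(128 * 1024, 4) / GIB)  # 64.0
```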

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI pass.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Committed by: Jade Zheng, 2025-10-21 20:19:04 +08:00 (via GitHub)
Parent: 5f8b1699ae · Commit: 0c6349610e

@@ -17,19 +17,15 @@ import torch
 def _generate_attn_mask(max_seq_len, dtype):
     # Construct lower triangle matrix.
-    mask_flag = torch.tril(
-        torch.ones((max_seq_len, max_seq_len),
-                   dtype=torch.bool)).view(max_seq_len, max_seq_len)
+    mask_flag = torch.ones((max_seq_len, max_seq_len),
+                           dtype=torch.bool).tril_()
     # Create upper triangle matrix used to mark mask positions.
     mask_flag = ~mask_flag
     # Currently for fp16 dtype, the mask value should be set to -inf.
     # TODO: Eliminate this part in the future.
-    if dtype == torch.float16:
-        mask_value = torch.finfo(torch.float32).min
-    else:
-        mask_value = 1
-    attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
-                                  mask_flag, mask_value).to(dtype)
+    mask_value = float('-inf') if dtype == torch.float16 else 1
+    attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \
+        .masked_fill_(mask_flag, mask_value)
     return attn_mask
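As a sanity check, the two code paths can be reproduced side by side at a small size to confirm the refactor changes only how many intermediates are allocated, not the resulting mask (a standalone sketch, not part of the PR):

```python
import torch

def old_mask(max_seq_len, dtype):
    # Pre-PR path: bool tril plus a float32 zeros tensor that is
    # masked-filled out of place and then copied again via .to(dtype).
    mask_flag = ~torch.tril(
        torch.ones((max_seq_len, max_seq_len), dtype=torch.bool))
    mask_value = torch.finfo(torch.float32).min if dtype == torch.float16 else 1
    return torch.masked_fill(torch.zeros((max_seq_len, max_seq_len)),
                             mask_flag, mask_value).to(dtype)

def new_mask(max_seq_len, dtype):
    # Post-PR path: in-place tril_ on the bool mask, and the output
    # allocated directly in the target dtype and filled in place.
    mask_flag = ~torch.ones((max_seq_len, max_seq_len),
                            dtype=torch.bool).tril_()
    mask_value = float('-inf') if dtype == torch.float16 else 1
    return torch.zeros((max_seq_len, max_seq_len),
                       dtype=dtype).masked_fill_(mask_flag, mask_value)

# float32's minimum saturates to -inf when cast to fp16, so both paths agree.
for dt in (torch.float16, torch.bfloat16):
    assert torch.equal(old_mask(8, dt), new_mask(8, dt))
```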