[Feature] Reduce host memory usage for attention mask generation (#3048)
### What this PR does / why we need it?
Previously, the mask construction process created multiple tensors of shape (max_model_len, max_model_len). When max_model_len reached 128k, host memory usage on a single GPU worker exceeded hundreds of GB, crashing the process with an OOM. This update optimizes mask generation to significantly reduce host memory consumption.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI pass.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
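For a sense of scale, the back-of-the-envelope sketch below (not part of the PR; the 128k figure comes from the description above) estimates the host-memory cost of a single full-size mask tensor:

```python
# Cost of ONE (max_model_len, max_model_len) tensor on the host,
# assuming max_model_len = 128 * 1024 as in the PR description.
max_model_len = 128 * 1024
elems = max_model_len ** 2              # ~1.7e10 elements

print(f"bool mask:    {elems * 1 / 2**30:.0f} GiB")  # 16 GiB
print(f"float32 mask: {elems * 4 / 2**30:.0f} GiB")  # 64 GiB
print(f"float16 mask: {elems * 2 / 2**30:.0f} GiB")  # 32 GiB
```

The old construction path kept several such tensors alive at once (bool intermediates plus a float32 buffer and its dtype-converted copy), which is how peak host usage reached hundreds of GB.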
```diff
@@ -17,19 +17,15 @@ import torch


 def _generate_attn_mask(max_seq_len, dtype):
     # Construct lower triangle matrix.
-    mask_flag = torch.tril(
-        torch.ones((max_seq_len, max_seq_len),
-                   dtype=torch.bool)).view(max_seq_len, max_seq_len)
+    mask_flag = torch.ones((max_seq_len, max_seq_len),
+                           dtype=torch.bool).tril_()
     # Create upper triangle matrix used to mark mask positions.
     mask_flag = ~mask_flag
     # Currently for fp16 dtype, the mask value should be set to -inf.
     # TODO: Eliminate this part in the future.
-    if dtype == torch.float16:
-        mask_value = torch.finfo(torch.float32).min
-    else:
-        mask_value = 1
-    attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
-                                  mask_flag, mask_value).to(dtype)
+    mask_value = float('-inf') if dtype == torch.float16 else 1
+    attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \
+        .masked_fill_(mask_flag, mask_value)
     return attn_mask

```
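For reference, here is the optimized path from the diff as a self-contained function. This is a minimal sketch with comments added; the function name and docstring are mine, but the logic follows the `+` lines above:

```python
import torch


def generate_attn_mask(max_seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    """Sketch of the optimized causal-mask construction from the diff.

    Compared to the old path, the avoidable full-size temporaries are
    gone: tril_() and masked_fill_() mutate in place, and the zeros
    buffer is allocated directly in the target dtype instead of being
    created as float32 and converted afterwards.
    """
    # One bool buffer, lower-triangulated in place (no separate tril copy).
    mask_flag = torch.ones((max_seq_len, max_seq_len),
                           dtype=torch.bool).tril_()
    # Invert: True now marks the positions to be masked out.
    mask_flag = ~mask_flag
    # fp16 kernels expect -inf at masked positions; other dtypes use 1.
    mask_value = float('-inf') if dtype == torch.float16 else 1
    # Allocate once in the target dtype and fill in place.
    return torch.zeros((max_seq_len, max_seq_len), dtype=dtype) \
        .masked_fill_(mask_flag, mask_value)


# Example: a 4k causal mask in fp16.
mask = generate_attn_mask(4096, torch.float16)
```

One behavioral detail visible in the diff: for fp16 the mask value changes from `torch.finfo(torch.float32).min` to `float('-inf')`, which matches the intent stated in the pre-existing comment. The memory win itself comes from the in-place `tril_()` and `masked_fill_()` calls and from allocating the zeros buffer directly in the target dtype.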