[Bugfix] Fix accuracy problem caused by mask pollution (#1678)
### What this PR does / why we need it?
If a small batch of short requests is sent first, forming a chunk with a
length <128, it will corrupt the `attn_mask_cache`, causing subsequent
requests that do not form a chunk to have accuracy issues.
The root cause of this problem is the use of in-place multiplication.
Modifying it to use out-of-place multiplication will resolve the
accuracy problem.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Yes.
- vLLM version: v0.9.2
- vLLM main:
ad6c2e1a0b
---------
Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
@@ -105,3 +105,52 @@ class TestAttentionMaskBuilder(TestBase):
|
|||||||
device=torch.device("cpu"),
|
device=torch.device("cpu"),
|
||||||
)
|
)
|
||||||
self.assertEqual(attn_mask.shape, (1, 512))
|
self.assertEqual(attn_mask.shape, (1, 512))
|
||||||
|
|
||||||
|
def test_use_multiple_masks(self):
|
||||||
|
max_seq_lens = [128, 512, 1024]
|
||||||
|
dtypes = [torch.float16, torch.bfloat16, torch.int8]
|
||||||
|
for max_seq_len, dtype in zip(max_seq_lens, dtypes):
|
||||||
|
with self.subTest(max_seq_len=max_seq_len, dtype=dtype):
|
||||||
|
self._test_use_multiple_masks(max_seq_len, dtype)
|
||||||
|
|
||||||
|
def _test_use_multiple_masks(self, max_seq_len, dtype):
|
||||||
|
expected_mask_value = torch.finfo(
|
||||||
|
torch.float32).min if dtype == torch.float16 else 1
|
||||||
|
if dtype == torch.float16:
|
||||||
|
expected_splitfuse_mask_value = expected_mask_value
|
||||||
|
elif dtype == torch.bfloat16:
|
||||||
|
expected_splitfuse_mask_value = -10000
|
||||||
|
else:
|
||||||
|
assert dtype == torch.int8, "Unsupported dtype for attention mask"
|
||||||
|
expected_splitfuse_mask_value = -16
|
||||||
|
|
||||||
|
attention_mask_builder = AttentionMaskBuilder(max_seq_len=max_seq_len,
|
||||||
|
dtype=dtype)
|
||||||
|
|
||||||
|
splitfuse_attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
|
||||||
|
seq_lens=[max_seq_len],
|
||||||
|
query_lens=[max_seq_len],
|
||||||
|
position=torch.tensor([0]),
|
||||||
|
dtype=dtype,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
)
|
||||||
|
self.assertEqual(splitfuse_attn_mask.shape, (1, max_seq_len))
|
||||||
|
self.assertEqual(
|
||||||
|
splitfuse_attn_mask[0][-1],
|
||||||
|
torch.tensor(expected_splitfuse_mask_value, dtype=dtype))
|
||||||
|
self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
|
||||||
|
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
|
||||||
|
(max_seq_len, max_seq_len))
|
||||||
|
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
|
||||||
|
torch.tensor(expected_mask_value, dtype=dtype))
|
||||||
|
|
||||||
|
attn_mask = attention_mask_builder.get_attn_mask(
|
||||||
|
max_seq_len=max_seq_len, dtype=dtype, device=torch.device("cpu"))
|
||||||
|
self.assertEqual(attn_mask.shape, (max_seq_len, max_seq_len))
|
||||||
|
self.assertEqual(attn_mask[0][-1],
|
||||||
|
torch.tensor(expected_mask_value, dtype=dtype))
|
||||||
|
self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
|
||||||
|
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
|
||||||
|
(max_seq_len, max_seq_len))
|
||||||
|
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
|
||||||
|
torch.tensor(expected_mask_value, dtype=dtype))
|
||||||
|
|||||||
@@ -572,7 +572,8 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|||||||
attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
||||||
max_seq_len, dtype, device)
|
max_seq_len, dtype, device)
|
||||||
if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
|
if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
|
||||||
attn_mask *= -10000
|
# Do not use in-place multiplication to avoid modifying `attn_mask_cache`!
|
||||||
|
attn_mask = attn_mask * -10000
|
||||||
chunk_mask_list = []
|
chunk_mask_list = []
|
||||||
for i, seq_len in enumerate(seq_lens):
|
for i, seq_len in enumerate(seq_lens):
|
||||||
context_len = self.context_lens[i]
|
context_len = self.context_lens[i]
|
||||||
|
|||||||
@@ -68,7 +68,8 @@ class AttentionMaskBuilder:
|
|||||||
) > 1 and self.attn_mask_cache[0][1] > 0:
|
) > 1 and self.attn_mask_cache[0][1] > 0:
|
||||||
attn_mask = self.get_attn_mask( # type: ignore
|
attn_mask = self.get_attn_mask( # type: ignore
|
||||||
max_seq_len, dtype, device)
|
max_seq_len, dtype, device)
|
||||||
attn_mask *= -10000
|
# Do not use in-place multiplication to avoid modifying `self.attn_mask_cache`!
|
||||||
|
attn_mask = attn_mask * -10000
|
||||||
else:
|
else:
|
||||||
attn_mask = self.attn_mask_cache
|
attn_mask = self.attn_mask_cache
|
||||||
return torch.index_select(attn_mask, dim=0,
|
return torch.index_select(attn_mask, dim=0,
|
||||||
|
|||||||
Reference in New Issue
Block a user