[Bugfix] Fix index out of bounds in local attention with large sequences (#5173)
This commit is contained in:
@@ -236,7 +236,11 @@ def make_local_attention_virtual_batches(
|
|||||||
np.arange(pages_per_local_batch, dtype=np.int32),
|
np.arange(pages_per_local_batch, dtype=np.int32),
|
||||||
(virtual_batches, pages_per_local_batch),
|
(virtual_batches, pages_per_local_batch),
|
||||||
) + np.expand_dims(block_starts, axis=1)
|
) + np.expand_dims(block_starts, axis=1)
|
||||||
block_indices = block_indices.flatten()
|
# Clamp block_indices to the last valid column index of block_table (block_table.shape[1] - 1).
|
||||||
|
# This is a critical safety check that prevents index out of bounds errors
|
||||||
|
# when dealing with large sequences (>8192 tokens) or when the block_table
|
||||||
|
# dimensions are smaller than what would be needed for the full attention chunk size.
|
||||||
|
block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
|
||||||
batch_indices = np.repeat(
|
batch_indices = np.repeat(
|
||||||
np.arange(actual_batch_size, dtype=np.int32),
|
np.arange(actual_batch_size, dtype=np.int32),
|
||||||
local_blocks * pages_per_local_batch,
|
local_blocks * pages_per_local_batch,
|
||||||
|
|||||||
Reference in New Issue
Block a user