[Fix] Fix cuda graph padding for triton attention backend (#1782)
@@ -41,6 +41,10 @@ class AttentionBackend(ABC):
         """Init the metadata for a forward pass for replaying a cuda graph."""
         raise NotImplementedError()
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
+        raise NotImplementedError()
+
     def forward(
         self,
         q: torch.Tensor,
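The hook added above makes the padding value part of the AttentionBackend contract instead of a constant hardcoded in CudaGraphRunner. A minimal sketch of what an out-of-tree backend would now implement (MyAttnBackend is a hypothetical name, not part of this commit):

    class MyAttnBackend(AttentionBackend):
        def get_cuda_graph_seq_len_fill_value(self):
            # Value written into the padded tail of seq_lens during cuda
            # graph replay; return whatever this backend's kernels tolerate.
            return 1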
@@ -161,6 +161,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
     ):
@@ -210,6 +210,9 @@ class FlashInferAttnBackend(AttentionBackend):
             encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 0
+
     def forward_extend(
         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
     ):
@@ -108,6 +108,9 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
     ):
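The two Triton-style backends (this one and DoubleSparseAttnBackend above) pad with 1, presumably so that every padded slot looks like a valid one-token request: cuda_graph_start_loc is a running sum over seq_lens, and a fill value of 1 keeps the offsets strictly increasing so the decode kernel never sees an empty row. An illustrative repro of that arithmetic (standalone, not part of the diff):

    import torch

    bs = 4
    # Two real requests plus two padded slots filled with 1.
    seq_lens = torch.tensor([5, 3, 1, 1])
    start_loc = torch.zeros(bs, dtype=torch.int64)
    start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
    print(start_loc)  # tensor([0, 5, 8, 9]) -- no duplicate or empty rows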
@@ -38,7 +38,7 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.empty(
+        self.req_to_token = torch.zeros(
             (size, max_context_len), dtype=torch.int32, device=device
         )
         self.free_slots = list(range(size))
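The torch.empty to torch.zeros change is the other half of the fix: padded requests replay attention over rows of req_to_token that were never written. With torch.empty those rows hold arbitrary int32 values that would then be used as KV-cache indices; with torch.zeros a padded one-token request reads slot 0, which always exists. An illustrative repro (standalone, not part of the diff):

    import torch

    table = torch.empty((4, 8), dtype=torch.int32)
    print(table[3])  # uninitialized garbage -- unsafe as KV-cache indices

    table = torch.zeros((4, 8), dtype=torch.int32)
    print(table[3])  # all zeros -- padded rows point at a valid slot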
@@ -133,8 +133,11 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
+        )
+
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
-        self.seq_len_fill_value = 1
         self.encoder_len_fill_value = 0
 
         if self.use_torch_compile:
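With the hook in place, the runner queries the backend at init time instead of hardcoding 1: FlashInfer pads with 0, the Triton and double-sparsity backends with 1. A condensed sketch of the resulting wiring (simplified; attn_backend stands in for model_runner.attn_backend):

    class CudaGraphRunnerSketch:
        def __init__(self, attn_backend, capture_bs):
            self.max_bs = max(capture_bs)
            attn_backend.init_cuda_graph_state(self.max_bs)
            # Backend-specific instead of the old hardcoded 1.
            self.seq_len_fill_value = attn_backend.get_cuda_graph_seq_len_fill_value()
            self.encoder_len_fill_value = 0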
@@ -287,7 +290,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
 
         # Common inputs
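At replay time the batch is padded up to the nearest captured size, and the padded tail of seq_lens now gets the backend-specific value. A condensed, standalone sketch of the path shown in the last hunk (assumes capture_bs is sorted ascending; the real runner then restores the live values for the first raw_bs entries, which is omitted here):

    import bisect
    import torch

    def pad_for_replay(capture_bs, raw_bs, seq_lens, out_cache_loc, fill_value):
        # Smallest captured batch size that can hold the live batch.
        bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
        if bs != raw_bs:
            # Backend-specific padding value; this used to be a hardcoded 1.
            seq_lens.fill_(fill_value)
            out_cache_loc.zero_()
        return bs

    seq_lens = torch.ones(4, dtype=torch.int32)
    out_cache_loc = torch.ones(4, dtype=torch.int64)
    print(pad_for_replay([1, 2, 4], 3, seq_lens, out_cache_loc, fill_value=1))  # 4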