[Fix] Fix cuda graph padding for triton attention backend (#1782)
This commit is contained in:
@@ -41,10 +41,6 @@ class AttentionBackend(ABC):
|
|||||||
"""Init the metadata for a forward pass for replying a cuda graph."""
|
"""Init the metadata for a forward pass for replying a cuda graph."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
|
||||||
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
|
|||||||
@@ -161,9 +161,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
self.cuda_graph_start_loc.zero_()
|
self.cuda_graph_start_loc.zero_()
|
||||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
|
||||||
return 1
|
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -210,9 +210,6 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -108,9 +108,6 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.cuda_graph_start_loc.zero_()
|
self.cuda_graph_start_loc.zero_()
|
||||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
|
||||||
return 1
|
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class ReqToTokenPool:
|
|||||||
self.size = size
|
self.size = size
|
||||||
self.max_context_len = max_context_len
|
self.max_context_len = max_context_len
|
||||||
self.device = device
|
self.device = device
|
||||||
self.req_to_token = torch.empty(
|
self.req_to_token = torch.zeros(
|
||||||
(size, max_context_len), dtype=torch.int32, device=device
|
(size, max_context_len), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
self.free_slots = list(range(size))
|
self.free_slots = list(range(size))
|
||||||
|
|||||||
@@ -133,11 +133,8 @@ class CudaGraphRunner:
|
|||||||
# Attention backend
|
# Attention backend
|
||||||
self.max_bs = max(self.capture_bs)
|
self.max_bs = max(self.capture_bs)
|
||||||
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
|
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
|
||||||
self.seq_len_fill_value = (
|
|
||||||
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
|
||||||
)
|
|
||||||
|
|
||||||
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
self.seq_len_fill_value = 1
|
||||||
self.encoder_len_fill_value = 0
|
self.encoder_len_fill_value = 0
|
||||||
|
|
||||||
if self.use_torch_compile:
|
if self.use_torch_compile:
|
||||||
@@ -290,7 +287,7 @@ class CudaGraphRunner:
|
|||||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||||
bs = self.capture_bs[index]
|
bs = self.capture_bs[index]
|
||||||
if bs != raw_bs:
|
if bs != raw_bs:
|
||||||
self.seq_lens.fill_(1)
|
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||||
self.out_cache_loc.zero_()
|
self.out_cache_loc.zero_()
|
||||||
|
|
||||||
# Common inputs
|
# Common inputs
|
||||||
|
|||||||
Reference in New Issue
Block a user