Re-introduce get_cuda_graph_seq_len_fill_value (#1783)
This commit is contained in:
@@ -41,6 +41,10 @@ class AttentionBackend(ABC):
|
||||
"""Init the metadata for a forward pass for replying a cuda graph."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
|
||||
@@ -161,6 +161,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
||||
self.cuda_graph_start_loc.zero_()
|
||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 1
|
||||
|
||||
def forward_extend(
|
||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||
):
|
||||
|
||||
@@ -210,6 +210,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
||||
)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 0
|
||||
|
||||
def forward_extend(
|
||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||
):
|
||||
|
||||
@@ -108,6 +108,9 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.cuda_graph_start_loc.zero_()
|
||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 1
|
||||
|
||||
def forward_extend(
|
||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||
):
|
||||
|
||||
@@ -134,7 +134,11 @@ class CudaGraphRunner:
|
||||
self.max_bs = max(self.capture_bs)
|
||||
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
|
||||
|
||||
self.seq_len_fill_value = 1
|
||||
self.seq_len_fill_value = (
|
||||
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||
)
|
||||
|
||||
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
||||
self.encoder_len_fill_value = 0
|
||||
|
||||
if self.use_torch_compile:
|
||||
@@ -287,7 +291,7 @@ class CudaGraphRunner:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
if bs != raw_bs:
|
||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||
self.seq_lens.fill_(1)
|
||||
self.out_cache_loc.zero_()
|
||||
|
||||
# Common inputs
|
||||
|
||||
Reference in New Issue
Block a user