Re-introduce get_cuda_graph_seq_len_fill_value (#1783)
This commit is contained in:
@@ -41,6 +41,10 @@ class AttentionBackend(ABC):
|
|||||||
"""Init the metadata for a forward pass for replying a cuda graph."""
|
"""Init the metadata for a forward pass for replying a cuda graph."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
|
|||||||
@@ -161,6 +161,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
self.cuda_graph_start_loc.zero_()
|
self.cuda_graph_start_loc.zero_()
|
||||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -210,6 +210,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -108,6 +108,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.cuda_graph_start_loc.zero_()
|
self.cuda_graph_start_loc.zero_()
|
||||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -134,7 +134,11 @@ class CudaGraphRunner:
|
|||||||
self.max_bs = max(self.capture_bs)
|
self.max_bs = max(self.capture_bs)
|
||||||
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
|
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
|
||||||
|
|
||||||
self.seq_len_fill_value = 1
|
self.seq_len_fill_value = (
|
||||||
|
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||||
|
)
|
||||||
|
|
||||||
|
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
||||||
self.encoder_len_fill_value = 0
|
self.encoder_len_fill_value = 0
|
||||||
|
|
||||||
if self.use_torch_compile:
|
if self.use_torch_compile:
|
||||||
@@ -287,7 +291,7 @@ class CudaGraphRunner:
|
|||||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||||
bs = self.capture_bs[index]
|
bs = self.capture_bs[index]
|
||||||
if bs != raw_bs:
|
if bs != raw_bs:
|
||||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
self.seq_lens.fill_(1)
|
||||||
self.out_cache_loc.zero_()
|
self.out_cache_loc.zero_()
|
||||||
|
|
||||||
# Common inputs
|
# Common inputs
|
||||||
|
|||||||
Reference in New Issue
Block a user