Use seq_len_fill_value in the cuda graph runners (#7233)
This commit is contained in:
@@ -1807,7 +1807,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
"""Get the fill value for sequence length in CUDA graph."""
|
||||
return 0
|
||||
return 1
|
||||
|
||||
def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
|
||||
"""Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
|
||||
|
||||
@@ -440,7 +440,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
raise ValueError("Invalid forward mode")
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 0
|
||||
return 1
|
||||
|
||||
def forward_extend(
|
||||
self,
|
||||
|
||||
@@ -364,7 +364,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 0
|
||||
return 1
|
||||
|
||||
def forward_extend(
|
||||
self,
|
||||
|
||||
@@ -612,7 +612,7 @@ class CudaGraphRunner:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
if bs != raw_bs:
|
||||
self.seq_lens.fill_(1)
|
||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||
self.out_cache_loc.zero_()
|
||||
|
||||
# Common inputs
|
||||
@@ -624,7 +624,7 @@ class CudaGraphRunner:
|
||||
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
if bs != raw_bs:
|
||||
self.seq_lens_cpu.fill_(1)
|
||||
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||
|
||||
if pp_proxy_tensors:
|
||||
@@ -652,7 +652,7 @@ class CudaGraphRunner:
|
||||
bs,
|
||||
self.req_pool_indices,
|
||||
self.seq_lens,
|
||||
forward_batch.seq_lens_sum + (bs - raw_bs),
|
||||
forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
|
||||
self.encoder_lens,
|
||||
forward_batch.forward_mode,
|
||||
forward_batch.spec_info,
|
||||
|
||||
@@ -187,9 +187,8 @@ class EAGLEDraftCudaGraphRunner:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
if bs != raw_bs:
|
||||
self.seq_lens.fill_(1)
|
||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||
self.out_cache_loc.zero_()
|
||||
self.positions.zero_()
|
||||
|
||||
num_tokens = bs * self.num_tokens_per_bs
|
||||
|
||||
@@ -211,15 +210,15 @@ class EAGLEDraftCudaGraphRunner:
|
||||
forward_batch.req_pool_indices = self.req_pool_indices[:bs]
|
||||
forward_batch.positions = self.positions[:num_tokens]
|
||||
|
||||
# Special handle for seq_len_cpu used when flashinfer mla is used
|
||||
if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
|
||||
self.seq_lens_cpu.fill_(1)
|
||||
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||
forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
|
||||
|
||||
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
forward_batch, bs
|
||||
)
|
||||
# TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph
|
||||
|
||||
# Replay
|
||||
self.graphs[bs].replay()
|
||||
|
||||
@@ -207,9 +207,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
if bs * self.num_tokens_per_bs != num_tokens:
|
||||
self.seq_lens.fill_(1)
|
||||
self.accept_length.fill_(1)
|
||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||
self.out_cache_loc.zero_()
|
||||
self.accept_length.fill_(1)
|
||||
|
||||
# Common inputs
|
||||
self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
|
||||
@@ -223,18 +223,19 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
if bs != raw_bs:
|
||||
self.seq_lens_cpu.fill_(1)
|
||||
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||
|
||||
if bs != raw_bs:
|
||||
forward_batch.spec_info.positions = self.positions[:num_tokens]
|
||||
forward_batch.spec_info.accept_length = self.accept_length[:bs]
|
||||
forward_batch.spec_info.positions = None
|
||||
|
||||
self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
bs=bs,
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
seq_lens=self.seq_lens,
|
||||
seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs),
|
||||
seq_lens_sum=forward_batch.seq_lens_sum
|
||||
+ (bs - raw_bs) * self.seq_len_fill_value,
|
||||
encoder_lens=None,
|
||||
forward_mode=ForwardMode.DRAFT_EXTEND,
|
||||
spec_info=forward_batch.spec_info,
|
||||
|
||||
@@ -166,6 +166,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
def init_attention_backend(self):
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
|
||||
self.has_prefill_wrapper_verify = False
|
||||
self.draft_extend_attn_backend = None
|
||||
|
||||
if self.server_args.attention_backend == "flashinfer":
|
||||
if not global_server_args_dict["use_mla_backend"]:
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
@@ -213,7 +217,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.has_prefill_wrapper_verify = False
|
||||
elif self.server_args.attention_backend == "fa3":
|
||||
from sglang.srt.layers.attention.flashattention_backend import (
|
||||
FlashAttentionBackend,
|
||||
@@ -229,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.has_prefill_wrapper_verify = False
|
||||
elif self.server_args.attention_backend == "flashmla":
|
||||
from sglang.srt.layers.attention.flashmla_backend import (
|
||||
FlashMLAMultiStepDraftBackend,
|
||||
@@ -240,8 +242,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.has_prefill_wrapper_verify = False
|
||||
else:
|
||||
raise ValueError(
|
||||
f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
|
||||
|
||||
Reference in New Issue
Block a user