Use seq_len_fill_value in the cuda graph runners (#7233)
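In short: the attention backends' get_cuda_graph_seq_len_fill_value() now returns 1 instead of 0, and the CUDA graph runners read that value (as self.seq_len_fill_value, presumably initialized from the backend's get_cuda_graph_seq_len_fill_value()) wherever they pad a batch up to a captured graph size, instead of hardcoding 1. A minimal sketch of the interaction, using simplified ToyBackend/ToyRunner stand-ins rather than the real sglang classes:

import bisect

import torch


class ToyBackend:
    """Stand-in for an attention backend; illustrative only, not the sglang class."""

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        # Mirrors the backends in this diff, which now return 1 instead of 0.
        return 1


class ToyRunner:
    """Stand-in for a CUDA graph runner that pads batches to captured sizes."""

    def __init__(self, backend: ToyBackend, capture_bs: list) -> None:
        self.capture_bs = capture_bs
        # The runners now read the backend-provided fill value instead of
        # hardcoding 1 when filling padded seq_lens slots.
        self.seq_len_fill_value = backend.get_cuda_graph_seq_len_fill_value()
        self.seq_lens = torch.zeros(max(capture_bs), dtype=torch.int64)

    def pad(self, raw_seq_lens: torch.Tensor):
        raw_bs = raw_seq_lens.numel()
        # Round the batch up to the nearest captured graph size.
        bs = self.capture_bs[bisect.bisect_left(self.capture_bs, raw_bs)]
        if bs != raw_bs:
            self.seq_lens.fill_(self.seq_len_fill_value)
        self.seq_lens[:raw_bs].copy_(raw_seq_lens)
        # Padded rows contribute (bs - raw_bs) * fill_value extra tokens, which
        # is why seq_lens_sum is corrected the same way in the hunks below.
        seq_lens_sum = int(raw_seq_lens.sum()) + (bs - raw_bs) * self.seq_len_fill_value
        return bs, seq_lens_sum


runner = ToyRunner(ToyBackend(), capture_bs=[1, 2, 4, 8])
bs, total = runner.pad(torch.tensor([5, 7, 9]))
print(bs, total)  # 4 22: three real rows plus one padded row of length 1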
@@ -1807,7 +1807,7 @@ class FlashAttentionBackend(AttentionBackend):
 
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
-        return 0
+        return 1
 
     def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
@@ -440,7 +440,7 @@ class FlashInferAttnBackend(AttentionBackend):
             raise ValueError("Invalid forward mode")
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1
 
     def forward_extend(
         self,
@@ -364,7 +364,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1
 
     def forward_extend(
         self,
@@ -612,7 +612,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
 
         # Common inputs
@@ -624,7 +624,7 @@ class CudaGraphRunner:
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
                 self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if pp_proxy_tensors:
@@ -652,7 +652,7 @@ class CudaGraphRunner:
             bs,
             self.req_pool_indices,
             self.seq_lens,
-            forward_batch.seq_lens_sum + (bs - raw_bs),
+            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
             self.encoder_lens,
             forward_batch.forward_mode,
             forward_batch.spec_info,
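Note on the seq_lens_sum correction above: each padded row contributes seq_len_fill_value tokens, so adding a bare (bs - raw_bs) was only correct while the fill value happened to be 1. A small self-contained check with made-up numbers:

# Hypothetical numbers, only to illustrate the padding arithmetic.
raw_bs, bs = 3, 4                 # real batch size vs. captured graph size
real_seq_lens = [5, 7, 9]         # lengths of the real requests
seq_len_fill_value = 1            # what the backends report after this change

# Each padded row is filled with seq_len_fill_value, so the total must count it.
seq_lens_sum = sum(real_seq_lens) + (bs - raw_bs) * seq_len_fill_value
assert seq_lens_sum == 22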
@@ -187,9 +187,8 @@ class EAGLEDraftCudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
-            self.positions.zero_()
 
         num_tokens = bs * self.num_tokens_per_bs
 
@@ -211,15 +210,15 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.req_pool_indices = self.req_pool_indices[:bs]
         forward_batch.positions = self.positions[:num_tokens]
 
-        # Special handle for seq_len_cpu used when flashinfer mla is used
         if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
-            self.seq_lens_cpu.fill_(1)
+            self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
 
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
         )
+        # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph
 
         # Replay
         self.graphs[bs].replay()
@@ -207,9 +207,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs * self.num_tokens_per_bs != num_tokens:
-            self.seq_lens.fill_(1)
-            self.accept_length.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
+            self.accept_length.fill_(1)
 
         # Common inputs
         self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
@@ -223,18 +223,19 @@ class EAGLEDraftExtendCudaGraphRunner:
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
                 self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if bs != raw_bs:
+            forward_batch.spec_info.positions = self.positions[:num_tokens]
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
-            forward_batch.spec_info.positions = None
 
         self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
             bs=bs,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
-            seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs),
+            seq_lens_sum=forward_batch.seq_lens_sum
+            + (bs - raw_bs) * self.seq_len_fill_value,
             encoder_lens=None,
             forward_mode=ForwardMode.DRAFT_EXTEND,
             spec_info=forward_batch.spec_info,
@@ -166,6 +166,10 @@ class EAGLEWorker(TpModelWorker):
 
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
+
+        self.has_prefill_wrapper_verify = False
+        self.draft_extend_attn_backend = None
+
         if self.server_args.attention_backend == "flashinfer":
             if not global_server_args_dict["use_mla_backend"]:
                 from sglang.srt.layers.attention.flashinfer_backend import (
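The EAGLEWorker hunks below are the companion cleanup to the two defaults added above: has_prefill_wrapper_verify and draft_extend_attn_backend are now initialized before the backend branch, so branches that merely restated those defaults can drop the assignments. A hedged sketch of the pattern with dummy branch bodies (placeholder strings, not the real backend construction):

def init_attention_backend_defaults(attention_backend: str) -> dict:
    # Defaults first; only branches that need something different override them.
    state = {"has_prefill_wrapper_verify": False, "draft_extend_attn_backend": None}

    if attention_backend == "flashinfer":
        state["draft_extend_attn_backend"] = "FlashInferAttnBackend"  # placeholder
    elif attention_backend == "fa3":
        state["draft_extend_attn_backend"] = "FlashAttentionBackend"  # placeholder
    elif attention_backend == "flashmla":
        pass  # keeps both defaults, replacing the deleted per-branch assignments
    else:
        raise ValueError(
            f"EAGLE is not supported in attention backend {attention_backend}"
        )
    return state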
@@ -213,7 +217,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -229,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import (
                 FlashMLAMultiStepDraftBackend,
@@ -240,8 +242,6 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
-            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"