[Qwen3-Next] switch to triton and cache conv states to accelerate MTP from 300 tok/s to 341 tok/s (#10335)
Co-authored-by: Binyao Jiang <byjiang1996@gmail.com>
@@ -13,7 +13,7 @@ from sglang.srt.layers.attention.fla.fused_recurrent import (
 from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
     fused_sigmoid_gating_delta_rule_update,
 )
-from sglang.srt.layers.attention.mamba.causal_conv1d import (
+from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
     causal_conv1d_fn,
     causal_conv1d_update,
 )
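The only interface change here is the import path: the Triton port keeps the same `causal_conv1d_fn` / `causal_conv1d_update` entry points as the CUDA extension it replaces. For orientation, a minimal PyTorch sketch of the operation both implement, a depthwise causal conv over the time axis (sizes and the silu activation are illustrative assumptions; this is not the Triton kernel itself):

    import torch
    import torch.nn.functional as F

    dim, k, seqlen = 8, 4, 16
    x = torch.randn(dim, seqlen)        # layout after mixed_qkv.transpose(0, 1)
    weight = torch.randn(dim, k)        # conv1d.weight.view(dim, k)
    bias = torch.randn(dim)

    y = F.conv1d(
        F.pad(x.unsqueeze(0), (k - 1, 0)),  # left padding makes the conv causal
        weight.unsqueeze(1),                # (dim, 1, k): one filter per channel
        bias,
        groups=dim,                         # depthwise: channels do not mix
    ).squeeze(0)
    y = F.silu(y)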
@@ -195,7 +195,9 @@ class MambaAttnBackend(AttentionBackend):
         dt_bias = kwargs["dt_bias"]
         layer_id = kwargs["layer_id"]
 
-        conv_states, ssm_states = self.req_to_token_pool.get_mamba_params(layer_id)
+        conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
+            layer_id
+        )
         query_start_loc = self.forward_metadata.query_start_loc
         cache_indices = self.forward_metadata.mamba_cache_indices
 
@@ -277,12 +279,9 @@ class MambaAttnBackend(AttentionBackend):
             (
                 conv_states,
                 ssm_states,
-                mixed_qkv_cache,
                 intermediate_state_cache,
+                intermediate_conv_window_cache,
             ) = self.req_to_token_pool.get_mamba_params(layer_id)
-            mixed_qkv_cache[cache_indices] = mixed_qkv.view(
-                (-1,) + mixed_qkv_cache.shape[1:]
-            ).clone()
             has_initial_states = torch.ones(
                 seq_len // forward_batch.spec_info.draft_token_num,
                 dtype=torch.bool,
@@ -295,16 +294,38 @@ class MambaAttnBackend(AttentionBackend):
             )
         else:
             has_initial_states = forward_batch.extend_prefix_lens > 0
         conv_states_to_use = conv_states
-        mixed_qkv = causal_conv1d_fn(
-            mixed_qkv.transpose(0, 1),
-            conv_weights,
-            bias,
-            activation=activation,
-            conv_states=conv_states_to_use,
-            has_initial_state=has_initial_states,
-            cache_indices=cache_indices,
-            query_start_loc=query_start_loc,
-        ).transpose(0, 1)[:seq_len]
+        if is_target_verify:
+            batch_size = seq_len // forward_batch.spec_info.draft_token_num
+            draft_token_num = forward_batch.spec_info.draft_token_num
+            mixed_qkv_reshaped = (
+                mixed_qkv.view(batch_size, draft_token_num, -1)
+                .transpose(1, 2)
+                .contiguous()
+            )
+            mixed_qkv_processed = causal_conv1d_update(
+                mixed_qkv_reshaped,
+                conv_states_to_use,
+                conv_weights,
+                bias,
+                activation,
+                conv_state_indices=cache_indices[:batch_size],
+                intermediate_conv_window=intermediate_conv_window_cache,
+            )
+            mixed_qkv = (
+                mixed_qkv_processed.transpose(1, 2).contiguous().view(seq_len, -1)
+            )
+        else:
+            mixed_qkv = causal_conv1d_fn(
+                mixed_qkv.transpose(0, 1),
+                conv_weights,
+                bias,
+                activation=activation,
+                conv_states=conv_states_to_use,
+                has_initial_state=has_initial_states,
+                cache_indices=cache_indices,
+                query_start_loc=query_start_loc,
+            ).transpose(0, 1)[:seq_len]
 
         key_split_dim = key_dim // attn_tp_size
         value_split_dim = value_dim // attn_tp_size
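During target verify every request contributes exactly `draft_token_num` tokens, so the flat token stream can be reshaped to (batch, dim, draft_token_num) and fed to the decode-style `causal_conv1d_update`, which rolls the per-request conv state one token at a time; the new `intermediate_conv_window` argument snapshots that rolling window after each draft token so the state of whichever token is ultimately accepted can be committed later without recomputation. A pure-PyTorch sketch of those semantics (hypothetical helper name and illustrative shapes; not the Triton kernel):

    import torch
    import torch.nn.functional as F

    def conv1d_update_ref(x, conv_state, weight, bias, window_cache=None):
        # x: (batch, dim, seqlen); conv_state: (batch, dim, K-1) holds the last
        # K-1 inputs; weight: (dim, K). window_cache, if given, is
        # (batch, seqlen, dim, K-1) and records the window after every step.
        out = torch.empty_like(x)
        for t in range(x.shape[-1]):
            window = torch.cat([conv_state, x[..., t : t + 1]], dim=-1)  # (batch, dim, K)
            out[..., t] = F.silu((window * weight).sum(-1) + bias)       # depthwise tap
            conv_state = window[..., 1:]            # roll: keep the last K-1 inputs
            if window_cache is not None:
                window_cache[:, t] = conv_state     # snapshot per draft token
        return out, conv_state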
@@ -507,26 +528,6 @@ class HybridLinearAttnBackend(AttentionBackend):
 
     def update_mamba_state_after_mtp_verify(self, accepted_length, model):
         request_number = accepted_length.shape[0]
-        # QQ: step = spec num_draft token num
-        num_draft_tokens = (
-            self.attn_backend_list[1]
-            .req_to_token_pool.mamba_pool.mamba_cache[2]
-            .shape[2]
-        )
-        query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype)
-        query_start_loc = torch.cat(
-            [
-                torch.zeros(
-                    1,
-                    dtype=query_start_loc.dtype,
-                    device=query_start_loc.device,
-                ),
-                query_start_loc,
-            ]
-        )
-        mask = torch.arange(num_draft_tokens, device=accepted_length.device).unsqueeze(
-            0
-        ) < accepted_length.unsqueeze(1)
 
         state_indices_tensor = self.attn_backend_list[
             1
@@ -536,46 +537,48 @@ class HybridLinearAttnBackend(AttentionBackend):
             1
         ].req_to_token_pool.get_mamba_params_all_layers()
 
-        conv_states, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches
-
-        mixed_qkvs = mix_qkv_cache[:, state_indices_tensor][:, mask]
-
-        mamba_map = self.attn_backend_list[1].req_to_token_pool.mamba_map
-
-        has_initial_states = torch.ones(
-            request_number, dtype=torch.bool, device=accepted_length.device
-        )
-
-        # Batch SSM state updates (outside the loop for efficiency)
+        (
+            conv_states,
+            ssm_states,
+            intermediate_state_cache,
+            intermediate_conv_window_cache,
+        ) = mamba_caches
+
+        # SSM state updates (chunked to reduce peak memory)
+
         valid_mask = accepted_length > 0
-        if intermediate_state_cache is not None:
-            last_steps = (accepted_length - 1).to(torch.int64)
-            valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
-
-            ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
-                :, valid_state_indices, last_steps
-            ].to(ssm_states.dtype)
-
-        # For loop conv state updates (can be optimized)
-        for i in range(len(model.model.layers)):
-            layer = model.model.layers[i]
-            if isinstance(layer, Qwen3HybridLinearDecoderLayer):
-                conv_weights = layer.linear_attn.conv1d.weight.view(
-                    layer.linear_attn.conv1d.weight.size(0),
-                    layer.linear_attn.conv1d.weight.size(2),
-                )
-
-                layer_id = mamba_map[i]
-                conv_state = conv_states[layer_id]
-                mixed_qkv = mixed_qkvs[layer_id]
-
-                _ = causal_conv1d_fn(
-                    mixed_qkv.transpose(0, 1),
-                    conv_weights,
-                    layer.linear_attn.conv1d.bias,
-                    activation=layer.linear_attn.activation,
-                    conv_states=conv_state,
-                    has_initial_state=has_initial_states,
-                    cache_indices=state_indices_tensor,
-                    query_start_loc=query_start_loc,
-                )
+
+        # Compute common indices once to avoid duplication
+        last_steps_all = (accepted_length - 1).to(torch.int64)
+        valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
+        last_steps = last_steps_all[valid_mask].to(torch.int64)
+
+        if valid_state_indices.numel() > 0:
+            chunk = 256
+            num_valid = valid_state_indices.numel()
+
+            # SSM state updates
+            for i in range(0, num_valid, chunk):
+                idx = valid_state_indices[i : i + chunk]
+                steps = last_steps[i : i + chunk]
+                # per (cache line, step)
+                for j in range(idx.numel()):
+                    ci = idx[j].item()
+                    st = steps[j].item()
+                    ssm_states[:, ci, :].copy_(
+                        intermediate_state_cache[:, ci, st].to(
+                            ssm_states.dtype, copy=False
+                        )
+                    )
+
+            # Conv window updates
+            for i in range(0, num_valid, chunk):
+                idx = valid_state_indices[i : i + chunk]
+                steps = last_steps[i : i + chunk]
+                for j in range(idx.numel()):
+                    ci = idx[j].item()
+                    st = steps[j].item()
+                    conv_states[:, ci, :, :].copy_(
+                        intermediate_conv_window_cache[:, ci, st].to(
+                            conv_states.dtype, copy=False
+                        )
+                    )
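The commit rule after verify: request r accepted `accepted_length[r]` draft tokens, so the window snapshot taken at step `accepted_length[r] - 1` becomes its live state, and requests with zero accepted tokens keep their old state. A toy sketch of the same indexing (shapes assumed for illustration):

    import torch

    layers, num_slots, steps, dim, k = 2, 8, 4, 16, 3
    conv_states = torch.zeros(layers, num_slots, dim, k)
    window_cache = torch.randn(layers, num_slots, steps, dim, k)
    accepted = torch.tensor([3, 0, 1])       # accepted tokens per request
    pool_slots = torch.tensor([5, 2, 7])     # cache line per request

    valid = accepted > 0
    idx = pool_slots[valid].long()
    last = (accepted[valid] - 1).long()
    # advanced indexing pairs idx[j] with last[j]: one (line, step) pick per request
    conv_states[:, idx] = window_cache[:, idx, last]

The diff's element-wise `.copy_` loops compute the same thing; chunking the index list bounds the temporaries that one-shot advanced indexing would materialize, which is what the "chunked to reduce peak memory" comment refers to.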
python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py (new file, 1052 lines)
File diff suppressed because it is too large
@@ -125,16 +125,6 @@ class MambaPool:
             device=device,
         )
         if speculative_num_draft_tokens is not None:
-            mixed_qkv_cache = torch.empty(
-                size=(
-                    num_mamba_layers,
-                    size + 1,
-                    speculative_num_draft_tokens,
-                    conv_state_shape[0],
-                ),
-                dtype=conv_dtype,
-                device="cuda",
-            )
             # Cache intermediate SSM states per draft token during target verify
             # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
             intermediate_ssm_state_cache = torch.empty(
@@ -149,11 +139,24 @@ class MambaPool:
                 dtype=ssm_dtype,
                 device="cuda",
             )
+            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
+            intermediate_conv_window_cache = torch.empty(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    conv_state_shape[0],
+                    conv_state_shape[1],
+                ),
+                dtype=conv_dtype,
+                device="cuda",
+            )
             self.mamba_cache = (
                 conv_state,
                 temporal_state,
-                mixed_qkv_cache,
                 intermediate_ssm_state_cache,
+                intermediate_conv_window_cache,
             )
         else:
             self.mamba_cache = (conv_state, temporal_state)
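The window cache stores the last K-1 conv inputs per layer, pool slot, and draft step, so its footprint is the product of those five dimensions times the element size. A back-of-envelope estimate with assumed example values (the real numbers depend on the model and pool configuration):

    num_layers, pool_size, draft_tokens = 24, 128, 4   # assumed example values
    dim, k_minus_1, bytes_per_elem = 4096, 3, 2        # conv dim, K-1, bf16
    total = num_layers * (pool_size + 1) * draft_tokens * dim * k_minus_1 * bytes_per_elem
    print(f"window cache ~ {total / 2**20:.0f} MiB")

Note this replaces, rather than adds to, the removed `mixed_qkv_cache`, which stored full activations per draft token instead of just the conv windows.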
(deleted file: python/sglang/srt/speculative/eagle_target_verify_cuda_graph_runner.py)
@@ -1,195 +0,0 @@
-import bisect
-from typing import TYPE_CHECKING, Callable
-
-import torch
-import torch.nn.functional as F
-
-from sglang.srt.layers.attention.fla.fused_recurrent import (
-    fused_recurrent_gated_delta_rule_update,
-)
-from sglang.srt.layers.attention.mamba.causal_conv1d import causal_conv1d_fn
-from sglang.srt.model_executor.cuda_graph_runner import (
-    CUDA_GRAPH_CAPTURE_FAILED_MSG,
-    CudaGraphRunner,
-    get_batch_sizes_to_capture,
-    get_global_graph_memory_pool,
-    model_capture_mode,
-    set_global_graph_memory_pool,
-)
-from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer
-
-if TYPE_CHECKING:
-    from sglang.srt.speculative.eagle_worker import EAGLEWorker
-
-
-class MambaStateUpdateCudaGraphRunner:
-    def __init__(self, eagle_worker: "EAGLEWorker"):
-        self.eagle_worker = eagle_worker
-        model_runner = eagle_worker.target_worker.model_runner
-        self.model_runner = model_runner
-        self.attn_backend = model_runner.attn_backend.attn_backend_list[1]
-        self.req_to_token_pool = self.attn_backend.req_to_token_pool
-
-        self.graphs = {}
-        self.output_buffers = {}
-        self.graph_input_buffer = None
-        self.stream = torch.cuda.Stream()
-        self.model = model_runner.model
-
-        self.enable_profile_cuda_graph = (
-            model_runner.server_args.enable_profile_cuda_graph
-        )
-        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
-        self.max_bs = self.capture_bs[-1]
-
-        self.init_cuda_graph_state()
-        # Capture
-        try:
-            with model_capture_mode():
-                self.capture()
-        except RuntimeError as e:
-            raise Exception(
-                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
-            )
-
-    def init_cuda_graph_state(self):
-        self.mamba_cache = self.req_to_token_pool.mamba_pool.mamba_cache
-        self.num_tokens_per_bs = self.max_accepted_tokens = self.mamba_cache[2].shape[2]
-        num_mamba_layers = self.mamba_cache[0].shape[0]
-        conv_dtype = torch.bfloat16
-        conv_shape = self.mamba_cache[0].shape[2]
-        total_token_number = self.max_accepted_tokens * self.max_bs
-        self.mixed_qkv_cache = torch.empty(
-            size=(
-                num_mamba_layers,
-                total_token_number,
-                conv_shape,
-            ),
-            dtype=conv_dtype,
-            device="cuda",
-        )
-        self.query_start_loc = torch.zeros(
-            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.state_indices = torch.zeros(
-            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.has_initial_states = torch.ones(
-            self.max_bs, dtype=torch.bool, device="cuda"
-        )
-
-    def capture(self):
-        CudaGraphRunner.capture(self)
-
-    def capture_one_batch_size(self, bs: int, forward: Callable):
-        """
-        Capture CUDA Graph for a typical workload
-        """
-        graph = torch.cuda.CUDAGraph()
-        stream = self.stream
-        total_token_number = bs * self.max_accepted_tokens
-        mixed_qkvs = self.mixed_qkv_cache[:, :total_token_number]
-
-        query_start_loc = self.query_start_loc[: bs + 1]
-        state_indices = self.state_indices[:bs]
-        has_initial_states = self.has_initial_states[:bs]
-
-        mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers()
-        conv_states = mamba_caches[0]
-        mamba_map = self.req_to_token_pool.mamba_map
-
-        def run_once():
-            for i in range(len(self.model.model.layers)):
-                layer = self.model.model.layers[i]
-                if not isinstance(layer, Qwen3HybridLinearDecoderLayer):
-                    continue
-                conv_weights = layer.linear_attn.conv1d.weight.view(
-                    layer.linear_attn.conv1d.weight.size(0),
-                    layer.linear_attn.conv1d.weight.size(2),
-                )
-                layer_id = mamba_map[i]
-
-                causal_conv1d_fn(
-                    mixed_qkvs[layer_id].transpose(0, 1),
-                    conv_weights,
-                    layer.linear_attn.conv1d.bias,
-                    activation=layer.linear_attn.activation,
-                    conv_states=conv_states[layer_id],
-                    has_initial_state=has_initial_states,
-                    cache_indices=state_indices,
-                    query_start_loc=query_start_loc,
-                )
-
-            return None
-
-        for _ in range(2):
-            torch.cuda.synchronize()
-            self.model_runner.tp_group.barrier()
-
-            run_once()
-
-        with torch.cuda.graph(
-            graph, pool=get_global_graph_memory_pool(), stream=stream
-        ):
-            out = run_once()
-
-        set_global_graph_memory_pool(graph.pool())
-        return graph, out
-
-    def can_run(self, accepted_length):
-        bs = accepted_length.shape[0]
-        return bs <= self.max_bs
-
-    def replay_repare(self, accepted_length):
-        request_number = accepted_length.shape[0]
-        # QQ: step = spec num_draft token num
-        num_draft_tokens = self.req_to_token_pool.mamba_pool.mamba_cache[2].shape[2]
-        query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype)
-        query_start_loc = torch.cat(
-            [
-                torch.zeros(
-                    1,
-                    dtype=query_start_loc.dtype,
-                    device=query_start_loc.device,
-                ),
-                query_start_loc,
-            ]
-        )
-        mask = torch.arange(num_draft_tokens, device=accepted_length.device).unsqueeze(
-            0
-        ) < accepted_length.unsqueeze(1)
-
-        state_indices_tensor = self.attn_backend.forward_metadata.mamba_cache_indices[
-            :request_number
-        ]
-        mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers()
-
-        _, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches
-        mixed_qkvs = mamba_caches[2][:, state_indices_tensor][:, mask]
-        self.mixed_qkv_cache[:, : mixed_qkvs.shape[1]].copy_(mixed_qkvs)
-        self.query_start_loc[: request_number + 1] = query_start_loc
-        self.query_start_loc[request_number + 1 :] = self.query_start_loc[
-            request_number
-        ]
-        self.state_indices[:request_number] = state_indices_tensor
-        self.state_indices[request_number:] = -1
-        valid_mask = accepted_length > 0
-        if intermediate_state_cache is not None:
-            last_steps = (accepted_length - 1).to(torch.int64)
-            valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
-
-            ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
-                :, valid_state_indices, last_steps
-            ].to(ssm_states.dtype)
-
-    def replay(self, accepted_length):
-        # batch_size and num_seqs can be different in case there are finished examples
-        # in the batch, which will not be counted as num_seqs
-        raw_bs = accepted_length.shape[0]
-        index = bisect.bisect_left(self.capture_bs, raw_bs)
-
-        bs = self.capture_bs[index]
-
-        self.replay_repare(accepted_length)
-        # Replay
-        self.graphs[bs].replay()
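The runner deleted above existed to hide the cost of re-running `causal_conv1d_fn` over all accepted tokens after each verify step: it staged `mixed_qkv` activations in a dedicated buffer and replayed a captured graph per batch size. With conv windows now snapshotted during verify itself (the `intermediate_conv_window_cache` hunks above), committing state reduces to indexed copies that are cheap enough to run eagerly, so the capture/replay machinery and the `mixed_qkv_cache` staging buffer go away, as do the `can_run`/`replay` call sites in `EAGLEWorker` below.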
@@ -407,15 +407,6 @@ class EAGLEWorker(TpModelWorker):
             f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
-        if self.target_worker.model_runner.is_hybrid_gdn:
-            from sglang.srt.speculative.eagle_target_verify_cuda_graph_runner import (
-                MambaStateUpdateCudaGraphRunner,
-            )
-
-            self.cuda_graph_runner_for_target_verify = MambaStateUpdateCudaGraphRunner(
-                self
-            )
-
     @property
     def draft_model_runner(self):
         return self.model_runner
@@ -848,12 +839,9 @@ class EAGLEWorker(TpModelWorker):
                 )
                 + 1
             )
-            if self.cuda_graph_runner_for_target_verify.can_run(accepted_length):
-                self.cuda_graph_runner_for_target_verify.replay(accepted_length)
-            else:
-                self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
-                    accepted_length, self.target_worker.model_runner.model
-                )
+            self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
+                accepted_length, self.target_worker.model_runner.model
+            )
 
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)