Sync the changes on cuda graph runners (#6932)
This commit is contained in:
@@ -20,7 +20,7 @@ import copy
|
|||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from sglang.srt.mm_utils import has_valid_data
|
from sglang.srt.mm_utils import has_valid_data
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
Image = Any
|
Image = Any
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason, flatten_nested_list
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -259,23 +259,8 @@ class CudaGraphRunner:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Speculative_inference
|
# Speculative_inference
|
||||||
if (
|
if model_runner.spec_algorithm.is_eagle3():
|
||||||
model_runner.spec_algorithm.is_eagle3()
|
|
||||||
and not model_runner.is_draft_worker
|
|
||||||
):
|
|
||||||
self.hidden_states = torch.zeros(
|
|
||||||
(
|
|
||||||
self.max_num_token,
|
|
||||||
3 * self.model_runner.model_config.hidden_size,
|
|
||||||
),
|
|
||||||
dtype=self.model_runner.dtype,
|
|
||||||
)
|
|
||||||
self.model_runner.model.set_eagle3_layers_to_capture()
|
self.model_runner.model.set_eagle3_layers_to_capture()
|
||||||
elif model_runner.spec_algorithm.is_eagle():
|
|
||||||
self.hidden_states = torch.zeros(
|
|
||||||
(self.max_num_token, self.model_runner.model_config.hidden_size),
|
|
||||||
dtype=self.model_runner.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
|
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
|
||||||
@@ -284,6 +269,7 @@ class CudaGraphRunner:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.encoder_lens = None
|
self.encoder_lens = None
|
||||||
|
|
||||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||||
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
|
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
|
||||||
self.gathered_buffer = torch.zeros(
|
self.gathered_buffer = torch.zeros(
|
||||||
@@ -303,13 +289,7 @@ class CudaGraphRunner:
|
|||||||
self.capture()
|
self.capture()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Capture CUDA graph failed: {e}\n"
|
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||||
"Possible solutions:\n"
|
|
||||||
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
|
|
||||||
"2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
|
|
||||||
"3. disable torch compile by not using --enable-torch-compile\n"
|
|
||||||
"4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
|
|
||||||
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@@ -439,6 +419,7 @@ class CudaGraphRunner:
|
|||||||
self.capture_hidden_mode = (
|
self.capture_hidden_mode = (
|
||||||
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
|
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.model_runner.server_args.lora_paths is not None:
|
if self.model_runner.server_args.lora_paths is not None:
|
||||||
# Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
|
# Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
|
||||||
# different logic to handle lora, so we need to set `lora_paths` to a list of non-None
|
# different logic to handle lora, so we need to set `lora_paths` to a list of non-None
|
||||||
@@ -467,9 +448,9 @@ class CudaGraphRunner:
|
|||||||
spec_algorithm=self.model_runner.spec_algorithm,
|
spec_algorithm=self.model_runner.spec_algorithm,
|
||||||
spec_info=spec_info,
|
spec_info=spec_info,
|
||||||
capture_hidden_mode=self.capture_hidden_mode,
|
capture_hidden_mode=self.capture_hidden_mode,
|
||||||
lora_paths=lora_paths,
|
|
||||||
num_token_non_padded=self.num_token_non_padded,
|
num_token_non_padded=self.num_token_non_padded,
|
||||||
global_forward_mode=self.capture_forward_mode,
|
global_forward_mode=self.capture_forward_mode,
|
||||||
|
lora_paths=lora_paths,
|
||||||
)
|
)
|
||||||
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
|
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
|
||||||
|
|
||||||
@@ -497,7 +478,9 @@ class CudaGraphRunner:
|
|||||||
self.pp_size > 1
|
self.pp_size > 1
|
||||||
and "pp_proxy_tensors" in inspect.signature(forward).parameters
|
and "pp_proxy_tensors" in inspect.signature(forward).parameters
|
||||||
):
|
):
|
||||||
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
|
kwargs["pp_proxy_tensors"] = PPProxyTensors(
|
||||||
|
{k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
|
||||||
|
)
|
||||||
|
|
||||||
logits_output_or_pp_proxy_tensors = forward(
|
logits_output_or_pp_proxy_tensors = forward(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -590,9 +573,6 @@ class CudaGraphRunner:
|
|||||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||||
|
|
||||||
if hasattr(forward_batch.spec_info, "hidden_states"):
|
|
||||||
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
|
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
@@ -650,7 +630,7 @@ class CudaGraphRunner:
|
|||||||
else:
|
else:
|
||||||
spec_info = EagleVerifyInput(
|
spec_info = EagleVerifyInput(
|
||||||
draft_token=None,
|
draft_token=None,
|
||||||
custom_mask=torch.zeros(
|
custom_mask=torch.ones(
|
||||||
(num_tokens * self.model_runner.model_config.context_len),
|
(num_tokens * self.model_runner.model_config.context_len),
|
||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
@@ -660,9 +640,20 @@ class CudaGraphRunner:
|
|||||||
retrive_next_token=None,
|
retrive_next_token=None,
|
||||||
retrive_next_sibling=None,
|
retrive_next_sibling=None,
|
||||||
retrive_cum_len=None,
|
retrive_cum_len=None,
|
||||||
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
|
|
||||||
spec_steps=self.model_runner.server_args.speculative_num_steps,
|
spec_steps=self.model_runner.server_args.speculative_num_steps,
|
||||||
|
topk=self.model_runner.server_args.speculative_eagle_topk,
|
||||||
|
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
|
||||||
capture_hidden_mode=CaptureHiddenMode.FULL,
|
capture_hidden_mode=CaptureHiddenMode.FULL,
|
||||||
)
|
)
|
||||||
|
|
||||||
return spec_info
|
return spec_info
|
||||||
|
|
||||||
|
|
||||||
|
CUDA_GRAPH_CAPTURE_FAILED_MSG = (
|
||||||
|
"Possible solutions:\n"
|
||||||
|
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
|
||||||
|
"2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
|
||||||
|
"3. disable torch compile by not using --enable-torch-compile\n"
|
||||||
|
"4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
|
||||||
|
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
||||||
|
)
|
||||||
|
|||||||
@@ -447,7 +447,7 @@ class ServerArgs:
|
|||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
self.speculative_eagle_topk,
|
self.speculative_eagle_topk,
|
||||||
self.speculative_num_draft_tokens,
|
self.speculative_num_draft_tokens,
|
||||||
) = auto_choose_speculative_params(model_arch)
|
) = auto_choose_speculative_params(self)
|
||||||
|
|
||||||
if self.page_size > 1 and self.speculative_eagle_topk > 1:
|
if self.page_size > 1 and self.speculative_eagle_topk > 1:
|
||||||
self.speculative_eagle_topk = 1
|
self.speculative_eagle_topk = 1
|
||||||
@@ -1655,12 +1655,23 @@ def get_model_arch(args: ServerArgs):
|
|||||||
return hf_config.architectures[0]
|
return hf_config.architectures[0]
|
||||||
|
|
||||||
|
|
||||||
def auto_choose_speculative_params(arch: str):
|
def auto_choose_speculative_params(self: ServerArgs):
|
||||||
"""
|
"""
|
||||||
Automatically choose the parameters for speculative decoding.
|
Automatically choose the parameters for speculative decoding.
|
||||||
|
|
||||||
You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
|
You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
|
||||||
"""
|
"""
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
|
hf_config = get_config(
|
||||||
|
self.model_path,
|
||||||
|
trust_remote_code=self.trust_remote_code,
|
||||||
|
revision=self.revision,
|
||||||
|
model_override_args=json.loads(self.json_model_override_args),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
arch = hf_config.architectures[0]
|
||||||
|
|
||||||
if arch in ["LlamaForCausalLM"]:
|
if arch in ["LlamaForCausalLM"]:
|
||||||
# The default value for llama
|
# The default value for llama
|
||||||
return (5, 4, 8)
|
return (5, 4, 8)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import List
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.utils import is_cuda, is_hip
|
from sglang.srt.utils import is_cuda, is_hip, rank0_print
|
||||||
|
|
||||||
if is_cuda() or is_hip():
|
if is_cuda() or is_hip():
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
@@ -344,13 +344,13 @@ def test_build_tree_kernel_efficient():
|
|||||||
num_verify_tokens=num_draft_token,
|
num_verify_tokens=num_draft_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
first_rank_print("=========== build tree kernel efficient ==========")
|
rank0_print("=========== build tree kernel efficient ==========")
|
||||||
# first_rank_print(f"{tree_mask=}", flush=True)
|
# rank0_print(f"{tree_mask=}", flush=True)
|
||||||
first_rank_print(f"{position=}", flush=True)
|
rank0_print(f"{position=}", flush=True)
|
||||||
first_rank_print(f"{retrive_index=}", flush=True)
|
rank0_print(f"{retrive_index=}", flush=True)
|
||||||
first_rank_print(f"{retrive_next_token=}", flush=True)
|
rank0_print(f"{retrive_next_token=}", flush=True)
|
||||||
first_rank_print(f"{retrive_next_sibling=}", flush=True)
|
rank0_print(f"{retrive_next_sibling=}", flush=True)
|
||||||
first_rank_print(f"{draft_tokens=}", flush=True)
|
rank0_print(f"{draft_tokens=}", flush=True)
|
||||||
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
|
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
|
||||||
assert retrive_index.tolist() == [
|
assert retrive_index.tolist() == [
|
||||||
[0, 1, 2, 3, 4, 5, 6, 7],
|
[0, 1, 2, 3, 4, 5, 6, 7],
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Callable
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import (
|
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||||
|
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||||
CudaGraphRunner,
|
CudaGraphRunner,
|
||||||
get_batch_sizes_to_capture,
|
get_batch_sizes_to_capture,
|
||||||
get_global_graph_memory_pool,
|
get_global_graph_memory_pool,
|
||||||
@@ -73,7 +74,7 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
|
self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
|
||||||
self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
|
self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
|
||||||
self.hidden_states = torch.zeros(
|
self.hidden_states = torch.zeros(
|
||||||
(self.max_bs, self.model_runner.model_config.hidden_size),
|
(self.max_num_token, self.model_runner.model_config.hidden_size),
|
||||||
dtype=self.model_runner.dtype,
|
dtype=self.model_runner.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -82,13 +83,7 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
self.capture()
|
self.capture()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Capture CUDA graph failed: {e}\n"
|
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||||
"Possible solutions:\n"
|
|
||||||
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
|
|
||||||
"2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
|
|
||||||
"3. disable torch compile by not using --enable-torch-compile\n"
|
|
||||||
"4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
|
|
||||||
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Callable
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import (
|
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||||
|
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||||
CudaGraphRunner,
|
CudaGraphRunner,
|
||||||
LogitsProcessorOutput,
|
LogitsProcessorOutput,
|
||||||
get_batch_sizes_to_capture,
|
get_batch_sizes_to_capture,
|
||||||
@@ -89,13 +90,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
self.capture()
|
self.capture()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Capture CUDA graph failed: {e}\n"
|
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||||
"Possible solutions:\n"
|
|
||||||
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
|
|
||||||
"2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
|
|
||||||
"3. disable torch compile by not using --enable-torch-compile\n"
|
|
||||||
"4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
|
|
||||||
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
@@ -200,7 +195,6 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
# in the batch, which will not be counted as num_seqs
|
# in the batch, which will not be counted as num_seqs
|
||||||
raw_bs = forward_batch.batch_size
|
raw_bs = forward_batch.batch_size
|
||||||
num_tokens = forward_batch.input_ids.shape[0]
|
num_tokens = forward_batch.input_ids.shape[0]
|
||||||
assert raw_bs * self.num_tokens_per_bs == num_tokens
|
|
||||||
|
|
||||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||||
bs = self.capture_bs[index]
|
bs = self.capture_bs[index]
|
||||||
@@ -224,9 +218,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
self.seq_lens_cpu.fill_(1)
|
self.seq_lens_cpu.fill_(1)
|
||||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||||
|
|
||||||
forward_batch.spec_info.positions = None
|
|
||||||
if bs != raw_bs:
|
if bs != raw_bs:
|
||||||
forward_batch.spec_info.accept_length = self.accept_length[:bs]
|
forward_batch.spec_info.accept_length = self.accept_length[:bs]
|
||||||
|
forward_batch.spec_info.positions = None
|
||||||
|
|
||||||
self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
|
self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
bs=bs,
|
bs=bs,
|
||||||
|
|||||||
@@ -232,8 +232,9 @@ class EagleVerifyInput:
|
|||||||
retrive_next_token: torch.Tensor
|
retrive_next_token: torch.Tensor
|
||||||
retrive_next_sibling: torch.Tensor
|
retrive_next_sibling: torch.Tensor
|
||||||
retrive_cum_len: torch.Tensor
|
retrive_cum_len: torch.Tensor
|
||||||
draft_token_num: int
|
|
||||||
spec_steps: int
|
spec_steps: int
|
||||||
|
topk: int
|
||||||
|
draft_token_num: int
|
||||||
capture_hidden_mode: CaptureHiddenMode
|
capture_hidden_mode: CaptureHiddenMode
|
||||||
grammar: BaseGrammarObject = None
|
grammar: BaseGrammarObject = None
|
||||||
|
|
||||||
@@ -270,16 +271,17 @@ class EagleVerifyInput:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
draft_tokens,
|
draft_token=draft_tokens,
|
||||||
tree_mask,
|
custom_mask=tree_mask,
|
||||||
position,
|
positions=position,
|
||||||
retrive_index,
|
retrive_index=retrive_index,
|
||||||
retrive_next_token,
|
retrive_next_token=retrive_next_token,
|
||||||
retrive_next_sibling,
|
retrive_next_sibling=retrive_next_sibling,
|
||||||
None,
|
retrive_cum_len=None,
|
||||||
num_verify_tokens,
|
spec_steps=spec_steps,
|
||||||
spec_steps,
|
topk=topk,
|
||||||
CaptureHiddenMode.FULL,
|
draft_token_num=num_verify_tokens,
|
||||||
|
capture_hidden_mode=CaptureHiddenMode.FULL,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
|
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
|
||||||
|
|||||||
Reference in New Issue
Block a user