Files
sglang/python/sglang/srt/model_executor/cpu_graph_runner.py

641 lines
23 KiB
Python

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with cpu torch compile."""
# The implementation of CPUGraphRunner follows the CudaGraphRunner
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union
import psutil
import torch
import tqdm
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
log_info_on_rank0,
require_attn_tp_gather,
require_gathered_buffer,
require_mlp_sync,
require_mlp_tp_gather,
)
# Module-level logger, used for the capture log line on TP rank 0.
logger = logging.getLogger(__name__)
# ModelRunner is only needed in annotations (which are lazy thanks to
# `from __future__ import annotations`), so it is imported for type checking
# only — presumably to avoid a runtime import cycle; confirm.
if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner
@contextmanager
def patch_model(
    model: torch.nn.Module,
    enable_compile: bool,
    num_tokens: int,
    tp_group: GroupCoordinator,
):
    """Yield the forward callable to use for graph capture.

    Args:
        model: the model whose ``forward`` is exposed.
        enable_compile: when False, ``model.forward`` is yielded unchanged;
            when True, a ``torch.compile``-d (``dynamic=False``) version of
            ``model.forward`` wrapped in ``torch.no_grad()`` is yielded.
        num_tokens: currently unused; kept for signature parity with the
            CUDA graph runner's helper.
        tp_group: tensor-parallel group whose ``ca_comm`` is snapshotted on
            entry and restored on exit when compiling.
    """
    saved_ca_comm = None
    try:
        if not enable_compile:
            yield model.forward
        else:
            saved_ca_comm = tp_group.ca_comm
            # The custom allreduce communicator is intentionally left in place:
            # it was found to be much faster than torch's built-in allreduce,
            # even with ENABLE_INTRA_NODE_COMM=1. The snapshot/restore around it
            # is therefore purely defensive.
            no_grad_forward = torch.no_grad()(model.forward)
            yield torch.compile(no_grad_forward, dynamic=False)
    finally:
        if enable_compile:
            tp_group.ca_comm = saved_ca_comm
def set_torch_compile_config():
    """Configure torch inductor/dynamo for CPU compilation, then apply SGLang's torch.compile patch.

    Enables the FX graph cache (experimental, expected to become the default) and
    inductor freezing, and raises the dynamo cache limits to 1024 so that one
    compiled graph per captured batch size does not hit the recompile limit.
    """
    import torch._dynamo.config as dynamo_config
    import torch._inductor.config as inductor_config

    inductor_config.fx_graph_cache = True
    inductor_config.freezing = True

    dynamo_config.accumulated_cache_size_limit = 1024
    # Only set on torch versions that still expose this option.
    if hasattr(dynamo_config, "cache_size_limit"):
        dynamo_config.cache_size_limit = 1024

    monkey_patch_torch_compile()
def get_batch_sizes_to_capture(model_runner: ModelRunner):
    """Return the ascending list of decode batch sizes to compile for.

    CPU torch.compile only speeds up decoding by reducing Python overhead when
    the batch is small, so candidates are limited to 1..16. They are further
    capped by ``server_args.torch_compile_max_bs`` and by the size of the
    request-to-token pool.

    Raises:
        AssertionError: if no batch size survives the caps.
    """
    server_args = model_runner.server_args
    capture_bs = [
        bs
        for bs in range(1, 17)
        if bs <= server_args.torch_compile_max_bs
        and bs <= model_runner.req_to_token_pool.size
    ]
    # range() already yields unique ascending values; no sort/dedup needed.
    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
    return capture_bs
def register_fake_ops():
    """
    Registers fake/meta implementations for all custom sgl_kernel CPU operators
    using torch.library.register_fake to support torch.compile

    Each fake only describes the outputs of an op (shape, dtype and device of
    freshly allocated tensors, or ``None`` for ops with no described output),
    so tracing can propagate metadata without running the real CPU kernel.

    NOTE(review): torch.library.register_fake refuses a second fake
    registration for the same operator, so this is presumably meant to run
    only once per process (it is invoked from CPUGraphRunner.__init__ when
    torch compile is enabled) — confirm before calling it from elsewhere.
    """
    # Ops whose fake returns None: no output tensor is described for them
    # (presumably they write into caller-provided buffers — TODO confirm).
    none_return_ops = [
        "shm_allreduce",
        "bmm_cpu",
        "fused_add_rmsnorm_cpu",
        "decode_attention_cpu",
        "extend_attention_cpu",
    ]
    # The decorator expression is evaluated on every iteration, so each op
    # gets its own registration (no late-binding problem with `op`).
    for op in none_return_ops:

        @torch.library.register_fake(f"sgl_kernel::{op}")
        def _(*args, **kwargs):
            return

    # Ops whose output matches the first argument's shape/dtype/device.
    for op in [
        "rmsnorm_cpu",
        "l2norm_cpu",
        "fused_experts_cpu",
        "shared_expert_cpu",
    ]:

        @torch.library.register_fake(f"sgl_kernel::{op}")
        def _(input, *args, **kwargs):
            return torch.empty_like(input)

    # Fused MLA q/k projection + rope. Outputs:
    #   q_input: (num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim)
    #   k_input: (num_seqs, 1, kv_lora_rank + qk_rope_head_dim)
    #   v_input: view of the first kv_lora_rank channels of k_input
    # num_heads / kv_lora_rank come from w_kc.shape[0] / w_kc.shape[1], and
    # qk_rope_head_dim is the remainder of kv_a_proj_weight's rows.
    @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope")
    def _(
        hidden_states,
        q_a_proj_weight,
        q_b_proj_weight,
        kv_a_proj_weight,
        w_kc,
        q_a_layernorm_weight,
        kv_a_layernorm_weight,
        positions,
        cos_sin_cache,
        eps,
        use_int8_w8a8,
        use_fp8_w8a16,
        q_a_proj_scale,
        q_b_proj_scale,
        kv_a_proj_scale,
        is_vnni,
        block_size,
    ):
        num_seqs = hidden_states.shape[0]
        num_heads = w_kc.shape[0]
        kv_lora_rank = w_kc.shape[1]
        qk_rope_head_dim = kv_a_proj_weight.shape[0] - kv_lora_rank
        q_input = torch.empty(
            num_seqs,
            num_heads,
            kv_lora_rank + qk_rope_head_dim,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        k_input = torch.empty(
            num_seqs,
            1,
            kv_lora_rank + qk_rope_head_dim,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        # v_input aliases k_input's storage (narrow returns a view).
        v_input = k_input.narrow(-1, 0, kv_lora_rank)
        return q_input, k_input, v_input

    # For a 2-D query the inputs themselves are returned (presumably the real
    # kernel rotates them in place in that case — TODO confirm); otherwise new
    # tensors with the same metadata as query/key are returned.
    @torch.library.register_fake("sgl_kernel::rotary_embedding_cpu")
    def _(positions, query, key, head_size, cos_sin_cache, is_neox):
        if query.ndim == 2:
            return query, key
        else:
            return torch.empty_like(query), torch.empty_like(key)

    # Same outputs as qkv_proj_with_rope, for the variant whose q_a and kv_a
    # projection weights are fused into q_a_proj_weight.
    @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope_fused_weight")
    def _(
        hidden_states,
        q_a_proj_weight,
        q_b_proj_weight,
        w_kc,
        q_a_layernorm_weight,
        kv_a_layernorm_weight,
        positions,
        cos_sin_cache,
        eps,
        use_int8_w8a8,
        use_fp8_w8a16,
        qkv_a_proj_scale,
        q_b_proj_scale,
        is_vnni,
        block_size,
        q_lora_rank,
        kv_lora_rank,
        qk_rope_head_dim,
    ):
        num_seqs = hidden_states.shape[0]
        num_heads = w_kc.shape[0]
        # Overrides the kv_lora_rank argument with the value implied by w_kc.
        kv_lora_rank = w_kc.shape[1]
        # torch.split with explicit sizes requires them to sum to dim 0, so
        # this also checks q_a_proj_weight's row count; the second chunk has
        # kv_lora_rank + qk_rope_head_dim rows, so the recomputed
        # qk_rope_head_dim below equals the argument.
        weight_chunks = torch.split(
            q_a_proj_weight, [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=0
        )
        qk_rope_head_dim = weight_chunks[1].shape[0] - kv_lora_rank
        q_input = torch.empty(
            num_seqs,
            num_heads,
            kv_lora_rank + qk_rope_head_dim,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        k_input = torch.empty(
            num_seqs,
            1,
            kv_lora_rank + qk_rope_head_dim,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        v_input = k_input.narrow(-1, 0, kv_lora_rank)
        return q_input, k_input, v_input

    # Linear layer: 2-D output of (x rows, weight rows), x's dtype/device.
    @torch.library.register_fake("sgl_kernel::weight_packed_linear")
    def _(x, weight, bias, is_vnni):
        return x.new_empty(x.shape[0], weight.shape[0])

    # Per-row int8 quantization: int8 values of shape (M, K) plus one float32
    # scale per row, shape (M,).
    @torch.library.register_fake("sgl_kernel::per_token_quant_int8_cpu")
    def _(input):
        M = input.shape[0]
        K = input.shape[1]
        Aq = input.new_empty(M, K, dtype=torch.int8)
        As = input.new_empty(M, dtype=torch.float32)
        return Aq, As

    # Scaled int8 GEMM: (mat1 rows, mat2 rows) in out_dtype.
    @torch.library.register_fake("sgl_kernel::int8_scaled_mm_cpu")
    def _(mat1, mat2, scales1, scales2, bias, out_dtype, is_vnni):
        M = mat1.shape[0]
        N = mat2.shape[0]
        out = mat1.new_empty(M, N, dtype=out_dtype)
        return out

    # Top-k routing ops below all return (num_tokens, topk) float32 weights
    # and int32 (torch.int) ids.
    @torch.library.register_fake("sgl_kernel::grouped_topk_cpu")
    def _(
        hidden_states,
        gating_output,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded,
    ):
        num_tokens = hidden_states.shape[0]
        shape = (num_tokens, topk)
        device = hidden_states.device
        topk_weights = torch.empty(shape, device=device, dtype=torch.float32)
        topk_ids = torch.empty(shape, device=device, dtype=torch.int)
        return topk_weights, topk_ids

    @torch.library.register_fake("sgl_kernel::biased_grouped_topk_cpu")
    def _(
        hidden_states,
        gating_output,
        correction_bias,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded,
    ):
        num_tokens = hidden_states.shape[0]
        shape = (num_tokens, topk)
        device = hidden_states.device
        topk_weights = torch.empty(shape, device=device, dtype=torch.float32)
        topk_ids = torch.empty(shape, device=device, dtype=torch.int)
        return topk_weights, topk_ids

    @torch.library.register_fake("sgl_kernel::topk_sigmoid_cpu")
    def _(hidden_states, gating_output, topk, renormalize):
        num_tokens = hidden_states.shape[0]
        shape = (num_tokens, topk)
        return (
            torch.empty(shape, device=hidden_states.device, dtype=torch.float),
            torch.empty(shape, device=hidden_states.device, dtype=torch.int),
        )

    @torch.library.register_fake("sgl_kernel::topk_softmax_cpu")
    def _(
        hidden_states,
        gating_output,
        topk,
        renormalize,
    ):
        num_tokens = hidden_states.shape[0]
        shape = (num_tokens, topk)
        return (
            torch.empty(shape, device=hidden_states.device, dtype=torch.float),
            torch.empty(shape, device=hidden_states.device, dtype=torch.int),
        )

    # Gated activation: halves the second dimension.
    @torch.library.register_fake("sgl_kernel::silu_and_mul_cpu")
    def _(input):
        return input.new_empty(input.shape[0], input.shape[1] // 2)

    # Int8 GEMM with on-the-fly activation quantization: (M, N) in out_dtype.
    @torch.library.register_fake("sgl_kernel::int8_scaled_mm_with_quant")
    def _(
        mat1,
        mat2,
        scales2,
        bias,
        out_dtype,
        is_vnni,
    ):
        M = mat1.shape[0]
        N = mat2.shape[0]
        return mat1.new_empty(M, N, dtype=out_dtype)

    # Block-scaled fp8 GEMM: (M, N) in out_dtype.
    @torch.library.register_fake("sgl_kernel::fp8_scaled_mm_cpu")
    def _(
        mat1,
        mat2,
        scales2,
        block_size,
        bias,
        out_dtype,
        is_vnni,
    ):
        M = mat1.shape[0]
        N = mat2.shape[0]
        return mat1.new_empty(M, N, dtype=out_dtype)
# TODO Remove unnecessary settings for CPUGraphRunner.
# Re-abstract the graph runner and restructure CPUGraphRunner to reuse the same logic.
class CPUGraphRunner:
    """A CPUGraphRunner runs the forward pass of a model with cpu torch.compile.

    Modeled on CudaGraphRunner. At construction time it prepares, and warms
    up, one forward callable per batch size in ``capture_bs``, all in DECODE
    mode. That callable is compiled when torch compile is enabled.
    ``replay`` then dispatches a decode batch of a captured size to the stored
    callable. Features the CPU path does not support yet are rejected by
    assertions in ``__init__``: LoRA, two-batch overlap, MLP sync/gather,
    gathered buffers, speculative decoding, encoder-decoder models, DP and PP.
    """

    def __init__(self, model_runner: ModelRunner):
        # Parse args
        self.model_runner = model_runner
        self.device = model_runner.device
        # batch size -> forward callable used for replay / its warm-up output
        self.graphs = {}
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
        self.enable_two_batch_overlap = (
            model_runner.server_args.enable_two_batch_overlap
        )
        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
        self.enable_profile_cuda_graph = (
            model_runner.server_args.enable_profile_cuda_graph
        )
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size
        self.pp_size = model_runner.server_args.pp_size

        self.capture_forward_mode = ForwardMode.DECODE
        self.capture_hidden_mode = CaptureHiddenMode.NULL
        self.num_tokens_per_bs = 1

        # If returning hidden states is enabled, set initial capture hidden mode
        # to full to avoid double-capture on startup
        if model_runner.server_args.enable_return_hidden_states:
            self.capture_hidden_mode = CaptureHiddenMode.FULL

        assert (
            not self.model_runner.server_args.enable_lora
        ), "CPUGraphRunner does not support LoRA yet."
        assert (
            not self.enable_two_batch_overlap
        ), "CPUGraphRunner does not support two batch overlap yet."
        assert (
            not self.require_mlp_tp_gather
        ), "CPUGraphRunner does not support MLP TP gather yet."
        assert (
            not self.require_mlp_sync
        ), "CPUGraphRunner does not support MLP sync yet."
        assert (
            not self.require_gathered_buffer
        ), "CPUGraphRunner does not support gathered buffer yet."
        assert (
            model_runner.spec_algorithm == SpeculativeAlgorithm.NONE
        ), "CPUGraphRunner does not support speculative inference yet."
        # TODO add compile support for encoder-decoder models
        assert (
            not self.is_encoder_decoder
        ), "CPUGraphRunner does not support encoder-decoder models yet."
        assert self.dp_size == 1, "CPUGraphRunner does not support DP yet."
        assert self.pp_size == 1, "CPUGraphRunner does not support PP yet."

        # Batch sizes to capture
        self.capture_bs = get_batch_sizes_to_capture(model_runner)
        log_info_on_rank0(logger, f"Capture cpu graph bs {self.capture_bs}")
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs

        # Attention backend: value used to fill the dummy seq_lens buffer
        self.seq_len_fill_value = (
            self.model_runner.attn_backend.get_graph_seq_len_fill_value()
        )

        if self.enable_torch_compile:
            register_fake_ops()
            set_torch_compile_config()

        # Graph inputs: preallocated for max_bs; each capture uses a prefix slice.
        with torch.device(self.device):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int64)
            self.seq_lens = torch.full(
                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int64
            )
            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
            self.num_token_non_padded = torch.zeros((1,), dtype=torch.int64)
            self.custom_mask = torch.ones(
                (
                    (self.seq_lens.sum().item() + self.max_num_token)
                    * self.num_tokens_per_bs
                ),
                dtype=torch.bool,
                device=self.device,
            )

        # Capture
        try:
            self.capture()
        except RuntimeError as e:
            # Chain the original error so its traceback is kept as __cause__.
            raise Exception(
                f"Capture CPU graph failed: {e}\n{CPU_GRAPH_CAPTURE_FAILED_MSG}"
            ) from e

    @staticmethod
    def _spec_info_capture_hidden_mode(forward_batch: ForwardBatch):
        """Return ``forward_batch.spec_info.capture_hidden_mode``, or NULL.

        NULL is returned when there is no spec_info, when it has no
        ``capture_hidden_mode`` attribute, or when that attribute is None, so
        the result is always safe to pass to ``max``.
        """
        mode = getattr(forward_batch.spec_info, "capture_hidden_mode", None)
        return CaptureHiddenMode.NULL if mode is None else mode

    def can_run(self, forward_batch: ForwardBatch):
        """Return whether ``forward_batch`` can be served by a captured graph.

        The batch size must have been captured, and the requested hidden-state
        capture mode must be NULL or equal to the mode used at capture time.
        """
        is_bs_supported = forward_batch.batch_size in self.graphs
        requested_capture_hidden_mode = max(
            forward_batch.capture_hidden_mode,
            self._spec_info_capture_hidden_mode(forward_batch),
        )
        capture_hidden_mode_matches = (
            requested_capture_hidden_mode == CaptureHiddenMode.NULL
            or requested_capture_hidden_mode == self.capture_hidden_mode
        )
        return is_bs_supported and capture_hidden_mode_matches

    def capture(self) -> None:
        """Prepare and warm up a forward callable for every batch size in capture_bs.

        Batch sizes are processed largest first. On TP rank 0 a tqdm progress
        bar shows the available host memory. Results are stored in
        ``self.graphs`` and ``self.output_buffers``, keyed by batch size.
        """
        capture_range = (
            tqdm.tqdm(list(reversed(self.capture_bs)))
            if get_tensor_model_parallel_rank() == 0
            else reversed(self.capture_bs)
        )
        for bs in capture_range:
            if get_tensor_model_parallel_rank() == 0:
                avail_mem = psutil.virtual_memory().available / (1 << 30)
                capture_range.set_description(
                    f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
                )

            # Compile only when torch compile is enabled. This matches the
            # register_fake_ops / set_torch_compile_config setup in __init__,
            # which only runs in that case.
            with patch_model(
                self.model_runner.model,
                self.enable_torch_compile,
                num_tokens=bs * self.num_tokens_per_bs,
                tp_group=self.model_runner.tp_group,
            ) as forward:
                graph, output_buffers = self.capture_one_batch_size(bs, forward)
                self.graphs[bs] = graph
                self.output_buffers[bs] = output_buffers

    def capture_one_batch_size(self, bs: int, forward: Callable):
        """Build a dummy decode batch of size ``bs`` and warm up ``forward`` on it.

        The batch tensors are prefix slices of the graph-input buffers
        preallocated in ``__init__``. The model is run eagerly once, then
        ``forward`` is run twice, each run preceded by a TP barrier. All of
        these runs happen under ``torch.no_grad()``.

        Returns:
            ``(forward, out)``: the callable to store for this batch size, and
            the output of its last warm-up run.
        """
        num_tokens = bs * self.num_tokens_per_bs

        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]
        mrope_positions = self.mrope_positions[:, :bs]
        self.num_token_non_padded[...] = num_tokens

        spec_info = self.get_spec_info(num_tokens)
        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
            self.capture_hidden_mode = (
                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
            )

        forward_batch = ForwardBatch(
            forward_mode=self.capture_forward_mode,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens.sum().item(),
            return_logprob=False,
            positions=positions,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=self.capture_hidden_mode,
            num_token_non_padded=self.num_token_non_padded,
            global_forward_mode=self.capture_forward_mode,
        )

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata(forward_batch)

        # Run and capture
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
            logits_output_or_pp_proxy_tensors = forward(
                input_ids,
                forward_batch.positions,
                forward_batch,
            )
            return logits_output_or_pp_proxy_tensors

        with torch.no_grad():
            # Do inference once eagerly to avoid setting attrs at runtime, e.g.,
            # self.attn_mha.kv_b_proj = self.kv_b_proj for full graph compile on CPU
            self.model_runner.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
            )
            for _ in range(2):
                self.model_runner.tp_group.barrier()
                out = run_once()

        return forward, out

    def recapture_if_needed(self, forward_batch: ForwardBatch):
        """Recapture all batch sizes if the required hidden-state capture mode changed."""
        # These are the different factors that can influence the capture_hidden_mode.
        # A spec_info without a (non-None) capture_hidden_mode counts as NULL,
        # exactly as in can_run.
        capture_hidden_mode_required_by_forward_batch = (
            forward_batch.capture_hidden_mode
        )
        capture_hidden_mode_required_by_spec_info = (
            self._spec_info_capture_hidden_mode(forward_batch)
        )
        capture_hidden_mode_required_for_returning_hidden_states = (
            CaptureHiddenMode.FULL
            if self.model_runner.server_args.enable_return_hidden_states
            else CaptureHiddenMode.NULL
        )

        # Determine the highest capture_hidden_mode required
        # (If we have FULL, we can emulate LAST or NULL)
        # (If we have LAST, we can emulate NULL)
        required_capture_hidden_mode = max(
            capture_hidden_mode_required_by_forward_batch,
            capture_hidden_mode_required_by_spec_info,
            capture_hidden_mode_required_for_returning_hidden_states,
        )

        # If the current hidden mode is no longer aligned with the required
        # hidden mode, set it to what is required and re-capture.
        if self.capture_hidden_mode != required_capture_hidden_mode:
            self.capture_hidden_mode = required_capture_hidden_mode
            self.capture()

    # TODO add padding support for CPUGraphRunner
    def replay(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        """Run ``forward_batch`` through the stored forward for its batch size.

        Args:
            forward_batch: decode batch. Its batch size must be a key of
                ``self.graphs``, otherwise a KeyError is raised; check
                ``can_run`` first.
            skip_attn_backend_init: accepted for interface parity with the
                CUDA graph runner but currently ignored. Attention metadata is
                always re-initialized, since a recapture triggered just before
                may have re-initialized it for dummy capture batches.
            pp_proxy_tensors: must be None; pipeline parallelism is unsupported.
        """
        assert (
            pp_proxy_tensors is None
        ), "PPProxyTensors is not supported in CPUGraphRunner yet."
        self.recapture_if_needed(forward_batch)
        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
        return self.graphs[forward_batch.batch_size](
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
        )

    def get_spec_info(self, num_tokens: int):
        """Return dummy speculative info for capture, or None.

        Only EAGLE produces a non-None value. Since __init__ asserts that
        speculative decoding is disabled, this presumably always returns None
        for now.
        """
        spec_info = None
        if self.model_runner.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_utils import EagleVerifyInput

            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen.")
            else:
                spec_info = EagleVerifyInput(
                    draft_token=None,
                    custom_mask=self.custom_mask,
                    positions=None,
                    retrive_index=None,
                    retrive_next_token=None,
                    retrive_next_sibling=None,
                    retrive_cum_len=None,
                    spec_steps=self.model_runner.server_args.speculative_num_steps,
                    topk=self.model_runner.server_args.speculative_eagle_topk,
                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                    capture_hidden_mode=CaptureHiddenMode.FULL,
                    seq_lens_sum=None,
                    seq_lens_cpu=None,
                )

        return spec_info
# Troubleshooting hint appended to the exception that CPUGraphRunner.__init__
# raises when self.capture() fails with a RuntimeError.
CPU_GRAPH_CAPTURE_FAILED_MSG = (
    "Possible solutions:\n"
    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
    "2. set --torch-compile-max-bs to a smaller value (e.g., 8)\n"
    "3. disable torch compile by not using --enable-torch-compile\n"
    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)