diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index e01919399..bc7a9c7a1 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -99,10 +99,7 @@ class BenchArgs:
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
         parser.add_argument(
-            "--profile",
-            action="store_true",
-            help="Use Torch Profiler. The endpoint must be launched with "
-            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+            "--profile", action="store_true", help="Use Torch Profiler."
         )
         parser.add_argument(
             "--profile-filename-prefix",
@@ -381,6 +378,7 @@ def latency_test_run_once(
         parent_dir = os.path.dirname(os.path.abspath(profile_filename))
         os.makedirs(parent_dir, exist_ok=True)
         profiler.export_chrome_trace(profile_filename)
+        rank_print(f"torch profiler chrome trace saved to {profile_filename}")

     # Record decode timing from 2nd output
     if output_len > 1:
@@ -451,7 +449,7 @@ def latency_test(
             il,
             ol,
             server_args.device,
-            bench_args.profile,
+            bench_args.profile if tp_rank == 0 else None,
             bench_args.profile_filename_prefix,
         )
         if ret is not None:
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index e5794f052..08ee5a350 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -296,7 +296,7 @@ def fused_softcap_kernel(
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index ebaa1aa0e..f3c376ed1 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -1,12 +1,11 @@
 import logging
-from typing import Dict, List
+from typing import List

 import torch
 from torch import nn

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import crash_on_warnings, is_flashinfer_available

@@ -109,8 +108,6 @@ class Sampler(nn.Module):
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )

-            batch_next_token_ids = batch_next_token_ids.to(torch.int32)
-
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
             if any(x > 0 for x in top_logprobs_nums):
@@ -124,7 +121,7 @@ class Sampler(nn.Module):
                 batch_next_token_ids,
             ]

-        return batch_next_token_ids
+        return batch_next_token_ids.to(torch.int32)

     def _apply_custom_logit_processor(
         self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index d9af81515..6c44b17ff 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -550,13 +550,13 @@ class ScheduleBatch:
     next_batch_sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    input_embeds: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
+    input_ids: torch.Tensor = None  # shape: [b], int32
+    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    req_pool_indices: torch.Tensor = None  # shape: [b], int32
+    seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
-    out_cache_loc: torch.Tensor = None
-    output_ids: torch.Tensor = None
+    out_cache_loc: torch.Tensor = None  # shape: [b], int32
+    output_ids: torch.Tensor = None  # shape: [b], int32

     # The sum of all sequence lengths
     seq_lens_sum: int = None
@@ -1026,7 +1026,7 @@ class ScheduleBatch:
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 762dac140..169b64343 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -24,7 +24,7 @@ import tqdm
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.distributed.parallel_state import graph_capture
+from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -63,7 +63,7 @@ def patch_model(
     model: torch.nn.Module,
     enable_compile: bool,
     batch_size: int,
-    tp_group: "GroupCoordinator",
+    tp_group: GroupCoordinator,
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
@@ -149,9 +149,18 @@ class CudaGraphRunner:
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]

+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1
-
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
                 self.num_tokens_per_bs = (
@@ -163,16 +172,6 @@ class CudaGraphRunner:
                     self.model_runner.server_args.speculative_num_draft_tokens
                 )

-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -180,7 +179,6 @@ class CudaGraphRunner:
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0

@@ -189,14 +187,14 @@ class CudaGraphRunner:

         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
+            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

         # Speculative_inference
         if model_runner.spec_algorithm.is_eagle():
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 8ef5c57b8..8bd105275 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import maybe_torch_compile
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -415,6 +415,6 @@
     return positions.to(torch.int64), extend_start_loc


-@maybe_torch_compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
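
For illustration only, outside the diff above: a minimal, self-contained sketch of the 64-bit indexing pattern that the fused_softcap_kernel hunk adopts. The kernel and helper names (scale_kernel, scale) are hypothetical, not SGLang code; the point is that widening the Triton program id to int64 before multiplying by BLOCK_SIZE keeps block_start and offsets from wrapping once a flattened tensor exceeds 2**31 elements.

# Illustrative sketch, not SGLang code.
import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # tl.program_id(0) is int32 by default; cast before the multiply so the
    # offset arithmetic happens in int64.
    pid = tl.program_id(0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)


def scale(x: torch.Tensor) -> torch.Tensor:
    # Launch one program per BLOCK_SIZE chunk of the flattened tensor.
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    scale_kernel[grid](x, out, n, BLOCK_SIZE=1024)
    return out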