Use int64 as indices for set_kv_buffer (#3039)
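This commit widens several index tensors from int32 to int64: the Triton program id in `fused_softcap_kernel`, the CUDA-graph input buffers `input_ids`, `out_cache_loc`, and `mrope_positions`, and the `seq_lens` handling around `ScheduleBatch`. The sampler now keeps token ids in int64 until the final return, and `clamp_position` switches from the removed `maybe_torch_compile` wrapper to a plain `torch.compile` with an explicit backend. A minimal sketch of the failure mode the wider indices guard against (the tensor names here are illustrative, not from the patch):

```python
import torch

# With int32 indices, a flat offset into a large KV buffer wraps silently
# once it exceeds 2**31 - 1; the same arithmetic in int64 stays exact.
slot = torch.tensor([70_000], dtype=torch.int32)
stride = torch.tensor([40_000], dtype=torch.int32)
print(slot * stride)                   # wraps to a negative offset
print(slot.to(torch.int64) * stride)   # 2_800_000_000, as intended
```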
@@ -99,10 +99,7 @@ class BenchArgs:
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
         parser.add_argument(
-            "--profile",
-            action="store_true",
-            help="Use Torch Profiler. The endpoint must be launched with "
-            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+            "--profile", action="store_true", help="Use Torch Profiler."
         )
         parser.add_argument(
             "--profile-filename-prefix",
@@ -381,6 +378,7 @@ def latency_test_run_once(
         parent_dir = os.path.dirname(os.path.abspath(profile_filename))
         os.makedirs(parent_dir, exist_ok=True)
         profiler.export_chrome_trace(profile_filename)
+        rank_print(f"torch profiler chrome trace saved to {profile_filename}")

     # Record decode timing from 2nd output
     if output_len > 1:
@@ -451,7 +449,7 @@ def latency_test(
             il,
             ol,
             server_args.device,
-            bench_args.profile,
+            bench_args.profile if tp_rank == 0 else None,
            bench_args.profile_filename_prefix,
         )
         if ret is not None:
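The `bench_args.profile if tp_rank == 0 else None` change means only rank 0 receives a truthy profile flag, so a tensor-parallel run writes a single trace instead of one per rank. A hedged sketch of the resulting control flow (`run_once` and its arguments are illustrative stand-ins, not the benchmark's real signature):

```python
import os
import torch

def run_once(profile, profile_filename):
    profiler = None
    if profile:  # non-zero ranks get None here, so they skip tracing
        profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        )
        profiler.start()

    # ... forward passes being measured ...

    if profiler is not None:
        profiler.stop()
        os.makedirs(os.path.dirname(os.path.abspath(profile_filename)), exist_ok=True)
        profiler.export_chrome_trace(profile_filename)
```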
@@ -296,7 +296,7 @@ def fused_softcap_kernel(
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
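`tl.program_id(0)` is a 32-bit value in Triton, so `pid * BLOCK_SIZE` can wrap for grids covering more than 2**31 elements; casting the pid to `tl.int64` first keeps every downstream offset 64-bit. A minimal self-contained kernel using the same pattern (a sketch, not the actual fused_softcap code):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0).to(tl.int64)   # widen before any offset math
    block_start = pid * BLOCK_SIZE        # int64, cannot wrap
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)

def copy(src: torch.Tensor) -> torch.Tensor:
    dst = torch.empty_like(src)
    grid = (triton.cdiv(src.numel(), 1024),)
    copy_kernel[grid](src, dst, src.numel(), BLOCK_SIZE=1024)
    return dst
```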
@@ -1,12 +1,11 @@
 import logging
-from typing import Dict, List
+from typing import List

 import torch
 from torch import nn

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import crash_on_warnings, is_flashinfer_available

@@ -109,8 +108,6 @@ class Sampler(nn.Module):
                 f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
             )

-        batch_next_token_ids = batch_next_token_ids.to(torch.int32)
-
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
             if any(x > 0 for x in top_logprobs_nums):
@@ -124,7 +121,7 @@ class Sampler(nn.Module):
                 batch_next_token_ids,
             ]

-        return batch_next_token_ids
+        return batch_next_token_ids.to(torch.int32)

     def _apply_custom_logit_processor(
         self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
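Deferring the int32 cast to the return keeps `batch_next_token_ids` in int64 for the logprob gathering above it, where the ids serve as indices. A small sketch of the ordering (illustrative shapes and names):

```python
import torch

logprobs = torch.log_softmax(torch.randn(4, 32_000), dim=-1)
next_ids = torch.argmax(logprobs, dim=-1)              # int64 from argmax
token_logprobs = logprobs[torch.arange(4), next_ids]   # index while still int64
next_ids = next_ids.to(torch.int32)                    # narrow only at the end
```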
@@ -550,13 +550,13 @@ class ScheduleBatch:
     next_batch_sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    input_embeds: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
+    input_ids: torch.Tensor = None  # shape: [b], int32
+    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    req_pool_indices: torch.Tensor = None  # shape: [b], int32
+    seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
-    out_cache_loc: torch.Tensor = None
-    output_ids: torch.Tensor = None
+    out_cache_loc: torch.Tensor = None  # shape: [b], int32
+    output_ids: torch.Tensor = None  # shape: [b], int32

     # The sum of all sequence lengths
     seq_lens_sum: int = None
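Among the annotated fields, `seq_lens` is the one kept at int64: running sums of sequence lengths feed offset computations into the KV pool, and those sums overflow int32 well within realistic batch sizes. For example (illustrative numbers):

```python
import torch

seq_lens = torch.full((20_000,), 131_072, dtype=torch.int64)
kv_offsets = torch.cumsum(seq_lens, dim=0)             # tops out near 2.6e9
assert kv_offsets[-1] > torch.iinfo(torch.int32).max   # would wrap in int32
```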
@@ -1026,7 +1026,7 @@ class ScheduleBatch:
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
@@ -24,7 +24,7 @@ import tqdm
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.distributed.parallel_state import graph_capture
+from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -63,7 +63,7 @@ def patch_model(
     model: torch.nn.Module,
     enable_compile: bool,
     batch_size: int,
-    tp_group: "GroupCoordinator",
+    tp_group: GroupCoordinator,
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
@@ -149,9 +149,18 @@ class CudaGraphRunner:
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]

+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1

         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
                 self.num_tokens_per_bs = (
@@ -163,16 +172,6 @@ class CudaGraphRunner:
                     self.model_runner.server_args.speculative_num_draft_tokens
                 )

-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -180,7 +179,6 @@ class CudaGraphRunner:
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0

@@ -189,14 +187,14 @@ class CudaGraphRunner:

         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
+            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

         # Speculative_inference
         if model_runner.spec_algorithm.is_eagle():
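`out_cache_loc` holding int64 slot indices matches the commit title: KV writes select cache slots through an index tensor, and the write stays exact regardless of pool size. A hedged sketch of a set_kv_buffer-style write (pool shape and names are illustrative):

```python
import torch

k_pool = torch.zeros(10_000, 8, 128)                    # [slots, heads, head_dim]
loc = torch.tensor([3, 77, 9_999], dtype=torch.int64)   # out_cache_loc slots
new_k = torch.randn(3, 8, 128)
k_pool.index_copy_(0, loc, new_k)                       # int64-indexed KV write
```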
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import maybe_torch_compile
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -415,6 +415,6 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc


-@maybe_torch_compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
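`@torch.compile(dynamic=True, backend=get_compiler_backend())` makes the backend explicit instead of hiding it behind the removed `maybe_torch_compile` wrapper. The body of `get_compiler_backend` in `sglang.srt.utils` is not shown in this diff; a plausible stand-in that makes the decorator self-contained might look like:

```python
import torch

def get_compiler_backend() -> str:
    # Hypothetical stand-in; the real helper in sglang.srt.utils may choose
    # differently (e.g. per-device or via an environment variable).
    return "inductor" if torch.cuda.is_available() else "aot_eager"

@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)
```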