[Minor] more code cleanup (#4077)
This commit is contained in:
@@ -40,6 +40,7 @@ runtime_common = [
|
|||||||
"transformers==4.48.3",
|
"transformers==4.48.3",
|
||||||
"llguidance>=0.6.15"
|
"llguidance>=0.6.15"
|
||||||
]
|
]
|
||||||
|
|
||||||
srt = [
|
srt = [
|
||||||
"sglang[runtime_common]",
|
"sglang[runtime_common]",
|
||||||
"sgl-kernel==0.0.3.post6",
|
"sgl-kernel==0.0.3.post6",
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||||
|
ASSISTANT_SUFFIX = "Assistant:"
|
||||||
|
|
||||||
global args
|
global args
|
||||||
|
|
||||||
@@ -635,7 +636,11 @@ def sample_sharegpt_requests(
|
|||||||
# Tokenize the prompts and completions.
|
# Tokenize the prompts and completions.
|
||||||
prompt = dataset[i][0]
|
prompt = dataset[i][0]
|
||||||
if prompt_suffix:
|
if prompt_suffix:
|
||||||
prompt = prompt
|
prompt = (
|
||||||
|
remove_suffix(prompt, ASSISTANT_SUFFIX)
|
||||||
|
+ prompt_suffix
|
||||||
|
+ ASSISTANT_SUFFIX
|
||||||
|
)
|
||||||
|
|
||||||
if apply_chat_template:
|
if apply_chat_template:
|
||||||
prompt = tokenizer.apply_chat_template(
|
prompt = tokenizer.apply_chat_template(
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from json import JSONDecodeError, JSONDecoder
|
from json import JSONDecodeError, JSONDecoder
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|||||||
39
python/sglang/srt/layers/attention/utils.py
Normal file
39
python/sglang/srt/layers/attention/utils.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def create_flashinfer_kv_indices_triton(
|
||||||
|
req_to_token_ptr, # [max_batch, max_context_len]
|
||||||
|
req_pool_indices_ptr,
|
||||||
|
page_kernel_lens_ptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_start_idx,
|
||||||
|
kv_indices_ptr,
|
||||||
|
req_to_token_ptr_stride: tl.constexpr,
|
||||||
|
):
|
||||||
|
BLOCK_SIZE: tl.constexpr = 512
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
|
||||||
|
req_pool_index = tl.load(req_pool_indices_ptr + pid)
|
||||||
|
kv_indices_offset = tl.load(kv_indptr + pid)
|
||||||
|
|
||||||
|
kv_start = 0
|
||||||
|
kv_end = 0
|
||||||
|
if kv_start_idx:
|
||||||
|
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
|
||||||
|
kv_end = kv_start
|
||||||
|
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
|
||||||
|
|
||||||
|
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||||
|
for i in range(num_loop):
|
||||||
|
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||||
|
mask = offset < kv_end - kv_start
|
||||||
|
data = tl.load(
|
||||||
|
req_to_token_ptr
|
||||||
|
+ req_pool_index * req_to_token_ptr_stride
|
||||||
|
+ kv_start
|
||||||
|
+ offset,
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
|
||||||
@@ -33,6 +33,7 @@ from sglang.srt.layers.dp_attention import (
|
|||||||
get_attention_dp_size,
|
get_attention_dp_size,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
CaptureHiddenMode,
|
CaptureHiddenMode,
|
||||||
ForwardBatch,
|
ForwardBatch,
|
||||||
@@ -152,6 +153,13 @@ class LogitsMetadata:
|
|||||||
token_ids_logprobs=forward_batch.token_ids_logprobs,
|
token_ids_logprobs=forward_batch.token_ids_logprobs,
|
||||||
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
|
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
|
||||||
padded_static_len=forward_batch.padded_static_len,
|
padded_static_len=forward_batch.padded_static_len,
|
||||||
|
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
|
||||||
|
dp_local_start_pos=forward_batch.dp_local_start_pos,
|
||||||
|
dp_local_num_tokens=forward_batch.dp_local_num_tokens,
|
||||||
|
gathered_buffer=forward_batch.gathered_buffer,
|
||||||
|
forward_batch_gathered_buffer=forward_batch.gathered_buffer,
|
||||||
|
global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
|
||||||
|
global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
|
def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
|
||||||
@@ -204,8 +212,6 @@ class LogitsProcessor(nn.Module):
|
|||||||
):
|
):
|
||||||
self.final_logit_softcapping = None
|
self.final_logit_softcapping = None
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
|
|
||||||
self.debug_tensor_dump_output_folder = global_server_args_dict.get(
|
self.debug_tensor_dump_output_folder = global_server_args_dict.get(
|
||||||
"debug_tensor_dump_output_folder", None
|
"debug_tensor_dump_output_folder", None
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -212,6 +212,7 @@ class DetokenizerManager:
|
|||||||
rids=recv_obj.rids,
|
rids=recv_obj.rids,
|
||||||
finished_reasons=recv_obj.finished_reasons,
|
finished_reasons=recv_obj.finished_reasons,
|
||||||
output_strs=output_strs,
|
output_strs=output_strs,
|
||||||
|
output_ids=None,
|
||||||
prompt_tokens=recv_obj.prompt_tokens,
|
prompt_tokens=recv_obj.prompt_tokens,
|
||||||
completion_tokens=recv_obj.completion_tokens,
|
completion_tokens=recv_obj.completion_tokens,
|
||||||
cached_tokens=recv_obj.cached_tokens,
|
cached_tokens=recv_obj.cached_tokens,
|
||||||
|
|||||||
@@ -414,6 +414,12 @@ class BatchTokenIDOut:
|
|||||||
class BatchMultimodalDecodeReq:
|
class BatchMultimodalDecodeReq:
|
||||||
# The request id
|
# The request id
|
||||||
rids: List[str]
|
rids: List[str]
|
||||||
|
finished_reasons: List[BaseFinishReason]
|
||||||
|
|
||||||
|
# Token counts
|
||||||
|
prompt_tokens: List[int]
|
||||||
|
completion_tokens: List[int]
|
||||||
|
cached_tokens: List[int]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -424,6 +430,8 @@ class BatchStrOut:
|
|||||||
finished_reasons: List[dict]
|
finished_reasons: List[dict]
|
||||||
# The output decoded strings
|
# The output decoded strings
|
||||||
output_strs: List[str]
|
output_strs: List[str]
|
||||||
|
# The token ids
|
||||||
|
output_ids: Optional[List[int]]
|
||||||
|
|
||||||
# Token counts
|
# Token counts
|
||||||
prompt_tokens: List[int]
|
prompt_tokens: List[int]
|
||||||
@@ -453,6 +461,15 @@ class BatchStrOut:
|
|||||||
class BatchMultimodalOut:
|
class BatchMultimodalOut:
|
||||||
# The request id
|
# The request id
|
||||||
rids: List[str]
|
rids: List[str]
|
||||||
|
# The finish reason
|
||||||
|
finished_reasons: List[dict]
|
||||||
|
# The outputs
|
||||||
|
outputs: List[List[Dict]]
|
||||||
|
|
||||||
|
# Token counts
|
||||||
|
prompt_tokens: List[int]
|
||||||
|
completion_tokens: List[int]
|
||||||
|
cached_tokens: List[int]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -1141,7 +1141,7 @@ async def print_exception_wrapper(func):
|
|||||||
|
|
||||||
|
|
||||||
class SignalHandler:
|
class SignalHandler:
|
||||||
def __init__(self, tokenizer_manager):
|
def __init__(self, tokenizer_manager: TokenizerManager):
|
||||||
self.tokenizer_manager = tokenizer_manager
|
self.tokenizer_manager = tokenizer_manager
|
||||||
|
|
||||||
def signal_handler(self, signum=None, frame=None):
|
def signal_handler(self, signum=None, frame=None):
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
|||||||
|
|
||||||
k_size, v_size = self.get_kv_size_bytes()
|
k_size, v_size = self.get_kv_size_bytes()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
|
f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_buffers(self):
|
def _create_buffers(self):
|
||||||
|
|||||||
@@ -238,6 +238,9 @@ class CudaGraphRunner:
|
|||||||
),
|
),
|
||||||
dtype=self.model_runner.dtype,
|
dtype=self.model_runner.dtype,
|
||||||
)
|
)
|
||||||
|
self.global_num_tokens_gpu = torch.zeros(
|
||||||
|
(self.dp_size,), dtype=torch.int32
|
||||||
|
)
|
||||||
|
|
||||||
# Capture
|
# Capture
|
||||||
try:
|
try:
|
||||||
@@ -266,9 +269,9 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
|
min_num_tokens, max_num_tokens = min(
|
||||||
forward_batch.global_num_tokens
|
forward_batch.global_num_tokens_cpu
|
||||||
)
|
), max(forward_batch.global_num_tokens_cpu)
|
||||||
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
|
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
|
||||||
(min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
|
(min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
|
||||||
if self.disable_padding
|
if self.disable_padding
|
||||||
@@ -360,7 +363,7 @@ class CudaGraphRunner:
|
|||||||
encoder_lens=encoder_lens,
|
encoder_lens=encoder_lens,
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
global_num_tokens=global_num_tokens,
|
global_num_tokens_cpu=global_num_tokens,
|
||||||
gathered_buffer=gathered_buffer,
|
gathered_buffer=gathered_buffer,
|
||||||
mrope_positions=mrope_positions,
|
mrope_positions=mrope_positions,
|
||||||
spec_algorithm=self.model_runner.spec_algorithm,
|
spec_algorithm=self.model_runner.spec_algorithm,
|
||||||
@@ -430,7 +433,7 @@ class CudaGraphRunner:
|
|||||||
# Pad
|
# Pad
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
index = bisect.bisect_left(
|
index = bisect.bisect_left(
|
||||||
self.capture_bs, max(forward_batch.global_num_tokens)
|
self.capture_bs, max(forward_batch.global_num_tokens_cpu)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||||
|
|||||||
@@ -190,7 +190,16 @@ class ForwardBatch:
|
|||||||
attn_backend: AttentionBackend = None
|
attn_backend: AttentionBackend = None
|
||||||
|
|
||||||
# For DP attention
|
# For DP attention
|
||||||
global_num_tokens: Optional[List[int]] = None
|
global_num_tokens_cpu: Optional[List[int]] = None
|
||||||
|
global_num_tokens_gpu: Optional[torch.Tensor] = None
|
||||||
|
# Has to be None when cuda graph is captured.
|
||||||
|
global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
|
||||||
|
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
|
||||||
|
# for extend, local start pos and num tokens is different in logits processor
|
||||||
|
# this will be computed in get_dp_local_info
|
||||||
|
# this will be recomputed in LogitsMetadata.from_forward_batch
|
||||||
|
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
|
||||||
|
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
|
||||||
gathered_buffer: Optional[torch.Tensor] = None
|
gathered_buffer: Optional[torch.Tensor] = None
|
||||||
can_run_dp_cuda_graph: bool = False
|
can_run_dp_cuda_graph: bool = False
|
||||||
|
|
||||||
@@ -234,7 +243,6 @@ class ForwardBatch:
|
|||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
token_ids_logprobs=batch.token_ids_logprobs,
|
token_ids_logprobs=batch.token_ids_logprobs,
|
||||||
global_num_tokens=batch.global_num_tokens,
|
|
||||||
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
||||||
lora_paths=batch.lora_paths,
|
lora_paths=batch.lora_paths,
|
||||||
sampling_info=batch.sampling_info,
|
sampling_info=batch.sampling_info,
|
||||||
@@ -248,8 +256,9 @@ class ForwardBatch:
|
|||||||
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
|
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
if ret.global_num_tokens is not None:
|
if batch.global_num_tokens is not None:
|
||||||
max_len = max(ret.global_num_tokens)
|
ret.global_num_tokens_cpu = batch.global_num_tokens
|
||||||
|
max_len = max(ret.global_num_tokens_cpu)
|
||||||
ret.gathered_buffer = torch.zeros(
|
ret.gathered_buffer = torch.zeros(
|
||||||
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
|
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
|
||||||
dtype=model_runner.dtype,
|
dtype=model_runner.dtype,
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""ModelRunner runs the forward passes of the models."""
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
|
|
||||||
import collections
|
|
||||||
import datetime
|
import datetime
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
@@ -269,6 +268,7 @@ class ModelRunner:
|
|||||||
elif self.device == "cpu":
|
elif self.device == "cpu":
|
||||||
backend = "gloo"
|
backend = "gloo"
|
||||||
|
|
||||||
|
before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
if not self.server_args.enable_p2p_check:
|
if not self.server_args.enable_p2p_check:
|
||||||
monkey_patch_p2p_access_check()
|
monkey_patch_p2p_access_check()
|
||||||
|
|
||||||
@@ -299,20 +299,24 @@ class ModelRunner:
|
|||||||
min_per_gpu_memory = get_available_gpu_memory(
|
min_per_gpu_memory = get_available_gpu_memory(
|
||||||
self.device, self.gpu_id, distributed=self.tp_size > 1
|
self.device, self.gpu_id, distributed=self.tp_size > 1
|
||||||
)
|
)
|
||||||
|
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
self.tp_group = get_tp_group()
|
self.tp_group = get_tp_group()
|
||||||
self.attention_tp_group = get_attention_tp_group()
|
self.attention_tp_group = get_attention_tp_group()
|
||||||
|
|
||||||
# Check memory for tensor parallelism
|
# Check memory for tensor parallelism
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
|
||||||
if min_per_gpu_memory < local_gpu_memory * 0.9:
|
if min_per_gpu_memory < local_gpu_memory * 0.9:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
|
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
|
||||||
|
)
|
||||||
return min_per_gpu_memory
|
return min_per_gpu_memory
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
|
before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
@@ -382,11 +386,13 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
self.dtype = self.model_config.dtype
|
self.dtype = self.model_config.dtype
|
||||||
|
|
||||||
|
after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Load weight end. "
|
f"Load weight end. "
|
||||||
f"type={type(self.model).__name__}, "
|
f"type={type(self.model).__name__}, "
|
||||||
f"dtype={self.dtype}, "
|
f"dtype={self.dtype}, "
|
||||||
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
f"avail mem={after_avail_memory:.2f} GB, "
|
||||||
|
f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_weights_from_disk(
|
def update_weights_from_disk(
|
||||||
@@ -785,12 +791,15 @@ class ModelRunner:
|
|||||||
return
|
return
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
|
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
||||||
)
|
)
|
||||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||||
|
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
|
||||||
|
f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
|
||||||
)
|
)
|
||||||
|
|
||||||
def apply_torch_tp(self):
|
def apply_torch_tp(self):
|
||||||
@@ -806,8 +815,12 @@ class ModelRunner:
|
|||||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_extend(self, forward_batch: ForwardBatch):
|
def forward_extend(
|
||||||
self.attn_backend.init_forward_metadata(forward_batch)
|
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
|
||||||
|
):
|
||||||
|
if not skip_attn_backend_init:
|
||||||
|
self.attn_backend.init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
if forward_batch.input_embeds is None:
|
if forward_batch.input_embeds is None:
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
|
|||||||
@@ -818,8 +818,8 @@ def all_gather(
|
|||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
return input_tensor
|
return input_tensor
|
||||||
|
|
||||||
all_lens = forward_batch.global_num_tokens
|
all_lens = forward_batch.global_num_tokens_cpu
|
||||||
max_len = max(forward_batch.global_num_tokens)
|
max_len = max(forward_batch.global_num_tokens_cpu)
|
||||||
|
|
||||||
padded_tensor = torch.nn.functional.pad(
|
padded_tensor = torch.nn.functional.pad(
|
||||||
input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
|
input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
|
||||||
|
|||||||
@@ -741,13 +741,6 @@ def pytorch_profile(name, func, *args, data_size=-1):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def first_rank_print(*args, **kwargs):
|
|
||||||
if torch.cuda.current_device() == 0:
|
|
||||||
print(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def get_zmq_socket(
|
def get_zmq_socket(
|
||||||
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
|
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
|
||||||
):
|
):
|
||||||
@@ -1177,6 +1170,11 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
|
|||||||
return value.lower() in ("true", "1")
|
return value.lower() in ("true", "1")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=2)
|
||||||
|
def disable_request_logging() -> bool:
|
||||||
|
return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=8)
|
@lru_cache(maxsize=8)
|
||||||
def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
|
def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
|
||||||
# Note: cuda_visible_devices is not used, but we keep it as an argument for
|
# Note: cuda_visible_devices is not used, but we keep it as an argument for
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ nvcc_flags = [
|
|||||||
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1",
|
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1",
|
||||||
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
|
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
|
||||||
"--ptxas-options=-v",
|
"--ptxas-options=-v",
|
||||||
|
"--expt-relaxed-constexpr",
|
||||||
"-Xcompiler=-Wconversion",
|
"-Xcompiler=-Wconversion",
|
||||||
"-Xcompiler=-fno-strict-aliasing",
|
"-Xcompiler=-fno-strict-aliasing",
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user