diff --git a/benchmark/latency_throughput/bench_one.py b/benchmark/latency_throughput/bench_one.py index d67bfe960..5b2fa3cbc 100644 --- a/benchmark/latency_throughput/bench_one.py +++ b/benchmark/latency_throughput/bench_one.py @@ -92,4 +92,4 @@ if __name__ == "__main__": print(ret) speed = args.batch_size * max_new_tokens / latency - print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s") \ No newline at end of file + print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s") diff --git a/benchmark/latency_throughput/bench_serving.py b/benchmark/latency_throughput/bench_serving.py index cbe63a55b..1adb78958 100644 --- a/benchmark/latency_throughput/bench_serving.py +++ b/benchmark/latency_throughput/bench_serving.py @@ -307,8 +307,9 @@ def main(args: argparse.Namespace): avg_per_output_token_latency = np.mean( [latency / output_len for _, output_len, latency in REQUEST_LATENCY] ) - decoding_throughput = np.sum([ - output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time + decoding_throughput = ( + np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time + ) print(f"Total time: {benchmark_time:.2f} s") print(f"Request throughput: {args.num_prompts / benchmark_time:.2f} requests/s") diff --git a/benchmark/line_retrieval/gen_data.py b/benchmark/line_retrieval/gen_data.py index 5763e6615..c88ecba49 100644 --- a/benchmark/line_retrieval/gen_data.py +++ b/benchmark/line_retrieval/gen_data.py @@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio): ) for i in redirect_indices: target_idx = np.random.choice(min(i * 2 + 100, num_lines)) - lines[ - i - ] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}." + lines[i] = ( + f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}." 
+ ) redirects[i] = target_idx # Build links and find sources diff --git a/benchmark/mmlu/bench_sglang.py b/benchmark/mmlu/bench_sglang.py index b7b9267fa..00176343c 100644 --- a/benchmark/mmlu/bench_sglang.py +++ b/benchmark/mmlu/bench_sglang.py @@ -80,10 +80,12 @@ def main(args): for i in range(test_df.shape[0]): prompt_end = format_example(test_df, i, include_answer=False) - arguments.append({ - "examples": few_shot_examples, - "question": prompt_end, - }) + arguments.append( + { + "examples": few_shot_examples, + "question": prompt_end, + } + ) label = test_df.iloc[i, test_df.shape[1] - 1] labels.append(label) @@ -134,7 +136,9 @@ def main(args): pt = 0 for subject, num_qs in zip(subjects[: args.nsub], num_questions): - print(f"subject: {subject}, #q:{num_qs}, acc: {np.mean(cors[pt: pt + num_qs]):.3f}") + print( + f"subject: {subject}, #q:{num_qs}, acc: {np.mean(cors[pt: pt + num_qs]):.3f}" + ) pt += num_qs assert pt == len(cors) weighted_acc = np.mean(cors) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 14a2f1824..485242781 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -108,7 +108,7 @@ def prepare_inputs(bench_args, tokenizer): for i in range(len(prompts)): assert len(input_ids[i]) > bench_args.cut_len - tmp_input_ids = input_ids[i][:bench_args.cut_len] + tmp_input_ids = input_ids[i][: bench_args.cut_len] req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids) req.prefix_indices = [] req.sampling_params = sampling_params @@ -121,9 +121,9 @@ def prepare_inputs(bench_args, tokenizer): def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner): for i in range(len(reqs)): req = reqs[i] - req.input_ids += input_ids[i][bench_args.cut_len:] + req.input_ids += input_ids[i][bench_args.cut_len :] req.prefix_indices = model_runner.req_to_token_pool.req_to_token[ - i, :bench_args.cut_len + i, : bench_args.cut_len ] return reqs @@ -151,7 +151,8 @@ def extend(reqs, model_runner): reqs=reqs, req_to_token_pool=model_runner.req_to_token_pool, token_to_kv_pool=model_runner.token_to_kv_pool, - tree_cache=None) + tree_cache=None, + ) batch.prepare_for_extend(model_runner.model_config.vocab_size, None) output = model_runner.forward(batch, ForwardMode.EXTEND) next_token_ids, _ = batch.sample(output.next_token_logits) @@ -212,7 +213,9 @@ def latency_test( # Load the model model_runner, tokenizer = load_model(server_args, tp_rank) - print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}") + print( + f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}" + ) # Prepare inputs reqs = prepare_synthetic_inputs(bench_args, tokenizer) @@ -232,7 +235,9 @@ def latency_test( prefill_latency = time.time() - tic tot_latency += prefill_latency throughput = bench_args.input_len * bench_args.batch_size / prefill_latency - rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s") + rank_print( + f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s" + ) # Decode for i in range(output_len): @@ -243,13 +248,24 @@ def latency_test( latency = time.time() - tic tot_latency += latency throughput = bench_args.batch_size / latency - if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s") + if i < 5: + rank_print( + f"Decode. 
latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s" + ) avg_decode_latency = (tot_latency - prefill_latency) / output_len avg_decode_throughput = bench_args.batch_size / avg_decode_latency - rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s") - - throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency - rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s") + rank_print( + f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s" + ) + + throughput = ( + (bench_args.input_len + bench_args.output_len) + * bench_args.batch_size + / tot_latency + ) + rank_print( + f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s" + ) # Warm up run_once(4) @@ -298,4 +314,4 @@ if __name__ == "__main__": format="%(message)s", ) - main(server_args, bench_args) \ No newline at end of file + main(server_args, bench_args) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 377bde82e..00340b59a 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -39,4 +39,5 @@ class GlobalConfig: # This can improve the speed for large batch sizes during prefill. self.layer_sync_threshold = 8192 + global_config = GlobalConfig() diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index eb7d87c24..ad2e9fb2b 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -185,8 +185,10 @@ class SglFunction: batch_kwargs = [ {self.arg_names[i]: v for i, v in enumerate(arg_values)} for arg_values in batch_kwargs - if isinstance(arg_values, (list, tuple)) and - len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names) + if isinstance(arg_values, (list, tuple)) + and len(self.arg_names) - len(self.arg_defaults) + <= len(arg_values) + <= len(self.arg_names) ] # Ensure to raise an exception if the number of arguments mismatch if len(batch_kwargs) != num_programs: diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index 6cc85fd64..b6d5a7358 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -5,13 +5,14 @@ from pydantic import BaseModel try: from outlines.caching import cache as disk_cache - from outlines.fsm.guide import RegexGuide from outlines.caching import disable_cache from outlines.fsm.guide import RegexGuide from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm from outlines.models.transformers import TransformerTokenizer except ImportError as e: - print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n') + print( + f'\nError: {e}. 
Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n' + ) raise try: diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 734601031..218af433c 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -264,7 +264,9 @@ class TiktokenTokenizer: return self.tokenizer.decode_batch(batch) def apply_chat_template(self, messages, tokenize, add_generation_prompt): - ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt) + ret = self.chat_template.render( + messages=messages, add_generation_prompt=add_generation_prompt + ) return self.encode(ret) if tokenize else ret @@ -297,5 +299,7 @@ class SentencePieceTokenizer: return self.tokenizer.decode(batch) def apply_chat_template(self, messages, tokenize, add_generation_prompt): - ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt) - return self.encode(ret) if tokenize else ret \ No newline at end of file + ret = self.chat_template.render( + messages=messages, add_generation_prompt=add_generation_prompt + ) + return self.encode(ret) if tokenize else ret diff --git a/python/sglang/srt/layers/fused_moe.py b/python/sglang/srt/layers/fused_moe.py index bfd6cc666..7dddabb05 100644 --- a/python/sglang/srt/layers/fused_moe.py +++ b/python/sglang/srt/layers/fused_moe.py @@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple import torch import triton import triton.language as tl - from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -108,12 +107,16 @@ def fused_moe_kernel( offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + - offs_k[None, :] * stride_ak) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) off_experts = tl.load(expert_ids_ptr + pid_m) - b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + - offs_bn[None, :] * stride_bn) + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) if use_fp8: a_scale = tl.load(a_scale_ptr) @@ -129,13 +132,12 @@ def fused_moe_kernel( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. - a = tl.load(a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. 
if use_fp8: accumulator = tl.dot(a, b, acc=accumulator) @@ -146,9 +148,7 @@ def fused_moe_kernel( b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, - mask=token_mask, - other=0) + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] if use_fp8: @@ -158,15 +158,14 @@ def fused_moe_kernel( # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, - num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + topk_ids: torch.Tensor, block_size: int, num_experts: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. @@ -205,32 +204,38 @@ def moe_align_block_size( by block_size for proper block matrix operations. """ max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - sorted_ids = torch.empty((max_num_tokens_padded, ), - dtype=torch.int32, - device=topk_ids.device) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - expert_ids = torch.empty((max_num_m_blocks, ), - dtype=torch.int32, - device=topk_ids.device) - num_tokens_post_pad = torch.empty((1), - dtype=torch.int32, - device=topk_ids.device) - ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) + ops.moe_align_block_size( + topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad + ) return sorted_ids, expert_ids, num_tokens_post_pad -def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, top_k: int, - config: Dict[str, Any], compute_type: tl.dtype, - use_fp8: bool) -> None: +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: Dict[str, Any], + compute_type: tl.dtype, + use_fp8: bool, +) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -241,8 +246,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None - grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], 
META['BLOCK_SIZE_N']), ) + grid = lambda META: ( + triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), + ) fused_moe_kernel[grid]( A, @@ -280,8 +287,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: @functools.lru_cache -def get_moe_configs(E: int, N: int, - dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -296,11 +302,11 @@ def get_moe_configs(E: int, N: int, json_file_name = get_config_file_name(E, N, dtype) config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) if os.path.exists(config_file_path): with open(config_file_path) as f: - logger.info("Using configuration from %s for MoE layer.", - config_file_path) + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} @@ -319,35 +325,35 @@ def get_default_config( ) -> Dict[str, int]: if dtype == "float8": config = { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 4 + "num_stages": 4, } if M <= E: config = { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 4, } else: config = { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, } if M <= E: config = { - 'BLOCK_SIZE_M': 16, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 1 + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, } return config @@ -358,23 +364,17 @@ def fused_topk( topk: int, renormalize: bool, ): - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" M, _ = hidden_states.shape - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) + token_expert_indicies = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) ops.topk_softmax( topk_weights, topk_ids, @@ -388,27 +388,27 @@ def fused_topk( return topk_weights, topk_ids -def fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None): +def 
fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] M, _ = hidden_states.shape E, N, _ = w1.shape @@ -417,8 +417,7 @@ def fused_experts(hidden_states: torch.Tensor, config = override_config else: # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2], - "float8" if use_fp8 else None) + configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) if configs: # If an optimal configuration map has been found, look up the @@ -426,65 +425,76 @@ def fused_experts(hidden_states: torch.Tensor, config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1.shape[2], - topk_ids.shape[1], - "float8" if use_fp8 else None) + config = get_default_config( + M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None + ) - intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config['BLOCK_SIZE_M'], E) - compute_type = (tl.bfloat16 - if hidden_states.dtype == torch.bfloat16 else tl.float16) + topk_ids, config["BLOCK_SIZE_M"], E + ) + compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 - invoke_fused_moe_kernel(hidden_states, - w1, - intermediate_cache1, - a1_scale, - w1_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=compute_type, - use_fp8=use_fp8) + invoke_fused_moe_kernel( + hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_kernel(intermediate_cache2, - w2, - 
intermediate_cache3, - a2_scale, - w2_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8=use_fp8) + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) if inplace: - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=hidden_states) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) + return torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states, + ) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) def fused_moe( @@ -532,25 +542,28 @@ def fused_moe( assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" if hasattr(ops, "topk_softmax"): - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) + topk_weights, topk_ids = fused_topk( + hidden_states, gating_output, topk, renormalize + ) else: - topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output, topk, - renormalize) - - return fused_experts(hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - override_config=override_config, - use_fp8=use_fp8, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale) + topk_weights, topk_ids = fused_topk_v0_4_3( + hidden_states, gating_output, topk, renormalize + ) + return fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + override_config=override_config, + use_fp8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) def fused_topk_v0_4_3( @@ -560,6 +573,7 @@ def fused_topk_v0_4_3( renormalize: bool, ): import vllm._moe_C as moe_kernels + M, _ = hidden_states.shape topk_weights = torch.empty( @@ -579,4 +593,4 @@ def fused_topk_v0_4_3( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids \ No newline at end of file + return topk_weights, topk_ids diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 25493eef5..b9ea6ad85 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -1,4 +1,5 @@ """Radix attention.""" + import numpy as np import torch from torch import nn @@ -11,8 +12,13 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada class RadixAttention(nn.Module): def __init__( - self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int, - layer_id: int, logit_cap: int = -1 + self, + num_heads: int, + head_dim: int, + scaling: float, + num_kv_heads: int, + layer_id: int, + logit_cap: int = -1, ): super().__init__() self.tp_q_head_num = num_heads @@ -112,6 +118,7 @@ class RadixAttention(nn.Module): ) from flashinfer.cascade import merge_state + o, _ = merge_state(o1, s1, o2, s2) if input_metadata.total_num_tokens >= global_config.layer_sync_threshold: diff --git a/python/sglang/srt/managers/controller/manager_single.py b/python/sglang/srt/managers/controller/manager_single.py index bed97c391..4c2720733 100644 --- a/python/sglang/srt/managers/controller/manager_single.py +++ b/python/sglang/srt/managers/controller/manager_single.py @@ -99,4 +99,4 @@ 
def start_controller_process( except Exception: logger.error("Exception in ControllerSingle:\n" + get_exception_traceback()) finally: - kill_parent_process() \ No newline at end of file + kill_parent_process() diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py index bded85af9..879f44151 100644 --- a/python/sglang/srt/managers/controller/model_runner.py +++ b/python/sglang/srt/managers/controller/model_runner.py @@ -127,7 +127,7 @@ class InputMetadata: num_qo_heads, num_kv_heads, head_dim, - 1 + 1, ) else: self.flashinfer_decode_wrapper.end_forward() @@ -140,7 +140,7 @@ class InputMetadata: head_dim, 1, pos_encoding_mode="NONE", - data_type=self.token_to_kv_pool.kv_data[0].dtype + data_type=self.token_to_kv_pool.kv_data[0].dtype, ) def init_extend_args(self): @@ -228,7 +228,7 @@ class InputMetadata: ret.init_flashinfer_args( model_runner.model_config.num_attention_heads // tp_size, model_runner.model_config.get_num_kv_heads(tp_size), - model_runner.model_config.head_dim + model_runner.model_config.head_dim, ) return ret @@ -269,7 +269,7 @@ class ModelRunner: world_size=self.tp_size, rank=self.tp_rank, local_rank=self.gpu_id, - distributed_init_method=nccl_init_method + distributed_init_method=nccl_init_method, ) initialize_model_parallel(tensor_model_parallel_size=self.tp_size) total_gpu_memory = get_available_gpu_memory( @@ -341,7 +341,13 @@ class ModelRunner: ) head_dim = self.model_config.head_dim head_num = self.model_config.get_num_kv_heads(self.tp_size) - cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype) + cell_size = ( + head_num + * head_dim + * self.model_config.num_hidden_layers + * 2 + * torch._utils._element_size(self.dtype) + ) rest_memory = available_gpu_memory - total_gpu_memory * ( 1 - self.mem_fraction_static ) @@ -384,15 +390,16 @@ class ModelRunner: def init_flash_infer(self): if not global_server_args_dict.get("disable_flashinfer", False): from flashinfer import ( - BatchPrefillWithRaggedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper, BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, ) from flashinfer.decode import _grouped_size_compiled_for_decode_kernels if not _grouped_size_compiled_for_decode_kernels( self.model_config.num_attention_heads // self.tp_size, - self.model_config.get_num_kv_heads(self.tp_size)): + self.model_config.get_num_kv_heads(self.tp_size), + ): use_tensor_cores = True else: use_tensor_cores = False @@ -400,8 +407,8 @@ class ModelRunner: workspace_buffers = torch.empty( 3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda" ) - self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( - workspace_buffers[0], "NHD" + self.flashinfer_prefill_wrapper_ragged = ( + BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD") ) self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( workspace_buffers[1], "NHD" @@ -410,7 +417,9 @@ class ModelRunner: workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores ) else: - self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None + self.flashinfer_prefill_wrapper_ragged = ( + self.flashinfer_prefill_wrapper_paged + ) = None self.flashinfer_decode_wrapper = None @torch.inference_mode() diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index a788118ec..6d92c6bab 
100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -34,11 +34,11 @@ from sglang.srt.managers.io_struct import ( from sglang.srt.model_config import ModelConfig from sglang.srt.server_args import ModelPortArgs, ServerArgs from sglang.srt.utils import ( + connect_rpyc_service, get_int_token_logit_bias, is_multimodal_model, set_random_seed, start_rpyc_service_process, - connect_rpyc_service, suppress_other_loggers, ) from sglang.utils import get_exception_traceback @@ -368,9 +368,11 @@ class ModelTpServer: if ( req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens < available_size - and (req.extend_input_len + new_batch_input_tokens - <= self.max_prefill_tokens - or len(can_run_list) == 0) + and ( + req.extend_input_len + new_batch_input_tokens + <= self.max_prefill_tokens + or len(can_run_list) == 0 + ) ): delta = self.tree_cache.inc_lock_ref(req.last_node) available_size += delta @@ -452,7 +454,9 @@ class ModelTpServer: next_token_ids, ].tolist() output.prefill_token_logprobs = output.prefill_token_logprobs.tolist() - output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist() + output.normalized_prompt_logprobs = ( + output.normalized_prompt_logprobs.tolist() + ) next_token_ids = next_token_ids.tolist() else: @@ -582,7 +586,9 @@ class ModelTpServer: req.check_finished() if req.return_logprob: - req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id)) + req.decode_token_logprobs.append( + (next_token_logprobs[i], next_token_id) + ) if req.top_logprobs_num > 0: req.decode_top_logprobs.append(output.decode_top_logprobs[i]) @@ -759,16 +765,27 @@ class ModelTpClient: with ThreadPoolExecutor(self.tp_size) as executor: # Launch model processes if server_args.nnodes == 1: - self.procs = list(executor.map( - lambda args: start_rpyc_service_process(*args), - [(ModelTpService, p) for p in model_port_args.model_tp_ports], - )) + self.procs = list( + executor.map( + lambda args: start_rpyc_service_process(*args), + [ + (ModelTpService, p) + for p in model_port_args.model_tp_ports + ], + ) + ) addrs = [("localhost", p) for p in model_port_args.model_tp_ports] else: - addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)] + addrs = [ + (ip, port) + for ip, port in zip( + model_port_args.model_tp_ips, model_port_args.model_tp_ports + ) + ] - self.model_services = list(executor.map( - lambda args: connect_rpyc_service(*args), addrs)) + self.model_services = list( + executor.map(lambda args: connect_rpyc_service(*args), addrs) + ) # Init model def init_model(i): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 0d137eb8a..42f970370 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -334,15 +334,15 @@ class TokenizerManager: ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs ) if top_logprobs_num > 0: - ret["meta_info"][ - "prefill_top_logprobs" - ] = self.detokenize_top_logprobs_tokens( - ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs + ret["meta_info"]["prefill_top_logprobs"] = ( + self.detokenize_top_logprobs_tokens( + ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs + ) ) - ret["meta_info"][ - "decode_top_logprobs" - ] = self.detokenize_top_logprobs_tokens( - ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs + 
ret["meta_info"]["decode_top_logprobs"] = ( + self.detokenize_top_logprobs_tokens( + ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs + ) ) return ret diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 3a54d25c2..2d92b53c9 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -5,19 +5,23 @@ from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn from transformers import Gemma2Config - from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size + +# FIXME: temporary solution, remove after next vllm release +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import GeluAndMul + # from vllm.model_executor.layers.layernorm import GemmaRMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -26,8 +30,6 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.controller.model_runner import InputMetadata -# FIXME: temporary solution, remove after next vllm release -from vllm.model_executor.custom_op import CustomOp class GemmaRMSNorm(CustomOp): """RMS normalization for Gemma. @@ -76,13 +78,19 @@ class GemmaRMSNorm(CustomOp): # FIXME: temporary solution, remove after next vllm release from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + + class GemmaRotaryEmbedding(RotaryEmbedding): def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107 - inv_freq = 1.0 / (base**( - torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() / - self.rotary_dim)) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() + / self.rotary_dim + ) + ) return inv_freq @@ -98,18 +106,17 @@ class Gemma2MLP(nn.Module): ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"): raise ValueError( "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation " "function. Please set `hidden_act` and `hidden_activation` to " - "`gelu_pytorch_tanh`.") + "`gelu_pytorch_tanh`." 
+ ) self.act_fn = GeluAndMul(approximate="tanh") def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -121,17 +128,19 @@ class Gemma2MLP(nn.Module): class Gemma2Attention(nn.Module): - def __init__(self, - layer_idx: int, - config: Gemma2Config, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - max_position_embeddings: int, - rope_theta: float, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__( + self, + layer_idx: int, + config: Gemma2Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + rope_theta: float, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: super().__init__() self.layer_idx = layer_idx self.config = config @@ -183,15 +192,16 @@ class Gemma2Attention(nn.Module): # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every # odd layer, vLLM currently ignores it and uses global attention for # all layers. - use_sliding_window = (layer_idx % 2 == 1 - and config.sliding_window is not None) + use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None del use_sliding_window # Unused. - self.attn = RadixAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - layer_id=layer_idx, - logit_cap=self.config.attn_logit_softcapping) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_idx, + logit_cap=self.config.attn_logit_softcapping, + ) def forward( self, @@ -238,14 +248,16 @@ class Gemma2DecoderLayer(nn.Module): hidden_activation=config.hidden_activation, quant_config=quant_config, ) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -258,8 +270,7 @@ class Gemma2DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -268,7 +279,8 @@ class Gemma2DecoderLayer(nn.Module): hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, residual = self.pre_feedforward_layernorm( - hidden_states, residual) + hidden_states, residual + ) hidden_states = self.mlp(hidden_states) hidden_states = self.post_feedforward_layernorm(hidden_states) return hidden_states, residual @@ -289,10 +301,12 @@ class Gemma2Model(nn.Module): config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config) - for layer_idx in range(config.num_hidden_layers) - 
]) + self.layers = nn.ModuleList( + [ + Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config) + for layer_idx in range(config.num_hidden_layers) + ] + ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Normalize the embedding by sqrt(hidden_size) @@ -392,7 +406,7 @@ class Gemma2ForCausalLM(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: - for (param_name, shard_name, shard_id) in stacked_params_mapping: + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue name = name.replace(shard_name, param_name) @@ -412,8 +426,7 @@ class Gemma2ForCausalLM(nn.Module): if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -421,7 +434,8 @@ class Gemma2ForCausalLM(nn.Module): if unloaded_params: raise RuntimeError( "Some weights are not initialized from checkpoints: " - f"{unloaded_params}") + f"{unloaded_params}" + ) -EntryClass = Gemma2ForCausalLM \ No newline at end of file +EntryClass = Gemma2ForCausalLM diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index cd053cf66..eb9dde45c 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -5,14 +5,12 @@ import tqdm from torch import nn from transformers import LlamaConfig from vllm.config import CacheConfig -from vllm.distributed import ( - get_tensor_model_parallel_rank, -) +from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.layers.logits_processor import LogitProcessorOutput +from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.models.llama2 import LlamaModel @@ -28,7 +26,9 @@ class LlamaForClassification(nn.Module): self.quant_config = quant_config self.model = LlamaModel(config, quant_config=quant_config) - self.classification_head = nn.Linear(config.hidden_size, config.classification_out_size) + self.classification_head = nn.Linear( + config.hidden_size, config.classification_out_size + ) self.eos_token_id = config.eos_token_id def forward( @@ -45,7 +45,9 @@ class LlamaForClassification(nn.Module): if scores.shape[0] != input_metadata.batch_size: print("Warning: the EOS tokens are missing in some sentences.") - scores = torch.ones((input_metadata.batch_size, self.config.classification_out_size)).to(input_ids.device) + scores = torch.ones( + (input_metadata.batch_size, self.config.classification_out_size) + ).to(input_ids.device) return LogitProcessorOutput( next_token_logits=scores, @@ -101,4 +103,5 @@ class LlamaForClassification(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) -EntryClass = LlamaForClassification \ No newline at end of file + +EntryClass = LlamaForClassification diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 045609b34..eb37c7bb5 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -51,13 +51,12 @@ 
from sglang.srt.utils import ( allocate_init_ports, assert_pkg_version, enable_show_time_cost, - send_addrs_to_rank_0, receive_addrs, + send_addrs_to_rank_0, start_rpyc_service_process, ) from sglang.utils import get_exception_traceback - logger = logging.getLogger(__name__) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -152,9 +151,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg if server_args.disable_disk_cache: disable_cache() if not server_args.disable_flashinfer: - assert_pkg_version("flashinfer", "0.0.8", "Please uninstall the old version and " - "reinstall the latest version by following the instructions " - "at https://docs.flashinfer.ai/installation.html.") + assert_pkg_version( + "flashinfer", + "0.0.8", + "Please uninstall the old version and " + "reinstall the latest version by following the instructions " + "at https://docs.flashinfer.ai/installation.html.", + ) if server_args.chat_template: # TODO: replace this with huggingface transformers template load_chat_template_for_openai_api(server_args.chat_template) @@ -176,7 +179,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg ModelPortArgs( nccl_port=ports[3 + i * (tp_size_local + 1)], model_tp_ips=[None] * tp_size_local, - model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)], + model_tp_ports=ports[ + 3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1) + ], ) ) port_args = PortArgs( @@ -194,9 +199,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg else: receive_addrs(model_port_args[0], server_args) for i in range(tp_size_local): - start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i]) + start_rpyc_service_process( + ModelTpService, model_port_args[0].model_tp_ports[i] + ) if server_args.node_rank != 0: - logger.info(f"[node_rank={server_args.node_rank}]: Listen for connections...") + logger.info( + f"[node_rank={server_args.node_rank}]: Listen for connections..." + ) while True: pass diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a1b8014bc..14cf4d3b0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -137,17 +137,16 @@ class ServerArgs: "--dtype", type=str, default=ServerArgs.dtype, - choices=[ - "auto", "half", "float16", "bfloat16", "float", "float32" - ], - help='Data type for model weights and activations.\n\n' + choices=["auto", "half", "float16", "bfloat16", "float", "float32"], + help="Data type for model weights and activations.\n\n" '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - 'BF16 precision for BF16 models.\n' + "BF16 precision for BF16 models.\n" '* "half" for FP16. Recommended for AWQ quantization.\n' '* "float16" is the same as "half".\n' '* "bfloat16" for a balance between precision and range.\n' '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.') + '* "float32" for FP32 precision.', + ) parser.add_argument( "--trust-remote-code", action="store_true", @@ -271,19 +270,12 @@ class ServerArgs: parser.add_argument( "--nccl-init-addr", type=str, - help="The nccl init address of multi-node server." + help="The nccl init address of multi-node server.", ) parser.add_argument( - "--nnodes", - type=int, - default=1, - help="The number of nodes." - ) - parser.add_argument( - "--node-rank", - type=int, - help="The node rank." + "--nnodes", type=int, default=1, help="The number of nodes." 
) + parser.add_argument("--node-rank", type=int, help="The node rank.") # Optimization/debug options parser.add_argument( diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 6da86cbeb..a9ea62e4b 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -432,13 +432,12 @@ def assert_pkg_version(pkg: str, min_version: str, message: str): if pkg_version.parse(installed_version) < pkg_version.parse(min_version): raise Exception( f"{pkg} is installed with version {installed_version}, which " - f"is less than the minimum required version {min_version}. " + - message + f"is less than the minimum required version {min_version}. " + message ) except PackageNotFoundError: raise Exception( - f"{pkg} with minimum required version {min_version} is not installed. " + - message + f"{pkg} with minimum required version {min_version} is not installed. " + + message ) @@ -474,24 +473,40 @@ def monkey_patch_vllm_dummy_weight_loader(): """ from vllm.model_executor.model_loader.loader import ( - ModelConfig, DeviceConfig, LoRAConfig, VisionLanguageConfig, - ParallelConfig, SchedulerConfig, CacheConfig, nn, - set_default_torch_dtype, _initialize_model, initialize_dummy_weights, - DummyModelLoader + CacheConfig, + DeviceConfig, + DummyModelLoader, + LoRAConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VisionLanguageConfig, + _initialize_model, + initialize_dummy_weights, + nn, + set_default_torch_dtype, ) - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config, - cache_config) + model = _initialize_model( + model_config, + self.load_config, + lora_config, + vision_language_config, + cache_config, + ) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) @@ -541,7 +556,7 @@ def get_ip_address(ifname): ip_address = fcntl.ioctl( s.fileno(), 0x8915, # SIOCGIFADDR - struct.pack('256s', bytes(ifname[:15], 'utf-8')) + struct.pack("256s", bytes(ifname[:15], "utf-8")), )[20:24] return socket.inet_ntoa(ip_address) @@ -550,44 +565,66 @@ def send_addrs_to_rank_0(model_port_args, server_args): assert server_args.node_rank != 0 and server_args.dp_size == 1 import torch.distributed as dist - ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")) + ifname = os.environ.get( + "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0") + ) ip_addr = get_ip_address(ifname) num_tp_ports = server_args.tp_size // server_args.nnodes model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports ip_addr = [int(x) for x in ip_addr.split(".")] - addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int) + addrs_tensor = torch.tensor( + ip_addr + model_port_args.model_tp_ports, dtype=torch.int + ) init_method = f"tcp://{server_args.nccl_init_addr}" 
- dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes) + dist.init_process_group( + backend="gloo", + init_method=init_method, + rank=server_args.node_rank, + world_size=server_args.nnodes, + ) dist.send(addrs_tensor, dst=0) - print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}") + print( + f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}" + ) dist.barrier() - dist.destroy_process_group() + dist.destroy_process_group() def receive_addrs(model_port_args, server_args): assert server_args.node_rank == 0 and server_args.dp_size == 1 import torch.distributed as dist - ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")) + ifname = os.environ.get( + "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0") + ) ip_addr = get_ip_address(ifname) num_tp_ports = server_args.tp_size // server_args.nnodes model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports init_method = f"tcp://{server_args.nccl_init_addr}" - dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes) + dist.init_process_group( + backend="gloo", + init_method=init_method, + rank=server_args.node_rank, + world_size=server_args.nnodes, + ) for src_rank in range(1, server_args.nnodes): tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int) dist.recv(tensor, src=src_rank) ip = ".".join([str(x) for x in tensor[:4].tolist()]) ports = tensor[4:].tolist() - model_port_args.model_tp_ips[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports - model_port_args.model_tp_ports[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = ports + model_port_args.model_tp_ips[ + num_tp_ports * src_rank : num_tp_ports * (src_rank + 1) + ] = [ip] * num_tp_ports + model_port_args.model_tp_ports[ + num_tp_ports * src_rank : num_tp_ports * (src_rank + 1) + ] = ports print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}") dist.barrier() - dist.destroy_process_group() + dist.destroy_process_group() diff --git a/test/srt/test_httpserver_classify.py b/test/srt/test_httpserver_classify.py index 40da2b749..cafbd19fd 100644 --- a/test/srt/test_httpserver_classify.py +++ b/test/srt/test_httpserver_classify.py @@ -37,10 +37,12 @@ def get_logits_batch(url, prompts): }, ) ret = response.json() - logits = np.array(list( - ret[i]["meta_info"]["normalized_prompt_logprob"] - for i in range(len(prompts)) - )) + logits = np.array( + list( + ret[i]["meta_info"]["normalized_prompt_logprob"] + for i in range(len(prompts)) + ) + ) return logits @@ -64,4 +66,4 @@ if __name__ == "__main__": "This is a long long long long test prompt.<|eot_id|>", ] logits = get_logits_batch(url, prompts) - print(f"{logits=}") \ No newline at end of file + print(f"{logits=}")
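Reviewer note (not part of the patch): the bench_serving.py hunk near the top only reflows the summary-metric math, but the formulas are easy to miss in the diff noise. Below is a minimal standalone sketch of how those metrics are derived, assuming each REQUEST_LATENCY entry is a (prompt_len, output_len, latency_seconds) tuple, as the list comprehensions in the hunk imply; the helper name `summarize` and the sample records are illustrative only.

```python
# Sketch only: how the serving-benchmark summary metrics are computed from
# per-request measurements. Assumes records of (prompt_len, output_len, latency_s).
import numpy as np


def summarize(request_latency, benchmark_time, num_prompts):
    # Mean end-to-end latency per request.
    avg_latency = np.mean([latency for _, _, latency in request_latency])
    # Mean latency normalized by the number of generated tokens.
    avg_per_output_token_latency = np.mean(
        [latency / output_len for _, output_len, latency in request_latency]
    )
    # Decoding throughput: total generated tokens over the whole run.
    decoding_throughput = (
        np.sum([output_len for _, output_len, _ in request_latency]) / benchmark_time
    )
    return {
        "request_throughput_req_per_s": num_prompts / benchmark_time,
        "avg_latency_s": avg_latency,
        "avg_per_output_token_latency_s": avg_per_output_token_latency,
        "decoding_throughput_token_per_s": decoding_throughput,
    }


if __name__ == "__main__":
    # Three hypothetical requests: (prompt_len, output_len, latency_s).
    records = [(32, 128, 2.0), (64, 256, 3.5), (16, 64, 1.2)]
    print(summarize(records, benchmark_time=4.0, num_prompts=len(records)))
```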
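Likewise, the fused_moe.py hunks reformat but do not change the kernel-config selection in fused_experts(): a tuned config table keyed by batch size M is consulted first, and the entry whose key is closest to the actual M wins; otherwise get_default_config() falls back to smaller tiles when M <= E. A hedged sketch of that lookup follows; the helper name `pick_config` and the example table are illustrative, not part of the patch.

```python
# Sketch only (not the vLLM/sglang implementation): the config-selection logic
# that fused_experts() follows when choosing Triton tile sizes.
from typing import Any, Dict, Optional


def pick_config(
    configs: Optional[Dict[int, Dict[str, Any]]], M: int, E: int
) -> Dict[str, Any]:
    if configs:
        # Use the tuned config whose batch-size key is nearest to M.
        return configs[min(configs.keys(), key=lambda x: abs(x - M))]
    # Fallback mirrors get_default_config(): smaller tiles for tiny batches.
    if M <= E:
        return {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1}
    return {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}


# Example: with tuned entries for M in {1, 16, 64}, a batch of 24 tokens
# picks the M=16 entry.
tuned = {1: {"BLOCK_SIZE_M": 16}, 16: {"BLOCK_SIZE_M": 32}, 64: {"BLOCK_SIZE_M": 64}}
assert pick_config(tuned, M=24, E=8) == {"BLOCK_SIZE_M": 32}
```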