diff --git a/benchmark/latency_throughput/bench_throughput.py b/benchmark/latency_throughput/bench_throughput.py index f058ecad4..323a3f449 100644 --- a/benchmark/latency_throughput/bench_throughput.py +++ b/benchmark/latency_throughput/bench_throughput.py @@ -233,8 +233,10 @@ def main(args: argparse.Namespace): if args.dataset: input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) else: - input_lens = np.random.randint(args.input_len * args.range_ratio, args.input_len + 1, size=args.num_prompts) - output_lens = np.random.randint(args.output_len * args.range_ratio, args.output_len + 1, size=args.num_prompts) + input_lens = np.random.randint( + int(args.input_len * args.range_ratio), args.input_len + 1, size=args.num_prompts) + output_lens = np.random.randint( + int(args.output_len * args.range_ratio), args.output_len + 1, size=args.num_prompts) offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts) input_requests = [] for i in range(args.num_prompts): @@ -284,13 +286,13 @@ if __name__ == "__main__": choices=["vllm", "tgi", "srt", "lightllm", "xinfer"], ) parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--port", type=int, default=30000) parser.add_argument( "--dataset", type=str, help="Path to the dataset." ) - parser.add_argument("--input-len", type=str, default=1024) - parser.add_argument("--output-len", type=str, default=128) - parser.add_argument("--range-ratio", type=float, default=1.0) + parser.add_argument("--input-len", type=str, default=2048) + parser.add_argument("--output-len", type=str, default=256) + parser.add_argument("--range-ratio", type=float, default=0.5) parser.add_argument( "--tokenizer", type=str, default="NousResearch/Meta-Llama-3-8B", diff --git a/benchmark/mmlu/bench_other.py b/benchmark/mmlu/bench_other.py index c5d48dac6..1799744f0 100644 --- a/benchmark/mmlu/bench_other.py +++ b/benchmark/mmlu/bench_other.py @@ -170,4 +170,4 @@ if __name__ == "__main__": parser.add_argument("--data_dir", type=str, default="data") parser.add_argument("--nsub", type=int, default=60) args = add_common_other_args_and_parse(parser) - main(args) + main(args) \ No newline at end of file diff --git a/python/sglang/srt/layers/fused_moe.py b/python/sglang/srt/layers/fused_moe.py new file mode 100644 index 000000000..776194710 --- /dev/null +++ b/python/sglang/srt/layers/fused_moe.py @@ -0,0 +1,485 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1 +"""Fused MoE kernel.""" +import functools +import json +import os +from typing import Any, Dict, Optional, Tuple + +import torch +import triton +import triton.language as tl + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.utils import is_hip + +logger = init_logger(__name__) + + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). 
+ stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + + if use_fp8: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. 
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + # We accumulate along the K dimension. + if use_fp8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + if use_fp8: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def moe_align_block_size( + topk_ids: torch.Tensor, block_size: int, + num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block + size for matrix multiplication. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the + top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according + to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, + ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process + so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions + align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], + block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, + with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids + [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in + the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible + by block_size for proper block matrix operations. 
+ """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) + return sorted_ids, expert_ids, num_tokens_post_pad + + +def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, + config: Dict[str, Any], compute_type: tl.dtype, + use_fp8: bool) -> None: + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + if not use_fp8: + assert A_scale is None + assert B_scale is None + else: + A, A_scale = ops.scaled_fp8_quant(A, A_scale) + assert B_scale is not None + + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) + + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2], + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8=use_fp8, + **config, + ) + + +def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: + device_name = torch.cuda.get_device_name().replace(" ", "_") + dtype_selector = "" if not dtype else f",dtype={dtype}" + return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + + +@functools.lru_cache +def get_moe_configs(E: int, N: int, + dtype: Optional[str]) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the fused MoE kernel. + + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the fused_moe kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. 
+ """ + + # First look up if an optimized configuration is available in the configs + # directory + json_file_name = get_config_file_name(E, N, dtype) + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", + config_file_path) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + return None + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + M, _ = hidden_states.shape + E, N, _ = w1.shape + + if is_hip(): + # The MoE kernels are not yet supported on ROCm. + routing_weights = torch.softmax(gating_output, + dim=-1, + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) + else: + import vllm._moe_C as moe_kernels + + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + moe_kernels.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. 
+ if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if override_config: + config = override_config + else: + # First try to load optimal config from the file + configs = get_moe_configs(E, w2.shape[2], + "float8" if use_fp8 else None) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } + + if M <= E: + config = { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } + + intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config['BLOCK_SIZE_M'], E) + compute_type = (tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16) + + invoke_fused_moe_kernel(hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8) + + ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel(intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8) + + if inplace: + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) \ No newline at end of file diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py new file mode 100644 index 000000000..3aeb72850 --- /dev/null +++ b/python/sglang/srt/models/grok.py @@ -0,0 +1,669 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 +"""Inference-only Grok1 model.""" +from typing import Iterable, Optional, Tuple, List + +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +from torch import nn +from transformers import PretrainedConfig + +from vllm import _custom_ops as ops +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.loader import DefaultModelLoader 
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import print_warning_once + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.fused_moe import fused_moe +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.managers.router.model_runner import InputMetadata + + +use_fused = True + + +class Grok1MLP(nn.Module): + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.num_experts = num_experts + self.ffn_dim = intermediate_size + self.hidden_dim = hidden_size + + self.w1 = ReplicatedLinear( + self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config + ) + self.w2 = ReplicatedLinear( + self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config + ) + self.w3 = ReplicatedLinear( + self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config + ) + + self.act_fn = nn.GELU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + w1_out, _ = self.w1(hidden_states) + w1_out = self.act_fn(w1_out) + w3_out, _ = self.w3(hidden_states) + current_hidden_states = w1_out * w3_out + current_hidden_states, _ = self.w2(current_hidden_states) + return current_hidden_states + + +class Grok1MoEUnfused(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.num_total_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.num_total_experts}." 
+ ) + # Split experts equally between ranks + self.expert_indicies = np.array_split( + range(self.num_total_experts), self.tp_size + )[self.rank].tolist() + if not self.expert_indicies: + raise ValueError(f"Rank {self.rank} has no experts assigned to it.") + + self.experts = nn.ModuleList( + [ + ( + Grok1MLP( + self.num_total_experts, + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + ) + if idx in self.expert_indicies + else None + ) + for idx in range(self.num_total_experts) + ] + ) + self.gate = ReplicatedLinear( + config.hidden_size, self.num_total_experts, bias=False, quant_config=None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + router_logits, _ = self.gate(hidden_states) + router_logits = 30 * F.tanh(router_logits / 30) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights = routing_weights.to(hidden_states.dtype) + hidden_dim = hidden_states.shape[1] + + final_hidden_states = torch.zeros( + (hidden_states.shape[0], hidden_dim), + dtype=hidden_states.dtype, device=hidden_states.device + ) + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_total_experts).permute(2, 1, 0) + + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states) + + return tensor_model_parallel_all_reduce(final_hidden_states) + + +class Grok1MoE(nn.Module): + """A tensor-parallel MoE implementation for Grok1 that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + self.quant_config = quant_config + + # FIXME(pcmoritz): Make this more general to support different + # quantization schemes + self.use_fp8 = isinstance(quant_config, Fp8Config) + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + # Gate always runs at half / full precision for now. 
+ self.gate = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=None) + + if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + self.w13_weight = nn.Parameter( + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + dtype=params_dtype)) + self.w2_weight = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + dtype=params_dtype)) + + set_weight_attrs(self.w13_weight, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2_weight, { + "weight_loader": self.weight_loader, + }) + + # Used for fp8. + self.w13_scale = None + self.w2_scale = None + self.a13_scale = None + self.a2_scale = None + + if self.use_fp8: + # WEIGHT_SCALE (for fp8) + self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, + dtype=torch.float32), + requires_grad=False) + self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts, + dtype=torch.float32), + requires_grad=False) + + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(self.w13_scale, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2_scale, { + "weight_loader": self.weight_loader, + }) + + # ACT_SCALE (for fp8) + if quant_config.activation_scheme == "static": + if not quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8.") + self.a13_scale = nn.Parameter(torch.zeros( + self.num_total_experts, dtype=torch.float32), + requires_grad=False) + self.a2_scale = nn.Parameter(torch.zeros( + self.num_total_experts, dtype=torch.float32), + requires_grad=False) + + set_weight_attrs(self.a13_scale, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.a2_scale, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int, pre_sharded: bool): + param_data = param.data + shard_size = self.intermediate_size + if pre_sharded: + # The weight is already sharded. Readl the full shard + shard = slice(None) + else: + tp_rank = get_tensor_model_parallel_rank() + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + if "act_scale" in weight_name or "weight_scale" in weight_name: + param_data[expert_id] = loaded_weight + + def process_weights_after_loading(self): + # Fp8 is the only case where we need to process after loading. + if not self.use_fp8: + return + + # If checkpoint is fp16, quantize here. 
+ if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like(self.w13_weight.data, + dtype=torch.float8_e4m3fn) + w2_weight = torch.empty_like(self.w2_weight.data, + dtype=torch.float8_e4m3fn) + for expert in range(self.num_total_experts): + w13_weight[expert, :, :], self.w13_scale[ + expert] = ops.scaled_fp8_quant( + self.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], self.w2_scale[ + expert] = ops.scaled_fp8_quant( + self.w2_weight.data[expert, :, :]) + self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) + self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) + + # If checkpoint is fp8 + static, cleanup act_scales. + # Since state_dict has an act_scale per expert but our kernels + # are passed one act_scale shared across all experts. + elif self.quant_config.activation_scheme == "static": + if self.a13_scale is None or self.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + + if (not all_close_1d(self.a13_scale) + or not all_close_1d(self.a2_scale)): + print_warning_once( + "Found act_scales that are not equal for fp8 MoE layer. " + "Using the maximum across experts for each layer. ") + + self.a13_scale = nn.Parameter(self.a13_scale.max(), + requires_grad=False) + self.a2_scale = nn.Parameter(self.a2_scale.max(), + requires_grad=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.w13_weight, + self.w2_weight, + router_logits, + self.top_k, + renormalize=False, + inplace=True, + use_fp8=self.use_fp8, + w1_scale=self.w13_scale, + w2_scale=self.w2_scale, + a1_scale=self.a13_scale, + a2_scale=self.a2_scale) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_size) + + +class Grok1Attention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + logit_cap: float = 30, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = 128 + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + logit_cap=logit_cap, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Grok1DecoderLayer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = Grok1Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + quant_config=quant_config, + ) + if use_fused: + self.block_sparse_moe = Grok1MoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config) + else: + self.block_sparse_moe = Grok1MoEUnfused( + config=config, quant_config=quant_config) + self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + + hidden_states = self.post_attn_norm(self.self_attn( + positions=positions, hidden_states=self.pre_attn_norm(hidden_states), + input_metadata=input_metadata, + )) + hidden_states + + hidden_states = self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + hidden_states + + return hidden_states + + +class Grok1Model(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + Grok1DecoderLayer(config, i, quant_config=quant_config) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = 
RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + hidden_states.mul_(self.config.embedding_multiplier_scale) + + for i in range(len(self.layers)): + hidden_states = self.layers[i]( + positions, hidden_states, input_metadata + ) + + hidden_states = self.norm(hidden_states) + hidden_states.mul_(self.config.output_multiplier_scale) + return hidden_states + + +class Grok1ModelForCausalLM(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = Grok1Model(config, quant_config=quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config) + + # Monkey patch _prepare_weights to load pre-sharded weights + setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + if use_fused: + expert_params_mapping = [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id) + ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id) + ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the activation scales for the experts + # (param_name, weight_name, expert_id) + ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", + f"experts.{expert_id}.{weight_name}.act_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + else: + expert_params_mapping = [] + + params_dict = dict(self.named_parameters()) + if get_tensor_model_parallel_rank() == 0: + weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4)) + for name, loaded_weight in weights: + #print(get_tensor_model_parallel_rank(), name) + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id, + pre_sharded=get_tensor_model_parallel_world_size() > 1) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) + + +old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights") +def _prepare_presharded_weights(self, + model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: + import glob + import os + + if get_tensor_model_parallel_world_size() == 1: + return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt) + + tp_rank = get_tensor_model_parallel_rank() + allow_patterns = [f"*-{tp_rank:03d}.bin"] + + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + use_safetensors = False + + return hf_folder, hf_weights_files, use_safetensors + + +EntryClass = Grok1ModelForCausalLM \ No newline at end of file diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index f006ad6b1..c222f4378 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -11,13 +11,13 @@ from io import BytesIO from typing import List, Optional import numpy as np -import pydantic import requests import torch +import triton from fastapi.responses import JSONResponse from packaging import version as pkg_version -from pydantic import BaseModel from starlette.middleware.base import BaseHTTPMiddleware +import torch.distributed as dist logger = logging.getLogger(__name__) @@ -178,7 +178,8 @@ def get_int_token_logit_bias(tokenizer, vocab_size): def wrap_kernel_launcher(kernel): """A faster launcher for triton kernels.""" - import torch.distributed as dist + if int(triton.__version__.split(".")[0]) >= 3: + return None if dist.is_initialized(): rank = dist.get_rank() diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index d73c59219..18d0f6c32 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -95,6 +95,11 @@ def call_generate_xinfer(prompt, temperature, max_tokens, stop=None, url=None): sampler_channel = grpc.insecure_channel(url.replace("http://", "")) sampler = sampler_pb2_grpc.SamplerStub(sampler_channel) + if stop is None: + stop_strings = None + else: + stop_strings = [stop] + sample_request = sampler_pb2.SampleTextRequest( prompt=prompt, settings=sampler_pb2.SampleSettings( @@ -102,7 +107,7 @@ def call_generate_xinfer(prompt, temperature, max_tokens, stop=None, url=None): rng_seed=0, temperature=max(temperature, 1e-7), nucleus_p=1, - stop_strings=[stop], + stop_strings=stop_strings, ), ) stream = sampler.SampleText(sample_request)