diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 790e9a858..dcee80b89 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -1,7 +1,7 @@ # Supported Models ## Generative Models -- Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2 / Llama 3.3 +- Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2 / Llama 3.3 / Llama 4 - Mistral / Mixtral / Mistral NeMo / Mistral Small 3 - Gemma / Gemma 2 / Gemma3 - Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL / Qwen 2.5 VL / Olympic Coder diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 8554d28d0..36128bab1 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -294,6 +294,30 @@ register_chat_template( ) ) +# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json +register_chat_template( + ChatTemplate( + name="llama-4", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ( + "<|header_start|>system<|header_end|>\n\n", + "<|eot|>", + ), + "user": ( + "<|header_start|>user<|header_end|>\n\n", + "<|eot|>", + ), + "assistant": ( + "<|header_start|>assistant<|header_end|>\n\n", + "<|eot|>", + ), + }, + stop_str=("<|eot|>",), + image_token="<|image|>", + ) +) + # Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1 register_chat_template( ChatTemplate( diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index bb16e2747..654095c85 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -65,6 +65,9 @@ class ModelConfig: **kwargs, ) self.hf_text_config = get_hf_text_config(self.hf_config) + self.attention_chunk_size = getattr( + self.hf_text_config, "attention_chunk_size", None + ) # Check model type self.is_generation = is_generation_model( @@ -467,6 +470,7 
@@ multimodal_model_archs = [ "Gemma3ForConditionalGeneration", "Grok1VForCausalLM", "Grok1AForCausalLM", + # TODO: add multimodal support for "Llama4ForConditionalGeneration", "LlavaLlamaForCausalLM", "LlavaMistralForCausalLM", "LlavaQwenForCausalLM", diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 70152a6b7..8e9129224 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -33,6 +33,7 @@ class SeparatorStyle(IntEnum): ADD_NEW_LINE_SINGLE = auto() LLAMA2 = auto() LLAMA3 = auto() + LLAMA4 = auto() CHATGLM = auto() CHATML = auto() CHATINTERN = auto() @@ -156,19 +157,30 @@ class Conversation: else: ret += role + ":" return ret - elif self.sep_style == SeparatorStyle.LLAMA3: - ret = "<|begin_of_text|>" + elif self.sep_style == SeparatorStyle.LLAMA4: + # begin_of_text is added by default if self.system_message: - ret += system_prompt + ret = system_prompt else: - ret += "" + ret = "" + for i, (role, message) in enumerate(self.messages): + if message: + ret += f"<|header_start|>{role}<|header_end|>\n\n" + ret += f"{message.strip()}<|eot|>" + else: + ret += f"<|header_start|>{role}<|header_end|>\n\n" + return ret + elif self.sep_style == SeparatorStyle.LLAMA3: + if self.system_message: + ret = system_prompt + else: + ret = "" for i, (role, message) in enumerate(self.messages): if message: ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" ret += f"{message.strip()}<|eot_id|>" else: ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" - # print(ret) return ret elif self.sep_style == SeparatorStyle.LLAMA2: seps = [self.sep, self.sep2] @@ -561,6 +573,19 @@ register_conv_template( ) ) +# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json +register_conv_template( + Conversation( + name="llama-4", + system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>", + roles=("user", "assistant"), + 
sep_style=SeparatorStyle.LLAMA4, + sep="", + stop_str=["<|end_of_text|>", "<|eot|>", "<|eom|>"], + image_token="<|image|>", + ) +) + register_conv_template( Conversation( name="chatml", diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 051603a87..62604fe56 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -1,5 +1,7 @@ from __future__ import annotations +import numpy as np + from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput """ @@ -45,6 +47,206 @@ class FlashAttentionMetadata: # Sequence lengths for the forward batch cache_seqlens_int32: torch.Tensor = None + @dataclass + class LocalAttentionMetadata: + local_query_start_loc: torch.Tensor = None # cu_seqlens_q for local attention + local_seqused_k: torch.Tensor = None # sequence lengths for local attention + local_block_table: torch.Tensor = None # block table for local attention + local_max_query_len: int = 0 # max query length for local attention + local_max_seq_len: int = 0 # max sequence length for local attention + + local_attn_metadata: Optional[LocalAttentionMetadata] = None + + +# Copied from: +# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py +# +# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into +# local attention blocks, where each block is passed to the attention kernel +# as an independent local ("virtual") batch item. 
+# +# For example, if we are performing a chunked prefill of a batch of 3 sequences: +# q_seqlens = [4, 10, 5] +# kv_seqlens = [6, 17, 9] +# Then normally for regular attention we would compute with an attention mask +# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like: +# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6) +# k_toks > 0 1 2 3 4 5 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# 2 | 1 1 1 1 1 +# 3 | 1 1 1 1 1 1 +# +# for local attention (with attn_chunk_size = 4) we would compute with an +# attention mask like: +# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4) +# k_toks > 0 1 2 3 4 5 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# 2 | 1 +# 3 | 1 1 +# +# We can simulate this mask using standard flash-attention by breaking the +# sequences into local ("virtual") batches, where each local batch item is a +# local attention block, so in this case batch idx 0 would be broken up into: +# +# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0) +# k_toks > 0 1 2 3 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0) +# k_toks > 4 5 +# q_toks v _____________ +# 2 | 1 +# 3 | 1 1 +# +# e.g. 
if we have: +# attn_chunk_size = 4 +# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5]) +# Then this function would return: +# __b0__ ______b1______ __b2__ < orig batch indices +# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1] +# cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19] +# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1] +# block_table_local : shape[local_virtual_batches, pages_per_local_batch] +def make_local_attention_virtual_batches( + attn_chunk_size: int, + query_start_loc_np: np.ndarray, + seq_lens_np: np.ndarray, + block_table: torch.Tensor, + page_size: int = 0, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: + """ + Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into + local attention blocks, where each block is passed to the attention kernel + as an independent local ("virtual") batch item. + + Args: + attn_chunk_size: Size of local attention chunks + query_start_loc_np: Cumulative sum of query lengths (numpy array) + seq_lens_np: Sequence lengths (numpy array) + block_table: Block table for KV cache + page_size: Size of each page in the KV cache + + Returns: + seqlens_q_local: Query sequence lengths for local attention + cu_seqlens_q_local: Cumulative sum of query sequence lengths for local attention + seqlens_k_local: Key sequence lengths for local attention + block_table_local: Block table for local attention + """ + q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] + actual_batch_size = seq_lens_np.shape[0] + + # Handle if we are starting in the middle of a local attention block, + # we assume q_seqlens > 0 (for all elements), for each batch idx we compute + # the number of tokens that are not in the first local attention block and + # then we can simply use a cdiv for the rest. 
+ # For example if we have: + # attn_chunk_size = 4 + # q_seqlens = [4, 10, 5] + # k_seqlens = [6, 17, 9] + # Then we would get: + # q_tokens_in_first_block = [2, 1, 4] + # local_blocks = [2, 4, 2] + q_tokens_in_first_block = np.minimum( + attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens + ).astype(np.int32) + tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) + local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size) + + # Once we know the number of local blocks we can compute the request spans + # for each batch idx, we can figure out the number of "virtual" requests we + # have to make, + # For the above example we would get: + # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] + # + # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) + # (TODO: make a utility to share this code with _prepare_inputs) + # arange step 1. [2, 4, 2] -> [2, 6, 8] + cu_num_blocks = np.cumsum(local_blocks) + virtual_batches = cu_num_blocks[-1] + # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] + block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) + # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1] + arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets + # also compute reverse arange (i.e. 
[1, 0, 3, 2, 1, 0, 1, 0]) + rarange = np.repeat(local_blocks, local_blocks) - arange - 1 + # Then we can compute the seqlens_q_local, handling the fact that the + # first and last blocks could be partial + seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) + # set the first block since this may be a partial block + seqlens_q_local[arange == 0] = q_tokens_in_first_block + # set the remaining blocks + seqlens_q_local[arange > 0] = np.minimum( + seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size + )[arange > 0] + + # convert from q_seqlens to cu_seqlens_q + cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32) + + # compute the seqlens_k_local, + # basically a full local attention block for all but the last block in each + # batch + # For our example this will be: + # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] + seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32) + seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block + + k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - ( + rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks) + ) + # For the example the local attention blocks start at: + # _b0_ _____b1_____ _b2_ + # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] + block_starts = k_seqstarts_absolute // page_size + + assert attn_chunk_size % page_size == 0, ( + f"attn_chunk_size {attn_chunk_size} is not " + f"divisible by page_size {page_size}" + ) + pages_per_local_batch = attn_chunk_size // page_size + + # Create a block_table for the local attention blocks + # For our example if we have a block-table like (assuming page_size=2): + # block_table = [ + # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 + # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 + # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 + # ] + # Then for the local batches we would want a block-table like + # block_table_local = [ + # [ 0, 1 ], < local-batch 0, 
(batch 0, starting from k[0]) + # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) + # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) + # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) + # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) + # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) + # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) + # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) + # ] + block_indices = np.broadcast_to( + np.arange(pages_per_local_batch, dtype=np.int32), + (virtual_batches, pages_per_local_batch), + ) + np.expand_dims(block_starts, axis=1) + block_indices = block_indices.flatten() + batch_indices = np.repeat( + np.arange(actual_batch_size, dtype=np.int32), + local_blocks * pages_per_local_batch, + ) + block_table_local = block_table[batch_indices, block_indices].view( + virtual_batches, -1 + ) + + return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local + + +def cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return -(a // -b) + class FlashAttentionBackend(AttentionBackend): """FlashAttention backend implementation. 
@@ -100,6 +302,13 @@ class FlashAttentionBackend(AttentionBackend): self.step_id = step_id self.speculative_num_steps = speculative_num_steps + # Local attention settings + self.attention_chunk_size = ( + model_runner.attention_chunk_size + if hasattr(model_runner, "attention_chunk_size") + else None + ) + def init_forward_metadata(self, forward_batch: ForwardBatch): """Initialize forward metadata to cache repetitive calculations.""" metadata = FlashAttentionMetadata() @@ -189,6 +398,7 @@ class FlashAttentionBackend(AttentionBackend): metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k ] + # Precompute cumulative sequence lengths if ( any(forward_batch.extend_prefix_lens_cpu) @@ -203,6 +413,51 @@ class FlashAttentionBackend(AttentionBackend): metadata.cu_seqlens_q = metadata.cu_seqlens_k metadata.max_seq_len_q = metadata.max_seq_len_k + # Setup local attention if enabled + if ( + self.attention_chunk_size is not None + and forward_batch.forward_mode == ForwardMode.EXTEND + ): + # Convert tensors to numpy for local attention processing + cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy() + seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy() + + # Adjust attention_chunk_size based on the actual sequence length + # to avoid index out of bounds errors + max_seq_len = seq_lens_np.max() + effective_chunk_size = min(self.attention_chunk_size, max_seq_len) + # Make sure effective_chunk_size is divisible by page_size + effective_chunk_size = ( + effective_chunk_size // self.page_size + ) * self.page_size + if effective_chunk_size < self.page_size: + effective_chunk_size = self.page_size + + # Create local attention metadata + ( + seqlens_q_local_np, + cu_seqlens_q_local_np, + seqlens_k_local_np, + block_table_local, + ) = make_local_attention_virtual_batches( + effective_chunk_size, + cu_seqlens_q_np, + seq_lens_np, + metadata.page_table, + self.page_size, + ) + + local_metadata = 
FlashAttentionMetadata.LocalAttentionMetadata( + local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to( + device + ), + local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device), + local_block_table=block_table_local, + local_max_query_len=seqlens_q_local_np.max(), + local_max_seq_len=seqlens_k_local_np.max(), + ) + metadata.local_attn_metadata = local_metadata + # Precompute strided indices if self.page_size > 1: self.strided_indices = torch.arange( @@ -211,6 +466,7 @@ class FlashAttentionBackend(AttentionBackend): metadata.page_table = ( metadata.page_table[:, self.strided_indices] // self.page_size ) + self.forward_metadata = metadata def forward_extend( @@ -254,7 +510,28 @@ class FlashAttentionBackend(AttentionBackend): else (-1, -1) ) - page_table = metadata.page_table + # Check if we should use local attention + use_local_attn = ( + self.attention_chunk_size is not None + and metadata.local_attn_metadata is not None + and (hasattr(layer, "use_irope") and layer.use_irope) + ) + + # Get the appropriate page table based on whether we're using local attention + if use_local_attn: + local_metadata = metadata.local_attn_metadata + page_table = local_metadata.local_block_table + cu_seqlens_q = local_metadata.local_query_start_loc + cache_seqlens = local_metadata.local_seqused_k + max_seqlen_q = local_metadata.local_max_query_len + max_seqlen_k = local_metadata.local_max_seq_len + else: + page_table = metadata.page_table + cu_seqlens_q = metadata.cu_seqlens_q + cache_seqlens = metadata.cache_seqlens_int32 + max_seqlen_q = metadata.max_seq_len_q + max_seqlen_k = metadata.max_seq_len_k + cu_seqlens_k = metadata.cu_seqlens_k # Use Flash Attention for prefill if not self.use_mla: @@ -272,10 +549,10 @@ class FlashAttentionBackend(AttentionBackend): k_cache=key_cache, v_cache=value_cache, page_table=page_table, - cache_seqlens=metadata.cache_seqlens_int32, - cu_seqlens_q=metadata.cu_seqlens_q, - cu_seqlens_k_new=metadata.cu_seqlens_k, - 
max_seqlen_q=metadata.max_seq_len_q, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, + max_seqlen_q=max_seqlen_q, softmax_scale=layer.scaling, causal=True, window_size=window_size, @@ -307,10 +584,10 @@ class FlashAttentionBackend(AttentionBackend): v_cache=c_kv_cache, qv=q_nope, page_table=page_table, - cache_seqlens=metadata.cache_seqlens_int32, - cu_seqlens_q=metadata.cu_seqlens_q, - cu_seqlens_k_new=metadata.cu_seqlens_k, - max_seqlen_q=metadata.max_seq_len_q, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, + max_seqlen_q=max_seqlen_q, softmax_scale=layer.scaling, causal=True, softcap=layer.logit_cap, diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py index 1766e2c25..57e910943 100644 --- a/python/sglang/srt/layers/moe/fused_moe_native.py +++ b/python/sglang/srt/layers/moe/fused_moe_native.py @@ -23,9 +23,14 @@ def fused_moe_forward_native( custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, ) -> torch.Tensor: + + if apply_router_weight_on_input: + raise NotImplementedError + topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..249359fb9 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, 
+ "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..587fb2f2e --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..089894816 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 
8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json new file mode 100644 index 000000000..9814a3819 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..7e09be1ff --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..d471f9821 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..fe8dab46e --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { 
+ "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + 
"BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 7c9ead9ce..588173cec 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -1079,6 +1079,7 @@ def inplace_fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -1099,6 +1100,7 @@ def inplace_fused_experts( topk_ids, True, activation, + apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, @@ -1120,6 +1122,7 @@ def inplace_fused_experts_fake( topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -1150,6 +1153,7 @@ def outplace_fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -1171,6 +1175,7 @@ def outplace_fused_experts( topk_ids, False, activation, + apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, @@ -1193,6 +1198,7 @@ def outplace_fused_experts_fake( topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -1225,6 +1231,7 @@ def fused_experts( topk_ids: torch.Tensor, inplace: bool = False, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: 
bool = False, @@ -1247,6 +1254,7 @@ def fused_experts( topk_weights, topk_ids, activation, + apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, @@ -1268,6 +1276,7 @@ def fused_experts( topk_weights, topk_ids, activation, + apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, @@ -1291,6 +1300,7 @@ def fused_experts_impl( topk_ids: torch.Tensor, inplace: bool = False, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -1423,7 +1433,7 @@ def fused_experts_impl( sorted_token_ids, expert_ids, num_tokens_post_padded, - False, + apply_router_weight_on_input, topk_ids.shape[1], config, compute_type=compute_type, @@ -1456,7 +1466,7 @@ def fused_experts_impl( ( intermediate_cache3 if not no_combine and topk_ids.shape[1] != 1 - else out_hidden_states[begin_chunk_idx:end_chunk_idx] + else out_hidden_states[begin_chunk_idx:end_chunk_idx].unsqueeze(0) ), a2_scale, w2_scale, @@ -1466,7 +1476,7 @@ def fused_experts_impl( sorted_token_ids, expert_ids, num_tokens_post_padded, - True, + not apply_router_weight_on_input, 1, config, compute_type=compute_type, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index da16e1680..a33cf691f 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -128,6 +128,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, ) -> torch.Tensor: @@ -143,6 +144,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function=custom_routing_function, correction_bias=correction_bias, 
activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, inplace=inplace, no_combine=no_combine, ) @@ -160,6 +162,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, ) -> torch.Tensor: @@ -200,6 +203,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): topk_ids=topk_ids, inplace=inplace and not no_combine, activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, no_combine=no_combine, ) @@ -276,6 +280,7 @@ class FusedMoE(torch.nn.Module): custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_presharded_weights: bool = False, inplace: bool = True, no_combine: bool = False, @@ -302,6 +307,7 @@ class FusedMoE(torch.nn.Module): self.custom_routing_function = custom_routing_function self.correction_bias = correction_bias self.activation = activation + self.apply_router_weight_on_input = apply_router_weight_on_input self.use_presharded_weights = use_presharded_weights self.inplace = inplace self.no_combine = no_combine @@ -630,6 +636,7 @@ class FusedMoE(torch.nn.Module): custom_routing_function=self.custom_routing_function, correction_bias=self.correction_bias, activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, ) if self.reduce_results and self.tp_size > 1: diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 3152e265f..b2375bc35 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -280,6 +280,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"): 
custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, ): diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py index a5f15c92b..c147d5b2f 100644 --- a/python/sglang/srt/layers/quantization/blockwise_int8.py +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -370,6 +370,7 @@ class BlockInt8MoEMethod: custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, ) -> torch.Tensor: @@ -398,6 +399,7 @@ class BlockInt8MoEMethod: topk_ids=topk_ids, inplace=inplace, activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, use_int8_w8a8=True, w1_scale=(layer.w13_weight_scale_inv), w2_scale=(layer.w2_weight_scale_inv), diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 7e668fb95..f977d899b 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -905,6 +905,7 @@ class Fp8MoEMethod: custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, ) -> torch.Tensor: @@ -975,6 +976,7 @@ class Fp8MoEMethod: topk_ids=topk_ids, inplace=inplace and not no_combine, activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=True, w1_scale=( layer.w13_weight_scale_inv diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py index b99016d95..4c3e1dfc7 100644 --- 
a/python/sglang/srt/layers/quantization/moe_wna16.py +++ b/python/sglang/srt/layers/quantization/moe_wna16.py @@ -344,6 +344,7 @@ class MoeWNA16Method: custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, ) -> torch.Tensor: @@ -374,6 +375,7 @@ class MoeWNA16Method: topk_weights=topk_weights, topk_ids=topk_ids, inplace=inplace, + apply_router_weight_on_input=apply_router_weight_on_input, use_int4_w4a16=weight_bits == 4, use_int8_w8a16=weight_bits == 8, w1_scale=layer.w13_scales, diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 6467aca5f..280a9a249 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -230,6 +230,7 @@ class W8A8Int8MoEMethod: custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, ) -> torch.Tensor: @@ -257,6 +258,7 @@ class W8A8Int8MoEMethod: topk_ids=topk_ids, inplace=inplace, activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, use_int8_w8a8=True, w1_scale=(layer.w13_weight_scale), w2_scale=(layer.w2_weight_scale), diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 507f6e950..69c105997 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -35,6 +35,7 @@ class RadixAttention(nn.Module): sliding_window_size: int = -1, is_cross_attention: bool = False, prefix: str = "", + use_irope: bool = False, ): super().__init__() self.tp_q_head_num = num_heads @@ -50,6 +51,7 @@ class RadixAttention(nn.Module): self.is_cross_attention = is_cross_attention 
self.k_scale = None self.v_scale = None + self.use_irope = use_irope def forward( self, diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index fb6bdd76b..aadaf4e3e 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -733,6 +733,69 @@ class Llama3RotaryEmbedding(RotaryEmbedding): return new_freqs +class Llama4VisionRotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ): + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + inv_freqs = inv_freqs[: (self.rotary_dim // 2)] + return inv_freqs + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + + # self.max_position_embeddings here is number of image patches + # i.e. 
(image_size // patch_size) ** 2 + num_patches = self.max_position_embeddings + img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1) + img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) + img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN + num_patches_single_dim = int(math.sqrt(num_patches)) + frequencies_x = img_idx % num_patches_single_dim + frequencies_y = img_idx // num_patches_single_dim + freqs_x = ( + (frequencies_x + 1)[..., None] * inv_freq[None, None, :] + ).repeat_interleave(2, dim=-1) + freqs_y = ( + (frequencies_y + 1)[..., None] * inv_freq[None, None, :] + ).repeat_interleave(2, dim=-1) + freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2] + freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) + cache = torch.view_as_complex( + torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) + ) + return cache + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) + query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2)) + key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2)) + broadcast_shape = [ + d if i == 1 or i == (query_.ndim - 1) else 1 + for i, d in enumerate(query_.shape) + ] + freqs_ci = self.cos_sin_cache.view(*broadcast_shape) + query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) + key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) + return query_out.type_as(query), key_out.type_as(key) + + class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" diff --git a/python/sglang/srt/managers/multimodal_processors/mllama4.py b/python/sglang/srt/managers/multimodal_processors/mllama4.py new file mode 100644 index 000000000..41b6f3835 --- /dev/null +++ b/python/sglang/srt/managers/multimodal_processors/mllama4.py @@ -0,0 +1,161 @@ +from typing import List, Mapping, Optional, Tuple, Union + 
+import torch +from PIL import Image +from transformers import Llama4Processor +from transformers.image_utils import SizeDict +from transformers.models.llama4.image_processing_llama4 import ( + find_supported_resolutions, + get_best_fit, +) + +from sglang.srt.managers.multimodal_processors.base_processor import ( + BaseMultimodalProcessor, + MultimodalSpecialTokens, +) +from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem +from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration +from sglang.srt.utils import load_image + + +class Mllama4ImageProcessor(BaseMultimodalProcessor): + models = [Llama4ForConditionalGeneration] + + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + self.vision_config = hf_config.vision_config + self.text_config = hf_config.text_config + self.multimodal_tokens = MultimodalSpecialTokens( + image_token=_processor.image_token + ) + + async def process_mm_data_async( + self, + image_data: List[Union[str, bytes]], + input_text, + max_req_input_len=None, + *args, + **kwargs, + ): + if not image_data: + return None + + if isinstance(input_text, list): + assert len(input_text) and isinstance(input_text[0], int) + input_text = self._processor.tokenizer.decode(input_text) + + # Process images and text using the base processor's load_mm_data method + processed_data = self.load_mm_data( + prompt=input_text, + multimodal_tokens=self.multimodal_tokens, + max_req_input_len=max_req_input_len or 4096, + image_data=image_data, + return_text=True, + ) + + # Process the images using the processor + processor = Llama4Processor.from_pretrained( + self.server_args.model_path, **kwargs + ) + + # Process the prompt and images + image_inputs = processor( + text=processed_data.input_text, + images=processed_data.images, + return_tensors="pt", + ) + + # Handle image resolutions and aspect ratios + if "pixel_values" in image_inputs: + image_processor = 
processor.image_processor + tokenizer = self._processor.tokenizer + + # Calculate tile size and find supported resolutions + tile_size = self.vision_config.image_size + max_num_tiles = getattr(self.vision_config, "max_patches", 1) + + possible_resolutions = find_supported_resolutions( + max_num_chunks=max_num_tiles, + patch_size=SizeDict(height=tile_size, width=tile_size), + ) + + # Find best fit for each image + best_fit_sizes = [ + get_best_fit( + (image.size[1], image.size[0]), # (height, width) + torch.tensor(possible_resolutions), + resize_to_max_canvas=image_processor.resize_to_max_canvas, + ) + for image in processed_data.images + ] + + # Calculate aspect ratios and patches per image + aspect_ratios = [ + (image_size[0] // tile_size, image_size[1] // tile_size) + for image_size in best_fit_sizes + ] + + patches_per_image = [ + 1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios + ] + + # Add to image_inputs + image_inputs["aspect_ratios"] = aspect_ratios + image_inputs["patches_per_image"] = torch.tensor(patches_per_image) + + # Process embed_is_patch + vocab = tokenizer.get_vocab() + patch_id = vocab.get(processor.img_patch_token, -1) + image_end_id = vocab.get(processor.end_of_img_token, -1) + + if patch_id != -1 and image_end_id != -1: + input_ids = image_inputs["input_ids"].view(-1) + + # Remove BOS token if present + if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id: + input_ids = input_ids[1:] + + # Find image end indices and split input_ids + image_end_indices = (input_ids == image_end_id).nonzero().view(-1) + + if image_end_indices.size(0) > 0: + # Split at image boundaries + split_indices = (image_end_indices + 1)[:-1] + split_input_ids = torch.tensor_split(input_ids, split_indices) + split_input_ids = [x for x in split_input_ids if x.numel() > 0] + + # Create embed_is_patch for each image + embed_is_patch = [] + for per_image_input_ids in split_input_ids: + embed_is_patch.append(per_image_input_ids == patch_id) 
+ + image_inputs["embed_is_patch"] = embed_is_patch + + # Convert to the format expected by SGLang + image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] + + # Add metadata for image processing + image_inputs["mm_items"] = [ + MultimodalDataItem( + pixel_values=image_inputs["pixel_values"], + modality=Modality.IMAGE, + # Add additional metadata needed for Llama4 vision processing + embed_is_patch=image_inputs.get("embed_is_patch", None), + aspect_ratios=image_inputs.get("aspect_ratios", None), + patches_per_image=image_inputs.get("patches_per_image", None), + ) + ] + + return image_inputs + + def get_patch_per_chunk(self): + """Calculate patches per chunk based on vision config""" + image_size = self.vision_config.image_size + patch_size = self.vision_config.patch_size + + assert ( + image_size % patch_size == 0 + ), f"chunk size {image_size} should be multiple of patch_size {patch_size}" + + ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2))) + return (image_size // patch_size) ** 2 // ds_ratio diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 65a3d5fed..8506485fe 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -128,6 +128,7 @@ class ModelRunner: self.model_config.attention_arch == AttentionArch.MLA and not server_args.disable_mla ) + self.attention_chunk_size = model_config.attention_chunk_size # Model-specific adjustment self.model_specific_adjustment() diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 3ddf2f479..d53100d2c 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -63,6 +63,7 @@ class LlamaMLP(nn.Module): hidden_act: str, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + reduce_results: bool = True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -78,6 
+79,7 @@ class LlamaMLP(nn.Module): bias=False, quant_config=quant_config, prefix=add_prefix("down_proj", prefix), + reduce_results=reduce_results, ) if hidden_act != "silu": raise ValueError( @@ -281,7 +283,7 @@ class LlamaModel(nn.Module): self.layers = make_layers( config.num_hidden_layers, lambda idx, prefix: LlamaDecoderLayer( - config=config, quant_config=quant_config, layer_id=idx, prefix=prefix + config=config, layer_id=idx, quant_config=quant_config, prefix=prefix ), prefix="model.layers", ) @@ -375,9 +377,7 @@ class LlamaForCausalLM(nn.Module): super().__init__() self.config = config self.quant_config = quant_config - self.model = LlamaModel( - config, quant_config=quant_config, prefix=add_prefix("model", prefix) - ) + self.model = self._init_model(config, quant_config, add_prefix("model", prefix)) # Llama 3.2 1B Instruct set tie_word_embeddings to True # Llama 3.1 8B Instruct set tie_word_embeddings to False if self.config.tie_word_embeddings: @@ -402,6 +402,14 @@ class LlamaForCausalLM(nn.Module): self.capture_aux_hidden_states = False + def _init_model( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + return LlamaModel(config, quant_config=quant_config, prefix=prefix) + @torch.no_grad() def forward( self, diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py new file mode 100644 index 000000000..b3dbd50f0 --- /dev/null +++ b/python/sglang/srt/models/llama4.py @@ -0,0 +1,420 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Adapted from +# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py +"""Inference-only LLaMA model compatible with HuggingFace weights.""" + +import logging +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import Llama4TextConfig + +from sglang.srt.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP +from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers + +logger = logging.getLogger(__name__) + + +class Llama4MoE(nn.Module): + + @torch.compile(dynamic=True, backend=get_compiler_backend()) + @staticmethod + def custom_routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1) + router_scores_aK = torch.sigmoid(router_scores_aK.float()).to( + hidden_states.dtype + ) + return ( + router_scores_aK.view(-1).reshape(router_scores_aK.shape), + router_indices_aK.to(torch.int32), + ) + + def __init__( + self, + config: 
Llama4TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.top_k = config.num_experts_per_tok + + intermediate_size_moe = config.intermediate_size + self.router = ReplicatedLinear( + config.hidden_size, + config.num_local_experts, + bias=False, + quant_config=None, + prefix=add_prefix("router", prefix), + ) + + self.experts = FusedMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + custom_routing_function=Llama4MoE.custom_routing_function, + intermediate_size=intermediate_size_moe, + reduce_results=False, + renormalize=False, + quant_config=quant_config, + apply_router_weight_on_input=True, + prefix=add_prefix("experts", prefix), + ) + + self.shared_expert = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size_moe, + hidden_act="silu", + quant_config=quant_config, + prefix=add_prefix("shared_expert", prefix), + reduce_results=False, # We need to do scatter before reduce + ) + + def forward(self, hidden_states): + # router_scores: [num_tokens, num_experts] + router_logits, _ = self.router(hidden_states) + shared_out = self.shared_expert(hidden_states) + routed_out = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + ) + out_aD = routed_out + shared_out + + if self.tp_size > 1: + out_aD = tensor_model_parallel_all_reduce(out_aD) + + return out_aD + + +class Llama4Attention(nn.Module): + + def __init__( + self, + config: Llama4TextConfig, + layer_id: int, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_id = layer_id + self.hidden_size = 
hidden_size + self.use_rope = int((layer_id + 1) % 4 != 0) + self.use_qk_norm = config.use_qk_norm and self.use_rope + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.attn_temperature_tuning = config.attn_temperature_tuning + self.floor_scale = config.floor_scale + self.attn_scale = config.attn_scale + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.n_rep = self.num_heads // self.num_kv_heads + self.qk_norm = ( + RMSNorm( + hidden_size=self.head_dim, + eps=config.rms_norm_eps, + ) + if self.use_qk_norm + else None + ) + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + ) + is_neox_style = True + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type in ["llama", "llama4"]: + is_neox_style = 
False + + self.rotary_emb = ( + get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=int(rope_theta), + rope_scaling=rope_scaling if rope_scaling != "default" else None, + is_neox_style=is_neox_style, + ) + if self.use_rope + else None + ) + + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + prefix=add_prefix("attn", prefix), + use_irope=self.use_rope, + ) + + def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: + floor = torch.floor((positions + 1.0) / self.floor_scale) + attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0 + + return attn_scale.unsqueeze(-1) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + if self.rotary_emb is not None: + q, k = self.rotary_emb(positions, q, k) + + if self.qk_norm is not None: + # TODO: support float + q = q.reshape(-1, self.head_dim).contiguous().bfloat16() + k = k.reshape(-1, self.head_dim).contiguous().bfloat16() + q = self.qk_norm(q).to(q.dtype) + k = self.qk_norm(k).to(k.dtype) + q = q.reshape(-1, self.q_size) + k = k.reshape(-1, self.kv_size) + + # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where + # the inference-time temperature tuning function is customized to not affect short context + # while working at very long context + # https://arxiv.org/abs/2501.19399 + if self.attn_temperature_tuning and not self.use_rope: + attn_scale = self._get_attn_scale(positions) + q = (q * attn_scale).to(q.dtype) + + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class Llama4DecoderLayer(nn.Module): + def __init__( + self, + config: Llama4TextConfig, + layer_id: int = 0, + 
quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.layer_id = layer_id + self.hidden_size = config.hidden_size + rope_theta = config.rope_theta + rope_scaling = config.rope_scaling + max_position_embeddings = config.max_position_embeddings + + self.self_attn = Llama4Attention( + config=config, + layer_id=layer_id, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=False, + bias_o_proj=False, + prefix=add_prefix("self_attn", prefix), + ) + is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0 + if is_moe_layer: + self.feed_forward = Llama4MoE( + config=config, + quant_config=quant_config, + prefix=add_prefix("feed_forward", prefix), + ) + else: + self.feed_forward = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size_mlp, + hidden_act="silu", + quant_config=quant_config, + prefix=add_prefix("feed_forward", prefix), + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return 
hidden_states, residual + + +class Llama4Model(nn.Module): + def __init__( + self, + config: Llama4TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("embed_tokens", prefix), + ) + self.layers = make_layers( + config.num_hidden_layers, + lambda idx, prefix: Llama4DecoderLayer( + config=config, layer_id=idx, quant_config=quant_config, prefix=prefix + ), + prefix="model.layers", + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layers_to_capture = [] + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None + aux_hidden_states = [] + for i in range(len(self.layers)): + if i in self.layers_to_capture: + aux_hidden_states.append(hidden_states + residual) + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) == 0: + return hidden_states + + return hidden_states, aux_hidden_states + + +class Llama4ForCausalLM(LlamaForCausalLM): + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__( + self, + config: Llama4TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__(config, quant_config, prefix) + + def _init_model( + self, + config: Llama4TextConfig, + quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", + ): + return Llama4Model(config, quant_config=quant_config, prefix=prefix) + + +EntryClass = [Llama4ForCausalLM] diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py new file mode 100644 index 000000000..f254903a2 --- /dev/null +++ b/python/sglang/srt/models/mllama4.py @@ -0,0 +1,154 @@ +# TODO: add Adapted from vllm/mllama4.py +from collections.abc import Iterable +from typing import Optional, Set, Tuple + +import torch +from torch import nn +from transformers import Llama4Config + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix + + +class Llama4ForConditionalGeneration(nn.Module): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + } + + def __init__( + self, + config: Llama4Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.quant_config = quant_config + + # Initialize the language model + from sglang.srt.models.llama4 import Llama4ForCausalLM + + self.language_model = Llama4ForCausalLM( + config.text_config, + quant_config=quant_config, + prefix=add_prefix("language_model", prefix), + ) + + self.logits_processor = LogitsProcessor(config.text_config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs: object, + ) -> torch.Tensor: + + return self.language_model(input_ids, positions, forward_batch) + + def permute_qk_weight_for_rotary( + self, + name: str, + loaded_weight: torch.Tensor, + ) -> Tuple[str, torch.Tensor]: + + def permute(w: torch.Tensor, n_heads: int): + attn_in = self.language_model.config.head_dim * n_heads + attn_out =
self.language_model.config.hidden_size + + return ( + w.view(n_heads, attn_in // n_heads // 2, 2, attn_out) + .transpose(1, 2) + .reshape(attn_in, attn_out) + ) + + modules = name.split(".") + + # rotary embeds should be sliced + if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight": + loaded_weight = permute( + loaded_weight, self.language_model.config.num_key_value_heads + ) + elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight": + loaded_weight = permute( + loaded_weight, self.language_model.config.num_attention_heads + ) + + return name, loaded_weight + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), + (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), + (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), + (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0), + (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1), + (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0), + (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + + num_experts = self.config.text_config.num_local_experts + + for name, loaded_weight in weights: + + if name.startswith("vision_model") or name.startswith( + "multi_modal_projector" + ): + continue + + name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight) + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if ".experts" in name: + if ".gate_up_proj" in name: + name_list = [ + name.replace(".experts.gate_up_proj", ".experts.w13_weight") + ] * 2 + loaded_weight_list = loaded_weight.chunk(2, dim=-1) + 
shard_id_list = ["w1", "w3"] + else: + name_list = [ + name.replace(".experts.down_proj", ".experts.w2_weight") + ] + shard_id_list = ["w2"] + loaded_weight_list = [loaded_weight] + for name, loaded_weight, shard_id in zip( + name_list, loaded_weight_list, shard_id_list + ): + param = params_dict[name] + weight_loader = param.weight_loader + for expert_id in range(num_experts): + weight_loader( + param, + loaded_weight[expert_id].T, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + + +EntryClass = Llama4ForConditionalGeneration