Add Llama4 support (#5092)

Co-authored-by: Cheng Wan <cwan39@gatech.edu>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
Co-authored-by: ispobock <ispobaoke@163.com>
Author: Chang Su
Date: 2025-04-07 00:29:36 -07:00
Committed by: GitHub
Parent: d1bb171180
Commit: f04c80dc42
27 changed files with 2214 additions and 22 deletions

View File

@@ -294,6 +294,30 @@ register_chat_template(
)
)
# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
register_chat_template(
ChatTemplate(
name="llama-4",
default_system_prompt=None,
role_prefix_and_suffix={
"system": (
"<|header_start|>system<|header_end|>\n\n",
"<|eot|>",
),
"user": (
"<|header_start|>user<|header_end|>\n\n",
"<|eot|>",
),
"assistant": (
"<|header_start|>assistant<|header_end|>\n\n",
"<|eot|>",
),
},
stop_str=("<|eot|>",),
image_token="<|image|>",
)
)
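# Illustration (not part of the registry): with the role prefixes/suffixes
# above, a system + user exchange renders roughly as
#   <|header_start|>system<|header_end|>\n\nYou are a helpful assistant.<|eot|>
#   <|header_start|>user<|header_end|>\n\nHello!<|eot|>
#   <|header_start|>assistant<|header_end|>\n\n
# with <|eot|> serving as the stop string for generation.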
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
register_chat_template(
ChatTemplate(

View File

@@ -65,6 +65,9 @@ class ModelConfig:
**kwargs,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None
)
# Check model type
self.is_generation = is_generation_model(
@@ -467,6 +470,7 @@ multimodal_model_archs = [
"Gemma3ForConditionalGeneration",
"Grok1VForCausalLM",
"Grok1AForCausalLM",
# TODO: add multimodal support for "Llama4ForConditionalGeneration",
"LlavaLlamaForCausalLM",
"LlavaMistralForCausalLM",
"LlavaQwenForCausalLM",

View File

@@ -33,6 +33,7 @@ class SeparatorStyle(IntEnum):
ADD_NEW_LINE_SINGLE = auto()
LLAMA2 = auto()
LLAMA3 = auto()
LLAMA4 = auto()
CHATGLM = auto()
CHATML = auto()
CHATINTERN = auto()
@@ -156,19 +157,30 @@ class Conversation:
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.LLAMA4:
    # begin_of_text is added by default
    if self.system_message:
        ret = system_prompt
    else:
        ret = ""

    for i, (role, message) in enumerate(self.messages):
        if message:
            ret += f"<|header_start|>{role}<|header_end|>\n\n"
            ret += f"{message.strip()}<|eot|>"
        else:
            ret += f"<|header_start|>{role}<|header_end|>\n\n"
    return ret
elif self.sep_style == SeparatorStyle.LLAMA3:
    if self.system_message:
        ret = system_prompt
    else:
        ret = ""

    for i, (role, message) in enumerate(self.messages):
        if message:
            ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
            ret += f"{message.strip()}<|eot_id|>"
        else:
            ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
    # print(ret)
    return ret
elif self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
@@ -561,6 +573,19 @@ register_conv_template(
)
)
# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
register_conv_template(
Conversation(
name="llama-4",
system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA4,
sep="",
stop_str=["<|end_of_text|>", "<|eot|>", "<|eom|>"],
image_token="<|image|>",
)
)
register_conv_template(
Conversation(
name="chatml",

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import numpy as np
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
"""
@@ -45,6 +47,206 @@ class FlashAttentionMetadata:
# Sequence lengths for the forward batch
cache_seqlens_int32: torch.Tensor = None
@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor = None # cu_seqlens_q for local attention
local_seqused_k: torch.Tensor = None # sequence lengths for local attention
local_block_table: torch.Tensor = None # block table for local attention
local_max_query_len: int = 0 # max query length for local attention
local_max_seq_len: int = 0 # max sequence length for local attention
local_attn_metadata: Optional[LocalAttentionMetadata] = None
# Copied from:
# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
#                         __b0__  ______b1______  __b2__  < orig batch indices
#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
#   cu_seqlens_q_local = [0,  2,  4,  5,  9, 13, 14, 18, 19]
#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
attn_chunk_size: int,
query_start_loc_np: np.ndarray,
seq_lens_np: np.ndarray,
block_table: torch.Tensor,
page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
"""
Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
local attention blocks, where each block is passed to the attention kernel
as an independent local ("virtual") batch item.
Args:
attn_chunk_size: Size of local attention chunks
query_start_loc_np: Cumulative sum of query lengths (numpy array)
seq_lens_np: Sequence lengths (numpy array)
block_table: Block table for KV cache
page_size: Size of each page in the KV cache
Returns:
seqlens_q_local: Query sequence lengths for local attention
cu_seqlens_q_local: Cumulative sum of query sequence lengths for local attention
seqlens_k_local: Key sequence lengths for local attention
block_table_local: Block table for local attention
"""
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
actual_batch_size = seq_lens_np.shape[0]
# Handle the case where we start in the middle of a local attention block:
# we assume q_seqlens > 0 (for all elements). For each batch idx we compute
# the number of query tokens that fall in the first (possibly partial) local
# attention block, and then simply use a cdiv for the rest.
# For example if we have:
#   attn_chunk_size = 4
#   q_seqlens = [4, 10, 5]
#   k_seqlens = [6, 17, 9]
# Then we would get:
#   q_tokens_in_first_block = [2, 1, 4]
#   local_blocks = [2, 4, 2]
q_tokens_in_first_block = np.minimum(
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
).astype(np.int32)
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
# Once we know the number of local blocks we can compute the request spans
# for each batch idx and figure out the number of "virtual" requests we
# have to make.
# For the above example we would get:
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
#
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
# (TODO: make a utility to share this code with _prepare_inputs)
# arange step 1. [2, 4, 2] -> [2, 6, 8]
cu_num_blocks = np.cumsum(local_blocks)
virtual_batches = cu_num_blocks[-1]
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
# Then we can compute the seqlens_q_local, handling the fact that the
# first and last blocks could be partial
seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
# set the first block since this may be a partial block
seqlens_q_local[arange == 0] = q_tokens_in_first_block
# set the remaining blocks
seqlens_q_local[arange > 0] = np.minimum(
seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
)[arange > 0]
# convert from q_seqlens to cu_seqlens_q
cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
# compute the seqlens_k_local,
# basically a full local attention block for all but the last block in each
# batch
# For our example this will be:
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
)
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts = k_seqstarts_absolute // page_size
assert attn_chunk_size % page_size == 0, (
f"attn_chunk_size {attn_chunk_size} is not "
f"divisible by page_size {page_size}"
)
pages_per_local_batch = attn_chunk_size // page_size
# Create a block_table for the local attention blocks
# For our example, if we have a block-table like (assuming page_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
# ]
# Then for the local batches we would want a block-table like
# block_table_local = [
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
# ]
block_indices = np.broadcast_to(
np.arange(pages_per_local_batch, dtype=np.int32),
(virtual_batches, pages_per_local_batch),
) + np.expand_dims(block_starts, axis=1)
block_indices = block_indices.flatten()
batch_indices = np.repeat(
np.arange(actual_batch_size, dtype=np.int32),
local_blocks * pages_per_local_batch,
)
block_table_local = block_table[batch_indices, block_indices].view(
virtual_batches, -1
)
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
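A quick sanity check of the worked example in the comments above (a sketch; it assumes make_local_attention_virtual_batches from this module is in scope):

import numpy as np
import torch

attn_chunk_size = 4
query_start_loc = np.array([0, 4, 14, 19], dtype=np.int32)  # q_seqlens = [4, 10, 5]
seq_lens = np.array([6, 17, 9], dtype=np.int32)             # kv_seqlens
block_table = torch.arange(30, dtype=torch.int32).reshape(3, 10)  # page_size = 2

seqlens_q, cu_seqlens_q, seqlens_k, table = make_local_attention_virtual_batches(
    attn_chunk_size, query_start_loc, seq_lens, block_table, page_size=2
)
# seqlens_q    -> [2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q -> [0, 2, 4, 5, 9, 13, 14, 18, 19]
# seqlens_k    -> [4, 2, 4, 4, 4, 1, 4, 1]
# table        -> [[0, 1], [2, 3], [12, 13], [14, 15],
#                  [16, 17], [18, 19], [22, 23], [24, 25]]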
class FlashAttentionBackend(AttentionBackend):
"""FlashAttention backend implementation.
@@ -100,6 +302,13 @@ class FlashAttentionBackend(AttentionBackend):
self.step_id = step_id
self.speculative_num_steps = speculative_num_steps
# Local attention settings
self.attention_chunk_size = (
model_runner.attention_chunk_size
if hasattr(model_runner, "attention_chunk_size")
else None
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize forward metadata to cache repetitive calculations."""
metadata = FlashAttentionMetadata()
@@ -189,6 +398,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
# Precompute cumulative sequence lengths
if (
any(forward_batch.extend_prefix_lens_cpu)
@@ -203,6 +413,51 @@ class FlashAttentionBackend(AttentionBackend):
metadata.cu_seqlens_q = metadata.cu_seqlens_k
metadata.max_seq_len_q = metadata.max_seq_len_k
# Setup local attention if enabled
if (
self.attention_chunk_size is not None
and forward_batch.forward_mode == ForwardMode.EXTEND
):
# Convert tensors to numpy for local attention processing
cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
# Adjust attention_chunk_size based on the actual sequence length
# to avoid index out of bounds errors
max_seq_len = seq_lens_np.max()
effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
# Make sure effective_chunk_size is divisible by page_size
effective_chunk_size = (
effective_chunk_size // self.page_size
) * self.page_size
if effective_chunk_size < self.page_size:
effective_chunk_size = self.page_size
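# Example (hypothetical numbers): with attention_chunk_size = 8192,
# page_size = 16, and max_seq_len = 100, effective_chunk_size becomes
# min(8192, 100) = 100, rounded down to (100 // 16) * 16 = 96.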
# Create local attention metadata
(
seqlens_q_local_np,
cu_seqlens_q_local_np,
seqlens_k_local_np,
block_table_local,
) = make_local_attention_virtual_batches(
effective_chunk_size,
cu_seqlens_q_np,
seq_lens_np,
metadata.page_table,
self.page_size,
)
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
device
),
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
local_block_table=block_table_local,
local_max_query_len=seqlens_q_local_np.max(),
local_max_seq_len=seqlens_k_local_np.max(),
)
metadata.local_attn_metadata = local_metadata
# Precompute strided indices
if self.page_size > 1:
self.strided_indices = torch.arange(
@@ -211,6 +466,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata.page_table = (
metadata.page_table[:, self.strided_indices] // self.page_size
)
self.forward_metadata = metadata
def forward_extend(
@@ -254,7 +510,28 @@ class FlashAttentionBackend(AttentionBackend):
else (-1, -1)
)
# Check if we should use local attention
use_local_attn = (
self.attention_chunk_size is not None
and metadata.local_attn_metadata is not None
and (hasattr(layer, "use_irope") and layer.use_irope)
)
# Get the appropriate page table based on whether we're using local attention
if use_local_attn:
local_metadata = metadata.local_attn_metadata
page_table = local_metadata.local_block_table
cu_seqlens_q = local_metadata.local_query_start_loc
cache_seqlens = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
else:
page_table = metadata.page_table
cu_seqlens_q = metadata.cu_seqlens_q
cache_seqlens = metadata.cache_seqlens_int32
max_seqlen_q = metadata.max_seq_len_q
max_seqlen_k = metadata.max_seq_len_k
cu_seqlens_k = metadata.cu_seqlens_k
# Use Flash Attention for prefill
if not self.use_mla:
@@ -272,10 +549,10 @@ class FlashAttentionBackend(AttentionBackend):
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
window_size=window_size,
@@ -307,10 +584,10 @@ class FlashAttentionBackend(AttentionBackend):
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,

View File

@@ -23,9 +23,14 @@ def fused_moe_forward_native(
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
if apply_router_weight_on_input:
raise NotImplementedError
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
}
}
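The keys in these tuning files are token counts (the M dimension of the grouped GEMM) and the values are Triton launch parameters tuned for that batch size on a particular device and expert shape. A runtime would pick the entry nearest to the observed number of tokens; a sketch of such a lookup (a hypothetical helper, not SGLang's actual one):

import json

def pick_moe_config(path: str, num_tokens: int) -> dict:
    # Choose the tuned config whose key is closest to the observed batch size.
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    return configs[min(configs, key=lambda m: abs(m - num_tokens))]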

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -1079,6 +1079,7 @@ def inplace_fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
@@ -1099,6 +1100,7 @@ def inplace_fused_experts(
topk_ids,
True,
activation,
apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
@@ -1120,6 +1122,7 @@ def inplace_fused_experts_fake(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
@@ -1150,6 +1153,7 @@ def outplace_fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
@@ -1171,6 +1175,7 @@ def outplace_fused_experts(
topk_ids,
False,
activation,
apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
@@ -1193,6 +1198,7 @@ def outplace_fused_experts_fake(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
@@ -1225,6 +1231,7 @@ def fused_experts(
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
@@ -1247,6 +1254,7 @@ def fused_experts(
topk_weights,
topk_ids,
activation,
apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
@@ -1268,6 +1276,7 @@ def fused_experts(
topk_weights,
topk_ids,
activation,
apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
@@ -1291,6 +1300,7 @@ def fused_experts_impl(
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
@@ -1423,7 +1433,7 @@ def fused_experts_impl(
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
apply_router_weight_on_input,
topk_ids.shape[1],
config,
compute_type=compute_type,
@@ -1456,7 +1466,7 @@ def fused_experts_impl(
(
intermediate_cache3
if not no_combine and topk_ids.shape[1] != 1
else out_hidden_states[begin_chunk_idx:end_chunk_idx].unsqueeze(0)
),
a2_scale,
w2_scale,
@@ -1466,7 +1476,7 @@ def fused_experts_impl(
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
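Note how mul_routed_weight is apply_router_weight_on_input for the first grouped GEMM and not apply_router_weight_on_input for the second, so the routing weight is multiplied in exactly once. Llama4 scales the expert input rather than the expert output; for a purely linear expert the two orders coincide, which this sketch (hypothetical stand-in expert) illustrates:

import torch

expert = torch.nn.Linear(8, 8, bias=False)  # stand-in for one expert
x = torch.randn(4, 8)                       # token activations
w = torch.rand(4, 1)                        # per-token router weight (top_k = 1)

y_pre = expert(x * w)    # apply_router_weight_on_input=True: weight the input
y_post = expert(x) * w   # default: weight when combining expert outputs
assert torch.allclose(y_pre, y_post, atol=1e-5)

With a real (nonlinear) expert MLP the two placements differ, which is why the model must declare which convention it uses.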

View File

@@ -128,6 +128,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
@@ -143,6 +144,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
)
@@ -160,6 +162,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
@@ -200,6 +203,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
)
@@ -276,6 +280,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_presharded_weights: bool = False,
inplace: bool = True,
no_combine: bool = False,
@@ -302,6 +307,7 @@ class FusedMoE(torch.nn.Module):
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.activation = activation
self.apply_router_weight_on_input = apply_router_weight_on_input
self.use_presharded_weights = use_presharded_weights
self.inplace = inplace
self.no_combine = no_combine
@@ -630,6 +636,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
)
if self.reduce_results and self.tp_size > 1:

View File

@@ -280,6 +280,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
):

View File

@@ -370,6 +370,7 @@ class BlockInt8MoEMethod:
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
@@ -398,6 +399,7 @@ class BlockInt8MoEMethod:
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
w1_scale=(layer.w13_weight_scale_inv),
w2_scale=(layer.w2_weight_scale_inv),

View File

@@ -905,6 +905,7 @@ class Fp8MoEMethod:
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
@@ -975,6 +976,7 @@ class Fp8MoEMethod:
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv

View File

@@ -344,6 +344,7 @@ class MoeWNA16Method:
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
@@ -374,6 +375,7 @@ class MoeWNA16Method:
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
w1_scale=layer.w13_scales,

View File

@@ -230,6 +230,7 @@ class W8A8Int8MoEMethod:
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
@@ -257,6 +258,7 @@ class W8A8Int8MoEMethod:
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),

View File

@@ -35,6 +35,7 @@ class RadixAttention(nn.Module):
sliding_window_size: int = -1,
is_cross_attention: bool = False,
prefix: str = "",
use_irope: bool = False,
):
super().__init__()
self.tp_q_head_num = num_heads
@@ -50,6 +51,7 @@ class RadixAttention(nn.Module):
self.is_cross_attention = is_cross_attention
self.k_scale = None
self.v_scale = None
self.use_irope = use_irope
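# use_irope marks layers that use interleaved RoPE; attention backends that
# support chunked local attention (e.g. the FlashAttention backend above)
# key off this flag to route such layers to the local-attention path.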
def forward(
self,

View File

@@ -733,6 +733,69 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return new_freqs
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
):
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
inv_freqs = super()._compute_inv_freq(base)
inv_freqs = inv_freqs[: (self.rotary_dim // 2)]
return inv_freqs
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
# self.max_position_embeddings here is number of image patches
# i.e. (image_size // patch_size) ** 2
num_patches = self.max_position_embeddings
img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN
num_patches_single_dim = int(math.sqrt(num_patches))
frequencies_x = img_idx % num_patches_single_dim
frequencies_y = img_idx // num_patches_single_dim
freqs_x = (
(frequencies_x + 1)[..., None] * inv_freq[None, None, :]
).repeat_interleave(2, dim=-1)
freqs_y = (
(frequencies_y + 1)[..., None] * inv_freq[None, None, :]
).repeat_interleave(2, dim=-1)
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
cache = torch.view_as_complex(
torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
)
return cache
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
broadcast_shape = [
d if i == 1 or i == (query_.ndim - 1) else 1
for i, d in enumerate(query_.shape)
]
freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
return query_out.type_as(query), key_out.type_as(key)
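To see the 2D indexing at work, here is a minimal sketch of the patch-to-grid mapping for a 2x2 image grid (num_patches = 4); the trailing slot is the CLS token, whose negative id is later zeroed out by the masked_fill above:

import torch

num_patches, side = 4, 2  # side = int(math.sqrt(num_patches))
img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
img_idx[-1, -1] = -2  # ID_CLS_TOKEN
print((img_idx % side).flatten())   # x coords: tensor([0, 1, 0, 1, 0], ...)
print((img_idx // side).flatten())  # y coords: tensor([0, 0, 1, 1, -1], ...)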
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""

View File

@@ -0,0 +1,161 @@
from typing import List, Mapping, Optional, Tuple, Union
import torch
from PIL import Image
from transformers import Llama4Processor
from transformers.image_utils import SizeDict
from transformers.models.llama4.image_processing_llama4 import (
find_supported_resolutions,
get_best_fit,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
from sglang.srt.utils import load_image
class Mllama4ImageProcessor(BaseMultimodalProcessor):
models = [Llama4ForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.vision_config = hf_config.vision_config
self.text_config = hf_config.text_config
self.multimodal_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token
)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
max_req_input_len=None,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
# Process images and text using the base processor's load_mm_data method
processed_data = self.load_mm_data(
prompt=input_text,
multimodal_tokens=self.multimodal_tokens,
max_req_input_len=max_req_input_len or 4096,
image_data=image_data,
return_text=True,
)
# Process the images using the processor
processor = Llama4Processor.from_pretrained(
self.server_args.model_path, **kwargs
)
# Process the prompt and images
image_inputs = processor(
text=processed_data.input_text,
images=processed_data.images,
return_tensors="pt",
)
# Handle image resolutions and aspect ratios
if "pixel_values" in image_inputs:
image_processor = processor.image_processor
tokenizer = self._processor.tokenizer
# Calculate tile size and find supported resolutions
tile_size = self.vision_config.image_size
max_num_tiles = getattr(self.vision_config, "max_patches", 1)
possible_resolutions = find_supported_resolutions(
max_num_chunks=max_num_tiles,
patch_size=SizeDict(height=tile_size, width=tile_size),
)
# Find best fit for each image
best_fit_sizes = [
get_best_fit(
(image.size[1], image.size[0]), # (height, width)
torch.tensor(possible_resolutions),
resize_to_max_canvas=image_processor.resize_to_max_canvas,
)
for image in processed_data.images
]
# Calculate aspect ratios and patches per image
aspect_ratios = [
(image_size[0] // tile_size, image_size[1] // tile_size)
for image_size in best_fit_sizes
]
patches_per_image = [
1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
]
# Add to image_inputs
image_inputs["aspect_ratios"] = aspect_ratios
image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
# Process embed_is_patch
vocab = tokenizer.get_vocab()
patch_id = vocab.get(processor.img_patch_token, -1)
image_end_id = vocab.get(processor.end_of_img_token, -1)
if patch_id != -1 and image_end_id != -1:
input_ids = image_inputs["input_ids"].view(-1)
# Remove BOS token if present
if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
input_ids = input_ids[1:]
# Find image end indices and split input_ids
image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
if image_end_indices.size(0) > 0:
# Split at image boundaries
split_indices = (image_end_indices + 1)[:-1]
split_input_ids = torch.tensor_split(input_ids, split_indices)
split_input_ids = [x for x in split_input_ids if x.numel() > 0]
# Create embed_is_patch for each image
embed_is_patch = []
for per_image_input_ids in split_input_ids:
embed_is_patch.append(per_image_input_ids == patch_id)
image_inputs["embed_is_patch"] = embed_is_patch
# Convert to the format expected by SGLang
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
# Add metadata for image processing
image_inputs["mm_items"] = [
MultimodalDataItem(
pixel_values=image_inputs["pixel_values"],
modality=Modality.IMAGE,
# Add additional metadata needed for Llama4 vision processing
embed_is_patch=image_inputs.get("embed_is_patch", None),
aspect_ratios=image_inputs.get("aspect_ratios", None),
patches_per_image=image_inputs.get("patches_per_image", None),
)
]
return image_inputs
def get_patch_per_chunk(self):
"""Calculate patches per chunk based on vision config"""
image_size = self.vision_config.image_size
patch_size = self.vision_config.patch_size
assert (
image_size % patch_size == 0
), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
return (image_size // patch_size) ** 2 // ds_ratio
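For instance, assuming image_size = 336, patch_size = 14, and pixel_shuffle_ratio = 0.5 (representative values; the real numbers come from the checkpoint's vision config), ds_ratio = round(1 / 0.5**2) = 4 and the method returns (336 // 14) ** 2 // 4 = 576 // 4 = 144 patches per chunk.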

View File

@@ -128,6 +128,7 @@ class ModelRunner:
self.model_config.attention_arch == AttentionArch.MLA
and not server_args.disable_mla
)
self.attention_chunk_size = model_config.attention_chunk_size
# Model-specific adjustment
self.model_specific_adjustment()

View File

@@ -63,6 +63,7 @@ class LlamaMLP(nn.Module):
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -78,6 +79,7 @@ class LlamaMLP(nn.Module):
bias=False,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
reduce_results=reduce_results,
)
if hidden_act != "silu":
raise ValueError(
@@ -281,7 +283,7 @@ class LlamaModel(nn.Module):
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: LlamaDecoderLayer(
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
),
prefix="model.layers",
)
@@ -375,9 +377,7 @@ class LlamaForCausalLM(nn.Module):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
# Llama 3.2 1B Instruct set tie_word_embeddings to True
# Llama 3.1 8B Instruct set tie_word_embeddings to False
if self.config.tie_word_embeddings:
@@ -402,6 +402,14 @@ class LlamaForCausalLM(nn.Module):
self.capture_aux_hidden_states = False
def _init_model(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
return LlamaModel(config, quant_config=quant_config, prefix=prefix)
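This hook lets subclasses swap in a different backbone without duplicating __init__; the new Llama4 causal LM can do roughly (a sketch, signature assumed from the base class):

class Llama4ForCausalLM(LlamaForCausalLM):
    def _init_model(self, config, quant_config=None, prefix=""):
        return Llama4Model(config, quant_config=quant_config, prefix=prefix)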
@torch.no_grad()
def forward(
self,

View File

@@ -0,0 +1,420 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import Llama4TextConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
logger = logging.getLogger(__name__)
class Llama4MoE(nn.Module):
@torch.compile(dynamic=True, backend=get_compiler_backend())
@staticmethod
def custom_routing_function(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
hidden_states.dtype
)
return (
router_scores_aK.view(-1).reshape(router_scores_aK.shape),
router_indices_aK.to(torch.int32),
)
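Unlike softmax routing, Llama4 takes the top-k of the raw router logits and converts them to gate values with a sigmoid, without renormalizing across the selected experts. A minimal illustration:

import torch

gating = torch.tensor([[2.0, -1.0, 0.5, 0.0]])  # one token, four experts
scores, idx = torch.topk(gating, k=1, dim=-1)
print(idx)                    # tensor([[0]])      -> expert 0 selected
print(torch.sigmoid(scores))  # tensor([[0.8808]]) -> its routing weight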
def __init__(
self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok
intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
bias=False,
quant_config=None,
prefix=add_prefix("router", prefix),
)
self.experts = FusedMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
custom_routing_function=Llama4MoE.custom_routing_function,
intermediate_size=intermediate_size_moe,
reduce_results=False,
renormalize=False,
quant_config=quant_config,
apply_router_weight_on_input=True,
prefix=add_prefix("experts", prefix),
)
self.shared_expert = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size_moe,
hidden_act="silu",
quant_config=quant_config,
prefix=add_prefix("shared_expert", prefix),
reduce_results=False, # We need to do scatter before reduce
)
def forward(self, hidden_states):
# router_scores: [num_tokens, num_experts]
router_logits, _ = self.router(hidden_states)
shared_out = self.shared_expert(hidden_states)
routed_out = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
out_aD = routed_out + shared_out
if self.tp_size > 1:
out_aD = tensor_model_parallel_all_reduce(out_aD)
return out_aD
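Note that both the shared expert and the routed experts are built with reduce_results=False: under tensor parallelism their partial outputs are summed locally first, so only a single all-reduce is needed for the combined result.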
class Llama4Attention(nn.Module):
def __init__(
self,
config: Llama4TextConfig,
layer_id: int,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
bias_o_proj: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
self.hidden_size = hidden_size
self.use_rope = int((layer_id + 1) % 4 != 0)
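# Every 4th layer ((layer_id + 1) % 4 == 0) is a NoPE layer: no rotary
# embedding and global attention. The remaining layers use interleaved RoPE
# with chunked local attention (see use_irope below).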
self.use_qk_norm = config.use_qk_norm and self.use_rope
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.attn_temperature_tuning = config.attn_temperature_tuning
self.floor_scale = config.floor_scale
self.attn_scale = config.attn_scale
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.n_rep = self.num_heads // self.num_kv_heads
self.qk_norm = (
RMSNorm(
hidden_size=self.head_dim,
eps=config.rms_norm_eps,
)
if self.use_qk_norm
else None
)
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
)
is_neox_style = True
is_gguf = quant_config and quant_config.get_name() == "gguf"
if is_gguf and config.model_type in ["llama", "llama4"]:
is_neox_style = False
self.rotary_emb = (
get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
rope_scaling=rope_scaling if rope_scaling != "default" else None,
is_neox_style=is_neox_style,
)
if self.use_rope
else None
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
prefix=add_prefix("attn", prefix),
use_irope=self.use_rope,
)
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale)
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
return attn_scale.unsqueeze(-1)
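A quick numerical check of the temperature-tuning scale (assumed config values floor_scale = 8192 and attn_scale = 0.1; the real ones come from the text config):

import torch

floor_scale, attn_scale = 8192.0, 0.1
pos = torch.tensor([0.0, 8190.0, 81920.0])
scale = torch.log(torch.floor((pos + 1.0) / floor_scale) + 1.0) * attn_scale + 1.0
print(scale)  # tensor([1.0000, 1.0000, 1.2398]) -- only long contexts are scaled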
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.rotary_emb is not None:
q, k = self.rotary_emb(positions, q, k)
if self.qk_norm is not None:
# TODO: support float
q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
q = self.qk_norm(q).to(q.dtype)
k = self.qk_norm(k).to(k.dtype)
q = q.reshape(-1, self.q_size)
k = k.reshape(-1, self.kv_size)
# Apply temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers:
# the inference-time temperature function is customized so that it does not
# affect short contexts while still helping at very long context.
if self.attn_temperature_tuning and not self.use_rope:
attn_scale = self._get_attn_scale(positions)
q = (q * attn_scale).to(q.dtype)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
class Llama4DecoderLayer(nn.Module):
def __init__(
self,
config: Llama4TextConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.layer_id = layer_id
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
max_position_embeddings = config.max_position_embeddings
self.self_attn = Llama4Attention(
config=config,
layer_id=layer_id,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=False,
bias_o_proj=False,
prefix=add_prefix("self_attn", prefix),
)
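        # MoE and dense MLP layers are interleaved: every
        # `interleave_moe_layer_step`-th layer uses the MoE block.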
is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
if is_moe_layer:
self.feed_forward = Llama4MoE(
config=config,
quant_config=quant_config,
prefix=add_prefix("feed_forward", prefix),
)
else:
self.feed_forward = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size_mlp,
hidden_act="silu",
quant_config=quant_config,
prefix=add_prefix("feed_forward", prefix),
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
class Llama4Model(nn.Module):
def __init__(
self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("embed_tokens", prefix),
)
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: Llama4DecoderLayer(
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
),
prefix="model.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.layers_to_capture = []
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
aux_hidden_states = []
for i in range(len(self.layers)):
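            # Optionally capture the hidden states entering this layer
            # (used e.g. by speculative-decoding draft models).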
if i in self.layers_to_capture:
aux_hidden_states.append(hidden_states + residual)
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) == 0:
return hidden_states
return hidden_states, aux_hidden_states
class Llama4ForCausalLM(LlamaForCausalLM):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(
self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(config, quant_config, prefix)
def _init_model(
self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
return Llama4Model(config, quant_config=quant_config, prefix=prefix)
EntryClass = [Llama4ForCausalLM]

View File

@@ -0,0 +1,154 @@
# TODO: add "Adapted from vllm/mllama4.py"
from collections.abc import Iterable
from typing import Optional, Set, Tuple
import torch
from torch import nn
from transformers import Llama4Config
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
class Llama4ForConditionalGeneration(nn.Module):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
}
def __init__(
self,
config: Llama4Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.quant_config = quant_config
# Initialize the language model
from sglang.srt.models.llama4 import Llama4ForCausalLM
self.language_model = Llama4ForCausalLM(
config.text_config,
quant_config=quant_config,
prefix=add_prefix("language_model", prefix),
)
self.logits_processor = LogitsProcessor(config.text_config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
**kwargs: object,
) -> torch.Tensor:
return self.language_model(input_ids, positions, forward_batch)
def permute_qk_weight_for_rotary(
self,
name: str,
loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]:
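        # HF checkpoints store q/k rows in interleaved rotary order
        # (x0, y0, x1, y1, ...); permute them into the neox-style half-split
        # layout (x0, x1, ..., y0, y1, ...) expected by the rotary embedding.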
def permute(w: torch.Tensor, n_heads: int):
attn_in = self.language_model.config.head_dim * n_heads
attn_out = self.language_model.config.hidden_size
return (
w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
.transpose(1, 2)
.reshape(attn_in, attn_out)
)
modules = name.split(".")
        # Only q/k projection weights need the rotary permutation.
if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
loaded_weight = permute(
loaded_weight, self.language_model.config.num_key_value_heads
)
elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
loaded_weight = permute(
loaded_weight, self.language_model.config.num_attention_heads
)
return name, loaded_weight
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
(".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
(".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
(".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
(".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
]
params_dict = dict(self.named_parameters())
num_experts = self.config.text_config.num_local_experts
for name, loaded_weight in weights:
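            # The vision tower and projector are not yet supported; skip their weights.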
if name.startswith("vision_model") or name.startswith(
"multi_modal_projector"
):
continue
name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if ".experts" in name:
if ".gate_up_proj" in name:
name_list = [
name.replace(".experts.gate_up_proj", ".experts.w13_weight")
] * 2
loaded_weight_list = loaded_weight.chunk(2, dim=-1)
shard_id_list = ["w1", "w3"]
else:
name_list = [
name.replace(".experts.down_proj", ".experts.w2_weight")
]
shard_id_list = ["w2"]
loaded_weight_list = [loaded_weight]
for name, loaded_weight, shard_id in zip(
name_list, loaded_weight_list, shard_id_list
):
param = params_dict[name]
weight_loader = param.weight_loader
for expert_id in range(num_experts):
weight_loader(
param,
loaded_weight[expert_id].T,
name,
shard_id=shard_id,
expert_id=expert_id,
)
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
EntryClass = Llama4ForConditionalGeneration