Add Llama4 support (#5092)
Co-authored-by: Cheng Wan <cwan39@gatech.edu> Co-authored-by: fzyzcjy <ch271828n@outlook.com> Co-authored-by: ispobock <ispobaoke@163.com>
This commit is contained in:
@@ -294,6 +294,30 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="llama-4",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": (
|
||||
"<|header_start|>system<|header_end|>\n\n",
|
||||
"<|eot|>",
|
||||
),
|
||||
"user": (
|
||||
"<|header_start|>user<|header_end|>\n\n",
|
||||
"<|eot|>",
|
||||
),
|
||||
"assistant": (
|
||||
"<|header_start|>assistant<|header_end|>\n\n",
|
||||
"<|eot|>",
|
||||
),
|
||||
},
|
||||
stop_str=("<|eot|>",),
|
||||
image_token="<|image|>",
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
|
||||
@@ -65,6 +65,9 @@ class ModelConfig:
|
||||
**kwargs,
|
||||
)
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
self.attention_chunk_size = getattr(
|
||||
self.hf_text_config, "attention_chunk_size", None
|
||||
)
|
||||
|
||||
# Check model type
|
||||
self.is_generation = is_generation_model(
|
||||
@@ -467,6 +470,7 @@ multimodal_model_archs = [
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Grok1VForCausalLM",
|
||||
"Grok1AForCausalLM",
|
||||
# TODO: add multimodal support for "Llama4ForConditionalGeneration",
|
||||
"LlavaLlamaForCausalLM",
|
||||
"LlavaMistralForCausalLM",
|
||||
"LlavaQwenForCausalLM",
|
||||
|
||||
@@ -33,6 +33,7 @@ class SeparatorStyle(IntEnum):
|
||||
ADD_NEW_LINE_SINGLE = auto()
|
||||
LLAMA2 = auto()
|
||||
LLAMA3 = auto()
|
||||
LLAMA4 = auto()
|
||||
CHATGLM = auto()
|
||||
CHATML = auto()
|
||||
CHATINTERN = auto()
|
||||
@@ -156,19 +157,30 @@ class Conversation:
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.LLAMA3:
|
||||
ret = "<|begin_of_text|>"
|
||||
elif self.sep_style == SeparatorStyle.LLAMA4:
|
||||
# begin_of_text is added by default
|
||||
if self.system_message:
|
||||
ret += system_prompt
|
||||
ret = system_prompt
|
||||
else:
|
||||
ret += ""
|
||||
ret = ""
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += f"<|header_start|>{role}<|header_end|>\n\n"
|
||||
ret += f"{message.strip()}<|eot|>"
|
||||
else:
|
||||
ret += f"<|header_start|>{role}<|header_end|>\n\n"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.LLAMA3:
|
||||
if self.system_message:
|
||||
ret = system_prompt
|
||||
else:
|
||||
ret = ""
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
||||
ret += f"{message.strip()}<|eot_id|>"
|
||||
else:
|
||||
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
||||
# print(ret)
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.LLAMA2:
|
||||
seps = [self.sep, self.sep2]
|
||||
@@ -561,6 +573,19 @@ register_conv_template(
|
||||
)
|
||||
)
|
||||
|
||||
# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="llama-4",
|
||||
system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
|
||||
roles=("user", "assistant"),
|
||||
sep_style=SeparatorStyle.LLAMA4,
|
||||
sep="",
|
||||
stop_str=["<|end_of_text|>", "<|eot|>", "<|eom|>"],
|
||||
image_token="<|image|>",
|
||||
)
|
||||
)
|
||||
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="chatml",
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
|
||||
"""
|
||||
@@ -45,6 +47,206 @@ class FlashAttentionMetadata:
|
||||
# Sequence lengths for the forward batch
|
||||
cache_seqlens_int32: torch.Tensor = None
|
||||
|
||||
@dataclass
|
||||
class LocalAttentionMetadata:
|
||||
local_query_start_loc: torch.Tensor = None # cu_seqlens_q for local attention
|
||||
local_seqused_k: torch.Tensor = None # sequence lengths for local attention
|
||||
local_block_table: torch.Tensor = None # block table for local attention
|
||||
local_max_query_len: int = 0 # max query length for local attention
|
||||
local_max_seq_len: int = 0 # max sequence length for local attention
|
||||
|
||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
||||
|
||||
|
||||
# Copied from:
|
||||
# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
|
||||
#
|
||||
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
|
||||
# local attention blocks, where each block is passed to the attention kernel
|
||||
# as an independent local ("virtual") batch item.
|
||||
#
|
||||
# For example, if are performing a chunked prefill a batch of 3 sequences:
|
||||
# q_seqlens = [4, 10, 5]
|
||||
# kv_seqlens = [6, 17, 9]
|
||||
# Then normally for regular attention we would compute with an attention mask
|
||||
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
|
||||
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
|
||||
# k_toks > 0 1 2 3 4 5
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# 2 | 1 1 1 1 1
|
||||
# 3 | 1 1 1 1 1 1
|
||||
#
|
||||
# for local attention (with attn_chunk_size = 4) we would compute with an
|
||||
# attention mask like:
|
||||
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
|
||||
# k_toks > 0 1 2 3 4 5
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# 2 | 1
|
||||
# 3 | 1 1
|
||||
#
|
||||
# We can simulate this mask using standard flash-attention by breaking the
|
||||
# sequences into local ("virtual") batches, where each local batch item is a
|
||||
# local attention block, so in this case batch idx 0 would be broken up into:
|
||||
#
|
||||
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
|
||||
# k_toks > 0 1 2 3
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
|
||||
# k_toks > 4 5
|
||||
# q_toks v _____________
|
||||
# 2 | 1
|
||||
# 3 | 1 1
|
||||
#
|
||||
# e.g. if we have:
|
||||
# attn_chunk_size = 4
|
||||
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
|
||||
# Then this function would return:
|
||||
# __b0__ ______b1______ __b2__ < orig batch indices
|
||||
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
|
||||
# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
|
||||
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
|
||||
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
|
||||
def make_local_attention_virtual_batches(
|
||||
attn_chunk_size: int,
|
||||
query_start_loc_np: np.ndarray,
|
||||
seq_lens_np: np.ndarray,
|
||||
block_table: torch.Tensor,
|
||||
page_size: int = 0,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
|
||||
local attention blocks, where each block is passed to the attention kernel
|
||||
as an independent local ("virtual") batch item.
|
||||
|
||||
Args:
|
||||
attn_chunk_size: Size of local attention chunks
|
||||
query_start_loc_np: Cumulative sum of query lengths (numpy array)
|
||||
seq_lens_np: Sequence lengths (numpy array)
|
||||
block_table: Block table for KV cache
|
||||
page_size: Size of each page in the KV cache
|
||||
|
||||
Returns:
|
||||
seqlens_q_local: Query sequence lengths for local attention
|
||||
cu_seqlens_q_local: Cumulative sum of query sequence lengths for local attention
|
||||
seqlens_k_local: Key sequence lengths for local attention
|
||||
block_table_local: Block table for local attention
|
||||
"""
|
||||
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
||||
actual_batch_size = seq_lens_np.shape[0]
|
||||
|
||||
# Handle if we are starting in the middle of a local attention block,
|
||||
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
|
||||
# the number of tokens that are not in the first local attention block and
|
||||
# then we can simply use a cdiv for the rest.
|
||||
# For example if we have:
|
||||
# attn_chunk_size = 4
|
||||
# q_seqlens = [4, 10, 5]
|
||||
# k_seqlens = [6, 17, 9]
|
||||
# Then we would get:
|
||||
# new_tokens_in_first_block = [2, 1, 4]
|
||||
# local_blocks = [2, 4, 2]
|
||||
q_tokens_in_first_block = np.minimum(
|
||||
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
|
||||
).astype(np.int32)
|
||||
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
|
||||
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
|
||||
|
||||
# Once we know the number of local blocks we can compute the request spans
|
||||
# for each batch idx, we can figure out the number of "virtual" requests we
|
||||
# have to make,
|
||||
# For the above example we would get:
|
||||
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
|
||||
#
|
||||
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
|
||||
# (TODO: max a utility to share this code with _prepare_inputs)
|
||||
# arange step 1. [2, 4, 2] -> [2, 6, 8]
|
||||
cu_num_blocks = np.cumsum(local_blocks)
|
||||
virtual_batches = cu_num_blocks[-1]
|
||||
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
|
||||
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
|
||||
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
|
||||
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
|
||||
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
|
||||
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
|
||||
# Then we can compute the seqlens_q_local, handling the fact that the
|
||||
# first and last blocks could be partial
|
||||
seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
|
||||
# set the first block since this may be a partial block
|
||||
seqlens_q_local[arange == 0] = q_tokens_in_first_block
|
||||
# set the remaining blocks
|
||||
seqlens_q_local[arange > 0] = np.minimum(
|
||||
seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
|
||||
)[arange > 0]
|
||||
|
||||
# convert from q_seqlens to cu_seqlens_q
|
||||
cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
|
||||
|
||||
# compute the seqlens_k_local,
|
||||
# basically a full local attention block for all but the last block in each
|
||||
# batch
|
||||
# For our example this will be:
|
||||
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
|
||||
seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
|
||||
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
|
||||
|
||||
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
|
||||
rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
|
||||
)
|
||||
# For the example the local attention blocks start at:
|
||||
# _b0_ _____b1_____ _b2_
|
||||
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
|
||||
block_starts = k_seqstarts_absolute // page_size
|
||||
|
||||
assert attn_chunk_size % page_size == 0, (
|
||||
f"attn_chunk_size {attn_chunk_size} is not "
|
||||
f"divisible by page_size {page_size}"
|
||||
)
|
||||
pages_per_local_batch = attn_chunk_size // page_size
|
||||
|
||||
# Create a block_table for the local attention blocks
|
||||
# For out example if we have a block-table like (assuming page_size=2):
|
||||
# block_table = [
|
||||
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
|
||||
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
|
||||
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
|
||||
# ]
|
||||
# Then for the local batches we would want a block-table like
|
||||
# block_table_local = [
|
||||
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
|
||||
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
|
||||
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
|
||||
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
|
||||
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
|
||||
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
|
||||
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
|
||||
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
|
||||
# ]
|
||||
block_indices = np.broadcast_to(
|
||||
np.arange(pages_per_local_batch, dtype=np.int32),
|
||||
(virtual_batches, pages_per_local_batch),
|
||||
) + np.expand_dims(block_starts, axis=1)
|
||||
block_indices = block_indices.flatten()
|
||||
batch_indices = np.repeat(
|
||||
np.arange(actual_batch_size, dtype=np.int32),
|
||||
local_blocks * pages_per_local_batch,
|
||||
)
|
||||
block_table_local = block_table[batch_indices, block_indices].view(
|
||||
virtual_batches, -1
|
||||
)
|
||||
|
||||
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
|
||||
"""Ceiling division."""
|
||||
return -(a // -b)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
"""FlashAttention backend implementation.
|
||||
@@ -100,6 +302,13 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
self.step_id = step_id
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
|
||||
# Local attention settings
|
||||
self.attention_chunk_size = (
|
||||
model_runner.attention_chunk_size
|
||||
if hasattr(model_runner, "attention_chunk_size")
|
||||
else None
|
||||
)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Initialize forward metadata to cache repetitive calculations."""
|
||||
metadata = FlashAttentionMetadata()
|
||||
@@ -189,6 +398,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
# Precompute cumulative sequence lengths
|
||||
if (
|
||||
any(forward_batch.extend_prefix_lens_cpu)
|
||||
@@ -203,6 +413,51 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
||||
metadata.max_seq_len_q = metadata.max_seq_len_k
|
||||
|
||||
# Setup local attention if enabled
|
||||
if (
|
||||
self.attention_chunk_size is not None
|
||||
and forward_batch.forward_mode == ForwardMode.EXTEND
|
||||
):
|
||||
# Convert tensors to numpy for local attention processing
|
||||
cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
|
||||
seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
|
||||
|
||||
# Adjust attention_chunk_size based on the actual sequence length
|
||||
# to avoid index out of bounds errors
|
||||
max_seq_len = seq_lens_np.max()
|
||||
effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
|
||||
# Make sure effective_chunk_size is divisible by page_size
|
||||
effective_chunk_size = (
|
||||
effective_chunk_size // self.page_size
|
||||
) * self.page_size
|
||||
if effective_chunk_size < self.page_size:
|
||||
effective_chunk_size = self.page_size
|
||||
|
||||
# Create local attention metadata
|
||||
(
|
||||
seqlens_q_local_np,
|
||||
cu_seqlens_q_local_np,
|
||||
seqlens_k_local_np,
|
||||
block_table_local,
|
||||
) = make_local_attention_virtual_batches(
|
||||
effective_chunk_size,
|
||||
cu_seqlens_q_np,
|
||||
seq_lens_np,
|
||||
metadata.page_table,
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
|
||||
device
|
||||
),
|
||||
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
|
||||
local_block_table=block_table_local,
|
||||
local_max_query_len=seqlens_q_local_np.max(),
|
||||
local_max_seq_len=seqlens_k_local_np.max(),
|
||||
)
|
||||
metadata.local_attn_metadata = local_metadata
|
||||
|
||||
# Precompute strided indices
|
||||
if self.page_size > 1:
|
||||
self.strided_indices = torch.arange(
|
||||
@@ -211,6 +466,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.page_table = (
|
||||
metadata.page_table[:, self.strided_indices] // self.page_size
|
||||
)
|
||||
|
||||
self.forward_metadata = metadata
|
||||
|
||||
def forward_extend(
|
||||
@@ -254,7 +510,28 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
else (-1, -1)
|
||||
)
|
||||
|
||||
page_table = metadata.page_table
|
||||
# Check if we should use local attention
|
||||
use_local_attn = (
|
||||
self.attention_chunk_size is not None
|
||||
and metadata.local_attn_metadata is not None
|
||||
and (hasattr(layer, "use_irope") and layer.use_irope)
|
||||
)
|
||||
|
||||
# Get the appropriate page table based on whether we're using local attention
|
||||
if use_local_attn:
|
||||
local_metadata = metadata.local_attn_metadata
|
||||
page_table = local_metadata.local_block_table
|
||||
cu_seqlens_q = local_metadata.local_query_start_loc
|
||||
cache_seqlens = local_metadata.local_seqused_k
|
||||
max_seqlen_q = local_metadata.local_max_query_len
|
||||
max_seqlen_k = local_metadata.local_max_seq_len
|
||||
else:
|
||||
page_table = metadata.page_table
|
||||
cu_seqlens_q = metadata.cu_seqlens_q
|
||||
cache_seqlens = metadata.cache_seqlens_int32
|
||||
max_seqlen_q = metadata.max_seq_len_q
|
||||
max_seqlen_k = metadata.max_seq_len_k
|
||||
cu_seqlens_k = metadata.cu_seqlens_k
|
||||
|
||||
# Use Flash Attention for prefill
|
||||
if not self.use_mla:
|
||||
@@ -272,10 +549,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
page_table=page_table,
|
||||
cache_seqlens=metadata.cache_seqlens_int32,
|
||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||
max_seqlen_q=metadata.max_seq_len_q,
|
||||
cache_seqlens=cache_seqlens,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
@@ -307,10 +584,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
v_cache=c_kv_cache,
|
||||
qv=q_nope,
|
||||
page_table=page_table,
|
||||
cache_seqlens=metadata.cache_seqlens_int32,
|
||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||
max_seqlen_q=metadata.max_seq_len_q,
|
||||
cache_seqlens=cache_seqlens,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
softcap=layer.logit_cap,
|
||||
|
||||
@@ -23,9 +23,14 @@ def fused_moe_forward_native(
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
raise NotImplementedError
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -1079,6 +1079,7 @@ def inplace_fused_experts(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
@@ -1099,6 +1100,7 @@ def inplace_fused_experts(
|
||||
topk_ids,
|
||||
True,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
@@ -1120,6 +1122,7 @@ def inplace_fused_experts_fake(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
@@ -1150,6 +1153,7 @@ def outplace_fused_experts(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
@@ -1171,6 +1175,7 @@ def outplace_fused_experts(
|
||||
topk_ids,
|
||||
False,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
@@ -1193,6 +1198,7 @@ def outplace_fused_experts_fake(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
@@ -1225,6 +1231,7 @@ def fused_experts(
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
@@ -1247,6 +1254,7 @@ def fused_experts(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
@@ -1268,6 +1276,7 @@ def fused_experts(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
@@ -1291,6 +1300,7 @@ def fused_experts_impl(
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
@@ -1423,7 +1433,7 @@ def fused_experts_impl(
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
apply_router_weight_on_input,
|
||||
topk_ids.shape[1],
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
@@ -1456,7 +1466,7 @@ def fused_experts_impl(
|
||||
(
|
||||
intermediate_cache3
|
||||
if not no_combine and topk_ids.shape[1] != 1
|
||||
else out_hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||
else out_hidden_states[begin_chunk_idx:end_chunk_idx].unsqueeze(0)
|
||||
),
|
||||
a2_scale,
|
||||
w2_scale,
|
||||
@@ -1466,7 +1476,7 @@ def fused_experts_impl(
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
True,
|
||||
not apply_router_weight_on_input,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
|
||||
@@ -128,6 +128,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -143,6 +144,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
@@ -160,6 +162,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -200,6 +203,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
@@ -276,6 +280,7 @@ class FusedMoE(torch.nn.Module):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_presharded_weights: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
@@ -302,6 +307,7 @@ class FusedMoE(torch.nn.Module):
|
||||
self.custom_routing_function = custom_routing_function
|
||||
self.correction_bias = correction_bias
|
||||
self.activation = activation
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
self.use_presharded_weights = use_presharded_weights
|
||||
self.inplace = inplace
|
||||
self.no_combine = no_combine
|
||||
@@ -630,6 +636,7 @@ class FusedMoE(torch.nn.Module):
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
correction_bias=self.correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
|
||||
@@ -280,6 +280,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
|
||||
@@ -370,6 +370,7 @@ class BlockInt8MoEMethod:
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -398,6 +399,7 @@ class BlockInt8MoEMethod:
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_int8_w8a8=True,
|
||||
w1_scale=(layer.w13_weight_scale_inv),
|
||||
w2_scale=(layer.w2_weight_scale_inv),
|
||||
|
||||
@@ -905,6 +905,7 @@ class Fp8MoEMethod:
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -975,6 +976,7 @@ class Fp8MoEMethod:
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=(
|
||||
layer.w13_weight_scale_inv
|
||||
|
||||
@@ -344,6 +344,7 @@ class MoeWNA16Method:
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -374,6 +375,7 @@ class MoeWNA16Method:
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
w1_scale=layer.w13_scales,
|
||||
|
||||
@@ -230,6 +230,7 @@ class W8A8Int8MoEMethod:
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -257,6 +258,7 @@ class W8A8Int8MoEMethod:
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_int8_w8a8=True,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale),
|
||||
|
||||
@@ -35,6 +35,7 @@ class RadixAttention(nn.Module):
|
||||
sliding_window_size: int = -1,
|
||||
is_cross_attention: bool = False,
|
||||
prefix: str = "",
|
||||
use_irope: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_q_head_num = num_heads
|
||||
@@ -50,6 +51,7 @@ class RadixAttention(nn.Module):
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.k_scale = None
|
||||
self.v_scale = None
|
||||
self.use_irope = use_irope
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -733,6 +733,69 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
|
||||
return new_freqs
|
||||
|
||||
|
||||
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
super().__init__(
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
inv_freqs = super()._compute_inv_freq(base)
|
||||
inv_freqs = inv_freqs[: (self.rotary_dim // 2)]
|
||||
return inv_freqs
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
|
||||
# self.max_position_embeddings here is number of image patches
|
||||
# i.e. (image_size // patch_size) ** 2
|
||||
num_patches = self.max_position_embeddings
|
||||
img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
|
||||
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
|
||||
img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN
|
||||
num_patches_single_dim = int(math.sqrt(num_patches))
|
||||
frequencies_x = img_idx % num_patches_single_dim
|
||||
frequencies_y = img_idx // num_patches_single_dim
|
||||
freqs_x = (
|
||||
(frequencies_x + 1)[..., None] * inv_freq[None, None, :]
|
||||
).repeat_interleave(2, dim=-1)
|
||||
freqs_y = (
|
||||
(frequencies_y + 1)[..., None] * inv_freq[None, None, :]
|
||||
).repeat_interleave(2, dim=-1)
|
||||
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
|
||||
freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
|
||||
cache = torch.view_as_complex(
|
||||
torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
|
||||
)
|
||||
return cache
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
|
||||
query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
|
||||
key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
|
||||
broadcast_shape = [
|
||||
d if i == 1 or i == (query_.ndim - 1) else 1
|
||||
for i, d in enumerate(query_.shape)
|
||||
]
|
||||
freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
|
||||
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
|
||||
key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
|
||||
return query_out.type_as(query), key_out.type_as(key)
|
||||
|
||||
|
||||
class MRotaryEmbedding(RotaryEmbedding):
|
||||
"""Rotary Embedding with Multimodal Sections."""
|
||||
|
||||
|
||||
161
python/sglang/srt/managers/multimodal_processors/mllama4.py
Normal file
161
python/sglang/srt/managers/multimodal_processors/mllama4.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from typing import List, Mapping, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import Llama4Processor
|
||||
from transformers.image_utils import SizeDict
|
||||
from transformers.models.llama4.image_processing_llama4 import (
|
||||
find_supported_resolutions,
|
||||
get_best_fit,
|
||||
)
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
|
||||
from sglang.srt.utils import load_image
|
||||
|
||||
|
||||
class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
||||
models = [Llama4ForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.vision_config = hf_config.vision_config
|
||||
self.text_config = hf_config.text_config
|
||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||
image_token=_processor.image_token
|
||||
)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
max_req_input_len=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
|
||||
# Process images and text using the base processor's load_mm_data method
|
||||
processed_data = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=self.multimodal_tokens,
|
||||
max_req_input_len=max_req_input_len or 4096,
|
||||
image_data=image_data,
|
||||
return_text=True,
|
||||
)
|
||||
|
||||
# Process the images using the processor
|
||||
processor = Llama4Processor.from_pretrained(
|
||||
self.server_args.model_path, **kwargs
|
||||
)
|
||||
|
||||
# Process the prompt and images
|
||||
image_inputs = processor(
|
||||
text=processed_data.input_text,
|
||||
images=processed_data.images,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Handle image resolutions and aspect ratios
|
||||
if "pixel_values" in image_inputs:
|
||||
image_processor = processor.image_processor
|
||||
tokenizer = self._processor.tokenizer
|
||||
|
||||
# Calculate tile size and find supported resolutions
|
||||
tile_size = self.vision_config.image_size
|
||||
max_num_tiles = getattr(self.vision_config, "max_patches", 1)
|
||||
|
||||
possible_resolutions = find_supported_resolutions(
|
||||
max_num_chunks=max_num_tiles,
|
||||
patch_size=SizeDict(height=tile_size, width=tile_size),
|
||||
)
|
||||
|
||||
# Find best fit for each image
|
||||
best_fit_sizes = [
|
||||
get_best_fit(
|
||||
(image.size[1], image.size[0]), # (height, width)
|
||||
torch.tensor(possible_resolutions),
|
||||
resize_to_max_canvas=image_processor.resize_to_max_canvas,
|
||||
)
|
||||
for image in processed_data.images
|
||||
]
|
||||
|
||||
# Calculate aspect ratios and patches per image
|
||||
aspect_ratios = [
|
||||
(image_size[0] // tile_size, image_size[1] // tile_size)
|
||||
for image_size in best_fit_sizes
|
||||
]
|
||||
|
||||
patches_per_image = [
|
||||
1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
|
||||
]
|
||||
|
||||
# Add to image_inputs
|
||||
image_inputs["aspect_ratios"] = aspect_ratios
|
||||
image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
|
||||
|
||||
# Process embed_is_patch
|
||||
vocab = tokenizer.get_vocab()
|
||||
patch_id = vocab.get(processor.img_patch_token, -1)
|
||||
image_end_id = vocab.get(processor.end_of_img_token, -1)
|
||||
|
||||
if patch_id != -1 and image_end_id != -1:
|
||||
input_ids = image_inputs["input_ids"].view(-1)
|
||||
|
||||
# Remove BOS token if present
|
||||
if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
|
||||
input_ids = input_ids[1:]
|
||||
|
||||
# Find image end indices and split input_ids
|
||||
image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
|
||||
|
||||
if image_end_indices.size(0) > 0:
|
||||
# Split at image boundaries
|
||||
split_indices = (image_end_indices + 1)[:-1]
|
||||
split_input_ids = torch.tensor_split(input_ids, split_indices)
|
||||
split_input_ids = [x for x in split_input_ids if x.numel() > 0]
|
||||
|
||||
# Create embed_is_patch for each image
|
||||
embed_is_patch = []
|
||||
for per_image_input_ids in split_input_ids:
|
||||
embed_is_patch.append(per_image_input_ids == patch_id)
|
||||
|
||||
image_inputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
# Convert to the format expected by SGLang
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
|
||||
# Add metadata for image processing
|
||||
image_inputs["mm_items"] = [
|
||||
MultimodalDataItem(
|
||||
pixel_values=image_inputs["pixel_values"],
|
||||
modality=Modality.IMAGE,
|
||||
# Add additional metadata needed for Llama4 vision processing
|
||||
embed_is_patch=image_inputs.get("embed_is_patch", None),
|
||||
aspect_ratios=image_inputs.get("aspect_ratios", None),
|
||||
patches_per_image=image_inputs.get("patches_per_image", None),
|
||||
)
|
||||
]
|
||||
|
||||
return image_inputs
|
||||
|
||||
def get_patch_per_chunk(self):
|
||||
"""Calculate patches per chunk based on vision config"""
|
||||
image_size = self.vision_config.image_size
|
||||
patch_size = self.vision_config.patch_size
|
||||
|
||||
assert (
|
||||
image_size % patch_size == 0
|
||||
), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
|
||||
|
||||
ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
|
||||
return (image_size // patch_size) ** 2 // ds_ratio
|
||||
@@ -128,6 +128,7 @@ class ModelRunner:
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and not server_args.disable_mla
|
||||
)
|
||||
self.attention_chunk_size = model_config.attention_chunk_size
|
||||
|
||||
# Model-specific adjustment
|
||||
self.model_specific_adjustment()
|
||||
|
||||
@@ -63,6 +63,7 @@ class LlamaMLP(nn.Module):
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
reduce_results: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
@@ -78,6 +79,7 @@ class LlamaMLP(nn.Module):
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
reduce_results=reduce_results,
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -281,7 +283,7 @@ class LlamaModel(nn.Module):
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda idx, prefix: LlamaDecoderLayer(
|
||||
config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
|
||||
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
|
||||
),
|
||||
prefix="model.layers",
|
||||
)
|
||||
@@ -375,9 +377,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = LlamaModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
|
||||
# Llama 3.2 1B Instruct set tie_word_embeddings to True
|
||||
# Llama 3.1 8B Instruct set tie_word_embeddings to False
|
||||
if self.config.tie_word_embeddings:
|
||||
@@ -402,6 +402,14 @@ class LlamaForCausalLM(nn.Module):
|
||||
|
||||
self.capture_aux_hidden_states = False
|
||||
|
||||
def _init_model(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
return LlamaModel(config, quant_config=quant_config, prefix=prefix)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
|
||||
420
python/sglang/srt/models/llama4.py
Normal file
420
python/sglang/srt/models/llama4.py
Normal file
@@ -0,0 +1,420 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Llama4TextConfig
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
|
||||
from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Llama4MoE(nn.Module):
    """Llama 4 mixture-of-experts feed-forward block.

    Combines an always-active shared expert (a plain SwiGLU MLP) with top-k
    routed experts, sums the two paths, and performs a single tensor-parallel
    all-reduce over the combined output.
    """

    # NOTE(review): decorators apply bottom-up, so torch.compile here wraps the
    # staticmethod object rather than the raw function. `staticmethod` is only
    # callable on Python >= 3.10 — confirm this resolves to a plain callable on
    # all supported Python versions.
    @torch.compile(dynamic=True, backend=get_compiler_backend())
    @staticmethod
    def custom_routing_function(
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Llama 4 router: select top-k logits first, then sigmoid only the
        selected scores (no softmax over all experts; `renormalize` is
        accepted for interface compatibility but intentionally unused).

        Returns (scores, indices); indices are cast to int32 for the MoE kernel.
        The `_aK` suffix follows the source convention: [num_tokens(a), topk(K)].
        """
        router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
        # Sigmoid in float32 for stability, then cast back to the activation dtype.
        router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
            hidden_states.dtype
        )
        return (
            # view(-1).reshape(...) round-trips through a flat view; the values
            # and shape are unchanged.
            router_scores_aK.view(-1).reshape(router_scores_aK.shape),
            router_indices_aK.to(torch.int32),
        )

    def __init__(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.top_k = config.num_experts_per_tok

        intermediate_size_moe = config.intermediate_size
        # Router is replicated (not sharded): every rank computes full routing
        # logits over all experts. It is deliberately left unquantized.
        self.router = ReplicatedLinear(
            config.hidden_size,
            config.num_local_experts,
            bias=False,
            quant_config=None,
            prefix=add_prefix("router", prefix),
        )

        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            custom_routing_function=Llama4MoE.custom_routing_function,
            intermediate_size=intermediate_size_moe,
            # reduce_results=False: the all-reduce is deferred to forward() so
            # the shared-expert output can be folded in first (one reduce, not two).
            reduce_results=False,
            renormalize=False,
            quant_config=quant_config,
            apply_router_weight_on_input=True,
            prefix=add_prefix("experts", prefix),
        )

        self.shared_expert = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size_moe,
            hidden_act="silu",
            quant_config=quant_config,
            prefix=add_prefix("shared_expert", prefix),
            reduce_results=False,  # We need to do scatter before reduce
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply shared + routed experts and all-reduce the summed output."""
        # router_scores: [num_tokens, num_experts]
        router_logits, _ = self.router(hidden_states)
        shared_out = self.shared_expert(hidden_states)
        routed_out = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        # `_aD` convention: [num_tokens(a), hidden_size(D)].
        out_aD = routed_out + shared_out

        # Single all-reduce covering both paths (both were built with
        # reduce_results=False above).
        if self.tp_size > 1:
            out_aD = tensor_model_parallel_all_reduce(out_aD)

        return out_aD
|
||||
|
||||
|
||||
class Llama4Attention(nn.Module):
    """Llama 4 attention layer.

    Llama 4 interleaves RoPE and NoPE layers: every 4th layer (layer_id + 1
    divisible by 4) has no rotary embedding. QK-norm is applied only on RoPE
    layers (when enabled in the config), and inference-time attention
    temperature tuning is applied only on NoPE layers.
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        layer_id: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        # 1 for RoPE layers, 0 for every 4th (NoPE) layer.
        self.use_rope = int((layer_id + 1) % 4 != 0)
        # QK-norm only applies on RoPE layers.
        self.use_qk_norm = config.use_qk_norm and self.use_rope
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        # Config knobs for the NoPE-layer temperature tuning in forward().
        self.attn_temperature_tuning = config.attn_temperature_tuning
        self.floor_scale = config.floor_scale
        self.attn_scale = config.attn_scale
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # Query heads per KV head (GQA replication factor).
        self.n_rep = self.num_heads // self.num_kv_heads
        self.qk_norm = (
            RMSNorm(
                hidden_size=self.head_dim,
                eps=config.rms_norm_eps,
            )
            if self.use_qk_norm
            else None
        )
        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )
        # GGUF checkpoints store llama Q/K weights in the non-neox layout.
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type in ["llama", "llama4"]:
            is_neox_style = False

        # NoPE layers get no rotary embedding at all.
        self.rotary_emb = (
            get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=int(rope_theta),
                rope_scaling=rope_scaling if rope_scaling != "default" else None,
                is_neox_style=is_neox_style,
            )
            if self.use_rope
            else None
        )

        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
            # use_irope signals the backend that RoPE layers use (chunked)
            # local attention — NOTE(review): confirm against the attention
            # backend's interpretation of this flag.
            use_irope=self.use_rope,
        )

    def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Position-dependent query scale for temperature tuning.

        Grows logarithmically with position (stepped every `floor_scale`
        positions), so short contexts are nearly unaffected.
        """
        floor = torch.floor((positions + 1.0) / self.floor_scale)
        attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0

        # Unsqueeze for broadcasting against the per-token query vectors.
        return attn_scale.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Fused QKV projection -> (RoPE) -> (QK-norm) -> (temperature
        tuning on NoPE layers) -> attention -> output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        if self.rotary_emb is not None:
            q, k = self.rotary_emb(positions, q, k)

        if self.qk_norm is not None:
            # TODO: support float
            # RMSNorm is applied per head: flatten to [tokens*heads, head_dim],
            # normalize in bfloat16, then restore the packed per-token layout.
            q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
            k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
            q = self.qk_norm(q).to(q.dtype)
            k = self.qk_norm(k).to(k.dtype)
            q = q.reshape(-1, self.q_size)
            k = k.reshape(-1, self.kv_size)

        # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
        # the inference-time temperature tuning function is customized to not affect short context
        # while working at very long context
        # https://arxiv.org/abs/2501.19399
        if self.attn_temperature_tuning and not self.use_rope:
            attn_scale = self._get_attn_scale(positions)
            q = (q * attn_scale).to(q.dtype)

        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output
|
||||
|
||||
|
||||
class Llama4DecoderLayer(nn.Module):
    """One Llama 4 transformer layer.

    Pre-norm attention followed by a feed-forward block that is either a
    dense SwiGLU MLP or a Llama4MoE block — MoE layers are interleaved every
    `config.interleave_moe_layer_step` layers.
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size

        self.self_attn = Llama4Attention(
            config=config,
            layer_id=layer_id,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=config.rope_theta,
            rope_scaling=config.rope_scaling,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            bias_o_proj=False,
            prefix=add_prefix("self_attn", prefix),
        )

        # Every `interleave_moe_layer_step`-th layer is an MoE layer.
        if (layer_id + 1) % config.interleave_moe_layer_step == 0:
            self.feed_forward = Llama4MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("feed_forward", prefix),
            )
        else:
            self.feed_forward = LlamaMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size_mlp,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=add_prefix("feed_forward", prefix),
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer; returns the updated (hidden_states, residual) pair."""
        # Self Attention (fused add-and-norm once a residual stream exists).
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # First layer: seed the residual stream from the raw input.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Feed-forward (dense MLP or MoE, selected at construction time).
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
|
||||
|
||||
|
||||
class Llama4Model(nn.Module):
    """The Llama 4 text backbone: token embeddings, decoder layers, final norm.

    `layers_to_capture` may be populated externally with layer indices whose
    residual-stream input should be returned as auxiliary hidden states; it is
    empty by default, in which case forward() returns a single tensor.
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("embed_tokens", prefix),
        )
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Llama4DecoderLayer(
                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
            ),
            prefix="model.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Layer indices whose pre-layer hidden states are captured in forward().
        self.layers_to_capture = []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Embed `input_ids` (unless `input_embeds` is provided), run all
        decoder layers, and apply the final norm.

        Returns the final hidden states, or (hidden_states, aux_hidden_states)
        when `layers_to_capture` is non-empty.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        residual = None
        aux_hidden_states = []
        # enumerate() instead of the original index loop over range(len(...)).
        for i, layer in enumerate(self.layers):
            if i in self.layers_to_capture:
                # Capture the full residual-stream value entering this layer.
                # Before layer 0 the residual is still None (the stream is just
                # the embeddings) — the original unconditionally computed
                # `hidden_states + residual` and would have raised a TypeError.
                aux_hidden_states.append(
                    hidden_states + residual if residual is not None else hidden_states
                )
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) == 0:
            return hidden_states

        return hidden_states, aux_hidden_states
|
||||
|
||||
|
||||
class Llama4ForCausalLM(LlamaForCausalLM):
    """Llama 4 text-only causal LM.

    Reuses all of LlamaForCausalLM's scaffolding (lm head, logits processing,
    weight loading) and only overrides `_init_model` to swap in the Llama 4
    decoder stack.
    """

    # Mapping from fused parameter names to the per-projection checkpoint
    # names, consumed by the weight loader.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        # All shared setup happens in the Llama base class; it calls
        # `_init_model` below to build the backbone.
        super().__init__(config, quant_config, prefix)

    def _init_model(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Hook overridden from LlamaForCausalLM: build the Llama 4 backbone."""
        return Llama4Model(config, quant_config=quant_config, prefix=prefix)


EntryClass = [Llama4ForCausalLM]
|
||||
154
python/sglang/srt/models/mllama4.py
Normal file
154
python/sglang/srt/models/mllama4.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# TODO: add "Adapted from vllm/mllama4.py" attribution header
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Llama4Config
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class Llama4ForConditionalGeneration(nn.Module):
    """Entry-point wrapper for Llama 4 conditional-generation checkpoints.

    Currently text-only: the vision tower and multimodal projector weights are
    skipped during loading, and forward() delegates straight to the wrapped
    Llama4ForCausalLM language model.
    """

    # Fused-parameter mapping consumed by the weight loader.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    def __init__(
        self,
        config: Llama4Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        # Initialize the language model
        # (local import avoids a circular import between llama4.py and this file)
        from sglang.srt.models.llama4 import Llama4ForCausalLM

        self.language_model = Llama4ForCausalLM(
            config.text_config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )

        self.logits_processor = LogitsProcessor(config.text_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ) -> torch.Tensor:
        # Text-only path: multimodal kwargs are currently accepted and ignored.
        return self.language_model(input_ids, positions, forward_batch)

    def permute_qk_weight_for_rotary(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Re-permute Q/K projection weights from the checkpoint layout to the
        interleaved (neox-style) layout expected by this implementation.

        Returns the (possibly unchanged) name and weight.
        """

        def permute(w: torch.Tensor, n_heads: int):
            # Swap the per-head (pair, 2) axes: [h, d/2, 2, out] -> [h, 2, d/2, out].
            attn_in = self.language_model.config.head_dim * n_heads
            attn_out = self.language_model.config.hidden_size

            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        modules = name.split(".")

        # rotary embeds should be sliced
        # Only weight tensors of K/Q projections are permuted; biases and
        # other params pass through untouched.
        if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.language_model.config.num_key_value_heads
            )
        elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.language_model.config.num_attention_heads
            )

        return name, loaded_weight

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Vision-tower and projector weights are skipped (text-only support).
        Fused params (qkv, gate_up) are assembled from their per-shard
        checkpoint tensors; MoE expert weights are fanned out per expert.
        """

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
            (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())

        num_experts = self.config.text_config.num_local_experts

        for name, loaded_weight in weights:

            # Multimodal weights are not loaded (text-only for now).
            if name.startswith("vision_model") or name.startswith(
                "multi_modal_projector"
            ):
                continue

            name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # for/else: only reached when no stacked mapping matched.
                if ".experts" in name:
                    # MoE expert weights: gate/up are stored fused in the
                    # checkpoint and split into w1/w3 shards; down_proj maps
                    # to the w2 shard.
                    if ".gate_up_proj" in name:
                        name_list = [
                            name.replace(".experts.gate_up_proj", ".experts.w13_weight")
                        ] * 2
                        loaded_weight_list = loaded_weight.chunk(2, dim=-1)
                        shard_id_list = ["w1", "w3"]
                    else:
                        name_list = [
                            name.replace(".experts.down_proj", ".experts.w2_weight")
                        ]
                        shard_id_list = ["w2"]
                        loaded_weight_list = [loaded_weight]
                    # NOTE: deliberately shadows the outer `name`/`loaded_weight`;
                    # both are re-bound on the next outer iteration.
                    for name, loaded_weight, shard_id in zip(
                        name_list, loaded_weight_list, shard_id_list
                    ):
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        for expert_id in range(num_experts):
                            # Checkpoint stores [expert, in, out]; the .T matches
                            # the [out, in] layout FusedMoE expects — NOTE(review):
                            # confirm against FusedMoE.weight_loader's contract.
                            weight_loader(
                                param,
                                loaded_weight[expert_id].T,
                                name,
                                shard_id=shard_id,
                                expert_id=expert_id,
                            )
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = Llama4ForConditionalGeneration
|
||||
Reference in New Issue
Block a user