Add Llama4 support (#5092)
Co-authored-by: Cheng Wan <cwan39@gatech.edu> Co-authored-by: fzyzcjy <ch271828n@outlook.com> Co-authored-by: ispobock <ispobaoke@163.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
# Supported Models
|
# Supported Models
|
||||||
|
|
||||||
## Generative Models
|
## Generative Models
|
||||||
- Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2 / Llama 3.3
|
- Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2 / Llama 3.3 / Llama 4
|
||||||
- Mistral / Mixtral / Mistral NeMo / Mistral Small 3
|
- Mistral / Mixtral / Mistral NeMo / Mistral Small 3
|
||||||
- Gemma / Gemma 2 / Gemma3
|
- Gemma / Gemma 2 / Gemma3
|
||||||
- Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL / Qwen 2.5 VL / Olympic Coder
|
- Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL / Qwen 2.5 VL / Olympic Coder
|
||||||
|
|||||||
@@ -294,6 +294,30 @@ register_chat_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
|
||||||
|
register_chat_template(
|
||||||
|
ChatTemplate(
|
||||||
|
name="llama-4",
|
||||||
|
default_system_prompt=None,
|
||||||
|
role_prefix_and_suffix={
|
||||||
|
"system": (
|
||||||
|
"<|header_start|>system<|header_end|>\n\n",
|
||||||
|
"<|eot|>",
|
||||||
|
),
|
||||||
|
"user": (
|
||||||
|
"<|header_start|>user<|header_end|>\n\n",
|
||||||
|
"<|eot|>",
|
||||||
|
),
|
||||||
|
"assistant": (
|
||||||
|
"<|header_start|>assistant<|header_end|>\n\n",
|
||||||
|
"<|eot|>",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
stop_str=("<|eot|>",),
|
||||||
|
image_token="<|image|>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
|
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
|
||||||
register_chat_template(
|
register_chat_template(
|
||||||
ChatTemplate(
|
ChatTemplate(
|
||||||
|
|||||||
@@ -65,6 +65,9 @@ class ModelConfig:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||||
|
self.attention_chunk_size = getattr(
|
||||||
|
self.hf_text_config, "attention_chunk_size", None
|
||||||
|
)
|
||||||
|
|
||||||
# Check model type
|
# Check model type
|
||||||
self.is_generation = is_generation_model(
|
self.is_generation = is_generation_model(
|
||||||
@@ -467,6 +470,7 @@ multimodal_model_archs = [
|
|||||||
"Gemma3ForConditionalGeneration",
|
"Gemma3ForConditionalGeneration",
|
||||||
"Grok1VForCausalLM",
|
"Grok1VForCausalLM",
|
||||||
"Grok1AForCausalLM",
|
"Grok1AForCausalLM",
|
||||||
|
# TODO: add multimodal support for "Llama4ForConditionalGeneration",
|
||||||
"LlavaLlamaForCausalLM",
|
"LlavaLlamaForCausalLM",
|
||||||
"LlavaMistralForCausalLM",
|
"LlavaMistralForCausalLM",
|
||||||
"LlavaQwenForCausalLM",
|
"LlavaQwenForCausalLM",
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ class SeparatorStyle(IntEnum):
|
|||||||
ADD_NEW_LINE_SINGLE = auto()
|
ADD_NEW_LINE_SINGLE = auto()
|
||||||
LLAMA2 = auto()
|
LLAMA2 = auto()
|
||||||
LLAMA3 = auto()
|
LLAMA3 = auto()
|
||||||
|
LLAMA4 = auto()
|
||||||
CHATGLM = auto()
|
CHATGLM = auto()
|
||||||
CHATML = auto()
|
CHATML = auto()
|
||||||
CHATINTERN = auto()
|
CHATINTERN = auto()
|
||||||
@@ -156,19 +157,30 @@ class Conversation:
|
|||||||
else:
|
else:
|
||||||
ret += role + ":"
|
ret += role + ":"
|
||||||
return ret
|
return ret
|
||||||
elif self.sep_style == SeparatorStyle.LLAMA3:
|
elif self.sep_style == SeparatorStyle.LLAMA4:
|
||||||
ret = "<|begin_of_text|>"
|
# begin_of_text is added by default
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
ret += system_prompt
|
ret = system_prompt
|
||||||
else:
|
else:
|
||||||
ret += ""
|
ret = ""
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
ret += f"<|header_start|>{role}<|header_end|>\n\n"
|
||||||
|
ret += f"{message.strip()}<|eot|>"
|
||||||
|
else:
|
||||||
|
ret += f"<|header_start|>{role}<|header_end|>\n\n"
|
||||||
|
return ret
|
||||||
|
elif self.sep_style == SeparatorStyle.LLAMA3:
|
||||||
|
if self.system_message:
|
||||||
|
ret = system_prompt
|
||||||
|
else:
|
||||||
|
ret = ""
|
||||||
for i, (role, message) in enumerate(self.messages):
|
for i, (role, message) in enumerate(self.messages):
|
||||||
if message:
|
if message:
|
||||||
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
||||||
ret += f"{message.strip()}<|eot_id|>"
|
ret += f"{message.strip()}<|eot_id|>"
|
||||||
else:
|
else:
|
||||||
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
||||||
# print(ret)
|
|
||||||
return ret
|
return ret
|
||||||
elif self.sep_style == SeparatorStyle.LLAMA2:
|
elif self.sep_style == SeparatorStyle.LLAMA2:
|
||||||
seps = [self.sep, self.sep2]
|
seps = [self.sep, self.sep2]
|
||||||
@@ -561,6 +573,19 @@ register_conv_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name="llama-4",
|
||||||
|
system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
|
||||||
|
roles=("user", "assistant"),
|
||||||
|
sep_style=SeparatorStyle.LLAMA4,
|
||||||
|
sep="",
|
||||||
|
stop_str=["<|end_of_text|>", "<|eot|>", "<|eom|>"],
|
||||||
|
image_token="<|image|>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
register_conv_template(
|
register_conv_template(
|
||||||
Conversation(
|
Conversation(
|
||||||
name="chatml",
|
name="chatml",
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -45,6 +47,206 @@ class FlashAttentionMetadata:
|
|||||||
# Sequence lengths for the forward batch
|
# Sequence lengths for the forward batch
|
||||||
cache_seqlens_int32: torch.Tensor = None
|
cache_seqlens_int32: torch.Tensor = None
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LocalAttentionMetadata:
|
||||||
|
local_query_start_loc: torch.Tensor = None # cu_seqlens_q for local attention
|
||||||
|
local_seqused_k: torch.Tensor = None # sequence lengths for local attention
|
||||||
|
local_block_table: torch.Tensor = None # block table for local attention
|
||||||
|
local_max_query_len: int = 0 # max query length for local attention
|
||||||
|
local_max_seq_len: int = 0 # max sequence length for local attention
|
||||||
|
|
||||||
|
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from:
|
||||||
|
# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
|
||||||
|
#
|
||||||
|
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
|
||||||
|
# local attention blocks, where each block is passed to the attention kernel
|
||||||
|
# as an independent local ("virtual") batch item.
|
||||||
|
#
|
||||||
|
# For example, if are performing a chunked prefill a batch of 3 sequences:
|
||||||
|
# q_seqlens = [4, 10, 5]
|
||||||
|
# kv_seqlens = [6, 17, 9]
|
||||||
|
# Then normally for regular attention we would compute with an attention mask
|
||||||
|
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
|
||||||
|
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
|
||||||
|
# k_toks > 0 1 2 3 4 5
|
||||||
|
# q_toks v _____________
|
||||||
|
# 0 | 1 1 1
|
||||||
|
# 1 | 1 1 1 1
|
||||||
|
# 2 | 1 1 1 1 1
|
||||||
|
# 3 | 1 1 1 1 1 1
|
||||||
|
#
|
||||||
|
# for local attention (with attn_chunk_size = 4) we would compute with an
|
||||||
|
# attention mask like:
|
||||||
|
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
|
||||||
|
# k_toks > 0 1 2 3 4 5
|
||||||
|
# q_toks v _____________
|
||||||
|
# 0 | 1 1 1
|
||||||
|
# 1 | 1 1 1 1
|
||||||
|
# 2 | 1
|
||||||
|
# 3 | 1 1
|
||||||
|
#
|
||||||
|
# We can simulate this mask using standard flash-attention by breaking the
|
||||||
|
# sequences into local ("virtual") batches, where each local batch item is a
|
||||||
|
# local attention block, so in this case batch idx 0 would be broken up into:
|
||||||
|
#
|
||||||
|
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
|
||||||
|
# k_toks > 0 1 2 3
|
||||||
|
# q_toks v _____________
|
||||||
|
# 0 | 1 1 1
|
||||||
|
# 1 | 1 1 1 1
|
||||||
|
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
|
||||||
|
# k_toks > 4 5
|
||||||
|
# q_toks v _____________
|
||||||
|
# 2 | 1
|
||||||
|
# 3 | 1 1
|
||||||
|
#
|
||||||
|
# e.g. if we have:
|
||||||
|
# attn_chunk_size = 4
|
||||||
|
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
|
||||||
|
# Then this function would return:
|
||||||
|
# __b0__ ______b1______ __b2__ < orig batch indices
|
||||||
|
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
|
||||||
|
# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
|
||||||
|
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
|
||||||
|
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
|
||||||
|
def make_local_attention_virtual_batches(
|
||||||
|
attn_chunk_size: int,
|
||||||
|
query_start_loc_np: np.ndarray,
|
||||||
|
seq_lens_np: np.ndarray,
|
||||||
|
block_table: torch.Tensor,
|
||||||
|
page_size: int = 0,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
|
||||||
|
local attention blocks, where each block is passed to the attention kernel
|
||||||
|
as an independent local ("virtual") batch item.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attn_chunk_size: Size of local attention chunks
|
||||||
|
query_start_loc_np: Cumulative sum of query lengths (numpy array)
|
||||||
|
seq_lens_np: Sequence lengths (numpy array)
|
||||||
|
block_table: Block table for KV cache
|
||||||
|
page_size: Size of each page in the KV cache
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
seqlens_q_local: Query sequence lengths for local attention
|
||||||
|
cu_seqlens_q_local: Cumulative sum of query sequence lengths for local attention
|
||||||
|
seqlens_k_local: Key sequence lengths for local attention
|
||||||
|
block_table_local: Block table for local attention
|
||||||
|
"""
|
||||||
|
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
||||||
|
actual_batch_size = seq_lens_np.shape[0]
|
||||||
|
|
||||||
|
# Handle if we are starting in the middle of a local attention block,
|
||||||
|
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
|
||||||
|
# the number of tokens that are not in the first local attention block and
|
||||||
|
# then we can simply use a cdiv for the rest.
|
||||||
|
# For example if we have:
|
||||||
|
# attn_chunk_size = 4
|
||||||
|
# q_seqlens = [4, 10, 5]
|
||||||
|
# k_seqlens = [6, 17, 9]
|
||||||
|
# Then we would get:
|
||||||
|
# new_tokens_in_first_block = [2, 1, 4]
|
||||||
|
# local_blocks = [2, 4, 2]
|
||||||
|
q_tokens_in_first_block = np.minimum(
|
||||||
|
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
|
||||||
|
).astype(np.int32)
|
||||||
|
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
|
||||||
|
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
|
||||||
|
|
||||||
|
# Once we know the number of local blocks we can compute the request spans
|
||||||
|
# for each batch idx, we can figure out the number of "virtual" requests we
|
||||||
|
# have to make,
|
||||||
|
# For the above example we would get:
|
||||||
|
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
|
||||||
|
#
|
||||||
|
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
|
||||||
|
# (TODO: max a utility to share this code with _prepare_inputs)
|
||||||
|
# arange step 1. [2, 4, 2] -> [2, 6, 8]
|
||||||
|
cu_num_blocks = np.cumsum(local_blocks)
|
||||||
|
virtual_batches = cu_num_blocks[-1]
|
||||||
|
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
|
||||||
|
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
|
||||||
|
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
|
||||||
|
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
|
||||||
|
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
|
||||||
|
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
|
||||||
|
# Then we can compute the seqlens_q_local, handling the fact that the
|
||||||
|
# first and last blocks could be partial
|
||||||
|
seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
|
||||||
|
# set the first block since this may be a partial block
|
||||||
|
seqlens_q_local[arange == 0] = q_tokens_in_first_block
|
||||||
|
# set the remaining blocks
|
||||||
|
seqlens_q_local[arange > 0] = np.minimum(
|
||||||
|
seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
|
||||||
|
)[arange > 0]
|
||||||
|
|
||||||
|
# convert from q_seqlens to cu_seqlens_q
|
||||||
|
cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
|
||||||
|
|
||||||
|
# compute the seqlens_k_local,
|
||||||
|
# basically a full local attention block for all but the last block in each
|
||||||
|
# batch
|
||||||
|
# For our example this will be:
|
||||||
|
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
|
||||||
|
seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
|
||||||
|
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
|
||||||
|
|
||||||
|
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
|
||||||
|
rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
|
||||||
|
)
|
||||||
|
# For the example the local attention blocks start at:
|
||||||
|
# _b0_ _____b1_____ _b2_
|
||||||
|
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
|
||||||
|
block_starts = k_seqstarts_absolute // page_size
|
||||||
|
|
||||||
|
assert attn_chunk_size % page_size == 0, (
|
||||||
|
f"attn_chunk_size {attn_chunk_size} is not "
|
||||||
|
f"divisible by page_size {page_size}"
|
||||||
|
)
|
||||||
|
pages_per_local_batch = attn_chunk_size // page_size
|
||||||
|
|
||||||
|
# Create a block_table for the local attention blocks
|
||||||
|
# For out example if we have a block-table like (assuming page_size=2):
|
||||||
|
# block_table = [
|
||||||
|
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
|
||||||
|
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
|
||||||
|
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
|
||||||
|
# ]
|
||||||
|
# Then for the local batches we would want a block-table like
|
||||||
|
# block_table_local = [
|
||||||
|
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
|
||||||
|
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
|
||||||
|
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
|
||||||
|
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
|
||||||
|
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
|
||||||
|
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
|
||||||
|
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
|
||||||
|
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
|
||||||
|
# ]
|
||||||
|
block_indices = np.broadcast_to(
|
||||||
|
np.arange(pages_per_local_batch, dtype=np.int32),
|
||||||
|
(virtual_batches, pages_per_local_batch),
|
||||||
|
) + np.expand_dims(block_starts, axis=1)
|
||||||
|
block_indices = block_indices.flatten()
|
||||||
|
batch_indices = np.repeat(
|
||||||
|
np.arange(actual_batch_size, dtype=np.int32),
|
||||||
|
local_blocks * pages_per_local_batch,
|
||||||
|
)
|
||||||
|
block_table_local = block_table[batch_indices, block_indices].view(
|
||||||
|
virtual_batches, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
|
||||||
|
|
||||||
|
|
||||||
|
def cdiv(a: int, b: int) -> int:
|
||||||
|
"""Ceiling division."""
|
||||||
|
return -(a // -b)
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionBackend(AttentionBackend):
|
class FlashAttentionBackend(AttentionBackend):
|
||||||
"""FlashAttention backend implementation.
|
"""FlashAttention backend implementation.
|
||||||
@@ -100,6 +302,13 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self.step_id = step_id
|
self.step_id = step_id
|
||||||
self.speculative_num_steps = speculative_num_steps
|
self.speculative_num_steps = speculative_num_steps
|
||||||
|
|
||||||
|
# Local attention settings
|
||||||
|
self.attention_chunk_size = (
|
||||||
|
model_runner.attention_chunk_size
|
||||||
|
if hasattr(model_runner, "attention_chunk_size")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Initialize forward metadata to cache repetitive calculations."""
|
"""Initialize forward metadata to cache repetitive calculations."""
|
||||||
metadata = FlashAttentionMetadata()
|
metadata = FlashAttentionMetadata()
|
||||||
@@ -189,6 +398,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
]
|
]
|
||||||
|
|
||||||
# Precompute cumulative sequence lengths
|
# Precompute cumulative sequence lengths
|
||||||
if (
|
if (
|
||||||
any(forward_batch.extend_prefix_lens_cpu)
|
any(forward_batch.extend_prefix_lens_cpu)
|
||||||
@@ -203,6 +413,51 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
||||||
metadata.max_seq_len_q = metadata.max_seq_len_k
|
metadata.max_seq_len_q = metadata.max_seq_len_k
|
||||||
|
|
||||||
|
# Setup local attention if enabled
|
||||||
|
if (
|
||||||
|
self.attention_chunk_size is not None
|
||||||
|
and forward_batch.forward_mode == ForwardMode.EXTEND
|
||||||
|
):
|
||||||
|
# Convert tensors to numpy for local attention processing
|
||||||
|
cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
|
||||||
|
seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
|
||||||
|
|
||||||
|
# Adjust attention_chunk_size based on the actual sequence length
|
||||||
|
# to avoid index out of bounds errors
|
||||||
|
max_seq_len = seq_lens_np.max()
|
||||||
|
effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
|
||||||
|
# Make sure effective_chunk_size is divisible by page_size
|
||||||
|
effective_chunk_size = (
|
||||||
|
effective_chunk_size // self.page_size
|
||||||
|
) * self.page_size
|
||||||
|
if effective_chunk_size < self.page_size:
|
||||||
|
effective_chunk_size = self.page_size
|
||||||
|
|
||||||
|
# Create local attention metadata
|
||||||
|
(
|
||||||
|
seqlens_q_local_np,
|
||||||
|
cu_seqlens_q_local_np,
|
||||||
|
seqlens_k_local_np,
|
||||||
|
block_table_local,
|
||||||
|
) = make_local_attention_virtual_batches(
|
||||||
|
effective_chunk_size,
|
||||||
|
cu_seqlens_q_np,
|
||||||
|
seq_lens_np,
|
||||||
|
metadata.page_table,
|
||||||
|
self.page_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||||
|
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
|
||||||
|
device
|
||||||
|
),
|
||||||
|
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
|
||||||
|
local_block_table=block_table_local,
|
||||||
|
local_max_query_len=seqlens_q_local_np.max(),
|
||||||
|
local_max_seq_len=seqlens_k_local_np.max(),
|
||||||
|
)
|
||||||
|
metadata.local_attn_metadata = local_metadata
|
||||||
|
|
||||||
# Precompute strided indices
|
# Precompute strided indices
|
||||||
if self.page_size > 1:
|
if self.page_size > 1:
|
||||||
self.strided_indices = torch.arange(
|
self.strided_indices = torch.arange(
|
||||||
@@ -211,6 +466,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table = (
|
metadata.page_table = (
|
||||||
metadata.page_table[:, self.strided_indices] // self.page_size
|
metadata.page_table[:, self.strided_indices] // self.page_size
|
||||||
)
|
)
|
||||||
|
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
@@ -254,7 +510,28 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
else (-1, -1)
|
else (-1, -1)
|
||||||
)
|
)
|
||||||
|
|
||||||
page_table = metadata.page_table
|
# Check if we should use local attention
|
||||||
|
use_local_attn = (
|
||||||
|
self.attention_chunk_size is not None
|
||||||
|
and metadata.local_attn_metadata is not None
|
||||||
|
and (hasattr(layer, "use_irope") and layer.use_irope)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the appropriate page table based on whether we're using local attention
|
||||||
|
if use_local_attn:
|
||||||
|
local_metadata = metadata.local_attn_metadata
|
||||||
|
page_table = local_metadata.local_block_table
|
||||||
|
cu_seqlens_q = local_metadata.local_query_start_loc
|
||||||
|
cache_seqlens = local_metadata.local_seqused_k
|
||||||
|
max_seqlen_q = local_metadata.local_max_query_len
|
||||||
|
max_seqlen_k = local_metadata.local_max_seq_len
|
||||||
|
else:
|
||||||
|
page_table = metadata.page_table
|
||||||
|
cu_seqlens_q = metadata.cu_seqlens_q
|
||||||
|
cache_seqlens = metadata.cache_seqlens_int32
|
||||||
|
max_seqlen_q = metadata.max_seq_len_q
|
||||||
|
max_seqlen_k = metadata.max_seq_len_k
|
||||||
|
cu_seqlens_k = metadata.cu_seqlens_k
|
||||||
|
|
||||||
# Use Flash Attention for prefill
|
# Use Flash Attention for prefill
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
@@ -272,10 +549,10 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
v_cache=value_cache,
|
v_cache=value_cache,
|
||||||
page_table=page_table,
|
page_table=page_table,
|
||||||
cache_seqlens=metadata.cache_seqlens_int32,
|
cache_seqlens=cache_seqlens,
|
||||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
||||||
max_seqlen_q=metadata.max_seq_len_q,
|
max_seqlen_q=max_seqlen_q,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=True,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
@@ -307,10 +584,10 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
v_cache=c_kv_cache,
|
v_cache=c_kv_cache,
|
||||||
qv=q_nope,
|
qv=q_nope,
|
||||||
page_table=page_table,
|
page_table=page_table,
|
||||||
cache_seqlens=metadata.cache_seqlens_int32,
|
cache_seqlens=cache_seqlens,
|
||||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
||||||
max_seqlen_q=metadata.max_seq_len_q,
|
max_seqlen_q=max_seqlen_q,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=True,
|
||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
|
|||||||
@@ -23,9 +23,14 @@ def fused_moe_forward_native(
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
if apply_router_weight_on_input:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
|
|||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 256,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 256,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1079,6 +1079,7 @@ def inplace_fused_experts(
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
@@ -1099,6 +1100,7 @@ def inplace_fused_experts(
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
True,
|
True,
|
||||||
activation,
|
activation,
|
||||||
|
apply_router_weight_on_input,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
@@ -1120,6 +1122,7 @@ def inplace_fused_experts_fake(
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
@@ -1150,6 +1153,7 @@ def outplace_fused_experts(
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
@@ -1171,6 +1175,7 @@ def outplace_fused_experts(
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
False,
|
False,
|
||||||
activation,
|
activation,
|
||||||
|
apply_router_weight_on_input,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
@@ -1193,6 +1198,7 @@ def outplace_fused_experts_fake(
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
@@ -1225,6 +1231,7 @@ def fused_experts(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
@@ -1247,6 +1254,7 @@ def fused_experts(
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
activation,
|
activation,
|
||||||
|
apply_router_weight_on_input,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
@@ -1268,6 +1276,7 @@ def fused_experts(
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
activation,
|
activation,
|
||||||
|
apply_router_weight_on_input,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
@@ -1291,6 +1300,7 @@ def fused_experts_impl(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
@@ -1423,7 +1433,7 @@ def fused_experts_impl(
|
|||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_padded,
|
num_tokens_post_padded,
|
||||||
False,
|
apply_router_weight_on_input,
|
||||||
topk_ids.shape[1],
|
topk_ids.shape[1],
|
||||||
config,
|
config,
|
||||||
compute_type=compute_type,
|
compute_type=compute_type,
|
||||||
@@ -1456,7 +1466,7 @@ def fused_experts_impl(
|
|||||||
(
|
(
|
||||||
intermediate_cache3
|
intermediate_cache3
|
||||||
if not no_combine and topk_ids.shape[1] != 1
|
if not no_combine and topk_ids.shape[1] != 1
|
||||||
else out_hidden_states[begin_chunk_idx:end_chunk_idx]
|
else out_hidden_states[begin_chunk_idx:end_chunk_idx].unsqueeze(0)
|
||||||
),
|
),
|
||||||
a2_scale,
|
a2_scale,
|
||||||
w2_scale,
|
w2_scale,
|
||||||
@@ -1466,7 +1476,7 @@ def fused_experts_impl(
|
|||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_padded,
|
num_tokens_post_padded,
|
||||||
True,
|
not apply_router_weight_on_input,
|
||||||
1,
|
1,
|
||||||
config,
|
config,
|
||||||
compute_type=compute_type,
|
compute_type=compute_type,
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -143,6 +144,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
inplace=inplace,
|
inplace=inplace,
|
||||||
no_combine=no_combine,
|
no_combine=no_combine,
|
||||||
)
|
)
|
||||||
@@ -160,6 +162,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -200,6 +203,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=inplace and not no_combine,
|
inplace=inplace and not no_combine,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
no_combine=no_combine,
|
no_combine=no_combine,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -276,6 +280,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
use_presharded_weights: bool = False,
|
use_presharded_weights: bool = False,
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
@@ -302,6 +307,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.custom_routing_function = custom_routing_function
|
self.custom_routing_function = custom_routing_function
|
||||||
self.correction_bias = correction_bias
|
self.correction_bias = correction_bias
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
|
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||||
self.use_presharded_weights = use_presharded_weights
|
self.use_presharded_weights = use_presharded_weights
|
||||||
self.inplace = inplace
|
self.inplace = inplace
|
||||||
self.no_combine = no_combine
|
self.no_combine = no_combine
|
||||||
@@ -630,6 +636,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
custom_routing_function=self.custom_routing_function,
|
custom_routing_function=self.custom_routing_function,
|
||||||
correction_bias=self.correction_bias,
|
correction_bias=self.correction_bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
|
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.reduce_results and self.tp_size > 1:
|
if self.reduce_results and self.tp_size > 1:
|
||||||
|
|||||||
@@ -280,6 +280,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -370,6 +370,7 @@ class BlockInt8MoEMethod:
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -398,6 +399,7 @@ class BlockInt8MoEMethod:
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=inplace,
|
inplace=inplace,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_int8_w8a8=True,
|
use_int8_w8a8=True,
|
||||||
w1_scale=(layer.w13_weight_scale_inv),
|
w1_scale=(layer.w13_weight_scale_inv),
|
||||||
w2_scale=(layer.w2_weight_scale_inv),
|
w2_scale=(layer.w2_weight_scale_inv),
|
||||||
|
|||||||
@@ -905,6 +905,7 @@ class Fp8MoEMethod:
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -975,6 +976,7 @@ class Fp8MoEMethod:
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=inplace and not no_combine,
|
inplace=inplace and not no_combine,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
w1_scale=(
|
w1_scale=(
|
||||||
layer.w13_weight_scale_inv
|
layer.w13_weight_scale_inv
|
||||||
|
|||||||
@@ -344,6 +344,7 @@ class MoeWNA16Method:
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -374,6 +375,7 @@ class MoeWNA16Method:
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=inplace,
|
inplace=inplace,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_int4_w4a16=weight_bits == 4,
|
use_int4_w4a16=weight_bits == 4,
|
||||||
use_int8_w8a16=weight_bits == 8,
|
use_int8_w8a16=weight_bits == 8,
|
||||||
w1_scale=layer.w13_scales,
|
w1_scale=layer.w13_scales,
|
||||||
|
|||||||
@@ -230,6 +230,7 @@ class W8A8Int8MoEMethod:
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -257,6 +258,7 @@ class W8A8Int8MoEMethod:
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=inplace,
|
inplace=inplace,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_int8_w8a8=True,
|
use_int8_w8a8=True,
|
||||||
w1_scale=(layer.w13_weight_scale),
|
w1_scale=(layer.w13_weight_scale),
|
||||||
w2_scale=(layer.w2_weight_scale),
|
w2_scale=(layer.w2_weight_scale),
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class RadixAttention(nn.Module):
|
|||||||
sliding_window_size: int = -1,
|
sliding_window_size: int = -1,
|
||||||
is_cross_attention: bool = False,
|
is_cross_attention: bool = False,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
use_irope: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_q_head_num = num_heads
|
self.tp_q_head_num = num_heads
|
||||||
@@ -50,6 +51,7 @@ class RadixAttention(nn.Module):
|
|||||||
self.is_cross_attention = is_cross_attention
|
self.is_cross_attention = is_cross_attention
|
||||||
self.k_scale = None
|
self.k_scale = None
|
||||||
self.v_scale = None
|
self.v_scale = None
|
||||||
|
self.use_irope = use_irope
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -733,6 +733,69 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
|
|||||||
return new_freqs
|
return new_freqs
|
||||||
|
|
||||||
|
|
||||||
|
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
base: int,
|
||||||
|
is_neox_style: bool,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||||
|
inv_freqs = super()._compute_inv_freq(base)
|
||||||
|
inv_freqs = inv_freqs[: (self.rotary_dim // 2)]
|
||||||
|
return inv_freqs
|
||||||
|
|
||||||
|
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||||
|
inv_freq = self._compute_inv_freq(self.base)
|
||||||
|
|
||||||
|
# self.max_position_embeddings here is number of image patches
|
||||||
|
# i.e. (image_size // patch_size) ** 2
|
||||||
|
num_patches = self.max_position_embeddings
|
||||||
|
img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
|
||||||
|
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
|
||||||
|
img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN
|
||||||
|
num_patches_single_dim = int(math.sqrt(num_patches))
|
||||||
|
frequencies_x = img_idx % num_patches_single_dim
|
||||||
|
frequencies_y = img_idx // num_patches_single_dim
|
||||||
|
freqs_x = (
|
||||||
|
(frequencies_x + 1)[..., None] * inv_freq[None, None, :]
|
||||||
|
).repeat_interleave(2, dim=-1)
|
||||||
|
freqs_y = (
|
||||||
|
(frequencies_y + 1)[..., None] * inv_freq[None, None, :]
|
||||||
|
).repeat_interleave(2, dim=-1)
|
||||||
|
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
|
||||||
|
freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
|
||||||
|
cache = torch.view_as_complex(
|
||||||
|
torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
|
||||||
|
)
|
||||||
|
return cache
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
|
||||||
|
query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
|
||||||
|
key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
|
||||||
|
broadcast_shape = [
|
||||||
|
d if i == 1 or i == (query_.ndim - 1) else 1
|
||||||
|
for i, d in enumerate(query_.shape)
|
||||||
|
]
|
||||||
|
freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
|
||||||
|
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
|
||||||
|
key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
|
||||||
|
return query_out.type_as(query), key_out.type_as(key)
|
||||||
|
|
||||||
|
|
||||||
class MRotaryEmbedding(RotaryEmbedding):
|
class MRotaryEmbedding(RotaryEmbedding):
|
||||||
"""Rotary Embedding with Multimodal Sections."""
|
"""Rotary Embedding with Multimodal Sections."""
|
||||||
|
|
||||||
|
|||||||
161
python/sglang/srt/managers/multimodal_processors/mllama4.py
Normal file
161
python/sglang/srt/managers/multimodal_processors/mllama4.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
from typing import List, Mapping, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import Llama4Processor
|
||||||
|
from transformers.image_utils import SizeDict
|
||||||
|
from transformers.models.llama4.image_processing_llama4 import (
|
||||||
|
find_supported_resolutions,
|
||||||
|
get_best_fit,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
|
BaseMultimodalProcessor,
|
||||||
|
MultimodalSpecialTokens,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
|
from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
|
||||||
|
from sglang.srt.utils import load_image
|
||||||
|
|
||||||
|
|
||||||
|
class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
||||||
|
models = [Llama4ForConditionalGeneration]
|
||||||
|
|
||||||
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
|
super().__init__(hf_config, server_args, _processor)
|
||||||
|
self.vision_config = hf_config.vision_config
|
||||||
|
self.text_config = hf_config.text_config
|
||||||
|
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||||
|
image_token=_processor.image_token
|
||||||
|
)
|
||||||
|
|
||||||
|
async def process_mm_data_async(
|
||||||
|
self,
|
||||||
|
image_data: List[Union[str, bytes]],
|
||||||
|
input_text,
|
||||||
|
max_req_input_len=None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if not image_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(input_text, list):
|
||||||
|
assert len(input_text) and isinstance(input_text[0], int)
|
||||||
|
input_text = self._processor.tokenizer.decode(input_text)
|
||||||
|
|
||||||
|
# Process images and text using the base processor's load_mm_data method
|
||||||
|
processed_data = self.load_mm_data(
|
||||||
|
prompt=input_text,
|
||||||
|
multimodal_tokens=self.multimodal_tokens,
|
||||||
|
max_req_input_len=max_req_input_len or 4096,
|
||||||
|
image_data=image_data,
|
||||||
|
return_text=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process the images using the processor
|
||||||
|
processor = Llama4Processor.from_pretrained(
|
||||||
|
self.server_args.model_path, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process the prompt and images
|
||||||
|
image_inputs = processor(
|
||||||
|
text=processed_data.input_text,
|
||||||
|
images=processed_data.images,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle image resolutions and aspect ratios
|
||||||
|
if "pixel_values" in image_inputs:
|
||||||
|
image_processor = processor.image_processor
|
||||||
|
tokenizer = self._processor.tokenizer
|
||||||
|
|
||||||
|
# Calculate tile size and find supported resolutions
|
||||||
|
tile_size = self.vision_config.image_size
|
||||||
|
max_num_tiles = getattr(self.vision_config, "max_patches", 1)
|
||||||
|
|
||||||
|
possible_resolutions = find_supported_resolutions(
|
||||||
|
max_num_chunks=max_num_tiles,
|
||||||
|
patch_size=SizeDict(height=tile_size, width=tile_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find best fit for each image
|
||||||
|
best_fit_sizes = [
|
||||||
|
get_best_fit(
|
||||||
|
(image.size[1], image.size[0]), # (height, width)
|
||||||
|
torch.tensor(possible_resolutions),
|
||||||
|
resize_to_max_canvas=image_processor.resize_to_max_canvas,
|
||||||
|
)
|
||||||
|
for image in processed_data.images
|
||||||
|
]
|
||||||
|
|
||||||
|
# Calculate aspect ratios and patches per image
|
||||||
|
aspect_ratios = [
|
||||||
|
(image_size[0] // tile_size, image_size[1] // tile_size)
|
||||||
|
for image_size in best_fit_sizes
|
||||||
|
]
|
||||||
|
|
||||||
|
patches_per_image = [
|
||||||
|
1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add to image_inputs
|
||||||
|
image_inputs["aspect_ratios"] = aspect_ratios
|
||||||
|
image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
|
||||||
|
|
||||||
|
# Process embed_is_patch
|
||||||
|
vocab = tokenizer.get_vocab()
|
||||||
|
patch_id = vocab.get(processor.img_patch_token, -1)
|
||||||
|
image_end_id = vocab.get(processor.end_of_img_token, -1)
|
||||||
|
|
||||||
|
if patch_id != -1 and image_end_id != -1:
|
||||||
|
input_ids = image_inputs["input_ids"].view(-1)
|
||||||
|
|
||||||
|
# Remove BOS token if present
|
||||||
|
if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
|
||||||
|
input_ids = input_ids[1:]
|
||||||
|
|
||||||
|
# Find image end indices and split input_ids
|
||||||
|
image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
|
||||||
|
|
||||||
|
if image_end_indices.size(0) > 0:
|
||||||
|
# Split at image boundaries
|
||||||
|
split_indices = (image_end_indices + 1)[:-1]
|
||||||
|
split_input_ids = torch.tensor_split(input_ids, split_indices)
|
||||||
|
split_input_ids = [x for x in split_input_ids if x.numel() > 0]
|
||||||
|
|
||||||
|
# Create embed_is_patch for each image
|
||||||
|
embed_is_patch = []
|
||||||
|
for per_image_input_ids in split_input_ids:
|
||||||
|
embed_is_patch.append(per_image_input_ids == patch_id)
|
||||||
|
|
||||||
|
image_inputs["embed_is_patch"] = embed_is_patch
|
||||||
|
|
||||||
|
# Convert to the format expected by SGLang
|
||||||
|
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||||
|
|
||||||
|
# Add metadata for image processing
|
||||||
|
image_inputs["mm_items"] = [
|
||||||
|
MultimodalDataItem(
|
||||||
|
pixel_values=image_inputs["pixel_values"],
|
||||||
|
modality=Modality.IMAGE,
|
||||||
|
# Add additional metadata needed for Llama4 vision processing
|
||||||
|
embed_is_patch=image_inputs.get("embed_is_patch", None),
|
||||||
|
aspect_ratios=image_inputs.get("aspect_ratios", None),
|
||||||
|
patches_per_image=image_inputs.get("patches_per_image", None),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return image_inputs
|
||||||
|
|
||||||
|
def get_patch_per_chunk(self):
|
||||||
|
"""Calculate patches per chunk based on vision config"""
|
||||||
|
image_size = self.vision_config.image_size
|
||||||
|
patch_size = self.vision_config.patch_size
|
||||||
|
|
||||||
|
assert (
|
||||||
|
image_size % patch_size == 0
|
||||||
|
), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
|
||||||
|
|
||||||
|
ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
|
||||||
|
return (image_size // patch_size) ** 2 // ds_ratio
|
||||||
@@ -128,6 +128,7 @@ class ModelRunner:
|
|||||||
self.model_config.attention_arch == AttentionArch.MLA
|
self.model_config.attention_arch == AttentionArch.MLA
|
||||||
and not server_args.disable_mla
|
and not server_args.disable_mla
|
||||||
)
|
)
|
||||||
|
self.attention_chunk_size = model_config.attention_chunk_size
|
||||||
|
|
||||||
# Model-specific adjustment
|
# Model-specific adjustment
|
||||||
self.model_specific_adjustment()
|
self.model_specific_adjustment()
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ class LlamaMLP(nn.Module):
|
|||||||
hidden_act: str,
|
hidden_act: str,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
reduce_results: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
@@ -78,6 +79,7 @@ class LlamaMLP(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("down_proj", prefix),
|
prefix=add_prefix("down_proj", prefix),
|
||||||
|
reduce_results=reduce_results,
|
||||||
)
|
)
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -281,7 +283,7 @@ class LlamaModel(nn.Module):
|
|||||||
self.layers = make_layers(
|
self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda idx, prefix: LlamaDecoderLayer(
|
lambda idx, prefix: LlamaDecoderLayer(
|
||||||
config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
|
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
|
||||||
),
|
),
|
||||||
prefix="model.layers",
|
prefix="model.layers",
|
||||||
)
|
)
|
||||||
@@ -375,9 +377,7 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = LlamaModel(
|
self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
|
||||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
|
||||||
)
|
|
||||||
# Llama 3.2 1B Instruct set tie_word_embeddings to True
|
# Llama 3.2 1B Instruct set tie_word_embeddings to True
|
||||||
# Llama 3.1 8B Instruct set tie_word_embeddings to False
|
# Llama 3.1 8B Instruct set tie_word_embeddings to False
|
||||||
if self.config.tie_word_embeddings:
|
if self.config.tie_word_embeddings:
|
||||||
@@ -402,6 +402,14 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
|
|
||||||
self.capture_aux_hidden_states = False
|
self.capture_aux_hidden_states = False
|
||||||
|
|
||||||
|
def _init_model(
|
||||||
|
self,
|
||||||
|
config: LlamaConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
return LlamaModel(config, quant_config=quant_config, prefix=prefix)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
420
python/sglang/srt/models/llama4.py
Normal file
420
python/sglang/srt/models/llama4.py
Normal file
@@ -0,0 +1,420 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py
|
||||||
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import Llama4TextConfig
|
||||||
|
|
||||||
|
from sglang.srt.distributed import (
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
|
from sglang.srt.layers.linear import (
|
||||||
|
QKVParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
|
||||||
|
from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Llama4MoE(nn.Module):
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||||
|
@staticmethod
|
||||||
|
def custom_routing_function(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
renormalize: bool,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
|
||||||
|
router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
|
||||||
|
hidden_states.dtype
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
router_scores_aK.view(-1).reshape(router_scores_aK.shape),
|
||||||
|
router_indices_aK.to(torch.int32),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Llama4TextConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
|
|
||||||
|
intermediate_size_moe = config.intermediate_size
|
||||||
|
self.router = ReplicatedLinear(
|
||||||
|
config.hidden_size,
|
||||||
|
config.num_local_experts,
|
||||||
|
bias=False,
|
||||||
|
quant_config=None,
|
||||||
|
prefix=add_prefix("router", prefix),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.experts = FusedMoE(
|
||||||
|
num_experts=config.num_local_experts,
|
||||||
|
top_k=config.num_experts_per_tok,
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||||
|
intermediate_size=intermediate_size_moe,
|
||||||
|
reduce_results=False,
|
||||||
|
renormalize=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
apply_router_weight_on_input=True,
|
||||||
|
prefix=add_prefix("experts", prefix),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.shared_expert = LlamaMLP(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
intermediate_size=intermediate_size_moe,
|
||||||
|
hidden_act="silu",
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("shared_expert", prefix),
|
||||||
|
reduce_results=False, # We need to do scatter before reduce
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
# router_scores: [num_tokens, num_experts]
|
||||||
|
router_logits, _ = self.router(hidden_states)
|
||||||
|
shared_out = self.shared_expert(hidden_states)
|
||||||
|
routed_out = self.experts(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
)
|
||||||
|
out_aD = routed_out + shared_out
|
||||||
|
|
||||||
|
if self.tp_size > 1:
|
||||||
|
out_aD = tensor_model_parallel_all_reduce(out_aD)
|
||||||
|
|
||||||
|
return out_aD
|
||||||
|
|
||||||
|
|
||||||
|
class Llama4Attention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Llama4TextConfig,
|
||||||
|
layer_id: int,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
bias: bool = False,
|
||||||
|
bias_o_proj: bool = False,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.use_rope = int((layer_id + 1) % 4 != 0)
|
||||||
|
self.use_qk_norm = config.use_qk_norm and self.use_rope
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
if self.total_num_kv_heads >= tp_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
self.head_dim = config.head_dim
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.attn_temperature_tuning = config.attn_temperature_tuning
|
||||||
|
self.floor_scale = config.floor_scale
|
||||||
|
self.attn_scale = config.attn_scale
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.n_rep = self.num_heads // self.num_kv_heads
|
||||||
|
self.qk_norm = (
|
||||||
|
RMSNorm(
|
||||||
|
hidden_size=self.head_dim,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
if self.use_qk_norm
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
head_size=self.head_dim,
|
||||||
|
total_num_heads=self.total_num_heads,
|
||||||
|
total_num_kv_heads=self.total_num_kv_heads,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("qkv_proj", prefix),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
input_size=self.total_num_heads * self.head_dim,
|
||||||
|
output_size=hidden_size,
|
||||||
|
bias=bias_o_proj,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("o_proj", prefix),
|
||||||
|
)
|
||||||
|
is_neox_style = True
|
||||||
|
is_gguf = quant_config and quant_config.get_name() == "gguf"
|
||||||
|
if is_gguf and config.model_type in ["llama", "llama4"]:
|
||||||
|
is_neox_style = False
|
||||||
|
|
||||||
|
self.rotary_emb = (
|
||||||
|
get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
base=int(rope_theta),
|
||||||
|
rope_scaling=rope_scaling if rope_scaling != "default" else None,
|
||||||
|
is_neox_style=is_neox_style,
|
||||||
|
)
|
||||||
|
if self.use_rope
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attn = RadixAttention(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
layer_id=layer_id,
|
||||||
|
prefix=add_prefix("attn", prefix),
|
||||||
|
use_irope=self.use_rope,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
|
||||||
|
floor = torch.floor((positions + 1.0) / self.floor_scale)
|
||||||
|
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
|
||||||
|
|
||||||
|
return attn_scale.unsqueeze(-1)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
|
||||||
|
if self.rotary_emb is not None:
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
|
||||||
|
if self.qk_norm is not None:
|
||||||
|
# TODO: support float
|
||||||
|
q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
|
||||||
|
k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
|
||||||
|
q = self.qk_norm(q).to(q.dtype)
|
||||||
|
k = self.qk_norm(k).to(k.dtype)
|
||||||
|
q = q.reshape(-1, self.q_size)
|
||||||
|
k = k.reshape(-1, self.kv_size)
|
||||||
|
|
||||||
|
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
|
||||||
|
# the inference-time temperature tuning function is customized to not affect short context
|
||||||
|
# while working at very long context
|
||||||
|
# https://arxiv.org/abs/2501.19399
|
||||||
|
if self.attn_temperature_tuning and not self.use_rope:
|
||||||
|
attn_scale = self._get_attn_scale(positions)
|
||||||
|
q = (q * attn_scale).to(q.dtype)
|
||||||
|
|
||||||
|
attn_output = self.attn(q, k, v, forward_batch)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Llama4DecoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Llama4TextConfig,
|
||||||
|
layer_id: int = 0,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
rope_theta = config.rope_theta
|
||||||
|
rope_scaling = config.rope_scaling
|
||||||
|
max_position_embeddings = config.max_position_embeddings
|
||||||
|
|
||||||
|
self.self_attn = Llama4Attention(
|
||||||
|
config=config,
|
||||||
|
layer_id=layer_id,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
num_kv_heads=config.num_key_value_heads,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
quant_config=quant_config,
|
||||||
|
bias=False,
|
||||||
|
bias_o_proj=False,
|
||||||
|
prefix=add_prefix("self_attn", prefix),
|
||||||
|
)
|
||||||
|
is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
|
||||||
|
if is_moe_layer:
|
||||||
|
self.feed_forward = Llama4MoE(
|
||||||
|
config=config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("feed_forward", prefix),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.feed_forward = LlamaMLP(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size_mlp,
|
||||||
|
hidden_act="silu",
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("feed_forward", prefix),
|
||||||
|
)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(
|
||||||
|
config.hidden_size, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
residual: Optional[torch.Tensor],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Self Attention
|
||||||
|
if residual is None:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
|
hidden_states = self.feed_forward(hidden_states)
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
class Llama4Model(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Llama4TextConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("embed_tokens", prefix),
|
||||||
|
)
|
||||||
|
self.layers = make_layers(
|
||||||
|
config.num_hidden_layers,
|
||||||
|
lambda idx, prefix: Llama4DecoderLayer(
|
||||||
|
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
|
||||||
|
),
|
||||||
|
prefix="model.layers",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.layers_to_capture = []
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
input_embeds: torch.Tensor = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
if input_embeds is None:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
else:
|
||||||
|
hidden_states = input_embeds
|
||||||
|
residual = None
|
||||||
|
aux_hidden_states = []
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
if i in self.layers_to_capture:
|
||||||
|
aux_hidden_states.append(hidden_states + residual)
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
forward_batch,
|
||||||
|
residual,
|
||||||
|
)
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
|
if len(aux_hidden_states) == 0:
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
return hidden_states, aux_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Llama4ForCausalLM(LlamaForCausalLM):
|
||||||
|
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Llama4TextConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__(config, quant_config, prefix)
|
||||||
|
|
||||||
|
def _init_model(
|
||||||
|
self,
|
||||||
|
config: Llama4TextConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
return Llama4Model(config, quant_config=quant_config, prefix=prefix)
|
||||||
|
|
||||||
|
|
||||||
|
EntryClass = [Llama4ForCausalLM]
|
||||||
154
python/sglang/srt/models/mllama4.py
Normal file
154
python/sglang/srt/models/mllama4.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
# TODO: add "Adapted from vllm/mllama4.py" attribution note
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Optional, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import Llama4Config
|
||||||
|
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.quantization import QuantizationConfig
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
from sglang.srt.utils import add_prefix
|
||||||
|
|
||||||
|
|
||||||
|
class Llama4ForConditionalGeneration(nn.Module):
    """Multimodal Llama 4 wrapper.

    Currently only the text path is wired up: vision weights are skipped at
    load time and ``forward`` delegates directly to the language model.
    """

    # Fused-module mapping consumed by quantization/weight-loading utilities.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    def __init__(
        self,
        config: Llama4Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        # Initialize the language model
        # (imported here to avoid a circular import with llama4.py).
        from sglang.srt.models.llama4 import Llama4ForCausalLM

        self.language_model = Llama4ForCausalLM(
            config.text_config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )

        self.logits_processor = LogitsProcessor(config.text_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ) -> torch.Tensor:
        # Text-only for now: any multimodal kwargs are accepted but ignored.
        return self.language_model(input_ids, positions, forward_batch)

    def permute_qk_weight_for_rotary(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Reorder q/k projection rows from the checkpoint's interleaved
        rotary layout to the layout expected by this implementation.

        Returns the (possibly unchanged) name and the permuted weight.
        """

        def permute(w: torch.Tensor, n_heads: int):
            # Per head: view rows as (half_dim, 2) pairs, swap the pair axis,
            # then flatten back — converts interleaved (re, im) rotary pairs
            # to the split layout.
            attn_in = self.language_model.config.head_dim * n_heads
            attn_out = self.language_model.config.hidden_size

            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        modules = name.split(".")

        # rotary embeds should be sliced
        # (only q/k projection *weights* participate in rotary embedding).
        if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.language_model.config.num_key_value_heads
            )
        elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.language_model.config.num_attention_heads
            )

        return name, loaded_weight

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Handles three cases per tensor: fused (stacked) projections, MoE
        expert weights, and plain parameters. Vision-tower weights are
        skipped entirely.
        """

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
            (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())

        num_experts = self.config.text_config.num_local_experts

        for name, loaded_weight in weights:

            # Multimodal towers are not instantiated yet — skip their weights.
            if name.startswith("vision_model") or name.startswith(
                "multi_modal_projector"
            ):
                continue

            # Fix q/k rotary layout before any name remapping.
            name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched (for/else): expert or plain param.
                if ".experts" in name:
                    if ".gate_up_proj" in name:
                        # Fused gate+up expert weight: split in half along the
                        # last dim and load as the w1/w3 shards of w13_weight.
                        name_list = [
                            name.replace(".experts.gate_up_proj", ".experts.w13_weight")
                        ] * 2
                        loaded_weight_list = loaded_weight.chunk(2, dim=-1)
                        shard_id_list = ["w1", "w3"]
                    else:
                        # down_proj maps one-to-one onto w2_weight.
                        name_list = [
                            name.replace(".experts.down_proj", ".experts.w2_weight")
                        ]
                        shard_id_list = ["w2"]
                        loaded_weight_list = [loaded_weight]
                    # NOTE: intentionally shadows the outer name/loaded_weight.
                    for name, loaded_weight, shard_id in zip(
                        name_list, loaded_weight_list, shard_id_list
                    ):
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        for expert_id in range(num_experts):
                            # Checkpoint stores experts stacked on dim 0;
                            # transpose each expert slice for the fused layer.
                            weight_loader(
                                param,
                                loaded_weight[expert_id].T,
                                name,
                                shard_id=shard_id,
                                expert_id=expert_id,
                            )
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
# Model class exported from this module to the sglang model registry.
EntryClass = Llama4ForConditionalGeneration
|
||||||
Reference in New Issue
Block a user