[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)

Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
This commit is contained in:
chanzhennan
2026-02-28 11:15:50 +08:00
committed by GitHub
parent 153093d3b3
commit 82544aa0cc
17 changed files with 2668 additions and 1532 deletions

View File

@@ -3,95 +3,113 @@ from vllm import ModelRegistry
def register_model():
# from .demo_model import DemoModel # noqa: F401
from .qwen2_vl import Qwen2VLForConditionalGeneration #noqa: F401
from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration #noqa: F401
from .qwen3_moe import Qwen3MoeForCausalLM #noqa: F401
from .qwen3_vl import Qwen3VLForConditionalGeneration
from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
from .qwen3_omni_moe_thinker import Qwen3OmniMoeThinkerForConditionalGeneration
from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration # noqa: F401
from .qwen2_vl import Qwen2VLForConditionalGeneration # noqa: F401
from .qwen3_moe import Qwen3MoeForCausalLM # noqa: F401
from .qwen3_omni_moe_thinker import ( # noqa: F401
Qwen3OmniMoeThinkerForConditionalGeneration,
)
from .qwen3_vl import Qwen3VLForConditionalGeneration # noqa: F401
from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration # noqa: F401
# from .llama4 import Llama4ForCausalLM #noqa: F401
# from .mllama4 import Llama4ForConditionalGeneration #noqa: F401
# from .deepseek_v2 import KunlunDeepseekV2MoE
# ModelRegistry.register_model(
# "DemoModel",
# "vllm_kunlun.model_executor.models.demo_model:DemoModel")
ModelRegistry.register_model(
"Qwen2VLForConditionalGeneration",
"vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration")
"vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration",
)
ModelRegistry.register_model(
"Qwen2_5_VLForConditionalGeneration",
"vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration")
"vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration",
)
ModelRegistry.register_model(
"Qwen3ForCausalLM",
"vllm_kunlun.models.qwen3:Qwen3ForCausalLM")
"Qwen3ForCausalLM", "vllm_kunlun.models.qwen3:Qwen3ForCausalLM"
)
ModelRegistry.register_model(
"Qwen3MoeForCausalLM",
"vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM")
"Qwen3MoeForCausalLM", "vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM"
)
ModelRegistry.register_model(
"Qwen3NextForCausalLM",
"vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM")
"Qwen3NextForCausalLM", "vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM"
)
ModelRegistry.register_model(
"GptOssForCausalLM",
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
"Qwen3NextMTP", "vllm_kunlun.models.qwen3_next_mtp:Qwen3NextMTP"
)
ModelRegistry.register_model(
"InternLM2ForCausalLM",
"vllm_kunlun.models.internlm2:InternLM2ForCausalLM")
"GlmForCausalLM", "vllm_kunlun.models.glm:GlmForCausalLM"
)
ModelRegistry.register_model(
"InternVLChatModel",
"vllm_kunlun.models.internvl:InternVLChatModel")
"GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM"
)
ModelRegistry.register_model(
"InternLM2ForCausalLM", "vllm_kunlun.models.internlm2:InternLM2ForCausalLM"
)
ModelRegistry.register_model(
"InternVLChatModel", "vllm_kunlun.models.internvl:InternVLChatModel"
)
ModelRegistry.register_model(
"InternS1ForConditionalGeneration",
"vllm_kunlun.models.interns1:InternS1ForConditionalGeneration")
"vllm_kunlun.models.interns1:InternS1ForConditionalGeneration",
)
ModelRegistry.register_model(
"Qwen3VLForConditionalGeneration",
"vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration")
"vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration",
)
ModelRegistry.register_model(
"Qwen3VLMoeForConditionalGeneration",
"vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration")
"vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration",
)
ModelRegistry.register_model(
"Qwen3OmniMoeForConditionalGeneration",
"vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration")
"vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration",
)
ModelRegistry.register_model(
"SeedOssForCausalLM",
"vllm_kunlun.models.seed_oss:SeedOssForCausalLM")
"SeedOssForCausalLM", "vllm_kunlun.models.seed_oss:SeedOssForCausalLM"
)
ModelRegistry.register_model(
"MiMoV2FlashForCausalLM",
"vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM")
"vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM",
)
ModelRegistry.register_model(
"GptOssForCausalLM",
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
"GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM"
)
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM")
"DeepseekV3ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM"
)
ModelRegistry.register_model(
"DeepseekV32ForCausalLM",
"vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM")
ModelRegistry.register_model(
"DeepSeekMTPModel",
"vllm_kunlun.models.deepseek_mtp:DeepSeekMTP")
"DeepseekV32ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM"
)
ModelRegistry.register_model(
"GlmMoeDsaForCausalLM",
"vllm_kunlun.models.deepseek_v2:GlmMoeDsaForCausalLM")
"DeepSeekMTPModel", "vllm_kunlun.models.deepseek_mtp:DeepSeekMTP"
)
ModelRegistry.register_model(
"GlmMoeDsaForCausalLM", "vllm_kunlun.models.deepseek_v2:GlmMoeDsaForCausalLM"
)
def register_quant_method():
"""to do"""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next MTP model."""
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig
from .qwen3_next import Qwen3NextDecoderLayer, Qwen3NextRMSNorm
logger = init_logger(__name__)
KVCache = tuple[torch.Tensor, torch.Tensor]
@support_torch_compile
class Qwen3NextMultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
config: Qwen3NextConfig = model_config.hf_config
self.config = config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.fc = ColumnParallelLinear(
self.config.hidden_size * 2,
self.config.hidden_size,
gather_output=True,
bias=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fc",
)
self.layers = torch.nn.ModuleList(
Qwen3NextDecoderLayer(
vllm_config,
layer_type="full_attention",
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(self.num_mtp_layers)
)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_fc_norm_hidden = Qwen3NextRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.pre_fc_norm_embedding = Qwen3NextRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
hidden_states = self.pre_fc_norm_hidden(hidden_states)
hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
hidden_states = self.fc(hidden_states)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
current_step_idx = spec_step_idx % self.num_mtp_layers
hidden_states, residual = self.layers[current_step_idx](
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@support_torch_compile
class Qwen3NextMTP(nn.Module, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["up_proj", "down_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
assert (
not cache_config.enable_prefix_caching
), "Qwen3NextMTP currently does not support prefix caching"
self.quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.model = Qwen3NextMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
)
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size
)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
hidden_states = self.model(
input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
shared_weight_names = ["embed_tokens", "lm_head"]
def remap_weight_names(weights):
for name, weight in weights:
if name.startswith("mtp."):
name = name.replace("mtp.", "model.")
elif not any(key in name for key in shared_weight_names):
continue
yield name, weight
loader = AutoWeightsLoader(self)
return loader.load_weights(remap_weight_names(weights))

View File

@@ -16,33 +16,33 @@
# limitations under the License.
"""kunlun custom op entry"""
import torch_xmlir
from typing import Optional
import torch
import os
from typing import Optional, List, Dict
import vllm.envs as envs
import os
import ctypes
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
import kunlun_ops
logger.info(f"Load custom ops library success!")
logger.info("Load custom ops library success!")
except ImportError as e:
logger.warning("Import error msg: %s", e.msg)
_per_token_smooth_quant = True
def is_per_token_smooth_quant():
""" is per token smooth quant """
"""is per token smooth quant"""
return _per_token_smooth_quant
class KunlunOps:
"""KunlunOps"""
# Attention ops
@staticmethod
def paged_attention_v1(
@@ -67,9 +67,9 @@ class KunlunOps:
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
alibi_sqrt=False
):
""" PagedAttentionV1 """
alibi_sqrt=False,
):
"""PagedAttentionV1"""
# block_size = value_cache.shape[2]
kunlun_ops.paged_attention(
x=query,
@@ -81,7 +81,7 @@ class KunlunOps:
is_context=is_context,
is_causal=True,
out=output,
vo_head_dim=128
vo_head_dim=128,
)
@staticmethod
@@ -110,9 +110,9 @@ class KunlunOps:
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
alibi_sqrt=False
):
""" PagedAttentionV2 """
alibi_sqrt=False,
):
"""PagedAttentionV2"""
# block_size = value_cache.shape[2]
kunlun_ops.paged_attention(
x=query,
@@ -124,31 +124,28 @@ class KunlunOps:
is_context=is_context,
is_causal=True,
out=output,
vo_head_dim=128
vo_head_dim=128,
)
# Activation ops
@staticmethod
def silu_and_mul(out: torch.Tensor,
x: torch.Tensor):
""" silu and mul """
def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
"""silu and mul"""
kunlun_ops.silu_and_mul(
x,
axis=-1,
turn=True,
out=out,
)
)
# Activation ops
@staticmethod
def quick_gelu(out: torch.Tensor,
x: torch.Tensor):
""" quick gelu """
def quick_gelu(out: torch.Tensor, x: torch.Tensor):
"""quick gelu"""
kunlun_ops.quick_gelu(
x,
out=out,
)
)
# Layernorm
@staticmethod
@@ -159,9 +156,7 @@ class KunlunOps:
epsilon,
):
"""rms_norm"""
kunlun_ops.rmsnorm(
x, weight.to(torch.float32), epsilon, out=out
)
kunlun_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)
@staticmethod
def fused_add_rms_norm(
@@ -179,16 +174,11 @@ class KunlunOps:
residual.copy_(fused_input, non_blocking=True)
x.copy_(output)
# Rotary embedding
@staticmethod
def rotary_embedding(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style):
positions, query, key, head_size, cos_sin_cache, is_neox_style
):
"""
refactor RotaryEmbedding forward function
"""
@@ -196,62 +186,38 @@ class KunlunOps:
key_x = key.contiguous()
torch.ops._C.rotary_embedding(
positions,
query_x,
key_x,
head_size,
cos_sin_cache,
is_neox_style)
positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
)
return query_x, key_x
# Rotary embedding
@staticmethod
def mrotary_embedding(
positions,
mrope_section,
query,
key,
head_size,
cos_sin_cache,
is_neox_style):
positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
):
"""
refactor RotaryEmbedding forward function
"""
query_x = query.contiguous()
key_x = key.contiguous()
query_x_dim = query_x.dim()
assert is_neox_style
kunlun_ops.mrotary_embedding_neox(
positions,
query_x,
key_x,
head_size,
cos_sin_cache,
mrope_section)
positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
)
query.data = query_x
key.data = key_x
key.data = key_x
return query, key
@staticmethod
def swap_blocks(
src,
dst,
block_mapping):
""" swap_blocks """
kunlun_ops.swap_blocks(
src,
dst,
block_mapping
)
def swap_blocks(src, dst, block_mapping):
"""swap_blocks"""
kunlun_ops.swap_blocks(src, dst, block_mapping)
@staticmethod
def copy_blocks(
key_caches,
value_caches,
block_mapping):
""" copy_blocks """
def copy_blocks(key_caches, value_caches, block_mapping):
"""copy_blocks"""
for i in range(len(key_caches)):
key_caches[i] = key_caches[i].contiguous()
value_caches[i] = value_caches[i].contiguous()
@@ -269,16 +235,10 @@ class KunlunOps:
value_cache,
slot_mapping,
kv_cache_dtype,
):
""" reshape_and_cache """
):
"""reshape_and_cache"""
# slot_mapping_cast = slot_mapping.to(torch.int32)
kunlun_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping
)
kunlun_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
@staticmethod
def multi_query_kv_attention(
@@ -287,7 +247,7 @@ class KunlunOps:
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
**kargs
**kargs,
) -> torch.Tensor:
"""
query: shape = [num_prompt_tokens, num_heads, head_size]
@@ -297,16 +257,12 @@ class KunlunOps:
key = key.unsqueeze(0)
value = value.unsqueeze(0)
output = torch.empty_like(query)
alibi_slopes = kargs.get("alibi_slopes", None)
mask = kargs.get("mask", None)
is_causal = kargs.get("is_causal", True)
is_lvsl = kargs.get("is_lvsl", True)
B, T, Qh, Hd = query.shape
KVh = key.size(2)
if KVh != Qh:
repeat = Qh // KVh
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
value = value.repeat_interleave(repeat, dim=2)
kunlun_ops.attention(
q=query,
@@ -321,80 +277,90 @@ class KunlunOps:
return output
@staticmethod
def quant_fusedresidual_rmsnorm_op(x,
residual,
weight,
bias,
scale_to_int,
eps,
dyn_scale: bool,
type: int = 1):
def quant_fusedresidual_rmsnorm_op(
x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
"""Quantized fused residual layer normalization"""
out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant():
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
kunlun_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
out=out, out_scale=out_scale , residual_tensor=residual)
kunlun_ops.quant_fusedresidual_rmsnorm(
x,
residual,
weight,
bias,
eps,
out=out,
out_scale=out_scale,
residual_tensor=residual,
)
if residual is None:
return out, out_scale
return out, out_scale, residual
@staticmethod
def quant_rmsnorm_op(x,
weight,
bias,
scale_to_int,
eps,
dyn_scale : bool,
type: int = 1):
def quant_rmsnorm_op(
x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
"""Quantized RMSNorm"""
out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant():
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
kunlun_ops.quant_rmsnorm(x, weight, bias, eps,
out=out, out_scale=out_scale)
kunlun_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
return out, out_scale
@staticmethod
def smooth_quant_matmul_column_row_kernels(input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype):
def smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype,
):
"""smooth_quant_matmul_column_row_kernels"""
input_shape = input_tensor.shape
weight_shape = weight.shape
if input_tensor.dim() == 3:
input_tensor = input_tensor.reshape(-1, input_shape[-1])
out = torch.empty((input_shape[0] * input_shape[1],
weight_shape[0]),
dtype=torch.float16,
device=weight.device)
out = torch.empty(
(input_shape[0] * input_shape[1], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
output_bs_shape = [input_shape[0], input_shape[1]]
elif input_tensor.dim() == 2:
out = torch.empty((input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device)
out = torch.empty(
(input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
output_bs_shape = [-1]
kunlun_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
weight, smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out)
kunlun_ops.smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out,
)
out = out.view(*output_bs_shape, weight_shape[0])
@@ -404,6 +370,7 @@ class KunlunOps:
if torch.is_tensor(x):
return (type(x), x.device, x.dtype, x.shape, x.is_contiguous())
return (type(x), x)
@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
@@ -420,23 +387,24 @@ class KunlunOps:
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""fused_moe"""
global_num_experts, up_gate_size, _ = w1.shape
M, N = hidden_states.shape
hidden_dim = w2.shape[1]
normed_score = torch.empty(M,
moe_top_k,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
moe_top_k,
dtype=torch.int32,
device=hidden_states.device)
normed_score = torch.empty(
M, moe_top_k, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M, moe_top_k, dtype=torch.int32, device=hidden_states.device
)
num_blocks = 12
block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
num_blocks,
global_num_experts,
dtype=torch.int32,
device=hidden_states.device,
)
router_logits = router_logits.to(torch.float)
if scoring_func == "softmax":
@@ -445,24 +413,27 @@ class KunlunOps:
normed_score=normed_score,
topk_index=topk_ids,
block_statistic=None,
stable=True)
stable=False,
)
elif scoring_func == "sigmoid":
torch.ops._C.moe_sigmoid_group_topk_norm(
x=router_logits,
topk_index=topk_ids,
norm_score=normed_score,
block_static=block_statistic,
bias=e_score_correction_bias,
scale=1.0,
n_group=num_expert_group,
topk_group=topk_group,
)
x=router_logits,
topk_index=topk_ids,
norm_score=normed_score,
block_static=block_statistic,
bias=e_score_correction_bias,
scale=1.0,
n_group=num_expert_group,
topk_group=topk_group,
)
if w1_bias is not None or w2_bias is not None:
if w1_bias is not None or w2_bias is not None:
# Rignt now this branch is for gpt oss
# TODO (@xyDong23): faster here using moe_fc kernel
normed_score = normed_score.to(hidden_states.dtype)
out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device)
out = torch.zeros(
M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device
)
repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
topk_ids_flat = topk_ids.flatten()
for i in range(global_num_experts):
@@ -470,9 +441,13 @@ class KunlunOps:
selected_token = topk_ids_flat == experts_id
if selected_token.sum():
cur_token = repeat_x[selected_token]
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
dtype=cur_token.dtype, device=cur_token.device)
groupgemm1 = cur_token@ w1[i].T
up_gate = torch.empty(
selected_token.sum(),
up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
groupgemm1 = cur_token @ w1[i].T
# Add w13 bias
if w1_bias is not None:
groupgemm1 = groupgemm1 + w1_bias[i]
@@ -482,53 +457,129 @@ class KunlunOps:
if w2_bias is not None:
groupgemm2 = groupgemm2 + w2_bias[i]
out[selected_token] = groupgemm2
ouput = (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)).sum(dim=1).to(hidden_states.dtype)
ouput = (
(out.view(M, moe_top_k, N) * normed_score.unsqueeze(2))
.sum(dim=1)
.to(hidden_states.dtype)
)
return ouput
else:
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
# from vllm.forward_context import get_forward_context
# forward_context = get_forward_context()
# attn_metadata: AttentionMetadata = forward_context.attn_metadata
# prefix = "model.layers.0.linear_attn"
# if attn_metadata is not None:
# attn_metadata = attn_metadata[prefix]
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod)
# if attn_metadata is None or attn_metadata.num_prefills > 0 or :
if M * moe_top_k < 400:
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
torch.ops.xspeedgate_ops.moe_pre_small(
topk_ids, global_num_experts, False, False, hidden_states
)
)
experts_num_lod = torch.ops.xspeedgate_ops.moe_active_expert_balance(
topk_ids, global_num_experts, False
)
out = torch.ops.xspeedgate_ops.fused_moe(
hidden_states,
w1,
w2,
normed_score.to(hidden_states.dtype),
sorted_tokens_num_lod,
sorted_tokens_idx,
experts_num_lod,
)
return out.sum(1)
y = torch.empty(M,moe_top_k,
w1.shape[1],
if M * moe_top_k > 768:
moe_expand = torch.empty(
(M * moe_top_k, N),
dtype=hidden_states.dtype,
device=hidden_states.device)
device=hidden_states.device,
) # [M*top_k, N], float
expert_m = torch.zeros(
global_num_experts, dtype=torch.int32, device=hidden_states.device
) # [E]
sorted_tokens_num_lod = torch.zeros(
global_num_experts + 1,
dtype=torch.int32,
device=hidden_states.device,
) # [E+1]
sorted_tokens_idx = torch.zeros(
M * moe_top_k, dtype=torch.int32, device=hidden_states.device
)
torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod,
)
else:
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
torch.ops.xspeedgate_ops.moe_pre_small(
topk_ids,
global_num_experts,
index_have_neg=False,
sort_mode=True,
x=hidden_states,
)
)
y = torch.empty(
M,
moe_top_k,
w1.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
if M < 1024:
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
)
d = y.shape[-1] // 2
output_shape = y.shape[:-1] + (d,)
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out1 = out1.reshape(-1, out1.shape[-1])
else:
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
act="SWISH_GLU",
)
y = y[..., : y.shape[-1] // 2]
out1 = y.reshape(-1, y.shape[-1])
out = torch.empty(
M,
moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
d = y.shape[-1] // 2
output_shape = (y.shape[:-1] + (d, ))
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out = torch.empty(M,moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
out1 = out1.reshape(-1, out1.shape[-1])
torch.ops._C.moe_fc(
x=out1,
weight=w2,
@@ -538,8 +589,12 @@ class KunlunOps:
y=out,
)
dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
dequant_scale = torch.ones(
[M, moe_top_k], dtype=torch.float32, device=out.device
)
output = torch.empty(
[M, N], dtype=hidden_states.dtype, device=hidden_states.device
)
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
torch.ops._C.moe_post(
@@ -547,9 +602,9 @@ class KunlunOps:
moe_index=sorted_tokens_idx,
normed_scale=normed_score,
dequant_scale=dequant_scale,
y=output
y=output,
)
return output
@staticmethod
@@ -568,23 +623,23 @@ class KunlunOps:
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> torch.Tensor:
x = hidden_states
batch, hidden_size = x.shape
batch, hidden_size = x.shape
num_local_experts, up_gate_size, _ = w13_weight.shape
router_logits = x.to(linear_weights.dtype)@linear_weights.T
topk_weights = torch.empty(batch,
top_k,
dtype=router_logits.dtype,
device=router_logits.device)
topk_ids = torch.empty(batch,
top_k,
dtype=torch.int32,
device=router_logits.device)
block_static = torch.empty(0, dtype=torch.int32,device=router_logits.device)
torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids, block_static)
router_logits = x.to(linear_weights.dtype) @ linear_weights.T
topk_weights = torch.empty(
batch, top_k, dtype=router_logits.dtype, device=router_logits.device
)
topk_ids = torch.empty(
batch, top_k, dtype=torch.int32, device=router_logits.device
)
block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
torch.ops._C.moe_softmax_topk(
router_logits, topk_weights, topk_ids, block_static
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
@@ -598,11 +653,19 @@ class KunlunOps:
selected_token = topk_ids_flat == experts_id
if selected_token.sum():
cur_token = repeat_x[selected_token]
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
dtype=cur_token.dtype, device=cur_token.device)
torch.ops._C.silu_and_mul(up_gate, cur_token@ w13_weight[i].T)
up_gate = torch.empty(
selected_token.sum(),
up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
torch.ops._C.silu_and_mul(up_gate, cur_token @ w13_weight[i].T)
out[selected_token] = up_gate @ w2_weight[i].T
output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype)
output = (
(out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
.sum(dim=1)
.to(x.dtype)
)
return output
@@ -638,10 +701,11 @@ class KunlunOps:
prompt_lods_cpu: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
) -> torch.Tensor:
) -> torch.Tensor:
"""mla pa block"""
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
device=hidden_states.device)
output = torch.empty(
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
)
kunlun_ops.xft_multi_head_latent_page_attention_block(
hidden_states,
q_lora_rank,
@@ -679,7 +743,6 @@ class KunlunOps:
)
return output
def fused_gdn_gating(
A_log: torch.Tensor,
a: torch.Tensor,
@@ -695,25 +758,34 @@ class KunlunOps:
)
return output
def fused_recurrent_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
h0_source: torch.Tensor,
output_final_state: bool,
use_qk_l2norm_in_kernel: bool,
cu_seqlens: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
'''
Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 将做完sigmoid_gating和delta_rule_update融合在一起
1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)。
2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制
'''
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
h0_source: torch.Tensor,
output_final_state: bool,
use_qk_l2norm_in_kernel: bool,
cu_seqlens: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 将做完sigmoid_gating和delta_rule_update融合在一起
1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)
2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。
"""
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
cu_seqlens)
return (o, final_state)
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
q,
k,
v,
g,
beta,
scale,
h0_source,
output_final_state,
use_qk_l2norm_in_kernel,
cu_seqlens,
)
return (o, final_state)

View File

@@ -9,60 +9,196 @@
# ruff: noqa: E501
import warnings
from typing import Optional
import torch.nn.functional as F
import cocopod # noqa
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd
from .index import prepare_chunk_indices
import xspeedgate_ops
import cocopod
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
chunk_size=64
A = -A.transpose(1,2)
def torch_solve_tril(
A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
output_dtype: torch.dtype = torch.float,
):
chunk_size = 64
A = -A.transpose(1, 2)
sequence_length = A.shape[-2]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
# mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0)
# A = A.masked_fill(mask, 0)
for i in range(1, chunk_size):
row = A[..., i, :i].clone()
sub = A[..., :i, :i].clone()
A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[:,:,:sequence_length,:].transpose(1,2)
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[
:, :, :sequence_length, :
].transpose(1, 2)
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
A = chunk_scaled_dot_kkt_fwd(k=k,
beta=beta,
g_cumsum=g,
cu_seqlens=cu_seqlens,
output_dtype=q.dtype)
#kernel版
torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
chunk_indices = prepare_chunk_indices(
cu_seqlens, 64) if cu_seqlens is not None else None
def recompute_w_u_fwd_torch(
k: torch.Tensor, # [B, T, H, K]
v: torch.Tensor, # [B, T, H, V]
beta: torch.Tensor, # [B, T, H]
g: torch.Tensor, # [B, T, H]
A: torch.Tensor, # [B, H, T, T]
):
"""
最简单版本假设等长序列key和value头数相同
"""
chunk_size = 64
num_v_heads, num_k_heads = v.shape[2], k.shape[2]
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k, v, beta, g, A = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A)
]
batch_size, num_heads, sequence_length, k_head_dim = k.shape
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
k = F.pad(k, (0, 0, 0, pad_size))
v = F.pad(v, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
v_beta = v * beta.unsqueeze(-1)
k_beta = k * beta.unsqueeze(-1)
k, v, k_beta, v_beta = [
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
for x in (k, v, k_beta, v_beta)
]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
u = A @ v_beta
w = A @ (k_beta * g.exp().unsqueeze(-1))
w = (
w.reshape(w.shape[0], w.shape[1], -1, w.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
u = (
u.reshape(u.shape[0], u.shape[1], -1, u.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
return w, u
def split_by_value(tensor, chunk_size=64):
indices = tensor.tolist()
result = set(indices) # 使用集合避免重复
for i in range(len(indices) - 1):
start = indices[i]
end = indices[i + 1]
# 计算第一个对齐边界
# 我们要找的是 start + n*chunk_size其中n是使结果大于start的最小整数
first_boundary = start + chunk_size
# 在(start, end)范围内插入所有对齐边界
boundary = first_boundary
while boundary < end:
result.add(boundary)
boundary += chunk_size
return torch.tensor(sorted(result), dtype=tensor.dtype, device=tensor.device)
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
):
chunk_size = 64
chunk_indices = (
prepare_chunk_indices(cu_seqlens, 64) if cu_seqlens is not None else None
)
chunk_offsets = (
prepare_chunk_offsets(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
# !
# g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
g = torch.ops.xspeedgate_ops.chunk_local_cumsum(
g,
chunk_size=64,
reverse=False,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
head_first=False,
)
# !
# A = chunk_scaled_dot_kkt_fwd(k=k,
# beta=beta,
# g_cumsum=g,
# cu_seqlens=cu_seqlens,
# output_dtype=q.dtype)
A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd(
k, beta, g, cu_seqlens, chunk_indices, chunk_size
)
# torch版
# if get_tensor_model_parallel_rank() == 0:
# torch.save(A, "A_in")
# torch.save(cu_seqlens, "cu_seqlens")
# A2 = A.clone()
torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size)
# !
# torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
# if get_tensor_model_parallel_rank() == 0:
# err = torch.max(torch.abs(A - A2))
# print("err", err)
# if err > 1e-3:
# raise
# A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
# for i in range(len(cu_seqlens)-1):
# A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
# A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)
"""
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
for i in range(len(cu_seqlens)-1):
k_i = k[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
v_i = v[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
beta_i = beta[:, cu_seqlens[i]:cu_seqlens[i+1], :]
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
g_i = g[:, cu_seqlens[i]:cu_seqlens[i+1], :]
w_i, u_i = recompute_w_u_fwd_torch(
k=k_i,
v=v_i,
beta=beta_i,
A=A_i,
g=g_i,
)
w[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = w_i
u[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = u_i
"""
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
k=k,
v=v,
@@ -71,17 +207,63 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
g_cumsum=g,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_size=64
chunk_size=64,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
"""
w, u = recompute_w_u_fwd(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=output_final_state,
v=v,
beta=beta,
A=A,
g_cumsum=g,
cu_seqlens=cu_seqlens,
)
"""
# i
# import os
# if not os.path.exists("/qwen-next/in"):
# os.makedirs("/qwen-next/in")
# torch.save(k, "/qwen-next/in/k.pt")
# torch.save(u, "/qwen-next/in/u.pt")
# torch.save(w, "/qwen-next/in/w.pt")
# torch.save(g, "/qwen-next/in/g.pt")
# torch.save(initial_state, "/qwen-next/in/initial_state.pt")
# torch.save(cu_seqlens, "/qwen-next/in/cu_seqlens.pt")
# torch.save(chunk_indices, "/qwen-next/in/chunk_indices.pt")
# torch.save(chunk_offsets.to(torch.int32), "/qwen-next/in/chunk_offsets.pt")
# torch.save(chunk_size, "/qwen-next/in/chunk_size.pt")
# torch.save(output_final_state, "/qwen-next/in/output_final_state.pt")
h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h(
k,
u,
w,
g,
initial_state,
cu_seqlens,
chunk_indices,
chunk_offsets.to(torch.int32),
chunk_size,
output_final_state,
True,
)
# h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
# k=k,
# w=w,
# u=u,
# g=g,
# initial_state=initial_state,
# output_final_state=output_final_state,
# cu_seqlens=cu_seqlens,
# )
# if not os.path.exists("/qwen-next/out"):
# os.makedirs("/qwen-next/out")
# torch.save(h, "/qwen-next/out/h.pt")
# torch.save(v_new, "/qwen-next/out/v_new.pt")
# torch.save(final_state, "/qwen-next/out/final_state.pt")
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
q=q,
k=k,
@@ -91,8 +273,19 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_size=64
chunk_size=64,
)
"""
o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
)
"""
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None
elif SUPPRESS_LEVEL >= 3:
@@ -103,18 +296,20 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@input_guard
@torch.amp.custom_fwd(device_type='cuda')
def forward(ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False):
@torch.amp.custom_fwd(device_type="cuda")
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
k = l2norm_fwd(k)
@@ -136,17 +331,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False):
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
Args:
q (torch.Tensor):
@@ -211,42 +408,85 @@ def chunk_gated_delta_rule(q: torch.Tensor,
)
"""
assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert len(
beta.shape
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
assert (
q.dtype != torch.float32
), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert (
len(beta.shape) == 3
), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
if head_first:
raise DeprecationWarning(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead.",
stacklevel=2)
stacklevel=2,
)
q, k, v, beta, g = map(
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
(q, k, v, beta, g))
lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
)
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2)
stacklevel=2,
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.")
if initial_state is not None and initial_state.shape[0] != len(
cu_seqlens) - 1:
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1]**-0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
use_qk_l2norm_in_kernel)
scale = k.shape[-1] ** -0.5
if False:
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
g = g.contiguous()
beta = beta.contiguous()
initial_state = initial_state.contiguous()
o = torch.empty_like(v)
final_state = torch.empty_like(initial_state)
import kunlun_ops
kunlun_ops.gated_delta_rule(
q,
k,
v,
initial_state,
g,
beta,
final_state,
o,
scale,
cu_seqlens.cpu(),
cu_seqlens,
cu_seqlens.cpu(),
cu_seqlens,
use_qk_l2norm_in_kernel=True,
)
else:
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q,
k,
v,
g,
beta,
scale,
initial_state,
output_final_state,
cu_seqlens,
use_qk_l2norm_in_kernel,
)
if head_first:
o = rearrange(o, 'b t h ... -> b h t ...')
o = rearrange(o, "b t h ... -> b h t ...")
return o, final_state

View File

@@ -12,21 +12,21 @@
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
@triton.heuristics({
'USE_G': lambda args: args['g'] is not None,
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
# @triton.autotune(
# configs=[
# triton.Config({
@@ -40,7 +40,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
# ],
# key=['H', 'K', 'V', 'BT'],
# )
@triton.jit(do_not_specialize=['T'])
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
q,
k,
@@ -67,10 +67,12 @@ def chunk_fwd_kernel_o(
if IS_VARLEN:
i_tg = i_t
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
chunk_indices + i_t * 2 + 1
).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
cu_seqlens + i_n + 1
).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
else:
@@ -89,12 +91,15 @@ def chunk_fwd_kernel_o(
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
(BT, BK), (1, 0))
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT),
(BK, BT), (0, 1))
p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV),
(BK, BV), (1, 0))
p_q = tl.make_block_ptr(
q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_k = tl.make_block_ptr(
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
)
p_h = tl.make_block_ptr(
h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
)
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT]
@@ -109,8 +114,8 @@ def chunk_fwd_kernel_o(
if USE_G:
g += bos * H + i_h
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_o = b_o * tl.exp(b_g)[:, None]
b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :])
@@ -120,10 +125,12 @@ def chunk_fwd_kernel_o(
# b_A = tl.where(m_A, b_A, 0)
b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0)
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
(BT, BV), (1, 0))
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
(BT, BV), (1, 0))
p_v = tl.make_block_ptr(
v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_o = tl.make_block_ptr(
o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
# to fix mma -> mma layout conversion
@@ -133,48 +140,29 @@ def chunk_fwd_kernel_o(
def chunk_fwd_o(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None, # cumsum of log decay
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None, # cumsum of log decay
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
) -> torch.Tensor:
_, T, _, _, _ = *q.shape, v.shape[-1]
if FLA_GDN_FIX_BT:
BT = 64
else:
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(
cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
o = torch.empty_like(v)
def grid(meta):
return (triton.cdiv(V, meta['BV']), NT, B * H)
chunk_fwd_kernel_o[grid](
q,
k,
v,
h,
g,
o,
cu_seqlens,
chunk_indices,
scale,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=64,
BV=32
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
q, k, v, h, g, scale, cu_seqlens, chunk_indices, chunk_size
)
return o

View File

@@ -9,28 +9,28 @@
# ruff: noqa: E501
from typing import Optional
import torch
import kunlun_ops
import torch
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
def forward(ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False):
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2(
q.contiguous(),
k.contiguous(),
@@ -44,7 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
h0_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
is_h0_transposed=True
is_h0_transposed=True,
)
return o, final_state
@@ -130,9 +130,10 @@ def fused_recurrent_gated_delta_rule(
if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.")
f"Please flatten variable-length inputs before processing."
)
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
else:
assert scale > 0, "scale must be positive"
if beta is None:

View File

@@ -10,22 +10,21 @@
import os
from typing import Optional
import kunlun_ops
import torch
from vllm.triton_utils import tl, triton
import kunlun_ops
BT_LIST = [8, 16, 32, 64, 128]
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
@triton.autotune(configs=[
triton.Config({}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16, 32]
],
key=['D'])
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
],
key=["D"],
)
@triton.jit
def l2norm_fwd_kernel1(
x,
@@ -49,11 +48,14 @@ def l2norm_fwd_kernel1(
tl.store(y + cols, b_y, mask=mask)
@triton.autotune(configs=[
triton.Config({'BT': BT}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
],
key=['D'])
@triton.autotune(
configs=[
triton.Config({"BT": BT}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16]
for BT in BT_LIST
],
key=["D"],
)
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
x,
@@ -87,67 +89,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
def l2norm_fwd_triton(x: torch.Tensor,
eps: float = 1e-6,
output_dtype: Optional[torch.dtype] = None):
x_shape_og = x.shape
x = x.view(-1, x.shape[-1])
# allocate output
if output_dtype is None:
y = torch.empty_like(x)
else:
y = torch.empty_like(x, dtype=output_dtype)
assert y.stride(-1) == 1
T, D = x.shape[0], x.shape[-1]
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
if D > BD:
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
if not USE_DEFAULT_FLA_NORM:
MBLOCK = 32
# M, N = x.shape
l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
x,
y,
eps,
T,
D,
MBLOCK,
)
else:
if D <= 512:
NB = triton.cdiv(T, 2048)
def grid(meta):
return (triton.cdiv(T, meta['BT']), )
l2norm_fwd_kernel[grid](
x,
y,
eps,
NB=NB,
T=T,
D=D,
BD=BD,
)
else:
l2norm_fwd_kernel1[(T, )](
x,
y,
eps=eps,
D=D,
BD=BD,
)
return y.view(x_shape_og)
def l2norm_fwd(x: torch.Tensor,
eps: float = 1e-6,
output_dtype: Optional[torch.dtype] = None):
def l2norm_fwd(
x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
):
out = torch.empty_like(x)
kunlun_ops.l2norm(x, out, eps)
kunlun_ops.l2norm(x, out, eps)
return out

View File

@@ -19,20 +19,21 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vllm.triton_utils import tl, triton
from .utils import input_guard
def rms_norm_ref(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
upcast=True):
def rms_norm_ref(
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
upcast=True,
):
dtype = x.dtype
weight = weight.float()
bias = bias.float() if bias is not None else None
@@ -43,12 +44,10 @@ def rms_norm_ref(x,
x = x * F.silu(z)
if group_size is None:
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
weight)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
else:
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
eps)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
if bias is not None:
out = out + bias
@@ -57,10 +56,12 @@ def rms_norm_ref(x,
return out.to(dtype)
@triton.heuristics({
"HAS_BIAS": lambda args: args["B"] is not None,
"HAS_Z": lambda args: args["Z"] is not None,
})
@triton.heuristics(
{
"HAS_BIAS": lambda args: args["B"] is not None,
"HAS_Z": lambda args: args["Z"] is not None,
}
)
@triton.jit
def layer_norm_fwd_kernel(
X, # pointer to the input
@@ -97,17 +98,17 @@ def layer_norm_fwd_kernel(
B += group * N
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
x *= z * tl.sigmoid(z)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.)
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
@@ -149,46 +150,50 @@ def layer_norm_fwd(
# weight = weight.reshape(N)
# print("weight",weight.shape)
# print("x",x.shape)
assert weight.shape == (N, )
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N, )
assert bias.shape == (N,)
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
device=x.device) if not is_rms_norm else None
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
mean = (
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
if not is_rms_norm
else None
)
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups)
layer_norm_fwd_kernel[grid](x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps)
layer_norm_fwd_kernel[grid](
x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
)
return out, mean, rstd
@@ -196,17 +201,18 @@ class LayerNormFn(torch.autograd.Function):
@input_guard
@staticmethod
def forward(ctx,
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
def forward(
ctx,
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og = x.shape
# reshape input data into 2D tensor
@@ -223,16 +229,15 @@ class LayerNormFn(torch.autograd.Function):
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = layer_norm_fwd(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
# y, mean, rstd = torch.ops.xspeedgate_ops.rms_norm_gated_fwd(x, weight, bias, eps, z, group_size, norm_before_gate, is_rms_norm)
y = torch.empty_like(x)
mean, rstd = None, None
import kunlun_ops
kunlun_ops.rms_norm_gated(
x, y, z, weight, eps, group_size, norm_before_gate, is_rms_norm
)
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
@@ -242,27 +247,27 @@ class LayerNormFn(torch.autograd.Function):
return y.reshape(x_shape_og)
def layernorm_fn(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False):
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
norm_before_gate, is_rms_norm)
def layernorm_fn(
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
)
def rmsnorm_fn(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True):
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
norm_before_gate, True)
def rmsnorm_fn(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, True
)
class LayerNormGated(nn.Module):
@@ -294,15 +299,16 @@ class LayerNormGated(nn.Module):
torch.nn.init.zeros_(self.bias)
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return layernorm_fn(x,
self.weight,
self.bias,
z=z,
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate)
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return layernorm_fn(
x,
self.weight,
self.bias,
z=z,
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate,
)
class RMSNormGated(nn.Module):
@@ -332,12 +338,13 @@ class RMSNormGated(nn.Module):
torch.nn.init.ones_(self.weight)
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return rmsnorm_fn(x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate)
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return rmsnorm_fn(
x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
)

View File

@@ -11,7 +11,6 @@
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
@@ -28,6 +27,7 @@ RESOLUTION = {
torch.complex64: 1.3e-6,
}
def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
assert res.dtype == dtype
ref = ref.to(dtype)
@@ -35,6 +35,7 @@ def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
rtol = RESOLUTION[dtype]
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
# configs=[
@@ -80,7 +81,6 @@ def recompute_u_fwd_kernel(
p_beta = tl.make_block_ptr(
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
@@ -110,7 +110,6 @@ def recompute_u_fwd_kernel(
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
# configs=[
@@ -195,53 +194,12 @@ def recompute_w_u_fwd(
A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor],
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
BT = A.shape[-1]
chunk_indices = prepare_chunk_indices(
cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 64
BV = 64
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
recompute_u_fwd_kernel[(NT, B * H)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g_cumsum,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
recompute_w_fwd_kernel[(NT, B * H)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g_cumsum,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
k, v, beta, g_cumsum, A, cu_seqlens, chunk_indices, chunk_size=BT
)
return w, u
return w, u

View File

@@ -15,51 +15,52 @@
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
from vllm.model_executor.layers import layernorm
from typing import Optional, Union
import kunlun_ops
import torch
from vllm.model_executor.layers import layernorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm
def vllm_kunlun_forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""forward_cuda"""
if x.is_contiguous() == False:
# kunlun does not support uncontiguous input and they do not think it is a bug
# so we must make it contiguous() manually
x = x.contiguous()
if self.variance_size_override is not None:
return self.forward_native(x, residual)
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""forward_cuda"""
if not x.is_contiguous():
# kunlun does not support uncontiguous input and they do not think it is a bug
# so we must make it contiguous() manually
x = x.contiguous()
if self.variance_size_override is not None:
return self.forward_native(x, residual)
if residual is not None:
# residual_output = torch.empty_like(residual)
torch.ops._C.add_rmsnorm(
x,
residual,
residual_output=residual,
weight=self.weight.data,
eps=self.variance_epsilon,
output=x
)
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
if residual is not None:
# residual_output = torch.empty_like(residual)
torch.ops._C.add_rmsnorm(
x,
self.weight.data,
out,
self.variance_epsilon,
residual,
residual_output=residual,
weight=self.weight.data,
eps=self.variance_epsilon,
output=x,
)
return out
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
x,
self.weight.data,
out,
self.variance_epsilon,
)
return out
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
@staticmethod
def forward_xpu(
@@ -68,30 +69,42 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
x: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if x.is_contiguous() == False:
if not x.is_contiguous():
# kunlun does not support uncontiguous input and they do not think it is a bug
# so we must make it contiguous() manually
x = x.contiguous()
if x.dim() == 3:
x_shape = x.shape
x = x.view(-1, x.size(-1))
if residual is not None:
torch.ops._C.add_rmsnorm(
out = torch.empty_like(x)
out_residual = torch.empty_like(residual)
torch.ops._C.gemma_add_rmsnorm(
x,
residual,
residual_output=residual,
weight=weight+1,
residual_output=out_residual,
weight=weight,
eps=variance_epsilon,
output=x
output=out,
)
else:
out = torch.empty_like(x)
torch.ops._C.gemma_rmsnorm(
x,
weight,
out,
variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
x,
weight+1,
out,
variance_epsilon,
)
return out
if x.dim() == 3:
x = x.view(x_shape)
if out is not None:
out = out.view(x_shape)
if residual is not None:
return out, out_residual
else:
return out
def forward_cuda(
self,
@@ -99,16 +112,17 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if torch.compiler.is_compiling():
self.forward_static = self.forward_xpu # only use in cudagraph
self.forward_static = self.forward_xpu # only use in cudagraph
return self.forward_native(x, residual)
if not getattr(self, "_is_compiled", False):
self.forward_static = torch.compile( # type: ignore
self.forward_static, backend="aot_eager")
self.forward_static, backend="aot_eager"
)
self._is_compiled = True
return self.forward_native(x, residual)
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,390 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends import gdn_attn
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class GDNAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
return GDNAttentionMetadataBuilder
@dataclass
class GDNAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_spec_decodes: int
num_spec_decode_tokens: int
num_actual_tokens: int
has_initial_state: Optional[torch.Tensor] = None
has_initial_state_cpu: Optional[torch.Tensor] = None
spec_query_start_loc: Optional[torch.Tensor] = (
None # shape: [num_spec_decodes + 1,]
)
non_spec_query_start_loc: Optional[torch.Tensor] = (
None # shape: [batch - num_spec_decodes + 1,]
)
spec_state_indices_tensor: Optional[torch.Tensor] = None # shape: [batch, num_spec]
non_spec_state_indices_tensor: Optional[torch.Tensor] = (
None # shape: [batch - num_spec_decodes,]
)
non_spec_state_indices_tensor_cpu: Optional[torch.Tensor] = None
spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,]
spec_token_masks: Optional[torch.Tensor] = (
None # shape: [num_prefill_tokens + num_decode_tokens,]
)
num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: int = 1
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
assert isinstance(kv_cache_spec, MambaSpec)
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.speculative_config = vllm_config.speculative_config
self.kv_cache_spec = kv_cache_spec
if self.speculative_config:
self.num_spec = self.speculative_config.num_speculative_tokens # noqa: E501
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self._init_reorder_batch_threshold(1, self.use_spec_decode)
self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
self.compilation_config.max_capture_size,
)
self.spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1),
dtype=torch.int32,
device=device,
)
self.non_spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.spec_sequence_masks = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.bool,
device=device,
)
self.spec_token_masks = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.bool,
device=device,
)
self.spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.non_spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.num_accepted_tokens = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
def build( # type: ignore[override]
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
num_accepted_tokens: Optional[torch.Tensor] = None,
num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
fast_build: bool = False,
) -> GDNAttentionMetadata:
m = common_attn_metadata
query_start_loc = m.query_start_loc
context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device)
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if (
not self.use_spec_decode
or num_decode_draft_tokens_cpu is None
or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
.sum()
.item()
== 0
):
spec_sequence_masks = None
num_spec_decodes = 0
else:
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
num_spec_decodes = spec_sequence_masks.sum().item()
if num_spec_decodes == 0:
spec_sequence_masks = None
else:
spec_sequence_masks = spec_sequence_masks.to(
query_start_loc.device, non_blocking=True
)
if spec_sequence_masks is None:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(m, decode_threshold=1)
)
num_spec_decode_tokens = 0
spec_token_masks = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
num_accepted_tokens = None
else:
query_lens = query_start_loc[1:] - query_start_loc[:-1]
non_spec_query_lens = query_lens[~spec_sequence_masks]
num_decodes = (non_spec_query_lens == 1).sum().item()
num_prefills = non_spec_query_lens.size(0) - num_decodes
num_decode_tokens = num_decodes
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
if num_prefills == 0 and num_decodes == 0:
spec_token_masks = torch.ones(
(
min(
num_spec_decodes * (self.num_spec + 1),
query_start_loc[-1].item(),
)
),
dtype=torch.bool,
device=query_start_loc.device,
)
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
non_spec_query_start_loc = None
else:
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens
)
spec_state_indices_tensor = m.block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
non_spec_state_indices_tensor = m.block_table_tensor[
~spec_sequence_masks, 0
]
spec_query_start_loc = torch.zeros(
num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device,
)
torch.cumsum(
query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
)
non_spec_query_start_loc = torch.zeros(
query_lens.size(0) - num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device,
)
torch.cumsum(
query_lens[~spec_sequence_masks],
dim=0,
out=non_spec_query_start_loc[1:],
)
num_spec_decode_tokens = (
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
)
assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
if num_prefills > 0:
has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks]
has_initial_state_cpu = has_initial_state.cpu()
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(non_spec_query_start_loc)
)
else:
has_initial_state = None
has_initial_state_cpu = None
num_actual_tokens = (
num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
)
# prepare tensors for cudagraph
#
# With speculative decoding, the xgrammar backend may rollback tokens
# and causing some sequences has less draft tokens than self.num_spec.
#
# In above cases, the max possible batch size for n tokens, can be
# min(n, cudagraph_max_bs).
if (
self.use_full_cuda_graph
and num_prefills == 0
and num_decodes == 0
and num_spec_decodes <= self.decode_cudagraph_max_bs
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
):
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True
)
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
self.spec_sequence_masks[:num_spec_decodes].copy_(
spec_sequence_masks, non_blocking=True
)
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
spec_sequence_masks[num_spec_decodes:].fill_(False)
assert spec_token_masks is not None
self.spec_token_masks[: spec_token_masks.size(0)].copy_(
spec_token_masks, non_blocking=True
)
spec_token_masks = self.spec_token_masks[:num_actual_tokens]
spec_token_masks[spec_token_masks.size(0) :].fill_(False)
self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
spec_query_start_loc, non_blocking=True
)
spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index]
spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)
self.num_accepted_tokens[:num_spec_decodes].copy_(
num_accepted_tokens, non_blocking=True
)
num_accepted_tokens = self.num_accepted_tokens[:batch_size]
num_accepted_tokens[num_spec_decodes:].fill_(1)
if (
self.use_full_cuda_graph
and num_prefills == 0
and num_spec_decodes == 0
and num_decodes <= self.decode_cudagraph_max_bs
):
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
batch_size = num_actual_tokens
self.non_spec_state_indices_tensor[:num_decodes].copy_(
non_spec_state_indices_tensor, non_blocking=True
)
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
:batch_size
]
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
self.non_spec_query_start_loc[: num_decodes + 1].copy_(
non_spec_query_start_loc, non_blocking=True
)
non_spec_num_query_tokens = non_spec_query_start_loc[
-1
] # type: ignore[index]
non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)
if num_accepted_tokens is not None:
num_accepted_tokens = num_accepted_tokens.to(torch.int32)
attn_metadata = GDNAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_spec_decodes=num_spec_decodes,
num_spec_decode_tokens=num_spec_decode_tokens,
num_actual_tokens=num_actual_tokens,
has_initial_state=has_initial_state,
has_initial_state_cpu=has_initial_state_cpu,
spec_query_start_loc=spec_query_start_loc,
non_spec_query_start_loc=non_spec_query_start_loc,
spec_state_indices_tensor=spec_state_indices_tensor,
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
non_spec_state_indices_tensor_cpu=(
non_spec_state_indices_tensor.cpu()
if non_spec_state_indices_tensor is not None
else None
),
spec_sequence_masks=spec_sequence_masks,
spec_token_masks=spec_token_masks,
num_accepted_tokens=num_accepted_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata
assert (
m.num_reqs <= self.decode_cudagraph_max_bs
and m.num_actual_tokens <= self.decode_cudagraph_max_bs
), (
f"GDN only supports decode-only full CUDAGraph capture. "
f"Make sure batch size ({m.num_reqs}) <= "
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
f"and number of tokens ({m.num_actual_tokens}) <= "
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})."
)
num_accepted_tokens = torch.diff(m.query_start_loc)
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
gdn_attn.GDNAttentionMetadata = GDNAttentionMetadata
gdn_attn.GDNAttentionMetadataBuilder = GDNAttentionMetadataBuilder

View File

@@ -770,24 +770,14 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
value = value.contiguous()
if key_cache.is_contiguous():
kunlun_ops.reshape_and_cache(
key[: attn_metadata.num_actual_tokens],
value[: attn_metadata.num_actual_tokens],
key_cache,
value_cache,
updated_slot_mapping,
)
else:
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
kunlun_ops.reshape_and_cache_flash(
key,
value,
cast_key_cache,
cast_value_cache,
updated_slot_mapping,
)
kunlun_ops.reshape_and_cache_flash(
key[: attn_metadata.num_actual_tokens],
value[: attn_metadata.num_actual_tokens],
key_cache,
value_cache,
updated_slot_mapping,
BLHD_LAYOUT=False,
)
assert attn_type == AttentionType.DECODER
# Decoder self-attention supports chunked prefill.

View File

@@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Union
import kunlun_ops
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
'''
"""
Args:
metadata:
Metadata for spec decoding.
@@ -81,7 +80,7 @@ class RejectionSampler(nn.Module):
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
"""
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
@@ -124,11 +123,11 @@ class RejectionSampler(nn.Module):
"""
output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
outputs = [
row[valid_mask[i]].tolist()
for i, row in enumerate(output_token_ids_np)
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
]
return outputs
@@ -179,25 +178,15 @@ def rejection_sample(
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
)
else:
rejection_greedy_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
num_draft_tokens,
max_spec_len,
is_greedy,
)
kunlun_ops.rejection_greedy_sample(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
)
if sampling_metadata.all_greedy:
return output_token_ids
@@ -222,8 +211,9 @@ def rejection_sample(
sampling_metadata,
device,
)
bonus_token_ids = bonus_token_ids.squeeze(1)
rejection_random_sample_pytorch(
kunlun_ops.rejection_random_sample(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -235,8 +225,7 @@ def rejection_sample(
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
# num_warps=1,
no_draft_probs=draft_probs is None,
)
return output_token_ids
@@ -374,7 +363,7 @@ def generate_uniform_probs(
random values in the range [0, 1).
"""
uniform_probs = torch.rand(
(num_tokens, ),
(num_tokens,),
dtype=torch.float32,
device=device,
)
@@ -422,7 +411,7 @@ def sample_recovered_tokens(
q[i].exponential_(generator=generator)
recovered_token_ids = torch.empty_like(draft_token_ids)
sample_recovered_tokens_pytorch(
kunlun_ops.sample_recovered_tokens(
recovered_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -430,16 +419,16 @@ def sample_recovered_tokens(
target_probs,
q,
vocab_size,
IS_NGRAM=draft_probs is None,
no_draft_probs=draft_probs is None,
)
return recovered_token_ids
def rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
@@ -447,73 +436,72 @@ def rejection_greedy_sample_spec_len_1_pytorch(
accept_req_mask = draft_token_ids == target_argmax
output_token_ids[:, 0] = target_argmax
bonus_token_ids = bonus_token_ids.squeeze(1)
output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids,
output_token_ids[:, 1])
output_token_ids[:, 1] = torch.where(
accept_req_mask, bonus_token_ids, output_token_ids[:, 1]
)
def rejection_greedy_sample_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
device = output_token_ids.device
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
device, non_blocking=True)
device, non_blocking=True
)
if is_greedy is None:
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
start_indices = cu_num_draft_tokens - draft_tokens_per_req
req_ids = torch.arange(batch_size, device=device)
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
token_positions = torch.arange(
num_tokens, device=device) - start_indices[token_req_ids]
token_positions = (
torch.arange(num_tokens, device=device) - start_indices[token_req_ids]
)
# Find the first mismatch position of each request.
mismatch_global = (draft_token_ids != target_argmax)
mismatch_global = draft_token_ids != target_argmax
if max_spec_len == 0:
first_mismatch_pos_per_req = torch.zeros(batch_size,
dtype=torch.long,
device=device)
first_mismatch_pos_per_req = torch.zeros(
batch_size, dtype=torch.long, device=device
)
else:
# [bs, max_spec_len]
pos_matrix = torch.full((batch_size, max_spec_len),
-1,
dtype=torch.long,
device=device)
pos_matrix = torch.full(
(batch_size, max_spec_len), -1, dtype=torch.long, device=device
)
pos_matrix[token_req_ids, token_positions] = token_positions
mismatch_matrix = torch.full((batch_size, max_spec_len),
False,
dtype=torch.bool,
device=device)
mismatch_matrix = torch.full(
(batch_size, max_spec_len), False, dtype=torch.bool, device=device
)
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
max_spec_len * 2)
mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
no_mismatch_mask]
no_mismatch_mask
]
# Copy matched target tokens into output.
copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1,
device=device).expand(batch_size, -1)
copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1)
copy_mask = copy_indices < copy_len.unsqueeze(1)
greedy_mask = is_greedy.unsqueeze(1)
final_copy_mask = copy_mask & greedy_mask
global_idx = start_indices.unsqueeze(1) + copy_indices
output_token_ids[final_copy_mask] = target_argmax[
global_idx[final_copy_mask]].to(output_token_ids.dtype)
output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(
output_token_ids.dtype
)
# Fill bonus token.
needs_bonus = is_greedy & (first_mismatch_pos_per_req
>= draft_tokens_per_req)
needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req)
if torch.any(needs_bonus):
bonus_rows = torch.where(needs_bonus)[0]
bonus_cols = draft_tokens_per_req[bonus_rows]
@@ -556,11 +544,9 @@ def rejection_random_sample_pytorch(
if IS_NGRAM:
draft_prob = 1.0
else:
draft_prob = draft_probs[start_idx + pos,
draft_token_id].item()
draft_prob = draft_probs[start_idx + pos, draft_token_id].item()
target_prob = target_probs[start_idx + pos,
draft_token_id].item()
target_prob = target_probs[start_idx + pos, draft_token_id].item()
uniform_prob = uniform_probs[start_idx + pos].item()
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
@@ -629,12 +615,11 @@ def sample_recovered_tokens_pytorch(
else:
draft_p = draft_probs[token_idx].clone()
target_p = target_probs[token_idx].clone()
prob = torch.maximum(target_p - draft_p,
torch.tensor(0.0, device=target_p.device))
prob = torch.maximum(
target_p - draft_p, torch.tensor(0.0, device=target_p.device)
)
q_values = torch.full((vocab_size, ),
float('-inf'),
device=q.device)
q_values = torch.full((vocab_size,), float("-inf"), device=q.device)
q_values[:vocab_size] = q[req_idx, :vocab_size]
recovered_id = torch.argmax(prob / q_values).item()
@@ -642,4 +627,3 @@ def sample_recovered_tokens_pytorch(
if IS_NGRAM:
target_probs[token_idx, draft_token_id] = orig_prob

View File

@@ -337,5 +337,5 @@ def prepare_next_token_ids_padded(
return next_token_ids, valid_sampled_tokens_count
EagleProposer.propose = propose
# EagleProposer.propose = propose
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded

View File

@@ -407,7 +407,7 @@ def add_rmsnorm(
) -> None:
kunlun_ops.add_rmsnorm(
x,
y, # 原来写 residual这里其实是 y
y,
residual_output=residual_output,
weight=weight,
eps=eps,
@@ -523,6 +523,145 @@ def _fake_add_rmsnorm(
add_rmsnorm.register_fake(_fake_add_rmsnorm)
@custom_op("_C::gemma_add_rmsnorm", mutates_args=())
def gemma_add_rmsnorm(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
smooth: torch.Tensor = None,
residual_output: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_add_rmsnorm wrapper")
kunlun_ops.gemma_add_rmsnorm(
x,
y,
weight=weight,
output=output,
eps=eps,
enable_pdl=enable_pdl,
interweaved=interweaved,
store_output_before_norm=store_output_before_norm,
bias=bias,
smooth=smooth,
residual_output=residual_output,
force_sdnn=force_sdnn,
)
@impl("_C::gemma_add_rmsnorm", "CUDA")
def gemma_add_rmsnorm_cuda(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
smooth: torch.Tensor = None,
residual_output: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_add_rmsnorm_cuda wrapper")
kunlun_ops.gemma_add_rmsnorm(
x,
y,
weight=weight,
output=output,
eps=eps,
enable_pdl=enable_pdl,
interweaved=interweaved,
store_output_before_norm=store_output_before_norm,
bias=bias,
smooth=smooth,
residual_output=residual_output,
force_sdnn=force_sdnn,
)
def _fake_gemma_add_rmsnorm(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
smooth: torch.Tensor = None,
residual_output: torch.Tensor = None,
force_sdnn: bool = False,
):
output.fake_shape = x.shape
output.fake_dtype = x.dtype
return None
gemma_add_rmsnorm.register_fake(_fake_gemma_add_rmsnorm)
@custom_op("_C::gemma_rmsnorm", mutates_args=())
def gemma_rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweave: bool = False,
bias: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_rmsnorm wrapper")
kunlun_ops.gemma_rmsnorm(
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
)
@impl("_C::gemma_rmsnorm", "CUDA")
def gemma_rmsnorm_cuda(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweave: bool = False,
bias: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_rmsnorm_cuda wrapper")
kunlun_ops.gemma_rmsnorm(
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
)
def _fake_gemma_rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweave: bool = False,
bias: torch.Tensor = None,
force_sdnn: bool = False,
):
# 设置 shape/dtype但不返回值
output.fake_shape = x.shape
output.fake_dtype = x.dtype
return None
gemma_rmsnorm.register_fake(_fake_gemma_rmsnorm)
@custom_op("_C::split_norm_rope_neox", mutates_args=())
def split_norm_rope_neox(
q_emb: torch.Tensor,