[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)

Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
This commit is contained in:
chanzhennan
2026-02-28 11:15:50 +08:00
committed by GitHub
parent 153093d3b3
commit 82544aa0cc
17 changed files with 2668 additions and 1532 deletions

View File

@@ -3,95 +3,113 @@ from vllm import ModelRegistry
def register_model(): def register_model():
# from .demo_model import DemoModel # noqa: F401 # from .demo_model import DemoModel # noqa: F401
from .qwen2_vl import Qwen2VLForConditionalGeneration #noqa: F401 from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration # noqa: F401
from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration #noqa: F401 from .qwen2_vl import Qwen2VLForConditionalGeneration # noqa: F401
from .qwen3_moe import Qwen3MoeForCausalLM #noqa: F401 from .qwen3_moe import Qwen3MoeForCausalLM # noqa: F401
from .qwen3_vl import Qwen3VLForConditionalGeneration from .qwen3_omni_moe_thinker import ( # noqa: F401
from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration Qwen3OmniMoeThinkerForConditionalGeneration,
from .qwen3_omni_moe_thinker import Qwen3OmniMoeThinkerForConditionalGeneration )
from .qwen3_vl import Qwen3VLForConditionalGeneration # noqa: F401
from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration # noqa: F401
# from .llama4 import Llama4ForCausalLM #noqa: F401 # from .llama4 import Llama4ForCausalLM #noqa: F401
# from .mllama4 import Llama4ForConditionalGeneration #noqa: F401 # from .mllama4 import Llama4ForConditionalGeneration #noqa: F401
# from .deepseek_v2 import KunlunDeepseekV2MoE # from .deepseek_v2 import KunlunDeepseekV2MoE
# ModelRegistry.register_model( # ModelRegistry.register_model(
# "DemoModel", # "DemoModel",
# "vllm_kunlun.model_executor.models.demo_model:DemoModel") # "vllm_kunlun.model_executor.models.demo_model:DemoModel")
ModelRegistry.register_model( ModelRegistry.register_model(
"Qwen2VLForConditionalGeneration", "Qwen2VLForConditionalGeneration",
"vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration") "vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration",
)
ModelRegistry.register_model( ModelRegistry.register_model(
"Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration",
"vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration") "vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration",
)
ModelRegistry.register_model( ModelRegistry.register_model(
"Qwen3ForCausalLM", "Qwen3ForCausalLM", "vllm_kunlun.models.qwen3:Qwen3ForCausalLM"
"vllm_kunlun.models.qwen3:Qwen3ForCausalLM") )
ModelRegistry.register_model( ModelRegistry.register_model(
"Qwen3MoeForCausalLM", "Qwen3MoeForCausalLM", "vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM"
"vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM") )
ModelRegistry.register_model( ModelRegistry.register_model(
"Qwen3NextForCausalLM", "Qwen3NextForCausalLM", "vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM"
"vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM") )
ModelRegistry.register_model( ModelRegistry.register_model(
"GptOssForCausalLM", "Qwen3NextMTP", "vllm_kunlun.models.qwen3_next_mtp:Qwen3NextMTP"
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM") )
ModelRegistry.register_model( ModelRegistry.register_model(
"InternLM2ForCausalLM", "GlmForCausalLM", "vllm_kunlun.models.glm:GlmForCausalLM"
"vllm_kunlun.models.internlm2:InternLM2ForCausalLM") )
ModelRegistry.register_model( ModelRegistry.register_model(
"InternVLChatModel", "GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM"
"vllm_kunlun.models.internvl:InternVLChatModel") )
ModelRegistry.register_model(
"InternLM2ForCausalLM", "vllm_kunlun.models.internlm2:InternLM2ForCausalLM"
)
ModelRegistry.register_model(
"InternVLChatModel", "vllm_kunlun.models.internvl:InternVLChatModel"
)
ModelRegistry.register_model( ModelRegistry.register_model(
"InternS1ForConditionalGeneration", "InternS1ForConditionalGeneration",
"vllm_kunlun.models.interns1:InternS1ForConditionalGeneration") "vllm_kunlun.models.interns1:InternS1ForConditionalGeneration",
)
ModelRegistry.register_model( ModelRegistry.register_model(
"Qwen3VLForConditionalGeneration", "Qwen3VLForConditionalGeneration",
"vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration") "vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration",
)
ModelRegistry.register_model( ModelRegistry.register_model(
"Qwen3VLMoeForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration",
"vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration") "vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration",
)
ModelRegistry.register_model( ModelRegistry.register_model(
"Qwen3OmniMoeForConditionalGeneration", "Qwen3OmniMoeForConditionalGeneration",
"vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration") "vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration",
)
ModelRegistry.register_model( ModelRegistry.register_model(
"SeedOssForCausalLM", "SeedOssForCausalLM", "vllm_kunlun.models.seed_oss:SeedOssForCausalLM"
"vllm_kunlun.models.seed_oss:SeedOssForCausalLM") )
ModelRegistry.register_model( ModelRegistry.register_model(
"MiMoV2FlashForCausalLM", "MiMoV2FlashForCausalLM",
"vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM") "vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM",
)
ModelRegistry.register_model( ModelRegistry.register_model(
"GptOssForCausalLM", "GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM"
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM") )
ModelRegistry.register_model( ModelRegistry.register_model(
"DeepseekV3ForCausalLM", "DeepseekV3ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM"
"vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM") )
ModelRegistry.register_model( ModelRegistry.register_model(
"DeepseekV32ForCausalLM", "DeepseekV32ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM"
"vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM") )
ModelRegistry.register_model(
"DeepSeekMTPModel",
"vllm_kunlun.models.deepseek_mtp:DeepSeekMTP")
ModelRegistry.register_model( ModelRegistry.register_model(
"GlmMoeDsaForCausalLM", "DeepSeekMTPModel", "vllm_kunlun.models.deepseek_mtp:DeepSeekMTP"
"vllm_kunlun.models.deepseek_v2:GlmMoeDsaForCausalLM") )
ModelRegistry.register_model(
"GlmMoeDsaForCausalLM", "vllm_kunlun.models.deepseek_v2:GlmMoeDsaForCausalLM"
)
def register_quant_method(): def register_quant_method():
"""to do""" """to do"""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next MTP model."""
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig
from .qwen3_next import Qwen3NextDecoderLayer, Qwen3NextRMSNorm
logger = init_logger(__name__)
KVCache = tuple[torch.Tensor, torch.Tensor]
@support_torch_compile
class Qwen3NextMultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
config: Qwen3NextConfig = model_config.hf_config
self.config = config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.fc = ColumnParallelLinear(
self.config.hidden_size * 2,
self.config.hidden_size,
gather_output=True,
bias=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fc",
)
self.layers = torch.nn.ModuleList(
Qwen3NextDecoderLayer(
vllm_config,
layer_type="full_attention",
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(self.num_mtp_layers)
)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_fc_norm_hidden = Qwen3NextRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.pre_fc_norm_embedding = Qwen3NextRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
hidden_states = self.pre_fc_norm_hidden(hidden_states)
hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
hidden_states = self.fc(hidden_states)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
current_step_idx = spec_step_idx % self.num_mtp_layers
hidden_states, residual = self.layers[current_step_idx](
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@support_torch_compile
class Qwen3NextMTP(nn.Module, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["up_proj", "down_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
assert (
not cache_config.enable_prefix_caching
), "Qwen3NextMTP currently does not support prefix caching"
self.quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.model = Qwen3NextMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
)
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size
)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
hidden_states = self.model(
input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
shared_weight_names = ["embed_tokens", "lm_head"]
def remap_weight_names(weights):
for name, weight in weights:
if name.startswith("mtp."):
name = name.replace("mtp.", "model.")
elif not any(key in name for key in shared_weight_names):
continue
yield name, weight
loader = AutoWeightsLoader(self)
return loader.load_weights(remap_weight_names(weights))

View File

@@ -16,33 +16,33 @@
# limitations under the License. # limitations under the License.
"""kunlun custom op entry""" """kunlun custom op entry"""
import torch_xmlir
from typing import Optional
import torch import torch
import os
from typing import Optional, List, Dict
import vllm.envs as envs
import os
import ctypes
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
try: try:
import kunlun_ops import kunlun_ops
logger.info(f"Load custom ops library success!")
logger.info("Load custom ops library success!")
except ImportError as e: except ImportError as e:
logger.warning("Import error msg: %s", e.msg) logger.warning("Import error msg: %s", e.msg)
_per_token_smooth_quant = True _per_token_smooth_quant = True
def is_per_token_smooth_quant(): def is_per_token_smooth_quant():
""" is per token smooth quant """ """is per token smooth quant"""
return _per_token_smooth_quant return _per_token_smooth_quant
class KunlunOps: class KunlunOps:
"""KunlunOps""" """KunlunOps"""
# Attention ops # Attention ops
@staticmethod @staticmethod
def paged_attention_v1( def paged_attention_v1(
@@ -67,9 +67,9 @@ class KunlunOps:
blocksparse_vert_stride, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_block_size,
blocksparse_head_sliding_step, blocksparse_head_sliding_step,
alibi_sqrt=False alibi_sqrt=False,
): ):
""" PagedAttentionV1 """ """PagedAttentionV1"""
# block_size = value_cache.shape[2] # block_size = value_cache.shape[2]
kunlun_ops.paged_attention( kunlun_ops.paged_attention(
x=query, x=query,
@@ -81,7 +81,7 @@ class KunlunOps:
is_context=is_context, is_context=is_context,
is_causal=True, is_causal=True,
out=output, out=output,
vo_head_dim=128 vo_head_dim=128,
) )
@staticmethod @staticmethod
@@ -110,9 +110,9 @@ class KunlunOps:
blocksparse_vert_stride, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_block_size,
blocksparse_head_sliding_step, blocksparse_head_sliding_step,
alibi_sqrt=False alibi_sqrt=False,
): ):
""" PagedAttentionV2 """ """PagedAttentionV2"""
# block_size = value_cache.shape[2] # block_size = value_cache.shape[2]
kunlun_ops.paged_attention( kunlun_ops.paged_attention(
x=query, x=query,
@@ -124,31 +124,28 @@ class KunlunOps:
is_context=is_context, is_context=is_context,
is_causal=True, is_causal=True,
out=output, out=output,
vo_head_dim=128 vo_head_dim=128,
) )
# Activation ops # Activation ops
@staticmethod @staticmethod
def silu_and_mul(out: torch.Tensor, def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
x: torch.Tensor): """silu and mul"""
""" silu and mul """
kunlun_ops.silu_and_mul( kunlun_ops.silu_and_mul(
x, x,
axis=-1, axis=-1,
turn=True, turn=True,
out=out, out=out,
) )
# Activation ops # Activation ops
@staticmethod @staticmethod
def quick_gelu(out: torch.Tensor, def quick_gelu(out: torch.Tensor, x: torch.Tensor):
x: torch.Tensor): """quick gelu"""
""" quick gelu """
kunlun_ops.quick_gelu( kunlun_ops.quick_gelu(
x, x,
out=out, out=out,
) )
# Layernorm # Layernorm
@staticmethod @staticmethod
@@ -159,9 +156,7 @@ class KunlunOps:
epsilon, epsilon,
): ):
"""rms_norm""" """rms_norm"""
kunlun_ops.rmsnorm( kunlun_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)
x, weight.to(torch.float32), epsilon, out=out
)
@staticmethod @staticmethod
def fused_add_rms_norm( def fused_add_rms_norm(
@@ -179,16 +174,11 @@ class KunlunOps:
residual.copy_(fused_input, non_blocking=True) residual.copy_(fused_input, non_blocking=True)
x.copy_(output) x.copy_(output)
# Rotary embedding # Rotary embedding
@staticmethod @staticmethod
def rotary_embedding( def rotary_embedding(
positions, positions, query, key, head_size, cos_sin_cache, is_neox_style
query, ):
key,
head_size,
cos_sin_cache,
is_neox_style):
""" """
refactor RotaryEmbedding forward function refactor RotaryEmbedding forward function
""" """
@@ -196,62 +186,38 @@ class KunlunOps:
key_x = key.contiguous() key_x = key.contiguous()
torch.ops._C.rotary_embedding( torch.ops._C.rotary_embedding(
positions, positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
query_x, )
key_x,
head_size,
cos_sin_cache,
is_neox_style)
return query_x, key_x return query_x, key_x
# Rotary embedding # Rotary embedding
@staticmethod @staticmethod
def mrotary_embedding( def mrotary_embedding(
positions, positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
mrope_section, ):
query,
key,
head_size,
cos_sin_cache,
is_neox_style):
""" """
refactor RotaryEmbedding forward function refactor RotaryEmbedding forward function
""" """
query_x = query.contiguous() query_x = query.contiguous()
key_x = key.contiguous() key_x = key.contiguous()
query_x_dim = query_x.dim()
assert is_neox_style assert is_neox_style
kunlun_ops.mrotary_embedding_neox( kunlun_ops.mrotary_embedding_neox(
positions, positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
query_x, )
key_x,
head_size,
cos_sin_cache,
mrope_section)
query.data = query_x query.data = query_x
key.data = key_x key.data = key_x
return query, key return query, key
@staticmethod @staticmethod
def swap_blocks( def swap_blocks(src, dst, block_mapping):
src, """swap_blocks"""
dst, kunlun_ops.swap_blocks(src, dst, block_mapping)
block_mapping):
""" swap_blocks """
kunlun_ops.swap_blocks(
src,
dst,
block_mapping
)
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(key_caches, value_caches, block_mapping):
key_caches, """copy_blocks"""
value_caches,
block_mapping):
""" copy_blocks """
for i in range(len(key_caches)): for i in range(len(key_caches)):
key_caches[i] = key_caches[i].contiguous() key_caches[i] = key_caches[i].contiguous()
value_caches[i] = value_caches[i].contiguous() value_caches[i] = value_caches[i].contiguous()
@@ -269,16 +235,10 @@ class KunlunOps:
value_cache, value_cache,
slot_mapping, slot_mapping,
kv_cache_dtype, kv_cache_dtype,
): ):
""" reshape_and_cache """ """reshape_and_cache"""
# slot_mapping_cast = slot_mapping.to(torch.int32) # slot_mapping_cast = slot_mapping.to(torch.int32)
kunlun_ops.reshape_and_cache( kunlun_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
key,
value,
key_cache,
value_cache,
slot_mapping
)
@staticmethod @staticmethod
def multi_query_kv_attention( def multi_query_kv_attention(
@@ -287,7 +247,7 @@ class KunlunOps:
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
**kargs **kargs,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
query: shape = [num_prompt_tokens, num_heads, head_size] query: shape = [num_prompt_tokens, num_heads, head_size]
@@ -297,16 +257,12 @@ class KunlunOps:
key = key.unsqueeze(0) key = key.unsqueeze(0)
value = value.unsqueeze(0) value = value.unsqueeze(0)
output = torch.empty_like(query) output = torch.empty_like(query)
alibi_slopes = kargs.get("alibi_slopes", None)
mask = kargs.get("mask", None)
is_causal = kargs.get("is_causal", True)
is_lvsl = kargs.get("is_lvsl", True)
B, T, Qh, Hd = query.shape B, T, Qh, Hd = query.shape
KVh = key.size(2) KVh = key.size(2)
if KVh != Qh: if KVh != Qh:
repeat = Qh // KVh repeat = Qh // KVh
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd] key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
value = value.repeat_interleave(repeat, dim=2) value = value.repeat_interleave(repeat, dim=2)
kunlun_ops.attention( kunlun_ops.attention(
q=query, q=query,
@@ -321,80 +277,90 @@ class KunlunOps:
return output return output
@staticmethod @staticmethod
def quant_fusedresidual_rmsnorm_op(x, def quant_fusedresidual_rmsnorm_op(
residual, x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
weight, ):
bias,
scale_to_int,
eps,
dyn_scale: bool,
type: int = 1):
"""Quantized fused residual layer normalization""" """Quantized fused residual layer normalization"""
out = torch.empty_like(x, dtype=torch.int8) out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant(): if is_per_token_smooth_quant():
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1) out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
else: else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float) out_scale = torch.empty(12, device=x.device, dtype=torch.float)
kunlun_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps, kunlun_ops.quant_fusedresidual_rmsnorm(
out=out, out_scale=out_scale , residual_tensor=residual) x,
residual,
weight,
bias,
eps,
out=out,
out_scale=out_scale,
residual_tensor=residual,
)
if residual is None: if residual is None:
return out, out_scale return out, out_scale
return out, out_scale, residual return out, out_scale, residual
@staticmethod @staticmethod
def quant_rmsnorm_op(x, def quant_rmsnorm_op(
weight, x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
bias, ):
scale_to_int,
eps,
dyn_scale : bool,
type: int = 1):
"""Quantized RMSNorm""" """Quantized RMSNorm"""
out = torch.empty_like(x, dtype=torch.int8) out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant(): if is_per_token_smooth_quant():
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1) out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
else: else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float) out_scale = torch.empty(12, device=x.device, dtype=torch.float)
kunlun_ops.quant_rmsnorm(x, weight, bias, eps, kunlun_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
out=out, out_scale=out_scale)
return out, out_scale return out, out_scale
@staticmethod @staticmethod
def smooth_quant_matmul_column_row_kernels(input_tensor, def smooth_quant_matmul_column_row_kernels(
weight, input_tensor,
smoother, weight,
input_scale, smoother,
weight_scale, input_scale,
perTokenScaling, weight_scale,
perChannelScaling, perTokenScaling,
otype): perChannelScaling,
otype,
):
"""smooth_quant_matmul_column_row_kernels""" """smooth_quant_matmul_column_row_kernels"""
input_shape = input_tensor.shape input_shape = input_tensor.shape
weight_shape = weight.shape weight_shape = weight.shape
if input_tensor.dim() == 3: if input_tensor.dim() == 3:
input_tensor = input_tensor.reshape(-1, input_shape[-1]) input_tensor = input_tensor.reshape(-1, input_shape[-1])
out = torch.empty((input_shape[0] * input_shape[1], out = torch.empty(
weight_shape[0]), (input_shape[0] * input_shape[1], weight_shape[0]),
dtype=torch.float16, dtype=torch.float16,
device=weight.device) device=weight.device,
)
output_bs_shape = [input_shape[0], input_shape[1]] output_bs_shape = [input_shape[0], input_shape[1]]
elif input_tensor.dim() == 2: elif input_tensor.dim() == 2:
out = torch.empty((input_shape[0], weight_shape[0]), out = torch.empty(
dtype=torch.float16, (input_shape[0], weight_shape[0]),
device=weight.device) dtype=torch.float16,
device=weight.device,
)
output_bs_shape = [-1] output_bs_shape = [-1]
kunlun_ops.smooth_quant_matmul_column_row_kernels(input_tensor, kunlun_ops.smooth_quant_matmul_column_row_kernels(
weight, smoother, input_tensor,
input_scale, weight,
weight_scale, smoother,
perTokenScaling, input_scale,
perChannelScaling, weight_scale,
out=out) perTokenScaling,
perChannelScaling,
out=out,
)
out = out.view(*output_bs_shape, weight_shape[0]) out = out.view(*output_bs_shape, weight_shape[0])
@@ -404,6 +370,7 @@ class KunlunOps:
if torch.is_tensor(x): if torch.is_tensor(x):
return (type(x), x.device, x.dtype, x.shape, x.is_contiguous()) return (type(x), x.device, x.dtype, x.shape, x.is_contiguous())
return (type(x), x) return (type(x), x)
@staticmethod @staticmethod
def fused_moe( def fused_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -420,23 +387,24 @@ class KunlunOps:
w1_bias: Optional[torch.Tensor] = None, w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""fused_moe""" """fused_moe"""
global_num_experts, up_gate_size, _ = w1.shape global_num_experts, up_gate_size, _ = w1.shape
M, N = hidden_states.shape M, N = hidden_states.shape
hidden_dim = w2.shape[1] hidden_dim = w2.shape[1]
normed_score = torch.empty(M, normed_score = torch.empty(
moe_top_k, M, moe_top_k, dtype=torch.float32, device=hidden_states.device
dtype=torch.float32, )
device=hidden_states.device) topk_ids = torch.empty(
topk_ids = torch.empty(M, M, moe_top_k, dtype=torch.int32, device=hidden_states.device
moe_top_k, )
dtype=torch.int32,
device=hidden_states.device)
num_blocks = 12 num_blocks = 12
block_statistic = torch.zeros( block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device num_blocks,
global_num_experts,
dtype=torch.int32,
device=hidden_states.device,
) )
router_logits = router_logits.to(torch.float) router_logits = router_logits.to(torch.float)
if scoring_func == "softmax": if scoring_func == "softmax":
@@ -445,24 +413,27 @@ class KunlunOps:
normed_score=normed_score, normed_score=normed_score,
topk_index=topk_ids, topk_index=topk_ids,
block_statistic=None, block_statistic=None,
stable=True) stable=False,
)
elif scoring_func == "sigmoid": elif scoring_func == "sigmoid":
torch.ops._C.moe_sigmoid_group_topk_norm( torch.ops._C.moe_sigmoid_group_topk_norm(
x=router_logits, x=router_logits,
topk_index=topk_ids, topk_index=topk_ids,
norm_score=normed_score, norm_score=normed_score,
block_static=block_statistic, block_static=block_statistic,
bias=e_score_correction_bias, bias=e_score_correction_bias,
scale=1.0, scale=1.0,
n_group=num_expert_group, n_group=num_expert_group,
topk_group=topk_group, topk_group=topk_group,
) )
if w1_bias is not None or w2_bias is not None: if w1_bias is not None or w2_bias is not None:
# Rignt now this branch is for gpt oss # Rignt now this branch is for gpt oss
# TODO (@xyDong23): faster here using moe_fc kernel # TODO (@xyDong23): faster here using moe_fc kernel
normed_score = normed_score.to(hidden_states.dtype) normed_score = normed_score.to(hidden_states.dtype)
out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device) out = torch.zeros(
M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device
)
repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0) repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
topk_ids_flat = topk_ids.flatten() topk_ids_flat = topk_ids.flatten()
for i in range(global_num_experts): for i in range(global_num_experts):
@@ -470,9 +441,13 @@ class KunlunOps:
selected_token = topk_ids_flat == experts_id selected_token = topk_ids_flat == experts_id
if selected_token.sum(): if selected_token.sum():
cur_token = repeat_x[selected_token] cur_token = repeat_x[selected_token]
up_gate = torch.empty(selected_token.sum(), up_gate_size//2, up_gate = torch.empty(
dtype=cur_token.dtype, device=cur_token.device) selected_token.sum(),
groupgemm1 = cur_token@ w1[i].T up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
groupgemm1 = cur_token @ w1[i].T
# Add w13 bias # Add w13 bias
if w1_bias is not None: if w1_bias is not None:
groupgemm1 = groupgemm1 + w1_bias[i] groupgemm1 = groupgemm1 + w1_bias[i]
@@ -482,53 +457,129 @@ class KunlunOps:
if w2_bias is not None: if w2_bias is not None:
groupgemm2 = groupgemm2 + w2_bias[i] groupgemm2 = groupgemm2 + w2_bias[i]
out[selected_token] = groupgemm2 out[selected_token] = groupgemm2
ouput = (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)).sum(dim=1).to(hidden_states.dtype) ouput = (
(out.view(M, moe_top_k, N) * normed_score.unsqueeze(2))
.sum(dim=1)
.to(hidden_states.dtype)
)
return ouput return ouput
else: else:
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float # from vllm.forward_context import get_forward_context
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E] # forward_context = get_forward_context()
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1] # attn_metadata: AttentionMetadata = forward_context.attn_metadata
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device) # prefix = "model.layers.0.linear_attn"
# if attn_metadata is not None:
torch.ops._C.gen_block_statistic(topk_ids,block_statistic) # attn_metadata = attn_metadata[prefix]
torch.ops._C.moe_pre_sorted( # if attn_metadata is None or attn_metadata.num_prefills > 0 or :
x=hidden_states, if M * moe_top_k < 400:
topk_index=topk_ids, sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
block_statistic=block_statistic, torch.ops.xspeedgate_ops.moe_pre_small(
moe_expand=moe_expand, topk_ids, global_num_experts, False, False, hidden_states
moe_index=sorted_tokens_idx, )
expert_m=expert_m, )
sorted_tokens_num_lod=sorted_tokens_num_lod) experts_num_lod = torch.ops.xspeedgate_ops.moe_active_expert_balance(
topk_ids, global_num_experts, False
)
out = torch.ops.xspeedgate_ops.fused_moe(
hidden_states,
w1,
w2,
normed_score.to(hidden_states.dtype),
sorted_tokens_num_lod,
sorted_tokens_idx,
experts_num_lod,
)
return out.sum(1)
y = torch.empty(M,moe_top_k, if M * moe_top_k > 768:
w1.shape[1], moe_expand = torch.empty(
(M * moe_top_k, N),
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
device=hidden_states.device) device=hidden_states.device,
) # [M*top_k, N], float
expert_m = torch.zeros(
global_num_experts, dtype=torch.int32, device=hidden_states.device
) # [E]
sorted_tokens_num_lod = torch.zeros(
global_num_experts + 1,
dtype=torch.int32,
device=hidden_states.device,
) # [E+1]
sorted_tokens_idx = torch.zeros(
M * moe_top_k, dtype=torch.int32, device=hidden_states.device
)
torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod,
)
else:
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
torch.ops.xspeedgate_ops.moe_pre_small(
topk_ids,
global_num_experts,
index_have_neg=False,
sort_mode=True,
x=hidden_states,
)
)
y = torch.empty(
M,
moe_top_k,
w1.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim) moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
torch.ops._C.moe_fc( if M < 1024:
x=moe_expand, torch.ops._C.moe_fc(
weight=w1, x=moe_expand,
sorted_tokens_num_lod=sorted_tokens_num_lod, weight=w1,
sorted_tokens_idx=sorted_tokens_idx, sorted_tokens_num_lod=sorted_tokens_num_lod,
moe_topk=moe_top_k, sorted_tokens_idx=sorted_tokens_idx,
y=y, moe_topk=moe_top_k,
y=y,
)
d = y.shape[-1] // 2
output_shape = y.shape[:-1] + (d,)
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out1 = out1.reshape(-1, out1.shape[-1])
else:
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
act="SWISH_GLU",
)
y = y[..., : y.shape[-1] // 2]
out1 = y.reshape(-1, y.shape[-1])
out = torch.empty(
M,
moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
) )
d = y.shape[-1] // 2
output_shape = (y.shape[:-1] + (d, ))
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out = torch.empty(M,moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
out1 = out1.reshape(-1, out1.shape[-1])
torch.ops._C.moe_fc( torch.ops._C.moe_fc(
x=out1, x=out1,
weight=w2, weight=w2,
@@ -538,8 +589,12 @@ class KunlunOps:
y=out, y=out,
) )
dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device) dequant_scale = torch.ones(
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device) [M, moe_top_k], dtype=torch.float32, device=out.device
)
output = torch.empty(
[M, N], dtype=hidden_states.dtype, device=hidden_states.device
)
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k) sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
torch.ops._C.moe_post( torch.ops._C.moe_post(
@@ -547,9 +602,9 @@ class KunlunOps:
moe_index=sorted_tokens_idx, moe_index=sorted_tokens_idx,
normed_scale=normed_score, normed_scale=normed_score,
dequant_scale=dequant_scale, dequant_scale=dequant_scale,
y=output y=output,
) )
return output return output
@staticmethod @staticmethod
@@ -568,23 +623,23 @@ class KunlunOps:
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None, w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
x = hidden_states x = hidden_states
batch, hidden_size = x.shape batch, hidden_size = x.shape
num_local_experts, up_gate_size, _ = w13_weight.shape num_local_experts, up_gate_size, _ = w13_weight.shape
router_logits = x.to(linear_weights.dtype)@linear_weights.T router_logits = x.to(linear_weights.dtype) @ linear_weights.T
topk_weights = torch.empty(batch, topk_weights = torch.empty(
top_k, batch, top_k, dtype=router_logits.dtype, device=router_logits.device
dtype=router_logits.dtype, )
device=router_logits.device) topk_ids = torch.empty(
topk_ids = torch.empty(batch, batch, top_k, dtype=torch.int32, device=router_logits.device
top_k, )
dtype=torch.int32, block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
device=router_logits.device) torch.ops._C.moe_softmax_topk(
block_static = torch.empty(0, dtype=torch.int32,device=router_logits.device) router_logits, topk_weights, topk_ids, block_static
torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids, block_static) )
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
@@ -598,11 +653,19 @@ class KunlunOps:
selected_token = topk_ids_flat == experts_id selected_token = topk_ids_flat == experts_id
if selected_token.sum(): if selected_token.sum():
cur_token = repeat_x[selected_token] cur_token = repeat_x[selected_token]
up_gate = torch.empty(selected_token.sum(), up_gate_size//2, up_gate = torch.empty(
dtype=cur_token.dtype, device=cur_token.device) selected_token.sum(),
torch.ops._C.silu_and_mul(up_gate, cur_token@ w13_weight[i].T) up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
torch.ops._C.silu_and_mul(up_gate, cur_token @ w13_weight[i].T)
out[selected_token] = up_gate @ w2_weight[i].T out[selected_token] = up_gate @ w2_weight[i].T
output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype) output = (
(out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
.sum(dim=1)
.to(x.dtype)
)
return output return output
@@ -638,10 +701,11 @@ class KunlunOps:
prompt_lods_cpu: torch.Tensor, prompt_lods_cpu: torch.Tensor,
k_cache: torch.Tensor, k_cache: torch.Tensor,
v_cache: torch.Tensor, v_cache: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
"""mla pa block""" """mla pa block"""
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, output = torch.empty(
device=hidden_states.device) hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
)
kunlun_ops.xft_multi_head_latent_page_attention_block( kunlun_ops.xft_multi_head_latent_page_attention_block(
hidden_states, hidden_states,
q_lora_rank, q_lora_rank,
@@ -679,7 +743,6 @@ class KunlunOps:
) )
return output return output
def fused_gdn_gating( def fused_gdn_gating(
A_log: torch.Tensor, A_log: torch.Tensor,
a: torch.Tensor, a: torch.Tensor,
@@ -695,25 +758,34 @@ class KunlunOps:
) )
return output return output
def fused_recurrent_gated_delta_rule_fwd( def fused_recurrent_gated_delta_rule_fwd(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
g: torch.Tensor, g: torch.Tensor,
beta: torch.Tensor, beta: torch.Tensor,
scale: float, scale: float,
h0_source: torch.Tensor, h0_source: torch.Tensor,
output_final_state: bool, output_final_state: bool,
use_qk_l2norm_in_kernel: bool, use_qk_l2norm_in_kernel: bool,
cu_seqlens: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]: cu_seqlens: torch.Tensor = None,
''' ) -> tuple[torch.Tensor, torch.Tensor]:
Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 将做完sigmoid_gating和delta_rule_update融合在一起 """
1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)。 Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 将做完sigmoid_gating和delta_rule_update融合在一起
2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制 1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)
''' 2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。
"""
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd( o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel, q,
cu_seqlens) k,
return (o, final_state) v,
g,
beta,
scale,
h0_source,
output_final_state,
use_qk_l2norm_in_kernel,
cu_seqlens,
)
return (o, final_state)

View File

@@ -9,60 +9,196 @@
# ruff: noqa: E501 # ruff: noqa: E501
import warnings import warnings
from typing import Optional from typing import Optional
import torch.nn.functional as F
import cocopod # noqa
import torch import torch
import torch.distributed as dist import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h from .index import prepare_chunk_indices, prepare_chunk_offsets
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .l2norm import l2norm_fwd from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd
from .index import prepare_chunk_indices
import xspeedgate_ops
import cocopod
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,): def torch_solve_tril(
chunk_size=64 A: torch.Tensor,
A = -A.transpose(1,2) cu_seqlens: Optional[torch.LongTensor] = None,
output_dtype: torch.dtype = torch.float,
):
chunk_size = 64
A = -A.transpose(1, 2)
sequence_length = A.shape[-2] sequence_length = A.shape[-2]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
A = F.pad(A, (0, 0, 0, pad_size)) A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1]) A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
# mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0)
# A = A.masked_fill(mask, 0)
for i in range(1, chunk_size): for i in range(1, chunk_size):
row = A[..., i, :i].clone() row = A[..., i, :i].clone()
sub = A[..., :i, :i].clone() sub = A[..., :i, :i].clone()
A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device) A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[:,:,:sequence_length,:].transpose(1,2) return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[
:, :, :sequence_length, :
].transpose(1, 2)
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
A = chunk_scaled_dot_kkt_fwd(k=k,
beta=beta,
g_cumsum=g,
cu_seqlens=cu_seqlens,
output_dtype=q.dtype)
#kernel版 def recompute_w_u_fwd_torch(
torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens) k: torch.Tensor, # [B, T, H, K]
chunk_indices = prepare_chunk_indices( v: torch.Tensor, # [B, T, H, V]
cu_seqlens, 64) if cu_seqlens is not None else None beta: torch.Tensor, # [B, T, H]
g: torch.Tensor, # [B, T, H]
A: torch.Tensor, # [B, H, T, T]
):
"""
最简单版本假设等长序列key和value头数相同
"""
chunk_size = 64
num_v_heads, num_k_heads = v.shape[2], k.shape[2]
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k, v, beta, g, A = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A)
]
batch_size, num_heads, sequence_length, k_head_dim = k.shape
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
k = F.pad(k, (0, 0, 0, pad_size))
v = F.pad(v, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
v_beta = v * beta.unsqueeze(-1)
k_beta = k * beta.unsqueeze(-1)
k, v, k_beta, v_beta = [
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
for x in (k, v, k_beta, v_beta)
]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
u = A @ v_beta
w = A @ (k_beta * g.exp().unsqueeze(-1))
w = (
w.reshape(w.shape[0], w.shape[1], -1, w.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
u = (
u.reshape(u.shape[0], u.shape[1], -1, u.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
return w, u
def split_by_value(tensor, chunk_size=64):
indices = tensor.tolist()
result = set(indices) # 使用集合避免重复
for i in range(len(indices) - 1):
start = indices[i]
end = indices[i + 1]
# 计算第一个对齐边界
# 我们要找的是 start + n*chunk_size其中n是使结果大于start的最小整数
first_boundary = start + chunk_size
# 在(start, end)范围内插入所有对齐边界
boundary = first_boundary
while boundary < end:
result.add(boundary)
boundary += chunk_size
return torch.tensor(sorted(result), dtype=tensor.dtype, device=tensor.device)
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
):
chunk_size = 64
chunk_indices = (
prepare_chunk_indices(cu_seqlens, 64) if cu_seqlens is not None else None
)
chunk_offsets = (
prepare_chunk_offsets(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
# !
# g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
g = torch.ops.xspeedgate_ops.chunk_local_cumsum(
g,
chunk_size=64,
reverse=False,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
head_first=False,
)
# !
# A = chunk_scaled_dot_kkt_fwd(k=k,
# beta=beta,
# g_cumsum=g,
# cu_seqlens=cu_seqlens,
# output_dtype=q.dtype)
A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd(
k, beta, g, cu_seqlens, chunk_indices, chunk_size
)
# torch版
# if get_tensor_model_parallel_rank() == 0:
# torch.save(A, "A_in")
# torch.save(cu_seqlens, "cu_seqlens")
# A2 = A.clone()
torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size)
# !
# torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
# if get_tensor_model_parallel_rank() == 0:
# err = torch.max(torch.abs(A - A2))
# print("err", err)
# if err > 1e-3:
# raise
# A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
# for i in range(len(cu_seqlens)-1):
# A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
# A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)
"""
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
for i in range(len(cu_seqlens)-1):
k_i = k[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
v_i = v[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
beta_i = beta[:, cu_seqlens[i]:cu_seqlens[i+1], :]
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
g_i = g[:, cu_seqlens[i]:cu_seqlens[i+1], :]
w_i, u_i = recompute_w_u_fwd_torch(
k=k_i,
v=v_i,
beta=beta_i,
A=A_i,
g=g_i,
)
w[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = w_i
u[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = u_i
"""
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd( w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
k=k, k=k,
v=v, v=v,
@@ -71,17 +207,63 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
g_cumsum=g, g_cumsum=g,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices, chunk_indices=chunk_indices,
chunk_size=64 chunk_size=64,
) )
h, v_new, final_state = chunk_gated_delta_rule_fwd_h( """
w, u = recompute_w_u_fwd(
k=k, k=k,
w=w, v=v,
u=u, beta=beta,
g=g, A=A,
initial_state=initial_state, g_cumsum=g,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
) )
"""
# i
# import os
# if not os.path.exists("/qwen-next/in"):
# os.makedirs("/qwen-next/in")
# torch.save(k, "/qwen-next/in/k.pt")
# torch.save(u, "/qwen-next/in/u.pt")
# torch.save(w, "/qwen-next/in/w.pt")
# torch.save(g, "/qwen-next/in/g.pt")
# torch.save(initial_state, "/qwen-next/in/initial_state.pt")
# torch.save(cu_seqlens, "/qwen-next/in/cu_seqlens.pt")
# torch.save(chunk_indices, "/qwen-next/in/chunk_indices.pt")
# torch.save(chunk_offsets.to(torch.int32), "/qwen-next/in/chunk_offsets.pt")
# torch.save(chunk_size, "/qwen-next/in/chunk_size.pt")
# torch.save(output_final_state, "/qwen-next/in/output_final_state.pt")
h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h(
k,
u,
w,
g,
initial_state,
cu_seqlens,
chunk_indices,
chunk_offsets.to(torch.int32),
chunk_size,
output_final_state,
True,
)
# h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
# k=k,
# w=w,
# u=u,
# g=g,
# initial_state=initial_state,
# output_final_state=output_final_state,
# cu_seqlens=cu_seqlens,
# )
# if not os.path.exists("/qwen-next/out"):
# os.makedirs("/qwen-next/out")
# torch.save(h, "/qwen-next/out/h.pt")
# torch.save(v_new, "/qwen-next/out/v_new.pt")
# torch.save(final_state, "/qwen-next/out/final_state.pt")
o = torch.ops.xspeedgate_ops.chunk_fwd_o( o = torch.ops.xspeedgate_ops.chunk_fwd_o(
q=q, q=q,
k=k, k=k,
@@ -91,8 +273,19 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
scale=scale, scale=scale,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices, chunk_indices=chunk_indices,
chunk_size=64 chunk_size=64,
) )
"""
o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
)
"""
if SUPPRESS_LEVEL < 3: if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None return g, o, A, final_state, None, None, None
elif SUPPRESS_LEVEL >= 3: elif SUPPRESS_LEVEL >= 3:
@@ -103,18 +296,20 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod @staticmethod
@input_guard @input_guard
@torch.amp.custom_fwd(device_type='cuda') @torch.amp.custom_fwd(device_type="cuda")
def forward(ctx, def forward(
q: torch.Tensor, ctx,
k: torch.Tensor, q: torch.Tensor,
v: torch.Tensor, k: torch.Tensor,
g: torch.Tensor, v: torch.Tensor,
beta: torch.Tensor, g: torch.Tensor,
scale: float, beta: torch.Tensor,
initial_state: torch.Tensor, scale: float,
output_final_state: bool, initial_state: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None, output_final_state: bool,
use_qk_l2norm_in_kernel: bool = False): cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel: if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q) q = l2norm_fwd(q)
k = l2norm_fwd(k) k = l2norm_fwd(k)
@@ -136,17 +331,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@torch.compiler.disable @torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor, def chunk_gated_delta_rule(
k: torch.Tensor, q: torch.Tensor,
v: torch.Tensor, k: torch.Tensor,
g: torch.Tensor, v: torch.Tensor,
beta: torch.Tensor, g: torch.Tensor,
scale: float = None, beta: torch.Tensor,
initial_state: torch.Tensor = None, scale: float = None,
output_final_state: bool = False, initial_state: torch.Tensor = None,
cu_seqlens: Optional[torch.LongTensor] = None, output_final_state: bool = False,
head_first: bool = False, cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False): head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False,
):
r""" r"""
Args: Args:
q (torch.Tensor): q (torch.Tensor):
@@ -211,42 +408,85 @@ def chunk_gated_delta_rule(q: torch.Tensor,
) )
""" """
assert q.dtype == k.dtype == v.dtype assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." assert (
assert len( q.dtype != torch.float32
beta.shape ), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." assert (
len(beta.shape) == 3
), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
if head_first: if head_first:
raise DeprecationWarning( raise DeprecationWarning(
"head_first is deprecated and will be removed in a future version. " "head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead.", "Please use head_first=False for now instead.",
stacklevel=2) stacklevel=2,
)
q, k, v, beta, g = map( q, k, v, beta, g = map(
lambda x: rearrange(x, 'b h t ... -> b t h ...'), lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
(q, k, v, beta, g)) )
if not head_first and q.shape[1] < q.shape[2]: if not head_first and q.shape[1] < q.shape[2]:
warnings.warn( warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] " "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. " "when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].", "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2) stacklevel=2,
)
if cu_seqlens is not None: if cu_seqlens is not None:
if q.shape[0] != 1: if q.shape[0] != 1:
raise ValueError( raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.") f"Please flatten variable-length inputs before processing."
if initial_state is not None and initial_state.shape[0] != len( )
cu_seqlens) - 1: if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError( raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, " f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
) )
if scale is None: if scale is None:
scale = k.shape[-1]**-0.5 scale = k.shape[-1] ** -0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, if False:
use_qk_l2norm_in_kernel) q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
g = g.contiguous()
beta = beta.contiguous()
initial_state = initial_state.contiguous()
o = torch.empty_like(v)
final_state = torch.empty_like(initial_state)
import kunlun_ops
kunlun_ops.gated_delta_rule(
q,
k,
v,
initial_state,
g,
beta,
final_state,
o,
scale,
cu_seqlens.cpu(),
cu_seqlens,
cu_seqlens.cpu(),
cu_seqlens,
use_qk_l2norm_in_kernel=True,
)
else:
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q,
k,
v,
g,
beta,
scale,
initial_state,
output_final_state,
cu_seqlens,
use_qk_l2norm_in_kernel,
)
if head_first: if head_first:
o = rearrange(o, 'b t h ... -> b h t ...') o = rearrange(o, "b t h ... -> b h t ...")
return o, final_state return o, final_state

View File

@@ -12,21 +12,21 @@
from typing import Optional from typing import Optional
import torch import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices from .index import prepare_chunk_indices
from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
@triton.heuristics({ @triton.heuristics(
'USE_G': lambda args: args['g'] is not None, {
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None "USE_G": lambda args: args["g"] is not None,
}) "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
# triton.Config({ # triton.Config({
@@ -40,7 +40,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
# ], # ],
# key=['H', 'K', 'V', 'BT'], # key=['H', 'K', 'V', 'BT'],
# ) # )
@triton.jit(do_not_specialize=['T']) @triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o( def chunk_fwd_kernel_o(
q, q,
k, k,
@@ -67,10 +67,12 @@ def chunk_fwd_kernel_o(
if IS_VARLEN: if IS_VARLEN:
i_tg = i_t i_tg = i_t
i_n, i_t = tl.load(chunk_indices + i_t * 2).to( i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) chunk_indices + i_t * 2 + 1
bos, eos = tl.load(cu_seqlens + i_n).to( ).to(tl.int32)
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
cu_seqlens + i_n + 1
).to(tl.int32)
T = eos - bos T = eos - bos
NT = tl.cdiv(T, BT) NT = tl.cdiv(T, BT)
else: else:
@@ -89,12 +91,15 @@ def chunk_fwd_kernel_o(
b_A = tl.zeros([BT, BT], dtype=tl.float32) b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)): for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), p_q = tl.make_block_ptr(
(BT, BK), (1, 0)) q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), )
(BK, BT), (0, 1)) p_k = tl.make_block_ptr(
p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
(BK, BV), (1, 0)) )
p_h = tl.make_block_ptr(
h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
)
# [BT, BK] # [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT] # [BK, BT]
@@ -109,8 +114,8 @@ def chunk_fwd_kernel_o(
if USE_G: if USE_G:
g += bos * H + i_h g += bos * H + i_h
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, )) p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0, )) b_g = tl.load(p_g, boundary_check=(0,))
b_o = b_o * tl.exp(b_g)[:, None] b_o = b_o * tl.exp(b_g)[:, None]
b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :]) b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :])
@@ -120,10 +125,12 @@ def chunk_fwd_kernel_o(
# b_A = tl.where(m_A, b_A, 0) # b_A = tl.where(m_A, b_A, 0)
b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0) b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0)
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), p_v = tl.make_block_ptr(
(BT, BV), (1, 0)) v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), )
(BT, BV), (1, 0)) p_o = tl.make_block_ptr(
o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1)) b_v = tl.load(p_v, boundary_check=(0, 1))
# to fix mma -> mma layout conversion # to fix mma -> mma layout conversion
@@ -133,48 +140,29 @@ def chunk_fwd_kernel_o(
def chunk_fwd_o( def chunk_fwd_o(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
h: torch.Tensor, h: torch.Tensor,
g: Optional[torch.Tensor] = None, # cumsum of log decay g: Optional[torch.Tensor] = None, # cumsum of log decay
scale: Optional[float] = None, scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64) -> torch.Tensor: chunk_size: int = 64,
B, T, Hg, K, V = *q.shape, v.shape[-1] ) -> torch.Tensor:
H = v.shape[-2] _, T, _, _, _ = *q.shape, v.shape[-1]
if FLA_GDN_FIX_BT: if FLA_GDN_FIX_BT:
BT = 64 BT = 64
else: else:
BT = min(chunk_size, max(16, triton.next_power_of_2(T))) BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices( chunk_indices = (
cu_seqlens, BT) if cu_seqlens is not None else None prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) )
if scale is None: if scale is None:
scale = k.shape[-1]**-0.5 scale = k.shape[-1] ** -0.5
o = torch.empty_like(v) o = torch.empty_like(v)
def grid(meta): o = torch.ops.xspeedgate_ops.chunk_fwd_o(
return (triton.cdiv(V, meta['BV']), NT, B * H) q, k, v, h, g, scale, cu_seqlens, chunk_indices, chunk_size
chunk_fwd_kernel_o[grid](
q,
k,
v,
h,
g,
o,
cu_seqlens,
chunk_indices,
scale,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=64,
BV=32
) )
return o return o

View File

@@ -9,28 +9,28 @@
# ruff: noqa: E501 # ruff: noqa: E501
from typing import Optional from typing import Optional
import torch
import kunlun_ops import kunlun_ops
import torch
class FusedRecurrentFunction(torch.autograd.Function): class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, def forward(
q: torch.Tensor, ctx,
k: torch.Tensor, q: torch.Tensor,
v: torch.Tensor, k: torch.Tensor,
g: torch.Tensor, v: torch.Tensor,
beta: torch.Tensor, g: torch.Tensor,
scale: float, beta: torch.Tensor,
initial_state: torch.Tensor, scale: float,
inplace_final_state: bool = True, initial_state: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None, inplace_final_state: bool = True,
ssm_state_indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.LongTensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None, ssm_state_indices: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False): num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2( o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2(
q.contiguous(), q.contiguous(),
k.contiguous(), k.contiguous(),
@@ -44,7 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
h0_indices=ssm_state_indices, h0_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens, num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
is_h0_transposed=True is_h0_transposed=True,
) )
return o, final_state return o, final_state
@@ -130,9 +130,10 @@ def fused_recurrent_gated_delta_rule(
if cu_seqlens is not None and q.shape[0] != 1: if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError( raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.") f"Please flatten variable-length inputs before processing."
)
if scale is None: if scale is None:
scale = k.shape[-1]**-0.5 scale = k.shape[-1] ** -0.5
else: else:
assert scale > 0, "scale must be positive" assert scale > 0, "scale must be positive"
if beta is None: if beta is None:

View File

@@ -10,22 +10,21 @@
import os import os
from typing import Optional from typing import Optional
import kunlun_ops
import torch import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
import kunlun_ops
BT_LIST = [8, 16, 32, 64, 128] BT_LIST = [8, 16, 32, 64, 128]
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
@triton.autotune(configs=[ @triton.autotune(
triton.Config({}, num_warps=num_warps) configs=[
for num_warps in [1, 2, 4, 8, 16, 32] triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
], ],
key=['D']) key=["D"],
)
@triton.jit @triton.jit
def l2norm_fwd_kernel1( def l2norm_fwd_kernel1(
x, x,
@@ -49,11 +48,14 @@ def l2norm_fwd_kernel1(
tl.store(y + cols, b_y, mask=mask) tl.store(y + cols, b_y, mask=mask)
@triton.autotune(configs=[ @triton.autotune(
triton.Config({'BT': BT}, num_warps=num_warps) configs=[
for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST triton.Config({"BT": BT}, num_warps=num_warps)
], for num_warps in [1, 2, 4, 8, 16]
key=['D']) for BT in BT_LIST
],
key=["D"],
)
@triton.jit(do_not_specialize=["NB"]) @triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel( def l2norm_fwd_kernel(
x, x,
@@ -87,67 +89,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
def l2norm_fwd_triton(x: torch.Tensor, def l2norm_fwd(
eps: float = 1e-6, x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
output_dtype: Optional[torch.dtype] = None): ):
x_shape_og = x.shape
x = x.view(-1, x.shape[-1])
# allocate output
if output_dtype is None:
y = torch.empty_like(x)
else:
y = torch.empty_like(x, dtype=output_dtype)
assert y.stride(-1) == 1
T, D = x.shape[0], x.shape[-1]
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
if D > BD:
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
if not USE_DEFAULT_FLA_NORM:
MBLOCK = 32
# M, N = x.shape
l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
x,
y,
eps,
T,
D,
MBLOCK,
)
else:
if D <= 512:
NB = triton.cdiv(T, 2048)
def grid(meta):
return (triton.cdiv(T, meta['BT']), )
l2norm_fwd_kernel[grid](
x,
y,
eps,
NB=NB,
T=T,
D=D,
BD=BD,
)
else:
l2norm_fwd_kernel1[(T, )](
x,
y,
eps=eps,
D=D,
BD=BD,
)
return y.view(x_shape_og)
def l2norm_fwd(x: torch.Tensor,
eps: float = 1e-6,
output_dtype: Optional[torch.dtype] = None):
out = torch.empty_like(x) out = torch.empty_like(x)
kunlun_ops.l2norm(x, out, eps) kunlun_ops.l2norm(x, out, eps)
return out return out

View File

@@ -19,20 +19,21 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from .utils import input_guard from .utils import input_guard
def rms_norm_ref(x, def rms_norm_ref(
weight, x,
bias, weight,
z=None, bias,
eps=1e-6, z=None,
group_size=None, eps=1e-6,
norm_before_gate=True, group_size=None,
upcast=True): norm_before_gate=True,
upcast=True,
):
dtype = x.dtype dtype = x.dtype
weight = weight.float() weight = weight.float()
bias = bias.float() if bias is not None else None bias = bias.float() if bias is not None else None
@@ -43,12 +44,10 @@ def rms_norm_ref(x,
x = x * F.silu(z) x = x * F.silu(z)
if group_size is None: if group_size is None:
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
weight)
else: else:
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
eps)
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
if bias is not None: if bias is not None:
out = out + bias out = out + bias
@@ -57,10 +56,12 @@ def rms_norm_ref(x,
return out.to(dtype) return out.to(dtype)
@triton.heuristics({ @triton.heuristics(
"HAS_BIAS": lambda args: args["B"] is not None, {
"HAS_Z": lambda args: args["Z"] is not None, "HAS_BIAS": lambda args: args["B"] is not None,
}) "HAS_Z": lambda args: args["Z"] is not None,
}
)
@triton.jit @triton.jit
def layer_norm_fwd_kernel( def layer_norm_fwd_kernel(
X, # pointer to the input X, # pointer to the input
@@ -97,17 +98,17 @@ def layer_norm_fwd_kernel(
B += group * N B += group * N
# Compute mean and variance # Compute mean and variance
cols = tl.arange(0, BLOCK_N) cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE: if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=cols < N).to(tl.float32) z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
x *= z * tl.sigmoid(z) x *= z * tl.sigmoid(z)
if not IS_RMS_NORM: if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean) tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.) xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N var = tl.sum(xbar * xbar, axis=0) / N
else: else:
xbar = tl.where(cols < N, x, 0.) xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps) rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd) tl.store(Rstd + row, rstd)
@@ -149,46 +150,50 @@ def layer_norm_fwd(
# weight = weight.reshape(N) # weight = weight.reshape(N)
# print("weight",weight.shape) # print("weight",weight.shape)
# print("x",x.shape) # print("x",x.shape)
assert weight.shape == (N, ) assert weight.shape == (N,)
assert weight.stride(-1) == 1 assert weight.stride(-1) == 1
if bias is not None: if bias is not None:
assert bias.stride(-1) == 1 assert bias.stride(-1) == 1
assert bias.shape == (N, ) assert bias.shape == (N,)
# allocate output # allocate output
if out is not None: if out is not None:
assert out.shape == x.shape assert out.shape == x.shape
else: else:
out = torch.empty_like(x) out = torch.empty_like(x)
assert out.stride(-1) == 1 assert out.stride(-1) == 1
mean = torch.empty((ngroups * M, ), dtype=torch.float32, mean = (
device=x.device) if not is_rms_norm else None torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm
else None
)
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel # Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size() MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N: if group_size > BLOCK_N:
raise RuntimeError( raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
"This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps # heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8) num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups) grid = (M, ngroups)
layer_norm_fwd_kernel[grid](x, layer_norm_fwd_kernel[grid](
out, x,
weight, out,
bias, weight,
z, bias,
mean, z,
rstd, mean,
x.stride(0), rstd,
out.stride(0), x.stride(0),
z.stride(0) if z is not None else 0, out.stride(0),
M, z.stride(0) if z is not None else 0,
group_size, M,
eps, group_size,
BLOCK_N=BLOCK_N, eps,
NORM_BEFORE_GATE=norm_before_gate, BLOCK_N=BLOCK_N,
IS_RMS_NORM=is_rms_norm, NORM_BEFORE_GATE=norm_before_gate,
num_warps=num_warps) IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
)
return out, mean, rstd return out, mean, rstd
@@ -196,17 +201,18 @@ class LayerNormFn(torch.autograd.Function):
@input_guard @input_guard
@staticmethod @staticmethod
def forward(ctx, def forward(
x, ctx,
weight, x,
bias, weight,
z=None, bias,
eps=1e-6, z=None,
group_size=None, eps=1e-6,
norm_before_gate=True, group_size=None,
is_rms_norm=False): norm_before_gate=True,
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) is_rms_norm=False,
""" ):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og = x.shape x_shape_og = x.shape
# reshape input data into 2D tensor # reshape input data into 2D tensor
@@ -223,16 +229,15 @@ class LayerNormFn(torch.autograd.Function):
weight = weight.contiguous() weight = weight.contiguous()
if bias is not None: if bias is not None:
bias = bias.contiguous() bias = bias.contiguous()
y, mean, rstd = layer_norm_fwd( # y, mean, rstd = torch.ops.xspeedgate_ops.rms_norm_gated_fwd(x, weight, bias, eps, z, group_size, norm_before_gate, is_rms_norm)
x, y = torch.empty_like(x)
weight, mean, rstd = None, None
bias, import kunlun_ops
eps,
z=z, kunlun_ops.rms_norm_gated(
group_size=group_size, x, y, z, weight, eps, group_size, norm_before_gate, is_rms_norm
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
) )
ctx.save_for_backward(x, weight, bias, mean, rstd, z) ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og ctx.x_shape_og = x_shape_og
ctx.eps = eps ctx.eps = eps
@@ -242,27 +247,27 @@ class LayerNormFn(torch.autograd.Function):
return y.reshape(x_shape_og) return y.reshape(x_shape_og)
def layernorm_fn(x, def layernorm_fn(
weight, x,
bias, weight,
z=None, bias,
eps=1e-6, z=None,
group_size=None, eps=1e-6,
norm_before_gate=True, group_size=None,
is_rms_norm=False): norm_before_gate=True,
return LayerNormFn.apply(x, weight, bias, z, eps, group_size, is_rms_norm=False,
norm_before_gate, is_rms_norm) ):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
)
def rmsnorm_fn(x, def rmsnorm_fn(
weight, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
bias, ):
z=None, return LayerNormFn.apply(
eps=1e-6, x, weight, bias, z, eps, group_size, norm_before_gate, True
group_size=None, )
norm_before_gate=True):
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
norm_before_gate, True)
class LayerNormGated(nn.Module): class LayerNormGated(nn.Module):
@@ -294,15 +299,16 @@ class LayerNormGated(nn.Module):
torch.nn.init.zeros_(self.bias) torch.nn.init.zeros_(self.bias)
def forward(self, x, z=None): def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
""" return layernorm_fn(
return layernorm_fn(x, x,
self.weight, self.weight,
self.bias, self.bias,
z=z, z=z,
group_size=self.group_size, group_size=self.group_size,
eps=self.eps, eps=self.eps,
norm_before_gate=self.norm_before_gate) norm_before_gate=self.norm_before_gate,
)
class RMSNormGated(nn.Module): class RMSNormGated(nn.Module):
@@ -332,12 +338,13 @@ class RMSNormGated(nn.Module):
torch.nn.init.ones_(self.weight) torch.nn.init.ones_(self.weight)
def forward(self, x, z=None): def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
""" return rmsnorm_fn(
return rmsnorm_fn(x, x,
self.weight, self.weight,
self.bias, self.bias,
z=z, z=z,
eps=self.eps, eps=self.eps,
group_size=self.group_size, group_size=self.group_size,
norm_before_gate=self.norm_before_gate) norm_before_gate=self.norm_before_gate,
)

View File

@@ -11,7 +11,6 @@
from typing import Optional from typing import Optional
import torch import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices from .index import prepare_chunk_indices
@@ -28,6 +27,7 @@ RESOLUTION = {
torch.complex64: 1.3e-6, torch.complex64: 1.3e-6,
} }
def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1): def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
assert res.dtype == dtype assert res.dtype == dtype
ref = ref.to(dtype) ref = ref.to(dtype)
@@ -35,6 +35,7 @@ def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
rtol = RESOLUTION[dtype] rtol = RESOLUTION[dtype]
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan) torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
@@ -80,7 +81,6 @@ def recompute_u_fwd_kernel(
p_beta = tl.make_block_ptr( p_beta = tl.make_block_ptr(
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
) )
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
p_A = tl.make_block_ptr( p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0) A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
) )
@@ -110,7 +110,6 @@ def recompute_u_fwd_kernel(
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
@@ -195,53 +194,12 @@ def recompute_w_u_fwd(
A: torch.Tensor, A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor], cu_seqlens: Optional[torch.LongTensor],
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
BT = A.shape[-1] BT = A.shape[-1]
chunk_indices = prepare_chunk_indices( chunk_indices = (
cu_seqlens, BT) if cu_seqlens is not None else None prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 64
BV = 64
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
recompute_u_fwd_kernel[(NT, B * H)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g_cumsum,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
) )
recompute_w_fwd_kernel[(NT, B * H)]( w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
k=k, k, v, beta, g_cumsum, A, cu_seqlens, chunk_indices, chunk_size=BT
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g_cumsum,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
) )
return w, u return w, u

View File

@@ -15,51 +15,52 @@
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
# #
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
from vllm.model_executor.layers import layernorm
from typing import Optional, Union from typing import Optional, Union
import kunlun_ops
import torch
from vllm.model_executor.layers import layernorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm
def vllm_kunlun_forward_cuda( def vllm_kunlun_forward_cuda(
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""forward_cuda""" """forward_cuda"""
if x.is_contiguous() == False: if not x.is_contiguous():
# kunlun does not support uncontiguous input and they do not think it is a bug # kunlun does not support uncontiguous input and they do not think it is a bug
# so we must make it contiguous() manually # so we must make it contiguous() manually
x = x.contiguous() x = x.contiguous()
if self.variance_size_override is not None: if self.variance_size_override is not None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
if residual is not None:
if residual is not None: # residual_output = torch.empty_like(residual)
# residual_output = torch.empty_like(residual) torch.ops._C.add_rmsnorm(
torch.ops._C.add_rmsnorm(
x,
residual,
residual_output=residual,
weight=self.weight.data,
eps=self.variance_epsilon,
output=x
)
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
x, x,
self.weight.data, residual,
out, residual_output=residual,
self.variance_epsilon, weight=self.weight.data,
eps=self.variance_epsilon,
output=x,
) )
return out return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
x,
self.weight.data,
out,
self.variance_epsilon,
)
return out
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda RMSNorm.forward = vllm_kunlun_forward_cuda
class KunlunGemmaRMSNorm(OriGemmaRMSNorm): class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
@staticmethod @staticmethod
def forward_xpu( def forward_xpu(
@@ -68,30 +69,42 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if x.is_contiguous() == False: if not x.is_contiguous():
# kunlun does not support uncontiguous input and they do not think it is a bug # kunlun does not support uncontiguous input and they do not think it is a bug
# so we must make it contiguous() manually # so we must make it contiguous() manually
x = x.contiguous() x = x.contiguous()
if x.dim() == 3:
x_shape = x.shape
x = x.view(-1, x.size(-1))
if residual is not None: if residual is not None:
torch.ops._C.add_rmsnorm( out = torch.empty_like(x)
out_residual = torch.empty_like(residual)
torch.ops._C.gemma_add_rmsnorm(
x, x,
residual, residual,
residual_output=residual, residual_output=out_residual,
weight=weight+1, weight=weight,
eps=variance_epsilon, eps=variance_epsilon,
output=x output=out,
)
else:
out = torch.empty_like(x)
torch.ops._C.gemma_rmsnorm(
x,
weight,
out,
variance_epsilon,
) )
return x, residual
out = torch.empty_like(x) if x.dim() == 3:
torch.ops._C.rmsnorm( x = x.view(x_shape)
x, if out is not None:
weight+1, out = out.view(x_shape)
out,
variance_epsilon, if residual is not None:
) return out, out_residual
return out else:
return out
def forward_cuda( def forward_cuda(
self, self,
@@ -99,16 +112,17 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if torch.compiler.is_compiling(): if torch.compiler.is_compiling():
self.forward_static = self.forward_xpu # only use in cudagraph self.forward_static = self.forward_xpu # only use in cudagraph
return self.forward_native(x, residual) return self.forward_native(x, residual)
if not getattr(self, "_is_compiled", False): if not getattr(self, "_is_compiled", False):
self.forward_static = torch.compile( # type: ignore self.forward_static = torch.compile( # type: ignore
self.forward_static, backend="aot_eager") self.forward_static, backend="aot_eager"
)
self._is_compiled = True self._is_compiled = True
return self.forward_native(x, residual) return self.forward_native(x, residual)
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda RMSNorm.forward = vllm_kunlun_forward_cuda
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,390 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends import gdn_attn
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class GDNAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
return GDNAttentionMetadataBuilder
@dataclass
class GDNAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_spec_decodes: int
num_spec_decode_tokens: int
num_actual_tokens: int
has_initial_state: Optional[torch.Tensor] = None
has_initial_state_cpu: Optional[torch.Tensor] = None
spec_query_start_loc: Optional[torch.Tensor] = (
None # shape: [num_spec_decodes + 1,]
)
non_spec_query_start_loc: Optional[torch.Tensor] = (
None # shape: [batch - num_spec_decodes + 1,]
)
spec_state_indices_tensor: Optional[torch.Tensor] = None # shape: [batch, num_spec]
non_spec_state_indices_tensor: Optional[torch.Tensor] = (
None # shape: [batch - num_spec_decodes,]
)
non_spec_state_indices_tensor_cpu: Optional[torch.Tensor] = None
spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,]
spec_token_masks: Optional[torch.Tensor] = (
None # shape: [num_prefill_tokens + num_decode_tokens,]
)
num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: int = 1
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
assert isinstance(kv_cache_spec, MambaSpec)
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.speculative_config = vllm_config.speculative_config
self.kv_cache_spec = kv_cache_spec
if self.speculative_config:
self.num_spec = self.speculative_config.num_speculative_tokens # noqa: E501
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self._init_reorder_batch_threshold(1, self.use_spec_decode)
self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
self.compilation_config.max_capture_size,
)
self.spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1),
dtype=torch.int32,
device=device,
)
self.non_spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.spec_sequence_masks = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.bool,
device=device,
)
self.spec_token_masks = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.bool,
device=device,
)
self.spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.non_spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.num_accepted_tokens = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
def build( # type: ignore[override]
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
num_accepted_tokens: Optional[torch.Tensor] = None,
num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
fast_build: bool = False,
) -> GDNAttentionMetadata:
m = common_attn_metadata
query_start_loc = m.query_start_loc
context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device)
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if (
not self.use_spec_decode
or num_decode_draft_tokens_cpu is None
or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
.sum()
.item()
== 0
):
spec_sequence_masks = None
num_spec_decodes = 0
else:
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
num_spec_decodes = spec_sequence_masks.sum().item()
if num_spec_decodes == 0:
spec_sequence_masks = None
else:
spec_sequence_masks = spec_sequence_masks.to(
query_start_loc.device, non_blocking=True
)
if spec_sequence_masks is None:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(m, decode_threshold=1)
)
num_spec_decode_tokens = 0
spec_token_masks = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
num_accepted_tokens = None
else:
query_lens = query_start_loc[1:] - query_start_loc[:-1]
non_spec_query_lens = query_lens[~spec_sequence_masks]
num_decodes = (non_spec_query_lens == 1).sum().item()
num_prefills = non_spec_query_lens.size(0) - num_decodes
num_decode_tokens = num_decodes
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
if num_prefills == 0 and num_decodes == 0:
spec_token_masks = torch.ones(
(
min(
num_spec_decodes * (self.num_spec + 1),
query_start_loc[-1].item(),
)
),
dtype=torch.bool,
device=query_start_loc.device,
)
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
non_spec_query_start_loc = None
else:
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens
)
spec_state_indices_tensor = m.block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
non_spec_state_indices_tensor = m.block_table_tensor[
~spec_sequence_masks, 0
]
spec_query_start_loc = torch.zeros(
num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device,
)
torch.cumsum(
query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
)
non_spec_query_start_loc = torch.zeros(
query_lens.size(0) - num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device,
)
torch.cumsum(
query_lens[~spec_sequence_masks],
dim=0,
out=non_spec_query_start_loc[1:],
)
num_spec_decode_tokens = (
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
)
assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
if num_prefills > 0:
has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks]
has_initial_state_cpu = has_initial_state.cpu()
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(non_spec_query_start_loc)
)
else:
has_initial_state = None
has_initial_state_cpu = None
num_actual_tokens = (
num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
)
# prepare tensors for cudagraph
#
# With speculative decoding, the xgrammar backend may rollback tokens
# and causing some sequences has less draft tokens than self.num_spec.
#
# In above cases, the max possible batch size for n tokens, can be
# min(n, cudagraph_max_bs).
if (
self.use_full_cuda_graph
and num_prefills == 0
and num_decodes == 0
and num_spec_decodes <= self.decode_cudagraph_max_bs
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
):
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True
)
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
self.spec_sequence_masks[:num_spec_decodes].copy_(
spec_sequence_masks, non_blocking=True
)
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
spec_sequence_masks[num_spec_decodes:].fill_(False)
assert spec_token_masks is not None
self.spec_token_masks[: spec_token_masks.size(0)].copy_(
spec_token_masks, non_blocking=True
)
spec_token_masks = self.spec_token_masks[:num_actual_tokens]
spec_token_masks[spec_token_masks.size(0) :].fill_(False)
self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
spec_query_start_loc, non_blocking=True
)
spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index]
spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)
self.num_accepted_tokens[:num_spec_decodes].copy_(
num_accepted_tokens, non_blocking=True
)
num_accepted_tokens = self.num_accepted_tokens[:batch_size]
num_accepted_tokens[num_spec_decodes:].fill_(1)
if (
self.use_full_cuda_graph
and num_prefills == 0
and num_spec_decodes == 0
and num_decodes <= self.decode_cudagraph_max_bs
):
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
batch_size = num_actual_tokens
self.non_spec_state_indices_tensor[:num_decodes].copy_(
non_spec_state_indices_tensor, non_blocking=True
)
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
:batch_size
]
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
self.non_spec_query_start_loc[: num_decodes + 1].copy_(
non_spec_query_start_loc, non_blocking=True
)
non_spec_num_query_tokens = non_spec_query_start_loc[
-1
] # type: ignore[index]
non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)
if num_accepted_tokens is not None:
num_accepted_tokens = num_accepted_tokens.to(torch.int32)
attn_metadata = GDNAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_spec_decodes=num_spec_decodes,
num_spec_decode_tokens=num_spec_decode_tokens,
num_actual_tokens=num_actual_tokens,
has_initial_state=has_initial_state,
has_initial_state_cpu=has_initial_state_cpu,
spec_query_start_loc=spec_query_start_loc,
non_spec_query_start_loc=non_spec_query_start_loc,
spec_state_indices_tensor=spec_state_indices_tensor,
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
non_spec_state_indices_tensor_cpu=(
non_spec_state_indices_tensor.cpu()
if non_spec_state_indices_tensor is not None
else None
),
spec_sequence_masks=spec_sequence_masks,
spec_token_masks=spec_token_masks,
num_accepted_tokens=num_accepted_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata
assert (
m.num_reqs <= self.decode_cudagraph_max_bs
and m.num_actual_tokens <= self.decode_cudagraph_max_bs
), (
f"GDN only supports decode-only full CUDAGraph capture. "
f"Make sure batch size ({m.num_reqs}) <= "
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
f"and number of tokens ({m.num_actual_tokens}) <= "
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})."
)
num_accepted_tokens = torch.diff(m.query_start_loc)
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
gdn_attn.GDNAttentionMetadata = GDNAttentionMetadata
gdn_attn.GDNAttentionMetadataBuilder = GDNAttentionMetadataBuilder

View File

@@ -770,24 +770,14 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# If kv_cache is not provided, the new key and value tensors are # If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory # not cached. This happens during the initial memory
value = value.contiguous() value = value.contiguous()
if key_cache.is_contiguous(): kunlun_ops.reshape_and_cache_flash(
kunlun_ops.reshape_and_cache( key[: attn_metadata.num_actual_tokens],
key[: attn_metadata.num_actual_tokens], value[: attn_metadata.num_actual_tokens],
value[: attn_metadata.num_actual_tokens], key_cache,
key_cache, value_cache,
value_cache, updated_slot_mapping,
updated_slot_mapping, BLHD_LAYOUT=False,
) )
else:
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
kunlun_ops.reshape_and_cache_flash(
key,
value,
cast_key_cache,
cast_value_cache,
updated_slot_mapping,
)
assert attn_type == AttentionType.DECODER assert attn_type == AttentionType.DECODER
# Decoder self-attention supports chunked prefill. # Decoder self-attention supports chunked prefill.

View File

@@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Optional
from typing import Union
import kunlun_ops
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
''' """
Args: Args:
metadata: metadata:
Metadata for spec decoding. Metadata for spec decoding.
@@ -81,7 +80,7 @@ class RejectionSampler(nn.Module):
Returns: Returns:
output_token_ids (torch.Tensor): output_token_ids (torch.Tensor):
A tensor containing the final output token IDs. A tensor containing the final output token IDs.
''' """
assert metadata.max_spec_len <= MAX_SPEC_LEN assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size] # [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the # NOTE(woosuk): `target_logits` can be updated in place inside the
@@ -124,11 +123,11 @@ class RejectionSampler(nn.Module):
""" """
output_token_ids_np = output_token_ids.cpu().numpy() output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens. # Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
(output_token_ids_np < vocab_size)) output_token_ids_np < vocab_size
)
outputs = [ outputs = [
row[valid_mask[i]].tolist() row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
for i, row in enumerate(output_token_ids_np)
] ]
return outputs return outputs
@@ -179,25 +178,15 @@ def rejection_sample(
if not sampling_metadata.all_random: if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests. # Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1) target_argmax = target_probs.argmax(dim=-1)
if min(num_draft_tokens) == 1 and max( kunlun_ops.rejection_greedy_sample(
num_draft_tokens) == 1 and sampling_metadata.all_greedy: output_token_ids,
rejection_greedy_sample_spec_len_1_pytorch( cu_num_draft_tokens,
output_token_ids, draft_token_ids,
draft_token_ids, target_argmax,
target_argmax, bonus_token_ids,
bonus_token_ids, is_greedy,
) max_spec_len,
else: )
rejection_greedy_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
num_draft_tokens,
max_spec_len,
is_greedy,
)
if sampling_metadata.all_greedy: if sampling_metadata.all_greedy:
return output_token_ids return output_token_ids
@@ -222,8 +211,9 @@ def rejection_sample(
sampling_metadata, sampling_metadata,
device, device,
) )
bonus_token_ids = bonus_token_ids.squeeze(1)
rejection_random_sample_pytorch( kunlun_ops.rejection_random_sample(
output_token_ids, output_token_ids,
cu_num_draft_tokens, cu_num_draft_tokens,
draft_token_ids, draft_token_ids,
@@ -235,8 +225,7 @@ def rejection_sample(
is_greedy, is_greedy,
max_spec_len, max_spec_len,
vocab_size, vocab_size,
IS_NGRAM=draft_probs is None, no_draft_probs=draft_probs is None,
# num_warps=1,
) )
return output_token_ids return output_token_ids
@@ -374,7 +363,7 @@ def generate_uniform_probs(
random values in the range [0, 1). random values in the range [0, 1).
""" """
uniform_probs = torch.rand( uniform_probs = torch.rand(
(num_tokens, ), (num_tokens,),
dtype=torch.float32, dtype=torch.float32,
device=device, device=device,
) )
@@ -422,7 +411,7 @@ def sample_recovered_tokens(
q[i].exponential_(generator=generator) q[i].exponential_(generator=generator)
recovered_token_ids = torch.empty_like(draft_token_ids) recovered_token_ids = torch.empty_like(draft_token_ids)
sample_recovered_tokens_pytorch( kunlun_ops.sample_recovered_tokens(
recovered_token_ids, recovered_token_ids,
cu_num_draft_tokens, cu_num_draft_tokens,
draft_token_ids, draft_token_ids,
@@ -430,16 +419,16 @@ def sample_recovered_tokens(
target_probs, target_probs,
q, q,
vocab_size, vocab_size,
IS_NGRAM=draft_probs is None, no_draft_probs=draft_probs is None,
) )
return recovered_token_ids return recovered_token_ids
def rejection_greedy_sample_spec_len_1_pytorch( def rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids, # [batch_size, 2] output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens] draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens] target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size] bonus_token_ids, # [batch_size]
): ):
batch_size = output_token_ids.size(0) batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0) num_tokens = draft_token_ids.size(0)
@@ -447,73 +436,72 @@ def rejection_greedy_sample_spec_len_1_pytorch(
accept_req_mask = draft_token_ids == target_argmax accept_req_mask = draft_token_ids == target_argmax
output_token_ids[:, 0] = target_argmax output_token_ids[:, 0] = target_argmax
bonus_token_ids = bonus_token_ids.squeeze(1) bonus_token_ids = bonus_token_ids.squeeze(1)
output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids, output_token_ids[:, 1] = torch.where(
output_token_ids[:, 1]) accept_req_mask, bonus_token_ids, output_token_ids[:, 1]
)
def rejection_greedy_sample_pytorch( def rejection_greedy_sample_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1] output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size] cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens] draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens] target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size] bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list draft_tokens_per_req, # [batch_size], list
max_spec_len, max_spec_len,
is_greedy=None, # [batch_size] or None is_greedy=None, # [batch_size] or None
): ):
batch_size = output_token_ids.size(0) batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0) num_tokens = draft_token_ids.size(0)
device = output_token_ids.device device = output_token_ids.device
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to( draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
device, non_blocking=True) device, non_blocking=True
)
if is_greedy is None: if is_greedy is None:
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device) is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
start_indices = cu_num_draft_tokens - draft_tokens_per_req start_indices = cu_num_draft_tokens - draft_tokens_per_req
req_ids = torch.arange(batch_size, device=device) req_ids = torch.arange(batch_size, device=device)
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req) token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
token_positions = torch.arange( token_positions = (
num_tokens, device=device) - start_indices[token_req_ids] torch.arange(num_tokens, device=device) - start_indices[token_req_ids]
)
# Find the first mismatch position of each request. # Find the first mismatch position of each request.
mismatch_global = (draft_token_ids != target_argmax) mismatch_global = draft_token_ids != target_argmax
if max_spec_len == 0: if max_spec_len == 0:
first_mismatch_pos_per_req = torch.zeros(batch_size, first_mismatch_pos_per_req = torch.zeros(
dtype=torch.long, batch_size, dtype=torch.long, device=device
device=device) )
else: else:
# [bs, max_spec_len] # [bs, max_spec_len]
pos_matrix = torch.full((batch_size, max_spec_len), pos_matrix = torch.full(
-1, (batch_size, max_spec_len), -1, dtype=torch.long, device=device
dtype=torch.long, )
device=device)
pos_matrix[token_req_ids, token_positions] = token_positions pos_matrix[token_req_ids, token_positions] = token_positions
mismatch_matrix = torch.full((batch_size, max_spec_len), mismatch_matrix = torch.full(
False, (batch_size, max_spec_len), False, dtype=torch.bool, device=device
dtype=torch.bool, )
device=device)
mismatch_matrix[token_req_ids, token_positions] = mismatch_global mismatch_matrix[token_req_ids, token_positions] = mismatch_global
mismatch_positions = torch.where(mismatch_matrix, pos_matrix, mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
max_spec_len * 2)
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1) first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2) no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[ first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
no_mismatch_mask] no_mismatch_mask
]
# Copy matched target tokens into output. # Copy matched target tokens into output.
copy_len = torch.minimum(first_mismatch_pos_per_req + 1, copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req)
draft_tokens_per_req) copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1)
copy_indices = torch.arange(max_spec_len + 1,
device=device).expand(batch_size, -1)
copy_mask = copy_indices < copy_len.unsqueeze(1) copy_mask = copy_indices < copy_len.unsqueeze(1)
greedy_mask = is_greedy.unsqueeze(1) greedy_mask = is_greedy.unsqueeze(1)
final_copy_mask = copy_mask & greedy_mask final_copy_mask = copy_mask & greedy_mask
global_idx = start_indices.unsqueeze(1) + copy_indices global_idx = start_indices.unsqueeze(1) + copy_indices
output_token_ids[final_copy_mask] = target_argmax[ output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(
global_idx[final_copy_mask]].to(output_token_ids.dtype) output_token_ids.dtype
)
# Fill bonus token. # Fill bonus token.
needs_bonus = is_greedy & (first_mismatch_pos_per_req needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req)
>= draft_tokens_per_req)
if torch.any(needs_bonus): if torch.any(needs_bonus):
bonus_rows = torch.where(needs_bonus)[0] bonus_rows = torch.where(needs_bonus)[0]
bonus_cols = draft_tokens_per_req[bonus_rows] bonus_cols = draft_tokens_per_req[bonus_rows]
@@ -556,11 +544,9 @@ def rejection_random_sample_pytorch(
if IS_NGRAM: if IS_NGRAM:
draft_prob = 1.0 draft_prob = 1.0
else: else:
draft_prob = draft_probs[start_idx + pos, draft_prob = draft_probs[start_idx + pos, draft_token_id].item()
draft_token_id].item()
target_prob = target_probs[start_idx + pos, target_prob = target_probs[start_idx + pos, draft_token_id].item()
draft_token_id].item()
uniform_prob = uniform_probs[start_idx + pos].item() uniform_prob = uniform_probs[start_idx + pos].item()
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
@@ -629,12 +615,11 @@ def sample_recovered_tokens_pytorch(
else: else:
draft_p = draft_probs[token_idx].clone() draft_p = draft_probs[token_idx].clone()
target_p = target_probs[token_idx].clone() target_p = target_probs[token_idx].clone()
prob = torch.maximum(target_p - draft_p, prob = torch.maximum(
torch.tensor(0.0, device=target_p.device)) target_p - draft_p, torch.tensor(0.0, device=target_p.device)
)
q_values = torch.full((vocab_size, ), q_values = torch.full((vocab_size,), float("-inf"), device=q.device)
float('-inf'),
device=q.device)
q_values[:vocab_size] = q[req_idx, :vocab_size] q_values[:vocab_size] = q[req_idx, :vocab_size]
recovered_id = torch.argmax(prob / q_values).item() recovered_id = torch.argmax(prob / q_values).item()
@@ -642,4 +627,3 @@ def sample_recovered_tokens_pytorch(
if IS_NGRAM: if IS_NGRAM:
target_probs[token_idx, draft_token_id] = orig_prob target_probs[token_idx, draft_token_id] = orig_prob

View File

@@ -337,5 +337,5 @@ def prepare_next_token_ids_padded(
return next_token_ids, valid_sampled_tokens_count return next_token_ids, valid_sampled_tokens_count
EagleProposer.propose = propose # EagleProposer.propose = propose
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded

View File

@@ -407,7 +407,7 @@ def add_rmsnorm(
) -> None: ) -> None:
kunlun_ops.add_rmsnorm( kunlun_ops.add_rmsnorm(
x, x,
y, # 原来写 residual这里其实是 y y,
residual_output=residual_output, residual_output=residual_output,
weight=weight, weight=weight,
eps=eps, eps=eps,
@@ -523,6 +523,145 @@ def _fake_add_rmsnorm(
add_rmsnorm.register_fake(_fake_add_rmsnorm) add_rmsnorm.register_fake(_fake_add_rmsnorm)
@custom_op("_C::gemma_add_rmsnorm", mutates_args=())
def gemma_add_rmsnorm(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
smooth: torch.Tensor = None,
residual_output: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_add_rmsnorm wrapper")
kunlun_ops.gemma_add_rmsnorm(
x,
y,
weight=weight,
output=output,
eps=eps,
enable_pdl=enable_pdl,
interweaved=interweaved,
store_output_before_norm=store_output_before_norm,
bias=bias,
smooth=smooth,
residual_output=residual_output,
force_sdnn=force_sdnn,
)
@impl("_C::gemma_add_rmsnorm", "CUDA")
def gemma_add_rmsnorm_cuda(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
smooth: torch.Tensor = None,
residual_output: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_add_rmsnorm_cuda wrapper")
kunlun_ops.gemma_add_rmsnorm(
x,
y,
weight=weight,
output=output,
eps=eps,
enable_pdl=enable_pdl,
interweaved=interweaved,
store_output_before_norm=store_output_before_norm,
bias=bias,
smooth=smooth,
residual_output=residual_output,
force_sdnn=force_sdnn,
)
def _fake_gemma_add_rmsnorm(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
smooth: torch.Tensor = None,
residual_output: torch.Tensor = None,
force_sdnn: bool = False,
):
output.fake_shape = x.shape
output.fake_dtype = x.dtype
return None
gemma_add_rmsnorm.register_fake(_fake_gemma_add_rmsnorm)
@custom_op("_C::gemma_rmsnorm", mutates_args=())
def gemma_rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweave: bool = False,
bias: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_rmsnorm wrapper")
kunlun_ops.gemma_rmsnorm(
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
)
@impl("_C::gemma_rmsnorm", "CUDA")
def gemma_rmsnorm_cuda(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweave: bool = False,
bias: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_rmsnorm_cuda wrapper")
kunlun_ops.gemma_rmsnorm(
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
)
def _fake_gemma_rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweave: bool = False,
bias: torch.Tensor = None,
force_sdnn: bool = False,
):
# 设置 shape/dtype但不返回值
output.fake_shape = x.shape
output.fake_dtype = x.dtype
return None
gemma_rmsnorm.register_fake(_fake_gemma_rmsnorm)
@custom_op("_C::split_norm_rope_neox", mutates_args=()) @custom_op("_C::split_norm_rope_neox", mutates_args=())
def split_norm_rope_neox( def split_norm_rope_neox(
q_emb: torch.Tensor, q_emb: torch.Tensor,