[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)
Signed-off-by: xyDong0223 <dongxinyu03@baidu.com> Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -3,95 +3,113 @@ from vllm import ModelRegistry
|
|||||||
|
|
||||||
def register_model():
    """Register the Kunlun model implementations with vLLM's ``ModelRegistry``.

    Each entry maps an architecture name to a lazy ``"module:ClassName"``
    target and is passed to ``ModelRegistry.register_model``. Every
    architecture is registered exactly once; the previous code registered
    ``GptOssForCausalLM`` twice with the same target.
    """
    # NOTE(review): these imports bind names that are never used below
    # (hence the F401 suppressions); presumably they are kept for their
    # import-time side effects -- confirm before removing.
    # from .demo_model import DemoModel  # noqa: F401
    from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration  # noqa: F401
    from .qwen2_vl import Qwen2VLForConditionalGeneration  # noqa: F401
    from .qwen3_moe import Qwen3MoeForCausalLM  # noqa: F401
    from .qwen3_omni_moe_thinker import (  # noqa: F401
        Qwen3OmniMoeThinkerForConditionalGeneration,
    )
    from .qwen3_vl import Qwen3VLForConditionalGeneration  # noqa: F401
    from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration  # noqa: F401

    # from .llama4 import Llama4ForCausalLM  #noqa: F401
    # from .mllama4 import Llama4ForConditionalGeneration  #noqa: F401
    # from .deepseek_v2 import KunlunDeepseekV2MoE

    # ModelRegistry.register_model(
    #     "DemoModel",
    #     "vllm_kunlun.model_executor.models.demo_model:DemoModel")

    # (architecture name, "module:class" target), in registration order.
    registrations = (
        (
            "Qwen2VLForConditionalGeneration",
            "vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration",
        ),
        (
            "Qwen2_5_VLForConditionalGeneration",
            "vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration",
        ),
        ("Qwen3ForCausalLM", "vllm_kunlun.models.qwen3:Qwen3ForCausalLM"),
        ("Qwen3MoeForCausalLM", "vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM"),
        (
            "Qwen3NextForCausalLM",
            "vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM",
        ),
        ("Qwen3NextMTP", "vllm_kunlun.models.qwen3_next_mtp:Qwen3NextMTP"),
        ("GlmForCausalLM", "vllm_kunlun.models.glm:GlmForCausalLM"),
        ("GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM"),
        (
            "InternLM2ForCausalLM",
            "vllm_kunlun.models.internlm2:InternLM2ForCausalLM",
        ),
        ("InternVLChatModel", "vllm_kunlun.models.internvl:InternVLChatModel"),
        (
            "InternS1ForConditionalGeneration",
            "vllm_kunlun.models.interns1:InternS1ForConditionalGeneration",
        ),
        (
            "Qwen3VLForConditionalGeneration",
            "vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration",
        ),
        (
            "Qwen3VLMoeForConditionalGeneration",
            "vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration",
        ),
        (
            "Qwen3OmniMoeForConditionalGeneration",
            "vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration",
        ),
        ("SeedOssForCausalLM", "vllm_kunlun.models.seed_oss:SeedOssForCausalLM"),
        (
            "MiMoV2FlashForCausalLM",
            "vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM",
        ),
        (
            "DeepseekV3ForCausalLM",
            "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM",
        ),
        # DeepseekV32 intentionally reuses the DeepseekV3 implementation.
        (
            "DeepseekV32ForCausalLM",
            "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM",
        ),
        ("DeepSeekMTPModel", "vllm_kunlun.models.deepseek_mtp:DeepSeekMTP"),
        (
            "GlmMoeDsaForCausalLM",
            "vllm_kunlun.models.deepseek_v2:GlmMoeDsaForCausalLM",
        ),
    )
    for arch_name, target in registrations:
        ModelRegistry.register_model(arch_name, target)
|
||||||
|
|
||||||
|
|
||||||
def register_quant_method():
    """Placeholder hook for registering quantization methods (to do).

    The function currently has no body beyond this docstring and returns
    ``None``.
    """
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
303
vllm_kunlun/models/qwen3_next_mtp.py
Normal file
303
vllm_kunlun/models/qwen3_next_mtp.py
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Inference-only Qwen3Next MTP model."""
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
|
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
|
ParallelLMHead,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.interfaces import SupportsPP
|
||||||
|
from vllm.model_executor.models.utils import (
|
||||||
|
AutoWeightsLoader,
|
||||||
|
is_pp_missing_parameter,
|
||||||
|
make_empty_intermediate_tensors_factory,
|
||||||
|
maybe_prefix,
|
||||||
|
)
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.transformers_utils.configs import Qwen3NextConfig
|
||||||
|
|
||||||
|
from .qwen3_next import Qwen3NextDecoderLayer, Qwen3NextRMSNorm
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
KVCache = tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
class Qwen3NextMultiTokenPredictor(nn.Module):
    """Multi-token-prediction (MTP) trunk used by ``Qwen3NextMTP``.

    On the first pipeline rank the token embeddings and the incoming hidden
    states are each RMS-normalised, concatenated along the last dimension and
    projected from ``2 * hidden_size`` back to ``hidden_size`` by ``fc``. One
    full-attention decoder layer, chosen by
    ``spec_step_idx % num_mtp_layers``, is then applied, and the last
    pipeline rank applies the final norm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config: Qwen3NextConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = hf_config

        # Extra vocabulary rows are only reserved when LoRA is configured.
        if lora_config:
            extra_vocab = lora_config.lora_extra_vocab_size * (
                lora_config.max_loras or 1
            )
        else:
            extra_vocab = 0
        self.vocab_size = hf_config.vocab_size + extra_vocab
        self.org_vocab_size = hf_config.vocab_size

        self.mtp_start_layer_idx = hf_config.num_hidden_layers
        # Defaults to a single MTP layer when the config does not say.
        self.num_mtp_layers = getattr(hf_config, "num_nextn_predict_layers", 1)

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
        )

        # Projects the concatenated [embedding, hidden] vector back down to
        # hidden_size.
        self.fc = ColumnParallelLinear(
            self.config.hidden_size * 2,
            self.config.hidden_size,
            gather_output=True,
            bias=False,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.fc",
        )

        self.layers = torch.nn.ModuleList(
            [
                Qwen3NextDecoderLayer(
                    vllm_config,
                    layer_type="full_attention",
                    prefix=f"{prefix}.layers.{layer_idx}",
                )
                for layer_idx in range(self.num_mtp_layers)
            ]
        )

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_config.hidden_size
        )

        self.norm = Qwen3NextRMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.pre_fc_norm_hidden = Qwen3NextRMSNorm(
            hf_config.hidden_size, eps=hf_config.rms_norm_eps
        )
        self.pre_fc_norm_embedding = Qwen3NextRMSNorm(
            hf_config.hidden_size, eps=hf_config.rms_norm_eps
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run one MTP step.

        Returns the normalised hidden states on the last pipeline rank, and
        an ``IntermediateTensors`` holding ``hidden_states`` and ``residual``
        on every other rank.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
            normed_embeds = self.pre_fc_norm_embedding(inputs_embeds)
            normed_hidden = self.pre_fc_norm_hidden(hidden_states)
            fused = torch.cat([normed_embeds, normed_hidden], dim=-1)
            hidden_states = self.fc(fused)
            residual = None
        else:
            # Later ranks resume from what the previous rank handed over.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        layer = self.layers[spec_step_idx % self.num_mtp_layers]
        hidden_states, residual = layer(
            positions=positions,
            hidden_states=hidden_states,
            residual=residual,
        )

        if not pp_group.is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each weight is tried, in order, against the fused (stacked)
        projections, then the per-expert MoE mapping, and finally loaded as a
        plain parameter. Returns the set of (possibly remapped) names that
        were processed.
        """
        # (fused parameter, checkpoint shard name, shard id)
        fused_shards = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Entries are (param_name, weight_name, expert_id, shard_id).
        expert_shards = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

        params = dict(self.named_parameters())
        loaded: set[str] = set()
        for name, tensor in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for fused_name, shard_name, shard_id in fused_shards:
                # Expert weights are handled by the expert mapping below.
                if shard_name not in name or "mlp.experts" in name:
                    continue

                name = name.replace(shard_name, fused_name)
                # Skip, in this order: an extra bias (GPTQ models), layers
                # held by another pipeline rank, and unknown parameters.
                if (
                    (name.endswith(".bias") and name not in params)
                    or is_pp_missing_parameter(name, self)
                    or name not in params
                ):
                    continue
                target = params[name]
                target.weight_loader(target, tensor, shard_id)
                break
            else:
                for fused_name, shard_name, expert_id, shard_id in expert_shards:
                    if shard_name not in name:
                        continue
                    name = name.replace(shard_name, fused_name)
                    # Skip layers on other devices, then an extra bias for
                    # GPTQ models.
                    if is_pp_missing_parameter(name, self) or (
                        name.endswith((".bias", "_bias")) and name not in params
                    ):
                        continue
                    target = params[name]
                    target.weight_loader(
                        target,
                        tensor,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Plain parameter: skip an extra bias for GPTQ models and
                    # anything held by another pipeline rank.
                    if (
                        name.endswith(".bias") and name not in params
                    ) or is_pp_missing_parameter(name, self):
                        continue

                    target = params[name]
                    load_fn = getattr(target, "weight_loader", default_weight_loader)
                    load_fn(target, tensor)
            loaded.add(name)
        return loaded
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
class Qwen3NextMTP(nn.Module, SupportsPP):
    """Qwen3-Next multi-token-prediction model: MTP trunk plus LM head.

    Wraps ``Qwen3NextMultiTokenPredictor`` (as ``self.model``) and adds the
    ``lm_head`` and logits processor. Prefix caching is not supported and is
    rejected at construction time.
    """

    # Fused module -> checkpoint shards, kept consistent with the stacked
    # parameter mapping used when loading weights: gate_up_proj is built from
    # gate_proj and up_proj (down_proj is not part of the fused projection).
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        assert (
            not cache_config.enable_prefix_caching
        ), "Qwen3NextMTP currently does not support prefix caching"

        self.vllm_config = vllm_config
        self.quant_config = vllm_config.quant_config
        self.config = config

        self.model = Qwen3NextMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
        )
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )
        # Re-export the trunk's factory on the top-level model.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the trunk's token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the MTP trunk and return its output.

        Extra keyword arguments are accepted but not forwarded, so the trunk
        always runs with its default ``spec_step_idx``.
        """
        hidden_states = self.model(
            input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to logits through ``lm_head``.

        ``spec_step_idx`` is accepted for interface compatibility and is not
        used.
        """
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load the MTP weights plus the shared embedding / LM-head weights.

        Checkpoint names starting with ``mtp.`` are remapped onto
        ``self.model``. Any other weight is kept only if its name mentions
        one of the shared weights (``embed_tokens`` or ``lm_head``); the rest
        are dropped.
        """
        shared_weight_names = ["embed_tokens", "lm_head"]

        def remap_weight_names(weights):
            for name, weight in weights:
                if name.startswith("mtp."):
                    name = name.replace("mtp.", "model.")
                elif not any(key in name for key in shared_weight_names):
                    continue
                yield name, weight

        loader = AutoWeightsLoader(self)
        return loader.load_weights(remap_weight_names(weights))
|
||||||
@@ -16,33 +16,33 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""kunlun custom op entry"""
|
"""kunlun custom op entry"""
|
||||||
import torch_xmlir
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import os
|
|
||||||
from typing import Optional, List, Dict
|
|
||||||
import vllm.envs as envs
|
|
||||||
import os
|
|
||||||
import ctypes
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import kunlun_ops
|
import kunlun_ops
|
||||||
logger.info(f"Load custom ops library success!")
|
|
||||||
|
logger.info("Load custom ops library success!")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("Import error msg: %s", e.msg)
|
logger.warning("Import error msg: %s", e.msg)
|
||||||
|
|
||||||
|
|
||||||
_per_token_smooth_quant = True
|
_per_token_smooth_quant = True
|
||||||
|
|
||||||
|
|
||||||
def is_per_token_smooth_quant():
    """Return the module-level ``_per_token_smooth_quant`` flag."""
    return _per_token_smooth_quant
|
||||||
|
|
||||||
|
|
||||||
class KunlunOps:
|
class KunlunOps:
|
||||||
"""KunlunOps"""
|
"""KunlunOps"""
|
||||||
|
|
||||||
# Attention ops
|
# Attention ops
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def paged_attention_v1(
|
def paged_attention_v1(
|
||||||
@@ -67,9 +67,9 @@ class KunlunOps:
|
|||||||
blocksparse_vert_stride,
|
blocksparse_vert_stride,
|
||||||
blocksparse_block_size,
|
blocksparse_block_size,
|
||||||
blocksparse_head_sliding_step,
|
blocksparse_head_sliding_step,
|
||||||
alibi_sqrt=False
|
alibi_sqrt=False,
|
||||||
):
|
):
|
||||||
""" PagedAttentionV1 """
|
"""PagedAttentionV1"""
|
||||||
# block_size = value_cache.shape[2]
|
# block_size = value_cache.shape[2]
|
||||||
kunlun_ops.paged_attention(
|
kunlun_ops.paged_attention(
|
||||||
x=query,
|
x=query,
|
||||||
@@ -81,7 +81,7 @@ class KunlunOps:
|
|||||||
is_context=is_context,
|
is_context=is_context,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
out=output,
|
out=output,
|
||||||
vo_head_dim=128
|
vo_head_dim=128,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -110,9 +110,9 @@ class KunlunOps:
|
|||||||
blocksparse_vert_stride,
|
blocksparse_vert_stride,
|
||||||
blocksparse_block_size,
|
blocksparse_block_size,
|
||||||
blocksparse_head_sliding_step,
|
blocksparse_head_sliding_step,
|
||||||
alibi_sqrt=False
|
alibi_sqrt=False,
|
||||||
):
|
):
|
||||||
""" PagedAttentionV2 """
|
"""PagedAttentionV2"""
|
||||||
# block_size = value_cache.shape[2]
|
# block_size = value_cache.shape[2]
|
||||||
kunlun_ops.paged_attention(
|
kunlun_ops.paged_attention(
|
||||||
x=query,
|
x=query,
|
||||||
@@ -124,31 +124,28 @@ class KunlunOps:
|
|||||||
is_context=is_context,
|
is_context=is_context,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
out=output,
|
out=output,
|
||||||
vo_head_dim=128
|
vo_head_dim=128,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Activation ops
|
# Activation ops
|
||||||
@staticmethod
def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
    """Run the ``kunlun_ops.silu_and_mul`` kernel on ``x``, writing to ``out``.

    The kernel is called with ``axis=-1`` and ``turn=True``.
    """
    kunlun_ops.silu_and_mul(x, axis=-1, turn=True, out=out)
|
||||||
|
|
||||||
# Activation ops
|
# Activation ops
|
||||||
@staticmethod
def quick_gelu(out: torch.Tensor, x: torch.Tensor):
    """Run the ``kunlun_ops.quick_gelu`` kernel on ``x``, writing to ``out``."""
    kunlun_ops.quick_gelu(x, out=out)
|
||||||
|
|
||||||
# Layernorm
|
# Layernorm
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -159,9 +156,7 @@ class KunlunOps:
|
|||||||
epsilon,
|
epsilon,
|
||||||
):
|
):
|
||||||
"""rms_norm"""
|
"""rms_norm"""
|
||||||
kunlun_ops.rmsnorm(
|
kunlun_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)
|
||||||
x, weight.to(torch.float32), epsilon, out=out
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fused_add_rms_norm(
|
def fused_add_rms_norm(
|
||||||
@@ -179,16 +174,11 @@ class KunlunOps:
|
|||||||
residual.copy_(fused_input, non_blocking=True)
|
residual.copy_(fused_input, non_blocking=True)
|
||||||
x.copy_(output)
|
x.copy_(output)
|
||||||
|
|
||||||
|
|
||||||
# Rotary embedding
|
# Rotary embedding
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def rotary_embedding(
|
def rotary_embedding(
|
||||||
positions,
|
positions, query, key, head_size, cos_sin_cache, is_neox_style
|
||||||
query,
|
):
|
||||||
key,
|
|
||||||
head_size,
|
|
||||||
cos_sin_cache,
|
|
||||||
is_neox_style):
|
|
||||||
"""
|
"""
|
||||||
refactor RotaryEmbedding forward function
|
refactor RotaryEmbedding forward function
|
||||||
"""
|
"""
|
||||||
@@ -196,62 +186,38 @@ class KunlunOps:
|
|||||||
key_x = key.contiguous()
|
key_x = key.contiguous()
|
||||||
|
|
||||||
torch.ops._C.rotary_embedding(
|
torch.ops._C.rotary_embedding(
|
||||||
positions,
|
positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
|
||||||
query_x,
|
)
|
||||||
key_x,
|
|
||||||
head_size,
|
|
||||||
cos_sin_cache,
|
|
||||||
is_neox_style)
|
|
||||||
|
|
||||||
return query_x, key_x
|
return query_x, key_x
|
||||||
|
|
||||||
# Rotary embedding
|
# Rotary embedding
|
||||||
@staticmethod
def mrotary_embedding(
    positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
):
    """Apply the multi-section rotary embedding kernel to ``query`` and ``key``.

    Contiguous copies of both tensors are handed to
    ``kunlun_ops.mrotary_embedding_neox`` and then installed as the ``.data``
    of the original tensors, which are returned. Only the NeoX style is
    supported (asserted).
    """
    q_contig = query.contiguous()
    k_contig = key.contiguous()
    assert is_neox_style
    # NOTE(review): presumably the kernel updates q_contig / k_contig in
    # place, since nothing is read back from its return value -- confirm.
    kunlun_ops.mrotary_embedding_neox(
        positions, q_contig, k_contig, head_size, cos_sin_cache, mrope_section
    )

    query.data, key.data = q_contig, k_contig
    return query, key
|
||||||
|
|
||||||
@staticmethod
def swap_blocks(src, dst, block_mapping):
    """Forward ``src``, ``dst`` and ``block_mapping`` to ``kunlun_ops.swap_blocks``."""
    kunlun_ops.swap_blocks(
        src,
        dst,
        block_mapping,
    )
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def copy_blocks(
|
def copy_blocks(key_caches, value_caches, block_mapping):
|
||||||
key_caches,
|
"""copy_blocks"""
|
||||||
value_caches,
|
|
||||||
block_mapping):
|
|
||||||
""" copy_blocks """
|
|
||||||
for i in range(len(key_caches)):
|
for i in range(len(key_caches)):
|
||||||
key_caches[i] = key_caches[i].contiguous()
|
key_caches[i] = key_caches[i].contiguous()
|
||||||
value_caches[i] = value_caches[i].contiguous()
|
value_caches[i] = value_caches[i].contiguous()
|
||||||
@@ -269,16 +235,10 @@ class KunlunOps:
|
|||||||
value_cache,
|
value_cache,
|
||||||
slot_mapping,
|
slot_mapping,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
):
|
):
|
||||||
""" reshape_and_cache """
|
"""reshape_and_cache"""
|
||||||
# slot_mapping_cast = slot_mapping.to(torch.int32)
|
# slot_mapping_cast = slot_mapping.to(torch.int32)
|
||||||
kunlun_ops.reshape_and_cache(
|
kunlun_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||||
key,
|
|
||||||
value,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
slot_mapping
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def multi_query_kv_attention(
|
def multi_query_kv_attention(
|
||||||
@@ -287,7 +247,7 @@ class KunlunOps:
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
**kargs
|
**kargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
query: shape = [num_prompt_tokens, num_heads, head_size]
|
query: shape = [num_prompt_tokens, num_heads, head_size]
|
||||||
@@ -297,16 +257,12 @@ class KunlunOps:
|
|||||||
key = key.unsqueeze(0)
|
key = key.unsqueeze(0)
|
||||||
value = value.unsqueeze(0)
|
value = value.unsqueeze(0)
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
alibi_slopes = kargs.get("alibi_slopes", None)
|
|
||||||
mask = kargs.get("mask", None)
|
|
||||||
is_causal = kargs.get("is_causal", True)
|
|
||||||
is_lvsl = kargs.get("is_lvsl", True)
|
|
||||||
|
|
||||||
B, T, Qh, Hd = query.shape
|
B, T, Qh, Hd = query.shape
|
||||||
KVh = key.size(2)
|
KVh = key.size(2)
|
||||||
if KVh != Qh:
|
if KVh != Qh:
|
||||||
repeat = Qh // KVh
|
repeat = Qh // KVh
|
||||||
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
||||||
value = value.repeat_interleave(repeat, dim=2)
|
value = value.repeat_interleave(repeat, dim=2)
|
||||||
kunlun_ops.attention(
|
kunlun_ops.attention(
|
||||||
q=query,
|
q=query,
|
||||||
@@ -321,80 +277,90 @@ class KunlunOps:
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
def quant_fusedresidual_rmsnorm_op(
    x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
    """Quantized RMSNorm fused with a residual input.

    Allocates an int8 output shaped like ``x`` and a float scale buffer, then
    calls ``kunlun_ops.quant_fusedresidual_rmsnorm``. ``scale_to_int``,
    ``dyn_scale`` and ``type`` are accepted but not used here.

    Returns ``(out, out_scale)`` when ``residual`` is ``None``, otherwise
    ``(out, out_scale, residual)``.
    """
    quantized = torch.empty_like(x, dtype=torch.int8)

    if is_per_token_smooth_quant():
        # One scale per leading index of x, with a trailing unit dimension.
        scale_buf = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float)
        scale = scale_buf.unsqueeze(-1)
    else:
        scale = torch.empty(12, device=x.device, dtype=torch.float)

    kunlun_ops.quant_fusedresidual_rmsnorm(
        x,
        residual,
        weight,
        bias,
        eps,
        out=quantized,
        out_scale=scale,
        residual_tensor=residual,
    )

    if residual is not None:
        return quantized, scale, residual
    return quantized, scale
|
||||||
|
|
||||||
@staticmethod
def quant_rmsnorm_op(
    x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
    """Quantized RMSNorm.

    Allocates an int8 output shaped like ``x`` and a float scale buffer, then
    calls ``kunlun_ops.quant_rmsnorm``. ``scale_to_int``, ``dyn_scale`` and
    ``type`` are accepted but not used here.

    Returns ``(out, out_scale)``.
    """
    quantized = torch.empty_like(x, dtype=torch.int8)

    if is_per_token_smooth_quant():
        # One scale per leading index of x, with a trailing unit dimension.
        scale_buf = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float)
        scale = scale_buf.unsqueeze(-1)
    else:
        scale = torch.empty(12, device=x.device, dtype=torch.float)

    kunlun_ops.quant_rmsnorm(
        x,
        weight,
        bias,
        eps,
        out=quantized,
        out_scale=scale,
    )
    return quantized, scale
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def smooth_quant_matmul_column_row_kernels(input_tensor,
|
def smooth_quant_matmul_column_row_kernels(
|
||||||
weight,
|
input_tensor,
|
||||||
smoother,
|
weight,
|
||||||
input_scale,
|
smoother,
|
||||||
weight_scale,
|
input_scale,
|
||||||
perTokenScaling,
|
weight_scale,
|
||||||
perChannelScaling,
|
perTokenScaling,
|
||||||
otype):
|
perChannelScaling,
|
||||||
|
otype,
|
||||||
|
):
|
||||||
"""smooth_quant_matmul_column_row_kernels"""
|
"""smooth_quant_matmul_column_row_kernels"""
|
||||||
input_shape = input_tensor.shape
|
input_shape = input_tensor.shape
|
||||||
weight_shape = weight.shape
|
weight_shape = weight.shape
|
||||||
if input_tensor.dim() == 3:
|
if input_tensor.dim() == 3:
|
||||||
input_tensor = input_tensor.reshape(-1, input_shape[-1])
|
input_tensor = input_tensor.reshape(-1, input_shape[-1])
|
||||||
out = torch.empty((input_shape[0] * input_shape[1],
|
out = torch.empty(
|
||||||
weight_shape[0]),
|
(input_shape[0] * input_shape[1], weight_shape[0]),
|
||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
device=weight.device)
|
device=weight.device,
|
||||||
|
)
|
||||||
output_bs_shape = [input_shape[0], input_shape[1]]
|
output_bs_shape = [input_shape[0], input_shape[1]]
|
||||||
elif input_tensor.dim() == 2:
|
elif input_tensor.dim() == 2:
|
||||||
out = torch.empty((input_shape[0], weight_shape[0]),
|
out = torch.empty(
|
||||||
dtype=torch.float16,
|
(input_shape[0], weight_shape[0]),
|
||||||
device=weight.device)
|
dtype=torch.float16,
|
||||||
|
device=weight.device,
|
||||||
|
)
|
||||||
output_bs_shape = [-1]
|
output_bs_shape = [-1]
|
||||||
kunlun_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
|
kunlun_ops.smooth_quant_matmul_column_row_kernels(
|
||||||
weight, smoother,
|
input_tensor,
|
||||||
input_scale,
|
weight,
|
||||||
weight_scale,
|
smoother,
|
||||||
perTokenScaling,
|
input_scale,
|
||||||
perChannelScaling,
|
weight_scale,
|
||||||
out=out)
|
perTokenScaling,
|
||||||
|
perChannelScaling,
|
||||||
|
out=out,
|
||||||
|
)
|
||||||
|
|
||||||
out = out.view(*output_bs_shape, weight_shape[0])
|
out = out.view(*output_bs_shape, weight_shape[0])
|
||||||
|
|
||||||
@@ -404,6 +370,7 @@ class KunlunOps:
|
|||||||
if torch.is_tensor(x):
|
if torch.is_tensor(x):
|
||||||
return (type(x), x.device, x.dtype, x.shape, x.is_contiguous())
|
return (type(x), x.device, x.dtype, x.shape, x.is_contiguous())
|
||||||
return (type(x), x)
|
return (type(x), x)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fused_moe(
|
def fused_moe(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -420,23 +387,24 @@ class KunlunOps:
|
|||||||
w1_bias: Optional[torch.Tensor] = None,
|
w1_bias: Optional[torch.Tensor] = None,
|
||||||
w2_bias: Optional[torch.Tensor] = None,
|
w2_bias: Optional[torch.Tensor] = None,
|
||||||
scoring_func: str = "softmax",
|
scoring_func: str = "softmax",
|
||||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""fused_moe"""
|
"""fused_moe"""
|
||||||
global_num_experts, up_gate_size, _ = w1.shape
|
global_num_experts, up_gate_size, _ = w1.shape
|
||||||
M, N = hidden_states.shape
|
M, N = hidden_states.shape
|
||||||
hidden_dim = w2.shape[1]
|
hidden_dim = w2.shape[1]
|
||||||
normed_score = torch.empty(M,
|
normed_score = torch.empty(
|
||||||
moe_top_k,
|
M, moe_top_k, dtype=torch.float32, device=hidden_states.device
|
||||||
dtype=torch.float32,
|
)
|
||||||
device=hidden_states.device)
|
topk_ids = torch.empty(
|
||||||
topk_ids = torch.empty(M,
|
M, moe_top_k, dtype=torch.int32, device=hidden_states.device
|
||||||
moe_top_k,
|
)
|
||||||
dtype=torch.int32,
|
|
||||||
device=hidden_states.device)
|
|
||||||
num_blocks = 12
|
num_blocks = 12
|
||||||
block_statistic = torch.zeros(
|
block_statistic = torch.zeros(
|
||||||
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
|
num_blocks,
|
||||||
|
global_num_experts,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=hidden_states.device,
|
||||||
)
|
)
|
||||||
router_logits = router_logits.to(torch.float)
|
router_logits = router_logits.to(torch.float)
|
||||||
if scoring_func == "softmax":
|
if scoring_func == "softmax":
|
||||||
@@ -445,24 +413,27 @@ class KunlunOps:
|
|||||||
normed_score=normed_score,
|
normed_score=normed_score,
|
||||||
topk_index=topk_ids,
|
topk_index=topk_ids,
|
||||||
block_statistic=None,
|
block_statistic=None,
|
||||||
stable=True)
|
stable=False,
|
||||||
|
)
|
||||||
elif scoring_func == "sigmoid":
|
elif scoring_func == "sigmoid":
|
||||||
torch.ops._C.moe_sigmoid_group_topk_norm(
|
torch.ops._C.moe_sigmoid_group_topk_norm(
|
||||||
x=router_logits,
|
x=router_logits,
|
||||||
topk_index=topk_ids,
|
topk_index=topk_ids,
|
||||||
norm_score=normed_score,
|
norm_score=normed_score,
|
||||||
block_static=block_statistic,
|
block_static=block_statistic,
|
||||||
bias=e_score_correction_bias,
|
bias=e_score_correction_bias,
|
||||||
scale=1.0,
|
scale=1.0,
|
||||||
n_group=num_expert_group,
|
n_group=num_expert_group,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
if w1_bias is not None or w2_bias is not None:
|
if w1_bias is not None or w2_bias is not None:
|
||||||
# Right now this branch is for gpt oss
|
# Right now this branch is for gpt oss
|
||||||
# TODO (@xyDong23): faster here using moe_fc kernel
|
# TODO (@xyDong23): faster here using moe_fc kernel
|
||||||
normed_score = normed_score.to(hidden_states.dtype)
|
normed_score = normed_score.to(hidden_states.dtype)
|
||||||
out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device)
|
out = torch.zeros(
|
||||||
|
M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device
|
||||||
|
)
|
||||||
repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
|
repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
|
||||||
topk_ids_flat = topk_ids.flatten()
|
topk_ids_flat = topk_ids.flatten()
|
||||||
for i in range(global_num_experts):
|
for i in range(global_num_experts):
|
||||||
@@ -470,9 +441,13 @@ class KunlunOps:
|
|||||||
selected_token = topk_ids_flat == experts_id
|
selected_token = topk_ids_flat == experts_id
|
||||||
if selected_token.sum():
|
if selected_token.sum():
|
||||||
cur_token = repeat_x[selected_token]
|
cur_token = repeat_x[selected_token]
|
||||||
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
|
up_gate = torch.empty(
|
||||||
dtype=cur_token.dtype, device=cur_token.device)
|
selected_token.sum(),
|
||||||
groupgemm1 = cur_token@ w1[i].T
|
up_gate_size // 2,
|
||||||
|
dtype=cur_token.dtype,
|
||||||
|
device=cur_token.device,
|
||||||
|
)
|
||||||
|
groupgemm1 = cur_token @ w1[i].T
|
||||||
# Add w13 bias
|
# Add w13 bias
|
||||||
if w1_bias is not None:
|
if w1_bias is not None:
|
||||||
groupgemm1 = groupgemm1 + w1_bias[i]
|
groupgemm1 = groupgemm1 + w1_bias[i]
|
||||||
@@ -482,53 +457,129 @@ class KunlunOps:
|
|||||||
if w2_bias is not None:
|
if w2_bias is not None:
|
||||||
groupgemm2 = groupgemm2 + w2_bias[i]
|
groupgemm2 = groupgemm2 + w2_bias[i]
|
||||||
out[selected_token] = groupgemm2
|
out[selected_token] = groupgemm2
|
||||||
ouput = (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)).sum(dim=1).to(hidden_states.dtype)
|
ouput = (
|
||||||
|
(out.view(M, moe_top_k, N) * normed_score.unsqueeze(2))
|
||||||
|
.sum(dim=1)
|
||||||
|
.to(hidden_states.dtype)
|
||||||
|
)
|
||||||
return ouput
|
return ouput
|
||||||
else:
|
else:
|
||||||
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
|
# from vllm.forward_context import get_forward_context
|
||||||
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
# forward_context = get_forward_context()
|
||||||
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
# attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||||
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
|
# prefix = "model.layers.0.linear_attn"
|
||||||
|
# if attn_metadata is not None:
|
||||||
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
# attn_metadata = attn_metadata[prefix]
|
||||||
|
|
||||||
torch.ops._C.moe_pre_sorted(
|
# if attn_metadata is None or attn_metadata.num_prefills > 0 or :
|
||||||
x=hidden_states,
|
if M * moe_top_k < 400:
|
||||||
topk_index=topk_ids,
|
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
|
||||||
block_statistic=block_statistic,
|
torch.ops.xspeedgate_ops.moe_pre_small(
|
||||||
moe_expand=moe_expand,
|
topk_ids, global_num_experts, False, False, hidden_states
|
||||||
moe_index=sorted_tokens_idx,
|
)
|
||||||
expert_m=expert_m,
|
)
|
||||||
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
experts_num_lod = torch.ops.xspeedgate_ops.moe_active_expert_balance(
|
||||||
|
topk_ids, global_num_experts, False
|
||||||
|
)
|
||||||
|
out = torch.ops.xspeedgate_ops.fused_moe(
|
||||||
|
hidden_states,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
normed_score.to(hidden_states.dtype),
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
sorted_tokens_idx,
|
||||||
|
experts_num_lod,
|
||||||
|
)
|
||||||
|
return out.sum(1)
|
||||||
|
|
||||||
y = torch.empty(M,moe_top_k,
|
if M * moe_top_k > 768:
|
||||||
w1.shape[1],
|
moe_expand = torch.empty(
|
||||||
|
(M * moe_top_k, N),
|
||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
device=hidden_states.device)
|
device=hidden_states.device,
|
||||||
|
) # [M*top_k, N], float
|
||||||
|
expert_m = torch.zeros(
|
||||||
|
global_num_experts, dtype=torch.int32, device=hidden_states.device
|
||||||
|
) # [E]
|
||||||
|
sorted_tokens_num_lod = torch.zeros(
|
||||||
|
global_num_experts + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=hidden_states.device,
|
||||||
|
) # [E+1]
|
||||||
|
sorted_tokens_idx = torch.zeros(
|
||||||
|
M * moe_top_k, dtype=torch.int32, device=hidden_states.device
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
|
||||||
|
|
||||||
|
torch.ops._C.moe_pre_sorted(
|
||||||
|
x=hidden_states,
|
||||||
|
topk_index=topk_ids,
|
||||||
|
block_statistic=block_statistic,
|
||||||
|
moe_expand=moe_expand,
|
||||||
|
moe_index=sorted_tokens_idx,
|
||||||
|
expert_m=expert_m,
|
||||||
|
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
|
||||||
|
torch.ops.xspeedgate_ops.moe_pre_small(
|
||||||
|
topk_ids,
|
||||||
|
global_num_experts,
|
||||||
|
index_have_neg=False,
|
||||||
|
sort_mode=True,
|
||||||
|
x=hidden_states,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
y = torch.empty(
|
||||||
|
M,
|
||||||
|
moe_top_k,
|
||||||
|
w1.shape[1],
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
|
||||||
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
|
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
|
||||||
|
|
||||||
torch.ops._C.moe_fc(
|
if M < 1024:
|
||||||
x=moe_expand,
|
torch.ops._C.moe_fc(
|
||||||
weight=w1,
|
x=moe_expand,
|
||||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
weight=w1,
|
||||||
sorted_tokens_idx=sorted_tokens_idx,
|
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||||
moe_topk=moe_top_k,
|
sorted_tokens_idx=sorted_tokens_idx,
|
||||||
y=y,
|
moe_topk=moe_top_k,
|
||||||
|
y=y,
|
||||||
|
)
|
||||||
|
|
||||||
|
d = y.shape[-1] // 2
|
||||||
|
output_shape = y.shape[:-1] + (d,)
|
||||||
|
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||||
|
torch.ops._C.silu_and_mul(out1, y)
|
||||||
|
|
||||||
|
out1 = out1.reshape(-1, out1.shape[-1])
|
||||||
|
else:
|
||||||
|
torch.ops._C.moe_fc(
|
||||||
|
x=moe_expand,
|
||||||
|
weight=w1,
|
||||||
|
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||||
|
sorted_tokens_idx=sorted_tokens_idx,
|
||||||
|
moe_topk=moe_top_k,
|
||||||
|
y=y,
|
||||||
|
act="SWISH_GLU",
|
||||||
|
)
|
||||||
|
|
||||||
|
y = y[..., : y.shape[-1] // 2]
|
||||||
|
out1 = y.reshape(-1, y.shape[-1])
|
||||||
|
|
||||||
|
out = torch.empty(
|
||||||
|
M,
|
||||||
|
moe_top_k,
|
||||||
|
w2.shape[1],
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
d = y.shape[-1] // 2
|
|
||||||
output_shape = (y.shape[:-1] + (d, ))
|
|
||||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
|
||||||
torch.ops._C.silu_and_mul(out1, y)
|
|
||||||
|
|
||||||
out = torch.empty(M,moe_top_k,
|
|
||||||
w2.shape[1],
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
device=hidden_states.device)
|
|
||||||
|
|
||||||
out1 = out1.reshape(-1, out1.shape[-1])
|
|
||||||
|
|
||||||
torch.ops._C.moe_fc(
|
torch.ops._C.moe_fc(
|
||||||
x=out1,
|
x=out1,
|
||||||
weight=w2,
|
weight=w2,
|
||||||
@@ -538,8 +589,12 @@ class KunlunOps:
|
|||||||
y=out,
|
y=out,
|
||||||
)
|
)
|
||||||
|
|
||||||
dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
|
dequant_scale = torch.ones(
|
||||||
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
[M, moe_top_k], dtype=torch.float32, device=out.device
|
||||||
|
)
|
||||||
|
output = torch.empty(
|
||||||
|
[M, N], dtype=hidden_states.dtype, device=hidden_states.device
|
||||||
|
)
|
||||||
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
|
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
|
||||||
|
|
||||||
torch.ops._C.moe_post(
|
torch.ops._C.moe_post(
|
||||||
@@ -547,9 +602,9 @@ class KunlunOps:
|
|||||||
moe_index=sorted_tokens_idx,
|
moe_index=sorted_tokens_idx,
|
||||||
normed_scale=normed_score,
|
normed_scale=normed_score,
|
||||||
dequant_scale=dequant_scale,
|
dequant_scale=dequant_scale,
|
||||||
y=output
|
y=output,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -568,23 +623,23 @@ class KunlunOps:
|
|||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
w1_bias: Optional[torch.Tensor] = None,
|
w1_bias: Optional[torch.Tensor] = None,
|
||||||
w2_bias: Optional[torch.Tensor] = None,
|
w2_bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
x = hidden_states
|
x = hidden_states
|
||||||
batch, hidden_size = x.shape
|
batch, hidden_size = x.shape
|
||||||
num_local_experts, up_gate_size, _ = w13_weight.shape
|
num_local_experts, up_gate_size, _ = w13_weight.shape
|
||||||
|
|
||||||
router_logits = x.to(linear_weights.dtype)@linear_weights.T
|
router_logits = x.to(linear_weights.dtype) @ linear_weights.T
|
||||||
|
|
||||||
topk_weights = torch.empty(batch,
|
topk_weights = torch.empty(
|
||||||
top_k,
|
batch, top_k, dtype=router_logits.dtype, device=router_logits.device
|
||||||
dtype=router_logits.dtype,
|
)
|
||||||
device=router_logits.device)
|
topk_ids = torch.empty(
|
||||||
topk_ids = torch.empty(batch,
|
batch, top_k, dtype=torch.int32, device=router_logits.device
|
||||||
top_k,
|
)
|
||||||
dtype=torch.int32,
|
block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
|
||||||
device=router_logits.device)
|
torch.ops._C.moe_softmax_topk(
|
||||||
block_static = torch.empty(0, dtype=torch.int32,device=router_logits.device)
|
router_logits, topk_weights, topk_ids, block_static
|
||||||
torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids, block_static)
|
)
|
||||||
|
|
||||||
if renormalize:
|
if renormalize:
|
||||||
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
|
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
|
||||||
@@ -598,11 +653,19 @@ class KunlunOps:
|
|||||||
selected_token = topk_ids_flat == experts_id
|
selected_token = topk_ids_flat == experts_id
|
||||||
if selected_token.sum():
|
if selected_token.sum():
|
||||||
cur_token = repeat_x[selected_token]
|
cur_token = repeat_x[selected_token]
|
||||||
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
|
up_gate = torch.empty(
|
||||||
dtype=cur_token.dtype, device=cur_token.device)
|
selected_token.sum(),
|
||||||
torch.ops._C.silu_and_mul(up_gate, cur_token@ w13_weight[i].T)
|
up_gate_size // 2,
|
||||||
|
dtype=cur_token.dtype,
|
||||||
|
device=cur_token.device,
|
||||||
|
)
|
||||||
|
torch.ops._C.silu_and_mul(up_gate, cur_token @ w13_weight[i].T)
|
||||||
out[selected_token] = up_gate @ w2_weight[i].T
|
out[selected_token] = up_gate @ w2_weight[i].T
|
||||||
output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype)
|
output = (
|
||||||
|
(out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
|
||||||
|
.sum(dim=1)
|
||||||
|
.to(x.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@@ -638,10 +701,11 @@ class KunlunOps:
|
|||||||
prompt_lods_cpu: torch.Tensor,
|
prompt_lods_cpu: torch.Tensor,
|
||||||
k_cache: torch.Tensor,
|
k_cache: torch.Tensor,
|
||||||
v_cache: torch.Tensor,
|
v_cache: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""mla pa block"""
|
"""mla pa block"""
|
||||||
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
|
output = torch.empty(
|
||||||
device=hidden_states.device)
|
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
|
||||||
|
)
|
||||||
kunlun_ops.xft_multi_head_latent_page_attention_block(
|
kunlun_ops.xft_multi_head_latent_page_attention_block(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
q_lora_rank,
|
q_lora_rank,
|
||||||
@@ -679,7 +743,6 @@ class KunlunOps:
|
|||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def fused_gdn_gating(
|
def fused_gdn_gating(
|
||||||
A_log: torch.Tensor,
|
A_log: torch.Tensor,
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
@@ -695,25 +758,34 @@ class KunlunOps:
|
|||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    h0_source: torch.Tensor,
    output_final_state: bool,
    use_qk_l2norm_in_kernel: bool,
    cu_seqlens: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Core Gated DeltaNet operator of the Qwen3-Next model.

    Fuses the sigmoid gating step and the delta rule update into one call:

    1. Sigmoid gating: gates the input, similar to a GLU (Gated Linear Unit).
    2. Delta rule update: performs a parallel state-space-model (SSM)
       recurrent update combined with a local attention mechanism.

    This function is a thin wrapper: every argument is forwarded positionally,
    in signature order, to ``kunlun_ops.fused_recurrent_gated_delta_rule_fwd``.

    Args:
        q: Forwarded to the kernel.
        k: Forwarded to the kernel.
        v: Forwarded to the kernel.
        g: Forwarded to the kernel.
        beta: Forwarded to the kernel.
        scale: Forwarded to the kernel.
        h0_source: Forwarded to the kernel.
        output_final_state: Forwarded to the kernel.
        use_qk_l2norm_in_kernel: Forwarded to the kernel.
        cu_seqlens: Forwarded to the kernel; ``None`` by default.

    Returns:
        The ``(o, final_state)`` pair produced by the kernel.
    """
    o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
        q,
        k,
        v,
        g,
        beta,
        scale,
        h0_source,
        output_final_state,
        use_qk_l2norm_in_kernel,
        cu_seqlens,
    )
    return (o, final_state)
|
||||||
|
|||||||
@@ -9,60 +9,196 @@
|
|||||||
# ruff: noqa: E501
|
# ruff: noqa: E501
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
import cocopod # noqa
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
from .index import prepare_chunk_indices, prepare_chunk_offsets
|
||||||
from .chunk_o import chunk_fwd_o
|
|
||||||
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
|
||||||
from .cumsum import chunk_local_cumsum
|
|
||||||
from .l2norm import l2norm_fwd
|
from .l2norm import l2norm_fwd
|
||||||
from .solve_tril import solve_tril
|
|
||||||
from .utils import SUPPRESS_LEVEL, input_guard
|
from .utils import SUPPRESS_LEVEL, input_guard
|
||||||
from .wy_fast import recompute_w_u_fwd
|
|
||||||
from .index import prepare_chunk_indices
|
|
||||||
import xspeedgate_ops
|
|
||||||
import cocopod
|
|
||||||
|
|
||||||
|
|
||||||
def torch_solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    output_dtype: torch.dtype = torch.float,
):
    """Pure-PyTorch chunkwise inversion of ``I + A`` by forward substitution.

    ``A`` is laid out as ``[batch, time, heads, 64]``: each time step holds one
    row of its 64x64 chunk. The time axis is zero-padded up to a multiple of
    64, every chunk is processed independently, and the padding is sliced off
    again before returning.

    No masking is applied, so every chunk must already be strictly
    lower-triangular; for such a chunk ``L`` the result is ``(I + L)^-1``.

    Args:
        A: Input of shape ``[batch, time, heads, 64]``. It is not modified.
        cu_seqlens: Not read by this implementation; the whole time axis is
            treated as a single sequence.
        output_dtype: Not read by this implementation; the result keeps the
            dtype of ``A``.

    Returns:
        Tensor of shape ``[batch, time, heads, 64]`` with the dtype of ``A``.
    """
    chunk_size = 64

    # Negating allocates a fresh tensor, so the in-place row updates below
    # never touch the caller's data. Heads move ahead of time for chunking.
    A = -A.transpose(1, 2)
    sequence_length = A.shape[-2]
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    A = F.pad(A, (0, 0, 0, pad_size))
    A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])

    # Row i of the inverse depends only on rows < i, which are final by the
    # time they are read. The clones keep the right-hand side independent of
    # the slice being written.
    for i in range(1, chunk_size):
        row = A[..., i, :i].clone()
        sub = A[..., :i, :i].clone()
        A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)

    A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)

    # Merge the chunks back into one time axis, drop padding, restore layout.
    A = A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])
    return A[:, :, :sequence_length, :].transpose(1, 2)
|
||||||
|
|
||||||
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
g: torch.Tensor,
|
|
||||||
beta: torch.Tensor,
|
|
||||||
scale: float,
|
|
||||||
initial_state: torch.Tensor,
|
|
||||||
output_final_state: bool,
|
|
||||||
cu_seqlens: Optional[torch.LongTensor] = None):
|
|
||||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
|
||||||
A = chunk_scaled_dot_kkt_fwd(k=k,
|
|
||||||
beta=beta,
|
|
||||||
g_cumsum=g,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
output_dtype=q.dtype)
|
|
||||||
|
|
||||||
#kernel版
|
def recompute_w_u_fwd_torch(
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, H, V]
    beta: torch.Tensor,  # [B, T, H]
    g: torch.Tensor,  # [B, T, H]
    A: torch.Tensor,  # [B, T, H, 64]
):
    """Pure-PyTorch computation of the chunkwise ``w`` and ``u`` tensors.

    Simplest variant: the whole time axis is treated as one sequence (no
    variable-length handling). Key heads are repeated ``H // Hk`` times so
    they line up with the value heads. All arithmetic is done in float32.

    Per 64-step chunk this computes ``u = A @ (v * beta)`` and
    ``w = A @ (k * beta * exp(g))``.

    Args:
        k: Keys, ``[B, T, Hk, K]``.
        v: Values, ``[B, T, H, V]``.
        beta: Per-step scaling, ``[B, T, H]``.
        g: Per-step log-gate, ``[B, T, H]``; it is exponentiated here.
        A: Per-chunk matrix rows, ``[B, T, H, 64]``: each time step holds one
            row of its 64x64 chunk.

    Returns:
        Tuple ``(w, u)`` of contiguous float32 tensors shaped ``[B, T, H, K]``
        and ``[B, T, H, V]``.
    """
    chunk_size = 64

    num_v_heads, num_k_heads = v.shape[2], k.shape[2]
    k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)

    # Heads ahead of time, float32 throughout.
    k, v, beta, g, A = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A)
    ]

    # The 4-way unpack also enforces that ``k`` is 4-D.
    _, _, sequence_length, _ = k.shape
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size

    # Zero-pad the time axis to a whole number of chunks.
    k = F.pad(k, (0, 0, 0, pad_size))
    v = F.pad(v, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    A = F.pad(A, (0, 0, 0, pad_size))
    A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])

    v_beta = v * beta.unsqueeze(-1)
    k_beta = k * beta.unsqueeze(-1)

    # Split the time axis into [num_chunks, chunk_size].
    k, v, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (k, v, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)

    u = A @ v_beta
    w = A @ (k_beta * g.exp().unsqueeze(-1))

    # Merge chunks, drop the padding, and return to the [B, T, H, D] layout.
    w = (
        w.reshape(w.shape[0], w.shape[1], -1, w.shape[-1])[:, :, :sequence_length, :]
        .transpose(1, 2)
        .contiguous()
    )
    u = (
        u.reshape(u.shape[0], u.shape[1], -1, u.shape[-1])[:, :, :sequence_length, :]
        .transpose(1, 2)
        .contiguous()
    )

    return w, u
|
||||||
|
|
||||||
|
|
||||||
|
def split_by_value(tensor, chunk_size=64):
    """Insert chunk-aligned boundaries between consecutive entries.

    For every consecutive pair ``(start, end)`` in ``tensor``, the values
    ``start + chunk_size``, ``start + 2 * chunk_size``, ... that are strictly
    below ``end`` are added. The original entries are kept, duplicates are
    collapsed, and the result is sorted ascending.

    Args:
        tensor: 1-D tensor of boundary values.
        chunk_size: Positive spacing of the inserted boundaries.

    Returns:
        A new 1-D tensor with the dtype and device of ``tensor``.

    Raises:
        ValueError: If ``chunk_size`` is not positive (the boundary walk
            would otherwise never terminate).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    indices = tensor.tolist()
    boundaries = set(indices)  # a set keeps the boundaries unique

    for start, end in zip(indices, indices[1:]):
        # Walk the aligned boundaries that fall strictly inside (start, end).
        boundary = start + chunk_size
        while boundary < end:
            boundaries.add(boundary)
            boundary += chunk_size

    return torch.tensor(sorted(boundaries), dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_gated_delta_rule_fwd(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
g: torch.Tensor,
|
||||||
|
beta: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
initial_state: torch.Tensor,
|
||||||
|
output_final_state: bool,
|
||||||
|
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
|
chunk_size = 64
|
||||||
|
chunk_indices = (
|
||||||
|
prepare_chunk_indices(cu_seqlens, 64) if cu_seqlens is not None else None
|
||||||
|
)
|
||||||
|
chunk_offsets = (
|
||||||
|
prepare_chunk_offsets(cu_seqlens, chunk_size)
|
||||||
|
if cu_seqlens is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# !
|
||||||
|
# g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||||
|
g = torch.ops.xspeedgate_ops.chunk_local_cumsum(
|
||||||
|
g,
|
||||||
|
chunk_size=64,
|
||||||
|
reverse=False,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
chunk_indices=chunk_indices,
|
||||||
|
head_first=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# !
|
||||||
|
# A = chunk_scaled_dot_kkt_fwd(k=k,
|
||||||
|
# beta=beta,
|
||||||
|
# g_cumsum=g,
|
||||||
|
# cu_seqlens=cu_seqlens,
|
||||||
|
# output_dtype=q.dtype)
|
||||||
|
A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd(
|
||||||
|
k, beta, g, cu_seqlens, chunk_indices, chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# torch版
|
||||||
|
# if get_tensor_model_parallel_rank() == 0:
|
||||||
|
# torch.save(A, "A_in")
|
||||||
|
# torch.save(cu_seqlens, "cu_seqlens")
|
||||||
|
# A2 = A.clone()
|
||||||
|
torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size)
|
||||||
|
|
||||||
|
# !
|
||||||
|
# torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
|
||||||
|
# if get_tensor_model_parallel_rank() == 0:
|
||||||
|
# err = torch.max(torch.abs(A - A2))
|
||||||
|
# print("err", err)
|
||||||
|
# if err > 1e-3:
|
||||||
|
# raise
|
||||||
|
# A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
||||||
|
# for i in range(len(cu_seqlens)-1):
|
||||||
|
# A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||||
|
# A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)
|
||||||
|
|
||||||
|
"""
|
||||||
|
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||||
|
H = v.shape[-2]
|
||||||
|
u = torch.empty_like(v)
|
||||||
|
w = k.new_empty(B, T, H, K)
|
||||||
|
for i in range(len(cu_seqlens)-1):
|
||||||
|
k_i = k[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||||
|
v_i = v[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||||
|
beta_i = beta[:, cu_seqlens[i]:cu_seqlens[i+1], :]
|
||||||
|
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||||
|
g_i = g[:, cu_seqlens[i]:cu_seqlens[i+1], :]
|
||||||
|
|
||||||
|
w_i, u_i = recompute_w_u_fwd_torch(
|
||||||
|
k=k_i,
|
||||||
|
v=v_i,
|
||||||
|
beta=beta_i,
|
||||||
|
A=A_i,
|
||||||
|
g=g_i,
|
||||||
|
)
|
||||||
|
w[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = w_i
|
||||||
|
u[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = u_i
|
||||||
|
"""
|
||||||
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
|
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
@@ -71,17 +207,63 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
|||||||
g_cumsum=g,
|
g_cumsum=g,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
chunk_indices=chunk_indices,
|
chunk_indices=chunk_indices,
|
||||||
chunk_size=64
|
chunk_size=64,
|
||||||
)
|
)
|
||||||
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
"""
|
||||||
|
w, u = recompute_w_u_fwd(
|
||||||
k=k,
|
k=k,
|
||||||
w=w,
|
v=v,
|
||||||
u=u,
|
beta=beta,
|
||||||
g=g,
|
A=A,
|
||||||
initial_state=initial_state,
|
g_cumsum=g,
|
||||||
output_final_state=output_final_state,
|
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# i
|
||||||
|
# import os
|
||||||
|
# if not os.path.exists("/qwen-next/in"):
|
||||||
|
# os.makedirs("/qwen-next/in")
|
||||||
|
# torch.save(k, "/qwen-next/in/k.pt")
|
||||||
|
# torch.save(u, "/qwen-next/in/u.pt")
|
||||||
|
# torch.save(w, "/qwen-next/in/w.pt")
|
||||||
|
# torch.save(g, "/qwen-next/in/g.pt")
|
||||||
|
# torch.save(initial_state, "/qwen-next/in/initial_state.pt")
|
||||||
|
# torch.save(cu_seqlens, "/qwen-next/in/cu_seqlens.pt")
|
||||||
|
# torch.save(chunk_indices, "/qwen-next/in/chunk_indices.pt")
|
||||||
|
# torch.save(chunk_offsets.to(torch.int32), "/qwen-next/in/chunk_offsets.pt")
|
||||||
|
# torch.save(chunk_size, "/qwen-next/in/chunk_size.pt")
|
||||||
|
# torch.save(output_final_state, "/qwen-next/in/output_final_state.pt")
|
||||||
|
|
||||||
|
h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h(
|
||||||
|
k,
|
||||||
|
u,
|
||||||
|
w,
|
||||||
|
g,
|
||||||
|
initial_state,
|
||||||
|
cu_seqlens,
|
||||||
|
chunk_indices,
|
||||||
|
chunk_offsets.to(torch.int32),
|
||||||
|
chunk_size,
|
||||||
|
output_final_state,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||||
|
# k=k,
|
||||||
|
# w=w,
|
||||||
|
# u=u,
|
||||||
|
# g=g,
|
||||||
|
# initial_state=initial_state,
|
||||||
|
# output_final_state=output_final_state,
|
||||||
|
# cu_seqlens=cu_seqlens,
|
||||||
|
# )
|
||||||
|
# if not os.path.exists("/qwen-next/out"):
|
||||||
|
# os.makedirs("/qwen-next/out")
|
||||||
|
# torch.save(h, "/qwen-next/out/h.pt")
|
||||||
|
# torch.save(v_new, "/qwen-next/out/v_new.pt")
|
||||||
|
# torch.save(final_state, "/qwen-next/out/final_state.pt")
|
||||||
|
|
||||||
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
|
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
@@ -91,8 +273,19 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
|||||||
scale=scale,
|
scale=scale,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
chunk_indices=chunk_indices,
|
chunk_indices=chunk_indices,
|
||||||
chunk_size=64
|
chunk_size=64,
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
|
o = chunk_fwd_o(
|
||||||
|
q=q,
|
||||||
|
k=k,
|
||||||
|
v=v_new,
|
||||||
|
h=h,
|
||||||
|
g=g,
|
||||||
|
scale=scale,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
)
|
||||||
|
"""
|
||||||
if SUPPRESS_LEVEL < 3:
|
if SUPPRESS_LEVEL < 3:
|
||||||
return g, o, A, final_state, None, None, None
|
return g, o, A, final_state, None, None, None
|
||||||
elif SUPPRESS_LEVEL >= 3:
|
elif SUPPRESS_LEVEL >= 3:
|
||||||
@@ -103,18 +296,20 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@input_guard
|
@input_guard
|
||||||
@torch.amp.custom_fwd(device_type='cuda')
|
@torch.amp.custom_fwd(device_type="cuda")
|
||||||
def forward(ctx,
|
def forward(
|
||||||
q: torch.Tensor,
|
ctx,
|
||||||
k: torch.Tensor,
|
q: torch.Tensor,
|
||||||
v: torch.Tensor,
|
k: torch.Tensor,
|
||||||
g: torch.Tensor,
|
v: torch.Tensor,
|
||||||
beta: torch.Tensor,
|
g: torch.Tensor,
|
||||||
scale: float,
|
beta: torch.Tensor,
|
||||||
initial_state: torch.Tensor,
|
scale: float,
|
||||||
output_final_state: bool,
|
initial_state: torch.Tensor,
|
||||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
output_final_state: bool,
|
||||||
use_qk_l2norm_in_kernel: bool = False):
|
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||||
|
use_qk_l2norm_in_kernel: bool = False,
|
||||||
|
):
|
||||||
if use_qk_l2norm_in_kernel:
|
if use_qk_l2norm_in_kernel:
|
||||||
q = l2norm_fwd(q)
|
q = l2norm_fwd(q)
|
||||||
k = l2norm_fwd(k)
|
k = l2norm_fwd(k)
|
||||||
@@ -136,17 +331,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
|||||||
|
|
||||||
|
|
||||||
@torch.compiler.disable
|
@torch.compiler.disable
|
||||||
def chunk_gated_delta_rule(q: torch.Tensor,
|
def chunk_gated_delta_rule(
|
||||||
k: torch.Tensor,
|
q: torch.Tensor,
|
||||||
v: torch.Tensor,
|
k: torch.Tensor,
|
||||||
g: torch.Tensor,
|
v: torch.Tensor,
|
||||||
beta: torch.Tensor,
|
g: torch.Tensor,
|
||||||
scale: float = None,
|
beta: torch.Tensor,
|
||||||
initial_state: torch.Tensor = None,
|
scale: float = None,
|
||||||
output_final_state: bool = False,
|
initial_state: torch.Tensor = None,
|
||||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
output_final_state: bool = False,
|
||||||
head_first: bool = False,
|
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||||
use_qk_l2norm_in_kernel: bool = False):
|
head_first: bool = False,
|
||||||
|
use_qk_l2norm_in_kernel: bool = False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
q (torch.Tensor):
|
q (torch.Tensor):
|
||||||
@@ -211,42 +408,85 @@ def chunk_gated_delta_rule(q: torch.Tensor,
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
assert q.dtype == k.dtype == v.dtype
|
assert q.dtype == k.dtype == v.dtype
|
||||||
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
assert (
|
||||||
assert len(
|
q.dtype != torch.float32
|
||||||
beta.shape
|
), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
||||||
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
assert (
|
||||||
|
len(beta.shape) == 3
|
||||||
|
), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||||
|
|
||||||
if head_first:
|
if head_first:
|
||||||
raise DeprecationWarning(
|
raise DeprecationWarning(
|
||||||
"head_first is deprecated and will be removed in a future version. "
|
"head_first is deprecated and will be removed in a future version. "
|
||||||
"Please use head_first=False for now instead.",
|
"Please use head_first=False for now instead.",
|
||||||
stacklevel=2)
|
stacklevel=2,
|
||||||
|
)
|
||||||
q, k, v, beta, g = map(
|
q, k, v, beta, g = map(
|
||||||
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
|
lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
|
||||||
(q, k, v, beta, g))
|
)
|
||||||
if not head_first and q.shape[1] < q.shape[2]:
|
if not head_first and q.shape[1] < q.shape[2]:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
|
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
|
||||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||||
"when head_first=False was specified. "
|
"when head_first=False was specified. "
|
||||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||||
stacklevel=2)
|
stacklevel=2,
|
||||||
|
)
|
||||||
if cu_seqlens is not None:
|
if cu_seqlens is not None:
|
||||||
if q.shape[0] != 1:
|
if q.shape[0] != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||||
f"Please flatten variable-length inputs before processing.")
|
f"Please flatten variable-length inputs before processing."
|
||||||
if initial_state is not None and initial_state.shape[0] != len(
|
)
|
||||||
cu_seqlens) - 1:
|
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||||
)
|
)
|
||||||
if scale is None:
|
if scale is None:
|
||||||
scale = k.shape[-1]**-0.5
|
scale = k.shape[-1] ** -0.5
|
||||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
|
||||||
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
|
if False:
|
||||||
use_qk_l2norm_in_kernel)
|
q = q.contiguous()
|
||||||
|
k = k.contiguous()
|
||||||
|
v = v.contiguous()
|
||||||
|
g = g.contiguous()
|
||||||
|
beta = beta.contiguous()
|
||||||
|
initial_state = initial_state.contiguous()
|
||||||
|
|
||||||
|
o = torch.empty_like(v)
|
||||||
|
final_state = torch.empty_like(initial_state)
|
||||||
|
import kunlun_ops
|
||||||
|
|
||||||
|
kunlun_ops.gated_delta_rule(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
initial_state,
|
||||||
|
g,
|
||||||
|
beta,
|
||||||
|
final_state,
|
||||||
|
o,
|
||||||
|
scale,
|
||||||
|
cu_seqlens.cpu(),
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens.cpu(),
|
||||||
|
cu_seqlens,
|
||||||
|
use_qk_l2norm_in_kernel=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
g,
|
||||||
|
beta,
|
||||||
|
scale,
|
||||||
|
initial_state,
|
||||||
|
output_final_state,
|
||||||
|
cu_seqlens,
|
||||||
|
use_qk_l2norm_in_kernel,
|
||||||
|
)
|
||||||
if head_first:
|
if head_first:
|
||||||
o = rearrange(o, 'b t h ... -> b h t ...')
|
o = rearrange(o, "b t h ... -> b h t ...")
|
||||||
return o, final_state
|
return o, final_state
|
||||||
|
|||||||
@@ -12,21 +12,21 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
from .index import prepare_chunk_indices
|
from .index import prepare_chunk_indices
|
||||||
from .op import exp
|
|
||||||
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
|
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
|
||||||
|
|
||||||
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
|
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
|
||||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
|
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
|
||||||
|
|
||||||
|
|
||||||
@triton.heuristics({
|
@triton.heuristics(
|
||||||
'USE_G': lambda args: args['g'] is not None,
|
{
|
||||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
"USE_G": lambda args: args["g"] is not None,
|
||||||
})
|
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||||
|
}
|
||||||
|
)
|
||||||
# @triton.autotune(
|
# @triton.autotune(
|
||||||
# configs=[
|
# configs=[
|
||||||
# triton.Config({
|
# triton.Config({
|
||||||
@@ -40,7 +40,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
|
|||||||
# ],
|
# ],
|
||||||
# key=['H', 'K', 'V', 'BT'],
|
# key=['H', 'K', 'V', 'BT'],
|
||||||
# )
|
# )
|
||||||
@triton.jit(do_not_specialize=['T'])
|
@triton.jit(do_not_specialize=["T"])
|
||||||
def chunk_fwd_kernel_o(
|
def chunk_fwd_kernel_o(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
@@ -67,10 +67,12 @@ def chunk_fwd_kernel_o(
|
|||||||
|
|
||||||
if IS_VARLEN:
|
if IS_VARLEN:
|
||||||
i_tg = i_t
|
i_tg = i_t
|
||||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
chunk_indices + i_t * 2 + 1
|
||||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
).to(tl.int32)
|
||||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||||
|
cu_seqlens + i_n + 1
|
||||||
|
).to(tl.int32)
|
||||||
T = eos - bos
|
T = eos - bos
|
||||||
NT = tl.cdiv(T, BT)
|
NT = tl.cdiv(T, BT)
|
||||||
else:
|
else:
|
||||||
@@ -89,12 +91,15 @@ def chunk_fwd_kernel_o(
|
|||||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||||
|
|
||||||
for i_k in range(tl.cdiv(K, BK)):
|
for i_k in range(tl.cdiv(K, BK)):
|
||||||
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
|
p_q = tl.make_block_ptr(
|
||||||
(BT, BK), (1, 0))
|
q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
|
||||||
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT),
|
)
|
||||||
(BK, BT), (0, 1))
|
p_k = tl.make_block_ptr(
|
||||||
p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV),
|
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
|
||||||
(BK, BV), (1, 0))
|
)
|
||||||
|
p_h = tl.make_block_ptr(
|
||||||
|
h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
|
||||||
|
)
|
||||||
# [BT, BK]
|
# [BT, BK]
|
||||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||||
# [BK, BT]
|
# [BK, BT]
|
||||||
@@ -109,8 +114,8 @@ def chunk_fwd_kernel_o(
|
|||||||
|
|
||||||
if USE_G:
|
if USE_G:
|
||||||
g += bos * H + i_h
|
g += bos * H + i_h
|
||||||
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
|
p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
b_g = tl.load(p_g, boundary_check=(0,))
|
||||||
b_o = b_o * tl.exp(b_g)[:, None]
|
b_o = b_o * tl.exp(b_g)[:, None]
|
||||||
b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :])
|
b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :])
|
||||||
|
|
||||||
@@ -120,10 +125,12 @@ def chunk_fwd_kernel_o(
|
|||||||
# b_A = tl.where(m_A, b_A, 0)
|
# b_A = tl.where(m_A, b_A, 0)
|
||||||
b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0)
|
b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0)
|
||||||
|
|
||||||
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
p_v = tl.make_block_ptr(
|
||||||
(BT, BV), (1, 0))
|
v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
|
||||||
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
)
|
||||||
(BT, BV), (1, 0))
|
p_o = tl.make_block_ptr(
|
||||||
|
o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
|
||||||
|
)
|
||||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||||
|
|
||||||
# to fix mma -> mma layout conversion
|
# to fix mma -> mma layout conversion
|
||||||
@@ -133,48 +140,29 @@ def chunk_fwd_kernel_o(
|
|||||||
|
|
||||||
|
|
||||||
def chunk_fwd_o(
|
def chunk_fwd_o(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
h: torch.Tensor,
|
h: torch.Tensor,
|
||||||
g: Optional[torch.Tensor] = None, # cumsum of log decay
|
g: Optional[torch.Tensor] = None, # cumsum of log decay
|
||||||
scale: Optional[float] = None,
|
scale: Optional[float] = None,
|
||||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||||
chunk_size: int = 64) -> torch.Tensor:
|
chunk_size: int = 64,
|
||||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
) -> torch.Tensor:
|
||||||
H = v.shape[-2]
|
_, T, _, _, _ = *q.shape, v.shape[-1]
|
||||||
if FLA_GDN_FIX_BT:
|
if FLA_GDN_FIX_BT:
|
||||||
BT = 64
|
BT = 64
|
||||||
else:
|
else:
|
||||||
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
|
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
|
||||||
chunk_indices = prepare_chunk_indices(
|
chunk_indices = (
|
||||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
)
|
||||||
if scale is None:
|
if scale is None:
|
||||||
scale = k.shape[-1]**-0.5
|
scale = k.shape[-1] ** -0.5
|
||||||
|
|
||||||
o = torch.empty_like(v)
|
o = torch.empty_like(v)
|
||||||
|
|
||||||
def grid(meta):
|
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
|
||||||
return (triton.cdiv(V, meta['BV']), NT, B * H)
|
q, k, v, h, g, scale, cu_seqlens, chunk_indices, chunk_size
|
||||||
|
|
||||||
chunk_fwd_kernel_o[grid](
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
h,
|
|
||||||
g,
|
|
||||||
o,
|
|
||||||
cu_seqlens,
|
|
||||||
chunk_indices,
|
|
||||||
scale,
|
|
||||||
T=T,
|
|
||||||
H=H,
|
|
||||||
Hg=Hg,
|
|
||||||
K=K,
|
|
||||||
V=V,
|
|
||||||
BT=BT,
|
|
||||||
BK=64,
|
|
||||||
BV=32
|
|
||||||
)
|
)
|
||||||
return o
|
return o
|
||||||
|
|||||||
@@ -9,28 +9,28 @@
|
|||||||
# ruff: noqa: E501
|
# ruff: noqa: E501
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import kunlun_ops
|
import kunlun_ops
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class FusedRecurrentFunction(torch.autograd.Function):
|
class FusedRecurrentFunction(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx,
|
def forward(
|
||||||
q: torch.Tensor,
|
ctx,
|
||||||
k: torch.Tensor,
|
q: torch.Tensor,
|
||||||
v: torch.Tensor,
|
k: torch.Tensor,
|
||||||
g: torch.Tensor,
|
v: torch.Tensor,
|
||||||
beta: torch.Tensor,
|
g: torch.Tensor,
|
||||||
scale: float,
|
beta: torch.Tensor,
|
||||||
initial_state: torch.Tensor,
|
scale: float,
|
||||||
inplace_final_state: bool = True,
|
initial_state: torch.Tensor,
|
||||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
inplace_final_state: bool = True,
|
||||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||||
use_qk_l2norm_in_kernel: bool = False):
|
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||||
|
use_qk_l2norm_in_kernel: bool = False,
|
||||||
|
):
|
||||||
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2(
|
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2(
|
||||||
q.contiguous(),
|
q.contiguous(),
|
||||||
k.contiguous(),
|
k.contiguous(),
|
||||||
@@ -44,7 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
|
|||||||
h0_indices=ssm_state_indices,
|
h0_indices=ssm_state_indices,
|
||||||
num_accepted_tokens=num_accepted_tokens,
|
num_accepted_tokens=num_accepted_tokens,
|
||||||
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
||||||
is_h0_transposed=True
|
is_h0_transposed=True,
|
||||||
)
|
)
|
||||||
return o, final_state
|
return o, final_state
|
||||||
|
|
||||||
@@ -130,9 +130,10 @@ def fused_recurrent_gated_delta_rule(
|
|||||||
if cu_seqlens is not None and q.shape[0] != 1:
|
if cu_seqlens is not None and q.shape[0] != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||||
f"Please flatten variable-length inputs before processing.")
|
f"Please flatten variable-length inputs before processing."
|
||||||
|
)
|
||||||
if scale is None:
|
if scale is None:
|
||||||
scale = k.shape[-1]**-0.5
|
scale = k.shape[-1] ** -0.5
|
||||||
else:
|
else:
|
||||||
assert scale > 0, "scale must be positive"
|
assert scale > 0, "scale must be positive"
|
||||||
if beta is None:
|
if beta is None:
|
||||||
|
|||||||
@@ -10,22 +10,21 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import kunlun_ops
|
||||||
import torch
|
import torch
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
import kunlun_ops
|
|
||||||
|
|
||||||
|
|
||||||
BT_LIST = [8, 16, 32, 64, 128]
|
BT_LIST = [8, 16, 32, 64, 128]
|
||||||
|
|
||||||
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
|
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
|
||||||
|
|
||||||
|
|
||||||
@triton.autotune(configs=[
|
@triton.autotune(
|
||||||
triton.Config({}, num_warps=num_warps)
|
configs=[
|
||||||
for num_warps in [1, 2, 4, 8, 16, 32]
|
triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
|
||||||
],
|
],
|
||||||
key=['D'])
|
key=["D"],
|
||||||
|
)
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def l2norm_fwd_kernel1(
|
def l2norm_fwd_kernel1(
|
||||||
x,
|
x,
|
||||||
@@ -49,11 +48,14 @@ def l2norm_fwd_kernel1(
|
|||||||
tl.store(y + cols, b_y, mask=mask)
|
tl.store(y + cols, b_y, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
@triton.autotune(configs=[
|
@triton.autotune(
|
||||||
triton.Config({'BT': BT}, num_warps=num_warps)
|
configs=[
|
||||||
for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
|
triton.Config({"BT": BT}, num_warps=num_warps)
|
||||||
],
|
for num_warps in [1, 2, 4, 8, 16]
|
||||||
key=['D'])
|
for BT in BT_LIST
|
||||||
|
],
|
||||||
|
key=["D"],
|
||||||
|
)
|
||||||
@triton.jit(do_not_specialize=["NB"])
|
@triton.jit(do_not_specialize=["NB"])
|
||||||
def l2norm_fwd_kernel(
|
def l2norm_fwd_kernel(
|
||||||
x,
|
x,
|
||||||
@@ -87,67 +89,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
|
|||||||
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
|
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
|
||||||
|
|
||||||
|
|
||||||
def l2norm_fwd_triton(x: torch.Tensor,
|
def l2norm_fwd(
|
||||||
eps: float = 1e-6,
|
x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
|
||||||
output_dtype: Optional[torch.dtype] = None):
|
):
|
||||||
x_shape_og = x.shape
|
|
||||||
x = x.view(-1, x.shape[-1])
|
|
||||||
# allocate output
|
|
||||||
if output_dtype is None:
|
|
||||||
y = torch.empty_like(x)
|
|
||||||
else:
|
|
||||||
y = torch.empty_like(x, dtype=output_dtype)
|
|
||||||
assert y.stride(-1) == 1
|
|
||||||
T, D = x.shape[0], x.shape[-1]
|
|
||||||
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
|
|
||||||
# Less than 64KB per feature: enqueue fused kernel
|
|
||||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
|
||||||
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
|
|
||||||
if D > BD:
|
|
||||||
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
|
|
||||||
|
|
||||||
if not USE_DEFAULT_FLA_NORM:
|
|
||||||
MBLOCK = 32
|
|
||||||
# M, N = x.shape
|
|
||||||
l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
|
|
||||||
x,
|
|
||||||
y,
|
|
||||||
eps,
|
|
||||||
T,
|
|
||||||
D,
|
|
||||||
MBLOCK,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if D <= 512:
|
|
||||||
NB = triton.cdiv(T, 2048)
|
|
||||||
|
|
||||||
def grid(meta):
|
|
||||||
return (triton.cdiv(T, meta['BT']), )
|
|
||||||
|
|
||||||
l2norm_fwd_kernel[grid](
|
|
||||||
x,
|
|
||||||
y,
|
|
||||||
eps,
|
|
||||||
NB=NB,
|
|
||||||
T=T,
|
|
||||||
D=D,
|
|
||||||
BD=BD,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
l2norm_fwd_kernel1[(T, )](
|
|
||||||
x,
|
|
||||||
y,
|
|
||||||
eps=eps,
|
|
||||||
D=D,
|
|
||||||
BD=BD,
|
|
||||||
)
|
|
||||||
|
|
||||||
return y.view(x_shape_og)
|
|
||||||
|
|
||||||
|
|
||||||
def l2norm_fwd(x: torch.Tensor,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
output_dtype: Optional[torch.dtype] = None):
|
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
kunlun_ops.l2norm(x, out, eps)
|
kunlun_ops.l2norm(x, out, eps)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -19,20 +19,21 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
from .utils import input_guard
|
from .utils import input_guard
|
||||||
|
|
||||||
|
|
||||||
def rms_norm_ref(x,
|
def rms_norm_ref(
|
||||||
weight,
|
x,
|
||||||
bias,
|
weight,
|
||||||
z=None,
|
bias,
|
||||||
eps=1e-6,
|
z=None,
|
||||||
group_size=None,
|
eps=1e-6,
|
||||||
norm_before_gate=True,
|
group_size=None,
|
||||||
upcast=True):
|
norm_before_gate=True,
|
||||||
|
upcast=True,
|
||||||
|
):
|
||||||
dtype = x.dtype
|
dtype = x.dtype
|
||||||
weight = weight.float()
|
weight = weight.float()
|
||||||
bias = bias.float() if bias is not None else None
|
bias = bias.float() if bias is not None else None
|
||||||
@@ -43,12 +44,10 @@ def rms_norm_ref(x,
|
|||||||
x = x * F.silu(z)
|
x = x * F.silu(z)
|
||||||
if group_size is None:
|
if group_size is None:
|
||||||
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
||||||
out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
|
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
|
||||||
weight)
|
|
||||||
else:
|
else:
|
||||||
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
|
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
|
||||||
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
|
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
|
||||||
eps)
|
|
||||||
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
|
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
out = out + bias
|
out = out + bias
|
||||||
@@ -57,10 +56,12 @@ def rms_norm_ref(x,
|
|||||||
return out.to(dtype)
|
return out.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
@triton.heuristics({
|
@triton.heuristics(
|
||||||
"HAS_BIAS": lambda args: args["B"] is not None,
|
{
|
||||||
"HAS_Z": lambda args: args["Z"] is not None,
|
"HAS_BIAS": lambda args: args["B"] is not None,
|
||||||
})
|
"HAS_Z": lambda args: args["Z"] is not None,
|
||||||
|
}
|
||||||
|
)
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def layer_norm_fwd_kernel(
|
def layer_norm_fwd_kernel(
|
||||||
X, # pointer to the input
|
X, # pointer to the input
|
||||||
@@ -97,17 +98,17 @@ def layer_norm_fwd_kernel(
|
|||||||
B += group * N
|
B += group * N
|
||||||
# Compute mean and variance
|
# Compute mean and variance
|
||||||
cols = tl.arange(0, BLOCK_N)
|
cols = tl.arange(0, BLOCK_N)
|
||||||
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||||
if HAS_Z and not NORM_BEFORE_GATE:
|
if HAS_Z and not NORM_BEFORE_GATE:
|
||||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||||
x *= z * tl.sigmoid(z)
|
x *= z * tl.sigmoid(z)
|
||||||
if not IS_RMS_NORM:
|
if not IS_RMS_NORM:
|
||||||
mean = tl.sum(x, axis=0) / N
|
mean = tl.sum(x, axis=0) / N
|
||||||
tl.store(Mean + row, mean)
|
tl.store(Mean + row, mean)
|
||||||
xbar = tl.where(cols < N, x - mean, 0.)
|
xbar = tl.where(cols < N, x - mean, 0.0)
|
||||||
var = tl.sum(xbar * xbar, axis=0) / N
|
var = tl.sum(xbar * xbar, axis=0) / N
|
||||||
else:
|
else:
|
||||||
xbar = tl.where(cols < N, x, 0.)
|
xbar = tl.where(cols < N, x, 0.0)
|
||||||
var = tl.sum(xbar * xbar, axis=0) / N
|
var = tl.sum(xbar * xbar, axis=0) / N
|
||||||
rstd = 1 / tl.sqrt(var + eps)
|
rstd = 1 / tl.sqrt(var + eps)
|
||||||
tl.store(Rstd + row, rstd)
|
tl.store(Rstd + row, rstd)
|
||||||
@@ -149,46 +150,50 @@ def layer_norm_fwd(
|
|||||||
# weight = weight.reshape(N)
|
# weight = weight.reshape(N)
|
||||||
# print("weight",weight.shape)
|
# print("weight",weight.shape)
|
||||||
# print("x",x.shape)
|
# print("x",x.shape)
|
||||||
assert weight.shape == (N, )
|
assert weight.shape == (N,)
|
||||||
assert weight.stride(-1) == 1
|
assert weight.stride(-1) == 1
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
assert bias.stride(-1) == 1
|
assert bias.stride(-1) == 1
|
||||||
assert bias.shape == (N, )
|
assert bias.shape == (N,)
|
||||||
# allocate output
|
# allocate output
|
||||||
if out is not None:
|
if out is not None:
|
||||||
assert out.shape == x.shape
|
assert out.shape == x.shape
|
||||||
else:
|
else:
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
assert out.stride(-1) == 1
|
assert out.stride(-1) == 1
|
||||||
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
|
mean = (
|
||||||
device=x.device) if not is_rms_norm else None
|
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||||
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
if not is_rms_norm
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||||
# Less than 64KB per feature: enqueue fused kernel
|
# Less than 64KB per feature: enqueue fused kernel
|
||||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||||
if group_size > BLOCK_N:
|
if group_size > BLOCK_N:
|
||||||
raise RuntimeError(
|
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||||
"This layer norm doesn't support feature dim >= 64KB.")
|
|
||||||
# heuristics for number of warps
|
# heuristics for number of warps
|
||||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||||
grid = (M, ngroups)
|
grid = (M, ngroups)
|
||||||
layer_norm_fwd_kernel[grid](x,
|
layer_norm_fwd_kernel[grid](
|
||||||
out,
|
x,
|
||||||
weight,
|
out,
|
||||||
bias,
|
weight,
|
||||||
z,
|
bias,
|
||||||
mean,
|
z,
|
||||||
rstd,
|
mean,
|
||||||
x.stride(0),
|
rstd,
|
||||||
out.stride(0),
|
x.stride(0),
|
||||||
z.stride(0) if z is not None else 0,
|
out.stride(0),
|
||||||
M,
|
z.stride(0) if z is not None else 0,
|
||||||
group_size,
|
M,
|
||||||
eps,
|
group_size,
|
||||||
BLOCK_N=BLOCK_N,
|
eps,
|
||||||
NORM_BEFORE_GATE=norm_before_gate,
|
BLOCK_N=BLOCK_N,
|
||||||
IS_RMS_NORM=is_rms_norm,
|
NORM_BEFORE_GATE=norm_before_gate,
|
||||||
num_warps=num_warps)
|
IS_RMS_NORM=is_rms_norm,
|
||||||
|
num_warps=num_warps,
|
||||||
|
)
|
||||||
return out, mean, rstd
|
return out, mean, rstd
|
||||||
|
|
||||||
|
|
||||||
@@ -196,17 +201,18 @@ class LayerNormFn(torch.autograd.Function):
|
|||||||
|
|
||||||
@input_guard
|
@input_guard
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx,
|
def forward(
|
||||||
x,
|
ctx,
|
||||||
weight,
|
x,
|
||||||
bias,
|
weight,
|
||||||
z=None,
|
bias,
|
||||||
eps=1e-6,
|
z=None,
|
||||||
group_size=None,
|
eps=1e-6,
|
||||||
norm_before_gate=True,
|
group_size=None,
|
||||||
is_rms_norm=False):
|
norm_before_gate=True,
|
||||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
is_rms_norm=False,
|
||||||
"""
|
):
|
||||||
|
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||||
|
|
||||||
x_shape_og = x.shape
|
x_shape_og = x.shape
|
||||||
# reshape input data into 2D tensor
|
# reshape input data into 2D tensor
|
||||||
@@ -223,16 +229,15 @@ class LayerNormFn(torch.autograd.Function):
|
|||||||
weight = weight.contiguous()
|
weight = weight.contiguous()
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
bias = bias.contiguous()
|
bias = bias.contiguous()
|
||||||
y, mean, rstd = layer_norm_fwd(
|
# y, mean, rstd = torch.ops.xspeedgate_ops.rms_norm_gated_fwd(x, weight, bias, eps, z, group_size, norm_before_gate, is_rms_norm)
|
||||||
x,
|
y = torch.empty_like(x)
|
||||||
weight,
|
mean, rstd = None, None
|
||||||
bias,
|
import kunlun_ops
|
||||||
eps,
|
|
||||||
z=z,
|
kunlun_ops.rms_norm_gated(
|
||||||
group_size=group_size,
|
x, y, z, weight, eps, group_size, norm_before_gate, is_rms_norm
|
||||||
norm_before_gate=norm_before_gate,
|
|
||||||
is_rms_norm=is_rms_norm,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
|
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
|
||||||
ctx.x_shape_og = x_shape_og
|
ctx.x_shape_og = x_shape_og
|
||||||
ctx.eps = eps
|
ctx.eps = eps
|
||||||
@@ -242,27 +247,27 @@ class LayerNormFn(torch.autograd.Function):
|
|||||||
return y.reshape(x_shape_og)
|
return y.reshape(x_shape_og)
|
||||||
|
|
||||||
|
|
||||||
def layernorm_fn(x,
|
def layernorm_fn(
|
||||||
weight,
|
x,
|
||||||
bias,
|
weight,
|
||||||
z=None,
|
bias,
|
||||||
eps=1e-6,
|
z=None,
|
||||||
group_size=None,
|
eps=1e-6,
|
||||||
norm_before_gate=True,
|
group_size=None,
|
||||||
is_rms_norm=False):
|
norm_before_gate=True,
|
||||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
|
is_rms_norm=False,
|
||||||
norm_before_gate, is_rms_norm)
|
):
|
||||||
|
return LayerNormFn.apply(
|
||||||
|
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def rmsnorm_fn(x,
|
def rmsnorm_fn(
|
||||||
weight,
|
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
|
||||||
bias,
|
):
|
||||||
z=None,
|
return LayerNormFn.apply(
|
||||||
eps=1e-6,
|
x, weight, bias, z, eps, group_size, norm_before_gate, True
|
||||||
group_size=None,
|
)
|
||||||
norm_before_gate=True):
|
|
||||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
|
|
||||||
norm_before_gate, True)
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNormGated(nn.Module):
|
class LayerNormGated(nn.Module):
|
||||||
@@ -294,15 +299,16 @@ class LayerNormGated(nn.Module):
|
|||||||
torch.nn.init.zeros_(self.bias)
|
torch.nn.init.zeros_(self.bias)
|
||||||
|
|
||||||
def forward(self, x, z=None):
    """Apply the gated layer norm to ``x`` using this module's parameters.

    If z is not None, we do norm(x) * silu(z) if norm_before_gate,
    else norm(x * silu(z)).
    """
    # Per-module settings are passed by keyword; x / weight / bias stay
    # positional.
    norm_options = {
        "z": z,
        "group_size": self.group_size,
        "eps": self.eps,
        "norm_before_gate": self.norm_before_gate,
    }
    return layernorm_fn(x, self.weight, self.bias, **norm_options)
|
||||||
|
|
||||||
|
|
||||||
class RMSNormGated(nn.Module):
|
class RMSNormGated(nn.Module):
|
||||||
@@ -332,12 +338,13 @@ class RMSNormGated(nn.Module):
|
|||||||
torch.nn.init.ones_(self.weight)
|
torch.nn.init.ones_(self.weight)
|
||||||
|
|
||||||
def forward(self, x, z=None):
    """Apply the gated RMS norm to ``x`` using this module's parameters.

    If z is not None, we do norm(x) * silu(z) if norm_before_gate,
    else norm(x * silu(z)).
    """
    # Per-module settings are passed by keyword; x / weight / bias stay
    # positional.
    norm_options = {
        "z": z,
        "eps": self.eps,
        "group_size": self.group_size,
        "norm_before_gate": self.norm_before_gate,
    }
    return rmsnorm_fn(x, self.weight, self.bias, **norm_options)
|
||||||
|
|||||||
@@ -11,7 +11,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
from .index import prepare_chunk_indices
|
from .index import prepare_chunk_indices
|
||||||
@@ -28,6 +27,7 @@ RESOLUTION = {
|
|||||||
torch.complex64: 1.3e-6,
|
torch.complex64: 1.3e-6,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
|
def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
|
||||||
assert res.dtype == dtype
|
assert res.dtype == dtype
|
||||||
ref = ref.to(dtype)
|
ref = ref.to(dtype)
|
||||||
@@ -35,6 +35,7 @@ def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
|
|||||||
rtol = RESOLUTION[dtype]
|
rtol = RESOLUTION[dtype]
|
||||||
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)
|
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)
|
||||||
|
|
||||||
|
|
||||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||||
# @triton.autotune(
|
# @triton.autotune(
|
||||||
# configs=[
|
# configs=[
|
||||||
@@ -80,7 +81,6 @@ def recompute_u_fwd_kernel(
|
|||||||
p_beta = tl.make_block_ptr(
|
p_beta = tl.make_block_ptr(
|
||||||
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
|
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
|
||||||
)
|
)
|
||||||
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
|
|
||||||
p_A = tl.make_block_ptr(
|
p_A = tl.make_block_ptr(
|
||||||
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
||||||
)
|
)
|
||||||
@@ -110,7 +110,6 @@ def recompute_u_fwd_kernel(
|
|||||||
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
|
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||||
# @triton.autotune(
|
# @triton.autotune(
|
||||||
# configs=[
|
# configs=[
|
||||||
@@ -195,53 +194,12 @@ def recompute_w_u_fwd(
|
|||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
cu_seqlens: Optional[torch.LongTensor],
|
cu_seqlens: Optional[torch.LongTensor],
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
|
||||||
H = v.shape[-2]
|
|
||||||
BT = A.shape[-1]
|
BT = A.shape[-1]
|
||||||
|
|
||||||
chunk_indices = prepare_chunk_indices(
|
chunk_indices = (
|
||||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
|
||||||
BK = 64
|
|
||||||
BV = 64
|
|
||||||
u = torch.empty_like(v)
|
|
||||||
w = k.new_empty(B, T, H, K)
|
|
||||||
recompute_u_fwd_kernel[(NT, B * H)](
|
|
||||||
k=k,
|
|
||||||
v=v,
|
|
||||||
beta=beta,
|
|
||||||
w=w,
|
|
||||||
u=u,
|
|
||||||
A=A,
|
|
||||||
g=g_cumsum,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
chunk_indices=chunk_indices,
|
|
||||||
T=T,
|
|
||||||
H=H,
|
|
||||||
Hg=Hg,
|
|
||||||
K=K,
|
|
||||||
V=V,
|
|
||||||
BT=BT,
|
|
||||||
BK=BK,
|
|
||||||
BV=BV,
|
|
||||||
)
|
)
|
||||||
recompute_w_fwd_kernel[(NT, B * H)](
|
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
|
||||||
k=k,
|
k, v, beta, g_cumsum, A, cu_seqlens, chunk_indices, chunk_size=BT
|
||||||
v=v,
|
|
||||||
beta=beta,
|
|
||||||
w=w,
|
|
||||||
u=u,
|
|
||||||
A=A,
|
|
||||||
g=g_cumsum,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
chunk_indices=chunk_indices,
|
|
||||||
T=T,
|
|
||||||
H=H,
|
|
||||||
Hg=Hg,
|
|
||||||
K=K,
|
|
||||||
V=V,
|
|
||||||
BT=BT,
|
|
||||||
BK=BK,
|
|
||||||
BV=BV,
|
|
||||||
)
|
)
|
||||||
return w, u
|
return w, u
|
||||||
|
|||||||
@@ -15,51 +15,52 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
#
|
#
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
||||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
|
|
||||||
from vllm.model_executor.layers import layernorm
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
import kunlun_ops
|
|
||||||
|
import torch
|
||||||
|
from vllm.model_executor.layers import layernorm
|
||||||
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
|
||||||
|
|
||||||
def vllm_kunlun_forward_cuda(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """RMSNorm forward that dispatches to the ``torch.ops._C`` kernels.

    Args:
        self: The norm module; ``weight``, ``variance_epsilon``,
            ``variance_size_override`` and ``forward_native`` are read.
        x: Input tensor. A non-contiguous input is copied to a contiguous
            one before any kernel is called.
        residual: Optional residual tensor.

    Returns:
        ``self.forward_native(x, residual)`` when
        ``variance_size_override`` is set. Otherwise a new output tensor
        when ``residual`` is ``None``, or the ``(x, residual)`` pair after
        ``add_rmsnorm`` was called with those same tensors as its
        ``output`` / ``residual_output`` arguments.
    """
    # kunlun does not support non-contiguous input and does not consider
    # that a bug, so the input is made contiguous manually.
    x = x if x.is_contiguous() else x.contiguous()

    if self.variance_size_override is not None:
        return self.forward_native(x, residual)

    rms_weight = self.weight.data
    rms_eps = self.variance_epsilon

    if residual is None:
        normed = torch.empty_like(x)
        torch.ops._C.rmsnorm(x, rms_weight, normed, rms_eps)
        return normed

    # x and residual are passed as both inputs and output buffers, so no
    # extra tensors are allocated on this path.
    torch.ops._C.add_rmsnorm(
        x,
        residual,
        residual_output=residual,
        weight=rms_weight,
        eps=rms_eps,
        output=x,
    )
    return x, residual
|
||||||
|
|
||||||
|
|
||||||
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
|
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
|
||||||
RMSNorm.forward = vllm_kunlun_forward_cuda
|
RMSNorm.forward = vllm_kunlun_forward_cuda
|
||||||
|
|
||||||
|
|
||||||
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward_xpu(
|
def forward_xpu(
|
||||||
@@ -68,30 +69,42 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||||
if x.is_contiguous() == False:
|
if not x.is_contiguous():
|
||||||
# kunlun does not support uncontiguous input and they do not think it is a bug
|
# kunlun does not support uncontiguous input and they do not think it is a bug
|
||||||
# so we must make it contiguous() manually
|
# so we must make it contiguous() manually
|
||||||
x = x.contiguous()
|
x = x.contiguous()
|
||||||
|
if x.dim() == 3:
|
||||||
|
x_shape = x.shape
|
||||||
|
x = x.view(-1, x.size(-1))
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
torch.ops._C.add_rmsnorm(
|
out = torch.empty_like(x)
|
||||||
|
out_residual = torch.empty_like(residual)
|
||||||
|
torch.ops._C.gemma_add_rmsnorm(
|
||||||
x,
|
x,
|
||||||
residual,
|
residual,
|
||||||
residual_output=residual,
|
residual_output=out_residual,
|
||||||
weight=weight+1,
|
weight=weight,
|
||||||
eps=variance_epsilon,
|
eps=variance_epsilon,
|
||||||
output=x
|
output=out,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
out = torch.empty_like(x)
|
||||||
|
torch.ops._C.gemma_rmsnorm(
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
out,
|
||||||
|
variance_epsilon,
|
||||||
)
|
)
|
||||||
return x, residual
|
|
||||||
|
|
||||||
out = torch.empty_like(x)
|
if x.dim() == 3:
|
||||||
torch.ops._C.rmsnorm(
|
x = x.view(x_shape)
|
||||||
x,
|
if out is not None:
|
||||||
weight+1,
|
out = out.view(x_shape)
|
||||||
out,
|
|
||||||
variance_epsilon,
|
if residual is not None:
|
||||||
)
|
return out, out_residual
|
||||||
return out
|
else:
|
||||||
|
return out
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
self,
|
self,
|
||||||
@@ -99,16 +112,17 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
|||||||
residual: Optional[torch.Tensor] = None,
|
residual: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||||
if torch.compiler.is_compiling():
|
if torch.compiler.is_compiling():
|
||||||
self.forward_static = self.forward_xpu # only use in cudagraph
|
self.forward_static = self.forward_xpu # only use in cudagraph
|
||||||
return self.forward_native(x, residual)
|
return self.forward_native(x, residual)
|
||||||
|
|
||||||
if not getattr(self, "_is_compiled", False):
|
if not getattr(self, "_is_compiled", False):
|
||||||
self.forward_static = torch.compile( # type: ignore
|
self.forward_static = torch.compile( # type: ignore
|
||||||
self.forward_static, backend="aot_eager")
|
self.forward_static, backend="aot_eager"
|
||||||
|
)
|
||||||
self._is_compiled = True
|
self._is_compiled = True
|
||||||
return self.forward_native(x, residual)
|
return self.forward_native(x, residual)
|
||||||
|
|
||||||
|
|
||||||
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
|
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
|
||||||
RMSNorm.forward = vllm_kunlun_forward_cuda
|
RMSNorm.forward = vllm_kunlun_forward_cuda
|
||||||
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm
|
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
390
vllm_kunlun/v1/attention/backends/gdn_attn.py
Normal file
390
vllm_kunlun/v1/attention/backends/gdn_attn.py
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Backend for GatedDeltaNet attention."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.v1.attention.backends import gdn_attn
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
AttentionCGSupport,
|
||||||
|
AttentionMetadataBuilder,
|
||||||
|
CommonAttentionMetadata,
|
||||||
|
compute_causal_conv1d_metadata,
|
||||||
|
split_decodes_and_prefills,
|
||||||
|
)
|
||||||
|
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||||
|
|
||||||
|
|
||||||
|
class GDNAttentionBackend(AttentionBackend):
    """Attention backend descriptor for GatedDeltaNet attention."""

    @staticmethod
    def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
        """Return the metadata builder class that belongs to this backend."""
        return GDNAttentionMetadataBuilder
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GDNAttentionMetadata:
    """Per-batch metadata produced by ``GDNAttentionMetadataBuilder.build``.

    The seven counters are required; every tensor field is optional and
    defaults to ``None``.
    """

    # Request / token counts for each category of the batch.
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_spec_decodes: int
    num_spec_decode_tokens: int
    num_actual_tokens: int

    # Same flag kept twice: once as built, once as a CPU copy.
    has_initial_state: Optional[torch.Tensor] = None
    has_initial_state_cpu: Optional[torch.Tensor] = None

    # shape: [num_spec_decodes + 1,]
    spec_query_start_loc: Optional[torch.Tensor] = None
    # shape: [batch - num_spec_decodes + 1,]
    non_spec_query_start_loc: Optional[torch.Tensor] = None

    # shape: [batch, num_spec]
    spec_state_indices_tensor: Optional[torch.Tensor] = None
    # shape: [batch - num_spec_decodes,]
    non_spec_state_indices_tensor: Optional[torch.Tensor] = None
    # CPU copy of non_spec_state_indices_tensor.
    non_spec_state_indices_tensor_cpu: Optional[torch.Tensor] = None
    # shape: [batch,]
    spec_sequence_masks: Optional[torch.Tensor] = None
    # shape: [num_prefill_tokens + num_decode_tokens,]
    spec_token_masks: Optional[torch.Tensor] = None
    # shape: [batch,]
    num_accepted_tokens: Optional[torch.Tensor] = None

    # The following attributes are for triton implementation of causal_conv1d
    nums_dict: Optional[dict] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
    """Builds ``GDNAttentionMetadata`` from ``CommonAttentionMetadata``.

    The builder owns a set of persistent tensors sized by
    ``decode_cudagraph_max_bs``. When full CUDA graphs are enabled and the
    batch is decode-only, ``build`` copies the per-step values into those
    tensors and returns padded slices of them instead of fresh tensors.
    """

    cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Read the relevant config and allocate the persistent buffers.

        Args:
            kv_cache_spec: Must be a ``MambaSpec`` (asserted).
            layer_names: Not used by this builder.
            vllm_config: Source of the compilation, speculative and
                scheduler configuration.
            device: Device on which the persistent buffers are allocated.
        """
        assert isinstance(kv_cache_spec, MambaSpec)
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.speculative_config = vllm_config.speculative_config
        self.kv_cache_spec = kv_cache_spec
        if self.speculative_config:
            self.num_spec = self.speculative_config.num_speculative_tokens  # noqa: E501
        else:
            self.num_spec = 0
        self.use_spec_decode = self.num_spec > 0
        self._init_reorder_batch_threshold(1, self.use_spec_decode)

        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
            self.compilation_config.max_capture_size,
        )

        # Persistent buffers. They are allocated uninitialised, so every
        # slice handed out by build() must have its padded tail filled.
        self.spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, self.num_spec + 1),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.int32,
            device=device,
        )
        self.spec_sequence_masks = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.bool,
            device=device,
        )
        self.spec_token_masks = torch.empty(
            (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
            dtype=torch.bool,
            device=device,
        )
        self.spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.num_accepted_tokens = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.int32,
            device=device,
        )

    def build(  # type: ignore[override]
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        num_accepted_tokens: Optional[torch.Tensor] = None,
        num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
        fast_build: bool = False,
    ) -> GDNAttentionMetadata:
        """Build the metadata for one batch.

        Requests whose entry in ``num_decode_draft_tokens_cpu`` is ``>= 0``
        are treated as speculative-decode requests; the remaining requests
        are split into decodes (query length 1) and prefills.

        Args:
            common_prefix_len: Not used by this builder.
            common_attn_metadata: Batch description; ``query_start_loc``,
                ``num_computed_tokens_cpu``, ``block_table_tensor`` and
                ``num_actual_tokens`` are read.
            num_accepted_tokens: Per-request accepted-token counts. Must be
                provided when the batch contains speculative-decode
                requests (asserted); ignored otherwise.
            num_decode_draft_tokens_cpu: Per-request draft-token counts on
                CPU, or ``None`` when there is no speculative decoding.
            fast_build: Not used by this builder.

        Returns:
            The populated ``GDNAttentionMetadata``.
        """
        m = common_attn_metadata

        query_start_loc = m.query_start_loc
        context_lens = m.num_computed_tokens_cpu
        context_lens_tensor = context_lens.to(query_start_loc.device)
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

        if (
            not self.use_spec_decode
            or num_decode_draft_tokens_cpu is None
            or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
            .sum()
            .item()
            == 0
        ):
            spec_sequence_masks = None
            num_spec_decodes = 0
        else:
            spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
            num_spec_decodes = spec_sequence_masks.sum().item()
            if num_spec_decodes == 0:
                spec_sequence_masks = None
            else:
                spec_sequence_masks = spec_sequence_masks.to(
                    query_start_loc.device, non_blocking=True
                )

        if spec_sequence_masks is None:
            # No speculative requests: everything is a plain decode/prefill.
            num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
                split_decodes_and_prefills(m, decode_threshold=1)
            )
            num_spec_decode_tokens = 0
            spec_token_masks = None
            spec_state_indices_tensor = None
            non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
            spec_query_start_loc = None
            non_spec_query_start_loc = query_start_loc
            num_accepted_tokens = None
        else:
            query_lens = query_start_loc[1:] - query_start_loc[:-1]

            non_spec_query_lens = query_lens[~spec_sequence_masks]
            num_decodes = (non_spec_query_lens == 1).sum().item()
            num_prefills = non_spec_query_lens.size(0) - num_decodes
            num_decode_tokens = num_decodes
            num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens

            if num_prefills == 0 and num_decodes == 0:
                # Every request is a speculative decode.
                spec_token_masks = torch.ones(
                    (
                        min(
                            num_spec_decodes * (self.num_spec + 1),
                            query_start_loc[-1].item(),
                        )
                    ),
                    dtype=torch.bool,
                    device=query_start_loc.device,
                )
                spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
                non_spec_state_indices_tensor = None
                spec_query_start_loc = query_start_loc
                non_spec_query_start_loc = None
            else:
                # Mixed batch: split per-request data by the sequence mask and
                # rebuild separate cumulative query offsets for each group.
                spec_token_masks = torch.repeat_interleave(
                    spec_sequence_masks, query_lens
                )
                spec_state_indices_tensor = m.block_table_tensor[
                    spec_sequence_masks, : self.num_spec + 1
                ]
                non_spec_state_indices_tensor = m.block_table_tensor[
                    ~spec_sequence_masks, 0
                ]

                spec_query_start_loc = torch.zeros(
                    num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device,
                )
                torch.cumsum(
                    query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
                )
                non_spec_query_start_loc = torch.zeros(
                    query_lens.size(0) - num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device,
                )
                torch.cumsum(
                    query_lens[~spec_sequence_masks],
                    dim=0,
                    out=non_spec_query_start_loc[1:],
                )

            num_spec_decode_tokens = (
                query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
            )
            assert num_accepted_tokens is not None
            num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

        if num_prefills > 0:
            has_initial_state = context_lens_tensor > 0
            if spec_sequence_masks is not None:
                has_initial_state = has_initial_state[~spec_sequence_masks]
            has_initial_state_cpu = has_initial_state.cpu()
            nums_dict, batch_ptr, token_chunk_offset_ptr = (
                compute_causal_conv1d_metadata(non_spec_query_start_loc)
            )
        else:
            has_initial_state = None
            has_initial_state_cpu = None
        num_actual_tokens = (
            num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
        )

        # Prepare tensors for cudagraph.
        #
        # With speculative decoding, the xgrammar backend may roll back
        # tokens, causing some sequences to have fewer draft tokens than
        # self.num_spec.
        #
        # In those cases the max possible batch size for n tokens can be
        # min(n, cudagraph_max_bs).
        if (
            self.use_full_cuda_graph
            and num_prefills == 0
            and num_decodes == 0
            and num_spec_decodes <= self.decode_cudagraph_max_bs
            and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
        ):
            num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)

            self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                spec_state_indices_tensor, non_blocking=True
            )
            spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
            spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)

            self.spec_sequence_masks[:num_spec_decodes].copy_(
                spec_sequence_masks, non_blocking=True
            )
            spec_sequence_masks = self.spec_sequence_masks[:batch_size]
            spec_sequence_masks[num_spec_decodes:].fill_(False)

            assert spec_token_masks is not None
            # Remember the real token count BEFORE rebinding the name to the
            # padded buffer slice. Taking size(0) after the rebind yields the
            # padded length, which makes the tail slice empty and leaves
            # stale values from the uninitialised buffer in the padding.
            num_real_spec_tokens = spec_token_masks.size(0)
            self.spec_token_masks[:num_real_spec_tokens].copy_(
                spec_token_masks, non_blocking=True
            )
            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
            spec_token_masks[num_real_spec_tokens:].fill_(False)

            self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
                spec_query_start_loc, non_blocking=True
            )
            spec_num_query_tokens = spec_query_start_loc[-1]  # type: ignore[index]
            spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
            # Padded entries repeat the final offset.
            spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)

            self.num_accepted_tokens[:num_spec_decodes].copy_(
                num_accepted_tokens, non_blocking=True
            )
            num_accepted_tokens = self.num_accepted_tokens[:batch_size]
            num_accepted_tokens[num_spec_decodes:].fill_(1)

        if (
            self.use_full_cuda_graph
            and num_prefills == 0
            and num_spec_decodes == 0
            and num_decodes <= self.decode_cudagraph_max_bs
        ):
            num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
            batch_size = num_actual_tokens

            self.non_spec_state_indices_tensor[:num_decodes].copy_(
                non_spec_state_indices_tensor, non_blocking=True
            )
            non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
                :batch_size
            ]
            non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)

            self.non_spec_query_start_loc[: num_decodes + 1].copy_(
                non_spec_query_start_loc, non_blocking=True
            )
            non_spec_num_query_tokens = non_spec_query_start_loc[
                -1
            ]  # type: ignore[index]
            non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
            # Padded entries repeat the final offset.
            non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)

        # NOTE(review): placed at method level so it applies to every
        # speculative path; confirm against the repository file, since the
        # copy this was recovered from had lost its indentation.
        if num_accepted_tokens is not None:
            num_accepted_tokens = num_accepted_tokens.to(torch.int32)
        attn_metadata = GDNAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_spec_decodes=num_spec_decodes,
            num_spec_decode_tokens=num_spec_decode_tokens,
            num_actual_tokens=num_actual_tokens,
            has_initial_state=has_initial_state,
            has_initial_state_cpu=has_initial_state_cpu,
            spec_query_start_loc=spec_query_start_loc,
            non_spec_query_start_loc=non_spec_query_start_loc,
            spec_state_indices_tensor=spec_state_indices_tensor,
            non_spec_state_indices_tensor=non_spec_state_indices_tensor,
            non_spec_state_indices_tensor_cpu=(
                non_spec_state_indices_tensor.cpu()
                if non_spec_state_indices_tensor is not None
                else None
            ),
            spec_sequence_masks=spec_sequence_masks,
            spec_token_masks=spec_token_masks,
            num_accepted_tokens=num_accepted_tokens,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
        )
        return attn_metadata

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ):
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.

        The per-request query lengths from ``query_start_loc`` are used as
        the accepted-token counts, and ``num_computed_tokens_cpu`` on the
        passed metadata object is overwritten in place before delegating
        to ``build``.
        """
        m = common_attn_metadata

        assert (
            m.num_reqs <= self.decode_cudagraph_max_bs
            and m.num_actual_tokens <= self.decode_cudagraph_max_bs
        ), (
            f"GDN only supports decode-only full CUDAGraph capture. "
            f"Make sure batch size ({m.num_reqs}) <= "
            f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
            f"and number of tokens ({m.num_actual_tokens}) <= "
            f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})."
        )

        num_accepted_tokens = torch.diff(m.query_start_loc)
        num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
        m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()

        return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
|
||||||
|
|
||||||
|
|
||||||
|
# Rebind the names on the imported vLLM ``gdn_attn`` module so that lookups
# through that module resolve to the classes defined in this file.
gdn_attn.GDNAttentionMetadataBuilder = GDNAttentionMetadataBuilder
gdn_attn.GDNAttentionMetadata = GDNAttentionMetadata
|
||||||
@@ -770,24 +770,14 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
# If kv_cache is not provided, the new key and value tensors are
|
# If kv_cache is not provided, the new key and value tensors are
|
||||||
# not cached. This happens during the initial memory
|
# not cached. This happens during the initial memory
|
||||||
value = value.contiguous()
|
value = value.contiguous()
|
||||||
if key_cache.is_contiguous():
|
kunlun_ops.reshape_and_cache_flash(
|
||||||
kunlun_ops.reshape_and_cache(
|
key[: attn_metadata.num_actual_tokens],
|
||||||
key[: attn_metadata.num_actual_tokens],
|
value[: attn_metadata.num_actual_tokens],
|
||||||
value[: attn_metadata.num_actual_tokens],
|
key_cache,
|
||||||
key_cache,
|
value_cache,
|
||||||
value_cache,
|
updated_slot_mapping,
|
||||||
updated_slot_mapping,
|
BLHD_LAYOUT=False,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
|
|
||||||
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
|
|
||||||
kunlun_ops.reshape_and_cache_flash(
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
cast_key_cache,
|
|
||||||
cast_value_cache,
|
|
||||||
updated_slot_mapping,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert attn_type == AttentionType.DECODER
|
assert attn_type == AttentionType.DECODER
|
||||||
# Decoder self-attention supports chunked prefill.
|
# Decoder self-attention supports chunked prefill.
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import Union
|
|
||||||
|
import kunlun_ops
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
|
|||||||
bonus_token_ids: torch.Tensor,
|
bonus_token_ids: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
'''
|
"""
|
||||||
Args:
|
Args:
|
||||||
metadata:
|
metadata:
|
||||||
Metadata for spec decoding.
|
Metadata for spec decoding.
|
||||||
@@ -81,7 +80,7 @@ class RejectionSampler(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
output_token_ids (torch.Tensor):
|
output_token_ids (torch.Tensor):
|
||||||
A tensor containing the final output token IDs.
|
A tensor containing the final output token IDs.
|
||||||
'''
|
"""
|
||||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||||
# [num_tokens, vocab_size]
|
# [num_tokens, vocab_size]
|
||||||
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
||||||
@@ -124,11 +123,11 @@ class RejectionSampler(nn.Module):
|
|||||||
"""
|
"""
|
||||||
output_token_ids_np = output_token_ids.cpu().numpy()
|
output_token_ids_np = output_token_ids.cpu().numpy()
|
||||||
# Create mask for valid tokens.
|
# Create mask for valid tokens.
|
||||||
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
|
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
|
||||||
(output_token_ids_np < vocab_size))
|
output_token_ids_np < vocab_size
|
||||||
|
)
|
||||||
outputs = [
|
outputs = [
|
||||||
row[valid_mask[i]].tolist()
|
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
|
||||||
for i, row in enumerate(output_token_ids_np)
|
|
||||||
]
|
]
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -179,25 +178,15 @@ def rejection_sample(
|
|||||||
if not sampling_metadata.all_random:
|
if not sampling_metadata.all_random:
|
||||||
# Rejection sampling for greedy sampling requests.
|
# Rejection sampling for greedy sampling requests.
|
||||||
target_argmax = target_probs.argmax(dim=-1)
|
target_argmax = target_probs.argmax(dim=-1)
|
||||||
if min(num_draft_tokens) == 1 and max(
|
kunlun_ops.rejection_greedy_sample(
|
||||||
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
|
output_token_ids,
|
||||||
rejection_greedy_sample_spec_len_1_pytorch(
|
cu_num_draft_tokens,
|
||||||
output_token_ids,
|
draft_token_ids,
|
||||||
draft_token_ids,
|
target_argmax,
|
||||||
target_argmax,
|
bonus_token_ids,
|
||||||
bonus_token_ids,
|
is_greedy,
|
||||||
)
|
max_spec_len,
|
||||||
else:
|
)
|
||||||
rejection_greedy_sample_pytorch(
|
|
||||||
output_token_ids,
|
|
||||||
cu_num_draft_tokens,
|
|
||||||
draft_token_ids,
|
|
||||||
target_argmax,
|
|
||||||
bonus_token_ids,
|
|
||||||
num_draft_tokens,
|
|
||||||
max_spec_len,
|
|
||||||
is_greedy,
|
|
||||||
)
|
|
||||||
if sampling_metadata.all_greedy:
|
if sampling_metadata.all_greedy:
|
||||||
return output_token_ids
|
return output_token_ids
|
||||||
|
|
||||||
@@ -222,8 +211,9 @@ def rejection_sample(
|
|||||||
sampling_metadata,
|
sampling_metadata,
|
||||||
device,
|
device,
|
||||||
)
|
)
|
||||||
|
bonus_token_ids = bonus_token_ids.squeeze(1)
|
||||||
|
|
||||||
rejection_random_sample_pytorch(
|
kunlun_ops.rejection_random_sample(
|
||||||
output_token_ids,
|
output_token_ids,
|
||||||
cu_num_draft_tokens,
|
cu_num_draft_tokens,
|
||||||
draft_token_ids,
|
draft_token_ids,
|
||||||
@@ -235,8 +225,7 @@ def rejection_sample(
|
|||||||
is_greedy,
|
is_greedy,
|
||||||
max_spec_len,
|
max_spec_len,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
IS_NGRAM=draft_probs is None,
|
no_draft_probs=draft_probs is None,
|
||||||
# num_warps=1,
|
|
||||||
)
|
)
|
||||||
return output_token_ids
|
return output_token_ids
|
||||||
|
|
||||||
@@ -374,7 +363,7 @@ def generate_uniform_probs(
|
|||||||
random values in the range [0, 1).
|
random values in the range [0, 1).
|
||||||
"""
|
"""
|
||||||
uniform_probs = torch.rand(
|
uniform_probs = torch.rand(
|
||||||
(num_tokens, ),
|
(num_tokens,),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
@@ -422,7 +411,7 @@ def sample_recovered_tokens(
|
|||||||
q[i].exponential_(generator=generator)
|
q[i].exponential_(generator=generator)
|
||||||
|
|
||||||
recovered_token_ids = torch.empty_like(draft_token_ids)
|
recovered_token_ids = torch.empty_like(draft_token_ids)
|
||||||
sample_recovered_tokens_pytorch(
|
kunlun_ops.sample_recovered_tokens(
|
||||||
recovered_token_ids,
|
recovered_token_ids,
|
||||||
cu_num_draft_tokens,
|
cu_num_draft_tokens,
|
||||||
draft_token_ids,
|
draft_token_ids,
|
||||||
@@ -430,16 +419,16 @@ def sample_recovered_tokens(
|
|||||||
target_probs,
|
target_probs,
|
||||||
q,
|
q,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
IS_NGRAM=draft_probs is None,
|
no_draft_probs=draft_probs is None,
|
||||||
)
|
)
|
||||||
return recovered_token_ids
|
return recovered_token_ids
|
||||||
|
|
||||||
|
|
||||||
def rejection_greedy_sample_spec_len_1_pytorch(
|
def rejection_greedy_sample_spec_len_1_pytorch(
|
||||||
output_token_ids, # [batch_size, 2]
|
output_token_ids, # [batch_size, 2]
|
||||||
draft_token_ids, # [num_tokens]
|
draft_token_ids, # [num_tokens]
|
||||||
target_argmax, # [num_tokens]
|
target_argmax, # [num_tokens]
|
||||||
bonus_token_ids, # [batch_size]
|
bonus_token_ids, # [batch_size]
|
||||||
):
|
):
|
||||||
batch_size = output_token_ids.size(0)
|
batch_size = output_token_ids.size(0)
|
||||||
num_tokens = draft_token_ids.size(0)
|
num_tokens = draft_token_ids.size(0)
|
||||||
@@ -447,73 +436,72 @@ def rejection_greedy_sample_spec_len_1_pytorch(
|
|||||||
accept_req_mask = draft_token_ids == target_argmax
|
accept_req_mask = draft_token_ids == target_argmax
|
||||||
output_token_ids[:, 0] = target_argmax
|
output_token_ids[:, 0] = target_argmax
|
||||||
bonus_token_ids = bonus_token_ids.squeeze(1)
|
bonus_token_ids = bonus_token_ids.squeeze(1)
|
||||||
output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids,
|
output_token_ids[:, 1] = torch.where(
|
||||||
output_token_ids[:, 1])
|
accept_req_mask, bonus_token_ids, output_token_ids[:, 1]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def rejection_greedy_sample_pytorch(
|
def rejection_greedy_sample_pytorch(
|
||||||
output_token_ids, # [batch_size, max_spec_len + 1]
|
output_token_ids, # [batch_size, max_spec_len + 1]
|
||||||
cu_num_draft_tokens, # [batch_size]
|
cu_num_draft_tokens, # [batch_size]
|
||||||
draft_token_ids, # [num_tokens]
|
draft_token_ids, # [num_tokens]
|
||||||
target_argmax, # [num_tokens]
|
target_argmax, # [num_tokens]
|
||||||
bonus_token_ids, # [batch_size]
|
bonus_token_ids, # [batch_size]
|
||||||
draft_tokens_per_req, # [batch_size], list
|
draft_tokens_per_req, # [batch_size], list
|
||||||
max_spec_len,
|
max_spec_len,
|
||||||
is_greedy=None, # [batch_size] or None
|
is_greedy=None, # [batch_size] or None
|
||||||
):
|
):
|
||||||
batch_size = output_token_ids.size(0)
|
batch_size = output_token_ids.size(0)
|
||||||
num_tokens = draft_token_ids.size(0)
|
num_tokens = draft_token_ids.size(0)
|
||||||
device = output_token_ids.device
|
device = output_token_ids.device
|
||||||
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
|
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
|
||||||
device, non_blocking=True)
|
device, non_blocking=True
|
||||||
|
)
|
||||||
if is_greedy is None:
|
if is_greedy is None:
|
||||||
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
|
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||||
|
|
||||||
start_indices = cu_num_draft_tokens - draft_tokens_per_req
|
start_indices = cu_num_draft_tokens - draft_tokens_per_req
|
||||||
req_ids = torch.arange(batch_size, device=device)
|
req_ids = torch.arange(batch_size, device=device)
|
||||||
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
|
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
|
||||||
token_positions = torch.arange(
|
token_positions = (
|
||||||
num_tokens, device=device) - start_indices[token_req_ids]
|
torch.arange(num_tokens, device=device) - start_indices[token_req_ids]
|
||||||
|
)
|
||||||
|
|
||||||
# Find the first mismatch position of each request.
|
# Find the first mismatch position of each request.
|
||||||
mismatch_global = (draft_token_ids != target_argmax)
|
mismatch_global = draft_token_ids != target_argmax
|
||||||
if max_spec_len == 0:
|
if max_spec_len == 0:
|
||||||
first_mismatch_pos_per_req = torch.zeros(batch_size,
|
first_mismatch_pos_per_req = torch.zeros(
|
||||||
dtype=torch.long,
|
batch_size, dtype=torch.long, device=device
|
||||||
device=device)
|
)
|
||||||
else:
|
else:
|
||||||
# [bs, max_spec_len]
|
# [bs, max_spec_len]
|
||||||
pos_matrix = torch.full((batch_size, max_spec_len),
|
pos_matrix = torch.full(
|
||||||
-1,
|
(batch_size, max_spec_len), -1, dtype=torch.long, device=device
|
||||||
dtype=torch.long,
|
)
|
||||||
device=device)
|
|
||||||
pos_matrix[token_req_ids, token_positions] = token_positions
|
pos_matrix[token_req_ids, token_positions] = token_positions
|
||||||
mismatch_matrix = torch.full((batch_size, max_spec_len),
|
mismatch_matrix = torch.full(
|
||||||
False,
|
(batch_size, max_spec_len), False, dtype=torch.bool, device=device
|
||||||
dtype=torch.bool,
|
)
|
||||||
device=device)
|
|
||||||
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
|
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
|
||||||
mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
|
mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
|
||||||
max_spec_len * 2)
|
|
||||||
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
|
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
|
||||||
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
|
no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
|
||||||
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
|
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
|
||||||
no_mismatch_mask]
|
no_mismatch_mask
|
||||||
|
]
|
||||||
|
|
||||||
# Copy matched target tokens into output.
|
# Copy matched target tokens into output.
|
||||||
copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
|
copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req)
|
||||||
draft_tokens_per_req)
|
copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1)
|
||||||
copy_indices = torch.arange(max_spec_len + 1,
|
|
||||||
device=device).expand(batch_size, -1)
|
|
||||||
copy_mask = copy_indices < copy_len.unsqueeze(1)
|
copy_mask = copy_indices < copy_len.unsqueeze(1)
|
||||||
greedy_mask = is_greedy.unsqueeze(1)
|
greedy_mask = is_greedy.unsqueeze(1)
|
||||||
final_copy_mask = copy_mask & greedy_mask
|
final_copy_mask = copy_mask & greedy_mask
|
||||||
global_idx = start_indices.unsqueeze(1) + copy_indices
|
global_idx = start_indices.unsqueeze(1) + copy_indices
|
||||||
output_token_ids[final_copy_mask] = target_argmax[
|
output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(
|
||||||
global_idx[final_copy_mask]].to(output_token_ids.dtype)
|
output_token_ids.dtype
|
||||||
|
)
|
||||||
# Fill bonus token.
|
# Fill bonus token.
|
||||||
needs_bonus = is_greedy & (first_mismatch_pos_per_req
|
needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req)
|
||||||
>= draft_tokens_per_req)
|
|
||||||
if torch.any(needs_bonus):
|
if torch.any(needs_bonus):
|
||||||
bonus_rows = torch.where(needs_bonus)[0]
|
bonus_rows = torch.where(needs_bonus)[0]
|
||||||
bonus_cols = draft_tokens_per_req[bonus_rows]
|
bonus_cols = draft_tokens_per_req[bonus_rows]
|
||||||
@@ -556,11 +544,9 @@ def rejection_random_sample_pytorch(
|
|||||||
if IS_NGRAM:
|
if IS_NGRAM:
|
||||||
draft_prob = 1.0
|
draft_prob = 1.0
|
||||||
else:
|
else:
|
||||||
draft_prob = draft_probs[start_idx + pos,
|
draft_prob = draft_probs[start_idx + pos, draft_token_id].item()
|
||||||
draft_token_id].item()
|
|
||||||
|
|
||||||
target_prob = target_probs[start_idx + pos,
|
target_prob = target_probs[start_idx + pos, draft_token_id].item()
|
||||||
draft_token_id].item()
|
|
||||||
uniform_prob = uniform_probs[start_idx + pos].item()
|
uniform_prob = uniform_probs[start_idx + pos].item()
|
||||||
|
|
||||||
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
||||||
@@ -629,12 +615,11 @@ def sample_recovered_tokens_pytorch(
|
|||||||
else:
|
else:
|
||||||
draft_p = draft_probs[token_idx].clone()
|
draft_p = draft_probs[token_idx].clone()
|
||||||
target_p = target_probs[token_idx].clone()
|
target_p = target_probs[token_idx].clone()
|
||||||
prob = torch.maximum(target_p - draft_p,
|
prob = torch.maximum(
|
||||||
torch.tensor(0.0, device=target_p.device))
|
target_p - draft_p, torch.tensor(0.0, device=target_p.device)
|
||||||
|
)
|
||||||
|
|
||||||
q_values = torch.full((vocab_size, ),
|
q_values = torch.full((vocab_size,), float("-inf"), device=q.device)
|
||||||
float('-inf'),
|
|
||||||
device=q.device)
|
|
||||||
q_values[:vocab_size] = q[req_idx, :vocab_size]
|
q_values[:vocab_size] = q[req_idx, :vocab_size]
|
||||||
|
|
||||||
recovered_id = torch.argmax(prob / q_values).item()
|
recovered_id = torch.argmax(prob / q_values).item()
|
||||||
@@ -642,4 +627,3 @@ def sample_recovered_tokens_pytorch(
|
|||||||
|
|
||||||
if IS_NGRAM:
|
if IS_NGRAM:
|
||||||
target_probs[token_idx, draft_token_id] = orig_prob
|
target_probs[token_idx, draft_token_id] = orig_prob
|
||||||
|
|
||||||
|
|||||||
@@ -337,5 +337,5 @@ def prepare_next_token_ids_padded(
|
|||||||
return next_token_ids, valid_sampled_tokens_count
|
return next_token_ids, valid_sampled_tokens_count
|
||||||
|
|
||||||
|
|
||||||
EagleProposer.propose = propose
|
# EagleProposer.propose = propose
|
||||||
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded
|
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded
|
||||||
|
|||||||
@@ -407,7 +407,7 @@ def add_rmsnorm(
|
|||||||
) -> None:
|
) -> None:
|
||||||
kunlun_ops.add_rmsnorm(
|
kunlun_ops.add_rmsnorm(
|
||||||
x,
|
x,
|
||||||
y, # 原来写 residual,这里其实是 y
|
y,
|
||||||
residual_output=residual_output,
|
residual_output=residual_output,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
eps=eps,
|
eps=eps,
|
||||||
@@ -523,6 +523,145 @@ def _fake_add_rmsnorm(
|
|||||||
add_rmsnorm.register_fake(_fake_add_rmsnorm)
|
add_rmsnorm.register_fake(_fake_add_rmsnorm)
|
||||||
|
|
||||||
|
|
||||||
|
@custom_op("_C::gemma_add_rmsnorm", mutates_args=())
|
||||||
|
def gemma_add_rmsnorm(
|
||||||
|
x: torch.Tensor,
|
||||||
|
y: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
enable_pdl: bool = False,
|
||||||
|
interweaved: bool = False,
|
||||||
|
store_output_before_norm: bool = True,
|
||||||
|
bias: torch.Tensor = None,
|
||||||
|
smooth: torch.Tensor = None,
|
||||||
|
residual_output: torch.Tensor = None,
|
||||||
|
force_sdnn: bool = False,
|
||||||
|
) -> None:
|
||||||
|
# print("gemma_add_rmsnorm wrapper")
|
||||||
|
kunlun_ops.gemma_add_rmsnorm(
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
weight=weight,
|
||||||
|
output=output,
|
||||||
|
eps=eps,
|
||||||
|
enable_pdl=enable_pdl,
|
||||||
|
interweaved=interweaved,
|
||||||
|
store_output_before_norm=store_output_before_norm,
|
||||||
|
bias=bias,
|
||||||
|
smooth=smooth,
|
||||||
|
residual_output=residual_output,
|
||||||
|
force_sdnn=force_sdnn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::gemma_add_rmsnorm", "CUDA")
|
||||||
|
def gemma_add_rmsnorm_cuda(
|
||||||
|
x: torch.Tensor,
|
||||||
|
y: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
enable_pdl: bool = False,
|
||||||
|
interweaved: bool = False,
|
||||||
|
store_output_before_norm: bool = True,
|
||||||
|
bias: torch.Tensor = None,
|
||||||
|
smooth: torch.Tensor = None,
|
||||||
|
residual_output: torch.Tensor = None,
|
||||||
|
force_sdnn: bool = False,
|
||||||
|
) -> None:
|
||||||
|
# print("gemma_add_rmsnorm_cuda wrapper")
|
||||||
|
kunlun_ops.gemma_add_rmsnorm(
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
weight=weight,
|
||||||
|
output=output,
|
||||||
|
eps=eps,
|
||||||
|
enable_pdl=enable_pdl,
|
||||||
|
interweaved=interweaved,
|
||||||
|
store_output_before_norm=store_output_before_norm,
|
||||||
|
bias=bias,
|
||||||
|
smooth=smooth,
|
||||||
|
residual_output=residual_output,
|
||||||
|
force_sdnn=force_sdnn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_gemma_add_rmsnorm(
|
||||||
|
x: torch.Tensor,
|
||||||
|
y: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
enable_pdl: bool = False,
|
||||||
|
interweaved: bool = False,
|
||||||
|
store_output_before_norm: bool = True,
|
||||||
|
bias: torch.Tensor = None,
|
||||||
|
smooth: torch.Tensor = None,
|
||||||
|
residual_output: torch.Tensor = None,
|
||||||
|
force_sdnn: bool = False,
|
||||||
|
):
|
||||||
|
output.fake_shape = x.shape
|
||||||
|
output.fake_dtype = x.dtype
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
gemma_add_rmsnorm.register_fake(_fake_gemma_add_rmsnorm)
|
||||||
|
|
||||||
|
|
||||||
|
@custom_op("_C::gemma_rmsnorm", mutates_args=())
|
||||||
|
def gemma_rmsnorm(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
enable_pdl: bool = False,
|
||||||
|
interweave: bool = False,
|
||||||
|
bias: torch.Tensor = None,
|
||||||
|
force_sdnn: bool = False,
|
||||||
|
) -> None:
|
||||||
|
# print("gemma_rmsnorm wrapper")
|
||||||
|
kunlun_ops.gemma_rmsnorm(
|
||||||
|
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::gemma_rmsnorm", "CUDA")
|
||||||
|
def gemma_rmsnorm_cuda(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
enable_pdl: bool = False,
|
||||||
|
interweave: bool = False,
|
||||||
|
bias: torch.Tensor = None,
|
||||||
|
force_sdnn: bool = False,
|
||||||
|
) -> None:
|
||||||
|
# print("gemma_rmsnorm_cuda wrapper")
|
||||||
|
kunlun_ops.gemma_rmsnorm(
|
||||||
|
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_gemma_rmsnorm(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
enable_pdl: bool = False,
|
||||||
|
interweave: bool = False,
|
||||||
|
bias: torch.Tensor = None,
|
||||||
|
force_sdnn: bool = False,
|
||||||
|
):
|
||||||
|
# 设置 shape/dtype,但不返回值
|
||||||
|
output.fake_shape = x.shape
|
||||||
|
output.fake_dtype = x.dtype
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
gemma_rmsnorm.register_fake(_fake_gemma_rmsnorm)
|
||||||
|
|
||||||
|
|
||||||
@custom_op("_C::split_norm_rope_neox", mutates_args=())
|
@custom_op("_C::split_norm_rope_neox", mutates_args=())
|
||||||
def split_norm_rope_neox(
|
def split_norm_rope_neox(
|
||||||
q_emb: torch.Tensor,
|
q_emb: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user