[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)
Signed-off-by: xyDong0223 <dongxinyu03@baidu.com> Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -3,95 +3,113 @@ from vllm import ModelRegistry
|
||||
|
||||
def register_model():
|
||||
# from .demo_model import DemoModel # noqa: F401
|
||||
from .qwen2_vl import Qwen2VLForConditionalGeneration #noqa: F401
|
||||
from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration #noqa: F401
|
||||
from .qwen3_moe import Qwen3MoeForCausalLM #noqa: F401
|
||||
from .qwen3_vl import Qwen3VLForConditionalGeneration
|
||||
from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
|
||||
from .qwen3_omni_moe_thinker import Qwen3OmniMoeThinkerForConditionalGeneration
|
||||
from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration # noqa: F401
|
||||
from .qwen2_vl import Qwen2VLForConditionalGeneration # noqa: F401
|
||||
from .qwen3_moe import Qwen3MoeForCausalLM # noqa: F401
|
||||
from .qwen3_omni_moe_thinker import ( # noqa: F401
|
||||
Qwen3OmniMoeThinkerForConditionalGeneration,
|
||||
)
|
||||
from .qwen3_vl import Qwen3VLForConditionalGeneration # noqa: F401
|
||||
from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration # noqa: F401
|
||||
|
||||
# from .llama4 import Llama4ForCausalLM #noqa: F401
|
||||
# from .mllama4 import Llama4ForConditionalGeneration #noqa: F401
|
||||
# from .deepseek_v2 import KunlunDeepseekV2MoE
|
||||
|
||||
# ModelRegistry.register_model(
|
||||
# "DemoModel",
|
||||
# "vllm_kunlun.model_executor.models.demo_model:DemoModel")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen2VLForConditionalGeneration",
|
||||
"vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration")
|
||||
"vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration",
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration")
|
||||
"vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration",
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3ForCausalLM",
|
||||
"vllm_kunlun.models.qwen3:Qwen3ForCausalLM")
|
||||
"Qwen3ForCausalLM", "vllm_kunlun.models.qwen3:Qwen3ForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3MoeForCausalLM",
|
||||
"vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM")
|
||||
"Qwen3MoeForCausalLM", "vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3NextForCausalLM",
|
||||
"vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM")
|
||||
"Qwen3NextForCausalLM", "vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"GptOssForCausalLM",
|
||||
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
|
||||
"Qwen3NextMTP", "vllm_kunlun.models.qwen3_next_mtp:Qwen3NextMTP"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"InternLM2ForCausalLM",
|
||||
"vllm_kunlun.models.internlm2:InternLM2ForCausalLM")
|
||||
|
||||
"GlmForCausalLM", "vllm_kunlun.models.glm:GlmForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"InternVLChatModel",
|
||||
"vllm_kunlun.models.internvl:InternVLChatModel")
|
||||
"GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"InternLM2ForCausalLM", "vllm_kunlun.models.internlm2:InternLM2ForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"InternVLChatModel", "vllm_kunlun.models.internvl:InternVLChatModel"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"InternS1ForConditionalGeneration",
|
||||
"vllm_kunlun.models.interns1:InternS1ForConditionalGeneration")
|
||||
|
||||
"vllm_kunlun.models.interns1:InternS1ForConditionalGeneration",
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3VLForConditionalGeneration",
|
||||
"vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration")
|
||||
|
||||
"vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration",
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3VLMoeForConditionalGeneration",
|
||||
"vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration")
|
||||
"vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration",
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3OmniMoeForConditionalGeneration",
|
||||
"vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration")
|
||||
"vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration",
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"SeedOssForCausalLM",
|
||||
"vllm_kunlun.models.seed_oss:SeedOssForCausalLM")
|
||||
"SeedOssForCausalLM", "vllm_kunlun.models.seed_oss:SeedOssForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"MiMoV2FlashForCausalLM",
|
||||
"vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM")
|
||||
"vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM",
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"GptOssForCausalLM",
|
||||
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
|
||||
"GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV3ForCausalLM",
|
||||
"vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM")
|
||||
"DeepseekV3ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV32ForCausalLM",
|
||||
"vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepSeekMTPModel",
|
||||
"vllm_kunlun.models.deepseek_mtp:DeepSeekMTP")
|
||||
"DeepseekV32ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"GlmMoeDsaForCausalLM",
|
||||
"vllm_kunlun.models.deepseek_v2:GlmMoeDsaForCausalLM")
|
||||
"DeepSeekMTPModel", "vllm_kunlun.models.deepseek_mtp:DeepSeekMTP"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"GlmMoeDsaForCausalLM", "vllm_kunlun.models.deepseek_v2:GlmMoeDsaForCausalLM"
|
||||
)
|
||||
|
||||
|
||||
def register_quant_method():
|
||||
"""to do"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
303
vllm_kunlun/models/qwen3_next_mtp.py
Normal file
303
vllm_kunlun/models/qwen3_next_mtp.py
Normal file
@@ -0,0 +1,303 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Inference-only Qwen3Next MTP model."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import SupportsPP
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
maybe_prefix,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import Qwen3NextConfig
|
||||
|
||||
from .qwen3_next import Qwen3NextDecoderLayer, Qwen3NextRMSNorm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
KVCache = tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Qwen3NextMultiTokenPredictor(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
config: Qwen3NextConfig = model_config.hf_config
|
||||
|
||||
self.config = config
|
||||
lora_vocab = (
|
||||
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
|
||||
if lora_config
|
||||
else 0
|
||||
)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
|
||||
self.fc = ColumnParallelLinear(
|
||||
self.config.hidden_size * 2,
|
||||
self.config.hidden_size,
|
||||
gather_output=True,
|
||||
bias=False,
|
||||
return_bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc",
|
||||
)
|
||||
|
||||
self.layers = torch.nn.ModuleList(
|
||||
Qwen3NextDecoderLayer(
|
||||
vllm_config,
|
||||
layer_type="full_attention",
|
||||
prefix=f"{prefix}.layers.{idx}",
|
||||
)
|
||||
for idx in range(self.num_mtp_layers)
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_fc_norm_hidden = Qwen3NextRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
self.pre_fc_norm_embedding = Qwen3NextRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
|
||||
inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
|
||||
hidden_states = self.pre_fc_norm_hidden(hidden_states)
|
||||
hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
|
||||
hidden_states = self.fc(hidden_states)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
current_step_idx = spec_step_idx % self.num_mtp_layers
|
||||
hidden_states, residual = self.layers[current_step_idx](
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.num_experts,
|
||||
)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
if "mlp.experts" in name:
|
||||
continue
|
||||
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if (
|
||||
name.endswith(".bias") or name.endswith("_bias")
|
||||
) and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Qwen3NextMTP(nn.Module, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": ["up_proj", "down_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.vllm_config = vllm_config
|
||||
cache_config = vllm_config.cache_config
|
||||
assert (
|
||||
not cache_config.enable_prefix_caching
|
||||
), "Qwen3NextMTP currently does not support prefix caching"
|
||||
|
||||
self.quant_config = vllm_config.quant_config
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = Qwen3NextMultiTokenPredictor(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
|
||||
)
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, config.vocab_size
|
||||
)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
spec_step_idx: int = 0,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.logits_processor(self.lm_head, hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
shared_weight_names = ["embed_tokens", "lm_head"]
|
||||
|
||||
def remap_weight_names(weights):
|
||||
for name, weight in weights:
|
||||
if name.startswith("mtp."):
|
||||
name = name.replace("mtp.", "model.")
|
||||
elif not any(key in name for key in shared_weight_names):
|
||||
continue
|
||||
yield name, weight
|
||||
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(remap_weight_names(weights))
|
||||
@@ -16,33 +16,33 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""kunlun custom op entry"""
|
||||
import torch_xmlir
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import os
|
||||
from typing import Optional, List, Dict
|
||||
import vllm.envs as envs
|
||||
import os
|
||||
import ctypes
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
import kunlun_ops
|
||||
logger.info(f"Load custom ops library success!")
|
||||
|
||||
logger.info("Load custom ops library success!")
|
||||
except ImportError as e:
|
||||
logger.warning("Import error msg: %s", e.msg)
|
||||
|
||||
|
||||
_per_token_smooth_quant = True
|
||||
|
||||
|
||||
def is_per_token_smooth_quant():
|
||||
""" is per token smooth quant """
|
||||
"""is per token smooth quant"""
|
||||
return _per_token_smooth_quant
|
||||
|
||||
|
||||
class KunlunOps:
|
||||
"""KunlunOps"""
|
||||
|
||||
# Attention ops
|
||||
@staticmethod
|
||||
def paged_attention_v1(
|
||||
@@ -67,9 +67,9 @@ class KunlunOps:
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
alibi_sqrt=False
|
||||
):
|
||||
""" PagedAttentionV1 """
|
||||
alibi_sqrt=False,
|
||||
):
|
||||
"""PagedAttentionV1"""
|
||||
# block_size = value_cache.shape[2]
|
||||
kunlun_ops.paged_attention(
|
||||
x=query,
|
||||
@@ -81,7 +81,7 @@ class KunlunOps:
|
||||
is_context=is_context,
|
||||
is_causal=True,
|
||||
out=output,
|
||||
vo_head_dim=128
|
||||
vo_head_dim=128,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -110,9 +110,9 @@ class KunlunOps:
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
alibi_sqrt=False
|
||||
):
|
||||
""" PagedAttentionV2 """
|
||||
alibi_sqrt=False,
|
||||
):
|
||||
"""PagedAttentionV2"""
|
||||
# block_size = value_cache.shape[2]
|
||||
kunlun_ops.paged_attention(
|
||||
x=query,
|
||||
@@ -124,31 +124,28 @@ class KunlunOps:
|
||||
is_context=is_context,
|
||||
is_causal=True,
|
||||
out=output,
|
||||
vo_head_dim=128
|
||||
vo_head_dim=128,
|
||||
)
|
||||
|
||||
|
||||
# Activation ops
|
||||
@staticmethod
|
||||
def silu_and_mul(out: torch.Tensor,
|
||||
x: torch.Tensor):
|
||||
""" silu and mul """
|
||||
def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
|
||||
"""silu and mul"""
|
||||
kunlun_ops.silu_and_mul(
|
||||
x,
|
||||
axis=-1,
|
||||
turn=True,
|
||||
out=out,
|
||||
)
|
||||
)
|
||||
|
||||
# Activation ops
|
||||
@staticmethod
|
||||
def quick_gelu(out: torch.Tensor,
|
||||
x: torch.Tensor):
|
||||
""" quick gelu """
|
||||
def quick_gelu(out: torch.Tensor, x: torch.Tensor):
|
||||
"""quick gelu"""
|
||||
kunlun_ops.quick_gelu(
|
||||
x,
|
||||
out=out,
|
||||
)
|
||||
)
|
||||
|
||||
# Layernorm
|
||||
@staticmethod
|
||||
@@ -159,9 +156,7 @@ class KunlunOps:
|
||||
epsilon,
|
||||
):
|
||||
"""rms_norm"""
|
||||
kunlun_ops.rmsnorm(
|
||||
x, weight.to(torch.float32), epsilon, out=out
|
||||
)
|
||||
kunlun_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)
|
||||
|
||||
@staticmethod
|
||||
def fused_add_rms_norm(
|
||||
@@ -179,16 +174,11 @@ class KunlunOps:
|
||||
residual.copy_(fused_input, non_blocking=True)
|
||||
x.copy_(output)
|
||||
|
||||
|
||||
# Rotary embedding
|
||||
@staticmethod
|
||||
def rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style):
|
||||
positions, query, key, head_size, cos_sin_cache, is_neox_style
|
||||
):
|
||||
"""
|
||||
refactor RotaryEmbedding forward function
|
||||
"""
|
||||
@@ -196,62 +186,38 @@ class KunlunOps:
|
||||
key_x = key.contiguous()
|
||||
|
||||
torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style)
|
||||
positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
|
||||
)
|
||||
|
||||
return query_x, key_x
|
||||
|
||||
# Rotary embedding
|
||||
@staticmethod
|
||||
def mrotary_embedding(
|
||||
positions,
|
||||
mrope_section,
|
||||
query,
|
||||
key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style):
|
||||
positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
|
||||
):
|
||||
"""
|
||||
refactor RotaryEmbedding forward function
|
||||
"""
|
||||
query_x = query.contiguous()
|
||||
key_x = key.contiguous()
|
||||
query_x_dim = query_x.dim()
|
||||
assert is_neox_style
|
||||
kunlun_ops.mrotary_embedding_neox(
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
mrope_section)
|
||||
positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
|
||||
)
|
||||
|
||||
query.data = query_x
|
||||
key.data = key_x
|
||||
key.data = key_x
|
||||
return query, key
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src,
|
||||
dst,
|
||||
block_mapping):
|
||||
""" swap_blocks """
|
||||
kunlun_ops.swap_blocks(
|
||||
src,
|
||||
dst,
|
||||
block_mapping
|
||||
)
|
||||
def swap_blocks(src, dst, block_mapping):
|
||||
"""swap_blocks"""
|
||||
kunlun_ops.swap_blocks(src, dst, block_mapping)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
key_caches,
|
||||
value_caches,
|
||||
block_mapping):
|
||||
""" copy_blocks """
|
||||
def copy_blocks(key_caches, value_caches, block_mapping):
|
||||
"""copy_blocks"""
|
||||
for i in range(len(key_caches)):
|
||||
key_caches[i] = key_caches[i].contiguous()
|
||||
value_caches[i] = value_caches[i].contiguous()
|
||||
@@ -269,16 +235,10 @@ class KunlunOps:
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype,
|
||||
):
|
||||
""" reshape_and_cache """
|
||||
):
|
||||
"""reshape_and_cache"""
|
||||
# slot_mapping_cast = slot_mapping.to(torch.int32)
|
||||
kunlun_ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping
|
||||
)
|
||||
kunlun_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
|
||||
@staticmethod
|
||||
def multi_query_kv_attention(
|
||||
@@ -287,7 +247,7 @@ class KunlunOps:
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
**kargs
|
||||
**kargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
query: shape = [num_prompt_tokens, num_heads, head_size]
|
||||
@@ -297,16 +257,12 @@ class KunlunOps:
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
output = torch.empty_like(query)
|
||||
alibi_slopes = kargs.get("alibi_slopes", None)
|
||||
mask = kargs.get("mask", None)
|
||||
is_causal = kargs.get("is_causal", True)
|
||||
is_lvsl = kargs.get("is_lvsl", True)
|
||||
|
||||
B, T, Qh, Hd = query.shape
|
||||
KVh = key.size(2)
|
||||
if KVh != Qh:
|
||||
repeat = Qh // KVh
|
||||
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
||||
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
||||
value = value.repeat_interleave(repeat, dim=2)
|
||||
kunlun_ops.attention(
|
||||
q=query,
|
||||
@@ -321,80 +277,90 @@ class KunlunOps:
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def quant_fusedresidual_rmsnorm_op(x,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
scale_to_int,
|
||||
eps,
|
||||
dyn_scale: bool,
|
||||
type: int = 1):
|
||||
def quant_fusedresidual_rmsnorm_op(
|
||||
x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
|
||||
):
|
||||
"""Quantized fused residual layer normalization"""
|
||||
out = torch.empty_like(x, dtype=torch.int8)
|
||||
|
||||
if is_per_token_smooth_quant():
|
||||
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
|
||||
out_scale = torch.empty(
|
||||
x.shape[:-1], device=x.device, dtype=torch.float
|
||||
).unsqueeze(-1)
|
||||
else:
|
||||
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
||||
|
||||
kunlun_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
|
||||
out=out, out_scale=out_scale , residual_tensor=residual)
|
||||
kunlun_ops.quant_fusedresidual_rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
out=out,
|
||||
out_scale=out_scale,
|
||||
residual_tensor=residual,
|
||||
)
|
||||
|
||||
if residual is None:
|
||||
return out, out_scale
|
||||
return out, out_scale, residual
|
||||
|
||||
@staticmethod
|
||||
def quant_rmsnorm_op(x,
|
||||
weight,
|
||||
bias,
|
||||
scale_to_int,
|
||||
eps,
|
||||
dyn_scale : bool,
|
||||
type: int = 1):
|
||||
def quant_rmsnorm_op(
|
||||
x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
|
||||
):
|
||||
"""Quantized RMSNorm"""
|
||||
|
||||
out = torch.empty_like(x, dtype=torch.int8)
|
||||
if is_per_token_smooth_quant():
|
||||
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
|
||||
out_scale = torch.empty(
|
||||
x.shape[:-1], device=x.device, dtype=torch.float
|
||||
).unsqueeze(-1)
|
||||
else:
|
||||
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
||||
|
||||
kunlun_ops.quant_rmsnorm(x, weight, bias, eps,
|
||||
out=out, out_scale=out_scale)
|
||||
kunlun_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
|
||||
return out, out_scale
|
||||
|
||||
@staticmethod
|
||||
def smooth_quant_matmul_column_row_kernels(input_tensor,
|
||||
weight,
|
||||
smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
perTokenScaling,
|
||||
perChannelScaling,
|
||||
otype):
|
||||
def smooth_quant_matmul_column_row_kernels(
|
||||
input_tensor,
|
||||
weight,
|
||||
smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
perTokenScaling,
|
||||
perChannelScaling,
|
||||
otype,
|
||||
):
|
||||
"""smooth_quant_matmul_column_row_kernels"""
|
||||
input_shape = input_tensor.shape
|
||||
weight_shape = weight.shape
|
||||
if input_tensor.dim() == 3:
|
||||
input_tensor = input_tensor.reshape(-1, input_shape[-1])
|
||||
out = torch.empty((input_shape[0] * input_shape[1],
|
||||
weight_shape[0]),
|
||||
dtype=torch.float16,
|
||||
device=weight.device)
|
||||
out = torch.empty(
|
||||
(input_shape[0] * input_shape[1], weight_shape[0]),
|
||||
dtype=torch.float16,
|
||||
device=weight.device,
|
||||
)
|
||||
output_bs_shape = [input_shape[0], input_shape[1]]
|
||||
elif input_tensor.dim() == 2:
|
||||
out = torch.empty((input_shape[0], weight_shape[0]),
|
||||
dtype=torch.float16,
|
||||
device=weight.device)
|
||||
out = torch.empty(
|
||||
(input_shape[0], weight_shape[0]),
|
||||
dtype=torch.float16,
|
||||
device=weight.device,
|
||||
)
|
||||
output_bs_shape = [-1]
|
||||
kunlun_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
|
||||
weight, smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
perTokenScaling,
|
||||
perChannelScaling,
|
||||
out=out)
|
||||
kunlun_ops.smooth_quant_matmul_column_row_kernels(
|
||||
input_tensor,
|
||||
weight,
|
||||
smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
perTokenScaling,
|
||||
perChannelScaling,
|
||||
out=out,
|
||||
)
|
||||
|
||||
out = out.view(*output_bs_shape, weight_shape[0])
|
||||
|
||||
@@ -404,6 +370,7 @@ class KunlunOps:
|
||||
if torch.is_tensor(x):
|
||||
return (type(x), x.device, x.dtype, x.shape, x.is_contiguous())
|
||||
return (type(x), x)
|
||||
|
||||
@staticmethod
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -420,23 +387,24 @@ class KunlunOps:
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""fused_moe"""
|
||||
global_num_experts, up_gate_size, _ = w1.shape
|
||||
M, N = hidden_states.shape
|
||||
hidden_dim = w2.shape[1]
|
||||
normed_score = torch.empty(M,
|
||||
moe_top_k,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = torch.empty(M,
|
||||
moe_top_k,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
normed_score = torch.empty(
|
||||
M, moe_top_k, dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
topk_ids = torch.empty(
|
||||
M, moe_top_k, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
num_blocks = 12
|
||||
block_statistic = torch.zeros(
|
||||
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
|
||||
num_blocks,
|
||||
global_num_experts,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
router_logits = router_logits.to(torch.float)
|
||||
if scoring_func == "softmax":
|
||||
@@ -445,24 +413,27 @@ class KunlunOps:
|
||||
normed_score=normed_score,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=None,
|
||||
stable=True)
|
||||
stable=False,
|
||||
)
|
||||
elif scoring_func == "sigmoid":
|
||||
torch.ops._C.moe_sigmoid_group_topk_norm(
|
||||
x=router_logits,
|
||||
topk_index=topk_ids,
|
||||
norm_score=normed_score,
|
||||
block_static=block_statistic,
|
||||
bias=e_score_correction_bias,
|
||||
scale=1.0,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
x=router_logits,
|
||||
topk_index=topk_ids,
|
||||
norm_score=normed_score,
|
||||
block_static=block_statistic,
|
||||
bias=e_score_correction_bias,
|
||||
scale=1.0,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
|
||||
if w1_bias is not None or w2_bias is not None:
|
||||
if w1_bias is not None or w2_bias is not None:
|
||||
# Rignt now this branch is for gpt oss
|
||||
# TODO (@xyDong23): faster here using moe_fc kernel
|
||||
normed_score = normed_score.to(hidden_states.dtype)
|
||||
out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
out = torch.zeros(
|
||||
M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
|
||||
topk_ids_flat = topk_ids.flatten()
|
||||
for i in range(global_num_experts):
|
||||
@@ -470,9 +441,13 @@ class KunlunOps:
|
||||
selected_token = topk_ids_flat == experts_id
|
||||
if selected_token.sum():
|
||||
cur_token = repeat_x[selected_token]
|
||||
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
|
||||
dtype=cur_token.dtype, device=cur_token.device)
|
||||
groupgemm1 = cur_token@ w1[i].T
|
||||
up_gate = torch.empty(
|
||||
selected_token.sum(),
|
||||
up_gate_size // 2,
|
||||
dtype=cur_token.dtype,
|
||||
device=cur_token.device,
|
||||
)
|
||||
groupgemm1 = cur_token @ w1[i].T
|
||||
# Add w13 bias
|
||||
if w1_bias is not None:
|
||||
groupgemm1 = groupgemm1 + w1_bias[i]
|
||||
@@ -482,53 +457,129 @@ class KunlunOps:
|
||||
if w2_bias is not None:
|
||||
groupgemm2 = groupgemm2 + w2_bias[i]
|
||||
out[selected_token] = groupgemm2
|
||||
ouput = (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)).sum(dim=1).to(hidden_states.dtype)
|
||||
ouput = (
|
||||
(out.view(M, moe_top_k, N) * normed_score.unsqueeze(2))
|
||||
.sum(dim=1)
|
||||
.to(hidden_states.dtype)
|
||||
)
|
||||
return ouput
|
||||
else:
|
||||
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
|
||||
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
||||
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
||||
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
||||
# from vllm.forward_context import get_forward_context
|
||||
# forward_context = get_forward_context()
|
||||
# attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
# prefix = "model.layers.0.linear_attn"
|
||||
# if attn_metadata is not None:
|
||||
# attn_metadata = attn_metadata[prefix]
|
||||
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
x=hidden_states,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=block_statistic,
|
||||
moe_expand=moe_expand,
|
||||
moe_index=sorted_tokens_idx,
|
||||
expert_m=expert_m,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
||||
# if attn_metadata is None or attn_metadata.num_prefills > 0 or :
|
||||
if M * moe_top_k < 400:
|
||||
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
|
||||
torch.ops.xspeedgate_ops.moe_pre_small(
|
||||
topk_ids, global_num_experts, False, False, hidden_states
|
||||
)
|
||||
)
|
||||
experts_num_lod = torch.ops.xspeedgate_ops.moe_active_expert_balance(
|
||||
topk_ids, global_num_experts, False
|
||||
)
|
||||
out = torch.ops.xspeedgate_ops.fused_moe(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
normed_score.to(hidden_states.dtype),
|
||||
sorted_tokens_num_lod,
|
||||
sorted_tokens_idx,
|
||||
experts_num_lod,
|
||||
)
|
||||
return out.sum(1)
|
||||
|
||||
y = torch.empty(M,moe_top_k,
|
||||
w1.shape[1],
|
||||
if M * moe_top_k > 768:
|
||||
moe_expand = torch.empty(
|
||||
(M * moe_top_k, N),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
device=hidden_states.device,
|
||||
) # [M*top_k, N], float
|
||||
expert_m = torch.zeros(
|
||||
global_num_experts, dtype=torch.int32, device=hidden_states.device
|
||||
) # [E]
|
||||
sorted_tokens_num_lod = torch.zeros(
|
||||
global_num_experts + 1,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
) # [E+1]
|
||||
sorted_tokens_idx = torch.zeros(
|
||||
M * moe_top_k, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
|
||||
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
x=hidden_states,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=block_statistic,
|
||||
moe_expand=moe_expand,
|
||||
moe_index=sorted_tokens_idx,
|
||||
expert_m=expert_m,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
)
|
||||
else:
|
||||
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
|
||||
torch.ops.xspeedgate_ops.moe_pre_small(
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
index_have_neg=False,
|
||||
sort_mode=True,
|
||||
x=hidden_states,
|
||||
)
|
||||
)
|
||||
|
||||
y = torch.empty(
|
||||
M,
|
||||
moe_top_k,
|
||||
w1.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=moe_expand,
|
||||
weight=w1,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=moe_top_k,
|
||||
y=y,
|
||||
if M < 1024:
|
||||
torch.ops._C.moe_fc(
|
||||
x=moe_expand,
|
||||
weight=w1,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=moe_top_k,
|
||||
y=y,
|
||||
)
|
||||
|
||||
d = y.shape[-1] // 2
|
||||
output_shape = y.shape[:-1] + (d,)
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.silu_and_mul(out1, y)
|
||||
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
else:
|
||||
torch.ops._C.moe_fc(
|
||||
x=moe_expand,
|
||||
weight=w1,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=moe_top_k,
|
||||
y=y,
|
||||
act="SWISH_GLU",
|
||||
)
|
||||
|
||||
y = y[..., : y.shape[-1] // 2]
|
||||
out1 = y.reshape(-1, y.shape[-1])
|
||||
|
||||
out = torch.empty(
|
||||
M,
|
||||
moe_top_k,
|
||||
w2.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
d = y.shape[-1] // 2
|
||||
output_shape = (y.shape[:-1] + (d, ))
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.silu_and_mul(out1, y)
|
||||
|
||||
out = torch.empty(M,moe_top_k,
|
||||
w2.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=out1,
|
||||
weight=w2,
|
||||
@@ -538,8 +589,12 @@ class KunlunOps:
|
||||
y=out,
|
||||
)
|
||||
|
||||
dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
|
||||
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
dequant_scale = torch.ones(
|
||||
[M, moe_top_k], dtype=torch.float32, device=out.device
|
||||
)
|
||||
output = torch.empty(
|
||||
[M, N], dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
|
||||
|
||||
torch.ops._C.moe_post(
|
||||
@@ -547,9 +602,9 @@ class KunlunOps:
|
||||
moe_index=sorted_tokens_idx,
|
||||
normed_scale=normed_score,
|
||||
dequant_scale=dequant_scale,
|
||||
y=output
|
||||
y=output,
|
||||
)
|
||||
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@@ -568,23 +623,23 @@ class KunlunOps:
|
||||
topk_group: Optional[int] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> torch.Tensor:
|
||||
x = hidden_states
|
||||
batch, hidden_size = x.shape
|
||||
batch, hidden_size = x.shape
|
||||
num_local_experts, up_gate_size, _ = w13_weight.shape
|
||||
|
||||
router_logits = x.to(linear_weights.dtype)@linear_weights.T
|
||||
|
||||
topk_weights = torch.empty(batch,
|
||||
top_k,
|
||||
dtype=router_logits.dtype,
|
||||
device=router_logits.device)
|
||||
topk_ids = torch.empty(batch,
|
||||
top_k,
|
||||
dtype=torch.int32,
|
||||
device=router_logits.device)
|
||||
block_static = torch.empty(0, dtype=torch.int32,device=router_logits.device)
|
||||
torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids, block_static)
|
||||
router_logits = x.to(linear_weights.dtype) @ linear_weights.T
|
||||
|
||||
topk_weights = torch.empty(
|
||||
batch, top_k, dtype=router_logits.dtype, device=router_logits.device
|
||||
)
|
||||
topk_ids = torch.empty(
|
||||
batch, top_k, dtype=torch.int32, device=router_logits.device
|
||||
)
|
||||
block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
|
||||
torch.ops._C.moe_softmax_topk(
|
||||
router_logits, topk_weights, topk_ids, block_static
|
||||
)
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
|
||||
@@ -598,11 +653,19 @@ class KunlunOps:
|
||||
selected_token = topk_ids_flat == experts_id
|
||||
if selected_token.sum():
|
||||
cur_token = repeat_x[selected_token]
|
||||
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
|
||||
dtype=cur_token.dtype, device=cur_token.device)
|
||||
torch.ops._C.silu_and_mul(up_gate, cur_token@ w13_weight[i].T)
|
||||
up_gate = torch.empty(
|
||||
selected_token.sum(),
|
||||
up_gate_size // 2,
|
||||
dtype=cur_token.dtype,
|
||||
device=cur_token.device,
|
||||
)
|
||||
torch.ops._C.silu_and_mul(up_gate, cur_token @ w13_weight[i].T)
|
||||
out[selected_token] = up_gate @ w2_weight[i].T
|
||||
output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype)
|
||||
output = (
|
||||
(out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
|
||||
.sum(dim=1)
|
||||
.to(x.dtype)
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@@ -638,10 +701,11 @@ class KunlunOps:
|
||||
prompt_lods_cpu: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
) -> torch.Tensor:
|
||||
"""mla pa block"""
|
||||
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
output = torch.empty(
|
||||
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
kunlun_ops.xft_multi_head_latent_page_attention_block(
|
||||
hidden_states,
|
||||
q_lora_rank,
|
||||
@@ -679,7 +743,6 @@ class KunlunOps:
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def fused_gdn_gating(
|
||||
A_log: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
@@ -695,25 +758,34 @@ class KunlunOps:
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
h0_source: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
use_qk_l2norm_in_kernel: bool,
|
||||
cu_seqlens: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
'''
|
||||
Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 将做完sigmoid_gating和delta_rule_update融合在一起
|
||||
1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)。
|
||||
2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。
|
||||
'''
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
h0_source: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
use_qk_l2norm_in_kernel: bool,
|
||||
cu_seqlens: torch.Tensor = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 将做完sigmoid_gating和delta_rule_update融合在一起
|
||||
1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)。
|
||||
2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。
|
||||
"""
|
||||
|
||||
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
|
||||
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
|
||||
cu_seqlens)
|
||||
return (o, final_state)
|
||||
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale,
|
||||
h0_source,
|
||||
output_final_state,
|
||||
use_qk_l2norm_in_kernel,
|
||||
cu_seqlens,
|
||||
)
|
||||
return (o, final_state)
|
||||
|
||||
@@ -9,60 +9,196 @@
|
||||
# ruff: noqa: E501
|
||||
import warnings
|
||||
from typing import Optional
|
||||
import torch.nn.functional as F
|
||||
|
||||
import cocopod # noqa
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
||||
from .chunk_o import chunk_fwd_o
|
||||
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
||||
from .cumsum import chunk_local_cumsum
|
||||
from .index import prepare_chunk_indices, prepare_chunk_offsets
|
||||
from .l2norm import l2norm_fwd
|
||||
from .solve_tril import solve_tril
|
||||
from .utils import SUPPRESS_LEVEL, input_guard
|
||||
from .wy_fast import recompute_w_u_fwd
|
||||
from .index import prepare_chunk_indices
|
||||
import xspeedgate_ops
|
||||
import cocopod
|
||||
|
||||
|
||||
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
|
||||
chunk_size=64
|
||||
A = -A.transpose(1,2)
|
||||
def torch_solve_tril(
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
output_dtype: torch.dtype = torch.float,
|
||||
):
|
||||
chunk_size = 64
|
||||
A = -A.transpose(1, 2)
|
||||
sequence_length = A.shape[-2]
|
||||
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||
A = F.pad(A, (0, 0, 0, pad_size))
|
||||
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
|
||||
# mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0)
|
||||
|
||||
# A = A.masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = A[..., i, :i].clone()
|
||||
sub = A[..., :i, :i].clone()
|
||||
A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
|
||||
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[:,:,:sequence_length,:].transpose(1,2)
|
||||
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[
|
||||
:, :, :sequence_length, :
|
||||
].transpose(1, 2)
|
||||
|
||||
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
A = chunk_scaled_dot_kkt_fwd(k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
output_dtype=q.dtype)
|
||||
|
||||
#kernel版
|
||||
torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, 64) if cu_seqlens is not None else None
|
||||
def recompute_w_u_fwd_torch(
|
||||
k: torch.Tensor, # [B, T, H, K]
|
||||
v: torch.Tensor, # [B, T, H, V]
|
||||
beta: torch.Tensor, # [B, T, H]
|
||||
g: torch.Tensor, # [B, T, H]
|
||||
A: torch.Tensor, # [B, H, T, T]
|
||||
):
|
||||
"""
|
||||
最简单版本:假设等长序列,key和value头数相同
|
||||
"""
|
||||
chunk_size = 64
|
||||
num_v_heads, num_k_heads = v.shape[2], k.shape[2]
|
||||
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
||||
k, v, beta, g, A = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A)
|
||||
]
|
||||
|
||||
batch_size, num_heads, sequence_length, k_head_dim = k.shape
|
||||
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||
k = F.pad(k, (0, 0, 0, pad_size))
|
||||
v = F.pad(v, (0, 0, 0, pad_size))
|
||||
beta = F.pad(beta, (0, pad_size))
|
||||
g = F.pad(g, (0, pad_size))
|
||||
A = F.pad(A, (0, 0, 0, pad_size))
|
||||
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
|
||||
|
||||
v_beta = v * beta.unsqueeze(-1)
|
||||
k_beta = k * beta.unsqueeze(-1)
|
||||
|
||||
k, v, k_beta, v_beta = [
|
||||
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
|
||||
for x in (k, v, k_beta, v_beta)
|
||||
]
|
||||
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||
|
||||
u = A @ v_beta
|
||||
w = A @ (k_beta * g.exp().unsqueeze(-1))
|
||||
w = (
|
||||
w.reshape(w.shape[0], w.shape[1], -1, w.shape[-1])[:, :, :sequence_length, :]
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
u = (
|
||||
u.reshape(u.shape[0], u.shape[1], -1, u.shape[-1])[:, :, :sequence_length, :]
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
return w, u
|
||||
|
||||
|
||||
def split_by_value(tensor, chunk_size=64):
|
||||
indices = tensor.tolist()
|
||||
result = set(indices) # 使用集合避免重复
|
||||
|
||||
for i in range(len(indices) - 1):
|
||||
start = indices[i]
|
||||
end = indices[i + 1]
|
||||
|
||||
# 计算第一个对齐边界
|
||||
# 我们要找的是 start + n*chunk_size,其中n是使结果大于start的最小整数
|
||||
first_boundary = start + chunk_size
|
||||
|
||||
# 在(start, end)范围内插入所有对齐边界
|
||||
boundary = first_boundary
|
||||
while boundary < end:
|
||||
result.add(boundary)
|
||||
boundary += chunk_size
|
||||
|
||||
return torch.tensor(sorted(result), dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
chunk_size = 64
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, 64) if cu_seqlens is not None else None
|
||||
)
|
||||
chunk_offsets = (
|
||||
prepare_chunk_offsets(cu_seqlens, chunk_size)
|
||||
if cu_seqlens is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# !
|
||||
# g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
g = torch.ops.xspeedgate_ops.chunk_local_cumsum(
|
||||
g,
|
||||
chunk_size=64,
|
||||
reverse=False,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
head_first=False,
|
||||
)
|
||||
|
||||
# !
|
||||
# A = chunk_scaled_dot_kkt_fwd(k=k,
|
||||
# beta=beta,
|
||||
# g_cumsum=g,
|
||||
# cu_seqlens=cu_seqlens,
|
||||
# output_dtype=q.dtype)
|
||||
A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd(
|
||||
k, beta, g, cu_seqlens, chunk_indices, chunk_size
|
||||
)
|
||||
|
||||
# torch版
|
||||
# if get_tensor_model_parallel_rank() == 0:
|
||||
# torch.save(A, "A_in")
|
||||
# torch.save(cu_seqlens, "cu_seqlens")
|
||||
# A2 = A.clone()
|
||||
torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size)
|
||||
|
||||
# !
|
||||
# torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
|
||||
# if get_tensor_model_parallel_rank() == 0:
|
||||
# err = torch.max(torch.abs(A - A2))
|
||||
# print("err", err)
|
||||
# if err > 1e-3:
|
||||
# raise
|
||||
# A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
||||
# for i in range(len(cu_seqlens)-1):
|
||||
# A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||
# A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)
|
||||
|
||||
"""
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
u = torch.empty_like(v)
|
||||
w = k.new_empty(B, T, H, K)
|
||||
for i in range(len(cu_seqlens)-1):
|
||||
k_i = k[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||
v_i = v[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||
beta_i = beta[:, cu_seqlens[i]:cu_seqlens[i+1], :]
|
||||
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||
g_i = g[:, cu_seqlens[i]:cu_seqlens[i+1], :]
|
||||
|
||||
w_i, u_i = recompute_w_u_fwd_torch(
|
||||
k=k_i,
|
||||
v=v_i,
|
||||
beta=beta_i,
|
||||
A=A_i,
|
||||
g=g_i,
|
||||
)
|
||||
w[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = w_i
|
||||
u[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = u_i
|
||||
"""
|
||||
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
|
||||
k=k,
|
||||
v=v,
|
||||
@@ -71,17 +207,63 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_size=64
|
||||
chunk_size=64,
|
||||
)
|
||||
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||
"""
|
||||
w, u = recompute_w_u_fwd(
|
||||
k=k,
|
||||
w=w,
|
||||
u=u,
|
||||
g=g,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
v=v,
|
||||
beta=beta,
|
||||
A=A,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
"""
|
||||
|
||||
# i
|
||||
# import os
|
||||
# if not os.path.exists("/qwen-next/in"):
|
||||
# os.makedirs("/qwen-next/in")
|
||||
# torch.save(k, "/qwen-next/in/k.pt")
|
||||
# torch.save(u, "/qwen-next/in/u.pt")
|
||||
# torch.save(w, "/qwen-next/in/w.pt")
|
||||
# torch.save(g, "/qwen-next/in/g.pt")
|
||||
# torch.save(initial_state, "/qwen-next/in/initial_state.pt")
|
||||
# torch.save(cu_seqlens, "/qwen-next/in/cu_seqlens.pt")
|
||||
# torch.save(chunk_indices, "/qwen-next/in/chunk_indices.pt")
|
||||
# torch.save(chunk_offsets.to(torch.int32), "/qwen-next/in/chunk_offsets.pt")
|
||||
# torch.save(chunk_size, "/qwen-next/in/chunk_size.pt")
|
||||
# torch.save(output_final_state, "/qwen-next/in/output_final_state.pt")
|
||||
|
||||
h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h(
|
||||
k,
|
||||
u,
|
||||
w,
|
||||
g,
|
||||
initial_state,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
chunk_offsets.to(torch.int32),
|
||||
chunk_size,
|
||||
output_final_state,
|
||||
True,
|
||||
)
|
||||
|
||||
# h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||
# k=k,
|
||||
# w=w,
|
||||
# u=u,
|
||||
# g=g,
|
||||
# initial_state=initial_state,
|
||||
# output_final_state=output_final_state,
|
||||
# cu_seqlens=cu_seqlens,
|
||||
# )
|
||||
# if not os.path.exists("/qwen-next/out"):
|
||||
# os.makedirs("/qwen-next/out")
|
||||
# torch.save(h, "/qwen-next/out/h.pt")
|
||||
# torch.save(v_new, "/qwen-next/out/v_new.pt")
|
||||
# torch.save(final_state, "/qwen-next/out/final_state.pt")
|
||||
|
||||
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
|
||||
q=q,
|
||||
k=k,
|
||||
@@ -91,8 +273,19 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_size=64
|
||||
chunk_size=64,
|
||||
)
|
||||
"""
|
||||
o = chunk_fwd_o(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_new,
|
||||
h=h,
|
||||
g=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
"""
|
||||
if SUPPRESS_LEVEL < 3:
|
||||
return g, o, A, final_state, None, None, None
|
||||
elif SUPPRESS_LEVEL >= 3:
|
||||
@@ -103,18 +296,20 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
@torch.amp.custom_fwd(device_type='cuda')
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
@torch.amp.custom_fwd(device_type="cuda")
|
||||
def forward(
|
||||
ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q = l2norm_fwd(q)
|
||||
k = l2norm_fwd(k)
|
||||
@@ -136,17 +331,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
@torch.compiler.disable
|
||||
def chunk_gated_delta_rule(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
def chunk_gated_delta_rule(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
@@ -211,42 +408,85 @@ def chunk_gated_delta_rule(q: torch.Tensor,
|
||||
)
|
||||
"""
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
||||
assert len(
|
||||
beta.shape
|
||||
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||
assert (
|
||||
q.dtype != torch.float32
|
||||
), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
||||
assert (
|
||||
len(beta.shape) == 3
|
||||
), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||
|
||||
if head_first:
|
||||
raise DeprecationWarning(
|
||||
"head_first is deprecated and will be removed in a future version. "
|
||||
"Please use head_first=False for now instead.",
|
||||
stacklevel=2)
|
||||
stacklevel=2,
|
||||
)
|
||||
q, k, v, beta, g = map(
|
||||
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
|
||||
(q, k, v, beta, g))
|
||||
lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
|
||||
)
|
||||
if not head_first and q.shape[1] < q.shape[2]:
|
||||
warnings.warn(
|
||||
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
|
||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
"when head_first=False was specified. "
|
||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||
stacklevel=2)
|
||||
stacklevel=2,
|
||||
)
|
||||
if cu_seqlens is not None:
|
||||
if q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
if initial_state is not None and initial_state.shape[0] != len(
|
||||
cu_seqlens) - 1:
|
||||
f"Please flatten variable-length inputs before processing."
|
||||
)
|
||||
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
|
||||
raise ValueError(
|
||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
|
||||
use_qk_l2norm_in_kernel)
|
||||
scale = k.shape[-1] ** -0.5
|
||||
|
||||
if False:
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
g = g.contiguous()
|
||||
beta = beta.contiguous()
|
||||
initial_state = initial_state.contiguous()
|
||||
|
||||
o = torch.empty_like(v)
|
||||
final_state = torch.empty_like(initial_state)
|
||||
import kunlun_ops
|
||||
|
||||
kunlun_ops.gated_delta_rule(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
initial_state,
|
||||
g,
|
||||
beta,
|
||||
final_state,
|
||||
o,
|
||||
scale,
|
||||
cu_seqlens.cpu(),
|
||||
cu_seqlens,
|
||||
cu_seqlens.cpu(),
|
||||
cu_seqlens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
else:
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale,
|
||||
initial_state,
|
||||
output_final_state,
|
||||
cu_seqlens,
|
||||
use_qk_l2norm_in_kernel,
|
||||
)
|
||||
if head_first:
|
||||
o = rearrange(o, 'b t h ... -> b h t ...')
|
||||
o = rearrange(o, "b t h ... -> b h t ...")
|
||||
return o, final_state
|
||||
|
||||
@@ -12,21 +12,21 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .op import exp
|
||||
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
|
||||
|
||||
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
|
||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_G': lambda args: args['g'] is not None,
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
||||
})
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({
|
||||
@@ -40,7 +40,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
|
||||
# ],
|
||||
# key=['H', 'K', 'V', 'BT'],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
@@ -67,10 +67,12 @@ def chunk_fwd_kernel_o(
|
||||
|
||||
if IS_VARLEN:
|
||||
i_tg = i_t
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
else:
|
||||
@@ -89,12 +91,15 @@ def chunk_fwd_kernel_o(
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
|
||||
(BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT),
|
||||
(BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV),
|
||||
(BK, BV), (1, 0))
|
||||
p_q = tl.make_block_ptr(
|
||||
q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
|
||||
)
|
||||
p_k = tl.make_block_ptr(
|
||||
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
|
||||
)
|
||||
p_h = tl.make_block_ptr(
|
||||
h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
|
||||
)
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
@@ -109,8 +114,8 @@ def chunk_fwd_kernel_o(
|
||||
|
||||
if USE_G:
|
||||
g += bos * H + i_h
|
||||
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
b_g = tl.load(p_g, boundary_check=(0,))
|
||||
b_o = b_o * tl.exp(b_g)[:, None]
|
||||
b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :])
|
||||
|
||||
@@ -120,10 +125,12 @@ def chunk_fwd_kernel_o(
|
||||
# b_A = tl.where(m_A, b_A, 0)
|
||||
b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0)
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
p_v = tl.make_block_ptr(
|
||||
v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
|
||||
)
|
||||
p_o = tl.make_block_ptr(
|
||||
o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
|
||||
)
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
# to fix mma -> mma layout conversion
|
||||
@@ -133,48 +140,29 @@ def chunk_fwd_kernel_o(
|
||||
|
||||
|
||||
def chunk_fwd_o(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None, # cumsum of log decay
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None, # cumsum of log decay
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
) -> torch.Tensor:
|
||||
_, T, _, _, _ = *q.shape, v.shape[-1]
|
||||
if FLA_GDN_FIX_BT:
|
||||
BT = 64
|
||||
else:
|
||||
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
scale = k.shape[-1] ** -0.5
|
||||
|
||||
o = torch.empty_like(v)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta['BV']), NT, B * H)
|
||||
|
||||
chunk_fwd_kernel_o[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
scale,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=64,
|
||||
BV=32
|
||||
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
|
||||
q, k, v, h, g, scale, cu_seqlens, chunk_indices, chunk_size
|
||||
)
|
||||
return o
|
||||
|
||||
@@ -9,28 +9,28 @@
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import kunlun_ops
|
||||
import torch
|
||||
|
||||
|
||||
class FusedRecurrentFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
|
||||
def forward(
|
||||
ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2(
|
||||
q.contiguous(),
|
||||
k.contiguous(),
|
||||
@@ -44,7 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
|
||||
h0_indices=ssm_state_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
||||
is_h0_transposed=True
|
||||
is_h0_transposed=True,
|
||||
)
|
||||
return o, final_state
|
||||
|
||||
@@ -130,9 +130,10 @@ def fused_recurrent_gated_delta_rule(
|
||||
if cu_seqlens is not None and q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
f"Please flatten variable-length inputs before processing."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
scale = k.shape[-1] ** -0.5
|
||||
else:
|
||||
assert scale > 0, "scale must be positive"
|
||||
if beta is None:
|
||||
|
||||
@@ -10,22 +10,21 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import kunlun_ops
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
import kunlun_ops
|
||||
|
||||
|
||||
BT_LIST = [8, 16, 32, 64, 128]
|
||||
|
||||
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
|
||||
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({}, num_warps=num_warps)
|
||||
for num_warps in [1, 2, 4, 8, 16, 32]
|
||||
],
|
||||
key=['D'])
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
|
||||
],
|
||||
key=["D"],
|
||||
)
|
||||
@triton.jit
|
||||
def l2norm_fwd_kernel1(
|
||||
x,
|
||||
@@ -49,11 +48,14 @@ def l2norm_fwd_kernel1(
|
||||
tl.store(y + cols, b_y, mask=mask)
|
||||
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({'BT': BT}, num_warps=num_warps)
|
||||
for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
|
||||
],
|
||||
key=['D'])
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BT": BT}, num_warps=num_warps)
|
||||
for num_warps in [1, 2, 4, 8, 16]
|
||||
for BT in BT_LIST
|
||||
],
|
||||
key=["D"],
|
||||
)
|
||||
@triton.jit(do_not_specialize=["NB"])
|
||||
def l2norm_fwd_kernel(
|
||||
x,
|
||||
@@ -87,67 +89,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
|
||||
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
|
||||
|
||||
|
||||
def l2norm_fwd_triton(x: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
output_dtype: Optional[torch.dtype] = None):
|
||||
x_shape_og = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
# allocate output
|
||||
if output_dtype is None:
|
||||
y = torch.empty_like(x)
|
||||
else:
|
||||
y = torch.empty_like(x, dtype=output_dtype)
|
||||
assert y.stride(-1) == 1
|
||||
T, D = x.shape[0], x.shape[-1]
|
||||
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
|
||||
if D > BD:
|
||||
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
|
||||
|
||||
if not USE_DEFAULT_FLA_NORM:
|
||||
MBLOCK = 32
|
||||
# M, N = x.shape
|
||||
l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
T,
|
||||
D,
|
||||
MBLOCK,
|
||||
)
|
||||
else:
|
||||
if D <= 512:
|
||||
NB = triton.cdiv(T, 2048)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(T, meta['BT']), )
|
||||
|
||||
l2norm_fwd_kernel[grid](
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
NB=NB,
|
||||
T=T,
|
||||
D=D,
|
||||
BD=BD,
|
||||
)
|
||||
else:
|
||||
l2norm_fwd_kernel1[(T, )](
|
||||
x,
|
||||
y,
|
||||
eps=eps,
|
||||
D=D,
|
||||
BD=BD,
|
||||
)
|
||||
|
||||
return y.view(x_shape_og)
|
||||
|
||||
|
||||
def l2norm_fwd(x: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
output_dtype: Optional[torch.dtype] = None):
|
||||
def l2norm_fwd(
|
||||
x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
|
||||
):
|
||||
out = torch.empty_like(x)
|
||||
kunlun_ops.l2norm(x, out, eps)
|
||||
kunlun_ops.l2norm(x, out, eps)
|
||||
return out
|
||||
|
||||
@@ -19,20 +19,21 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .utils import input_guard
|
||||
|
||||
|
||||
def rms_norm_ref(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
upcast=True):
|
||||
def rms_norm_ref(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
upcast=True,
|
||||
):
|
||||
dtype = x.dtype
|
||||
weight = weight.float()
|
||||
bias = bias.float() if bias is not None else None
|
||||
@@ -43,12 +44,10 @@ def rms_norm_ref(x,
|
||||
x = x * F.silu(z)
|
||||
if group_size is None:
|
||||
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
|
||||
weight)
|
||||
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
|
||||
else:
|
||||
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
|
||||
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
|
||||
eps)
|
||||
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
@@ -57,10 +56,12 @@ def rms_norm_ref(x,
|
||||
return out.to(dtype)
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"HAS_BIAS": lambda args: args["B"] is not None,
|
||||
"HAS_Z": lambda args: args["Z"] is not None,
|
||||
})
|
||||
@triton.heuristics(
|
||||
{
|
||||
"HAS_BIAS": lambda args: args["B"] is not None,
|
||||
"HAS_Z": lambda args: args["Z"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def layer_norm_fwd_kernel(
|
||||
X, # pointer to the input
|
||||
@@ -97,17 +98,17 @@ def layer_norm_fwd_kernel(
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.)
|
||||
xbar = tl.where(cols < N, x - mean, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.)
|
||||
xbar = tl.where(cols < N, x, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
@@ -149,46 +150,50 @@ def layer_norm_fwd(
|
||||
# weight = weight.reshape(N)
|
||||
# print("weight",weight.shape)
|
||||
# print("x",x.shape)
|
||||
assert weight.shape == (N, )
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N, )
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
|
||||
device=x.device) if not is_rms_norm else None
|
||||
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
mean = (
|
||||
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
if not is_rms_norm
|
||||
else None
|
||||
)
|
||||
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
layer_norm_fwd_kernel[grid](x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps)
|
||||
layer_norm_fwd_kernel[grid](
|
||||
x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
@@ -196,17 +201,18 @@ class LayerNormFn(torch.autograd.Function):
|
||||
|
||||
@input_guard
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
def forward(
|
||||
ctx,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
@@ -223,16 +229,15 @@ class LayerNormFn(torch.autograd.Function):
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, mean, rstd = layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
# y, mean, rstd = torch.ops.xspeedgate_ops.rms_norm_gated_fwd(x, weight, bias, eps, z, group_size, norm_before_gate, is_rms_norm)
|
||||
y = torch.empty_like(x)
|
||||
mean, rstd = None, None
|
||||
import kunlun_ops
|
||||
|
||||
kunlun_ops.rms_norm_gated(
|
||||
x, y, z, weight, eps, group_size, norm_before_gate, is_rms_norm
|
||||
)
|
||||
|
||||
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
|
||||
ctx.x_shape_og = x_shape_og
|
||||
ctx.eps = eps
|
||||
@@ -242,27 +247,27 @@ class LayerNormFn(torch.autograd.Function):
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
def layernorm_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False):
|
||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
|
||||
norm_before_gate, is_rms_norm)
|
||||
def layernorm_fn(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
return LayerNormFn.apply(
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
|
||||
)
|
||||
|
||||
|
||||
def rmsnorm_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True):
|
||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
|
||||
norm_before_gate, True)
|
||||
def rmsnorm_fn(
|
||||
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
|
||||
):
|
||||
return LayerNormFn.apply(
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, True
|
||||
)
|
||||
|
||||
|
||||
class LayerNormGated(nn.Module):
|
||||
@@ -294,15 +299,16 @@ class LayerNormGated(nn.Module):
|
||||
torch.nn.init.zeros_(self.bias)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
return layernorm_fn(x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
group_size=self.group_size,
|
||||
eps=self.eps,
|
||||
norm_before_gate=self.norm_before_gate)
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
return layernorm_fn(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
group_size=self.group_size,
|
||||
eps=self.eps,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
)
|
||||
|
||||
|
||||
class RMSNormGated(nn.Module):
|
||||
@@ -332,12 +338,13 @@ class RMSNormGated(nn.Module):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
return rmsnorm_fn(x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate)
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
return rmsnorm_fn(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
@@ -28,6 +27,7 @@ RESOLUTION = {
|
||||
torch.complex64: 1.3e-6,
|
||||
}
|
||||
|
||||
|
||||
def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
|
||||
assert res.dtype == dtype
|
||||
ref = ref.to(dtype)
|
||||
@@ -35,6 +35,7 @@ def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
|
||||
rtol = RESOLUTION[dtype]
|
||||
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
@@ -80,7 +81,6 @@ def recompute_u_fwd_kernel(
|
||||
p_beta = tl.make_block_ptr(
|
||||
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
p_A = tl.make_block_ptr(
|
||||
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
||||
)
|
||||
@@ -110,7 +110,6 @@ def recompute_u_fwd_kernel(
|
||||
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
@@ -195,53 +194,12 @@ def recompute_w_u_fwd(
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.LongTensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
BT = A.shape[-1]
|
||||
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
BK = 64
|
||||
BV = 64
|
||||
u = torch.empty_like(v)
|
||||
w = k.new_empty(B, T, H, K)
|
||||
recompute_u_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
w=w,
|
||||
u=u,
|
||||
A=A,
|
||||
g=g_cumsum,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
recompute_w_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
w=w,
|
||||
u=u,
|
||||
A=A,
|
||||
g=g_cumsum,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
|
||||
k, v, beta, g_cumsum, A, cu_seqlens, chunk_indices, chunk_size=BT
|
||||
)
|
||||
return w, u
|
||||
return w, u
|
||||
|
||||
@@ -15,51 +15,52 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
|
||||
from vllm.model_executor.layers import layernorm
|
||||
from typing import Optional, Union
|
||||
import kunlun_ops
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers import layernorm
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
|
||||
def vllm_kunlun_forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""forward_cuda"""
|
||||
if x.is_contiguous() == False:
|
||||
# kunlun does not support uncontiguous input and they do not think it is a bug
|
||||
# so we must make it contiguous() manually
|
||||
x = x.contiguous()
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""forward_cuda"""
|
||||
if not x.is_contiguous():
|
||||
# kunlun does not support uncontiguous input and they do not think it is a bug
|
||||
# so we must make it contiguous() manually
|
||||
x = x.contiguous()
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
|
||||
if residual is not None:
|
||||
# residual_output = torch.empty_like(residual)
|
||||
torch.ops._C.add_rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
residual_output=residual,
|
||||
weight=self.weight.data,
|
||||
eps=self.variance_epsilon,
|
||||
output=x
|
||||
)
|
||||
return x, residual
|
||||
out = torch.empty_like(x)
|
||||
torch.ops._C.rmsnorm(
|
||||
if residual is not None:
|
||||
# residual_output = torch.empty_like(residual)
|
||||
torch.ops._C.add_rmsnorm(
|
||||
x,
|
||||
self.weight.data,
|
||||
out,
|
||||
self.variance_epsilon,
|
||||
residual,
|
||||
residual_output=residual,
|
||||
weight=self.weight.data,
|
||||
eps=self.variance_epsilon,
|
||||
output=x,
|
||||
)
|
||||
return out
|
||||
return x, residual
|
||||
out = torch.empty_like(x)
|
||||
torch.ops._C.rmsnorm(
|
||||
x,
|
||||
self.weight.data,
|
||||
out,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RMSNorm.forward = vllm_kunlun_forward_cuda
|
||||
|
||||
|
||||
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
||||
@staticmethod
|
||||
def forward_xpu(
|
||||
@@ -68,30 +69,42 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if x.is_contiguous() == False:
|
||||
if not x.is_contiguous():
|
||||
# kunlun does not support uncontiguous input and they do not think it is a bug
|
||||
# so we must make it contiguous() manually
|
||||
x = x.contiguous()
|
||||
|
||||
if x.dim() == 3:
|
||||
x_shape = x.shape
|
||||
x = x.view(-1, x.size(-1))
|
||||
if residual is not None:
|
||||
torch.ops._C.add_rmsnorm(
|
||||
out = torch.empty_like(x)
|
||||
out_residual = torch.empty_like(residual)
|
||||
torch.ops._C.gemma_add_rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
residual_output=residual,
|
||||
weight=weight+1,
|
||||
residual_output=out_residual,
|
||||
weight=weight,
|
||||
eps=variance_epsilon,
|
||||
output=x
|
||||
output=out,
|
||||
)
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
torch.ops._C.gemma_rmsnorm(
|
||||
x,
|
||||
weight,
|
||||
out,
|
||||
variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
|
||||
out = torch.empty_like(x)
|
||||
torch.ops._C.rmsnorm(
|
||||
x,
|
||||
weight+1,
|
||||
out,
|
||||
variance_epsilon,
|
||||
)
|
||||
return out
|
||||
if x.dim() == 3:
|
||||
x = x.view(x_shape)
|
||||
if out is not None:
|
||||
out = out.view(x_shape)
|
||||
|
||||
if residual is not None:
|
||||
return out, out_residual
|
||||
else:
|
||||
return out
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
@@ -99,16 +112,17 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if torch.compiler.is_compiling():
|
||||
self.forward_static = self.forward_xpu # only use in cudagraph
|
||||
self.forward_static = self.forward_xpu # only use in cudagraph
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
if not getattr(self, "_is_compiled", False):
|
||||
self.forward_static = torch.compile( # type: ignore
|
||||
self.forward_static, backend="aot_eager")
|
||||
self.forward_static, backend="aot_eager"
|
||||
)
|
||||
self._is_compiled = True
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
|
||||
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RMSNorm.forward = vllm_kunlun_forward_cuda
|
||||
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm
|
||||
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
390
vllm_kunlun/v1/attention/backends/gdn_attn.py
Normal file
390
vllm_kunlun/v1/attention/backends/gdn_attn.py
Normal file
@@ -0,0 +1,390 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Backend for GatedDeltaNet attention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends import gdn_attn
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class GDNAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
|
||||
return GDNAttentionMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class GDNAttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_spec_decodes: int
|
||||
num_spec_decode_tokens: int
|
||||
num_actual_tokens: int
|
||||
|
||||
has_initial_state: Optional[torch.Tensor] = None
|
||||
has_initial_state_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
spec_query_start_loc: Optional[torch.Tensor] = (
|
||||
None # shape: [num_spec_decodes + 1,]
|
||||
)
|
||||
non_spec_query_start_loc: Optional[torch.Tensor] = (
|
||||
None # shape: [batch - num_spec_decodes + 1,]
|
||||
)
|
||||
|
||||
spec_state_indices_tensor: Optional[torch.Tensor] = None # shape: [batch, num_spec]
|
||||
non_spec_state_indices_tensor: Optional[torch.Tensor] = (
|
||||
None # shape: [batch - num_spec_decodes,]
|
||||
)
|
||||
non_spec_state_indices_tensor_cpu: Optional[torch.Tensor] = None
|
||||
spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,]
|
||||
spec_token_masks: Optional[torch.Tensor] = (
|
||||
None # shape: [num_prefill_tokens + num_decode_tokens,]
|
||||
)
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,]
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: Optional[dict] = None
|
||||
batch_ptr: Optional[torch.Tensor] = None
|
||||
token_chunk_offset_ptr: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
|
||||
|
||||
cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
if self.speculative_config:
|
||||
self.num_spec = self.speculative_config.num_speculative_tokens # noqa: E501
|
||||
else:
|
||||
self.num_spec = 0
|
||||
self.use_spec_decode = self.num_spec > 0
|
||||
self._init_reorder_batch_threshold(1, self.use_spec_decode)
|
||||
|
||||
self.use_full_cuda_graph = (
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
)
|
||||
self.decode_cudagraph_max_bs = min(
|
||||
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
|
||||
self.compilation_config.max_capture_size,
|
||||
)
|
||||
|
||||
self.spec_state_indices_tensor = torch.empty(
|
||||
(self.decode_cudagraph_max_bs, self.num_spec + 1),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.non_spec_state_indices_tensor = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.spec_sequence_masks = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
self.spec_token_masks = torch.empty(
|
||||
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
self.spec_query_start_loc = torch.empty(
|
||||
(self.decode_cudagraph_max_bs + 1,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.non_spec_query_start_loc = torch.empty(
|
||||
(self.decode_cudagraph_max_bs + 1,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.num_accepted_tokens = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def build( # type: ignore[override]
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
|
||||
fast_build: bool = False,
|
||||
) -> GDNAttentionMetadata:
|
||||
m = common_attn_metadata
|
||||
|
||||
query_start_loc = m.query_start_loc
|
||||
context_lens = m.num_computed_tokens_cpu
|
||||
context_lens_tensor = context_lens.to(query_start_loc.device)
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
if (
|
||||
not self.use_spec_decode
|
||||
or num_decode_draft_tokens_cpu is None
|
||||
or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
|
||||
.sum()
|
||||
.item()
|
||||
== 0
|
||||
):
|
||||
spec_sequence_masks = None
|
||||
num_spec_decodes = 0
|
||||
else:
|
||||
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
|
||||
num_spec_decodes = spec_sequence_masks.sum().item()
|
||||
if num_spec_decodes == 0:
|
||||
spec_sequence_masks = None
|
||||
else:
|
||||
spec_sequence_masks = spec_sequence_masks.to(
|
||||
query_start_loc.device, non_blocking=True
|
||||
)
|
||||
|
||||
if spec_sequence_masks is None:
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(m, decode_threshold=1)
|
||||
)
|
||||
num_spec_decode_tokens = 0
|
||||
spec_token_masks = None
|
||||
spec_state_indices_tensor = None
|
||||
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
|
||||
spec_query_start_loc = None
|
||||
non_spec_query_start_loc = query_start_loc
|
||||
num_accepted_tokens = None
|
||||
else:
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
|
||||
non_spec_query_lens = query_lens[~spec_sequence_masks]
|
||||
num_decodes = (non_spec_query_lens == 1).sum().item()
|
||||
num_prefills = non_spec_query_lens.size(0) - num_decodes
|
||||
num_decode_tokens = num_decodes
|
||||
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
|
||||
|
||||
if num_prefills == 0 and num_decodes == 0:
|
||||
spec_token_masks = torch.ones(
|
||||
(
|
||||
min(
|
||||
num_spec_decodes * (self.num_spec + 1),
|
||||
query_start_loc[-1].item(),
|
||||
)
|
||||
),
|
||||
dtype=torch.bool,
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
|
||||
non_spec_state_indices_tensor = None
|
||||
spec_query_start_loc = query_start_loc
|
||||
non_spec_query_start_loc = None
|
||||
else:
|
||||
spec_token_masks = torch.repeat_interleave(
|
||||
spec_sequence_masks, query_lens
|
||||
)
|
||||
spec_state_indices_tensor = m.block_table_tensor[
|
||||
spec_sequence_masks, : self.num_spec + 1
|
||||
]
|
||||
non_spec_state_indices_tensor = m.block_table_tensor[
|
||||
~spec_sequence_masks, 0
|
||||
]
|
||||
|
||||
spec_query_start_loc = torch.zeros(
|
||||
num_spec_decodes + 1,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
torch.cumsum(
|
||||
query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
|
||||
)
|
||||
non_spec_query_start_loc = torch.zeros(
|
||||
query_lens.size(0) - num_spec_decodes + 1,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
torch.cumsum(
|
||||
query_lens[~spec_sequence_masks],
|
||||
dim=0,
|
||||
out=non_spec_query_start_loc[1:],
|
||||
)
|
||||
|
||||
num_spec_decode_tokens = (
|
||||
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
|
||||
)
|
||||
assert num_accepted_tokens is not None
|
||||
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
|
||||
|
||||
if num_prefills > 0:
|
||||
has_initial_state = context_lens_tensor > 0
|
||||
if spec_sequence_masks is not None:
|
||||
has_initial_state = has_initial_state[~spec_sequence_masks]
|
||||
has_initial_state_cpu = has_initial_state.cpu()
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(non_spec_query_start_loc)
|
||||
)
|
||||
else:
|
||||
has_initial_state = None
|
||||
has_initial_state_cpu = None
|
||||
num_actual_tokens = (
|
||||
num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
|
||||
)
|
||||
|
||||
# prepare tensors for cudagraph
|
||||
#
|
||||
# With speculative decoding, the xgrammar backend may rollback tokens
|
||||
# and causing some sequences has less draft tokens than self.num_spec.
|
||||
#
|
||||
# In above cases, the max possible batch size for n tokens, can be
|
||||
# min(n, cudagraph_max_bs).
|
||||
if (
|
||||
self.use_full_cuda_graph
|
||||
and num_prefills == 0
|
||||
and num_decodes == 0
|
||||
and num_spec_decodes <= self.decode_cudagraph_max_bs
|
||||
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
|
||||
):
|
||||
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
|
||||
batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
|
||||
|
||||
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
|
||||
spec_state_indices_tensor, non_blocking=True
|
||||
)
|
||||
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
|
||||
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
|
||||
|
||||
self.spec_sequence_masks[:num_spec_decodes].copy_(
|
||||
spec_sequence_masks, non_blocking=True
|
||||
)
|
||||
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
|
||||
spec_sequence_masks[num_spec_decodes:].fill_(False)
|
||||
|
||||
assert spec_token_masks is not None
|
||||
self.spec_token_masks[: spec_token_masks.size(0)].copy_(
|
||||
spec_token_masks, non_blocking=True
|
||||
)
|
||||
spec_token_masks = self.spec_token_masks[:num_actual_tokens]
|
||||
spec_token_masks[spec_token_masks.size(0) :].fill_(False)
|
||||
|
||||
self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
|
||||
spec_query_start_loc, non_blocking=True
|
||||
)
|
||||
spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index]
|
||||
spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
|
||||
spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)
|
||||
|
||||
self.num_accepted_tokens[:num_spec_decodes].copy_(
|
||||
num_accepted_tokens, non_blocking=True
|
||||
)
|
||||
num_accepted_tokens = self.num_accepted_tokens[:batch_size]
|
||||
num_accepted_tokens[num_spec_decodes:].fill_(1)
|
||||
|
||||
if (
|
||||
self.use_full_cuda_graph
|
||||
and num_prefills == 0
|
||||
and num_spec_decodes == 0
|
||||
and num_decodes <= self.decode_cudagraph_max_bs
|
||||
):
|
||||
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
|
||||
batch_size = num_actual_tokens
|
||||
|
||||
self.non_spec_state_indices_tensor[:num_decodes].copy_(
|
||||
non_spec_state_indices_tensor, non_blocking=True
|
||||
)
|
||||
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
|
||||
:batch_size
|
||||
]
|
||||
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
|
||||
|
||||
self.non_spec_query_start_loc[: num_decodes + 1].copy_(
|
||||
non_spec_query_start_loc, non_blocking=True
|
||||
)
|
||||
non_spec_num_query_tokens = non_spec_query_start_loc[
|
||||
-1
|
||||
] # type: ignore[index]
|
||||
non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
|
||||
non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)
|
||||
|
||||
if num_accepted_tokens is not None:
|
||||
num_accepted_tokens = num_accepted_tokens.to(torch.int32)
|
||||
attn_metadata = GDNAttentionMetadata(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_spec_decodes=num_spec_decodes,
|
||||
num_spec_decode_tokens=num_spec_decode_tokens,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
has_initial_state=has_initial_state,
|
||||
has_initial_state_cpu=has_initial_state_cpu,
|
||||
spec_query_start_loc=spec_query_start_loc,
|
||||
non_spec_query_start_loc=non_spec_query_start_loc,
|
||||
spec_state_indices_tensor=spec_state_indices_tensor,
|
||||
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
|
||||
non_spec_state_indices_tensor_cpu=(
|
||||
non_spec_state_indices_tensor.cpu()
|
||||
if non_spec_state_indices_tensor is not None
|
||||
else None
|
||||
),
|
||||
spec_sequence_masks=spec_sequence_masks,
|
||||
spec_token_masks=spec_token_masks,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
):
|
||||
"""
|
||||
This method builds the metadata for full cudagraph capture.
|
||||
Currently, only decode is supported for full cudagraphs with Mamba.
|
||||
"""
|
||||
m = common_attn_metadata
|
||||
|
||||
assert (
|
||||
m.num_reqs <= self.decode_cudagraph_max_bs
|
||||
and m.num_actual_tokens <= self.decode_cudagraph_max_bs
|
||||
), (
|
||||
f"GDN only supports decode-only full CUDAGraph capture. "
|
||||
f"Make sure batch size ({m.num_reqs}) <= "
|
||||
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
|
||||
f"and number of tokens ({m.num_actual_tokens}) <= "
|
||||
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})."
|
||||
)
|
||||
|
||||
num_accepted_tokens = torch.diff(m.query_start_loc)
|
||||
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
|
||||
m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
|
||||
|
||||
return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
|
||||
|
||||
|
||||
gdn_attn.GDNAttentionMetadata = GDNAttentionMetadata
|
||||
gdn_attn.GDNAttentionMetadataBuilder = GDNAttentionMetadataBuilder
|
||||
@@ -770,24 +770,14 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory
|
||||
value = value.contiguous()
|
||||
if key_cache.is_contiguous():
|
||||
kunlun_ops.reshape_and_cache(
|
||||
key[: attn_metadata.num_actual_tokens],
|
||||
value[: attn_metadata.num_actual_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
)
|
||||
else:
|
||||
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
|
||||
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
|
||||
kunlun_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
cast_key_cache,
|
||||
cast_value_cache,
|
||||
updated_slot_mapping,
|
||||
)
|
||||
kunlun_ops.reshape_and_cache_flash(
|
||||
key[: attn_metadata.num_actual_tokens],
|
||||
value[: attn_metadata.num_actual_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
BLHD_LAYOUT=False,
|
||||
)
|
||||
|
||||
assert attn_type == AttentionType.DECODER
|
||||
# Decoder self-attention supports chunked prefill.
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import kunlun_ops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
"""
|
||||
Args:
|
||||
metadata:
|
||||
Metadata for spec decoding.
|
||||
@@ -81,7 +80,7 @@ class RejectionSampler(nn.Module):
|
||||
Returns:
|
||||
output_token_ids (torch.Tensor):
|
||||
A tensor containing the final output token IDs.
|
||||
'''
|
||||
"""
|
||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||
# [num_tokens, vocab_size]
|
||||
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
||||
@@ -124,11 +123,11 @@ class RejectionSampler(nn.Module):
|
||||
"""
|
||||
output_token_ids_np = output_token_ids.cpu().numpy()
|
||||
# Create mask for valid tokens.
|
||||
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
|
||||
(output_token_ids_np < vocab_size))
|
||||
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
|
||||
output_token_ids_np < vocab_size
|
||||
)
|
||||
outputs = [
|
||||
row[valid_mask[i]].tolist()
|
||||
for i, row in enumerate(output_token_ids_np)
|
||||
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
|
||||
]
|
||||
return outputs
|
||||
|
||||
@@ -179,25 +178,15 @@ def rejection_sample(
|
||||
if not sampling_metadata.all_random:
|
||||
# Rejection sampling for greedy sampling requests.
|
||||
target_argmax = target_probs.argmax(dim=-1)
|
||||
if min(num_draft_tokens) == 1 and max(
|
||||
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
|
||||
rejection_greedy_sample_spec_len_1_pytorch(
|
||||
output_token_ids,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
)
|
||||
else:
|
||||
rejection_greedy_sample_pytorch(
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
num_draft_tokens,
|
||||
max_spec_len,
|
||||
is_greedy,
|
||||
)
|
||||
kunlun_ops.rejection_greedy_sample(
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
)
|
||||
if sampling_metadata.all_greedy:
|
||||
return output_token_ids
|
||||
|
||||
@@ -222,8 +211,9 @@ def rejection_sample(
|
||||
sampling_metadata,
|
||||
device,
|
||||
)
|
||||
bonus_token_ids = bonus_token_ids.squeeze(1)
|
||||
|
||||
rejection_random_sample_pytorch(
|
||||
kunlun_ops.rejection_random_sample(
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
@@ -235,8 +225,7 @@ def rejection_sample(
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
IS_NGRAM=draft_probs is None,
|
||||
# num_warps=1,
|
||||
no_draft_probs=draft_probs is None,
|
||||
)
|
||||
return output_token_ids
|
||||
|
||||
@@ -374,7 +363,7 @@ def generate_uniform_probs(
|
||||
random values in the range [0, 1).
|
||||
"""
|
||||
uniform_probs = torch.rand(
|
||||
(num_tokens, ),
|
||||
(num_tokens,),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
@@ -422,7 +411,7 @@ def sample_recovered_tokens(
|
||||
q[i].exponential_(generator=generator)
|
||||
|
||||
recovered_token_ids = torch.empty_like(draft_token_ids)
|
||||
sample_recovered_tokens_pytorch(
|
||||
kunlun_ops.sample_recovered_tokens(
|
||||
recovered_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
@@ -430,16 +419,16 @@ def sample_recovered_tokens(
|
||||
target_probs,
|
||||
q,
|
||||
vocab_size,
|
||||
IS_NGRAM=draft_probs is None,
|
||||
no_draft_probs=draft_probs is None,
|
||||
)
|
||||
return recovered_token_ids
|
||||
|
||||
|
||||
def rejection_greedy_sample_spec_len_1_pytorch(
|
||||
output_token_ids, # [batch_size, 2]
|
||||
draft_token_ids, # [num_tokens]
|
||||
target_argmax, # [num_tokens]
|
||||
bonus_token_ids, # [batch_size]
|
||||
output_token_ids, # [batch_size, 2]
|
||||
draft_token_ids, # [num_tokens]
|
||||
target_argmax, # [num_tokens]
|
||||
bonus_token_ids, # [batch_size]
|
||||
):
|
||||
batch_size = output_token_ids.size(0)
|
||||
num_tokens = draft_token_ids.size(0)
|
||||
@@ -447,73 +436,72 @@ def rejection_greedy_sample_spec_len_1_pytorch(
|
||||
accept_req_mask = draft_token_ids == target_argmax
|
||||
output_token_ids[:, 0] = target_argmax
|
||||
bonus_token_ids = bonus_token_ids.squeeze(1)
|
||||
output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids,
|
||||
output_token_ids[:, 1])
|
||||
output_token_ids[:, 1] = torch.where(
|
||||
accept_req_mask, bonus_token_ids, output_token_ids[:, 1]
|
||||
)
|
||||
|
||||
|
||||
def rejection_greedy_sample_pytorch(
|
||||
output_token_ids, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens, # [batch_size]
|
||||
draft_token_ids, # [num_tokens]
|
||||
target_argmax, # [num_tokens]
|
||||
bonus_token_ids, # [batch_size]
|
||||
draft_tokens_per_req, # [batch_size], list
|
||||
max_spec_len,
|
||||
is_greedy=None, # [batch_size] or None
|
||||
output_token_ids, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens, # [batch_size]
|
||||
draft_token_ids, # [num_tokens]
|
||||
target_argmax, # [num_tokens]
|
||||
bonus_token_ids, # [batch_size]
|
||||
draft_tokens_per_req, # [batch_size], list
|
||||
max_spec_len,
|
||||
is_greedy=None, # [batch_size] or None
|
||||
):
|
||||
batch_size = output_token_ids.size(0)
|
||||
num_tokens = draft_token_ids.size(0)
|
||||
device = output_token_ids.device
|
||||
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
|
||||
device, non_blocking=True)
|
||||
device, non_blocking=True
|
||||
)
|
||||
if is_greedy is None:
|
||||
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
start_indices = cu_num_draft_tokens - draft_tokens_per_req
|
||||
req_ids = torch.arange(batch_size, device=device)
|
||||
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
|
||||
token_positions = torch.arange(
|
||||
num_tokens, device=device) - start_indices[token_req_ids]
|
||||
token_positions = (
|
||||
torch.arange(num_tokens, device=device) - start_indices[token_req_ids]
|
||||
)
|
||||
|
||||
# Find the first mismatch position of each request.
|
||||
mismatch_global = (draft_token_ids != target_argmax)
|
||||
mismatch_global = draft_token_ids != target_argmax
|
||||
if max_spec_len == 0:
|
||||
first_mismatch_pos_per_req = torch.zeros(batch_size,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
first_mismatch_pos_per_req = torch.zeros(
|
||||
batch_size, dtype=torch.long, device=device
|
||||
)
|
||||
else:
|
||||
# [bs, max_spec_len]
|
||||
pos_matrix = torch.full((batch_size, max_spec_len),
|
||||
-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
pos_matrix = torch.full(
|
||||
(batch_size, max_spec_len), -1, dtype=torch.long, device=device
|
||||
)
|
||||
pos_matrix[token_req_ids, token_positions] = token_positions
|
||||
mismatch_matrix = torch.full((batch_size, max_spec_len),
|
||||
False,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
mismatch_matrix = torch.full(
|
||||
(batch_size, max_spec_len), False, dtype=torch.bool, device=device
|
||||
)
|
||||
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
|
||||
mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
|
||||
max_spec_len * 2)
|
||||
mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
|
||||
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
|
||||
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
|
||||
no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
|
||||
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
|
||||
no_mismatch_mask]
|
||||
no_mismatch_mask
|
||||
]
|
||||
|
||||
# Copy matched target tokens into output.
|
||||
copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
|
||||
draft_tokens_per_req)
|
||||
copy_indices = torch.arange(max_spec_len + 1,
|
||||
device=device).expand(batch_size, -1)
|
||||
copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req)
|
||||
copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1)
|
||||
copy_mask = copy_indices < copy_len.unsqueeze(1)
|
||||
greedy_mask = is_greedy.unsqueeze(1)
|
||||
final_copy_mask = copy_mask & greedy_mask
|
||||
global_idx = start_indices.unsqueeze(1) + copy_indices
|
||||
output_token_ids[final_copy_mask] = target_argmax[
|
||||
global_idx[final_copy_mask]].to(output_token_ids.dtype)
|
||||
output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(
|
||||
output_token_ids.dtype
|
||||
)
|
||||
# Fill bonus token.
|
||||
needs_bonus = is_greedy & (first_mismatch_pos_per_req
|
||||
>= draft_tokens_per_req)
|
||||
needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req)
|
||||
if torch.any(needs_bonus):
|
||||
bonus_rows = torch.where(needs_bonus)[0]
|
||||
bonus_cols = draft_tokens_per_req[bonus_rows]
|
||||
@@ -556,11 +544,9 @@ def rejection_random_sample_pytorch(
|
||||
if IS_NGRAM:
|
||||
draft_prob = 1.0
|
||||
else:
|
||||
draft_prob = draft_probs[start_idx + pos,
|
||||
draft_token_id].item()
|
||||
draft_prob = draft_probs[start_idx + pos, draft_token_id].item()
|
||||
|
||||
target_prob = target_probs[start_idx + pos,
|
||||
draft_token_id].item()
|
||||
target_prob = target_probs[start_idx + pos, draft_token_id].item()
|
||||
uniform_prob = uniform_probs[start_idx + pos].item()
|
||||
|
||||
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
||||
@@ -629,12 +615,11 @@ def sample_recovered_tokens_pytorch(
|
||||
else:
|
||||
draft_p = draft_probs[token_idx].clone()
|
||||
target_p = target_probs[token_idx].clone()
|
||||
prob = torch.maximum(target_p - draft_p,
|
||||
torch.tensor(0.0, device=target_p.device))
|
||||
prob = torch.maximum(
|
||||
target_p - draft_p, torch.tensor(0.0, device=target_p.device)
|
||||
)
|
||||
|
||||
q_values = torch.full((vocab_size, ),
|
||||
float('-inf'),
|
||||
device=q.device)
|
||||
q_values = torch.full((vocab_size,), float("-inf"), device=q.device)
|
||||
q_values[:vocab_size] = q[req_idx, :vocab_size]
|
||||
|
||||
recovered_id = torch.argmax(prob / q_values).item()
|
||||
@@ -642,4 +627,3 @@ def sample_recovered_tokens_pytorch(
|
||||
|
||||
if IS_NGRAM:
|
||||
target_probs[token_idx, draft_token_id] = orig_prob
|
||||
|
||||
|
||||
@@ -337,5 +337,5 @@ def prepare_next_token_ids_padded(
|
||||
return next_token_ids, valid_sampled_tokens_count
|
||||
|
||||
|
||||
EagleProposer.propose = propose
|
||||
# EagleProposer.propose = propose
|
||||
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded
|
||||
|
||||
@@ -407,7 +407,7 @@ def add_rmsnorm(
|
||||
) -> None:
|
||||
kunlun_ops.add_rmsnorm(
|
||||
x,
|
||||
y, # 原来写 residual,这里其实是 y
|
||||
y,
|
||||
residual_output=residual_output,
|
||||
weight=weight,
|
||||
eps=eps,
|
||||
@@ -523,6 +523,145 @@ def _fake_add_rmsnorm(
|
||||
add_rmsnorm.register_fake(_fake_add_rmsnorm)
|
||||
|
||||
|
||||
@custom_op("_C::gemma_add_rmsnorm", mutates_args=())
|
||||
def gemma_add_rmsnorm(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
eps: float = 1e-5,
|
||||
enable_pdl: bool = False,
|
||||
interweaved: bool = False,
|
||||
store_output_before_norm: bool = True,
|
||||
bias: torch.Tensor = None,
|
||||
smooth: torch.Tensor = None,
|
||||
residual_output: torch.Tensor = None,
|
||||
force_sdnn: bool = False,
|
||||
) -> None:
|
||||
# print("gemma_add_rmsnorm wrapper")
|
||||
kunlun_ops.gemma_add_rmsnorm(
|
||||
x,
|
||||
y,
|
||||
weight=weight,
|
||||
output=output,
|
||||
eps=eps,
|
||||
enable_pdl=enable_pdl,
|
||||
interweaved=interweaved,
|
||||
store_output_before_norm=store_output_before_norm,
|
||||
bias=bias,
|
||||
smooth=smooth,
|
||||
residual_output=residual_output,
|
||||
force_sdnn=force_sdnn,
|
||||
)
|
||||
|
||||
|
||||
@impl("_C::gemma_add_rmsnorm", "CUDA")
|
||||
def gemma_add_rmsnorm_cuda(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
eps: float = 1e-5,
|
||||
enable_pdl: bool = False,
|
||||
interweaved: bool = False,
|
||||
store_output_before_norm: bool = True,
|
||||
bias: torch.Tensor = None,
|
||||
smooth: torch.Tensor = None,
|
||||
residual_output: torch.Tensor = None,
|
||||
force_sdnn: bool = False,
|
||||
) -> None:
|
||||
# print("gemma_add_rmsnorm_cuda wrapper")
|
||||
kunlun_ops.gemma_add_rmsnorm(
|
||||
x,
|
||||
y,
|
||||
weight=weight,
|
||||
output=output,
|
||||
eps=eps,
|
||||
enable_pdl=enable_pdl,
|
||||
interweaved=interweaved,
|
||||
store_output_before_norm=store_output_before_norm,
|
||||
bias=bias,
|
||||
smooth=smooth,
|
||||
residual_output=residual_output,
|
||||
force_sdnn=force_sdnn,
|
||||
)
|
||||
|
||||
|
||||
def _fake_gemma_add_rmsnorm(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
eps: float = 1e-5,
|
||||
enable_pdl: bool = False,
|
||||
interweaved: bool = False,
|
||||
store_output_before_norm: bool = True,
|
||||
bias: torch.Tensor = None,
|
||||
smooth: torch.Tensor = None,
|
||||
residual_output: torch.Tensor = None,
|
||||
force_sdnn: bool = False,
|
||||
):
|
||||
output.fake_shape = x.shape
|
||||
output.fake_dtype = x.dtype
|
||||
return None
|
||||
|
||||
|
||||
gemma_add_rmsnorm.register_fake(_fake_gemma_add_rmsnorm)
|
||||
|
||||
|
||||
@custom_op("_C::gemma_rmsnorm", mutates_args=())
|
||||
def gemma_rmsnorm(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
eps: float = 1e-5,
|
||||
enable_pdl: bool = False,
|
||||
interweave: bool = False,
|
||||
bias: torch.Tensor = None,
|
||||
force_sdnn: bool = False,
|
||||
) -> None:
|
||||
# print("gemma_rmsnorm wrapper")
|
||||
kunlun_ops.gemma_rmsnorm(
|
||||
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
|
||||
)
|
||||
|
||||
|
||||
@impl("_C::gemma_rmsnorm", "CUDA")
|
||||
def gemma_rmsnorm_cuda(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
eps: float = 1e-5,
|
||||
enable_pdl: bool = False,
|
||||
interweave: bool = False,
|
||||
bias: torch.Tensor = None,
|
||||
force_sdnn: bool = False,
|
||||
) -> None:
|
||||
# print("gemma_rmsnorm_cuda wrapper")
|
||||
kunlun_ops.gemma_rmsnorm(
|
||||
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
|
||||
)
|
||||
|
||||
|
||||
def _fake_gemma_rmsnorm(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
eps: float = 1e-5,
|
||||
enable_pdl: bool = False,
|
||||
interweave: bool = False,
|
||||
bias: torch.Tensor = None,
|
||||
force_sdnn: bool = False,
|
||||
):
|
||||
# 设置 shape/dtype,但不返回值
|
||||
output.fake_shape = x.shape
|
||||
output.fake_dtype = x.dtype
|
||||
return None
|
||||
|
||||
|
||||
gemma_rmsnorm.register_fake(_fake_gemma_rmsnorm)
|
||||
|
||||
|
||||
@custom_op("_C::split_norm_rope_neox", mutates_args=())
|
||||
def split_norm_rope_neox(
|
||||
q_emb: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user