### What this PR does / why we need it?
Refactor model structure in qwen3_next.py to reduce code line.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
```
def main():
prompts = [
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen3-Next-80B-A3B-Instruct",
tensor_parallel_size=4,
enforce_eager=True,
trust_remote_code=True,
max_model_len=256,
gpu_memory_utilization=0.7,
block_size=64,
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
- vLLM version: v0.10.2
- vLLM main:
https://github.com/vllm-project/vllm/commit/releases/v0.11.0
---------
Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
@@ -53,4 +53,4 @@ def register_model():
|
||||
)
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3NextForCausalLM",
|
||||
"vllm_ascend.models.qwen3_next:Qwen3NextForCausalLM")
|
||||
"vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")
|
||||
|
||||
@@ -14,10 +14,9 @@ from vllm.attention import AttentionBackend, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
|
||||
VllmConfig, get_current_vllm_config)
|
||||
from vllm.distributed import (divide, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fla.ops import RMSNormGated
|
||||
from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
|
||||
from vllm.model_executor.layers.fla.ops.fused_recurrent import \
|
||||
@@ -44,27 +43,24 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
||||
MixtureOfExperts,
|
||||
SupportsLoRA, SupportsPP)
|
||||
from vllm.model_executor.models.mamba_cache import MambaCacheParams
|
||||
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
|
||||
from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
|
||||
Qwen3NextSparseMoeBlock,
|
||||
fused_gdn_gating)
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
||||
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
|
||||
make_layers, maybe_prefix)
|
||||
PPMissingLayer, extract_layer_index, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import Qwen3NextConfig
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||
|
||||
from vllm.model_executor.models.qwen3_next import Qwen3NextAttention # isort: skip
|
||||
from vllm.model_executor.models.qwen3_next import Qwen3NextDecoderLayer # isort: skip
|
||||
from vllm.model_executor.models.qwen3_next import Qwen3NextForCausalLM # isort: skip
|
||||
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet # isort: skip
|
||||
from vllm.model_executor.models.qwen3_next import Qwen3NextModel # isort: skip
|
||||
from vllm.model_executor.models.qwen3_next import Qwen3NextSparseMoeBlock # isort: skip
|
||||
from vllm.model_executor.models.qwen3_next import fused_gdn_gating # isort: skip
|
||||
|
||||
class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
|
||||
class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
@@ -80,14 +76,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
return MambaStateShapeCalculator.gated_delta_net_state_shape(
|
||||
self.tp_size,
|
||||
self.num_k_heads,
|
||||
self.num_v_heads,
|
||||
self.head_k_dim,
|
||||
self.head_v_dim,
|
||||
self.conv_kernel_size,
|
||||
self.num_spec,
|
||||
use_v1=True)
|
||||
self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim,
|
||||
self.head_v_dim, self.conv_kernel_size, self.num_spec)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -98,7 +88,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
speculative_config: Optional[SpeculativeConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
nn.Module.__init__(self)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -195,85 +185,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def fix_query_key_value_ordering(
|
||||
self,
|
||||
mixed_qkvz,
|
||||
mixed_ba,
|
||||
):
|
||||
"""
|
||||
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
|
||||
"""
|
||||
new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
|
||||
self.num_k_heads // self.tp_size,
|
||||
(self.head_k_dim + self.head_k_dim +
|
||||
(self.head_v_dim + self.head_v_dim) * self.num_v_heads //
|
||||
self.num_k_heads),
|
||||
)
|
||||
new_tensor_shape_ba = mixed_qkvz.size()[:-1] + (
|
||||
self.num_k_heads // self.tp_size,
|
||||
2 * self.num_v_heads // self.num_k_heads,
|
||||
)
|
||||
|
||||
mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
|
||||
mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
|
||||
|
||||
split_arg_list_qkvz = [
|
||||
self.head_k_dim,
|
||||
self.head_k_dim,
|
||||
(self.num_v_heads // self.num_k_heads * self.head_v_dim),
|
||||
(self.num_v_heads // self.num_k_heads * self.head_v_dim),
|
||||
]
|
||||
split_arg_list_ba = [
|
||||
self.num_v_heads // self.num_k_heads,
|
||||
self.num_v_heads // self.num_k_heads
|
||||
]
|
||||
|
||||
# [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
|
||||
# --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn],
|
||||
# [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
|
||||
(query, key, value, z) = torch.split(mixed_qkvz,
|
||||
split_arg_list_qkvz,
|
||||
dim=2)
|
||||
(b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)
|
||||
|
||||
# [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
|
||||
value = value.reshape(value.size(0), -1, self.head_v_dim)
|
||||
z = z.reshape(z.size(0), -1, self.head_v_dim)
|
||||
b = b.reshape(b.size(0), self.num_v_heads // self.tp_size)
|
||||
a = a.reshape(a.size(0), self.num_v_heads // self.tp_size)
|
||||
|
||||
return query, key, value, z, b, a
|
||||
|
||||
def rearrange_mixed_qkv(self, mixed_qkv):
|
||||
if mixed_qkv is None:
|
||||
return None, None, None
|
||||
query, key, value = torch.split(
|
||||
mixed_qkv,
|
||||
[
|
||||
self.key_dim // self.tp_size,
|
||||
self.key_dim // self.tp_size,
|
||||
self.value_dim // self.tp_size,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
query, key = map(
|
||||
lambda x: rearrange(x, 'l (h d) -> 1 l h d', d=self.head_k_dim),
|
||||
(query, key))
|
||||
value = rearrange(value, 'l (h d) -> 1 l h d', d=self.head_v_dim)
|
||||
return query, key, value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
cache_params: Optional[MambaCacheParams] = None,
|
||||
):
|
||||
return torch.ops.vllm.npu_gdn_attention(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -340,24 +251,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
mixed_qkv_spec = None
|
||||
mixed_qkv_non_spec = mixed_qkv
|
||||
|
||||
# 2.1: process the mutli-query part
|
||||
# if spec_sequence_masks is not None:
|
||||
# mixed_qkv_spec = mixed_qkv_spec.view(
|
||||
# attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
|
||||
# mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
|
||||
# mixed_qkv_spec = causal_conv1d_update(
|
||||
# mixed_qkv_spec,
|
||||
# conv_state,
|
||||
# conv_weights,
|
||||
# self.conv1d.bias,
|
||||
# self.activation,
|
||||
# conv_state_indices=spec_state_indices_tensor[:, 0]
|
||||
# [:attn_metadata.num_spec_decodes],
|
||||
# num_accepted_tokens=num_accepted_tokens,
|
||||
# validate_data=False,
|
||||
# )
|
||||
# mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
|
||||
|
||||
# 2.2: process the remaining part
|
||||
if attn_metadata.num_prefills > 0:
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
@@ -532,7 +425,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
|
||||
|
||||
|
||||
class Qwen3NextDecoderLayer(nn.Module):
|
||||
class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -545,14 +438,14 @@ class Qwen3NextDecoderLayer(nn.Module):
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
|
||||
self.layer_type = layer_type
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
|
||||
if self.layer_type == "linear_attention":
|
||||
self.linear_attn = Qwen3NextGatedDeltaNet(
|
||||
self.linear_attn = CustomQwen3NextGatedDeltaNet(
|
||||
config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
@@ -611,69 +504,12 @@ class Qwen3NextDecoderLayer(nn.Module):
|
||||
dtype=config.torch_dtype,
|
||||
), )
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
positions: torch.Tensor = None,
|
||||
**kwargs: object,
|
||||
):
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
self_attention_output = torch.empty_like(hidden_states)
|
||||
if self.layer_type == "linear_attention":
|
||||
self.linear_attn(
|
||||
hidden_states=hidden_states,
|
||||
output=self_attention_output,
|
||||
)
|
||||
elif self.layer_type == "full_attention":
|
||||
self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
output=self_attention_output,
|
||||
positions=positions,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid layer_type")
|
||||
hidden_states = self_attention_output
|
||||
|
||||
if self.layer_scale:
|
||||
if len(hidden_states.shape) == 2:
|
||||
hidden_states = hidden_states * (
|
||||
self.attn_layer_scale.to(hidden_states.dtype)[0] + 1)
|
||||
else:
|
||||
hidden_states = hidden_states * (
|
||||
self.attn_layer_scale.to(hidden_states.dtype) + 1)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
if self.layer_scale:
|
||||
if len(hidden_states.shape) == 2:
|
||||
hidden_states = hidden_states * (
|
||||
self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1)
|
||||
else:
|
||||
assert len(hidden_states.shape) == len(
|
||||
self.ffn_layer_scale.shape
|
||||
), f'shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}' # noqa: E501
|
||||
hidden_states = hidden_states * (
|
||||
self.ffn_layer_scale.to(hidden_states.dtype) + 1)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Qwen3NextModel(nn.Module):
|
||||
class CustomQwen3NextModel(Qwen3NextModel):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
nn.Module.__init__(self)
|
||||
config: Qwen3NextConfig = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
@@ -697,7 +533,7 @@ class Qwen3NextModel(nn.Module):
|
||||
)
|
||||
|
||||
def get_layer(prefix: str):
|
||||
return Qwen3NextDecoderLayer(
|
||||
return CustomQwen3NextDecoderLayer(
|
||||
config,
|
||||
layer_type=config.layer_types[extract_layer_index(prefix)],
|
||||
model_config=model_config,
|
||||
@@ -717,52 +553,6 @@ class Qwen3NextModel(nn.Module):
|
||||
self.norm = Qwen3NextRMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for layer in self.layers:
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
return FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.num_experts,
|
||||
num_redundant_experts=self.num_redundant_experts)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
@@ -842,10 +632,10 @@ class Qwen3NextModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
MixtureOfExperts, IsHybrid):
|
||||
class CustomQwen3NextForCausalLM(Qwen3NextForCausalLM):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
@@ -856,12 +646,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
"Qwen3Next currently does not support prefix caching"
|
||||
assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"
|
||||
self.quant_config = vllm_config.quant_config
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.model = Qwen3NextModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.model = CustomQwen3NextModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
@@ -904,127 +692,3 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
self.num_local_physical_experts = example_layer.n_local_physical_experts
|
||||
self.num_routed_experts = example_layer.n_routed_experts
|
||||
self.num_redundant_experts = example_layer.n_redundant_experts
|
||||
|
||||
def set_eplb_state(
|
||||
self,
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
) -> None:
|
||||
for layer_idx, layer in enumerate(self.moe_layers):
|
||||
# Register the expert weights.
|
||||
self.expert_weights.append(layer.get_expert_weights())
|
||||
layer.set_eplb_state(
|
||||
moe_layer_idx=layer_idx,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
def update_physical_experts_metadata(
|
||||
self,
|
||||
num_physical_experts: int,
|
||||
num_local_physical_experts: int,
|
||||
) -> None:
|
||||
assert self.num_local_physical_experts == num_local_physical_experts
|
||||
self.num_physical_experts = num_physical_experts
|
||||
self.num_local_physical_experts = num_local_physical_experts
|
||||
self.num_redundant_experts = (num_physical_experts -
|
||||
self.num_logical_experts)
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
|
||||
moe = layer.mlp
|
||||
moe.n_local_physical_experts = num_local_physical_experts
|
||||
moe.n_physical_experts = num_physical_experts
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
):
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_dtype_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
|
||||
vllm_config.model_config.dtype,
|
||||
vllm_config.cache_config.mamba_cache_dtype)
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls, vllm_config: "VllmConfig"
|
||||
) -> tuple[tuple[int, int], tuple[int, int]]:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
tp_size = parallel_config.tensor_parallel_size
|
||||
num_spec = (vllm_config.speculative_config.num_speculative_tokens
|
||||
if vllm_config.speculative_config else 0)
|
||||
return MambaStateShapeCalculator.gated_delta_net_state_shape(
|
||||
tp_size,
|
||||
hf_config.linear_num_key_heads,
|
||||
hf_config.linear_num_value_heads,
|
||||
hf_config.linear_key_head_dim,
|
||||
hf_config.linear_value_head_dim,
|
||||
hf_config.linear_conv_kernel_dim,
|
||||
num_spec,
|
||||
use_v1=True)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata=None, # type: ignore
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=["mtp."],
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return self.model.get_expert_mapping()
|
||||
|
||||
|
||||
def npu_gdn_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self._forward(hidden_states=hidden_states, output=output)
|
||||
|
||||
|
||||
def npu_gdn_attention_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="npu_gdn_attention",
|
||||
op_func=npu_gdn_attention,
|
||||
mutates_args=["output"],
|
||||
fake_impl=npu_gdn_attention_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user