add qwen3

This commit is contained in:
Chranos
2026-02-04 17:22:39 +08:00
parent d1c0f68ab4
commit 8511fe8530
1932 changed files with 300426 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
# hijack vllm layers
import vllm_mlu.model_executor.layers
# hijack vllm models
import vllm_mlu.model_executor.models
# hijack vllm model loader
import vllm_mlu.model_executor.model_loader

View File

@@ -0,0 +1 @@
import vllm_mlu.model_executor.custom_model.custom

View File

@@ -0,0 +1,644 @@
from collections import namedtuple
from typing import Any, Dict, Iterable, Union, List, Optional, Tuple
import math
import torch
from torch import nn
from vllm import _mlu_ops as mlu_ops
from vllm.config import CacheConfig, VllmConfig
from vllm_mlu.transformers_utils.configs import CustomConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear,
ReplicatedLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm.model_executor.models.utils import PPMissingLayer, make_layers
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base, is_per_tensor_smoothquant,
is_per_token_smoothquant, quant_fusion_with_rmsnorm,
quant_fusion_with_layernorm)
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class LayerNorm(nn.Module):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
x = x.view(-1, self.weight.data.shape[0])
if residual is not None:
residual = residual.view(-1, self.weight.data.shape[0])
return mlu_ops.fused_layer_norm(x, residual, self.weight.data, self.bias.data, None, self.variance_epsilon, True)
else:
return mlu_ops.fused_layer_norm(x, residual, self.weight.data, self.bias.data, None, self.variance_epsilon, False)
_NORM_DICT: Dict[str, nn.Module] = {"rmsnorm": RMSNorm, "layernorm": LayerNorm}
class CustomMoeBlock(nn.Module):
def __init__(
self,
config: CustomConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.n_routed_experts = config.num_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.n_routed_experts}.")
self.moe_intermediate_size = self.config.moe_intermediate_size // self.tp_size
if quant_config is None:
self.w1 = nn.Parameter(
torch.empty(self.config.num_experts,
2 * self.moe_intermediate_size if self.config.is_gated else self.moe_intermediate_size,
self.config.hidden_size,
dtype=torch.get_default_dtype()), requires_grad=False)
self.w2 = nn.Parameter(
torch.empty(self.config.num_experts,
self.config.hidden_size,
self.moe_intermediate_size,
dtype=torch.get_default_dtype()), requires_grad=False)
self.w1_scale = None
self.w2_scale = None
self.input_smooth = None
self.act_smooth = None
else:
assert quant_config.weight_bits == 8
self.w1 = nn.Parameter(
torch.empty(self.config.num_experts,
2 * self.moe_intermediate_size if self.config.is_gated else self.moe_intermediate_size,
self.config.hidden_size,
device="mlu",
dtype=torch.int8), requires_grad=False)
self.w2 = nn.Parameter(
torch.empty(self.config.num_experts,
self.config.hidden_size,
self.moe_intermediate_size,
device="mlu",
dtype=torch.int8), requires_grad=False)
self.w1_scale = nn.Parameter(
torch.empty(
self.config.num_experts,
2 * self.moe_intermediate_size if self.config.is_gated else self.moe_intermediate_size,
device="mlu",
dtype=torch.float32), requires_grad=False)
self.w2_scale = nn.Parameter(
torch.empty(
self.config.num_experts,
self.config.hidden_size,
device="mlu",
dtype=torch.float32), requires_grad=False)
self.input_smooth = None
self.act_smooth = None
if quant_config.quant_mode == "SmoothQuant":
self.input_smooth =nn.Parameter(
torch.empty(
self.config.num_experts,
self.config.hidden_size,
device="mlu",
dtype=torch.float32), requires_grad=False)
self.act_smooth =nn.Parameter(
torch.empty(
self.config.num_experts,
self.moe_intermediate_size,
device="mlu",
dtype=torch.float32), requires_grad=False)
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
bias=False,
quant_config=None)
if config.shared_expert_intermediate_size > 0:
self.shared_expert = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.shared_expert_intermediate_size,
hidden_act = self.config.hidden_act,
up_proj_name='gate_up_proj',
is_gated=self.config.is_gated,
down_proj_name='down_proj',
bias=False,
quant_config=quant_config,
reduce_results=False)
else:
self.shared_expert = None
self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
1,
bias=False)
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
shared_output = F.sigmoid(
self.shared_expert_gate(hidden_states)) * shared_output
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
residual_ = None if self.rank > 0 else residual
final_hidden_states = mlu_ops.fused_moe(hidden_states,
router_logits,
self.w1,
self.w2,
None,
None,
residual_,
self.input_smooth,
self.act_smooth,
self.w1_scale,
self.w2_scale,
self.top_k,
self.config.norm_topk_prob,
self.config.is_gated,
self.config.hidden_act)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
reduce_results = (self.config.use_parallel_residual == False)
if reduce_results:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
class CustomAttention(nn.Module):
def __init__(
self,
config: CustomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
attention_bias = getattr(config, "attention_bias", False) or getattr(config, "bias", False)
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads)
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.kv_scale = 1.0
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=attention_bias,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=attention_bias,
quant_config=quant_config,
skip_bias_add=(self.config.use_parallel_residual and attention_bias),
reduce_results = (self.config.use_parallel_residual == False),
)
self.alibi_slopes = None
self.rotary_emb = None
if self.config.position_embedding_type == "ALIBI":
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
self.alibi_slopes = alibi_slopes[head_start:head_end].tolist()
else:
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_sequence_length", 8192)
is_neox_style = getattr(config, "is_neox_style", False)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
is_neox_style=is_neox_style,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
alibi_slopes=self.alibi_slopes,
cache_config=cache_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.rotary_emb:
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, bias = self.o_proj(attn_output, residual)
if self.o_proj.skip_bias_add and get_tensor_model_parallel_rank() == 0:
output += bias
return output
class CustomDecoderLayer(nn.Module):
def __init__(
self,
config: CustomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.self_attn = CustomAttention(
config=config,
cache_config=cache_config,
quant_config=quant_config,
)
mlp_bias = getattr(config, "mlp_bias", False) or getattr(config, "bias", False)
is_gated = getattr(config, "is_gated", False)
if config.num_experts is not None:
self.mlp = CustomMoeBlock(config=config,
quant_config=quant_config)
else:
self.mlp = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=self.config.hidden_act,
up_proj_name='up_proj',
is_gated=is_gated,
down_proj_name='down_proj',
bias=mlp_bias,
quant_config=quant_config,
skip_bias_add=(self.config.use_parallel_residual and mlp_bias),
reduce_results = (self.config.use_parallel_residual == False))
self.input_layernorm = _NORM_DICT[self.config.norm_type](config.hidden_size, eps=config.norm_eps)
self.post_attention_layernorm = _NORM_DICT[self.config.norm_type](config.hidden_size, eps=config.norm_eps)
# perf per-tensor sq cases by fusing quantization in layernorm
self.is_per_tesnor_sq_perf_cases = (is_per_tensor_smoothquant(quant_config) and
not self.config.apply_residual_connection_post_layernorm)
self.is_per_token_sq_perf_cases = (is_per_token_smoothquant(quant_config) and
not self.config.apply_residual_connection_post_layernorm)
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
self.quant_fusion_attn_layernorm = None
self.is_moe = config.num_experts is not None
self.use_rmsnorm = self.config.norm_type == "rmsnorm"
if not self.is_moe:
self.mlp.up_proj.quant_method.skip_quant_input = True
self.quant_fusion_mlp_layernorm = None
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.config.use_parallel_residual:
# x = x + attn(ln1(x)) + mlp(ln2(x))
layernorm_output = self.input_layernorm(hidden_states)
attention_output = self.self_attn(
positions=positions,
hidden_states=layernorm_output,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
layernorm_output = self.post_attention_layernorm(hidden_states)
if self.mlp.skip_bias_add:
mlp_output, mlp_bias = self.mlp(layernorm_output)
if get_tensor_model_parallel_rank() == 0:
mlp_output += mlp_bias
else:
mlp_output = self.mlp(layernorm_output)
if get_tensor_model_parallel_rank() == 0:
hidden_states = mlp_output + attention_output + hidden_states
else:
hidden_states = mlp_output + attention_output
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
return hidden_states, None
else:
# rmsnorm use fused_rms_norm to get better performance
# if apply_residual_connection_post_layernorm:
# x = ln1(x) + attn(ln1(x))
# x = ln2(x) + mlp(ln2(x))
# else:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_layernorm = self.input_layernorm
mlp_layernorm = self.post_attention_layernorm
if self.is_per_tesnor_sq_perf_cases:
quant_fusion_func = (quant_fusion_with_rmsnorm if
self.use_rmsnorm else quant_fusion_with_layernorm)
if self.quant_fusion_attn_layernorm is None:
self.quant_fusion_attn_layernorm = quant_fusion_func(
self.input_layernorm, self.self_attn.qkv_proj.scale_to_int)
attn_layernorm = self.quant_fusion_attn_layernorm
if not self.is_moe:
if self.quant_fusion_mlp_layernorm is None:
self.quant_fusion_mlp_layernorm = quant_fusion_func(
self.post_attention_layernorm, self.mlp.up_proj.scale_to_int)
mlp_layernorm = self.quant_fusion_mlp_layernorm
elif self.is_per_token_sq_perf_cases:
quant_fusion_func = (quant_fusion_with_rmsnorm if
self.use_rmsnorm else quant_fusion_with_layernorm)
if self.quant_fusion_attn_layernorm is None:
self.quant_fusion_attn_layernorm = quant_fusion_func(
self.input_layernorm, self.self_attn.qkv_proj.smooth, dynamic_quant=True)
attn_layernorm = self.quant_fusion_attn_layernorm
if not self.is_moe:
if self.quant_fusion_mlp_layernorm is None:
self.quant_fusion_mlp_layernorm = quant_fusion_func(
self.post_attention_layernorm, self.mlp.up_proj.smooth, dynamic_quant=True)
mlp_layernorm = self.quant_fusion_mlp_layernorm
post_norm_fuse_en=(self.is_per_token_sq_perf_cases and not self.is_moe)
return decoder_layer_forward_base(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
input_layernorm=attn_layernorm,
self_attn=self.self_attn,
post_layernorm=mlp_layernorm,
mlp=self.mlp,
apply_residual_connection_post_layernorm=self.config.apply_residual_connection_post_layernorm,
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
post_norm_fuse_en=post_norm_fuse_en), None
class CustomModel(nn.Module):
def __init__(
self,
config: CustomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
embed_layer = VocabParallelEmbedding if self.config.use_parallel_embedding else nn.Embedding
self.embed_tokens = embed_layer(
config.vocab_size,
config.hidden_size,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: CustomDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config),
prefix="custom_model")
if get_pp_group().is_last_rank:
self.norm = _NORM_DICT[self.config.norm_type](config.hidden_size, eps=config.norm_eps)
else:
self.norm = PPMissingLayer()
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
if residual is not None:
hidden_states, _ = self.norm(hidden_states, residual)
else:
hidden_states = self.norm(hidden_states)
return hidden_states
class CustomForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
self.config = vllm_config.model_config.hf_text_config
self.quant_config = vllm_config.quant_config
self.cache_config = vllm_config.cache_config
self._verify_params()
self.model = CustomModel(self.config, self.cache_config, self.quant_config)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(self.config.vocab_size, self.config.hidden_size)
self.logits_processor = LogitsProcessor(self.config.vocab_size)
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
pass
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def _verify_params(self) -> None:
if (self.config.max_sequence_length) is None or \
(self.config.num_hidden_layers) is None or \
(self.config.hidden_size) is None or \
(self.config.vocab_size) is None or \
(self.config.num_attention_heads) is None:
raise ValueError(
"max_sequence_length, num_hidden_layers, hidden_size, vocab_size, "
"num_attention_heads, must be vaild int values")
if self.config.hidden_act not in ["silu", "gelu"]:
raise ValueError(
"CustomConfig hidden_act must be one of [silu, gelu]. Got "
f"{self.config.hidden_act}.")
if self.config.position_embedding_type not in ["ALIBI", "ROPE"]:
raise ValueError(
"position_embedding_type must be one of [ALIBI, ROPE]. Got "
f"{self.config.position_embedding_type}.")
if self.config.num_experts is not None:
if self.config.num_experts_per_tok is None:
raise ValueError(
"num_experts_per_tok must be a valid int value when num_experts is not None")
if self.config.moe_intermediate_size is None:
raise ValueError(
"moe_intermediate_size must be a valid int value when num_experts is not None")
if self.config.shared_expert_intermediate_size is None:
raise ValueError(
"shared_expert_intermediate_size must be a valid int value when num_experts is not None")
if self.config.norm_topk_prob is None:
raise ValueError(
"norm_topk_prob must be a valid bool value when num_experts is not None")
if self.config.mlp_bias is True:
raise ValueError(
"mlp_bias must be False when num_experts is not None")
if self.quant_config is not None and self.quant_config.get_name() != "SmoothQuant":
raise ValueError(
"moe only support smoothquant now")
else:
if self.config.intermediate_size is None:
raise ValueError(
"intermediate_size must be a valid int value when num_experts is None")
if self.config.norm_type not in ["rmsnorm", "layernorm"]:
raise ValueError(
"norm_type must be one of [rmsnorm, layernorm]. Got "
f"{self.config.norm_type}.")

View File

@@ -0,0 +1,8 @@
import vllm_mlu.model_executor.layers.feed_forward
import vllm_mlu.model_executor.layers.sparse_moe_mlp
import vllm_mlu.model_executor.layers.linear
import vllm_mlu.model_executor.layers.spec_decode_base_sampler
import vllm_mlu.model_executor.layers.rotary_embedding
import vllm_mlu.model_executor.layers.quantization
import vllm_mlu.model_executor.layers.activation

View File

@@ -0,0 +1,22 @@
import torch
from vllm.model_executor.layers.activation import QuickGELU
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm import _mlu_ops as mlu_ops
def vllm__model_executor__activation__QuickGELU__forward_mlu(self, x: torch.Tensor) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: implement forward_mlu
'''
return mlu_ops.active(x, 'quick_gelu', False)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(QuickGELU,
"forward_mlu",
vllm__model_executor__activation__QuickGELU__forward_mlu)

View File

@@ -0,0 +1,150 @@
import torch
import torch.nn.functional as F
from typing import Optional
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm import _mlu_ops as mlu_ops
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
ColumnParallelLinear,
RowParallelLinear
)
from vllm_mlu.mlu_hijack_utils import set_is_gated
from vllm.distributed import get_tensor_model_parallel_rank
logger = init_logger(__name__)
class FeedForward(torch.nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
up_proj_name: str,
is_gated: bool,
down_proj_name: str,
bias: bool,
quant_config: Optional[QuantizationConfig] = None,
skip_bias_add: bool = False,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.is_gated = is_gated
self.bias = bias
self.up_proj_name = up_proj_name
self.down_proj_name = down_proj_name
self.quant_config = quant_config
self.is_initialized = False
self.skip_bias_add = skip_bias_add
self.reduce_results = reduce_results
self.use_bt_ffn = True if quant_config is None else False
self.tp_size = get_tensor_model_parallel_world_size()
set_is_gated(self.is_gated)
self.tp_rank = get_tensor_model_parallel_rank()
# up_proj with gate or not
if self.is_gated:
up_proj = MergedColumnParallelLinear(hidden_size,
[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.{up_proj_name}")
else:
up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=bias,
skip_bias_add=skip_bias_add,
quant_config=quant_config,
prefix=f"{prefix}.{up_proj_name}")
self.register_module(up_proj_name, up_proj)
# down_proj
down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=bias,
skip_bias_add=skip_bias_add,
reduce_results=reduce_results,
quant_config=quant_config,
prefix=f"{prefix}.{down_proj_name}")
self.register_module(down_proj_name, down_proj)
def prepare_weight(self):
if not self.is_initialized:
# alpha and beta are 1.0 and 0.0 respectively due to the fact that we don't need residual for now
self.alpha = 1.0
self.beta = 0.0
# place it here to avoid the overhead of calling it in the forward pass
self.is_initialized = True
def _forward(self, hidden_states):
self.prepare_weight()
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
act_dict = {
"relu": F.relu,
"gelu": F.gelu,
"silu": F.silu,
}
fc1 = F.linear(hidden_states, up_proj.weight, bias=up_proj.bias)
if self.is_gated:
d = fc1.shape[-1] // 2
fc1 = act_dict[self.hidden_act](fc1[..., :d]) * fc1[..., d:]
else:
fc1 = act_dict[self.hidden_act](fc1)
fc2 = F.linear(fc1, down_proj.weight, bias=None)
fc2 = tensor_model_parallel_all_reduce(fc2)
if not self.skip_bias_add:
fc2 = fc2 + down_proj.bias if down_proj.bias is not None else fc2
return fc2
def forward(
self,
hidden_states,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None
):
self.prepare_weight()
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
residual_ = None if self.tp_rank > 0 else residual
if (self.use_bt_ffn and not isinstance(up_proj, BaseLayerWithLoRA)
and not isinstance(down_proj, BaseLayerWithLoRA)):
# The matmul formula is the following:
# mul_out = alpha * (matmul(input, filter, transpose\_b=True) + bias) + beta * residual
# output = active(mul_out)
# Notes: We cannot use the activation function in matmul because it does not support gated operation
# we might support its in tmo matmul in the future
fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj.weight, up_proj.bias,
None, 'none', self.alpha, self.beta)
act_out = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
beta = 0.0
if residual_ is not None:
beta = 1.0
residual_ = residual_.view(-1, residual_.shape[-1])
out_ = mlu_ops.matmul(act_out, down_proj.weight, None, residual_, 'none', self.alpha, beta)
# bias if existed need to add after second matmul according to the original design of vllm
if self.reduce_results:
out = tensor_model_parallel_all_reduce(out_)
else:
out = out_
# do the bias add if needed
if not self.skip_bias_add:
out = out + down_proj.bias if down_proj.bias is not None else out
else:
return out, down_proj.bias
else:
fc1, bias = up_proj(hidden_states, smooth_quant_scale=smooth_quant_scale)
if bias is not None:
fc1 += bias
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
out, bias = down_proj(fc1, residual=residual_)
if self.skip_bias_add:
return out, bias
return out

View File

@@ -0,0 +1,101 @@
import torch
from typing import Optional
from vllm.distributed import (get_tensor_model_parallel_rank,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (
UnquantizedLinearMethod,
ColumnParallelLinear,
RowParallelLinear
)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm import _mlu_ops as mlu_ops
from vllm.logger import init_logger
logger = init_logger(__name__)
def vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
beta = 0.0
if residual is not None:
beta = 1.0
residual = residual.view(-1, residual.shape[-1])
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
return mlu_ops.matmul(x.view(-1, x.shape[-1]), layer.weight, bias, residual, 'none', 1.0, beta).view(res_shape)
def vllm__module_executor__layers__linear__RowParallelLinear__forward(self, input_, residual: Optional[torch.Tensor] = None):
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
residual_ = None if self.tp_rank > 0 else residual
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_,
residual=residual_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def vllm__module_executor__layers__linear__ColumnParallelLinear__forward(
self, input_, smooth_quant_scale: Optional[torch.Tensor] = None):
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add input_scale parameter.
'''
if smooth_quant_scale is not None:
output_parallel = self.quant_method.apply(self, input_, bias,
input_scale=smooth_quant_scale)
else:
output_parallel = self.quant_method.apply(self, input_, bias)
'''
==================
End of MLU Hijack
==================
'''
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
MluHijackObject.apply_hijack(UnquantizedLinearMethod,
UnquantizedLinearMethod.apply,
vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.forward,
vllm__module_executor__layers__linear__RowParallelLinear__forward)
MluHijackObject.apply_hijack(ColumnParallelLinear,
ColumnParallelLinear.forward,
vllm__module_executor__layers__linear__ColumnParallelLinear__forward)

View File

@@ -0,0 +1,13 @@
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig
from vllm_mlu.model_executor.layers.quantization.gptq_mlu import GPTQMluConfig
from vllm_mlu.model_executor.layers.quantization.awq_mlu import AWQMluConfig
QUANTIZATION_METHODS.update({
"gptq_mlu": GPTQMluConfig,
"awq_mlu": AWQMluConfig,
"weightonly": WeightOnlyConfig,
"smoothquant": SmoothQuantConfig,
})

View File

@@ -0,0 +1,414 @@
from typing import Any, Dict, List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.linear import QKVParallelLinear, MergedColumnParallelLinear
from vllm.scalar_type import ScalarType, scalar_types
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm import _mlu_ops as mlu_ops
logger = init_logger(__name__)
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
# We only support gptq and awq over 300 serials and only support int4 and int8 precision
def query_mlu_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if device_capability < 50:
return []
if has_zp:
# AWQ style, unsigned + zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
return [scalar_types.uint4b8, scalar_types.uint8b128]
def check_mlu_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
supported_types = query_mlu_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Mlu does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
return (False, f"Mlu does not support group_size = {group_size}. "
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True
class AWQMluConfig(QuantizationConfig):
"""Config class for AWQMlu.
Reference: https://arxiv.org/abs/2306.00978
"""
# num_bits -> type
TYPE_MAP = {
4: {
False: scalar_types.uint4b8,
True: scalar_types.uint4,
},
8: {
False: scalar_types.uint8b128,
True: scalar_types.uint8,
}
}
VERSION = ["gemm"]
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
lm_head_quantized: bool,
version: str = "gemm",
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
self.lm_head_quantized = lm_head_quantized
self.pack_factor = 32 // self.weight_bits
self.version = version
self.support_scale_zeros = False
if self.weight_bits not in [4, 8]:
raise ValueError(
"Currently, only 4/8-bit weight quantization is supported for "
f"AWQMlu, but got {self.weight_bits} bits.")
if self.version not in self.VERSION:
raise ValueError(
"Currently, only gemm, gemv version is supported for "
f"AWQMlu, but got verion:{self.version}.")
if self.version in ["gemm"]:
self.order_map = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]}
self.reverse_order_map = {4 : [0, 4, 1, 5, 2, 6, 3, 7], 8: [0, 2, 1, 3]}
else:
self.order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
self.reverse_order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
def __repr__(self) -> str:
return (f"AWQMluConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point}), "
f"lm_head_quantized={self.lm_head_quantized})")
@classmethod
def get_name(cls) -> str:
return "awq_mlu"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@classmethod
def get_min_capability(cls) -> int:
return 50
@staticmethod
def get_config_filenames() -> List[str]:
return ["quant_config.json", "quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQMluConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
version = cls.get_from_keys_or(config, ["version"],
default="gemm")
return cls(weight_bits, group_size, zero_point, lm_head_quantized, version)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQMluLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return AWQMluLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_awq_mlu_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "awq"
or user_quant == "awq_mlu")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "awq":
logger.info("Detected that the model can run with awq_mlu"
", however you specified quantization=awq explicitly,"
" so forcing awq. Use quantization=awq_mlu for"
" faster inference")
return None
@classmethod
def is_awq_mlu_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
has_zp = quant_config.get("zero_point", None)
version = quant_config.get("version", "gemm")
if quant_method != "awq":
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or has_zp is None):
return False
if num_bits not in cls.TYPE_MAP:
return False
if version not in cls.VERSION:
return False
return check_mlu_supported(quant_type=cls.TYPE_MAP[num_bits][has_zp],
group_size=group_size,
has_zp=has_zp)
class AWQMluLinearMethod(LinearMethodBase):
"""Linear method for AWQMlu.
Args:
quant_config: The AWQMlu quantization config.
"""
def __init__(self, quant_config: AWQMluConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
weight_loader = extra_weight_attrs.get("weight_loader")
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
qzeros = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
scales = GroupQuantScaleParameter(data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader)
layer.register_parameter("qweight", qweight)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
packed_qweight, scale_zeros = self.extract_autoawq(layer)
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
layer.qzeros = None
layer.scales = None
else:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
if scale_zeros is not None:
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
else:
layer.qzeros = None
layer.scales = torch.nn.Parameter(layer.scales.data.transpose(0, 1).contiguous(), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.quant_config.zero_point and not self.quant_config.support_scale_zeros:
output = mlu_ops.matmul(x, layer.qweight, bias)
if residual is not None:
output = output + residual
else:
output = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
layer.qzeros,
bias,
residual,
"none",
self.quant_config.weight_bits)
return output
def extract_autoawq(self, layer: torch.nn.Module):
qweight = layer.qweight.data
qzeros = layer.qzeros.data
scales = layer.scales.data
bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
# Unpack the qweight and qzeros tensors
iweight, izeros = self.unpack_awq_int32_into_int8(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = self.reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
if izeros is not None:
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
scales = scales.repeat_interleave(group_size, dim=0)
if izeros is not None:
izeros = izeros.repeat_interleave(group_size, dim=0)
fweight = (iweight - izeros) * scales
else:
fweight = iweight * scales
# transpose [ci, co] -> [co, ci]
fweight = fweight.transpose(0, 1)
return fweight, None
if self.quant_config.zero_point and self.quant_config.support_scale_zeros and izeros is not None:
scale_zeros = izeros.to(scales.dtype) * -1 * scales
# transpose [ci, co] -> [co, ci]
scale_zeros = scale_zeros.transpose(0, 1)
else:
scale_zeros = None
# transpose [ci, co] -> [co, ci]
iweight = iweight.to(torch.int8).transpose(0, 1)
if bits == 4:
higher_bit_tensor = iweight[:, 1::2]
lower_bit_tensor = iweight[:, 0::2]
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
else:
packed_qweight = iweight
return packed_qweight, scale_zeros
def unpack_awq_int32_into_int8(self, qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qweight.device)
dtype = torch.int16 if bits == 8 else torch.int8
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(dtype)
iweights = iweights.view(iweights.shape[0], -1)
if not self.quant_config.zero_point or self.quant_config.support_scale_zeros:
iweights = torch.bitwise_and(iweights - 2**(bits - 1), (2 ** bits) - 1)
# unpacking columnwise
if qzeros is not None:
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(dtype)
izeros = izeros.view(izeros.shape[0], -1)
if not self.quant_config.zero_point:
izeros = torch.bitwise_and(izeros - 2**(bits - 1), (2 ** bits) - 1)
else:
izeros = None
return iweights, izeros
def reverse_awq_order(self, iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(iweights.shape[-1], dtype=torch.int32, device=iweights.device)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, self.quant_config.reverse_order_map[bits]]
reverse_order_tensor = reverse_order_tensor.view(-1)
rweights = iweights[:, reverse_order_tensor]
if izeros is not None:
rzeros = izeros[:, reverse_order_tensor]
return rweights, rzeros
def combine_low_bits(self, tensor_a, tensor_b):
"""
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
Args:
tensor_a (torch.Tensor): First tensor of type int8.
tensor_b (torch.Tensor): Second tensor of type int8.
Returns:
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
"""
# 确保输入是 int8 类型
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
raise ValueError("Both tensors must be of int8 type.")
# 提取每个 tensor 的低4位
low_bits_a = torch.bitwise_and(tensor_a, 0x0F) # 保留 tensor_a 的低4位
low_bits_b = torch.bitwise_and(tensor_b, 0x0F) # 保留 tensor_b 的低4位
# 将 tensor_a 的低4位左移4位
shifted_low_bits_a = low_bits_a << 4
# 组合两个 tensor 的低4位
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
return combined

View File

@@ -0,0 +1,441 @@
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm import _mlu_ops as mlu_ops
logger = init_logger(__name__)
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
# We only support gptq and awq over 300 serials and only support int4 and int8 precision
def query_mlu_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if device_capability < 50:
return []
if has_zp:
# AWQ style, unsigned + zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
return [scalar_types.uint4b8, scalar_types.uint8b128]
def check_mlu_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
supported_types = query_mlu_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Mlu does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
return (False, f"Mlu does not support group_size = {group_size}. "
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True
class GPTQMluConfig(QuantizationConfig):
"""Config class for GPTQMlu.
Reference: https://arxiv.org/abs/2210.17323
"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.lm_head_quantized = lm_head_quantized
self.pack_factor = Fraction(32, self.weight_bits)
self.support_scale_zeros = False
if self.weight_bits not in [4, 8]:
raise ValueError(
"Currently, only 4/8-bit weight quantization is "
f"supported for GPTQMlu, but got {self.weight_bits} bits.")
def __repr__(self) -> str:
return (f"GPTQMluConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}),"
f"lm_head_quantized={self.lm_head_quantized}")
@classmethod
def get_name(cls) -> str:
return "gptq_mlu"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 50
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quant_config.json", "quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMluConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQMluLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQMluLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@classmethod
def is_gptq_mlu_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
if quant_method != "gptq":
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
if (num_bits, sym) not in cls.TYPE_MAP:
return False
return check_mlu_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size, has_zp=False)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_gptq_mlu_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "gptq_mlu")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
return None
class GPTQMluLinearMethod(LinearMethodBase):
"""Linear method for GPTQMlu.
Args:
quant_config: The GPTQMlu quantization config.
"""
def __init__(self, quant_config: GPTQMluConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs.get("weight_loader")
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if (input_size != input_size_per_partition) and (self.quant_config.group_size !=
-1) and (not self.quant_config.desc_act):
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
g_idx = RowvLLMParameter(data=torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
qzeros_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scale_and_zero_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.device = layer.qweight.data.device
if self.quant_config.desc_act:
g_idx_list = layer.g_idx.data.tolist()
g_idx_unique = list(dict.fromkeys(g_idx_list))
g_idx = torch.tensor(g_idx_unique, dtype=layer.g_idx.data.dtype, device=self.device)
scales = layer.scales.data[g_idx]
else:
scales = layer.scales.data
packed_qweight, scale_zeros = self.extract_autogptq(layer, scales)
if (not self.quant_config.is_sym) and (not self.quant_config.support_scale_zeros):
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
layer.qzeros = None
layer.scales = None
else:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
if scale_zeros is not None:
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
else:
layer.qzeros = None
layer.scales = torch.nn.Parameter(scales.transpose(0, 1).contiguous(), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
if (not self.quant_config.is_sym) and (not self.quant_config.support_scale_zeros):
output = mlu_ops.matmul(x, layer.qweight, bias)
if residual is not None:
output = output + residual
else:
output = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
layer.qzeros,
bias,
residual,
"none",
self.quant_config.weight_bits)
return output
def extract_autogptq(self, layer: torch.nn.Module, scales: torch.Tensor):
bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
# Unpack the qweight and qzeros tensors
iweight = self.unpack_gptq_qweight_int32_into_int8(layer.qweight.data, bits)
izeros = self.unpack_gptq_qzeros_int32_into_int8(layer.qzeros.data, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
if izeros is not None:
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
if not self.quant_config.is_sym and (not self.quant_config.support_scale_zeros):
scales = scales.repeat_interleave(group_size, dim=0)
if izeros is not None:
izeros = izeros.repeat_interleave(group_size, dim=0)
fweight = (iweight - izeros) * scales
else:
fweight = iweight * scales
# transpose [ci, co] -> [co, ci]
fweight = fweight.transpose(0, 1)
return fweight, None
if not self.quant_config.is_sym and self.quant_config.support_scale_zeros and izeros is not None:
scale_zeros = izeros.to(scales.dtype) * -1 * scales
# transpose [ci, co] -> [co, ci]
scale_zeros = scale_zeros.transpose(0, 1)
else:
scale_zeros = None
# transpose [ci, co] -> [co, ci]
iweight = iweight.to(torch.int8).transpose(0, 1)
if bits == 4:
higher_bit_tensor = iweight[:, 1::2]
lower_bit_tensor = iweight[:, 0::2]
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
else:
packed_qweight = iweight
return packed_qweight, scale_zeros
def unpack_gptq_qweight_int32_into_int8(self, qweight: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qweight.device)
dtype = torch.int16 if bits == 8 else torch.int8
# unpacking columnwise
iweight = torch.bitwise_right_shift(qweight[:, None, :], shifts[None, :, None]).to(dtype)
iweight = iweight.view(-1, iweight.shape[-1])
# minus 2**(bit-1)
if self.quant_config.is_sym or self.quant_config.support_scale_zeros:
iweight = torch.bitwise_and(iweight - 2**(bits - 1), (2 ** bits) - 1)
return iweight
def unpack_gptq_qzeros_int32_into_int8(self, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device)
dtype = torch.int16 if bits == 8 else torch.int8
# unpacking columnwise
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(dtype)
izeros = izeros.view(izeros.shape[0], -1)
izeros = izeros + 1
# minus 2**(bit-1)
if self.quant_config.is_sym:
izeros = torch.bitwise_and(izeros - 2**(bits - 1), (2 ** bits) - 1)
return izeros
def combine_low_bits(self, tensor_a, tensor_b):
"""
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
Args:
tensor_a (torch.Tensor): First tensor of type int8.
tensor_b (torch.Tensor): Second tensor of type int8.
Returns:
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
"""
# 确保输入是 int8 类型
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
raise ValueError("Both tensors must be of int8 type.")
# 提取每个 tensor 的低4位
low_bits_a = torch.bitwise_and(tensor_a, 0x0F) # 保留 tensor_a 的低4位
low_bits_b = torch.bitwise_and(tensor_b, 0x0F) # 保留 tensor_b 的低4位
# 将 tensor_a 的低4位左移4位
shifted_low_bits_a = low_bits_a << 4
# 组合两个 tensor 的低4位
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
return combined

View File

@@ -0,0 +1,192 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm import _mlu_ops as mlu_ops
from vllm.logger import init_logger
logger = init_logger(__name__)
class SmoothQuantConfig(QuantizationConfig):
"""Config class for SmoothQuant.
"""
def __init__(
self,
weight_bits: int,
quant_mode: str, # smoothquant
input_quant_method: str, # per token/per tensor
) -> None:
self.weight_bits = weight_bits
self.quant_mode = quant_mode
self.input_quant_method = input_quant_method
if quant_mode == "SmoothQuant" and (self.weight_bits != 8):
raise ValueError(
"Currently, only 8-bit weight quantization is supported for "
f"SmoothQuant, but got {self.weight_bits} bits.")
if quant_mode == "SmoothQuant" and (self.input_quant_method != "per_token" and self.input_quant_method != "per_tensor"):
raise ValueError(
"Currently, only per_token or per_tensor input quantization is supported for "
f"SmoothQuant, but got {self.input_quant_method}.")
self.pack_factor = 8 // self.weight_bits
def __repr__(self) -> str:
return (f"SmoothQuantConfig(weight_bits={self.weight_bits}, "
f"input_quant_method={self.input_quant_method}, "
f"quant_mode={self.quant_mode})")
def get_name(self) -> str:
return "SmoothQuant"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
def get_min_capability(self) -> int:
return 30
@staticmethod
def get_config_filenames() -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
input_quant_method = cls.get_from_keys(config, ["input_quant_method"])
quant_mode = cls.get_from_keys(config, ["quant_mode"])
return cls(weight_bits, quant_mode, input_quant_method)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["SmoothQuantLinearMethod"]:
if isinstance(layer, LinearBase):
return SmoothQuantLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class SmoothQuantLinearMethod(LinearMethodBase):
"""Linear method for SmoothQuant.
Args:
quant_config: The SmoothQuant quantization config.
"""
def __init__(self, quant_config: SmoothQuantConfig):
self.quant_config = quant_config
# for per-tensor case, we can skip quant input for the first attn|ffn linear
# and fusion this step in layernorm to get better performance
self.skip_quant_input = False
self.compute_dtype = torch.get_default_dtype()
def create_weights(
self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs,
) -> Dict[str, Any]:
output_size_per_partition = sum(output_partition_sizes)
if self.quant_config.quant_mode == "SmoothQuant":
input_dim = None
if input_size != input_size_per_partition:
input_dim = 0
qweight = Parameter(
torch.empty(
output_size_per_partition,
input_size_per_partition // self.quant_config.pack_factor,
device="mlu",
dtype=torch.int8,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 1,
"output_dim": 0,
})
per_channel_scale = Parameter(
torch.empty(
output_size_per_partition,
device="mlu",
dtype=torch.float32,
),
requires_grad=False,
)
set_weight_attrs(per_channel_scale, {
"input_dim": None,
"output_dim": 0,
})
layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("per_channel_scale", per_channel_scale)
set_weight_attrs(per_channel_scale, extra_weight_attrs)
if self.quant_config.input_quant_method == "per_token":
smooth = Parameter(
torch.empty(
input_size_per_partition,
device="mlu",
dtype=torch.float32,
),
requires_grad=False,
)
set_weight_attrs(smooth, {
"input_dim": input_dim,
"output_dim": None,
"ignore_warning": True,
})
layer.register_parameter("smooth", smooth)
set_weight_attrs(smooth, extra_weight_attrs)
if self.quant_config.input_quant_method == "per_tensor":
scale_to_int = Parameter(
torch.empty(
input_size_per_partition,
device="mlu",
dtype=torch.float32,
),
requires_grad=False,
)
set_weight_attrs(scale_to_int, {
"input_dim": input_dim,
"output_dim": None,
"ignore_warning": True,
})
layer.register_parameter("scale_to_int", scale_to_int)
set_weight_attrs(scale_to_int, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.quant_config.input_quant_method == "per_token" and layer.smooth.dtype != torch.float:
layer.smooth = Parameter(layer.smooth.to(torch.float), requires_grad=False)
if self.quant_config.input_quant_method == "per_tensor" and layer.scale_to_int.dtype != torch.float:
layer.scale_to_int = Parameter(layer.scale_to_int.to(torch.float), requires_grad=False)
if layer.per_channel_scale.dtype != torch.float:
layer.per_channel_scale = Parameter(layer.per_channel_scale.to(torch.float), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
quant_input = None
if self.skip_quant_input:
quant_input = x
elif self.quant_config.input_quant_method == "per_token":
quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer.smooth, None)
elif self.quant_config.input_quant_method == "per_tensor":
quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)
else:
raise ValueError(
"Currently, only per_token or per_tensor input quantization is supported for "
f"SmoothQuant, but got {self.input_quant_method}.")
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer.qweight,
layer.per_channel_scale, self.compute_dtype, bias, residual)
return out

View File

@@ -0,0 +1,143 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm import _mlu_ops as mlu_ops
from vllm.logger import init_logger
logger = init_logger(__name__)
class WeightOnlyConfig(QuantizationConfig):
"""Config class for WeightOnly.
"""
def __init__(
self,
weight_bits: int,
quant_mode: str, # weight_only
) -> None:
self.weight_bits = weight_bits
self.quant_mode = quant_mode
if quant_mode == "WeightOnly" and (self.weight_bits != 8 and self.weight_bits != 4):
raise ValueError(
"Currently, only 8/4-bit weight quantization is supported for "
f"weight_only, but got {self.weight_bits} bits.")
self.pack_factor = 8 // self.weight_bits
def __repr__(self) -> str:
return (f"WeightOnlyConfig(weight_bits={self.weight_bits}, "
f"quant_mode={self.quant_mode})")
def get_name(self) -> str:
return "WeightOnly"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
def get_min_capability(self) -> int:
return 30
@staticmethod
def get_config_filenames() -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "WeightOnlyConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
try:
quant_mode = cls.get_from_keys(config, ["quant_mode"])
except Exception:
quant_mode = "WeightOnly"
return cls(weight_bits, quant_mode)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["WeightOnlyLinearMethod"]:
if isinstance(layer, LinearBase):
return WeightOnlyLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class WeightOnlyLinearMethod(LinearMethodBase):
"""Linear method for WeightOnly.
Args:
quant_config: The WeightOnly quantization config.
"""
def __init__(self, quant_config: WeightOnlyConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> Dict[str, Any]:
output_size_per_partition = sum(output_partition_sizes)
if self.quant_config.quant_mode == "WeightOnly":
scale_and_zero_input_dim = None
if output_size != output_size_per_partition:
scale_and_zero_input_dim = 0
qweight = Parameter(
torch.empty(
output_size_per_partition,
input_size_per_partition // self.quant_config.pack_factor,
device="mlu",
dtype=torch.int8,
),
requires_grad=False,
)
set_weight_attrs(qweight, {
"input_dim": 1,
"output_dim": 0,
})
scales = Parameter(
torch.empty(
output_size_per_partition,
device="mlu",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": scale_and_zero_input_dim,
"output_dim": 0,
})
layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if layer.scales.dtype != torch.float:
layer.scales = Parameter(layer.scales.to(torch.float), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
out = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
None,
bias,
residual,
"none",
self.quant_config.weight_bits)
return out

View File

@@ -0,0 +1,647 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import math
import torch
from vllm.attention import AttentionMetadata
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.rotary_embedding import (
RotaryEmbedding, MRotaryEmbedding,
LinearScalingRotaryEmbedding, DeepseekScalingRotaryEmbedding,
DynamicNTKScalingRotaryEmbedding, DynamicNTKAlphaRotaryEmbedding,
YaRNScalingRotaryEmbedding, Phi3LongRoPEScaledRotaryEmbedding,
_yarn_find_correction_range, _ROPE_DICT, yarn_get_mscale, _yarn_linear_ramp_mask)
from vllm.model_executor.layers import rotary_embedding
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.envs import VLLM_ALLOW_LONG_MAX_MODEL_LEN
from vllm.logger import init_logger
logger = init_logger(__name__)
def get_long_max_model_max_position_emb(max_position_embeddings, scaling_factor):
if MLURotaryEmbedding.max_seq_len != None and \
MLURotaryEmbedding.max_seq_len > max_position_embeddings * scaling_factor:
logger.warning(f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) is different with " +
f"max_position_embedding ({max_position_embeddings}) * scaling_factor ({scaling_factor}) " +
"from model's config.json, This may lead to incorrect model outputs or MLU errors. " +
f"Make sure the value is correct and within the model context size. " +
f"Set max_position_embedding={MLURotaryEmbedding.max_seq_len}.")
return math.ceil(MLURotaryEmbedding.max_seq_len / scaling_factor)
return max_position_embeddings
@CustomOp.register("rotary_embedding_mlu")
class MLURotaryEmbedding(RotaryEmbedding, CustomOp):
cu_seq_lens : torch.Tensor = None
max_seq_len : int = None
is_prompt : bool = False
is_chunked : bool = False
set_cos_sin : bool = False
cos_ : torch.Tensor = None
sin_ : torch.Tensor = None
positions_: torch.Tensor = None
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
CustomOp.__init__(self)
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
if MLURotaryEmbedding.max_seq_len != None \
and self.max_position_embeddings < MLURotaryEmbedding.max_seq_len and \
not isinstance(self, (YaRNScalingRotaryEmbedding, DeepseekScalingRotaryEmbedding)):
logger.warning(f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) is different with " +
f"max_position_embedding ({max_position_embeddings}) from model's config.json, " +
f"This may lead to incorrect model outputs or MLU errors. " +
f"Make sure the value is correct and within the model context size. " +
f"Set max_position_embedding={MLURotaryEmbedding.max_seq_len}.")
self.max_position_embeddings = MLURotaryEmbedding.max_seq_len
cache = self._compute_cos_sin_cache()
if isinstance(self, MLULinearScalingRotaryEmbedding):
logger.debug(f"Using mlu defining _compute_cos_sin_cache due to the special tensor composition")
elif is_neox_style:
cache_pos = cache.shape[0]
cache = cache.reshape(cache_pos, 2, -1)
cache = torch.tile(cache, (1, 1, 2)).reshape(cache_pos, -1)
else:
cache = cache.repeat_interleave(2, dim=-1)
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
@classmethod
def set_mlu_var(
cls,
input_ids: torch.Tensor,
attn_metadata: AttentionMetadata
) -> None:
cls.unset_mlu_var()
is_chunked = False
is_prompt = False
prefill_metadata = attn_metadata.prefill_metadata
decode_metadata = attn_metadata.decode_metadata
if prefill_metadata:
cu_seq_lens = prefill_metadata.query_start_loc
rope_max_seq_len = prefill_metadata.max_query_len
is_prompt = True
# Workaround: mlugraph does not support torch.ne|eq|equal .etc for now,
# because context mlugraph always uses in benchmark latency, and in this
# case, query_start_loc always equals to seq_start_loc, so we can set
# is_chunked to False directly.
if prefill_metadata.use_cuda_graph:
is_chunked = False
elif decode_metadata or \
max(prefill_metadata.seq_lens) != prefill_metadata.max_query_len:
is_chunked = True
if decode_metadata:
if prefill_metadata:
cu_seq_lens = attn_metadata.query_start_loc
rope_max_seq_len = max(rope_max_seq_len,
attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens)
else:
# input_ids is pack mode, and the decode_seq_len = 1
cu_seq_lens = torch.arange(0, input_ids.shape[0] + 1, 1, dtype=torch.int32, device="mlu")
rope_max_seq_len = 1
cls.cu_seq_lens = cu_seq_lens
cls.max_seq_len = rope_max_seq_len
cls.is_prompt = is_prompt
cls.is_chunked = is_chunked
@classmethod
def unset_mlu_var(cls):
cls.cu_seq_lens = None
cls.max_seq_len = None
cls.is_prompt = False
cls.is_chunked = False
cls.set_cos_sin = False
cls.cos_ = None
cls.sin_ = None
cls.positions_ = None
def _get_cos_sin(self) -> Tuple[torch.Tensor, torch.Tensor]:
cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
sin = sin.view(-1, self.rotary_dim)
cos = cos.view(-1, self.rotary_dim)
return cos, sin
def _get_positions_with_offsets_mlu(
self,
positions: torch.Tensor,
offsets: torch.Tensor
) -> torch.Tensor:
if offsets.numel() != positions.numel():
raise Exception("rope offsets numel mismatch with positions, "
f"positions: {positions.numel()}, offsets: {offsets.numel()}")
return (positions + offsets).to(torch.int32)
def forward_mlu(
self,
positions: torch.Tensor,
x: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _mlu_ops as mlu_ops
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if MLURotaryEmbedding.set_cos_sin == False:
MLURotaryEmbedding.cos_, MLURotaryEmbedding.sin_ = self._get_cos_sin()
MLURotaryEmbedding.set_cos_sin = True
interleaved = True
if self.is_neox_style:
interleaved = False
if offsets is not None:
if MLURotaryEmbedding.positions_ is None:
MLURotaryEmbedding.positions_ = (
self._get_positions_with_offsets_mlu(positions, offsets))
position_ids = MLURotaryEmbedding.positions_
discrete = True
elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
position_ids = positions
discrete = True
else:
position_ids = None
discrete = False
x = mlu_ops.rotary_embedding(x,
MLURotaryEmbedding.sin_,
MLURotaryEmbedding.cos_,
position_ids,
MLURotaryEmbedding.cu_seq_lens,
interleaved,
discrete,
False,
MLURotaryEmbedding.max_seq_len)
return x
class MLULinearScalingRotaryEmbedding(MLURotaryEmbedding, LinearScalingRotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factors: Union[List[float], float],
dtype: torch.dtype,
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors: List[float] = scaling_factors # noqa
MLURotaryEmbedding.__init__(self, head_size, rotary_dim,
max_position_embeddings, base,
is_neox_style, dtype)
# Lazy initialized.
self._scaling_factor_to_offset: Dict[float, int]
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
half_dim = self.rotary_dim // 2
if self.is_neox_style:
inv_freq = 1.0 / (base ** ((torch.arange(
0, self.rotary_dim, 1, dtype=torch.float32, device="mlu") % half_dim) * 2 / self.rotary_dim)
)
else:
inv_freq = 1.0 / (
base
** ( torch.arange(0, self.rotary_dim, 1, device="mlu", dtype=torch.float32) // 2 * 2
/ self.rotary_dim
)
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
cache_list: List[torch.Tensor] = []
# offsets to the next cache in a tensor.
# Each offset corresponds to the same index in scaling_factors.
offsets: List[int] = []
for scaling_factor in self.scaling_factors:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * scaling_factor
t = torch.arange(max_len, dtype=torch.float, device="mlu")
t = t / scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
if not cache_list:
offset = 0
else:
last_offset = offsets[-1]
next_max_len = cache_list[-1].shape[0]
offset = last_offset + next_max_len
offsets.append(offset)
cache_list.append(cache)
self._scaling_factor_to_offset = {
float(scaling_factor): offsets[i]
for i, scaling_factor in enumerate(self.scaling_factors)
}
assert len(self.scaling_factors) == len(offsets)
return torch.cat(cache_list, dim=0)
class MLUDeepseekScalingRotaryEmbedding(MLURotaryEmbedding, DeepseekScalingRotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
MLURotaryEmbedding.__init__(self, head_size, rotary_dim,
max_position_embeddings, base,
is_neox_style, dtype)
def forward_mlu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _mlu_ops as mlu_ops
"""PyTorch-native implementation equivalent to forward()."""
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
if MLURotaryEmbedding.set_cos_sin == False:
MLURotaryEmbedding.cos_, MLURotaryEmbedding.sin_ = self._get_cos_sin()
MLURotaryEmbedding.set_cos_sin = True
interleaved = True
if self.is_neox_style:
interleaved = False
if MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
position_ids = positions
discrete = True
else :
position_ids = None
discrete = False
query_rot = mlu_ops.rotary_embedding(query_rot,
MLURotaryEmbedding.sin_,
MLURotaryEmbedding.cos_,
position_ids,
MLURotaryEmbedding.cu_seq_lens,
interleaved,
discrete,
False,
MLURotaryEmbedding.max_seq_len)
key_rot = mlu_ops.rotary_embedding(key_rot,
MLURotaryEmbedding.sin_,
MLURotaryEmbedding.cos_,
position_ids,
MLURotaryEmbedding.cu_seq_lens,
interleaved,
discrete,
False,
MLURotaryEmbedding.max_seq_len)
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
return query, key
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: change device cuda to mlu
'''
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="mlu") /
self.rotary_dim)
'''
==================
End of MLU Hijack
==================
'''
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
'''
=============================
Modify by vllm_mlu
=============================
@brief: change device cuda to mlu
'''
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="mlu",
dtype=torch.float32)
'''
==================
End of MLU Hijack
==================
'''
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
cache = torch.cat((cos, sin), dim=-1)
print("Cache shape", cache.shape)
return cache
class MLULlama3RotaryEmbedding(MLURotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
scaling_factor: float,
low_freq_factor: float,
high_freq_factor: float,
orig_max_position: int,
) -> None:
self.scaling_factor = scaling_factor
self.low_freq_factor = low_freq_factor
self.high_freq_factor = high_freq_factor
self.orig_max_position = orig_max_position
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
class MLUDynamicNTKAlphaRotaryEmbedding(MLURotaryEmbedding, DynamicNTKAlphaRotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_alpha: float,
dtype: torch.dtype,
) -> None:
self.scaling_alpha = scaling_alpha
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
class MLUMRotaryEmbedding(MLURotaryEmbedding, MRotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
) -> None:
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
self.mrope_section = mrope_section
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
def forward_mlu(
self,
positions: torch.Tensor,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert positions.ndim == 1 or positions.ndim == 2
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section
num_section = len(self.mrope_section)
mrope_section = self.mrope_section * 2
cos = torch.cat([
m[i % num_section]
for i, m in enumerate(cos.split(mrope_section, dim=-1))
],
dim=-1)
sin = torch.cat([
m[i % num_section]
for i, m in enumerate(sin.split(mrope_section, dim=-1))
],
dim=-1)
from vllm import _mlu_ops as mlu_ops
interleaved = True
if self.is_neox_style:
interleaved = False
position_ids = None
discrete = False
# mlu_ops.rotary_embedding() is a in-place operation that update the query and key tensors.
x = mlu_ops.rotary_embedding(x,
sin,
cos,
position_ids,
MLURotaryEmbedding.cu_seq_lens,
interleaved,
discrete,
False,
MLURotaryEmbedding.max_seq_len)
return x
def vllm__model_executor__layers__rotary_embedding__get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: int,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling_args, dtype)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
rotary_emb = MLURotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
else:
scaling_type = rope_scaling["rope_type"]
if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
rotary_emb = MLULlama3RotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype,
scaling_factor, low_freq_factor,
high_freq_factor,
original_max_position)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MLUMRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else:
rotary_emb = MLURotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"]
rotary_emb = MLULinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
scaling_factor, dtype)
elif scaling_type == "dynamic":
if "alpha" in rope_scaling:
rotary_emb = MLUDynamicNTKAlphaRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling["alpha"], dtype)
else:
scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor, dtype)
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
"beta_slow")
}
original_max_position = get_long_max_model_max_position_emb(original_max_position, scaling_factor)
rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
original_max_position,
base, is_neox_style,
scaling_factor, dtype,
**extra_kwargs)
elif scaling_type == "deepseek_yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
"beta_slow", "mscale", "mscale_all_dim")
}
original_max_position = get_long_max_model_max_position_emb(original_max_position, scaling_factor)
rotary_emb = MLUDeepseekScalingRotaryEmbedding(
head_size, rotary_dim, original_max_position, base,
is_neox_style, scaling_factor, dtype, **extra_kwargs)
elif scaling_type == "longrope":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb
MluHijackObject.apply_hijack(rotary_embedding,
rotary_embedding.get_rope,
vllm__model_executor__layers__rotary_embedding__get_rope)

View File

@@ -0,0 +1,468 @@
"""Inference-only MOE model."""
from typing import Optional
import torch
from torch import nn
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig
from vllm_mlu._mlu_utils import get_device_major_capability
from vllm.platforms import current_platform
class SparseMoeMlp(nn.Module):
"""
Tensor Parallel evenly splits each expert's weight and distributes them to different ranks,
which means each rank holds partial weight of all experts.
While Expert Parallel evenly distributes some of the experts' full weight to different ranks,
which means each rank holds part of the experts' full weight.
As a result, each rank in the Tensor Parallel group receives all tokens' hidden states for all experts,
then computes using the partial weights, while for Expert Parallel, each rank only receives
part of tokens' hidden states for experts on this rank, then computes using the full weights.
When both Tensor Parallel and Expert Parallel are enabled, each rank handles
a portion of the expert weights matrices (as in EP mode) and these weights are further sliced
across ranks (as in TP mode). This hybrid approach aims to balance the workload more evenly across ranks,
enhancing efficiency and reducing the likelihood of bottlenecks associated with EP mode alone.
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
up_proj_name: str,
is_gated: bool,
down_proj_name: str,
has_bias: bool,
skip_bias_add: bool = False,
renormalize:bool = False,
hidden_act: str = "silu",
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
is_use_fused_moe: bool = False,
expert_group: Optional[int] = 1,
topk_group: Optional[int] = 1,
):
super().__init__()
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tensor_model_parallel_group()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.up_proj_name = up_proj_name
self.is_gated = is_gated
self.down_proj_name = down_proj_name
self.has_bias = has_bias
self.renormalize = renormalize
self.hidden_act = hidden_act
self.quant_config = quant_config
self.is_use_fused_moe = is_use_fused_moe
self.expert_group = expert_group
self.topk_group = topk_group
if get_device_major_capability() == 3:
self.is_use_fused_moe = False
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# NOTE: The bias for fc2 is only applied on tp_rank 0. If we added it on all nodes the allreduce() would
# contain multiple copies of the bias. The bias on other node will be ignored, and may be set to nullptr
self.skip_bias_add = True if self.tp_rank > 0 else False
assert self.intermediate_size % self.tp_size == 0, (
f"need intermediate_size:{self.intermediate_size} % tp_size:{self.tp_size} == 0")
self.num_experts_per_rank = self.num_total_experts
self.start_expert_id = 0
self.end_expert_id = self.start_expert_id + self.num_experts_per_rank
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=None)
self.experts = nn.ModuleList([
FeedForward(hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
up_proj_name=self.up_proj_name,
is_gated=self.is_gated,
down_proj_name=self.down_proj_name,
bias=self.has_bias,
quant_config=self.quant_config,
skip_bias_add=self.skip_bias_add,
reduce_results=False) for idx in range(self.num_experts_per_rank)
])
self.init_pack_param()
def init_pack_param(self):
self.w13 = None
self.w2 = None
self.b13 = None
self.b2 = None
self.w13_scale = None
self.w2_scale = None
self.a13_scale = None
self.a2_scale = None
self.pack_params_done = False
def map_param_data(self, param_list, is_use_first_data=False):
if len(param_list) == 0:
return None
if is_use_first_data or len(param_list) == 1:
first_data = param_list[0].data
for param in param_list[1: -1]:
param.data = first_data
out_param = first_data.view_as(param_list[0])
else:
packed_param = torch._utils._flatten_dense_tensors(param_list)
data_list = torch._utils._unflatten_dense_tensors(packed_param, param_list)
for data, param in zip(data_list, param_list):
param.data = data
out_param = packed_param.view(len(param_list), *data_list[0].shape)
torch.mlu.empty_cache()
return out_param
def pack_unquantized_params(self, w13, w2, b13, b2):
for expert in self.experts:
up_proj = getattr(expert, self.up_proj_name)
down_proj = getattr(expert, self.down_proj_name)
w13.append(up_proj.weight)
w2.append(down_proj.weight)
if self.has_bias:
b13.append(up_proj.bias)
b2.append(down_proj.bias)
def pack_smoothquant_params(self, w13, w2, b13, b2, w13_scale, w2_scale, a13_scale, a2_scale):
for expert in self.experts:
up_proj = getattr(expert, self.up_proj_name)
down_proj = getattr(expert, self.down_proj_name)
w13.append(up_proj.qweight)
w2.append(down_proj.qweight)
if self.has_bias:
b13.append(up_proj.bias)
b2.append(down_proj.bias)
w13_scale.append(up_proj.per_channel_scale)
w2_scale.append(down_proj.per_channel_scale)
if self.quant_config.input_quant_method == "per_token":
a13_scale.append(up_proj.smooth)
a2_scale.append(down_proj.smooth)
else:
a13_scale.append(up_proj.scale_to_int)
a2_scale.append(down_proj.scale_to_int)
def pack_weightonly_params(self, w13, w2, b13, b2, w13_scale, w2_scale):
for expert in self.experts:
up_proj = getattr(expert, self.up_proj_name)
down_proj = getattr(expert, self.down_proj_name)
w13.append(up_proj.qweight)
w2.append(down_proj.qweight)
if self.has_bias:
b13.append(up_proj.bias)
b2.append(down_proj.bias)
w13_scale.append(up_proj.per_channel_scale)
w2_scale.append(up_proj.per_channel_scale)
def pack_params(self):
if self.pack_params_done:
return
w13 = []
w2 = []
b13 = []
b2 = []
w13_scale = []
w2_scale = []
a13_scale = []
a2_scale = []
if self.quant_config is None:
self.pack_unquantized_params(w13, w2, b13, b2)
elif isinstance(self.quant_config, SmoothQuantConfig):
self.pack_smoothquant_params(w13, w2, b13, b2, w13_scale, w2_scale, a13_scale, a2_scale)
elif isinstance(self.quant_config, WeightOnlyConfig):
self.pack_weightonly_params(w13, w2, b13, b2, w13_scale, w2_scale)
else:
raise ValueError(f'Unsupported quantization:{self.quant_config}')
# pack weigth
self.w13 = self.map_param_data(w13)
self.w2 = self.map_param_data(w2)
# pack bias
if self.has_bias:
self.b13 = self.map_param_data(b13)
# NOTE: The bias for fc2 is only applied on tp_rank 0. If we added it on all nodes the allreduce() would
# contain multiple copies of the bias. The bias on other node will be ignored, and may be set to nullptr
if self.skip_bias_add is False:
self.b2 = self.map_param_data(b2)
# pack weight scale
if len(w13_scale) > 0:
self.w13_scale = self.map_param_data(w13_scale)
if len(w2_scale) > 0:
self.w2_scale = self.map_param_data(w2_scale)
# pack activate scale
if len(a13_scale) > 0:
self.a13_scale = self.map_param_data(a13_scale)
if len(a2_scale) > 0:
self.a2_scale = self.map_param_data(a2_scale)
self.pack_params_done = True
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
orig_hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# expert_logits: [num_tokens, self.num_experts_per_rank]
expert_logits, _ = self.gate(hidden_states)
final_hidden_states = self.forward_experts(hidden_states, expert_logits, residual)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
output = final_hidden_states.view(orig_hidden_states_shape)
return output
def forward_experts(self, hidden_states, expert_logits, residual: Optional[torch.Tensor] = None):
residual_ = None if self.tp_rank > 0 else residual
if self.is_use_fused_moe and self.expert_group != 1:
final_hidden_states = self.forward_group_experts(hidden_states, expert_logits, residual_)
elif self.is_use_fused_moe:
self.pack_params()
final_hidden_states = mlu_ops.fused_moe(hidden_states=hidden_states,
gating_output=expert_logits,
w1=self.w13,
w2=self.w2,
bias1=self.b13,
bias2=self.b2,
residual=residual_,
input_smooth=self.a13_scale,
act_smooth=self.a2_scale,
w1_scale=self.w13_scale,
w2_scale=self.w2_scale,
topk=self.top_k,
renormalize=self.renormalize,
gated=self.is_gated,
act_mode=self.hidden_act,
start_expert_id=self.start_expert_id)
else:
final_hidden_states = self.forward_experts_nofused(hidden_states, expert_logits)
if residual_ is not None:
final_hidden_states = final_hidden_states + residual_
return final_hidden_states
def forward_experts_nofused(self, hidden_states, expert_logits):
hidden_states_shape = hidden_states.shape
topk_values, topk_indices = self.topk_softmax(expert_logits)
expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = self.generate_gather_idx(
topk_indices)
# no expert is routed, then expand_gather_idx, expand_scatter_idx has no item,
# expand_token_count and expand_cusum_token_count has item but the value is all zero
# so this rank should only return final_hidden_states with zero value
if expand_gather_idx.numel() == 0:
final_hidden_states = torch.zeros_like(hidden_states,
dtype=hidden_states.dtype,
device=hidden_states.device)
return final_hidden_states
expand_hidden_states = self.expand_input(hidden_states, expand_gather_idx)
expand_output_list = []
expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id +
1] - cusum_token_count[self.start_expert_id]
for expert_idx, num_tokens_per_expert in enumerate(expand_token_count):
if num_tokens_per_expert > 0:
expert_hidden_states = expand_hidden_states[
expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]]
expert_output = self.experts[expert_idx](expert_hidden_states)
expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output
expand_output_list.append(expert_output)
expand_output = torch.cat(expand_output_list, dim=0)
final_hidden_states = self.combine_moe(expand_output, scatter_idx, cusum_token_count, hidden_states_shape,
topk_values)
return final_hidden_states
def forward_group_experts(self, hidden_states, expert_logits, residual_):
ori_input_shape = hidden_states.shape
dtype = hidden_states.dtype
self.pack_params()
gating_output=expert_logits.to(torch.float32)
w1=self.w13
w2=self.w2
bias1=self.b13
bias2=self.b2
input_smooth=self.a13_scale
act_smooth=self.a2_scale
w1_scale=self.w13_scale
w2_scale=self.w2_scale
topk=self.top_k
renormalized=self.renormalize
gated=self.is_gated
act_mode=self.hidden_act
start_expert_id=self.start_expert_id
expert_num = gating_output.size(-1)
expert_size = w1.size(0)
max_m = hidden_states.shape[0]
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
gating_output = gating_output.view(-1, gating_output.size(-1))
residual_ = residual_.view(-1, residual_.size(-1)) if residual_ is not None else None
per_token_sq = False
# check quant
check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
if all(x is not None for x in check_list):
per_token_sq = True
if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present "
"and absent at the same time.")
# softmax_topk
reduce_weight, expert_id = mlu_ops.moe_softmax_topk(gating_output,
topk, renormalized)
# gen_idx
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(expert_id, expert_num)
# check quant
if per_token_sq:
major, minor = current_platform.get_device_capability()
if major == 3:
expand_hidden_states = mlu_ops.moe_expand_input(hidden_states, expand_idx,
cusum_token_count, start_expert_id, expert_size)
quant_input, input_scale = mlu_ops.moe_quantize(expand_hidden_states,
input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size])
else:
quant_input, input_scale = mlu_ops.moe_quantize(hidden_states,
input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx,
cusum_token_count[start_expert_id].unsqueeze(0))
else:
expand_hidden_states = mlu_ops.moe_expand_input(hidden_states, expand_idx,
cusum_token_count, start_expert_id, expert_size)
if per_token_sq:
gemm1_out = mlu_ops.smooth_quant_group_gemm(quant_input, w1,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None,
input_scale, w1_scale, dtype, max_m)
else:
gemm1_out = mlu_ops.group_gemm(expand_hidden_states, w1,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, max_m)
# add_bias_active
act_out = mlu_ops.moe_active(gemm1_out, act_mode, gated, None, bias1, cusum_token_count, start_expert_id, expert_size)
if per_token_sq:
quant_input, input_scale = mlu_ops.moe_quantize(act_out, act_smooth, None,
token_count[start_expert_id:start_expert_id+expert_size])
if per_token_sq:
gemm2_out = mlu_ops.smooth_quant_group_gemm(quant_input, w2,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, input_scale, w2_scale, dtype, max_m)
else:
gemm2_out = mlu_ops.group_gemm(act_out, w2,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, max_m)
output = mlu_ops.moe_combine_result(gemm2_out, reduce_weight, combine_idx,
residual_, cusum_token_count, start_expert_id,
expert_size, bias2)
return output.view(ori_input_shape)
def topk_softmax(self, expert_logits):
# expert_logits: [num_tokens, self.num_experts_per_rank]
# topk_values: [num_tokens, self.top_k]
# topk_indices: [num_tokens, self.top_k]
if self.renormalize:
topk_values, topk_indices = torch.topk(expert_logits, self.top_k, dim=-1)
topk_values = torch.softmax(topk_values, -1)
else:
router_probs = torch.softmax(expert_logits, -1)
topk_values, topk_indices = torch.topk(router_probs, self.top_k, dim=-1)
return topk_values, topk_indices
def generate_gather_idx(self, topk_indices):
device = topk_indices.device
# gather_expand_idx: [num_tokens * self.top_k]
sorted_expert_id, indices = topk_indices.flatten().sort()
gather_idx = indices // self.top_k
seqs = torch.arange(indices.numel(), dtype=indices.dtype, device=indices.device)
scatter_idx=torch.zeros((indices.numel(),), dtype=seqs.dtype, device=seqs.device).scatter(0, indices, seqs)
# token_count: [self.num_experts_per_rank]
partial_token_index, partial_token_count = sorted_expert_id.unique(sorted=True, return_counts=True)
zero_token_count = torch.zeros(self.num_total_experts, dtype=partial_token_count.dtype, device=device)
token_count = zero_token_count.scatter(dim=0, index=partial_token_index, src=partial_token_count)
# cusum_token_count: [self.num_experts_per_rank + 1]
cusum_token_count = torch.cat(
[torch.tensor([0], dtype=token_count.dtype, device=device),
token_count.cumsum(dim=0)])
num_tokens_before_expert = cusum_token_count[self.start_expert_id]
num_tokens_including_expert = cusum_token_count[self.end_expert_id]
expand_gather_idx = gather_idx[num_tokens_before_expert:num_tokens_including_expert]
expand_token_count = token_count[self.start_expert_id:self.end_expert_id]
return expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count
def expand_input(self, hidden_states, expand_gather_idx):
expand_hidden_states = hidden_states[expand_gather_idx]
return expand_hidden_states
def combine_moe(self, expand_output, scatter_idx, cusum_token_count, hidden_states_shape, topk_values):
num_tokens, hidden_size = hidden_states_shape
num_tokens_before_expert = cusum_token_count[self.start_expert_id]
num_tokens_after_expert = cusum_token_count[-1] - cusum_token_count[self.end_expert_id]
expand_output_before_expert = torch.zeros((num_tokens_before_expert, hidden_size),
dtype=expand_output.dtype,
device=expand_output.device)
expand_output_after_expert = torch.zeros((num_tokens_after_expert, hidden_size),
dtype=expand_output.dtype,
device=expand_output.device)
unscatted_output = torch.cat([expand_output_before_expert, expand_output, expand_output_after_expert], dim=0)
scatter_output = unscatted_output[scatter_idx]
hidden_states_weight = topk_values.flatten().unsqueeze(-1)
weighted_hidden_states = scatter_output * hidden_states_weight
unreduced_hidden_states = weighted_hidden_states.view(num_tokens, self.top_k, hidden_size)
final_hidden_states = unreduced_hidden_states.sum(dim=1)
return final_hidden_states

View File

@@ -0,0 +1,42 @@
from typing import Union
import torch
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler
from vllm.platforms import current_platform
def vllm__model_executor__layers__spec_decode_base_sampler__SpecDecodeBaseSampler__init_gpu_tensors(
self, device: Union[int, str]
) -> None:
assert self.num_accepted_tokens is None
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add mlu device support.
'''
if isinstance(device, int) and current_platform.is_mlu():
device = f"mlu:{device}"
elif isinstance(device, int) and current_platform.is_cuda():
device = f"cuda:{device}"
elif not isinstance(device, str):
raise ValueError(f"Device must be int or str, get {type(device)}")
'''
==================
End of MLU Hijack
==================
'''
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
self.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
MluHijackObject.apply_hijack(
SpecDecodeBaseSampler,
SpecDecodeBaseSampler.init_gpu_tensors,
vllm__model_executor__layers__spec_decode_base_sampler__SpecDecodeBaseSampler__init_gpu_tensors
)

View File

@@ -0,0 +1,2 @@
import vllm_mlu.model_executor.model_loader.loader
import vllm_mlu.model_executor.model_loader.tensorizer

View File

@@ -0,0 +1,138 @@
import torch
from tqdm.auto import tqdm
from safetensors.torch import safe_open
from typing import List, Tuple, Generator
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import (np_cache_weights_iterator,
_BAR_FORMAT)
from vllm.config import LoadFormat
from vllm.platforms import current_platform
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu._mlu_utils import get_device_major_capability
CAST_BFLOAT16_TO_FLOAT16_ENABLE = (get_device_major_capability() == 3)
def vllm__model_executor__model_loader__weight_utils__safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
'''
=============================
Modify by vllm_mlu
=============================
@brief: cast bfloat16 to float16 for MLU3xx
'''
if CAST_BFLOAT16_TO_FLOAT16_ENABLE and param.dtype == torch.bfloat16:
param = param.to(torch.float16)
'''
==================
End of MLU Hijack
==================
'''
yield name, param
def vllm__model_executor__model_loader__weight_utils__pt_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
for bin_file in tqdm(
hf_weights_files,
desc="Loading pt checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
'''
=============================
Modify by vllm_mlu
=============================
@brief: cast bfloat16 to float16 for MLU3xx
'''
if CAST_BFLOAT16_TO_FLOAT16_ENABLE and param.dtype == torch.bfloat16:
param = param.to(torch.float16)
'''
==================
End of MLU Hijack
==================
'''
yield name, param
del state
torch.mlu.empty_cache()
def vllm__model_executor__model_loader__loader__DefaultModelLoader___get_weights_iterator(
self, source: "DefaultModelLoader.Source"
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.revision, source.fall_back_to_pt)
if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
weights_iterator = np_cache_weights_iterator(
source.model_or_path, self.load_config.download_dir, hf_folder,
hf_weights_files)
elif use_safetensors:
'''
=============================
Modify by vllm_mlu
=============================
@brief: cast bfloat16 to float16 for MLU3xx
'''
weights_iterator = vllm__model_executor__model_loader__weight_utils__safetensors_weights_iterator(hf_weights_files)
'''
==================
End of MLU Hijack
==================
'''
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: cast bfloat16 to float16 for MLU3xx
'''
weights_iterator = vllm__model_executor__model_loader__weight_utils__pt_weights_iterator(hf_weights_files)
'''
==================
End of MLU Hijack
==================
'''
if current_platform.is_tpu():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
import torch_xla.core.xla_model as xm
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
xm.mark_step()
weights_iterator = _xla_weights_iterator(weights_iterator)
# Apply the prefix.
return ((source.prefix + name, tensor)
for (name, tensor) in weights_iterator)
MluHijackObject.apply_hijack(DefaultModelLoader,
DefaultModelLoader._get_weights_iterator,
vllm__model_executor__model_loader__loader__DefaultModelLoader___get_weights_iterator)

View File

@@ -0,0 +1,68 @@
import time
import torch
from vllm.model_executor.model_loader.tensorizer import (TensorizerAgent,
TensorDeserializer,
get_mem_usage,
_read_stream,
convert_bytes)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
logger = init_logger(__name__)
def vllm__model_executor__model_loader__tensorizer__TensorizerAgent__deserialize(self):
"""
Deserialize the model using the TensorDeserializer. This method is
specifically for vLLM models using tensorizer's plaid_mode.
The deserializer makes use of tensorizer_args.stream_params
to configure the behavior of the stream when loading tensors from a
serialized model. The deserializer_params are used to configure the
behavior of the TensorDeserializer when loading tensors themselves.
Documentation on these params can be found in TensorizerArgs
Returns:
nn.Module: The deserialized model.
"""
before_mem = get_mem_usage()
start = time.perf_counter()
'''
=============================
Modify by vllm_mlu
=============================
@brief: use mlu device
'''
with _read_stream(
self.tensorizer_config.tensorizer_uri,
**self.tensorizer_args.stream_params
) as stream, TensorDeserializer(
stream,
dtype=self.tensorizer_config.dtype,
device=f'mlu:{torch.mlu.current_device()}',
**self.tensorizer_args.deserializer_params) as deserializer:
deserializer.load_into_module(self.model)
end = time.perf_counter()
'''
==================
End of MLU Hijack
==================
'''
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
deserializer.close()
logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
end - start, per_second)
logger.info("Memory usage before: %s", before_mem)
logger.info("Memory usage after: %s", after_mem)
self._check_tensors_on_meta_device()
self._resize_lora_embeddings()
del self.model.vllm_tensorized_marker
return self.model.eval()
MluHijackObject.apply_hijack(TensorizerAgent,
TensorizerAgent.deserialize,
vllm__model_executor__model_loader__tensorizer__TensorizerAgent__deserialize)

View File

@@ -0,0 +1,41 @@
import vllm_mlu.model_executor.models.deepseek_v2
import vllm_mlu.model_executor.models.baichuan
import vllm_mlu.model_executor.models.bloom
import vllm_mlu.model_executor.models.chatglm
# Multimodal models - may fail with older transformers versions
try:
import vllm_mlu.model_executor.models.clip
except ImportError as e:
import logging
logging.warning(f"Failed to import clip hijack: {e}")
import vllm_mlu.model_executor.models.gpt_neox
import vllm_mlu.model_executor.models.llama
import vllm_mlu.model_executor.models.mixtral
import vllm_mlu.model_executor.models.qwen
import vllm_mlu.model_executor.models.qwen2
import vllm_mlu.model_executor.models.qwen2_moe
try:
import vllm_mlu.model_executor.models.qwen2_vl
except ImportError as e:
import logging
logging.warning(f"Failed to import qwen2_vl hijack: {e}")
try:
import vllm_mlu.model_executor.models.qwen3
except ImportError as e:
import logging
logging.warning(f"Failed to import qwen3 hijack: {e}")
import vllm_mlu.model_executor.models.falcon
import vllm_mlu.model_executor.models.internlm2
import vllm_mlu.model_executor.models.hunyuan
import vllm_mlu.model_executor.models.layer_utils
try:
import vllm_mlu.model_executor.models.mllama
except ImportError as e:
import logging
logging.warning(f"Failed to import mllama hijack: {e}")

View File

@@ -0,0 +1,309 @@
import torch
from typing import List, Optional, Union
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.sequence import IntermediateTensors
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.baichuan import (
_get_alibi_slopes, BaiChuanAttention,
BaiChuanDecoderLayer, BaiChuanModel)
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base, decoder_model_forward_base_pp,
is_per_tensor_smoothquant, is_per_token_smoothquant,
quant_fusion_with_rmsnorm)
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
logger = init_logger(__name__)
def vllm__module_executor__models__baichuan__BaiChuanAttention__init__(
self,
hidden_size: int,
num_heads: int,
position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super(BaiChuanAttention, self).__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.postion_embedding = position_embedding
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# pylint: disable=invalid-name
self.W_pack = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5
'''
=============================
Modify by vllm_mlu
=============================
@brief: add cache_config to support kv8
'''
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)
'''
==================
End of MLU Hijack
==================
'''
else:
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config,
quant_config=quant_config)
def vllm__module_executor__models__baichuan__BaiChuanAttention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states, smooth_quant_scale)
q, k, v = qkv.chunk(chunks=3, dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, _ = qkv.split([self.num_heads * self.head_dim * 2, self.num_heads * self.head_dim], dim=-1)
if self.postion_embedding != "ALIBI":
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
def vllm__module_executor__models__baichuan__BaiChuanDecoderLayer__init__(
self,
config: PretrainedConfig,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None
):
super(BaiChuanDecoderLayer, self).__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = BaiChuanAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
position_embedding=position_embedding,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: use FeedForward instead of MLP
'''
self.mlp = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act='silu',
up_proj_name='gate_up_proj',
is_gated=True,
down_proj_name='down_proj',
bias=False,
quant_config=quant_config)
'''
==================
End of MLU Hijack
==================
'''
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
'''
=============================
Modify by vllm_mlu
=============================
@brief: prepare to perf per-tensor sq cases if suitable
'''
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
self.self_attn.W_pack.quant_method.skip_quant_input = True
self.mlp.gate_up_proj.quant_method.skip_quant_input = True
self.quant_fusion_attn_layernorm = None
self.quant_fusion_mlp_layernorm = None
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__baichuan__BaiChuanDecoderLayer__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf model by:
1) add residual in matmul;
2) fuse quantization in layernorm in per-tensor sq case;
'''
attn_layernorm = self.input_layernorm
mlp_layernorm = self.post_attention_layernorm
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
if self.is_per_token_sq_perf_cases:
attn_quant_scale = self.self_attn.W_pack.smooth
mlp_quant_scale = self.mlp.gate_up_proj.smooth
else:
attn_quant_scale = self.self_attn.W_pack.scale_to_int
mlp_quant_scale = self.mlp.gate_up_proj.scale_to_int
if self.quant_fusion_attn_layernorm is None:
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.input_layernorm, attn_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
self.post_attention_layernorm, mlp_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
attn_layernorm = self.quant_fusion_attn_layernorm
mlp_layernorm = self.quant_fusion_mlp_layernorm
return decoder_layer_forward_base(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
input_layernorm=attn_layernorm,
self_attn=self.self_attn,
post_layernorm=mlp_layernorm,
mlp=self.mlp,
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__baichuan__BaiChuanModel__forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
'''
=============================
Modify by vllm_mlu
=============================
'''
return decoder_model_forward_base_pp(input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
layers=self.layers,
start_layer=self.start_layer,
end_layer=self.end_layer,
get_input_embeddings=self.embed_tokens,
norm=self.norm)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(BaiChuanAttention,
BaiChuanAttention.__init__,
vllm__module_executor__models__baichuan__BaiChuanAttention__init__)
MluHijackObject.apply_hijack(BaiChuanAttention,
BaiChuanAttention.forward,
vllm__module_executor__models__baichuan__BaiChuanAttention__forward)
MluHijackObject.apply_hijack(BaiChuanDecoderLayer,
BaiChuanDecoderLayer.__init__,
vllm__module_executor__models__baichuan__BaiChuanDecoderLayer__init__)
MluHijackObject.apply_hijack(BaiChuanDecoderLayer,
BaiChuanDecoderLayer.forward,
vllm__module_executor__models__baichuan__BaiChuanDecoderLayer__forward)
MluHijackObject.apply_hijack(BaiChuanModel,
BaiChuanModel.forward,
vllm__module_executor__models__baichuan__BaiChuanModel__forward)

View File

@@ -0,0 +1,170 @@
from typing import Optional
import torch
from torch import nn
from transformers import BloomConfig
from vllm.config import CacheConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.models.bloom import BloomAttention, BloomBlock
from vllm.attention import AttentionMetadata
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base,
is_per_tensor_smoothquant,
is_per_token_smoothquant,
quant_fusion_with_layernorm
)
logger = init_logger(__name__)
def vllm__module_executor__models__bloom__BloomAttention__forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
del position_ids # Unused.
qkv, _ = self.query_key_value(hidden_states, smooth_quant_scale)
q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
'''
output, _ = self.dense(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
def vllm__module_executor__models__bloom__BloomBlock__init__(
self,
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super(BloomBlock, self).__init__()
hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, cache_config,
quant_config)
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
'''
=============================
Modify by vllm_mlu
=============================
@brief: use FeedForward instead of MLP
'''
self.mlp = FeedForward(hidden_size=hidden_size,
intermediate_size=hidden_size * 4,
hidden_act='gelu',
up_proj_name="dense_h_to_4h",
is_gated=False,
down_proj_name="dense_4h_to_h",
bias=True,
quant_config=quant_config)
'''
==================
End of MLU Hijack
==================
'''
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
'''
=============================
Modify by vllm_mlu
=============================
@brief: prepare to perf sq cases if suitable
'''
self.is_per_tesnor_sq_perf_cases = (is_per_tensor_smoothquant(quant_config) and
not self.apply_residual_connection_post_layernorm)
self.is_per_token_sq_perf_cases = (is_per_token_smoothquant(quant_config) and
not self.apply_residual_connection_post_layernorm)
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
self.self_attention.query_key_value.quant_method.skip_quant_input = True
self.mlp.dense_h_to_4h.quant_method.skip_quant_input = True
self.quant_fusion_attn_layernorm = None
self.quant_fusion_mlp_layernorm = None
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__bloom__BloomBlock__forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf model by:
1) add residual in matmul;
2) fuse quantization in layernorm in per-tensor sq case;
'''
attn_layernorm = self.input_layernorm
mlp_layernorm = self.post_attention_layernorm
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
if self.is_per_token_sq_perf_cases:
attn_quant_scale = self.self_attention.query_key_value.smooth
mlp_quant_scale = self.mlp.dense_h_to_4h.smooth
else:
attn_quant_scale = self.self_attention.query_key_value.scale_to_int
mlp_quant_scale = self.mlp.dense_h_to_4h.scale_to_int
self.quant_fusion_attn_layernorm = quant_fusion_with_layernorm(
self.input_layernorm, attn_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
self.quant_fusion_mlp_layernorm = quant_fusion_with_layernorm(
self.post_attention_layernorm, mlp_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
attn_layernorm = self.quant_fusion_attn_layernorm
mlp_layernorm = self.quant_fusion_mlp_layernorm
return decoder_layer_forward_base(positions=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
input_layernorm=attn_layernorm,
self_attn=self.self_attention,
post_layernorm=mlp_layernorm,
mlp=self.mlp,
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
position_name='position_ids',
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(BloomAttention,
BloomAttention.forward,
vllm__module_executor__models__bloom__BloomAttention__forward)
MluHijackObject.apply_hijack(BloomBlock,
BloomBlock.__init__,
vllm__module_executor__models__bloom__BloomBlock__init__)
MluHijackObject.apply_hijack(BloomBlock,
BloomBlock.forward,
vllm__module_executor__models__bloom__BloomBlock__forward)

View File

@@ -0,0 +1,195 @@
import torch
from torch.nn import LayerNorm
from typing import Optional
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
from vllm.model_executor.models.chatglm import GLMAttention, GLMBlock
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base,
is_per_tensor_smoothquant,
is_per_token_smoothquant,
quant_fusion_with_layernorm,
quant_fusion_with_rmsnorm
)
logger = init_logger(__name__)
def vllm__module_executor__models__chatglm__GLMAttention__forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states, smooth_quant_scale)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(position_ids, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
context_layer = self.attn(
q,
k,
v,
kv_cache,
attn_metadata,
)
'''
=============================
Modify by vllm_mlu
=============================
'''
attn_output, _ = self.dense(context_layer, residual)
'''
==================
End of MLU Hijack
==================
'''
return attn_output
def vllm__module_executor__models__chatglm__GLMBlock__init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super(GLMBlock, self).__init__()
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
self.fp32_residual_connection = config.fp32_residual_connection
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data.
self.input_layernorm = layer_norm_func(config.hidden_size,
eps=config.layernorm_epsilon)
# Self attention.
self.self_attention = GLMAttention(config, cache_config, quant_config)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
self.post_attention_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon)
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1) use FeedForward instead of MLP
2) prepare to perf per-tensor sq cases if suitable
'''
# MLP
self.mlp = FeedForward(
hidden_size=config.hidden_size,
intermediate_size=config.ffn_hidden_size,
hidden_act='silu',
up_proj_name='dense_h_to_4h',
is_gated=True,
down_proj_name='dense_4h_to_h',
bias=config.add_bias_linear,
quant_config=quant_config
)
self.is_per_tesnor_sq_perf_cases = (is_per_tensor_smoothquant(quant_config) and
not self.apply_residual_connection_post_layernorm)
self.is_per_token_sq_perf_cases = (is_per_token_smoothquant(quant_config) and
not self.apply_residual_connection_post_layernorm)
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
self.self_attention.query_key_value.quant_method.skip_quant_input = True
self.mlp.dense_h_to_4h.quant_method.skip_quant_input = True
self.use_rmsnorm = config.rmsnorm
self.quant_fusion_attn_layernorm = None
self.quant_fusion_mlp_layernorm = None
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__chatglm__GLMBlock__forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf model by:
1) add residual in matmul;
2) fuse quantization in layernorm in per-tensor sq case;
'''
attn_layernorm = self.input_layernorm
mlp_layernorm = self.post_attention_layernorm
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
quant_fusion_func = (quant_fusion_with_rmsnorm if
self.use_rmsnorm else quant_fusion_with_layernorm)
if self.is_per_token_sq_perf_cases:
attn_quant_scale = self.self_attention.query_key_value.smooth
mlp_quant_scale = self.mlp.dense_h_to_4h.smooth
else:
attn_quant_scale = self.self_attention.query_key_value.scale_to_int
mlp_quant_scale = self.mlp.dense_h_to_4h.scale_to_int
self.quant_fusion_attn_layernorm = quant_fusion_func(
self.input_layernorm, attn_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
self.quant_fusion_mlp_layernorm = quant_fusion_func(
self.post_attention_layernorm, mlp_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
attn_layernorm = self.quant_fusion_attn_layernorm
mlp_layernorm = self.quant_fusion_mlp_layernorm
return decoder_layer_forward_base(positions=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
input_layernorm=attn_layernorm,
self_attn=self.self_attention,
post_layernorm=mlp_layernorm,
mlp=self.mlp,
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
position_name='position_ids',
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(GLMAttention,
GLMAttention.forward,
vllm__module_executor__models__chatglm__GLMAttention__forward)
MluHijackObject.apply_hijack(GLMBlock,
GLMBlock.__init__,
vllm__module_executor__models__chatglm__GLMBlock__init__)
MluHijackObject.apply_hijack(GLMBlock,
GLMBlock.forward,
vllm__module_executor__models__chatglm__GLMBlock__forward)

View File

@@ -0,0 +1,370 @@
from typing import Optional
import torch
import torch.nn as nn
from transformers import CLIPVisionConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.clip import (CLIPVisionModel,
CLIPVisionTransformer,
CLIPEncoderLayer,
CLIPParallelAttention)
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.logger import init_logger
logger = init_logger(__name__)
class MLUCLIPAttention(nn.Module):
"""
MLU-compatible CLIP attention implementation.
Used as fallback when num_heads % tp_size != 0.
This implementation does not use tensor parallelism.
"""
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
)
self.scale = self.head_dim ** -0.5
self.dropout = getattr(config, 'attention_dropout', 0.0)
# Use non-parallel linear layers since this is fallback for non-divisible cases
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
bias=True,
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
bias=True,
)
def forward(
self,
hidden_states: torch.Tensor,
):
"""
Input shape: Batch x Time x Channel
Compatible with CLIPSdpaAttention interface
"""
bsz, tgt_len, _ = hidden_states.size()
# Project to Q, K, V
qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
# Reshape for attention computation
query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
value_states = value_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
# Use MLU flash attention for inference
if self.dropout == 0.0:
try:
from vllm import _mlu_ops as mlu_ops
out = mlu_ops.flash_attention(
query_states,
key_states,
value_states,
out=None,
cu_seq_lens_q=None,
cu_seq_lens_kv=None,
alibi_slope=None,
attn_bias=None,
max_seq_len_q=tgt_len,
max_seq_len_kv=tgt_len,
softmax_scale=self.scale,
is_causal=False
)
except (ImportError, AttributeError):
# Fallback to standard PyTorch attention if MLU ops not available
logger.warning("MLU ops not available, using standard PyTorch attention")
out = self._pytorch_attention(query_states, key_states, value_states)
else:
# Use xformers if dropout is needed (training mode)
try:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(
query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale
)
except ImportError:
logger.warning("xformers not available, using standard PyTorch attention")
out = self._pytorch_attention(query_states, key_states, value_states)
# Reshape output
out = out.view(bsz, tgt_len, -1)
# Output projection
attn_output, _ = self.out_proj(out)
return attn_output, None
def _pytorch_attention(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
) -> torch.Tensor:
"""
Standard PyTorch scaled dot-product attention as fallback.
Input shape: [batch, seq_len, num_heads, head_dim]
"""
# Transpose to [batch, num_heads, seq_len, head_dim]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# Compute attention scores
attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scale
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
# Apply attention to values
attn_output = torch.matmul(attn_weights, value_states)
# Transpose back to [batch, seq_len, num_heads, head_dim]
attn_output = attn_output.transpose(1, 2)
return attn_output
def vllm__module_executor__models__clip__CLIPParallelAttention__forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf attn using tmo flash attn
'''
if self.dropout is None or self.dropout == 0.0:
# Always true for inference
from vllm import _mlu_ops as mlu_ops
out = mlu_ops.flash_attention(query_states,
key_states,
value_states,
out=None,
cu_seq_lens_q=None,
cu_seq_lens_kv=None,
alibi_slope=None,
attn_bias=None,
max_seq_len_q=tgt_len,
max_seq_len_kv=tgt_len,
softmax_scale=self.scale,
is_causal=False)
else:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
'''
==================
End of MLU Hijack
==================
'''
out = out.view(bsz, tgt_len, -1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
attn_output, _ = self.out_proj(out, residual)
'''
==================
End of MLU Hijack
==================
'''
return attn_output, None
def vllm__module_executor__models__clip__CLIPEncoderLayer____init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super(CLIPEncoderLayer, self).__init__()
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf attn using tmo flash attn, do not check xformers
'''
if num_heads % tp_size == 0:
self.self_attn = CLIPParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.use_parallel_attn = True
else:
logger.warning("Use MLUCLIPAttention for clip model (fallback for non-divisible heads).")
self.self_attn = MLUCLIPAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.use_parallel_attn = False
'''
==================
End of MLU Hijack
==================
'''
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
'''
=============================
Modify by vllm_mlu
=============================
@brief: use FeedForward instead of MLP
'''
self.mlp = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
is_gated=False,
bias=True,
hidden_act=config.hidden_act,
quant_config=quant_config,
up_proj_name='fc1',
down_proj_name='fc2',
prefix=f"{prefix}.mlp")
'''
==================
End of MLU Hijack
==================
'''
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def vllm__module_executor__models__clip__CLIPEncoderLayer__forward(
self,
hidden_states: torch.Tensor
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: apply residual fusion
'''
residual = hidden_states
if self.use_parallel_attn:
hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states, residual)
else:
hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
bsz, tgt_len, _ = hidden_states.size()
hidden_states = self.mlp(hidden_states, residual)
hidden_states = hidden_states.view(bsz, tgt_len, -1)
'''
==================
End of MLU Hijack
==================
'''
return hidden_states
def vllm__module_executor__models__clip__CLIPVisionModel____init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super(CLIPVisionModel, self).__init__()
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf attn using tmo flash attn, do not check xformers
'''
self.shard_weight = num_heads % tp_size == 0
'''
==================
End of MLU Hijack
==================
'''
self.vision_model = CLIPVisionTransformer(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
)
MluHijackObject.apply_hijack(CLIPParallelAttention,
CLIPParallelAttention.forward,
vllm__module_executor__models__clip__CLIPParallelAttention__forward)
MluHijackObject.apply_hijack(CLIPEncoderLayer,
CLIPEncoderLayer.__init__,
vllm__module_executor__models__clip__CLIPEncoderLayer____init__)
MluHijackObject.apply_hijack(CLIPEncoderLayer,
CLIPEncoderLayer.forward,
vllm__module_executor__models__clip__CLIPEncoderLayer__forward)
MluHijackObject.apply_hijack(CLIPVisionModel,
CLIPVisionModel.__init__,
vllm__module_executor__models__clip__CLIPVisionModel____init__)

View File

@@ -0,0 +1,625 @@
import torch
from torch import nn
from typing import Any, Dict, Iterable, List, Optional, Tuple
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm.utils import print_warning_once
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm_mlu.model_executor.models.layer_utils import quant_fusion_with_rmsnorm
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, yarn_get_mscale
class DeepseekV2MoE(SparseMoeMlp):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
up_proj_name="gate_up_proj",
is_gated=True,
down_proj_name="down_proj",
has_bias=False,
skip_bias_add=False,
renormalize=config.norm_topk_prob,
hidden_act=config.hidden_act,
params_dtype=None,
quant_config=quant_config,
is_use_fused_moe=True,
expert_group=config.n_group,
topk_group=config.topk_group)
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.routed_scaling_factor = config.routed_scaling_factor
if self.tp_size > config.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.n_routed_experts}.")
if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")
self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate")
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace MLP with FeedForward.
'''
self.shared_experts = FeedForward(hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
up_proj_name='gate_up_proj',
is_gated=True,
down_proj_name='down_proj',
bias=False,
quant_config=quant_config,
reduce_results=False)
'''
==================
End of MLU Hijack
==================
'''
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace experts() with forward_experts, which defined by SparseMoeMlp.
'''
final_hidden_states = self.forward_experts(
hidden_states, router_logits) * self.routed_scaling_factor
'''
==================
End of MLU Hijack
==================
'''
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
def forward_prefill(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q_scale = None
if hasattr(self.q_b_proj.quant_method, "quant_config"):
self.q_b_proj.quant_method.skip_quant_input = True
quant_scale = self.q_b_proj.smooth
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.q_a_layernorm, quant_scale, dynamic_quant=True)
q, q_scale = self.quant_fusion_attn_layernorm(q)
q = self.q_b_proj(q, q_scale)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
self.qk_head_dim)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, _ = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads,
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
'''
=============================
Modify by vllm_mlu
=============================
@brief: MLA save cache before flashattn
'''
from vllm import _mlu_ops as mlu_ops
if len(kv_cache) != 0 and kv_cache[0].numel() > 0:
key_cache = kv_cache[0][0]
key_value = torch.concat((kv_a.unsqueeze(1), k_pe), dim=-1)
updated_slot_mapping = attn_metadata.slot_mapping
if self.attn.kv_cache_dtype == 'int8':
key_cache_scale = kv_cache[1][0]
mlu_ops.quant_to_paged_cache(key_value,
None,
key_cache,
None,
key_cache_scale,
None,
updated_slot_mapping.flatten())
else:
mlu_ops.reshape_paged_cache(key_value,
None,
key_cache,
None,
updated_slot_mapping.flatten())
'''
==================
End of MLU Hijack
==================
'''
q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim:] = k_pe
'''
=============================
Modify by vllm_mlu
=============================
@brief: mlu attention not pad but qk_head_dim 192 v_head_dim 128.
'''
q = q.reshape(-1, self.num_local_heads * self.qk_head_dim)
k = k.reshape(-1, self.num_local_heads * self.qk_head_dim)
v = v.contiguous().reshape(-1, self.num_local_heads * self.v_head_dim)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
'''
==================
End of MLU Hijack
==================
'''
output, _ = self.o_proj(attn_output)
return output
def forward_decoder(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
q_len = hidden_states.shape[0]
q_input = hidden_states.new_empty(
q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
)
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q_scale = None
if hasattr(self.q_b_proj.quant_method, "quant_config"):
self.q_b_proj.quant_method.skip_quant_input = True
quant_scale = self.q_b_proj.smooth
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.q_a_layernorm, quant_scale, dynamic_quant=True)
q, q_scale = self.quant_fusion_attn_layernorm(q)
q = self.q_b_proj(q, q_scale)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
from vllm import _mlu_ops as mlu_ops
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
v_input = latent_cache[..., : self.kv_lora_rank]
v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
k_input = latent_cache.unsqueeze(1)
k_input[..., : self.kv_lora_rank] = v_input
k_pe = k_input[..., self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q_input[..., self.kv_lora_rank :] = q_pe
k_input[..., self.kv_lora_rank :] = k_pe
v_input = torch.nn.functional.pad(v_input, [0, self.qk_rope_head_dim, 0, 0, 0, 0],
value=0).view(-1, self.kv_lora_rank + self.qk_rope_head_dim)
q_input = q_input.reshape(q_input.shape[0], -1)
k_input = k_input.reshape(k_input.shape[0], -1)
v_input = v_input.reshape(v_input.shape[0], -1)
attn_output = self.attn_decoder(q_input, k_input, v_input, kv_cache, attn_metadata)
attn_output = attn_output.reshape(-1, self.num_local_heads,
self.kv_lora_rank + self.qk_rope_head_dim)
attn_output = attn_output[:, :, :self.kv_lora_rank]
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc.transpose(1, 2).contiguous())
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
output, _ = self.o_proj(attn_output)
return output
def vllm__module_executor__models__deepseek_v2__DeepseekV2Attention__init(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int,
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super(DeepseekV2Attention, self).__init__()
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
# only RowParallelLinear/ColumnParallelLinear will be quantize
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=None,
prefix=f"{prefix}.q_a_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj")
else:
self.q_proj = ColumnParallelLinear(self.hidden_size,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj")
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=None,
prefix=f"{prefix}.kv_a_proj_with_mqa")
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=None,
prefix=f"{prefix}.kv_b_proj")
kv_b_proj_weight = self.kv_b_proj.weight
w_kc, w_vc = kv_b_proj_weight.unflatten(
0, (-1, self.qk_nope_head_dim + self.v_head_dim)
).split([self.qk_nope_head_dim, self.v_head_dim], dim=1)
self.w_kc = w_kc
self.w_vc = w_vc
# O projection.
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling['rope_type'] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
# self.attn = Attention(self.num_heads,
# self.qk_head_dim,
# self.scaling,
# num_kv_heads=self.num_heads)
'''
=============================
Modify by vllm_mlu
=============================
@brief: mlu attention support head_size 192.
'''
self.attn = Attention(self.num_local_heads,
self.qk_head_dim,
self.scaling,
num_kv_heads=self.num_local_heads,
cache_config=cache_config,
quant_config=quant_config,
use_mla=True)
self.attn_decoder = Attention(self.num_local_heads,
self.kv_lora_rank + self.qk_rope_head_dim,
self.scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
use_mla=True)
import types
self.forward_prefill = types.MethodType(forward_prefill, self)
self.forward_decoder = types.MethodType(forward_decoder, self)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__deepseek_v2__DeepseekV2Attention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Use normal computation for prefill and use weight absorption for extend/decode
if attn_metadata.prefill_metadata:
return self.forward_prefill(positions, hidden_states, kv_cache,
attn_metadata)
else:
return self.forward_decoder(positions, hidden_states, kv_cache,
attn_metadata)
def vllm__module_executor__models__deepseek_v2__DeepseekV2DecoderLayer__init__(
self,
config: PretrainedConfig,
prefix: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super(DeepseekV2DecoderLayer, self).__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.self_attn = DeepseekV2Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=config.q_lora_rank
if hasattr(config, "q_lora_rank") else None,
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace MLP with FeedForward.
'''
self.mlp = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
up_proj_name='gate_up_proj',
is_gated=True,
down_proj_name='down_proj',
bias=False,
quant_config=quant_config)
'''
==================
End of MLU Hijack
==================
'''
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack params and cal start expert id
'''
for name, m in self.model.named_modules():
if isinstance(m, SparseMoeMlp):
m.pack_params()
start_expert_id = 0
'''
==================
End of MLU Hijack
==================
'''
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
'''
=============================
Modify by vllm_mlu
=============================
@brief: delete expert_params_mapping for no useless
'''
'''
==================
End of MLU Hijack
==================
'''
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace expert_id in weight to named_expert_id in params_dict
'''
if start_expert_id > 0 and "mlp.experts." in name:
expert_str = re.search(r'experts\.\d+', name).group(0)
expert_id=int(expert_str.split(".")[1])
named_expert_id = expert_id - start_expert_id
old_expert_name = f"experts.{expert_id}"
new_expert_name = f"experts.{named_expert_id}"
name = name.replace(old_expert_name, new_expert_name)
'''
==================
End of MLU Hijack
==================
'''
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
'''
=============================
Modify by vllm_mlu
=============================
@brief: add expert skiped condition and delete useless if name not in params_dict: continue condition
'''
name = name.replace(weight_name, param_name)
if (("mlp.experts." in name or "mlp.shared_experts." in name or "mlp.shared_expert_gate." in name)
and name not in params_dict):
continue
'''
==================
End of MLU Hijack
==================
'''
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add expert skiped condition
'''
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if (("mlp.experts." in name or "mlp.shared_experts." in name or "mlp.shared_expert_gate." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(DeepseekV2Attention,
DeepseekV2Attention.forward,
vllm__module_executor__models__deepseek_v2__DeepseekV2Attention__forward)
MluHijackObject.apply_hijack(DeepseekV2Attention,
DeepseekV2Attention.__init__,
vllm__module_executor__models__deepseek_v2__DeepseekV2Attention__init)
MluHijackObject.apply_hijack(DeepseekV2DecoderLayer,
DeepseekV2DecoderLayer.__init__,
vllm__module_executor__models__deepseek_v2__DeepseekV2DecoderLayer__init__)
MluHijackObject.apply_hijack(DeepseekV2ForCausalLM,
DeepseekV2ForCausalLM.load_weights,
vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weights)

View File

@@ -0,0 +1,242 @@
import math
import torch
from torch import nn
from typing import List, Optional, Union
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig
from vllm.transformers_utils.configs import RWConfig
FalconConfig = Union[HF_FalconConfig, RWConfig]
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.models.falcon import (FalconAttention,
FalconDecoderLayer,
_get_alibi_slopes)
logger = init_logger(__name__)
def vllm__module_executor__models__falcon__FalconAttention____init__(
self,
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super(FalconAttention, self).__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
self.new_decoder_architecture = config.new_decoder_architecture
self.multi_query = config.multi_query
if self.new_decoder_architecture:
self.total_num_kv_heads = config.num_kv_heads
elif self.multi_query:
self.total_num_kv_heads = 1
else:
self.total_num_kv_heads = self.total_num_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config,
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config,
reduce_results=self.reduce_row_parallel_results)
self.use_rotary = config.rotary
self.use_alibi = config.alibi
assert not (self.use_rotary and self.use_alibi), (
"Rotary and alibi are mutually exclusive.")
'''
=============================
Modify by vllm_mlu
=============================
@brief: set cache_config for rotary & alibi
'''
if self.use_rotary:
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config,
"max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
self.inv_norm_factor)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.attn = Attention(self.num_heads,
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)
else:
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__falcon__FalconAttention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, bias = self.query_key_value(hidden_states)
if bias is not None:
qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
if self.use_rotary:
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, bias = self.dense(attn_output)
return attn_output, bias
def vllm__module_executor__models__falcon__FalconDecoderLayer____init__(
self,
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super(FalconDecoderLayer, self).__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, cache_config,
quant_config)
'''
=============================
Modify by vllm_mlu
=============================
@brief: use FeedForward instead of MLP
'''
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
self.mlp = FeedForward(hidden_size=hidden_size,
intermediate_size=hidden_size * 4,
hidden_act='gelu',
up_proj_name='dense_h_to_4h',
is_gated=False,
down_proj_name='dense_4h_to_h',
bias=config.bias,
quant_config=quant_config,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
'''
==================
End of MLU Hijack
==================
'''
self.config = config
if (config.num_ln_in_parallel_attn is None
and config.new_decoder_architecture):
config.num_ln_in_parallel_attn = 2
if not config.parallel_attn:
self.post_attention_layernorm = LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.input_layernorm = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
else:
if config.num_ln_in_parallel_attn == 2:
# The layer norm before self-attention
self.ln_attn = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
# The layer norm before the MLP
self.ln_mlp = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
else:
self.input_layernorm = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
MluHijackObject.apply_hijack(FalconAttention,
FalconAttention.__init__,
vllm__module_executor__models__falcon__FalconAttention____init__)
MluHijackObject.apply_hijack(FalconAttention,
FalconAttention.forward,
vllm__module_executor__models__falcon__FalconAttention__forward)
MluHijackObject.apply_hijack(FalconDecoderLayer,
FalconDecoderLayer.__init__,
vllm__module_executor__models__falcon__FalconDecoderLayer____init__)

View File

@@ -0,0 +1,238 @@
import torch
from torch import nn
from typing import Optional
from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.models.gpt_neox import GPTNeoXAttention, GPTNeoXLayer
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
logger = init_logger(__name__)
def vllm__module_executor__models__gpt_neox__GPTNeoXAttention__init__(
self,
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super(GPTNeoXAttention, self).__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
self.bias = getattr(config, "attention_bias", True)
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.query_key_value = QKVParallelLinear(
config.hidden_size,
self.head_size,
self.total_num_heads,
bias=self.bias,
quant_config=quant_config,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: do not do allreduce in linear and skip bias add when use_parallel_residual
'''
if config.use_parallel_residual:
reduce_results = False
skip_bias_add = True
else:
reduce_results = True
skip_bias_add = False
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=self.bias,
quant_config=quant_config,
reduce_results=reduce_results,
skip_bias_add=skip_bias_add,
)
'''
==================
End of MLU Hijack
==================
'''
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config,
quant_config=quant_config)
def vllm__module_executor__models__gpt_neox__GPTNeoXAttention__forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, _ = qkv.split([self.num_heads * self.head_size * 2, self.num_heads * self.head_size], dim=-1)
self.rotary_emb(position_ids, qk.view(-1, self.num_heads + self.num_heads, self.head_size))
'''
==================
End of MLU Hijack
==================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add bias for rank 0 when use_parallel_residual
'''
output, bias = self.dense(attn_output)
if self.dense.skip_bias_add and get_tensor_model_parallel_rank() == 0:
output += bias
'''
==================
End of MLU Hijack
==================
'''
return output
def vllm__module_executor__models__gpt_neox__GPTNeoXLayer__init__(
self,
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super(GPTNeoXLayer, self).__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config, cache_config, quant_config)
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1) use FeedForward instead of MLP
2) do not do allreduce in row linear and skip bias add in it
'''
if self.use_parallel_residual:
reduce_results = False
skip_bias_add = True
else:
reduce_results = True
skip_bias_add = False
self.mlp = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act='gelu',
up_proj_name='dense_h_to_4h',
is_gated=False,
down_proj_name='dense_4h_to_h',
bias=True,
quant_config=quant_config,
skip_bias_add=skip_bias_add,
reduce_results=reduce_results)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__gpt_neox__GPTNeoXLayer__forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
attn_input = self.input_layernorm(hidden_states)
attn_output = self.attention(
position_ids=position_ids,
hidden_states=attn_input,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: only do one allreduce when use_parallel_residual
'''
if self.use_parallel_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_input = self.post_attention_layernorm(hidden_states)
mlp_output, mlp_bias = self.mlp(mlp_input)
if get_tensor_model_parallel_rank() == 0:
mlp_output += mlp_bias
hidden_states = mlp_output + attn_output + hidden_states
else:
hidden_states = mlp_output + attn_output
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_input = self.post_attention_layernorm(attn_output)
mlp_output = self.mlp(mlp_input)
hidden_states = mlp_output + attn_output
return hidden_states
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(GPTNeoXAttention,
GPTNeoXAttention.__init__,
vllm__module_executor__models__gpt_neox__GPTNeoXAttention__init__)
MluHijackObject.apply_hijack(GPTNeoXAttention,
GPTNeoXAttention.forward,
vllm__module_executor__models__gpt_neox__GPTNeoXAttention__forward)
MluHijackObject.apply_hijack(GPTNeoXLayer,
GPTNeoXLayer.__init__,
vllm__module_executor__models__gpt_neox__GPTNeoXLayer__init__)
MluHijackObject.apply_hijack(GPTNeoXLayer,
GPTNeoXLayer.forward,
vllm__module_executor__models__gpt_neox__GPTNeoXLayer__forward)

View File

@@ -0,0 +1,502 @@
import torch
import re
from typing import List, Optional, Tuple, Iterable, Union
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.models.hunyuan import HunYuanAttention, HunYuanDecoderLayer, HunYuanForCausalLM, HunYuanModel
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.model_executor.models.layer_utils import (
hunyuan_decoder_layer_forward_base, hunyuan_decoder_model_forward_base_pp,
is_per_tensor_smoothquant, is_per_token_smoothquant, quant_fusion_with_rmsnorm)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
class HunYuanSparseMoeBlock(SparseMoeMlp):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(num_experts=config.num_experts,
top_k=config.moe_topk,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
up_proj_name="gate_up_proj",
is_gated=True,
down_proj_name="down_proj",
has_bias=False,
skip_bias_add=False,
renormalize=True if config.moe_topk>1 else False,
hidden_act=config.hidden_act,
params_dtype=None,
quant_config=quant_config,
is_use_fused_moe=True)
self.config = config
self.shared_mlp = None
if config.use_mixed_mlp_moe > 0:
self.shared_mlp = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size * config.num_shared_expert,
hidden_act=config.hidden_act,
up_proj_name='gate_up_proj',
is_gated=True,
down_proj_name='down_proj',
bias=False,
quant_config=quant_config,
reduce_results=False)
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_mlp is not None:
shared_output = self.shared_mlp(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.forward_experts(hidden_states, router_logits, residual)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
def vllm__module_executor__models__hunyuan__HunYuanAttention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
kv_states: Optional[Tuple[torch.Tensor]] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.attention_type == "self":
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
ori_k = k
if self.use_qk_norm:
q = self.query_layernorm(q.reshape(-1, self.num_heads, self.head_dim).contiguous()).reshape(-1, self.num_heads*self.head_dim)
k = self.key_layernorm(k.reshape(-1, self.num_kv_heads, self.head_dim).contiguous()).reshape(-1, self.num_kv_heads*self.head_dim)
elif self.attention_type == "cross":
assert kv_states is not None
ori_k, v = kv_states # use last layer kv,
k = ori_k
q, _ = self.q_proj(hidden_states, smooth_quant_scale)
k_tmp = torch.empty_like(k) # Todo: reduant rotary embedding
qk_temp = torch.cat((q, k_tmp), dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
self.rotary_emb(positions, qk_temp.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
if self.use_qk_norm:
q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous()).reshape(-1, self.num_heads*self.head_dim)
k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous()).reshape(-1, self.num_kv_heads*self.head_dim)
else:
raise RuntimeError("Not support attnention type")
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output, (ori_k, v)
def vllm__module_executor__models__hunyuan__HunYuanDecoderLayer____init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
layer_id: int = -1,
) -> None:
super(HunYuanDecoderLayer, self).__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
cla_factor = getattr(config, "cla_share_factor", 1)
attention_type = "cross" \
if layer_id >= 0 and layer_id % cla_factor != 0 else "self"
self.self_attn = HunYuanAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
attention_type=attention_type,
)
if getattr(config, "num_experts", None):
self.mlp = HunYuanSparseMoeBlock(config=config,
quant_config=quant_config)
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: use FeedForward instead of MLP
'''
self.mlp = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
up_proj_name='gate_up_proj',
is_gated=True,
down_proj_name='down_proj',
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
quant_config=quant_config)
'''
==================
End of MLU Hijack
==================
'''
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
'''
=============================
Modify by vllm_mlu
=============================
@brief: prepare to perf per-tensor sq cases if suitable. For moe
model, we only do quant fusion in attn block.
'''
self.is_per_tensor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
if self.is_per_tensor_sq_perf_cases or self.is_per_token_sq_perf_cases:
if self.self_attn.attention_type == "self":
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
if self.self_attn.attention_type == "cross":
self.self_attn.q_proj.quant_method.skip_quant_input = True
self.quant_fusion_attn_layernorm = None
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__hunyuan__HunYuanForCausalLM__load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]]):
cla_factor = getattr(self.config, "cla_share_factor", 1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack params and cal start expert id
'''
for name, m in self.model.named_modules():
if isinstance(m, SparseMoeMlp):
m.pack_params()
start_expert_id = 0
'''
==================
End of MLU Hijack
==================
'''
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
'''
=============================
Modify by vllm_mlu
=============================
@brief: delete expert_params_mapping for no useless
'''
'''
==================
End of MLU Hijack
==================
'''
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace expert_id in weight to named_expert_id in params_dict
'''
if start_expert_id > 0 and "mlp.experts." in name:
expert_str = re.search(r'experts\.\d+', name).group(0)
expert_id=int(expert_str.split(".")[1])
named_expert_id = expert_id - start_expert_id
old_expert_name = f"experts.{expert_id}"
new_expert_name = f"experts.{named_expert_id}"
name = name.replace(old_expert_name, new_expert_name)
'''
==================
End of MLU Hijack
==================
'''
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
'''
=============================
Modify by vllm_mlu
=============================
@brief: delete if "mlp.experts" in name: continue condition
'''
'''
==================
End of MLU Hijack
==================
'''
# cross layer only have q_proj, skip qkv pack
if weight_name == ".q_proj":
match = re.search(r'layers\.\d+', name)
if match:
layer_id = int(match.group(0).split('.')[-1])
if cla_factor > 1 and layer_id % cla_factor != 0:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if (name.endswith(".bias") and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
'''
=============================
Modify by vllm_mlu
=============================
@brief: add expert skiped condition and delete useless if name not in params_dict: continue condition
'''
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_mlp." in name)
and name not in params_dict):
continue
'''
==================
End of MLU Hijack
==================
'''
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if (name.endswith(".bias") and name not in params_dict):
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
if "mlp.gate.wg." in name:
name = name.replace("wg.", "")
'''
=============================
Modify by vllm_mlu
=============================
@brief: add expert skiped condition
'''
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_mlp." in name)
and name not in params_dict):
continue
'''
==================
End of MLU Hijack
==================
'''
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def vllm__module_executor__models__hunyuan__HunYuanDecoderLayer__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_states: Optional[Tuple[torch.Tensor]] = None,
residual: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf model by:
1) add residual in matmul;
2) fuse quantization in layernorm in per-tensor sq case;
'''
attn_layernorm = self.input_layernorm
if self.is_per_tensor_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
if self.self_attn.attention_type == "self":
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.input_layernorm, self.self_attn.qkv_proj.scale_to_int)
if self.self_attn.attention_type == "cross":
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.input_layernorm, self.self_attn.q_proj.scale_to_int)
attn_layernorm = self.quant_fusion_attn_layernorm
elif self.is_per_token_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
if self.self_attn.attention_type == "self":
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.input_layernorm, self.self_attn.qkv_proj.smooth, dynamic_quant=True)
if self.self_attn.attention_type == "cross":
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.input_layernorm, self.self_attn.q_proj.smooth, dynamic_quant=True)
attn_layernorm = self.quant_fusion_attn_layernorm
return hunyuan_decoder_layer_forward_base(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
input_layernorm=attn_layernorm,
self_attn=self.self_attn,
post_layernorm=self.post_attention_layernorm,
mlp=self.mlp,
kv_states=kv_states,
input_norm_fuse_en=self.is_per_token_sq_perf_cases)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__hunyuan__HunYuanModel__forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
'''
=============================
Modify by vllm_mlu
=============================
'''
return hunyuan_decoder_model_forward_base_pp(config=self.config,
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
layers=self.layers,
start_layer=self.start_layer,
end_layer=self.end_layer,
get_input_embeddings=self.get_input_embeddings,
norm=self.norm,
inputs_embeds=inputs_embeds)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(HunYuanAttention,
HunYuanAttention.forward,
vllm__module_executor__models__hunyuan__HunYuanAttention__forward)
MluHijackObject.apply_hijack(HunYuanDecoderLayer,
HunYuanDecoderLayer.__init__,
vllm__module_executor__models__hunyuan__HunYuanDecoderLayer____init__)
MluHijackObject.apply_hijack(HunYuanForCausalLM,
HunYuanForCausalLM.load_weights,
vllm__module_executor__models__hunyuan__HunYuanForCausalLM__load_weights)
MluHijackObject.apply_hijack(HunYuanDecoderLayer,
HunYuanDecoderLayer.forward,
vllm__module_executor__models__hunyuan__HunYuanDecoderLayer__forward)
MluHijackObject.apply_hijack(HunYuanModel,
HunYuanModel.forward,
vllm__module_executor__models__hunyuan__HunYuanModel__forward)

View File

@@ -0,0 +1,294 @@
import torch
from typing import Optional, Tuple, Iterable, Union, List
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
from vllm.model_executor.models.internlm2 import InternLM2Attention, InternLMDecoderLayer, InternLM2ForCausalLM, InternLM2Model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import (is_pp_missing_parameter)
from vllm.sequence import IntermediateTensors
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base, decoder_model_forward_base_pp,
is_per_tensor_smoothquant, is_per_token_smoothquant,
quant_fusion_with_rmsnorm)
logger = init_logger(__name__)
def vllm__module_executor__models__internlm2__InternLM2Attention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states, smooth_quant_scale)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.wo(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
def vllm__module_executor__models__internlm2__InternLMDecoderLayer____init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super(InternLMDecoderLayer, self).__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.attention = InternLM2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attention",
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: use FeedForward instead of MLP
'''
self.feed_forward = FeedForward(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
up_proj_name='gate_up_proj',
is_gated=True,
down_proj_name='w2',
bias=False,
quant_config=quant_config,
prefix=f'{prefix}.feed_forward'
)
'''
==================
End of MLU Hijack
==================
'''
self.attention_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
'''
=============================
Modify by vllm_mlu
=============================
@brief: prepare to perf per-tensor sq cases if suitable
'''
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
self.attention.wqkv.quant_method.skip_quant_input = True
self.feed_forward.gate_up_proj.quant_method.skip_quant_input = True
self.quant_fusion_attn_layernorm = None
self.quant_fusion_mlp_layernorm = None
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__internlm2__InternLM2ForCausalLM__load_weights(
self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w1", 0),
("gate_up_proj", "w3", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
'''
=============================
Modify by vllm_mlu
=============================
@brief: support load quant weights and params
'''
if "wqkv" in name and 'smooth' not in name and 'scale_to_int' not in name:
config = self.config
kv_groups = (config.num_attention_heads //
config.num_key_value_heads)
head_dim = config.hidden_size // config.num_attention_heads
if 'weight' in name:
loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
head_dim,
loaded_weight.shape[-1])
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
dim=1)
wq = wq.reshape(-1, wq.shape[-1])
wk = wk.reshape(-1, wk.shape[-1])
wv = wv.reshape(-1, wv.shape[-1])
elif 'scale' in name:
loaded_weight = loaded_weight.view(-1, 2 + kv_groups, head_dim)
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
dim=1)
wq = wq.reshape(-1)
wk = wk.reshape(-1)
wv = wv.reshape(-1)
else:
logger.error(f"unsupport internlm2 quant param: {name}, shape: {loaded_weight.shape}")
weight_loader = param.weight_loader
weight_loader(param, wq, 'q')
weight_loader(param, wk, 'k')
weight_loader(param, wv, 'v')
else:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__internlm2__InternLMDecoderLayer__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf model by:
1) add residual in matmul;
2) fuse quantization in layernorm in per-tensor sq case;
'''
attn_layernorm = self.attention_norm
mlp_layernorm = self.ffn_norm
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
if self.is_per_token_sq_perf_cases:
attn_quant_scale = self.attention.wqkv.smooth
mlp_quant_scale = self.feed_forward.gate_up_proj.smooth
else:
attn_quant_scale = self.attention.wqkv.scale_to_int
mlp_quant_scale = self.feed_forward.gate_up_proj.scale_to_int
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.attention_norm, attn_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
self.ffn_norm, mlp_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
attn_layernorm = self.quant_fusion_attn_layernorm
mlp_layernorm = self.quant_fusion_mlp_layernorm
return decoder_layer_forward_base(positions, hidden_states, kv_cache, attn_metadata,
attn_layernorm,
self.attention,
mlp_layernorm,
self.feed_forward,
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__internlm2__InternLM2Model__forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
'''
=============================
Modify by vllm_mlu
=============================
'''
return decoder_model_forward_base_pp(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors,
self.layers, self.start_layer, self.end_layer,
self.tok_embeddings,
self.norm,
inputs_embeds)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(InternLM2Attention,
InternLM2Attention.forward,
vllm__module_executor__models__internlm2__InternLM2Attention__forward)
MluHijackObject.apply_hijack(InternLMDecoderLayer,
InternLMDecoderLayer.__init__,
vllm__module_executor__models__internlm2__InternLMDecoderLayer____init__)
MluHijackObject.apply_hijack(InternLM2ForCausalLM,
InternLM2ForCausalLM.load_weights,
vllm__module_executor__models__internlm2__InternLM2ForCausalLM__load_weights)
MluHijackObject.apply_hijack(InternLMDecoderLayer,
InternLMDecoderLayer.forward,
vllm__module_executor__models__internlm2__InternLMDecoderLayer__forward)
MluHijackObject.apply_hijack(InternLM2Model,
InternLM2Model.forward,
vllm__module_executor__models__internlm2__InternLM2Model__forward)

View File

@@ -0,0 +1,273 @@
import torch
from typing import Callable, Optional, List, Union, Tuple
from vllm import _mlu_ops as mlu_ops
from vllm.attention import AttentionMetadata
from vllm.sequence import IntermediateTensors
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from transformers import PretrainedConfig
from vllm_mlu._mlu_utils import check_context_comm_cmpt_parallel
def hunyuan_decoder_layer_forward_base(
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
input_layernorm: Callable,
self_attn: Callable,
post_layernorm: Callable,
mlp: Callable,
kv_states: Optional[Tuple[torch.Tensor]] = None,
apply_residual_connection_post_layernorm: bool = False,
position_name: str = 'positions',
input_norm_fuse_en: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
smooth_quant_scale = None
if input_norm_fuse_en:
layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
else:
layernorm_output = input_layernorm(hidden_states)
if apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self Attention
attention_output, ori_kv_states = self_attn(
**{position_name: positions},
hidden_states=layernorm_output,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=residual,
kv_states=kv_states,
smooth_quant_scale=smooth_quant_scale,
)
layernorm_output = post_layernorm(attention_output)
if apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# Fully Connected
hidden_states = mlp(layernorm_output, residual)
return hidden_states, ori_kv_states
def decoder_layer_forward_base(
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
input_layernorm: Callable,
self_attn: Callable,
post_layernorm: Callable,
mlp: Callable,
apply_residual_connection_post_layernorm: bool = False,
position_name: str = 'positions',
input_norm_fuse_en: bool = False,
post_norm_fuse_en: bool = False,
) -> torch.Tensor:
smooth_quant_scale = None
if input_norm_fuse_en:
layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
else:
layernorm_output = input_layernorm(hidden_states)
if apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self Attention
attention_output = self_attn(
**{position_name: positions},
hidden_states=layernorm_output,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=residual,
smooth_quant_scale=smooth_quant_scale,
)
smooth_quant_scale = None
if post_norm_fuse_en:
layernorm_output, smooth_quant_scale = post_layernorm(attention_output)
else:
layernorm_output = post_layernorm(attention_output)
if apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# Fully Connected
kwargs = dict()
if post_norm_fuse_en:
kwargs['smooth_quant_scale'] = smooth_quant_scale
hidden_states = mlp(layernorm_output, residual, **kwargs)
return hidden_states
def decoder_model_forward_base(
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
layers: torch.nn.ModuleList,
get_input_embeddings: Callable,
norm: Callable
) -> torch.Tensor:
hidden_states = get_input_embeddings(input_ids)
for i in range(len(layers)):
layer = layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
)
hidden_states = norm(hidden_states)
return hidden_states
def hunyuan_decoder_model_forward_base_pp(
config: PretrainedConfig,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
layers: torch.nn.ModuleList,
start_layer: int,
end_layer: int,
get_input_embeddings: Callable,
norm: Callable,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = get_input_embeddings(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
cla_factor = getattr(config, "cla_share_factor", 1)
prev_kv_states = None
for i in range(start_layer, end_layer):
layer = layers[i]
hidden_states, kv_states = layer(
positions,
hidden_states,
kv_caches[i - start_layer],
attn_metadata,
prev_kv_states,
)
if (i - start_layer) % cla_factor == 0:
prev_kv_states = kv_states
else:
prev_kv_states = None
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
})
hidden_states = norm(hidden_states)
return hidden_states
def decoder_model_forward_base_pp(
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
layers: torch.nn.ModuleList,
start_layer: int,
end_layer: int,
get_input_embeddings: Callable,
norm: Callable,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = get_input_embeddings(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(start_layer, end_layer):
layer = layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
})
hidden_states = norm(hidden_states)
return hidden_states
def is_smoothquant(quant_config: QuantizationConfig) -> bool:
return (quant_config is not None and quant_config.get_name() == "SmoothQuant")
def is_per_tensor_smoothquant(quant_config: QuantizationConfig) -> bool:
return is_smoothquant(quant_config) and quant_config.input_quant_method == "per_tensor"
def is_per_token_smoothquant(quant_config: QuantizationConfig) -> bool:
if check_context_comm_cmpt_parallel():
return False
return is_smoothquant(quant_config) and quant_config.input_quant_method == "per_token"
def quant_fusion_with_layernorm(
op: torch.nn.LayerNorm,
quant_scale: torch.Tensor,
dynamic_quant: bool = False,
) -> Callable:
bias = None
if op.bias is not None:
bias = op.bias.data
def func(x: torch.Tensor) -> torch.Tensor:
return mlu_ops.fused_layer_norm(
x,
None,
op.weight.data,
bias,
None,
op.eps,
False,
quant_scale,
dynamic_quant)
return func
def quant_fusion_with_rmsnorm(
op: RMSNorm,
quant_scale: torch.Tensor,
dynamic_quant: bool = False,
) -> Callable:
def func(x: torch.Tensor) -> torch.Tensor:
return mlu_ops.fused_rms_norm(
x,
None,
op.weight.data,
None,
None,
op.variance_epsilon,
False,
quant_scale,
dynamic_quant)
return func

View File

@@ -0,0 +1,275 @@
import torch
from typing import Dict, List, Optional, Union, Any
from transformers import LlamaConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
from vllm.model_executor.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
from vllm.sequence import IntermediateTensors
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base, decoder_model_forward_base_pp,
is_per_tensor_smoothquant, is_per_token_smoothquant,
quant_fusion_with_rmsnorm)
logger = init_logger(__name__)
vllm__module_executor__models__llama__LlamaAttention__init__org = LlamaAttention.__init__
def vllm__module_executor__models__llama__LlamaAttention____init__(
self,
config: LlamaConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
vllm__module_executor__models__llama__LlamaAttention__init__org(
self,
config,
hidden_size,
num_heads,
num_kv_heads,
rope_theta,
rope_scaling,
max_position_embeddings,
quant_config,
bias,
cache_config,
prefix)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add rope_scaling params
'''
self.rope_scaling = rope_scaling
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__llama__LlamaAttention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
if self.rope_scaling is not None and self.rope_scaling["rope_type"] == "longrope":
q, k = self.rotary_emb(positions, q, k)
else:
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
def vllm__module_executor__models__llama__LlamaDecoderLayer____init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super(LlamaDecoderLayer, self).__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
self.self_attn = LlamaAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: use FeedForward instead of MLP
'''
self.mlp = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act='silu',
up_proj_name='gate_up_proj',
is_gated=True,
down_proj_name='down_proj',
bias=getattr(config, "mlp_bias", False),
quant_config=quant_config,
prefix=f"{prefix}.mlp")
'''
==================
End of MLU Hijack
==================
'''
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
'''
=============================
Modify by vllm_mlu
=============================
@brief: prepare to perf sq cases if suitable
'''
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
self.mlp.gate_up_proj.quant_method.skip_quant_input = True
self.quant_fusion_attn_layernorm = None
self.quant_fusion_mlp_layernorm = None
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__llama__LlamaDecoderLayer__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf model by:
1) add residual in matmul;
2) fuse quantization in layernorm in per-tensor sq case;
'''
attn_layernorm = self.input_layernorm
mlp_layernorm = self.post_attention_layernorm
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
if self.is_per_token_sq_perf_cases:
attn_quant_scale = self.self_attn.qkv_proj.smooth
mlp_quant_scale = self.mlp.gate_up_proj.smooth
else:
attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
mlp_quant_scale = self.mlp.gate_up_proj.scale_to_int
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.input_layernorm, attn_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
self.post_attention_layernorm, mlp_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
attn_layernorm = self.quant_fusion_attn_layernorm
mlp_layernorm = self.quant_fusion_mlp_layernorm
return decoder_layer_forward_base(positions, hidden_states, kv_cache, attn_metadata,
attn_layernorm,
self.self_attn,
mlp_layernorm,
self.mlp,
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__llama__LlamaModel__forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
'''
=============================
Modify by vllm_mlu
=============================
'''
return decoder_model_forward_base_pp(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors,
self.layers, self.start_layer, self.end_layer,
self.get_input_embeddings,
self.norm,
inputs_embeds)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(LlamaAttention,
LlamaAttention.__init__,
vllm__module_executor__models__llama__LlamaAttention____init__)
MluHijackObject.apply_hijack(LlamaAttention,
LlamaAttention.forward,
vllm__module_executor__models__llama__LlamaAttention__forward)
MluHijackObject.apply_hijack(LlamaDecoderLayer,
LlamaDecoderLayer.__init__,
vllm__module_executor__models__llama__LlamaDecoderLayer____init__)
MluHijackObject.apply_hijack(LlamaDecoderLayer,
LlamaDecoderLayer.forward,
vllm__module_executor__models__llama__LlamaDecoderLayer__forward)
MluHijackObject.apply_hijack(LlamaModel,
LlamaModel.forward,
vllm__module_executor__models__llama__LlamaModel__forward)

View File

@@ -0,0 +1,336 @@
import torch
import re
from typing import List, Optional, Tuple, Union, Iterable
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.models.mixtral import (MixtralAttention, MixtralDecoderLayer,
MixtralForCausalLM, MixtralModel)
from vllm_mlu.mlu_hijack_utils import set_is_gated
from transformers import MixtralConfig
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader.weight_utils import (default_weight_loader,
maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base, decoder_model_forward_base_pp,
is_per_tensor_smoothquant, is_per_token_smoothquant,
quant_fusion_with_rmsnorm)
def vllm__module_executor__models__mixtral__MixtralAttention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
def vllm__module_executor__models__mixtral__MixtralDecoderLayer____init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super(MixtralDecoderLayer, self).__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace MixtralMoE to SparseMoeMlp
'''
self.block_sparse_moe = SparseMoeMlp(num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
up_proj_name="w13",
is_gated=True,
down_proj_name="w2",
has_bias=False,
skip_bias_add=False,
renormalize=True,
hidden_act=config.hidden_act,
params_dtype=None,
quant_config=quant_config,
is_use_fused_moe=True)
'''
==================
End of MLU Hijack
==================
'''
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
'''
=============================
Modify by vllm_mlu
=============================
@brief: prepare to perf per-tensor sq cases if suitable. MoE gate linear always runs
in half/full precision for now, so we only do quant fusion in attn block.
'''
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
self.quant_fusion_attn_layernorm = None
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__mixtral__MixtralForCausalLM__load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]]):
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack params and cal start expert id
'''
for name, m in self.model.named_modules():
if isinstance(m, SparseMoeMlp):
m.pack_params()
start_expert_id = 0
'''
==================
End of MLU Hijack
==================
'''
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("w13", "w1", 0),
("w13", "w3", 1),
]
'''
=============================
Modify by vllm_mlu
=============================
@brief: delete expert_params_mapping for no useless
'''
'''
==================
End of MLU Hijack
==================
'''
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace expert_id in weight to named_expert_id in params_dict
'''
if start_expert_id > 0 and "block_sparse_moe.experts." in name:
expert_str = re.search(r'experts\.\d+', name).group(0)
expert_id=int(expert_str.split(".")[1])
named_expert_id = expert_id - start_expert_id
old_expert_name = f"experts.{expert_id}"
new_expert_name = f"experts.{named_expert_id}"
name = name.replace(old_expert_name, new_expert_name)
'''
==================
End of MLU Hijack
==================
'''
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
'''
=============================
Modify by vllm_mlu
=============================
@brief: add expert skiped condition
'''
# Skip experts that are not assigned to this worker.
if (("block_sparse_moe.experts." in name) and (name not in params_dict)):
continue
'''
==================
End of MLU Hijack
==================
'''
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
'''
=============================
Modify by vllm_mlu
=============================
@brief: add expert skiped condition
'''
# Skip experts that are not assigned to this worker.
if (("block_sparse_moe.experts." in name) and (name not in params_dict)):
continue
'''
==================
End of MLU Hijack
==================
'''
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def vllm__module_executor__models__mixtral__MixtralDecoderLayer__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf model by:
1) add residual in matmul;
2) fuse quantization in layernorm in per-tensor sq case;
'''
attn_layernorm = self.input_layernorm
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
if self.is_per_token_sq_perf_cases:
attn_quant_scale = self.self_attn.qkv_proj.smooth
else:
attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.input_layernorm, attn_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
attn_layernorm = self.quant_fusion_attn_layernorm
return decoder_layer_forward_base(positions, hidden_states, kv_cache, attn_metadata,
attn_layernorm,
self.self_attn,
self.post_attention_layernorm,
self.block_sparse_moe,
input_norm_fuse_en=self.is_per_token_sq_perf_cases)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__mixtral__MixtralModel__forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
'''
=============================
Modify by vllm_mlu
=============================
'''
return decoder_model_forward_base_pp(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors,
self.layers, self.start_layer, self.end_layer,
self.embed_tokens,
self.norm)
'''
==================
End of MLU Hijack
==================
'''
set_is_gated(True)
MluHijackObject.apply_hijack(MixtralAttention,
MixtralAttention.forward,
vllm__module_executor__models__mixtral__MixtralAttention__forward)
MluHijackObject.apply_hijack(MixtralDecoderLayer,
MixtralDecoderLayer.__init__,
vllm__module_executor__models__mixtral__MixtralDecoderLayer____init__)
MluHijackObject.apply_hijack(MixtralForCausalLM,
MixtralForCausalLM.load_weights,
vllm__module_executor__models__mixtral__MixtralForCausalLM__load_weights)
MluHijackObject.apply_hijack(MixtralDecoderLayer,
MixtralDecoderLayer.forward,
vllm__module_executor__models__mixtral__MixtralDecoderLayer__forward)
MluHijackObject.apply_hijack(MixtralModel,
MixtralModel.forward,
vllm__module_executor__models__mixtral__MixtralModel__forward)

View File

@@ -0,0 +1,230 @@
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
import math
from typing import (List, Optional, Tuple)
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from vllm import _mlu_ops as mlu_ops
from vllm._mlu_ops import flash_attention
from vllm.attention import AttentionMetadata
from vllm.attention.backends.mlu_attn import MLUFlashAttentionMetadata
from vllm.logger import init_logger
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.model_executor.models.llama import LlamaDecoderLayer
from vllm.model_executor.models.mllama import (MllamaCrossAttentionDecoderLayer,
MllamaTextCrossAttention,
MllamaTextModel,
MllamaVisionSdpaAttention)
from vllm_mlu._mlu_utils import USE_PAGED, BlockSizeInfo
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__model_executor__models__mllama__MllamaTextModel__forward(
self,
input_ids: torch.LongTensor,
positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
skip_cross_attention: bool,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers):
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
if not skip_cross_attention:
hidden_states = decoder_layer(
hidden_states=hidden_states,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=
full_text_row_masked_out_mask,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
elif isinstance(decoder_layer, LlamaDecoderLayer):
'''
=============================
Modify by vllm_mlu
=============================
@brief: fuse residual into decoder layer.
'''
hidden_states = decoder_layer(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
'''
==================
End of MLU Hijack
==================
'''
else:
raise ValueError(
f"Unknown decoder layer type {type(decoder_layer)}")
hidden_states = self.norm(hidden_states)
return hidden_states
def vllm__model_executor__models__mllama__MllamaVisionSdpaAttention__forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_state)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(q.shape[0], q.shape[1], self.num_local_heads,
self.head_dim)
k = k.view(k.shape[0], k.shape[1], self.num_local_heads,
self.head_dim)
v = v.view(v.shape[0], v.shape[1], self.num_local_heads,
self.head_dim)
batch, seq_len_q, q_head_num, head_size = q.shape
seq_len_k = k.shape[1]
softmax_scale = head_size ** -0.5
attention_mask = attention_mask.repeat(1, q_head_num, 1, 1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace SDPA with flash attn.
'''
attn_output = flash_attention(q, k, v,
None, # out
None, # cu_seq_lens_q
None, # cu_seq_lens_kv
None, # alibi_slop
attention_mask, # attn_bias
seq_len_q, # max_seq_len_q
seq_len_k, # max_seq_len_kv
softmax_scale, # softmax_scale
False, # is_casual
)
'''
==================
End of MLU Hijack
==================
'''
attn_output = attn_output.reshape(attn_output.shape[0],
attn_output.shape[1], -1).contiguous()
output, _ = self.o_proj(attn_output)
return output
def vllm__model_executor__models__mllama__MllamaTextCrossAttention___attention_with_mask(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
attention_mask: torch.Tensor,
kv_range_for_decode: List[Tuple[int, int]],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run.
if len(kv_cache[0].shape) > 1:
if isinstance(attn_metadata, MLUFlashAttentionMetadata):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
if USE_PAGED:
mlu_ops.reshape_paged_cache(cached_k,
cached_v,
kv_cache[0][0],
kv_cache[0][1],
attn_metadata.cross_slot_mapping,
)
else:
cross_slot_mapping_flat = attn_metadata.cross_slot_mapping.flatten()
seq_start_loc = attn_metadata._cached_prefill_metadata.encoder_seq_start_loc
batch_ids = cross_slot_mapping_flat[seq_start_loc[:-1]] // BlockSizeInfo.BLOCK_SIZE
max_context_len = attn_metadata._cached_prefill_metadata.max_encoder_seq_len
mlu_ops.reshape_linear_cache(cached_k,
cached_v,
kv_cache[0][0],
kv_cache[0][1],
seq_start_loc, # context_lengths
max_context_len,
True, # packed
None, # context_seq_offset
batch_ids, # cache_bs_id
None, # cache_seqlen_offset
)
else:
raise ValueError(
f"Unsupported AttentionMetadata {type(attn_metadata)} "
f"class found. Expected the AttentionMetadata to "
f"be either FlashAttentionMetadata.")
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace SDPA with flash attn.
'''
# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
# standard causal mask, neither a block diagonal mask which
# can be optimized by xformers.BlockDiagonalMask.
# The mask is specially calculated for supporting multi
# images and interleaved images.
seq_len_q, q_head_num, head_size = q.shape
softmax_scale = head_size ** -0.5
cu_seq_lens_q = attn_metadata.seq_start_loc
cu_seq_lens_kv = attn_metadata.encoder_seq_start_loc
max_seq_len_q = attn_metadata.max_prefill_seq_len
max_seq_len_kv = attn_metadata.max_encoder_seq_len
attn_output = flash_attention(q, k, v,
None, # out
cu_seq_lens_q,
cu_seq_lens_kv,
None, # alibi_slope
None, # attn_bias
max_seq_len_q,
max_seq_len_kv,
softmax_scale,
False, # is_causal
)
'''
==================
End of MLU Hijack
==================
'''
output = attn_output.reshape(seq_len_q, self.num_local_heads * self.head_dim)
return output
MluHijackObject.apply_hijack(MllamaTextCrossAttention,
MllamaTextCrossAttention._attention_with_mask,
vllm__model_executor__models__mllama__MllamaTextCrossAttention___attention_with_mask)
MluHijackObject.apply_hijack(MllamaTextModel,
MllamaTextModel.forward,
vllm__model_executor__models__mllama__MllamaTextModel__forward)
MluHijackObject.apply_hijack(MllamaVisionSdpaAttention,
MllamaVisionSdpaAttention.forward,
vllm__model_executor__models__mllama__MllamaVisionSdpaAttention__forward)

View File

@@ -0,0 +1,241 @@
import torch
import numpy as np
from typing import List, Optional, Union
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
from vllm.model_executor.models.qwen import QWenAttention, QWenBlock, QWenModel, QwenImageInputs
from vllm.sequence import IntermediateTensors
from vllm.distributed import get_pp_group
from .layer_utils import decoder_layer_forward_base
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base, decoder_model_forward_base,
is_per_tensor_smoothquant, is_per_token_smoothquant,
quant_fusion_with_rmsnorm)
logger = init_logger(__name__)
def vllm__module_executor__models__qwen__QwenAttention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states, smooth_quant_scale)
q, k, v = qkv.chunk(chunks=3, dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, _ = qkv.split([self.head_dim * self.num_heads * 2, self.head_dim * self.num_heads], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.c_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
def vllm__module_executor__models__qwen__QWenBlock__init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super(QWenBlock, self).__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
self.attn = QWenAttention(config.hidden_size,
config.num_attention_heads,
config.max_position_embeddings,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
cache_config=cache_config,
quant_config=quant_config)
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1) use FeedForward instead of MLP
2) prepare to perf per-tensor sq cases if suitable
'''
self.mlp = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size // 2,
hidden_act='silu',
up_proj_name='gate_up_proj',
is_gated=True,
down_proj_name='c_proj',
bias=False,
quant_config=quant_config)
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
self.attn.c_attn.quant_method.skip_quant_input = True
self.mlp.gate_up_proj.quant_method.skip_quant_input = True
self.quant_fusion_attn_layernorm = None
self.quant_fusion_mlp_layernorm = None
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__qwen__QWenBlock__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf model by:
1) add residual in matmul;
2) fuse quantization in layernorm in per-tensor sq case;
'''
attn_layernorm = self.ln_1
mlp_layernorm = self.ln_2
if self.is_per_tesnor_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.ln_1, self.attn.c_attn.scale_to_int)
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
self.ln_2, self.mlp.gate_up_proj.scale_to_int)
attn_layernorm = self.quant_fusion_attn_layernorm
mlp_layernorm = self.quant_fusion_mlp_layernorm
elif self.is_per_token_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.ln_1, self.attn.c_attn.smooth, dynamic_quant=True)
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
self.ln_2, self.mlp.gate_up_proj.smooth, dynamic_quant=True)
attn_layernorm = self.quant_fusion_attn_layernorm
mlp_layernorm = self.quant_fusion_mlp_layernorm
return decoder_layer_forward_base(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
input_layernorm=attn_layernorm,
self_attn=self.attn,
post_layernorm=mlp_layernorm,
mlp=self.mlp,
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__qwen__QWenModel__forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
pixel_values: Optional[QwenImageInputs],
) -> Union[torch.Tensor, IntermediateTensors]:
img_pos = None
# If pixel / visual embeddings are provided, this is a visual model
if pixel_values is not None and self.visual is not None:
if pixel_values["type"] != "image_embeds":
image_embeds = self.visual(pixel_values["data"])
else:
image_embeds = pixel_values["data"]
# features should be of shape (# images, 256, hidden_dim)
img_pos = self.visual.get_image_positions(input_ids)
if isinstance(
img_pos,
np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]:
raise ValueError(
f"Number of placeholders: {img_pos.shape[0]} "
f"does not match number of images {image_embeds.shape[0]}."
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: remove residual
'''
if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
# Merge the image embeddings into the hidden states if actually have
# visual features and the corresponding image tokens
if img_pos is not None:
for idx, (img_bos, img_eos) in enumerate(img_pos):
hidden_states[img_bos + 1:img_eos] = image_embeds[idx]
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
})
hidden_states = self.ln_f(hidden_states)
'''
==================
End of MLU Hijack
==================
'''
return hidden_states
MluHijackObject.apply_hijack(QWenAttention,
QWenAttention.forward,
vllm__module_executor__models__qwen__QwenAttention__forward)
MluHijackObject.apply_hijack(QWenBlock,
QWenBlock.__init__,
vllm__module_executor__models__qwen__QWenBlock__init__)
MluHijackObject.apply_hijack(QWenBlock,
QWenBlock.forward,
vllm__module_executor__models__qwen__QWenBlock__forward)
MluHijackObject.apply_hijack(QWenModel,
QWenModel.forward,
vllm__module_executor__models__qwen__QWenModel__forward)

View File

@@ -0,0 +1,225 @@
import torch
from typing import List, Optional
from transformers import Qwen2Config
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.qwen2 import Qwen2Attention, Qwen2DecoderLayer, Qwen2Model
from vllm.sequence import IntermediateTensors
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base, decoder_model_forward_base_pp,
is_per_tensor_smoothquant, is_per_token_smoothquant,
quant_fusion_with_rmsnorm)
logger = init_logger(__name__)
def vllm__module_executor__models__qwen2__Qwen2Attention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
def vllm__module_executor__models__qwen2__Qwen2DecoderLayer____init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super(Qwen2DecoderLayer, self).__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: use FeedForward instead of MLP
'''
self.mlp = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act='silu',
up_proj_name='gate_up_proj',
is_gated=True,
down_proj_name='down_proj',
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
'''
==================
End of MLU Hijack
==================
'''
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
'''
=============================
Modify by vllm_mlu
=============================
@brief: prepare to perf per-tensor sq cases if suitable
'''
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
self.mlp.gate_up_proj.quant_method.skip_quant_input = True
self.quant_fusion_attn_layernorm = None
self.quant_fusion_mlp_layernorm = None
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__qwen2__Qwen2DecoderLayer__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf model by:
1) add residual in matmul;
2) fuse quantization in layernorm in per-tensor sq case;
'''
attn_layernorm = self.input_layernorm
mlp_layernorm = self.post_attention_layernorm
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
if self.is_per_token_sq_perf_cases:
attn_quant_scale = self.self_attn.qkv_proj.smooth
mlp_quant_scale = self.mlp.gate_up_proj.smooth
else:
attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
mlp_quant_scale = self.mlp.gate_up_proj.scale_to_int
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.input_layernorm, attn_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
self.post_attention_layernorm, mlp_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
attn_layernorm = self.quant_fusion_attn_layernorm
mlp_layernorm = self.quant_fusion_mlp_layernorm
return decoder_layer_forward_base(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
input_layernorm=attn_layernorm,
self_attn=self.self_attn,
post_layernorm=mlp_layernorm,
mlp=self.mlp,
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__qwen2__Qwen2Model__forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
'''
return decoder_model_forward_base_pp(input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
layers=self.layers,
start_layer=self.start_layer,
end_layer=self.end_layer,
get_input_embeddings=self.embed_tokens,
norm=self.norm,
inputs_embeds=inputs_embeds)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(Qwen2Attention,
Qwen2Attention.forward,
vllm__module_executor__models__qwen2__Qwen2Attention__forward)
MluHijackObject.apply_hijack(Qwen2DecoderLayer,
Qwen2DecoderLayer.__init__,
vllm__module_executor__models__qwen2__Qwen2DecoderLayer____init__)
MluHijackObject.apply_hijack(Qwen2DecoderLayer,
Qwen2DecoderLayer.forward,
vllm__module_executor__models__qwen2__Qwen2DecoderLayer__forward)
MluHijackObject.apply_hijack(Qwen2Model,
Qwen2Model.forward,
vllm__module_executor__models__qwen2__Qwen2Model__forward)

View File

@@ -0,0 +1,449 @@
import torch
import re
import torch.nn.functional as F
from typing import List, Optional, Tuple, Iterable
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.qwen2_moe import Qwen2MoeAttention, Qwen2MoeDecoderLayer, Qwen2MoeForCausalLM, Qwen2MoeModel
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm.utils import print_warning_once
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base, decoder_model_forward_base_pp,
is_per_tensor_smoothquant, is_per_token_smoothquant,
quant_fusion_with_rmsnorm)
class Qwen2MoeSparseMoeBlock(SparseMoeMlp):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
up_proj_name="gate_up_proj",
is_gated=True,
down_proj_name="down_proj",
has_bias=False,
skip_bias_add=False,
renormalize=config.norm_topk_prob,
hidden_act=config.hidden_act,
params_dtype=None,
quant_config=quant_config,
is_use_fused_moe=True)
self.config = config
self.shared_expert = None
self.shared_expert_gate = None
if config.shared_expert_intermediate_size > 0:
self.shared_expert = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act,
up_proj_name='gate_up_proj',
is_gated=True,
down_proj_name='down_proj',
bias=False,
quant_config=quant_config,
reduce_results=False)
self.shared_expert_gate = ReplicatedLinear(config.hidden_size,
1,
bias=False,
params_dtype=self.params_dtype,
quant_config=None)
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
gate_output = self.shared_expert_gate(hidden_states)
shared_output = F.sigmoid(gate_output[0]) * shared_output
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.forward_experts(hidden_states, router_logits, residual)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
def vllm__module_executor__models__qwen2moe__Qwen2MoeAttention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
def vllm__module_executor__models__qwen2moe__Qwen2MoeDecoderLayer____init__(
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super(Qwen2MoeDecoderLayer, self).__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = Qwen2MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
# `mlp_only_layers` in the config.
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen2MoeSparseMoeBlock(config=config,
quant_config=quant_config)
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: use FeedForward instead of MLP
'''
self.mlp = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
up_proj_name='gate_up_proj',
is_gated=True,
down_proj_name='down_proj',
bias=False,
quant_config=quant_config)
'''
==================
End of MLU Hijack
==================
'''
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
'''
=============================
Modify by vllm_mlu
=============================
@brief: prepare to perf per-tensor sq cases if suitable. For moe
model, we only do quant fusion in attn block.
'''
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
self.quant_fusion_attn_layernorm = None
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__qwen2moe__Qwen2MoeForCausalLM__load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]]):
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack params and cal start expert id
'''
for name, m in self.model.named_modules():
if isinstance(m, SparseMoeMlp):
m.pack_params()
start_expert_id = 0
'''
==================
End of MLU Hijack
==================
'''
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
'''
=============================
Modify by vllm_mlu
=============================
@brief: delete expert_params_mapping for no useless
'''
'''
==================
End of MLU Hijack
==================
'''
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace expert_id in weight to named_expert_id in params_dict
'''
if start_expert_id > 0 and "mlp.experts." in name:
expert_str = re.search(r'experts\.\d+', name).group(0)
expert_id=int(expert_str.split(".")[1])
named_expert_id = expert_id - start_expert_id
old_expert_name = f"experts.{expert_id}"
new_expert_name = f"experts.{named_expert_id}"
name = name.replace(old_expert_name, new_expert_name)
'''
==================
End of MLU Hijack
==================
'''
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
'''
=============================
Modify by vllm_mlu
=============================
@brief: delete if "mlp.experts" in name: continue condition
'''
'''
==================
End of MLU Hijack
==================
'''
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
'''
=============================
Modify by vllm_mlu
=============================
@brief: add expert skiped condition and delete useless if name not in params_dict: continue condition
'''
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_expert." in name or "mlp.shared_expert_gate." in name)
and name not in params_dict):
continue
'''
==================
End of MLU Hijack
==================
'''
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: delete for mapping in expert_params_mapping condition
'''
'''
==================
End of MLU Hijack
==================
'''
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
'''
=============================
Modify by vllm_mlu
=============================
@brief: add expert skiped condition
'''
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_expert." in name or "mlp.shared_expert_gate." in name)
and name not in params_dict):
continue
'''
==================
End of MLU Hijack
==================
'''
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def vllm__module_executor__models__qwen2moe__Qwen2MoeDecoderLayer__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf model by:
1) add residual in matmul;
2) fuse quantization in layernorm in per-tensor sq case;
'''
attn_layernorm = self.input_layernorm
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
if self.is_per_token_sq_perf_cases:
attn_quant_scale = self.self_attn.qkv_proj.smooth
else:
attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.input_layernorm, attn_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
attn_layernorm = self.quant_fusion_attn_layernorm
return decoder_layer_forward_base(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
input_layernorm=attn_layernorm,
self_attn=self.self_attn,
post_layernorm=self.post_attention_layernorm,
mlp=self.mlp,
input_norm_fuse_en=self.is_per_token_sq_perf_cases)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__qwen2moe__Qwen2MoeModel__forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
'''
return decoder_model_forward_base_pp(input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
layers=self.layers,
start_layer=self.start_layer,
end_layer=self.end_layer,
get_input_embeddings=self.embed_tokens,
norm=self.norm)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(Qwen2MoeAttention,
Qwen2MoeAttention.forward,
vllm__module_executor__models__qwen2moe__Qwen2MoeAttention__forward)
MluHijackObject.apply_hijack(Qwen2MoeDecoderLayer,
Qwen2MoeDecoderLayer.__init__,
vllm__module_executor__models__qwen2moe__Qwen2MoeDecoderLayer____init__)
MluHijackObject.apply_hijack(Qwen2MoeForCausalLM,
Qwen2MoeForCausalLM.load_weights,
vllm__module_executor__models__qwen2moe__Qwen2MoeForCausalLM__load_weights)
MluHijackObject.apply_hijack(Qwen2MoeDecoderLayer,
Qwen2MoeDecoderLayer.forward,
vllm__module_executor__models__qwen2moe__Qwen2MoeDecoderLayer__forward)
MluHijackObject.apply_hijack(Qwen2MoeModel,
Qwen2MoeModel.forward,
vllm__module_executor__models__qwen2moe__Qwen2MoeModel__forward)

View File

@@ -0,0 +1,272 @@
from typing import Optional
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.models.qwen2_vl import (
Qwen2VisionMLP, Qwen2VisionTransformer, Qwen2VisionAttention, Qwen2VLForConditionalGeneration)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.config import VllmConfig
from vllm.attention.selector import _Backend
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__module_executor__models__qwen2_vl__Qwen2VisionTransformer__forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor
):
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
# compute cos sin for apply_rope
cos = rotary_pos_emb.cos()
sin = rotary_pos_emb.sin()
cos = repeat(cos, "... d -> ... (2 d)")
sin = repeat(sin, "... d -> ... (2 d)")
rotary_pos_emb.cos = cos
rotary_pos_emb.sin = sin
'''
==================
End of MLU Hijack
==================
'''
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
grid_thw[:, 0]).cumsum(
dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# transformers
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
# adapter
x = self.merger(x)
return x
def vllm__module_executor__models__qwen2_vl__Qwen2VisionAttention__forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor
):
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
new_x_shape = x.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
x = x.view(*new_x_shape)
'''
=============================
Modify by vllm_mlu
=============================
@brief: apply mlu_ops.apply_rotary
'''
# [s, b, head, 3 * head_dim] --> 3 [b, s, head, head_dim]
batch_size = x.shape[1]
x = rearrange(x, "s b ... -> b s ...")
head_dim = x.shape[-1] // 3
q, k, v = x.split([head_dim] * 3, dim=-1)
if rotary_pos_emb is not None:
sin = rotary_pos_emb.sin
cos = rotary_pos_emb.cos
from vllm import _mlu_ops as mlu_ops
q = q.float()
q = mlu_ops.rotary_embedding(
q, sin, cos, None, None, False, False, False, q.shape[1]
)
k = k.float()
k = mlu_ops.rotary_embedding(
k, sin, cos, None, None, False, False, False, k.shape[1]
)
q = q.type_as(v)
k = k.type_as(v)
'''
==================
End of MLU Hijack
==================
'''
if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
from flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
output = flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0,
causal=False)
context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
elif self.attn_backend == _Backend.TORCH_SDPA:
seq_length = q.size(1)
q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
attention_mask = torch.zeros([1, seq_length, seq_length],
device=q.device,
dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
cu_seqlens[i - 1]:cu_seqlens[i]] = True
output = F.scaled_dot_product_attention(q,
k,
v,
attention_mask,
dropout_p=0.0)
context_layer = rearrange(output, "b h s d -> b s h d ")
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
kv_seqlen=None)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None)
elif self.attn_backend == _Backend.MLU_FLASH_ATTN:
'''
=============================
Modify by vllm_mlu
=============================
@brief: apply mlu_ops.flash_attention
'''
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
output = mlu_ops.flash_attention(q,
k,
v,
out=None,
cu_seq_lens_q=cu_seqlens,
cu_seq_lens_kv=cu_seqlens,
max_seq_len_q=max_seqlen,
max_seq_len_kv=max_seqlen,
alibi_slope=None,
attn_bias=None,
softmax_scale=head_dim ** -0.5,
is_causal=False)
context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
'''
==================
End of MLU Hijack
==================
'''
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output
vllm__module_executor__models__qwen2_vl__Qwen2VisionAttention__init_org = Qwen2VisionAttention.__init__
def vllm__module_executor__models__qwen2_vl__Qwen2VisionAttention____init__(
self,
embed_dim: Optional[int] = None,
num_heads: Optional[int] = None,
projection_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
vllm__module_executor__models__qwen2_vl__Qwen2VisionAttention__init_org(
self, embed_dim, num_heads, projection_size, quant_config, prefix)
'''
=============================
Modify by vllm_mlu
=============================
@brief: use mlu_ops.flash_atten for better performance
'''
self.attn_backend = _Backend.MLU_FLASH_ATTN
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__qwen2_vl___maybe_ignore_quant_config(
self,
quant_config: QuantizationConfig
):
'''
=============================
Modify by vllm_mlu
=============================
@brief: quantization for vit not yet supported
'''
if quant_config is not None:
logger.warning("Quantization for VisionTransformer not yet supported.")
return None
'''
==================
End of MLU Hijack
==================
'''
def vllm_module_executor__models__qwen2_vl__Qwen2VisionMLP__forward(
self,
x: torch.Tensor
) -> torch.Tensor:
x_parallel, _ = self.fc1(x)
'''
=============================
Modify by vllm_mlu
=============================
@brief: better acc than mlu_ops.active for half precision
'''
if x_parallel.dtype == torch.half and isinstance(self.act, QuickGELU):
x_parallel = self.act.forward_native(x_parallel)
else:
x_parallel = self.act(x_parallel)
'''
==================
End of MLU Hijack
==================
'''
x, _ = self.fc2(x_parallel)
return x
MluHijackObject.apply_hijack(Qwen2VisionTransformer,
Qwen2VisionTransformer.forward,
vllm__module_executor__models__qwen2_vl__Qwen2VisionTransformer__forward)
MluHijackObject.apply_hijack(Qwen2VisionAttention,
Qwen2VisionAttention.forward,
vllm__module_executor__models__qwen2_vl__Qwen2VisionAttention__forward)
MluHijackObject.apply_hijack(Qwen2VisionAttention,
Qwen2VisionAttention.__init__,
vllm__module_executor__models__qwen2_vl__Qwen2VisionAttention____init__)
MluHijackObject.apply_hijack(Qwen2VLForConditionalGeneration,
Qwen2VLForConditionalGeneration._maybe_ignore_quant_config,
vllm__module_executor__models__qwen2_vl___maybe_ignore_quant_config)
MluHijackObject.apply_hijack(Qwen2VisionMLP,
Qwen2VisionMLP.forward,
vllm_module_executor__models__qwen2_vl__Qwen2VisionMLP__forward)

View File

@@ -0,0 +1,264 @@
import torch
from typing import List, Optional
from transformers import Qwen2Config as Qwen3Config
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.qwen3 import Qwen3Attention, Qwen3DecoderLayer, Qwen3Model
from vllm.sequence import IntermediateTensors
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base, decoder_model_forward_base_pp,
is_per_tensor_smoothquant, is_per_token_smoothquant,
quant_fusion_with_rmsnorm)
logger = init_logger(__name__)
def vllm__module_executor__models__qwen3__Qwen3Attention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu for Qwen3
=============================
@brief: Apply QK normalization (Qwen3 specific)
Reference: Qwen2 MLU implementation style
'''
# Make q and k contiguous before reshape/view operations
q = q.contiguous()
k = k.contiguous()
# Apply QK normalization before rotary embedding
q_shape = q.shape
k_shape = k.shape
q_by_head = q.view(*q_shape[:-1], self.num_heads, self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q_shape)
k_by_head = k.view(*k_shape[:-1], self.num_kv_heads, self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k_shape)
'''
==================
End of QK Norm
==================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
Reference: Qwen2 MLU implementation
'''
# Pack q and k for MLU rotary embedding optimization
# Ensure qk is contiguous for view operation
qk = torch.cat([q, k], dim=-1).contiguous()
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
# Split back after rotary
q, k = qk.split([self.q_size, self.kv_size], dim=-1)
'''
==================
End of MLU Hijack
==================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
def vllm__module_executor__models__qwen3__Qwen3DecoderLayer____init__(
self,
config: Qwen3Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super(Qwen3DecoderLayer, self).__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
self.self_attn = Qwen3Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position_embeddings=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, "attention_bias", False),
head_dim=getattr(config, "head_dim", None),
prefix=f"{prefix}.self_attn",
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: use FeedForward instead of MLP
'''
self.mlp = FeedForward(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act='silu',
up_proj_name='gate_up_proj',
is_gated=True,
down_proj_name='down_proj',
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
'''
==================
End of MLU Hijack
==================
'''
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
'''
=============================
Modify by vllm_mlu
=============================
@brief: prepare to perf per-tensor sq cases if suitable
'''
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
self.mlp.gate_up_proj.quant_method.skip_quant_input = True
self.quant_fusion_attn_layernorm = None
self.quant_fusion_mlp_layernorm = None
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__qwen3__Qwen3DecoderLayer__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: perf model by:
1) add residual in matmul;
2) fuse quantization in layernorm in per-tensor sq case;
'''
attn_layernorm = self.input_layernorm
mlp_layernorm = self.post_attention_layernorm
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
if self.is_per_token_sq_perf_cases:
attn_quant_scale = self.self_attn.qkv_proj.smooth
mlp_quant_scale = self.mlp.gate_up_proj.smooth
else:
attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
mlp_quant_scale = self.mlp.gate_up_proj.scale_to_int
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.input_layernorm, attn_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
self.post_attention_layernorm, mlp_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
attn_layernorm = self.quant_fusion_attn_layernorm
mlp_layernorm = self.quant_fusion_mlp_layernorm
return decoder_layer_forward_base(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
input_layernorm=attn_layernorm,
self_attn=self.self_attn,
post_layernorm=mlp_layernorm,
mlp=self.mlp,
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
'''
==================
End of MLU Hijack
==================
'''
def vllm__module_executor__models__qwen3__Qwen3Model__forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
'''
return decoder_model_forward_base_pp(input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
layers=self.layers,
start_layer=self.start_layer,
end_layer=self.end_layer,
get_input_embeddings=self.embed_tokens,
norm=self.norm,
inputs_embeds=inputs_embeds)
'''
==================
End of MLU Hijack
==================
'''
# Apply hijacks
MluHijackObject.apply_hijack(Qwen3Attention,
Qwen3Attention.forward,
vllm__module_executor__models__qwen3__Qwen3Attention__forward)
MluHijackObject.apply_hijack(Qwen3DecoderLayer,
Qwen3DecoderLayer.__init__,
vllm__module_executor__models__qwen3__Qwen3DecoderLayer____init__)
MluHijackObject.apply_hijack(Qwen3DecoderLayer,
Qwen3DecoderLayer.forward,
vllm__module_executor__models__qwen3__Qwen3DecoderLayer__forward)
MluHijackObject.apply_hijack(Qwen3Model,
Qwen3Model.forward,
vllm__module_executor__models__qwen3__Qwen3Model__forward)