add qwen3
This commit is contained in:
8
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/__init__.py
Executable file
8
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/__init__.py
Executable file
@@ -0,0 +1,8 @@
|
||||
# hijack vllm layers
|
||||
import vllm_mlu.model_executor.layers
|
||||
|
||||
# hijack vllm models
|
||||
import vllm_mlu.model_executor.models
|
||||
|
||||
# hijack vllm model loader
|
||||
import vllm_mlu.model_executor.model_loader
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
import vllm_mlu.model_executor.custom_model.custom
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,644 @@
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict, Iterable, Union, List, Optional, Tuple
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm_mlu.transformers_utils.configs import CustomConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
ReplicatedLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.interfaces import SupportsPP
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm.model_executor.models.utils import PPMissingLayer, make_layers
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
decoder_layer_forward_base, is_per_tensor_smoothquant,
|
||||
is_per_token_smoothquant, quant_fusion_with_rmsnorm,
|
||||
quant_fusion_with_layernorm)
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(
|
||||
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2,
|
||||
total_num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(start=1,
|
||||
end=1 + 2 * num_remaining_heads,
|
||||
step=2,
|
||||
dtype=torch.int32)
|
||||
slopes = torch.cat(
|
||||
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.bias = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
x = x.view(-1, self.weight.data.shape[0])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, self.weight.data.shape[0])
|
||||
return mlu_ops.fused_layer_norm(x, residual, self.weight.data, self.bias.data, None, self.variance_epsilon, True)
|
||||
else:
|
||||
return mlu_ops.fused_layer_norm(x, residual, self.weight.data, self.bias.data, None, self.variance_epsilon, False)
|
||||
|
||||
|
||||
_NORM_DICT: Dict[str, nn.Module] = {"rmsnorm": RMSNorm, "layernorm": LayerNorm}
|
||||
|
||||
|
||||
class CustomMoeBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CustomConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.n_routed_experts = config.num_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
if self.tp_size > self.n_routed_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {self.n_routed_experts}.")
|
||||
|
||||
self.moe_intermediate_size = self.config.moe_intermediate_size // self.tp_size
|
||||
|
||||
if quant_config is None:
|
||||
self.w1 = nn.Parameter(
|
||||
torch.empty(self.config.num_experts,
|
||||
2 * self.moe_intermediate_size if self.config.is_gated else self.moe_intermediate_size,
|
||||
self.config.hidden_size,
|
||||
dtype=torch.get_default_dtype()), requires_grad=False)
|
||||
self.w2 = nn.Parameter(
|
||||
torch.empty(self.config.num_experts,
|
||||
self.config.hidden_size,
|
||||
self.moe_intermediate_size,
|
||||
dtype=torch.get_default_dtype()), requires_grad=False)
|
||||
self.w1_scale = None
|
||||
self.w2_scale = None
|
||||
self.input_smooth = None
|
||||
self.act_smooth = None
|
||||
else:
|
||||
assert quant_config.weight_bits == 8
|
||||
self.w1 = nn.Parameter(
|
||||
torch.empty(self.config.num_experts,
|
||||
2 * self.moe_intermediate_size if self.config.is_gated else self.moe_intermediate_size,
|
||||
self.config.hidden_size,
|
||||
device="mlu",
|
||||
dtype=torch.int8), requires_grad=False)
|
||||
self.w2 = nn.Parameter(
|
||||
torch.empty(self.config.num_experts,
|
||||
self.config.hidden_size,
|
||||
self.moe_intermediate_size,
|
||||
device="mlu",
|
||||
dtype=torch.int8), requires_grad=False)
|
||||
self.w1_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.config.num_experts,
|
||||
2 * self.moe_intermediate_size if self.config.is_gated else self.moe_intermediate_size,
|
||||
device="mlu",
|
||||
dtype=torch.float32), requires_grad=False)
|
||||
self.w2_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.config.num_experts,
|
||||
self.config.hidden_size,
|
||||
device="mlu",
|
||||
dtype=torch.float32), requires_grad=False)
|
||||
self.input_smooth = None
|
||||
self.act_smooth = None
|
||||
if quant_config.quant_mode == "SmoothQuant":
|
||||
self.input_smooth =nn.Parameter(
|
||||
torch.empty(
|
||||
self.config.num_experts,
|
||||
self.config.hidden_size,
|
||||
device="mlu",
|
||||
dtype=torch.float32), requires_grad=False)
|
||||
self.act_smooth =nn.Parameter(
|
||||
torch.empty(
|
||||
self.config.num_experts,
|
||||
self.moe_intermediate_size,
|
||||
device="mlu",
|
||||
dtype=torch.float32), requires_grad=False)
|
||||
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
self.n_routed_experts,
|
||||
bias=False,
|
||||
quant_config=None)
|
||||
if config.shared_expert_intermediate_size > 0:
|
||||
self.shared_expert = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.shared_expert_intermediate_size,
|
||||
hidden_act = self.config.hidden_act,
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=self.config.is_gated,
|
||||
down_proj_name='down_proj',
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False)
|
||||
else:
|
||||
self.shared_expert = None
|
||||
self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
|
||||
1,
|
||||
bias=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
shared_output = None
|
||||
if self.shared_expert is not None:
|
||||
shared_output = self.shared_expert(hidden_states)
|
||||
if self.shared_expert_gate is not None:
|
||||
shared_output = F.sigmoid(
|
||||
self.shared_expert_gate(hidden_states)) * shared_output
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
residual_ = None if self.rank > 0 else residual
|
||||
final_hidden_states = mlu_ops.fused_moe(hidden_states,
|
||||
router_logits,
|
||||
self.w1,
|
||||
self.w2,
|
||||
None,
|
||||
None,
|
||||
residual_,
|
||||
self.input_smooth,
|
||||
self.act_smooth,
|
||||
self.w1_scale,
|
||||
self.w2_scale,
|
||||
self.top_k,
|
||||
self.config.norm_topk_prob,
|
||||
self.config.is_gated,
|
||||
self.config.hidden_act)
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
|
||||
reduce_results = (self.config.use_parallel_residual == False)
|
||||
if reduce_results:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
class CustomAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CustomConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
attention_bias = getattr(config, "attention_bias", False) or getattr(config, "bias", False)
|
||||
self.hidden_size = config.hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads)
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = self.hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.kv_scale = 1.0
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=attention_bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=attention_bias,
|
||||
quant_config=quant_config,
|
||||
skip_bias_add=(self.config.use_parallel_residual and attention_bias),
|
||||
reduce_results = (self.config.use_parallel_residual == False),
|
||||
)
|
||||
|
||||
self.alibi_slopes = None
|
||||
self.rotary_emb = None
|
||||
if self.config.position_embedding_type == "ALIBI":
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
head_end = (tp_rank + 1) * self.num_heads
|
||||
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
|
||||
self.alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||
else:
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
max_position_embeddings = getattr(config, "max_sequence_length", 8192)
|
||||
is_neox_style = getattr(config, "is_neox_style", False)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
cache_config=cache_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.rotary_emb:
|
||||
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
|
||||
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output, bias = self.o_proj(attn_output, residual)
|
||||
if self.o_proj.skip_bias_add and get_tensor_model_parallel_rank() == 0:
|
||||
output += bias
|
||||
return output
|
||||
|
||||
|
||||
class CustomDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CustomConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.self_attn = CustomAttention(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
mlp_bias = getattr(config, "mlp_bias", False) or getattr(config, "bias", False)
|
||||
is_gated = getattr(config, "is_gated", False)
|
||||
|
||||
if config.num_experts is not None:
|
||||
self.mlp = CustomMoeBlock(config=config,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.mlp = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=self.config.hidden_act,
|
||||
up_proj_name='up_proj',
|
||||
is_gated=is_gated,
|
||||
down_proj_name='down_proj',
|
||||
bias=mlp_bias,
|
||||
quant_config=quant_config,
|
||||
skip_bias_add=(self.config.use_parallel_residual and mlp_bias),
|
||||
reduce_results = (self.config.use_parallel_residual == False))
|
||||
|
||||
self.input_layernorm = _NORM_DICT[self.config.norm_type](config.hidden_size, eps=config.norm_eps)
|
||||
self.post_attention_layernorm = _NORM_DICT[self.config.norm_type](config.hidden_size, eps=config.norm_eps)
|
||||
|
||||
# perf per-tensor sq cases by fusing quantization in layernorm
|
||||
self.is_per_tesnor_sq_perf_cases = (is_per_tensor_smoothquant(quant_config) and
|
||||
not self.config.apply_residual_connection_post_layernorm)
|
||||
self.is_per_token_sq_perf_cases = (is_per_token_smoothquant(quant_config) and
|
||||
not self.config.apply_residual_connection_post_layernorm)
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
|
||||
self.quant_fusion_attn_layernorm = None
|
||||
self.is_moe = config.num_experts is not None
|
||||
self.use_rmsnorm = self.config.norm_type == "rmsnorm"
|
||||
if not self.is_moe:
|
||||
self.mlp.up_proj.quant_method.skip_quant_input = True
|
||||
self.quant_fusion_mlp_layernorm = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.config.use_parallel_residual:
|
||||
# x = x + attn(ln1(x)) + mlp(ln2(x))
|
||||
layernorm_output = self.input_layernorm(hidden_states)
|
||||
attention_output = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=layernorm_output,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
layernorm_output = self.post_attention_layernorm(hidden_states)
|
||||
if self.mlp.skip_bias_add:
|
||||
mlp_output, mlp_bias = self.mlp(layernorm_output)
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
mlp_output += mlp_bias
|
||||
else:
|
||||
mlp_output = self.mlp(layernorm_output)
|
||||
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
hidden_states = mlp_output + attention_output + hidden_states
|
||||
else:
|
||||
hidden_states = mlp_output + attention_output
|
||||
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
return hidden_states, None
|
||||
else:
|
||||
# rmsnorm use fused_rms_norm to get better performance
|
||||
# if apply_residual_connection_post_layernorm:
|
||||
# x = ln1(x) + attn(ln1(x))
|
||||
# x = ln2(x) + mlp(ln2(x))
|
||||
# else:
|
||||
# x = x + attn(ln1(x))
|
||||
# x = x + mlp(ln2(x))
|
||||
attn_layernorm = self.input_layernorm
|
||||
mlp_layernorm = self.post_attention_layernorm
|
||||
if self.is_per_tesnor_sq_perf_cases:
|
||||
quant_fusion_func = (quant_fusion_with_rmsnorm if
|
||||
self.use_rmsnorm else quant_fusion_with_layernorm)
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_func(
|
||||
self.input_layernorm, self.self_attn.qkv_proj.scale_to_int)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
if not self.is_moe:
|
||||
if self.quant_fusion_mlp_layernorm is None:
|
||||
self.quant_fusion_mlp_layernorm = quant_fusion_func(
|
||||
self.post_attention_layernorm, self.mlp.up_proj.scale_to_int)
|
||||
mlp_layernorm = self.quant_fusion_mlp_layernorm
|
||||
elif self.is_per_token_sq_perf_cases:
|
||||
quant_fusion_func = (quant_fusion_with_rmsnorm if
|
||||
self.use_rmsnorm else quant_fusion_with_layernorm)
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_func(
|
||||
self.input_layernorm, self.self_attn.qkv_proj.smooth, dynamic_quant=True)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
if not self.is_moe:
|
||||
if self.quant_fusion_mlp_layernorm is None:
|
||||
self.quant_fusion_mlp_layernorm = quant_fusion_func(
|
||||
self.post_attention_layernorm, self.mlp.up_proj.smooth, dynamic_quant=True)
|
||||
mlp_layernorm = self.quant_fusion_mlp_layernorm
|
||||
|
||||
post_norm_fuse_en=(self.is_per_token_sq_perf_cases and not self.is_moe)
|
||||
return decoder_layer_forward_base(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
input_layernorm=attn_layernorm,
|
||||
self_attn=self.self_attn,
|
||||
post_layernorm=mlp_layernorm,
|
||||
mlp=self.mlp,
|
||||
apply_residual_connection_post_layernorm=self.config.apply_residual_connection_post_layernorm,
|
||||
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
|
||||
post_norm_fuse_en=post_norm_fuse_en), None
|
||||
|
||||
|
||||
class CustomModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CustomConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
if get_pp_group().is_first_rank or (config.tie_word_embeddings
|
||||
and get_pp_group().is_last_rank):
|
||||
embed_layer = VocabParallelEmbedding if self.config.use_parallel_embedding else nn.Embedding
|
||||
self.embed_tokens = embed_layer(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: CustomDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config),
|
||||
prefix="custom_model")
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = _NORM_DICT[self.config.norm_type](config.hidden_size, eps=config.norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
if residual is not None:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
else:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_text_config
|
||||
self.quant_config = vllm_config.quant_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self._verify_params()
|
||||
self.model = CustomModel(self.config, self.cache_config, self.quant_config)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(self.config.vocab_size, self.config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
pass
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def _verify_params(self) -> None:
|
||||
if (self.config.max_sequence_length) is None or \
|
||||
(self.config.num_hidden_layers) is None or \
|
||||
(self.config.hidden_size) is None or \
|
||||
(self.config.vocab_size) is None or \
|
||||
(self.config.num_attention_heads) is None:
|
||||
raise ValueError(
|
||||
"max_sequence_length, num_hidden_layers, hidden_size, vocab_size, "
|
||||
"num_attention_heads, must be vaild int values")
|
||||
|
||||
if self.config.hidden_act not in ["silu", "gelu"]:
|
||||
raise ValueError(
|
||||
"CustomConfig hidden_act must be one of [silu, gelu]. Got "
|
||||
f"{self.config.hidden_act}.")
|
||||
|
||||
if self.config.position_embedding_type not in ["ALIBI", "ROPE"]:
|
||||
raise ValueError(
|
||||
"position_embedding_type must be one of [ALIBI, ROPE]. Got "
|
||||
f"{self.config.position_embedding_type}.")
|
||||
|
||||
if self.config.num_experts is not None:
|
||||
if self.config.num_experts_per_tok is None:
|
||||
raise ValueError(
|
||||
"num_experts_per_tok must be a valid int value when num_experts is not None")
|
||||
if self.config.moe_intermediate_size is None:
|
||||
raise ValueError(
|
||||
"moe_intermediate_size must be a valid int value when num_experts is not None")
|
||||
if self.config.shared_expert_intermediate_size is None:
|
||||
raise ValueError(
|
||||
"shared_expert_intermediate_size must be a valid int value when num_experts is not None")
|
||||
if self.config.norm_topk_prob is None:
|
||||
raise ValueError(
|
||||
"norm_topk_prob must be a valid bool value when num_experts is not None")
|
||||
if self.config.mlp_bias is True:
|
||||
raise ValueError(
|
||||
"mlp_bias must be False when num_experts is not None")
|
||||
if self.quant_config is not None and self.quant_config.get_name() != "SmoothQuant":
|
||||
raise ValueError(
|
||||
"moe only support smoothquant now")
|
||||
else:
|
||||
if self.config.intermediate_size is None:
|
||||
raise ValueError(
|
||||
"intermediate_size must be a valid int value when num_experts is None")
|
||||
|
||||
if self.config.norm_type not in ["rmsnorm", "layernorm"]:
|
||||
raise ValueError(
|
||||
"norm_type must be one of [rmsnorm, layernorm]. Got "
|
||||
f"{self.config.norm_type}.")
|
||||
@@ -0,0 +1,8 @@
|
||||
import vllm_mlu.model_executor.layers.feed_forward
|
||||
import vllm_mlu.model_executor.layers.sparse_moe_mlp
|
||||
import vllm_mlu.model_executor.layers.linear
|
||||
import vllm_mlu.model_executor.layers.spec_decode_base_sampler
|
||||
import vllm_mlu.model_executor.layers.rotary_embedding
|
||||
import vllm_mlu.model_executor.layers.quantization
|
||||
import vllm_mlu.model_executor.layers.activation
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
def vllm__model_executor__activation__QuickGELU__forward_mlu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: implement forward_mlu
|
||||
'''
|
||||
return mlu_ops.active(x, 'quick_gelu', False)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
MluHijackObject.apply_hijack(QuickGELU,
|
||||
"forward_mlu",
|
||||
vllm__model_executor__activation__QuickGELU__forward_mlu)
|
||||
150
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/feed_forward.py
Executable file
150
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/feed_forward.py
Executable file
@@ -0,0 +1,150 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import BaseLayerWithLoRA
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear
|
||||
)
|
||||
from vllm_mlu.mlu_hijack_utils import set_is_gated
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class FeedForward(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
up_proj_name: str,
|
||||
is_gated: bool,
|
||||
down_proj_name: str,
|
||||
bias: bool,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
skip_bias_add: bool = False,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.is_gated = is_gated
|
||||
self.bias = bias
|
||||
self.up_proj_name = up_proj_name
|
||||
self.down_proj_name = down_proj_name
|
||||
self.quant_config = quant_config
|
||||
self.is_initialized = False
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.reduce_results = reduce_results
|
||||
self.use_bt_ffn = True if quant_config is None else False
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
set_is_gated(self.is_gated)
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
# up_proj with gate or not
|
||||
if self.is_gated:
|
||||
up_proj = MergedColumnParallelLinear(hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.{up_proj_name}")
|
||||
else:
|
||||
up_proj = ColumnParallelLinear(hidden_size,
|
||||
intermediate_size,
|
||||
bias=bias,
|
||||
skip_bias_add=skip_bias_add,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.{up_proj_name}")
|
||||
self.register_module(up_proj_name, up_proj)
|
||||
|
||||
# down_proj
|
||||
down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
skip_bias_add=skip_bias_add,
|
||||
reduce_results=reduce_results,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.{down_proj_name}")
|
||||
self.register_module(down_proj_name, down_proj)
|
||||
|
||||
def prepare_weight(self):
|
||||
if not self.is_initialized:
|
||||
# alpha and beta are 1.0 and 0.0 respectively due to the fact that we don't need residual for now
|
||||
self.alpha = 1.0
|
||||
self.beta = 0.0
|
||||
# place it here to avoid the overhead of calling it in the forward pass
|
||||
self.is_initialized = True
|
||||
|
||||
def _forward(self, hidden_states):
|
||||
self.prepare_weight()
|
||||
up_proj = getattr(self, self.up_proj_name)
|
||||
down_proj = getattr(self, self.down_proj_name)
|
||||
act_dict = {
|
||||
"relu": F.relu,
|
||||
"gelu": F.gelu,
|
||||
"silu": F.silu,
|
||||
}
|
||||
fc1 = F.linear(hidden_states, up_proj.weight, bias=up_proj.bias)
|
||||
if self.is_gated:
|
||||
d = fc1.shape[-1] // 2
|
||||
fc1 = act_dict[self.hidden_act](fc1[..., :d]) * fc1[..., d:]
|
||||
else:
|
||||
fc1 = act_dict[self.hidden_act](fc1)
|
||||
fc2 = F.linear(fc1, down_proj.weight, bias=None)
|
||||
fc2 = tensor_model_parallel_all_reduce(fc2)
|
||||
if not self.skip_bias_add:
|
||||
fc2 = fc2 + down_proj.bias if down_proj.bias is not None else fc2
|
||||
return fc2
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None
|
||||
):
|
||||
self.prepare_weight()
|
||||
up_proj = getattr(self, self.up_proj_name)
|
||||
down_proj = getattr(self, self.down_proj_name)
|
||||
residual_ = None if self.tp_rank > 0 else residual
|
||||
if (self.use_bt_ffn and not isinstance(up_proj, BaseLayerWithLoRA)
|
||||
and not isinstance(down_proj, BaseLayerWithLoRA)):
|
||||
# The matmul formula is the following:
|
||||
# mul_out = alpha * (matmul(input, filter, transpose\_b=True) + bias) + beta * residual
|
||||
# output = active(mul_out)
|
||||
# Notes: We cannot use the activation function in matmul because it does not support gated operation
|
||||
# we might support its in tmo matmul in the future
|
||||
fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj.weight, up_proj.bias,
|
||||
None, 'none', self.alpha, self.beta)
|
||||
act_out = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
|
||||
beta = 0.0
|
||||
if residual_ is not None:
|
||||
beta = 1.0
|
||||
residual_ = residual_.view(-1, residual_.shape[-1])
|
||||
out_ = mlu_ops.matmul(act_out, down_proj.weight, None, residual_, 'none', self.alpha, beta)
|
||||
# bias if existed need to add after second matmul according to the original design of vllm
|
||||
if self.reduce_results:
|
||||
out = tensor_model_parallel_all_reduce(out_)
|
||||
else:
|
||||
out = out_
|
||||
# do the bias add if needed
|
||||
if not self.skip_bias_add:
|
||||
out = out + down_proj.bias if down_proj.bias is not None else out
|
||||
else:
|
||||
return out, down_proj.bias
|
||||
else:
|
||||
fc1, bias = up_proj(hidden_states, smooth_quant_scale=smooth_quant_scale)
|
||||
if bias is not None:
|
||||
fc1 += bias
|
||||
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
|
||||
out, bias = down_proj(fc1, residual=residual_)
|
||||
if self.skip_bias_add:
|
||||
return out, bias
|
||||
return out
|
||||
101
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/linear.py
Normal file
101
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/linear.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
UnquantizedLinearMethod,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear
|
||||
)
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
def vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
beta = 0.0
|
||||
if residual is not None:
|
||||
beta = 1.0
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
|
||||
return mlu_ops.matmul(x.view(-1, x.shape[-1]), layer.weight, bias, residual, 'none', 1.0, beta).view(res_shape)
|
||||
|
||||
|
||||
def vllm__module_executor__layers__linear__RowParallelLinear__forward(self, input_, residual: Optional[torch.Tensor] = None):
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
residual_ = None if self.tp_rank > 0 else residual
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_,
|
||||
residual=residual_)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
return output, output_bias
|
||||
|
||||
def vllm__module_executor__layers__linear__ColumnParallelLinear__forward(
|
||||
self, input_, smooth_quant_scale: Optional[torch.Tensor] = None):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add input_scale parameter.
|
||||
'''
|
||||
if smooth_quant_scale is not None:
|
||||
output_parallel = self.quant_method.apply(self, input_, bias,
|
||||
input_scale=smooth_quant_scale)
|
||||
else:
|
||||
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(UnquantizedLinearMethod,
|
||||
UnquantizedLinearMethod.apply,
|
||||
vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply)
|
||||
|
||||
MluHijackObject.apply_hijack(RowParallelLinear,
|
||||
RowParallelLinear.forward,
|
||||
vllm__module_executor__layers__linear__RowParallelLinear__forward)
|
||||
|
||||
MluHijackObject.apply_hijack(ColumnParallelLinear,
|
||||
ColumnParallelLinear.forward,
|
||||
vllm__module_executor__layers__linear__ColumnParallelLinear__forward)
|
||||
@@ -0,0 +1,13 @@
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig
|
||||
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig
|
||||
from vllm_mlu.model_executor.layers.quantization.gptq_mlu import GPTQMluConfig
|
||||
from vllm_mlu.model_executor.layers.quantization.awq_mlu import AWQMluConfig
|
||||
|
||||
|
||||
QUANTIZATION_METHODS.update({
|
||||
"gptq_mlu": GPTQMluConfig,
|
||||
"awq_mlu": AWQMluConfig,
|
||||
"weightonly": WeightOnlyConfig,
|
||||
"smoothquant": SmoothQuantConfig,
|
||||
})
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,414 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear, MergedColumnParallelLinear
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
|
||||
|
||||
# We only support gptq and awq over 300 serials and only support int4 and int8 precision
|
||||
def query_mlu_supported_quant_types(has_zp: bool,
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
if device_capability < 50:
|
||||
return []
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + zero-point
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
|
||||
|
||||
def check_mlu_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
supported_types = query_mlu_supported_quant_types(
|
||||
has_zp, device_capability)
|
||||
|
||||
if quant_type not in supported_types:
|
||||
return (False, f"Mlu does not support weight_bits = {quant_type}. "
|
||||
f"Only types = {supported_types} "
|
||||
f"are supported (for group_size = {group_size}, "
|
||||
f"device_capability = {device_capability}, zp = {has_zp}).")
|
||||
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
|
||||
return (False, f"Mlu does not support group_size = {group_size}. "
|
||||
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
|
||||
return True
|
||||
|
||||
class AWQMluConfig(QuantizationConfig):
|
||||
"""Config class for AWQMlu.
|
||||
|
||||
Reference: https://arxiv.org/abs/2306.00978
|
||||
"""
|
||||
|
||||
# num_bits -> type
|
||||
TYPE_MAP = {
|
||||
4: {
|
||||
False: scalar_types.uint4b8,
|
||||
True: scalar_types.uint4,
|
||||
},
|
||||
8: {
|
||||
False: scalar_types.uint8b128,
|
||||
True: scalar_types.uint8,
|
||||
}
|
||||
}
|
||||
|
||||
VERSION = ["gemm"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
lm_head_quantized: bool,
|
||||
version: str = "gemm",
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
self.version = version
|
||||
self.support_scale_zeros = False
|
||||
|
||||
if self.weight_bits not in [4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 4/8-bit weight quantization is supported for "
|
||||
f"AWQMlu, but got {self.weight_bits} bits.")
|
||||
if self.version not in self.VERSION:
|
||||
raise ValueError(
|
||||
"Currently, only gemm, gemv version is supported for "
|
||||
f"AWQMlu, but got verion:{self.version}.")
|
||||
|
||||
if self.version in ["gemm"]:
|
||||
self.order_map = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]}
|
||||
self.reverse_order_map = {4 : [0, 4, 1, 5, 2, 6, 3, 7], 8: [0, 2, 1, 3]}
|
||||
else:
|
||||
self.order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
|
||||
self.reverse_order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQMluConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point}), "
|
||||
f"lm_head_quantized={self.lm_head_quantized})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "awq_mlu"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 50
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quant_config.json", "quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQMluConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
version = cls.get_from_keys_or(config, ["version"],
|
||||
default="gemm")
|
||||
return cls(weight_bits, group_size, zero_point, lm_head_quantized, version)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["AWQMluLinearMethod"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
return AWQMluLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
can_convert = cls.is_awq_mlu_compatible(hf_quant_cfg)
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "awq"
|
||||
or user_quant == "awq_mlu")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "awq":
|
||||
logger.info("Detected that the model can run with awq_mlu"
|
||||
", however you specified quantization=awq explicitly,"
|
||||
" so forcing awq. Use quantization=awq_mlu for"
|
||||
" faster inference")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_awq_mlu_compatible(cls, quant_config: Dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits", None)
|
||||
group_size = quant_config.get("group_size", None)
|
||||
has_zp = quant_config.get("zero_point", None)
|
||||
version = quant_config.get("version", "gemm")
|
||||
|
||||
if quant_method != "awq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or has_zp is None):
|
||||
return False
|
||||
|
||||
if num_bits not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
if version not in cls.VERSION:
|
||||
return False
|
||||
|
||||
return check_mlu_supported(quant_type=cls.TYPE_MAP[num_bits][has_zp],
|
||||
group_size=group_size,
|
||||
has_zp=has_zp)
|
||||
|
||||
class AWQMluLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQMlu.
|
||||
|
||||
Args:
|
||||
quant_config: The AWQMlu quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AWQMluConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scales = GroupQuantScaleParameter(data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
packed_qweight, scale_zeros = self.extract_autoawq(layer)
|
||||
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
layer.qzeros = None
|
||||
layer.scales = None
|
||||
else:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
if scale_zeros is not None:
|
||||
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
|
||||
else:
|
||||
layer.qzeros = None
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data.transpose(0, 1).contiguous(), requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.quant_config.zero_point and not self.quant_config.support_scale_zeros:
|
||||
output = mlu_ops.matmul(x, layer.qweight, bias)
|
||||
if residual is not None:
|
||||
output = output + residual
|
||||
else:
|
||||
output = mlu_ops.weight_only_quant_matmul(x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
bias,
|
||||
residual,
|
||||
"none",
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
return output
|
||||
|
||||
def extract_autoawq(self, layer: torch.nn.Module):
|
||||
qweight = layer.qweight.data
|
||||
qzeros = layer.qzeros.data
|
||||
scales = layer.scales.data
|
||||
bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
|
||||
# Unpack the qweight and qzeros tensors
|
||||
iweight, izeros = self.unpack_awq_int32_into_int8(qweight, qzeros, bits)
|
||||
# Reverse the order of the iweight and izeros tensors
|
||||
iweight, izeros = self.reverse_awq_order(iweight, izeros, bits)
|
||||
|
||||
# overflow checks
|
||||
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
|
||||
if izeros is not None:
|
||||
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
|
||||
|
||||
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
|
||||
scales = scales.repeat_interleave(group_size, dim=0)
|
||||
if izeros is not None:
|
||||
izeros = izeros.repeat_interleave(group_size, dim=0)
|
||||
fweight = (iweight - izeros) * scales
|
||||
else:
|
||||
fweight = iweight * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
fweight = fweight.transpose(0, 1)
|
||||
|
||||
return fweight, None
|
||||
|
||||
if self.quant_config.zero_point and self.quant_config.support_scale_zeros and izeros is not None:
|
||||
scale_zeros = izeros.to(scales.dtype) * -1 * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
scale_zeros = scale_zeros.transpose(0, 1)
|
||||
else:
|
||||
scale_zeros = None
|
||||
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
iweight = iweight.to(torch.int8).transpose(0, 1)
|
||||
|
||||
if bits == 4:
|
||||
higher_bit_tensor = iweight[:, 1::2]
|
||||
lower_bit_tensor = iweight[:, 0::2]
|
||||
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
|
||||
else:
|
||||
packed_qweight = iweight
|
||||
|
||||
return packed_qweight, scale_zeros
|
||||
|
||||
def unpack_awq_int32_into_int8(self, qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qweight.device)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
# unpacking columnwise
|
||||
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(dtype)
|
||||
iweights = iweights.view(iweights.shape[0], -1)
|
||||
if not self.quant_config.zero_point or self.quant_config.support_scale_zeros:
|
||||
iweights = torch.bitwise_and(iweights - 2**(bits - 1), (2 ** bits) - 1)
|
||||
|
||||
# unpacking columnwise
|
||||
if qzeros is not None:
|
||||
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(dtype)
|
||||
izeros = izeros.view(izeros.shape[0], -1)
|
||||
if not self.quant_config.zero_point:
|
||||
izeros = torch.bitwise_and(izeros - 2**(bits - 1), (2 ** bits) - 1)
|
||||
else:
|
||||
izeros = None
|
||||
|
||||
return iweights, izeros
|
||||
|
||||
def reverse_awq_order(self, iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
|
||||
reverse_order_tensor = torch.arange(iweights.shape[-1], dtype=torch.int32, device=iweights.device)
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
|
||||
reverse_order_tensor = reverse_order_tensor[:, self.quant_config.reverse_order_map[bits]]
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1)
|
||||
|
||||
rweights = iweights[:, reverse_order_tensor]
|
||||
if izeros is not None:
|
||||
rzeros = izeros[:, reverse_order_tensor]
|
||||
|
||||
return rweights, rzeros
|
||||
|
||||
def combine_low_bits(self, tensor_a, tensor_b):
|
||||
"""
|
||||
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
|
||||
|
||||
Args:
|
||||
tensor_a (torch.Tensor): First tensor of type int8.
|
||||
tensor_b (torch.Tensor): Second tensor of type int8.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
|
||||
"""
|
||||
# 确保输入是 int8 类型
|
||||
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
|
||||
raise ValueError("Both tensors must be of int8 type.")
|
||||
|
||||
# 提取每个 tensor 的低4位
|
||||
low_bits_a = torch.bitwise_and(tensor_a, 0x0F) # 保留 tensor_a 的低4位
|
||||
low_bits_b = torch.bitwise_and(tensor_b, 0x0F) # 保留 tensor_b 的低4位
|
||||
|
||||
# 将 tensor_a 的低4位左移4位
|
||||
shifted_low_bits_a = low_bits_a << 4
|
||||
|
||||
# 组合两个 tensor 的低4位
|
||||
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
|
||||
|
||||
return combined
|
||||
@@ -0,0 +1,441 @@
|
||||
import enum
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
|
||||
|
||||
# We only support gptq and awq over 300 serials and only support int4 and int8 precision
|
||||
def query_mlu_supported_quant_types(has_zp: bool,
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
if device_capability < 50:
|
||||
return []
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + zero-point
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
|
||||
|
||||
def check_mlu_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
supported_types = query_mlu_supported_quant_types(
|
||||
has_zp, device_capability)
|
||||
|
||||
if quant_type not in supported_types:
|
||||
return (False, f"Mlu does not support weight_bits = {quant_type}. "
|
||||
f"Only types = {supported_types} "
|
||||
f"are supported (for group_size = {group_size}, "
|
||||
f"device_capability = {device_capability}, zp = {has_zp}).")
|
||||
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
|
||||
return (False, f"Mlu does not support group_size = {group_size}. "
|
||||
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class GPTQMluConfig(QuantizationConfig):
|
||||
"""Config class for GPTQMlu.
|
||||
|
||||
Reference: https://arxiv.org/abs/2210.17323
|
||||
"""
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.is_sym = is_sym
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = Fraction(32, self.weight_bits)
|
||||
self.support_scale_zeros = False
|
||||
|
||||
if self.weight_bits not in [4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 4/8-bit weight quantization is "
|
||||
f"supported for GPTQMlu, but got {self.weight_bits} bits.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQMluConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}),"
|
||||
f"lm_head_quantized={self.lm_head_quantized}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "gptq_mlu"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 50
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quant_config.json", "quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQMluConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
is_sym = cls.get_from_keys(config, ["sym"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["GPTQMluLinearMethod"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
return GPTQMluLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
@classmethod
|
||||
def is_gptq_mlu_compatible(cls, quant_config: Dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits", None)
|
||||
group_size = quant_config.get("group_size", None)
|
||||
sym = quant_config.get("sym", None)
|
||||
desc_act = quant_config.get("desc_act", None)
|
||||
|
||||
if quant_method != "gptq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or sym is None
|
||||
or desc_act is None):
|
||||
return False
|
||||
|
||||
if (num_bits, sym) not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
return check_mlu_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
|
||||
group_size=group_size, has_zp=False)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
can_convert = cls.is_gptq_mlu_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
|
||||
or user_quant == "gptq_mlu")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
class GPTQMluLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQMlu.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQMlu quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQMluConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (output_size_per_partition % self.quant_config.pack_factor.numerator
|
||||
!= 0):
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
scale_and_zero_size = input_size // group_size
|
||||
scale_and_zero_input_dim = None
|
||||
if (input_size != input_size_per_partition) and (self.quant_config.group_size !=
|
||||
-1) and (not self.quant_config.desc_act):
|
||||
scale_and_zero_size = input_size_per_partition // group_size
|
||||
scale_and_zero_input_dim = 0
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
g_idx = RowvLLMParameter(data=torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if scale_and_zero_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.device = layer.qweight.data.device
|
||||
if self.quant_config.desc_act:
|
||||
g_idx_list = layer.g_idx.data.tolist()
|
||||
g_idx_unique = list(dict.fromkeys(g_idx_list))
|
||||
g_idx = torch.tensor(g_idx_unique, dtype=layer.g_idx.data.dtype, device=self.device)
|
||||
scales = layer.scales.data[g_idx]
|
||||
else:
|
||||
scales = layer.scales.data
|
||||
|
||||
packed_qweight, scale_zeros = self.extract_autogptq(layer, scales)
|
||||
if (not self.quant_config.is_sym) and (not self.quant_config.support_scale_zeros):
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
layer.qzeros = None
|
||||
layer.scales = None
|
||||
else:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
if scale_zeros is not None:
|
||||
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
|
||||
else:
|
||||
layer.qzeros = None
|
||||
layer.scales = torch.nn.Parameter(scales.transpose(0, 1).contiguous(), requires_grad=False)
|
||||
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if (not self.quant_config.is_sym) and (not self.quant_config.support_scale_zeros):
|
||||
output = mlu_ops.matmul(x, layer.qweight, bias)
|
||||
if residual is not None:
|
||||
output = output + residual
|
||||
else:
|
||||
output = mlu_ops.weight_only_quant_matmul(x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
bias,
|
||||
residual,
|
||||
"none",
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def extract_autogptq(self, layer: torch.nn.Module, scales: torch.Tensor):
|
||||
bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
# Unpack the qweight and qzeros tensors
|
||||
iweight = self.unpack_gptq_qweight_int32_into_int8(layer.qweight.data, bits)
|
||||
izeros = self.unpack_gptq_qzeros_int32_into_int8(layer.qzeros.data, bits)
|
||||
|
||||
# overflow checks
|
||||
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
|
||||
if izeros is not None:
|
||||
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
|
||||
|
||||
if not self.quant_config.is_sym and (not self.quant_config.support_scale_zeros):
|
||||
scales = scales.repeat_interleave(group_size, dim=0)
|
||||
if izeros is not None:
|
||||
izeros = izeros.repeat_interleave(group_size, dim=0)
|
||||
fweight = (iweight - izeros) * scales
|
||||
else:
|
||||
fweight = iweight * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
fweight = fweight.transpose(0, 1)
|
||||
|
||||
return fweight, None
|
||||
|
||||
if not self.quant_config.is_sym and self.quant_config.support_scale_zeros and izeros is not None:
|
||||
scale_zeros = izeros.to(scales.dtype) * -1 * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
scale_zeros = scale_zeros.transpose(0, 1)
|
||||
else:
|
||||
scale_zeros = None
|
||||
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
iweight = iweight.to(torch.int8).transpose(0, 1)
|
||||
|
||||
if bits == 4:
|
||||
higher_bit_tensor = iweight[:, 1::2]
|
||||
lower_bit_tensor = iweight[:, 0::2]
|
||||
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
|
||||
else:
|
||||
packed_qweight = iweight
|
||||
|
||||
return packed_qweight, scale_zeros
|
||||
|
||||
def unpack_gptq_qweight_int32_into_int8(self, qweight: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qweight.device)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
# unpacking columnwise
|
||||
iweight = torch.bitwise_right_shift(qweight[:, None, :], shifts[None, :, None]).to(dtype)
|
||||
iweight = iweight.view(-1, iweight.shape[-1])
|
||||
# minus 2**(bit-1)
|
||||
if self.quant_config.is_sym or self.quant_config.support_scale_zeros:
|
||||
iweight = torch.bitwise_and(iweight - 2**(bits - 1), (2 ** bits) - 1)
|
||||
|
||||
return iweight
|
||||
|
||||
def unpack_gptq_qzeros_int32_into_int8(self, qzeros: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qzeros.device)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
# unpacking columnwise
|
||||
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(dtype)
|
||||
izeros = izeros.view(izeros.shape[0], -1)
|
||||
izeros = izeros + 1
|
||||
# minus 2**(bit-1)
|
||||
if self.quant_config.is_sym:
|
||||
izeros = torch.bitwise_and(izeros - 2**(bits - 1), (2 ** bits) - 1)
|
||||
|
||||
return izeros
|
||||
|
||||
def combine_low_bits(self, tensor_a, tensor_b):
|
||||
"""
|
||||
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
|
||||
|
||||
Args:
|
||||
tensor_a (torch.Tensor): First tensor of type int8.
|
||||
tensor_b (torch.Tensor): Second tensor of type int8.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
|
||||
"""
|
||||
# 确保输入是 int8 类型
|
||||
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
|
||||
raise ValueError("Both tensors must be of int8 type.")
|
||||
|
||||
# 提取每个 tensor 的低4位
|
||||
low_bits_a = torch.bitwise_and(tensor_a, 0x0F) # 保留 tensor_a 的低4位
|
||||
low_bits_b = torch.bitwise_and(tensor_b, 0x0F) # 保留 tensor_b 的低4位
|
||||
|
||||
# 将 tensor_a 的低4位左移4位
|
||||
shifted_low_bits_a = low_bits_a << 4
|
||||
|
||||
# 组合两个 tensor 的低4位
|
||||
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
|
||||
|
||||
return combined
|
||||
192
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/quantization/smoothquant.py
Executable file
192
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/quantization/smoothquant.py
Executable file
@@ -0,0 +1,192 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SmoothQuantConfig(QuantizationConfig):
|
||||
"""Config class for SmoothQuant.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
quant_mode: str, # smoothquant
|
||||
input_quant_method: str, # per token/per tensor
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.quant_mode = quant_mode
|
||||
self.input_quant_method = input_quant_method
|
||||
|
||||
if quant_mode == "SmoothQuant" and (self.weight_bits != 8):
|
||||
raise ValueError(
|
||||
"Currently, only 8-bit weight quantization is supported for "
|
||||
f"SmoothQuant, but got {self.weight_bits} bits.")
|
||||
if quant_mode == "SmoothQuant" and (self.input_quant_method != "per_token" and self.input_quant_method != "per_tensor"):
|
||||
raise ValueError(
|
||||
"Currently, only per_token or per_tensor input quantization is supported for "
|
||||
f"SmoothQuant, but got {self.input_quant_method}.")
|
||||
self.pack_factor = 8 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SmoothQuantConfig(weight_bits={self.weight_bits}, "
|
||||
f"input_quant_method={self.input_quant_method}, "
|
||||
f"quant_mode={self.quant_mode})")
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "SmoothQuant"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
def get_min_capability(self) -> int:
|
||||
return 30
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
input_quant_method = cls.get_from_keys(config, ["input_quant_method"])
|
||||
quant_mode = cls.get_from_keys(config, ["quant_mode"])
|
||||
return cls(weight_bits, quant_mode, input_quant_method)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["SmoothQuantLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return SmoothQuantLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
|
||||
class SmoothQuantLinearMethod(LinearMethodBase):
|
||||
"""Linear method for SmoothQuant.
|
||||
|
||||
Args:
|
||||
quant_config: The SmoothQuant quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: SmoothQuantConfig):
|
||||
self.quant_config = quant_config
|
||||
# for per-tensor case, we can skip quant input for the first attn|ffn linear
|
||||
# and fusion this step in layernorm to get better performance
|
||||
self.skip_quant_input = False
|
||||
self.compute_dtype = torch.get_default_dtype()
|
||||
|
||||
def create_weights(
|
||||
self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> Dict[str, Any]:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if self.quant_config.quant_mode == "SmoothQuant":
|
||||
input_dim = None
|
||||
if input_size != input_size_per_partition:
|
||||
input_dim = 0
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
device="mlu",
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
})
|
||||
per_channel_scale = Parameter(
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(per_channel_scale, {
|
||||
"input_dim": None,
|
||||
"output_dim": 0,
|
||||
})
|
||||
layer.register_parameter("qweight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("per_channel_scale", per_channel_scale)
|
||||
set_weight_attrs(per_channel_scale, extra_weight_attrs)
|
||||
if self.quant_config.input_quant_method == "per_token":
|
||||
smooth = Parameter(
|
||||
torch.empty(
|
||||
input_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(smooth, {
|
||||
"input_dim": input_dim,
|
||||
"output_dim": None,
|
||||
"ignore_warning": True,
|
||||
})
|
||||
layer.register_parameter("smooth", smooth)
|
||||
set_weight_attrs(smooth, extra_weight_attrs)
|
||||
if self.quant_config.input_quant_method == "per_tensor":
|
||||
scale_to_int = Parameter(
|
||||
torch.empty(
|
||||
input_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(scale_to_int, {
|
||||
"input_dim": input_dim,
|
||||
"output_dim": None,
|
||||
"ignore_warning": True,
|
||||
})
|
||||
layer.register_parameter("scale_to_int", scale_to_int)
|
||||
set_weight_attrs(scale_to_int, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if self.quant_config.input_quant_method == "per_token" and layer.smooth.dtype != torch.float:
|
||||
layer.smooth = Parameter(layer.smooth.to(torch.float), requires_grad=False)
|
||||
if self.quant_config.input_quant_method == "per_tensor" and layer.scale_to_int.dtype != torch.float:
|
||||
layer.scale_to_int = Parameter(layer.scale_to_int.to(torch.float), requires_grad=False)
|
||||
if layer.per_channel_scale.dtype != torch.float:
|
||||
layer.per_channel_scale = Parameter(layer.per_channel_scale.to(torch.float), requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
quant_input = None
|
||||
if self.skip_quant_input:
|
||||
quant_input = x
|
||||
elif self.quant_config.input_quant_method == "per_token":
|
||||
quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer.smooth, None)
|
||||
elif self.quant_config.input_quant_method == "per_tensor":
|
||||
quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Currently, only per_token or per_tensor input quantization is supported for "
|
||||
f"SmoothQuant, but got {self.input_quant_method}.")
|
||||
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer.qweight,
|
||||
layer.per_channel_scale, self.compute_dtype, bias, residual)
|
||||
return out
|
||||
143
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/quantization/weightonly.py
Executable file
143
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/quantization/weightonly.py
Executable file
@@ -0,0 +1,143 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class WeightOnlyConfig(QuantizationConfig):
|
||||
"""Config class for WeightOnly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
quant_mode: str, # weight_only
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.quant_mode = quant_mode
|
||||
|
||||
if quant_mode == "WeightOnly" and (self.weight_bits != 8 and self.weight_bits != 4):
|
||||
raise ValueError(
|
||||
"Currently, only 8/4-bit weight quantization is supported for "
|
||||
f"weight_only, but got {self.weight_bits} bits.")
|
||||
self.pack_factor = 8 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"WeightOnlyConfig(weight_bits={self.weight_bits}, "
|
||||
f"quant_mode={self.quant_mode})")
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "WeightOnly"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
def get_min_capability(self) -> int:
|
||||
return 30
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "WeightOnlyConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
try:
|
||||
quant_mode = cls.get_from_keys(config, ["quant_mode"])
|
||||
except Exception:
|
||||
quant_mode = "WeightOnly"
|
||||
return cls(weight_bits, quant_mode)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["WeightOnlyLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return WeightOnlyLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
|
||||
class WeightOnlyLinearMethod(LinearMethodBase):
|
||||
"""Linear method for WeightOnly.
|
||||
|
||||
Args:
|
||||
quant_config: The WeightOnly quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: WeightOnlyConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> Dict[str, Any]:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if self.quant_config.quant_mode == "WeightOnly":
|
||||
scale_and_zero_input_dim = None
|
||||
if output_size != output_size_per_partition:
|
||||
scale_and_zero_input_dim = 0
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
device="mlu",
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
})
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(scales, {
|
||||
"input_dim": scale_and_zero_input_dim,
|
||||
"output_dim": 0,
|
||||
})
|
||||
layer.register_parameter("qweight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("scales", scales)
|
||||
set_weight_attrs(scales, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if layer.scales.dtype != torch.float:
|
||||
layer.scales = Parameter(layer.scales.to(torch.float), requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
out = mlu_ops.weight_only_quant_matmul(x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
None,
|
||||
bias,
|
||||
residual,
|
||||
"none",
|
||||
self.quant_config.weight_bits)
|
||||
return out
|
||||
|
||||
@@ -0,0 +1,647 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import math
|
||||
import torch
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
RotaryEmbedding, MRotaryEmbedding,
|
||||
LinearScalingRotaryEmbedding, DeepseekScalingRotaryEmbedding,
|
||||
DynamicNTKScalingRotaryEmbedding, DynamicNTKAlphaRotaryEmbedding,
|
||||
YaRNScalingRotaryEmbedding, Phi3LongRoPEScaledRotaryEmbedding,
|
||||
_yarn_find_correction_range, _ROPE_DICT, yarn_get_mscale, _yarn_linear_ramp_mask)
|
||||
from vllm.model_executor.layers import rotary_embedding
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.envs import VLLM_ALLOW_LONG_MAX_MODEL_LEN
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
def get_long_max_model_max_position_emb(max_position_embeddings, scaling_factor):
|
||||
if MLURotaryEmbedding.max_seq_len != None and \
|
||||
MLURotaryEmbedding.max_seq_len > max_position_embeddings * scaling_factor:
|
||||
logger.warning(f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) is different with " +
|
||||
f"max_position_embedding ({max_position_embeddings}) * scaling_factor ({scaling_factor}) " +
|
||||
"from model's config.json, This may lead to incorrect model outputs or MLU errors. " +
|
||||
f"Make sure the value is correct and within the model context size. " +
|
||||
f"Set max_position_embedding={MLURotaryEmbedding.max_seq_len}.")
|
||||
return math.ceil(MLURotaryEmbedding.max_seq_len / scaling_factor)
|
||||
return max_position_embeddings
|
||||
|
||||
|
||||
@CustomOp.register("rotary_embedding_mlu")
|
||||
class MLURotaryEmbedding(RotaryEmbedding, CustomOp):
|
||||
|
||||
cu_seq_lens : torch.Tensor = None
|
||||
max_seq_len : int = None
|
||||
is_prompt : bool = False
|
||||
is_chunked : bool = False
|
||||
set_cos_sin : bool = False
|
||||
cos_ : torch.Tensor = None
|
||||
sin_ : torch.Tensor = None
|
||||
positions_: torch.Tensor = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
CustomOp.__init__(self)
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
|
||||
if MLURotaryEmbedding.max_seq_len != None \
|
||||
and self.max_position_embeddings < MLURotaryEmbedding.max_seq_len and \
|
||||
not isinstance(self, (YaRNScalingRotaryEmbedding, DeepseekScalingRotaryEmbedding)):
|
||||
logger.warning(f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) is different with " +
|
||||
f"max_position_embedding ({max_position_embeddings}) from model's config.json, " +
|
||||
f"This may lead to incorrect model outputs or MLU errors. " +
|
||||
f"Make sure the value is correct and within the model context size. " +
|
||||
f"Set max_position_embedding={MLURotaryEmbedding.max_seq_len}.")
|
||||
self.max_position_embeddings = MLURotaryEmbedding.max_seq_len
|
||||
cache = self._compute_cos_sin_cache()
|
||||
if isinstance(self, MLULinearScalingRotaryEmbedding):
|
||||
logger.debug(f"Using mlu defining _compute_cos_sin_cache due to the special tensor composition")
|
||||
elif is_neox_style:
|
||||
cache_pos = cache.shape[0]
|
||||
cache = cache.reshape(cache_pos, 2, -1)
|
||||
cache = torch.tile(cache, (1, 1, 2)).reshape(cache_pos, -1)
|
||||
else:
|
||||
cache = cache.repeat_interleave(2, dim=-1)
|
||||
|
||||
cache = cache.to(dtype)
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
|
||||
@classmethod
|
||||
def set_mlu_var(
|
||||
cls,
|
||||
input_ids: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata
|
||||
) -> None:
|
||||
cls.unset_mlu_var()
|
||||
is_chunked = False
|
||||
is_prompt = False
|
||||
prefill_metadata = attn_metadata.prefill_metadata
|
||||
decode_metadata = attn_metadata.decode_metadata
|
||||
if prefill_metadata:
|
||||
cu_seq_lens = prefill_metadata.query_start_loc
|
||||
rope_max_seq_len = prefill_metadata.max_query_len
|
||||
is_prompt = True
|
||||
# Workaround: mlugraph does not support torch.ne|eq|equal .etc for now,
|
||||
# because context mlugraph always uses in benchmark latency, and in this
|
||||
# case, query_start_loc always equals to seq_start_loc, so we can set
|
||||
# is_chunked to False directly.
|
||||
if prefill_metadata.use_cuda_graph:
|
||||
is_chunked = False
|
||||
elif decode_metadata or \
|
||||
max(prefill_metadata.seq_lens) != prefill_metadata.max_query_len:
|
||||
is_chunked = True
|
||||
if decode_metadata:
|
||||
if prefill_metadata:
|
||||
cu_seq_lens = attn_metadata.query_start_loc
|
||||
rope_max_seq_len = max(rope_max_seq_len,
|
||||
attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens)
|
||||
else:
|
||||
# input_ids is pack mode, and the decode_seq_len = 1
|
||||
cu_seq_lens = torch.arange(0, input_ids.shape[0] + 1, 1, dtype=torch.int32, device="mlu")
|
||||
rope_max_seq_len = 1
|
||||
cls.cu_seq_lens = cu_seq_lens
|
||||
cls.max_seq_len = rope_max_seq_len
|
||||
cls.is_prompt = is_prompt
|
||||
cls.is_chunked = is_chunked
|
||||
|
||||
@classmethod
|
||||
def unset_mlu_var(cls):
|
||||
cls.cu_seq_lens = None
|
||||
cls.max_seq_len = None
|
||||
cls.is_prompt = False
|
||||
cls.is_chunked = False
|
||||
cls.set_cos_sin = False
|
||||
cls.cos_ = None
|
||||
cls.sin_ = None
|
||||
cls.positions_ = None
|
||||
|
||||
def _get_cos_sin(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
|
||||
sin = sin.view(-1, self.rotary_dim)
|
||||
cos = cos.view(-1, self.rotary_dim)
|
||||
return cos, sin
|
||||
|
||||
def _get_positions_with_offsets_mlu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
offsets: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
if offsets.numel() != positions.numel():
|
||||
raise Exception("rope offsets numel mismatch with positions, "
|
||||
f"positions: {positions.numel()}, offsets: {offsets.numel()}")
|
||||
return (positions + offsets).to(torch.int32)
|
||||
|
||||
def forward_mlu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
if MLURotaryEmbedding.set_cos_sin == False:
|
||||
MLURotaryEmbedding.cos_, MLURotaryEmbedding.sin_ = self._get_cos_sin()
|
||||
MLURotaryEmbedding.set_cos_sin = True
|
||||
interleaved = True
|
||||
if self.is_neox_style:
|
||||
interleaved = False
|
||||
|
||||
if offsets is not None:
|
||||
if MLURotaryEmbedding.positions_ is None:
|
||||
MLURotaryEmbedding.positions_ = (
|
||||
self._get_positions_with_offsets_mlu(positions, offsets))
|
||||
position_ids = MLURotaryEmbedding.positions_
|
||||
discrete = True
|
||||
elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
|
||||
position_ids = positions
|
||||
discrete = True
|
||||
else:
|
||||
position_ids = None
|
||||
discrete = False
|
||||
|
||||
x = mlu_ops.rotary_embedding(x,
|
||||
MLURotaryEmbedding.sin_,
|
||||
MLURotaryEmbedding.cos_,
|
||||
position_ids,
|
||||
MLURotaryEmbedding.cu_seq_lens,
|
||||
interleaved,
|
||||
discrete,
|
||||
False,
|
||||
MLURotaryEmbedding.max_seq_len)
|
||||
return x
|
||||
|
||||
|
||||
class MLULinearScalingRotaryEmbedding(MLURotaryEmbedding, LinearScalingRotaryEmbedding):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factors: Union[List[float], float],
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
if isinstance(scaling_factors, float):
|
||||
scaling_factors = [scaling_factors]
|
||||
self.scaling_factors: List[float] = scaling_factors # noqa
|
||||
MLURotaryEmbedding.__init__(self, head_size, rotary_dim,
|
||||
max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
# Lazy initialized.
|
||||
self._scaling_factor_to_offset: Dict[float, int]
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
half_dim = self.rotary_dim // 2
|
||||
if self.is_neox_style:
|
||||
inv_freq = 1.0 / (base ** ((torch.arange(
|
||||
0, self.rotary_dim, 1, dtype=torch.float32, device="mlu") % half_dim) * 2 / self.rotary_dim)
|
||||
)
|
||||
else:
|
||||
inv_freq = 1.0 / (
|
||||
base
|
||||
** ( torch.arange(0, self.rotary_dim, 1, device="mlu", dtype=torch.float32) // 2 * 2
|
||||
/ self.rotary_dim
|
||||
)
|
||||
)
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
cache_list: List[torch.Tensor] = []
|
||||
# offsets to the next cache in a tensor.
|
||||
# Each offset corresponds to the same index in scaling_factors.
|
||||
offsets: List[int] = []
|
||||
for scaling_factor in self.scaling_factors:
|
||||
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||
# maximum length before applying the rope scaling.
|
||||
# Thus, the maximum length after applying the rope scaling is
|
||||
# self.max_position_embeddings * self.scaling_factor.
|
||||
max_len = self.max_position_embeddings * scaling_factor
|
||||
t = torch.arange(max_len, dtype=torch.float, device="mlu")
|
||||
t = t / scaling_factor
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
if not cache_list:
|
||||
offset = 0
|
||||
else:
|
||||
last_offset = offsets[-1]
|
||||
next_max_len = cache_list[-1].shape[0]
|
||||
offset = last_offset + next_max_len
|
||||
offsets.append(offset)
|
||||
cache_list.append(cache)
|
||||
self._scaling_factor_to_offset = {
|
||||
float(scaling_factor): offsets[i]
|
||||
for i, scaling_factor in enumerate(self.scaling_factors)
|
||||
}
|
||||
assert len(self.scaling_factors) == len(offsets)
|
||||
return torch.cat(cache_list, dim=0)
|
||||
|
||||
|
||||
class MLUDeepseekScalingRotaryEmbedding(MLURotaryEmbedding, DeepseekScalingRotaryEmbedding):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1,
|
||||
mscale_all_dim: float = 0,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
# Get n-d magnitude scaling corrected for interpolation.
|
||||
self.mscale = float(
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale)) /
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
|
||||
attn_factor)
|
||||
MLURotaryEmbedding.__init__(self, head_size, rotary_dim,
|
||||
max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
|
||||
def forward_mlu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
if self.rotary_dim < self.head_size:
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
if MLURotaryEmbedding.set_cos_sin == False:
|
||||
MLURotaryEmbedding.cos_, MLURotaryEmbedding.sin_ = self._get_cos_sin()
|
||||
MLURotaryEmbedding.set_cos_sin = True
|
||||
interleaved = True
|
||||
if self.is_neox_style:
|
||||
interleaved = False
|
||||
if MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
|
||||
position_ids = positions
|
||||
discrete = True
|
||||
else :
|
||||
position_ids = None
|
||||
discrete = False
|
||||
query_rot = mlu_ops.rotary_embedding(query_rot,
|
||||
MLURotaryEmbedding.sin_,
|
||||
MLURotaryEmbedding.cos_,
|
||||
position_ids,
|
||||
MLURotaryEmbedding.cu_seq_lens,
|
||||
interleaved,
|
||||
discrete,
|
||||
False,
|
||||
MLURotaryEmbedding.max_seq_len)
|
||||
key_rot = mlu_ops.rotary_embedding(key_rot,
|
||||
MLURotaryEmbedding.sin_,
|
||||
MLURotaryEmbedding.cos_,
|
||||
position_ids,
|
||||
MLURotaryEmbedding.cu_seq_lens,
|
||||
interleaved,
|
||||
discrete,
|
||||
False,
|
||||
MLURotaryEmbedding.max_seq_len)
|
||||
if self.rotary_dim < self.head_size:
|
||||
query = torch.cat((query_rot, query_pass), dim=-1)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1)
|
||||
else:
|
||||
query = query_rot
|
||||
key = key_rot
|
||||
return query, key
|
||||
|
||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: change device cuda to mlu
|
||||
'''
|
||||
pos_freqs = self.base**(torch.arange(
|
||||
0, self.rotary_dim, 2, dtype=torch.float, device="mlu") /
|
||||
self.rotary_dim)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
|
||||
|
||||
low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
|
||||
self.rotary_dim, self.base,
|
||||
self.max_position_embeddings)
|
||||
# Get n-d rotational scaling corrected for extrapolation
|
||||
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
|
||||
low, high, self.rotary_dim // 2,
|
||||
dtype=torch.float)) * self.extrapolation_factor
|
||||
inv_freq = inv_freq_interpolation * (
|
||||
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.scaling_factor)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: change device cuda to mlu
|
||||
'''
|
||||
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
|
||||
device="mlu",
|
||||
dtype=torch.float32)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = (freqs.cos() * self.mscale)
|
||||
sin = (freqs.sin() * self.mscale)
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
print("Cache shape", cache.shape)
|
||||
return cache
|
||||
|
||||
|
||||
class MLULlama3RotaryEmbedding(MLURotaryEmbedding):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
scaling_factor: float,
|
||||
low_freq_factor: float,
|
||||
high_freq_factor: float,
|
||||
orig_max_position: int,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
self.low_freq_factor = low_freq_factor
|
||||
self.high_freq_factor = high_freq_factor
|
||||
self.orig_max_position = orig_max_position
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
|
||||
|
||||
class MLUDynamicNTKAlphaRotaryEmbedding(MLURotaryEmbedding, DynamicNTKAlphaRotaryEmbedding):
|
||||
"""RotaryEmbedding extended with Dynamic NTK scaling.
|
||||
Credits to the Reddit users /u/bloc97 and /u/emozilla
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_alpha: float,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
self.scaling_alpha = scaling_alpha
|
||||
MLURotaryEmbedding.__init__(
|
||||
self, head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
|
||||
|
||||
class MLUMRotaryEmbedding(MLURotaryEmbedding, MRotaryEmbedding):
|
||||
"""Rotary Embedding with Multimodal Sections."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
mrope_section: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
MLURotaryEmbedding.__init__(
|
||||
self, head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
|
||||
self.mrope_section = mrope_section
|
||||
if self.mrope_section:
|
||||
assert sum(self.mrope_section) == rotary_dim // 2
|
||||
|
||||
def forward_mlu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert positions.ndim == 1 or positions.ndim == 2
|
||||
num_tokens = positions.shape[-1]
|
||||
|
||||
cos_sin = self.cos_sin_cache[positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
if positions.ndim == 2:
|
||||
assert self.mrope_section
|
||||
num_section = len(self.mrope_section)
|
||||
mrope_section = self.mrope_section * 2
|
||||
cos = torch.cat([
|
||||
m[i % num_section]
|
||||
for i, m in enumerate(cos.split(mrope_section, dim=-1))
|
||||
],
|
||||
dim=-1)
|
||||
sin = torch.cat([
|
||||
m[i % num_section]
|
||||
for i, m in enumerate(sin.split(mrope_section, dim=-1))
|
||||
],
|
||||
dim=-1)
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
interleaved = True
|
||||
if self.is_neox_style:
|
||||
interleaved = False
|
||||
position_ids = None
|
||||
discrete = False
|
||||
# mlu_ops.rotary_embedding() is a in-place operation that update the query and key tensors.
|
||||
x = mlu_ops.rotary_embedding(x,
|
||||
sin,
|
||||
cos,
|
||||
position_ids,
|
||||
MLURotaryEmbedding.cu_seq_lens,
|
||||
interleaved,
|
||||
discrete,
|
||||
False,
|
||||
MLURotaryEmbedding.max_seq_len)
|
||||
return x
|
||||
|
||||
|
||||
def vllm__model_executor__layers__rotary_embedding__get_rope(
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position: int,
|
||||
base: int,
|
||||
is_neox_style: bool = True,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
) -> RotaryEmbedding:
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
if rope_scaling is not None:
|
||||
# Transforms every value that is a list into a tuple for caching calls
|
||||
rope_scaling_tuple = {
|
||||
k: tuple(v) if isinstance(v, list) else v
|
||||
for k, v in rope_scaling.items()
|
||||
}
|
||||
rope_scaling_args = tuple(rope_scaling_tuple.items())
|
||||
else:
|
||||
rope_scaling_args = None
|
||||
if partial_rotary_factor < 1.0:
|
||||
rotary_dim = int(rotary_dim * partial_rotary_factor)
|
||||
key = (head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
rope_scaling_args, dtype)
|
||||
if key in _ROPE_DICT:
|
||||
return _ROPE_DICT[key]
|
||||
|
||||
if rope_scaling is None:
|
||||
rotary_emb = MLURotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style, dtype)
|
||||
else:
|
||||
scaling_type = rope_scaling["rope_type"]
|
||||
|
||||
if scaling_type == "llama3":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
low_freq_factor = rope_scaling["low_freq_factor"]
|
||||
high_freq_factor = rope_scaling["high_freq_factor"]
|
||||
original_max_position = rope_scaling[
|
||||
"original_max_position_embeddings"]
|
||||
rotary_emb = MLULlama3RotaryEmbedding(head_size, rotary_dim,
|
||||
max_position, base,
|
||||
is_neox_style, dtype,
|
||||
scaling_factor, low_freq_factor,
|
||||
high_freq_factor,
|
||||
original_max_position)
|
||||
elif scaling_type == "default":
|
||||
if "mrope_section" in rope_scaling:
|
||||
rotary_emb = MLUMRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
mrope_section=rope_scaling["mrope_section"],
|
||||
)
|
||||
else:
|
||||
rotary_emb = MLURotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
)
|
||||
elif scaling_type == "linear":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
rotary_emb = MLULinearScalingRotaryEmbedding(head_size, rotary_dim,
|
||||
max_position, base,
|
||||
is_neox_style,
|
||||
scaling_factor, dtype)
|
||||
elif scaling_type == "dynamic":
|
||||
if "alpha" in rope_scaling:
|
||||
rotary_emb = MLUDynamicNTKAlphaRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
rope_scaling["alpha"], dtype)
|
||||
else:
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
scaling_factor, dtype)
|
||||
elif scaling_type == "yarn":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
original_max_position = rope_scaling[
|
||||
"original_max_position_embeddings"]
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in rope_scaling.items()
|
||||
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
|
||||
"beta_slow")
|
||||
}
|
||||
original_max_position = get_long_max_model_max_position_emb(original_max_position, scaling_factor)
|
||||
rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
|
||||
original_max_position,
|
||||
base, is_neox_style,
|
||||
scaling_factor, dtype,
|
||||
**extra_kwargs)
|
||||
elif scaling_type == "deepseek_yarn":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
original_max_position = rope_scaling[
|
||||
"original_max_position_embeddings"]
|
||||
# assert max_position == original_max_position * scaling_factor
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in rope_scaling.items()
|
||||
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
|
||||
"beta_slow", "mscale", "mscale_all_dim")
|
||||
}
|
||||
original_max_position = get_long_max_model_max_position_emb(original_max_position, scaling_factor)
|
||||
rotary_emb = MLUDeepseekScalingRotaryEmbedding(
|
||||
head_size, rotary_dim, original_max_position, base,
|
||||
is_neox_style, scaling_factor, dtype, **extra_kwargs)
|
||||
elif scaling_type == "longrope":
|
||||
short_factor = rope_scaling["short_factor"]
|
||||
long_factor = rope_scaling["long_factor"]
|
||||
original_max_position = rope_scaling[
|
||||
"original_max_position_embeddings"]
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in rope_scaling.items()
|
||||
if k in ("short_mscale", "long_mscale")
|
||||
}
|
||||
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, original_max_position,
|
||||
base, is_neox_style, dtype, short_factor, long_factor,
|
||||
**extra_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
_ROPE_DICT[key] = rotary_emb
|
||||
return rotary_emb
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(rotary_embedding,
|
||||
rotary_embedding.get_rope,
|
||||
vllm__model_executor__layers__rotary_embedding__get_rope)
|
||||
@@ -0,0 +1,468 @@
|
||||
"""Inference-only MOE model."""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_group,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig
|
||||
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig
|
||||
from vllm_mlu._mlu_utils import get_device_major_capability
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
class SparseMoeMlp(nn.Module):
|
||||
"""
|
||||
Tensor Parallel evenly splits each expert's weight and distributes them to different ranks,
|
||||
which means each rank holds partial weight of all experts.
|
||||
While Expert Parallel evenly distributes some of the experts' full weight to different ranks,
|
||||
which means each rank holds part of the experts' full weight.
|
||||
|
||||
As a result, each rank in the Tensor Parallel group receives all tokens' hidden states for all experts,
|
||||
then computes using the partial weights, while for Expert Parallel, each rank only receives
|
||||
part of tokens' hidden states for experts on this rank, then computes using the full weights.
|
||||
|
||||
When both Tensor Parallel and Expert Parallel are enabled, each rank handles
|
||||
a portion of the expert weights matrices (as in EP mode) and these weights are further sliced
|
||||
across ranks (as in TP mode). This hybrid approach aims to balance the workload more evenly across ranks,
|
||||
enhancing efficiency and reducing the likelihood of bottlenecks associated with EP mode alone.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
up_proj_name: str,
|
||||
is_gated: bool,
|
||||
down_proj_name: str,
|
||||
has_bias: bool,
|
||||
skip_bias_add: bool = False,
|
||||
renormalize:bool = False,
|
||||
hidden_act: str = "silu",
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
is_use_fused_moe: bool = False,
|
||||
expert_group: Optional[int] = 1,
|
||||
topk_group: Optional[int] = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_group = get_tensor_model_parallel_group()
|
||||
self.num_total_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.up_proj_name = up_proj_name
|
||||
self.is_gated = is_gated
|
||||
self.down_proj_name = down_proj_name
|
||||
self.has_bias = has_bias
|
||||
self.renormalize = renormalize
|
||||
self.hidden_act = hidden_act
|
||||
self.quant_config = quant_config
|
||||
self.is_use_fused_moe = is_use_fused_moe
|
||||
self.expert_group = expert_group
|
||||
self.topk_group = topk_group
|
||||
if get_device_major_capability() == 3:
|
||||
self.is_use_fused_moe = False
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
# NOTE: The bias for fc2 is only applied on tp_rank 0. If we added it on all nodes the allreduce() would
|
||||
# contain multiple copies of the bias. The bias on other node will be ignored, and may be set to nullptr
|
||||
self.skip_bias_add = True if self.tp_rank > 0 else False
|
||||
|
||||
assert self.intermediate_size % self.tp_size == 0, (
|
||||
f"need intermediate_size:{self.intermediate_size} % tp_size:{self.tp_size} == 0")
|
||||
|
||||
self.num_experts_per_rank = self.num_total_experts
|
||||
|
||||
self.start_expert_id = 0
|
||||
self.end_expert_id = self.start_expert_id + self.num_experts_per_rank
|
||||
|
||||
# Gate always runs at half / full precision for now.
|
||||
self.gate = ReplicatedLinear(self.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
params_dtype=self.params_dtype,
|
||||
quant_config=None)
|
||||
self.experts = nn.ModuleList([
|
||||
FeedForward(hidden_size=self.hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
up_proj_name=self.up_proj_name,
|
||||
is_gated=self.is_gated,
|
||||
down_proj_name=self.down_proj_name,
|
||||
bias=self.has_bias,
|
||||
quant_config=self.quant_config,
|
||||
skip_bias_add=self.skip_bias_add,
|
||||
reduce_results=False) for idx in range(self.num_experts_per_rank)
|
||||
])
|
||||
|
||||
self.init_pack_param()
|
||||
|
||||
|
||||
def init_pack_param(self):
|
||||
self.w13 = None
|
||||
self.w2 = None
|
||||
self.b13 = None
|
||||
self.b2 = None
|
||||
self.w13_scale = None
|
||||
self.w2_scale = None
|
||||
self.a13_scale = None
|
||||
self.a2_scale = None
|
||||
self.pack_params_done = False
|
||||
|
||||
|
||||
def map_param_data(self, param_list, is_use_first_data=False):
|
||||
if len(param_list) == 0:
|
||||
return None
|
||||
|
||||
if is_use_first_data or len(param_list) == 1:
|
||||
first_data = param_list[0].data
|
||||
for param in param_list[1: -1]:
|
||||
param.data = first_data
|
||||
out_param = first_data.view_as(param_list[0])
|
||||
else:
|
||||
packed_param = torch._utils._flatten_dense_tensors(param_list)
|
||||
data_list = torch._utils._unflatten_dense_tensors(packed_param, param_list)
|
||||
for data, param in zip(data_list, param_list):
|
||||
param.data = data
|
||||
out_param = packed_param.view(len(param_list), *data_list[0].shape)
|
||||
|
||||
torch.mlu.empty_cache()
|
||||
|
||||
return out_param
|
||||
|
||||
|
||||
def pack_unquantized_params(self, w13, w2, b13, b2):
|
||||
for expert in self.experts:
|
||||
up_proj = getattr(expert, self.up_proj_name)
|
||||
down_proj = getattr(expert, self.down_proj_name)
|
||||
w13.append(up_proj.weight)
|
||||
w2.append(down_proj.weight)
|
||||
if self.has_bias:
|
||||
b13.append(up_proj.bias)
|
||||
b2.append(down_proj.bias)
|
||||
|
||||
|
||||
def pack_smoothquant_params(self, w13, w2, b13, b2, w13_scale, w2_scale, a13_scale, a2_scale):
|
||||
for expert in self.experts:
|
||||
up_proj = getattr(expert, self.up_proj_name)
|
||||
down_proj = getattr(expert, self.down_proj_name)
|
||||
w13.append(up_proj.qweight)
|
||||
w2.append(down_proj.qweight)
|
||||
if self.has_bias:
|
||||
b13.append(up_proj.bias)
|
||||
b2.append(down_proj.bias)
|
||||
w13_scale.append(up_proj.per_channel_scale)
|
||||
w2_scale.append(down_proj.per_channel_scale)
|
||||
if self.quant_config.input_quant_method == "per_token":
|
||||
a13_scale.append(up_proj.smooth)
|
||||
a2_scale.append(down_proj.smooth)
|
||||
else:
|
||||
a13_scale.append(up_proj.scale_to_int)
|
||||
a2_scale.append(down_proj.scale_to_int)
|
||||
|
||||
|
||||
def pack_weightonly_params(self, w13, w2, b13, b2, w13_scale, w2_scale):
|
||||
for expert in self.experts:
|
||||
up_proj = getattr(expert, self.up_proj_name)
|
||||
down_proj = getattr(expert, self.down_proj_name)
|
||||
w13.append(up_proj.qweight)
|
||||
w2.append(down_proj.qweight)
|
||||
if self.has_bias:
|
||||
b13.append(up_proj.bias)
|
||||
b2.append(down_proj.bias)
|
||||
w13_scale.append(up_proj.per_channel_scale)
|
||||
w2_scale.append(up_proj.per_channel_scale)
|
||||
|
||||
|
||||
def pack_params(self):
|
||||
if self.pack_params_done:
|
||||
return
|
||||
|
||||
w13 = []
|
||||
w2 = []
|
||||
b13 = []
|
||||
b2 = []
|
||||
w13_scale = []
|
||||
w2_scale = []
|
||||
a13_scale = []
|
||||
a2_scale = []
|
||||
|
||||
if self.quant_config is None:
|
||||
self.pack_unquantized_params(w13, w2, b13, b2)
|
||||
elif isinstance(self.quant_config, SmoothQuantConfig):
|
||||
self.pack_smoothquant_params(w13, w2, b13, b2, w13_scale, w2_scale, a13_scale, a2_scale)
|
||||
elif isinstance(self.quant_config, WeightOnlyConfig):
|
||||
self.pack_weightonly_params(w13, w2, b13, b2, w13_scale, w2_scale)
|
||||
else:
|
||||
raise ValueError(f'Unsupported quantization:{self.quant_config}')
|
||||
|
||||
# pack weigth
|
||||
self.w13 = self.map_param_data(w13)
|
||||
self.w2 = self.map_param_data(w2)
|
||||
|
||||
# pack bias
|
||||
if self.has_bias:
|
||||
self.b13 = self.map_param_data(b13)
|
||||
# NOTE: The bias for fc2 is only applied on tp_rank 0. If we added it on all nodes the allreduce() would
|
||||
# contain multiple copies of the bias. The bias on other node will be ignored, and may be set to nullptr
|
||||
if self.skip_bias_add is False:
|
||||
self.b2 = self.map_param_data(b2)
|
||||
|
||||
|
||||
# pack weight scale
|
||||
if len(w13_scale) > 0:
|
||||
self.w13_scale = self.map_param_data(w13_scale)
|
||||
if len(w2_scale) > 0:
|
||||
self.w2_scale = self.map_param_data(w2_scale)
|
||||
|
||||
# pack activate scale
|
||||
if len(a13_scale) > 0:
|
||||
self.a13_scale = self.map_param_data(a13_scale)
|
||||
if len(a2_scale) > 0:
|
||||
self.a2_scale = self.map_param_data(a2_scale)
|
||||
|
||||
self.pack_params_done = True
|
||||
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
orig_hidden_states_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
# expert_logits: [num_tokens, self.num_experts_per_rank]
|
||||
expert_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.forward_experts(hidden_states, expert_logits, residual)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
output = final_hidden_states.view(orig_hidden_states_shape)
|
||||
return output
|
||||
|
||||
|
||||
def forward_experts(self, hidden_states, expert_logits, residual: Optional[torch.Tensor] = None):
|
||||
residual_ = None if self.tp_rank > 0 else residual
|
||||
if self.is_use_fused_moe and self.expert_group != 1:
|
||||
final_hidden_states = self.forward_group_experts(hidden_states, expert_logits, residual_)
|
||||
elif self.is_use_fused_moe:
|
||||
self.pack_params()
|
||||
final_hidden_states = mlu_ops.fused_moe(hidden_states=hidden_states,
|
||||
gating_output=expert_logits,
|
||||
w1=self.w13,
|
||||
w2=self.w2,
|
||||
bias1=self.b13,
|
||||
bias2=self.b2,
|
||||
residual=residual_,
|
||||
input_smooth=self.a13_scale,
|
||||
act_smooth=self.a2_scale,
|
||||
w1_scale=self.w13_scale,
|
||||
w2_scale=self.w2_scale,
|
||||
topk=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
gated=self.is_gated,
|
||||
act_mode=self.hidden_act,
|
||||
start_expert_id=self.start_expert_id)
|
||||
else:
|
||||
final_hidden_states = self.forward_experts_nofused(hidden_states, expert_logits)
|
||||
if residual_ is not None:
|
||||
final_hidden_states = final_hidden_states + residual_
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def forward_experts_nofused(self, hidden_states, expert_logits):
|
||||
hidden_states_shape = hidden_states.shape
|
||||
topk_values, topk_indices = self.topk_softmax(expert_logits)
|
||||
expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = self.generate_gather_idx(
|
||||
topk_indices)
|
||||
# no expert is routed, then expand_gather_idx, expand_scatter_idx has no item,
|
||||
# expand_token_count and expand_cusum_token_count has item but the value is all zero
|
||||
# so this rank should only return final_hidden_states with zero value
|
||||
if expand_gather_idx.numel() == 0:
|
||||
final_hidden_states = torch.zeros_like(hidden_states,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
return final_hidden_states
|
||||
|
||||
expand_hidden_states = self.expand_input(hidden_states, expand_gather_idx)
|
||||
|
||||
expand_output_list = []
|
||||
expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id +
|
||||
1] - cusum_token_count[self.start_expert_id]
|
||||
for expert_idx, num_tokens_per_expert in enumerate(expand_token_count):
|
||||
if num_tokens_per_expert > 0:
|
||||
expert_hidden_states = expand_hidden_states[
|
||||
expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]]
|
||||
expert_output = self.experts[expert_idx](expert_hidden_states)
|
||||
expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output
|
||||
expand_output_list.append(expert_output)
|
||||
expand_output = torch.cat(expand_output_list, dim=0)
|
||||
final_hidden_states = self.combine_moe(expand_output, scatter_idx, cusum_token_count, hidden_states_shape,
|
||||
topk_values)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
|
||||
def forward_group_experts(self, hidden_states, expert_logits, residual_):
|
||||
ori_input_shape = hidden_states.shape
|
||||
dtype = hidden_states.dtype
|
||||
self.pack_params()
|
||||
gating_output=expert_logits.to(torch.float32)
|
||||
w1=self.w13
|
||||
w2=self.w2
|
||||
bias1=self.b13
|
||||
bias2=self.b2
|
||||
input_smooth=self.a13_scale
|
||||
act_smooth=self.a2_scale
|
||||
w1_scale=self.w13_scale
|
||||
w2_scale=self.w2_scale
|
||||
topk=self.top_k
|
||||
renormalized=self.renormalize
|
||||
gated=self.is_gated
|
||||
act_mode=self.hidden_act
|
||||
|
||||
start_expert_id=self.start_expert_id
|
||||
expert_num = gating_output.size(-1)
|
||||
expert_size = w1.size(0)
|
||||
max_m = hidden_states.shape[0]
|
||||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||
gating_output = gating_output.view(-1, gating_output.size(-1))
|
||||
residual_ = residual_.view(-1, residual_.size(-1)) if residual_ is not None else None
|
||||
per_token_sq = False
|
||||
# check quant
|
||||
check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
|
||||
if all(x is not None for x in check_list):
|
||||
per_token_sq = True
|
||||
|
||||
if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
|
||||
raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present "
|
||||
"and absent at the same time.")
|
||||
# softmax_topk
|
||||
reduce_weight, expert_id = mlu_ops.moe_softmax_topk(gating_output,
|
||||
topk, renormalized)
|
||||
# gen_idx
|
||||
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(expert_id, expert_num)
|
||||
# check quant
|
||||
if per_token_sq:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
if major == 3:
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(hidden_states, expand_idx,
|
||||
cusum_token_count, start_expert_id, expert_size)
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(expand_hidden_states,
|
||||
input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size])
|
||||
else:
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(hidden_states,
|
||||
input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx,
|
||||
cusum_token_count[start_expert_id].unsqueeze(0))
|
||||
else:
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(hidden_states, expand_idx,
|
||||
cusum_token_count, start_expert_id, expert_size)
|
||||
|
||||
if per_token_sq:
|
||||
gemm1_out = mlu_ops.smooth_quant_group_gemm(quant_input, w1,
|
||||
token_count[start_expert_id:start_expert_id+expert_size],
|
||||
None, None, None, None,
|
||||
input_scale, w1_scale, dtype, max_m)
|
||||
else:
|
||||
gemm1_out = mlu_ops.group_gemm(expand_hidden_states, w1,
|
||||
token_count[start_expert_id:start_expert_id+expert_size],
|
||||
None, None, None, None, max_m)
|
||||
# add_bias_active
|
||||
act_out = mlu_ops.moe_active(gemm1_out, act_mode, gated, None, bias1, cusum_token_count, start_expert_id, expert_size)
|
||||
if per_token_sq:
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(act_out, act_smooth, None,
|
||||
token_count[start_expert_id:start_expert_id+expert_size])
|
||||
if per_token_sq:
|
||||
gemm2_out = mlu_ops.smooth_quant_group_gemm(quant_input, w2,
|
||||
token_count[start_expert_id:start_expert_id+expert_size],
|
||||
None, None, None, None, input_scale, w2_scale, dtype, max_m)
|
||||
else:
|
||||
gemm2_out = mlu_ops.group_gemm(act_out, w2,
|
||||
token_count[start_expert_id:start_expert_id+expert_size],
|
||||
None, None, None, None, max_m)
|
||||
|
||||
output = mlu_ops.moe_combine_result(gemm2_out, reduce_weight, combine_idx,
|
||||
residual_, cusum_token_count, start_expert_id,
|
||||
expert_size, bias2)
|
||||
return output.view(ori_input_shape)
|
||||
|
||||
|
||||
def topk_softmax(self, expert_logits):
|
||||
# expert_logits: [num_tokens, self.num_experts_per_rank]
|
||||
# topk_values: [num_tokens, self.top_k]
|
||||
# topk_indices: [num_tokens, self.top_k]
|
||||
if self.renormalize:
|
||||
topk_values, topk_indices = torch.topk(expert_logits, self.top_k, dim=-1)
|
||||
topk_values = torch.softmax(topk_values, -1)
|
||||
else:
|
||||
router_probs = torch.softmax(expert_logits, -1)
|
||||
topk_values, topk_indices = torch.topk(router_probs, self.top_k, dim=-1)
|
||||
|
||||
return topk_values, topk_indices
|
||||
|
||||
|
||||
def generate_gather_idx(self, topk_indices):
|
||||
device = topk_indices.device
|
||||
# gather_expand_idx: [num_tokens * self.top_k]
|
||||
sorted_expert_id, indices = topk_indices.flatten().sort()
|
||||
gather_idx = indices // self.top_k
|
||||
|
||||
seqs = torch.arange(indices.numel(), dtype=indices.dtype, device=indices.device)
|
||||
scatter_idx=torch.zeros((indices.numel(),), dtype=seqs.dtype, device=seqs.device).scatter(0, indices, seqs)
|
||||
|
||||
# token_count: [self.num_experts_per_rank]
|
||||
partial_token_index, partial_token_count = sorted_expert_id.unique(sorted=True, return_counts=True)
|
||||
zero_token_count = torch.zeros(self.num_total_experts, dtype=partial_token_count.dtype, device=device)
|
||||
token_count = zero_token_count.scatter(dim=0, index=partial_token_index, src=partial_token_count)
|
||||
# cusum_token_count: [self.num_experts_per_rank + 1]
|
||||
cusum_token_count = torch.cat(
|
||||
[torch.tensor([0], dtype=token_count.dtype, device=device),
|
||||
token_count.cumsum(dim=0)])
|
||||
|
||||
num_tokens_before_expert = cusum_token_count[self.start_expert_id]
|
||||
num_tokens_including_expert = cusum_token_count[self.end_expert_id]
|
||||
|
||||
expand_gather_idx = gather_idx[num_tokens_before_expert:num_tokens_including_expert]
|
||||
expand_token_count = token_count[self.start_expert_id:self.end_expert_id]
|
||||
|
||||
return expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count
|
||||
|
||||
|
||||
def expand_input(self, hidden_states, expand_gather_idx):
|
||||
expand_hidden_states = hidden_states[expand_gather_idx]
|
||||
return expand_hidden_states
|
||||
|
||||
|
||||
def combine_moe(self, expand_output, scatter_idx, cusum_token_count, hidden_states_shape, topk_values):
|
||||
num_tokens, hidden_size = hidden_states_shape
|
||||
num_tokens_before_expert = cusum_token_count[self.start_expert_id]
|
||||
num_tokens_after_expert = cusum_token_count[-1] - cusum_token_count[self.end_expert_id]
|
||||
|
||||
expand_output_before_expert = torch.zeros((num_tokens_before_expert, hidden_size),
|
||||
dtype=expand_output.dtype,
|
||||
device=expand_output.device)
|
||||
expand_output_after_expert = torch.zeros((num_tokens_after_expert, hidden_size),
|
||||
dtype=expand_output.dtype,
|
||||
device=expand_output.device)
|
||||
unscatted_output = torch.cat([expand_output_before_expert, expand_output, expand_output_after_expert], dim=0)
|
||||
scatter_output = unscatted_output[scatter_idx]
|
||||
hidden_states_weight = topk_values.flatten().unsqueeze(-1)
|
||||
weighted_hidden_states = scatter_output * hidden_states_weight
|
||||
unreduced_hidden_states = weighted_hidden_states.view(num_tokens, self.top_k, hidden_size)
|
||||
final_hidden_states = unreduced_hidden_states.sum(dim=1)
|
||||
|
||||
return final_hidden_states
|
||||
@@ -0,0 +1,42 @@
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def vllm__model_executor__layers__spec_decode_base_sampler__SpecDecodeBaseSampler__init_gpu_tensors(
|
||||
self, device: Union[int, str]
|
||||
) -> None:
|
||||
assert self.num_accepted_tokens is None
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add mlu device support.
|
||||
'''
|
||||
if isinstance(device, int) and current_platform.is_mlu():
|
||||
device = f"mlu:{device}"
|
||||
elif isinstance(device, int) and current_platform.is_cuda():
|
||||
device = f"cuda:{device}"
|
||||
elif not isinstance(device, str):
|
||||
raise ValueError(f"Device must be int or str, get {type(device)}")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
self.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
SpecDecodeBaseSampler,
|
||||
SpecDecodeBaseSampler.init_gpu_tensors,
|
||||
vllm__model_executor__layers__spec_decode_base_sampler__SpecDecodeBaseSampler__init_gpu_tensors
|
||||
)
|
||||
@@ -0,0 +1,2 @@
|
||||
import vllm_mlu.model_executor.model_loader.loader
|
||||
import vllm_mlu.model_executor.model_loader.tensorizer
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,138 @@
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
from safetensors.torch import safe_open
|
||||
from typing import List, Tuple, Generator
|
||||
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (np_cache_weights_iterator,
|
||||
_BAR_FORMAT)
|
||||
from vllm.config import LoadFormat
|
||||
from vllm.platforms import current_platform
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu._mlu_utils import get_device_major_capability
|
||||
|
||||
|
||||
CAST_BFLOAT16_TO_FLOAT16_ENABLE = (get_device_major_capability() == 3)
|
||||
|
||||
|
||||
def vllm__model_executor__model_loader__weight_utils__safetensors_weights_iterator(
|
||||
hf_weights_files: List[str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading safetensors checkpoint shards",
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
param = f.get_tensor(name)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: cast bfloat16 to float16 for MLU3xx
|
||||
'''
|
||||
if CAST_BFLOAT16_TO_FLOAT16_ENABLE and param.dtype == torch.bfloat16:
|
||||
param = param.to(torch.float16)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
yield name, param
|
||||
|
||||
|
||||
def vllm__model_executor__model_loader__weight_utils__pt_weights_iterator(
|
||||
hf_weights_files: List[str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model bin/pt files."""
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading pt checkpoint shards",
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: cast bfloat16 to float16 for MLU3xx
|
||||
'''
|
||||
if CAST_BFLOAT16_TO_FLOAT16_ENABLE and param.dtype == torch.bfloat16:
|
||||
param = param.to(torch.float16)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
yield name, param
|
||||
del state
|
||||
torch.mlu.empty_cache()
|
||||
|
||||
|
||||
def vllm__model_executor__model_loader__loader__DefaultModelLoader___get_weights_iterator(
|
||||
self, source: "DefaultModelLoader.Source"
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
source.model_or_path, source.revision, source.fall_back_to_pt)
|
||||
if self.load_config.load_format == LoadFormat.NPCACHE:
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
weights_iterator = np_cache_weights_iterator(
|
||||
source.model_or_path, self.load_config.download_dir, hf_folder,
|
||||
hf_weights_files)
|
||||
elif use_safetensors:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: cast bfloat16 to float16 for MLU3xx
|
||||
'''
|
||||
weights_iterator = vllm__model_executor__model_loader__weight_utils__safetensors_weights_iterator(hf_weights_files)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
else:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: cast bfloat16 to float16 for MLU3xx
|
||||
'''
|
||||
weights_iterator = vllm__model_executor__model_loader__weight_utils__pt_weights_iterator(hf_weights_files)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if current_platform.is_tpu():
|
||||
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
|
||||
# not too many ops are accumulated in the XLA program.
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
def _xla_weights_iterator(iterator: Generator):
|
||||
for weights in iterator:
|
||||
yield weights
|
||||
xm.mark_step()
|
||||
|
||||
weights_iterator = _xla_weights_iterator(weights_iterator)
|
||||
|
||||
# Apply the prefix.
|
||||
return ((source.prefix + name, tensor)
|
||||
for (name, tensor) in weights_iterator)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(DefaultModelLoader,
|
||||
DefaultModelLoader._get_weights_iterator,
|
||||
vllm__model_executor__model_loader__loader__DefaultModelLoader___get_weights_iterator)
|
||||
@@ -0,0 +1,68 @@
|
||||
import time
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.model_loader.tensorizer import (TensorizerAgent,
|
||||
TensorDeserializer,
|
||||
get_mem_usage,
|
||||
_read_stream,
|
||||
convert_bytes)
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__model_executor__model_loader__tensorizer__TensorizerAgent__deserialize(self):
|
||||
"""
|
||||
Deserialize the model using the TensorDeserializer. This method is
|
||||
specifically for vLLM models using tensorizer's plaid_mode.
|
||||
|
||||
The deserializer makes use of tensorizer_args.stream_params
|
||||
to configure the behavior of the stream when loading tensors from a
|
||||
serialized model. The deserializer_params are used to configure the
|
||||
behavior of the TensorDeserializer when loading tensors themselves.
|
||||
Documentation on these params can be found in TensorizerArgs
|
||||
|
||||
Returns:
|
||||
nn.Module: The deserialized model.
|
||||
"""
|
||||
before_mem = get_mem_usage()
|
||||
start = time.perf_counter()
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use mlu device
|
||||
'''
|
||||
with _read_stream(
|
||||
self.tensorizer_config.tensorizer_uri,
|
||||
**self.tensorizer_args.stream_params
|
||||
) as stream, TensorDeserializer(
|
||||
stream,
|
||||
dtype=self.tensorizer_config.dtype,
|
||||
device=f'mlu:{torch.mlu.current_device()}',
|
||||
**self.tensorizer_args.deserializer_params) as deserializer:
|
||||
deserializer.load_into_module(self.model)
|
||||
end = time.perf_counter()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
|
||||
duration = end - start
|
||||
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
|
||||
after_mem = get_mem_usage()
|
||||
deserializer.close()
|
||||
logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
|
||||
end - start, per_second)
|
||||
logger.info("Memory usage before: %s", before_mem)
|
||||
logger.info("Memory usage after: %s", after_mem)
|
||||
|
||||
self._check_tensors_on_meta_device()
|
||||
self._resize_lora_embeddings()
|
||||
del self.model.vllm_tensorized_marker
|
||||
return self.model.eval()
|
||||
|
||||
MluHijackObject.apply_hijack(TensorizerAgent,
|
||||
TensorizerAgent.deserialize,
|
||||
vllm__model_executor__model_loader__tensorizer__TensorizerAgent__deserialize)
|
||||
41
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/__init__.py
Executable file
41
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/__init__.py
Executable file
@@ -0,0 +1,41 @@
|
||||
import vllm_mlu.model_executor.models.deepseek_v2
|
||||
import vllm_mlu.model_executor.models.baichuan
|
||||
import vllm_mlu.model_executor.models.bloom
|
||||
import vllm_mlu.model_executor.models.chatglm
|
||||
|
||||
# Multimodal models - may fail with older transformers versions
|
||||
try:
|
||||
import vllm_mlu.model_executor.models.clip
|
||||
except ImportError as e:
|
||||
import logging
|
||||
logging.warning(f"Failed to import clip hijack: {e}")
|
||||
|
||||
import vllm_mlu.model_executor.models.gpt_neox
|
||||
import vllm_mlu.model_executor.models.llama
|
||||
import vllm_mlu.model_executor.models.mixtral
|
||||
import vllm_mlu.model_executor.models.qwen
|
||||
import vllm_mlu.model_executor.models.qwen2
|
||||
import vllm_mlu.model_executor.models.qwen2_moe
|
||||
|
||||
try:
|
||||
import vllm_mlu.model_executor.models.qwen2_vl
|
||||
except ImportError as e:
|
||||
import logging
|
||||
logging.warning(f"Failed to import qwen2_vl hijack: {e}")
|
||||
|
||||
try:
|
||||
import vllm_mlu.model_executor.models.qwen3
|
||||
except ImportError as e:
|
||||
import logging
|
||||
logging.warning(f"Failed to import qwen3 hijack: {e}")
|
||||
|
||||
import vllm_mlu.model_executor.models.falcon
|
||||
import vllm_mlu.model_executor.models.internlm2
|
||||
import vllm_mlu.model_executor.models.hunyuan
|
||||
import vllm_mlu.model_executor.models.layer_utils
|
||||
|
||||
try:
|
||||
import vllm_mlu.model_executor.models.mllama
|
||||
except ImportError as e:
|
||||
import logging
|
||||
logging.warning(f"Failed to import mllama hijack: {e}")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
309
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/baichuan.py
Normal file
309
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/baichuan.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import torch
|
||||
from typing import List, Optional, Union
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.models.baichuan import (
|
||||
_get_alibi_slopes, BaiChuanAttention,
|
||||
BaiChuanDecoderLayer, BaiChuanModel)
|
||||
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
decoder_layer_forward_base, decoder_model_forward_base_pp,
|
||||
is_per_tensor_smoothquant, is_per_token_smoothquant,
|
||||
quant_fusion_with_rmsnorm)
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__module_executor__models__baichuan__BaiChuanAttention__init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
position_embedding: str,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super(BaiChuanAttention, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = (self.total_num_heads //
|
||||
tensor_model_parallel_world_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.postion_embedding = position_embedding
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.W_pack = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
# Create the alibi slopes and slice them.
|
||||
if self.postion_embedding == "ALIBI":
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
head_end = (tp_rank + 1) * self.num_heads
|
||||
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
|
||||
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||
|
||||
scaling = self.head_dim**-0.5
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add cache_config to support kv8
|
||||
'''
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
scaling,
|
||||
alibi_slopes=alibi_slopes,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
else:
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
|
||||
def vllm__module_executor__models__baichuan__BaiChuanAttention__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.W_pack(hidden_states, smooth_quant_scale)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
qk, _ = qkv.split([self.num_heads * self.head_dim * 2, self.num_heads * self.head_dim], dim=-1)
|
||||
if self.postion_embedding != "ALIBI":
|
||||
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_heads, self.head_dim))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual
|
||||
'''
|
||||
output, _ = self.o_proj(attn_output, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output
|
||||
|
||||
|
||||
def vllm__module_executor__models__baichuan__BaiChuanDecoderLayer__init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None
|
||||
):
|
||||
super(BaiChuanDecoderLayer, self).__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = BaiChuanAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
position_embedding=position_embedding,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use FeedForward instead of MLP
|
||||
'''
|
||||
self.mlp = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act='silu',
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=True,
|
||||
down_proj_name='down_proj',
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: prepare to perf per-tensor sq cases if suitable
|
||||
'''
|
||||
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
|
||||
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
self.self_attn.W_pack.quant_method.skip_quant_input = True
|
||||
self.mlp.gate_up_proj.quant_method.skip_quant_input = True
|
||||
self.quant_fusion_attn_layernorm = None
|
||||
self.quant_fusion_mlp_layernorm = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__baichuan__BaiChuanDecoderLayer__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf model by:
|
||||
1) add residual in matmul;
|
||||
2) fuse quantization in layernorm in per-tensor sq case;
|
||||
'''
|
||||
attn_layernorm = self.input_layernorm
|
||||
mlp_layernorm = self.post_attention_layernorm
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
if self.is_per_token_sq_perf_cases:
|
||||
attn_quant_scale = self.self_attn.W_pack.smooth
|
||||
mlp_quant_scale = self.mlp.gate_up_proj.smooth
|
||||
else:
|
||||
attn_quant_scale = self.self_attn.W_pack.scale_to_int
|
||||
mlp_quant_scale = self.mlp.gate_up_proj.scale_to_int
|
||||
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.input_layernorm, attn_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.post_attention_layernorm, mlp_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
mlp_layernorm = self.quant_fusion_mlp_layernorm
|
||||
|
||||
return decoder_layer_forward_base(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
input_layernorm=attn_layernorm,
|
||||
self_attn=self.self_attn,
|
||||
post_layernorm=mlp_layernorm,
|
||||
mlp=self.mlp,
|
||||
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
|
||||
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__baichuan__BaiChuanModel__forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
return decoder_model_forward_base_pp(input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
layers=self.layers,
|
||||
start_layer=self.start_layer,
|
||||
end_layer=self.end_layer,
|
||||
get_input_embeddings=self.embed_tokens,
|
||||
norm=self.norm)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(BaiChuanAttention,
|
||||
BaiChuanAttention.__init__,
|
||||
vllm__module_executor__models__baichuan__BaiChuanAttention__init__)
|
||||
MluHijackObject.apply_hijack(BaiChuanAttention,
|
||||
BaiChuanAttention.forward,
|
||||
vllm__module_executor__models__baichuan__BaiChuanAttention__forward)
|
||||
MluHijackObject.apply_hijack(BaiChuanDecoderLayer,
|
||||
BaiChuanDecoderLayer.__init__,
|
||||
vllm__module_executor__models__baichuan__BaiChuanDecoderLayer__init__)
|
||||
MluHijackObject.apply_hijack(BaiChuanDecoderLayer,
|
||||
BaiChuanDecoderLayer.forward,
|
||||
vllm__module_executor__models__baichuan__BaiChuanDecoderLayer__forward)
|
||||
MluHijackObject.apply_hijack(BaiChuanModel,
|
||||
BaiChuanModel.forward,
|
||||
vllm__module_executor__models__baichuan__BaiChuanModel__forward)
|
||||
170
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/bloom.py
Normal file
170
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/bloom.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BloomConfig
|
||||
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
from vllm.model_executor.models.bloom import BloomAttention, BloomBlock
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
decoder_layer_forward_base,
|
||||
is_per_tensor_smoothquant,
|
||||
is_per_token_smoothquant,
|
||||
quant_fusion_with_layernorm
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
def vllm__module_executor__models__bloom__BloomAttention__forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
del position_ids # Unused.
|
||||
qkv, _ = self.query_key_value(hidden_states, smooth_quant_scale)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
output, _ = self.dense(attn_output, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output
|
||||
|
||||
def vllm__module_executor__models__bloom__BloomBlock__init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super(BloomBlock, self).__init__()
|
||||
hidden_size = config.hidden_size
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
self.self_attention = BloomAttention(config, cache_config,
|
||||
quant_config)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
hidden_size, eps=config.layer_norm_epsilon)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use FeedForward instead of MLP
|
||||
'''
|
||||
self.mlp = FeedForward(hidden_size=hidden_size,
|
||||
intermediate_size=hidden_size * 4,
|
||||
hidden_act='gelu',
|
||||
up_proj_name="dense_h_to_4h",
|
||||
is_gated=False,
|
||||
down_proj_name="dense_4h_to_h",
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
config.apply_residual_connection_post_layernorm)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: prepare to perf sq cases if suitable
|
||||
'''
|
||||
self.is_per_tesnor_sq_perf_cases = (is_per_tensor_smoothquant(quant_config) and
|
||||
not self.apply_residual_connection_post_layernorm)
|
||||
self.is_per_token_sq_perf_cases = (is_per_token_smoothquant(quant_config) and
|
||||
not self.apply_residual_connection_post_layernorm)
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
self.self_attention.query_key_value.quant_method.skip_quant_input = True
|
||||
self.mlp.dense_h_to_4h.quant_method.skip_quant_input = True
|
||||
self.quant_fusion_attn_layernorm = None
|
||||
self.quant_fusion_mlp_layernorm = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__bloom__BloomBlock__forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf model by:
|
||||
1) add residual in matmul;
|
||||
2) fuse quantization in layernorm in per-tensor sq case;
|
||||
'''
|
||||
attn_layernorm = self.input_layernorm
|
||||
mlp_layernorm = self.post_attention_layernorm
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
if self.is_per_token_sq_perf_cases:
|
||||
attn_quant_scale = self.self_attention.query_key_value.smooth
|
||||
mlp_quant_scale = self.mlp.dense_h_to_4h.smooth
|
||||
else:
|
||||
attn_quant_scale = self.self_attention.query_key_value.scale_to_int
|
||||
mlp_quant_scale = self.mlp.dense_h_to_4h.scale_to_int
|
||||
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_layernorm(
|
||||
self.input_layernorm, attn_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
self.quant_fusion_mlp_layernorm = quant_fusion_with_layernorm(
|
||||
self.post_attention_layernorm, mlp_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
mlp_layernorm = self.quant_fusion_mlp_layernorm
|
||||
|
||||
return decoder_layer_forward_base(positions=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
input_layernorm=attn_layernorm,
|
||||
self_attn=self.self_attention,
|
||||
post_layernorm=mlp_layernorm,
|
||||
mlp=self.mlp,
|
||||
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
|
||||
position_name='position_ids',
|
||||
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
|
||||
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
MluHijackObject.apply_hijack(BloomAttention,
|
||||
BloomAttention.forward,
|
||||
vllm__module_executor__models__bloom__BloomAttention__forward)
|
||||
MluHijackObject.apply_hijack(BloomBlock,
|
||||
BloomBlock.__init__,
|
||||
vllm__module_executor__models__bloom__BloomBlock__init__)
|
||||
MluHijackObject.apply_hijack(BloomBlock,
|
||||
BloomBlock.forward,
|
||||
vllm__module_executor__models__bloom__BloomBlock__forward)
|
||||
195
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/chatglm.py
Normal file
195
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/chatglm.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import torch
|
||||
|
||||
from torch.nn import LayerNorm
|
||||
from typing import Optional
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.model_executor.models.chatglm import GLMAttention, GLMBlock
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
decoder_layer_forward_base,
|
||||
is_per_tensor_smoothquant,
|
||||
is_per_token_smoothquant,
|
||||
quant_fusion_with_layernorm,
|
||||
quant_fusion_with_rmsnorm
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__module_executor__models__chatglm__GLMAttention__forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.query_key_value(hidden_states, smooth_quant_scale)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
|
||||
self.rotary_emb(position_ids, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
context_layer = self.attn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
attn_output, _ = self.dense(context_layer, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return attn_output
|
||||
|
||||
|
||||
def vllm__module_executor__models__chatglm__GLMBlock__init__(
|
||||
self,
|
||||
config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super(GLMBlock, self).__init__()
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
config.apply_residual_connection_post_layernorm)
|
||||
|
||||
self.fp32_residual_connection = config.fp32_residual_connection
|
||||
|
||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||
# Layernorm on the input data.
|
||||
self.input_layernorm = layer_norm_func(config.hidden_size,
|
||||
eps=config.layernorm_epsilon)
|
||||
|
||||
# Self attention.
|
||||
self.self_attention = GLMAttention(config, cache_config, quant_config)
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
# Layernorm on the attention output
|
||||
self.post_attention_layernorm = layer_norm_func(
|
||||
config.hidden_size, eps=config.layernorm_epsilon)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: 1) use FeedForward instead of MLP
|
||||
2) prepare to perf per-tensor sq cases if suitable
|
||||
'''
|
||||
# MLP
|
||||
self.mlp = FeedForward(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.ffn_hidden_size,
|
||||
hidden_act='silu',
|
||||
up_proj_name='dense_h_to_4h',
|
||||
is_gated=True,
|
||||
down_proj_name='dense_4h_to_h',
|
||||
bias=config.add_bias_linear,
|
||||
quant_config=quant_config
|
||||
)
|
||||
|
||||
self.is_per_tesnor_sq_perf_cases = (is_per_tensor_smoothquant(quant_config) and
|
||||
not self.apply_residual_connection_post_layernorm)
|
||||
self.is_per_token_sq_perf_cases = (is_per_token_smoothquant(quant_config) and
|
||||
not self.apply_residual_connection_post_layernorm)
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
self.self_attention.query_key_value.quant_method.skip_quant_input = True
|
||||
self.mlp.dense_h_to_4h.quant_method.skip_quant_input = True
|
||||
self.use_rmsnorm = config.rmsnorm
|
||||
self.quant_fusion_attn_layernorm = None
|
||||
self.quant_fusion_mlp_layernorm = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__chatglm__GLMBlock__forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf model by:
|
||||
1) add residual in matmul;
|
||||
2) fuse quantization in layernorm in per-tensor sq case;
|
||||
'''
|
||||
attn_layernorm = self.input_layernorm
|
||||
mlp_layernorm = self.post_attention_layernorm
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
quant_fusion_func = (quant_fusion_with_rmsnorm if
|
||||
self.use_rmsnorm else quant_fusion_with_layernorm)
|
||||
if self.is_per_token_sq_perf_cases:
|
||||
attn_quant_scale = self.self_attention.query_key_value.smooth
|
||||
mlp_quant_scale = self.mlp.dense_h_to_4h.smooth
|
||||
else:
|
||||
attn_quant_scale = self.self_attention.query_key_value.scale_to_int
|
||||
mlp_quant_scale = self.mlp.dense_h_to_4h.scale_to_int
|
||||
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_func(
|
||||
self.input_layernorm, attn_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
self.quant_fusion_mlp_layernorm = quant_fusion_func(
|
||||
self.post_attention_layernorm, mlp_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
mlp_layernorm = self.quant_fusion_mlp_layernorm
|
||||
|
||||
return decoder_layer_forward_base(positions=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
input_layernorm=attn_layernorm,
|
||||
self_attn=self.self_attention,
|
||||
post_layernorm=mlp_layernorm,
|
||||
mlp=self.mlp,
|
||||
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
|
||||
position_name='position_ids',
|
||||
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
|
||||
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(GLMAttention,
|
||||
GLMAttention.forward,
|
||||
vllm__module_executor__models__chatglm__GLMAttention__forward)
|
||||
MluHijackObject.apply_hijack(GLMBlock,
|
||||
GLMBlock.__init__,
|
||||
vllm__module_executor__models__chatglm__GLMBlock__init__)
|
||||
MluHijackObject.apply_hijack(GLMBlock,
|
||||
GLMBlock.forward,
|
||||
vllm__module_executor__models__chatglm__GLMBlock__forward)
|
||||
370
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/clip.py
Normal file
370
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/clip.py
Normal file
@@ -0,0 +1,370 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import CLIPVisionConfig
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.clip import (CLIPVisionModel,
|
||||
CLIPVisionTransformer,
|
||||
CLIPEncoderLayer,
|
||||
CLIPParallelAttention)
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MLUCLIPAttention(nn.Module):
|
||||
"""
|
||||
MLU-compatible CLIP attention implementation.
|
||||
Used as fallback when num_heads % tp_size != 0.
|
||||
This implementation does not use tensor parallelism.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads "
|
||||
f"(got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.dropout = getattr(config, 'attention_dropout', 0.0)
|
||||
|
||||
# Use non-parallel linear layers since this is fallback for non-divisible cases
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=self.embed_dim,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
input_size=self.embed_dim,
|
||||
output_size=self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Input shape: Batch x Time x Channel
|
||||
Compatible with CLIPSdpaAttention interface
|
||||
"""
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# Project to Q, K, V
|
||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||
|
||||
# Reshape for attention computation
|
||||
query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
key_states = key_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
value_states = value_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
|
||||
# Use MLU flash attention for inference
|
||||
if self.dropout == 0.0:
|
||||
try:
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
out = mlu_ops.flash_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
out=None,
|
||||
cu_seq_lens_q=None,
|
||||
cu_seq_lens_kv=None,
|
||||
alibi_slope=None,
|
||||
attn_bias=None,
|
||||
max_seq_len_q=tgt_len,
|
||||
max_seq_len_kv=tgt_len,
|
||||
softmax_scale=self.scale,
|
||||
is_causal=False
|
||||
)
|
||||
except (ImportError, AttributeError):
|
||||
# Fallback to standard PyTorch attention if MLU ops not available
|
||||
logger.warning("MLU ops not available, using standard PyTorch attention")
|
||||
out = self._pytorch_attention(query_states, key_states, value_states)
|
||||
else:
|
||||
# Use xformers if dropout is needed (training mode)
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("xformers not available, using standard PyTorch attention")
|
||||
out = self._pytorch_attention(query_states, key_states, value_states)
|
||||
|
||||
# Reshape output
|
||||
out = out.view(bsz, tgt_len, -1)
|
||||
|
||||
# Output projection
|
||||
attn_output, _ = self.out_proj(out)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
def _pytorch_attention(
|
||||
self,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Standard PyTorch scaled dot-product attention as fallback.
|
||||
Input shape: [batch, seq_len, num_heads, head_dim]
|
||||
"""
|
||||
# Transpose to [batch, num_heads, seq_len, head_dim]
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
# Compute attention scores
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scale
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
|
||||
# Apply attention to values
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
# Transpose back to [batch, seq_len, num_heads, head_dim]
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
def vllm__module_executor__models__clip__CLIPParallelAttention__forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor
|
||||
):
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||
|
||||
query_states = query_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
key_states = key_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
value_states = value_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf attn using tmo flash attn
|
||||
'''
|
||||
if self.dropout is None or self.dropout == 0.0:
|
||||
# Always true for inference
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
out = mlu_ops.flash_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
out=None,
|
||||
cu_seq_lens_q=None,
|
||||
cu_seq_lens_kv=None,
|
||||
alibi_slope=None,
|
||||
attn_bias=None,
|
||||
max_seq_len_q=tgt_len,
|
||||
max_seq_len_kv=tgt_len,
|
||||
softmax_scale=self.scale,
|
||||
is_causal=False)
|
||||
else:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
out = out.view(bsz, tgt_len, -1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual
|
||||
'''
|
||||
attn_output, _ = self.out_proj(out, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return attn_output, None
|
||||
|
||||
|
||||
def vllm__module_executor__models__clip__CLIPEncoderLayer____init__(
|
||||
self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super(CLIPEncoderLayer, self).__init__()
|
||||
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf attn using tmo flash attn, do not check xformers
|
||||
'''
|
||||
if num_heads % tp_size == 0:
|
||||
self.self_attn = CLIPParallelAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.use_parallel_attn = True
|
||||
else:
|
||||
logger.warning("Use MLUCLIPAttention for clip model (fallback for non-divisible heads).")
|
||||
self.self_attn = MLUCLIPAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.use_parallel_attn = False
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use FeedForward instead of MLP
|
||||
'''
|
||||
self.mlp = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
is_gated=False,
|
||||
bias=True,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
up_proj_name='fc1',
|
||||
down_proj_name='fc2',
|
||||
prefix=f"{prefix}.mlp")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
|
||||
def vllm__module_executor__models__clip__CLIPEncoderLayer__forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: apply residual fusion
|
||||
'''
|
||||
residual = hidden_states
|
||||
if self.use_parallel_attn:
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, _ = self.self_attn(hidden_states, residual)
|
||||
else:
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, _ = self.self_attn(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
hidden_states = self.mlp(hidden_states, residual)
|
||||
hidden_states = hidden_states.view(bsz, tgt_len, -1)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return hidden_states
|
||||
|
||||
|
||||
def vllm__module_executor__models__clip__CLIPVisionModel____init__(
|
||||
self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
require_post_norm: Optional[bool] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super(CLIPVisionModel, self).__init__()
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
num_heads = config.num_attention_heads
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf attn using tmo flash attn, do not check xformers
|
||||
'''
|
||||
self.shard_weight = num_heads % tp_size == 0
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.vision_model = CLIPVisionTransformer(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
require_post_norm=require_post_norm,
|
||||
prefix=f"{prefix}.vision_model",
|
||||
)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(CLIPParallelAttention,
|
||||
CLIPParallelAttention.forward,
|
||||
vllm__module_executor__models__clip__CLIPParallelAttention__forward)
|
||||
MluHijackObject.apply_hijack(CLIPEncoderLayer,
|
||||
CLIPEncoderLayer.__init__,
|
||||
vllm__module_executor__models__clip__CLIPEncoderLayer____init__)
|
||||
MluHijackObject.apply_hijack(CLIPEncoderLayer,
|
||||
CLIPEncoderLayer.forward,
|
||||
vllm__module_executor__models__clip__CLIPEncoderLayer__forward)
|
||||
MluHijackObject.apply_hijack(CLIPVisionModel,
|
||||
CLIPVisionModel.__init__,
|
||||
vllm__module_executor__models__clip__CLIPVisionModel____init__)
|
||||
@@ -0,0 +1,625 @@
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
from vllm.utils import print_warning_once
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
from vllm_mlu.model_executor.models.layer_utils import quant_fusion_with_rmsnorm
|
||||
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, yarn_get_mscale
|
||||
|
||||
class DeepseekV2MoE(SparseMoeMlp):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__(num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
up_proj_name="gate_up_proj",
|
||||
is_gated=True,
|
||||
down_proj_name="down_proj",
|
||||
has_bias=False,
|
||||
skip_bias_add=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
hidden_act=config.hidden_act,
|
||||
params_dtype=None,
|
||||
quant_config=quant_config,
|
||||
is_use_fused_moe=True,
|
||||
expert_group=config.n_group,
|
||||
topk_group=config.topk_group)
|
||||
self.config = config
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_shared_experts = config.n_shared_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
if self.tp_size > config.n_routed_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.n_routed_experts}.")
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate")
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
config.n_shared_experts)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace MLP with FeedForward.
|
||||
'''
|
||||
self.shared_experts = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=True,
|
||||
down_proj_name='down_proj',
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace experts() with forward_experts, which defined by SparseMoeMlp.
|
||||
'''
|
||||
final_hidden_states = self.forward_experts(
|
||||
hidden_states, router_logits) * self.routed_scaling_factor
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
def forward_prefill(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
q_scale = None
|
||||
if hasattr(self.q_b_proj.quant_method, "quant_config"):
|
||||
self.q_b_proj.quant_method.skip_quant_input = True
|
||||
quant_scale = self.q_b_proj.smooth
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.q_a_layernorm, quant_scale, dynamic_quant=True)
|
||||
q, q_scale = self.quant_fusion_attn_layernorm(q)
|
||||
q = self.q_b_proj(q, q_scale)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_a_layernorm(q)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
|
||||
self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
|
||||
self.qk_head_dim)
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
kv_a, _ = latent_cache.split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
latent_cache = latent_cache.unsqueeze(1)
|
||||
kv_a = self.kv_a_layernorm(kv_a.contiguous())
|
||||
kv = self.kv_b_proj(kv_a)[0]
|
||||
kv = kv.view(-1, self.num_local_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = latent_cache[:, :, self.kv_lora_rank:]
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: MLA save cache before flashattn
|
||||
'''
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
if len(kv_cache) != 0 and kv_cache[0].numel() > 0:
|
||||
key_cache = kv_cache[0][0]
|
||||
key_value = torch.concat((kv_a.unsqueeze(1), k_pe), dim=-1)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
if self.attn.kv_cache_dtype == 'int8':
|
||||
key_cache_scale = kv_cache[1][0]
|
||||
mlu_ops.quant_to_paged_cache(key_value,
|
||||
None,
|
||||
key_cache,
|
||||
None,
|
||||
key_cache_scale,
|
||||
None,
|
||||
updated_slot_mapping.flatten())
|
||||
else:
|
||||
mlu_ops.reshape_paged_cache(key_value,
|
||||
None,
|
||||
key_cache,
|
||||
None,
|
||||
updated_slot_mapping.flatten())
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
q[..., self.qk_nope_head_dim:] = q_pe
|
||||
k = torch.empty_like(q)
|
||||
k[..., :self.qk_nope_head_dim] = k_nope
|
||||
k[..., self.qk_nope_head_dim:] = k_pe
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: mlu attention not pad but qk_head_dim 192 v_head_dim 128.
|
||||
'''
|
||||
q = q.reshape(-1, self.num_local_heads * self.qk_head_dim)
|
||||
k = k.reshape(-1, self.num_local_heads * self.qk_head_dim)
|
||||
v = v.contiguous().reshape(-1, self.num_local_heads * self.v_head_dim)
|
||||
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def forward_decoder(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
q_len = hidden_states.shape[0]
|
||||
q_input = hidden_states.new_empty(
|
||||
q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
|
||||
)
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
q_scale = None
|
||||
if hasattr(self.q_b_proj.quant_method, "quant_config"):
|
||||
self.q_b_proj.quant_method.skip_quant_input = True
|
||||
quant_scale = self.q_b_proj.smooth
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.q_a_layernorm, quant_scale, dynamic_quant=True)
|
||||
q, q_scale = self.quant_fusion_attn_layernorm(q)
|
||||
q = self.q_b_proj(q, q_scale)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_a_layernorm(q)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)[0].view(
|
||||
-1, self.num_local_heads, self.qk_head_dim
|
||||
)
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
||||
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
|
||||
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
v_input = latent_cache[..., : self.kv_lora_rank]
|
||||
v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
|
||||
k_input = latent_cache.unsqueeze(1)
|
||||
k_input[..., : self.kv_lora_rank] = v_input
|
||||
k_pe = k_input[..., self.kv_lora_rank :]
|
||||
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
q_input[..., self.kv_lora_rank :] = q_pe
|
||||
k_input[..., self.kv_lora_rank :] = k_pe
|
||||
v_input = torch.nn.functional.pad(v_input, [0, self.qk_rope_head_dim, 0, 0, 0, 0],
|
||||
value=0).view(-1, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
q_input = q_input.reshape(q_input.shape[0], -1)
|
||||
k_input = k_input.reshape(k_input.shape[0], -1)
|
||||
v_input = v_input.reshape(v_input.shape[0], -1)
|
||||
attn_output = self.attn_decoder(q_input, k_input, v_input, kv_cache, attn_metadata)
|
||||
attn_output = attn_output.reshape(-1, self.num_local_heads,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
attn_output = attn_output[:, :, :self.kv_lora_rank]
|
||||
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc.transpose(1, 2).contiguous())
|
||||
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def vllm__module_executor__models__deepseek_v2__DeepseekV2Attention__init(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: int,
|
||||
kv_lora_rank: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super(DeepseekV2Attention, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.num_heads = num_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert num_heads % tp_size == 0
|
||||
self.num_local_heads = num_heads // tp_size
|
||||
self.scaling = self.qk_head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
# only RowParallelLinear/ColumnParallelLinear will be quantize
|
||||
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.q_a_proj")
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
|
||||
self.num_heads *
|
||||
self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_b_proj")
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(self.hidden_size,
|
||||
self.num_heads *
|
||||
self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj")
|
||||
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.kv_b_proj")
|
||||
kv_b_proj_weight = self.kv_b_proj.weight
|
||||
w_kc, w_vc = kv_b_proj_weight.unflatten(
|
||||
0, (-1, self.qk_nope_head_dim + self.v_head_dim)
|
||||
).split([self.qk_nope_head_dim, self.v_head_dim], dim=1)
|
||||
self.w_kc = w_kc
|
||||
self.w_vc = w_vc
|
||||
# O projection.
|
||||
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
rope_scaling['rope_type'] = 'deepseek_yarn'
|
||||
self.rotary_emb = get_rope(qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=False)
|
||||
|
||||
if rope_scaling:
|
||||
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
# self.attn = Attention(self.num_heads,
|
||||
# self.qk_head_dim,
|
||||
# self.scaling,
|
||||
# num_kv_heads=self.num_heads)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: mlu attention support head_size 192.
|
||||
'''
|
||||
self.attn = Attention(self.num_local_heads,
|
||||
self.qk_head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_local_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
use_mla=True)
|
||||
self.attn_decoder = Attention(self.num_local_heads,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
use_mla=True)
|
||||
import types
|
||||
self.forward_prefill = types.MethodType(forward_prefill, self)
|
||||
self.forward_decoder = types.MethodType(forward_decoder, self)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def vllm__module_executor__models__deepseek_v2__DeepseekV2Attention__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Use normal computation for prefill and use weight absorption for extend/decode
|
||||
if attn_metadata.prefill_metadata:
|
||||
return self.forward_prefill(positions, hidden_states, kv_cache,
|
||||
attn_metadata)
|
||||
else:
|
||||
return self.forward_decoder(positions, hidden_states, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
|
||||
def vllm__module_executor__models__deepseek_v2__DeepseekV2DecoderLayer__init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super(DeepseekV2DecoderLayer, self).__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
# DecoderLayers are created with `make_layers` which passes the prefix
|
||||
# with the layer's index.
|
||||
layer_idx = int(prefix.split(sep='.')[-1])
|
||||
self.self_attn = DeepseekV2Attention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||
v_head_dim=config.v_head_dim,
|
||||
q_lora_rank=config.q_lora_rank
|
||||
if hasattr(config, "q_lora_rank") else None,
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
if (config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0):
|
||||
self.mlp = DeepseekV2MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
else:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace MLP with FeedForward.
|
||||
'''
|
||||
self.mlp = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=True,
|
||||
down_proj_name='down_proj',
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack params and cal start expert id
|
||||
'''
|
||||
for name, m in self.model.named_modules():
|
||||
if isinstance(m, SparseMoeMlp):
|
||||
m.pack_params()
|
||||
|
||||
start_expert_id = 0
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: delete expert_params_mapping for no useless
|
||||
'''
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace expert_id in weight to named_expert_id in params_dict
|
||||
'''
|
||||
if start_expert_id > 0 and "mlp.experts." in name:
|
||||
expert_str = re.search(r'experts\.\d+', name).group(0)
|
||||
expert_id=int(expert_str.split(".")[1])
|
||||
named_expert_id = expert_id - start_expert_id
|
||||
old_expert_name = f"experts.{expert_id}"
|
||||
new_expert_name = f"experts.{named_expert_id}"
|
||||
name = name.replace(old_expert_name, new_expert_name)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add expert skiped condition and delete useless if name not in params_dict: continue condition
|
||||
'''
|
||||
name = name.replace(weight_name, param_name)
|
||||
if (("mlp.experts." in name or "mlp.shared_experts." in name or "mlp.shared_expert_gate." in name)
|
||||
and name not in params_dict):
|
||||
continue
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add expert skiped condition
|
||||
'''
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if (("mlp.experts." in name or "mlp.shared_experts." in name or "mlp.shared_expert_gate." in name)
|
||||
and name not in params_dict):
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
MluHijackObject.apply_hijack(DeepseekV2Attention,
|
||||
DeepseekV2Attention.forward,
|
||||
vllm__module_executor__models__deepseek_v2__DeepseekV2Attention__forward)
|
||||
MluHijackObject.apply_hijack(DeepseekV2Attention,
|
||||
DeepseekV2Attention.__init__,
|
||||
vllm__module_executor__models__deepseek_v2__DeepseekV2Attention__init)
|
||||
MluHijackObject.apply_hijack(DeepseekV2DecoderLayer,
|
||||
DeepseekV2DecoderLayer.__init__,
|
||||
vllm__module_executor__models__deepseek_v2__DeepseekV2DecoderLayer__init__)
|
||||
MluHijackObject.apply_hijack(DeepseekV2ForCausalLM,
|
||||
DeepseekV2ForCausalLM.load_weights,
|
||||
vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weights)
|
||||
242
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/falcon.py
Executable file
242
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/falcon.py
Executable file
@@ -0,0 +1,242 @@
|
||||
import math
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from typing import List, Optional, Union
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from torch.nn import LayerNorm
|
||||
from transformers import FalconConfig as HF_FalconConfig
|
||||
from vllm.transformers_utils.configs import RWConfig
|
||||
FalconConfig = Union[HF_FalconConfig, RWConfig]
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
|
||||
from vllm.model_executor.models.falcon import (FalconAttention,
|
||||
FalconDecoderLayer,
|
||||
_get_alibi_slopes)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__module_executor__models__falcon__FalconAttention____init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super(FalconAttention, self).__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.head_dim = self.hidden_size // self.total_num_heads
|
||||
assert self.head_dim * self.total_num_heads == self.hidden_size
|
||||
|
||||
self.new_decoder_architecture = config.new_decoder_architecture
|
||||
self.multi_query = config.multi_query
|
||||
|
||||
if self.new_decoder_architecture:
|
||||
self.total_num_kv_heads = config.num_kv_heads
|
||||
elif self.multi_query:
|
||||
self.total_num_kv_heads = 1
|
||||
else:
|
||||
self.total_num_kv_heads = self.total_num_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=config.bias,
|
||||
skip_bias_add=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
|
||||
# Layer-wise attention scaling
|
||||
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||
or config.parallel_attn)
|
||||
self.dense = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=config.bias,
|
||||
skip_bias_add=True,
|
||||
quant_config=quant_config,
|
||||
reduce_results=self.reduce_row_parallel_results)
|
||||
|
||||
self.use_rotary = config.rotary
|
||||
self.use_alibi = config.alibi
|
||||
assert not (self.use_rotary and self.use_alibi), (
|
||||
"Rotary and alibi are mutually exclusive.")
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: set cache_config for rotary & alibi
|
||||
'''
|
||||
if self.use_rotary:
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config,
|
||||
"max_position_embeddings", 8192)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.inv_norm_factor,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
elif self.use_alibi:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
head_end = (tp_rank + 1) * self.num_heads
|
||||
alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
|
||||
self.inv_norm_factor)
|
||||
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.inv_norm_factor,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
alibi_slopes=alibi_slopes,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
scale=self.inv_norm_factor,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__falcon__FalconAttention__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, bias = self.query_key_value(hidden_states)
|
||||
if bias is not None:
|
||||
qkv += bias
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
if self.use_rotary:
|
||||
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
|
||||
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output, bias = self.dense(attn_output)
|
||||
return attn_output, bias
|
||||
|
||||
|
||||
def vllm__module_executor__models__falcon__FalconDecoderLayer____init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super(FalconDecoderLayer, self).__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.self_attention = FalconAttention(config, cache_config,
|
||||
quant_config)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use FeedForward instead of MLP
|
||||
'''
|
||||
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||
or config.parallel_attn)
|
||||
self.mlp = FeedForward(hidden_size=hidden_size,
|
||||
intermediate_size=hidden_size * 4,
|
||||
hidden_act='gelu',
|
||||
up_proj_name='dense_h_to_4h',
|
||||
is_gated=False,
|
||||
down_proj_name='dense_4h_to_h',
|
||||
bias=config.bias,
|
||||
quant_config=quant_config,
|
||||
skip_bias_add=True,
|
||||
reduce_results=self.reduce_row_parallel_results)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.config = config
|
||||
|
||||
if (config.num_ln_in_parallel_attn is None
|
||||
and config.new_decoder_architecture):
|
||||
config.num_ln_in_parallel_attn = 2
|
||||
|
||||
if not config.parallel_attn:
|
||||
self.post_attention_layernorm = LayerNorm(
|
||||
hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.input_layernorm = LayerNorm(hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
else:
|
||||
if config.num_ln_in_parallel_attn == 2:
|
||||
# The layer norm before self-attention
|
||||
self.ln_attn = LayerNorm(hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
# The layer norm before the MLP
|
||||
self.ln_mlp = LayerNorm(hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
else:
|
||||
self.input_layernorm = LayerNorm(hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
|
||||
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||
or config.parallel_attn)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(FalconAttention,
|
||||
FalconAttention.__init__,
|
||||
vllm__module_executor__models__falcon__FalconAttention____init__)
|
||||
MluHijackObject.apply_hijack(FalconAttention,
|
||||
FalconAttention.forward,
|
||||
vllm__module_executor__models__falcon__FalconAttention__forward)
|
||||
MluHijackObject.apply_hijack(FalconDecoderLayer,
|
||||
FalconDecoderLayer.__init__,
|
||||
vllm__module_executor__models__falcon__FalconDecoderLayer____init__)
|
||||
238
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/gpt_neox.py
Normal file
238
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/gpt_neox.py
Normal file
@@ -0,0 +1,238 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Optional
|
||||
from transformers import GPTNeoXConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.models.gpt_neox import GPTNeoXAttention, GPTNeoXLayer
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__module_executor__models__gpt_neox__GPTNeoXAttention__init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super(GPTNeoXAttention, self).__init__()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.total_num_heads
|
||||
self.bias = getattr(config, "attention_bias", True)
|
||||
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = (self.total_num_heads //
|
||||
tensor_model_parallel_world_size)
|
||||
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
config.hidden_size,
|
||||
self.head_size,
|
||||
self.total_num_heads,
|
||||
bias=self.bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: do not do allreduce in linear and skip bias add when use_parallel_residual
|
||||
'''
|
||||
if config.use_parallel_residual:
|
||||
reduce_results = False
|
||||
skip_bias_add = True
|
||||
else:
|
||||
reduce_results = True
|
||||
skip_bias_add = False
|
||||
|
||||
self.dense = RowParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=self.bias,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
skip_bias_add=skip_bias_add,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
scaling = self.head_size**-0.5
|
||||
rotary_dim = int(self.head_size * config.rotary_pct)
|
||||
assert rotary_dim % 2 == 0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_size,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_size,
|
||||
scaling,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
|
||||
def vllm__module_executor__models__gpt_neox__GPTNeoXAttention__forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
qk, _ = qkv.split([self.num_heads * self.head_size * 2, self.num_heads * self.head_size], dim=-1)
|
||||
self.rotary_emb(position_ids, qk.view(-1, self.num_heads + self.num_heads, self.head_size))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add bias for rank 0 when use_parallel_residual
|
||||
'''
|
||||
output, bias = self.dense(attn_output)
|
||||
if self.dense.skip_bias_add and get_tensor_model_parallel_rank() == 0:
|
||||
output += bias
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output
|
||||
|
||||
|
||||
def vllm__module_executor__models__gpt_neox__GPTNeoXLayer__init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super(GPTNeoXLayer, self).__init__()
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.attention = GPTNeoXAttention(config, cache_config, quant_config)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: 1) use FeedForward instead of MLP
|
||||
2) do not do allreduce in row linear and skip bias add in it
|
||||
'''
|
||||
if self.use_parallel_residual:
|
||||
reduce_results = False
|
||||
skip_bias_add = True
|
||||
else:
|
||||
reduce_results = True
|
||||
skip_bias_add = False
|
||||
|
||||
self.mlp = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act='gelu',
|
||||
up_proj_name='dense_h_to_4h',
|
||||
is_gated=False,
|
||||
down_proj_name='dense_4h_to_h',
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
skip_bias_add=skip_bias_add,
|
||||
reduce_results=reduce_results)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__gpt_neox__GPTNeoXLayer__forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.input_layernorm(hidden_states)
|
||||
attn_output = self.attention(
|
||||
position_ids=position_ids,
|
||||
hidden_states=attn_input,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: only do one allreduce when use_parallel_residual
|
||||
'''
|
||||
if self.use_parallel_residual:
|
||||
# pseudocode:
|
||||
# x = x + attn(ln1(x)) + mlp(ln2(x))
|
||||
mlp_input = self.post_attention_layernorm(hidden_states)
|
||||
mlp_output, mlp_bias = self.mlp(mlp_input)
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
mlp_output += mlp_bias
|
||||
hidden_states = mlp_output + attn_output + hidden_states
|
||||
else:
|
||||
hidden_states = mlp_output + attn_output
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
else:
|
||||
# pseudocode:
|
||||
# x = x + attn(ln1(x))
|
||||
# x = x + mlp(ln2(x))
|
||||
attn_output = attn_output + hidden_states
|
||||
mlp_input = self.post_attention_layernorm(attn_output)
|
||||
mlp_output = self.mlp(mlp_input)
|
||||
hidden_states = mlp_output + attn_output
|
||||
return hidden_states
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(GPTNeoXAttention,
|
||||
GPTNeoXAttention.__init__,
|
||||
vllm__module_executor__models__gpt_neox__GPTNeoXAttention__init__)
|
||||
MluHijackObject.apply_hijack(GPTNeoXAttention,
|
||||
GPTNeoXAttention.forward,
|
||||
vllm__module_executor__models__gpt_neox__GPTNeoXAttention__forward)
|
||||
MluHijackObject.apply_hijack(GPTNeoXLayer,
|
||||
GPTNeoXLayer.__init__,
|
||||
vllm__module_executor__models__gpt_neox__GPTNeoXLayer__init__)
|
||||
MluHijackObject.apply_hijack(GPTNeoXLayer,
|
||||
GPTNeoXLayer.forward,
|
||||
vllm__module_executor__models__gpt_neox__GPTNeoXLayer__forward)
|
||||
502
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/hunyuan.py
Executable file
502
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/hunyuan.py
Executable file
@@ -0,0 +1,502 @@
|
||||
import torch
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Iterable, Union
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.model_executor.models.hunyuan import HunYuanAttention, HunYuanDecoderLayer, HunYuanForCausalLM, HunYuanModel
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
get_compressed_tensors_cache_scale)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
hunyuan_decoder_layer_forward_base, hunyuan_decoder_model_forward_base_pp,
|
||||
is_per_tensor_smoothquant, is_per_token_smoothquant, quant_fusion_with_rmsnorm)
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
class HunYuanSparseMoeBlock(SparseMoeMlp):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__(num_experts=config.num_experts,
|
||||
top_k=config.moe_topk,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
up_proj_name="gate_up_proj",
|
||||
is_gated=True,
|
||||
down_proj_name="down_proj",
|
||||
has_bias=False,
|
||||
skip_bias_add=False,
|
||||
renormalize=True if config.moe_topk>1 else False,
|
||||
hidden_act=config.hidden_act,
|
||||
params_dtype=None,
|
||||
quant_config=quant_config,
|
||||
is_use_fused_moe=True)
|
||||
self.config = config
|
||||
self.shared_mlp = None
|
||||
if config.use_mixed_mlp_moe > 0:
|
||||
self.shared_mlp = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size * config.num_shared_expert,
|
||||
hidden_act=config.hidden_act,
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=True,
|
||||
down_proj_name='down_proj',
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False)
|
||||
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
shared_output = None
|
||||
if self.shared_mlp is not None:
|
||||
shared_output = self.shared_mlp(hidden_states)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.forward_experts(hidden_states, router_logits, residual)
|
||||
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
def vllm__module_executor__models__hunyuan__HunYuanAttention__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
kv_states: Optional[Tuple[torch.Tensor]] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.attention_type == "self":
|
||||
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
|
||||
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
ori_k = k
|
||||
if self.use_qk_norm:
|
||||
q = self.query_layernorm(q.reshape(-1, self.num_heads, self.head_dim).contiguous()).reshape(-1, self.num_heads*self.head_dim)
|
||||
k = self.key_layernorm(k.reshape(-1, self.num_kv_heads, self.head_dim).contiguous()).reshape(-1, self.num_kv_heads*self.head_dim)
|
||||
elif self.attention_type == "cross":
|
||||
assert kv_states is not None
|
||||
ori_k, v = kv_states # use last layer kv,
|
||||
k = ori_k
|
||||
q, _ = self.q_proj(hidden_states, smooth_quant_scale)
|
||||
k_tmp = torch.empty_like(k) # Todo: reduant rotary embedding
|
||||
qk_temp = torch.cat((q, k_tmp), dim=-1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
self.rotary_emb(positions, qk_temp.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if self.use_qk_norm:
|
||||
q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous()).reshape(-1, self.num_heads*self.head_dim)
|
||||
k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous()).reshape(-1, self.num_kv_heads*self.head_dim)
|
||||
else:
|
||||
raise RuntimeError("Not support attnention type")
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual
|
||||
'''
|
||||
output, _ = self.o_proj(attn_output, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output, (ori_k, v)
|
||||
|
||||
|
||||
def vllm__module_executor__models__hunyuan__HunYuanDecoderLayer____init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
layer_id: int = -1,
|
||||
) -> None:
|
||||
super(HunYuanDecoderLayer, self).__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
# Support abacusai/Smaug-72B-v0.1 with attention_bias
|
||||
# Support internlm/internlm-7b with bias
|
||||
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||
config, "bias", False)
|
||||
cla_factor = getattr(config, "cla_share_factor", 1)
|
||||
attention_type = "cross" \
|
||||
if layer_id >= 0 and layer_id % cla_factor != 0 else "self"
|
||||
self.self_attn = HunYuanAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(config, "num_key_value_heads",
|
||||
config.num_attention_heads),
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
bias=attention_bias,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attention_type=attention_type,
|
||||
)
|
||||
|
||||
if getattr(config, "num_experts", None):
|
||||
self.mlp = HunYuanSparseMoeBlock(config=config,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use FeedForward instead of MLP
|
||||
'''
|
||||
self.mlp = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=True,
|
||||
down_proj_name='down_proj',
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
prefix=f"{prefix}.mlp",
|
||||
quant_config=quant_config)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: prepare to perf per-tensor sq cases if suitable. For moe
|
||||
model, we only do quant fusion in attn block.
|
||||
'''
|
||||
self.is_per_tensor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
|
||||
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
|
||||
if self.is_per_tensor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
if self.self_attn.attention_type == "self":
|
||||
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
|
||||
if self.self_attn.attention_type == "cross":
|
||||
self.self_attn.q_proj.quant_method.skip_quant_input = True
|
||||
self.quant_fusion_attn_layernorm = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__hunyuan__HunYuanForCausalLM__load_weights(
|
||||
self,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
cla_factor = getattr(self.config, "cla_share_factor", 1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack params and cal start expert id
|
||||
'''
|
||||
for name, m in self.model.named_modules():
|
||||
if isinstance(m, SparseMoeMlp):
|
||||
m.pack_params()
|
||||
|
||||
start_expert_id = 0
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: delete expert_params_mapping for no useless
|
||||
'''
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
# With tie_word_embeddings, we can skip lm_head.weight
|
||||
# The weight might appear unnecessarily in the files if the model is
|
||||
# processed with quantization, LoRA, fine-tuning, etc.
|
||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||
continue
|
||||
if scale_name := get_compressed_tensors_cache_scale(name):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
continue
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace expert_id in weight to named_expert_id in params_dict
|
||||
'''
|
||||
if start_expert_id > 0 and "mlp.experts." in name:
|
||||
expert_str = re.search(r'experts\.\d+', name).group(0)
|
||||
expert_id=int(expert_str.split(".")[1])
|
||||
named_expert_id = expert_id - start_expert_id
|
||||
old_expert_name = f"experts.{expert_id}"
|
||||
new_expert_name = f"experts.{named_expert_id}"
|
||||
name = name.replace(old_expert_name, new_expert_name)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: delete if "mlp.experts" in name: continue condition
|
||||
'''
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
# cross layer only have q_proj, skip qkv pack
|
||||
if weight_name == ".q_proj":
|
||||
match = re.search(r'layers\.\d+', name)
|
||||
if match:
|
||||
layer_id = int(match.group(0).split('.')[-1])
|
||||
if cla_factor > 1 and layer_id % cla_factor != 0:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if (name.endswith(".bias") and name not in params_dict):
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add expert skiped condition and delete useless if name not in params_dict: continue condition
|
||||
'''
|
||||
# Skip experts that are not assigned to this worker.
|
||||
if (("mlp.experts." in name or "mlp.shared_mlp." in name)
|
||||
and name not in params_dict):
|
||||
continue
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if (name.endswith(".bias") and name not in params_dict):
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
if "mlp.gate.wg." in name:
|
||||
name = name.replace("wg.", "")
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add expert skiped condition
|
||||
'''
|
||||
# Skip experts that are not assigned to this worker.
|
||||
if (("mlp.experts." in name or "mlp.shared_mlp." in name)
|
||||
and name not in params_dict):
|
||||
continue
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
def vllm__module_executor__models__hunyuan__HunYuanDecoderLayer__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
kv_states: Optional[Tuple[torch.Tensor]] = None,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf model by:
|
||||
1) add residual in matmul;
|
||||
2) fuse quantization in layernorm in per-tensor sq case;
|
||||
'''
|
||||
attn_layernorm = self.input_layernorm
|
||||
if self.is_per_tensor_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
if self.self_attn.attention_type == "self":
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.input_layernorm, self.self_attn.qkv_proj.scale_to_int)
|
||||
if self.self_attn.attention_type == "cross":
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.input_layernorm, self.self_attn.q_proj.scale_to_int)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
elif self.is_per_token_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
if self.self_attn.attention_type == "self":
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.input_layernorm, self.self_attn.qkv_proj.smooth, dynamic_quant=True)
|
||||
if self.self_attn.attention_type == "cross":
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.input_layernorm, self.self_attn.q_proj.smooth, dynamic_quant=True)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
return hunyuan_decoder_layer_forward_base(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
input_layernorm=attn_layernorm,
|
||||
self_attn=self.self_attn,
|
||||
post_layernorm=self.post_attention_layernorm,
|
||||
mlp=self.mlp,
|
||||
kv_states=kv_states,
|
||||
input_norm_fuse_en=self.is_per_token_sq_perf_cases)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__hunyuan__HunYuanModel__forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
return hunyuan_decoder_model_forward_base_pp(config=self.config,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
layers=self.layers,
|
||||
start_layer=self.start_layer,
|
||||
end_layer=self.end_layer,
|
||||
get_input_embeddings=self.get_input_embeddings,
|
||||
norm=self.norm,
|
||||
inputs_embeds=inputs_embeds)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(HunYuanAttention,
|
||||
HunYuanAttention.forward,
|
||||
vllm__module_executor__models__hunyuan__HunYuanAttention__forward)
|
||||
MluHijackObject.apply_hijack(HunYuanDecoderLayer,
|
||||
HunYuanDecoderLayer.__init__,
|
||||
vllm__module_executor__models__hunyuan__HunYuanDecoderLayer____init__)
|
||||
MluHijackObject.apply_hijack(HunYuanForCausalLM,
|
||||
HunYuanForCausalLM.load_weights,
|
||||
vllm__module_executor__models__hunyuan__HunYuanForCausalLM__load_weights)
|
||||
MluHijackObject.apply_hijack(HunYuanDecoderLayer,
|
||||
HunYuanDecoderLayer.forward,
|
||||
vllm__module_executor__models__hunyuan__HunYuanDecoderLayer__forward)
|
||||
MluHijackObject.apply_hijack(HunYuanModel,
|
||||
HunYuanModel.forward,
|
||||
vllm__module_executor__models__hunyuan__HunYuanModel__forward)
|
||||
294
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/internlm2.py
Normal file
294
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/internlm2.py
Normal file
@@ -0,0 +1,294 @@
|
||||
import torch
|
||||
from typing import Optional, Tuple, Iterable, Union, List
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.model_executor.models.internlm2 import InternLM2Attention, InternLMDecoderLayer, InternLM2ForCausalLM, InternLM2Model
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import (is_pp_missing_parameter)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
decoder_layer_forward_base, decoder_model_forward_base_pp,
|
||||
is_per_tensor_smoothquant, is_per_token_smoothquant,
|
||||
quant_fusion_with_rmsnorm)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__module_executor__models__internlm2__InternLM2Attention__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.wqkv(hidden_states, smooth_quant_scale)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
|
||||
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual
|
||||
'''
|
||||
output, _ = self.wo(attn_output, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output
|
||||
|
||||
|
||||
def vllm__module_executor__models__internlm2__InternLMDecoderLayer____init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
super(InternLMDecoderLayer, self).__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.attention = InternLM2Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attention",
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use FeedForward instead of MLP
|
||||
'''
|
||||
self.feed_forward = FeedForward(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=True,
|
||||
down_proj_name='w2',
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f'{prefix}.feed_forward'
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.attention_norm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: prepare to perf per-tensor sq cases if suitable
|
||||
'''
|
||||
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
|
||||
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
self.attention.wqkv.quant_method.skip_quant_input = True
|
||||
self.feed_forward.gate_up_proj.quant_method.skip_quant_input = True
|
||||
self.quant_fusion_attn_layernorm = None
|
||||
self.quant_fusion_mlp_layernorm = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__internlm2__InternLM2ForCausalLM__load_weights(
|
||||
self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "w1", 0),
|
||||
("gate_up_proj", "w3", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: support load quant weights and params
|
||||
'''
|
||||
if "wqkv" in name and 'smooth' not in name and 'scale_to_int' not in name:
|
||||
config = self.config
|
||||
kv_groups = (config.num_attention_heads //
|
||||
config.num_key_value_heads)
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
if 'weight' in name:
|
||||
loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
|
||||
head_dim,
|
||||
loaded_weight.shape[-1])
|
||||
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
|
||||
dim=1)
|
||||
wq = wq.reshape(-1, wq.shape[-1])
|
||||
wk = wk.reshape(-1, wk.shape[-1])
|
||||
wv = wv.reshape(-1, wv.shape[-1])
|
||||
elif 'scale' in name:
|
||||
loaded_weight = loaded_weight.view(-1, 2 + kv_groups, head_dim)
|
||||
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
|
||||
dim=1)
|
||||
wq = wq.reshape(-1)
|
||||
wk = wk.reshape(-1)
|
||||
wv = wv.reshape(-1)
|
||||
else:
|
||||
logger.error(f"unsupport internlm2 quant param: {name}, shape: {loaded_weight.shape}")
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, wq, 'q')
|
||||
weight_loader(param, wk, 'k')
|
||||
weight_loader(param, wv, 'v')
|
||||
else:
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__internlm2__InternLMDecoderLayer__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf model by:
|
||||
1) add residual in matmul;
|
||||
2) fuse quantization in layernorm in per-tensor sq case;
|
||||
'''
|
||||
attn_layernorm = self.attention_norm
|
||||
mlp_layernorm = self.ffn_norm
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
if self.is_per_token_sq_perf_cases:
|
||||
attn_quant_scale = self.attention.wqkv.smooth
|
||||
mlp_quant_scale = self.feed_forward.gate_up_proj.smooth
|
||||
else:
|
||||
attn_quant_scale = self.attention.wqkv.scale_to_int
|
||||
mlp_quant_scale = self.feed_forward.gate_up_proj.scale_to_int
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.attention_norm, attn_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.ffn_norm, mlp_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
mlp_layernorm = self.quant_fusion_mlp_layernorm
|
||||
|
||||
return decoder_layer_forward_base(positions, hidden_states, kv_cache, attn_metadata,
|
||||
attn_layernorm,
|
||||
self.attention,
|
||||
mlp_layernorm,
|
||||
self.feed_forward,
|
||||
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
|
||||
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def vllm__module_executor__models__internlm2__InternLM2Model__forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
return decoder_model_forward_base_pp(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors,
|
||||
self.layers, self.start_layer, self.end_layer,
|
||||
self.tok_embeddings,
|
||||
self.norm,
|
||||
inputs_embeds)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(InternLM2Attention,
|
||||
InternLM2Attention.forward,
|
||||
vllm__module_executor__models__internlm2__InternLM2Attention__forward)
|
||||
MluHijackObject.apply_hijack(InternLMDecoderLayer,
|
||||
InternLMDecoderLayer.__init__,
|
||||
vllm__module_executor__models__internlm2__InternLMDecoderLayer____init__)
|
||||
MluHijackObject.apply_hijack(InternLM2ForCausalLM,
|
||||
InternLM2ForCausalLM.load_weights,
|
||||
vllm__module_executor__models__internlm2__InternLM2ForCausalLM__load_weights)
|
||||
MluHijackObject.apply_hijack(InternLMDecoderLayer,
|
||||
InternLMDecoderLayer.forward,
|
||||
vllm__module_executor__models__internlm2__InternLMDecoderLayer__forward)
|
||||
MluHijackObject.apply_hijack(InternLM2Model,
|
||||
InternLM2Model.forward,
|
||||
vllm__module_executor__models__internlm2__InternLM2Model__forward)
|
||||
273
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/layer_utils.py
Executable file
273
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/layer_utils.py
Executable file
@@ -0,0 +1,273 @@
|
||||
import torch
|
||||
from typing import Callable, Optional, List, Union, Tuple
|
||||
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from transformers import PretrainedConfig
|
||||
from vllm_mlu._mlu_utils import check_context_comm_cmpt_parallel
|
||||
|
||||
def hunyuan_decoder_layer_forward_base(
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
input_layernorm: Callable,
|
||||
self_attn: Callable,
|
||||
post_layernorm: Callable,
|
||||
mlp: Callable,
|
||||
kv_states: Optional[Tuple[torch.Tensor]] = None,
|
||||
apply_residual_connection_post_layernorm: bool = False,
|
||||
position_name: str = 'positions',
|
||||
input_norm_fuse_en: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
smooth_quant_scale = None
|
||||
if input_norm_fuse_en:
|
||||
layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
|
||||
else:
|
||||
layernorm_output = input_layernorm(hidden_states)
|
||||
if apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = hidden_states
|
||||
|
||||
# Self Attention
|
||||
attention_output, ori_kv_states = self_attn(
|
||||
**{position_name: positions},
|
||||
hidden_states=layernorm_output,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
kv_states=kv_states,
|
||||
smooth_quant_scale=smooth_quant_scale,
|
||||
)
|
||||
|
||||
layernorm_output = post_layernorm(attention_output)
|
||||
if apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = attention_output
|
||||
|
||||
# Fully Connected
|
||||
hidden_states = mlp(layernorm_output, residual)
|
||||
return hidden_states, ori_kv_states
|
||||
|
||||
def decoder_layer_forward_base(
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
input_layernorm: Callable,
|
||||
self_attn: Callable,
|
||||
post_layernorm: Callable,
|
||||
mlp: Callable,
|
||||
apply_residual_connection_post_layernorm: bool = False,
|
||||
position_name: str = 'positions',
|
||||
input_norm_fuse_en: bool = False,
|
||||
post_norm_fuse_en: bool = False,
|
||||
) -> torch.Tensor:
|
||||
smooth_quant_scale = None
|
||||
if input_norm_fuse_en:
|
||||
layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
|
||||
else:
|
||||
layernorm_output = input_layernorm(hidden_states)
|
||||
|
||||
if apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = hidden_states
|
||||
|
||||
# Self Attention
|
||||
attention_output = self_attn(
|
||||
**{position_name: positions},
|
||||
hidden_states=layernorm_output,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
smooth_quant_scale=smooth_quant_scale,
|
||||
)
|
||||
smooth_quant_scale = None
|
||||
if post_norm_fuse_en:
|
||||
layernorm_output, smooth_quant_scale = post_layernorm(attention_output)
|
||||
else:
|
||||
layernorm_output = post_layernorm(attention_output)
|
||||
if apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = attention_output
|
||||
|
||||
# Fully Connected
|
||||
kwargs = dict()
|
||||
if post_norm_fuse_en:
|
||||
kwargs['smooth_quant_scale'] = smooth_quant_scale
|
||||
hidden_states = mlp(layernorm_output, residual, **kwargs)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def decoder_model_forward_base(
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
layers: torch.nn.ModuleList,
|
||||
get_input_embeddings: Callable,
|
||||
norm: Callable
|
||||
) -> torch.Tensor:
|
||||
hidden_states = get_input_embeddings(input_ids)
|
||||
for i in range(len(layers)):
|
||||
layer = layers[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
attn_metadata,
|
||||
)
|
||||
hidden_states = norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def hunyuan_decoder_model_forward_base_pp(
|
||||
config: PretrainedConfig,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
layers: torch.nn.ModuleList,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
get_input_embeddings: Callable,
|
||||
norm: Callable,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = get_input_embeddings(input_ids)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
cla_factor = getattr(config, "cla_share_factor", 1)
|
||||
prev_kv_states = None
|
||||
for i in range(start_layer, end_layer):
|
||||
layer = layers[i]
|
||||
hidden_states, kv_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - start_layer],
|
||||
attn_metadata,
|
||||
prev_kv_states,
|
||||
)
|
||||
if (i - start_layer) % cla_factor == 0:
|
||||
prev_kv_states = kv_states
|
||||
else:
|
||||
prev_kv_states = None
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
})
|
||||
|
||||
hidden_states = norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def decoder_model_forward_base_pp(
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
layers: torch.nn.ModuleList,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
get_input_embeddings: Callable,
|
||||
norm: Callable,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = get_input_embeddings(input_ids)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
for i in range(start_layer, end_layer):
|
||||
layer = layers[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
})
|
||||
|
||||
hidden_states = norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def is_smoothquant(quant_config: QuantizationConfig) -> bool:
|
||||
return (quant_config is not None and quant_config.get_name() == "SmoothQuant")
|
||||
|
||||
def is_per_tensor_smoothquant(quant_config: QuantizationConfig) -> bool:
|
||||
return is_smoothquant(quant_config) and quant_config.input_quant_method == "per_tensor"
|
||||
|
||||
def is_per_token_smoothquant(quant_config: QuantizationConfig) -> bool:
|
||||
if check_context_comm_cmpt_parallel():
|
||||
return False
|
||||
|
||||
return is_smoothquant(quant_config) and quant_config.input_quant_method == "per_token"
|
||||
|
||||
def quant_fusion_with_layernorm(
|
||||
op: torch.nn.LayerNorm,
|
||||
quant_scale: torch.Tensor,
|
||||
dynamic_quant: bool = False,
|
||||
) -> Callable:
|
||||
bias = None
|
||||
if op.bias is not None:
|
||||
bias = op.bias.data
|
||||
|
||||
def func(x: torch.Tensor) -> torch.Tensor:
|
||||
return mlu_ops.fused_layer_norm(
|
||||
x,
|
||||
None,
|
||||
op.weight.data,
|
||||
bias,
|
||||
None,
|
||||
op.eps,
|
||||
False,
|
||||
quant_scale,
|
||||
dynamic_quant)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def quant_fusion_with_rmsnorm(
|
||||
op: RMSNorm,
|
||||
quant_scale: torch.Tensor,
|
||||
dynamic_quant: bool = False,
|
||||
) -> Callable:
|
||||
|
||||
def func(x: torch.Tensor) -> torch.Tensor:
|
||||
return mlu_ops.fused_rms_norm(
|
||||
x,
|
||||
None,
|
||||
op.weight.data,
|
||||
None,
|
||||
None,
|
||||
op.variance_epsilon,
|
||||
False,
|
||||
quant_scale,
|
||||
dynamic_quant)
|
||||
|
||||
return func
|
||||
275
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/llama.py
Normal file
275
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/llama.py
Normal file
@@ -0,0 +1,275 @@
|
||||
import torch
|
||||
|
||||
from typing import Dict, List, Optional, Union, Any
|
||||
from transformers import LlamaConfig
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.model_executor.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
decoder_layer_forward_base, decoder_model_forward_base_pp,
|
||||
is_per_tensor_smoothquant, is_per_token_smoothquant,
|
||||
quant_fusion_with_rmsnorm)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
vllm__module_executor__models__llama__LlamaAttention__init__org = LlamaAttention.__init__
|
||||
|
||||
|
||||
def vllm__module_executor__models__llama__LlamaAttention____init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
vllm__module_executor__models__llama__LlamaAttention__init__org(
|
||||
self,
|
||||
config,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
rope_theta,
|
||||
rope_scaling,
|
||||
max_position_embeddings,
|
||||
quant_config,
|
||||
bias,
|
||||
cache_config,
|
||||
prefix)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add rope_scaling params
|
||||
'''
|
||||
self.rope_scaling = rope_scaling
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__llama__LlamaAttention__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
if self.rope_scaling is not None and self.rope_scaling["rope_type"] == "longrope":
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
else:
|
||||
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
|
||||
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual
|
||||
'''
|
||||
output, _ = self.o_proj(attn_output, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output
|
||||
|
||||
|
||||
def vllm__module_executor__models__llama__LlamaDecoderLayer____init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super(LlamaDecoderLayer, self).__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
# Support abacusai/Smaug-72B-v0.1 with attention_bias
|
||||
# Support internlm/internlm-7b with bias
|
||||
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||
config, "bias", False)
|
||||
self.self_attn = LlamaAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(config, "num_key_value_heads",
|
||||
config.num_attention_heads),
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
bias=attention_bias,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use FeedForward instead of MLP
|
||||
'''
|
||||
self.mlp = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act='silu',
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=True,
|
||||
down_proj_name='down_proj',
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: prepare to perf sq cases if suitable
|
||||
'''
|
||||
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
|
||||
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
|
||||
self.mlp.gate_up_proj.quant_method.skip_quant_input = True
|
||||
self.quant_fusion_attn_layernorm = None
|
||||
self.quant_fusion_mlp_layernorm = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def vllm__module_executor__models__llama__LlamaDecoderLayer__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf model by:
|
||||
1) add residual in matmul;
|
||||
2) fuse quantization in layernorm in per-tensor sq case;
|
||||
'''
|
||||
attn_layernorm = self.input_layernorm
|
||||
mlp_layernorm = self.post_attention_layernorm
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
if self.is_per_token_sq_perf_cases:
|
||||
attn_quant_scale = self.self_attn.qkv_proj.smooth
|
||||
mlp_quant_scale = self.mlp.gate_up_proj.smooth
|
||||
else:
|
||||
attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
|
||||
mlp_quant_scale = self.mlp.gate_up_proj.scale_to_int
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.input_layernorm, attn_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.post_attention_layernorm, mlp_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
mlp_layernorm = self.quant_fusion_mlp_layernorm
|
||||
|
||||
return decoder_layer_forward_base(positions, hidden_states, kv_cache, attn_metadata,
|
||||
attn_layernorm,
|
||||
self.self_attn,
|
||||
mlp_layernorm,
|
||||
self.mlp,
|
||||
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
|
||||
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def vllm__module_executor__models__llama__LlamaModel__forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
return decoder_model_forward_base_pp(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors,
|
||||
self.layers, self.start_layer, self.end_layer,
|
||||
self.get_input_embeddings,
|
||||
self.norm,
|
||||
inputs_embeds)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
MluHijackObject.apply_hijack(LlamaAttention,
|
||||
LlamaAttention.__init__,
|
||||
vllm__module_executor__models__llama__LlamaAttention____init__)
|
||||
MluHijackObject.apply_hijack(LlamaAttention,
|
||||
LlamaAttention.forward,
|
||||
vllm__module_executor__models__llama__LlamaAttention__forward)
|
||||
MluHijackObject.apply_hijack(LlamaDecoderLayer,
|
||||
LlamaDecoderLayer.__init__,
|
||||
vllm__module_executor__models__llama__LlamaDecoderLayer____init__)
|
||||
MluHijackObject.apply_hijack(LlamaDecoderLayer,
|
||||
LlamaDecoderLayer.forward,
|
||||
vllm__module_executor__models__llama__LlamaDecoderLayer__forward)
|
||||
MluHijackObject.apply_hijack(LlamaModel,
|
||||
LlamaModel.forward,
|
||||
vllm__module_executor__models__llama__LlamaModel__forward)
|
||||
336
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/mixtral.py
Normal file
336
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/mixtral.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import torch
|
||||
import re
|
||||
|
||||
from typing import List, Optional, Tuple, Union, Iterable
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.model_executor.models.mixtral import (MixtralAttention, MixtralDecoderLayer,
|
||||
MixtralForCausalLM, MixtralModel)
|
||||
from vllm_mlu.mlu_hijack_utils import set_is_gated
|
||||
from transformers import MixtralConfig
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.model_loader.weight_utils import (default_weight_loader,
|
||||
maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
decoder_layer_forward_base, decoder_model_forward_base_pp,
|
||||
is_per_tensor_smoothquant, is_per_token_smoothquant,
|
||||
quant_fusion_with_rmsnorm)
|
||||
|
||||
|
||||
def vllm__module_executor__models__mixtral__MixtralAttention__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
|
||||
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual
|
||||
'''
|
||||
output, _ = self.o_proj(attn_output, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output
|
||||
|
||||
|
||||
def vllm__module_executor__models__mixtral__MixtralDecoderLayer____init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super(MixtralDecoderLayer, self).__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
self.self_attn = MixtralAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace MixtralMoE to SparseMoeMlp
|
||||
'''
|
||||
self.block_sparse_moe = SparseMoeMlp(num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
up_proj_name="w13",
|
||||
is_gated=True,
|
||||
down_proj_name="w2",
|
||||
has_bias=False,
|
||||
skip_bias_add=False,
|
||||
renormalize=True,
|
||||
hidden_act=config.hidden_act,
|
||||
params_dtype=None,
|
||||
quant_config=quant_config,
|
||||
is_use_fused_moe=True)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: prepare to perf per-tensor sq cases if suitable. MoE gate linear always runs
|
||||
in half/full precision for now, so we only do quant fusion in attn block.
|
||||
'''
|
||||
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
|
||||
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
|
||||
self.quant_fusion_attn_layernorm = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__mixtral__MixtralForCausalLM__load_weights(
|
||||
self,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack params and cal start expert id
|
||||
'''
|
||||
for name, m in self.model.named_modules():
|
||||
if isinstance(m, SparseMoeMlp):
|
||||
m.pack_params()
|
||||
|
||||
start_expert_id = 0
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("w13", "w1", 0),
|
||||
("w13", "w3", 1),
|
||||
]
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: delete expert_params_mapping for no useless
|
||||
'''
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace expert_id in weight to named_expert_id in params_dict
|
||||
'''
|
||||
if start_expert_id > 0 and "block_sparse_moe.experts." in name:
|
||||
expert_str = re.search(r'experts\.\d+', name).group(0)
|
||||
expert_id=int(expert_str.split(".")[1])
|
||||
named_expert_id = expert_id - start_expert_id
|
||||
old_expert_name = f"experts.{expert_id}"
|
||||
new_expert_name = f"experts.{named_expert_id}"
|
||||
name = name.replace(old_expert_name, new_expert_name)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add expert skiped condition
|
||||
'''
|
||||
# Skip experts that are not assigned to this worker.
|
||||
if (("block_sparse_moe.experts." in name) and (name not in params_dict)):
|
||||
continue
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add expert skiped condition
|
||||
'''
|
||||
# Skip experts that are not assigned to this worker.
|
||||
if (("block_sparse_moe.experts." in name) and (name not in params_dict)):
|
||||
continue
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
def vllm__module_executor__models__mixtral__MixtralDecoderLayer__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf model by:
|
||||
1) add residual in matmul;
|
||||
2) fuse quantization in layernorm in per-tensor sq case;
|
||||
'''
|
||||
attn_layernorm = self.input_layernorm
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
if self.is_per_token_sq_perf_cases:
|
||||
attn_quant_scale = self.self_attn.qkv_proj.smooth
|
||||
else:
|
||||
attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.input_layernorm, attn_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
|
||||
return decoder_layer_forward_base(positions, hidden_states, kv_cache, attn_metadata,
|
||||
attn_layernorm,
|
||||
self.self_attn,
|
||||
self.post_attention_layernorm,
|
||||
self.block_sparse_moe,
|
||||
input_norm_fuse_en=self.is_per_token_sq_perf_cases)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__mixtral__MixtralModel__forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
return decoder_model_forward_base_pp(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors,
|
||||
self.layers, self.start_layer, self.end_layer,
|
||||
self.embed_tokens,
|
||||
self.norm)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
set_is_gated(True)
|
||||
MluHijackObject.apply_hijack(MixtralAttention,
|
||||
MixtralAttention.forward,
|
||||
vllm__module_executor__models__mixtral__MixtralAttention__forward)
|
||||
MluHijackObject.apply_hijack(MixtralDecoderLayer,
|
||||
MixtralDecoderLayer.__init__,
|
||||
vllm__module_executor__models__mixtral__MixtralDecoderLayer____init__)
|
||||
MluHijackObject.apply_hijack(MixtralForCausalLM,
|
||||
MixtralForCausalLM.load_weights,
|
||||
vllm__module_executor__models__mixtral__MixtralForCausalLM__load_weights)
|
||||
MluHijackObject.apply_hijack(MixtralDecoderLayer,
|
||||
MixtralDecoderLayer.forward,
|
||||
vllm__module_executor__models__mixtral__MixtralDecoderLayer__forward)
|
||||
MluHijackObject.apply_hijack(MixtralModel,
|
||||
MixtralModel.forward,
|
||||
vllm__module_executor__models__mixtral__MixtralModel__forward)
|
||||
|
||||
230
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/mllama.py
Normal file
230
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/mllama.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Mllama model."""
|
||||
import math
|
||||
from typing import (List, Optional, Tuple)
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm._mlu_ops import flash_attention
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.attention.backends.mlu_attn import MLUFlashAttentionMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.model_executor.models.llama import LlamaDecoderLayer
|
||||
from vllm.model_executor.models.mllama import (MllamaCrossAttentionDecoderLayer,
|
||||
MllamaTextCrossAttention,
|
||||
MllamaTextModel,
|
||||
MllamaVisionSdpaAttention)
|
||||
from vllm_mlu._mlu_utils import USE_PAGED, BlockSizeInfo
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__model_executor__models__mllama__MllamaTextModel__forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: Optional[torch.LongTensor],
|
||||
cross_attention_states: Optional[torch.LongTensor],
|
||||
cross_attention_mask: Optional[torch.LongTensor],
|
||||
kv_range_for_decode: Optional[List[Tuple[int, int]]],
|
||||
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
|
||||
torch.Tensor]],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
skip_cross_attention: bool,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
|
||||
if not skip_cross_attention:
|
||||
hidden_states = decoder_layer(
|
||||
hidden_states=hidden_states,
|
||||
cross_attention_states=cross_attention_states,
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
kv_range_for_decode=kv_range_for_decode,
|
||||
full_text_row_masked_out_mask=
|
||||
full_text_row_masked_out_mask,
|
||||
kv_cache=kv_caches[idx],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
elif isinstance(decoder_layer, LlamaDecoderLayer):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: fuse residual into decoder layer.
|
||||
'''
|
||||
hidden_states = decoder_layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_caches[idx],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown decoder layer type {type(decoder_layer)}")
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def vllm__model_executor__models__mllama__MllamaVisionSdpaAttention__forward(
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_state)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q = q.view(q.shape[0], q.shape[1], self.num_local_heads,
|
||||
self.head_dim)
|
||||
k = k.view(k.shape[0], k.shape[1], self.num_local_heads,
|
||||
self.head_dim)
|
||||
v = v.view(v.shape[0], v.shape[1], self.num_local_heads,
|
||||
self.head_dim)
|
||||
batch, seq_len_q, q_head_num, head_size = q.shape
|
||||
seq_len_k = k.shape[1]
|
||||
softmax_scale = head_size ** -0.5
|
||||
attention_mask = attention_mask.repeat(1, q_head_num, 1, 1)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace SDPA with flash attn.
|
||||
'''
|
||||
attn_output = flash_attention(q, k, v,
|
||||
None, # out
|
||||
None, # cu_seq_lens_q
|
||||
None, # cu_seq_lens_kv
|
||||
None, # alibi_slop
|
||||
attention_mask, # attn_bias
|
||||
seq_len_q, # max_seq_len_q
|
||||
seq_len_k, # max_seq_len_kv
|
||||
softmax_scale, # softmax_scale
|
||||
False, # is_casual
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
attn_output = attn_output.reshape(attn_output.shape[0],
|
||||
attn_output.shape[1], -1).contiguous()
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def vllm__model_executor__models__mllama__MllamaTextCrossAttention___attention_with_mask(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
kv_range_for_decode: List[Tuple[int, int]],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Skip writing kv-cache for the initial profiling run.
|
||||
if len(kv_cache[0].shape) > 1:
|
||||
if isinstance(attn_metadata, MLUFlashAttentionMetadata):
|
||||
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
|
||||
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
|
||||
if USE_PAGED:
|
||||
mlu_ops.reshape_paged_cache(cached_k,
|
||||
cached_v,
|
||||
kv_cache[0][0],
|
||||
kv_cache[0][1],
|
||||
attn_metadata.cross_slot_mapping,
|
||||
)
|
||||
else:
|
||||
cross_slot_mapping_flat = attn_metadata.cross_slot_mapping.flatten()
|
||||
seq_start_loc = attn_metadata._cached_prefill_metadata.encoder_seq_start_loc
|
||||
batch_ids = cross_slot_mapping_flat[seq_start_loc[:-1]] // BlockSizeInfo.BLOCK_SIZE
|
||||
max_context_len = attn_metadata._cached_prefill_metadata.max_encoder_seq_len
|
||||
mlu_ops.reshape_linear_cache(cached_k,
|
||||
cached_v,
|
||||
kv_cache[0][0],
|
||||
kv_cache[0][1],
|
||||
seq_start_loc, # context_lengths
|
||||
max_context_len,
|
||||
True, # packed
|
||||
None, # context_seq_offset
|
||||
batch_ids, # cache_bs_id
|
||||
None, # cache_seqlen_offset
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported AttentionMetadata {type(attn_metadata)} "
|
||||
f"class found. Expected the AttentionMetadata to "
|
||||
f"be either FlashAttentionMetadata.")
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace SDPA with flash attn.
|
||||
'''
|
||||
# We have to call torch.sdpa for prefill when using a
|
||||
# custom cross-attention mask. Because the mask is not a
|
||||
# standard causal mask, neither a block diagonal mask which
|
||||
# can be optimized by xformers.BlockDiagonalMask.
|
||||
# The mask is specially calculated for supporting multi
|
||||
# images and interleaved images.
|
||||
seq_len_q, q_head_num, head_size = q.shape
|
||||
softmax_scale = head_size ** -0.5
|
||||
cu_seq_lens_q = attn_metadata.seq_start_loc
|
||||
cu_seq_lens_kv = attn_metadata.encoder_seq_start_loc
|
||||
|
||||
max_seq_len_q = attn_metadata.max_prefill_seq_len
|
||||
max_seq_len_kv = attn_metadata.max_encoder_seq_len
|
||||
attn_output = flash_attention(q, k, v,
|
||||
None, # out
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_kv,
|
||||
None, # alibi_slope
|
||||
None, # attn_bias
|
||||
max_seq_len_q,
|
||||
max_seq_len_kv,
|
||||
softmax_scale,
|
||||
False, # is_causal
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
output = attn_output.reshape(seq_len_q, self.num_local_heads * self.head_dim)
|
||||
return output
|
||||
|
||||
MluHijackObject.apply_hijack(MllamaTextCrossAttention,
|
||||
MllamaTextCrossAttention._attention_with_mask,
|
||||
vllm__model_executor__models__mllama__MllamaTextCrossAttention___attention_with_mask)
|
||||
MluHijackObject.apply_hijack(MllamaTextModel,
|
||||
MllamaTextModel.forward,
|
||||
vllm__model_executor__models__mllama__MllamaTextModel__forward)
|
||||
MluHijackObject.apply_hijack(MllamaVisionSdpaAttention,
|
||||
MllamaVisionSdpaAttention.forward,
|
||||
vllm__model_executor__models__mllama__MllamaVisionSdpaAttention__forward)
|
||||
241
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/qwen.py
Normal file
241
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/qwen.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from typing import List, Optional, Union
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.model_executor.models.qwen import QWenAttention, QWenBlock, QWenModel, QwenImageInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.distributed import get_pp_group
|
||||
from .layer_utils import decoder_layer_forward_base
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
decoder_layer_forward_base, decoder_model_forward_base,
|
||||
is_per_tensor_smoothquant, is_per_token_smoothquant,
|
||||
quant_fusion_with_rmsnorm)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen__QwenAttention__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states, smooth_quant_scale)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
qk, _ = qkv.split([self.head_dim * self.num_heads * 2, self.head_dim * self.num_heads], dim=-1)
|
||||
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_heads, self.head_dim))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual
|
||||
'''
|
||||
output, _ = self.c_proj(attn_output, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen__QWenBlock__init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super(QWenBlock, self).__init__()
|
||||
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
self.attn = QWenAttention(config.hidden_size,
|
||||
config.num_attention_heads,
|
||||
config.max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: 1) use FeedForward instead of MLP
|
||||
2) prepare to perf per-tensor sq cases if suitable
|
||||
'''
|
||||
self.mlp = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size // 2,
|
||||
hidden_act='silu',
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=True,
|
||||
down_proj_name='c_proj',
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
|
||||
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
|
||||
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
self.attn.c_attn.quant_method.skip_quant_input = True
|
||||
self.mlp.gate_up_proj.quant_method.skip_quant_input = True
|
||||
self.quant_fusion_attn_layernorm = None
|
||||
self.quant_fusion_mlp_layernorm = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen__QWenBlock__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf model by:
|
||||
1) add residual in matmul;
|
||||
2) fuse quantization in layernorm in per-tensor sq case;
|
||||
'''
|
||||
attn_layernorm = self.ln_1
|
||||
mlp_layernorm = self.ln_2
|
||||
if self.is_per_tesnor_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.ln_1, self.attn.c_attn.scale_to_int)
|
||||
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.ln_2, self.mlp.gate_up_proj.scale_to_int)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
mlp_layernorm = self.quant_fusion_mlp_layernorm
|
||||
elif self.is_per_token_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.ln_1, self.attn.c_attn.smooth, dynamic_quant=True)
|
||||
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.ln_2, self.mlp.gate_up_proj.smooth, dynamic_quant=True)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
mlp_layernorm = self.quant_fusion_mlp_layernorm
|
||||
|
||||
return decoder_layer_forward_base(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
input_layernorm=attn_layernorm,
|
||||
self_attn=self.attn,
|
||||
post_layernorm=mlp_layernorm,
|
||||
mlp=self.mlp,
|
||||
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
|
||||
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen__QWenModel__forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
pixel_values: Optional[QwenImageInputs],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
img_pos = None
|
||||
# If pixel / visual embeddings are provided, this is a visual model
|
||||
if pixel_values is not None and self.visual is not None:
|
||||
if pixel_values["type"] != "image_embeds":
|
||||
image_embeds = self.visual(pixel_values["data"])
|
||||
else:
|
||||
image_embeds = pixel_values["data"]
|
||||
|
||||
# features should be of shape (# images, 256, hidden_dim)
|
||||
img_pos = self.visual.get_image_positions(input_ids)
|
||||
if isinstance(
|
||||
img_pos,
|
||||
np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]:
|
||||
raise ValueError(
|
||||
f"Number of placeholders: {img_pos.shape[0]} "
|
||||
f"does not match number of images {image_embeds.shape[0]}."
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: remove residual
|
||||
'''
|
||||
if get_pp_group().is_first_rank:
|
||||
hidden_states = self.wte(input_ids)
|
||||
# Merge the image embeddings into the hidden states if actually have
|
||||
# visual features and the corresponding image tokens
|
||||
if img_pos is not None:
|
||||
for idx, (img_bos, img_eos) in enumerate(img_pos):
|
||||
hidden_states[img_bos + 1:img_eos] = image_embeds[idx]
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
})
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return hidden_states
|
||||
|
||||
MluHijackObject.apply_hijack(QWenAttention,
|
||||
QWenAttention.forward,
|
||||
vllm__module_executor__models__qwen__QwenAttention__forward)
|
||||
MluHijackObject.apply_hijack(QWenBlock,
|
||||
QWenBlock.__init__,
|
||||
vllm__module_executor__models__qwen__QWenBlock__init__)
|
||||
MluHijackObject.apply_hijack(QWenBlock,
|
||||
QWenBlock.forward,
|
||||
vllm__module_executor__models__qwen__QWenBlock__forward)
|
||||
MluHijackObject.apply_hijack(QWenModel,
|
||||
QWenModel.forward,
|
||||
vllm__module_executor__models__qwen__QWenModel__forward)
|
||||
225
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/qwen2.py
Normal file
225
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/qwen2.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import torch
|
||||
|
||||
from typing import List, Optional
|
||||
from transformers import Qwen2Config
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Attention, Qwen2DecoderLayer, Qwen2Model
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
decoder_layer_forward_base, decoder_model_forward_base_pp,
|
||||
is_per_tensor_smoothquant, is_per_token_smoothquant,
|
||||
quant_fusion_with_rmsnorm)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen2__Qwen2Attention__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
|
||||
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual
|
||||
'''
|
||||
output, _ = self.o_proj(attn_output, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen2__Qwen2DecoderLayer____init__(
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super(Qwen2DecoderLayer, self).__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
self.self_attn = Qwen2Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
rope_scaling=rope_scaling,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use FeedForward instead of MLP
|
||||
'''
|
||||
self.mlp = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act='silu',
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=True,
|
||||
down_proj_name='down_proj',
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: prepare to perf per-tensor sq cases if suitable
|
||||
'''
|
||||
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
|
||||
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
|
||||
self.mlp.gate_up_proj.quant_method.skip_quant_input = True
|
||||
self.quant_fusion_attn_layernorm = None
|
||||
self.quant_fusion_mlp_layernorm = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen2__Qwen2DecoderLayer__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf model by:
|
||||
1) add residual in matmul;
|
||||
2) fuse quantization in layernorm in per-tensor sq case;
|
||||
'''
|
||||
attn_layernorm = self.input_layernorm
|
||||
mlp_layernorm = self.post_attention_layernorm
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
if self.is_per_token_sq_perf_cases:
|
||||
attn_quant_scale = self.self_attn.qkv_proj.smooth
|
||||
mlp_quant_scale = self.mlp.gate_up_proj.smooth
|
||||
else:
|
||||
attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
|
||||
mlp_quant_scale = self.mlp.gate_up_proj.scale_to_int
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.input_layernorm, attn_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.post_attention_layernorm, mlp_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
mlp_layernorm = self.quant_fusion_mlp_layernorm
|
||||
|
||||
return decoder_layer_forward_base(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
input_layernorm=attn_layernorm,
|
||||
self_attn=self.self_attn,
|
||||
post_layernorm=mlp_layernorm,
|
||||
mlp=self.mlp,
|
||||
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
|
||||
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen2__Qwen2Model__forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
return decoder_model_forward_base_pp(input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
layers=self.layers,
|
||||
start_layer=self.start_layer,
|
||||
end_layer=self.end_layer,
|
||||
get_input_embeddings=self.embed_tokens,
|
||||
norm=self.norm,
|
||||
inputs_embeds=inputs_embeds)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(Qwen2Attention,
|
||||
Qwen2Attention.forward,
|
||||
vllm__module_executor__models__qwen2__Qwen2Attention__forward)
|
||||
MluHijackObject.apply_hijack(Qwen2DecoderLayer,
|
||||
Qwen2DecoderLayer.__init__,
|
||||
vllm__module_executor__models__qwen2__Qwen2DecoderLayer____init__)
|
||||
MluHijackObject.apply_hijack(Qwen2DecoderLayer,
|
||||
Qwen2DecoderLayer.forward,
|
||||
vllm__module_executor__models__qwen2__Qwen2DecoderLayer__forward)
|
||||
MluHijackObject.apply_hijack(Qwen2Model,
|
||||
Qwen2Model.forward,
|
||||
vllm__module_executor__models__qwen2__Qwen2Model__forward)
|
||||
449
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/qwen2_moe.py
Normal file
449
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/qwen2_moe.py
Normal file
@@ -0,0 +1,449 @@
|
||||
import torch
|
||||
import re
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Optional, Tuple, Iterable
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.models.qwen2_moe import Qwen2MoeAttention, Qwen2MoeDecoderLayer, Qwen2MoeForCausalLM, Qwen2MoeModel
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
from vllm.utils import print_warning_once
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
decoder_layer_forward_base, decoder_model_forward_base_pp,
|
||||
is_per_tensor_smoothquant, is_per_token_smoothquant,
|
||||
quant_fusion_with_rmsnorm)
|
||||
|
||||
|
||||
class Qwen2MoeSparseMoeBlock(SparseMoeMlp):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__(num_experts=config.num_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
up_proj_name="gate_up_proj",
|
||||
is_gated=True,
|
||||
down_proj_name="down_proj",
|
||||
has_bias=False,
|
||||
skip_bias_add=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
hidden_act=config.hidden_act,
|
||||
params_dtype=None,
|
||||
quant_config=quant_config,
|
||||
is_use_fused_moe=True)
|
||||
self.config = config
|
||||
self.shared_expert = None
|
||||
self.shared_expert_gate = None
|
||||
if config.shared_expert_intermediate_size > 0:
|
||||
self.shared_expert = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.shared_expert_intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=True,
|
||||
down_proj_name='down_proj',
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False)
|
||||
self.shared_expert_gate = ReplicatedLinear(config.hidden_size,
|
||||
1,
|
||||
bias=False,
|
||||
params_dtype=self.params_dtype,
|
||||
quant_config=None)
|
||||
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
shared_output = None
|
||||
if self.shared_expert is not None:
|
||||
shared_output = self.shared_expert(hidden_states)
|
||||
if self.shared_expert_gate is not None:
|
||||
gate_output = self.shared_expert_gate(hidden_states)
|
||||
shared_output = F.sigmoid(gate_output[0]) * shared_output
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.forward_experts(hidden_states, router_logits, residual)
|
||||
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen2moe__Qwen2MoeAttention__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
|
||||
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual
|
||||
'''
|
||||
output, _ = self.o_proj(attn_output, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen2moe__Qwen2MoeDecoderLayer____init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super(Qwen2MoeDecoderLayer, self).__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = Qwen2MoeAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
|
||||
# `mlp_only_layers` in the config.
|
||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
||||
config.mlp_only_layers)
|
||||
if (layer_idx not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and
|
||||
(layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||
self.mlp = Qwen2MoeSparseMoeBlock(config=config,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use FeedForward instead of MLP
|
||||
'''
|
||||
self.mlp = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=True,
|
||||
down_proj_name='down_proj',
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: prepare to perf per-tensor sq cases if suitable. For moe
|
||||
model, we only do quant fusion in attn block.
|
||||
'''
|
||||
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
|
||||
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
|
||||
self.quant_fusion_attn_layernorm = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen2moe__Qwen2MoeForCausalLM__load_weights(
|
||||
self,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack params and cal start expert id
|
||||
'''
|
||||
for name, m in self.model.named_modules():
|
||||
if isinstance(m, SparseMoeMlp):
|
||||
m.pack_params()
|
||||
|
||||
start_expert_id = 0
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: delete expert_params_mapping for no useless
|
||||
'''
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace expert_id in weight to named_expert_id in params_dict
|
||||
'''
|
||||
if start_expert_id > 0 and "mlp.experts." in name:
|
||||
expert_str = re.search(r'experts\.\d+', name).group(0)
|
||||
expert_id=int(expert_str.split(".")[1])
|
||||
named_expert_id = expert_id - start_expert_id
|
||||
old_expert_name = f"experts.{expert_id}"
|
||||
new_expert_name = f"experts.{named_expert_id}"
|
||||
name = name.replace(old_expert_name, new_expert_name)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: delete if "mlp.experts" in name: continue condition
|
||||
'''
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add expert skiped condition and delete useless if name not in params_dict: continue condition
|
||||
'''
|
||||
# Skip experts that are not assigned to this worker.
|
||||
if (("mlp.experts." in name or "mlp.shared_expert." in name or "mlp.shared_expert_gate." in name)
|
||||
and name not in params_dict):
|
||||
continue
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: delete for mapping in expert_params_mapping condition
|
||||
'''
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
if name.endswith("kv_scale"):
|
||||
remapped_kv_scale_name = name.replace(
|
||||
".kv_scale", ".attn.kv_scale")
|
||||
if remapped_kv_scale_name not in params_dict:
|
||||
print_warning_once(
|
||||
"Found kv scale in the checkpoint "
|
||||
f"(e.g. {name}), but not found the expected "
|
||||
f"name in the model "
|
||||
f"(e.g. {remapped_kv_scale_name}). "
|
||||
"kv-scale is not loaded.")
|
||||
continue
|
||||
else:
|
||||
name = remapped_kv_scale_name
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add expert skiped condition
|
||||
'''
|
||||
# Skip experts that are not assigned to this worker.
|
||||
if (("mlp.experts." in name or "mlp.shared_expert." in name or "mlp.shared_expert_gate." in name)
|
||||
and name not in params_dict):
|
||||
continue
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen2moe__Qwen2MoeDecoderLayer__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf model by:
|
||||
1) add residual in matmul;
|
||||
2) fuse quantization in layernorm in per-tensor sq case;
|
||||
'''
|
||||
attn_layernorm = self.input_layernorm
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
if self.is_per_token_sq_perf_cases:
|
||||
attn_quant_scale = self.self_attn.qkv_proj.smooth
|
||||
else:
|
||||
attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.input_layernorm, attn_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
|
||||
return decoder_layer_forward_base(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
input_layernorm=attn_layernorm,
|
||||
self_attn=self.self_attn,
|
||||
post_layernorm=self.post_attention_layernorm,
|
||||
mlp=self.mlp,
|
||||
input_norm_fuse_en=self.is_per_token_sq_perf_cases)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen2moe__Qwen2MoeModel__forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
return decoder_model_forward_base_pp(input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
layers=self.layers,
|
||||
start_layer=self.start_layer,
|
||||
end_layer=self.end_layer,
|
||||
get_input_embeddings=self.embed_tokens,
|
||||
norm=self.norm)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(Qwen2MoeAttention,
|
||||
Qwen2MoeAttention.forward,
|
||||
vllm__module_executor__models__qwen2moe__Qwen2MoeAttention__forward)
|
||||
MluHijackObject.apply_hijack(Qwen2MoeDecoderLayer,
|
||||
Qwen2MoeDecoderLayer.__init__,
|
||||
vllm__module_executor__models__qwen2moe__Qwen2MoeDecoderLayer____init__)
|
||||
MluHijackObject.apply_hijack(Qwen2MoeForCausalLM,
|
||||
Qwen2MoeForCausalLM.load_weights,
|
||||
vllm__module_executor__models__qwen2moe__Qwen2MoeForCausalLM__load_weights)
|
||||
MluHijackObject.apply_hijack(Qwen2MoeDecoderLayer,
|
||||
Qwen2MoeDecoderLayer.forward,
|
||||
vllm__module_executor__models__qwen2moe__Qwen2MoeDecoderLayer__forward)
|
||||
MluHijackObject.apply_hijack(Qwen2MoeModel,
|
||||
Qwen2MoeModel.forward,
|
||||
vllm__module_executor__models__qwen2moe__Qwen2MoeModel__forward)
|
||||
272
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/qwen2_vl.py
Normal file
272
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/qwen2_vl.py
Normal file
@@ -0,0 +1,272 @@
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
from vllm.model_executor.models.qwen2_vl import (
|
||||
Qwen2VisionMLP, Qwen2VisionTransformer, Qwen2VisionAttention, Qwen2VLForConditionalGeneration)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
def vllm__module_executor__models__qwen2_vl__Qwen2VisionTransformer__forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor
|
||||
):
|
||||
# patchify
|
||||
x = x.to(device=self.device, dtype=self.dtype)
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# compute position embedding
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
'''
|
||||
# compute cos sin for apply_rope
|
||||
cos = rotary_pos_emb.cos()
|
||||
sin = rotary_pos_emb.sin()
|
||||
cos = repeat(cos, "... d -> ... (2 d)")
|
||||
sin = repeat(sin, "... d -> ... (2 d)")
|
||||
rotary_pos_emb.cos = cos
|
||||
rotary_pos_emb.sin = sin
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
# compute cu_seqlens
|
||||
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
|
||||
grid_thw[:, 0]).cumsum(
|
||||
dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
||||
|
||||
# transformers
|
||||
x = x.unsqueeze(1)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
|
||||
|
||||
# adapter
|
||||
x = self.merger(x)
|
||||
return x
|
||||
|
||||
def vllm__module_executor__models__qwen2_vl__Qwen2VisionAttention__forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor
|
||||
):
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
|
||||
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
|
||||
new_x_shape = x.size()[:-1] + (
|
||||
self.num_attention_heads_per_partition,
|
||||
3 * self.hidden_size_per_attention_head,
|
||||
)
|
||||
x = x.view(*new_x_shape)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: apply mlu_ops.apply_rotary
|
||||
'''
|
||||
# [s, b, head, 3 * head_dim] --> 3 [b, s, head, head_dim]
|
||||
batch_size = x.shape[1]
|
||||
x = rearrange(x, "s b ... -> b s ...")
|
||||
head_dim = x.shape[-1] // 3
|
||||
q, k, v = x.split([head_dim] * 3, dim=-1)
|
||||
|
||||
if rotary_pos_emb is not None:
|
||||
sin = rotary_pos_emb.sin
|
||||
cos = rotary_pos_emb.cos
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
q = q.float()
|
||||
q = mlu_ops.rotary_embedding(
|
||||
q, sin, cos, None, None, False, False, False, q.shape[1]
|
||||
)
|
||||
k = k.float()
|
||||
k = mlu_ops.rotary_embedding(
|
||||
k, sin, cos, None, None, False, False, False, k.shape[1]
|
||||
)
|
||||
q = q.type_as(v)
|
||||
k = k.type_as(v)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
# from vllm_flash_attn.flash_attn_interface import (
|
||||
# flash_attn_varlen_func)
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
output = flash_attn_varlen_func(q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
dropout_p=0,
|
||||
causal=False)
|
||||
|
||||
context_layer = rearrange(output,
|
||||
"(b s) ... -> b s ...",
|
||||
b=batch_size)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
seq_length = q.size(1)
|
||||
q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
|
||||
attention_mask = torch.zeros([1, seq_length, seq_length],
|
||||
device=q.device,
|
||||
dtype=torch.bool)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
|
||||
cu_seqlens[i - 1]:cu_seqlens[i]] = True
|
||||
output = F.scaled_dot_product_attention(q,
|
||||
k,
|
||||
v,
|
||||
attention_mask,
|
||||
dropout_p=0.0)
|
||||
context_layer = rearrange(output, "b h s d -> b s h d ")
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
|
||||
kv_seqlen=None)
|
||||
|
||||
context_layer = xops.memory_efficient_attention_forward(
|
||||
q, k, v, attn_bias=attn_bias, p=0, scale=None)
|
||||
elif self.attn_backend == _Backend.MLU_FLASH_ATTN:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: apply mlu_ops.flash_attention
|
||||
'''
|
||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
output = mlu_ops.flash_attention(q,
|
||||
k,
|
||||
v,
|
||||
out=None,
|
||||
cu_seq_lens_q=cu_seqlens,
|
||||
cu_seq_lens_kv=cu_seqlens,
|
||||
max_seq_len_q=max_seqlen,
|
||||
max_seq_len_kv=max_seqlen,
|
||||
alibi_slope=None,
|
||||
attn_bias=None,
|
||||
softmax_scale=head_dim ** -0.5,
|
||||
is_causal=False)
|
||||
context_layer = rearrange(output,
|
||||
"(b s) ... -> b s ...",
|
||||
b=batch_size)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
context_layer = rearrange(context_layer,
|
||||
"b s h d -> s b (h d)").contiguous()
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
return output
|
||||
|
||||
vllm__module_executor__models__qwen2_vl__Qwen2VisionAttention__init_org = Qwen2VisionAttention.__init__
|
||||
|
||||
def vllm__module_executor__models__qwen2_vl__Qwen2VisionAttention____init__(
|
||||
self,
|
||||
embed_dim: Optional[int] = None,
|
||||
num_heads: Optional[int] = None,
|
||||
projection_size: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
vllm__module_executor__models__qwen2_vl__Qwen2VisionAttention__init_org(
|
||||
self, embed_dim, num_heads, projection_size, quant_config, prefix)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use mlu_ops.flash_atten for better performance
|
||||
'''
|
||||
self.attn_backend = _Backend.MLU_FLASH_ATTN
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def vllm__module_executor__models__qwen2_vl___maybe_ignore_quant_config(
|
||||
self,
|
||||
quant_config: QuantizationConfig
|
||||
):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: quantization for vit not yet supported
|
||||
'''
|
||||
if quant_config is not None:
|
||||
logger.warning("Quantization for VisionTransformer not yet supported.")
|
||||
return None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def vllm_module_executor__models__qwen2_vl__Qwen2VisionMLP__forward(
|
||||
self,
|
||||
x: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
x_parallel, _ = self.fc1(x)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: better acc than mlu_ops.active for half precision
|
||||
'''
|
||||
if x_parallel.dtype == torch.half and isinstance(self.act, QuickGELU):
|
||||
x_parallel = self.act.forward_native(x_parallel)
|
||||
else:
|
||||
x_parallel = self.act(x_parallel)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
x, _ = self.fc2(x_parallel)
|
||||
return x
|
||||
|
||||
MluHijackObject.apply_hijack(Qwen2VisionTransformer,
|
||||
Qwen2VisionTransformer.forward,
|
||||
vllm__module_executor__models__qwen2_vl__Qwen2VisionTransformer__forward)
|
||||
MluHijackObject.apply_hijack(Qwen2VisionAttention,
|
||||
Qwen2VisionAttention.forward,
|
||||
vllm__module_executor__models__qwen2_vl__Qwen2VisionAttention__forward)
|
||||
MluHijackObject.apply_hijack(Qwen2VisionAttention,
|
||||
Qwen2VisionAttention.__init__,
|
||||
vllm__module_executor__models__qwen2_vl__Qwen2VisionAttention____init__)
|
||||
MluHijackObject.apply_hijack(Qwen2VLForConditionalGeneration,
|
||||
Qwen2VLForConditionalGeneration._maybe_ignore_quant_config,
|
||||
vllm__module_executor__models__qwen2_vl___maybe_ignore_quant_config)
|
||||
MluHijackObject.apply_hijack(Qwen2VisionMLP,
|
||||
Qwen2VisionMLP.forward,
|
||||
vllm_module_executor__models__qwen2_vl__Qwen2VisionMLP__forward)
|
||||
264
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/qwen3.py
Normal file
264
vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/qwen3.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import torch
|
||||
|
||||
from typing import List, Optional
|
||||
from transformers import Qwen2Config as Qwen3Config
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.models.qwen3 import Qwen3Attention, Qwen3DecoderLayer, Qwen3Model
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
decoder_layer_forward_base, decoder_model_forward_base_pp,
|
||||
is_per_tensor_smoothquant, is_per_token_smoothquant,
|
||||
quant_fusion_with_rmsnorm)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen3__Qwen3Attention__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu for Qwen3
|
||||
=============================
|
||||
@brief: Apply QK normalization (Qwen3 specific)
|
||||
Reference: Qwen2 MLU implementation style
|
||||
'''
|
||||
# Make q and k contiguous before reshape/view operations
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
|
||||
# Apply QK normalization before rotary embedding
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
q_by_head = q.view(*q_shape[:-1], self.num_heads, self.head_dim)
|
||||
q_by_head = self.q_norm(q_by_head)
|
||||
q = q_by_head.view(q_shape)
|
||||
|
||||
k_by_head = k.view(*k_shape[:-1], self.num_kv_heads, self.head_dim)
|
||||
k_by_head = self.k_norm(k_by_head)
|
||||
k = k_by_head.view(k_shape)
|
||||
'''
|
||||
==================
|
||||
End of QK Norm
|
||||
==================
|
||||
'''
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: pack q & k to fit tmo.apply_rotary
|
||||
Reference: Qwen2 MLU implementation
|
||||
'''
|
||||
# Pack q and k for MLU rotary embedding optimization
|
||||
# Ensure qk is contiguous for view operation
|
||||
qk = torch.cat([q, k], dim=-1).contiguous()
|
||||
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
|
||||
# Split back after rotary
|
||||
q, k = qk.split([self.q_size, self.kv_size], dim=-1)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual
|
||||
'''
|
||||
output, _ = self.o_proj(attn_output, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen3__Qwen3DecoderLayer____init__(
|
||||
self,
|
||||
config: Qwen3Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super(Qwen3DecoderLayer, self).__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
self.self_attn = Qwen3Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
rope_scaling=rope_scaling,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=getattr(config, "attention_bias", False),
|
||||
head_dim=getattr(config, "head_dim", None),
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use FeedForward instead of MLP
|
||||
'''
|
||||
self.mlp = FeedForward(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act='silu',
|
||||
up_proj_name='gate_up_proj',
|
||||
is_gated=True,
|
||||
down_proj_name='down_proj',
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: prepare to perf per-tensor sq cases if suitable
|
||||
'''
|
||||
self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(quant_config)
|
||||
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(quant_config)
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
|
||||
self.mlp.gate_up_proj.quant_method.skip_quant_input = True
|
||||
self.quant_fusion_attn_layernorm = None
|
||||
self.quant_fusion_mlp_layernorm = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen3__Qwen3DecoderLayer__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: perf model by:
|
||||
1) add residual in matmul;
|
||||
2) fuse quantization in layernorm in per-tensor sq case;
|
||||
'''
|
||||
attn_layernorm = self.input_layernorm
|
||||
mlp_layernorm = self.post_attention_layernorm
|
||||
if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
|
||||
if self.quant_fusion_attn_layernorm is None:
|
||||
if self.is_per_token_sq_perf_cases:
|
||||
attn_quant_scale = self.self_attn.qkv_proj.smooth
|
||||
mlp_quant_scale = self.mlp.gate_up_proj.smooth
|
||||
else:
|
||||
attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
|
||||
mlp_quant_scale = self.mlp.gate_up_proj.scale_to_int
|
||||
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.input_layernorm, attn_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
self.quant_fusion_mlp_layernorm = quant_fusion_with_rmsnorm(
|
||||
self.post_attention_layernorm, mlp_quant_scale,
|
||||
dynamic_quant=self.is_per_token_sq_perf_cases)
|
||||
attn_layernorm = self.quant_fusion_attn_layernorm
|
||||
mlp_layernorm = self.quant_fusion_mlp_layernorm
|
||||
|
||||
return decoder_layer_forward_base(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
input_layernorm=attn_layernorm,
|
||||
self_attn=self.self_attn,
|
||||
post_layernorm=mlp_layernorm,
|
||||
mlp=self.mlp,
|
||||
input_norm_fuse_en=self.is_per_token_sq_perf_cases,
|
||||
post_norm_fuse_en=self.is_per_token_sq_perf_cases)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen3__Qwen3Model__forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
return decoder_model_forward_base_pp(input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
layers=self.layers,
|
||||
start_layer=self.start_layer,
|
||||
end_layer=self.end_layer,
|
||||
get_input_embeddings=self.embed_tokens,
|
||||
norm=self.norm,
|
||||
inputs_embeds=inputs_embeds)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
# Apply hijacks
|
||||
MluHijackObject.apply_hijack(Qwen3Attention,
|
||||
Qwen3Attention.forward,
|
||||
vllm__module_executor__models__qwen3__Qwen3Attention__forward)
|
||||
MluHijackObject.apply_hijack(Qwen3DecoderLayer,
|
||||
Qwen3DecoderLayer.__init__,
|
||||
vllm__module_executor__models__qwen3__Qwen3DecoderLayer____init__)
|
||||
MluHijackObject.apply_hijack(Qwen3DecoderLayer,
|
||||
Qwen3DecoderLayer.forward,
|
||||
vllm__module_executor__models__qwen3__Qwen3DecoderLayer__forward)
|
||||
MluHijackObject.apply_hijack(Qwen3Model,
|
||||
Qwen3Model.forward,
|
||||
vllm__module_executor__models__qwen3__Qwen3Model__forward)
|
||||
Reference in New Issue
Block a user