add qwen3

This commit is contained in:
Chranos
2026-02-04 17:22:39 +08:00
parent d1c0f68ab4
commit 8511fe8530
1932 changed files with 300426 additions and 0 deletions

View File

@@ -0,0 +1 @@
from . import model_executor

View File

@@ -0,0 +1,2 @@
from . import layers
from . import model_loader

View File

@@ -0,0 +1,2 @@
from . import feed_forward
from . import linear

View File

@@ -0,0 +1,98 @@
import torch
from typing import Optional
from vllm_mlu.mlu_hijack_utils import MluHijackObject, set_is_gated
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_group
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed import get_tensor_model_parallel_rank
from vllm import _mlu_ops as mlu_ops
from vllm.lora.layers import BaseLayerWithLoRA
from vllm_mlu._mlu_utils import *
def vllm_mlu__model_executor__layers__feed_forward__FeedForward__forward(
self,
hidden_states,
residual: Optional[torch.Tensor] = None
):
self.prepare_weight()
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
residual_ = None if self.tp_rank > 0 else residual
if (self.use_bt_ffn and not isinstance(up_proj, BaseLayerWithLoRA)
and not isinstance(down_proj, BaseLayerWithLoRA)):
# The matmul formula is the following:
# mul_out = alpha * (matmul(input, filter, transpose\_b=True) + bias) + beta * residual
# output = active(mul_out)
# Notes: We cannot use the activation function in matmul because it does not support gated operation
# we might support its in tmo matmul in the future
fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj.weight, up_proj.bias,
None, 'none', self.alpha, self.beta)
act_out = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
beta = 1.0 if residual_ is not None else 0.0
'''
=======================================
Modify by custom vllm_mlu
=======================================
@brief: call parallel op and abandon original reduce if parallel_num is set
'''
is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
if is_parallel_enable:
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
out_ = mlu_ops.matmul_allreduce(cncl_comm, act_out, down_proj.weight, None, residual_,
self.alpha, beta, self.parallel_num)
else:
out_ = mlu_ops.matmul(act_out, down_proj.weight, None, residual_, 'none', self.alpha, beta)
'''
=======================================
End of custom MLU Hijack
=======================================
'''
# bias if existed need to add after second matmul according to the original design of vllm
'''
=============================
Modify by custom vllm_mlu
=============================
@brief: when preload_size is set, call GroupCoordinator.all_reduce() directly and
use async_op to set all_reduce paralleled with preload
'''
if self.reduce_results and self.tp_size > 1 and not is_parallel_enable:
if hasattr(self, 'preload_size') and self.preload_size > 0 and not self.is_prompt:
handle = get_tp_group().all_reduce(out_, async_op=True)
_MB = 1 << 20
mlu_ops.preload(self.preloaded_weights[0].data, self.preload_size * _MB)
preloaded_weights_size = self.preloaded_weights[0].numel() * self.preloaded_weights[0].element_size()
if preloaded_weights_size < (self.preload_size * _MB) and len(self.preloaded_weights) > 1:
mlu_ops.preload(self.preloaded_weights[1].data, (self.preload_size * _MB) - preloaded_weights_size)
handle.wait()
out = out_
else:
out = tensor_model_parallel_all_reduce(out_)
else:
out = out_
'''
=========================
End of custom MLU Hijack
=========================
'''
# do the bias add if needed
if not self.skip_bias_add:
out = out + down_proj.bias if down_proj.bias is not None else out
else:
return out, down_proj.bias
else:
fc1, bias = up_proj(hidden_states)
if bias is not None:
fc1 += bias
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
out, bias = down_proj(fc1, residual=residual_)
if self.skip_bias_add:
return out, bias
return out
MluHijackObject.apply_hijack(FeedForward,
FeedForward.forward,
vllm_mlu__model_executor__layers__feed_forward__FeedForward__forward)

View File

@@ -0,0 +1,116 @@
from typing import Optional
import torch
from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_group
from vllm.distributed import get_tensor_model_parallel_rank, split_tensor_along_last_dim
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.layers.linear import UnquantizedLinearMethod, RowParallelLinear
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
def vllm__model_executor__layers__linear__UnquantizedLinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
beta = 1.0 if residual is not None else 0.0
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
'''
=====================================================
Modify by custom vllm_mlu
=====================================================
@brief: call parallel op if parallel_num is set
'''
if hasattr(self, 'parallel_num') and get_is_prompt():
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
return mlu_ops.matmul_allreduce(cncl_comm, x.view(-1, x.shape[-1]), layer.weight,
bias, residual, 1.0, beta, self.parallel_num).view(res_shape)
return mlu_ops.matmul(x.view(-1, x.shape[-1]), layer.weight, bias, residual, 'none', 1.0, beta).view(res_shape)
'''
=====================================================
End of custom MLU Hijack
=====================================================
'''
def vllm__model_executor__layers__linear__RowParallelLinear__forward(
self,
input_,
residual: Optional[torch.Tensor] = None
):
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
residual_ = None if self.tp_rank > 0 else residual
'''
=====================================================
Modify by custom vllm_mlu
=====================================================
@brief: abandon original reduce if parallel_num is set
'''
is_parallel_enable = hasattr(self.quant_method, 'parallel_num') and get_is_prompt()
'''
=====================================================
End of custom MLU Hijack
=====================================================
'''
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_,
residual=residual_)
'''
=============================
Modify by custom vllm_mlu
=============================
@brief: when preload_size is set, call GroupCoordinator.all_reduce() directly and
use async_op to set all_reduce paralleled with preload
'''
if self.reduce_results and self.tp_size > 1 and not is_parallel_enable:
if hasattr(self, 'preload_size') and self.preload_size > 0 and not self.is_prompt:
handle = get_tp_group().all_reduce(output_parallel, async_op=True)
_MB = 1 << 20
mlu_ops.preload(self.preloaded_weights[0].data, self.preload_size * _MB)
preloaded_weights_size = self.preloaded_weights[0].numel() * self.preloaded_weights[0].element_size()
if preloaded_weights_size < (self.preload_size * _MB) and len(self.preloaded_weights) > 1:
mlu_ops.preload(self.preloaded_weights[1].data, (self.preload_size * _MB) - preloaded_weights_size)
handle.wait()
output = output_parallel
else:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
'''
=========================
End of custom MLU Hijack
=========================
'''
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
MluHijackObject.undo_hijack(UnquantizedLinearMethod,
UnquantizedLinearMethod.apply)
MluHijackObject.apply_hijack(UnquantizedLinearMethod,
UnquantizedLinearMethod.apply,
vllm__model_executor__layers__linear__UnquantizedLinearMethod__apply)
MluHijackObject.undo_hijack(RowParallelLinear,
RowParallelLinear.forward)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.forward,
vllm__model_executor__layers__linear__RowParallelLinear__forward)

View File

@@ -0,0 +1 @@
from . import loader

View File

@@ -0,0 +1,143 @@
import os
import torch
from torch import nn
from typing import Optional
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.config import VllmConfig, ModelConfig, ParallelConfig
from vllm_mlu._mlu_utils import *
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
logger = init_logger(__name__)
def get_parallel_num(
model_config: ModelConfig,
parallel_config: ParallelConfig
):
attention_parallel_num = os.environ.get(ATTN_PARALLEL_NUM)
ffn_parallel_num = os.environ.get(FFN_PARALLEL_NUM)
if attention_parallel_num and attention_parallel_num.isdecimal():
attention_parallel_num = int(attention_parallel_num)
else:
attention_parallel_num = 0
if ffn_parallel_num and ffn_parallel_num.isdecimal():
ffn_parallel_num = int(ffn_parallel_num)
else:
ffn_parallel_num = 0
if parallel_config.tensor_parallel_size == 1:
raise ValueError("Can not use context_comm_cmpt_parallel when tp num is 1.")
if (attention_parallel_num <= 0 and ffn_parallel_num <= 0):
raise ValueError("attention_parallel_num and ffn_parallel_num must be positive integers.")
hidden_size = model_config.get_hidden_size()
ffn_parallel_num = max(ffn_parallel_num, 1)
if hidden_size % ffn_parallel_num != 0:
raise ValueError(f"Hidden_size: {hidden_size} must be divisible by ffn_parallel_num: {ffn_parallel_num}")
return attention_parallel_num, ffn_parallel_num
def get_attr_by_path(obj, path):
# Split the path by dots to get individual attributes
attributes = path.split('.')
# Iterate through the attributes to access nested members
for attr in attributes:
if not hasattr(obj, attr):
return None
obj = getattr(obj, attr)
return obj
def set_custom_attributes(model, model_config, parallel_config):
attn_row_parallel_layers = []
attn_weights = []
ffn_row_parallel_layers = []
ffn_weights = []
sparse_moe_mlp_layers = []
for module in model.modules():
if module.__class__.__name__ == "FeedForward":
ffn_weight = []
if hasattr(module, "up_proj_name"):
up_proj_name = getattr(module, "up_proj_name")
up_proj = getattr(module, up_proj_name)
if hasattr(up_proj, "weight"):
ffn_weight.append(up_proj.weight)
if hasattr(module, "down_proj_name"):
down_proj_name = getattr(module, "down_proj_name")
down_proj = getattr(module, down_proj_name)
if hasattr(down_proj, "weight"):
ffn_weight.append(down_proj.weight)
if ffn_weight is not None:
ffn_weights.append(ffn_weight)
ffn_row_parallel_layers.append(module)
for child_module in module.children():
if child_module.__class__.__name__ == "Attention":
for sibling_module in module.children():
if sibling_module.__class__.__name__ == "QKVParallelLinear":
if hasattr(sibling_module, "weight"):
weight = getattr(sibling_module, "weight")
attn_weights.append([weight])
if sibling_module.__class__.__name__ == "RowParallelLinear":
attn_row_parallel_layers.append(sibling_module)
if module.__class__.__name__ == "SparseMoeMlp" or issubclass(module.__class__, SparseMoeMlp):
sparse_moe_mlp_layers.append(module)
if VLLM_PRELOAD_SIZE > 0:
if (len(attn_row_parallel_layers) \
== len(attn_weights) \
== len(ffn_row_parallel_layers) \
== len(ffn_weights)) and \
len(attn_row_parallel_layers) != 0:
for i in range(len(attn_row_parallel_layers)):
attn_row_parallel_layers[i].preloaded_weights = ffn_weights[i]
attn_row_parallel_layers[i].preload_size = VLLM_PRELOAD_SIZE
if i < len(attn_row_parallel_layers) - 1:
ffn_row_parallel_layers[i].preloaded_weights = attn_weights[i+1]
ffn_row_parallel_layers[i].preload_size = VLLM_PRELOAD_SIZE
else:
logger.warning("%s does not support preload weight!", model.__class__.__name__)
# context compute communication parallel
if check_context_comm_cmpt_parallel():
attention_parallel_num, ffn_parallel_num = get_parallel_num(model_config, parallel_config)
for o_proj in attn_row_parallel_layers:
setattr(o_proj.quant_method, 'parallel_num', attention_parallel_num)
if len(sparse_moe_mlp_layers) != 0:
for sparse_moe_mlp in sparse_moe_mlp_layers:
setattr(sparse_moe_mlp, 'parallel_num', ffn_parallel_num)
else:
for ffn in ffn_row_parallel_layers:
setattr(ffn, 'parallel_num', ffn_parallel_num)
vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model__org = DefaultModelLoader.load_model
def vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model(
self, vllm_config: VllmConfig) -> nn.Module:
model = vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model__org(
self, vllm_config=vllm_config)
'''
=============================
Modify by custom vllm_mlu
=============================
@brief: According to the layer name in models, set custom optimize attributes.
'''
set_custom_attributes(model, vllm_config.model_config, vllm_config.parallel_config)
'''
=========================
End of custom MLU Hijack
=========================
'''
return model
MluHijackObject.apply_hijack(DefaultModelLoader,
DefaultModelLoader.load_model,
vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model)

View File

@@ -0,0 +1,17 @@
### 简介
该劫持代码实现了vllm Context通算并行功能。开启后可在部分数据规模和切分数量上对Context Latency指标有优化效果。目前是可选功能默认不开启。
### 开启方法
- 设置环境变量ATTN_PARALLEL_NUM和FFN_PARALLEL_NUM为正整数分别控制attention和ffn部分的通算并行切分数量。两个环境变量相互独立可以同时开启。例如输入export ATTN_PARALLEL_NUM=2 FFN_PARALLEL_NUM=4则表示两部分均开启并行attention数据拆分为2份ffn数据拆分为4份。
- 需要保证tensor_parallel_size大于1。
- 开启ffn部分的通算并行时需要保证hidden_size能被FFN_PARALLEL_NUM整除。
### 注意事项
- 开启通算并行功能时由于算子限制Mixtral系列模型、Qwen2包含Qwen1.5和Qwen2.5系列模型在smoothquant量化下只支持batch_size = 1且算子默认切分数为4ATTN_PARALLEL_NUM不生效。
- smoothquant量化下vllm_mlu ffn部分不调用tmo matmul算子该部分通算融合不生效。

View File

@@ -0,0 +1 @@
from . import model_executor

View File

@@ -0,0 +1,3 @@
from . import custom_model
from . import layers
from . import models

View File

@@ -0,0 +1,62 @@
import torch
from typing import Optional
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import tensor_model_parallel_all_reduce, get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm_mlu.model_executor.custom_model.custom import CustomMoeBlock
def vllm__module_executor__custom_model__CustomMoeBlock__forward(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
shared_output = F.sigmoid(
self.shared_expert_gate(hidden_states)) * shared_output
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
residual_ = None if self.rank > 0 else residual
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: call fused_moe
'''
params = [hidden_states, router_logits, self.w1, self.w2, None, None,
residual_, self.input_smooth, self.act_smooth, self.w1_scale, self.w2_scale,
self.top_k, self.config.norm_topk_prob, self.config.is_gated, self.config.hidden_act, 0]
if hasattr(self, 'parallel_num') and get_is_prompt():
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
params.extend([self.parallel_num, cncl_comm])
final_hidden_states = mlu_ops.fused_moe(*params)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
reduce_results = (self.config.use_parallel_residual == False)
if reduce_results:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
MluHijackObject.apply_hijack(CustomMoeBlock,
CustomMoeBlock.forward,
vllm__module_executor__custom_model__CustomMoeBlock__forward)

View File

@@ -0,0 +1,2 @@
from . import quantization
from . import sparse_moe_mlp

View File

@@ -0,0 +1,51 @@
import torch
from typing import Optional
from vllm import _mlu_ops as mlu_ops
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm_mlu._mlu_utils import get_is_prompt
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm_mlu__model_executor__layers__quantization__smoothquant__SmoothQuantLinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
quant_input = None
input_scale = None
if self.quant_config.input_quant_method == "per_token":
quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer.smooth, None)
if self.quant_config.input_quant_method == "per_tensor":
quant_input = x if self.skip_quant_input else mlu_ops.quantize(x, layer.scale_to_int, None)
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: call parallel op
'''
if hasattr(self, 'parallel_num') and get_is_prompt():
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
params = [cncl_comm, quant_input, input_scale, layer.qweight, layer.per_channel_scale,
self.compute_dtype, bias, residual, 1.0, 1.0, self.parallel_num]
out = mlu_ops.smooth_quant_matmul_allreduce(*params)
else:
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer.qweight,
layer.per_channel_scale, self.compute_dtype, bias, residual)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
return out
MluHijackObject.apply_hijack(SmoothQuantLinearMethod,
SmoothQuantLinearMethod.apply,
vllm_mlu__model_executor__layers__quantization__smoothquant__SmoothQuantLinearMethod__apply)

View File

@@ -0,0 +1,89 @@
"""Inference-only MOE model."""
import torch
from torch import nn
from typing import Optional
from vllm.distributed import tensor_model_parallel_all_reduce, get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
def vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
orig_hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# expert_logits: [num_tokens, self.num_experts_per_rank]
expert_logits, _ = self.gate(hidden_states)
final_hidden_states = self.forward_experts(hidden_states, expert_logits, residual)
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: disbale reduce if parallel op used
'''
is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
if self.tp_size > 1 and not is_parallel_enable:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
output = final_hidden_states.view(orig_hidden_states_shape)
return output
def vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward_experts(
self,
hidden_states,
expert_logits,
residual: Optional[torch.Tensor] = None
):
residual_ = None if self.tp_rank > 0 else residual
if self.is_use_fused_moe:
self.pack_params()
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: call fused_moe all_reduce
'''
is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
if is_parallel_enable:
residual_ = residual
params = [hidden_states, expert_logits, self.w13, self.w2, self.b13, self.b2,
residual_, self.a13_scale, self.a2_scale, self.w13_scale, self.w2_scale,
self.top_k, self.renormalize, self.is_gated, self.hidden_act, self.start_expert_id]
if is_parallel_enable:
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
params.extend([self.parallel_num, cncl_comm])
final_hidden_states = mlu_ops.fused_moe(*params)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
else:
final_hidden_states = self.forward_experts_nofused(hidden_states, expert_logits)
if residual_ is not None:
final_hidden_states = final_hidden_states + residual_
return final_hidden_states
MluHijackObject.apply_hijack(SparseMoeMlp,
SparseMoeMlp.forward,
vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward)
MluHijackObject.apply_hijack(SparseMoeMlp,
SparseMoeMlp.forward_experts,
vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward_experts)

View File

@@ -0,0 +1,3 @@
from . import mixtral_quant
from . import qwen2
from . import qwen2_moe

View File

@@ -0,0 +1,299 @@
import torch
from typing import List, Optional
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm.model_executor.models.mixtral_quant import MixtralAttention
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionType)
from vllm.attention.backends.utils import get_num_prefill_decode_query_kv_tokens
from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.attention.backends.mlu_attn import (MLUFlashAttentionMetadata,
_get_query_key_seq_metadata,
_get_causal_option)
def vllm__model_executor__models__mixtral__MixtralAttention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: call flash_attn_sq_mm_allreduce to finish forward
'''
if (attn_metadata.prefill_metadata) and \
(kv_cache[0].numel() > 0) and \
(hasattr(self.o_proj, 'quant_method')) and \
(isinstance(self.o_proj.quant_method, SmoothQuantLinearMethod)) and \
(self.o_proj.quant_method.quant_config.input_quant_method == "per_token"):
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
return torch.ops.vllm.context_attn_comm_cmpt_parallel_flash_attention_v2(
q, k, v,
self.num_heads, self.head_dim, self.num_kv_heads,
kv_cache, self.attn.impl.kv_cache_dtype,
1.0, 1.0, self.scaling,
cncl_comm,
self.o_proj.smooth, self.o_proj.qweight,
self.o_proj.per_channel_scale.to(torch.float),
self.o_proj.quant_method.parallel_num,
residual, self.attn.impl.sliding_window, self.attn.impl.alibi_slopes
)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
MluHijackObject.apply_hijack(MixtralAttention,
MixtralAttention.forward,
vllm__model_executor__models__mixtral__MixtralAttention__forward)
def context_attn_comm_cmpt_parallel_flash_attention_v2(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: List[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
cncl_comm: int,
smooth: torch.Tensor,
qweight: torch.Tensor,
per_channel_scale: torch.Tensor,
parallel_num: int,
residual: Optional[torch.Tensor] = None,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, MLUFlashAttentionMetadata)
attn_metadata: MLUFlashAttentionMetadata = current_metadata
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
if (key is not None) and (key is not None):
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
kv_cache_, kv_cache_scale_ = kv_cache
key_cache = kv_cache_[0]
value_cache = kv_cache_[1]
key_cache_scale, value_cache_scale = None, None
if kv_cache_scale_.numel() > 0:
key_cache_scale = kv_cache_scale_[0]
value_cache_scale = kv_cache_scale_[1]
# if not specified in self.attn.forward params, use default DECODER
attn_type = AttentionType.DECODER
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the KV
# cache is already populated with the cross-attention tensor.
# Thus, we skip cache updates during this time.
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
if USE_PAGED:
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_paged_cache(key,
value,
key_cache,
value_cache,
key_cache_scale,
value_cache_scale,
attn_metadata.slot_mapping.flatten())
else:
mlu_ops.reshape_paged_cache(key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping.flatten())
else:
# FIXME: After TMO-1496 is completed, remove this code.
if key.stride() != value.stride():
key = key.contiguous()
value = value.contiguous()
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_linear_cache(key,
value,
key_cache,
value_cache,
key_cache_scale,
value_cache_scale,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, # packed
None, # context_seq_offset
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
else:
mlu_ops.reshape_linear_cache(key,
value,
key_cache,
value_cache,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, # packed
None, # context_seq_offset
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
decode_query = query[num_prefill_query_tokens:]
# QKV for prefill.
query = query[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens
alibi_slopes = None if alibi_slopes is None else \
alibi_slopes.repeat(attn_metadata.num_prefills, 1)
prefill_meta = attn_metadata.prefill_metadata
# Prompt run.
if (prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]
output = mlu_ops.flash_attn_sq_mm_allreduce(cncl_comm,
query, key, value,
q_seq_start_loc, k_seq_start_loc,
alibi_slopes, None,
smooth, qweight,
per_channel_scale, None,
q_seq_len, k_seq_len,
softmax_scale, _get_causal_option(attn_type),
-1 if window_size is None \
else window_size[0],
-1 if window_size is None \
else window_size[1],
torch.float, parallel_num)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
output = mlu_ops.flash_attn_sq_mm_allreduce(cncl_comm,
query, key_cache, value_cache,
prefill_meta.query_start_loc, prefill_meta.seq_start_loc,
alibi_slopes, None,
smooth, qweight,
per_channel_scale, None,
prefill_meta.max_query_len, max_seq_len,
softmax_scale, True,
-1 if window_size is None \
else window_size[0],
-1 if window_size is None \
else window_size[1],
torch.float, parallel_num)
# Add residual.
if residual is not None:
output = output + residual
return output
def context_attn_comm_cmpt_parallel_flash_attention_v2_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: List[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
cncl_comm: int,
smooth: torch.Tensor,
qweight: torch.Tensor,
per_channel_scale: torch.Tensor,
parallel_num: int,
residual: Optional[torch.Tensor] = None,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)
direct_register_custom_op(
op_name="context_attn_comm_cmpt_parallel_flash_attention_v2",
op_func=context_attn_comm_cmpt_parallel_flash_attention_v2,
mutates_args=["kv_cache"],
fake_impl=context_attn_comm_cmpt_parallel_flash_attention_v2_fake,
)

View File

@@ -0,0 +1,90 @@
import torch
from typing import Optional
from vllm.attention import AttentionMetadata
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm.model_executor.models.qwen2 import Qwen2Attention
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
def vllm__model_executor__models__qwen2__Qwen2Attention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert smooth_quant_scale is None
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: call flash_attn_sq_mm_allreduce to finish forward
'''
if (attn_metadata.prefill_metadata) and \
(kv_cache[0].numel() > 0) and \
(hasattr(self.o_proj, 'quant_method')) and \
(isinstance(self.o_proj.quant_method, SmoothQuantLinearMethod)) and \
(self.o_proj.quant_method.quant_config.input_quant_method == "per_token"):
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
return torch.ops.vllm.context_attn_comm_cmpt_parallel_flash_attention_v2(
q, k, v,
self.num_heads, self.head_dim, self.num_kv_heads,
kv_cache, self.attn.impl.kv_cache_dtype,
1.0, 1.0, self.scaling,
cncl_comm,
self.o_proj.smooth, self.o_proj.qweight,
self.o_proj.per_channel_scale.to(torch.float),
self.o_proj.quant_method.parallel_num,
residual, self.attn.impl.sliding_window, self.attn.impl.alibi_slopes
)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
MluHijackObject.undo_hijack(Qwen2Attention,
Qwen2Attention.forward)
MluHijackObject.apply_hijack(Qwen2Attention,
Qwen2Attention.forward,
vllm__model_executor__models__qwen2__Qwen2Attention__forward)

View File

@@ -0,0 +1,58 @@
import torch
import torch.nn.functional as F
from typing import Optional
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import get_tensor_model_parallel_rank, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm.attention import AttentionMetadata
from vllm_mlu.model_executor.models.qwen2_moe import Qwen2MoeSparseMoeBlock
def vllm_mlu__model_executor__models__qwen2_moe__Qwen2MoeSparseMoeBlock__forward(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
gate_output = self.shared_expert_gate(hidden_states)
shared_output = F.sigmoid(gate_output[0]) * shared_output
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.forward_experts(hidden_states, router_logits, residual)
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: disbale reduce if parallel op used
'''
is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
if self.tp_size > 1:
if is_parallel_enable:
shared_output = tensor_model_parallel_all_reduce(shared_output)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
else:
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
return final_hidden_states.view(num_tokens, hidden_dim)
MluHijackObject.apply_hijack(Qwen2MoeSparseMoeBlock,
Qwen2MoeSparseMoeBlock.forward,
vllm_mlu__model_executor__models__qwen2_moe__Qwen2MoeSparseMoeBlock__forward)

View File

@@ -0,0 +1,32 @@
### 简介
该劫持代码实现在vLLM的解码通信过程中预加载下一层的权重从而减少解码的延迟。
### 支持模型
仅支持以下模型不支持量化后的模型以及MOE模型。
- Baichuan
- Bloom
- ChatGLM
- Falcon
- GPTNeoX
- Llama
- Qwen
- Qwen2
### 支持板卡
300系列不支持其他系列支持。
### 使用方法
- 设置环境变量export VLLM_PRELOAD_SIZE=<PRELOAD_SIZE><PRELOAD_SIZE>表示预加载权重的大小单位MB。
- 参数设置参考在低带宽资源环境下对于模型Llama-65B不同batch_sized和preload_size对应的性能优化收益如下。
| batch\preload | 8 | 16 | 24 | 32 | 48 | 64 |
|:--------------:|:----:|:----:|:----:|:----:|:----:|:----:|
| 1 | 4.9% | 10.0%| 9.5% | 6.7% |-2.4% | -7.1%|
| 8 | 3.2% | 6.3% | 8.9% | 11.2%| 6.0% | 1.8% |
| 16 | 2.3% | 5.1% | 7.5% | 9.2% | 8.3% | 4.3% |
| 24 | 2.3% | 4.8% | 7.4% | 9.1% | 9.5% | 6.0% |
| 32 | 2.1% | 4.3% | 7.0% | 8.7% | 10.1%| 8.1% |

View File

@@ -0,0 +1 @@
from . import distributed

View File

@@ -0,0 +1 @@
from . import parallel_state

View File

@@ -0,0 +1,75 @@
import torch
from typing import Union
from vllm.distributed.parallel_state import GroupCoordinator, supports_custom_op
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__distributed__parallel_state__GroupCoordinator__all_reduce(
self, input_: torch.Tensor,
async_op: bool = False) -> Union[torch.distributed.Work, torch.Tensor]:
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, and then look up the group coordinator from
the group name, dispatch the all-reduce operation to the group
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we need to figure out if the op is
in-place or out-of-place ahead of time.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
if input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
return input_
if not supports_custom_op():
self._all_reduce_in_place(input_)
return input_
if self.tpu_communicator is not None and \
not self.tpu_communicator.disabled:
# TPU handles Dynamo with its own logic.
return self.tpu_communicator.all_reduce(input_)
if self.hpu_communicator is not None and \
not self.hpu_communicator.disabled:
return self.hpu_communicator.all_reduce(input_)
if self.xpu_communicator is not None and \
not self.xpu_communicator.disabled:
return self.xpu_communicator.all_reduce(input_)
if self.ca_comm is not None and \
not self.ca_comm.disabled and \
self.ca_comm.should_custom_ar(input_):
return torch.ops.vllm.outplace_all_reduce(
input_, group_name=self.unique_name)
else:
'''
=============================
Modify by custom vllm_mlu
=============================
@brief: use async all reduce when preload weights.
'''
handle = torch.distributed.all_reduce(input_, group=self.device_group,
async_op=async_op)
if async_op:
return handle
'''
==================
End of custom MLU Hijack
==================
'''
return input_
MluHijackObject.apply_hijack(GroupCoordinator,
GroupCoordinator.all_reduce,
vllm__distributed__parallel_state__GroupCoordinator__all_reduce)