forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
@@ -0,0 +1 @@
|
||||
from . import model_executor
|
||||
@@ -0,0 +1,2 @@
|
||||
from . import layers
|
||||
from . import model_loader
|
||||
@@ -0,0 +1,2 @@
|
||||
from . import feed_forward
|
||||
from . import linear
|
||||
@@ -0,0 +1,98 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject, set_is_gated
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_group
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm.lora.layers import BaseLayerWithLoRA
|
||||
from vllm_mlu._mlu_utils import *
|
||||
|
||||
|
||||
def vllm_mlu__model_executor__layers__feed_forward__FeedForward__forward(
    self,
    hidden_states,
    residual: Optional[torch.Tensor] = None
):
    """Replacement for ``FeedForward.forward`` with an optional fused matmul + all-reduce.

    Args:
        hidden_states: input activations; on the fused path they are viewed
            as ``(-1, self.hidden_size)``.
        residual: optional tensor handed to the down projection. It is
            dropped on ranks with ``self.tp_rank > 0`` (presumably so it is
            not accumulated once per rank by the reduction -- confirm).

    Returns:
        ``(out, bias)`` when ``self.skip_bias_add`` is set, otherwise ``out``.
    """
    self.prepare_weight()
    up_proj = getattr(self, self.up_proj_name)
    down_proj = getattr(self, self.down_proj_name)

    residual_ = residual
    if self.tp_rank > 0:
        residual_ = None

    use_fused_ffn = (self.use_bt_ffn
                     and not isinstance(up_proj, BaseLayerWithLoRA)
                     and not isinstance(down_proj, BaseLayerWithLoRA))

    if not use_fused_ffn:
        # Generic path: call the projection layers themselves.
        fc1, up_bias = up_proj(hidden_states)
        if up_bias is not None:
            fc1 += up_bias
        fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
        out, down_bias = down_proj(fc1, residual=residual_)
        if self.skip_bias_add:
            return out, down_bias
        return out

    # The matmul formula is the following:
    #   mul_out = alpha * (matmul(input, filter, transpose_b=True) + bias) + beta * residual
    #   output = active(mul_out)
    # Notes: the activation cannot be done inside matmul because matmul does
    # not support the gated operation; tmo matmul might support it in the future.
    flat_input = hidden_states.view(-1, self.hidden_size)
    fc1 = mlu_ops.matmul(flat_input, up_proj.weight, up_proj.bias,
                         None, 'none', self.alpha, self.beta)
    act_out = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
    beta = 0.0 if residual_ is None else 1.0
    '''
    =======================================
    Modify by custom vllm_mlu
    =======================================
    @brief: call parallel op and abandon original reduce if parallel_num is set
    '''
    is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
    if is_parallel_enable:
        tp_rank = get_tensor_model_parallel_rank()
        device_group = get_tensor_model_parallel_group().device_group
        cncl_comm = device_group._get_backend(torch.device("mlu")).get_cncl_comm(tp_rank)
        out_ = mlu_ops.matmul_allreduce(cncl_comm, act_out, down_proj.weight, None, residual_,
                                        self.alpha, beta, self.parallel_num)
    else:
        out_ = mlu_ops.matmul(act_out, down_proj.weight, None, residual_, 'none', self.alpha, beta)
    '''
    =======================================
    End of custom MLU Hijack
    =======================================
    '''
    # bias if existed need to add after second matmul according to the original design of vllm
    '''
    =============================
    Modify by custom vllm_mlu
    =============================
    @brief: when preload_size is set, call GroupCoordinator.all_reduce() directly and
            use async_op to set all_reduce paralleled with preload
    '''
    needs_reduce = self.reduce_results and self.tp_size > 1 and not is_parallel_enable
    if not needs_reduce:
        out = out_
    elif hasattr(self, 'preload_size') and self.preload_size > 0 and not self.is_prompt:
        # Start the all-reduce asynchronously and issue the weight preloads
        # before waiting on it.
        handle = get_tp_group().all_reduce(out_, async_op=True)
        _MB = 1 << 20
        preload_bytes = self.preload_size * _MB
        first_weight = self.preloaded_weights[0]
        mlu_ops.preload(first_weight.data, preload_bytes)
        first_weight_bytes = first_weight.numel() * first_weight.element_size()
        # Hand the rest of the byte budget to the second weight, if any.
        if first_weight_bytes < preload_bytes and len(self.preloaded_weights) > 1:
            mlu_ops.preload(self.preloaded_weights[1].data, preload_bytes - first_weight_bytes)
        handle.wait()
        out = out_
    else:
        out = tensor_model_parallel_all_reduce(out_)
    '''
    =========================
    End of custom MLU Hijack
    =========================
    '''
    # do the bias add if needed
    if self.skip_bias_add:
        return out, down_proj.bias
    if down_proj.bias is not None:
        out = out + down_proj.bias
    return out


MluHijackObject.apply_hijack(FeedForward,
                             FeedForward.forward,
                             vllm_mlu__model_executor__layers__feed_forward__FeedForward__forward)
|
||||
@@ -0,0 +1,116 @@
|
||||
from typing import Optional
|
||||
import torch
|
||||
from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_group
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, split_tensor_along_last_dim
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod, RowParallelLinear
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm_mlu._mlu_utils import *
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__UnquantizedLinearMethod__apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Replacement for ``UnquantizedLinearMethod.apply`` built on the MLU matmul ops.

    Args:
        layer: module whose ``weight`` is multiplied with ``x``.
        x: input tensor; flattened to 2-D over its last dimension for the op.
        bias: optional bias forwarded to the op.
        residual: optional residual forwarded to the op; its scale ``beta``
            is 1.0 when a residual is given and 0.0 otherwise.

    Returns:
        The op result viewed as ``x.shape[:-1] + (layer.weight.shape[0],)``.
    """
    residual_scale = 0.0 if residual is None else 1.0
    out_shape = x.shape[0:-1] + (layer.weight.shape[0], )
    flat_x = x.view(-1, x.shape[-1])
    '''
    =====================================================
    Modify by custom vllm_mlu
    =====================================================
    @brief: call parallel op if parallel_num is set
    '''
    if hasattr(self, 'parallel_num') and get_is_prompt():
        tp_rank = get_tensor_model_parallel_rank()
        device_group = get_tensor_model_parallel_group().device_group
        cncl_comm = device_group._get_backend(torch.device("mlu")).get_cncl_comm(tp_rank)
        result = mlu_ops.matmul_allreduce(cncl_comm, flat_x, layer.weight,
                                          bias, residual, 1.0, residual_scale, self.parallel_num)
    else:
        result = mlu_ops.matmul(flat_x, layer.weight, bias, residual, 'none', 1.0, residual_scale)
    '''
    =====================================================
    End of custom MLU Hijack
    =====================================================
    '''
    return result.view(out_shape)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__RowParallelLinear__forward(
    self,
    input_,
    residual: Optional[torch.Tensor] = None
):
    """Replacement for ``RowParallelLinear.forward``.

    The explicit all-reduce is skipped when the quant method carries a
    ``parallel_num`` attribute and ``get_is_prompt()`` is true; in that case
    the reduction is presumably done inside ``quant_method.apply`` -- confirm.

    Args:
        input_: input tensor; split along its last dimension across
            ``self.tp_size`` partitions unless ``self.input_is_parallel``.
        residual: optional residual, forwarded to the quant method only when
            ``self.tp_rank`` is not greater than 0.

    Returns:
        ``(output, output_bias)`` where ``output_bias`` is ``self.bias`` when
        ``self.skip_bias_add`` is set and ``None`` otherwise.
    """
    if self.input_is_parallel:
        local_input = input_
    else:
        this_rank = get_tensor_model_parallel_rank()
        shards = split_tensor_along_last_dim(
            input_, num_partitions=self.tp_size)
        local_input = shards[this_rank].contiguous()

    # Matrix multiply.
    assert self.quant_method is not None
    # Only fuse bias add into GEMM for rank 0 (this ensures that
    # bias will not get added more than once in TP>1 case)
    beyond_first_rank = self.tp_rank > 0
    bias_ = None if (beyond_first_rank or self.skip_bias_add) else self.bias
    residual_ = None if beyond_first_rank else residual
    '''
    =====================================================
    Modify by custom vllm_mlu
    =====================================================
    @brief: abandon original reduce if parallel_num is set
    '''
    is_parallel_enable = hasattr(self.quant_method, 'parallel_num') and get_is_prompt()
    '''
    =====================================================
    End of custom MLU Hijack
    =====================================================
    '''
    partial_output = self.quant_method.apply(self,
                                             local_input,
                                             bias=bias_,
                                             residual=residual_)
    '''
    =============================
    Modify by custom vllm_mlu
    =============================
    @brief: when preload_size is set, call GroupCoordinator.all_reduce() directly and
            use async_op to set all_reduce paralleled with preload
    '''
    needs_reduce = self.reduce_results and self.tp_size > 1 and not is_parallel_enable
    if not needs_reduce:
        output = partial_output
    elif hasattr(self, 'preload_size') and self.preload_size > 0 and not self.is_prompt:
        # Start the all-reduce asynchronously and issue the weight preloads
        # before waiting on it.
        handle = get_tp_group().all_reduce(partial_output, async_op=True)
        _MB = 1 << 20
        preload_bytes = self.preload_size * _MB
        first_weight = self.preloaded_weights[0]
        mlu_ops.preload(first_weight.data, preload_bytes)
        first_weight_bytes = first_weight.numel() * first_weight.element_size()
        # Hand the rest of the byte budget to the second weight, if any.
        if first_weight_bytes < preload_bytes and len(self.preloaded_weights) > 1:
            mlu_ops.preload(self.preloaded_weights[1].data, preload_bytes - first_weight_bytes)
        handle.wait()
        output = partial_output
    else:
        output = tensor_model_parallel_all_reduce(partial_output)
    '''
    =========================
    End of custom MLU Hijack
    =========================
    '''
    output_bias = self.bias if self.skip_bias_add else None

    return output, output_bias
|
||||
|
||||
|
||||
# Install the replacements defined in this module. Each target is first passed
# to undo_hijack (presumably to drop a hijack registered earlier -- confirm),
# and the attribute is looked up again afterwards for apply_hijack, so the
# statement order here matters.
MluHijackObject.undo_hijack(
    UnquantizedLinearMethod,
    UnquantizedLinearMethod.apply,
)
MluHijackObject.apply_hijack(
    UnquantizedLinearMethod,
    UnquantizedLinearMethod.apply,
    vllm__model_executor__layers__linear__UnquantizedLinearMethod__apply,
)
MluHijackObject.undo_hijack(
    RowParallelLinear,
    RowParallelLinear.forward,
)
MluHijackObject.apply_hijack(
    RowParallelLinear,
    RowParallelLinear.forward,
    vllm__model_executor__layers__linear__RowParallelLinear__forward,
)
|
||||
@@ -0,0 +1 @@
|
||||
from . import loader
|
||||
@@ -0,0 +1,143 @@
|
||||
import os
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Optional
|
||||
from vllm.logger import init_logger
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||
from vllm.config import VllmConfig, ModelConfig, ParallelConfig
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_parallel_num(
    model_config: ModelConfig,
    parallel_config: ParallelConfig
):
    """Read and validate the attention / FFN split counts from the environment.

    The counts come from the environment variables named by
    ``ATTN_PARALLEL_NUM`` and ``FFN_PARALLEL_NUM``. A value that is unset,
    empty, or not a decimal string counts as 0.

    Args:
        model_config: supplies ``get_hidden_size()`` for the divisibility check.
        parallel_config: supplies ``tensor_parallel_size``.

    Returns:
        ``(attention_parallel_num, ffn_parallel_num)``. The FFN count is
        raised to at least 1; the attention count may be 0.

    Raises:
        ValueError: if the tensor parallel size is 1, if neither count is
            positive, or if the hidden size is not divisible by the FFN count.
    """
    def _read_split_count(env_name):
        # Unset, empty or non-decimal values all fall back to 0.
        raw = os.environ.get(env_name)
        return int(raw) if raw and raw.isdecimal() else 0

    attention_parallel_num = _read_split_count(ATTN_PARALLEL_NUM)
    ffn_parallel_num = _read_split_count(FFN_PARALLEL_NUM)

    if parallel_config.tensor_parallel_size == 1:
        raise ValueError("Can not use context_comm_cmpt_parallel when tp num is 1.")
    # The two settings are independent: only reject the case where neither
    # one is enabled.
    if attention_parallel_num <= 0 and ffn_parallel_num <= 0:
        raise ValueError(
            "At least one of attention_parallel_num and ffn_parallel_num "
            "must be a positive integer.")

    hidden_size = model_config.get_hidden_size()
    ffn_parallel_num = max(ffn_parallel_num, 1)
    if hidden_size % ffn_parallel_num != 0:
        raise ValueError(f"Hidden_size: {hidden_size} must be divisible by ffn_parallel_num: {ffn_parallel_num}")

    return attention_parallel_num, ffn_parallel_num
|
||||
|
||||
|
||||
def get_attr_by_path(obj, path):
    """Follow a dot-separated attribute path starting at ``obj``.

    Args:
        obj: object to start from.
        path: attribute names joined by ``.``, e.g. ``"a.b.c"``.

    Returns:
        The attribute at the end of the path, or ``None`` as soon as one
        component is missing.
    """
    current = obj
    for name in path.split('.'):
        try:
            current = getattr(current, name)
        except AttributeError:
            # A missing component ends the walk.
            return None
    return current
|
||||
|
||||
|
||||
def set_custom_attributes(model, model_config, parallel_config):
    """Attach preload and context-parallel attributes to the layers of ``model``.

    Modules are matched by class name while walking ``model.modules()``:

    * ``FeedForward`` modules and the ``weight`` tensors of their up / down
      projections;
    * for every module with an ``Attention`` child, the ``weight`` of its
      ``QKVParallelLinear`` children and its ``RowParallelLinear`` children;
    * ``SparseMoeMlp`` modules (by name or by subclassing).

    When ``VLLM_PRELOAD_SIZE > 0`` and the collected lists line up, each
    attention row-parallel layer gets the FFN weights at the same index as
    ``preloaded_weights`` and each FFN except the last gets the next
    attention weights. When ``check_context_comm_cmpt_parallel()`` is true,
    ``parallel_num`` is set on the attention quant methods and on either the
    sparse MoE layers or the FFN layers.

    Args:
        model: the loaded model.
        model_config: forwarded to ``get_parallel_num``.
        parallel_config: forwarded to ``get_parallel_num``.
    """
    attn_row_parallel_layers = []
    attn_weights = []
    ffn_row_parallel_layers = []
    ffn_weights = []
    sparse_moe_mlp_layers = []

    for module in model.modules():
        module_type_name = module.__class__.__name__

        if module_type_name == "FeedForward":
            ffn_weight = []
            for name_attr in ("up_proj_name", "down_proj_name"):
                if hasattr(module, name_attr):
                    proj = getattr(module, getattr(module, name_attr))
                    if hasattr(proj, "weight"):
                        ffn_weight.append(proj.weight)
            # An empty list must not be registered: the forward hijacks index
            # preloaded_weights[0]. Skipping it makes the length check below
            # fail, which disables preload with a warning instead.
            if ffn_weight:
                ffn_weights.append(ffn_weight)
            ffn_row_parallel_layers.append(module)

        for child_module in module.children():
            if child_module.__class__.__name__ == "Attention":
                # The projections of interest are siblings of the Attention child.
                for sibling_module in module.children():
                    sibling_type_name = sibling_module.__class__.__name__
                    if sibling_type_name == "QKVParallelLinear":
                        if hasattr(sibling_module, "weight"):
                            attn_weights.append([sibling_module.weight])
                    if sibling_type_name == "RowParallelLinear":
                        attn_row_parallel_layers.append(sibling_module)

        if module_type_name == "SparseMoeMlp" or issubclass(module.__class__, SparseMoeMlp):
            sparse_moe_mlp_layers.append(module)

    if VLLM_PRELOAD_SIZE > 0:
        num_attn_layers = len(attn_row_parallel_layers)
        lists_line_up = (num_attn_layers
                         == len(attn_weights)
                         == len(ffn_row_parallel_layers)
                         == len(ffn_weights)) and num_attn_layers != 0
        if lists_line_up:
            for i, attn_layer in enumerate(attn_row_parallel_layers):
                attn_layer.preloaded_weights = ffn_weights[i]
                attn_layer.preload_size = VLLM_PRELOAD_SIZE
                # The last FFN has no following attention layer to preload.
                if i < num_attn_layers - 1:
                    ffn_row_parallel_layers[i].preloaded_weights = attn_weights[i + 1]
                    ffn_row_parallel_layers[i].preload_size = VLLM_PRELOAD_SIZE
        else:
            logger.warning("%s does not support preload weight!", model.__class__.__name__)

    # context compute communication parallel
    if check_context_comm_cmpt_parallel():
        attention_parallel_num, ffn_parallel_num = get_parallel_num(model_config, parallel_config)
        for o_proj in attn_row_parallel_layers:
            setattr(o_proj.quant_method, 'parallel_num', attention_parallel_num)

        # Sparse MoE layers, when present, take the FFN split count instead
        # of the plain FeedForward layers.
        ffn_targets = sparse_moe_mlp_layers if len(sparse_moe_mlp_layers) != 0 else ffn_row_parallel_layers
        for ffn_target in ffn_targets:
            setattr(ffn_target, 'parallel_num', ffn_parallel_num)
|
||||
|
||||
|
||||
# Keep a reference to the original loader so the replacement can delegate to it.
vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model__org = DefaultModelLoader.load_model


def vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model(
        self, vllm_config: VllmConfig) -> nn.Module:
    """Load the model with the original loader, then attach the custom attributes.

    Args:
        vllm_config: configuration forwarded unchanged to the original
            ``load_model``; its ``model_config`` and ``parallel_config`` are
            passed to ``set_custom_attributes``.

    Returns:
        The model returned by the original loader.
    """
    loaded_model = vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model__org(
        self, vllm_config=vllm_config)
    '''
    =============================
    Modify by custom vllm_mlu
    =============================
    @brief: According to the layer name in models, set custom optimize attributes.
    '''
    set_custom_attributes(
        loaded_model,
        vllm_config.model_config,
        vllm_config.parallel_config,
    )
    '''
    =========================
    End of custom MLU Hijack
    =========================
    '''
    return loaded_model


MluHijackObject.apply_hijack(
    DefaultModelLoader,
    DefaultModelLoader.load_model,
    vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model,
)
|
||||
@@ -0,0 +1,17 @@
|
||||
### 简介
|
||||
|
||||
该劫持代码实现了vllm Context通算并行功能。开启后可在部分数据规模和切分数量上对Context Latency指标有优化效果。目前是可选功能,默认不开启。
|
||||
|
||||
### 开启方法
|
||||
|
||||
- 设置环境变量ATTN_PARALLEL_NUM和FFN_PARALLEL_NUM为正整数,分别控制attention和ffn部分的通算并行切分数量。两个环境变量相互独立,可以同时开启。例如输入export ATTN_PARALLEL_NUM=2 FFN_PARALLEL_NUM=4,则表示两部分均开启并行,attention数据拆分为2份,ffn数据拆分为4份。
|
||||
|
||||
- 需要保证tensor_parallel_size大于1。
|
||||
|
||||
- 开启ffn部分的通算并行时,需要保证hidden_size能被FFN_PARALLEL_NUM整除。
|
||||
|
||||
### 注意事项
|
||||
|
||||
- 开启通算并行功能时,由于算子限制,Mixtral系列模型、Qwen2(包含Qwen1.5和Qwen2.5)系列模型在smoothquant量化下只支持batch_size = 1,且算子默认切分数为4,ATTN_PARALLEL_NUM不生效。
|
||||
|
||||
- smoothquant量化下,vllm_mlu ffn部分不调用tmo matmul算子,该部分通算融合不生效。
|
||||
@@ -0,0 +1 @@
|
||||
from . import model_executor
|
||||
@@ -0,0 +1,3 @@
|
||||
from . import custom_model
|
||||
from . import layers
|
||||
from . import models
|
||||
@@ -0,0 +1 @@
|
||||
from . import custom
|
||||
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce, get_tensor_model_parallel_rank
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
|
||||
from vllm_mlu.model_executor.custom_model.custom import CustomMoeBlock
|
||||
|
||||
|
||||
def vllm__module_executor__custom_model__CustomMoeBlock__forward(
    self,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Replacement for ``CustomMoeBlock.forward`` that can pass a CNCL communicator to ``fused_moe``.

    Args:
        hidden_states: 2-D tensor unpacked as ``(num_tokens, hidden_dim)``.
        residual: optional residual, passed to ``fused_moe`` only when
            ``self.rank`` is not greater than 0.

    Returns:
        The block output viewed as ``(num_tokens, hidden_dim)``.
    """
    num_tokens, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    shared_output = None
    if self.shared_expert is not None:
        shared_output = self.shared_expert(hidden_states)
        if self.shared_expert_gate is not None:
            # torch.sigmoid is used because no `F` alias is imported in this module.
            shared_output = torch.sigmoid(
                self.shared_expert_gate(hidden_states)) * shared_output

    # router_logits: (num_tokens, n_experts)
    router_logits, _ = self.gate(hidden_states)
    residual_ = None if self.rank > 0 else residual
    '''
    =====================================================
    Modify by Context Communication Computation Parallel
    =====================================================
    @brief: call fused_moe
    '''
    params = [hidden_states, router_logits, self.w1, self.w2, None, None,
              residual_, self.input_smooth, self.act_smooth, self.w1_scale, self.w2_scale,
              self.top_k, self.config.norm_topk_prob, self.config.is_gated, self.config.hidden_act, 0]
    if hasattr(self, 'parallel_num') and get_is_prompt():
        rank = get_tensor_model_parallel_rank()
        pg = get_tensor_model_parallel_group().device_group
        cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
        # The split count and communicator are appended as two extra
        # trailing positional arguments.
        params.extend([self.parallel_num, cncl_comm])
    final_hidden_states = mlu_ops.fused_moe(*params)
    '''
    =====================================================
    End of Context Communication Computation Parallel
    =====================================================
    '''

    if shared_output is not None:
        final_hidden_states = final_hidden_states + shared_output

    # NOTE(review): unlike the SparseMoeMlp hijack, this all-reduce is not
    # skipped when the parallel fused_moe path was taken -- confirm intended.
    reduce_results = (self.config.use_parallel_residual == False)
    if reduce_results:
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

    return final_hidden_states.view(num_tokens, hidden_dim)


MluHijackObject.apply_hijack(CustomMoeBlock,
                             CustomMoeBlock.forward,
                             vllm__module_executor__custom_model__CustomMoeBlock__forward)
|
||||
@@ -0,0 +1,2 @@
|
||||
from . import quantization
|
||||
from . import sparse_moe_mlp
|
||||
@@ -0,0 +1 @@
|
||||
from . import smoothquant
|
||||
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
|
||||
from vllm_mlu._mlu_utils import get_is_prompt
|
||||
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def vllm_mlu__model_executor__layers__quantization__smoothquant__SmoothQuantLinearMethod__apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Replacement for ``SmoothQuantLinearMethod.apply`` with an optional matmul + all-reduce op.

    Args:
        layer: module providing ``smooth`` / ``scale_to_int``, ``qweight``
            and ``per_channel_scale``.
        x: input tensor, quantized per ``self.quant_config.input_quant_method``.
        bias: optional bias forwarded to the matmul op.
        residual: optional residual forwarded to the matmul op.

    Returns:
        The output of the smooth-quant matmul op.
    """
    quant_input, input_scale = None, None
    quant_method = self.quant_config.input_quant_method
    if quant_method == "per_token":
        quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer.smooth, None)
    elif quant_method == "per_tensor":
        if self.skip_quant_input:
            quant_input = x
        else:
            quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)

    '''
    =====================================================
    Modify by Context Communication Computation Parallel
    =====================================================
    @brief: call parallel op
    '''
    if hasattr(self, 'parallel_num') and get_is_prompt():
        tp_rank = get_tensor_model_parallel_rank()
        device_group = get_tensor_model_parallel_group().device_group
        cncl_comm = device_group._get_backend(torch.device("mlu")).get_cncl_comm(tp_rank)
        out = mlu_ops.smooth_quant_matmul_allreduce(
            cncl_comm, quant_input, input_scale, layer.qweight, layer.per_channel_scale,
            self.compute_dtype, bias, residual, 1.0, 1.0, self.parallel_num)
    else:
        out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer.qweight,
                                          layer.per_channel_scale, self.compute_dtype, bias, residual)
    '''
    =====================================================
    End of Context Communication Computation Parallel
    =====================================================
    '''
    return out


MluHijackObject.apply_hijack(
    SmoothQuantLinearMethod,
    SmoothQuantLinearMethod.apply,
    vllm_mlu__model_executor__layers__quantization__smoothquant__SmoothQuantLinearMethod__apply,
)
|
||||
@@ -0,0 +1,89 @@
|
||||
"""Inference-only MOE model."""
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Optional
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce, get_tensor_model_parallel_rank
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
|
||||
|
||||
def vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward(
    self,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Replacement for ``SparseMoeMlp.forward`` that skips the all-reduce on the parallel path.

    Args:
        hidden_states: input tensor; viewed as ``(-1, self.hidden_size)`` for
            the gate and the experts.
        residual: optional residual forwarded to ``self.forward_experts``.

    Returns:
        The expert output viewed with the original shape of ``hidden_states``.
    """
    input_shape = hidden_states.shape
    flat_states = hidden_states.view(-1, self.hidden_size)
    # expert_logits: [num_tokens, self.num_experts_per_rank]
    expert_logits, _ = self.gate(flat_states)
    expert_output = self.forward_experts(flat_states, expert_logits, residual)

    '''
    =====================================================
    Modify by Context Communication Computation Parallel
    =====================================================
    @brief: disbale reduce if parallel op used
    '''
    is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
    if self.tp_size > 1 and not is_parallel_enable:
        expert_output = tensor_model_parallel_all_reduce(
            expert_output)
    '''
    =====================================================
    End of Context Communication Computation Parallel
    =====================================================
    '''

    return expert_output.view(input_shape)
|
||||
|
||||
|
||||
def vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward_experts(
    self,
    hidden_states,
    expert_logits,
    residual: Optional[torch.Tensor] = None
):
    """Replacement for ``SparseMoeMlp.forward_experts`` with an optional communicator for ``fused_moe``.

    Args:
        hidden_states: input activations passed to the experts.
        expert_logits: gate output for ``hidden_states``.
        residual: optional residual. It is dropped on ranks with
            ``self.tp_rank > 0`` unless the parallel fused path is taken, in
            which case every rank passes it through.

    Returns:
        The experts' output tensor.
    """
    residual_ = None if self.tp_rank > 0 else residual

    if not self.is_use_fused_moe:
        # Non-fused experts: the residual is added here.
        expert_output = self.forward_experts_nofused(hidden_states, expert_logits)
        if residual_ is not None:
            expert_output = expert_output + residual_
        return expert_output

    self.pack_params()
    '''
    =====================================================
    Modify by Context Communication Computation Parallel
    =====================================================
    @brief: call fused_moe all_reduce
    '''
    is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
    if is_parallel_enable:
        residual_ = residual
    fused_args = [hidden_states, expert_logits, self.w13, self.w2, self.b13, self.b2,
                  residual_, self.a13_scale, self.a2_scale, self.w13_scale, self.w2_scale,
                  self.top_k, self.renormalize, self.is_gated, self.hidden_act, self.start_expert_id]
    if is_parallel_enable:
        tp_rank = get_tensor_model_parallel_rank()
        device_group = get_tensor_model_parallel_group().device_group
        cncl_comm = device_group._get_backend(torch.device("mlu")).get_cncl_comm(tp_rank)
        # The split count and communicator are appended as two extra
        # trailing positional arguments.
        fused_args.extend([self.parallel_num, cncl_comm])
    expert_output = mlu_ops.fused_moe(*fused_args)
    '''
    =====================================================
    End of Context Communication Computation Parallel
    =====================================================
    '''
    return expert_output
|
||||
|
||||
|
||||
# Install both SparseMoeMlp replacements defined in this module.
MluHijackObject.apply_hijack(
    SparseMoeMlp,
    SparseMoeMlp.forward,
    vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward,
)
MluHijackObject.apply_hijack(
    SparseMoeMlp,
    SparseMoeMlp.forward_experts,
    vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward_experts,
)
|
||||
@@ -0,0 +1,3 @@
|
||||
from . import mixtral_quant
|
||||
from . import qwen2
|
||||
from . import qwen2_moe
|
||||
@@ -0,0 +1,299 @@
|
||||
import torch
|
||||
from typing import List, Optional
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
|
||||
from vllm.model_executor.models.mixtral_quant import MixtralAttention
|
||||
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
|
||||
from vllm.attention.backends.abstract import (AttentionMetadata,
|
||||
AttentionType)
|
||||
from vllm.attention.backends.utils import get_num_prefill_decode_query_kv_tokens
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.attention.backends.mlu_attn import (MLUFlashAttentionMetadata,
|
||||
_get_query_key_seq_metadata,
|
||||
_get_causal_option)
|
||||
|
||||
|
||||
def vllm__model_executor__models__mixtral__MixtralAttention__forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
    residual: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Replacement for ``MixtralAttention.forward`` with a fused prefill path.

    The fused custom op is used when there is prefill metadata, the first
    ``kv_cache`` entry is non-empty, and ``o_proj`` uses
    ``SmoothQuantLinearMethod`` with per-token input quantization. Otherwise
    the regular attention followed by ``o_proj`` is run.

    Args:
        positions: position ids passed to the rotary embedding.
        hidden_states: input of the QKV projection.
        kv_cache: KV cache passed to the attention implementation.
        attn_metadata: attention metadata; ``prefill_metadata`` is read here.
        residual: optional residual passed to the fused op or to ``o_proj``.

    Returns:
        The attention block output tensor.
    """
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: pack q & k to fit tmo.apply_rotary
    '''
    # The return value is not used; the rotary embedding presumably updates
    # the packed q/k view of `qkv` in place -- confirm.
    packed_qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
    self.rotary_emb(positions, packed_qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    '''
    =====================================================
    Modify by Context Communication Computation Parallel
    =====================================================
    @brief: call flash_attn_sq_mm_allreduce to finish forward
    '''
    use_fused_prefill = (
        attn_metadata.prefill_metadata
        and kv_cache[0].numel() > 0
        and hasattr(self.o_proj, 'quant_method')
        and isinstance(self.o_proj.quant_method, SmoothQuantLinearMethod)
        and self.o_proj.quant_method.quant_config.input_quant_method == "per_token"
    )
    if use_fused_prefill:
        tp_rank = get_tensor_model_parallel_rank()
        device_group = get_tensor_model_parallel_group().device_group
        cncl_comm = device_group._get_backend(torch.device("mlu")).get_cncl_comm(tp_rank)

        return torch.ops.vllm.context_attn_comm_cmpt_parallel_flash_attention_v2(
            q, k, v,
            self.num_heads, self.head_dim, self.num_kv_heads,
            kv_cache, self.attn.impl.kv_cache_dtype,
            1.0, 1.0, self.scaling,
            cncl_comm,
            self.o_proj.smooth, self.o_proj.qweight,
            self.o_proj.per_channel_scale.to(torch.float),
            self.o_proj.quant_method.parallel_num,
            residual, self.attn.impl.sliding_window, self.attn.impl.alibi_slopes
        )
    '''
    =====================================================
    End of Context Communication Computation Parallel
    =====================================================
    '''

    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add residual
    '''
    output, _ = self.o_proj(attn_output, residual)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return output


MluHijackObject.apply_hijack(
    MixtralAttention,
    MixtralAttention.forward,
    vllm__model_executor__models__mixtral__MixtralAttention__forward,
)
|
||||
|
||||
|
||||
def context_attn_comm_cmpt_parallel_flash_attention_v2(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: List[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    cncl_comm: int,
    smooth: torch.Tensor,
    qweight: torch.Tensor,
    per_channel_scale: torch.Tensor,
    parallel_num: int,
    residual: Optional[torch.Tensor] = None,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Write K/V into the cache and run prefill via flash_attn_sq_mm_allreduce.

    The attention metadata is taken from ``get_forward_context()`` and must be
    an ``MLUFlashAttentionMetadata``. The attention type is fixed to
    ``AttentionType.DECODER`` inside this function.

    Args:
        query: Query tensor; viewed as ``(-1, num_heads, head_size)``.
        key: Key tensor; viewed as ``(-1, num_kv_heads, head_size)`` when both
            ``key`` and ``value`` are given.
        value: Value tensor; viewed the same way as ``key``.
        num_heads: Number of query heads used for the reshape.
        head_size: Per-head size used for the reshape.
        num_kv_heads: Number of key/value heads used for the reshape.
        kv_cache: Pair ``(cache, cache_scale)``. Index 0/1 of ``cache`` are
            the key/value caches; index 0/1 of ``cache_scale`` are the
            corresponding scales when ``cache_scale`` is non-empty.
        kv_cache_dtype: ``'int8'`` selects the quantizing cache writers.
        k_scale: Accepted for signature compatibility; not read here.
        v_scale: Accepted for signature compatibility; not read here.
        softmax_scale: Forwarded to ``flash_attn_sq_mm_allreduce``.
        cncl_comm: Communicator handle forwarded to the fused op.
        smooth: Forwarded to the fused op (o_proj smooth-quant tensor at the
            visible call sites).
        qweight: Forwarded to the fused op (o_proj quantized weight at the
            visible call sites).
        per_channel_scale: Forwarded to the fused op.
        parallel_num: Forwarded to the fused op.
        residual: Optional tensor added to the fused op's output.
        window_size: Optional ``[left, right]`` window; ``-1`` is passed for
            both sides when it is ``None``.
        alibi_slopes: Optional slopes, repeated ``num_prefills`` times along
            dim 0 before use.
        logits_soft_cap: Accepted for signature compatibility; not read here.

    Returns:
        The fused op's output, plus ``residual`` when one is given.
    """
    current_metadata = get_forward_context()
    assert current_metadata is not None
    assert isinstance(current_metadata, MLUFlashAttentionMetadata)
    attn_metadata: MLUFlashAttentionMetadata = current_metadata

    # Reshape the query, key, and value tensors.
    query = query.view(-1, num_heads, head_size)
    # Both tensors must be present: the original guard tested ``key`` twice
    # and would call ``.view`` on a ``None`` value.
    if (key is not None) and (value is not None):
        key = key.view(-1, num_kv_heads, head_size)
        value = value.view(-1, num_kv_heads, head_size)

    kv_cache_, kv_cache_scale_ = kv_cache
    key_cache = kv_cache_[0]
    value_cache = kv_cache_[1]
    key_cache_scale, value_cache_scale = None, None
    if kv_cache_scale_.numel() > 0:
        key_cache_scale = kv_cache_scale_[0]
        value_cache_scale = kv_cache_scale_[1]

    # if not specified in self.attn.forward params, use default DECODER
    attn_type = AttentionType.DECODER

    # We skip updating the KV cache under two conditions:
    #  a. When the Attention Type is ENCODER. In this phase, we compute
    #     only the encoder attention without updating the cache.
    #  b. When both Key and Value are None. This occurs during
    #     cross-attention computation in the decoding phase, where the KV
    #     cache is already populated with the cross-attention tensor.
    #     Thus, we skip cache updates during this time.
    if (attn_type != AttentionType.ENCODER) and (key is not None) and (
            value is not None):
        if attn_type == AttentionType.ENCODER_DECODER:
            # Update cross-attention KV cache (prefill-only)
            updated_slot_mapping = attn_metadata.cross_slot_mapping
        else:
            # Update self-attention KV cache (prefill/decode)
            updated_slot_mapping = attn_metadata.slot_mapping

        # Reshape the input keys and values and store them in the cache.
        # If kv_cache is not provided, the new key and value tensors are
        # not cached. This happens during the initial memory profiling run.
        if USE_PAGED:
            # With attn_type fixed to DECODER this is the same tensor as
            # attn_metadata.slot_mapping, which the paged writers used before.
            flat_slot_mapping = updated_slot_mapping.flatten()
            if kv_cache_dtype == 'int8':
                mlu_ops.quant_to_paged_cache(key,
                                             value,
                                             key_cache,
                                             value_cache,
                                             key_cache_scale,
                                             value_cache_scale,
                                             flat_slot_mapping)
            else:
                mlu_ops.reshape_paged_cache(key,
                                            value,
                                            key_cache,
                                            value_cache,
                                            flat_slot_mapping)
        else:
            # FIXME: After TMO-1496 is completed, remove this code.
            if key.stride() != value.stride():
                key = key.contiguous()
                value = value.contiguous()
            if kv_cache_dtype == 'int8':
                mlu_ops.quant_to_linear_cache(key,
                                              value,
                                              key_cache,
                                              value_cache,
                                              key_cache_scale,
                                              value_cache_scale,
                                              attn_metadata.cu_seq_lens,
                                              attn_metadata.max_seq_len,
                                              True,  # packed
                                              None,  # context_seq_offset
                                              attn_metadata.batch_ids,
                                              attn_metadata.slot_mapping_unpaged)
            else:
                mlu_ops.reshape_linear_cache(key,
                                             value,
                                             key_cache,
                                             value_cache,
                                             attn_metadata.cu_seq_lens,
                                             attn_metadata.max_seq_len,
                                             True,  # packed
                                             None,  # context_seq_offset
                                             attn_metadata.batch_ids,
                                             attn_metadata.slot_mapping_unpaged)

    (num_prefill_query_tokens, num_prefill_kv_tokens,
     num_decode_query_tokens) = \
        get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
    decode_query = query[num_prefill_query_tokens:]
    # QKV for prefill.
    query = query[:num_prefill_query_tokens]
    assert query.shape[0] == num_prefill_query_tokens
    assert decode_query.shape[0] == num_decode_query_tokens

    alibi_slopes = None if alibi_slopes is None else \
        alibi_slopes.repeat(attn_metadata.num_prefills, 1)
    prefill_meta = attn_metadata.prefill_metadata

    # -1 is passed for both window sides when no window is configured.
    window_left = -1 if window_size is None else window_size[0]
    window_right = -1 if window_size is None else window_size[1]

    # Prompt run.
    if (prefill_meta.block_tables is None
            or prefill_meta.block_tables.numel() == 0):
        # normal attention
        # When block_tables are not filled, it means q and k are the
        # prompt, and they have the same length.
        q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
            _get_query_key_seq_metadata(prefill_meta, True, attn_type)

        key = key[:num_prefill_kv_tokens]
        value = value[:num_prefill_kv_tokens]

        output = mlu_ops.flash_attn_sq_mm_allreduce(
            cncl_comm,
            query, key, value,
            q_seq_start_loc, k_seq_start_loc,
            alibi_slopes, None,
            smooth, qweight,
            per_channel_scale, None,
            q_seq_len, k_seq_len,
            softmax_scale, _get_causal_option(attn_type),
            window_left,
            window_right,
            torch.float, parallel_num)
    else:
        # prefix-enabled attention
        assert attn_type == AttentionType.DECODER, (
            "Only decoder-only models support prefix caching")
        assert prefill_meta.seq_lens is not None
        max_seq_len = max(prefill_meta.seq_lens)
        output = mlu_ops.flash_attn_sq_mm_allreduce(
            cncl_comm,
            query, key_cache, value_cache,
            prefill_meta.query_start_loc, prefill_meta.seq_start_loc,
            alibi_slopes, None,
            smooth, qweight,
            per_channel_scale, None,
            prefill_meta.max_query_len, max_seq_len,
            softmax_scale, True,
            window_left,
            window_right,
            torch.float, parallel_num)

    # Add residual.
    if residual is not None:
        output = output + residual
    return output
|
||||
|
||||
|
||||
def context_attn_comm_cmpt_parallel_flash_attention_v2_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: List[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    cncl_comm: int,
    smooth: torch.Tensor,
    qweight: torch.Tensor,
    per_channel_scale: torch.Tensor,
    parallel_num: int,
    residual: Optional[torch.Tensor] = None,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Fake implementation paired with the real op at registration time.

    Only ``query`` is read; every other argument exists to mirror the real
    op's signature.

    Returns:
        An uninitialized tensor created with ``torch.empty_like(query)``.
    """
    placeholder = torch.empty_like(query)
    return placeholder
|
||||
|
||||
|
||||
# Register the fused prefill op together with its fake implementation;
# kv_cache is declared as the argument the op mutates.
direct_register_custom_op(
    op_name="context_attn_comm_cmpt_parallel_flash_attention_v2",
    op_func=context_attn_comm_cmpt_parallel_flash_attention_v2,
    fake_impl=context_attn_comm_cmpt_parallel_flash_attention_v2_fake,
    mutates_args=["kv_cache"],
)
|
||||
@@ -0,0 +1,90 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Attention
|
||||
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
|
||||
|
||||
|
||||
def vllm__model_executor__models__qwen2__Qwen2Attention__forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
    residual: Optional[torch.Tensor] = None,
    smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Replacement forward for Qwen2Attention.

    Projects ``hidden_states`` to QKV, applies rotary embedding to the packed
    Q/K slice, then either calls the fused
    ``context_attn_comm_cmpt_parallel_flash_attention_v2`` op or falls back
    to ``self.attn`` followed by ``self.o_proj(attn_output, residual)``.

    ``smooth_quant_scale`` must be ``None``.
    """
    assert smooth_quant_scale is None
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: pack q & k to fit tmo.apply_rotary
    packed_qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
    # The return value is discarded; presumably the rotary op updates the
    # packed view (and therefore q/k) in place -- TODO confirm.
    self.rotary_emb(
        positions,
        packed_qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
    # ==================
    # End of MLU Hijack
    # ==================

    # =====================================================
    # Modify by Context Communication Computation Parallel
    # =====================================================
    # @brief: call flash_attn_sq_mm_allreduce to finish forward
    if not ((attn_metadata.prefill_metadata)
            and (kv_cache[0].numel() > 0)
            and (hasattr(self.o_proj, 'quant_method'))
            and (isinstance(self.o_proj.quant_method, SmoothQuantLinearMethod))
            and (self.o_proj.quant_method.quant_config.input_quant_method
                 == "per_token")):
        # Regular path: attention, then the output projection.
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        # =============================
        # Modify by vllm_mlu
        # =============================
        # @brief: add residual
        output, _ = self.o_proj(attn_output, residual)
        # ==================
        # End of MLU Hijack
        # ==================
        return output

    # Fused path: prefill with a non-empty KV cache and a per-token
    # SmoothQuant output projection.
    tp_rank = get_tensor_model_parallel_rank()
    device_group = get_tensor_model_parallel_group().device_group
    comm_handle = device_group._get_backend(
        torch.device("mlu")).get_cncl_comm(tp_rank)

    return torch.ops.vllm.context_attn_comm_cmpt_parallel_flash_attention_v2(
        q, k, v,
        self.num_heads, self.head_dim, self.num_kv_heads,
        kv_cache, self.attn.impl.kv_cache_dtype,
        1.0, 1.0, self.scaling,
        comm_handle,
        self.o_proj.smooth, self.o_proj.qweight,
        self.o_proj.per_channel_scale.to(torch.float),
        self.o_proj.quant_method.parallel_num,
        residual, self.attn.impl.sliding_window, self.attn.impl.alibi_slopes
    )
    # =====================================================
    # End of Context Communication Computation Parallel
    # =====================================================
|
||||
|
||||
|
||||
# Drop the hijack currently registered for Qwen2Attention.forward, then
# register this module's replacement in its place.
MluHijackObject.undo_hijack(
    Qwen2Attention,
    Qwen2Attention.forward,
)
MluHijackObject.apply_hijack(
    Qwen2Attention,
    Qwen2Attention.forward,
    vllm__model_executor__models__qwen2__Qwen2Attention__forward,
)
|
||||
@@ -0,0 +1,58 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm_mlu.model_executor.models.qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||
|
||||
|
||||
def vllm_mlu__model_executor__models__qwen2_moe__Qwen2MoeSparseMoeBlock__forward(
    self,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Replacement forward for Qwen2MoeSparseMoeBlock.

    Computes the optional (gated) shared-expert output, runs the routed
    experts through ``self.forward_experts`` and merges the two.

    Args:
        hidden_states: 2-D input; unpacked as ``(num_tokens, hidden_dim)``.
        residual: Optional tensor forwarded to ``self.forward_experts``.

    Returns:
        Tensor viewed as ``(num_tokens, hidden_dim)``.
    """
    num_tokens, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)

    shared_output = None
    if self.shared_expert is not None:
        shared_output = self.shared_expert(hidden_states)
        if self.shared_expert_gate is not None:
            gate_output = self.shared_expert_gate(hidden_states)
            shared_output = torch.sigmoid(gate_output[0]) * shared_output

    # router_logits: (num_tokens, n_experts)
    router_logits, _ = self.gate(hidden_states)
    final_hidden_states = self.forward_experts(hidden_states, router_logits, residual)

    # =====================================================
    # Modify by Context Communication Computation Parallel
    # =====================================================
    # @brief: disable reduce if parallel op used
    is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
    needs_reduce = self.tp_size > 1

    if shared_output is not None:
        # In parallel mode the expert output is not reduced here, so the
        # shared-expert output is reduced on its own before being added.
        # The None check now comes first: previously the all-reduce was
        # called on None when the block had no shared expert.
        if needs_reduce and is_parallel_enable:
            shared_output = tensor_model_parallel_all_reduce(shared_output)
        # Added for every tp_size, including tp_size == 1.
        final_hidden_states = final_hidden_states + shared_output

    if needs_reduce and not is_parallel_enable:
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
    # =====================================================
    # End of Context Communication Computation Parallel
    # =====================================================

    return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
# Register the replacement for Qwen2MoeSparseMoeBlock.forward.
MluHijackObject.apply_hijack(
    Qwen2MoeSparseMoeBlock,
    Qwen2MoeSparseMoeBlock.forward,
    vllm_mlu__model_executor__models__qwen2_moe__Qwen2MoeSparseMoeBlock__forward,
)
|
||||
32
vllm-v0.6.2/vllm_mlu/vllm_mlu/mlu_custom/preload/README.md
Normal file
32
vllm-v0.6.2/vllm_mlu/vllm_mlu/mlu_custom/preload/README.md
Normal file
@@ -0,0 +1,32 @@
|
||||
### 简介
|
||||
|
||||
该劫持代码实现在vLLM的解码通信过程中预加载下一层的权重,从而减少解码的延迟。
|
||||
|
||||
### 支持模型
|
||||
|
||||
仅支持以下模型,不支持量化后的模型以及MOE模型。
|
||||
- Baichuan
|
||||
- Bloom
|
||||
- ChatGLM
|
||||
- Falcon
|
||||
- GPTNeoX
|
||||
- Llama
|
||||
- Qwen
|
||||
- Qwen2
|
||||
|
||||
### 支持板卡
|
||||
|
||||
300系列不支持,其他系列支持。
|
||||
|
||||
### 使用方法
|
||||
|
||||
- 设置环境变量 `export VLLM_PRELOAD_SIZE=<PRELOAD_SIZE>`,其中 `<PRELOAD_SIZE>` 表示预加载权重的大小,单位:MB。
|
||||
- 参数设置参考:在低带宽资源环境下,对于模型Llama-65B,不同batch_size和preload_size对应的性能优化收益如下。
|
||||
|
||||
| batch\preload | 8 | 16 | 24 | 32 | 48 | 64 |
|:--------------:|:----:|:----:|:----:|:----:|:----:|:----:|
| 1 | 4.9% | 10.0%| 9.5% | 6.7% |-2.4% | -7.1%|
| 8 | 3.2% | 6.3% | 8.9% | 11.2%| 6.0% | 1.8% |
| 16 | 2.3% | 5.1% | 7.5% | 9.2% | 8.3% | 4.3% |
| 24 | 2.3% | 4.8% | 7.4% | 9.1% | 9.5% | 6.0% |
| 32 | 2.1% | 4.3% | 7.0% | 8.7% | 10.1%| 8.1% |
|
||||
@@ -0,0 +1 @@
|
||||
from . import distributed
|
||||
@@ -0,0 +1 @@
|
||||
from . import parallel_state
|
||||
@@ -0,0 +1,75 @@
|
||||
import torch
|
||||
from typing import Union
|
||||
from vllm.distributed.parallel_state import GroupCoordinator, supports_custom_op
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
def vllm__distributed__parallel_state__GroupCoordinator__all_reduce(
        self, input_: torch.Tensor,
        async_op: bool = False) -> Union[torch.distributed.Work, torch.Tensor]:
    """All-reduce ``input_`` across this group, optionally asynchronously.

    Dispatch order: single-rank bypass, CPU tensors (via IPEX), in-place
    reduce when custom ops are unsupported, then the TPU / HPU / XPU
    communicators, then the custom all-reduce op, and finally
    ``torch.distributed.all_reduce``.

    ``async_op`` is only forwarded on the final ``torch.distributed`` path;
    that is the only path that can return the work handle instead of a
    tensor. Every earlier path returns a tensor regardless of ``async_op``.
    """
    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return input_

    if input_.is_cpu:
        import intel_extension_for_pytorch as ipex
        ipex.distributed.all_reduce(input_, group=self.device_group)
        return input_

    if not supports_custom_op():
        self._all_reduce_in_place(input_)
        return input_

    # Device communicators are tried in a fixed order (TPU, HPU, XPU); each
    # attribute is read lazily, only once the previous ones declined.
    for attr_name in ("tpu_communicator", "hpu_communicator",
                      "xpu_communicator"):
        communicator = getattr(self, attr_name)
        if communicator is not None and not communicator.disabled:
            return communicator.all_reduce(input_)

    custom_ar = self.ca_comm
    if (custom_ar is not None
            and not custom_ar.disabled
            and custom_ar.should_custom_ar(input_)):
        return torch.ops.vllm.outplace_all_reduce(
            input_, group_name=self.unique_name)

    # =============================
    # Modify by custom vllm_mlu
    # =============================
    # @brief: use async all reduce when preload weights.
    work = torch.distributed.all_reduce(input_, group=self.device_group,
                                        async_op=async_op)
    # ==================
    # End of custom MLU Hijack
    # ==================
    return work if async_op else input_
|
||||
|
||||
|
||||
# Register the async-capable replacement for GroupCoordinator.all_reduce.
MluHijackObject.apply_hijack(
    GroupCoordinator,
    GroupCoordinator.all_reduce,
    vllm__distributed__parallel_state__GroupCoordinator__all_reduce,
)
|
||||
Reference in New Issue
Block a user