[Model] Support DeepSeek-V4
3
vllm_mlu/model_executor/__init__.py
Executable file
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
3
vllm_mlu/model_executor/layers/__init__.py
Normal file
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
25
vllm_mlu/model_executor/layers/activation.py
Normal file
@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch
from vllm.model_executor.layers.activation import QuickGELU
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops

def vllm__model_executor__activation__QuickGELU__forward_oot(self, x: torch.Tensor) -> torch.Tensor:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: implement forward_oot
    '''
    return mlu_ops.active(x, 'quick_gelu', False)
'''
==================
End of MLU Hijack
==================
'''

MluHijackObject.apply_hijack(QuickGELU,
                             QuickGELU.forward_oot,
                             vllm__model_executor__activation__QuickGELU__forward_oot)
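# Hedged sketch of how a hijack helper like MluHijackObject.apply_hijack is
# typically wired up: the upstream method is rebound to the MLU implementation
# so vLLM dispatches to it transparently. The helper below is illustrative
# only, not the actual vllm_mlu API.
def apply_hijack_sketch(cls, old_method, new_func):
    # Rebind the attribute whose name matches the original method.
    setattr(cls, old_method.__name__, new_func)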
277
vllm_mlu/model_executor/layers/compressor.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import math
|
||||
from typing import Callable
|
||||
from scipy.linalg import hadamard
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
|
||||
|
||||
|
||||
def hadamard_transform_ref(x, scale=1.0):
|
||||
"""
|
||||
x: (..., dim)
|
||||
out: (..., dim)
|
||||
"""
|
||||
x_shape = x.shape
|
||||
dim = x.shape[-1]
|
||||
x = x.reshape(-1, dim)
|
||||
log_dim = math.ceil(math.log2(dim))
|
||||
dim_padded = 2 ** log_dim
|
||||
if dim != dim_padded:
|
||||
x = F.pad(x, (0, dim_padded - dim))
|
||||
out = F.linear(
|
||||
x,
|
||||
torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device),
|
||||
)
|
||||
out = out * scale
|
||||
return out[..., :dim].reshape(*x_shape)
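# Worked example (illustrative): for dim = 36, log_dim = ceil(log2(36)) = 6 and
# dim_padded = 64, so the input is zero-padded from 36 to 64 columns, multiplied
# by the 64x64 Hadamard matrix, and the first 36 output columns are kept.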
|
||||
|
||||
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
|
||||
assert x.dtype == torch.bfloat16
|
||||
hidden_size = x.size(-1)
|
||||
return hadamard_transform_ref(x, scale=hidden_size ** -0.5)
|
||||
|
||||
|
||||
class Compressor(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
rope,
|
||||
compress_ratio: int = 4,
|
||||
head_dim: int = 512,
|
||||
rotate: bool = False,
|
||||
prefix: str = "",
|
||||
**kwargs,):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.dim = config.dim
|
||||
self.head_dim = head_dim
|
||||
        self.rope_head_dim = config.rope_head_dim
|
||||
self.nope_head_dim = head_dim - config.rope_head_dim
|
||||
self.compress_ratio = compress_ratio
|
||||
self.overlap = compress_ratio == 4
|
||||
self.rotate = rotate
|
||||
coff = 1 + self.overlap
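        # coff is 2 when overlapping compression is enabled (compress_ratio == 4),
        # otherwise 1; it scales the APE, wkv/wgate and state buffer widths below.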
|
||||
self.norm_eps = config.norm_eps
|
||||
self.window_size = config.window_size
|
||||
|
||||
self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
|
||||
        # wkv and wgate are stored in bf16 in the checkpoint, while the parameters
        # here are kept in fp32 for convenience.
        # The first half of the dimensions is used for overlapping compression and
        # the second half for normal compression.
|
||||
|
||||
self.wkv = ReplicatedLinear(
|
||||
self.dim,
|
||||
coff * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
            params_dtype=torch.float32,
|
||||
prefix=f"{prefix}.wkv",
|
||||
)
|
||||
|
||||
self.wgate = ReplicatedLinear(
|
||||
self.dim,
|
||||
coff * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
            params_dtype=torch.float32,
|
||||
prefix=f"{prefix}.wgate",
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(self.head_dim, self.norm_eps)
|
||||
|
||||
self.rotary_emb = rope
|
||||
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
        assert hasattr(hf_config, "cached_state_num"), \
            "cached_state_num is not set in hf_config"
|
||||
cached_state_num = hf_config.cached_state_num
|
||||
self.register_buffer(
|
||||
"kv_state",
|
||||
torch.zeros(cached_state_num, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"score_state",
|
||||
torch.full(
|
||||
(cached_state_num, coff * compress_ratio, coff * self.head_dim),
|
||||
float("-inf"),
|
||||
dtype=torch.float32,
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self.hadamard_matrix = torch.tensor(
|
||||
hadamard(self.head_dim, dtype=float), dtype=torch.get_default_dtype(), device="mlu")
|
||||
|
||||
def overlap_transform(self, tensor: torch.Tensor, value=0):
|
||||
# tensor: [b,s,r,2d]
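        # Output: [b, s, 2*ratio, d]. The second half of the new axis takes the
        # current step's normal-compression slice (dims d:), while the first half
        # is shifted by one step to hold the previous step's overlapping slice
        # (dims :d); step 0 keeps the fill value in its first half.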
|
||||
b, s, _, _ = tensor.size()
|
||||
ratio, d = self.compress_ratio, self.head_dim
|
||||
new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
|
||||
new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
|
||||
new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
|
||||
return new_tensor
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
batch_to_kv_state: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
window_offset: int,
|
||||
compressor_slot_mapping: torch.Tensor,
|
||||
):
|
||||
x = x.float()
|
||||
kv_pack, _ = self.wkv(x)
|
||||
score_pack, _ = self.wgate(x)
|
||||
|
||||
mlu_ops.fused_compress_single_kv(
|
||||
kv=kv_pack.unsqueeze(1), # (token, D) -> (B, S, D)
|
||||
score=score_pack.unsqueeze(1), # (token, D) -> (B, S, D)
|
||||
position=positions,
|
||||
ape=self.ape,
|
||||
kv_state=self.kv_state,
|
||||
score_state=self.score_state,
|
||||
gamma=self.norm.weight,
|
||||
sin=self.rotary_emb.sin_,
|
||||
cos=self.rotary_emb.cos_,
|
||||
hadamard_matrix=self.hadamard_matrix,
|
||||
slot_mapping=compressor_slot_mapping,
|
||||
kv_cache=kv_cache,
|
||||
kv_cache_scale=None,
|
||||
eps=self.norm_eps,
|
||||
overlap=self.overlap,
|
||||
rotate=self.rotate,
|
||||
state_idx=batch_to_kv_state,
|
||||
)
|
||||
|
||||
        # The decode path updates the cached states and kv_cache in place, so
        # return a placeholder instead of a real compressed_kv.
|
||||
return None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
batch_to_kv_state: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
window_offset: int,
|
||||
compressor_slot_mapping: torch.Tensor,
|
||||
):
|
||||
common_metadata = get_common_metadata()
|
||||
forward_func: Callable = (
|
||||
self.forward_prefill if common_metadata.is_prefill_only
|
||||
else self.forward_decode
|
||||
)
|
||||
return forward_func(
|
||||
x,
|
||||
positions,
|
||||
attn_metadata,
|
||||
batch_to_kv_state,
|
||||
kv_cache,
|
||||
window_offset,
|
||||
compressor_slot_mapping,
|
||||
)
|
||||
|
||||
def forward_prefill(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
batch_to_kv_state: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
window_offset: int,
|
||||
compressor_slot_mapping: torch.Tensor,
|
||||
):
|
||||
common_metadata = get_common_metadata()
|
||||
seq_lens = common_metadata.seq_lens
|
||||
query_start_loc = common_metadata.query_start_loc
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
|
||||
ratio, overlap = self.compress_ratio, self.overlap
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
kv_pack, _ = self.wkv(x)
|
||||
score_pack, _ = self.wgate(x)
|
||||
|
||||
compress_lens = query_lens // self.compress_ratio
|
||||
cu_compress_lens = torch.cat([
|
||||
torch.tensor([0], dtype=compress_lens.dtype, device=compress_lens.device),
|
||||
torch.cumsum(compress_lens, dim=0)],
|
||||
)
|
||||
|
||||
compress_positions = []
|
||||
for i in range(len(seq_lens)):
|
||||
seqlen = (query_start_loc[i+1] - query_start_loc[i]).item()
|
||||
remainder = seqlen % ratio
|
||||
cutoff = seqlen - remainder
|
||||
pos = positions[query_start_loc[i]: query_start_loc[i+1]]
|
||||
positions_ = pos[:cutoff:ratio].contiguous()
|
||||
compress_positions.append(positions_)
|
||||
kv_positions = torch.cat(compress_positions, dim=0)
|
||||
|
||||
|
||||
total_compress_len = cu_compress_lens[-1].item()
|
||||
kv = torch.empty(
|
||||
[total_compress_len, self.head_dim],
|
||||
dtype=kv_pack.dtype,
|
||||
device=kv_pack.device,
|
||||
)
|
||||
|
||||
mlu_ops.fused_compress_multi_kv(
|
||||
kv = kv_pack,
|
||||
score = score_pack,
|
||||
kv_state = self.kv_state,
|
||||
score_state = self.score_state,
|
||||
state_batch_idx = batch_to_kv_state,
|
||||
cu_seqlens = query_start_loc,
|
||||
ape = self.ape,
|
||||
max_seqlen = common_metadata.max_query_len,
|
||||
overlap = overlap,
|
||||
compressed_kv = kv,
|
||||
)
|
||||
|
||||
if kv.size(0) == 0:
|
||||
return kv.unsqueeze(-2).to(dtype) # (compress_token_num, 1, head_size)
|
||||
|
||||
|
||||
kv = self.norm(kv.to(dtype))
|
||||
|
||||
kv_rope = kv[..., -self.rope_head_dim:].unsqueeze(-2)
|
||||
        # Compressed cu_seqlens are used here, so rotary_emb cannot be called directly.
|
||||
kv_rope = mlu_ops.rotary_embedding(
|
||||
kv_rope,
|
||||
self.rotary_emb.sin_,
|
||||
self.rotary_emb.cos_,
|
||||
kv_positions,
|
||||
torch.tensor([0, kv_positions.size(0)], dtype=torch.int32, device=kv_positions.device), # cu_seqlens
|
||||
True, # interleaved
|
||||
True, # discrete
|
||||
False,
|
||||
common_metadata.max_query_len,
|
||||
)
|
||||
|
||||
if self.rotate:
|
||||
kv = rotate_activation(kv)
|
||||
|
||||
mlu_ops.reshape_paged_cache(
|
||||
kv.unsqueeze(1),
|
||||
None,
|
||||
kv_cache,
|
||||
None,
|
||||
compressor_slot_mapping,
|
||||
)
|
||||
|
||||
return kv.unsqueeze(-2) # (compress_token_num, 1, head_size)
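# Hedged sketch of the prefill bookkeeping above: each sequence contributes
# floor(query_len / compress_ratio) compressed tokens, and cu_compress_lens is
# the exclusive prefix sum used to index the packed compressed_kv buffer.
# The numbers below are illustrative.
import torch

query_lens = torch.tensor([9, 4, 7])
compress_ratio = 4
compress_lens = query_lens // compress_ratio  # -> [2, 1, 1]
cu_compress_lens = torch.cat([
    torch.tensor([0], dtype=compress_lens.dtype),
    torch.cumsum(compress_lens, dim=0),
])  # -> [0, 2, 3, 4]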
|
||||
85
vllm_mlu/model_executor/layers/dp_logits_processor.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.communication_op import (
|
||||
tensor_model_parallel_all_gather, tensor_model_parallel_gather)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm_mlu.model_executor.models.dp_utils import (
|
||||
tensor_model_parallel_all_gather_dp, DataParallelRuntimeParams)
|
||||
|
||||
|
||||
class DPLogitsProcessor(LogitsProcessor):
|
||||
"""DP LogitsProcessor."""
|
||||
|
||||
def _get_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
embedding_bias: Optional[torch.Tensor],
|
||||
dp_params: Optional[DataParallelRuntimeParams] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
# Get the logits for the next tokens.
|
||||
batch_sizes = None
|
||||
if (lm_head.tp_group is not None
|
||||
and dp_params is not None
|
||||
and dp_params.logits_batch_split_list is not None):
|
||||
batch_sizes = dp_params.logits_batch_split_list
|
||||
hidden_states = tensor_model_parallel_all_gather_dp(
|
||||
group_num_tokens=batch_sizes,
|
||||
rank=lm_head.tp_rank,
|
||||
hidden_states=hidden_states,
|
||||
group=lm_head.tp_group,
|
||||
)
|
||||
|
||||
logits = lm_head.quant_method.apply(
|
||||
lm_head, hidden_states, bias=embedding_bias)
|
||||
|
||||
if self.use_all_gather:
|
||||
# Gather is not supported for some devices such as TPUs.
|
||||
# Use all-gather instead.
|
||||
# NOTE(woosuk): Here, the outputs of every device should not be None
|
||||
# because XLA requires strict SPMD among all devices. Every device
|
||||
# should execute the same operations after gathering the logits.
|
||||
logits = tensor_model_parallel_all_gather(logits, tp_group=lm_head.tp_group)
|
||||
else:
|
||||
# None may be returned for rank > 0
|
||||
logits = tensor_model_parallel_gather(logits, tp_group=lm_head.tp_group)
|
||||
|
||||
# Remove paddings in vocab (if any).
|
||||
if logits is not None:
|
||||
logits = logits[..., : self.org_vocab_size]
|
||||
|
||||
if batch_sizes is not None:
|
||||
offset = sum(batch_sizes[:lm_head.tp_rank])
|
||||
logits = logits[offset : offset + batch_sizes[lm_head.tp_rank]]
|
||||
|
||||
return logits
|
||||
|
||||
def forward(
|
||||
self,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
hidden_states: torch.Tensor,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
dp_params: Optional[DataParallelRuntimeParams] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if self.logits_as_input:
|
||||
logits = hidden_states
|
||||
else:
|
||||
# Get the logits for the next tokens.
|
||||
logits = self._get_logits(
|
||||
hidden_states, lm_head, embedding_bias, dp_params)
|
||||
if logits is not None:
|
||||
if self.soft_cap is not None:
|
||||
logits = logits / self.soft_cap
|
||||
logits = torch.tanh(logits)
|
||||
logits = logits * self.soft_cap
|
||||
|
||||
if self.scale != 1.0:
|
||||
logits *= self.scale
|
||||
|
||||
return logits
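# Hedged sketch of the DP slice-back in _get_logits above: after the all-gather
# every rank holds logits for all DP groups' tokens, and each rank keeps only
# its own contiguous slice. The values below are illustrative.
import torch

batch_sizes = [3, 2, 4]  # tokens contributed by ranks 0, 1, 2
rank = 1
logits = torch.randn(sum(batch_sizes), 8)
offset = sum(batch_sizes[:rank])  # 3
local_logits = logits[offset: offset + batch_sizes[rank]]  # rows 3..4
assert local_logits.shape[0] == batch_sizes[rank]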
|
||||
219
vllm_mlu/model_executor/layers/dp_vocab_parallel_embedding.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
method_has_implemented_embedding,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod,
|
||||
VocabParallelEmbedding,
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
get_masked_input_and_mask,
|
||||
pad_vocab_size,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.distributed.communication_op import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
divide,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_logits_tp_group,
|
||||
get_logits_tp_world_size,
|
||||
get_logits_tp_rank,
|
||||
)
|
||||
from vllm_mlu.model_executor.models.dp_utils import (
|
||||
DataParallelRuntimeParams,
|
||||
tensor_model_parallel_all_gather_dp,
|
||||
)
|
||||
|
||||
|
||||
class DPVocabParallelEmbedding(VocabParallelEmbedding):
|
||||
"""DP Embedding parallelized in the vocabulary dimension."""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
torch.nn.Module.__init__(self)
|
||||
|
||||
"""
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
        @brief: add self.tp_group, tp_world_size and tp_rank to support other parallel groups
|
||||
"""
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_world_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_group = None
|
||||
logits_tp_world_size = get_logits_tp_world_size()
|
||||
if logits_tp_world_size != self.tp_world_size:
|
||||
self.tp_group = get_logits_tp_group()
|
||||
self.tp_world_size = logits_tp_world_size
|
||||
self.tp_rank = get_logits_tp_rank()
|
||||
|
||||
# Keep the input dimensions.
|
||||
tp_rank = self.tp_rank
|
||||
self.tp_size = self.tp_world_size
|
||||
"""
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
"""
|
||||
self.num_embeddings = num_embeddings
|
||||
self.padding_size = padding_size
|
||||
self.org_vocab_size = org_num_embeddings or num_embeddings
|
||||
num_added_embeddings = num_embeddings - self.org_vocab_size
|
||||
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
|
||||
self.padding_size)
|
||||
self.num_embeddings_padded = pad_vocab_size(
|
||||
self.org_vocab_size_padded + num_added_embeddings,
|
||||
self.padding_size)
|
||||
assert self.org_vocab_size_padded <= self.num_embeddings_padded
|
||||
|
||||
self.shard_indices = self._get_indices(self.num_embeddings_padded,
|
||||
self.org_vocab_size_padded,
|
||||
self.num_embeddings,
|
||||
self.org_vocab_size, tp_rank,
|
||||
self.tp_size)
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
quant_method = None
|
||||
if quant_config is not None:
|
||||
quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
if quant_method is None:
|
||||
quant_method = UnquantizedEmbeddingMethod()
|
||||
|
||||
# If we are making an embedding layer, then our quantization linear
|
||||
# method must implement the embedding operation. If we are another
|
||||
# layer type like ParallelLMHead, this is not important.
|
||||
is_embedding_layer = type(self) is VocabParallelEmbedding
|
||||
quant_method_implements_embedding = method_has_implemented_embedding(
|
||||
type(quant_method))
|
||||
if is_embedding_layer and not quant_method_implements_embedding:
|
||||
raise NotImplementedError(
|
||||
f"The class {type(quant_method).__name__} must implement "
|
||||
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
|
||||
|
||||
self.quant_method: QuantizeMethodBase = quant_method
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
        # Divide the weight matrix along the vocabulary dimension.
|
||||
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
|
||||
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
|
||||
self.tp_size)
|
||||
assert (self.shard_indices.num_elements_padded ==
|
||||
self.num_embeddings_per_partition)
|
||||
self.num_org_embeddings_per_partition = (
|
||||
self.shard_indices.org_vocab_end_index -
|
||||
self.shard_indices.org_vocab_start_index)
|
||||
self.num_added_embeddings_per_partition = (
|
||||
self.shard_indices.added_vocab_end_index -
|
||||
self.shard_indices.added_vocab_start_index)
|
||||
|
||||
self.quant_method.create_weights(self,
|
||||
self.embedding_dim,
|
||||
[self.num_embeddings_per_partition],
|
||||
self.embedding_dim,
|
||||
self.num_embeddings_padded,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
def forward(self, input_,
|
||||
dp_params: Optional[DataParallelRuntimeParams] = None):
|
||||
token_split_list = None
|
||||
if (dp_params is not None
|
||||
and self.tp_group is not None
|
||||
and dp_params.emb_token_split_list is not None):
|
||||
token_split_list = dp_params.emb_token_split_list
|
||||
input_ = tensor_model_parallel_all_gather_dp(
|
||||
group_num_tokens=token_split_list,
|
||||
rank=self.tp_rank,
|
||||
hidden_states=input_.reshape(-1, 1),
|
||||
group=self.tp_group,
|
||||
).reshape(-1)
|
||||
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
masked_input, input_mask = get_masked_input_and_mask(
|
||||
input_,
|
||||
self.shard_indices.org_vocab_start_index,
|
||||
self.shard_indices.org_vocab_end_index,
|
||||
self.shard_indices.num_org_vocab_padding,
|
||||
self.shard_indices.added_vocab_start_index,
|
||||
self.shard_indices.added_vocab_end_index,
|
||||
)
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.quant_method.embedding(self, masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group)
|
||||
|
||||
if token_split_list is not None:
|
||||
offset = sum(token_split_list[:self.tp_rank])
|
||||
output = output[offset : offset + token_split_list[self.tp_rank]]
|
||||
|
||||
return output
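# Hedged sketch of the vocab-parallel lookup above: each rank embeds only the
# token ids that fall inside its vocab shard, zeroes the rest, and the
# all-reduce sums the per-rank partial embeddings. Shard bounds and sizes below
# are illustrative.
import torch

vocab_start, vocab_end = 100, 200
input_ids = torch.tensor([42, 150, 199])
mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
local_ids = (input_ids - vocab_start).clamp(min=0) * (~mask)
weight = torch.randn(vocab_end - vocab_start, 16)
out = weight[local_ids]
out[mask] = 0.0  # masked tokens contribute nothing to the all-reduce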
|
||||
|
||||
|
||||
class DPParallelLMHead(DPVocabParallelEmbedding):
|
||||
"""DP Parallelized LM head.
|
||||
|
||||
    NOTE: A copy of the ParallelLMHead class; only its parent is changed
    from VocabParallelEmbedding to DPVocabParallelEmbedding.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
bias: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__(num_embeddings, embedding_dim, params_dtype,
|
||||
org_num_embeddings, padding_size, quant_config,
|
||||
prefix)
|
||||
self.quant_config = quant_config
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def tie_weights(self, embed_tokens: VocabParallelEmbedding):
|
||||
"""Tie the weights with word embeddings."""
|
||||
# GGUF quantized embed_tokens.
|
||||
if self.quant_config and self.quant_config.get_name() == "gguf":
|
||||
return embed_tokens
|
||||
else:
|
||||
self.weight = embed_tokens.weight
|
||||
return self
|
||||
|
||||
def forward(self, input_):
|
||||
del input_
|
||||
raise RuntimeError("LMHead's weights should be used in the sampler.")
|
||||
224
vllm_mlu/model_executor/layers/feed_forward.py
Executable file
@@ -0,0 +1,224 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Any
|
||||
|
||||
from vllm.distributed import (
|
||||
get_parallel_world_size_with_group,
|
||||
get_parallel_rank_with_group,
|
||||
)
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import BaseLayerWithLoRA
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear
|
||||
)
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.mlu_hijack_utils import set_is_gated
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class FeedForward(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
up_proj_name: str,
|
||||
is_gated: bool,
|
||||
down_proj_name: str,
|
||||
bias: bool,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
skip_bias_add: bool = False,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
tp_group: Any = None,
|
||||
keep_full_weights: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.is_gated = is_gated
|
||||
self.bias = bias
|
||||
self.up_proj_name = up_proj_name
|
||||
self.down_proj_name = down_proj_name
|
||||
self.quant_config = quant_config
|
||||
self.is_initialized = False
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.reduce_results = reduce_results
|
||||
self.use_bt_ffn = True
|
||||
set_is_gated(self.is_gated)
|
||||
|
||||
# modify tp_size, tp_rank and tp_group when enable data parallel
|
||||
self.tp_size = get_parallel_world_size_with_group(tp_group)
|
||||
self.tp_rank = get_parallel_rank_with_group(tp_group)
|
||||
self.tp_group = tp_group
|
||||
self.keep_full_weights = keep_full_weights
|
||||
if self.keep_full_weights:
|
||||
self.tp_size = 1
|
||||
self.tp_rank = 0
|
||||
self.tp_group = None
|
||||
|
||||
# up_proj with gate or not
|
||||
if self.is_gated:
|
||||
up_proj = MergedColumnParallelLinear(hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.{up_proj_name}",
|
||||
tp_group=self.tp_group,
|
||||
keep_full_weights=keep_full_weights)
|
||||
else:
|
||||
up_proj = ColumnParallelLinear(hidden_size,
|
||||
intermediate_size,
|
||||
bias=bias,
|
||||
skip_bias_add=skip_bias_add,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.{up_proj_name}",
|
||||
tp_group=self.tp_group,
|
||||
keep_full_weights=keep_full_weights)
|
||||
self.register_module(up_proj_name, up_proj)
|
||||
|
||||
# down_proj
|
||||
down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
skip_bias_add=skip_bias_add,
|
||||
reduce_results=reduce_results,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.{down_proj_name}",
|
||||
tp_group=self.tp_group,
|
||||
keep_full_weights=keep_full_weights)
|
||||
|
||||
self.register_module(down_proj_name, down_proj)
|
||||
|
||||
def prepare_weight(self):
|
||||
if not self.is_initialized:
|
||||
            # alpha and beta are 1.0 and 0.0 respectively because we do not need the residual for now
|
||||
self.alpha = 1.0
|
||||
self.beta = 0.0
|
||||
# place it here to avoid the overhead of calling it in the forward pass
|
||||
self.is_initialized = True
|
||||
|
||||
def _forward(self, hidden_states):
|
||||
self.prepare_weight()
|
||||
up_proj = getattr(self, self.up_proj_name)
|
||||
down_proj = getattr(self, self.down_proj_name)
|
||||
act_dict = {
|
||||
"relu": F.relu,
|
||||
"gelu": F.gelu,
|
||||
"silu": F.silu,
|
||||
}
|
||||
fc1 = F.linear(hidden_states, up_proj.weight, bias=up_proj.bias)
|
||||
if self.is_gated:
|
||||
d = fc1.shape[-1] // 2
|
||||
fc1 = act_dict[self.hidden_act](fc1[..., :d]) * fc1[..., d:]
|
||||
else:
|
||||
fc1 = act_dict[self.hidden_act](fc1)
|
||||
fc2 = F.linear(fc1, down_proj.weight, bias=None)
|
||||
fc2 = tensor_model_parallel_all_reduce(fc2)
|
||||
if not self.skip_bias_add:
|
||||
fc2 = fc2 + down_proj.bias if down_proj.bias is not None else fc2
|
||||
return fc2
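# Hedged sketch of the gated activation used in _forward above: the merged
# up-projection produces [gate | up] along the last dim and the output is
# act(gate) * up. The shapes below are illustrative.
import torch
import torch.nn.functional as F

fc1 = torch.randn(2, 8)  # last dim = 2 * intermediate_size (here 2 * 4)
d = fc1.shape[-1] // 2
gated = F.silu(fc1[..., :d]) * fc1[..., d:]
assert gated.shape == (2, 4)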
|
||||
|
||||
def forward_naive(
|
||||
self,
|
||||
hidden_states,
|
||||
residual: torch.Tensor | None = None,
|
||||
smooth_quant_scale: torch.Tensor | None = None
|
||||
):
|
||||
'''
|
||||
used by quant_tools
|
||||
'''
|
||||
        assert self.quant_config is None, "ffn naive forward doesn't support quantization"
        assert smooth_quant_scale is None, "ffn naive forward doesn't support smooth_quant_scale"
|
||||
|
||||
up_proj = getattr(self, self.up_proj_name)
|
||||
down_proj = getattr(self, self.down_proj_name)
|
||||
residual_ = None if self.tp_rank > 0 else residual
|
||||
fc1, bias = up_proj(hidden_states)
|
||||
if bias is not None:
|
||||
fc1 += bias
|
||||
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
|
||||
out, bias = down_proj(fc1, residual=residual_)
|
||||
|
||||
if self.skip_bias_add:
|
||||
return out, bias
|
||||
return out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual: torch.Tensor | None = None,
|
||||
smooth_quant_scale: torch.Tensor | None = None,
|
||||
use_tp_weight: bool = False,
|
||||
output: torch.Tensor | None = None,
|
||||
):
|
||||
self.prepare_weight()
|
||||
|
||||
if self.use_bt_ffn is False:
|
||||
return self.forward_naive(hidden_states, residual, None)
|
||||
|
||||
up_proj = getattr(self, self.up_proj_name)
|
||||
down_proj = getattr(self, self.down_proj_name)
|
||||
residual_ = None if self.tp_rank > 0 else residual
|
||||
if (self.quant_config is None and not isinstance(up_proj, BaseLayerWithLoRA)
|
||||
and not isinstance(down_proj, BaseLayerWithLoRA)):
|
||||
            # The matmul formula is the following:
            #   mul_out = alpha * (matmul(input, filter, transpose_b=True) + bias) + beta * residual
            #   output = active(mul_out)
            # Note: we cannot use the activation function inside matmul because it
            # does not support gated operations; we may support it in the tmo matmul
            # in the future.
|
||||
up_proj_weight = up_proj.weight
|
||||
down_proj_weight = down_proj.weight
|
||||
if self.keep_full_weights and use_tp_weight:
|
||||
up_proj_weight = up_proj.tp_weight
|
||||
down_proj_weight = down_proj.tp_weight
|
||||
fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj_weight, up_proj.bias,
|
||||
None, 'none', self.alpha, self.beta)
|
||||
act_out = mlu_ops.active(fc1.float(), self.hidden_act, self.is_gated).to(dtype=fc1.dtype)
|
||||
beta = 0.0
|
||||
if residual_ is not None:
|
||||
beta = 1.0
|
||||
residual_ = residual_.view(-1, residual_.shape[-1])
|
||||
out_ = mlu_ops.matmul(act_out, down_proj_weight, None, residual_, 'none', self.alpha, beta)
|
||||
            # Per vLLM's original design, the bias (if present) is added after the second matmul.
|
||||
if self.reduce_results:
|
||||
out = tensor_model_parallel_all_reduce(out_, self.tp_group)
|
||||
else:
|
||||
out = out_
|
||||
# do the bias add if needed
|
||||
if not self.skip_bias_add:
|
||||
out = out + down_proj.bias if down_proj.bias is not None else out
|
||||
else:
|
||||
return out, down_proj.bias
|
||||
else:
|
||||
fc1, bias = up_proj(hidden_states, smooth_quant_scale=smooth_quant_scale, use_tp_weight=use_tp_weight)
|
||||
if bias is not None:
|
||||
fc1 += bias
|
||||
            input_scale = None
|
||||
if (self.quant_config is not None and self.quant_config.get_name() == "SmoothQuant" and
|
||||
self.quant_config.input_quant_method == "per_token" and not self.quant_config.is_fp8):
|
||||
down_proj.quant_method.skip_quant_input = True
|
||||
down_proj_smooth = down_proj.smooth
|
||||
if self.keep_full_weights and use_tp_weight:
|
||||
assert down_proj.tp_smooth is not None, "tp_smooth is not initialized"
|
||||
down_proj_smooth = down_proj.tp_smooth
|
||||
fc1, input_scale = mlu_ops.per_token_smooth_quantize(
|
||||
fc1, down_proj_smooth, None, None, act_mode=self.hidden_act, is_gated=self.is_gated)
|
||||
else:
|
||||
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
|
||||
out, bias = down_proj(
|
||||
fc1, residual=residual_, smooth_quant_scale=input_scale,
|
||||
use_tp_weight=use_tp_weight, output=output)
|
||||
|
||||
if self.skip_bias_add:
|
||||
return out, bias
|
||||
return out
|
||||
3
vllm_mlu/model_executor/layers/fused_moe/__init__.py
Normal file
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
935
vllm_mlu/model_executor/layers/fused_moe/fused_moe.py
Normal file
@@ -0,0 +1,935 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Fused MoE kernel."""
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
_get_config_dtype_str,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_moe_kernel_gptq_awq,
|
||||
write_zeros_to_output,
|
||||
get_default_config,
|
||||
try_get_optimal_moe_config,
|
||||
_get_config_quant_dtype,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
activation_without_mul,
|
||||
disable_inplace,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
from vllm_mlu.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||
import vllm_mlu._mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@triton.jit
|
||||
def fused_moe_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
b_bias_ptr,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
topk_weights_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
num_tokens_post_padded_ptr,
|
||||
# Matrix dimensions
|
||||
N,
|
||||
K,
|
||||
EM,
|
||||
num_valid_tokens,
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_be,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
stride_bbe, # bias expert stride
|
||||
stride_bbn, # bias N stride
|
||||
# Block size for block-wise quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
use_fp8_w8a8: tl.constexpr,
|
||||
use_int8_w8a8: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr,
|
||||
per_channel_quant: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Implements the fused computation for a Mixture of Experts (MOE) using
|
||||
token and expert matrices.
|
||||
|
||||
Key Parameters:
|
||||
- A: The input tensor representing tokens with shape (*, K), where '*' can
|
||||
be any shape representing batches and K is the feature dimension of
|
||||
each token.
|
||||
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
|
||||
the number of experts, K is the input feature dimension, and N is
|
||||
the output feature dimension.
|
||||
- C: The output cache tensor with shape (M, topk, N), where M is the
|
||||
total number of tokens post padding, topk is the number of times
|
||||
each token is repeated, and N is the output feature dimension.
|
||||
- sorted_token_ids: A tensor containing the sorted indices of tokens,
|
||||
repeated topk times and arranged by the expert index they are
|
||||
assigned to.
|
||||
- expert_ids: A tensor containing the indices of the expert for each
|
||||
block. It determines which expert matrix from B should be used for
|
||||
each block in A.
|
||||
This kernel performs the multiplication of a token by its corresponding
|
||||
expert matrix as determined by `expert_ids`. The sorting of
|
||||
`sorted_token_ids` by expert index and padding ensures divisibility by
|
||||
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
|
||||
multiplication across different blocks processed by the same expert.
|
||||
"""
|
||||
# -----------------------------------------------------------
|
||||
# Map program ids `pid` to the block of C it should compute.
|
||||
# This is done in a grouped ordering to promote L2 data reuse.
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Split the program ID into two dimensions (pid_0 and pid_1)
|
||||
'''
|
||||
pid_0 = tl.program_id(axis=0)
|
||||
pid_1 = tl.program_id(axis=1)
|
||||
pid = pid_1 * tl.num_programs(axis=0) + pid_0
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Create pointers for the first blocks of A and B.
|
||||
# We will advance this pointer as we move in the K direction
|
||||
# and accumulate
|
||||
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
||||
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
|
||||
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
|
||||
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
||||
return
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
||||
token_mask = offs_token < num_valid_tokens
|
||||
|
||||
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
|
||||
if off_experts == -1:
|
||||
# -----------------------------------------------------------
|
||||
# Write back zeros to the output when the expert is not
|
||||
# in the current expert parallel rank.
|
||||
write_zeros_to_output(
|
||||
c_ptr,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
pid_n,
|
||||
N,
|
||||
offs_token,
|
||||
token_mask,
|
||||
BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N,
|
||||
compute_type,
|
||||
)
|
||||
return
|
||||
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (
|
||||
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
|
||||
)
|
||||
|
||||
b_ptrs = (
|
||||
b_ptr
|
||||
+ off_experts * stride_be
|
||||
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
)
|
||||
if use_int8_w8a16:
|
||||
b_scale_ptrs = (
|
||||
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
|
||||
)
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
# block-wise
|
||||
if group_k > 0 and group_n > 0:
|
||||
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
|
||||
offs_bsn = offs_bn // group_n
|
||||
b_scale_ptrs = (
|
||||
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
|
||||
)
|
||||
# channel-wise
|
||||
elif per_channel_quant:
|
||||
b_scale_ptrs = (
|
||||
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
|
||||
)
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
# Load per-token scale for activations
|
||||
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
|
||||
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
|
||||
# tensor-wise
|
||||
else:
|
||||
a_scale = tl.load(a_scale_ptr)
|
||||
b_scale = tl.load(b_scale_ptr + off_experts)
|
||||
if HAS_BIAS:
|
||||
# bias shape: [num_experts, N]
|
||||
bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
|
||||
bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
|
||||
# -----------------------------------------------------------
|
||||
# Iterate to compute a block of the C matrix.
|
||||
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
||||
# of fp32 values for higher accuracy.
|
||||
# `accumulator` will be converted back to fp16 after the loop.
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
# Load the next block of A and B, generate a mask by checking the
|
||||
# K dimension.
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
# We accumulate along the K dimension.
|
||||
if use_int8_w8a16:
|
||||
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
|
||||
elif use_fp8_w8a8 or use_int8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
a_scale = tl.load(
|
||||
a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
|
||||
)
|
||||
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
|
||||
|
||||
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
|
||||
else:
|
||||
if use_fp8_w8a8:
|
||||
# acc used to enable fp8_fast_accum
|
||||
accumulator = tl.dot(a, b, acc=accumulator)
|
||||
else:
|
||||
accumulator += tl.dot(a, b)
|
||||
else:
|
||||
accumulator += tl.dot(a, b)
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
if HAS_BIAS:
|
||||
accumulator = accumulator + bias[None, :]
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
if use_int8_w8a16:
|
||||
accumulator = (accumulator * b_scale).to(compute_type)
|
||||
elif use_fp8_w8a8 or use_int8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
else:
|
||||
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
|
||||
else:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Write back the block of the output
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
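# Hedged illustration of the token indexing used above: sorted_token_ids holds
# token slots repeated top_k times and grouped by expert, and dividing a slot by
# top_k recovers the row of A it reads from. The ordering below is illustrative.
import torch

top_k = 2
sorted_token_ids = torch.tensor([0, 3, 1, 2, 4, 5])
a_rows = sorted_token_ids // top_k  # -> [0, 1, 0, 1, 2, 2]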
|
||||
|
||||
|
||||
def invoke_fused_moe_kernel(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
A_scale: torch.Tensor | None,
|
||||
B_scale: torch.Tensor | None,
|
||||
B_zp: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool,
|
||||
top_k: int,
|
||||
config: dict[str, Any],
|
||||
compute_type: tl.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: list[int] | None = None,
|
||||
B_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or triton.cdiv(
|
||||
B.size(-2), block_shape[0]
|
||||
) == B_scale.size(-2)
|
||||
assert block_shape is None or triton.cdiv(
|
||||
B.size(-1), block_shape[1]
|
||||
) == B_scale.size(-1)
|
||||
|
||||
elif use_int8_w8a16 or use_int4_w4a16:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
|
||||
M = A.size(0)
|
||||
num_tokens = M * top_k
|
||||
|
||||
EM = sorted_token_ids.size(0)
|
||||
if A.size(0) < config["BLOCK_SIZE_M"]:
|
||||
# optimize for small batch_size.
|
||||
        # We assume that the topk ids of each token are unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
|
||||
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Split the program ID into two dimensions (pid_0, pid_1)
|
||||
'''
|
||||
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']), triton.cdiv(
|
||||
B.shape[1], META['BLOCK_SIZE_N']), )
|
||||
|
||||
assert not (use_int8_w8a16 or use_int4_w4a16)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
HAS_BIAS = B_bias is not None
|
||||
if (
|
||||
(use_int8_w8a16 or use_int4_w4a16)
|
||||
and block_shape is not None
|
||||
and block_shape[1] > 0
|
||||
):
|
||||
assert B_scale is not None and B_scale.ndim == 3
|
||||
assert B_zp is None or B_zp.ndim == 3
|
||||
|
||||
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
|
||||
num_valid_tokens=num_tokens,
|
||||
group_size=block_shape[1],
|
||||
num_experts=B.size(0),
|
||||
bit=4 if use_int4_w4a16 else 8,
|
||||
)
|
||||
config = config.copy()
|
||||
config.update(
|
||||
get_moe_wna16_block_config(
|
||||
config=config,
|
||||
use_moe_wna16_cuda=use_moe_wna16_cuda,
|
||||
num_valid_tokens=num_tokens,
|
||||
size_k=A.size(1),
|
||||
size_n=B.size(1),
|
||||
num_experts=B.size(1),
|
||||
group_size=block_shape[1],
|
||||
real_top_k=top_k,
|
||||
block_size_m=config["BLOCK_SIZE_M"],
|
||||
)
|
||||
)
|
||||
|
||||
if use_moe_wna16_cuda:
|
||||
bit = 4 if use_int4_w4a16 else 8
|
||||
ops.moe_wna16_gemm(
|
||||
A,
|
||||
C,
|
||||
B,
|
||||
B_scale,
|
||||
B_zp,
|
||||
topk_weights if mul_routed_weight else None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
top_k,
|
||||
config["BLOCK_SIZE_M"],
|
||||
config["BLOCK_SIZE_N"],
|
||||
config["BLOCK_SIZE_K"],
|
||||
bit,
|
||||
)
|
||||
return
|
||||
fused_moe_kernel_gptq_awq[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
B_scale,
|
||||
B_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.size(1),
|
||||
A.size(1),
|
||||
EM,
|
||||
num_tokens,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
B_scale.stride(0),
|
||||
B_scale.stride(2),
|
||||
B_scale.stride(1),
|
||||
B_zp.stride(0) if B_zp is not None else 0,
|
||||
B_zp.stride(2) if B_zp is not None else 0,
|
||||
B_zp.stride(1) if B_zp is not None else 0,
|
||||
block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
|
||||
group_size=block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
has_zp=B_zp is not None,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
**config,
|
||||
)
|
||||
else:
|
||||
config = config.copy()
|
||||
config["SPLIT_K"] = 1
|
||||
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
|
||||
if block_shape is not None:
|
||||
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
|
||||
fused_moe_kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
B_bias,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.size(1),
|
||||
B.size(2),
|
||||
EM,
|
||||
num_tokens,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
|
||||
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_bias.stride(0) if B_bias is not None else 0,
|
||||
B_bias.stride(1) if B_bias is not None else 0,
|
||||
0 if block_shape is None else block_shape[0],
|
||||
0 if block_shape is None else block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
HAS_BIAS=HAS_BIAS,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
**config,
|
||||
)
|
||||
|
||||
|
||||
def outplace_fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return fused_experts_impl(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
True,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
ocp_mx_scheme,
|
||||
per_channel_quant,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
w1_zp,
|
||||
w2_zp,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
block_shape,
|
||||
w1_bias,
|
||||
w2_bias,
|
||||
)
|
||||
|
||||
|
||||
def outplace_fused_experts_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="outplace_fused_experts_mlu",
|
||||
op_func=outplace_fused_experts,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=outplace_fused_experts_fake,
|
||||
dispatch_key="PrivateUse1",
|
||||
tags=(
|
||||
()
|
||||
if is_torch_equal_or_newer("2.7.0")
|
||||
else (torch.Tag.needs_fixed_stride_order,)
|
||||
),
|
||||
)
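# Hedged usage note: once registered, the MLU out-of-place expert path is
# reachable through the torch.ops namespace, which is what fused_experts below
# dispatches to (illustrative call, tensor arguments omitted):
#   torch.ops.vllm.outplace_fused_experts_mlu(hidden_states=..., w1=..., w2=..., ...)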
|
||||
|
||||
def fused_experts(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
allow_deep_gemm: bool = False) -> torch.Tensor:
|
||||
return torch.ops.vllm.outplace_fused_experts_mlu(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
SILU_NO_MUL: str = activation_without_mul("silu")
|
||||
GELU_NO_MUL: str = activation_without_mul("gelu")
|
||||
RELU2_NO_MUL: str = activation_without_mul("relu2")
|
||||
|
||||
|
||||
def fused_experts_impl(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Check constraints.
|
||||
if use_int4_w4a16:
|
||||
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
|
||||
elif ocp_mx_scheme is not None:
|
||||
if ocp_mx_scheme in {
|
||||
"w_mxfp4_a_mxfp4",
|
||||
"w_mxfp4_a_mxfp6_e3m2",
|
||||
"w_mxfp4_a_mxfp6_e2m3",
|
||||
}:
|
||||
# 16bit activation and fp4x2 packed weight
|
||||
assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
|
||||
elif ocp_mx_scheme in {
|
||||
"w_mxfp6_e3m2_a_mxfp6_e3m2",
|
||||
"w_mxfp6_e2m3_a_mxfp6_e2m3",
|
||||
}:
|
||||
assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
|
||||
"hidden size mismatch"
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
|
||||
else:
|
||||
assert hidden_states.size(1) == w1.size(2), (
|
||||
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
|
||||
)
|
||||
|
||||
assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
num_tokens = hidden_states.size(0)
|
||||
E, N, _ = w1.size()
|
||||
K = w2.size(1)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
top_k_num = topk_ids.size(1)
|
||||
# We execute the fused_moe kernel in chunks to circumvent this issue:
|
||||
# https://github.com/vllm-project/vllm/issues/5938
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
|
||||
config_dtype = _get_config_dtype_str(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
ocp_mx_scheme=ocp_mx_scheme,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
# Note: for use_int8_w8a16 or use_int4_w4a16, the activations are
|
||||
# quantized prior to calling fused_experts.
|
||||
quant_dtype = _get_config_quant_dtype(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
ocp_mx_scheme=ocp_mx_scheme,
|
||||
)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
w1.size(),
|
||||
w2.size(),
|
||||
top_k_num,
|
||||
config_dtype,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Only use the default config
|
||||
'''
|
||||
config = get_default_config(M, E, N, w1.shape[2], topk_ids.shape[1],
|
||||
hidden_states.dtype, block_shape)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
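    # Unlike upstream vllm, no tuned-config lookup is attempted here; the kernel block
    # sizes always come from get_default_config for every problem size.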
|
||||
|
||||
# We can reuse the memory between these because by the time we need
|
||||
# cache3, we're done with cache1
|
||||
cache13 = torch.empty(
|
||||
M * top_k_num * max(N, K),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
|
||||
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
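    # intermediate_cache1 and intermediate_cache3 are overlapping views of cache13:
    # cache1 (the w1 GEMM output) is fully consumed by the activation before cache3
    # (the w2 GEMM output) is written, so the reuse is safe.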
|
||||
|
||||
# This needs separate memory since it's used concurrently with cache1
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
if hidden_states.dtype == torch.bfloat16:
|
||||
compute_type = tl.bfloat16
|
||||
elif hidden_states.dtype == torch.float16:
|
||||
compute_type = tl.float16
|
||||
elif hidden_states.dtype == torch.float32:
|
||||
compute_type = tl.float32
|
||||
else:
|
||||
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
if inplace and not disable_inplace():
|
||||
out_hidden_states = hidden_states
|
||||
else:
|
||||
out_hidden_states = torch.empty_like(hidden_states)
|
||||
|
||||
if ocp_mx_scheme is not None:
|
||||
# TODO: On platforms for which `current_platform.supports_mx()` is True
|
||||
# and for which we have a native OCP mx fused MOE kernel,
|
||||
# this dequantization step should not be done.
|
||||
if ocp_mx_scheme in {
|
||||
OCP_MX_Scheme.w_mxfp4_a_mxfp4,
|
||||
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
|
||||
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
|
||||
}:
|
||||
# Weight has to be dequantized for mxfp4 emulation.
|
||||
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
|
||||
w1_scale = None
|
||||
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
|
||||
w2_scale = None
|
||||
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
|
||||
w1 = dequant_mxfp6(
|
||||
w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
|
||||
)
|
||||
w1_scale = None
|
||||
w2 = dequant_mxfp6(
|
||||
w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
|
||||
)
|
||||
w2_scale = None
|
||||
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
|
||||
w1 = dequant_mxfp6(
|
||||
w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
|
||||
)
|
||||
w1_scale = None
|
||||
w2 = dequant_mxfp6(
|
||||
w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
|
||||
)
|
||||
w2_scale = None
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
|
||||
|
||||
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
|
||||
begin_chunk_idx, end_chunk_idx = (
|
||||
chunk * CHUNK_SIZE,
|
||||
min((chunk + 1) * CHUNK_SIZE, num_tokens),
|
||||
)
|
||||
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||
tokens_in_chunk, _ = curr_hidden_states.size()
|
||||
|
||||
if tokens_in_chunk == 0:
|
||||
break
|
||||
|
||||
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
|
||||
# Adjust the intermediate cache size and config for the last
|
||||
# chunk. Note that in most cases we only have one chunk
|
||||
# so the cache size and config are already set correctly and
|
||||
# do not need to be adjusted.
|
||||
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
|
||||
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
|
||||
|
||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||
|
||||
a1q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
if use_fp8_w8a8:
|
||||
qcurr_hidden_states, a1q_scale = _fp8_quantize(
|
||||
curr_hidden_states, a1_scale, block_shape)
|
||||
else:
|
||||
qcurr_hidden_states = curr_hidden_states
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
|
||||
)
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
qcurr_hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1q_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
curr_topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
apply_router_weight_on_input,
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
B_bias=w1_bias,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Activate by mlu_ops
|
||||
'''
|
||||
intermediate_cache2 = mlu_ops.active(intermediate_cache1.view(-1, N),
|
||||
act_mode=activation,
|
||||
is_gated=True)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
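        # mlu_ops.active with is_gated=True is expected to fuse the gate/up split with
        # the activation, producing the (rows, N // 2) gated-activation output that
        # replaces the pre-allocated intermediate_cache2.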
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
if use_fp8_w8a8:
|
||||
qintermediate_cache2, a2q_scale = _fp8_quantize(
|
||||
intermediate_cache2, a2_scale, block_shape)
|
||||
else:
|
||||
qintermediate_cache2 = intermediate_cache2
|
||||
invoke_fused_moe_kernel(
|
||||
qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
curr_topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
not apply_router_weight_on_input,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
B_bias=w2_bias,
|
||||
)
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace moe_sum with torch.sum
|
||||
Reference Links: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py#L1513
|
||||
'''
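        # top_k == 2 is reduced with a single fused add that writes directly into the
        # output slice; larger top_k values fall back to torch.sum over the expert dim.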
|
||||
if topk_ids.shape[1] == 2:
|
||||
torch.add(
|
||||
intermediate_cache3[:, 0],
|
||||
intermediate_cache3[:, 1],
|
||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
).squeeze(dim=1)
|
||||
elif topk_ids.shape[1] > 2:
|
||||
torch.sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1,
|
||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return out_hidden_states
|
||||
|
||||
106
vllm_mlu/model_executor/layers/fused_moe/layer.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional, Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
|
||||
|
||||
def vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
    # TODO: support `routed_scaling_factor`
|
||||
assert routed_scaling_factor == 1.0, (
|
||||
f"routed_scaling_factor {routed_scaling_factor} is not supported for MLU."
|
||||
)
|
||||
use_fused_kernel = topk_group is None
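    # Grouped top-k routing (topk_group / num_expert_group) is not handled by the fused
    # MLU kernel, so those cases fall back to select_experts + fused_experts below.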
|
||||
if use_fused_kernel:
|
||||
        assert not enable_eplb, "EPLB is not supported by the MLU fused_moe kernel."
|
||||
        assert use_grouped_topk is False and num_expert_group is None and topk_group is None, \
            "use_grouped_topk, num_expert_group and topk_group are not supported yet."
|
||||
return mlu_ops.fused_moe(
|
||||
x,
|
||||
router_logits,
|
||||
layer.w13_weight, layer.w2_weight,
|
||||
None, None, # bias1, bias2
|
||||
None, # residual
|
||||
None, # input_smooth
|
||||
None, # act_smooth
|
||||
None, None, # w1_scale, w2_scale
|
||||
top_k,
|
||||
renormalize,
|
||||
True, # gated
|
||||
activation
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
assert expert_map is None
|
||||
return self.rocm_aiter_fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
else:
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
UnquantizedFusedMoEMethod,
|
||||
UnquantizedFusedMoEMethod.forward_oot,
|
||||
vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot
|
||||
)
|
||||
248
vllm_mlu/model_executor/layers/fused_moe/moe_align_block_size.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Implementation of moe_align_block_size_triton.
|
||||
Note: this implementation was removed from vllm because the
CUDA implementation is more efficient.
|
||||
'''
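# The four Triton stages below reproduce the removed reference kernel:
#   stage1: per-program histogram of how many tokens each expert receives
#   stage2: per-expert running sum of those histograms across programs
#   stage3: block_size-padded cumulative offsets and the total padded token count
#   stage4: scatter of token indices into their expert-sorted, padded slots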
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage1(
|
||||
topk_ids_ptr,
|
||||
tokens_cnts_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
numel: tl.constexpr,
|
||||
tokens_per_thread: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
start_idx = pid * tokens_per_thread
|
||||
|
||||
off_c = (pid + 1) * num_experts
|
||||
|
||||
for i in range(tokens_per_thread):
|
||||
if start_idx + i < numel:
|
||||
idx = tl.load(topk_ids_ptr + start_idx + i)
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
|
||||
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage2(
|
||||
tokens_cnts_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
last_cnt = 0
|
||||
for i in range(1, num_experts + 1):
|
||||
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
|
||||
last_cnt = last_cnt + token_cnt
|
||||
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage3(
|
||||
total_tokens_post_pad_ptr,
|
||||
tokens_cnts_ptr,
|
||||
cumsum_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
):
|
||||
last_cumsum = 0
|
||||
off_cnt = num_experts * num_experts
|
||||
for i in range(1, num_experts + 1):
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
|
||||
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
|
||||
tl.store(cumsum_ptr + i, last_cumsum)
|
||||
tl.store(total_tokens_post_pad_ptr, last_cumsum)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage4(
|
||||
topk_ids_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
tokens_cnts_ptr,
|
||||
cumsum_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
numel: tl.constexpr,
|
||||
tokens_per_thread: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
start_idx = tl.load(cumsum_ptr + pid)
|
||||
end_idx = tl.load(cumsum_ptr + pid + 1)
|
||||
|
||||
for i in range(start_idx, end_idx, block_size):
|
||||
tl.store(expert_ids_ptr + i // block_size, pid)
|
||||
|
||||
start_idx = pid * tokens_per_thread
|
||||
off_t = pid * num_experts
|
||||
|
||||
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
|
||||
numel)):
|
||||
expert_id = tl.load(topk_ids_ptr + i)
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
|
||||
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
|
||||
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
|
||||
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
|
||||
|
||||
|
||||
# Triton implementation based on:
|
||||
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
|
||||
def moe_align_block_size_triton(
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
block_size: int,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_pad: torch.Tensor,
|
||||
) -> None:
|
||||
numel = topk_ids.numel()
|
||||
grid = (num_experts, )
|
||||
tokens_cnts = torch.zeros((num_experts + 1, num_experts),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
cumsum = torch.zeros((num_experts + 1, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
tokens_per_thread = cdiv(numel, num_experts)
|
||||
sorted_token_ids.fill_(numel)
|
||||
expert_ids.zero_()
|
||||
|
||||
moe_align_block_size_stage1[grid](
|
||||
topk_ids,
|
||||
tokens_cnts,
|
||||
num_experts,
|
||||
numel,
|
||||
tokens_per_thread,
|
||||
)
|
||||
moe_align_block_size_stage2[grid](
|
||||
tokens_cnts,
|
||||
num_experts,
|
||||
)
|
||||
moe_align_block_size_stage3[(1, )](
|
||||
num_tokens_post_pad,
|
||||
tokens_cnts,
|
||||
cumsum,
|
||||
num_experts,
|
||||
block_size,
|
||||
)
|
||||
moe_align_block_size_stage4[grid](
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
tokens_cnts,
|
||||
cumsum,
|
||||
num_experts,
|
||||
block_size,
|
||||
numel,
|
||||
tokens_per_thread,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def moe_align_block_size(
|
||||
topk_ids: torch.Tensor,
|
||||
block_size: int,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
pad_sorted_ids: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Aligns the token distribution across experts to be compatible with block
|
||||
size for matrix multiplication.
|
||||
|
||||
Parameters:
|
||||
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
|
||||
top-k expert indices for each token.
|
||||
- block_size: The block size used in block matrix multiplication.
|
||||
- num_experts: The total number of experts.
|
||||
- expert_map: A tensor of shape [num_experts] that maps the expert index
|
||||
from the global space to the local index space of the current
|
||||
expert parallel shard. If the expert is not in the current expert
|
||||
parallel shard, the mapping is set to -1.
|
||||
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length
|
||||
        should be padded to a multiple of block_size.
|
||||
|
||||
Returns:
|
||||
- sorted_token_ids: A tensor containing the sorted token indices according
|
||||
to their allocated expert.
|
||||
- expert_ids: A tensor indicating the assigned expert index for each block.
|
||||
- num_tokens_post_padded: The total number of tokens after padding,
|
||||
ensuring divisibility by block_size.
|
||||
|
||||
This function pads the number of tokens that each expert needs to process
|
||||
so that it is divisible by block_size.
|
||||
Padding ensures that during block matrix multiplication, the dimensions
|
||||
align correctly.
|
||||
|
||||
Example:
|
||||
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
|
||||
block_size = 4, and num_experts = 4:
|
||||
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
|
||||
with each expert needing to process 3 tokens.
|
||||
- As block_size is 4, we pad 1 token for each expert.
|
||||
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
|
||||
- Then append padding tokens [12, 12, 12, 12] for each block.
|
||||
- After sorting by expert index, we obtain token_ids
|
||||
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
|
||||
Tokens 12 are non-existent (padding) and are ignored in
|
||||
the subsequent matrix multiplication.
|
||||
- The padding ensures that the total number of tokens is now divisible
|
||||
by block_size for proper block matrix operations.
|
||||
"""
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
if pad_sorted_ids:
|
||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||
sorted_ids = torch.empty((max_num_tokens_padded, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||
# Expert ids must be zeroed out to prevent index out of bounds error while
|
||||
# mapping global expert ids to local expert ids in expert parallelism.
|
||||
expert_ids = torch.zeros((max_num_m_blocks, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
num_tokens_post_pad = torch.empty((1),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Only use triton to implement moe_align_block_size
|
||||
'''
|
||||
moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if expert_map is not None:
|
||||
expert_ids = expert_map[expert_ids]
|
||||
|
||||
return sorted_ids, expert_ids, num_tokens_post_pad
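# A minimal usage sketch (hypothetical tensors; values only illustrate the docstring
# example above and are not asserted):
#
#   sorted_ids, expert_ids, num_post_pad = moe_align_block_size(
#       topk_ids, block_size=4, num_experts=num_experts)
#
# Every consecutive group of block_size entries in sorted_ids is handled by the expert
# in the matching position of expert_ids; entries equal to topk_ids.numel() are padding
# and are skipped by the fused kernel.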
|
||||
31
vllm_mlu/model_executor/layers/fused_moe/utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from math import prod
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
|
||||
def _fp8_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
block_shape: List[int],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Perform fp8 quantization on the inputs. If a block_shape
|
||||
is provided, the output will be blocked.
|
||||
"""
|
||||
from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
assert block_shape is not None
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
return A, A_scale
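# Shape sketch (assuming block_shape == [128, 128], a hypothetical but typical value):
# quantizing an activation A of shape (tokens, 7168) returns an fp8 tensor of the same
# shape plus a per-group scale of shape (tokens, 7168 // 128) == (tokens, 56).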
|
||||
|
||||
|
||||
278
vllm_mlu/model_executor/layers/indexer.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.model_executor.layers.compressor import (
|
||||
Compressor,
|
||||
rotate_activation,
|
||||
)
|
||||
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Indexer(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
rope,
|
||||
compress_ratio: int = 4,
|
||||
prefix: str = "",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.dim = config.dim
|
||||
self.n_heads = config.index_n_heads
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.n_local_heads = config.index_n_heads // self.tp_size
|
||||
self.head_dim = config.index_head_dim
|
||||
self.rope_head_dim = config.rope_head_dim
|
||||
self.index_topk = config.index_topk
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.window_size = config.window_size
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
self.wq_b = ReplicatedLinear(
|
||||
self.q_lora_rank,
|
||||
self.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.wq_b",
|
||||
)
|
||||
|
||||
self.weights_proj = ReplicatedLinear(
|
||||
self.dim,
|
||||
self.n_heads,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
            params_dtype=torch.bfloat16,
|
||||
prefix=f"{prefix}.weights_proj",
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_dim ** -0.5
|
||||
self.merged_softmax_scale = (self.head_dim ** -0.5) * (self.n_heads ** -0.5)
|
||||
self.compress_ratio = compress_ratio
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
self.rotary_emb = rope
|
||||
self.tp_group = get_tp_group()
|
||||
|
||||
self.compressor = Compressor(vllm_config, self.rotary_emb, compress_ratio, self.head_dim, True, f"{prefix}.compressor")
|
||||
|
||||
self.freqs_cis = None
|
||||
|
||||
def forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
k_full: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
):
|
||||
assert attn_metadata.prefill.chunked_context is None, \
|
||||
f"Prefill chunked context is not supported."
|
||||
|
||||
query_start_loc = attn_metadata.prefill.query_start_loc
|
||||
cu_seq_q_lens = query_start_loc
|
||||
cu_seq_k_lens = torch.zeros(
|
||||
context_lens.size(0) + 1, dtype=torch.int32, device=q.device,
|
||||
)
|
||||
torch.cumsum(context_lens, dim=0, out=cu_seq_k_lens[1:])
|
||||
seq_lens = torch.diff(cu_seq_k_lens)
|
||||
|
||||
batch_size = seq_lens.shape[0]
|
||||
|
||||
new_block_tables = torch.empty(
|
||||
[attn_metadata.num_prefill_tokens, self.index_topk],
|
||||
dtype=torch.int32,
|
||||
device=q.device,
|
||||
)
|
||||
new_context_lens = torch.empty(
|
||||
[attn_metadata.num_prefill_tokens],
|
||||
dtype=torch.int32,
|
||||
device=q.device,
|
||||
)
|
||||
|
||||
        q_seq_lens = cu_seq_q_lens[1:] - cu_seq_q_lens[:-1]
|
||||
max_seq_len = q_seq_lens.max().item()
|
||||
batch_size = q_seq_lens.size(0)
|
||||
max_compressed_kv_len = max_seq_len // self.compress_ratio
|
||||
kv_cache_block_table = torch.zeros([batch_size, max_compressed_kv_len], dtype=torch.int32, device=q.device)
|
||||
|
||||
# The layout of linear kv is as follows:
|
||||
# | bs0_origin_kv | bs1_origin_kv | bs0_compressed_kv | bs1_compressed_kv |
|
||||
for i in range(batch_size):
|
||||
start = cu_seq_k_lens[i].item()
|
||||
kv_cache_block_table[i] = torch.arange(
|
||||
start, start + max_compressed_kv_len,
|
||||
dtype=torch.int32,
|
||||
device=q.device,
|
||||
)
|
||||
# offset total origin_kv len
|
||||
kv_cache_block_table = kv_cache_block_table + cu_seq_q_lens[-1]
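        # After this shift, row i of kv_cache_block_table addresses the compressed-kv
        # slots of request i, which start right after all origin-kv tokens in the
        # linear layout described above.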
|
||||
|
||||
# query: (tokens, index_head, index_head_dim)
|
||||
# k_full: (tokens, index_head_dim)
|
||||
# weights: (tokens, index_head, 1)
|
||||
mlu_ops.masked_indexer_select_paged_kv_prefill(
|
||||
query=q,
|
||||
key_value=k_full,
|
||||
weights=weights.unsqueeze(-1),
|
||||
kv_cache_block_table=kv_cache_block_table,
|
||||
cu_seq_q_lens=cu_seq_q_lens,
|
||||
cu_seq_k_lens=cu_seq_k_lens,
|
||||
index_topk=self.index_topk,
|
||||
kv_cache_block_size=self.block_size,
|
||||
softmax_scale=self.merged_softmax_scale,
|
||||
q_scale=None,
|
||||
k_scale_cache=None,
|
||||
sparse_block_table=new_block_tables,
|
||||
sparse_context_lens=new_context_lens,
|
||||
compress_ratio=self.compress_ratio,
|
||||
kv_cache_block_table_offset=None,
|
||||
)
|
||||
|
||||
return new_block_tables, new_context_lens
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
block_table = attn_metadata.decode.block_table
|
||||
batch_size = block_table.shape[0]
|
||||
seq_len = x.shape[0] // batch_size
|
||||
q = q.view(batch_size, seq_len, *q.shape[1:])
|
||||
weights = weights.view(batch_size, seq_len, *weights.shape[1:])
|
||||
|
||||
seq_lens = attn_metadata.decode.seq_lens
|
||||
k_block_table = block_table
|
||||
|
||||
seq_len = x.shape[0] // batch_size
|
||||
new_block_tables = torch.empty(
|
||||
[batch_size, seq_len, self.index_topk],
|
||||
dtype=torch.int32,
|
||||
device=block_table.device,
|
||||
)
|
||||
new_context_lens = torch.empty(
|
||||
[attn_metadata.num_decode_tokens],
|
||||
dtype=torch.int32,
|
||||
device=block_table.device,
|
||||
)
|
||||
|
||||
kv_cache_block_table_offset=torch.empty(
|
||||
[attn_metadata.num_decode_tokens],
|
||||
dtype=torch.int32,
|
||||
device=block_table.device,
|
||||
)
|
||||
kv_cache_block_table_offset.fill_(self.window_size)
|
||||
mlu_ops.masked_indexer_select_paged_kv_decode(
|
||||
query=q,
|
||||
k_cache=k_cache,
|
||||
weights=weights.unsqueeze(-1), # (bsz, seq_q, head_num, 1)
|
||||
kv_cache_block_table=block_table,
|
||||
k_context_lens=seq_lens // self.compress_ratio,
|
||||
k_cache_block_table=k_block_table,
|
||||
index_topk=self.index_topk,
|
||||
kv_cache_block_size=self.block_size,
|
||||
softmax_scale=self.merged_softmax_scale,
|
||||
q_scale=None,
|
||||
k_scale_cache=None,
|
||||
sparse_block_table=new_block_tables,
|
||||
sparse_context_lens=new_context_lens,
|
||||
compress_ratio=self.compress_ratio,
|
||||
kv_cache_block_table_offset=kv_cache_block_table_offset,
|
||||
)
|
||||
|
||||
# [batch, seq_q, index_topk] -> [batch, index_topk]
|
||||
new_block_tables = new_block_tables.squeeze(1)
|
||||
return new_block_tables, new_context_lens
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
qr: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
offsets: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
batch_to_kv_state: torch.Tensor,
|
||||
indexer_kv_cache: torch.Tensor,
|
||||
compressor_slot_mapping: torch.Tensor,
|
||||
):
|
||||
common_metadata = get_common_metadata()
|
||||
query_start_loc = common_metadata.query_start_loc
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
|
||||
rd = self.rope_head_dim
|
||||
q = self.wq_b(qr)[0]
|
||||
q = q.unflatten(-1, (self.n_heads, self.head_dim))
|
||||
|
||||
self.rotary_emb(positions, q[..., -rd:], None, only_prefill=False)
|
||||
q_pack = rotate_activation(q)
|
||||
weights_pack = self.weights_proj(x)[0] # (tokens, index_local_head)
|
||||
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
compressed_kv = self.compressor(
|
||||
x,
|
||||
positions,
|
||||
attn_metadata,
|
||||
batch_to_kv_state,
|
||||
indexer_kv_cache,
|
||||
0,
|
||||
compressor_slot_mapping,
|
||||
)
|
||||
if attn_metadata.prefill:
|
||||
assert compressed_kv is not None and compressed_kv.dim() == 3
|
||||
compressed_kv = compressed_kv.squeeze(-2)
|
||||
compressed_context_lens = query_lens // self.compress_ratio
|
||||
|
||||
prefill_q = q_pack[num_decode_tokens:, ...]
|
||||
prefill_weights = weights_pack[num_decode_tokens:, ...]
|
||||
prefill_block_tables, prefill_context_lens = self.forward_prefill(
|
||||
prefill_q,
|
||||
indexer_kv_cache,
|
||||
prefill_weights,
|
||||
attn_metadata,
|
||||
compressed_kv,
|
||||
compressed_context_lens,
|
||||
)
|
||||
|
||||
if attn_metadata.decode:
|
||||
decode_x = x[:num_decode_tokens, ...]
|
||||
decode_q = q_pack[:num_decode_tokens, ...]
|
||||
decode_weights = weights_pack[attn_metadata.num_prefills:]
|
||||
decode_block_tables, decode_context_lens = self.forward_decode(
|
||||
decode_q,
|
||||
decode_x,
|
||||
indexer_kv_cache,
|
||||
decode_weights,
|
||||
attn_metadata,
|
||||
)
|
||||
|
||||
if attn_metadata.prefill and attn_metadata.decode:
|
||||
new_block_tables = torch.cat([prefill_block_tables, decode_block_tables], dim=0)
|
||||
new_context_lens = torch.cat([prefill_context_lens, decode_context_lens], dim=0)
|
||||
elif attn_metadata.prefill:
|
||||
new_block_tables = prefill_block_tables
|
||||
new_context_lens = prefill_context_lens
|
||||
else:
|
||||
new_block_tables = decode_block_tables
|
||||
new_context_lens = decode_context_lens
|
||||
return new_block_tables, new_context_lens
|
||||
130
vllm_mlu/model_executor/layers/layernorm.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.model_executor.models.layer_utils import is_per_token_smoothquant
|
||||
|
||||
|
||||
@CustomOp.register("quant_fusion_rms_norm")
|
||||
class QuantFusionRMSNorm(RMSNorm):
|
||||
def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
|
||||
super().__init__(hidden_size, variance_epsilon)
|
||||
assert not isinstance(
|
||||
proj.quant_method, UnquantizedLinearMethod
|
||||
), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
|
||||
proj.quant_method.skip_quant_input = True
|
||||
if dynamic_quant := is_per_token_smoothquant(proj.quant_method.quant_config):
|
||||
quant_scale = proj.smooth.data
|
||||
else:
|
||||
quant_scale = proj.scale_to_int.data
|
||||
self.dynamic_quant = dynamic_quant
|
||||
self.quant_scale = torch.nn.Parameter(quant_scale)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, residual: torch.Tensor | None = None
|
||||
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
|
||||
return mlu_ops.fused_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
None,
|
||||
None,
|
||||
self.variance_epsilon,
|
||||
False,
|
||||
self.quant_scale.data,
|
||||
self.dynamic_quant,
|
||||
)
|
||||
|
||||
|
||||
@CustomOp.register("quant_fusion_layer_norm")
|
||||
class QuantFusionLayerNorm(torch.nn.LayerNorm, CustomOp):
|
||||
def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
|
||||
super().__init__(hidden_size, variance_epsilon)
|
||||
assert not isinstance(
|
||||
proj.quant_method, UnquantizedLinearMethod
|
||||
), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
|
||||
proj.quant_method.skip_quant_input = True
|
||||
if dynamic_quant := is_per_token_smoothquant(proj.quant_method.quant_config):
|
||||
quant_scale = proj.smooth.data
|
||||
else:
|
||||
quant_scale = proj.scale_to_int.data
|
||||
self.dynamic_quant = dynamic_quant
|
||||
self.quant_scale = torch.nn.Parameter(quant_scale)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, residual: torch.Tensor | None = None
|
||||
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
|
||||
bias = None if self.bias is None else self.bias.data
|
||||
return mlu_ops.fused_layer_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
bias,
|
||||
None,
|
||||
self.eps,
|
||||
False,
|
||||
self.quant_scale.data,
|
||||
self.dynamic_quant,
|
||||
)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__layernorm__RMSNorm__forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
out: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
|
||||
|
||||
org_shape = x.shape
|
||||
x = x.reshape(-1, self.weight.data.shape[0])
|
||||
if out is not None:
|
||||
out = out.view(-1, self.weight.data.shape[0])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, self.weight.data.shape[0])
|
||||
x = mlu_ops.fused_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
None,
|
||||
None,
|
||||
self.variance_epsilon,
|
||||
True,
|
||||
out=out,
|
||||
)
|
||||
else:
|
||||
x = mlu_ops.fused_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
None,
|
||||
None,
|
||||
self.variance_epsilon,
|
||||
False,
|
||||
out=out,
|
||||
)
|
||||
|
||||
if out is not None:
|
||||
return x
|
||||
|
||||
if residual is None:
|
||||
assert isinstance(x, torch.Tensor)
|
||||
return x.view(org_shape)
|
||||
|
||||
assert isinstance(x, tuple)
|
||||
assert len(x) == 2
|
||||
return x[0].view(org_shape), x[1].view(org_shape)
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
RMSNorm,
|
||||
RMSNorm.forward_oot,
|
||||
vllm__model_executor__layers__layernorm__RMSNorm__forward_oot,
|
||||
)
|
||||
693
vllm_mlu/model_executor/layers/linear.py
Normal file
@@ -0,0 +1,693 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional, Any
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.distributed import (divide, split_tensor_along_last_dim,
|
||||
get_parallel_rank_with_group, get_parallel_world_size_with_group,
|
||||
get_tp_world_group, get_tp_world_world_size, get_tp_world_rank)
|
||||
from vllm.distributed.communication_op import (
|
||||
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.linear import (
|
||||
WEIGHT_LOADER_V2_SUPPORTED, UnquantizedLinearMethod, LinearBase,
|
||||
ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
WEIGHT_LOADER_V2_SUPPORTED.extend([
|
||||
"GPTQMluLinearMethod",
|
||||
"AWQMluLinearMethod"
|
||||
])
|
||||
|
||||
vllm__module_executor__layers__linear__LinearBase____init__org = LinearBase.__init__
|
||||
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org = MergedColumnParallelLinear.weight_loader
|
||||
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org = RowParallelLinear.weight_loader
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual parameter.
|
||||
@brief: dispatch unquantized_gemm to mlu ops.
|
||||
'''
|
||||
def vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
residual: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
beta = 0.0
|
||||
if residual is not None:
|
||||
beta = 1.0
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
|
||||
return mlu_ops.matmul(x.reshape(x.numel() // x.shape[-1], x.shape[-1]),
|
||||
layer.weight,
|
||||
bias, residual, 'none', 1.0, beta).view(res_shape)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
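# When a residual is passed, beta is 1.0 and the MLU matmul is expected to fuse the
# residual add into the GEMM epilogue, i.e. roughly x @ weight.T + bias + residual
# in a single kernel launch.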
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group and keep_full_weights parameters.
|
||||
'''
|
||||
def vllm__module_executor__layers__linear__LinearBase____init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
tp_group: Any = None,
|
||||
keep_full_weights: bool = False,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
vllm__module_executor__layers__linear__LinearBase____init__org(
|
||||
self=self,
|
||||
input_size=input_size,
|
||||
output_size=output_size,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add self.tp_group, world_size and tp_rank to support data parallel and moe expert parallel
|
||||
'''
|
||||
self.tp_group = tp_group
|
||||
self.tp_world_size = get_parallel_world_size_with_group(self.tp_group)
|
||||
self.tp_size = self.tp_world_size
|
||||
self.tp_rank = get_parallel_rank_with_group(self.tp_group)
|
||||
|
||||
self.keep_full_weights = keep_full_weights
|
||||
if self.keep_full_weights or disable_tp:
|
||||
self.tp_group = None
|
||||
self.tp_world_size = 1
|
||||
self.tp_size = self.tp_world_size
|
||||
self.tp_rank = 0
|
||||
self.tp_world_size_org = get_tp_world_world_size()
|
||||
self.tp_rank_org = get_tp_world_rank()
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group and keep_full_weights parameters.
|
||||
'''
|
||||
def vllm__module_executor__layers__linear__ColumnParallelLinear____init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
output_sizes: list[int] | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
tp_group: Any = None,
|
||||
keep_full_weights: bool = False,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
super(ColumnParallelLinear, self).__init__(
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
tp_group=tp_group,
|
||||
keep_full_weights=keep_full_weights,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: self.tp_size and self.tp_rank have already been initialized in LinearBase.__init__
|
||||
'''
|
||||
# Divide the weight matrix along the last dimension.
|
||||
# self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
|
||||
# self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
self.input_size_per_partition = input_size
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
self.output_partition_sizes = [self.output_size_per_partition]
|
||||
# If QKV or MergedColumn, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = [
|
||||
divide(output_size, self.tp_size) for output_size in self.output_sizes
|
||||
]
|
||||
|
||||
self.gather_output = gather_output
|
||||
|
||||
if output_sizes is None:
|
||||
output_sizes = [output_size]
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group in create_weights
|
||||
'''
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2
|
||||
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
|
||||
else self.weight_loader
|
||||
),
|
||||
tp_group=self.tp_group,
|
||||
)
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition, dtype=params_dtype)
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.bias,
|
||||
{
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
self.update_param_tp_status()
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add smooth_quant_scale and use_tp_weight parameters.
|
||||
'''
|
||||
def vllm__module_executor__layers__linear__ColumnParallelLinear__forward(
|
||||
self,
|
||||
input_,
|
||||
smooth_quant_scale: torch.Tensor | None = None,
|
||||
use_tp_weight: bool = False,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add input_scale and use_tp_weight parameter.
|
||||
'''
|
||||
kwargs = {'bias': bias}
|
||||
if use_tp_weight:
|
||||
kwargs['use_tp_weight'] = use_tp_weight
|
||||
if smooth_quant_scale is not None:
|
||||
kwargs['input_scale'] = smooth_quant_scale
|
||||
output_parallel = self.quant_method.apply(self, input_, **kwargs)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if self.gather_output and self.tp_size > 1:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group param to tensor_model_parallel_all_gather
|
||||
'''
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel, dim=-1, tp_group=self.tp_group)
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group and keep_full_weights parameters.
|
||||
'''
|
||||
def vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: list[int],
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
tp_group: Any = None,
|
||||
keep_full_weights: bool = False,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: check output_sizes after init so self.tp_world_size is available
|
||||
@brief: add keep_full_weights for dp parallelize shared expert
|
||||
'''
|
||||
super(MergedColumnParallelLinear, self).__init__(
|
||||
input_size=input_size,
|
||||
output_size=sum(output_sizes),
|
||||
bias=bias,
|
||||
gather_output=gather_output,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
output_sizes=self.output_sizes,
|
||||
prefix=prefix,
|
||||
tp_group=tp_group,
|
||||
keep_full_weights=keep_full_weights,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp,
|
||||
)
|
||||
|
||||
assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
|
||||
|
||||
if self.keep_full_weights:
|
||||
tp_size = self.tp_world_size_org
|
||||
if isinstance(self.quant_method, UnquantizedLinearMethod):
|
||||
out_dim, in_dim = self.weight.shape
|
||||
out_dim_tp = divide(out_dim, tp_size)
|
||||
self.tp_weight = Parameter(
|
||||
self.weight.data.new_empty((out_dim_tp, in_dim)),
|
||||
requires_grad=False,
|
||||
)
|
||||
elif (isinstance(self.quant_method, SmoothQuantLinearMethod)
|
||||
and quant_config.input_quant_method == "per_token"):
|
||||
out_dim, in_dim = self.qweight.shape
|
||||
out_dim_tp = divide(out_dim, tp_size)
|
||||
self.tp_qweight = Parameter(
|
||||
self.qweight.data.new_empty((out_dim_tp, in_dim)),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.tp_per_channel_scale = Parameter(
|
||||
self.per_channel_scale.data.new_empty((out_dim_tp)),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"quant method is expected to be unquantized or smoothquant per-token")
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader(
|
||||
self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: int | None = None,
|
||||
):
|
||||
loaded_weight_orig = loaded_weight
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
|
||||
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org(
|
||||
self=self,
|
||||
param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
loaded_shard_id=loaded_shard_id,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add keep_full_weights for dp parallelize shared expert
|
||||
'''
|
||||
# load into tp weight
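    # keep_full_weights layers load the replicated full tensor above and additionally
    # copy the shard this rank would own under the original TP size into tp_weight
    # (or tp_qweight / tp_per_channel_scale), so the sharded GEMM path stays usable.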
|
||||
if self.keep_full_weights:
|
||||
tp_size = self.tp_world_size_org
|
||||
tp_rank = self.tp_rank_org
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
||||
start_idx = tp_rank * shard_size
|
||||
if isinstance(self.quant_method, UnquantizedLinearMethod):
|
||||
tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size)
|
||||
tp_weight_shard = self.tp_weight.narrow(output_dim, shard_offset, shard_size)
|
||||
tp_weight_shard.copy_(tp_weight)
|
||||
elif isinstance(self.quant_method, SmoothQuantLinearMethod):
|
||||
if output_dim is None:
|
||||
return
|
||||
tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size)
|
||||
if loaded_weight_orig.ndim == 1:
|
||||
tp_weight_shard = self.tp_per_channel_scale.narrow(output_dim, shard_offset, shard_size)
|
||||
elif loaded_weight_orig.ndim == 2:
|
||||
tp_weight_shard = self.tp_qweight.narrow(output_dim, shard_offset, shard_size)
|
||||
else:
|
||||
raise ValueError("only support rank 1 and 2 when using tp_weight")
|
||||
|
||||
tp_weight_shard.copy_(tp_weight)
|
||||
else:
|
||||
raise TypeError(f"quant method is expected to be either unquantized or smoothquant")
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__layers__linear__RowParallelLinear____init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
tp_group: Any = None,
|
||||
keep_full_weights: bool = False,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
super(RowParallelLinear, self).__init__(
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix=prefix,
|
||||
tp_group=tp_group,
|
||||
keep_full_weights=keep_full_weights,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp,
|
||||
)
|
||||
|
||||
# Divide the weight matrix along the last dimension
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.output_size_per_partition = output_size
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
assert self.quant_method is not None
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group in create_weights
|
||||
'''
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2
|
||||
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
|
||||
else self.weight_loader
|
||||
),
|
||||
tp_group=self.tp_group,
|
||||
)
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
|
||||
set_weight_attrs(
|
||||
self.bias,
|
||||
{
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add keep_full_weights for dp parallelize shared expert
|
||||
'''
|
||||
if self.keep_full_weights:
|
||||
tp_size = self.tp_world_size_org
|
||||
if isinstance(self.quant_method, UnquantizedLinearMethod):
|
||||
out_dim, in_dim = self.weight.data.shape
|
||||
in_dim_tp = divide(in_dim, tp_size)
|
||||
self.tp_weight = Parameter(self.weight.data.new_empty((out_dim, in_dim_tp)),
|
||||
requires_grad=False)
|
||||
elif (isinstance(self.quant_method, SmoothQuantLinearMethod)
|
||||
and quant_config.input_quant_method == "per_token"):
|
||||
out_dim, in_dim = self.qweight.data.shape
|
||||
in_dim_tp = divide(in_dim, tp_size)
|
||||
self.tp_qweight = Parameter(self.qweight.data.new_empty((out_dim, in_dim_tp)),
|
||||
requires_grad=False)
|
||||
if hasattr(self, "smooth"):
|
||||
assert len(self.smooth.shape) == 1, "smooth should be a 1D tensor"
|
||||
dim = self.smooth.shape[0]
|
||||
dim_tp = divide(dim, tp_size)
|
||||
self.tp_smooth = Parameter(self.smooth.data.new_empty((dim_tp)),
|
||||
requires_grad=False)
|
||||
else:
|
||||
raise TypeError("quant method expected to be unquantized or smoothquant per-token")
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
self.update_param_tp_status()
|
||||
|
||||
|
||||
def vllm__module_executor__layers__linear__RowParallelLinear__weight_loader(
|
||||
self, param: Parameter, loaded_weight: torch.Tensor
|
||||
):
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
loaded_weight_orig = loaded_weight
|
||||
|
||||
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org(
|
||||
self=self,
|
||||
param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add keep_full_weights for dp parallelize shared expert
|
||||
'''
|
||||
if self.keep_full_weights:
|
||||
if input_dim is None:
|
||||
return
|
||||
tp_size = self.tp_world_size_org
|
||||
tp_rank = self.tp_rank_org
|
||||
shard_size = divide(loaded_weight_orig.shape[input_dim], tp_size)
|
||||
start_idx = tp_rank * shard_size
|
||||
if isinstance(self.quant_method, UnquantizedLinearMethod):
|
||||
shard_view = self.weight.narrow(input_dim, start_idx, shard_size)
|
||||
self.tp_weight.copy_(shard_view)
|
||||
elif isinstance(self.quant_method, SmoothQuantLinearMethod):
|
||||
if loaded_weight_orig.ndim == 1:
|
||||
shard_view = self.smooth.narrow(input_dim, start_idx, shard_size)
|
||||
self.tp_smooth.copy_(shard_view)
|
||||
elif loaded_weight_orig.ndim == 2:
|
||||
shard_view = self.qweight.narrow(input_dim, start_idx, shard_size)
|
||||
self.tp_qweight.copy_(shard_view)
|
||||
else:
|
||||
raise ValueError("only rank 1 and 2 is supported for tp_weight")
|
||||
else:
|
||||
raise TypeError("quant method is expected to be UnquantizedLinearMethod and SmoothQuant")
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual, smooth_quant_scale, use_tp_weight and output parameters.
|
||||
'''
|
||||
def vllm__module_executor__layers__linear__RowParallelLinear__forward(
|
||||
self,
|
||||
input_,
|
||||
residual: torch.Tensor | None = None,
|
||||
smooth_quant_scale: torch.Tensor | None = None,
|
||||
use_tp_weight: bool = False,
|
||||
output: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add additional matmul parameters.
|
||||
'''
|
||||
residual_ = None if self.tp_rank > 0 else residual
|
||||
kwargs = {'bias': bias_, 'residual': residual_}
|
||||
if use_tp_weight:
|
||||
kwargs['use_tp_weight'] = use_tp_weight
|
||||
if smooth_quant_scale is not None:
|
||||
kwargs['input_scale'] = smooth_quant_scale
|
||||
if output is not None:
|
||||
kwargs['output'] = output
|
||||
output_parallel = self.quant_method.apply(self, input_parallel, **kwargs)
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tensor_model_parallel_all_reduce() with self.tp_group
|
||||
'''
|
||||
output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group)
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
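# Usage sketch (added; hypothetical call site): the extra keyword arguments allow a
# caller to fuse the residual add and reuse a preallocated output buffer, e.g.
#   out = row_parallel_layer(hidden_states, residual=residual, output=out_buf)
# where row_parallel_layer is a RowParallelLinear instance patched by this hijack.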
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(UnquantizedLinearMethod,
|
||||
UnquantizedLinearMethod.apply,
|
||||
vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply)
|
||||
MluHijackObject.apply_hijack(LinearBase,
|
||||
LinearBase.__init__,
|
||||
vllm__module_executor__layers__linear__LinearBase____init__)
|
||||
MluHijackObject.apply_hijack(ColumnParallelLinear,
|
||||
ColumnParallelLinear.__init__,
|
||||
vllm__module_executor__layers__linear__ColumnParallelLinear____init__)
|
||||
MluHijackObject.apply_hijack(ColumnParallelLinear,
|
||||
ColumnParallelLinear.forward,
|
||||
vllm__module_executor__layers__linear__ColumnParallelLinear__forward)
|
||||
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
|
||||
MergedColumnParallelLinear.__init__,
|
||||
vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__)
|
||||
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
|
||||
MergedColumnParallelLinear.weight_loader,
|
||||
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader)
|
||||
MluHijackObject.apply_hijack(RowParallelLinear,
|
||||
RowParallelLinear.__init__,
|
||||
vllm__module_executor__layers__linear__RowParallelLinear____init__)
|
||||
MluHijackObject.apply_hijack(RowParallelLinear,
|
||||
RowParallelLinear.weight_loader,
|
||||
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader)
|
||||
MluHijackObject.apply_hijack(RowParallelLinear,
|
||||
RowParallelLinear.forward,
|
||||
vllm__module_executor__layers__linear__RowParallelLinear__forward)
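# The hijacks above take effect when this module is imported. Conceptually (a rough
# sketch, assuming MluHijackObject simply rebinds the attribute and keeps the original
# for a later restore), each registration amounts to something like:
#   RowParallelLinear.forward = vllm__module_executor__layers__linear__RowParallelLinear__forward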
|
||||
744
vllm_mlu/model_executor/layers/longcat_sparse_moe_mlp.py
Normal file
@@ -0,0 +1,744 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
"""Inference-only MOE model."""
|
||||
from typing import Optional, Any, List, Dict
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.distributed import (
|
||||
divide,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.distributed.parallel_state import(
|
||||
cnclep_dispatch, cnclep_combine)
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
|
||||
class LongCatSparseMoeMlp(SparseMoeMlp):
|
||||
"""
|
||||
sparse moe mlp layer specific to longcat model
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
up_proj_name: str,
|
||||
is_gated: bool,
|
||||
down_proj_name: str,
|
||||
has_bias: bool,
|
||||
skip_bias_add: bool = False,
|
||||
renormalize:bool = False,
|
||||
hidden_act: str = "silu",
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
is_use_fused_moe: bool = False,
|
||||
expert_group: Optional[int] = 1,
|
||||
topk_group: Optional[int] = 1,
|
||||
scoring_func: str = "softmax",
|
||||
topk_method: str = "",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
tp_group: Any = None,
|
||||
use_all2all: bool = False,
|
||||
num_zero_experts: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
up_proj_name=up_proj_name,
|
||||
is_gated=is_gated,
|
||||
down_proj_name=down_proj_name,
|
||||
has_bias=has_bias,
|
||||
skip_bias_add=skip_bias_add,
|
||||
renormalize=renormalize,
|
||||
hidden_act=hidden_act,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
is_use_fused_moe=is_use_fused_moe,
|
||||
expert_group=expert_group,
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
topk_method=topk_method,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
tp_group=tp_group,
|
||||
use_all2all=use_all2all,
|
||||
init_avg_moe=False,
|
||||
)
|
||||
self.num_zero_experts = num_zero_experts
|
||||
self.total_experts_including_zero = self.num_total_experts + self.num_zero_experts
|
||||
self.use_quant_all2all = use_all2all and quant_config is not None
|
||||
self.zero_expert_size = divide(self.num_zero_experts, self.moe_ep_size)
|
||||
self.start_zero_expert_id = (
|
||||
self.num_total_experts + self.moe_ep_rank * ((self.num_zero_experts + self.moe_ep_size - 1) // self.moe_ep_size)
|
||||
)
|
||||
|
||||
if VLLM_AVG_MOE_EN and not SparseMoeMlp.is_expert_avg:
|
||||
n_tokens = SparseMoeMlp.max_batched_token * self.dp_size
|
||||
expert_group = self.moe_ep_size
|
||||
val = 1.0 / float(self.total_experts_including_zero)
|
||||
SparseMoeMlp.reduce_weight = torch.full((n_tokens, top_k), val, device="mlu", dtype=torch.float32)
|
||||
if VLLM_RANDOM_MOE_EN:
|
||||
import numpy as np
|
||||
# example deepseekv2: experts 160 topk 6
|
||||
# avg list: 92, 8, 88, 45, 99, 9,... 118, 142, 116, 57, 104, 6,......
|
||||
array = np.stack([np.random.permutation(self.total_experts_including_zero)[:top_k] for _ in range(n_tokens)])
|
||||
table = torch.from_numpy(array.flatten()).to(device="mlu", dtype=torch.int32)
|
||||
else:
|
||||
# example deepseekv2: experts 160
|
||||
# avg list: 0,20,40,60,80...120,140, 1,21,...121,141, 2...142, ...... 19,...159, 0,20,......
|
||||
import math
|
||||
batch_table = math.ceil(n_tokens * top_k / self.total_experts_including_zero) * self.total_experts_including_zero
|
||||
hi_val = batch_table // self.total_experts_including_zero
|
||||
table = (torch.arange(hi_val * num_experts, device="mlu", dtype=torch.int32) % num_experts).view(
|
||||
hi_val, expert_group, num_experts // expert_group).transpose(1, 2)
|
||||
if self.num_zero_experts > 0:
|
||||
# Longcat model, for avg expert, we choose eight non-zero experts and four zero
|
||||
# experts for each token according to the paper.
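# Illustration (added; e*/z* are placeholder ids): with non_zero_expert_table = [e0, e1, ...]
# and zero_expert_table = [z0, z1, ...], the loop below lays out the first token's 12
# experts as [e0..e7, z0..z3], the next token's as [e8..e15, z4..z7], and so on.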
|
||||
assert num_experts == 512 and num_zero_experts == 256 and top_k == 12
|
||||
assert num_zero_experts % expert_group == 0
|
||||
non_zero_expert_num_per_token = 8
|
||||
zero_expert_num_per_token = 4
|
||||
zero_expert_table = torch.arange(
|
||||
num_experts, num_experts + num_zero_experts, dtype=table.dtype, device=table.device).view(
|
||||
expert_group, num_zero_experts // expert_group).transpose(0, 1).flatten()
|
||||
non_zero_expert_table = table[0].flatten()
|
||||
token_expert_list = []
|
||||
for idx in range(0, num_experts // non_zero_expert_num_per_token):
|
||||
token_expert_list.append(non_zero_expert_table[
|
||||
idx * non_zero_expert_num_per_token:
|
||||
idx * non_zero_expert_num_per_token + non_zero_expert_num_per_token])
|
||||
token_expert_list.append(zero_expert_table[
|
||||
idx * zero_expert_num_per_token:
|
||||
idx * zero_expert_num_per_token + zero_expert_num_per_token])
|
||||
avg_expert_table = torch.cat(token_expert_list)
|
||||
table = avg_expert_table.repeat(hi_val)
|
||||
SparseMoeMlp.expert_id = table.flatten()[:n_tokens * top_k].view(n_tokens, top_k)
|
||||
SparseMoeMlp.is_expert_avg = True
|
||||
|
||||
|
||||
def forward_experts_nofused_longcat(
|
||||
self, hidden_states, total_num_experts, total_num_experts_per_rank,
|
||||
topk_indices=None, topk_weights=None, residual_=None):
|
||||
assert self.moe_ep_size == 1
|
||||
assert not self.use_all2all
|
||||
expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = mlu_ops.moe_gen_idx(
|
||||
topk_indices.to(torch.int32), total_num_experts)
|
||||
# If no expert is routed to this rank, expand_gather_idx and expand_scatter_idx are empty,
|
||||
# while expand_token_count and expand_cusum_token_count exist but contain only zeros,
|
||||
# so this rank should just return an all-zero final_hidden_states.
|
||||
if cusum_token_count[-1] == 0:
|
||||
final_hidden_states = torch.zeros_like(hidden_states,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
return final_hidden_states
|
||||
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_gather_idx, cusum_token_count,
|
||||
start_expert_id=self.start_expert_id,
|
||||
expert_size=self.end_expert_id - self.start_expert_id)
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_gather_idx, cusum_token_count,
|
||||
start_expert_id=self.start_zero_expert_id,
|
||||
expert_size=self.zero_expert_size)
|
||||
|
||||
expand_output_list = []
|
||||
expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id +
|
||||
1] - cusum_token_count[self.start_expert_id]
|
||||
|
||||
for expert_idx, num_tokens_per_expert in enumerate(expand_token_count[:self.num_total_experts]):
|
||||
if num_tokens_per_expert > 0:
|
||||
expert_hidden_states = expand_hidden_states[
|
||||
expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]]
|
||||
if expert_idx < self.num_total_experts:
|
||||
expert_output = self.experts[expert_idx](expert_hidden_states)
|
||||
else:
|
||||
expert_output = expert_hidden_states
|
||||
expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output
|
||||
expand_output_list.append(expert_output)
|
||||
expand_output = torch.cat(expand_output_list, dim=0)
|
||||
num_normal_tokens = cusum_token_count[self.num_total_experts]
|
||||
expand_hidden_states[:num_normal_tokens] = expand_output
|
||||
# reduce normal experts
|
||||
final_hidden_states = mlu_ops.moe_combine_result(
|
||||
expand_hidden_states, topk_weights, scatter_idx,
|
||||
residual_, cusum_token_count, start_expert_id=self.start_expert_id,
|
||||
expert_size=self.end_expert_id - self.start_expert_id, bias=None)
|
||||
# reduce zero experts
|
||||
if self.moe_ep_size > 1 or self.moe_tp_rank == 0:
|
||||
final_hidden_states = mlu_ops.moe_combine_result(
|
||||
expand_hidden_states_zero, topk_weights, scatter_idx,
|
||||
final_hidden_states, cusum_token_count, start_expert_id=self.start_zero_expert_id,
|
||||
expert_size=self.zero_expert_size, bias=None,
|
||||
output=final_hidden_states)
|
||||
return final_hidden_states
|
||||
|
||||
# no compute-communication parallel, for prototyping only, not in actual use.
|
||||
# subject to becoming stale
|
||||
def forward_all2all_int8_longcat(
|
||||
self, hidden_states, total_num_experts, total_num_experts_per_rank,
|
||||
topk_indices=None, topk_weights=None, residual_=None):
|
||||
ori_input_shape = hidden_states.shape
|
||||
dtype = hidden_states.dtype
|
||||
self.pack_params()
|
||||
self.pack_params_after_loading()
|
||||
w1=self.w13
|
||||
w2=self.w2
|
||||
bias2=self.b2
|
||||
input_smooth=self.a13_scale_all_experts
|
||||
act_smooth=self.a2_scale
|
||||
w1_scale=self.w13_scale
|
||||
w2_scale=self.w2_scale
|
||||
act_mode=self.hidden_act
|
||||
quant_input=None
|
||||
|
||||
max_m = hidden_states.shape[0]
|
||||
|
||||
reduce_weight = topk_weights
|
||||
expert_id = topk_indices
|
||||
|
||||
expand_idx, combine_idx, token_count, cusum_token_count \
|
||||
= mlu_ops.moe_gen_idx(expert_id, total_num_experts)
|
||||
|
||||
num_token_expand = hidden_states.shape[0] * self.top_k
|
||||
dispatch_bytes = num_token_expand * self.dispatch_token_size
|
||||
|
||||
dispatch_send_token_tensor = (
|
||||
self.dispatch_send_buffer[:dispatch_bytes]
|
||||
.view(num_token_expand, self.dispatch_token_size)
|
||||
)
|
||||
|
||||
quant_size = self.hidden_size
|
||||
quant_input = dispatch_send_token_tensor[:, : quant_size]
|
||||
input_scale = dispatch_send_token_tensor[:, quant_size :].view(torch.float32)
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(
|
||||
hidden_states, input_smooth, None, token_count[:self.num_total_experts],
|
||||
expand_idx, None,
|
||||
output=quant_input,
|
||||
output_scale=input_scale)
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count,
|
||||
start_expert_id=self.num_total_experts,
|
||||
expert_size=self.num_zero_experts)
|
||||
|
||||
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
|
||||
token_count[:self.num_total_experts], self.moe_ep_size)
|
||||
|
||||
cnclep_dispatch(self.dispatch_token_size,
|
||||
num_token_expand,
|
||||
dispatch_send_layout,
|
||||
token_count[:self.num_total_experts],
|
||||
self.dispatch_recv_layout,
|
||||
self.dispatch_recv_token_num)
|
||||
|
||||
recv_token_num = self.dispatch_recv_token_num.view(
|
||||
self.moe_ep_size, self.num_experts_per_rank)
|
||||
pad_num = self.max_num_tokens_per_rank
|
||||
|
||||
(
|
||||
gather_by_expert_index,
|
||||
gather_by_rank_index,
|
||||
tokens_per_local_expert,
|
||||
token_sum
|
||||
) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num)
|
||||
|
||||
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
|
||||
dispatch_recv_token_tensor = (
|
||||
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
|
||||
.view(self.max_num_tokens_recv, self.dispatch_token_size))
|
||||
|
||||
mlu_ops.gather_split(dispatch_recv_token_tensor,
|
||||
gather_by_expert_index,
|
||||
token_sum,
|
||||
self.quant_input_recv,
|
||||
self.input_scale_recv)
|
||||
|
||||
max_m = self.max_num_tokens_per_expert
|
||||
gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, w1,
|
||||
tokens_per_local_expert,
|
||||
None, None, None, None,
|
||||
self.input_scale_recv.view(torch.float32).flatten(),
|
||||
w1_scale, dtype, max_m)
|
||||
|
||||
# continue reusing self.quant_input_recv and self.input_scale_recv
|
||||
quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2]
|
||||
input_scale_fp32 = self.input_scale_recv.view(torch.float32).flatten()[:gemm_out.shape[0]]
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None,
|
||||
tokens_per_local_expert,
|
||||
output=quant_input,
|
||||
output_scale=input_scale_fp32,
|
||||
act_mode=act_mode,
|
||||
is_gated=self.is_gated)
|
||||
|
||||
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, w2,
|
||||
tokens_per_local_expert,
|
||||
None, None, None, None, input_scale, w2_scale, dtype, max_m)
|
||||
|
||||
combine_send_token_tensor = self.combine_send_buffer.view(self.max_num_tokens_recv, -1).view(hidden_states.dtype)
|
||||
mlu_ops.gather_split(gemm_out,
|
||||
gather_by_rank_index,
|
||||
token_sum,
|
||||
combine_send_token_tensor,
|
||||
None)
|
||||
|
||||
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(self.dispatch_recv_token_num, self.moe_ep_size)
|
||||
combine_recv_layout = self.dispatch_recv_layout
|
||||
|
||||
# combine
|
||||
combine_args = dict(
|
||||
token_byte=self.hidden_size * 2,
|
||||
token_num=num_token_expand,
|
||||
send_src_layout=combine_send_layout,
|
||||
send_dst_layout=combine_recv_layout,
|
||||
send_token=None,
|
||||
recv_token=None)
|
||||
|
||||
cnclep_combine(**combine_args)
|
||||
|
||||
numel_recv = num_token_expand * self.hidden_size
|
||||
recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv]
|
||||
.view(num_token_expand, self.hidden_size))
|
||||
|
||||
residual_ = None
|
||||
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
|
||||
residual_, cusum_token_count, start_expert_id=0,
|
||||
expert_size=self.num_total_experts, bias=bias2, output=hidden_states)
|
||||
assert self.moe_ep_size > 1
|
||||
# zero expert reduce
|
||||
output = mlu_ops.moe_combine_result(
|
||||
expand_hidden_states_zero, reduce_weight, combine_idx,
|
||||
output, cusum_token_count, self.num_total_experts,
|
||||
self.num_zero_experts, output=hidden_states)
|
||||
|
||||
return output.view(ori_input_shape)
|
||||
|
||||
# no compute-communication parallel, for prototyping only, not in actual use.
|
||||
# subject to becoming stale
|
||||
def forward_all2all_bf16_longcat(
|
||||
self, hidden_states, total_num_experts, total_num_experts_per_rank,
|
||||
topk_indices=None, topk_weights=None, residual_=None):
|
||||
is_fp8_quant = isinstance(self.quant_config, Fp8Config)
|
||||
ori_input_shape = hidden_states.shape
|
||||
dtype = hidden_states.dtype
|
||||
self.pack_params()
|
||||
self.pack_params_after_loading()
|
||||
w1=self.w13
|
||||
w2=self.w2
|
||||
bias1=self.b13
|
||||
bias2=self.b2
|
||||
gated=self.is_gated
|
||||
act_mode=self.hidden_act
|
||||
|
||||
max_m = hidden_states.shape[0]
|
||||
reduce_weight = topk_weights
|
||||
expert_id = topk_indices
|
||||
|
||||
# gen_idx
|
||||
expand_idx, combine_idx, token_count, cusum_token_count = \
|
||||
mlu_ops.moe_gen_idx(expert_id, total_num_experts)
|
||||
num_token_expand = hidden_states.shape[0] * self.top_k
|
||||
dispatch_bytes = num_token_expand * self.dispatch_token_size
|
||||
|
||||
dispatch_send_token_tensor = (
|
||||
self.dispatch_send_buffer[:dispatch_bytes]
|
||||
.view(num_token_expand, self.dispatch_token_size)
|
||||
.view(hidden_states.dtype)
|
||||
)
|
||||
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count, start_expert_id=0,
|
||||
expert_size=self.num_total_experts)
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count,
|
||||
start_expert_id=self.num_total_experts,
|
||||
expert_size=self.num_zero_experts)
|
||||
|
||||
dispatch_send_token_tensor.copy_(expand_hidden_states)
|
||||
|
||||
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
|
||||
token_count[:self.num_total_experts], self.moe_ep_size)
|
||||
|
||||
cnclep_dispatch(self.dispatch_token_size,
|
||||
num_token_expand,
|
||||
dispatch_send_layout,
|
||||
token_count[:self.num_total_experts],
|
||||
self.dispatch_recv_layout,
|
||||
self.dispatch_recv_token_num,
|
||||
use_quant_dispatch=False,
|
||||
)
|
||||
|
||||
recv_token_num = self.dispatch_recv_token_num.view(
|
||||
self.moe_ep_size, self.num_experts_per_rank)
|
||||
pad_num = self.max_num_tokens_per_rank
|
||||
|
||||
(
|
||||
gather_by_expert_index,
|
||||
gather_by_rank_index,
|
||||
tokens_per_local_expert,
|
||||
token_sum
|
||||
) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num)
|
||||
|
||||
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
|
||||
dispatch_recv_token_tensor = (
|
||||
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
|
||||
.view(self.max_num_tokens_recv, self.dispatch_token_size)
|
||||
.view(hidden_states.dtype)
|
||||
)
|
||||
|
||||
|
||||
self.quant_input_recv = self.quant_input_recv.view(hidden_states.dtype)
|
||||
mlu_ops.gather_split(dispatch_recv_token_tensor,
|
||||
gather_by_expert_index,
|
||||
token_sum,
|
||||
self.quant_input_recv)
|
||||
|
||||
max_m = self.max_num_tokens_per_expert
|
||||
gemm_out = mlu_ops.group_gemm(
|
||||
self.quant_input_recv, w1, tokens_per_local_expert,
|
||||
None, None, None, None, max_m)
|
||||
act_out = mlu_ops.moe_active(
|
||||
gemm_out, act_mode, gated)
|
||||
gemm_out = mlu_ops.group_gemm(
|
||||
act_out, w2, tokens_per_local_expert,
|
||||
None, None, None, None, max_m)
|
||||
|
||||
combine_send_token_tensor = self.combine_send_buffer.view(
|
||||
self.max_num_tokens_recv, -1).view(hidden_states.dtype)
|
||||
mlu_ops.gather_split(gemm_out,
|
||||
gather_by_rank_index,
|
||||
token_sum,
|
||||
combine_send_token_tensor,
|
||||
None)
|
||||
|
||||
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(
|
||||
self.dispatch_recv_token_num, self.moe_ep_size)
|
||||
combine_recv_layout = self.dispatch_recv_layout
|
||||
|
||||
combine_args = dict(
|
||||
token_byte=self.hidden_size * 2,
|
||||
token_num=num_token_expand,
|
||||
send_src_layout=combine_send_layout,
|
||||
send_dst_layout=combine_recv_layout,
|
||||
send_token=None,
|
||||
recv_token=None,
|
||||
use_quant_dispatch=False,
|
||||
)
|
||||
|
||||
cnclep_combine(**combine_args)
|
||||
|
||||
numel_recv = num_token_expand * self.hidden_size
|
||||
recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv]
|
||||
.view(num_token_expand, self.hidden_size))
|
||||
|
||||
residual_ = None
|
||||
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
|
||||
residual_, cusum_token_count, start_expert_id=0,
|
||||
expert_size=self.num_total_experts, bias=bias2, output=hidden_states)
|
||||
# zero expert reduce
|
||||
output = mlu_ops.moe_combine_result(
|
||||
expand_hidden_states_zero, reduce_weight, combine_idx,
|
||||
output, cusum_token_count, self.num_total_experts,
|
||||
self.num_zero_experts, output=hidden_states)
|
||||
return output.view(ori_input_shape)
|
||||
|
||||
def forward_before_dispatch(self, hidden_states: torch.Tensor,
|
||||
topk_indices: torch.Tensor):
|
||||
# For LongCat, the gate and softmax top-k are computed in the router;
|
||||
# other models can perform these operations here.
|
||||
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(
|
||||
topk_indices, self.total_experts_including_zero)
|
||||
|
||||
num_token_expand = hidden_states.shape[0] * self.top_k
|
||||
dispatch_bytes = num_token_expand * self.dispatch_token_size
|
||||
dispatch_send_token_tensor = (
|
||||
self.dispatch_send_buffer[:dispatch_bytes]
|
||||
.view(num_token_expand, self.dispatch_token_size)
|
||||
)
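# Buffer layout note (added): when use_quant_all2all is enabled, each row of
# dispatch_send_token_tensor is treated below as the int8-quantized token (the first
# hidden_size bytes) followed by its per-token fp32 quantization scale.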
|
||||
if self.use_quant_all2all:
|
||||
hidden_states_stride = self.hidden_size
|
||||
quant_input = dispatch_send_token_tensor[:, : hidden_states_stride]
|
||||
input_scale = dispatch_send_token_tensor[:, hidden_states_stride :].view(torch.float32)
|
||||
# expand input + quantize
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(
|
||||
hidden_states, self.a13_scale_all_experts, None,
|
||||
token_count[:self.num_total_experts],
|
||||
expand_idx, None,
|
||||
output=quant_input,
|
||||
output_scale=input_scale)
|
||||
# expand input of zero-expert
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count,
|
||||
start_expert_id=self.num_total_experts,
|
||||
expert_size=self.num_zero_experts)
|
||||
else:
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count, start_expert_id=0,
|
||||
expert_size=self.num_total_experts)
|
||||
dispatch_send_token_tensor = dispatch_send_token_tensor.view(
|
||||
hidden_states.dtype)
|
||||
dispatch_send_token_tensor.copy_(expand_hidden_states)
|
||||
del expand_hidden_states
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count,
|
||||
start_expert_id=self.num_total_experts,
|
||||
expert_size=self.num_zero_experts)
|
||||
|
||||
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
|
||||
token_count[:self.num_total_experts], self.moe_ep_size)
|
||||
|
||||
return combine_idx, token_count, cusum_token_count, dispatch_send_layout, expand_hidden_states_zero
|
||||
|
||||
def forward_dispatch(self, token_num: int, dispatch_send_layout: torch.Tensor,
|
||||
token_count: torch.Tensor):
|
||||
num_token_expand = token_num * self.top_k
|
||||
cnclep_dispatch(self.dispatch_token_size,
|
||||
num_token_expand,
|
||||
dispatch_send_layout,
|
||||
token_count[:self.num_total_experts],
|
||||
self.dispatch_recv_layout,
|
||||
self.dispatch_recv_token_num,
|
||||
use_quant_dispatch=self.use_quant_all2all)
|
||||
|
||||
def forward_before_combine(self, hidden_states_dtype: torch.dtype):
|
||||
recv_token_num = self.dispatch_recv_token_num.view(
|
||||
self.moe_ep_size, self.num_experts_per_rank)
|
||||
|
||||
(
|
||||
gather_by_expert_index,
|
||||
gather_by_rank_index,
|
||||
tokens_per_local_expert,
|
||||
token_sum,
|
||||
cusum_token_count
|
||||
) = mlu_ops.moe_all2all_gen_gather_index(
|
||||
recv_token_num, self.max_num_tokens_per_rank,
|
||||
return_cusum_token_count=True)
|
||||
|
||||
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
|
||||
dispatch_recv_token_tensor = (
|
||||
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
|
||||
.view(self.max_num_tokens_recv, self.dispatch_token_size))
|
||||
|
||||
max_m = self.max_num_tokens_per_expert
|
||||
if self.use_quant_all2all:
|
||||
mlu_ops.gather_split(dispatch_recv_token_tensor,
|
||||
gather_by_expert_index,
|
||||
token_sum,
|
||||
self.quant_input_recv,
|
||||
self.input_scale_recv)
|
||||
# OPT: input_scale_recv_flatten can reuse self.input_scale_recv
|
||||
input_scale_recv_flatten = self.input_scale_recv.view(torch.float32).flatten()
|
||||
gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, self.w13,
|
||||
tokens_per_local_expert,
|
||||
None, None, None, None,
|
||||
input_scale_recv_flatten,
|
||||
self.w13_scale, hidden_states_dtype, max_m)
|
||||
|
||||
quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2]
|
||||
input_scale_fp32 = input_scale_recv_flatten[:gemm_out.shape[0]]
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, self.a2_scale, None,
|
||||
tokens_per_local_expert,
|
||||
output=quant_input,
|
||||
output_scale=input_scale_fp32,
|
||||
act_mode=self.hidden_act,
|
||||
is_gated=self.is_gated)
|
||||
|
||||
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, self.w2, tokens_per_local_expert,
|
||||
None, None, None, None, input_scale, self.w2_scale,
|
||||
hidden_states_dtype, max_m)
|
||||
else:
|
||||
dispatch_recv_token_tensor = dispatch_recv_token_tensor.view(hidden_states_dtype)
|
||||
self.input_recv = self.input_recv.view(hidden_states_dtype)
|
||||
mlu_ops.gather_split(dispatch_recv_token_tensor,
|
||||
gather_by_expert_index,
|
||||
token_sum,
|
||||
self.input_recv)
|
||||
gemm_out = mlu_ops.group_gemm(
|
||||
self.input_recv, self.w13, tokens_per_local_expert,
|
||||
None, None, None, None, max_m)
|
||||
act_out = self.input_recv[:, :gemm_out.shape[-1] // 2]
|
||||
act_out = mlu_ops.moe_active(
|
||||
gemm_out, self.hidden_act, self.is_gated, output=act_out,
|
||||
bias=None, cusum_token_count=cusum_token_count,
|
||||
start_expert_id=0, expert_size=self.num_experts_per_rank)
|
||||
gemm_out = mlu_ops.group_gemm(
|
||||
act_out, self.w2, tokens_per_local_expert,
|
||||
None, None, None, None, max_m)
|
||||
|
||||
combine_send_token_tensor = self.combine_send_buffer.view(
|
||||
self.max_num_tokens_recv, -1).view(hidden_states_dtype)
|
||||
mlu_ops.gather_split(gemm_out,
|
||||
gather_by_rank_index,
|
||||
token_sum,
|
||||
combine_send_token_tensor,
|
||||
None)
|
||||
|
||||
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(
|
||||
self.dispatch_recv_token_num, self.moe_ep_size)
|
||||
|
||||
return combine_send_layout
|
||||
|
||||
def forward_combine(self, token_num: int, combine_send_layout: torch.Tensor):
|
||||
num_token_expand = token_num * self.top_k
|
||||
# combine_recv_layout (self.dispatch_recv_layout) is computed during cnclep_dispatch,
|
||||
# because dispatch and combine are inverse operations.
|
||||
cnclep_combine(token_byte=self.hidden_size * 2,
|
||||
token_num=num_token_expand,
|
||||
send_src_layout=combine_send_layout,
|
||||
send_dst_layout=self.dispatch_recv_layout,
|
||||
send_token=None,
|
||||
recv_token=None,
|
||||
use_quant_dispatch=self.use_quant_all2all)
|
||||
|
||||
def forward_after_combine(self, token_num: int,
|
||||
reduce_weight: torch.Tensor,
|
||||
combine_idx: torch.Tensor,
|
||||
cusum_token_count: torch.Tensor,
|
||||
expand_hidden_states_zero: torch.Tensor,
|
||||
output_tensor_dtype: torch.dtype,
|
||||
output_tensor: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None):
|
||||
num_token_expand = token_num * self.top_k
|
||||
numel_recv = num_token_expand * self.hidden_size
|
||||
recv_token = (self.combine_recv_buffer.view(output_tensor_dtype)[:numel_recv]
|
||||
.view(num_token_expand, self.hidden_size))
|
||||
|
||||
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
|
||||
residual, cusum_token_count, start_expert_id=0,
|
||||
expert_size=self.num_total_experts, bias=self.b2, output=output_tensor)
|
||||
output = mlu_ops.moe_combine_result(
|
||||
expand_hidden_states_zero, reduce_weight, combine_idx,
|
||||
output, cusum_token_count, self.num_total_experts,
|
||||
self.num_zero_experts, output=output_tensor)
|
||||
|
||||
return output
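# Call-order note (added; a sketch of the expected flow, not enforced here): the staged
# helpers are meant to be driven as forward_before_dispatch -> forward_dispatch ->
# forward_before_combine -> forward_combine -> forward_after_combine, letting the caller
# overlap computation with the all2all communication between stages.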
|
||||
|
||||
# no compute-communication parallel, for prototyping only, not in actual use.
|
||||
# subject to becoming stale
|
||||
def forward_group_experts_longcat(
|
||||
self, hidden_states, total_num_experts, total_num_experts_per_rank,
|
||||
topk_indices=None, topk_weights=None, residual_=None,
|
||||
expand_idx=None, combine_idx=None, token_count=None, cusum_token_count=None):
|
||||
is_fp8_quant = isinstance(self.quant_config, Fp8Config)
|
||||
ori_input_shape = hidden_states.shape
|
||||
dtype = hidden_states.dtype
|
||||
self.pack_params()
|
||||
self.pack_params_after_loading()
|
||||
w1=self.w13
|
||||
w2=self.w2
|
||||
bias1=self.b13
|
||||
bias2=self.b2
|
||||
input_smooth=self.a13_scale
|
||||
act_smooth=self.a2_scale
|
||||
w1_scale=self.w13_scale
|
||||
w2_scale=self.w2_scale
|
||||
gated=self.is_gated
|
||||
act_mode=self.hidden_act
|
||||
quant_input=None
|
||||
|
||||
start_expert_id=self.start_expert_id
|
||||
expert_size = w1.size(0)
|
||||
max_m = hidden_states.shape[0]
|
||||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||
residual_ = residual_.view(-1, residual_.size(-1)) if residual_ is not None else None
|
||||
# Check smooth quant parameters.
|
||||
per_token_sq = False
|
||||
if not is_fp8_quant:
|
||||
check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
|
||||
if all(x is not None for x in check_list):
|
||||
per_token_sq = True
|
||||
|
||||
if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
|
||||
raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present "
|
||||
"and absent at the same time.")
|
||||
|
||||
expert_id = topk_indices
|
||||
reduce_weight = topk_weights
|
||||
|
||||
# gen_idx
|
||||
if expert_id is not None:
|
||||
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(expert_id, total_num_experts)
|
||||
|
||||
# check quant
|
||||
if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token':
|
||||
raise NotImplementedError
|
||||
elif per_token_sq:
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count,
|
||||
start_expert_id=start_expert_id,
|
||||
expert_size=expert_size)
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count,
|
||||
start_expert_id=self.start_zero_expert_id,
|
||||
expert_size=self.zero_expert_size)
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(
|
||||
expand_hidden_states, input_smooth, None,
|
||||
token_count[start_expert_id:start_expert_id+expert_size])
|
||||
else:
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(hidden_states, expand_idx,
|
||||
cusum_token_count, start_expert_id, expert_size)
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(hidden_states, expand_idx,
|
||||
cusum_token_count, self.start_zero_expert_id, self.zero_expert_size)
|
||||
|
||||
if (is_fp8_quant and self.quant_config.activation_quant_method == 'per_token') or per_token_sq:
|
||||
gemm_out = mlu_ops.smooth_quant_group_gemm(
|
||||
quant_input, w1,
|
||||
token_count[start_expert_id:start_expert_id+expert_size],
|
||||
None, None, None, None, input_scale, w1_scale, dtype, max_m)
|
||||
else:
|
||||
gemm_out = mlu_ops.group_gemm(expand_hidden_states, w1,
|
||||
token_count[start_expert_id:start_expert_id+expert_size],
|
||||
None, None, None, None, max_m)
|
||||
|
||||
# add_bias_active
|
||||
if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token':
|
||||
raise NotImplementedError
|
||||
elif per_token_sq:
|
||||
quant_input = quant_input[:, :gemm_out.shape[-1] // 2]
|
||||
input_scale = input_scale[:gemm_out.shape[0]]
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None,
|
||||
token_count[start_expert_id:start_expert_id+expert_size],
|
||||
output=quant_input,
|
||||
output_scale=input_scale,
|
||||
act_mode=act_mode,
|
||||
is_gated=self.is_gated)
|
||||
|
||||
if ((is_fp8_quant and self.quant_config.activation_quant_method == 'per_token')
|
||||
or per_token_sq):
|
||||
# Remove the reference to gemm_out tensor.
|
||||
# If that was the only reference, the tensor's memory becomes eligible for deallocation,
|
||||
# so that it can be reused for the output allocation of the next gemm operation.
|
||||
# del gemm_out
|
||||
gemm_out = mlu_ops.smooth_quant_group_gemm(
|
||||
quant_input, w2,
|
||||
token_count[start_expert_id:start_expert_id+expert_size],
|
||||
None, None, None, None, input_scale, w2_scale, dtype, max_m,
|
||||
output=expand_hidden_states)
|
||||
else:
|
||||
act_out = mlu_ops.moe_active(
|
||||
gemm_out, act_mode, gated, gemm_out[:,:gemm_out.shape[-1]//2],
|
||||
bias1, cusum_token_count, start_expert_id, expert_size)
|
||||
gemm_out = mlu_ops.group_gemm(
|
||||
act_out, w2, token_count[start_expert_id:start_expert_id+expert_size],
|
||||
None, None, None, None, max_m,
|
||||
output=expand_hidden_states)
|
||||
|
||||
output = mlu_ops.moe_combine_result(
|
||||
gemm_out, reduce_weight, combine_idx,
|
||||
residual_, cusum_token_count, start_expert_id,
|
||||
expert_size, bias2)
|
||||
if self.moe_ep_size > 1 or self.moe_tp_rank == 0:
|
||||
output = mlu_ops.moe_combine_result(
|
||||
expand_hidden_states_zero, reduce_weight, combine_idx,
|
||||
output, cusum_token_count, self.start_zero_expert_id,
|
||||
self.zero_expert_size, bias2,
|
||||
output=output)
|
||||
return output.view(ori_input_shape)
|
||||
37
vllm_mlu/model_executor/layers/quantization/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.model_executor.layers.quantization import (
|
||||
QUANTIZATION_METHODS, register_quantization_config
|
||||
)
|
||||
|
||||
MLU_QUANTIZATION_METHODS= [
|
||||
"smoothquant",
|
||||
"weightonly",
|
||||
"awq_mlu",
|
||||
"gptq_mlu",
|
||||
]
|
||||
|
||||
|
||||
def register_fake_mlu_quantization_methods():
|
||||
for quant_method in MLU_QUANTIZATION_METHODS:
|
||||
if quant_method not in QUANTIZATION_METHODS:
|
||||
QUANTIZATION_METHODS.append(quant_method)
|
||||
|
||||
|
||||
def remove_fake_mlu_quantization_methods():
|
||||
for quant_method in MLU_QUANTIZATION_METHODS:
|
||||
if quant_method in QUANTIZATION_METHODS:
|
||||
QUANTIZATION_METHODS.remove(quant_method)
|
||||
|
||||
|
||||
def register_real_mlu_quantization_methods():
|
||||
remove_fake_mlu_quantization_methods()
|
||||
from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig
|
||||
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig
|
||||
from vllm_mlu.model_executor.layers.quantization.awq_mlu import AWQMluConfig
|
||||
from vllm_mlu.model_executor.layers.quantization.gptq_mlu import GPTQMluConfig
|
||||
register_quantization_config("weightonly")(WeightOnlyConfig)
|
||||
register_quantization_config("smoothquant")(SmoothQuantConfig)
|
||||
register_quantization_config("awq_mlu")(AWQMluConfig)
|
||||
register_quantization_config("gptq_mlu")(GPTQMluConfig)
|
||||
412
vllm_mlu/model_executor/layers/quantization/awq_mlu.py
Normal file
@@ -0,0 +1,412 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
|
||||
|
||||
# We only support gptq and awq on MLU 300 series and later, and only int4 and int8 precision
|
||||
def query_mlu_supported_quant_types(has_zp: bool,
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + zero-point
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
|
||||
|
||||
def check_mlu_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
supported_types = query_mlu_supported_quant_types(
|
||||
has_zp, device_capability)
|
||||
|
||||
if quant_type not in supported_types:
|
||||
return (False, f"Mlu does not support weight_bits = {quant_type}. "
|
||||
f"Only types = {supported_types} "
|
||||
f"are supported (for group_size = {group_size}, "
|
||||
f"device_capability = {device_capability}, zp = {has_zp}).")
|
||||
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
|
||||
return (False, f"Mlu does not support group_size = {group_size}. "
|
||||
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
|
||||
return True, None
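# Example (added; hedged): callers are expected to unpack the (ok, reason) pair, e.g.
#   ok, reason = check_mlu_supported(scalar_types.uint4, group_size=128, has_zp=True)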
|
||||
|
||||
|
||||
# @register_quantization_config("awq_mlu")
|
||||
class AWQMluConfig(QuantizationConfig):
|
||||
"""Config class for AWQMlu.
|
||||
|
||||
Reference: https://arxiv.org/abs/2306.00978
|
||||
"""
|
||||
|
||||
# num_bits -> type
|
||||
TYPE_MAP = {
|
||||
4: {
|
||||
False: scalar_types.uint4b8,
|
||||
True: scalar_types.uint4,
|
||||
},
|
||||
8: {
|
||||
False: scalar_types.uint8b128,
|
||||
True: scalar_types.uint8,
|
||||
}
|
||||
}
|
||||
|
||||
VERSION = ["gemm"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
lm_head_quantized: bool,
|
||||
version: str = "gemm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
self.version = version
|
||||
self.support_scale_zeros = False
|
||||
|
||||
if self.weight_bits not in [4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 4/8-bit weight quantization is supported for "
|
||||
f"AWQMlu, but got {self.weight_bits} bits.")
|
||||
if self.version not in self.VERSION:
|
||||
raise ValueError(
|
||||
"Currently, only gemm, gemv version is supported for "
|
||||
f"AWQMlu, but got verion:{self.version}.")
|
||||
|
||||
if self.version in ["gemm"]:
|
||||
self.order_map = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]}
|
||||
self.reverse_order_map = {4 : [0, 4, 1, 5, 2, 6, 3, 7], 8: [0, 2, 1, 3]}
|
||||
else:
|
||||
self.order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
|
||||
self.reverse_order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQMluConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point}), "
|
||||
f"lm_head_quantized={self.lm_head_quantized})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "awq_mlu"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quant_config.json", "quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQMluConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
version = cls.get_from_keys_or(config, ["version"],
|
||||
default="gemm")
|
||||
return cls(weight_bits, group_size, zero_point, lm_head_quantized, version)
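# Example (added; illustrative only) of a checkpoint quantization config this maps from:
#   {"quant_method": "awq", "w_bit": 4, "q_group_size": 128, "zero_point": true, "version": "gemm"}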
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["AWQMluLinearMethod"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
return AWQMluLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
can_convert = cls.is_awq_mlu_compatible(hf_quant_cfg)
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "awq"
|
||||
or user_quant == "awq_mlu")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "awq":
|
||||
logger.info("Detected that the model can run with awq_mlu"
|
||||
", however you specified quantization=awq explicitly,"
|
||||
" so forcing awq. Use quantization=awq_mlu for"
|
||||
" faster inference")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_awq_mlu_compatible(cls, quant_config: Dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits", None)
|
||||
group_size = quant_config.get("group_size", None)
|
||||
has_zp = quant_config.get("zero_point", None)
|
||||
version = quant_config.get("version", "gemm")
|
||||
|
||||
if quant_method != "awq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or has_zp is None):
|
||||
return False
|
||||
|
||||
if num_bits not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
if version not in cls.VERSION:
|
||||
return False
|
||||
|
||||
return check_mlu_supported(quant_type=cls.TYPE_MAP[num_bits][has_zp],
|
||||
group_size=group_size,
|
||||
has_zp=has_zp)
|
||||
|
||||
class AWQMluLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQMlu.
|
||||
|
||||
Args:
|
||||
quant_config: The AWQMlu quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AWQMluConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scales = GroupQuantScaleParameter(data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
packed_qweight, scale_zeros = self.extract_autoawq(layer)
|
||||
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
layer.qzeros = None
|
||||
layer.scales = None
|
||||
else:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
if scale_zeros is not None:
|
||||
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
|
||||
else:
|
||||
layer.qzeros = None
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data.transpose(0, 1).contiguous(), requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.quant_config.zero_point and not self.quant_config.support_scale_zeros:
|
||||
output = mlu_ops.matmul(x, layer.qweight, bias)
|
||||
if residual is not None:
|
||||
output = output + residual
|
||||
else:
|
||||
output = mlu_ops.weight_only_quant_matmul(x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
bias,
|
||||
residual,
|
||||
"none",
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
return output
|
||||
|
||||
def extract_autoawq(self, layer: torch.nn.Module):
|
||||
qweight = layer.qweight.data
|
||||
qzeros = layer.qzeros.data
|
||||
scales = layer.scales.data
|
||||
bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
|
||||
# Unpack the qweight and qzeros tensors
|
||||
iweight, izeros = self.unpack_awq_int32_into_int8(qweight, qzeros, bits)
|
||||
# Reverse the order of the iweight and izeros tensors
|
||||
iweight, izeros = self.reverse_awq_order(iweight, izeros, bits)
|
||||
|
||||
# overflow checks
|
||||
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
|
||||
if izeros is not None:
|
||||
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
|
||||
|
||||
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
|
||||
scales = scales.repeat_interleave(group_size, dim=0)
|
||||
if izeros is not None:
|
||||
izeros = izeros.repeat_interleave(group_size, dim=0)
|
||||
fweight = (iweight - izeros) * scales
|
||||
else:
|
||||
fweight = iweight * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
fweight = fweight.transpose(0, 1)
|
||||
|
||||
return fweight, None
|
||||
|
||||
if self.quant_config.zero_point and self.quant_config.support_scale_zeros and izeros is not None:
|
||||
scale_zeros = izeros.to(scales.dtype) * -1 * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
scale_zeros = scale_zeros.transpose(0, 1)
|
||||
else:
|
||||
scale_zeros = None
|
||||
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
iweight = iweight.to(torch.int8).transpose(0, 1)
|
||||
|
||||
if bits == 4:
|
||||
higher_bit_tensor = iweight[:, 1::2]
|
||||
lower_bit_tensor = iweight[:, 0::2]
|
||||
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
|
||||
else:
|
||||
packed_qweight = iweight
|
||||
|
||||
return packed_qweight, scale_zeros
|
||||
|
||||
def unpack_awq_int32_into_int8(self, qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qweight.device)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
# unpacking columnwise
|
||||
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(dtype)
|
||||
iweights = iweights.view(iweights.shape[0], -1)
|
||||
if not self.quant_config.zero_point or self.quant_config.support_scale_zeros:
|
||||
iweights = torch.bitwise_and(iweights - 2**(bits - 1), (2 ** bits) - 1)
|
||||
|
||||
# unpacking columnwise
|
||||
if qzeros is not None:
|
||||
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(dtype)
|
||||
izeros = izeros.view(izeros.shape[0], -1)
|
||||
if not self.quant_config.zero_point:
|
||||
izeros = torch.bitwise_and(izeros - 2**(bits - 1), (2 ** bits) - 1)
|
||||
else:
|
||||
izeros = None
|
||||
|
||||
return iweights, izeros
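# Note (added): each int32 in qweight/qzeros packs 32 // bits values (8 nibbles for 4-bit,
# 4 bytes for 8-bit); the right-shifts above unpack them column-wise, and the optional
# "- 2**(bits - 1)" handles the GPTQ-style biased (uint4b8 / uint8b128) encoding.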
|
||||
|
||||
def reverse_awq_order(self, iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
|
||||
reverse_order_tensor = torch.arange(iweights.shape[-1], dtype=torch.int32, device=iweights.device)
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
|
||||
reverse_order_tensor = reverse_order_tensor[:, self.quant_config.reverse_order_map[bits]]
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1)
|
||||
|
||||
rweights = iweights[:, reverse_order_tensor]
|
||||
rzeros = None
if izeros is not None:
|
||||
rzeros = izeros[:, reverse_order_tensor]
|
||||
|
||||
return rweights, rzeros
|
||||
|
||||
def combine_low_bits(self, tensor_a, tensor_b):
|
||||
"""
|
||||
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
|
||||
|
||||
Args:
|
||||
tensor_a (torch.Tensor): First tensor of type int8.
|
||||
tensor_b (torch.Tensor): Second tensor of type int8.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
|
||||
"""
|
||||
# Ensure both inputs are int8
|
||||
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
|
||||
raise ValueError("Both tensors must be of int8 type.")
|
||||
|
||||
# Extract the lower 4 bits of each tensor
|
||||
low_bits_a = torch.bitwise_and(tensor_a, 0x0F) # keep the lower 4 bits of tensor_a
|
||||
low_bits_b = torch.bitwise_and(tensor_b, 0x0F) # keep the lower 4 bits of tensor_b
|
||||
|
||||
# Shift tensor_a's lower 4 bits into the upper nibble
|
||||
shifted_low_bits_a = low_bits_a << 4
|
||||
|
||||
# Combine the lower 4 bits of the two tensors
|
||||
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
|
||||
|
||||
return combined
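# Example (added): for int8 inputs a == 0x03 and b == 0x0A, the packed byte is 0x3A
# (a's low nibble in the high half, b's low nibble in the low half).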
|
||||
753
vllm_mlu/model_executor/layers/quantization/fp8.py
Normal file
@@ -0,0 +1,753 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import functools
|
||||
from functools import partial
|
||||
import importlib.util
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
from vllm import envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
from vllm.model_executor.layers.quantization.fp8 import (
|
||||
get_flashinfer_moe_backend,
|
||||
ACTIVATION_SCHEMES,
|
||||
Fp8Config,
|
||||
Fp8LinearMethod,
|
||||
Fp8MoeBackend,
|
||||
Fp8MoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
flashinfer_cutlass_moe_fp8,
|
||||
get_flashinfer_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
validate_fp8_block_shape
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise, cutlass_block_fp8_supported, cutlass_fp8_supported,
|
||||
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale,
|
||||
maybe_create_device_identity, Fp8LinearOp)
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter, ChannelQuantScaleParameter,
|
||||
ModelWeightParameter, PerTensorScaleParameter)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.flashinfer import has_flashinfer_moe
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||
import vllm_mlu._mlu_ops as mlu_ops
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
|
||||
"""
|
||||
Select the primary FP8 MoE backend
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
"""
|
||||
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and (
|
||||
current_platform.is_device_capability(100)
|
||||
or current_platform.is_device_capability(90)
|
||||
)
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_FP8
|
||||
and has_flashinfer_moe()
|
||||
):
|
||||
backend = get_flashinfer_moe_backend()
|
||||
if backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
|
||||
return Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
else:
|
||||
if block_quant and current_platform.is_device_capability(100):
|
||||
raise ValueError(
|
||||
"FlashInfer FP8 MoE throughput backend does not "
|
||||
"support block quantization. Please use "
|
||||
"VLLM_FLASHINFER_MOE_BACKEND=latency "
|
||||
"instead."
|
||||
)
|
||||
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
|
||||
return Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
|
||||
# weight-only path for older GPUs without native FP8
|
||||
use_marlin = (
|
||||
not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: disable marlin for MLU backend.
|
||||
'''
|
||||
if current_platform.is_rocm() or current_platform.is_out_of_tree():
|
||||
use_marlin = False
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if use_marlin:
|
||||
logger.info_once("Using Marlin backend for FP8 MoE")
|
||||
return Fp8MoeBackend.MARLIN
|
||||
|
||||
# deepGEMM on supported platforms with block-quantized weights
|
||||
if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
|
||||
if not has_deep_gemm():
|
||||
logger.warning_once("DeepGEMM backend requested but not available.")
|
||||
elif is_deep_gemm_supported():
|
||||
logger.info_once("Using DeepGEMM backend for FP8 MoE")
|
||||
return Fp8MoeBackend.DEEPGEMM
|
||||
|
||||
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100)
|
||||
and block_quant
|
||||
):
|
||||
logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
|
||||
return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
|
||||
|
||||
# default to Triton
|
||||
logger.info_once("Using Triton backend for FP8 MoE")
|
||||
return Fp8MoeBackend.TRITON
|
||||
|
||||
|
||||
Fp8Config____init____org = Fp8Config.__init__
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8Config____init__(
|
||||
self,
|
||||
is_checkpoint_fp8_serialized: bool = False,
|
||||
activation_scheme: str = "dynamic",
|
||||
ignored_layers: list[str] | None = None,
|
||||
weight_block_size: list[int] | None = None,
|
||||
activation_quant_method: Optional[str] = None,
|
||||
weight_quant_method: Optional[str] = None,
|
||||
) -> None:
|
||||
super(Fp8Config, self).__init__()
|
||||
|
||||
Fp8Config____init____org(
|
||||
self,
|
||||
is_checkpoint_fp8_serialized,
|
||||
activation_scheme,
|
||||
ignored_layers,
|
||||
weight_block_size
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add class members activation_quant_method and weight_quant_method to
|
||||
indicate the granularity of quantization.
|
||||
'''
|
||||
self.activation_quant_method = activation_quant_method
|
||||
self.weight_quant_method = weight_quant_method
|
||||
|
||||
assert (self.weight_block_size or
        (self.activation_quant_method == "per_token"
         and self.weight_quant_method == "per_channel"
         and self.activation_scheme == "dynamic")), (
    "Only block-wise quantization, or dynamic per-token activation with "
    "per-channel weight quantization, is supported.")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
@classmethod
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config(
|
||||
cls, config: Dict[str, Any]
|
||||
) -> "Fp8Config":
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = "fp8" in quant_method
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
|
||||
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
|
||||
if not ignored_layers:
|
||||
ignored_layers = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add config members activation_quant_method and weight_quant_method to
|
||||
indicate the granularity of quantization.
|
||||
'''
|
||||
activation_quant_method = cls.get_from_keys_or(config,
|
||||
["activation_quant_method"],
|
||||
'per_token')
|
||||
weight_quant_method = cls.get_from_keys_or(config,
|
||||
["weight_quant_method"],
|
||||
None)
|
||||
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||
activation_scheme=activation_scheme,
|
||||
ignored_layers=ignored_layers,
|
||||
weight_block_size=weight_block_size,
|
||||
activation_quant_method=activation_quant_method,
|
||||
weight_quant_method=weight_quant_method)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
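# Illustrative sketch only (not taken from any real checkpoint): a minimal
# quantization_config that the hijacked Fp8Config.from_config above would accept
# for the dynamic per-token activation / per-channel weight path. The key names
# follow the code above; the concrete values are assumptions for this example.
#
# example_quant_config = {
#     "quant_method": "fp8",
#     "activation_scheme": "dynamic",
#     "weight_block_size": None,
#     "activation_quant_method": "per_token",
#     "weight_quant_method": "per_channel",
# }
# fp8_config = Fp8Config.from_config(example_quant_config)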
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group.
|
||||
'''
|
||||
tp_group = extra_weight_attrs.get("tp_group", None)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
layer.weight_block_size = self.weight_block_size
|
||||
validate_fp8_block_shape(
|
||||
layer,
|
||||
input_size,
|
||||
output_size,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes,
|
||||
self.weight_block_size,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group.
|
||||
'''
|
||||
# WEIGHT
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
weight = create_fp8_weight_parameter(
|
||||
output_size_per_partition, input_size_per_partition, weight_loader
|
||||
)
|
||||
else:
|
||||
# For non-serialized checkpoints, use original dtype
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# If checkpoint is serialized fp8, load them.
|
||||
# Otherwise, wait until process_weights_after_loading.
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# WEIGHT SCALE
|
||||
if not self.block_quant:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Support weight per channel quantization.
|
||||
@brief: Add tp_group to enable custom split.
|
||||
'''
|
||||
if self.weight_per_channel:
|
||||
scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty(sum(output_partition_sizes), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
else:
|
||||
scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
else:
|
||||
assert not self.act_q_static
|
||||
assert self.weight_block_size is not None
|
||||
scale = create_fp8_scale_parameter(
|
||||
BlockQuantScaleParameter,
|
||||
output_partition_sizes,
|
||||
input_size_per_partition,
|
||||
self.weight_block_size,
|
||||
weight_loader,
|
||||
)
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
# The weight_scale_inv name is intentional for deepseekv3
|
||||
layer.register_parameter("weight_scale_inv", scale)
|
||||
|
||||
# INPUT ACTIVATION SCALE
|
||||
if self.act_q_static:
|
||||
scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
|
||||
set_weight_attrs(scale, {"scale_type": "input_scale"})
|
||||
layer.register_parameter("input_scale", scale)
|
||||
else:
|
||||
layer.register_parameter("input_scale", None)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__(
|
||||
self,
|
||||
quant_config: Fp8Config
|
||||
):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (
|
||||
not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||
)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
if vllm_is_batch_invariant():
|
||||
self.use_marlin = False
|
||||
|
||||
# AITER is only supported on ROCm and only for FP8_FNUZ
|
||||
# and at the moment are MI300 series
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
self.use_deep_gemm = is_deep_gemm_supported()
|
||||
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
self.block_quant = self.weight_block_size is not None
|
||||
if self.block_quant:
|
||||
# Marlin doesn't support block-wise fp8
|
||||
self.use_marlin = False
|
||||
|
||||
self.act_q_static = self.quant_config.activation_scheme == "static"
|
||||
if self.weight_block_size:
|
||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||
else:
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if not self.act_q_static and cutlass_fp8_supported():
|
||||
self.act_q_group_shape = GroupShape.PER_TOKEN
|
||||
else:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add config members activation_quant_method and weight_quant_method to
|
||||
indicate the granularity of quantization.
|
||||
'''
|
||||
self.weight_per_channel = (self.quant_config.weight_quant_method == 'per_channel')
|
||||
self.activation_per_token = (self.quant_config.activation_quant_method == 'per_token')
|
||||
if self.weight_per_channel and self.activation_per_token:
|
||||
self.use_marlin = False
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if self.block_quant:
|
||||
assert not self.act_q_static
|
||||
assert self.weight_block_size is not None
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.act_q_static,
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
)
|
||||
|
||||
|
||||
Fp8LinearMethod__process_weights_after_loading__org = Fp8LinearMethod.process_weights_after_loading
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading(
|
||||
self,
|
||||
layer: Module,
|
||||
) -> None:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: For dynamic activation and channel-wise weight quantization,
|
||||
additional processing is not needed.
|
||||
'''
|
||||
if (self.quant_config.is_checkpoint_fp8_serialized
|
||||
and self.weight_per_channel
|
||||
and self.quant_config.activation_scheme == "dynamic"):
|
||||
return
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
Fp8LinearMethod__process_weights_after_loading__org(self=self, layer=layer)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert residual is None, "Fp8Linear residual is not supported yet."
|
||||
|
||||
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
|
||||
# we will use BF16 dequant when DeepGEMM is not supported.
|
||||
if vllm_is_batch_invariant():
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
# per-tensor/channel: dequant to BF16 and run GEMM
|
||||
weight_fp8 = layer.weight.to(torch.bfloat16)
|
||||
weight_scale = layer.weight_scale.to(torch.bfloat16)
|
||||
if weight_scale.numel() == 1:
|
||||
# Per-tensor: simple scalar multiplication
|
||||
weight_bf16 = weight_fp8 * weight_scale
|
||||
else:
|
||||
# Multiple scales (fused modules like QKV)
|
||||
# Try to infer correct broadcasting
|
||||
# weight is [K, N], scale could be [num_logical_weights]
|
||||
# Need to figure out how to broadcast - for now just try
|
||||
# direct multiplication
|
||||
if (
|
||||
weight_scale.dim() == 1
|
||||
and weight_scale.shape[0] == weight_fp8.shape[0]
|
||||
):
|
||||
# Per-row scaling
|
||||
weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
|
||||
else:
|
||||
# Fallback
|
||||
weight_bf16 = weight_fp8 * weight_scale
|
||||
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
|
||||
|
||||
if self.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
apply_w8a8_block_fp8_linear)
|
||||
return apply_w8a8_block_fp8_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
block_size=self.quant_config.weight_block_size,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Use activation per token quantization based on quantization config.
|
||||
'''
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
weight_per_channel=self.weight_per_channel,
|
||||
activation_per_token=self.activation_per_token)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
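# Minimal reference sketch of dynamic per-token fp8 quantization (pure PyTorch,
# illustrative only): the production path goes through the hijacked Fp8LinearOp
# with activation_per_token=True; this is just the math it is assumed to perform.
#
# def per_token_quant_fp8_ref(x: torch.Tensor):
#     finfo = torch.finfo(torch.float8_e4m3fn)
#     amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
#     scale = amax / finfo.max                      # one scale per token
#     x_q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
#     return x_q, scale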
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__(
|
||||
self,
|
||||
quant_config: Fp8Config,
|
||||
layer: torch.nn.Module
|
||||
):
|
||||
super(Fp8MoEMethod, self).__init__(layer.moe_config)
|
||||
self.layer = layer
|
||||
self.quant_config = quant_config
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
self.block_quant: bool = self.weight_block_size is not None
|
||||
self.fp8_backend = get_fp8_moe_backend(self.block_quant)
|
||||
|
||||
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size == [128, 128], (
|
||||
f"Only support weight_block_size == [128, 128], "
|
||||
f"got {self.weight_block_size}"
|
||||
)
|
||||
self.flashinfer_moe_fn = partial(
|
||||
flashinfer_cutlass_moe_fp8,
|
||||
moe=self.moe,
|
||||
use_deepseek_fp8_block_scale=self.block_quant,
|
||||
)
|
||||
|
||||
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
|
||||
self.allow_cutlass_block_scaled_grouped_gemm = (
|
||||
self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: In mlu, always set self.use_marlin as False.
|
||||
'''
|
||||
self.use_marlin = False
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
assert logical_replica_count is not None
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE  # local import, as elsewhere in this module
assert isinstance(layer, FusedMoE)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Use moe_softmax_topk and moe_sigmoid_topk of mlu_ops to implement FusedMoE.select_experts
|
||||
'''
|
||||
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
if scoring_func == "softmax":
|
||||
topk_weights, topk_ids = mlu_ops.moe_softmax_topk(
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
route_scale=routed_scaling_factor,
|
||||
)
|
||||
elif scoring_func == "sigmoid":
|
||||
topk_weights, topk_ids = mlu_ops.moe_sigmoid_topk(
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
# Generate dispatch/combine indices and per-expert token counts for the routed tokens
|
||||
ori_input_shape = x.shape
|
||||
x = x.reshape(-1, x.size(-1))
|
||||
router_logits = router_logits.reshape(-1, router_logits.size(-1))
|
||||
expert_num = router_logits.size(-1)
|
||||
tokens_num = x.size(0)
|
||||
expert_size = layer.w13_weight.size(0)
|
||||
expand_idx, combine_idx, token_count, cumsum_token_count = mlu_ops.moe_gen_idx(
|
||||
topk_ids, expert_num
|
||||
)
|
||||
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(
|
||||
x, expand_idx, cumsum_token_count, 0, expert_size
|
||||
)
|
||||
quant_input, input_scale = _fp8_quantize(
|
||||
expand_hidden_states, A_scale=None, block_shape=self.quant_config.weight_block_size
|
||||
)
|
||||
gemm1_out = mlu_ops.smooth_quant_group_gemm(
|
||||
quant_input,
|
||||
layer.w13_weight,
|
||||
token_count,
|
||||
expand_idx=None,
|
||||
c=None,
|
||||
alpha=None,
|
||||
beta=None,
|
||||
a_scale=input_scale.T.contiguous(),
|
||||
b_scale=layer.w13_weight_scale_inv,
|
||||
dtype=x.dtype,
|
||||
max_m=tokens_num,
|
||||
)
|
||||
|
||||
act_out = mlu_ops.active(gemm1_out, activation, is_gated=True)
|
||||
act_out_quantize, act_out_scale = _fp8_quantize(
|
||||
act_out, A_scale=None, block_shape=self.quant_config.weight_block_size
|
||||
)
|
||||
|
||||
gemm2_out = mlu_ops.smooth_quant_group_gemm(
|
||||
act_out_quantize,
|
||||
layer.w2_weight,
|
||||
token_count,
|
||||
expand_idx=None,
|
||||
c=None,
|
||||
alpha=None,
|
||||
beta=None,
|
||||
a_scale=act_out_scale.T.contiguous(),
|
||||
b_scale=layer.w2_weight_scale_inv,
|
||||
dtype=x.dtype,
|
||||
max_m=tokens_num,
|
||||
)
|
||||
|
||||
output = mlu_ops.moe_combine_result(
|
||||
gemm2_out,
|
||||
topk_weights,
|
||||
combine_idx,
|
||||
residual=None,
|
||||
cusum_token_count=cumsum_token_count,
|
||||
start_expert_id=0,
|
||||
expert_size=expert_size,
|
||||
bias=None,
|
||||
)
|
||||
return output.view(ori_input_shape)
|
||||
|
||||
"""
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
"""
|
||||
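# Reference sketch of the softmax top-k routing delegated to mlu_ops.moe_softmax_topk
# above (pure PyTorch, illustrative only; it ignores expert groups and the routed
# scaling factor):
#
# def reference_softmax_topk(router_logits, top_k, renormalize):
#     probs = torch.softmax(router_logits.float(), dim=-1)
#     topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
#     if renormalize:
#         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
#     return topk_weights, topk_ids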
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearMethod,
|
||||
Fp8LinearMethod.apply,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8Config,
|
||||
Fp8Config.__init__,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8Config____init__
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8Config,
|
||||
Fp8Config.from_config,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearMethod,
|
||||
Fp8LinearMethod.create_weights,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearMethod,
|
||||
Fp8LinearMethod.__init__,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearMethod,
|
||||
Fp8LinearMethod.process_weights_after_loading,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8MoEMethod,
|
||||
Fp8MoEMethod.__init__,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8MoEMethod,
|
||||
Fp8MoEMethod.apply,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply
|
||||
)
|
||||
440
vllm_mlu/model_executor/layers/quantization/gptq_mlu.py
Normal file
@@ -0,0 +1,440 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
|
||||
|
||||
# We only support GPTQ and AWQ on the MLU 300 series and later, and only int4 and int8 precision
|
||||
def query_mlu_supported_quant_types(has_zp: bool,
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + zero-point
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
|
||||
|
||||
def check_mlu_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
supported_types = query_mlu_supported_quant_types(
|
||||
has_zp, device_capability)
|
||||
|
||||
if quant_type not in supported_types:
|
||||
return (False, f"Mlu does not support weight_bits = {quant_type}. "
|
||||
f"Only types = {supported_types} "
|
||||
f"are supported (for group_size = {group_size}, "
|
||||
f"device_capability = {device_capability}, zp = {has_zp}).")
|
||||
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
|
||||
return (False, f"Mlu does not support group_size = {group_size}. "
|
||||
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
|
||||
return True, None
|
||||
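# Illustrative usage sketch (values are assumptions, not from a real model):
# checking whether a symmetric 4-bit GPTQ checkpoint with group_size=128 can
# use the MLU kernels.
#
# ok, reason = check_mlu_supported(quant_type=scalar_types.uint4b8,
#                                  group_size=128, has_zp=False)
# if not ok:
#     logger.warning(reason)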
|
||||
|
||||
# @register_quantization_config("gptq_mlu")
|
||||
class GPTQMluConfig(QuantizationConfig):
|
||||
"""Config class for GPTQMlu.
|
||||
|
||||
Reference: https://arxiv.org/abs/2210.17323
|
||||
"""
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
(4, False): scalar_types.uint4b8,
|
||||
(8, False): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.is_sym = is_sym
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = Fraction(32, self.weight_bits)
|
||||
self.support_scale_zeros = False
|
||||
self.use_native = self.desc_act or (not self.is_sym and not self.support_scale_zeros)
|
||||
|
||||
if self.weight_bits not in [4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 4/8-bit weight quantization is "
|
||||
f"supported for GPTQMlu, but got {self.weight_bits} bits.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQMluConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}),"
|
||||
f"lm_head_quantized={self.lm_head_quantized}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "gptq_mlu"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quant_config.json", "quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQMluConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
is_sym = cls.get_from_keys(config, ["sym"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["GPTQMluLinearMethod"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
return GPTQMluLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
@classmethod
|
||||
def is_gptq_mlu_compatible(cls, quant_config: Dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits", None)
|
||||
group_size = quant_config.get("group_size", None)
|
||||
sym = quant_config.get("sym", None)
|
||||
desc_act = quant_config.get("desc_act", None)
|
||||
|
||||
if quant_method != "gptq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or sym is None
|
||||
or desc_act is None):
|
||||
return False
|
||||
|
||||
if (num_bits, sym) not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
supported, _ = check_mlu_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
                                   group_size=group_size, has_zp=False)
return supported
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
can_convert = cls.is_gptq_mlu_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
|
||||
or user_quant == "gptq_mlu")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
class GPTQMluLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQMlu.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQMlu quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQMluConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (output_size_per_partition % self.quant_config.pack_factor.numerator
|
||||
!= 0):
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
scale_and_zero_size = input_size // group_size
|
||||
scale_and_zero_input_dim = None
|
||||
if (input_size != input_size_per_partition) and (self.quant_config.group_size !=
|
||||
-1) and (not self.quant_config.desc_act):
|
||||
scale_and_zero_size = input_size_per_partition // group_size
|
||||
scale_and_zero_input_dim = 0
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
g_idx = RowvLLMParameter(data=torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if scale_and_zero_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.device = layer.qweight.data.device
|
||||
packed_qweight, scale_zeros = self.extract_autogptq(layer)
|
||||
if self.quant_config.use_native:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
layer.qzeros = None
|
||||
layer.scales = None
|
||||
else:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
if scale_zeros is not None:
|
||||
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
|
||||
else:
|
||||
layer.qzeros = None
|
||||
layer.scales = torch.nn.Parameter(layer.scales.transpose(0, 1).contiguous(), requires_grad=False)
|
||||
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.quant_config.use_native:
|
||||
output = mlu_ops.matmul(x, layer.qweight, bias)
|
||||
if residual is not None:
|
||||
output = output + residual
|
||||
else:
|
||||
output = mlu_ops.weight_only_quant_matmul(x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
bias,
|
||||
residual,
|
||||
"none",
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def extract_autogptq(self, layer: torch.nn.Module):
|
||||
scales = layer.scales.data
|
||||
bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
# Unpack the qweight and qzeros tensors
|
||||
iweight = self.unpack_gptq_qweight_int32_into_int8(layer.qweight.data, bits)
|
||||
izeros = self.unpack_gptq_qzeros_int32_into_int8(layer.qzeros.data, bits)
|
||||
|
||||
if self.quant_config.use_native:
|
||||
if self.quant_config.desc_act:
|
||||
scales = torch.index_select(scales, 0, layer.g_idx)
|
||||
if izeros is not None:
|
||||
izeros = torch.index_select(izeros, 0, layer.g_idx)
|
||||
else:
|
||||
scales = scales.repeat_interleave(group_size, dim=0)
|
||||
if izeros is not None:
|
||||
izeros = izeros.repeat_interleave(group_size, dim=0)
|
||||
|
||||
if izeros is not None:
|
||||
fweight = (iweight - izeros) * scales
|
||||
else:
|
||||
fweight = iweight * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
fweight = fweight.transpose(0, 1)
|
||||
|
||||
return fweight, None
|
||||
|
||||
if not self.quant_config.is_sym and self.quant_config.support_scale_zeros and izeros is not None:
|
||||
scale_zeros = izeros.to(scales.dtype) * -1 * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
scale_zeros = scale_zeros.transpose(0, 1)
|
||||
else:
|
||||
# is_sym is true here, so convert iweight to signed values and ignore qzeros
|
||||
iweight = torch.bitwise_and(iweight - 2**(bits - 1), (2 ** bits) - 1)
|
||||
scale_zeros = None
|
||||
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
iweight = iweight.to(torch.int8).transpose(0, 1)
|
||||
|
||||
if bits == 4:
|
||||
higher_bit_tensor = iweight[:, 1::2]
|
||||
lower_bit_tensor = iweight[:, 0::2]
|
||||
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
|
||||
else:
|
||||
packed_qweight = iweight
|
||||
|
||||
return packed_qweight, scale_zeros
|
||||
|
||||
|
||||
def unpack_gptq_qweight_int32_into_int8(self, qweight: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qweight.device).unsqueeze(0)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
weight = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
|
||||
shifts.unsqueeze(-1),
|
||||
).to(dtype)
|
||||
weight = torch.bitwise_and(weight, (2**bits) - 1)
|
||||
weight = weight.reshape(-1, weight.shape[-1])
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
def unpack_gptq_qzeros_int32_into_int8(self, qzeros: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qzeros.device).unsqueeze(0)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
zeros = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
|
||||
shifts.unsqueeze(0),
|
||||
).to(dtype)
|
||||
|
||||
zeros = zeros + 1
|
||||
|
||||
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
|
||||
zeros = zeros.reshape(qzeros.shape[0], -1)
|
||||
|
||||
return zeros
|
||||
|
||||
|
||||
def combine_low_bits(self, tensor_a, tensor_b):
|
||||
"""
|
||||
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
|
||||
|
||||
Args:
|
||||
tensor_a (torch.Tensor): First tensor of type int8.
|
||||
tensor_b (torch.Tensor): Second tensor of type int8.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
|
||||
"""
|
||||
# Ensure the inputs are of int8 type
|
||||
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
|
||||
raise ValueError("Both tensors must be of int8 type.")
|
||||
|
||||
# Extract the lower 4 bits of each tensor
|
||||
low_bits_a = torch.bitwise_and(tensor_a, 0x0F)  # keep the lower 4 bits of tensor_a
|
||||
low_bits_b = torch.bitwise_and(tensor_b, 0x0F)  # keep the lower 4 bits of tensor_b
|
||||
|
||||
# Shift the lower 4 bits of tensor_a left by 4 bits
|
||||
shifted_low_bits_a = low_bits_a << 4
|
||||
|
||||
# Combine the lower 4 bits of the two tensors
|
||||
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
|
||||
|
||||
return combined
|
||||
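# Worked example of the int4 packing done by combine_low_bits above (standalone
# and illustrative only; not used by the kernel path): two int4 values share one
# int8 byte, tensor_a in the high nibble and tensor_b in the low nibble.
#
# a = torch.tensor([0x3], dtype=torch.int8)   # 0b0011
# b = torch.tensor([0x5], dtype=torch.int8)   # 0b0101
# packed = torch.bitwise_or(torch.bitwise_and(a, 0x0F) << 4,
#                           torch.bitwise_and(b, 0x0F))
# assert packed.item() == 0x35                # 0b0011_0101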
337
vllm_mlu/model_executor/layers/quantization/smoothquant.py
Executable file
@@ -0,0 +1,337 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.model_executor.layers.quantization.utils.common_utils import (str_dtype_to_torch,
|
||||
str_dtype_to_bits,
|
||||
is_fp8_str_dtype)
|
||||
|
||||
|
||||
# @register_quantization_config("smoothquant")
|
||||
class SmoothQuantConfig(QuantizationConfig):
|
||||
"""Config class for SmoothQuant.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_mode: str, # smoothquant
|
||||
input_quant_method: str, # per token/per tensor
|
||||
group_size: int,
|
||||
weight_precision: str,
|
||||
activation_precision: str,
|
||||
only_expert_per_group: bool,
|
||||
expert_weight_precision: str,
|
||||
expert_activation_precision: str,
|
||||
force_use_weightonly_except_expert: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.quant_mode = quant_mode
|
||||
self.input_quant_method = input_quant_method
|
||||
self.group_size = group_size
|
||||
self.weight_precision = weight_precision
|
||||
self.activation_precision = activation_precision
|
||||
self.only_expert_per_group = only_expert_per_group
|
||||
self.expert_weight_precision = expert_weight_precision
|
||||
self.expert_activation_precision = expert_activation_precision
|
||||
self.force_use_weightonly_except_expert = force_use_weightonly_except_expert
|
||||
|
||||
if quant_mode == "SmoothQuant" and self.input_quant_method not in ("per_token", "per_tensor"):
|
||||
raise ValueError(
|
||||
"Currently, only per_token or per_tensor input quantization is supported for "
|
||||
f"SmoothQuant, but got {self.input_quant_method}.")
|
||||
|
||||
self.weight_bits = str_dtype_to_bits(self.weight_precision)
|
||||
self.expert_weight_bits = str_dtype_to_bits(self.expert_weight_precision)
|
||||
if self.weight_precision == 'int4':
|
||||
self.weight_dtype = torch.int8
|
||||
else:
|
||||
self.weight_dtype = str_dtype_to_torch(self.weight_precision)
|
||||
if self.expert_weight_precision == 'int4':
|
||||
self.expert_weight_dtype = torch.int8
|
||||
else:
|
||||
self.expert_weight_dtype = str_dtype_to_torch(self.expert_weight_precision)
|
||||
self.is_fp8 = is_fp8_str_dtype(self.weight_precision)
|
||||
self.expert_is_fp8 = is_fp8_str_dtype(self.expert_weight_precision)
|
||||
|
||||
self.pack_factor = 8 // self.weight_bits
|
||||
self.expert_pack_factor = 8 // self.expert_weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SmoothQuantConfig(input_quant_method={self.input_quant_method}, "
|
||||
f"quant_mode={self.quant_mode}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"weight_precision={self.weight_precision}, "
|
||||
f"activation_precision={self.activation_precision}, "
|
||||
f"only_expert_per_group={self.only_expert_per_group}, "
|
||||
f"expert_weight_precision={self.expert_weight_precision}, "
|
||||
f"expert_activation_precision={self.expert_activation_precision}, "
|
||||
f"force_use_weightonly_except_expert={self.force_use_weightonly_except_expert})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "SmoothQuant"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
|
||||
quant_mode = cls.get_from_keys(config, ["quant_mode"])
|
||||
input_quant_method = cls.get_from_keys(config, ["input_quant_method"])
|
||||
group_size = cls.get_from_keys_or(config, ["group_size"], 1)
|
||||
weight_precision = cls.get_from_keys_or(config, ["weight_precision"], "int8")
|
||||
activation_precision = cls.get_from_keys_or(config, ["activation_precision"], "int8")
|
||||
only_expert_per_group = cls.get_from_keys_or(config, ["only_expert_per_group"], False)
|
||||
expert_weight_precision = cls.get_from_keys_or(config, ["expert_weight_precision"], None)
|
||||
expert_activation_precision = cls.get_from_keys_or(config, ["expert_activation_precision"], None)
|
||||
force_use_weightonly_except_expert = cls.get_from_keys_or(config, ["force_use_weightonly_except_expert"], False)
|
||||
|
||||
if expert_weight_precision is None:
|
||||
expert_weight_precision = weight_precision
|
||||
if group_size > 1 and only_expert_per_group and weight_precision == 'int4':
|
||||
weight_precision = 'int8'
|
||||
|
||||
if expert_activation_precision is None:
|
||||
expert_activation_precision = activation_precision
|
||||
|
||||
return cls(quant_mode=quant_mode,
|
||||
input_quant_method=input_quant_method,
|
||||
group_size=group_size,
|
||||
weight_precision=weight_precision,
|
||||
activation_precision=activation_precision,
|
||||
only_expert_per_group=only_expert_per_group,
|
||||
expert_weight_precision=expert_weight_precision,
|
||||
expert_activation_precision=expert_activation_precision,
|
||||
force_use_weightonly_except_expert=force_use_weightonly_except_expert)
|
||||
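# Illustrative sketch only: a minimal quantize_config.json (shown as the dict
# passed to from_config above) for per-token SmoothQuant with int8 weights and
# activations. The key names come from the code above; the values are assumptions.
#
# example_config = {
#     "quant_mode": "SmoothQuant",
#     "input_quant_method": "per_token",
#     "group_size": 1,
#     "weight_precision": "int8",
#     "activation_precision": "int8",
#     "only_expert_per_group": False,
#     "expert_weight_precision": "int8",
#     "expert_activation_precision": "int8",
#     "force_use_weightonly_except_expert": False,
# }
# sq_config = SmoothQuantConfig.from_config(example_config)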
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["SmoothQuantLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return SmoothQuantLinearMethod(self, prefix)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
|
||||
class SmoothQuantLinearMethod(LinearMethodBase):
|
||||
"""Linear method for SmoothQuant.
|
||||
|
||||
Args:
|
||||
quant_config: The SmoothQuant quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: SmoothQuantConfig, prefix: str):
|
||||
self.quant_config = quant_config
|
||||
# For the per-tensor case, we can skip quantizing the input of the first attn/ffn linear
|
||||
# and fuse this step into the layernorm for better performance
|
||||
self.skip_quant_input = False
|
||||
self.compute_dtype = torch.get_default_dtype()
|
||||
self.is_expert = 'expert' in prefix and "shared_expert" not in prefix
|
||||
self.weight_dtype = quant_config.expert_weight_dtype if self.is_expert else quant_config.weight_dtype
|
||||
self.pack_factor = quant_config.expert_pack_factor if self.is_expert else quant_config.pack_factor
|
||||
self.is_fp8 = quant_config.expert_is_fp8 if self.is_expert else quant_config.is_fp8
|
||||
|
||||
if quant_config.only_expert_per_group and self.is_expert and quant_config.group_size > 1:
|
||||
self.is_group_quant = True
|
||||
elif quant_config.only_expert_per_group is False and quant_config.group_size > 1:
|
||||
self.is_group_quant = True
|
||||
else:
|
||||
self.is_group_quant = False
|
||||
self.has_smooth = self.quant_config.input_quant_method == "per_token" and (
|
||||
self.quant_config.force_use_weightonly_except_expert is False or self.is_expert)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (output_size_per_partition % self.quant_config.pack_factor != 0):
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
group_num = 1
|
||||
if self.is_group_quant:
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
f"The input size {input_size_per_partition} is not aligned with the quantized "
|
||||
f"weight shape. This can be caused by too large "
|
||||
f"tensor parallel size. group_size: {self.quant_config.group_size}.")
|
||||
|
||||
group_num = (input_size + self.quant_config.group_size - 1) // self.quant_config.group_size
|
||||
if input_size_per_partition != input_size:
|
||||
group_num = (input_size_per_partition + self.quant_config.group_size - 1) // self.quant_config.group_size
|
||||
|
||||
qweight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // self.pack_factor,
|
||||
device="mlu",
|
||||
dtype=self.weight_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
if self.is_group_quant:
|
||||
per_channel_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
group_num,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
per_channel_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("per_channel_scale", per_channel_scale)
|
||||
|
||||
if self.has_smooth:
|
||||
smooth = RowvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
set_weight_attrs(smooth, {
|
||||
"ignore_warning": True,
|
||||
})
|
||||
layer.register_parameter("smooth", smooth)
|
||||
if self.quant_config.input_quant_method == "per_tensor":
|
||||
scale_to_int = RowvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
set_weight_attrs(scale_to_int, {
|
||||
"ignore_warning": True,
|
||||
})
|
||||
layer.register_parameter("scale_to_int", scale_to_int)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if self.has_smooth and layer.smooth.dtype != torch.float:
|
||||
layer.smooth = layer.smooth.to(torch.float)
|
||||
if self.quant_config.input_quant_method == "per_tensor" and layer.scale_to_int.dtype != torch.float:
|
||||
layer.scale_to_int = layer.scale_to_int.to(torch.float)
|
||||
if layer.per_channel_scale.dtype != torch.float:
|
||||
layer.per_channel_scale = layer.per_channel_scale.to(torch.float)
|
||||
|
||||
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.per_channel_scale = Parameter(layer.per_channel_scale.data, requires_grad=False)
|
||||
if self.has_smooth:
|
||||
layer.smooth = Parameter(layer.smooth.data, requires_grad=False)
|
||||
if self.quant_config.input_quant_method == "per_tensor":
|
||||
layer.scale_to_int = Parameter(layer.scale_to_int.data, requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
use_tp_weight : bool = False,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
layer_smooth = layer.smooth if self.has_smooth else None
|
||||
layer_qweight = layer.qweight
|
||||
layer_per_channel_scale = layer.per_channel_scale
|
||||
if use_tp_weight:
|
||||
if hasattr(layer, 'tp_smooth'):
|
||||
layer_smooth = layer.tp_smooth
|
||||
if hasattr(layer, 'tp_qweight'):
|
||||
layer_qweight = layer.tp_qweight
|
||||
if hasattr(layer, 'tp_per_channel_scale'):
|
||||
layer_per_channel_scale = layer.tp_per_channel_scale
|
||||
|
||||
quant_input = None
|
||||
if self.skip_quant_input:
|
||||
quant_input = x
|
||||
elif self.quant_config.input_quant_method == "per_token":
|
||||
if self.is_fp8:
|
||||
quant_input, input_scale = mlu_ops.scaled_quantize(x,
|
||||
layer_smooth,
|
||||
quant_type=self.weight_dtype,
|
||||
quant_mode='dynamic_per_token')
|
||||
else:
|
||||
quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer_smooth, None)
|
||||
elif self.quant_config.input_quant_method == "per_tensor":
|
||||
quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Currently, only per_token or per_tensor input quantization is supported for "
|
||||
f"SmoothQuant, but got {self.input_quant_method}.")
|
||||
quant_input_shape = quant_input.shape
|
||||
if len(quant_input_shape) > 2:
|
||||
quant_input = quant_input.view(-1, quant_input_shape[-1])
|
||||
input_scale = input_scale.view(-1)
|
||||
if residual is not None and len(residual.shape) > 2:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
if self.is_fp8:
|
||||
out = mlu_ops.scaled_matmul(quant_input, layer_qweight, input_scale,
|
||||
layer_per_channel_scale,
|
||||
self.compute_dtype if hasattr(self, 'compute_dtype') else x.dtype,
|
||||
bias,
|
||||
c=residual, act_mode="none",quant_bit_size=8,
|
||||
alpha=1.0, beta=1.0, use_hp_active=False,
|
||||
a_quant_bit_size=8, a_calib=None, b_calib=None)
|
||||
if output is not None:
|
||||
out = out.view(output.shape)
|
||||
output.copy_(out)
|
||||
out = output
|
||||
else:
|
||||
if output is not None:
|
||||
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
|
||||
layer_per_channel_scale, self.compute_dtype, bias, residual, output=output)
|
||||
else:
|
||||
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
|
||||
layer_per_channel_scale, self.compute_dtype, bias, residual)
|
||||
if len(quant_input_shape) > 2:
|
||||
out = out.view(*quant_input_shape[:-1], out.shape[-1])
|
||||
return out
|
||||
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
QUANTIZATION_CHOICES = ['int8', 'int4', 'e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
|
||||
INTERGER_DTYPES = [torch.uint8, torch.uint16, torch.uint32, torch.uint64, torch.int8, torch.int16, torch.short,
|
||||
torch.int32, torch.int, torch.int64, torch.long]
|
||||
FLOAT_DTYPES = [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.bfloat16,
|
||||
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, torch.half]
|
||||
FP8_DTYPE = [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
|
||||
FP8_STR_DTYPE = ['e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
|
||||
GEMM_GROUP_SIZE = [64, 128, 256, 512]
|
||||
|
||||
_STR_TO_TORCH_DTYPE_DICT = dict(
|
||||
bfloat16=torch.bfloat16,
|
||||
float16=torch.float16,
|
||||
float32=torch.float32,
|
||||
int64=torch.int64,
|
||||
int32=torch.int32,
|
||||
int8=torch.int8,
|
||||
bool=torch.bool,
|
||||
e4m3fn=torch.float8_e4m3fn,
|
||||
e4m3fnuz=torch.float8_e4m3fnuz,
|
||||
e5m2=torch.float8_e5m2,
|
||||
e5m2fnuz=torch.float8_e5m2fnuz,
|
||||
)
|
||||
|
||||
TORCH_DTYPE_TO_STR_DICT = {
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.float16: "float16",
|
||||
torch.float32: "float32",
|
||||
torch.int64: "int64",
|
||||
torch.int32: "int32",
|
||||
torch.int8: "int8",
|
||||
torch.bool: "bool",
|
||||
torch.float8_e4m3fn: "e4m3fn",
|
||||
torch.float8_e4m3fnuz: "e4m3fnuz",
|
||||
torch.float8_e5m2: "e5m2",
|
||||
torch.float8_e5m2fnuz: "e5m2fnuz",
|
||||
}
|
||||
|
||||
STR_DTYPE_TO_BITS_DICT = {
|
||||
"bfloat16": 16,
|
||||
"float16": 16,
|
||||
"float32": 32,
|
||||
"int64": 64,
|
||||
"int32": 32,
|
||||
"int8": 8,
|
||||
'int4': 4,
|
||||
"bool": 1,
|
||||
"e4m3fn": 8,
|
||||
"e4m3fnuz": 8,
|
||||
"e5m2": 8,
|
||||
"e5m2fnuz": 8,
|
||||
}
|
||||
|
||||
|
||||
def str_dtype_to_torch(str_dtype: str):
|
||||
'''
|
||||
convert a str dtype name to the corresponding torch dtype
|
||||
'''
|
||||
ret = _STR_TO_TORCH_DTYPE_DICT.get(str_dtype)
|
||||
dtype = ret if ret is not None else torch.float16
|
||||
return dtype
|
||||
|
||||
|
||||
def torch_dtype_to_str(dtype: torch.dtype):
|
||||
'''
|
||||
convert a torch dtype to its str dtype name
|
||||
'''
|
||||
ret = TORCH_DTYPE_TO_STR_DICT.get(dtype)
|
||||
str_dtype = ret if ret is not None else "float16"
|
||||
return str_dtype
|
||||
|
||||
|
||||
def str_dtype_to_bits(str_dtype):
|
||||
'''
|
||||
convert a str dtype name to its bit width
|
||||
'''
|
||||
ret = STR_DTYPE_TO_BITS_DICT.get(str_dtype)
|
||||
bits = ret if ret is not None else 8
|
||||
return bits
|
||||
|
||||
|
||||
def is_integer_dtype(dtype: torch.dtype):
|
||||
'''
|
||||
check whether the torch dtype is an integer type
|
||||
'''
|
||||
return dtype in INTERGER_DTYPES
|
||||
|
||||
|
||||
def is_float_dtype(dtype: torch.dtype):
|
||||
'''
|
||||
check whether the torch dtype is a floating-point type
|
||||
'''
|
||||
return dtype in FLOAT_DTYPES
|
||||
|
||||
|
||||
def is_fp8_dtype(dtype: torch.dtype):
|
||||
'''
|
||||
check whether the torch dtype is an fp8 type
|
||||
'''
|
||||
return dtype in FP8_DTYPE
|
||||
|
||||
|
||||
def is_fp8_str_dtype(str_dtype: str):
|
||||
'''
|
||||
check whether the str dtype name is an fp8 type
|
||||
'''
|
||||
return str_dtype in FP8_STR_DTYPE
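# Hedged usage sketch (illustrative, not part of the diff): round-tripping dtype names
# with the helpers above; the expected values follow from the dictionaries in this file.
#
#   str_dtype_to_torch("e4m3fn")        # -> torch.float8_e4m3fn
#   str_dtype_to_torch("unknown")       # -> torch.float16 (fallback default)
#   torch_dtype_to_str(torch.bfloat16)  # -> "bfloat16"
#   str_dtype_to_bits("int4")           # -> 4
#   is_fp8_dtype(torch.float8_e5m2)     # -> True
#   is_fp8_str_dtype("e5m2fnuz")        # -> True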
|
||||
424
vllm_mlu/model_executor/layers/quantization/utils/fp8_utils.py
Normal file
@@ -0,0 +1,424 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/pull/2575
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
_per_token_group_quant_fp8_colmajor)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: get the total core count for splitting the triton kernel
|
||||
'''
|
||||
|
||||
import triton.backends.mlu.driver as driver
|
||||
|
||||
_devprob = driver.BangUtils().get_device_properties(torch.mlu.current_device())
|
||||
TOTAL_CLUSTER_NUM = _devprob.get("cluster_num")
|
||||
TOTAL_CORE_NUM = TOTAL_CLUSTER_NUM * _devprob.get("core_num_per_cluster")
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def apply_w8a8_block_fp8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
block_size: List[int],
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
) -> torch.Tensor:
|
||||
assert input_scale is None
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
|
||||
shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
|
||||
and weight.shape[1] % 128 == 0)
|
||||
if current_platform.is_rocm():
|
||||
# TODO this is never used, as cutlass_block_fp8_supported is False
|
||||
scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
|
||||
input_2d.shape[:-1])[::-1]
|
||||
scale_b_shape = (weight_scale.view(-1, 1)
|
||||
if weight_scale.dim() <= 1 else weight_scale.T).shape
|
||||
ar, ac = scale_a_shape
|
||||
br, bc = scale_b_shape
|
||||
if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
|
||||
or br not in (1, weight.shape[0])):
|
||||
shape_supported_by_cutlass = False
|
||||
if cutlass_block_fp8_supported and shape_supported_by_cutlass:
|
||||
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||
block_size[1],
|
||||
column_major_scales=True)
|
||||
output = ops.cutlass_scaled_mm(q_input,
|
||||
weight.T,
|
||||
out_dtype=input.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale.T)
|
||||
else:
|
||||
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||
block_size[1],
|
||||
column_major_scales=False)
|
||||
output = w8a8_block_fp8_matmul(q_input,
|
||||
weight,
|
||||
x_scale,
|
||||
weight_scale,
|
||||
block_size,
|
||||
output_dtype=input.dtype)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=input.dtype).view(*output_shape)
|
||||
|
||||
def per_token_group_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
eps: float = 1e-10,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
column_major_scales: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||
It converts the tensor values into signed float8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
Args:
|
||||
x: The input tensor with ndim >= 2.
|
||||
group_size: The group size used for quantization.
|
||||
eps: The minimum to avoid dividing zero.
|
||||
dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
|
||||
is supported for now.
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
scaling factor for quantization.
|
||||
"""
|
||||
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
||||
assert (x.shape[-1] % group_size == 0), (
|
||||
f"the last dimension of `x` {x.shape[-1]} must be divisible "
|
||||
f"by `group_size` {group_size}")
|
||||
assert x.stride(-1) == 1, "`x` groups must be contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: split to limit the memory usage (65536)
|
||||
'''
|
||||
group_per_block = 1
|
||||
while M >= 65536:
|
||||
group_per_block *= 2
|
||||
M = x.numel() // (group_size * group_per_block)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if column_major_scales:
|
||||
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
|
||||
x_s = torch.empty(shape, device=x.device,
|
||||
dtype=torch.float32).permute(-1, -2)
|
||||
else:
|
||||
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
|
||||
|
||||
BLOCK = triton.next_power_of_2(N)
|
||||
# heuristics for number of warps
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: set num_warps to 1 for triton-mlu
|
||||
'''
|
||||
num_warps = 1
|
||||
num_stages = 1
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if column_major_scales:
|
||||
_per_token_group_quant_fp8_colmajor[(M, )](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
group_size,
|
||||
x.shape[1],
|
||||
x.stride(0),
|
||||
x_s.stride(1),
|
||||
eps,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
BLOCK=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
else:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use the 'scaled_quantize' kernel from the 'tmo' library in place of the
'_per_token_group_quant_fp8' triton kernel
|
||||
'''
|
||||
# Check if x is contiguous, if not, create a new tensor for contiguous x
|
||||
if not x.is_contiguous():
|
||||
x = x.contiguous()
|
||||
x_origin_shape = x.shape
|
||||
x = x.reshape(*x.shape[:-1], -1, group_size)
|
||||
x_q, x_s = mlu_ops.scaled_quantize(x,
|
||||
None,
|
||||
quant_type=dtype,
|
||||
quant_mode='dynamic_per_token')
|
||||
x_q = x_q.reshape(x_origin_shape)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
return x_q, x_s
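# Hedged usage sketch (illustrative; tensor shapes and device are assumptions): quantize a
# [num_tokens, hidden] bf16 activation with group_size=128. On MLU the non-column-major
# branch above delegates to mlu_ops.scaled_quantize.
#
#   x = torch.randn(16, 4096, dtype=torch.bfloat16, device="mlu")
#   x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
#   # x_q: [16, 4096] fp8 tensor; x_s: float32 scales with 4096 // 128 = 32 groups per token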
|
||||
|
||||
@triton.jit
|
||||
def _w8a8_block_fp8_matmul(
|
||||
# Pointers to inputs and output
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
As,
|
||||
Bs,
|
||||
# Shape for matmul
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
# Block size for block-wise quantization
|
||||
group_n,
|
||||
group_k,
|
||||
# Stride for inputs and output
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_As_m,
|
||||
stride_As_k,
|
||||
stride_Bs_k,
|
||||
stride_Bs_n,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""Triton-accelerated function used to perform linear operations (dot
|
||||
product) on input tensors `A` and `B` with block-wise quantization, and
|
||||
store the result in output tensor `C`.
|
||||
"""
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: split to limit the memory usage (65536)
|
||||
'''
|
||||
num_block_size_all = num_pid_m * num_pid_n
|
||||
num_block_size_per = num_block_size_all // tl.num_programs(axis=0)
|
||||
num_block_size_rem = num_block_size_all % tl.num_programs(axis=0)
|
||||
|
||||
core_deal_num_block_size = num_block_size_per + (pid < num_block_size_rem)
|
||||
core_deal_num_block_start = num_block_size_per * pid + min(num_block_size_rem, pid)
|
||||
|
||||
for pid_i in range(0, core_deal_num_block_size):
|
||||
pid_in_core_deal_block = core_deal_num_block_start + pid_i
|
||||
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid_in_core_deal_block // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid_in_core_deal_block % group_size_m)
|
||||
pid_n = (pid_in_core_deal_block % num_pid_in_group) // group_size_m
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
As_ptrs = As + offs_am * stride_As_m
|
||||
offs_bsn = offs_bn // group_n
|
||||
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(a_ptrs,
|
||||
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
|
||||
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
|
||||
|
||||
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
if C.dtype.element_ty == tl.bfloat16:
|
||||
c = accumulator.to(tl.bfloat16)
|
||||
elif C.dtype.element_ty == tl.float16:
|
||||
c = accumulator.to(tl.float16)
|
||||
else:
|
||||
c = accumulator.to(tl.float32)
|
||||
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
def w8a8_block_fp8_matmul(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: List[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
quantization.
|
||||
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||
The output is returned in the specified `output_dtype`.
|
||||
Args:
|
||||
A: The input tensor, e.g., activation.
|
||||
B: The input tensor, e.g., weight.
|
||||
As: The per-token-group quantization scale for `A`.
|
||||
Bs: The per-block quantization scale for `B`.
|
||||
block_size: The block size for per-block quantization. It should
|
||||
be 2-dim, e.g., [128, 128].
|
||||
output_dtype: The dtype of the returned tensor.
|
||||
Returns:
|
||||
torch.Tensor: The result of matmul.
|
||||
"""
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replaced the 'scaled_matmul' kernel from the 'tmo' library with
|
||||
'_w8a8_block_fp8_matmul' kernel
|
||||
'''
|
||||
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
||||
assert B.ndim == 2 and Bs.ndim == 2
|
||||
|
||||
if (B.shape[0] % 128 == 0) and (B.shape[1] % 128 == 0):
|
||||
C = mlu_ops.scaled_matmul(A, B, As, Bs, output_dtype, bias=None, c=None, act_mode="none",
|
||||
quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
|
||||
a_quant_bit_size=8, a_calib=None, b_calib=None)
|
||||
else:
|
||||
# NOTE(wulingchao): the underlying scaled_matmul kernel only supports n and k that are multiples of 128
|
||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||
M = A.numel() // A.shape[-1]
|
||||
|
||||
assert B.ndim == 2 and Bs.ndim == 2
|
||||
N, K = B.shape
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
C_shape = A.shape[:-1] + (N, )
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
|
||||
# Default config
|
||||
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
|
||||
# BLOCK_SIZE_K must be divisible by block_size[1]
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": block_size[0],
|
||||
"BLOCK_SIZE_K": block_size[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 1,
|
||||
"num_stages": 1,
|
||||
}
|
||||
|
||||
def grid(META):
|
||||
return (TOTAL_CORE_NUM, )
|
||||
|
||||
_w8a8_block_fp8_matmul[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
As,
|
||||
Bs,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_n,
|
||||
block_k,
|
||||
A.stride(-2),
|
||||
A.stride(-1),
|
||||
B.stride(1),
|
||||
B.stride(0),
|
||||
C.stride(-2),
|
||||
C.stride(-1),
|
||||
As.stride(-2),
|
||||
As.stride(-1),
|
||||
Bs.stride(1),
|
||||
Bs.stride(0),
|
||||
**config,
|
||||
)
|
||||
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
return C
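# Hedged end-to-end sketch (illustrative; names mirror apply_w8a8_block_fp8_linear above,
# block_size=[128, 128] is an assumption):
#
#   q_input, x_scale = per_token_group_quant_fp8(input_2d, 128, column_major_scales=False)
#   out = w8a8_block_fp8_matmul(q_input, weight, x_scale, weight_scale,
#                               block_size=[128, 128], output_dtype=torch.bfloat16)
#   # If weight's N and K are both multiples of 128 the call maps to mlu_ops.scaled_matmul,
#   # otherwise it falls back to the _w8a8_block_fp8_matmul triton kernel defined above.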
|
||||
178
vllm_mlu/model_executor/layers/quantization/utils/w8a8_utils.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional, Callable
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, USE_ROWWISE_TORCH_SCALED_MM, cutlass_w8a8_scaled_mm,
|
||||
flashinfer_w8a8_scaled_mm, rocm_per_tensor_w8a8_scaled_mm,
|
||||
torch_per_tensor_w8a8_scaled_mm, torch_per_token_w8a8_scaled_mm,
|
||||
torch_channelwise_w8a8_scaled_mm)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def mlu_w8a8_scaled_mm(
|
||||
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
output_shape: list, **kwargs
|
||||
) -> torch.Tensor:
|
||||
output = mlu_ops.scaled_matmul(
|
||||
qinput, # a
|
||||
weight, # b
|
||||
scale_a, # a_scale
|
||||
scale_b, # b_scale
|
||||
out_dtype, # output_dtype
|
||||
bias, # bias
|
||||
c=None, act_mode="none", quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
|
||||
a_quant_bit_size=8, a_calib=None, b_calib=None
|
||||
)
|
||||
return output.view(*output_shape)
|
||||
|
||||
|
||||
def dispatch_w8a8_scaled_mm(
|
||||
preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool,
|
||||
weight_per_channel: bool, activation_per_token: bool
|
||||
) -> Callable[..., torch.Tensor]:
|
||||
if per_tensor_weights and per_tensor_activations:
|
||||
if preferred_backend == "rocm":
|
||||
return rocm_per_tensor_w8a8_scaled_mm
|
||||
if preferred_backend == "flashinfer":
|
||||
return flashinfer_w8a8_scaled_mm
|
||||
if preferred_backend == "cutlass":
|
||||
return cutlass_w8a8_scaled_mm
|
||||
return torch_per_tensor_w8a8_scaled_mm
|
||||
|
||||
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
|
||||
if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
|
||||
return cutlass_w8a8_scaled_mm
|
||||
|
||||
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
|
||||
if (
|
||||
not per_tensor_weights
|
||||
and not per_tensor_activations
|
||||
and USE_ROWWISE_TORCH_SCALED_MM
|
||||
):
|
||||
return torch_per_token_w8a8_scaled_mm
|
||||
# Normally, torch.scaled_mm supports per tensor weights + activations only
|
||||
# so fallback to naive if per channel or per token
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: dispatch to mlu_w8a8_scaled_mm
|
||||
'''
|
||||
if weight_per_channel and activation_per_token:
|
||||
return mlu_w8a8_scaled_mm
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return torch_channelwise_w8a8_scaled_mm
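# Hedged dispatch sketch (illustrative; the backend string is an assumption): with
# per-channel weight scales and per-token activation scales, and assuming
# USE_ROWWISE_TORCH_SCALED_MM is False, the dispatcher above resolves to mlu_w8a8_scaled_mm.
#
#   func = dispatch_w8a8_scaled_mm(preferred_backend="torch",
#                                  per_tensor_weights=False, per_tensor_activations=False,
#                                  weight_per_channel=True, activation_per_token=True)
#   assert func is mlu_w8a8_scaled_mm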
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
input_scale: torch.Tensor | None = None,
|
||||
input_scale_ub: torch.Tensor | None = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
weight_per_channel: bool = True,
|
||||
activation_per_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add mlu_fp8_supported
|
||||
'''
|
||||
self.mlu_fp8_supported = False
|
||||
if weight_per_channel and activation_per_token:
|
||||
self.mlu_fp8_supported = True
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[1]]
|
||||
|
||||
if out_dtype is None:
|
||||
out_dtype = input.dtype
|
||||
|
||||
if self.mlu_fp8_supported:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add support for activation-per-token weight-per-channel quantization.
|
||||
'''
|
||||
qinput, x_scale = mlu_ops.scaled_quantize(
|
||||
input_2d,# x
|
||||
None, # scale
|
||||
None, # zero
|
||||
None, # scale_ub
|
||||
quant_type=torch.float8_e4m3fn,
|
||||
quant_mode='dynamic_per_token'
|
||||
)
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
else:
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
if input.dtype != current_platform.fp8_dtype():
|
||||
qinput, x_scale = self.quant_fp8(
|
||||
input_2d,
|
||||
input_scale,
|
||||
input_scale_ub,
|
||||
)
|
||||
else:
|
||||
qinput, x_scale = input_2d, input_scale
|
||||
|
||||
# Must have dim() conditions
|
||||
# In the per-token quant scenario, when the number of tokens is 1,
# the scale will only have 1 element.
# Without checking the dim(),
# we cannot distinguish between per-tensor and per-token quant.
# Example:
# When the number of tokens is 1, the per-token scale is [[1]],
# whereas a per-tensor scale is [1] or ().
|
||||
per_tensor_weights = weight_scale.numel() == 1
|
||||
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
|
||||
|
||||
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
|
||||
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
|
||||
self.preferred_backend, per_tensor_weights, per_tensor_activations,
|
||||
weight_per_channel, activation_per_token)
|
||||
return w8a8_scaled_mm_func(
|
||||
qinput=qinput,
|
||||
weight=weight,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearOp,
|
||||
Fp8LinearOp.apply,
|
||||
vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply
|
||||
)
|
||||
150
vllm_mlu/model_executor/layers/quantization/weightonly.py
Executable file
@@ -0,0 +1,150 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# @register_quantization_config("weightonly")
|
||||
class WeightOnlyConfig(QuantizationConfig):
|
||||
"""Config class for WeightOnly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
quant_mode: str, # weight_only
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.quant_mode = quant_mode
|
||||
|
||||
if quant_mode == "WeightOnly" and (self.weight_bits != 8 and self.weight_bits != 4):
|
||||
raise ValueError(
|
||||
"Currently, only 8/4-bit weight quantization is supported for "
|
||||
f"weight_only, but got {self.weight_bits} bits.")
|
||||
self.pack_factor = 8 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"WeightOnlyConfig(weight_bits={self.weight_bits}, "
|
||||
f"quant_mode={self.quant_mode})")
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "WeightOnly"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "WeightOnlyConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
try:
|
||||
quant_mode = cls.get_from_keys(config, ["quant_mode"])
|
||||
except Exception:
|
||||
quant_mode = "WeightOnly"
|
||||
return cls(weight_bits, quant_mode)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["WeightOnlyLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return WeightOnlyLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
|
||||
class WeightOnlyLinearMethod(LinearMethodBase):
|
||||
"""Linear method for WeightOnly.
|
||||
|
||||
Args:
|
||||
quant_config: The WeightOnly quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: WeightOnlyConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> Dict[str, Any]:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if self.quant_config.quant_mode == "WeightOnly":
|
||||
scale_and_zero_input_dim = None
|
||||
if output_size != output_size_per_partition:
|
||||
scale_and_zero_input_dim = 0
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
device="mlu",
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
})
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(scales, {
|
||||
"input_dim": scale_and_zero_input_dim,
|
||||
"output_dim": 0,
|
||||
})
|
||||
layer.register_parameter("qweight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("scales", scales)
|
||||
set_weight_attrs(scales, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if layer.scales.dtype != torch.float:
|
||||
layer.scales = Parameter(layer.scales.to(torch.float), requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
x_shape = x.shape
|
||||
if len(x_shape) > 2:
|
||||
x = x.view(-1, x_shape[-1])
|
||||
out = mlu_ops.weight_only_quant_matmul(x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
None,
|
||||
bias,
|
||||
residual,
|
||||
"none",
|
||||
self.quant_config.weight_bits)
|
||||
if len(x_shape) > 2:
|
||||
out = out.view(*x_shape[:-1], out.shape[-1])
|
||||
return out
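# Hedged usage sketch (illustrative): building the config from a quantize_config.json
# style dict; "quant_mode" is optional and falls back to "WeightOnly".
#
#   cfg8 = WeightOnlyConfig.from_config({"bits": 8})   # pack_factor == 1
#   cfg4 = WeightOnlyConfig.from_config({"bits": 4})   # pack_factor == 2, i.e. two 4-bit
#   # values are packed into each int8 element of the qweight created in create_weights()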
|
||||
342
vllm_mlu/model_executor/layers/rotary_embedding/__init__.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
import vllm.model_executor.layers.rotary_embedding as rotary_embedding
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
_ROPE_DICT,
|
||||
RotaryEmbedding,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
_ROPE_DICT,
|
||||
DualChunkRotaryEmbedding,
|
||||
DynamicNTKAlphaRotaryEmbedding,
|
||||
DynamicNTKScalingRotaryEmbedding,
|
||||
Llama4VisionRotaryEmbedding,
|
||||
MRotaryEmbedding,
|
||||
NTKScalingRotaryEmbedding,
|
||||
Phi3LongRoPEScaledRotaryEmbedding,
|
||||
YaRNScalingRotaryEmbedding,
|
||||
)
|
||||
|
||||
from .base import MLURotaryEmbedding
|
||||
from .deepseek_scaling_rope import MLUDeepseekScalingRotaryEmbedding
|
||||
from .dynamic_ntk_alpha_rope import MLUDynamicNTKAlphaRotaryEmbedding
|
||||
from .dynamic_ntk_scaling_rope import MLUDynamicNTKScalingRotaryEmbedding
|
||||
from .linear_scaling_rope import MLULinearScalingRotaryEmbedding
|
||||
from .llama3_rope import MLULlama3RotaryEmbedding
|
||||
from .mrope import MLUMRotaryEmbedding
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_long_max_model_max_position_emb(max_position_embeddings, scaling_factor):
|
||||
if MLURotaryEmbedding.max_seq_len is not None and \
MLURotaryEmbedding.max_seq_len > max_position_embeddings * scaling_factor:
logger.warning(f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) differs from " +
f"max_position_embedding ({max_position_embeddings}) * scaling_factor ({scaling_factor}) " +
"from the model's config.json. This may lead to incorrect model outputs or MLU errors. " +
"Make sure the value is correct and within the model context size. " +
f"Set max_position_embedding={MLURotaryEmbedding.max_seq_len}.")
|
||||
return math.ceil(MLURotaryEmbedding.max_seq_len / scaling_factor)
|
||||
return max_position_embeddings
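# Hedged numeric sketch (illustrative), assuming original max_position_embeddings=4096 and
# scaling_factor=40:
#   - user max_model_len = 163840 == 4096 * 40 -> no warning, returns 4096
#   - user max_model_len = 200000  > 4096 * 40 -> warns, returns math.ceil(200000 / 40) = 5000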
|
||||
|
||||
def vllm__model_executor__layers__rotary_embedding__get_rope(
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position: int,
|
||||
base: float,
|
||||
is_neox_style: bool = True,
|
||||
rope_scaling: dict[str, Any] | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
dual_chunk_attention_config: dict[str, Any] | None = None,
|
||||
inverse: bool = False
|
||||
) -> RotaryEmbedding:
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
if rope_scaling is not None:
|
||||
# Transforms every value that is a list into a tuple for caching calls
|
||||
rope_scaling_tuple = {
|
||||
k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
|
||||
}
|
||||
rope_scaling_args = tuple(rope_scaling_tuple.items())
|
||||
else:
|
||||
rope_scaling_args = None
|
||||
|
||||
if dual_chunk_attention_config is not None:
|
||||
dual_chunk_attention_tuple = {
|
||||
k: tuple(v) if isinstance(v, list) else v
|
||||
for k, v in dual_chunk_attention_config.items()
|
||||
if k != "sparse_attention_config"
|
||||
}
|
||||
dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
|
||||
else:
|
||||
dual_chunk_attention_args = None
|
||||
|
||||
if partial_rotary_factor < 1.0:
|
||||
rotary_dim = int(rotary_dim * partial_rotary_factor)
|
||||
key = (
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
rope_scaling_args,
|
||||
dual_chunk_attention_args,
|
||||
dtype,
|
||||
inverse,
|
||||
)
|
||||
if key in _ROPE_DICT:
|
||||
return _ROPE_DICT[key]
|
||||
|
||||
if dual_chunk_attention_config is not None:
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in dual_chunk_attention_config.items()
|
||||
if k in ("chunk_size", "local_size")
|
||||
}
|
||||
rotary_emb = DualChunkRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
**extra_kwargs,
|
||||
)
|
||||
elif not rope_scaling:
|
||||
rotary_emb = MLURotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style, dtype,
|
||||
inverse=inverse,
|
||||
)
|
||||
else:
|
||||
scaling_type = rope_scaling["rope_type"]
|
||||
|
||||
if scaling_type == "llama3":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
low_freq_factor = rope_scaling["low_freq_factor"]
|
||||
high_freq_factor = rope_scaling["high_freq_factor"]
|
||||
original_max_position = rope_scaling["original_max_position_embeddings"]
|
||||
rotary_emb = MLULlama3RotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
scaling_factor,
|
||||
low_freq_factor,
|
||||
high_freq_factor,
|
||||
original_max_position,
|
||||
)
|
||||
elif scaling_type == "mllama4":
|
||||
rotary_emb = Llama4VisionRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style, dtype
|
||||
)
|
||||
elif scaling_type == "default":
|
||||
if "mrope_section" in rope_scaling:
|
||||
rotary_emb = MLUMRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
mrope_section=rope_scaling["mrope_section"],
|
||||
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
|
||||
)
|
||||
else:
|
||||
rotary_emb = MLURotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
inverse=inverse,
|
||||
)
|
||||
elif scaling_type == "linear":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
rotary_emb = MLULinearScalingRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
)
|
||||
elif scaling_type == "ntk":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
mixed_b = rope_scaling.get('mixed_b', None)
|
||||
rotary_emb = NTKScalingRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
mixed_b,
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
if "alpha" in rope_scaling:
|
||||
scaling_alpha = rope_scaling["alpha"]
|
||||
rotary_emb = MLUDynamicNTKAlphaRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
scaling_alpha,
|
||||
dtype,
|
||||
)
|
||||
elif "factor" in rope_scaling:
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
rotary_emb = MLUDynamicNTKScalingRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
|
||||
)
|
||||
elif scaling_type == "yarn":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
original_max_position = rope_scaling["original_max_position_embeddings"]
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in rope_scaling.items()
|
||||
if k
|
||||
in (
|
||||
"extrapolation_factor",
|
||||
"attn_factor",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"apply_yarn_scaling",
|
||||
)
|
||||
}
|
||||
if "mrope_section" in rope_scaling:
|
||||
extra_kwargs.pop("apply_yarn_scaling", None)
|
||||
rotary_emb = MRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
original_max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
mrope_section=rope_scaling["mrope_section"],
|
||||
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
|
||||
scaling_factor=scaling_factor,
|
||||
**extra_kwargs,
|
||||
)
|
||||
else:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: update original_max_position
|
||||
'''
|
||||
original_max_position = get_long_max_model_max_position_emb(
|
||||
original_max_position, scaling_factor,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
rotary_emb = YaRNScalingRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
original_max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
**extra_kwargs,
|
||||
)
|
||||
elif scaling_type == "deepseek_yarn":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
original_max_position = rope_scaling["original_max_position_embeddings"]
|
||||
# assert max_position == original_max_position * scaling_factor
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in rope_scaling.items()
|
||||
if k
|
||||
in (
|
||||
"extrapolation_factor",
|
||||
"attn_factor",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
)
|
||||
}
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: update original_max_position
|
||||
'''
|
||||
original_max_position = get_long_max_model_max_position_emb(
|
||||
original_max_position, scaling_factor,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
rotary_emb = MLUDeepseekScalingRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
original_max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
inverse,
|
||||
**extra_kwargs,
|
||||
)
|
||||
elif scaling_type == "longrope":
|
||||
short_factor = rope_scaling["short_factor"]
|
||||
long_factor = rope_scaling["long_factor"]
|
||||
original_max_position = rope_scaling["original_max_position_embeddings"]
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in rope_scaling.items()
|
||||
if k in ("short_mscale", "long_mscale")
|
||||
}
|
||||
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
original_max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
short_factor,
|
||||
long_factor,
|
||||
**extra_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
_ROPE_DICT[key] = rotary_emb
|
||||
return rotary_emb
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
rotary_embedding,
|
||||
rotary_embedding.get_rope,
|
||||
vllm__model_executor__layers__rotary_embedding__get_rope,
|
||||
)
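# Hedged usage sketch (illustrative; all numeric values are assumptions): after the hijack
# above, vLLM's get_rope resolves to the MLU classes, e.g. a DeepSeek-style YaRN config
# yields MLUDeepseekScalingRotaryEmbedding.
#
#   rope = rotary_embedding.get_rope(
#       head_size=64, rotary_dim=64, max_position=163840, base=10000, is_neox_style=False,
#       rope_scaling={"rope_type": "deepseek_yarn", "factor": 40,
#                     "original_max_position_embeddings": 4096,
#                     "mscale": 1.0, "mscale_all_dim": 1.0},
#       dtype=torch.bfloat16)
#   # isinstance(rope, MLUDeepseekScalingRotaryEmbedding) -> True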
|
||||
302
vllm_mlu/model_executor/layers/rotary_embedding/base.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Tuple
|
||||
import torch
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.v1.attention.backends.utils import (
|
||||
get_common_metadata,
|
||||
MLUCommonAttentionMetadata,
|
||||
)
|
||||
from vllm_mlu.v1.attention.backends.mla.flashmla import MLACommonMetadata
|
||||
from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@CustomOp.register("rotary_embedding_mlu")
|
||||
class MLURotaryEmbedding(RotaryEmbedding, CustomOp):
|
||||
|
||||
cu_seq_lens : torch.Tensor = None
|
||||
max_seq_len : int = None
|
||||
max_model_len : int = None
|
||||
is_prompt : bool = False
|
||||
is_chunked : bool = False
|
||||
positions_: torch.Tensor = None
|
||||
chunked_prefill_enabled: bool = False
|
||||
prefill_cu_seq_lens: torch.Tensor = None
|
||||
prefill_max_seq_len: int = None
|
||||
decode_cu_seq_lens: torch.Tensor = None
|
||||
decode_max_seq_len: int = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
inverse: bool = False,
|
||||
) -> None:
|
||||
CustomOp.__init__(self)
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
# TODO(mgoin): disabled for now due to failures
|
||||
# Flashinfer only supports head_size=64, 128, 256, 512.
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
|
||||
# self.use_flashinfer = (self.enabled()
|
||||
# and dtype in (torch.float16, torch.bfloat16)
|
||||
# and current_platform.is_cuda()
|
||||
# and has_flashinfer()
|
||||
# and self.head_size in [64, 128, 256, 512])
|
||||
self.use_flashinfer = False
|
||||
self.inverse = inverse
|
||||
|
||||
# For vlm v1
|
||||
# 1. mlu rope runs in eager mode
# 2. all layers use layer 0's rope for inference
|
||||
prefix = "global_rope"
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.use_direct_call = False
|
||||
if not self.use_direct_call:
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
pass
|
||||
else:
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
self.layer_name = prefix
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import DeepseekScalingRotaryEmbedding
|
||||
from vllm.model_executor.layers.rotary_embedding.yarn_scaling_rope import YaRNScalingRotaryEmbedding
|
||||
|
||||
if MLURotaryEmbedding.max_seq_len is not None \
and self.max_position_embeddings < MLURotaryEmbedding.max_seq_len and \
not isinstance(self, (YaRNScalingRotaryEmbedding, DeepseekScalingRotaryEmbedding)):
logger.warning(f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) differs from " +
f"max_position_embedding ({max_position_embeddings}) from the model's config.json. " +
"This may lead to incorrect model outputs or MLU errors. " +
"Make sure the value is correct and within the model context size. " +
f"Set max_position_embedding={MLURotaryEmbedding.max_seq_len}.")
|
||||
self.max_position_embeddings = MLURotaryEmbedding.max_seq_len
|
||||
cache = self._compute_cos_sin_cache()
|
||||
|
||||
from vllm_mlu.model_executor.layers.rotary_embedding.linear_scaling_rope import MLULinearScalingRotaryEmbedding
|
||||
if isinstance(self, MLULinearScalingRotaryEmbedding):
|
||||
logger.debug(f"Using mlu defining _compute_cos_sin_cache due to the special tensor composition")
|
||||
elif is_neox_style:
|
||||
cache_pos = cache.shape[0]
|
||||
cache = cache.reshape(cache_pos, 2, -1)
|
||||
cache = torch.tile(cache, (1, 1, 2)).reshape(cache_pos, -1)
|
||||
else:
|
||||
cache = cache.repeat_interleave(2, dim=-1)
|
||||
|
||||
cache = cache.to(dtype)
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
self.cos_, self.sin_ = self._get_cos_sin()
|
||||
@classmethod
|
||||
def set_mlu_var_v1(
|
||||
cls,
|
||||
common_metadata: MLUCommonAttentionMetadata
|
||||
) -> None:
|
||||
cls.unset_mlu_var()
|
||||
cls.cu_seq_lens = common_metadata.query_start_loc
|
||||
cls.max_seq_len = common_metadata.max_query_len
|
||||
cls.is_prompt = common_metadata.is_prefill_only
|
||||
cls.is_chunked = common_metadata.is_chunked
|
||||
|
||||
# for MLA
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
_, attn_metadata = next(iter(attn_metadata.items()))
|
||||
if isinstance(attn_metadata, MLACommonMetadata):
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
decode_metadata = attn_metadata.decode
|
||||
if prefill_metadata:
|
||||
cls.prefill_max_seq_len = prefill_metadata.max_query_len
|
||||
cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
|
||||
else:
|
||||
cls.prefill_max_seq_len = cls.max_seq_len
|
||||
cls.prefill_cu_seq_lens = cls.cu_seq_lens
|
||||
|
||||
if decode_metadata:
|
||||
cls.decode_max_seq_len = decode_metadata.max_query_len
|
||||
cls.decode_cu_seq_lens = decode_metadata.query_start_loc
|
||||
else:
|
||||
cls.decode_max_seq_len = cls.max_seq_len
|
||||
cls.decode_cu_seq_lens = cls.cu_seq_lens
|
||||
|
||||
# for sp
|
||||
sp_context = get_sp_forward_context()
|
||||
if sp_context is not None and sp_context.is_v32:
|
||||
prefill_metadata = sp_context.sp_attn_metadata.prefill
|
||||
cls.is_chunked = True
|
||||
cls.prefill_max_seq_len = prefill_metadata.max_query_len
|
||||
cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
|
||||
|
||||
@classmethod
|
||||
def unset_mlu_var(cls):
|
||||
cls.cu_seq_lens = None
|
||||
cls.max_seq_len = None
|
||||
cls.is_prompt = False
|
||||
cls.is_chunked = False
|
||||
cls.positions_ = None
|
||||
cls.chunked_prefill_enabled = False
|
||||
cls.prefill_cu_seq_lens = None
|
||||
cls.prefill_max_seq_len = None
|
||||
cls.decode_cu_seq_lens = None
|
||||
cls.decode_max_seq_len = None
|
||||
|
||||
def _get_cos_sin(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
|
||||
sin = sin.view(-1, self.rotary_dim)
|
||||
cos = cos.view(-1, self.rotary_dim)
|
||||
return cos, sin
|
||||
|
||||
def _get_positions_with_offsets_mlu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
offsets: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
if offsets.numel() != positions.numel():
|
||||
raise Exception("rope offsets numel mismatch with positions, "
|
||||
f"positions: {positions.numel()}, offsets: {offsets.numel()}")
|
||||
return (positions + offsets).to(torch.int32)
|
||||
|
||||
def forward_impl(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
common_metadata: MLUCommonAttentionMetadata = get_common_metadata()
|
||||
if common_metadata is None:
|
||||
num_tokens, head_num, head_size = x.shape
|
||||
x = mlu_ops.rotary_embedding(
|
||||
x.view(1, num_tokens, head_num, head_size),
|
||||
self.sin_,
|
||||
self.cos_,
|
||||
positions,
|
||||
None,
|
||||
not self.is_neox_style,
|
||||
True,
|
||||
False,
|
||||
num_tokens
|
||||
)
|
||||
return x
|
||||
else:
|
||||
cu_seq_lens_ = common_metadata.query_start_loc
|
||||
|
||||
if offsets is not None:
|
||||
if MLURotaryEmbedding.positions_ is None:
|
||||
MLURotaryEmbedding.positions_ = (
|
||||
self._get_positions_with_offsets_mlu(positions, offsets))
|
||||
position_ids = MLURotaryEmbedding.positions_
|
||||
discrete = True
|
||||
elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
|
||||
position_ids = positions
|
||||
discrete = True
|
||||
else:
|
||||
position_ids = None
|
||||
discrete = False
|
||||
|
||||
x = mlu_ops.rotary_embedding(
|
||||
x,
|
||||
self.sin_,
|
||||
self.cos_,
|
||||
position_ids,
|
||||
cu_seq_lens_,
|
||||
not self.is_neox_style,
|
||||
discrete,
|
||||
False,
|
||||
MLURotaryEmbedding.max_seq_len
|
||||
)
|
||||
return x
|
||||
|
||||
def get_param(self, positions, discrete=False):
|
||||
interleaved = True
|
||||
if self.is_neox_style:
|
||||
interleaved = False
|
||||
|
||||
if discrete:
|
||||
position_ids = positions
|
||||
discrete = discrete
|
||||
else:
|
||||
if MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
|
||||
position_ids = positions
|
||||
discrete = True
|
||||
else:
|
||||
position_ids = None
|
||||
discrete = False
|
||||
|
||||
return position_ids, interleaved, discrete
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
cos = freqs_cis.real
|
||||
sin = freqs_cis.imag * (-1 if self.inverse else 1)
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor | None = None,
|
||||
key: torch.Tensor | None = None,
|
||||
offsets: torch.Tensor | None = None,
|
||||
only_prefill: bool | None = False,
|
||||
only_decode: bool | None = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self.forward_impl(positions, query, offsets)
|
||||
if key is not None:
|
||||
self.forward_impl(positions, key, offsets)
|
||||
return query, key
|
||||
|
||||
|
||||
def rope_forward(
|
||||
positions: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
layer_name: str,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
|
||||
self.forward_impl(positions, x, offsets)
|
||||
|
||||
|
||||
def rope_forward_fake(
|
||||
positions: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
layer_name: str,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rope_forward",
|
||||
op_func=rope_forward,
|
||||
mutates_args=["x"],
|
||||
fake_impl=rope_forward_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
@@ -0,0 +1,166 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Tuple
|
||||
import torch
|
||||
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
|
||||
DeepseekScalingRotaryEmbedding,
|
||||
yarn_get_mscale,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
rotate_gptj,
|
||||
rotate_neox,
|
||||
yarn_find_correction_range,
|
||||
yarn_linear_ramp_mask,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
|
||||
|
||||
|
||||
class MLUDeepseekScalingRotaryEmbedding(MLURotaryEmbedding, DeepseekScalingRotaryEmbedding):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
inverse: bool = False,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1,
|
||||
mscale_all_dim: float = 0,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
# Get n-d magnitude scaling corrected for interpolation.
|
||||
self.mscale = float(
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale)) /
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
|
||||
attn_factor)
|
||||
self.inverse = inverse
|
||||
MLURotaryEmbedding.__init__(
|
||||
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
)
|
||||
|
||||
def forward_mlu_rot(self, input, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len):
|
||||
"""only one input rotary implementation"""
|
||||
if input is None:
|
||||
return None
|
||||
if self.rotary_dim < self.head_size:
|
||||
input_pass = input[..., self.rotary_dim:]
|
||||
input_rot = input[..., :self.rotary_dim]
|
||||
input_rot = mlu_ops.rotary_embedding(
|
||||
input_rot,
|
||||
self.sin_,
|
||||
self.cos_,
|
||||
position_ids,
|
||||
cu_seq_lens,
|
||||
interleaved,
|
||||
discrete,
|
||||
False,
|
||||
max_seq_len
|
||||
)
|
||||
|
||||
if self.rotary_dim < self.head_size:
|
||||
input = torch.cat((input_rot, input_pass), dim=-1)
|
||||
else:
|
||||
input = input_rot
|
||||
|
||||
return input
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor | None = None,
|
||||
key: torch.Tensor | None = None,
|
||||
offsets: torch.Tensor | None = None,
|
||||
only_prefill: bool | None = False,
|
||||
only_decode: bool | None = False,
|
||||
discrete: bool | None = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
position_ids, interleaved, discrete = self.get_param(positions, discrete)
|
||||
|
||||
cu_seq_lens = MLURotaryEmbedding.cu_seq_lens
|
||||
max_seq_len = MLURotaryEmbedding.max_seq_len
|
||||
|
||||
# for MLA
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
_, attn_metadata = next(iter(attn_metadata.items()))
|
||||
if isinstance(attn_metadata, MLACommonMetadata):
|
||||
if only_prefill:
|
||||
cu_seq_lens = MLURotaryEmbedding.prefill_cu_seq_lens
|
||||
max_seq_len = MLURotaryEmbedding.prefill_max_seq_len
|
||||
elif only_decode:
|
||||
cu_seq_lens = MLURotaryEmbedding.decode_cu_seq_lens
|
||||
max_seq_len = MLURotaryEmbedding.decode_max_seq_len
|
||||
|
||||
query = self.forward_mlu_rot(query, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len)
|
||||
key = self.forward_mlu_rot(key, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len)
|
||||
|
||||
return query, key
|
||||
|
||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||
pos_freqs = self.base ** (
|
||||
torch.arange(
|
||||
0,
|
||||
self.rotary_dim,
|
||||
2,
|
||||
dtype=torch.float,
|
||||
device=current_platform.device_type,
|
||||
)
|
||||
/ self.rotary_dim
|
||||
)
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
|
||||
|
||||
low, high = yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
self.rotary_dim,
|
||||
self.base,
|
||||
self.max_position_embeddings,
|
||||
)
|
||||
# Get n-d rotational scaling corrected for extrapolation
|
||||
device = current_platform.device_type
|
||||
inv_freq_mask = ((
|
||||
1
|
||||
- yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
|
||||
) * self.extrapolation_factor).to(device)
|
||||
inv_freq = (
|
||||
inv_freq_interpolation * (1 - inv_freq_mask)
|
||||
+ inv_freq_extrapolation * inv_freq_mask
|
||||
)
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.scaling_factor)
|
||||
t = torch.arange(
|
||||
self.max_position_embeddings * self.scaling_factor,
|
||||
device=current_platform.device_type,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos() * self.mscale
|
||||
sin = freqs.sin() * self.mscale * (-1 if self.inverse else 1)
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
forward = MLURotaryEmbedding.forward
|
||||
forward_native = forward_oot
|
||||
@@ -0,0 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding.dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
|
||||
|
||||
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
|
||||
|
||||
|
||||
class MLUDynamicNTKAlphaRotaryEmbedding(MLURotaryEmbedding, DynamicNTKAlphaRotaryEmbedding):
|
||||
"""RotaryEmbedding extended with Dynamic NTK scaling.
|
||||
Credits to the Reddit users /u/bloc97 and /u/emozilla
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
scaling_alpha: float,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
self.scaling_alpha = scaling_alpha
|
||||
MLURotaryEmbedding.__init__(
|
||||
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
)
|
||||
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch

from vllm.model_executor.layers.rotary_embedding.dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding

from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding


class MLUDynamicNTKScalingRotaryEmbedding(MLURotaryEmbedding, DynamicNTKScalingRotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
    ) -> None:
        self.scaling_factor = scaling_factor
        MLURotaryEmbedding.__init__(
            self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Union

import torch

from vllm.platforms import current_platform
from vllm.model_executor.layers.rotary_embedding.linear_scaling_rope import LinearScalingRotaryEmbedding

from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding


class MLULinearScalingRotaryEmbedding(MLURotaryEmbedding, LinearScalingRotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factors: list[float] | float,
        dtype: torch.dtype,
    ) -> None:
        if isinstance(scaling_factors, float):
            scaling_factors = [scaling_factors]
        self.scaling_factors: list[float] = scaling_factors  # noqa
        MLURotaryEmbedding.__init__(
            self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )
        # Lazy initialized.
        self._scaling_factor_to_offset: dict[float, int]

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        device = current_platform.device_type
        if self.is_neox_style:
            half_dim = self.rotary_dim // 2
            inv_freq = 1.0 / (
                base
                ** (torch.arange(0, self.rotary_dim, 1, dtype=torch.float32, device=device)
                    % half_dim * 2 / self.rotary_dim)
            )
        else:
            inv_freq = 1.0 / (
                base
                ** (torch.arange(0, self.rotary_dim, 1, dtype=torch.float32, device=device)
                    // 2 * 2 / self.rotary_dim
                )
            )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)
        cache_list: list[torch.Tensor] = []
        # offsets to the next cache in a tensor.
        # Each offset corresponds to the same index in scaling_factors.
        offsets: list[int] = []
        device = current_platform.device_type
        for scaling_factor in self.scaling_factors:
            # NOTE(woosuk): self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * self.scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float, device=device)
            t = t / scaling_factor

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cos = freqs.cos()
            sin = freqs.sin()
            cache = torch.cat((cos, sin), dim=-1)
            if not cache_list:
                offset = 0
            else:
                last_offset = offsets[-1]
                next_max_len = cache_list[-1].shape[0]
                offset = last_offset + next_max_len
            offsets.append(offset)
            cache_list.append(cache)
        self._scaling_factor_to_offset = {
            float(scaling_factor): offsets[i]
            for i, scaling_factor in enumerate(self.scaling_factors)
        }
        assert len(self.scaling_factors) == len(offsets)
        return torch.cat(cache_list, dim=0)
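# NOTE: standalone sketch of the offset bookkeeping above, with toy numbers
# (not the class itself): per-factor caches are concatenated along dim 0 and
# _scaling_factor_to_offset records where each factor's slice begins.
import torch

max_position_embeddings = 8
scaling_factors = [1.0, 2.0]

caches, offsets, offset = [], [], 0
for s in scaling_factors:
    n = int(max_position_embeddings * s)
    offsets.append(offset)
    caches.append(torch.zeros(n, 4))   # stand-in for an (n, rotary_dim) cos/sin cache
    offset += n

cache = torch.cat(caches, dim=0)
factor_to_offset = dict(zip(scaling_factors, offsets))
assert factor_to_offset[2.0] == 8 and cache.shape[0] == 8 + 16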
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch

from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding


class MLULlama3RotaryEmbedding(MLURotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )
140
vllm_mlu/model_executor/layers/rotary_embedding/mrope.py
Normal file
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch

from vllm.model_executor.layers.rotary_embedding.common import yarn_get_mscale
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding


class MLUMRotaryEmbedding(MLURotaryEmbedding, MRotaryEmbedding):
    """Rotary Embedding with Multimodal Sections."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: list[int] | None = None,
        mrope_interleaved: bool = False,
        # YaRN parameters.
        *,
        scaling_factor: float | None = None,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        if self.scaling_factor is not None:
            # Get n-d magnitude scaling corrected for interpolation
            self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
        else:
            self.mscale = 1.0

        # In Qwen2.5-VL, the maximum index value is related to the duration of
        # the input video. We enlarge max_position_embeddings by 4x to get
        # a larger cos and sin cache.
        self.cache_max_position_num = max_position_embeddings * 4
        MLURotaryEmbedding.__init__(
            self,
            head_size,
            rotary_dim,
            self.cache_max_position_num,
            base,
            is_neox_style,
            dtype,
        )

        self.mrope_section = mrope_section
        self.mrope_interleaved = mrope_interleaved
        if self.mrope_section:
            assert sum(self.mrope_section) == rotary_dim // 2

    def _apply_mrope(self, positions):
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        num_section = len(self.mrope_section)
        mrope_section = self.mrope_section * 2

        def _apply(x):
            x = torch.cat([
                m[i % num_section]
                for i, m in enumerate(x.split(mrope_section, dim=-1))
            ],
                dim=-1)
            return x
        return _apply(cos), _apply(sin)

    def _apply_interleaved_mrope(self, positions):
        """Apply interleaved MRoPE to 3D rotary embeddings.
        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THTHWHTHW...TT], preserving frequency continuity.
        """
        mrope_section = self.mrope_section
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)

        def _apply(x):
            x_t = x[0].clone()
            x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
            x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
            offset = self.rotary_dim // 2
            x_t[..., 1 + offset:mrope_section[1] * 3 + offset:3] = x[1, ..., 1 + offset:mrope_section[1] * 3 + offset:3]
            x_t[..., 2 + offset:mrope_section[2] * 3 + offset:3] = x[2, ..., 2 + offset:mrope_section[2] * 3 + offset:3]
            return x_t
        return _apply(cos), _apply(sin)

    def precompute_sin_cos_cache(
        self,
        positions: torch.Tensor
    ):
        '''
        Call this before running the decoder layers to precompute
        the sin/cos cache for MRoPE.
        '''
        if positions.ndim == 1:
            return
        assert positions.ndim == 2
        assert self.mrope_section
        if self.mrope_interleaved:
            cos, sin = self._apply_interleaved_mrope(positions)
        else:
            cos, sin = self._apply_mrope(positions)
        self.mrope_cos_cache = cos
        self.mrope_sin_cache = sin
        self.mrope_cu_seq_lens = torch.zeros(2, dtype=torch.int32, device=positions.device)
        num_tokens = positions.shape[-1]
        self.mrope_cu_seq_lens[1] = num_tokens

    def forward_oot(
        self,
        positions: torch.Tensor,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        assert positions.ndim == 1 or positions.ndim == 2
        if positions.ndim == 1:
            return MLURotaryEmbedding.forward_oot(self, positions, x)
        assert self.mrope_cos_cache is not None and self.mrope_sin_cache is not None, \
            "call precompute_sin_cos_cache first!"
        num_tokens = positions.shape[-1]
        x = mlu_ops.rotary_embedding(x,
                                     self.mrope_sin_cache,
                                     self.mrope_cos_cache,
                                     None,
                                     self.mrope_cu_seq_lens,
                                     not self.is_neox_style,
                                     False,
                                     False,
                                     num_tokens)
        return x

    forward = MLURotaryEmbedding.forward
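# NOTE: standalone sketch of the section-wise gather behind _apply_mrope above,
# with toy shapes and plain PyTorch (the real path also handles the duplicated
# [cos | sin] layout and runs through mlu_ops):
import torch

mrope_section = [2, 1, 1]                     # T/H/W sections
cos = torch.arange(3 * 1 * 4, dtype=torch.float).reshape(3, 1, 4)  # (3, num_tokens, width)

num_section = len(mrope_section)
pieces = cos.split(mrope_section, dim=-1)
# Section i takes its slice from the i-th (T/H/W) position stream.
merged = torch.cat([p[i % num_section] for i, p in enumerate(pieces)], dim=-1)
assert merged.shape == (1, 4)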
1271
vllm_mlu/model_executor/layers/sparse_moe_mlp.py
Normal file
File diff suppressed because it is too large
3
vllm_mlu/model_executor/model_loader/__init__.py
Normal file
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

173
vllm_mlu/model_executor/model_loader/dummy_loader.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from typing import List, Tuple
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def initialize_dummy_weights_normal_dist(
|
||||
model: torch.nn.Module,
|
||||
low: float = -1e-3,
|
||||
high: float = 1e-3,
|
||||
std: float = 0.5,
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the weights of a PyTorch model with values drawn from a normal distribution.
|
||||
Floating point parameters are initialized from a normal distribution whose mean is randomly
sampled from [low, high] and whose standard deviation is given by std. Integer parameters are
|
||||
initialized with random integers in [floor(low), ceil(high)). The initialization is performed
|
||||
in a batched and efficient way for both floating point and integer parameters.
|
||||
|
||||
Optimized version: Uses shared pinned memory based on the largest parameter block size
|
||||
to minimize H2D transfers, sacrificing global uniqueness for performance.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model whose weights will be initialized.
|
||||
low (float): Lower bound for sampling the mean of the normal distribution (for float params).
|
||||
high (float): Upper bound for sampling the mean of the normal distribution (for float params).
|
||||
std (float): Standard deviation for the normal distribution (for float params).
|
||||
seed (int): Random seed for reproducibility.
|
||||
"""
|
||||
# Randomly sample the mean for the normal distribution from [low, high]
|
||||
rng = np.random.RandomState(seed)
|
||||
mean = float(rng.uniform(low, high, 1).item())
|
||||
|
||||
# Create a CPU generator for reproducibility
|
||||
cpu_gen = torch.Generator(device="cpu")
|
||||
cpu_gen.manual_seed(seed)
|
||||
|
||||
# Collect parameters: separate into floating point and integer types
|
||||
float_params: List[Tuple[str, torch.Tensor]] = []
|
||||
int_params: List[Tuple[str, torch.Tensor]] = []
|
||||
|
||||
for name, t in tqdm(model.state_dict().items(), desc="Gen dummy weights: Collect params"):
|
||||
if not isinstance(t, torch.Tensor):
|
||||
continue
|
||||
if torch.is_floating_point(t):
|
||||
float_params.append((name, t))
|
||||
elif t.dtype in (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64):
|
||||
int_params.append((name, t))
|
||||
|
||||
# -------- Floating point parameters: optimized shared memory initialization --------
|
||||
if float_params:
|
||||
# Find the largest parameter block size
|
||||
max_float_elems = max(p.numel() for _, p in float_params)
|
||||
|
||||
# Create shared pinned memory buffer based on largest parameter
|
||||
shared_float_buffer = torch.empty(max_float_elems, dtype=torch.float32, device="cpu", pin_memory=True)
|
||||
shared_float_buffer.normal_(mean=mean, std=std, generator=cpu_gen)
|
||||
|
||||
# Copy shared buffer to device once
|
||||
device_buffer = shared_float_buffer.to(next(iter(float_params))[1].device, non_blocking=True)
|
||||
|
||||
for _, p in tqdm(float_params, desc="Gen dummy weights: Init float params"):
|
||||
n = p.numel()
|
||||
# Extract from device buffer (may reuse same values for different parameters)
|
||||
view = device_buffer[:n].view(p.shape)
|
||||
|
||||
# torch.normal_ does not support dtypes < fp16, so cast via fp16 if needed
|
||||
if torch.finfo(p.dtype).bits < 16:
|
||||
tmp = view.to(torch.float16)
|
||||
tmp = tmp.to(p.dtype)
|
||||
else:
|
||||
tmp = view.to(p.dtype)
|
||||
|
||||
# Copy from device buffer to parameter (D2D copy, much faster)
|
||||
p.data.copy_(tmp)
|
||||
|
||||
# -------- Integer parameters: optimized shared memory initialization --------
|
||||
if int_params:
|
||||
# Find the largest parameter block size
|
||||
max_int_elems = max(p.numel() for _, p in int_params)
|
||||
|
||||
int_low = int(np.floor(low))
|
||||
int_high = int(np.ceil(high))
|
||||
if int_high == int_low:
|
||||
int_high = int_low + 1 # Ensure at least one possible value
|
||||
|
||||
# Create shared pinned memory buffer based on largest parameter
|
||||
shared_int_buffer = torch.randint(
|
||||
low=int_low,
|
||||
high=int_high,
|
||||
size=(max_int_elems,),
|
||||
dtype=torch.int64,
|
||||
generator=cpu_gen,
|
||||
device="cpu",
|
||||
pin_memory=True
|
||||
)
|
||||
|
||||
# Copy shared buffer to device once
|
||||
device_int_buffer = shared_int_buffer.to(next(iter(int_params))[1].device, non_blocking=True)
|
||||
|
||||
for _, p in tqdm(int_params, desc="Gen dummy weights: Init int params"):
|
||||
n = p.numel()
|
||||
# Extract from device buffer (may reuse same values for different parameters)
|
||||
view = device_int_buffer[:n].view(p.shape)
|
||||
tmp = view.to(p.dtype)
|
||||
# Copy from device buffer to parameter (D2D copy, much faster)
|
||||
p.data.copy_(tmp)
|
||||
|
||||
|
||||
SMOOTHQUANT_METHOD = "smoothquant"
|
||||
MULTIMODAL_ARCH_KEYWORDS = {"VL", "Vision", "Multimodal"}
|
||||
def vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use torch.normal_ instead of torch.uniform_ for distinguishable logits
|
||||
std=0.5 is used for more distinguishable logits
|
||||
'''
|
||||
|
||||
# === Default parameter setup (Original values as fallback) ===
|
||||
low_val = -1e-3
|
||||
high_val = 1e-3
|
||||
std_val = 0.5
|
||||
|
||||
# === Model and Quantization Check Logic ===
|
||||
quant_method = getattr(model_config, "quantization", None)
|
||||
|
||||
# Attempt to get the architectures list from model_config
|
||||
archs = getattr(model_config, "architectures", []) or []
|
||||
|
||||
# Determine if the model is multimodal (based on architecture names)
|
||||
is_multimodal = any(
|
||||
keyword in arch
|
||||
for arch in archs
|
||||
for keyword in MULTIMODAL_ARCH_KEYWORDS
|
||||
)
|
||||
|
||||
# === Apply SmoothQuant + Multimodal Parameters ===
|
||||
if is_multimodal and quant_method == SMOOTHQUANT_METHOD:
|
||||
# (smoothquant) + Multimodal specific values to mitigate NaN overflow
|
||||
std_val = 1e-4
|
||||
|
||||
initialize_dummy_weights_normal_dist(
|
||||
model,
|
||||
low=low_val,
|
||||
high=high_val,
|
||||
std=std_val
|
||||
)
|
||||
# add a sync to make sure the weights are initialized
|
||||
torch.mlu.synchronize()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
DummyModelLoader,
|
||||
DummyModelLoader.load_weights,
|
||||
vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights
|
||||
)
|
||||
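# NOTE: standalone sketch of the shared-buffer trick used by
# initialize_dummy_weights_normal_dist above: one buffer sized for the largest
# parameter is filled once, then every parameter copies a view of it (values
# may repeat across parameters; that is the intended speed trade-off).
import torch

params = [torch.empty(3, 4), torch.empty(8)]
shared = torch.empty(max(p.numel() for p in params)).normal_(mean=0.0, std=0.5)
for p in params:
    p.copy_(shared[:p.numel()].view(p.shape))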
137
vllm_mlu/model_executor/model_loader/tensorizer.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
# os and tempfile are used below by serialize_extra_artifacts.
import os
import tempfile
import time

import torch
from torch import nn
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
||||
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig, TensorDeserializer, TensorizerArgs,
|
||||
_check_tensors_on_meta_device, _resize_lora_embeddings,
|
||||
is_valid_deserialization_uri)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.logger import init_logger
|
||||
|
||||
try:
|
||||
from tensorizer.stream_io import open_stream
|
||||
from tensorizer.utils import (convert_bytes, get_mem_usage,
|
||||
no_init_or_tensor)
|
||||
|
||||
except ImportError:
|
||||
open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
|
||||
convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
|
||||
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
|
||||
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def deserialize_tensorizer_model(model: nn.Module,
|
||||
tensorizer_config: TensorizerConfig) -> None:
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
|
||||
raise ValueError(
|
||||
f"{tensorizer_config.tensorizer_uri} is not a valid "
|
||||
f"tensorizer URI. Please check that the URI is correct. "
|
||||
f"It must either point to a local existing file, or have a "
|
||||
f"S3, HTTP or HTTPS scheme.")
|
||||
before_mem = get_mem_usage()
|
||||
start = time.perf_counter()
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use mlu device
|
||||
'''
|
||||
device = ''
|
||||
if current_platform.is_out_of_tree():
|
||||
device = f'mlu:{torch.mlu.current_device()}'
|
||||
elif current_platform.is_xpu():
|
||||
device = f'xpu:{torch.xpu.current_device()}'
|
||||
else:
|
||||
device = f'cuda:{torch.cuda.current_device()}'
|
||||
with open_stream(
|
||||
tensorizer_config.tensorizer_uri,
|
||||
mode="rb",
|
||||
**tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
|
||||
stream,
|
||||
dtype=tensorizer_config.dtype,
|
||||
device=device,
|
||||
**tensorizer_args.deserialization_kwargs) as deserializer:
|
||||
deserializer.load_into_module(model)
|
||||
end = time.perf_counter()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
|
||||
duration = end - start
|
||||
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
|
||||
after_mem = get_mem_usage()
|
||||
deserializer.close()
|
||||
logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
|
||||
end - start, per_second)
|
||||
logger.info("Memory usage before: %s", before_mem)
|
||||
logger.info("Memory usage after: %s", after_mem)
|
||||
|
||||
_check_tensors_on_meta_device(model)
|
||||
_resize_lora_embeddings(model)
|
||||
del model.vllm_tensorized_marker
|
||||
|
||||
def serialize_extra_artifacts(
|
||||
tensorizer_args: TensorizerArgs,
|
||||
served_model_name: Union[str, list[str], None]) -> None:
|
||||
if not isinstance(served_model_name, str):
|
||||
raise ValueError(
|
||||
f"served_model_name must be a str for serialize_extra_artifacts, "
|
||||
f"not {type(served_model_name)}.")
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use local file
|
||||
'''
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
local_model_path = Path(served_model_name)
|
||||
if not local_model_path.exists() or not local_model_path.is_dir():
|
||||
raise ValueError(
|
||||
f"served_model_name must be a valid local directory in offline mode, "
|
||||
f"but got: {served_model_name}"
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: copy local file
|
||||
'''
|
||||
logger.info("Copying local model from %s to temporary directory %s",
|
||||
local_model_path, tmpdir)
|
||||
shutil.copytree(local_model_path, tmpdir, dirs_exist_ok=True)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
for artifact in os.scandir(tmpdir):
|
||||
if not artifact.is_file():
|
||||
continue
|
||||
with open(artifact.path, "rb") as f, open_stream(
|
||||
f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
|
||||
mode="wb+",
|
||||
**tensorizer_args.stream_kwargs) as stream:
|
||||
logger.info("Writing artifact %s", artifact.name)
|
||||
stream.write(f.read())
|
||||
|
||||
35
vllm_mlu/model_executor/model_loader/tensorizer_loader.py
Normal file
@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from torch import nn

from vllm.config import ModelConfig
from vllm.model_executor.model_loader.tensorizer import is_vllm_tensorized
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader

from vllm_mlu.model_executor.model_loader.tensorizer import deserialize_tensorizer_model
from vllm_mlu.mlu_hijack_utils import MluHijackObject


def vllm__model_executor__model_loader__tensorizer_loader__TensorizerLoader__load_weights(
    self,
    model: nn.Module,
    model_config: ModelConfig
) -> None:
    """Load serialized model weights with tensorizer.

    Expects a vLLM-tensorized model. See the
    examples/others/tensorize_vllm_model.py example script
    for serializing vLLM models."""
    if is_vllm_tensorized(self.tensorizer_config):
        tensorizer_config = self._patch_tensorizer_config(model_config)
        deserialize_tensorizer_model(model, tensorizer_config)
    else:
        model.load_weights(self._get_weights_iterator())


MluHijackObject.apply_hijack(
    TensorizerLoader,
    TensorizerLoader.load_weights,
    vllm__model_executor__model_loader__tensorizer_loader__TensorizerLoader__load_weights
)
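# NOTE: rough standalone sketch of what the apply_hijack pattern used throughout
# this plugin amounts to: a monkey-patch that swaps a class method for an
# MLU-specific replacement (the real MluHijackObject is assumed to also track
# the original attribute so it can be restored).
class _Demo:
    def forward(self):
        return "original"

def _mlu_forward(self):
    return "hijacked"

_Demo.forward = _mlu_forward        # conceptually what apply_hijack does
assert _Demo().forward() == "hijacked"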
12
vllm_mlu/model_executor/models/__init__.py
Executable file
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from vllm import ModelRegistry


def register_model():
    from .deepseek_v4 import MLUDeepseekV4ForCausalLM  # noqa: F401

    ModelRegistry.register_model(
        "DeepseekV4ForCausalLM",
        "vllm_mlu.model_executor.models.deepseek_v4:MLUDeepseekV4ForCausalLM")
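# NOTE: hypothetical usage sketch (assumes a vLLM environment where this plugin
# is importable): after the plugin calls register_model(), the new architecture
# should be visible in vLLM's model registry.
from vllm import ModelRegistry
from vllm_mlu.model_executor.models import register_model

register_model()
assert "DeepseekV4ForCausalLM" in ModelRegistry.get_supported_archs()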
192
vllm_mlu/model_executor/models/config.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from math import lcm
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.model_executor.models.config import (HybridAttentionMambaModelConfig,
|
||||
MambaModelConfig)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@classmethod
|
||||
def vllm__module_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig"
|
||||
) -> None:
|
||||
"""
|
||||
Ensure that page size of attention layers is greater than or
|
||||
equal to the mamba layers. If not, automatically set the attention
|
||||
block size to ensure that it is. If the attention page size is
|
||||
strictly greater than the mamba page size, we pad the mamba page size
|
||||
to make them equal.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM Config
|
||||
"""
|
||||
# Save the user input before it gets modified by MambaModelConfig
|
||||
mamba_block_size = vllm_config.cache_config.mamba_block_size
|
||||
# Enable FULL_AND_PIECEWISE by default
|
||||
MambaModelConfig.verify_and_update_config(vllm_config)
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
if cache_config.cache_dtype == "auto":
|
||||
kv_cache_dtype = model_config.dtype
|
||||
else:
|
||||
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||
|
||||
# get attention page size (for 1 token)
|
||||
# Attention backend constraints:
|
||||
# - FlashAttention (FA) requires block size to be multiple of 16
|
||||
# - MLA (Multi-head Latent Attention) requires larger alignment:
|
||||
# * CUTLASS_MLA backend: kernel_block_size 128 alignment
|
||||
# * Other MLA backends: kernel_block_size 64 alignment
|
||||
if model_config.use_mla:
|
||||
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
|
||||
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
|
||||
attn_page_size_1_token = MLAAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
).page_size_bytes
|
||||
else:
|
||||
kernel_block_alignment_size = 16
|
||||
if (
|
||||
current_platform.is_device_capability(100)
|
||||
and model_config.get_head_size() == 256
|
||||
and (
|
||||
envs.VLLM_ATTENTION_BACKEND is None
|
||||
or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER"
|
||||
)
|
||||
):
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
# head size 256 and block size 16 is not supported on Blackwell.
|
||||
kernel_block_alignment_size = 32
|
||||
attn_page_size_1_token = FullAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
).page_size_bytes
|
||||
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(
|
||||
model_config.architecture,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
# get mamba page size
|
||||
mamba_page_size = MambaSpec(
|
||||
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
|
||||
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
|
||||
block_size=model_config.max_model_len,
|
||||
).page_size_bytes
|
||||
|
||||
# Model may be marked as is_hybrid
|
||||
# but mamba is skipped via config,
|
||||
# return directly
|
||||
if mamba_page_size == 0:
|
||||
return
|
||||
|
||||
if cache_config.enable_prefix_caching:
|
||||
# With prefix caching, select attention block size to
|
||||
# optimize for mamba kernel performance
|
||||
|
||||
# Mamba2 SSD kernel uses a chunk_size, e.g. 256
|
||||
# Align the block to the kernel: use lowest multiple of chunk_size
|
||||
# of attention tokens that would fit mamba_page_size:
|
||||
# e.g. for mamba page size = 788kB
|
||||
# attn_1_token = 2kB -> fits ~394 tokens
|
||||
# then round up to a multiple of 256 -> 512 tokens
|
||||
# End result:
|
||||
# attn_block_size = 512
|
||||
# mamba_block_size = 512 (aligned to a multiple of chunk_size)
|
||||
# TODO(tdoublep): this constraint can be relaxed fairly
|
||||
# easily by changing the way we layout chunks in the
|
||||
# mamba2 kernels.
|
||||
|
||||
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
|
||||
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
|
||||
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
|
||||
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
|
||||
cache_config.mamba_block_size = attn_block_size
|
||||
else:
|
||||
# Without prefix caching, select minimum valid attention block size
|
||||
# to minimize mamba state padding
|
||||
|
||||
# Calculate minimum attention block size that satisfies both:
|
||||
# 1. Backend alignment requirements (kernel_block_alignment_size)
|
||||
# 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
|
||||
attn_block_size = kernel_block_alignment_size * cdiv(
|
||||
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: support qwen3-next
|
||||
'''
|
||||
if (vllm_config.mlu_config.enable_mamba_split_page_size):
|
||||
vllm_config.mlu_config.mamba_to_attn_block_ratio = cdiv(attn_block_size, cache_config.block_size)
|
||||
cache_config.mamba_page_size_padded = cache_config.block_size * attn_page_size_1_token
|
||||
return
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
# override attention block size if either (a) the
|
||||
# user has not set it or (b) the user has set it
|
||||
# too small.
|
||||
if cache_config.block_size is None or cache_config.block_size < attn_block_size:
|
||||
cache_config.block_size = attn_block_size
|
||||
logger.info(
|
||||
"Setting attention block size to %d tokens "
|
||||
"to ensure that attention page size is >= mamba page size.",
|
||||
attn_block_size,
|
||||
)
|
||||
|
||||
# compute new attention page size
|
||||
attn_page_size = cache_config.block_size * attn_page_size_1_token
|
||||
|
||||
assert attn_page_size >= mamba_page_size
|
||||
|
||||
if attn_page_size == mamba_page_size:
|
||||
# don't need to pad mamba page size
|
||||
return
|
||||
|
||||
# pad mamba page size to exactly match attention
|
||||
if (
|
||||
cache_config.mamba_page_size_padded is None
|
||||
or cache_config.mamba_page_size_padded != attn_page_size
|
||||
):
|
||||
cache_config.mamba_page_size_padded = attn_page_size
|
||||
mamba_padding_pct = (
|
||||
100 * (attn_page_size - mamba_page_size) / mamba_page_size
|
||||
)
|
||||
logger.info(
|
||||
"Padding mamba page size by %.2f%% to ensure "
|
||||
"that mamba page size and attention page size are "
|
||||
"exactly equal.",
|
||||
mamba_padding_pct,
|
||||
)
|
||||
|
||||
MluHijackObject.apply_hijack(HybridAttentionMambaModelConfig,
|
||||
HybridAttentionMambaModelConfig.verify_and_update_config,
|
||||
vllm__module_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config)
|
||||
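# NOTE: small numeric sketch of the block-size rounding described in the
# prefix-caching comments above, using the toy numbers from that example
# (788 kB mamba page, 2 kB per attention token, chunk_size 256); cdiv is
# ceiling division as in vllm.utils.math_utils.
def _cdiv(a, b):
    return -(-a // b)

_mamba_page_size = 788 * 1024
_attn_page_size_1_token = 2 * 1024
_chunk_size = 256
_tokens = _cdiv(_mamba_page_size, _attn_page_size_1_token)       # 394 tokens fit
_attn_block = _chunk_size * _cdiv(_tokens, _chunk_size)          # rounded up to 512
assert (_tokens, _attn_block) == (394, 512)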
1096
vllm_mlu/model_executor/models/deepseek_v4.py
Normal file
File diff suppressed because it is too large
607
vllm_mlu/model_executor/models/dp_utils.py
Normal file
@@ -0,0 +1,607 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import (
|
||||
Any, List, Tuple, Optional, Dict, Union, ClassVar, Literal,
|
||||
Protocol, overload, runtime_checkable)
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.communication_op import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_gather_into_list,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
get_tp_group,
|
||||
get_pp_group,
|
||||
get_dp_group,
|
||||
get_data_parallel_group_rank,
|
||||
get_data_parallel_group_world_size,
|
||||
get_dense_mlp_tp_world_size,
|
||||
get_tp_world_world_size,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_logits_tp_world_size,
|
||||
get_parallel_rank_with_group,
|
||||
get_tp_world_group,
|
||||
get_tp_world_rank,
|
||||
GroupCoordinator,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_mlu.mlu_forward_context import MLUDPMetadata
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# alias after refactor
|
||||
DataParallelRuntimeParams = MLUDPMetadata
|
||||
|
||||
|
||||
def enable_data_parallel():
|
||||
return get_dp_group().world_size > 1
|
||||
|
||||
|
||||
def enable_emb_logits_custom_parallel():
|
||||
return get_logits_tp_world_size() != get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
def enable_dense_mlp_custom_parallel():
|
||||
return get_dense_mlp_tp_world_size() != get_tp_world_world_size()
|
||||
|
||||
|
||||
def get_runtime_infos_per_dp_group(
|
||||
num_tokens: int, num_requests: int, all_prefill: bool, seq_lens: List[int],
|
||||
device: torch.device, vllm_config: VllmConfig) -> Tuple[List[int], List[bool]]:
|
||||
dp_tensor = torch.tensor([num_tokens, num_requests, int(all_prefill)]).to(device, non_blocking=True)
|
||||
outputs = tensor_model_parallel_all_gather_into_list(dp_tensor, get_dp_group())
|
||||
outputs = torch.cat(outputs).tolist() # d2h
|
||||
dp_world_size = get_data_parallel_group_world_size()
|
||||
dp_is_prefill, dp_query_lens, dp_group_bs, seq_len_per_batch = [], [], [], []
|
||||
for i in range(0, 3 * dp_world_size, 3):
|
||||
dp_query_lens.append(outputs[i])
|
||||
dp_group_bs.append(outputs[i + 1])
|
||||
dp_is_prefill.append(bool(outputs[i + 2]))
|
||||
|
||||
# Only run communication if mcc is enabled and is prefill.
|
||||
if vllm_config.mlu_config.is_dpsk_mcc_enabled and all(dp_is_prefill):
|
||||
assert len(seq_lens) == num_requests
|
||||
seq_len_per_batch = [torch.empty([bs], dtype=dp_tensor.dtype, device=device) for bs in dp_group_bs]
|
||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=dp_tensor.dtype, device=device)
|
||||
torch.distributed.all_gather(seq_len_per_batch, seq_lens_tensor, group=get_dp_group().device_group)
|
||||
seq_len_per_batch=torch.cat(seq_len_per_batch).tolist()
|
||||
else:
|
||||
seq_len_per_batch = [0] * sum(dp_group_bs)
|
||||
|
||||
return dp_query_lens, dp_group_bs, dp_is_prefill, seq_len_per_batch
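# NOTE: standalone sketch of how the flattened all-gather result above is
# unpacked; toy values for a 2-way DP group, no distributed calls involved.
outputs = [7, 2, 1, 5, 3, 0]   # [num_tokens, num_requests, is_prefill] per DP rank
dp_query_lens, dp_group_bs, dp_is_prefill = [], [], []
for i in range(0, len(outputs), 3):
    dp_query_lens.append(outputs[i])
    dp_group_bs.append(outputs[i + 1])
    dp_is_prefill.append(bool(outputs[i + 2]))
assert dp_query_lens == [7, 5] and dp_is_prefill == [True, False]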
|
||||
|
||||
|
||||
def get_deepseek_layer_split_list(
|
||||
dp_query_lens: List[int], dp_group_bs: List[int]
|
||||
) -> Tuple[Optional[List[int]], Optional[List[int]], Optional[List[int]]]:
|
||||
if len(dp_query_lens) != len(dp_group_bs) or len(dp_query_lens) != get_data_parallel_group_world_size():
|
||||
logger.warning(f"dp_query_lens length: {len(dp_query_lens)} != dp_group_bs length: {len(dp_group_bs)}, "
|
||||
f"disable deepseek layer split")
|
||||
return None, None, None
|
||||
emb_query_lens, logits_batch_sizes, dense_attn_token_split_list = None, None, None
|
||||
all_dp_query_lens, all_dp_group_bs = [], []
|
||||
for i in range(len(dp_query_lens)):
|
||||
all_dp_query_lens.extend([dp_query_lens[i]] * get_tensor_model_parallel_world_size())
|
||||
all_dp_group_bs.extend([dp_group_bs[i]] * get_tensor_model_parallel_world_size())
|
||||
if get_logits_tp_world_size() != get_tensor_model_parallel_world_size():
|
||||
slice_start = get_tp_world_rank() // get_logits_tp_world_size() * get_logits_tp_world_size()
|
||||
slice_end = slice_start + get_logits_tp_world_size()
|
||||
emb_query_lens = all_dp_query_lens[slice_start:slice_end]
|
||||
logits_batch_sizes = all_dp_group_bs[slice_start:slice_end]
|
||||
if get_dense_mlp_tp_world_size() != get_tp_world_world_size():
|
||||
slice_start = get_tp_world_rank() // get_dense_mlp_tp_world_size() * get_dense_mlp_tp_world_size()
|
||||
slice_end = slice_start + get_dense_mlp_tp_world_size()
|
||||
dense_attn_token_split_list = all_dp_query_lens[slice_start:slice_end]
|
||||
return emb_query_lens, logits_batch_sizes, dense_attn_token_split_list
|
||||
|
||||
|
||||
def get_dp_metadata(
|
||||
num_tokens: int,
|
||||
data_parallel_size: int,
|
||||
data_parallel_rank: int,
|
||||
tensor_parallel_size: int,
|
||||
prefill_dispatch_use_RS_AG: bool,
|
||||
) -> DataParallelRuntimeParams:
|
||||
"""
|
||||
Get dp params for dummy runs or model-graph capture. These two cases do not have
dp_params at forward time, because we do not want to hijack too much.
|
||||
"""
|
||||
dp_query_lens = [num_tokens] * data_parallel_size
|
||||
in_prefill = get_forward_context().attn_metadata is None # dummy run
|
||||
dp_is_prefill = [in_prefill] * data_parallel_size
|
||||
emb_query_lens, logits_batch_sizes, dense_attn_token_split_list = None, None, None
|
||||
if get_logits_tp_world_size() != get_tensor_model_parallel_world_size():
|
||||
emb_query_lens = [num_tokens] * get_logits_tp_world_size()
|
||||
logits_batch_sizes = None # dummy run and capture model does not contain logits
|
||||
if get_dense_mlp_tp_world_size() != get_tp_world_world_size():
|
||||
dense_attn_token_split_list = [num_tokens] * get_dense_mlp_tp_world_size()
|
||||
|
||||
return MLUDPMetadata.make_oot(data_parallel_rank,
|
||||
data_parallel_size,
|
||||
tensor_parallel_size,
|
||||
dp_query_lens,
|
||||
dp_is_prefill,
|
||||
prefill_dispatch_use_RS_AG,
|
||||
emb_query_lens=emb_query_lens,
|
||||
logits_batch_sizes=logits_batch_sizes,
|
||||
dense_attn_token_split_list=dense_attn_token_split_list)
|
||||
|
||||
|
||||
def remove_paddings_after_all_gather(
|
||||
hidden_states: torch.Tensor,
|
||||
padding_to_token_num: int,
|
||||
token_num_list: List[int],
|
||||
) -> torch.Tensor:
|
||||
dp_group_tensors = []
|
||||
offset = 0
|
||||
for token_num in token_num_list:
|
||||
if token_num != 0:
|
||||
dp_group_tensors.append(hidden_states[offset:offset+token_num])
|
||||
offset += padding_to_token_num
|
||||
if len(dp_group_tensors) == 1:
|
||||
hidden_states = dp_group_tensors[0]
|
||||
else:
|
||||
hidden_states = torch.cat(dp_group_tensors)
|
||||
return hidden_states
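# NOTE: standalone sketch of the unpadding step above: every rank's slice was
# padded to the group max before all_gather, so the real rows are carved back
# out at multiples of that max. Toy shapes, no collectives.
import torch

token_num_list = [2, 3, 0]          # real token counts per rank
padding_to_token_num = 3            # group max each rank padded to
gathered = torch.arange(9 * 4, dtype=torch.float).reshape(9, 4)
chunks, offset = [], 0
for n in token_num_list:
    if n != 0:
        chunks.append(gathered[offset:offset + n])
    offset += padding_to_token_num
assert torch.cat(chunks).shape == (5, 4)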
|
||||
|
||||
|
||||
def tensor_model_parallel_all_gather_dp(
|
||||
group_num_tokens: List[int],
|
||||
rank: int,
|
||||
hidden_states: Optional[torch.Tensor],
|
||||
group: GroupCoordinator,
|
||||
hidden_size: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None) -> torch.Tensor:
|
||||
"""
|
||||
All gather in the group.
|
||||
Input is a 2-D tensor, and can have different shape in the first dim,
|
||||
for example, [4, 7, 5, 8], [2, 5, 4, 0].
|
||||
"""
|
||||
num_tokens_equal = all(x == group_num_tokens[0] for x in group_num_tokens)
|
||||
if num_tokens_equal:
|
||||
hidden_states = tensor_model_parallel_all_gather(
|
||||
input_=hidden_states, dim=0, tp_group=group)
|
||||
else:
|
||||
max_num_tokens = max(group_num_tokens)
|
||||
num_padding = max_num_tokens - group_num_tokens[rank]
|
||||
if num_padding > 0:
|
||||
if hidden_states is None:
|
||||
hidden_states = torch.empty((max_num_tokens, hidden_size),
|
||||
dtype=dtype, device=device)
|
||||
else:
|
||||
hidden_states = F.pad(hidden_states, (0, 0, 0, num_padding))
|
||||
hidden_states = tensor_model_parallel_all_gather(
|
||||
input_=hidden_states, dim=0, tp_group=group)
|
||||
hidden_states = remove_paddings_after_all_gather(
|
||||
hidden_states, max_num_tokens, group_num_tokens)
|
||||
return hidden_states
|
||||
|
||||
def tensor_model_parallel_all_gather_op_v2(
|
||||
input_: torch.Tensor,
|
||||
dim_size_list: List[int],
|
||||
group_coordinator: GroupCoordinator,
|
||||
non_leading_dim_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
All gather the input tensor across model parallel group with only communication ops.
|
||||
|
||||
Note: compared to `tensor_model_parallel_all_gather_dp`, this method supports different
|
||||
sizes in the first dim, and does not involve padding operation.
|
||||
"""
|
||||
all_size_equal = all([dim_size == dim_size_list[0] for dim_size in dim_size_list])
|
||||
|
||||
output_shape = (sum(dim_size_list), non_leading_dim_size)
|
||||
output = torch.empty(output_shape, device=device, dtype=dtype)
|
||||
|
||||
if input_ is None:
|
||||
input_ = torch.empty((0, non_leading_dim_size), device=device, dtype=dtype)
|
||||
|
||||
if all_size_equal:
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
output, input_, group=group_coordinator.device_group)
|
||||
else:
|
||||
# Note: torch.split splits the tensor into chunks. And each chunk
|
||||
# is a view of the original tensor.
|
||||
tensor_list = torch.split(output, dim_size_list, dim=0)
|
||||
torch.distributed.all_gather(
|
||||
list(tensor_list), input_, group=group_coordinator.device_group)
|
||||
return output
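# NOTE: standalone sketch of the torch.split trick above: the split pieces are
# views into `output`, so writing into them (as the uneven all_gather does)
# fills the concatenated buffer with no extra copy. Toy tensors only.
import torch

dim_size_list = [2, 3]
output = torch.zeros(sum(dim_size_list), 4)
views = torch.split(output, dim_size_list, dim=0)
views[0].copy_(torch.ones(2, 4))        # stand-in for rank 0's gathered slice
views[1].copy_(2 * torch.ones(3, 4))    # stand-in for rank 1's gathered slice
assert output[0, 0] == 1 and output[-1, -1] == 2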
|
||||
|
||||
def process_post_attention_communication(
|
||||
hidden_states: Optional[torch.Tensor],
|
||||
dp_params: DataParallelRuntimeParams,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
tp_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Processes distributed communication operations after attention computation.
|
||||
|
||||
This function performs necessary communication operations after attention computation
|
||||
to ensure data synchronization across different parallel groups.
|
||||
Supports two modes:
|
||||
1. Tensor parallel mode: Uses tp_group for all-reduce and all-gather operations
|
||||
2. Data parallel mode: Uses reduce-scatter and all-gather for global synchronization
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states tensor after attention computation, can be None
|
||||
dp_params: Data parallel runtime parameters containing token distribution and padding info
|
||||
hidden_size: Dimension size of hidden states
|
||||
dtype: Data type of the tensor
|
||||
device: Device where the tensor is located
|
||||
tp_group: Tensor parallel group, if None uses data parallel mode
|
||||
|
||||
Returns:
|
||||
Hidden states tensor after communication synchronization processing
|
||||
|
||||
Note:
|
||||
- When prefill_pad_to_token_num != -1, padding and unpadding operations will be performed
|
||||
- Function selects optimal communication path based on token count and parallel strategy
|
||||
"""
|
||||
if tp_group is not None:
|
||||
if dp_params.token_num != 0:
|
||||
hidden_states = tensor_model_parallel_all_reduce(
|
||||
hidden_states)
|
||||
hidden_states = tensor_model_parallel_all_gather_dp(
|
||||
group_num_tokens=dp_params.dense_attn_token_split_list,
|
||||
rank=get_parallel_rank_with_group(tp_group),
|
||||
hidden_states=hidden_states,
|
||||
group=tp_group,
|
||||
)
|
||||
else:
|
||||
if dp_params.prefill_pad_to_token_num != -1:
|
||||
# pad hidden_states to use reduce_scatter and global all gather
|
||||
pad_num = dp_params.prefill_pad_to_token_num - dp_params.token_num
|
||||
if pad_num != 0:
|
||||
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_num))
|
||||
|
||||
hidden_states = tensor_model_parallel_reduce_scatter(
|
||||
hidden_states, dim=0)
|
||||
|
||||
hidden_states = tensor_model_parallel_all_gather_dp(
|
||||
group_num_tokens=dp_params.attn_token_split_list_reduce_scatter,
|
||||
rank=get_tp_world_rank(),
|
||||
hidden_states=hidden_states,
|
||||
group=get_tp_world_group(),
|
||||
)
|
||||
|
||||
# get origin hidden_states for moe compute
|
||||
hidden_states = remove_paddings_after_all_gather(
|
||||
hidden_states, dp_params.prefill_pad_to_token_num,
|
||||
dp_params.token_split_list)
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(
|
||||
hidden_states)
|
||||
|
||||
all_gather_group = get_dp_group()
|
||||
all_gather_rank = get_data_parallel_group_rank()
|
||||
hidden_states = tensor_model_parallel_all_gather_dp(
|
||||
dp_params.token_split_list, all_gather_rank, hidden_states,
|
||||
all_gather_group, hidden_size, dtype, device)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def dp_model_forward(
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor],
|
||||
dp_params: DataParallelRuntimeParams,
|
||||
embedding_layer: nn.Module,
|
||||
model_norm_layer: nn.Module,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
layers: List[nn.Module],
|
||||
layer_input_norm_name: str,
|
||||
prefill_dispatch_use_RS_AG: bool,
|
||||
streams: Optional[Dict[str, torch.mlu.Stream]] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
"""run model with dp."""
|
||||
if dp_params is None:
|
||||
dp_params = get_dp_metadata(positions.numel(),
|
||||
get_data_parallel_group_world_size(),
|
||||
get_data_parallel_group_rank(),
|
||||
get_tensor_model_parallel_world_size(),
|
||||
prefill_dispatch_use_RS_AG)
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
if embedding_layer.__class__.__name__ == "DPVocabParallelEmbedding":
|
||||
hidden_states = embedding_layer(input_ids, dp_params=dp_params)
|
||||
else:
|
||||
hidden_states = embedding_layer(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(start_layer, end_layer):
|
||||
is_first_layer = (i == start_layer)
|
||||
is_last_layer = (i == end_layer - 1)
|
||||
next_input_layernorm = None
|
||||
if not is_last_layer:
|
||||
next_input_layernorm = getattr(layers[i+1], layer_input_norm_name)
|
||||
hidden_states, residual = layers[i](
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
dp_params=dp_params,
|
||||
is_first_layer=is_first_layer,
|
||||
is_last_layer=is_last_layer,
|
||||
streams=streams,
|
||||
next_input_layernorm=next_input_layernorm,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states = model_norm_layer(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def dp_layer_forward(
|
||||
input_norm: nn.Module,
|
||||
self_attn: nn.Module,
|
||||
post_norm: nn.Module,
|
||||
mlp: nn.Module,
|
||||
mlp_kwargs: List[Dict[str, Any]],
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
dp_params: DataParallelRuntimeParams,
|
||||
hidden_size: int,
|
||||
hidden_states_dtype: torch.dtype,
|
||||
is_first_layer: bool = False,
|
||||
is_last_layer: bool = False,
|
||||
next_input_layernorm: Optional[nn.Module] = None,
|
||||
enable_all2all: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Run a layer with DP; dispatch to all2all, RS+AG, or the common path.

For mlp_kwargs: the all2all forward args often differ from the common mlp args,
so by convention mlp_kwargs[-1] always holds the all2all kwargs. For example:
|
||||
|
||||
Deepseek enable all2all, mlp_kwargs will be: [{mlp common forward kwargs}, {mlp all2all kwargs}].
|
||||
Deepseek does not enable all2all, mlp_kwargs will be: [{mlp common forward kwargs}].
|
||||
"""
|
||||
if dp_params.layer_use_reduce_scatter:
|
||||
common_metadata = get_common_metadata()
|
||||
is_decode_only = common_metadata is not None and common_metadata.is_decode_only
|
||||
use_all2all = enable_all2all and is_decode_only and isinstance(mlp, SparseMoeMlp)
|
||||
forward_func = _dp_forward_layer_all2all if use_all2all else _dp_forward_layer_rs_ag
|
||||
hidden_states, residual = forward_func(input_norm,
|
||||
self_attn,
|
||||
post_norm,
|
||||
mlp,
|
||||
mlp_kwargs,
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
dp_params,
|
||||
is_first_layer,
|
||||
is_last_layer,
|
||||
next_input_layernorm)
|
||||
else:
|
||||
hidden_states, residual = _dp_forward_layer_common(input_norm,
|
||||
self_attn,
|
||||
post_norm,
|
||||
mlp,
|
||||
mlp_kwargs,
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
dp_params,
|
||||
hidden_size,
|
||||
hidden_states_dtype)
|
||||
return hidden_states, residual
|
||||
|
||||
def _dp_forward_layer_rs_ag(
|
||||
input_norm: nn.Module,
|
||||
self_attn: nn.Module,
|
||||
post_norm: nn.Module,
|
||||
mlp: nn.Module,
|
||||
mlp_kwargs: List[Dict[str, Any]],
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
dp_params: DataParallelRuntimeParams,
|
||||
is_first_layer: bool,
|
||||
is_last_layer: bool,
|
||||
next_input_layernorm: List[Optional[nn.Module]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""run layer with rs+ag."""
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
|
||||
# We move the input_layernorm of i+1 layer to the end of i layer.
|
||||
# But for the first layer, we need to do input_layernorm first.
|
||||
if is_first_layer:
|
||||
hidden_states = input_norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
# add residual here for the first layer
|
||||
if is_first_layer and get_tensor_model_parallel_rank() == 0:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = tensor_model_parallel_reduce_scatter(
|
||||
hidden_states, dim=0)
|
||||
|
||||
# move norm between rs and ag
|
||||
if is_first_layer:
|
||||
residual = hidden_states
|
||||
hidden_states = post_norm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = post_norm(hidden_states, residual)
|
||||
|
||||
hidden_states = tensor_model_parallel_all_gather_dp(
|
||||
group_num_tokens=dp_params.attn_token_split_list_reduce_scatter,
|
||||
rank=get_tp_world_rank(),
|
||||
hidden_states=hidden_states,
|
||||
group=get_tp_world_group(),
|
||||
)
|
||||
|
||||
# mlp, use all cards
|
||||
hidden_states = mlp(hidden_states, **mlp_kwargs[0])
|
||||
|
||||
hidden_states = tensor_model_parallel_reduce_scatter(
|
||||
hidden_states, dim=0, tp_group=get_tp_world_group())
|
||||
|
||||
if is_last_layer:
|
||||
hidden_states = hidden_states + residual
|
||||
residual = None
|
||||
else:
|
||||
# To reduce layernorm computation, we move the layernorm of i+1 layer to
|
||||
# the end of i layer. Besides, we fuse residual addition into layernorm.
|
||||
assert next_input_layernorm is not None
|
||||
hidden_states, residual = next_input_layernorm(hidden_states, residual)
|
||||
|
||||
hidden_states = tensor_model_parallel_all_gather_dp(
|
||||
group_num_tokens=dp_params.moe_token_split_list_reduce_scatter,
|
||||
rank=get_tensor_model_parallel_rank(),
|
||||
hidden_states=hidden_states,
|
||||
group=get_tp_group(),
|
||||
)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def _dp_forward_layer_all2all(
|
||||
input_norm: nn.Module,
|
||||
self_attn: nn.Module,
|
||||
post_norm: nn.Module,
|
||||
mlp: nn.Module,
|
||||
mlp_kwargs: List[Dict[str, Any]],
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
dp_params: DataParallelRuntimeParams,
|
||||
is_first_layer: bool,
|
||||
is_last_layer: bool,
|
||||
next_input_layernorm: List[Optional[nn.Module]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""run layer with all2all."""
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
|
||||
# We move the input_layernorm of i+1 layer to the end of i layer.
|
||||
# But for the first layer, we need to do input_layernorm first.
|
||||
if is_first_layer:
|
||||
hidden_states = input_norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
# add residual here for the first layer
|
||||
if is_first_layer and get_tensor_model_parallel_rank() == 0:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = tensor_model_parallel_reduce_scatter(
|
||||
hidden_states, dim=0)
|
||||
|
||||
# move norm between rs and ag
|
||||
if is_first_layer:
|
||||
residual = hidden_states
|
||||
hidden_states = post_norm(hidden_states)
|
||||
else:
|
||||
# add residual in norm for other layers
|
||||
hidden_states, residual = post_norm(hidden_states, residual)
|
||||
|
||||
hidden_states = mlp.forward_all2all(hidden_states, **mlp_kwargs[-1])
|
||||
|
||||
if is_last_layer:
|
||||
hidden_states = hidden_states + residual
|
||||
residual = None
|
||||
else:
|
||||
# To reduce layernorm computation, we move the layernorm of i+1 layer to
|
||||
# the end of i layer. Besides, we fuse residual addition into layernorm.
|
||||
assert next_input_layernorm is not None
|
||||
hidden_states, residual = next_input_layernorm(hidden_states, residual)
|
||||
|
||||
hidden_states = tensor_model_parallel_all_gather_dp(
|
||||
group_num_tokens=dp_params.moe_token_split_list_reduce_scatter,
|
||||
rank=get_tensor_model_parallel_rank(),
|
||||
hidden_states=hidden_states,
|
||||
group=get_tp_group(),
|
||||
)
|
||||
|
||||
return hidden_states, residual


def _dp_forward_layer_common(
    input_norm: nn.Module,
    self_attn: nn.Module,
    post_norm: nn.Module,
    mlp: nn.Module,
    mlp_kwargs: List[Dict[str, Any]],
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    dp_params: DataParallelRuntimeParams,
    hidden_size: int,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the layer with the common path."""
    if residual is None:
        residual = hidden_states

    hidden_states = input_norm(hidden_states)
    hidden_states = self_attn(
        positions=positions,
        hidden_states=hidden_states,
    )

    # Add the residual here.
    if get_tensor_model_parallel_rank() == 0:
        hidden_states = hidden_states + residual

    hidden_states = process_post_attention_communication(
        hidden_states, dp_params, hidden_size, dtype, positions.device, None
    )

    residual = hidden_states[dp_params.token_num_offset:
                             dp_params.token_num_offset + dp_params.token_num]

    hidden_states = post_norm(hidden_states)

    hidden_states = mlp(hidden_states, **mlp_kwargs[0])

    hidden_states = tensor_model_parallel_all_reduce(
        hidden_states, tp_group=get_tp_world_group())

    # Add the residual here.
    hidden_states = hidden_states[dp_params.token_num_offset:
                                  dp_params.token_num_offset + dp_params.token_num]
    hidden_states = hidden_states + residual
    residual = hidden_states

    return hidden_states, residual
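
# NOTE (illustrative, not part of the original commit): in the common path the
# post-attention communication gathers tokens from all attention-DP ranks, so
# hidden_states holds every rank's tokens while this rank only keeps the slice
# [token_num_offset : token_num_offset + token_num]; the residual is therefore
# taken from, and added back to, that same local slice after the MLP
# all-reduce.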
245
vllm_mlu/model_executor/models/layer_utils.py
Executable file
@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch
from typing import Callable, Optional, List, Union, Tuple

from vllm_mlu import _mlu_ops as mlu_ops
from vllm.attention import AttentionMetadata
from vllm.sequence import IntermediateTensors
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from transformers import PretrainedConfig


def hunyuan_decoder_layer_forward_base(
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    input_layernorm: Callable,
    self_attn: Callable,
    post_layernorm: Callable,
    mlp: Callable,
    kv_states: Optional[Tuple[torch.Tensor]] = None,
    apply_residual_connection_post_layernorm: bool = False,
    position_name: str = 'positions',
    input_norm_fuse_en: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    smooth_quant_scale = None
    if input_norm_fuse_en:
        layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
    else:
        layernorm_output = input_layernorm(hidden_states)
        smooth_quant_scale = None
    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = hidden_states

    # Self Attention
    attention_output, ori_kv_states = self_attn(
        **{position_name: positions},
        hidden_states=layernorm_output,
        residual=residual,
        kv_states=kv_states,
        smooth_quant_scale=smooth_quant_scale,
    )

    layernorm_output = post_layernorm(attention_output)
    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = attention_output

    # Fully Connected
    hidden_states = mlp(layernorm_output, residual)
    return hidden_states, ori_kv_states


def decoder_layer_forward_base(
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    input_layernorm: Callable,
    self_attn: Callable,
    post_layernorm: Callable,
    mlp: Callable,
    apply_residual_connection_post_layernorm: bool = False,
    position_name: str = 'positions',
    input_norm_fuse_en: bool = False,
    post_norm_fuse_en: bool = False,
) -> torch.Tensor:
    if input_norm_fuse_en:
        layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
    else:
        layernorm_output = input_layernorm(hidden_states)
        smooth_quant_scale = None

    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = hidden_states

    # Self Attention
    attention_output = self_attn(
        **{position_name: positions},
        hidden_states=layernorm_output,
        residual=residual,
        smooth_quant_scale=smooth_quant_scale,
    )

    if post_norm_fuse_en:
        layernorm_output, smooth_quant_scale = post_layernorm(attention_output)
    else:
        layernorm_output = post_layernorm(attention_output)
        smooth_quant_scale = None

    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = attention_output

    # Fully Connected
    kwargs = dict()
    if post_norm_fuse_en:
        kwargs['smooth_quant_scale'] = smooth_quant_scale
    hidden_states = mlp(layernorm_output, residual, **kwargs)
    return hidden_states
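
# NOTE (illustrative, not part of the original commit): when input_norm_fuse_en
# or post_norm_fuse_en is set, the corresponding layernorm is assumed to return
# a (normalized_output, smooth_quant_scale) pair, and the per-token scale is
# forwarded to the attention / MLP so their SmoothQuant inputs can be quantized
# without an extra pass over the activations.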


def decoder_model_forward_base(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    layers: torch.nn.ModuleList,
    embed_input_ids: Callable,
    norm: Callable
) -> torch.Tensor:
    hidden_states = embed_input_ids(input_ids)
    for i in range(len(layers)):
        layer = layers[i]
        hidden_states = layer(
            positions,
            hidden_states,
        )
    hidden_states = norm(hidden_states)
    return hidden_states


def hunyuan_decoder_model_forward_base_pp(
    config: PretrainedConfig,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    layers: torch.nn.ModuleList,
    start_layer: int,
    end_layer: int,
    embed_input_ids: Callable,
    norm: Callable,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = embed_input_ids(input_ids)
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    cla_factor = getattr(config, "cla_share_factor", 1)
    prev_kv_states = None
    for i in range(start_layer, end_layer):
        layer = layers[i]
        hidden_states, kv_states = layer(
            positions,
            hidden_states,
            prev_kv_states,
        )
        if (i - start_layer) % cla_factor == 0:
            prev_kv_states = kv_states
        else:
            prev_kv_states = None

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })

    hidden_states = norm(hidden_states)
    return hidden_states
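
# NOTE (illustrative, not part of the original commit): `cla_share_factor`
# controls cross-layer KV reuse. For example, with cla_share_factor=2 the loop
# above lets layers 0, 2, 4, ... produce fresh kv_states, while each following
# layer receives the previous layer's kv_states via `prev_kv_states` instead of
# recomputing them.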


def decoder_model_forward_base_pp(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    layers: torch.nn.ModuleList,
    start_layer: int,
    end_layer: int,
    embed_input_ids: Callable,
    norm: Callable,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = embed_input_ids(input_ids)
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    for i in range(start_layer, end_layer):
        layer = layers[i]
        hidden_states = layer(
            positions,
            hidden_states,
        )

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })

    hidden_states = norm(hidden_states)
    return hidden_states


def is_smoothquant(quant_config: QuantizationConfig) -> bool:
    return (quant_config is not None
            and quant_config.get_name() == "SmoothQuant")


def is_per_token_smoothquant(quant_config: QuantizationConfig) -> bool:
    return (is_smoothquant(quant_config)
            and quant_config.input_quant_method == "per_token")


def compute_in_loop(func: Callable,
                    input: torch.Tensor,
                    chunk_size: int,
                    feature_size: Optional[int] = None,
                    **kwargs):
    """
    Divide the input into chunks along the leading dimension (dimension 0) and
    compute the chunks in a loop, instead of in a single batch at once.

    Args:
        feature_size: size of the output feature dimension. Provide it when
            the output's feature dimension differs from the input's feature
            dimension.
    """

    total = input.shape[0]
    # Directly compute if there is only one chunk.
    if chunk_size >= total:
        return func(input, **kwargs)

    feature_size = feature_size or input.shape[1]
    output = input.new_empty(total, feature_size)
    num_chunks = (total + chunk_size - 1) // chunk_size

    for i in range(num_chunks):
        start = i * chunk_size
        end = min((i + 1) * chunk_size, total)
        output[start:end] = func(input[start:end], **kwargs)

    return output
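
# NOTE (illustrative, not part of the original commit): a minimal usage sketch,
# assuming a large projection applied over many tokens. Processing 1024 rows at
# a time bounds the size of `func`'s intermediate activations; the results are
# written chunk by chunk into the preallocated output buffer.
#
#     weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="mlu")
#     activations = torch.randn(32768, 4096, dtype=torch.bfloat16, device="mlu")
#     out = compute_in_loop(lambda x: x @ weight, activations, chunk_size=1024)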
507
vllm_mlu/model_executor/models/partition_utils.py
Normal file
@@ -0,0 +1,507 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.models.dp_utils import DataParallelRuntimeParams
from vllm_mlu.v1.attention.backends.mla.flashmla import (
    FlashMLAPrefillMetadata, FlashMLAMetadata, MLACommonMetadata
)
from vllm_mlu.v1.attention.backends.utils import (
    COMMON_METADATA_STR,
    MLUCommonAttentionMetadata,
)

SEQUENCE_DIM_PARITION_THRESHOLD = 1024


def get_common_and_layer_metadata(
    attn_metadata: Optional[dict],
) -> Tuple[Optional[MLUCommonAttentionMetadata], Optional[AttentionMetadata]]:
    """
    Returns the common metadata and the layer metadata extracted from the given
    attention metadata.
    """
    if attn_metadata is None:
        return None, None

    if isinstance(attn_metadata, dict):
        assert COMMON_METADATA_STR in attn_metadata, (
            f"attn_metadata must contain {COMMON_METADATA_STR} key"
        )
        assert len({id(v) for v in attn_metadata.values()}) == 2, (
            f"attn_metadata should be a dict with two values, one for {COMMON_METADATA_STR} and "
            f"the other for layers."
        )
        common_metadata = attn_metadata[COMMON_METADATA_STR]
        layer_metadata = next((v for k, v in attn_metadata.items() if k != COMMON_METADATA_STR), None)

        return common_metadata, layer_metadata
    return None, attn_metadata


def should_skip_partition(layer_metadata, common_metadata) -> bool:
    """Helper function to simplify the partition condition check."""
    is_layer_metadata_invalid = (layer_metadata is None
                                 or layer_metadata.prefill is None
                                 or layer_metadata.query_start_loc is None
                                 or layer_metadata.query_start_loc.numel() == 0)
    is_common_metadata_invalid = common_metadata is None or not common_metadata.is_prefill_only
    return is_layer_metadata_invalid or is_common_metadata_invalid


def attn_mcc_plan(
    attn_metadata: Any,
    dp_params: DataParallelRuntimeParams,
    parts_to_split: int,
) -> Tuple[int, int]:
    """
    Returns the number of parts for the batch size dimension and the number of
    parts for the sequence length dimension.
    """
    # In the procedure of a dummy run, attn_metadata is an instance of MLACommonMetadata.
    if not isinstance(attn_metadata, (dict, MLACommonMetadata, type(None))):
        raise TypeError(f"attn_metadata must be dict or MLACommonMetadata, got {type(attn_metadata)}")

    if isinstance(attn_metadata, dict):
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
    else:
        common_metadata, layer_metadata = None, attn_metadata

    if dp_params is None:
        # We don't support mcc with decode yet.
        if should_skip_partition(layer_metadata, common_metadata):
            return 1, 1

        # The batch size dimension has a higher priority to split than the
        # sequence length dimension. And we ensure each subtask is not empty
        # without dp.
        num_prefills = layer_metadata.query_start_loc.numel() - 1
        if num_prefills > 1:
            return min(parts_to_split, num_prefills), 1

        try:
            max_query_len = torch.diff(layer_metadata.query_start_loc).max().item()
        except RuntimeError:
            return 1, 1

        if max_query_len < SEQUENCE_DIM_PARITION_THRESHOLD:
            return 1, 1
        return 1, min(parts_to_split, max_query_len)
    else:
        if not all(is_prefill for is_prefill in dp_params.dp_is_prefill):
            return 1, 1

        max_bs = max(dp_params.batch_sizes)
        if max_bs > 1:
            # Ensure parts_to_split does not exceed max_bs to avoid unnecessary splits.
            if max(dp_params.token_split_list) < SEQUENCE_DIM_PARITION_THRESHOLD:
                return 1, 1
            return min(parts_to_split, max_bs), 1
        else:
            if max(dp_params.token_split_list) < SEQUENCE_DIM_PARITION_THRESHOLD:
                return 1, 1
            return 1, parts_to_split
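
# NOTE (illustrative, not part of the original commit): two concrete cases of
# the plan above, assuming parts_to_split=2 and no dp_params. A batch of three
# prefill requests is split along the batch dimension -> (2, 1); a single
# prefill of 4096 tokens exceeds SEQUENCE_DIM_PARITION_THRESHOLD and is split
# along the sequence dimension -> (1, 2).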


def get_data_num_and_offset(total_size, parts_to_split):
    """
    Get the data size and offset for each part.
    For example, total batch 11 with parallel_num 4 gives [3, 3, 3, 2] with offsets [0, 3, 6, 9];
    total batch 8 with parallel_num 4 gives [2, 2, 2, 2] with offsets [0, 2, 4, 6].
    """
    # Calculate the quotient and remainder of total_size divided by parts_to_split.
    quotient = total_size // parts_to_split
    remainder = total_size % parts_to_split
    data_num_list = [quotient + 1] * remainder + [quotient] * (parts_to_split - remainder)
    offset_list = [0] + list(itertools.accumulate(data_num_list))
    return data_num_list, offset_list[:-1]
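
# NOTE (illustrative, not part of the original commit): the remainder is spread
# over the leading parts, so an uneven total such as
# get_data_num_and_offset(7, 3) yields ([3, 2, 2], [0, 3, 5]).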


def split_dp_params(
    dp_params: DataParallelRuntimeParams,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_data_parallel_size: int,
    attn_tensor_parallel_size: int,
    prefill_dispatch_use_RS_AG: bool,
    dp_rank_: int,
) -> List[DataParallelRuntimeParams]:
    assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
        "We don't support split batch and sequence dimensions concurrently."

    if dp_params is None:
        return [None] * bs_parts_to_split * seq_parts_to_split

    if bs_parts_to_split * seq_parts_to_split == 1:
        return list([dp_params])

    if bs_parts_to_split == 1:
        results: List[DataParallelRuntimeParams] = []
        dp_seq_lens = []
        for seq_len in dp_params.seq_lens:
            tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
            dp_seq_lens.append(tokens)

        query_lens_per_dp_rank = []

        # For each dp rank, the batch size is 0 or 1.
        bs_offset = 0
        for i in range(attn_data_parallel_size):
            if dp_params.batch_sizes[i] > 0:
                seq_len = dp_params.seq_lens[bs_offset]
                tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
                query_lens_per_dp_rank.append(tokens)
                bs_offset += dp_params.batch_sizes[i]
            else:
                query_lens_per_dp_rank.append([0] * seq_parts_to_split)

        for i in range(seq_parts_to_split):
            dp_is_prefill = []
            for dp_rank in range(attn_data_parallel_size):
                dp_is_prefill.append(True)

            results.append(MLUDPMetadata.make_oot(
                data_parallel_rank=dp_rank_,
                data_parallel_size=attn_data_parallel_size,
                tensor_parallel_size=attn_tensor_parallel_size,
                dp_token_nums=[query_lens_per_dp_rank[j][i] for j in range(attn_data_parallel_size)],
                dp_is_prefill=dp_is_prefill,
                prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
                seq_lens=[seq_lens[i] for seq_lens in dp_seq_lens],
                batch_sizes=dp_params.batch_sizes,
            ))
        return results

    bs_per_dp = dp_params.batch_sizes  # [bs_rank_0, bs_rank_1, ...]
    seq_lens_per_dp = dp_params.seq_lens  # [seq_len_bs_0, seq_len_bs_1, ...]

    # [[bs_rank_0_part_0, bs_rank_0_part_1, ...], [bs_rank_1_part_0, bs_rank_1_part_1, ...], ...]
    split_bs_per_dp = []
    # [[
    #   [bs0_part_0_rank_0, bs1_part_0_rank_0, ...],
    #   [bs0_part_1_rank_0, bs1_part_1_rank_0, ...],
    #   ...
    #  ],
    #  [
    #   [bs0_part_0_rank_1, bs1_part_0_rank_1, ...],
    #   [bs0_part_1_rank_1, bs1_part_1_rank_1, ...],
    #   ...
    #  ],
    # ]
    split_query_lens_per_dp = []
    for dp_rank in range(attn_data_parallel_size):
        _bs, _offset = get_data_num_and_offset(bs_per_dp[dp_rank], bs_parts_to_split)
        split_bs_per_dp.append(_bs)
        split_query_lens_per_dp.append([])
        for i in range(bs_parts_to_split):
            start = sum(bs_per_dp[:dp_rank]) + _offset[i]
            end = start + _bs[i]
            split_query_lens_per_dp[-1].append(dp_params.seq_lens[start:end])

    results: List[DataParallelRuntimeParams] = []
    for i in range(bs_parts_to_split):
        dp_query_lens = [sum(split_query_lens_per_dp[dp_rank][i]) for dp_rank in range(attn_data_parallel_size)]
        seq_lens = []
        for dp_rank in range(attn_data_parallel_size):
            seq_lens += split_query_lens_per_dp[dp_rank][i]
        batch_sizes = []
        for dp_rank in range(attn_data_parallel_size):
            batch_sizes.append(split_bs_per_dp[dp_rank][i])

        dp_is_prefill = []
        for dp_rank in range(attn_data_parallel_size):
            dp_is_prefill.append(True)
        results.append(MLUDPMetadata.make_oot(
            data_parallel_rank=dp_rank_,
            data_parallel_size=attn_data_parallel_size,
            tensor_parallel_size=attn_tensor_parallel_size,
            dp_token_nums=dp_query_lens,
            dp_is_prefill=dp_is_prefill,
            prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
            seq_lens=seq_lens,
            batch_sizes=batch_sizes,
        ))

    return results
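
# NOTE (illustrative, not part of the original commit): a worked batch-dimension
# split, assuming two attention-DP ranks with batch_sizes [3, 2] and
# bs_parts_to_split=2. Each rank's batch is divided with
# get_data_num_and_offset, giving per-rank parts [[2, 1], [1, 1]], so the first
# sub-step runs batch sizes [2, 1] and the second runs [1, 1], with seq_lens
# sliced to match.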


def split_input(
    input: torch.Tensor,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_metadata_list: List[AttentionMetadata],
) -> List[torch.Tensor]:
    assert seq_parts_to_split == 1 or bs_parts_to_split == 1, \
        "We don't support split batch and sequence dimensions concurrently."

    if input is None:
        return [None] * bs_parts_to_split * seq_parts_to_split

    if bs_parts_to_split * seq_parts_to_split == 1:
        return list([input])

    token_num_list = [0] * len(attn_metadata_list)
    for i, metadata in enumerate(attn_metadata_list):
        common_metadata, layer_metadata = get_common_and_layer_metadata(metadata)
        if layer_metadata is not None:
            token_num_list[i] = layer_metadata.num_actual_tokens

        # A special case for the dummy run.
        if layer_metadata is None and i == 0:
            token_num_list[i] = input.shape[0]

    results = list()
    for i in range(bs_parts_to_split * seq_parts_to_split):
        start = sum(token_num_list[:i])
        end = start + token_num_list[i]
        results.append(input[start:end])
    return results


def split_positions(
    positions: torch.Tensor,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_metadata: AttentionMetadata,
) -> List[torch.Tensor]:
    if seq_parts_to_split == 1:
        return [positions] * bs_parts_to_split

    common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
    total_tokens = layer_metadata.num_actual_tokens if layer_metadata is not None else 0

    tokens, offsets = get_data_num_and_offset(total_tokens, seq_parts_to_split)
    positions_list = []
    for i in range(seq_parts_to_split):
        positions_list.append(positions[offsets[i]: offsets[i] + tokens[i]])

    return positions_list


def split_attn_metadata(
    attn_metadata: dict,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
) -> List[Any]:
    """attn_metadata is a dict, which contains the common and layer metadata."""
    assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
        "We don't support split batch and sequence dimensions concurrently."
    if bs_parts_to_split == 1 and seq_parts_to_split == 1:
        return list([attn_metadata])

    if attn_metadata is None:
        return [None] * bs_parts_to_split * seq_parts_to_split

    if seq_parts_to_split > 1:
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
        if common_metadata is None or not hasattr(common_metadata, 'num_actual_tokens'):
            raise ValueError("common_metadata is invalid or missing num_actual_tokens")
        num_prefill_tokens = common_metadata.num_actual_tokens
        tokens, offsets = get_data_num_and_offset(num_prefill_tokens, seq_parts_to_split)
        device = common_metadata.seq_lens.device
        sub_common_metadata, sub_layer_metadata = [], []
        for i in range(seq_parts_to_split):
            # query_start_loc tensor, which indexes positions in the input.
            query_start_loc_tensor = torch.empty_like(common_metadata.query_start_loc)
            query_start_loc_tensor[0] = 0
            query_start_loc_tensor[1] = tokens[i]
            # seq_lens tensor
            seq_lens_tensor = torch.tensor(
                [offsets[i] + tokens[i]],
                dtype=common_metadata.seq_lens.dtype,
                device=device
            )
            # seq_start_loc tensor, which indicates positions in the sequence (kv cache).
            seq_start_loc_tensor = torch.empty_like(common_metadata.seq_start_loc)
            seq_start_loc_tensor[0] = offsets[i]
            seq_start_loc_tensor[1] = offsets[i] + tokens[i]
            # max_query_len scalar
            max_query_len = tokens[i]
            # num_actual_tokens scalar
            num_actual_tokens = tokens[i]
            # num_input_tokens scalar
            num_input_tokens = num_actual_tokens
            # infer_mode
            infer_mode = common_metadata.infer_mode
            # Update the common metadata.
            sub_common_metadata.append(MLUCommonAttentionMetadata(
                query_start_loc=query_start_loc_tensor,
                query_start_loc_cpu=common_metadata.query_start_loc_cpu,  # FIXME: split when used
                seq_lens=seq_lens_tensor,
                seq_lens_cpu=common_metadata.seq_lens_cpu,  # FIXME: split when used
                num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu,  # FIXME: split when used
                num_reqs=common_metadata.num_reqs,  # FIXME: split when used
                num_actual_tokens=num_actual_tokens,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                block_table_tensor=common_metadata.block_table_tensor,  # FIXME: split when used
                slot_mapping=common_metadata.slot_mapping,  # FIXME: split when used
                seq_start_loc=seq_start_loc_tensor,
                num_input_tokens=num_input_tokens,
                infer_mode=infer_mode,
                num_prefill_query_tokens=tokens[i],
                num_prefill_kv_tokens=offsets[i] + tokens[i],
            ))
            # slot_mapping tensor
            slot_mapping = layer_metadata.slot_mapping[offsets[i]:offsets[i] + tokens[i]]
            # Update the layer metadata.
            REQUIRED_NUM_DECODES = 0
            REQUIRED_NUM_DECODE_TOKENS = 0
            REQUIRED_NUM_PREFILLS = 1
            if not hasattr(layer_metadata, 'num_prefills') or \
                    layer_metadata.num_prefills is None:
                raise ValueError("layer_metadata.num_prefills is required")

            assert layer_metadata.num_decodes == REQUIRED_NUM_DECODES and \
                layer_metadata.num_decode_tokens == REQUIRED_NUM_DECODE_TOKENS and \
                layer_metadata.num_prefills == REQUIRED_NUM_PREFILLS, (
                f"num_decodes, num_decode_tokens, num_prefills must be {REQUIRED_NUM_DECODES}, {REQUIRED_NUM_DECODE_TOKENS}, "
                f"{REQUIRED_NUM_PREFILLS}, but got {layer_metadata.num_decodes}, {layer_metadata.num_decode_tokens}, "
                f"{layer_metadata.num_prefills}."
            )
            assert layer_metadata.prefill.chunked_context is None, (
                f"chunked_context is only available for prefill with chunked context, "
                f"and it is not supported when enabling mcc."
            )
            prefill_metadata = FlashMLAPrefillMetadata(
                block_table=layer_metadata.prefill.block_table,
                query_start_loc=query_start_loc_tensor,
                max_query_len=max_query_len,
                chunked_context=None,
                num_prefills=layer_metadata.prefill.num_prefills,
                max_seq_len=layer_metadata.prefill.max_seq_len,
            )
            # Note: for sequence dimension partition, we provide the cu_seqlens_kv field
            # to indicate the key/value size for the flash attention operator.
            prefill_metadata.cu_seqlens_kv = torch.empty_like(prefill_metadata.query_start_loc)
            prefill_metadata.cu_seqlens_kv[0] = 0
            prefill_metadata.cu_seqlens_kv[1] = offsets[i] + tokens[i]

            sub_layer_metadata.append(FlashMLAMetadata(
                num_reqs=layer_metadata.num_reqs,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                query_start_loc=query_start_loc_tensor,
                slot_mapping=slot_mapping,
                num_decodes=layer_metadata.num_decodes,
                num_decode_tokens=layer_metadata.num_decode_tokens,
                num_prefills=layer_metadata.num_prefills,
                num_prefill_tokens=tokens[i],
                head_dim=layer_metadata.head_dim,
                decode=layer_metadata.decode,
                prefill=prefill_metadata,
            ))

        sub_attn_metadata_list = []
        for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
            sub_attn_metadata_dict = {}
            for key, value in attn_metadata.items():
                if key == COMMON_METADATA_STR:
                    sub_attn_metadata_dict[key] = common_meta
                else:
                    sub_attn_metadata_dict[key] = layer_meta
            sub_attn_metadata_list.append(sub_attn_metadata_dict)
        return sub_attn_metadata_list
    elif bs_parts_to_split > 1:
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
        if not hasattr(layer_metadata, 'num_prefills') or layer_metadata.num_prefills is None:
            raise ValueError("layer_metadata.num_prefills is required")
        total_batch = layer_metadata.num_prefills
        batch_sizes, offsets = get_data_num_and_offset(total_batch, bs_parts_to_split)
        sub_common_metadata, sub_layer_metadata = [], []
        for i in range(bs_parts_to_split):
            # query_start_loc tensor
            start, end = offsets[i], offsets[i] + batch_sizes[i]
            query_start_loc_tensor = common_metadata.query_start_loc[start:end + 1].clone()
            if i > 0:
                query_start_loc_tensor -= common_metadata.query_start_loc[start]
            # block_table
            block_tables = torch.empty(
                (batch_sizes[i], 0),
                dtype=layer_metadata.prefill.block_table.dtype,
                device=layer_metadata.prefill.block_table.device,
            )
            # seq_lens tensor
            seq_lens_tensor = common_metadata.seq_lens[start:end].clone()
            # seq_start_loc tensor
            seq_start_loc_tensor = query_start_loc_tensor
            # max_query_len scalar
            max_query_len = seq_lens_tensor.max().item() if seq_lens_tensor.numel() > 0 else 0
            # num_actual_tokens scalar
            num_actual_tokens = seq_start_loc_tensor[-1].item()
            # num_input_tokens scalar
            num_input_tokens = num_actual_tokens
            # infer_mode
            infer_mode = common_metadata.infer_mode
            # slot_mapping tensor
            slot_mapping_start = 0
            for data in sub_common_metadata:
                slot_mapping_start += data.num_actual_tokens
            slot_mapping_tensor = layer_metadata.slot_mapping[
                slot_mapping_start:slot_mapping_start + num_actual_tokens
            ]
            # Update the common metadata.
            sub_common_metadata.append(MLUCommonAttentionMetadata(
                query_start_loc=query_start_loc_tensor,
                query_start_loc_cpu=common_metadata.query_start_loc_cpu,  # FIXME: split when used
                seq_lens=seq_lens_tensor,
                seq_lens_cpu=common_metadata.seq_lens_cpu,  # FIXME: split when used
                num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu,  # FIXME: split when used
                num_reqs=common_metadata.num_reqs,  # FIXME: split when used
                block_table_tensor=common_metadata.block_table_tensor,  # FIXME: split when used
                slot_mapping=common_metadata.slot_mapping,  # FIXME: split when used
                seq_start_loc=seq_start_loc_tensor,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                num_input_tokens=num_input_tokens,
                infer_mode=infer_mode,
                num_prefill_query_tokens=num_actual_tokens,
                num_prefill_kv_tokens=num_actual_tokens,
            ))
            # Update the layer metadata.
            prefill_metadata = FlashMLAPrefillMetadata(
                block_table=block_tables,
                query_start_loc=query_start_loc_tensor,
                max_query_len=max_query_len,
                chunked_context=None,
                num_prefills=batch_sizes[i],
                max_seq_len=max_query_len,
            )
            sub_layer_metadata.append(FlashMLAMetadata(
                num_reqs=batch_sizes[i],
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                query_start_loc=query_start_loc_tensor,
                slot_mapping=slot_mapping_tensor,
                num_decodes=layer_metadata.num_decodes,  # useless field
                num_decode_tokens=0,  # useless field
                num_prefills=batch_sizes[i],
                num_prefill_tokens=num_actual_tokens,
                head_dim=layer_metadata.head_dim,
                decode=layer_metadata.decode,
                prefill=prefill_metadata,
            ))

        sub_attn_metadata_list = []
        for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
            sub_attn_metadata_dict = {}
            for key, value in attn_metadata.items():
                if key == COMMON_METADATA_STR:
                    sub_attn_metadata_dict[key] = common_meta
                else:
                    sub_attn_metadata_dict[key] = layer_meta
            sub_attn_metadata_list.append(sub_attn_metadata_dict)
        return sub_attn_metadata_list


def execute_with_updated_forward_context(
    vllm_config: VllmConfig,
    attn_metadata: AttentionMetadata,
    func: Callable,
    kwargs: Dict[str, Any],
):
    with set_forward_context(attn_metadata, vllm_config):
        return func(**kwargs)
81
vllm_mlu/model_executor/models/registry.py
Normal file
@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Type, Union

import torch.nn as nn

from vllm.logger import init_logger
from vllm.model_executor.models.registry import (
    _LazyRegisteredModel, _RegisteredModel, _ModelRegistry)

from vllm_mlu.mlu_hijack_utils import MluHijackObject


logger = init_logger(__name__)


def vllm__model_executor__models__registry___ModelRegistry__register_model(
    self,
    model_arch: str,
    model_cls: Union[type[nn.Module], str],
) -> None:
    """
    Register an external model to be used in vLLM.

    `model_cls` can be either:

    - A [`torch.nn.Module`][] class directly referencing the model.
    - A string in the format `<module>:<class>` which can be used to
      lazily import the model. This is useful to avoid initializing CUDA
      when importing the model and thus the related error
      `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
    """
    if not isinstance(model_arch, str):
        msg = f"`model_arch` should be a string, not a {type(model_arch)}"
        raise TypeError(msg)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: change mlu models register log level
    '''
    if model_arch in self.models:
        if isinstance(model_cls, str) and "MLU" in model_cls:
            logger.debug(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)
        else:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    if isinstance(model_cls, str):
        split_str = model_cls.split(":")
        if len(split_str) != 2:
            msg = "Expected a string in the format `<module>:<class>`"
            raise ValueError(msg)

        model = _LazyRegisteredModel(*split_str)
    elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
        model = _RegisteredModel.from_model_cls(model_cls)
    else:
        msg = ("`model_cls` should be a string or PyTorch model class, "
               f"not a {type(model_arch)}")
        raise TypeError(msg)

    self.models[model_arch] = model


MluHijackObject.apply_hijack(
    _ModelRegistry,
    _ModelRegistry.register_model,
    vllm__model_executor__models__registry___ModelRegistry__register_model
)
67
vllm_mlu/model_executor/models/utils.py
Normal file
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import json
import os

import torch

import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.forward_context import get_forward_context


def set_attn_compute_dtype_v1(attn_metadata, dtype: torch.dtype):
    '''
    Set the attention compute_dtype for v1.
    '''
    if isinstance(attn_metadata, dict):
        for _, metadata in attn_metadata.items():
            metadata.compute_dtype = dtype
    else:
        attn_metadata.compute_dtype = dtype


def set_attn_compute_dtype(dtype: torch.dtype):
    '''
    Set the attention compute_dtype.

    TODO: FA may standardize on half-precision computation in the future;
    set_attn_compute_dtype might then be deprecated and removed.
    '''
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    set_attn_compute_dtype_v1(attn_metadata, dtype)


def is_tie_word_embeddings(
    model_config: ModelConfig,
    org_tie_word_embeddings: bool
) -> bool:
    '''
    The vLLM language model config of a multimodal model may carry a wrong
    tie_word_embeddings value, for example InternVL3.5-38B, InternVL3.5-30B-A3B, etc.

    This function is a workaround.
    '''
    from vllm.lora.utils import get_adapter_absolute_path

    if not model_config.is_multimodal_model:
        return org_tie_word_embeddings

    model_path = get_adapter_absolute_path(model_config.model)
    config_path = os.path.join(model_path, "config.json")
    if not os.path.exists(config_path):
        return org_tie_word_embeddings

    tie_word_embeddings = org_tie_word_embeddings
    with open(config_path) as f:
        config = json.load(f)
        # First, check whether the tie_word_embeddings flag is in the overall config.
        if config.get("tie_word_embeddings") is not None:
            tie_word_embeddings = config["tie_word_embeddings"]
        # Then, check whether the tie_word_embeddings flag is in the language model config.
        if (config.get("llm_config") is not None
                and config["llm_config"].get("tie_word_embeddings") is not None):
            tie_word_embeddings = config["llm_config"]["tie_word_embeddings"]

    return tie_word_embeddings
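
# NOTE (illustrative, not part of the original commit): the workaround above
# covers config.json layouts like the abridged, hypothetical one below, where
# the top-level flag is absent or wrong but the nested language-model config
# carries the real value:
#
#     {
#         "model_type": "internvl_chat",
#         "llm_config": {
#             "tie_word_embeddings": false
#         }
#     }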
48
vllm_mlu/model_executor/parameter.py
Normal file
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Callable, Any

import torch

from vllm.model_executor.parameter import BasevLLMParameter
from vllm.distributed import (
    get_parallel_rank_with_group,
    get_parallel_world_size_with_group,
)

from vllm_mlu.mlu_hijack_utils import MluHijackObject


vllm__model_executor__parameter__BasevLLMParameter____init__org = BasevLLMParameter.__init__


def vllm__model_executor__parameter__BasevLLMParameter____init__(
    self,
    data: torch.Tensor,
    weight_loader: Callable,
    tp_group: Any = None
):
    vllm__model_executor__parameter__BasevLLMParameter____init__org(
        self, data, weight_loader
    )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add self.tp_group, world_size and tp_rank
    '''
    if tp_group is not None:
        self.tp_group = tp_group
        self.tp_world_size = get_parallel_world_size_with_group(self.tp_group)
        self.tp_rank = get_parallel_rank_with_group(self.tp_group)
    '''
    =================
    End of MLU Hijack
    =================
    '''


MluHijackObject.apply_hijack(BasevLLMParameter,
                             BasevLLMParameter.__init__,
                             vllm__model_executor__parameter__BasevLLMParameter____init__)
33
vllm_mlu/model_executor/warmup/kernel_warmup.py
Normal file
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
Warmup kernels used during model execution.
This is useful specifically for JIT'ed kernels as we don't want JIT'ing to
happen during model execution.
"""

from typing import TYPE_CHECKING

from vllm.logger import init_logger

if TYPE_CHECKING:
    from vllm.v1.worker.gpu_worker import Worker

logger = init_logger(__name__)


def kernel_warmup(worker: "Worker"):
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: skip deep GEMM warmup, flashinfer autotune, and
            flashinfer attention warmup
    '''

    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    pass