[Model] Support DeepSeek-V4

This commit is contained in:
chenxb002
2026-04-24 09:50:34 +08:00
commit b9925203b8
172 changed files with 44780 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.model_executor.layers.activation import QuickGELU
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
def vllm__model_executor__activation__QuickGELU__forward_oot(self, x: torch.Tensor) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: implement forward_oot
'''
return mlu_ops.active(x, 'quick_gelu', False)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(QuickGELU,
QuickGELU.forward_oot,
vllm__model_executor__activation__QuickGELU__forward_oot)
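# Illustrative reference only (not part of the hijack): the MLU kernel call
# above is expected to match vLLM's native QuickGELU, i.e. x * sigmoid(1.702 * x).
# `_quick_gelu_reference` is a hypothetical helper added here purely to show the
# intended numerics on a plain torch tensor.
def _quick_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # Element-wise QuickGELU, as in vLLM's QuickGELU.forward_native.
    return x * torch.sigmoid(1.702 * x)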

View File

@@ -0,0 +1,277 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import math
from typing import Callable
from scipy.linalg import hadamard
import torch
from torch import nn
import torch.nn.functional as F
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
def hadamard_transform_ref(x, scale=1.0):
"""
x: (..., dim)
out: (..., dim)
"""
x_shape = x.shape
dim = x.shape[-1]
x = x.reshape(-1, dim)
log_dim = math.ceil(math.log2(dim))
dim_padded = 2 ** log_dim
if dim != dim_padded:
x = F.pad(x, (0, dim_padded - dim))
out = F.linear(
x,
torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device),
)
out = out * scale
return out[..., :dim].reshape(*x_shape)
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
assert x.dtype == torch.bfloat16
hidden_size = x.size(-1)
return hadamard_transform_ref(x, scale=hidden_size ** -0.5)
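# Illustrative check (hypothetical helper, not used by the model): for a
# power-of-two dim, the Sylvester Hadamard matrix H satisfies H @ H == dim * I,
# so applying the transform twice with scale dim ** -0.5 recovers the input up
# to floating-point error. This is why rotate_activation acts as an orthonormal
# rotation of the hidden dimension.
def _check_hadamard_involution(dim: int = 8) -> bool:
    x = torch.randn(2, dim, dtype=torch.float32)
    y = hadamard_transform_ref(x, scale=dim ** -0.5)
    x_back = hadamard_transform_ref(y, scale=dim ** -0.5)
    return torch.allclose(x, x_back, atol=1e-5)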
class Compressor(nn.Module):
def __init__(self,
vllm_config: VllmConfig,
rope,
compress_ratio: int = 4,
head_dim: int = 512,
rotate: bool = False,
prefix: str = "",
**kwargs,):
super().__init__()
config = vllm_config.model_config.hf_config
self.dim = config.dim
self.head_dim = head_dim
self.rope_head_dim = config.rope_head_dim
self.nope_head_dim = head_dim - config.rope_head_dim
self.compress_ratio = compress_ratio
self.overlap = compress_ratio == 4
self.rotate = rotate
coff = 1 + self.overlap
self.norm_eps = config.norm_eps
self.window_size = config.window_size
self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
# wkv and wgate in the checkpoint are stored in bf16, while the parameters here are stored in fp32 for convenience.
# The first half of the dimensions is used for overlapping compression and the second half for normal compression.
self.wkv = ReplicatedLinear(
self.dim,
coff * self.head_dim,
bias=False,
quant_config=None,
params_dtype = torch.float32,
prefix=f"{prefix}.wkv",
)
self.wgate = ReplicatedLinear(
self.dim,
coff * self.head_dim,
bias=False,
quant_config=None,
params_dtype = torch.float32,
prefix=f"{prefix}.wgate",
)
self.norm = RMSNorm(self.head_dim, self.norm_eps)
self.rotary_emb = rope
hf_config = vllm_config.model_config.hf_config
assert hasattr(hf_config, "cached_state_num"), \
f"cached_state_num is not set in hf_config"
cached_state_num = hf_config.cached_state_num
self.register_buffer(
"kv_state",
torch.zeros(cached_state_num, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"score_state",
torch.full(
(cached_state_num, coff * compress_ratio, coff * self.head_dim),
float("-inf"),
dtype=torch.float32,
),
persistent=False,
)
self.hadamard_matrix = torch.tensor(
hadamard(self.head_dim, dtype=float), dtype=torch.get_default_dtype(), device="mlu")
def overlap_transform(self, tensor: torch.Tensor, value=0):
# tensor: [b,s,r,2d]
b, s, _, _ = tensor.size()
ratio, d = self.compress_ratio, self.head_dim
new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
return new_tensor
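# Illustrative sketch (hypothetical standalone helper, not called by the model):
# the slot layout produced by overlap_transform on a tiny tensor with ratio=2
# and head_dim=1. Slots [ratio:] hold the "normal" half (last head_dim channels)
# of position s, slots [:ratio] hold the "overlap" half (first head_dim channels)
# of position s - 1, and the first position's overlap slots keep the fill value.
def _overlap_transform_sketch() -> torch.Tensor:
    b, s, ratio, d = 1, 2, 2, 1
    t = torch.arange(b * s * ratio * 2 * d, dtype=torch.float32).view(b, s, ratio, 2 * d)
    out = t.new_full((b, s, 2 * ratio, d), 0.0)
    out[:, :, ratio:] = t[:, :, :, d:]      # normal half of position s
    out[:, 1:, :ratio] = t[:, :-1, :, :d]   # overlap half of position s - 1
    return out  # shape [1, 2, 4, 1]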
def forward_decode(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
kv_cache: torch.Tensor,
window_offset: int,
compressor_slot_mapping: torch.Tensor,
):
x = x.float()
kv_pack, _ = self.wkv(x)
score_pack, _ = self.wgate(x)
mlu_ops.fused_compress_single_kv(
kv=kv_pack.unsqueeze(1), # (token, D) -> (B, S, D)
score=score_pack.unsqueeze(1), # (token, D) -> (B, S, D)
position=positions,
ape=self.ape,
kv_state=self.kv_state,
score_state=self.score_state,
gamma=self.norm.weight,
sin=self.rotary_emb.sin_,
cos=self.rotary_emb.cos_,
hadamard_matrix=self.hadamard_matrix,
slot_mapping=compressor_slot_mapping,
kv_cache=kv_cache,
kv_cache_scale=None,
eps=self.norm_eps,
overlap=self.overlap,
rotate=self.rotate,
state_idx=batch_to_kv_state,
)
# Return a placeholder here; the fused decode kernel above writes the compressed kv into kv_cache via compressor_slot_mapping.
return None
def forward(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
kv_cache: torch.Tensor,
window_offset: int,
compressor_slot_mapping: torch.Tensor,
):
common_metadata = get_common_metadata()
forward_func: Callable = (
self.forward_prefill if common_metadata.is_prefill_only
else self.forward_decode
)
return forward_func(
x,
positions,
attn_metadata,
batch_to_kv_state,
kv_cache,
window_offset,
compressor_slot_mapping,
)
def forward_prefill(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
kv_cache: torch.Tensor,
window_offset: int,
compressor_slot_mapping: torch.Tensor,
):
common_metadata = get_common_metadata()
seq_lens = common_metadata.seq_lens
query_start_loc = common_metadata.query_start_loc
query_lens = query_start_loc[1:] - query_start_loc[:-1]
ratio, overlap = self.compress_ratio, self.overlap
dtype = x.dtype
x = x.float()
kv_pack, _ = self.wkv(x)
score_pack, _ = self.wgate(x)
compress_lens = query_lens // self.compress_ratio
cu_compress_lens = torch.cat([
torch.tensor([0], dtype=compress_lens.dtype, device=compress_lens.device),
torch.cumsum(compress_lens, dim=0)],
)
compress_positions = []
for i in range(len(seq_lens)):
seqlen = (query_start_loc[i+1] - query_start_loc[i]).item()
remainder = seqlen % ratio
cutoff = seqlen - remainder
pos = positions[query_start_loc[i]: query_start_loc[i+1]]
positions_ = pos[:cutoff:ratio].contiguous()
compress_positions.append(positions_)
kv_positions = torch.cat(compress_positions, dim=0)
total_compress_len = cu_compress_lens[-1].item()
kv = torch.empty(
[total_compress_len, self.head_dim],
dtype=kv_pack.dtype,
device=kv_pack.device,
)
mlu_ops.fused_compress_multi_kv(
kv = kv_pack,
score = score_pack,
kv_state = self.kv_state,
score_state = self.score_state,
state_batch_idx = batch_to_kv_state,
cu_seqlens = query_start_loc,
ape = self.ape,
max_seqlen = common_metadata.max_query_len,
overlap = overlap,
compressed_kv = kv,
)
if kv.size(0) == 0:
return kv.unsqueeze(-2).to(dtype) # (compress_token_num, 1, head_size)
kv = self.norm(kv.to(dtype))
kv_rope = kv[..., -self.rope_head_dim:].unsqueeze(-2)
# Compressed cu_seqlens are used here, so rotary_emb cannot be called directly.
kv_rope = mlu_ops.rotary_embedding(
kv_rope,
self.rotary_emb.sin_,
self.rotary_emb.cos_,
kv_positions,
torch.tensor([0, kv_positions.size(0)], dtype=torch.int32, device=kv_positions.device), # cu_seqlens
True, # interleaved
True, # discrete
False,
common_metadata.max_query_len,
)
if self.rotate:
kv = rotate_activation(kv)
mlu_ops.reshape_paged_cache(
kv.unsqueeze(1),
None,
kv_cache,
None,
compressor_slot_mapping,
)
return kv.unsqueeze(-2) # (compress_token_num, 1, head_size)
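# Illustrative sketch (hypothetical helper, not part of the model): how the
# compressed positions are built in forward_prefill. Only full groups of
# `ratio` tokens are compressed, and each compressed token takes the position
# of the first token in its group; e.g. positions 0..9 with ratio=4 yield
# tensor([0, 4]) and the trailing remainder of 2 tokens is dropped.
def _compressed_positions_example(seqlen: int = 10, ratio: int = 4) -> torch.Tensor:
    positions = torch.arange(seqlen)
    cutoff = seqlen - seqlen % ratio
    return positions[:cutoff:ratio]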

View File

@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional
import torch
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_gather)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm_mlu.model_executor.models.dp_utils import (
tensor_model_parallel_all_gather_dp, DataParallelRuntimeParams)
class DPLogitsProcessor(LogitsProcessor):
"""DP LogitsProcessor."""
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
dp_params: Optional[DataParallelRuntimeParams] = None,
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
batch_sizes = None
if (lm_head.tp_group is not None
and dp_params is not None
and dp_params.logits_batch_split_list is not None):
batch_sizes = dp_params.logits_batch_split_list
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=batch_sizes,
rank=lm_head.tp_rank,
hidden_states=hidden_states,
group=lm_head.tp_group,
)
logits = lm_head.quant_method.apply(
lm_head, hidden_states, bias=embedding_bias)
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits, tp_group=lm_head.tp_group)
else:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits, tp_group=lm_head.tp_group)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[..., : self.org_vocab_size]
if batch_sizes is not None:
offset = sum(batch_sizes[:lm_head.tp_rank])
logits = logits[offset : offset + batch_sizes[lm_head.tp_rank]]
return logits
def forward(
self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
dp_params: Optional[DataParallelRuntimeParams] = None,
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
# Get the logits for the next tokens.
logits = self._get_logits(
hidden_states, lm_head, embedding_bias, dp_params)
if logits is not None:
if self.soft_cap is not None:
logits = logits / self.soft_cap
logits = torch.tanh(logits)
logits = logits * self.soft_cap
if self.scale != 1.0:
logits *= self.scale
return logits
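# Illustrative sketch (standalone, hypothetical tensors): the slice-back step
# used in _get_logits when `batch_sizes` is set. After the DP all-gather and
# the lm_head projection, each rank keeps only the rows that belong to its
# own tokens.
def _slice_logits_for_rank(logits: torch.Tensor, batch_sizes: list[int],
                           rank: int) -> torch.Tensor:
    offset = sum(batch_sizes[:rank])
    return logits[offset:offset + batch_sizes[rank]]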

View File

@@ -0,0 +1,219 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
method_has_implemented_embedding,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod,
VocabParallelEmbedding,
DEFAULT_VOCAB_PADDING_SIZE,
get_masked_input_and_mask,
pad_vocab_size,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed.communication_op import (
tensor_model_parallel_all_reduce,
)
from vllm.distributed import (
divide,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_logits_tp_group,
get_logits_tp_world_size,
get_logits_tp_rank,
)
from vllm_mlu.model_executor.models.dp_utils import (
DataParallelRuntimeParams,
tensor_model_parallel_all_gather_dp,
)
class DPVocabParallelEmbedding(VocabParallelEmbedding):
"""DP Embedding parallelized in the vocabulary dimension."""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
torch.nn.Module.__init__(self)
"""
=============================
Modify by vllm_mlu
=============================
@brief: add self.tp_group, tp_world_size and tp_rank to support other parallel schemes
"""
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_world_size = get_tensor_model_parallel_world_size()
self.tp_group = None
logits_tp_world_size = get_logits_tp_world_size()
if logits_tp_world_size != self.tp_world_size:
self.tp_group = get_logits_tp_group()
self.tp_world_size = logits_tp_world_size
self.tp_rank = get_logits_tp_rank()
# Keep the input dimensions.
tp_rank = self.tp_rank
self.tp_size = self.tp_world_size
"""
=================
End of MLU Hijack
=================
"""
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings,
self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
self.embedding_dim = embedding_dim
quant_method = None
if quant_config is not None:
quant_method = quant_config.get_quant_method(self, prefix=prefix)
if quant_method is None:
quant_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self) is VocabParallelEmbedding
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method))
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(quant_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
self.quant_method: QuantizeMethodBase = quant_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the vocabulary dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
self.tp_size)
assert (self.shard_indices.num_elements_padded ==
self.num_embeddings_per_partition)
self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index -
self.shard_indices.org_vocab_start_index)
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.quant_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
def forward(self, input_,
dp_params: Optional[DataParallelRuntimeParams] = None):
token_split_list = None
if (dp_params is not None
and self.tp_group is not None
and dp_params.emb_token_split_list is not None):
token_split_list = dp_params.emb_token_split_list
input_ = tensor_model_parallel_all_gather_dp(
group_num_tokens=token_split_list,
rank=self.tp_rank,
hidden_states=input_.reshape(-1, 1),
group=self.tp_group,
).reshape(-1)
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = get_masked_input_and_mask(
input_,
self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index,
)
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self, masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group)
if token_split_list is not None:
offset = sum(token_split_list[:self.tp_rank])
output = output[offset : offset + token_split_list[self.tp_rank]]
return output
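# Illustrative sketch (hypothetical helper, not used at runtime): the masking
# idea applied in forward() for a vocab-parallel shard covering ids
# [start, end). Out-of-shard ids are looked up at a dummy local index and then
# zeroed, so the later all-reduce leaves exactly one rank's contribution per
# token.
def _masked_lookup_sketch(weight_shard: torch.Tensor, ids: torch.Tensor,
                          start: int, end: int) -> torch.Tensor:
    in_shard = (ids >= start) & (ids < end)
    local_ids = torch.where(in_shard, ids - start, torch.zeros_like(ids))
    out = weight_shard[local_ids]
    return out * in_shard.unsqueeze(-1).to(out.dtype)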
class DPParallelLMHead(DPVocabParallelEmbedding):
"""DP Parallelized LM head.
NOTE: A copy of the ParallelLMHead class that only changes its parent
from VocabParallelEmbedding to DPVocabParallelEmbedding.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size, quant_config,
prefix)
self.quant_config = quant_config
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def tie_weights(self, embed_tokens: VocabParallelEmbedding):
"""Tie the weights with word embeddings."""
# GGUF quantized embed_tokens.
if self.quant_config and self.quant_config.get_name() == "gguf":
return embed_tokens
else:
self.weight = embed_tokens.weight
return self
def forward(self, input_):
del input_
raise RuntimeError("LMHead's weights should be used in the sampler.")

View File

@@ -0,0 +1,224 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import torch.nn.functional as F
from typing import Any
from vllm.distributed import (
get_parallel_world_size_with_group,
get_parallel_rank_with_group,
)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
ColumnParallelLinear,
RowParallelLinear
)
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import set_is_gated
logger = init_logger(__name__)
class FeedForward(torch.nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
up_proj_name: str,
is_gated: bool,
down_proj_name: str,
bias: bool,
quant_config: QuantizationConfig | None = None,
skip_bias_add: bool = False,
reduce_results: bool = True,
prefix: str = "",
tp_group: Any = None,
keep_full_weights: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.is_gated = is_gated
self.bias = bias
self.up_proj_name = up_proj_name
self.down_proj_name = down_proj_name
self.quant_config = quant_config
self.is_initialized = False
self.skip_bias_add = skip_bias_add
self.reduce_results = reduce_results
self.use_bt_ffn = True
set_is_gated(self.is_gated)
# modify tp_size, tp_rank and tp_group when data parallel is enabled
self.tp_size = get_parallel_world_size_with_group(tp_group)
self.tp_rank = get_parallel_rank_with_group(tp_group)
self.tp_group = tp_group
self.keep_full_weights = keep_full_weights
if self.keep_full_weights:
self.tp_size = 1
self.tp_rank = 0
self.tp_group = None
# up_proj with gate or not
if self.is_gated:
up_proj = MergedColumnParallelLinear(hidden_size,
[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.{up_proj_name}",
tp_group=self.tp_group,
keep_full_weights=keep_full_weights)
else:
up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=bias,
skip_bias_add=skip_bias_add,
quant_config=quant_config,
prefix=f"{prefix}.{up_proj_name}",
tp_group=self.tp_group,
keep_full_weights=keep_full_weights)
self.register_module(up_proj_name, up_proj)
# down_proj
down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=bias,
skip_bias_add=skip_bias_add,
reduce_results=reduce_results,
quant_config=quant_config,
prefix=f"{prefix}.{down_proj_name}",
tp_group=self.tp_group,
keep_full_weights=keep_full_weights)
self.register_module(down_proj_name, down_proj)
def prepare_weight(self):
if not self.is_initialized:
# alpha and beta are 1.0 and 0.0 respectively because no residual is needed for now
self.alpha = 1.0
self.beta = 0.0
# place it here to avoid the overhead of calling it in the forward pass
self.is_initialized = True
def _forward(self, hidden_states):
self.prepare_weight()
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
act_dict = {
"relu": F.relu,
"gelu": F.gelu,
"silu": F.silu,
}
fc1 = F.linear(hidden_states, up_proj.weight, bias=up_proj.bias)
if self.is_gated:
d = fc1.shape[-1] // 2
fc1 = act_dict[self.hidden_act](fc1[..., :d]) * fc1[..., d:]
else:
fc1 = act_dict[self.hidden_act](fc1)
fc2 = F.linear(fc1, down_proj.weight, bias=None)
fc2 = tensor_model_parallel_all_reduce(fc2)
if not self.skip_bias_add:
fc2 = fc2 + down_proj.bias if down_proj.bias is not None else fc2
return fc2
def forward_naive(
self,
hidden_states,
residual: torch.Tensor | None = None,
smooth_quant_scale: torch.Tensor | None = None
):
'''
used by quant_tools
'''
assert self.quant_config is None, "ffn naive forward doesn't support quantization"
assert smooth_quant_scale is None, "ffn naive forward doesn't support smooth_quant_scale"
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
residual_ = None if self.tp_rank > 0 else residual
fc1, bias = up_proj(hidden_states)
if bias is not None:
fc1 += bias
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
out, bias = down_proj(fc1, residual=residual_)
if self.skip_bias_add:
return out, bias
return out
def forward(
self,
hidden_states,
residual: torch.Tensor | None = None,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
output: torch.Tensor | None = None,
):
self.prepare_weight()
if self.use_bt_ffn is False:
return self.forward_naive(hidden_states, residual, None)
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
residual_ = None if self.tp_rank > 0 else residual
if (self.quant_config is None and not isinstance(up_proj, BaseLayerWithLoRA)
and not isinstance(down_proj, BaseLayerWithLoRA)):
# The matmul formula is the following:
# mul_out = alpha * (matmul(input, filter, transpose_b=True) + bias) + beta * residual
# output = active(mul_out)
# Notes: we cannot use the activation function inside matmul because it does not support gated operations;
# we might support it in tmo matmul in the future. (A plain-PyTorch reference sketch of this formula follows the class.)
up_proj_weight = up_proj.weight
down_proj_weight = down_proj.weight
if self.keep_full_weights and use_tp_weight:
up_proj_weight = up_proj.tp_weight
down_proj_weight = down_proj.tp_weight
fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj_weight, up_proj.bias,
None, 'none', self.alpha, self.beta)
act_out = mlu_ops.active(fc1.float(), self.hidden_act, self.is_gated).to(dtype=fc1.dtype)
beta = 0.0
if residual_ is not None:
beta = 1.0
residual_ = residual_.view(-1, residual_.shape[-1])
out_ = mlu_ops.matmul(act_out, down_proj_weight, None, residual_, 'none', self.alpha, beta)
# If a bias exists, it is added after the second matmul, following the original vLLM design.
if self.reduce_results:
out = tensor_model_parallel_all_reduce(out_, self.tp_group)
else:
out = out_
# do the bias add if needed
if not self.skip_bias_add:
out = out + down_proj.bias if down_proj.bias is not None else out
else:
return out, down_proj.bias
else:
fc1, bias = up_proj(hidden_states, smooth_quant_scale=smooth_quant_scale, use_tp_weight=use_tp_weight)
if bias is not None:
fc1 += bias
input_scale = None
if (self.quant_config is not None and self.quant_config.get_name() == "SmoothQuant" and
self.quant_config.input_quant_method == "per_token" and not self.quant_config.is_fp8):
down_proj.quant_method.skip_quant_input = True
down_proj_smooth = down_proj.smooth
if self.keep_full_weights and use_tp_weight:
assert down_proj.tp_smooth is not None, "tp_smooth is not initialized"
down_proj_smooth = down_proj.tp_smooth
fc1, input_scale = mlu_ops.per_token_smooth_quantize(
fc1, down_proj_smooth, None, None, act_mode=self.hidden_act, is_gated=self.is_gated)
else:
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
out, bias = down_proj(
fc1, residual=residual_, smooth_quant_scale=input_scale,
use_tp_weight=use_tp_weight, output=output)
if self.skip_bias_add:
return out, bias
return out
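# Illustrative reference (hypothetical helper, not used at runtime): plain-PyTorch
# version of the mlu_ops.matmul formula noted in forward(), assuming `weight` is
# laid out as [out_features, in_features] like a torch.nn.Linear weight.
def _matmul_reference(x: torch.Tensor, weight: torch.Tensor,
                      bias: torch.Tensor | None, residual: torch.Tensor | None,
                      alpha: float, beta: float) -> torch.Tensor:
    out = alpha * F.linear(x, weight, bias)   # alpha * (x @ weight.T + bias)
    if residual is not None:
        out = out + beta * residual           # + beta * residual
    return out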

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,935 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_moe_kernel_gptq_awq,
write_zeros_to_output,
get_default_config,
get_moe_wna16_block_config,
should_moe_wna16_use_cuda,
try_get_optimal_moe_config,
_get_config_quant_dtype,
)
from vllm.model_executor.layers.fused_moe.utils import (
activation_without_mul,
disable_inplace,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm_mlu.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
import vllm_mlu._mlu_ops as mlu_ops
logger = init_logger(__name__)
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
b_bias_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
stride_bbe, # bias expert stride
stride_bbn, # bias N stride
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
HAS_BIAS: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
'''
=============================
Modify by vllm_mlu
=============================
@brief: Split the program ID into two dimensions (pid_0 and pid_1)
'''
pid_0 = tl.program_id(axis=0)
pid_1 = tl.program_id(axis=1)
pid = pid_1 * tl.num_programs(axis=0) + pid_0
'''
==================
End of MLU Hijack
==================
'''
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(
c_ptr,
stride_cm,
stride_cn,
pid_n,
N,
offs_token,
token_mask,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
compute_type,
)
return
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
if use_int8_w8a16:
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
)
b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8 or use_int8_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
)
# channel-wise
elif per_channel_quant:
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
)
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
# tensor-wise
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
if HAS_BIAS:
# bias shape: [num_experts, N]
bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(
a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
if use_fp8_w8a8:
# acc used to enable fp8_fast_accum
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if HAS_BIAS:
accumulator = accumulator + bias[None, :]
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
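# Illustrative dense reference (eager PyTorch, no Triton, hypothetical helper):
# what two invocations of the kernel above plus the gated activation compute for
# the experts selected by the router, assuming a silu-gated MLP and ignoring
# quantization, expert_map and chunking. Shapes follow the docstring:
# hidden (M, K), w1 (E, N, K), w2 (E, K, N // 2).
def _fused_experts_reference(hidden: torch.Tensor, w1: torch.Tensor,
                             w2: torch.Tensor, topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(hidden.shape, dtype=torch.float32, device=hidden.device)
    for t in range(hidden.size(0)):
        for k in range(topk_ids.size(1)):
            e = int(topk_ids[t, k])
            gate_up = hidden[t].float() @ w1[e].float().t()      # first grouped GEMM
            d = gate_up.size(-1) // 2
            act = F.silu(gate_up[:d]) * gate_up[d:]              # silu-gated activation
            out[t] += topk_weights[t, k].float() * (act @ w2[e].float().t())  # second GEMM
    return out.to(hidden.dtype)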
def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None,
B_zp: torch.Tensor | None,
topk_weights: torch.Tensor | None,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: dict[str, Any],
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: list[int] | None = None,
B_bias: torch.Tensor | None = None,
) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
assert block_shape is None or triton.cdiv(
B.size(-2), block_shape[0]
) == B_scale.size(-2)
assert block_shape is None or triton.cdiv(
B.size(-1), block_shape[1]
) == B_scale.size(-1)
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
M = A.size(0)
num_tokens = M * top_k
EM = sorted_token_ids.size(0)
if A.size(0) < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
# We assume that the topk ids of each token are unique,
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
'''
=============================
Modify by vllm_mlu
=============================
@brief: Split the program ID into two dimensions (pid_0, pid_1)
'''
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']), triton.cdiv(
B.shape[1], META['BLOCK_SIZE_N']), )
assert not (use_int8_w8a16 or use_int4_w4a16)
'''
==================
End of MLU Hijack
==================
'''
HAS_BIAS = B_bias is not None
if (
(use_int8_w8a16 or use_int4_w4a16)
and block_shape is not None
and block_shape[1] > 0
):
assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
num_valid_tokens=num_tokens,
group_size=block_shape[1],
num_experts=B.size(0),
bit=4 if use_int4_w4a16 else 8,
)
config = config.copy()
config.update(
get_moe_wna16_block_config(
config=config,
use_moe_wna16_cuda=use_moe_wna16_cuda,
num_valid_tokens=num_tokens,
size_k=A.size(1),
size_n=B.size(1),
num_experts=B.size(0),
group_size=block_shape[1],
real_top_k=top_k,
block_size_m=config["BLOCK_SIZE_M"],
)
)
if use_moe_wna16_cuda:
bit = 4 if use_int4_w4a16 else 8
ops.moe_wna16_gemm(
A,
C,
B,
B_scale,
B_zp,
topk_weights if mul_routed_weight else None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
top_k,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
bit,
)
return
fused_moe_kernel_gptq_awq[grid](
A,
B,
C,
B_scale,
B_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.size(1),
A.size(1),
EM,
num_tokens,
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0),
B_scale.stride(2),
B_scale.stride(1),
B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0,
block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
**config,
)
else:
config = config.copy()
config["SPLIT_K"] = 1
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
if block_shape is not None:
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
fused_moe_kernel[grid](
A,
B,
C,
B_bias,
A_scale,
B_scale,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.size(1),
B.size(2),
EM,
num_tokens,
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
B_bias.stride(0) if B_bias is not None else 0,
B_bias.stride(1) if B_bias is not None else 0,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
HAS_BIAS=HAS_BIAS,
BLOCK_SIZE_K=BLOCK_SIZE_K,
**config,
)
def outplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
True,
activation,
apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
ocp_mx_scheme,
per_channel_quant,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1_scale,
a2_scale,
block_shape,
w1_bias,
w2_bias,
)
def outplace_fused_experts_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> None:
pass
direct_register_custom_op(
op_name="outplace_fused_experts_mlu",
op_func=outplace_fused_experts,
mutates_args=["hidden_states"],
fake_impl=outplace_fused_experts_fake,
dispatch_key="PrivateUse1",
tags=(
()
if is_torch_equal_or_newer("2.7.0")
else (torch.Tag.needs_fixed_stride_order,)
),
)
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
return torch.ops.vllm.outplace_fused_experts_mlu(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
RELU2_NO_MUL: str = activation_without_mul("relu2")
def fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
# Check constraints.
if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
elif ocp_mx_scheme is not None:
if ocp_mx_scheme in {
"w_mxfp4_a_mxfp4",
"w_mxfp4_a_mxfp6_e3m2",
"w_mxfp4_a_mxfp6_e2m3",
}:
# 16bit activation and fp4x2 packed weight
assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
elif ocp_mx_scheme in {
"w_mxfp6_e3m2_a_mxfp6_e3m2",
"w_mxfp6_e2m3_a_mxfp6_e2m3",
}:
assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
"hidden size mismatch"
)
else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
else:
assert hidden_states.size(1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
)
assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
num_tokens = hidden_states.size(0)
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
config_dtype = _get_config_dtype_str(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
ocp_mx_scheme=ocp_mx_scheme,
dtype=hidden_states.dtype,
)
# Note: for use_int8_w8a16 or use_int4_w4a16, the activations are
# quantized prior to calling fused_experts.
quant_dtype = _get_config_quant_dtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
ocp_mx_scheme=ocp_mx_scheme,
)
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.size(),
w2.size(),
top_k_num,
config_dtype,
block_shape=block_shape,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Only use the default config
'''
config = get_default_config(M, E, N, w1.shape[2], topk_ids.shape[1],
hidden_states.dtype, block_shape)
'''
==================
End of MLU Hijack
==================
'''
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
cache13 = torch.empty(
M * top_k_num * max(N, K),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
# This needs separate memory since it's used concurrently with cache1
intermediate_cache2 = torch.empty(
(M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
)
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
if inplace and not disable_inplace():
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
if ocp_mx_scheme is not None:
# TODO: On platforms for which `current_platform.supports_mx()` is True
# and for which we have a native OCP mx fused MOE kernel,
# this dequantization step should not be done.
if ocp_mx_scheme in {
OCP_MX_Scheme.w_mxfp4_a_mxfp4,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
}:
# Weight has to be dequantized for mxfp4 emulation.
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
w1_scale = None
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w1_scale = None
w2 = dequant_mxfp6(
w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
w1_scale = None
w2 = dequant_mxfp6(
w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
w2_scale = None
else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (
chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE, num_tokens),
)
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.size()
if tokens_in_chunk == 0:
break
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qcurr_hidden_states, a1q_scale = _fp8_quantize(
curr_hidden_states, a1_scale, block_shape)
else:
qcurr_hidden_states = curr_hidden_states
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
)
invoke_fused_moe_kernel(
qcurr_hidden_states,
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
curr_topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
apply_router_weight_on_input,
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w1_bias,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: apply the activation via mlu_ops
'''
intermediate_cache2 = mlu_ops.active(intermediate_cache1.view(-1, N),
act_mode=activation,
is_gated=True)
'''
==================
End of MLU Hijack
==================
'''
a2q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qintermediate_cache2, a2q_scale = _fp8_quantize(
intermediate_cache2, a2_scale, block_shape)
else:
qintermediate_cache2 = intermediate_cache2
invoke_fused_moe_kernel(
qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
curr_topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w2_bias,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace moe_sum with torch.sum
Reference Links: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py#L1513
'''
if topk_ids.shape[1] == 2:
torch.add(
intermediate_cache3[:, 0],
intermediate_cache3[:, 1],
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
).squeeze(dim=1)
elif topk_ids.shape[1] > 2:
torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
'''
==================
End of MLU Hijack
==================
'''
return out_hidden_states

View File

@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional, Callable
import torch
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
def vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
#TODO: support `routed_scaling_factor`
assert routed_scaling_factor == 1.0, (
f"routed_scaling_factor {routed_scaling_factor} is not supported for MLU."
)
use_fused_kernel = topk_group is None
if use_fused_kernel:
assert not enable_eplb, "EPLB is not supported by the MLU fused_moe kernel."
assert use_grouped_topk is False and num_expert_group is None and topk_group is None, \
"The params use_grouped_topk, num_expert_group and topk_group are not supported yet."
return mlu_ops.fused_moe(
x,
router_logits,
layer.w13_weight, layer.w2_weight,
None, None, # bias1, bias2
None, # residual
None, # input_smooth
None, # act_smooth
None, None, # w1_scale, w2_scale
top_k,
renormalize,
True, # gated
activation
)
else:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
if self.rocm_aiter_moe_enabled:
assert expert_map is None
return self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
MluHijackObject.apply_hijack(
UnquantizedFusedMoEMethod,
UnquantizedFusedMoEMethod.forward_oot,
vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot
)
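# Illustrative sketch (not the vLLM implementation): the basic routing that
# FusedMoE.select_experts reduces to in the non-fused branch above when
# scoring_func="softmax" and no grouped top-k or correction bias is used.
# `_softmax_topk_routing` is a hypothetical helper shown only for clarity.
def _softmax_topk_routing(router_logits: torch.Tensor, top_k: int,
                          renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
    scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids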

View File

@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv, round_up
'''
=============================
Modify by vllm_mlu
=============================
@brief: Implementation of moe_align_block_size_triton.
Note: this implementation has been removed from vLLM since the
CUDA implementation is more efficient.
'''
@triton.jit
def moe_align_block_size_stage1(
topk_ids_ptr,
tokens_cnts_ptr,
num_experts: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = pid * tokens_per_thread
off_c = (pid + 1) * num_experts
for i in range(tokens_per_thread):
if start_idx + i < numel:
idx = tl.load(topk_ids_ptr + start_idx + i)
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
@triton.jit
def moe_align_block_size_stage2(
tokens_cnts_ptr,
num_experts: tl.constexpr,
):
pid = tl.program_id(0)
last_cnt = 0
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
last_cnt = last_cnt + token_cnt
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
):
last_cumsum = 0
off_cnt = num_experts * num_experts
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
tl.store(cumsum_ptr + i, last_cumsum)
tl.store(total_tokens_post_pad_ptr, last_cumsum)
@triton.jit
def moe_align_block_size_stage4(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = tl.load(cumsum_ptr + pid)
end_idx = tl.load(cumsum_ptr + pid + 1)
for i in range(start_idx, end_idx, block_size):
tl.store(expert_ids_ptr + i // block_size, pid)
start_idx = pid * tokens_per_thread
off_t = pid * num_experts
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
numel)):
expert_id = tl.load(topk_ids_ptr + i)
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
numel = topk_ids.numel()
grid = (num_experts, )
tokens_cnts = torch.zeros((num_experts + 1, num_experts),
dtype=torch.int32,
device=topk_ids.device)
cumsum = torch.zeros((num_experts + 1, ),
dtype=torch.int32,
device=topk_ids.device)
tokens_per_thread = cdiv(numel, num_experts)
sorted_token_ids.fill_(numel)
expert_ids.zero_()
moe_align_block_size_stage1[grid](
topk_ids,
tokens_cnts,
num_experts,
numel,
tokens_per_thread,
)
moe_align_block_size_stage2[grid](
tokens_cnts,
num_experts,
)
moe_align_block_size_stage3[(1, )](
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
)
moe_align_block_size_stage4[grid](
topk_ids,
sorted_token_ids,
expert_ids,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
)
'''
==================
End of MLU Hijack
==================
'''
def moe_align_block_size(
topk_ids: torch.Tensor,
block_size: int,
num_experts: int,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
- expert_map: A tensor of shape [num_experts] that maps the expert index
from the global space to the local index space of the current
expert parallel shard. If the expert is not in the current expert
parallel shard, the mapping is set to -1.
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length
should be padded to a multiple of block_size.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids = torch.zeros((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Only use triton to implement moe_align_block_size
'''
moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
'''
==================
End of MLU Hijack
==================
'''
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
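# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the hijack): a tiny CPU reference that
# mirrors the docstring example above. The helper name below is hypothetical
# and not a vLLM API; it assumes the same padding convention as the Triton
# path (padding slots are filled with topk_ids.numel()) and does not reproduce
# the exact intra-expert token ordering of the kernels.
def _ref_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                              block_size: int):
    numel = topk_ids.numel()
    # Bucket flattened token indices by expert id.
    buckets: list[list[int]] = [[] for _ in range(num_experts)]
    for token_idx, expert in enumerate(topk_ids.flatten().tolist()):
        buckets[expert].append(token_idx)
    sorted_ids: list[int] = []
    expert_ids: list[int] = []
    for expert, bucket in enumerate(buckets):
        padded_len = cdiv(len(bucket), block_size) * block_size
        sorted_ids += bucket + [numel] * (padded_len - len(bucket))
        expert_ids += [expert] * (padded_len // block_size)
    return sorted_ids, expert_ids, len(sorted_ids)
# With the (0-indexed) variant of the docstring example,
#   topk_ids = torch.tensor([[1, 2, 3], [0, 1, 3], [0, 2, 3], [0, 1, 2]])
#   _ref_moe_align_block_size(topk_ids, num_experts=4, block_size=4)
# yields sorted ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12],
# matching the expected output described in the docstring.
# ---------------------------------------------------------------------------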

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from math import prod
from typing import List, Optional, Tuple
import torch
from vllm.utils.math_utils import cdiv
def _fp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
block_shape: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.
"""
from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
assert block_shape is not None
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
return A, A_scale
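# ---------------------------------------------------------------------------
# Illustrative sketch only: a minimal pure-PyTorch reference of per-token-group
# fp8 quantization, assuming the common convention of one scale per group of
# block_k channels computed as amax / fp8_max. The real path goes through
# vllm_mlu's per_token_group_quant_fp8; this helper (a hypothetical name) is
# only meant to show the expected shapes, and it assumes a PyTorch build that
# provides torch.float8_e4m3fn.
def _ref_per_token_group_quant_fp8(
    A: torch.Tensor, block_k: int, eps: float = 1e-10
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert A.shape[-1] % block_k == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # One scale per token per group of block_k channels, derived from the
    # group-wise absolute maximum.
    groups = A.float().reshape(*A.shape[:-1], A.shape[-1] // block_k, block_k)
    scale = groups.abs().amax(dim=-1).clamp(min=eps) / fp8_max
    q = (groups / scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
    return q.to(torch.float8_e4m3fn).reshape(A.shape), scale
# ---------------------------------------------------------------------------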

View File

@@ -0,0 +1,278 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_world_size,
get_tp_group
)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.compressor import (
Compressor,
rotate_activation,
)
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
logger = init_logger(__name__)
class Indexer(torch.nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
rope,
compress_ratio: int = 4,
prefix: str = "",
**kwargs,
):
super().__init__()
config = vllm_config.model_config.hf_config
self.dim = config.dim
self.n_heads = config.index_n_heads
self.tp_size = get_tensor_model_parallel_world_size()
self.n_local_heads = config.index_n_heads // self.tp_size
self.head_dim = config.index_head_dim
self.rope_head_dim = config.rope_head_dim
self.index_topk = config.index_topk
self.q_lora_rank = config.q_lora_rank
self.window_size = config.window_size
self.block_size = vllm_config.cache_config.block_size
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=None,
prefix=f"{prefix}.wq_b",
)
self.weights_proj = ReplicatedLinear(
self.dim,
self.n_heads,
bias=False,
quant_config=None,
params_dtype=torch.bfloat16,
prefix=f"{prefix}.weights_proj",
)
self.softmax_scale = self.head_dim ** -0.5
self.merged_softmax_scale = (self.head_dim ** -0.5) * (self.n_heads ** -0.5)
self.compress_ratio = compress_ratio
self.max_model_len = vllm_config.model_config.max_model_len
self.rotary_emb = rope
self.tp_group = get_tp_group()
self.compressor = Compressor(vllm_config, self.rotary_emb, compress_ratio, self.head_dim, True, f"{prefix}.compressor")
self.freqs_cis = None
def forward_prefill(
self,
q: torch.Tensor,
k_cache: torch.Tensor,
weights: torch.Tensor,
attn_metadata: AttentionMetadata,
k_full: torch.Tensor,
context_lens: torch.Tensor,
):
assert attn_metadata.prefill.chunked_context is None, \
f"Prefill chunked context is not supported."
query_start_loc = attn_metadata.prefill.query_start_loc
cu_seq_q_lens = query_start_loc
cu_seq_k_lens = torch.zeros(
context_lens.size(0) + 1, dtype=torch.int32, device=q.device,
)
torch.cumsum(context_lens, dim=0, out=cu_seq_k_lens[1:])
seq_lens = torch.diff(cu_seq_k_lens)
batch_size = seq_lens.shape[0]
new_block_tables = torch.empty(
[attn_metadata.num_prefill_tokens, self.index_topk],
dtype=torch.int32,
device=q.device,
)
new_context_lens = torch.empty(
[attn_metadata.num_prefill_tokens],
dtype=torch.int32,
device=q.device,
)
q_seq_lens = cu_seq_q_lens[1:]-cu_seq_q_lens[:-1]
max_seq_len = q_seq_lens.max().item()
batch_size = q_seq_lens.size(0)
max_compressed_kv_len = max_seq_len // self.compress_ratio
kv_cache_block_table = torch.zeros([batch_size, max_compressed_kv_len], dtype=torch.int32, device=q.device)
# The layout of linear kv is as follows:
# | bs0_origin_kv | bs1_origin_kv | bs0_compressed_kv | bs1_compressed_kv |
for i in range(batch_size):
start = cu_seq_k_lens[i].item()
kv_cache_block_table[i] = torch.arange(
start, start + max_compressed_kv_len,
dtype=torch.int32,
device=q.device,
)
# offset total origin_kv len
kv_cache_block_table = kv_cache_block_table + cu_seq_q_lens[-1]
# query: (tokens, index_head, index_head_dim)
# k_full: (tokens, index_head_dim)
# weights: (tokens, index_head, 1)
mlu_ops.masked_indexer_select_paged_kv_prefill(
query=q,
key_value=k_full,
weights=weights.unsqueeze(-1),
kv_cache_block_table=kv_cache_block_table,
cu_seq_q_lens=cu_seq_q_lens,
cu_seq_k_lens=cu_seq_k_lens,
index_topk=self.index_topk,
kv_cache_block_size=self.block_size,
softmax_scale=self.merged_softmax_scale,
q_scale=None,
k_scale_cache=None,
sparse_block_table=new_block_tables,
sparse_context_lens=new_context_lens,
compress_ratio=self.compress_ratio,
kv_cache_block_table_offset=None,
)
return new_block_tables, new_context_lens
def forward_decode(
self,
q: torch.Tensor,
x: torch.Tensor,
k_cache: torch.Tensor,
weights: torch.Tensor,
attn_metadata: AttentionMetadata,
):
block_table = attn_metadata.decode.block_table
batch_size = block_table.shape[0]
seq_len = x.shape[0] // batch_size
q = q.view(batch_size, seq_len, *q.shape[1:])
weights = weights.view(batch_size, seq_len, *weights.shape[1:])
seq_lens = attn_metadata.decode.seq_lens
k_block_table = block_table
new_block_tables = torch.empty(
[batch_size, seq_len, self.index_topk],
dtype=torch.int32,
device=block_table.device,
)
new_context_lens = torch.empty(
[attn_metadata.num_decode_tokens],
dtype=torch.int32,
device=block_table.device,
)
kv_cache_block_table_offset=torch.empty(
[attn_metadata.num_decode_tokens],
dtype=torch.int32,
device=block_table.device,
)
kv_cache_block_table_offset.fill_(self.window_size)
mlu_ops.masked_indexer_select_paged_kv_decode(
query=q,
k_cache=k_cache,
weights=weights.unsqueeze(-1), # (bsz, seq_q, head_num, 1)
kv_cache_block_table=block_table,
k_context_lens=seq_lens // self.compress_ratio,
k_cache_block_table=k_block_table,
index_topk=self.index_topk,
kv_cache_block_size=self.block_size,
softmax_scale=self.merged_softmax_scale,
q_scale=None,
k_scale_cache=None,
sparse_block_table=new_block_tables,
sparse_context_lens=new_context_lens,
compress_ratio=self.compress_ratio,
kv_cache_block_table_offset=kv_cache_block_table_offset,
)
# [batch, seq_q, index_topk] -> [batch, index_topk]
new_block_tables = new_block_tables.squeeze(1)
return new_block_tables, new_context_lens
def forward(self,
x: torch.Tensor,
qr: torch.Tensor,
positions: torch.Tensor,
offsets: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
indexer_kv_cache: torch.Tensor,
compressor_slot_mapping: torch.Tensor,
):
common_metadata = get_common_metadata()
query_start_loc = common_metadata.query_start_loc
query_lens = query_start_loc[1:] - query_start_loc[:-1]
rd = self.rope_head_dim
q = self.wq_b(qr)[0]
q = q.unflatten(-1, (self.n_heads, self.head_dim))
self.rotary_emb(positions, q[..., -rd:], None, only_prefill=False)
q_pack = rotate_activation(q)
weights_pack = self.weights_proj(x)[0] # (tokens, index_local_head)
num_decode_tokens = attn_metadata.num_decode_tokens
compressed_kv = self.compressor(
x,
positions,
attn_metadata,
batch_to_kv_state,
indexer_kv_cache,
0,
compressor_slot_mapping,
)
if attn_metadata.prefill:
assert compressed_kv is not None and compressed_kv.dim() == 3
compressed_kv = compressed_kv.squeeze(-2)
compressed_context_lens = query_lens // self.compress_ratio
prefill_q = q_pack[num_decode_tokens:, ...]
prefill_weights = weights_pack[num_decode_tokens:, ...]
prefill_block_tables, prefill_context_lens = self.forward_prefill(
prefill_q,
indexer_kv_cache,
prefill_weights,
attn_metadata,
compressed_kv,
compressed_context_lens,
)
if attn_metadata.decode:
decode_x = x[:num_decode_tokens, ...]
decode_q = q_pack[:num_decode_tokens, ...]
decode_weights = weights_pack[attn_metadata.num_prefills:]
decode_block_tables, decode_context_lens = self.forward_decode(
decode_q,
decode_x,
indexer_kv_cache,
decode_weights,
attn_metadata,
)
if attn_metadata.prefill and attn_metadata.decode:
new_block_tables = torch.cat([prefill_block_tables, decode_block_tables], dim=0)
new_context_lens = torch.cat([prefill_context_lens, decode_context_lens], dim=0)
elif attn_metadata.prefill:
new_block_tables = prefill_block_tables
new_context_lens = prefill_context_lens
else:
new_block_tables = decode_block_tables
new_context_lens = decode_context_lens
return new_block_tables, new_context_lens

View File

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
import torch
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.models.layer_utils import is_per_token_smoothquant
@CustomOp.register("quant_fusion_rms_norm")
class QuantFusionRMSNorm(RMSNorm):
def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
super().__init__(hidden_size, variance_epsilon)
assert not isinstance(
proj.quant_method, UnquantizedLinearMethod
), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
proj.quant_method.skip_quant_input = True
if dynamic_quant := is_per_token_smoothquant(proj.quant_method.quant_config):
quant_scale = proj.smooth.data
else:
quant_scale = proj.scale_to_int.data
self.dynamic_quant = dynamic_quant
self.quant_scale = torch.nn.Parameter(quant_scale)
def forward(
self, x: torch.Tensor, residual: torch.Tensor | None = None
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
return mlu_ops.fused_rms_norm(
x,
residual,
self.weight.data,
None,
None,
self.variance_epsilon,
False,
self.quant_scale.data,
self.dynamic_quant,
)
@CustomOp.register("quant_fusion_layer_norm")
class QuantFusionLayerNorm(torch.nn.LayerNorm, CustomOp):
def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
super().__init__(hidden_size, variance_epsilon)
assert not isinstance(
proj.quant_method, UnquantizedLinearMethod
), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
proj.quant_method.skip_quant_input = True
if dynamic_quant := is_per_token_smoothquant(proj.quant_method.quant_config):
quant_scale = proj.smooth.data
else:
quant_scale = proj.scale_to_int.data
self.dynamic_quant = dynamic_quant
self.quant_scale = torch.nn.Parameter(quant_scale)
def forward(
self, x: torch.Tensor, residual: torch.Tensor | None = None
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
bias = None if self.bias is None else self.bias.data
return mlu_ops.fused_layer_norm(
x,
residual,
self.weight.data,
bias,
None,
self.eps,
False,
self.quant_scale.data,
self.dynamic_quant,
)
def vllm__model_executor__layers__layernorm__RMSNorm__forward_oot(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
out: torch.Tensor | None = None,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
org_shape = x.shape
x = x.reshape(-1, self.weight.data.shape[0])
if out is not None:
out = out.view(-1, self.weight.data.shape[0])
if residual is not None:
residual = residual.view(-1, self.weight.data.shape[0])
x = mlu_ops.fused_rms_norm(
x,
residual,
self.weight.data,
None,
None,
self.variance_epsilon,
True,
out=out,
)
else:
x = mlu_ops.fused_rms_norm(
x,
residual,
self.weight.data,
None,
None,
self.variance_epsilon,
False,
out=out,
)
if out is not None:
return x
if residual is None:
assert isinstance(x, torch.Tensor)
return x.view(org_shape)
assert isinstance(x, tuple)
assert len(x) == 2
return x[0].view(org_shape), x[1].view(org_shape)
MluHijackObject.apply_hijack(
RMSNorm,
RMSNorm.forward_oot,
vllm__model_executor__layers__layernorm__RMSNorm__forward_oot,
)

View File

@@ -0,0 +1,693 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional, Any
import torch
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, split_tensor_along_last_dim,
get_parallel_rank_with_group, get_parallel_world_size_with_group,
get_tp_world_group, get_tp_world_world_size, get_tp_world_rank)
from vllm.distributed.communication_op import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.linear import (
WEIGHT_LOADER_V2_SUPPORTED, UnquantizedLinearMethod, LinearBase,
ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED.extend([
"GPTQMluLinearMethod",
"AWQMluLinearMethod"
])
vllm__module_executor__layers__linear__LinearBase____init__org = LinearBase.__init__
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org = MergedColumnParallelLinear.weight_loader
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org = RowParallelLinear.weight_loader
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual parameter.
@brief: dispatch unquantized_gemm to mlu ops.
'''
def vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
residual: torch.Tensor | None = None
) -> torch.Tensor:
beta = 0.0
if residual is not None:
beta = 1.0
residual = residual.view(-1, residual.shape[-1])
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
return mlu_ops.matmul(x.reshape(x.numel() // x.shape[-1], x.shape[-1]),
layer.weight,
bias, residual, 'none', 1.0, beta).view(res_shape)
'''
==================
End of MLU Hijack
==================
'''
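# ---------------------------------------------------------------------------
# Illustrative sketch only: the assumed semantics of the mlu_ops.matmul call
# above, expressed in plain PyTorch as F.linear plus an optional fused residual
# add (beta = 1.0 accumulates the residual, beta = 0.0 ignores it). This is a
# readability reference, not the MLU implementation.
def _ref_unquantized_apply(
    weight: torch.Tensor,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
    residual: torch.Tensor | None = None,
) -> torch.Tensor:
    out = torch.nn.functional.linear(x, weight, bias)
    if residual is not None:
        out = out + residual.view_as(out)
    return out
# ---------------------------------------------------------------------------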
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group and keep_full_weights parameters.
'''
def vllm__module_executor__layers__linear__LinearBase____init__(
self,
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
vllm__module_executor__layers__linear__LinearBase____init__org(
self=self,
input_size=input_size,
output_size=output_size,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add self.tp_group, world_size and tp_rank to support data parallelism and MoE expert parallelism
'''
self.tp_group = tp_group
self.tp_world_size = get_parallel_world_size_with_group(self.tp_group)
self.tp_size = self.tp_world_size
self.tp_rank = get_parallel_rank_with_group(self.tp_group)
self.keep_full_weights = keep_full_weights
if self.keep_full_weights or disable_tp:
self.tp_group = None
self.tp_world_size = 1
self.tp_size = self.tp_world_size
self.tp_rank = 0
self.tp_world_size_org = get_tp_world_world_size()
self.tp_rank_org = get_tp_world_rank()
'''
=================
End of MLU Hijack
=================
'''
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group and keep_full_weights parameters.
'''
def vllm__module_executor__layers__linear__ColumnParallelLinear____init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
output_sizes: list[int] | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
super(ColumnParallelLinear, self).__init__(
input_size,
output_size,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
tp_group=tp_group,
keep_full_weights=keep_full_weights,
return_bias=return_bias,
disable_tp=disable_tp,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: self.tp_size and self.tp_rank have been initialized in LinearBase.__init__
'''
# Divide the weight matrix along the last dimension.
# self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
# self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
'''
=================
End of MLU Hijack
=================
'''
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size) for output_size in self.output_sizes
]
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group in create_weights
'''
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
tp_group=self.tp_group,
)
'''
=================
End of MLU Hijack
=================
'''
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, dtype=params_dtype)
)
set_weight_attrs(
self.bias,
{
"output_dim": 0,
"weight_loader": self.weight_loader,
},
)
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add smooth_quant_scale and use_tp_weight parameters.
'''
def vllm__module_executor__layers__linear__ColumnParallelLinear__forward(
self,
input_,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add input_scale and use_tp_weight parameters.
'''
kwargs = {'bias': bias}
if use_tp_weight:
kwargs['use_tp_weight'] = use_tp_weight
if smooth_quant_scale is not None:
kwargs['input_scale'] = smooth_quant_scale
output_parallel = self.quant_method.apply(self, input_, **kwargs)
'''
==================
End of MLU Hijack
==================
'''
if self.gather_output and self.tp_size > 1:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group param to tensor_model_parallel_all_gather
'''
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel, dim=-1, tp_group=self.tp_group)
'''
=================
End of MLU Hijack
=================
'''
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group and keep_full_weights parameters.
'''
def vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
self.output_sizes = output_sizes
'''
=============================
Modify by vllm_mlu
=============================
@brief: check output_sizes after init, once self.tp_world_size is available
@brief: add keep_full_weights for the dp-parallelized shared expert
'''
super(MergedColumnParallelLinear, self).__init__(
input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
output_sizes=self.output_sizes,
prefix=prefix,
tp_group=tp_group,
keep_full_weights=keep_full_weights,
return_bias=return_bias,
disable_tp=disable_tp,
)
assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
if self.keep_full_weights:
tp_size = self.tp_world_size_org
if isinstance(self.quant_method, UnquantizedLinearMethod):
out_dim, in_dim = self.weight.shape
out_dim_tp = divide(out_dim, tp_size)
self.tp_weight = Parameter(
self.weight.data.new_empty((out_dim_tp, in_dim)),
requires_grad=False,
)
elif (isinstance(self.quant_method, SmoothQuantLinearMethod)
and quant_config.input_quant_method == "per_token"):
out_dim, in_dim = self.qweight.shape
out_dim_tp = divide(out_dim, tp_size)
self.tp_qweight = Parameter(
self.qweight.data.new_empty((out_dim_tp, in_dim)),
requires_grad=False,
)
self.tp_per_channel_scale = Parameter(
self.per_channel_scale.data.new_empty((out_dim_tp)),
requires_grad=False,
)
else:
raise TypeError(f"quant method is expected to be unquantized or smoothquant per-token")
'''
=================
End of MLU Hijack
=================
'''
'''
=================
End of MLU Hijack
=================
'''
def vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: int | None = None,
):
loaded_weight_orig = loaded_weight
output_dim = getattr(param, "output_dim", None)
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org(
self=self,
param=param,
loaded_weight=loaded_weight,
loaded_shard_id=loaded_shard_id,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add keep_full_weights for the dp-parallelized shared expert
'''
# load into tp weight
if self.keep_full_weights:
tp_size = self.tp_world_size_org
tp_rank = self.tp_rank_org
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
start_idx = tp_rank * shard_size
if isinstance(self.quant_method, UnquantizedLinearMethod):
tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size)
tp_weight_shard = self.tp_weight.narrow(output_dim, shard_offset, shard_size)
tp_weight_shard.copy_(tp_weight)
elif isinstance(self.quant_method, SmoothQuantLinearMethod):
if output_dim is None:
return
tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size)
if loaded_weight_orig.ndim == 1:
tp_weight_shard = self.tp_per_channel_scale.narrow(output_dim, shard_offset, shard_size)
elif loaded_weight_orig.ndim == 2:
tp_weight_shard = self.tp_qweight.narrow(output_dim, shard_offset, shard_size)
else:
raise ValueError("only support rank 1 and 2 when using tp_weight")
tp_weight_shard.copy_(tp_weight)
else:
raise TypeError(f"quant method is expected to be either unquantized or smoothquant")
'''
=================
End of MLU Hijack
=================
'''
def vllm__module_executor__layers__linear__RowParallelLinear____init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
super(RowParallelLinear, self).__init__(
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
tp_group=tp_group,
keep_full_weights=keep_full_weights,
return_bias=return_bias,
disable_tp=disable_tp,
)
# Divide the weight matrix along the last dimension
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group in create_weights
'''
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
tp_group=self.tp_group,
)
'''
=================
End of MLU Hijack
=================
'''
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(
self.bias,
{
"output_dim": 0,
"weight_loader": self.weight_loader,
},
)
else:
self.register_parameter("bias", None)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add keep_full_weights for the dp-parallelized shared expert
'''
if self.keep_full_weights:
tp_size = self.tp_world_size_org
if isinstance(self.quant_method, UnquantizedLinearMethod):
out_dim, in_dim = self.weight.data.shape
in_dim_tp = divide(in_dim, tp_size)
self.tp_weight = Parameter(self.weight.data.new_empty((out_dim, in_dim_tp)),
requires_grad=False)
elif (isinstance(self.quant_method, SmoothQuantLinearMethod)
and quant_config.input_quant_method == "per_token"):
out_dim, in_dim = self.qweight.data.shape
in_dim_tp = divide(in_dim, tp_size)
self.tp_qweight = Parameter(self.qweight.data.new_empty((out_dim, in_dim_tp)),
requires_grad=False)
if hasattr(self, "smooth"):
assert len(self.smooth.shape) == 1, "smooth should be a 1D tensor"
dim = self.smooth.shape[0]
dim_tp = divide(dim, tp_size)
self.tp_smooth = Parameter(self.smooth.data.new_empty((dim_tp)),
requires_grad=False)
else:
raise TypeError("quant method expected to be unquantized or smoothquant per-token")
'''
=================
End of MLU Hijack
=================
'''
self.update_param_tp_status()
def vllm__module_executor__layers__linear__RowParallelLinear__weight_loader(
self, param: Parameter, loaded_weight: torch.Tensor
):
input_dim = getattr(param, "input_dim", None)
loaded_weight_orig = loaded_weight
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org(
self=self,
param=param,
loaded_weight=loaded_weight,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add keep_full_weights for the dp-parallelized shared expert
'''
if self.keep_full_weights:
if input_dim is None:
return
tp_size = self.tp_world_size_org
tp_rank = self.tp_rank_org
shard_size = divide(loaded_weight_orig.shape[input_dim], tp_size)
start_idx = tp_rank * shard_size
if isinstance(self.quant_method, UnquantizedLinearMethod):
shard_view = self.weight.narrow(input_dim, start_idx, shard_size)
self.tp_weight.copy_(shard_view)
elif isinstance(self.quant_method, SmoothQuantLinearMethod):
if loaded_weight_orig.ndim == 1:
shard_view = self.smooth.narrow(input_dim, start_idx, shard_size)
self.tp_smooth.copy_(shard_view)
elif loaded_weight_orig.ndim == 2:
shard_view = self.qweight.narrow(input_dim, start_idx, shard_size)
self.tp_qweight.copy_(shard_view)
else:
raise ValueError("only rank 1 and 2 is supported for tp_weight")
else:
raise TypeError("quant method is expected to be UnquantizedLinearMethod and SmoothQuant")
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual, smooth_quant_scale, use_tp_weight and output parameters.
'''
def vllm__module_executor__layers__linear__RowParallelLinear__forward(
self,
input_,
residual: torch.Tensor | None = None,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
output: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.input_is_parallel:
input_parallel = input_
else:
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add additional matmul parameters.
'''
residual_ = None if self.tp_rank > 0 else residual
kwargs = {'bias': bias_, 'residual': residual_}
if use_tp_weight:
kwargs['use_tp_weight'] = use_tp_weight
if smooth_quant_scale is not None:
kwargs['input_scale'] = smooth_quant_scale
if output is not None:
kwargs['output'] = output
output_parallel = self.quant_method.apply(self, input_parallel, **kwargs)
'''
=================
End of MLU Hijack
=================
'''
if self.reduce_results and self.tp_size > 1:
'''
=============================
Modify by vllm_mlu
=============================
@brief: pass self.tp_group to tensor_model_parallel_all_reduce()
'''
output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group)
'''
=================
End of MLU Hijack
=================
'''
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
'''
=================
End of MLU Hijack
=================
'''
MluHijackObject.apply_hijack(UnquantizedLinearMethod,
UnquantizedLinearMethod.apply,
vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply)
MluHijackObject.apply_hijack(LinearBase,
LinearBase.__init__,
vllm__module_executor__layers__linear__LinearBase____init__)
MluHijackObject.apply_hijack(ColumnParallelLinear,
ColumnParallelLinear.__init__,
vllm__module_executor__layers__linear__ColumnParallelLinear____init__)
MluHijackObject.apply_hijack(ColumnParallelLinear,
ColumnParallelLinear.forward,
vllm__module_executor__layers__linear__ColumnParallelLinear__forward)
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
MergedColumnParallelLinear.__init__,
vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__)
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
MergedColumnParallelLinear.weight_loader,
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.__init__,
vllm__module_executor__layers__linear__RowParallelLinear____init__)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.weight_loader,
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.forward,
vllm__module_executor__layers__linear__RowParallelLinear__forward)

View File

@@ -0,0 +1,744 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""Inference-only MOE model."""
from typing import Optional, Any, List, Dict
import torch
from torch import nn
from vllm.distributed import (
divide,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.distributed.parallel_state import (
cnclep_dispatch, cnclep_combine)
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
class LongCatSparseMoeMlp(SparseMoeMlp):
"""
Sparse MoE MLP layer specific to the LongCat model.
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
up_proj_name: str,
is_gated: bool,
down_proj_name: str,
has_bias: bool,
skip_bias_add: bool = False,
renormalize: bool = False,
hidden_act: str = "silu",
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
is_use_fused_moe: bool = False,
expert_group: Optional[int] = 1,
topk_group: Optional[int] = 1,
scoring_func: str = "softmax",
topk_method: str = "",
routed_scaling_factor: float = 1.0,
tp_group: Any = None,
use_all2all: bool = False,
num_zero_experts: int = 0,
):
super().__init__(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
up_proj_name=up_proj_name,
is_gated=is_gated,
down_proj_name=down_proj_name,
has_bias=has_bias,
skip_bias_add=skip_bias_add,
renormalize=renormalize,
hidden_act=hidden_act,
params_dtype=params_dtype,
quant_config=quant_config,
is_use_fused_moe=is_use_fused_moe,
expert_group=expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
topk_method=topk_method,
routed_scaling_factor=routed_scaling_factor,
tp_group=tp_group,
use_all2all=use_all2all,
init_avg_moe=False,
)
self.num_zero_experts = num_zero_experts
self.total_experts_including_zero = self.num_total_experts + self.num_zero_experts
self.use_quant_all2all = use_all2all and quant_config is not None
self.zero_expert_size = divide(self.num_zero_experts, self.moe_ep_size)
self.start_zero_expert_id = (
self.num_total_experts + self.moe_ep_rank * ((self.num_zero_experts + self.moe_ep_size - 1) // self.moe_ep_size)
)
if VLLM_AVG_MOE_EN and not SparseMoeMlp.is_expert_avg:
n_tokens = SparseMoeMlp.max_batched_token * self.dp_size
expert_group = self.moe_ep_size
val = 1.0 / float(self.total_experts_including_zero)
SparseMoeMlp.reduce_weight = torch.full((n_tokens, top_k), val, device="mlu", dtype=torch.float32)
if VLLM_RANDOM_MOE_EN:
import numpy as np
# example deepseekv2: experts 160 topk 6
# avg list: 92, 8, 88, 45, 99, 9,... 118, 142, 116, 57, 104, 6,......
array = np.stack([np.random.permutation(self.total_experts_including_zero)[:top_k] for _ in range(n_tokens)])
table = torch.from_numpy(array.flatten()).to(device="mlu", dtype=torch.int32)
else:
# example deepseekv2: experts 160
# avg list: 0,20,40,60,80...120,140, 1,21,...121,141, 2...142, ...... 19,...159, 0,20,......
import math
batch_table = math.ceil(n_tokens * top_k / self.total_experts_including_zero) * self.total_experts_including_zero
hi_val = batch_table // self.total_experts_including_zero
table = (torch.arange(hi_val * num_experts, device="mlu", dtype=torch.int32) % num_experts).view(
hi_val, expert_group, num_experts // expert_group).transpose(1, 2)
if self.num_zero_experts > 0:
# LongCat model: for the avg-expert path, we choose eight non-zero experts and four
# zero experts for each token according to the paper.
assert num_experts == 512 and num_zero_experts == 256 and top_k == 12
assert num_zero_experts % expert_group == 0
non_zero_expert_num_per_token = 8
zero_expert_num_per_token = 4
zero_expert_table = torch.arange(
num_experts, num_experts + num_zero_experts, dtype=table.dtype, device=table.device).view(
expert_group, num_zero_experts // expert_group).transpose(0, 1).flatten()
non_zero_expert_table = table[0].flatten()
token_expert_list = []
for idx in range(0, num_experts // non_zero_expert_num_per_token):
token_expert_list.append(non_zero_expert_table[
idx * non_zero_expert_num_per_token:
idx * non_zero_expert_num_per_token + non_zero_expert_num_per_token])
token_expert_list.append(zero_expert_table[
idx * zero_expert_num_per_token:
idx * zero_expert_num_per_token + zero_expert_num_per_token])
avg_expert_table = torch.cat(token_expert_list)
table = avg_expert_table.repeat(hi_val)
SparseMoeMlp.expert_id = table.flatten()[:n_tokens * top_k].view(n_tokens, top_k)
SparseMoeMlp.is_expert_avg = True
def forward_experts_nofused_longcat(
self, hidden_states, total_num_experts, total_num_experts_per_rank,
topk_indices=None, topk_weights=None, residual_=None):
assert self.moe_ep_size == 1
assert not self.use_all2all
expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = mlu_ops.moe_gen_idx(
topk_indices.to(torch.int32), total_num_experts)
# If no expert is routed, expand_gather_idx and expand_scatter_idx are empty,
# while expand_token_count and expand_cusum_token_count contain only zeros,
# so this rank should simply return a zero-valued final_hidden_states.
if cusum_token_count[-1] == 0:
final_hidden_states = torch.zeros_like(hidden_states,
dtype=hidden_states.dtype,
device=hidden_states.device)
return final_hidden_states
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_gather_idx, cusum_token_count,
start_expert_id=self.start_expert_id,
expert_size=self.end_expert_id - self.start_expert_id)
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_gather_idx, cusum_token_count,
start_expert_id=self.start_zero_expert_id,
expert_size=self.zero_expert_size)
expand_output_list = []
expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id +
1] - cusum_token_count[self.start_expert_id]
for expert_idx, num_tokens_per_expert in enumerate(expand_token_count[:self.num_total_experts]):
if num_tokens_per_expert > 0:
expert_hidden_states = expand_hidden_states[
expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]]
if expert_idx < self.num_total_experts:
expert_output = self.experts[expert_idx](expert_hidden_states)
else:
expert_output = expert_hidden_states
expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output
expand_output_list.append(expert_output)
expand_output = torch.cat(expand_output_list, dim=0)
num_normal_tokens = cusum_token_count[self.num_total_experts]
expand_hidden_states[:num_normal_tokens] = expand_output
# reduce normal experts
final_hidden_states = mlu_ops.moe_combine_result(
expand_hidden_states, topk_weights, scatter_idx,
residual_, cusum_token_count, start_expert_id=self.start_expert_id,
expert_size=self.end_expert_id - self.start_expert_id, bias=None)
# reduce zero experts
if self.moe_ep_size > 1 or self.moe_tp_rank == 0:
final_hidden_states = mlu_ops.moe_combine_result(
expand_hidden_states_zero, topk_weights, scatter_idx,
final_hidden_states, cusum_token_count, start_expert_id=self.start_zero_expert_id,
expert_size=self.zero_expert_size, bias=None,
output=final_hidden_states)
return final_hidden_states
# No compute-communication overlap; for prototyping only, not in actual use.
# Subject to becoming stale.
def forward_all2all_int8_longcat(
self, hidden_states, total_num_experts, total_num_experts_per_rank,
topk_indices=None, topk_weights=None, residual_=None):
ori_input_shape = hidden_states.shape
dtype = hidden_states.dtype
self.pack_params()
self.pack_params_after_loading()
w1 = self.w13
w2 = self.w2
bias2 = self.b2
input_smooth = self.a13_scale_all_experts
act_smooth = self.a2_scale
w1_scale = self.w13_scale
w2_scale = self.w2_scale
act_mode = self.hidden_act
quant_input = None
max_m = hidden_states.shape[0]
reduce_weight = topk_weights
expert_id = topk_indices
expand_idx, combine_idx, token_count, cusum_token_count \
= mlu_ops.moe_gen_idx(expert_id, total_num_experts)
num_token_expand = hidden_states.shape[0] * self.top_k
dispatch_bytes = num_token_expand * self.dispatch_token_size
dispatch_send_token_tensor = (
self.dispatch_send_buffer[:dispatch_bytes]
.view(num_token_expand, self.dispatch_token_size)
)
quant_size = self.hidden_size
quant_input = dispatch_send_token_tensor[:, : quant_size]
input_scale = dispatch_send_token_tensor[:, quant_size :].view(torch.float32)
quant_input, input_scale = mlu_ops.moe_quantize(
hidden_states, input_smooth, None, token_count[:self.num_total_experts],
expand_idx, None,
output=quant_input,
output_scale=input_scale)
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.num_total_experts,
expert_size=self.num_zero_experts)
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
token_count[:self.num_total_experts], self.moe_ep_size)
cnclep_dispatch(self.dispatch_token_size,
num_token_expand,
dispatch_send_layout,
token_count[:self.num_total_experts],
self.dispatch_recv_layout,
self.dispatch_recv_token_num)
recv_token_num = self.dispatch_recv_token_num.view(
self.moe_ep_size, self.num_experts_per_rank)
pad_num = self.max_num_tokens_per_rank
(
gather_by_expert_index,
gather_by_rank_index,
tokens_per_local_expert,
token_sum
) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num)
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
dispatch_recv_token_tensor = (
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
.view(self.max_num_tokens_recv, self.dispatch_token_size))
mlu_ops.gather_split(dispatch_recv_token_tensor,
gather_by_expert_index,
token_sum,
self.quant_input_recv,
self.input_scale_recv)
max_m = self.max_num_tokens_per_expert
gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, w1,
tokens_per_local_expert,
None, None, None, None,
self.input_scale_recv.view(torch.float32).flatten(),
w1_scale, dtype, max_m)
# continue reusing self.quant_input_recv and self.input_scale_recv
quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2]
input_scale_fp32 = self.input_scale_recv.view(torch.float32).flatten()[:gemm_out.shape[0]]
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None,
tokens_per_local_expert,
output=quant_input,
output_scale=input_scale_fp32,
act_mode=act_mode,
is_gated=self.is_gated)
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, w2,
tokens_per_local_expert,
None, None, None, None, input_scale, w2_scale, dtype, max_m)
combine_send_token_tensor = self.combine_send_buffer.view(self.max_num_tokens_recv, -1).view(hidden_states.dtype)
mlu_ops.gather_split(gemm_out,
gather_by_rank_index,
token_sum,
combine_send_token_tensor,
None)
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(self.dispatch_recv_token_num, self.moe_ep_size)
combine_recv_layout = self.dispatch_recv_layout
# combine
combine_args = dict(
token_byte=self.hidden_size * 2,
token_num=num_token_expand,
send_src_layout=combine_send_layout,
send_dst_layout=combine_recv_layout,
send_token=None,
recv_token=None)
cnclep_combine(**combine_args)
numel_recv = num_token_expand * self.hidden_size
recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv]
.view(num_token_expand, self.hidden_size))
residual_ = None
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
residual_, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts, bias=bias2, output=hidden_states)
assert self.moe_ep_size > 1
# zero expert reduce
output = mlu_ops.moe_combine_result(
expand_hidden_states_zero, reduce_weight, combine_idx,
output, cusum_token_count, self.num_total_experts,
self.num_zero_experts, output=hidden_states)
return output.view(ori_input_shape)
# No compute-communication overlap; for prototyping only, not in actual use.
# Subject to becoming stale.
def forward_all2all_bf16_longcat(
self, hidden_states, total_num_experts, total_num_experts_per_rank,
topk_indices=None, topk_weights=None, residual_=None):
is_fp8_quant = isinstance(self.quant_config, Fp8Config)
ori_input_shape = hidden_states.shape
dtype = hidden_states.dtype
self.pack_params()
self.pack_params_after_loading()
w1 = self.w13
w2 = self.w2
bias1 = self.b13
bias2 = self.b2
gated = self.is_gated
act_mode = self.hidden_act
max_m = hidden_states.shape[0]
reduce_weight = topk_weights
expert_id = topk_indices
# gen_idx
expand_idx, combine_idx, token_count, cusum_token_count = \
mlu_ops.moe_gen_idx(expert_id, total_num_experts)
num_token_expand = hidden_states.shape[0] * self.top_k
dispatch_bytes = num_token_expand * self.dispatch_token_size
dispatch_send_token_tensor = (
self.dispatch_send_buffer[:dispatch_bytes]
.view(num_token_expand, self.dispatch_token_size)
.view(hidden_states.dtype)
)
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts)
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.num_total_experts,
expert_size=self.num_zero_experts)
dispatch_send_token_tensor.copy_(expand_hidden_states)
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
token_count[:self.num_total_experts], self.moe_ep_size)
cnclep_dispatch(self.dispatch_token_size,
num_token_expand,
dispatch_send_layout,
token_count[:self.num_total_experts],
self.dispatch_recv_layout,
self.dispatch_recv_token_num,
use_quant_dispatch=False,
)
recv_token_num = self.dispatch_recv_token_num.view(
self.moe_ep_size, self.num_experts_per_rank)
pad_num = self.max_num_tokens_per_rank
(
gather_by_expert_index,
gather_by_rank_index,
tokens_per_local_expert,
token_sum
) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num)
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
dispatch_recv_token_tensor = (
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
.view(self.max_num_tokens_recv, self.dispatch_token_size)
.view(hidden_states.dtype)
)
self.quant_input_recv = self.quant_input_recv.view(hidden_states.dtype)
mlu_ops.gather_split(dispatch_recv_token_tensor,
gather_by_expert_index,
token_sum,
self.quant_input_recv)
max_m = self.max_num_tokens_per_expert
gemm_out = mlu_ops.group_gemm(
self.quant_input_recv, w1, tokens_per_local_expert,
None, None, None, None, max_m)
act_out = mlu_ops.moe_active(
gemm_out, act_mode, gated)
gemm_out = mlu_ops.group_gemm(
act_out, w2, tokens_per_local_expert,
None, None, None, None, max_m)
combine_send_token_tensor = self.combine_send_buffer.view(
self.max_num_tokens_recv, -1).view(hidden_states.dtype)
mlu_ops.gather_split(gemm_out,
gather_by_rank_index,
token_sum,
combine_send_token_tensor,
None)
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(
self.dispatch_recv_token_num, self.moe_ep_size)
combine_recv_layout = self.dispatch_recv_layout
combine_args = dict(
token_byte=self.hidden_size * 2,
token_num=num_token_expand,
send_src_layout=combine_send_layout,
send_dst_layout=combine_recv_layout,
send_token=None,
recv_token=None,
use_quant_dispatch=False,
)
cnclep_combine(**combine_args)
numel_recv = num_token_expand * self.hidden_size
recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv]
.view(num_token_expand, self.hidden_size))
residual_ = None
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
residual_, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts, bias=bias2, output=hidden_states)
# zero expert reduce
output = mlu_ops.moe_combine_result(
expand_hidden_states_zero, reduce_weight, combine_idx,
output, cusum_token_count, self.num_total_experts,
self.num_zero_experts, output=hidden_states)
return output.view(ori_input_shape)
def forward_before_dispatch(self, hidden_states: torch.Tensor,
topk_indices: torch.Tensor):
# The gate and softmax top-k are applied in the router for LongCat;
# other models can perform these operations here.
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(
topk_indices, self.total_experts_including_zero)
num_token_expand = hidden_states.shape[0] * self.top_k
dispatch_bytes = num_token_expand * self.dispatch_token_size
dispatch_send_token_tensor = (
self.dispatch_send_buffer[:dispatch_bytes]
.view(num_token_expand, self.dispatch_token_size)
)
if self.use_quant_all2all:
hidden_states_stride = self.hidden_size
quant_input = dispatch_send_token_tensor[:, : hidden_states_stride]
input_scale = dispatch_send_token_tensor[:, hidden_states_stride :].view(torch.float32)
# expand input + quantize
quant_input, input_scale = mlu_ops.moe_quantize(
hidden_states, self.a13_scale_all_experts, None,
token_count[:self.num_total_experts],
expand_idx, None,
output=quant_input,
output_scale=input_scale)
# expand input of zero-expert
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.num_total_experts,
expert_size=self.num_zero_experts)
else:
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts)
dispatch_send_token_tensor = dispatch_send_token_tensor.view(
hidden_states.dtype)
dispatch_send_token_tensor.copy_(expand_hidden_states)
del expand_hidden_states
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.num_total_experts,
expert_size=self.num_zero_experts)
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
token_count[:self.num_total_experts], self.moe_ep_size)
return combine_idx, token_count, cusum_token_count, dispatch_send_layout, expand_hidden_states_zero
def forward_dispatch(self, token_num: int, dispatch_send_layout: torch.Tensor,
token_count: torch.Tensor):
num_token_expand = token_num * self.top_k
cnclep_dispatch(self.dispatch_token_size,
num_token_expand,
dispatch_send_layout,
token_count[:self.num_total_experts],
self.dispatch_recv_layout,
self.dispatch_recv_token_num,
use_quant_dispatch=self.use_quant_all2all)
def forward_before_combine(self, hidden_states_dtype: torch.dtype):
recv_token_num = self.dispatch_recv_token_num.view(
self.moe_ep_size, self.num_experts_per_rank)
(
gather_by_expert_index,
gather_by_rank_index,
tokens_per_local_expert,
token_sum,
cusum_token_count
) = mlu_ops.moe_all2all_gen_gather_index(
recv_token_num, self.max_num_tokens_per_rank,
return_cusum_token_count=True)
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
dispatch_recv_token_tensor = (
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
.view(self.max_num_tokens_recv, self.dispatch_token_size))
max_m = self.max_num_tokens_per_expert
if self.use_quant_all2all:
mlu_ops.gather_split(dispatch_recv_token_tensor,
gather_by_expert_index,
token_sum,
self.quant_input_recv,
self.input_scale_recv)
# OPT: input_scale_recv_flatten can reuse self.input_scale_recv
input_scale_recv_flatten = self.input_scale_recv.view(torch.float32).flatten()
gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, self.w13,
tokens_per_local_expert,
None, None, None, None,
input_scale_recv_flatten,
self.w13_scale, hidden_states_dtype, max_m)
quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2]
input_scale_fp32 = input_scale_recv_flatten[:gemm_out.shape[0]]
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, self.a2_scale, None,
tokens_per_local_expert,
output=quant_input,
output_scale=input_scale_fp32,
act_mode=self.hidden_act,
is_gated=self.is_gated)
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, self.w2, tokens_per_local_expert,
None, None, None, None, input_scale, self.w2_scale,
hidden_states_dtype, max_m)
else:
dispatch_recv_token_tensor = dispatch_recv_token_tensor.view(hidden_states_dtype)
self.input_recv = self.input_recv.view(hidden_states_dtype)
mlu_ops.gather_split(dispatch_recv_token_tensor,
gather_by_expert_index,
token_sum,
self.input_recv)
gemm_out = mlu_ops.group_gemm(
self.input_recv, self.w13, tokens_per_local_expert,
None, None, None, None, max_m)
act_out = self.input_recv[:, :gemm_out.shape[-1] // 2]
act_out = mlu_ops.moe_active(
gemm_out, self.hidden_act, self.is_gated, output=act_out,
bias=None, cusum_token_count=cusum_token_count,
start_expert_id=0, expert_size=self.num_experts_per_rank)
gemm_out = mlu_ops.group_gemm(
act_out, self.w2, tokens_per_local_expert,
None, None, None, None, max_m)
combine_send_token_tensor = self.combine_send_buffer.view(
self.max_num_tokens_recv, -1).view(hidden_states_dtype)
mlu_ops.gather_split(gemm_out,
gather_by_rank_index,
token_sum,
combine_send_token_tensor,
None)
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(
self.dispatch_recv_token_num, self.moe_ep_size)
return combine_send_layout
def forward_combine(self, token_num: int, combine_send_layout: torch.Tensor):
num_token_expand = token_num * self.top_k
# combine_recv_layout (i.e. self.dispatch_recv_layout) was already computed in
# cnclep_dispatch, because dispatch and combine are inverse operations of each other.
cnclep_combine(token_byte=self.hidden_size * 2,
token_num=num_token_expand,
send_src_layout=combine_send_layout,
send_dst_layout=self.dispatch_recv_layout,
send_token=None,
recv_token=None,
use_quant_dispatch=self.use_quant_all2all)
def forward_after_combine(self, token_num: int,
reduce_weight: torch.Tensor,
combine_idx: torch.Tensor,
cusum_token_count: torch.Tensor,
expand_hidden_states_zero: torch.Tensor,
output_tensor_dtype: torch.dtype,
output_tensor: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None):
num_token_expand = token_num * self.top_k
numel_recv = num_token_expand * self.hidden_size
recv_token = (self.combine_recv_buffer.view(output_tensor_dtype)[:numel_recv]
.view(num_token_expand, self.hidden_size))
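# View the combine receive buffer as (token_num * top_k, hidden_size) before the
# weighted reduction over the top-k expert outputs.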
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
residual, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts, bias=self.b2, output=output_tensor)
output = mlu_ops.moe_combine_result(
expand_hidden_states_zero, reduce_weight, combine_idx,
output, cusum_token_count, self.num_total_experts,
self.num_zero_experts, output=output_tensor)
return output
# No compute-communication overlap; kept for prototyping only and not used in production.
# This path may become stale.
def forward_group_experts_longcat(
self, hidden_states, total_num_experts, total_num_experts_per_rank,
topk_indices=None, topk_weights=None, residual_=None,
expand_idx=None, combine_idx=None, token_count=None, cusum_token_count=None):
is_fp8_quant = isinstance(self.quant_config, Fp8Config)
ori_input_shape = hidden_states.shape
dtype = hidden_states.dtype
self.pack_params()
self.pack_params_after_loading()
w1 = self.w13
w2 = self.w2
bias1 = self.b13
bias2 = self.b2
input_smooth = self.a13_scale
act_smooth = self.a2_scale
w1_scale = self.w13_scale
w2_scale = self.w2_scale
gated = self.is_gated
act_mode = self.hidden_act
quant_input = None
start_expert_id = self.start_expert_id
expert_size = w1.size(0)
max_m = hidden_states.shape[0]
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
residual_ = residual_.view(-1, residual_.size(-1)) if residual_ is not None else None
# Check smooth-quant parameters: per-token smooth quant is used only when all four
# scale tensors are provided; a partial set is a configuration error.
per_token_sq = False
if not is_fp8_quant:
check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
if all(x is not None for x in check_list):
per_token_sq = True
if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must either "
"all be provided or all be omitted.")
expert_id = topk_indices
reduce_weight = topk_weights
# Generate expand/combine indices and per-expert token counts when expert ids are provided.
if expert_id is not None:
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(expert_id, total_num_experts)
# Expand the inputs per expert and quantize them when a per-token smooth-quant scheme is configured.
if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token':
raise NotImplementedError
elif per_token_sq:
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=start_expert_id,
expert_size=expert_size)
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.start_zero_expert_id,
expert_size=self.zero_expert_size)
quant_input, input_scale = mlu_ops.moe_quantize(
expand_hidden_states, input_smooth, None,
token_count[start_expert_id:start_expert_id+expert_size])
else:
expand_hidden_states = mlu_ops.moe_expand_input(hidden_states, expand_idx,
cusum_token_count, start_expert_id, expert_size)
expand_hidden_states_zero = mlu_ops.moe_expand_input(hidden_states, expand_idx,
cusum_token_count, self.start_zero_expert_id, self.zero_expert_size)
if (is_fp8_quant and self.quant_config.activation_quant_method == 'per_token') or per_token_sq:
gemm_out = mlu_ops.smooth_quant_group_gemm(
quant_input, w1,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, input_scale, w1_scale, dtype, max_m)
else:
gemm_out = mlu_ops.group_gemm(expand_hidden_states, w1,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, max_m)
# Activation (with optional bias) between the two grouped GEMMs; the smooth-quant path re-quantizes here.
if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token':
raise NotImplementedError
elif per_token_sq:
quant_input = quant_input[:, :gemm_out.shape[-1] // 2]
input_scale = input_scale[:gemm_out.shape[0]]
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None,
token_count[start_expert_id:start_expert_id+expert_size],
output=quant_input,
output_scale=input_scale,
act_mode=act_mode,
is_gated=self.is_gated)
if ((is_fp8_quant and self.quant_config.activation_quant_method == 'per_token')
or per_token_sq):
# Dropping the reference to the gemm_out tensor would make its memory eligible for
# deallocation (if this were the only reference), so it could be reused by the
# allocation of the next GEMM.
# del gemm_out
gemm_out = mlu_ops.smooth_quant_group_gemm(
quant_input, w2,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, input_scale, w2_scale, dtype, max_m,
output=expand_hidden_states)
else:
act_out = mlu_ops.moe_active(
gemm_out, act_mode, gated, gemm_out[:,:gemm_out.shape[-1]//2],
bias1, cusum_token_count, start_expert_id, expert_size)
gemm_out = mlu_ops.group_gemm(
act_out, w2, token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, max_m,
output=expand_hidden_states)
output = mlu_ops.moe_combine_result(
gemm_out, reduce_weight, combine_idx,
residual_, cusum_token_count, start_expert_id,
expert_size, bias2)
if self.moe_ep_size > 1 or self.moe_tp_rank == 0:
output = mlu_ops.moe_combine_result(
expand_hidden_states_zero, reduce_weight, combine_idx,
output, cusum_token_count, self.start_zero_expert_id,
self.zero_expert_size, bias2,
output=output)
return output.view(ori_input_shape)

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.model_executor.layers.quantization import (
QUANTIZATION_METHODS, register_quantization_config
)
MLU_QUANTIZATION_METHODS = [
"smoothquant",
"weightonly",
"awq_mlu",
"gptq_mlu",
]
def register_fake_mlu_quantization_methods():
for quant_method in MLU_QUANTIZATION_METHODS:
if quant_method not in QUANTIZATION_METHODS:
QUANTIZATION_METHODS.append(quant_method)
def remove_fake_mlu_quantization_methods():
for quant_method in MLU_QUANTIZATION_METHODS:
if quant_method in QUANTIZATION_METHODS:
QUANTIZATION_METHODS.remove(quant_method)
def register_real_mlu_quantization_methods():
remove_fake_mlu_quantization_methods()
from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig
from vllm_mlu.model_executor.layers.quantization.awq_mlu import AWQMluConfig
from vllm_mlu.model_executor.layers.quantization.gptq_mlu import GPTQMluConfig
register_quantization_config("weightonly")(WeightOnlyConfig)
register_quantization_config("smoothquant")(SmoothQuantConfig)
register_quantization_config("awq_mlu")(AWQMluConfig)
register_quantization_config("gptq_mlu")(GPTQMluConfig)

View File

@@ -0,0 +1,412 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional, Tuple
import torch
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.scalar_type import ScalarType, scalar_types
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
# Only GPTQ and AWQ are supported, on MLU 300-series and later devices, and only int4 and int8 precisions.
def query_mlu_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if has_zp:
# AWQ style, unsigned + zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
return [scalar_types.uint4b8, scalar_types.uint8b128]
def check_mlu_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
supported_types = query_mlu_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Mlu does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
return (False, f"Mlu does not support group_size = {group_size}. "
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True, None
# @register_quantization_config("awq_mlu")
class AWQMluConfig(QuantizationConfig):
"""Config class for AWQMlu.
Reference: https://arxiv.org/abs/2306.00978
"""
# num_bits -> type
TYPE_MAP = {
4: {
False: scalar_types.uint4b8,
True: scalar_types.uint4,
},
8: {
False: scalar_types.uint8b128,
True: scalar_types.uint8,
}
}
VERSION = ["gemm"]
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
lm_head_quantized: bool,
version: str = "gemm",
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
self.lm_head_quantized = lm_head_quantized
self.pack_factor = 32 // self.weight_bits
self.version = version
self.support_scale_zeros = False
if self.weight_bits not in [4, 8]:
raise ValueError(
"Currently, only 4/8-bit weight quantization is supported for "
f"AWQMlu, but got {self.weight_bits} bits.")
if self.version not in self.VERSION:
raise ValueError(
"Currently, only gemm, gemv version is supported for "
f"AWQMlu, but got verion:{self.version}.")
if self.version in ["gemm"]:
self.order_map = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]}
self.reverse_order_map = {4 : [0, 4, 1, 5, 2, 6, 3, 7], 8: [0, 2, 1, 3]}
else:
self.order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
self.reverse_order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
def __repr__(self) -> str:
return (f"AWQMluConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point}), "
f"lm_head_quantized={self.lm_head_quantized})")
@classmethod
def get_name(cls) -> str:
return "awq_mlu"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@staticmethod
def get_config_filenames() -> List[str]:
return ["quant_config.json", "quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQMluConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
version = cls.get_from_keys_or(config, ["version"],
default="gemm")
return cls(weight_bits, group_size, zero_point, lm_head_quantized, version)
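# Illustrative AutoAWQ-style quantization config accepted by from_config (key names
# taken from the lookups above; the values are only an example, not a recommendation):
#     {"quant_method": "awq", "w_bit": 4, "q_group_size": 128,
#      "zero_point": true, "version": "gemm", "lm_head": false}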
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQMluLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return AWQMluLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_awq_mlu_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "awq"
or user_quant == "awq_mlu")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "awq":
logger.info("Detected that the model can run with awq_mlu"
", however you specified quantization=awq explicitly,"
" so forcing awq. Use quantization=awq_mlu for"
" faster inference")
return None
@classmethod
def is_awq_mlu_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
has_zp = quant_config.get("zero_point", None)
version = quant_config.get("version", "gemm")
if quant_method != "awq":
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or has_zp is None):
return False
if num_bits not in cls.TYPE_MAP:
return False
if version not in cls.VERSION:
return False
supported, _ = check_mlu_supported(quant_type=cls.TYPE_MAP[num_bits][has_zp],
group_size=group_size,
has_zp=has_zp)
return supported
class AWQMluLinearMethod(LinearMethodBase):
"""Linear method for AWQMlu.
Args:
quant_config: The AWQMlu quantization config.
"""
def __init__(self, quant_config: AWQMluConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
weight_loader = extra_weight_attrs.get("weight_loader")
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
qzeros = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
scales = GroupQuantScaleParameter(data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader)
layer.register_parameter("qweight", qweight)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
packed_qweight, scale_zeros = self.extract_autoawq(layer)
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
layer.qzeros = None
layer.scales = None
else:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
if scale_zeros is not None:
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
else:
layer.qzeros = None
layer.scales = torch.nn.Parameter(layer.scales.data.transpose(0, 1).contiguous(), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.quant_config.zero_point and not self.quant_config.support_scale_zeros:
output = mlu_ops.matmul(x, layer.qweight, bias)
if residual is not None:
output = output + residual
else:
output = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
layer.qzeros,
bias,
residual,
"none",
self.quant_config.weight_bits)
return output
def extract_autoawq(self, layer: torch.nn.Module):
qweight = layer.qweight.data
qzeros = layer.qzeros.data
scales = layer.scales.data
bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
# Unpack the qweight and qzeros tensors
iweight, izeros = self.unpack_awq_int32_into_int8(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = self.reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
if izeros is not None:
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
scales = scales.repeat_interleave(group_size, dim=0)
if izeros is not None:
izeros = izeros.repeat_interleave(group_size, dim=0)
fweight = (iweight - izeros) * scales
else:
fweight = iweight * scales
# transpose [ci, co] -> [co, ci]
fweight = fweight.transpose(0, 1)
return fweight, None
if self.quant_config.zero_point and self.quant_config.support_scale_zeros and izeros is not None:
scale_zeros = izeros.to(scales.dtype) * -1 * scales
# transpose [ci, co] -> [co, ci]
scale_zeros = scale_zeros.transpose(0, 1)
else:
scale_zeros = None
# transpose [ci, co] -> [co, ci]
iweight = iweight.to(torch.int8).transpose(0, 1)
if bits == 4:
higher_bit_tensor = iweight[:, 1::2]
lower_bit_tensor = iweight[:, 0::2]
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
else:
packed_qweight = iweight
return packed_qweight, scale_zeros
def unpack_awq_int32_into_int8(self, qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qweight.device)
dtype = torch.int16 if bits == 8 else torch.int8
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(dtype)
iweights = iweights.view(iweights.shape[0], -1)
if not self.quant_config.zero_point or self.quant_config.support_scale_zeros:
iweights = torch.bitwise_and(iweights - 2**(bits - 1), (2 ** bits) - 1)
# unpacking columnwise
if qzeros is not None:
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(dtype)
izeros = izeros.view(izeros.shape[0], -1)
if not self.quant_config.zero_point:
izeros = torch.bitwise_and(izeros - 2**(bits - 1), (2 ** bits) - 1)
else:
izeros = None
return iweights, izeros
def reverse_awq_order(self, iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(iweights.shape[-1], dtype=torch.int32, device=iweights.device)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, self.quant_config.reverse_order_map[bits]]
reverse_order_tensor = reverse_order_tensor.view(-1)
rweights = iweights[:, reverse_order_tensor]
rzeros = izeros[:, reverse_order_tensor] if izeros is not None else None
return rweights, rzeros
def combine_low_bits(self, tensor_a, tensor_b):
"""
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
Args:
tensor_a (torch.Tensor): First tensor of type int8.
tensor_b (torch.Tensor): Second tensor of type int8.
Returns:
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
"""
# Ensure both inputs are int8.
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
raise ValueError("Both tensors must be of int8 type.")
# Extract the lower 4 bits of each tensor.
low_bits_a = torch.bitwise_and(tensor_a, 0x0F)  # keep the lower 4 bits of tensor_a
low_bits_b = torch.bitwise_and(tensor_b, 0x0F)  # keep the lower 4 bits of tensor_b
# Shift tensor_a's lower 4 bits into the high nibble.
shifted_low_bits_a = low_bits_a << 4
# Combine the two nibbles into a single byte.
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
return combined

View File

@@ -0,0 +1,753 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import functools
from functools import partial
import importlib.util
from typing import Any, Callable, Optional, Union
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from typing import Any, Dict, List, Optional, Callable
from vllm import envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.quantization.fp8 import (
get_flashinfer_moe_backend,
ACTIVATION_SCHEMES,
Fp8Config,
Fp8LinearMethod,
Fp8MoeBackend,
Fp8MoEMethod,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
validate_fp8_block_shape
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, cutlass_block_fp8_supported, cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale,
maybe_create_device_identity, Fp8LinearOp)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter, ChannelQuantScaleParameter,
ModelWeightParameter, PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
import vllm_mlu._mlu_ops as mlu_ops
logger = init_logger(__name__)
def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if (
current_platform.is_cuda()
and (
current_platform.is_device_capability(100)
or current_platform.is_device_capability(90)
)
and envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
):
backend = get_flashinfer_moe_backend()
if backend == FlashinferMoeBackend.TENSORRT_LLM:
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
return Fp8MoeBackend.FLASHINFER_TRTLLM
else:
if block_quant and current_platform.is_device_capability(100):
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency "
"instead."
)
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
return Fp8MoeBackend.FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: disable marlin for MLU backend.
'''
if current_platform.is_rocm() or current_platform.is_out_of_tree():
use_marlin = False
'''
==================
End of MLU Hijack
==================
'''
if use_marlin:
logger.info_once("Using Marlin backend for FP8 MoE")
return Fp8MoeBackend.MARLIN
# deepGEMM on supported platforms with block-quantized weights
if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
if not has_deep_gemm():
logger.warning_once("DeepGEMM backend requested but not available.")
elif is_deep_gemm_supported():
logger.info_once("Using DeepGEMM backend for FP8 MoE")
return Fp8MoeBackend.DEEPGEMM
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
if (
current_platform.is_cuda()
and current_platform.is_device_capability(100)
and block_quant
):
logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
# default to Triton
logger.info_once("Using Triton backend for FP8 MoE")
return Fp8MoeBackend.TRITON
Fp8Config____init____org = Fp8Config.__init__
def vllm__model_executor__layers__quantization__fp8__Fp8Config____init__(
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: list[str] | None = None,
weight_block_size: list[int] | None = None,
activation_quant_method: Optional[str] = None,
weight_quant_method: Optional[str] = None,
) -> None:
super(Fp8Config, self).__init__()
Fp8Config____init____org(
self,
is_checkpoint_fp8_serialized,
activation_scheme,
ignored_layers,
weight_block_size
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add class members activation_quant_method and weight_quant_method to
indicate the granularity of quantization.
'''
self.activation_quant_method = activation_quant_method
self.weight_quant_method = weight_quant_method
assert (self.weight_block_size or
(self.activation_quant_method == "per_token"
and self.weight_quant_method == "per_channel"
and self.activation_scheme == "dynamic")), (
"Only block-wise quantization, or dynamic per-token activation with "
"per-channel weight quantization, is supported for now.")
'''
==================
End of MLU Hijack
==================
'''
@classmethod
def vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config(
cls, config: Dict[str, Any]
) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = "fp8" in quant_method
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
if not ignored_layers:
ignored_layers = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add config members activation_quant_method and weight_quant_method to
indicate the granularity of quantization.
'''
activation_quant_method = cls.get_from_keys_or(config,
["activation_quant_method"],
'per_token')
weight_quant_method = cls.get_from_keys_or(config,
["weight_quant_method"],
None)
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=weight_block_size,
activation_quant_method=activation_quant_method,
weight_quant_method=weight_quant_method)
'''
==================
End of MLU Hijack
==================
'''
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group.
'''
tp_group = extra_weight_attrs.get("tp_group", None)
'''
==================
End of MLU Hijack
==================
'''
if self.block_quant:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
validate_fp8_block_shape(
layer,
input_size,
output_size,
input_size_per_partition,
output_partition_sizes,
self.weight_block_size,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group.
'''
# WEIGHT
if self.quant_config.is_checkpoint_fp8_serialized:
weight = create_fp8_weight_parameter(
output_size_per_partition, input_size_per_partition, weight_loader
)
else:
# For non-serialized checkpoints, use original dtype
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
tp_group=tp_group,
)
'''
==================
End of MLU Hijack
==================
'''
layer.register_parameter("weight", weight)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
if not self.block_quant:
'''
=============================
Modify by vllm_mlu
=============================
@brief: Support weight per channel quantization.
@brief: Add tp_group to enable custom split.
'''
if self.weight_per_channel:
scale = ChannelQuantScaleParameter(
data=torch.empty(sum(output_partition_sizes), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
tp_group=tp_group,
)
else:
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
layer.register_parameter("weight_scale", scale)
'''
==================
End of MLU Hijack
==================
'''
else:
assert not self.act_q_static
assert self.weight_block_size is not None
scale = create_fp8_scale_parameter(
BlockQuantScaleParameter,
output_partition_sizes,
input_size_per_partition,
self.weight_block_size,
weight_loader,
)
set_weight_attrs(scale, {"scale_type": "weight_scale"})
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
# INPUT ACTIVATION SCALE
if self.act_q_static:
scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__(
self,
quant_config: Fp8Config
):
self.quant_config = quant_config
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
if vllm_is_batch_invariant():
self.use_marlin = False
# AITER is only supported on ROCm and only for FP8_FNUZ
# and at the moment are MI300 series
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
self.use_deep_gemm = is_deep_gemm_supported()
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None
if self.block_quant:
# Marlin doesn't support block-wise fp8
self.use_marlin = False
self.act_q_static = self.quant_config.activation_scheme == "static"
if self.weight_block_size:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add config members activation_quant_method and weight_quant_method to
indicate the granularity of quantization.
'''
self.weight_per_channel = (self.quant_config.weight_quant_method == 'per_channel')
self.activation_per_token = (self.quant_config.activation_quant_method == 'per_token')
if self.weight_per_channel and self.activation_per_token:
self.use_marlin = False
'''
==================
End of MLU Hijack
==================
'''
if self.block_quant:
assert not self.act_q_static
assert self.weight_block_size is not None
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape,
)
Fp8LinearMethod__process_weights_after_loading__org = Fp8LinearMethod.process_weights_after_loading
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading(
self,
layer: Module,
) -> None:
'''
=============================
Modify by vllm_mlu
=============================
@brief: For dynamic activation and channel-wise weight quantization,
additional processing is not needed.
'''
if (self.quant_config.is_checkpoint_fp8_serialized
and self.weight_per_channel
and self.quant_config.activation_scheme == "dynamic"):
return
'''
==================
End of MLU Hijack
==================
'''
Fp8LinearMethod__process_weights_after_loading__org(self=self, layer=layer)
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert residual is None, "Fp8Linear residual is not supported yet."
# If batch-invariant mode is enabled, prefer the DeepGEMM FP8 path;
# fall back to BF16 dequantization when DeepGEMM is not supported.
if vllm_is_batch_invariant():
if self.block_quant:
assert self.weight_block_size is not None
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
else:
# per-tensor/channel: dequant to BF16 and run GEMM
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)
if weight_scale.numel() == 1:
# Per-tensor: simple scalar multiplication
weight_bf16 = weight_fp8 * weight_scale
else:
# Multiple scales (fused modules like QKV)
# Try to infer correct broadcasting
# weight is [K, N], scale could be [num_logical_weights]
# Need to figure out how to broadcast - for now just try
# direct multiplication
if (
weight_scale.dim() == 1
and weight_scale.shape[0] == weight_fp8.shape[0]
):
# Per-row scaling
weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
else:
# Fallback
weight_bf16 = weight_fp8 * weight_scale
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
if self.use_marlin:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
if self.block_quant:
assert self.weight_block_size is not None
from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear)
return apply_w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Use activation per token quantization based on quantization config.
'''
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
weight_per_channel=self.weight_per_channel,
activation_per_token=self.activation_per_token)
'''
==================
End of MLU Hijack
==================
'''
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__(
self,
quant_config: Fp8Config,
layer: torch.nn.Module
):
super(Fp8MoEMethod, self).__init__(layer.moe_config)
self.layer = layer
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
self.fp8_backend = get_fp8_moe_backend(self.block_quant)
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
if self.block_quant:
assert self.weight_block_size == [128, 128], (
f"Only support weight_block_size == [128, 128], "
f"got {self.weight_block_size}"
)
self.flashinfer_moe_fn = partial(
flashinfer_cutlass_moe_fp8,
moe=self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
self.allow_cutlass_block_scaled_grouped_gemm = (
self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: In mlu, always set self.use_marlin as False.
'''
self.use_marlin = False
'''
==================
End of MLU Hijack
==================
'''
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Use moe_softmax_topk and moe_sigmoid_topk of mlu_ops to implement FusedMoE.select_experts
'''
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
if scoring_func == "softmax":
topk_weights, topk_ids = mlu_ops.moe_softmax_topk(
router_logits,
top_k,
renormalize,
num_expert_group,
topk_group,
route_scale=routed_scaling_factor,
)
elif scoring_func == "sigmoid":
topk_weights, topk_ids = mlu_ops.moe_sigmoid_topk(
router_logits,
top_k,
renormalize,
num_expert_group,
topk_group,
routed_scaling_factor,
e_score_correction_bias,
)
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
# Generate expand/combine indices and per-expert token counts from the selected expert ids.
ori_input_shape = x.shape
x = x.reshape(-1, x.size(-1))
router_logits = router_logits.reshape(-1, router_logits.size(-1))
expert_num = router_logits.size(-1)
tokens_num = x.size(0)
expert_size = layer.w13_weight.size(0)
expand_idx, combine_idx, token_count, cumsum_token_count = mlu_ops.moe_gen_idx(
topk_ids, expert_num
)
expand_hidden_states = mlu_ops.moe_expand_input(
x, expand_idx, cumsum_token_count, 0, expert_size
)
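# Block-wise FP8 quantization of the expanded tokens before the first grouped GEMM (w13).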
quant_input, input_scale = _fp8_quantize(
expand_hidden_states, A_scale=None, block_shape=self.quant_config.weight_block_size
)
gemm1_out = mlu_ops.smooth_quant_group_gemm(
quant_input,
layer.w13_weight,
token_count,
expand_idx=None,
c=None,
alpha=None,
beta=None,
a_scale=input_scale.T.contiguous(),
b_scale=layer.w13_weight_scale_inv,
dtype=x.dtype,
max_m=tokens_num,
)
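# Gated activation over the fused w13 output, then re-quantize (block-wise FP8) for the w2 grouped GEMM.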
act_out = mlu_ops.active(gemm1_out, activation, is_gated=True)
act_out_quantize, act_out_scale = _fp8_quantize(
act_out, A_scale=None, block_shape=self.quant_config.weight_block_size
)
gemm2_out = mlu_ops.smooth_quant_group_gemm(
act_out_quantize,
layer.w2_weight,
token_count,
expand_idx=None,
c=None,
alpha=None,
beta=None,
a_scale=act_out_scale.T.contiguous(),
b_scale=layer.w2_weight_scale_inv,
dtype=x.dtype,
max_m=tokens_num,
)
output = mlu_ops.moe_combine_result(
gemm2_out,
topk_weights,
combine_idx,
residual=None,
cusum_token_count=cumsum_token_count,
start_expert_id=0,
expert_size=expert_size,
bias=None,
)
return output.view(ori_input_shape)
"""
==================
End of MLU Hijack
==================
"""
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.apply,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply
)
MluHijackObject.apply_hijack(
Fp8Config,
Fp8Config.__init__,
vllm__model_executor__layers__quantization__fp8__Fp8Config____init__
)
MluHijackObject.apply_hijack(
Fp8Config,
Fp8Config.from_config,
vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config
)
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.create_weights,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights
)
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.__init__,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__
)
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.process_weights_after_loading,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading
)
MluHijackObject.apply_hijack(
Fp8MoEMethod,
Fp8MoEMethod.__init__,
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__
)
MluHijackObject.apply_hijack(
Fp8MoEMethod,
Fp8MoEMethod.apply,
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply
)

View File

@@ -0,0 +1,440 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from fractions import Fraction
from typing import Any, Dict, List, Optional, Tuple
import torch
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
# Only GPTQ and AWQ are supported, on MLU 300-series and later devices, and only int4 and int8 precisions.
def query_mlu_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if has_zp:
# AWQ style, unsigned + zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
return [scalar_types.uint4b8, scalar_types.uint8b128]
def check_mlu_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
supported_types = query_mlu_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Mlu does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
return (False, f"Mlu does not support group_size = {group_size}. "
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True, None
# @register_quantization_config("gptq_mlu")
class GPTQMluConfig(QuantizationConfig):
"""Config class for GPTQMlu.
Reference: https://arxiv.org/abs/2210.17323
"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
(4, False): scalar_types.uint4b8,
(8, False): scalar_types.uint8b128,
}
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.lm_head_quantized = lm_head_quantized
self.pack_factor = Fraction(32, self.weight_bits)
self.support_scale_zeros = False
self.use_native = self.desc_act or (not self.is_sym and not self.support_scale_zeros)
if self.weight_bits not in [4, 8]:
raise ValueError(
"Currently, only 4/8-bit weight quantization is "
f"supported for GPTQMlu, but got {self.weight_bits} bits.")
def __repr__(self) -> str:
return (f"GPTQMluConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}),"
f"lm_head_quantized={self.lm_head_quantized}")
@classmethod
def get_name(cls) -> str:
return "gptq_mlu"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quant_config.json", "quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMluConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized)
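# Illustrative GPTQ-style quantization config accepted by from_config (key names taken
# from the lookups above; the values are only an example):
#     {"quant_method": "gptq", "bits": 4, "group_size": 128,
#      "desc_act": false, "sym": true, "lm_head": false}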
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQMluLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQMluLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@classmethod
def is_gptq_mlu_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
if quant_method != "gptq":
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
if (num_bits, sym) not in cls.TYPE_MAP:
return False
supported, _ = check_mlu_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size, has_zp=False)
return supported
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_gptq_mlu_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "gptq_mlu")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
return None
class GPTQMluLinearMethod(LinearMethodBase):
"""Linear method for GPTQMlu.
Args:
quant_config: The GPTQMlu quantization config.
"""
def __init__(self, quant_config: GPTQMluConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs.get("weight_loader")
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if (input_size != input_size_per_partition) and (self.quant_config.group_size !=
-1) and (not self.quant_config.desc_act):
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
g_idx = RowvLLMParameter(data=torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
qzeros_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scale_and_zero_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.device = layer.qweight.data.device
packed_qweight, scale_zeros = self.extract_autogptq(layer)
if self.quant_config.use_native:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
layer.qzeros = None
layer.scales = None
else:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
if scale_zeros is not None:
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
else:
layer.qzeros = None
layer.scales = torch.nn.Parameter(layer.scales.transpose(0, 1).contiguous(), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.quant_config.use_native:
output = mlu_ops.matmul(x, layer.qweight, bias)
if residual is not None:
output = output + residual
else:
output = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
layer.qzeros,
bias,
residual,
"none",
self.quant_config.weight_bits)
return output
def extract_autogptq(self, layer: torch.nn.Module):
scales = layer.scales.data
bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
# Unpack the qweight and qzeros tensors
iweight = self.unpack_gptq_qweight_int32_into_int8(layer.qweight.data, bits)
izeros = self.unpack_gptq_qzeros_int32_into_int8(layer.qzeros.data, bits)
if self.quant_config.use_native:
if self.quant_config.desc_act:
scales = torch.index_select(scales, 0, layer.g_idx)
if izeros is not None:
izeros = torch.index_select(izeros, 0, layer.g_idx)
else:
scales = scales.repeat_interleave(group_size, dim=0)
if izeros is not None:
izeros = izeros.repeat_interleave(group_size, dim=0)
if izeros is not None:
fweight = (iweight - izeros) * scales
else:
fweight = iweight * scales
# transpose [ci, co] -> [co, ci]
fweight = fweight.transpose(0, 1)
return fweight, None
if not self.quant_config.is_sym and self.quant_config.support_scale_zeros and izeros is not None:
scale_zeros = izeros.to(scales.dtype) * -1 * scales
# transpose [ci, co] -> [co, ci]
scale_zeros = scale_zeros.transpose(0, 1)
else:
# is_sym is expected to be True here, so shift iweight to its signed representation and ignore qzeros.
iweight = torch.bitwise_and(iweight - 2**(bits - 1), (2 ** bits) - 1)
scale_zeros = None
# transpose [ci, co] -> [co, ci]
iweight = iweight.to(torch.int8).transpose(0, 1)
if bits == 4:
higher_bit_tensor = iweight[:, 1::2]
lower_bit_tensor = iweight[:, 0::2]
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
else:
packed_qweight = iweight
return packed_qweight, scale_zeros
def unpack_gptq_qweight_int32_into_int8(self, qweight: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qweight.device).unsqueeze(0)
dtype = torch.int16 if bits == 8 else torch.int8
weight = torch.bitwise_right_shift(
torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
shifts.unsqueeze(-1),
).to(dtype)
weight = torch.bitwise_and(weight, (2**bits) - 1)
weight = weight.reshape(-1, weight.shape[-1])
return weight
def unpack_gptq_qzeros_int32_into_int8(self, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device).unsqueeze(0)
dtype = torch.int16 if bits == 8 else torch.int8
zeros = torch.bitwise_right_shift(
torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
shifts.unsqueeze(0),
).to(dtype)
zeros = zeros + 1
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
zeros = zeros.reshape(qzeros.shape[0], -1)
return zeros
def combine_low_bits(self, tensor_a, tensor_b):
"""
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
Args:
tensor_a (torch.Tensor): First tensor of type int8.
tensor_b (torch.Tensor): Second tensor of type int8.
Returns:
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
"""
# Ensure both inputs are int8.
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
raise ValueError("Both tensors must be of int8 type.")
# Extract the lower 4 bits of each tensor.
low_bits_a = torch.bitwise_and(tensor_a, 0x0F)  # keep the lower 4 bits of tensor_a
low_bits_b = torch.bitwise_and(tensor_b, 0x0F)  # keep the lower 4 bits of tensor_b
# Shift tensor_a's lower 4 bits into the high nibble.
shifted_low_bits_a = low_bits_a << 4
# Combine the two nibbles into a single byte.
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
return combined

View File

@@ -0,0 +1,337 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
ModelWeightParameter,
RowvLLMParameter)
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.quantization.utils.common_utils import (str_dtype_to_torch,
str_dtype_to_bits,
is_fp8_str_dtype)
# @register_quantization_config("smoothquant")
class SmoothQuantConfig(QuantizationConfig):
"""Config class for SmoothQuant.
"""
def __init__(
self,
quant_mode: str, # smoothquant
input_quant_method: str, # per token/per tensor
group_size: int,
weight_precision: str,
activation_precision: str,
only_expert_per_group: bool,
expert_weight_precision: str,
expert_activation_precision: str,
force_use_weightonly_except_expert: bool,
) -> None:
super().__init__()
self.quant_mode = quant_mode
self.input_quant_method = input_quant_method
self.group_size = group_size
self.weight_precision = weight_precision
self.activation_precision = activation_precision
self.only_expert_per_group = only_expert_per_group
self.expert_weight_precision = expert_weight_precision
self.expert_activation_precision = expert_activation_precision
self.force_use_weightonly_except_expert = force_use_weightonly_except_expert
if quant_mode == "SmoothQuant" and (self.input_quant_method != "per_token" and self.input_quant_method != "per_tensor"):
raise ValueError(
"Currently, only per_token or per_tensor input quantization is supported for "
f"SmoothQuant, but got {self.input_quant_method}.")
self.weight_bits = str_dtype_to_bits(self.weight_precision)
self.expert_weight_bits = str_dtype_to_bits(self.expert_weight_precision)
if self.weight_precision == 'int4':
self.weight_dtype = torch.int8
else:
self.weight_dtype = str_dtype_to_torch(self.weight_precision)
if self.expert_weight_precision == 'int4':
self.expert_weight_dtype = torch.int8
else:
self.expert_weight_dtype = str_dtype_to_torch(self.expert_weight_precision)
self.is_fp8 = is_fp8_str_dtype(self.weight_precision)
self.expert_is_fp8 = is_fp8_str_dtype(self.expert_weight_precision)
self.pack_factor = 8 // self.weight_bits
self.expert_pack_factor = 8 // self.expert_weight_bits
def __repr__(self) -> str:
return (f"SmoothQuantConfig(input_quant_method={self.input_quant_method}, "
f"quant_mode={self.quant_mode}, "
f"group_size={self.group_size}, "
f"weight_precision={self.weight_precision}, "
f"activation_precision={self.activation_precision}, "
f"only_expert_per_group={self.only_expert_per_group}, "
f"expert_weight_precision={self.expert_weight_precision}, "
f"expert_activation_precision={self.expert_activation_precision}, "
f"force_use_weightonly_except_expert={self.force_use_weightonly_except_expert})")
@classmethod
def get_name(cls) -> str:
return "SmoothQuant"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@staticmethod
def get_config_filenames() -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
quant_mode = cls.get_from_keys(config, ["quant_mode"])
input_quant_method = cls.get_from_keys(config, ["input_quant_method"])
group_size = cls.get_from_keys_or(config, ["group_size"], 1)
weight_precision = cls.get_from_keys_or(config, ["weight_precision"], "int8")
activation_precision = cls.get_from_keys_or(config, ["activation_precision"], "int8")
only_expert_per_group = cls.get_from_keys_or(config, ["only_expert_per_group"], False)
expert_weight_precision = cls.get_from_keys_or(config, ["expert_weight_precision"], None)
expert_activation_precision = cls.get_from_keys_or(config, ["expert_activation_precision"], None)
force_use_weightonly_except_expert = cls.get_from_keys_or(config, ["force_use_weightonly_except_expert"], False)
if expert_weight_precision is None:
expert_weight_precision = weight_precision
if group_size > 1 and only_expert_per_group and weight_precision == 'int4':
weight_precision = 'int8'
if expert_activation_precision is None:
expert_activation_precision = activation_precision
return cls(quant_mode=quant_mode,
input_quant_method=input_quant_method,
group_size=group_size,
weight_precision=weight_precision,
activation_precision=activation_precision,
only_expert_per_group=only_expert_per_group,
expert_weight_precision=expert_weight_precision,
expert_activation_precision=expert_activation_precision,
force_use_weightonly_except_expert=force_use_weightonly_except_expert)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["SmoothQuantLinearMethod"]:
if isinstance(layer, LinearBase):
return SmoothQuantLinearMethod(self, prefix)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class SmoothQuantLinearMethod(LinearMethodBase):
"""Linear method for SmoothQuant.
Args:
quant_config: The SmoothQuant quantization config.
"""
def __init__(self, quant_config: SmoothQuantConfig, prefix: str):
self.quant_config = quant_config
# For the per-tensor case, we can skip quantizing the input of the first attn|ffn linear
# and fuse this step into the layernorm for better performance.
self.skip_quant_input = False
self.compute_dtype = torch.get_default_dtype()
self.is_expert = 'expert' in prefix and "shared_expert" not in prefix
self.weight_dtype = quant_config.expert_weight_dtype if self.is_expert else quant_config.weight_dtype
self.pack_factor = quant_config.expert_pack_factor if self.is_expert else quant_config.pack_factor
self.is_fp8 = quant_config.expert_is_fp8 if self.is_expert else quant_config.is_fp8
if quant_config.only_expert_per_group and self.is_expert and quant_config.group_size > 1:
self.is_group_quant = True
elif quant_config.only_expert_per_group is False and quant_config.group_size > 1:
self.is_group_quant = True
else:
self.is_group_quant = False
self.has_smooth = self.quant_config.input_quant_method == "per_token" and (
self.quant_config.force_use_weightonly_except_expert is False or self.is_expert)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor != 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
weight_loader = extra_weight_attrs.get("weight_loader")
group_num = 1
if self.is_group_quant:
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
f"The input size {input_size_per_partition} is not aligned with the quantized "
f"weight shape. This can be caused by too large "
f"tensor parallel size. group_size: {self.quant_config.group_size}.")
group_num = (input_size + self.quant_config.group_size - 1) // self.quant_config.group_size
if input_size_per_partition != input_size:
group_num = (input_size_per_partition + self.quant_config.group_size - 1) // self.quant_config.group_size
qweight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // self.pack_factor,
device="mlu",
dtype=self.weight_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
if self.is_group_quant:
per_channel_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
group_num,
device="mlu",
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
else:
per_channel_scale = ChannelQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
device="mlu",
dtype=torch.float32,
),
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("qweight", qweight)
layer.register_parameter("per_channel_scale", per_channel_scale)
if self.has_smooth:
smooth = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
device="mlu",
dtype=torch.float32,
),
input_dim=0,
weight_loader=weight_loader,
)
set_weight_attrs(smooth, {
"ignore_warning": True,
})
layer.register_parameter("smooth", smooth)
if self.quant_config.input_quant_method == "per_tensor":
scale_to_int = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
device="mlu",
dtype=torch.float32,
),
input_dim=0,
weight_loader=weight_loader,
)
set_weight_attrs(scale_to_int, {
"ignore_warning": True,
})
layer.register_parameter("scale_to_int", scale_to_int)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.has_smooth and layer.smooth.dtype != torch.float:
layer.smooth = layer.smooth.to(torch.float)
if self.quant_config.input_quant_method == "per_tensor" and layer.scale_to_int.dtype != torch.float:
layer.scale_to_int = layer.scale_to_int.to(torch.float)
if layer.per_channel_scale.dtype != torch.float:
layer.per_channel_scale = layer.per_channel_scale.to(torch.float)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.per_channel_scale = Parameter(layer.per_channel_scale.data, requires_grad=False)
if self.has_smooth:
layer.smooth = Parameter(layer.smooth.data, requires_grad=False)
if self.quant_config.input_quant_method == "per_tensor":
layer.scale_to_int = Parameter(layer.scale_to_int.data, requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
use_tp_weight : bool = False,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
layer_smooth = layer.smooth if self.has_smooth else None
layer_qweight = layer.qweight
layer_per_channel_scale = layer.per_channel_scale
if use_tp_weight:
if hasattr(layer, 'tp_smooth'):
layer_smooth = layer.tp_smooth
if hasattr(layer, 'tp_qweight'):
layer_qweight = layer.tp_qweight
if hasattr(layer, 'tp_per_channel_scale'):
layer_per_channel_scale = layer.tp_per_channel_scale
quant_input = None
if self.skip_quant_input:
quant_input = x
elif self.quant_config.input_quant_method == "per_token":
if self.is_fp8:
quant_input, input_scale = mlu_ops.scaled_quantize(x,
layer_smooth,
quant_type=self.weight_dtype,
quant_mode='dynamic_per_token')
else:
quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer_smooth, None)
elif self.quant_config.input_quant_method == "per_tensor":
quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)
else:
raise ValueError(
"Currently, only per_token or per_tensor input quantization is supported for "
f"SmoothQuant, but got {self.input_quant_method}.")
quant_input_shape = quant_input.shape
if len(quant_input_shape) > 2:
quant_input = quant_input.view(-1, quant_input_shape[-1])
input_scale = input_scale.view(-1)
if residual is not None and len(residual.shape) > 2:
residual = residual.view(-1, residual.shape[-1])
if self.is_fp8:
out = mlu_ops.scaled_matmul(quant_input, layer_qweight, input_scale,
layer_per_channel_scale,
self.compute_dtype if hasattr(self, 'compute_dtype') else x.dtype,
bias,
c=residual, act_mode="none", quant_bit_size=8,
alpha=1.0, beta=1.0, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None)
if output is not None:
out = out.view(output.shape)
output.copy_(out)
out = output
else:
if output is not None:
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
layer_per_channel_scale, self.compute_dtype, bias, residual, output=output)
else:
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
layer_per_channel_scale, self.compute_dtype, bias, residual)
if len(quant_input_shape) > 2:
out = out.view(*quant_input_shape[:-1], out.shape[-1])
return out

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
QUANTIZATION_CHOICES = ['int8', 'int4', 'e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
INTERGER_DTYPES = [torch.uint8, torch.uint16, torch.uint32, torch.uint64, torch.int8, torch.int16, torch.short,
torch.int32, torch.int, torch.int64, torch.long]
FLOAT_DTYPES = [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.bfloat16,
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, torch.half]
FP8_DTYPE = [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
FP8_STR_DTYPE = ['e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
GEMM_GROUP_SIZE = [64, 128, 256, 512]
_STR_TO_TORCH_DTYPE_DICT = dict(
bfloat16=torch.bfloat16,
float16=torch.float16,
float32=torch.float32,
int64=torch.int64,
int32=torch.int32,
int8=torch.int8,
bool=torch.bool,
e4m3fn=torch.float8_e4m3fn,
e4m3fnuz=torch.float8_e4m3fnuz,
e5m2=torch.float8_e5m2,
e5m2fnuz=torch.float8_e5m2fnuz,
)
TORCH_DTYPE_TO_STR_DICT = {
torch.bfloat16: "bfloat16",
torch.float16: "float16",
torch.float32: "float32",
torch.int64: "int64",
torch.int32: "int32",
torch.int8: "int8",
torch.bool: "bool",
torch.float8_e4m3fn: "e4m3fn",
torch.float8_e4m3fnuz: "e4m3fnuz",
torch.float8_e5m2: "e5m2",
torch.float8_e5m2fnuz: "e5m2fnuz",
}
STR_DTYPE_TO_BITS_DICT = {
"bfloat16": 16,
"float16": 16,
"float32": 32,
"int64": 64,
"int32": 32,
"int8": 8,
'int4': 4,
"bool": 1,
"e4m3fn": 8,
"e4m3fnuz": 8,
"e5m2": 8,
"e5m2fnuz": 8,
}
def str_dtype_to_torch(str_dtype: str):
'''
convert str dtype to torch dtype
'''
ret = _STR_TO_TORCH_DTYPE_DICT.get(str_dtype)
dtype = ret if ret is not None else torch.float16
return dtype
def torch_dtype_to_str(dtype: torch.dtype):
'''
convert torch dtype to str dtype
'''
ret = TORCH_DTYPE_TO_STR_DICT.get(dtype)
str_dtype = ret if ret is not None else "float16"
return str_dtype
def str_dtype_to_bits(str_dtype):
'''
convert str dtype to bit width
'''
ret = STR_DTYPE_TO_BITS_DICT.get(str_dtype)
bits = ret if ret is not None else 8
return bits
def is_integer_dtype(dtype: torch.dtype):
'''
check whether the dtype is an integer type
'''
return dtype in INTERGER_DTYPES
def is_float_dtype(dtype: torch.dtype):
'''
check whether the dtype is a floating-point type
'''
return dtype in FLOAT_DTYPES
def is_fp8_dtype(dtype: torch.dtype):
'''
check whether a torch dtype is an fp8 type
'''
return dtype in FP8_DTYPE
def is_fp8_str_dtype(str_dtype: str):
'''
check whether a str dtype is an fp8 type
'''
return str_dtype in FP8_STR_DTYPE
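# Quick usage sketch of the helpers above (illustrative; the expected values
# follow directly from the lookup tables defined in this file).
def _dtype_helpers_demo() -> None:
    assert str_dtype_to_torch("e4m3fn") is torch.float8_e4m3fn
    assert str_dtype_to_bits("int4") == 4
    assert torch_dtype_to_str(torch.bfloat16) == "bfloat16"
    assert is_fp8_str_dtype("e5m2") and not is_fp8_dtype(torch.int8)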

View File

@@ -0,0 +1,424 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/pull/2575
import functools
import json
import os
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_per_token_group_quant_fp8_colmajor)
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
'''
=============================
Modify by vllm_mlu
=============================
@brief: get the total core count for splitting the triton kernels
'''
import triton.backends.mlu.driver as driver
_devprob = driver.BangUtils().get_device_properties(torch.mlu.current_device())
TOTAL_CLUSTER_NUM = _devprob.get("cluster_num")
TOTAL_CORE_NUM = TOTAL_CLUSTER_NUM * _devprob.get("core_num_per_cluster")
'''
==================
End of MLU Hijack
==================
'''
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
and weight.shape[1] % 128 == 0)
if current_platform.is_rocm():
# TODO this is never used, as cutlass_block_fp8_supported is False
scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
input_2d.shape[:-1])[::-1]
scale_b_shape = (weight_scale.view(-1, 1)
if weight_scale.dim() <= 1 else weight_scale.T).shape
ar, ac = scale_a_shape
br, bc = scale_b_shape
if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
or br not in (1, weight.shape[0])):
shape_supported_by_cutlass = False
if cutlass_block_fp8_supported and shape_supported_by_cutlass:
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=True)
output = ops.cutlass_scaled_mm(q_input,
weight.T,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale.T)
else:
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=False)
output = w8a8_block_fp8_matmul(q_input,
weight,
x_scale,
weight_scale,
block_size,
output_dtype=input.dtype)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
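# Shape sketch for the block-quantized path above (hypothetical sizes, not tied
# to any particular model): the weight scale carries one value per
# block_n x block_k tile and the activation scale one value per token per
# k-group, matching the asserts in w8a8_block_fp8_matmul below.
#   weight:       [N, K] = [512, 384] in fp8
#   weight_scale: [cdiv(512, 128), cdiv(384, 128)] == [4, 3]
#   q_input:      [num_tokens, 384] in fp8
#   x_scale:      [num_tokens, cdiv(384, 128)] == [num_tokens, 3]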
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum value to avoid division by zero.
dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
'''
=============================
Modify by vllm_mlu
=============================
@brief: split to limit the memory usage (65536)
'''
group_per_block = 1
while M >= 65536:
group_per_block *= 2
M = x.numel() // (group_size * group_per_block)
'''
==================
End of MLU Hijack
==================
'''
if column_major_scales:
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device,
dtype=torch.float32).permute(-1, -2)
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
'''
=============================
Modify by vllm_mlu
=============================
@brief: set num_warps to 1 for triton-mlu
'''
num_warps = 1
num_stages = 1
'''
==================
End of MLU Hijack
==================
'''
if column_major_scales:
_per_token_group_quant_fp8_colmajor[(M, )](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
x_s.stride(1),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: replaced the '_per_token_group_quant_fp8' triton kernel with the
'scaled_quantize' kernel from the 'tmo' library
'''
# Check if x is contiguous; if not, make a contiguous copy
if not x.is_contiguous():
x = x.contiguous()
x_origin_shape = x.shape
x = x.reshape(*x.shape[:-1], -1, group_size)
x_q, x_s = mlu_ops.scaled_quantize(x,
None,
quant_type=dtype,
quant_mode='dynamic_per_token')
x_q = x_q.reshape(x_origin_shape)
'''
==================
End of MLU Hijack
==================
'''
return x_q, x_s
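# A plain-PyTorch reference of what the per-token-group quantization above is
# expected to compute (the hot paths go through mlu_ops.scaled_quantize or the
# triton kernel; this sketch is illustrative, assumes the last dim is divisible
# by group_size, and ignores the column-major scale layout).
def _per_token_group_quant_fp8_reference(x, group_size, eps=1e-10,
                                         dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    xg = x.reshape(*x.shape[:-1], -1, group_size).float()     # [..., num_groups, group_size]
    scale = xg.abs().amax(dim=-1).clamp(min=eps) / finfo.max  # one scale per token per group
    xq = (xg / scale.unsqueeze(-1)).clamp(finfo.min, finfo.max).to(dtype)
    return xq.reshape(x.shape), scale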
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
'''
=============================
Modify by vllm_mlu
=============================
@brief: split to limit the memory usage (65536)
'''
num_block_size_all = num_pid_m * num_pid_n
num_block_size_per = num_block_size_all // tl.num_programs(axis=0)
num_block_size_rem = num_block_size_all % tl.num_programs(axis=0)
core_deal_num_block_size = num_block_size_per + (pid < num_block_size_rem)
core_deal_num_block_start = num_block_size_per * pid + min(num_block_size_rem, pid)
for pid_i in range(0, core_deal_num_block_size):
pid_in_core_deal_block = core_deal_num_block_start + pid_i
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid_in_core_deal_block // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid_in_core_deal_block % group_size_m)
pid_n = (pid_in_core_deal_block % num_pid_in_group) // group_size_m
'''
==================
End of MLU Hijack
==================
'''
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
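# A small, unoptimized reference of the math the triton kernel above computes:
# dequantize A per (token, k-group) and B per (n-block, k-block) tile, then do
# a plain matmul. Intended only for checking the kernel on tiny shapes, not as
# a drop-in replacement for w8a8_block_fp8_matmul.
def _w8a8_block_fp8_matmul_reference(A, B, As, Bs, block_size,
                                     output_dtype=torch.float16):
    block_n, block_k = block_size
    A_deq = A.float() * As.float().repeat_interleave(block_k, dim=-1)[..., :A.shape[-1]]
    B_deq = B.float() * (Bs.float()
                         .repeat_interleave(block_n, dim=0)[:B.shape[0]]
                         .repeat_interleave(block_k, dim=1)[:, :B.shape[1]])
    return (A_deq @ B_deq.t()).to(output_dtype)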
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
'''
=============================
Modify by vllm_mlu
=============================
@brief: use the 'scaled_matmul' kernel from the 'tmo' library when shapes
allow, falling back to the '_w8a8_block_fp8_matmul' triton kernel otherwise
'''
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert B.ndim == 2 and Bs.ndim == 2
if (B.shape[0] % 128 == 0) and (B.shape[1] % 128 == 0):
C = mlu_ops.scaled_matmul(A, B, As, Bs, output_dtype, bias=None, c=None, act_mode="none",
quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None)
else:
# NOTE(wulingchao): the underlying scaled_matmul kernel only supports n and k that are multiples of 128
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype)
# Default config
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
# BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 1,
"num_stages": 1,
}
def grid(META):
return (TOTAL_CORE_NUM, )
_w8a8_block_fp8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
'''
==================
End of MLU Hijack
==================
'''
return C

View File

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional, Callable
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, USE_ROWWISE_TORCH_SCALED_MM, cutlass_w8a8_scaled_mm,
flashinfer_w8a8_scaled_mm, rocm_per_tensor_w8a8_scaled_mm,
torch_per_tensor_w8a8_scaled_mm, torch_per_token_w8a8_scaled_mm,
torch_channelwise_w8a8_scaled_mm)
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def mlu_w8a8_scaled_mm(
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
output_shape: list, **kwargs
) -> torch.Tensor:
output = mlu_ops.scaled_matmul(
qinput, # a
weight, # b
scale_a, # a_scale
scale_b, # b_scale
out_dtype, # output_dtype
bias, # bias
c=None, act_mode="none", quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None
)
return output.view(*output_shape)
def dispatch_w8a8_scaled_mm(
preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool,
weight_per_channel: bool, activation_per_token: bool
) -> Callable[..., torch.Tensor]:
if per_tensor_weights and per_tensor_activations:
if preferred_backend == "rocm":
return rocm_per_tensor_w8a8_scaled_mm
if preferred_backend == "flashinfer":
return flashinfer_w8a8_scaled_mm
if preferred_backend == "cutlass":
return cutlass_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
return cutlass_w8a8_scaled_mm
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
if (
not per_tensor_weights
and not per_tensor_activations
and USE_ROWWISE_TORCH_SCALED_MM
):
return torch_per_token_w8a8_scaled_mm
# Normally, torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
'''
=============================
Modify by vllm_mlu
=============================
@brief: dispatch to mlu_w8a8_scaled_mm
'''
if weight_per_channel and activation_per_token:
return mlu_w8a8_scaled_mm
'''
==================
End of MLU Hijack
==================
'''
return torch_channelwise_w8a8_scaled_mm
def vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype | None = None,
input_scale: torch.Tensor | None = None,
input_scale_ub: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
weight_per_channel: bool = True,
activation_per_token: bool = True,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
'''
=============================
Modify by vllm_mlu
=============================
@brief: add mlu_fp8_supported
'''
self.mlu_fp8_supported = False
if weight_per_channel and activation_per_token:
self.mlu_fp8_supported = True
'''
==================
End of MLU Hijack
==================
'''
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[1]]
if out_dtype is None:
out_dtype = input.dtype
if self.mlu_fp8_supported:
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add support for activation-per-token weight-per-channel quantization.
'''
qinput, x_scale = mlu_ops.scaled_quantize(
input_2d,# x
None, # scale
None, # zero
None, # scale_ub
quant_type=torch.float8_e4m3fn,
quant_mode='dynamic_per_token'
)
output_shape = [*input.shape[:-1], weight.shape[0]]
'''
==================
End of MLU Hijack
==================
'''
else:
# If input not quantized
# TODO(luka) remove this path if not used anymore
if input.dtype != current_platform.fp8_dtype():
qinput, x_scale = self.quant_fp8(
input_2d,
input_scale,
input_scale_ub,
)
else:
qinput, x_scale = input_2d, input_scale
# Must also check dim().
# In the per-token quant scenario, when the number of tokens is 1,
# the scale has only 1 element.
# Without checking dim(),
# we cannot distinguish between per-tensor and per-token quant.
# Example:
# when the number of tokens is 1, the per-token scale is [[1]],
# while the per-tensor scale is [1] or ().
per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
self.preferred_backend, per_tensor_weights, per_tensor_activations,
weight_per_channel, activation_per_token)
return w8a8_scaled_mm_func(
qinput=qinput,
weight=weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
output_shape=output_shape,
)
MluHijackObject.apply_hijack(
Fp8LinearOp,
Fp8LinearOp.apply,
vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply
)

View File

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm_mlu import _mlu_ops as mlu_ops
from vllm.logger import init_logger
logger = init_logger(__name__)
# @register_quantization_config("weightonly")
class WeightOnlyConfig(QuantizationConfig):
"""Config class for WeightOnly.
"""
def __init__(
self,
weight_bits: int,
quant_mode: str, # weight_only
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.quant_mode = quant_mode
if quant_mode == "WeightOnly" and (self.weight_bits != 8 and self.weight_bits != 4):
raise ValueError(
"Currently, only 8/4-bit weight quantization is supported for "
f"weight_only, but got {self.weight_bits} bits.")
self.pack_factor = 8 // self.weight_bits
def __repr__(self) -> str:
return (f"WeightOnlyConfig(weight_bits={self.weight_bits}, "
f"quant_mode={self.quant_mode})")
def get_name(self) -> str:
return "WeightOnly"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@staticmethod
def get_config_filenames() -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "WeightOnlyConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
try:
quant_mode = cls.get_from_keys(config, ["quant_mode"])
except Exception:
quant_mode = "WeightOnly"
return cls(weight_bits, quant_mode)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["WeightOnlyLinearMethod"]:
if isinstance(layer, LinearBase):
return WeightOnlyLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class WeightOnlyLinearMethod(LinearMethodBase):
"""Linear method for WeightOnly.
Args:
quant_config: The WeightOnly quantization config.
"""
def __init__(self, quant_config: WeightOnlyConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> Dict[str, Any]:
output_size_per_partition = sum(output_partition_sizes)
if self.quant_config.quant_mode == "WeightOnly":
scale_and_zero_input_dim = None
if output_size != output_size_per_partition:
scale_and_zero_input_dim = 0
qweight = Parameter(
torch.empty(
output_size_per_partition,
input_size_per_partition // self.quant_config.pack_factor,
device="mlu",
dtype=torch.int8,
),
requires_grad=False,
)
set_weight_attrs(qweight, {
"input_dim": 1,
"output_dim": 0,
})
scales = Parameter(
torch.empty(
output_size_per_partition,
device="mlu",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": scale_and_zero_input_dim,
"output_dim": 0,
})
layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if layer.scales.dtype != torch.float:
layer.scales = Parameter(layer.scales.to(torch.float), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
x_shape = x.shape
if len(x_shape) > 2:
x = x.view(-1, x_shape[-1])
out = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
None,
bias,
residual,
"none",
self.quant_config.weight_bits)
if len(x_shape) > 2:
out = out.view(*x_shape[:-1], out.shape[-1])
return out
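# A plain-PyTorch reference of the 8-bit weight-only path above (qweight is
# stored as [out_features, in_features] int8 with one float scale per output
# channel; the 4-bit case additionally packs two values per byte and is not
# covered by this illustrative sketch).
def _weight_only_int8_reference(x, qweight, scales, bias=None):
    w = qweight.float() * scales.float().unsqueeze(-1)   # per-channel dequantization
    out = x.float() @ w.t()
    if bias is not None:
        out = out + bias
    return out.to(x.dtype)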

View File

@@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import math
from typing import Any
import torch
from vllm.logger import init_logger
import vllm.model_executor.layers.rotary_embedding as rotary_embedding
from vllm.model_executor.layers.rotary_embedding import (
_ROPE_DICT,
RotaryEmbedding,
)
from vllm.model_executor.layers.rotary_embedding import (
_ROPE_DICT,
DualChunkRotaryEmbedding,
DynamicNTKAlphaRotaryEmbedding,
DynamicNTKScalingRotaryEmbedding,
Llama4VisionRotaryEmbedding,
MRotaryEmbedding,
NTKScalingRotaryEmbedding,
Phi3LongRoPEScaledRotaryEmbedding,
YaRNScalingRotaryEmbedding,
)
from .base import MLURotaryEmbedding
from .deepseek_scaling_rope import MLUDeepseekScalingRotaryEmbedding
from .dynamic_ntk_alpha_rope import MLUDynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import MLUDynamicNTKScalingRotaryEmbedding
from .linear_scaling_rope import MLULinearScalingRotaryEmbedding
from .llama3_rope import MLULlama3RotaryEmbedding
from .mrope import MLUMRotaryEmbedding
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def get_long_max_model_max_position_emb(max_position_embeddings, scaling_factor):
if MLURotaryEmbedding.max_seq_len != None and \
MLURotaryEmbedding.max_seq_len > max_position_embeddings * scaling_factor:
logger.warning(f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) is different with " +
f"max_position_embedding ({max_position_embeddings}) * scaling_factor ({scaling_factor}) " +
"from model's config.json, This may lead to incorrect model outputs or MLU errors. " +
f"Make sure the value is correct and within the model context size. " +
f"Set max_position_embedding={MLURotaryEmbedding.max_seq_len}.")
return math.ceil(MLURotaryEmbedding.max_seq_len / scaling_factor)
return max_position_embeddings
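# Worked example (numbers are hypothetical): with a user-specified
# max_model_len of 163840 and a rope scaling_factor of 40, the function above
# returns ceil(163840 / 40) = 4096, so that
# max_position_embeddings * scaling_factor again covers the requested length.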
def vllm__model_executor__layers__rotary_embedding__get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: float,
is_neox_style: bool = True,
rope_scaling: dict[str, Any] | None = None,
dtype: torch.dtype | None = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: dict[str, Any] | None = None,
inverse: bool = False
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if dual_chunk_attention_config is not None:
dual_chunk_attention_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in dual_chunk_attention_config.items()
if k != "sparse_attention_config"
}
dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
else:
dual_chunk_attention_args = None
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling_args,
dual_chunk_attention_args,
dtype,
inverse,
)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if dual_chunk_attention_config is not None:
extra_kwargs = {
k: v
for k, v in dual_chunk_attention_config.items()
if k in ("chunk_size", "local_size")
}
rotary_emb = DualChunkRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
**extra_kwargs,
)
elif not rope_scaling:
rotary_emb = MLURotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype,
inverse=inverse,
)
else:
scaling_type = rope_scaling["rope_type"]
if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
rotary_emb = MLULlama3RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
scaling_factor,
low_freq_factor,
high_freq_factor,
original_max_position,
)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MLUMRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
)
else:
rotary_emb = MLURotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
inverse=inverse,
)
elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"]
rotary_emb = MLULinearScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "ntk":
scaling_factor = rope_scaling["factor"]
mixed_b = rope_scaling.get('mixed_b', None)
rotary_emb = NTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
mixed_b,
)
elif scaling_type == "dynamic":
if "alpha" in rope_scaling:
scaling_alpha = rope_scaling["alpha"]
rotary_emb = MLUDynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
)
elif "factor" in rope_scaling:
scaling_factor = rope_scaling["factor"]
rotary_emb = MLUDynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
else:
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"apply_yarn_scaling",
)
}
if "mrope_section" in rope_scaling:
extra_kwargs.pop("apply_yarn_scaling", None)
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
scaling_factor=scaling_factor,
**extra_kwargs,
)
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: update original_max_position
'''
original_max_position = get_long_max_model_max_position_emb(
original_max_position, scaling_factor,
)
'''
==================
End of MLU Hijack
==================
'''
rotary_emb = YaRNScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type == "deepseek_yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
)
}
'''
=============================
Modify by vllm_mlu
=============================
@brief: update original_max_position
'''
original_max_position = get_long_max_model_max_position_emb(
original_max_position, scaling_factor,
)
'''
==================
End of MLU Hijack
==================
'''
rotary_emb = MLUDeepseekScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
inverse,
**extra_kwargs,
)
elif scaling_type == "longrope":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size,
rotary_dim,
max_position,
original_max_position,
base,
is_neox_style,
dtype,
short_factor,
long_factor,
**extra_kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb
MluHijackObject.apply_hijack(
rotary_embedding,
rotary_embedding.get_rope,
vllm__model_executor__layers__rotary_embedding__get_rope,
)
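# Illustrative call of the hijacked get_rope (argument values are hypothetical,
# loosely following a DeepSeek-style config; on MLU this resolves to
# MLUDeepseekScalingRotaryEmbedding and the result is cached in _ROPE_DICT
# under the key built above):
# rope = rotary_embedding.get_rope(
#     head_size=64, rotary_dim=64, max_position=163840, base=10000,
#     rope_scaling={
#         "rope_type": "deepseek_yarn", "factor": 40,
#         "original_max_position_embeddings": 4096,
#         "beta_fast": 32, "beta_slow": 1, "mscale": 1.0, "mscale_all_dim": 1.0,
#     },
# )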

View File

@@ -0,0 +1,302 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.utils import (
get_common_metadata,
MLUCommonAttentionMetadata,
)
from vllm_mlu.v1.attention.backends.mla.flashmla import MLACommonMetadata
from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context
logger = init_logger(__name__)
@CustomOp.register("rotary_embedding_mlu")
class MLURotaryEmbedding(RotaryEmbedding, CustomOp):
cu_seq_lens : torch.Tensor = None
max_seq_len : int = None
max_model_len : int = None
is_prompt : bool = False
is_chunked : bool = False
positions_: torch.Tensor = None
chunked_prefill_enabled: bool = False
prefill_cu_seq_lens: torch.Tensor = None
prefill_max_seq_len: int = None
decode_cu_seq_lens: torch.Tensor = None
decode_max_seq_len: int = None
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
inverse: bool = False,
) -> None:
CustomOp.__init__(self)
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
# TODO(mgoin): disabled for now due to failures
# Flashinfer only supports head_size=64, 128, 256, 512.
# https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
# self.use_flashinfer = (self.enabled()
# and dtype in (torch.float16, torch.bfloat16)
# and current_platform.is_cuda()
# and has_flashinfer()
# and self.head_size in [64, 128, 256, 512])
self.use_flashinfer = False
self.inverse = inverse
# For VLM v1:
# 1. MLU rope runs in eager mode
# 2. all layers use layer 0's rope for inference
prefix = "global_rope"
vllm_config = get_current_vllm_config()
self.use_direct_call = False
if not self.use_direct_call:
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
pass
else:
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import DeepseekScalingRotaryEmbedding
from vllm.model_executor.layers.rotary_embedding.yarn_scaling_rope import YaRNScalingRotaryEmbedding
if MLURotaryEmbedding.max_seq_len != None \
and self.max_position_embeddings < MLURotaryEmbedding.max_seq_len and \
not isinstance(self, (YaRNScalingRotaryEmbedding, DeepseekScalingRotaryEmbedding)):
logger.warning(f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) is different with " +
f"max_position_embedding ({max_position_embeddings}) from model's config.json, " +
f"This may lead to incorrect model outputs or MLU errors. " +
f"Make sure the value is correct and within the model context size. " +
f"Set max_position_embedding={MLURotaryEmbedding.max_seq_len}.")
self.max_position_embeddings = MLURotaryEmbedding.max_seq_len
cache = self._compute_cos_sin_cache()
from vllm_mlu.model_executor.layers.rotary_embedding.linear_scaling_rope import MLULinearScalingRotaryEmbedding
if isinstance(self, MLULinearScalingRotaryEmbedding):
logger.debug(f"Using mlu defining _compute_cos_sin_cache due to the special tensor composition")
elif is_neox_style:
cache_pos = cache.shape[0]
cache = cache.reshape(cache_pos, 2, -1)
cache = torch.tile(cache, (1, 1, 2)).reshape(cache_pos, -1)
else:
cache = cache.repeat_interleave(2, dim=-1)
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.cos_, self.sin_ = self._get_cos_sin()
@classmethod
def set_mlu_var_v1(
cls,
common_metadata: MLUCommonAttentionMetadata
) -> None:
cls.unset_mlu_var()
cls.cu_seq_lens = common_metadata.query_start_loc
cls.max_seq_len = common_metadata.max_query_len
cls.is_prompt = common_metadata.is_prefill_only
cls.is_chunked = common_metadata.is_chunked
# for MLA
attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
_, attn_metadata = next(iter(attn_metadata.items()))
if isinstance(attn_metadata, MLACommonMetadata):
prefill_metadata = attn_metadata.prefill
decode_metadata = attn_metadata.decode
if prefill_metadata:
cls.prefill_max_seq_len = prefill_metadata.max_query_len
cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
else:
cls.prefill_max_seq_len = cls.max_seq_len
cls.prefill_cu_seq_lens = cls.cu_seq_lens
if decode_metadata:
cls.decode_max_seq_len = decode_metadata.max_query_len
cls.decode_cu_seq_lens = decode_metadata.query_start_loc
else:
cls.decode_max_seq_len = cls.max_seq_len
cls.decode_cu_seq_lens = cls.cu_seq_lens
# for sp
sp_context = get_sp_forward_context()
if sp_context is not None and sp_context.is_v32:
prefill_metadata = sp_context.sp_attn_metadata.prefill
cls.is_chunked = True
cls.prefill_max_seq_len = prefill_metadata.max_query_len
cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
@classmethod
def unset_mlu_var(cls):
cls.cu_seq_lens = None
cls.max_seq_len = None
cls.is_prompt = False
cls.is_chunked = False
cls.positions_ = None
cls.chunked_prefill_enabled = False
cls.prefill_cu_seq_lens = None
cls.prefill_max_seq_len = None
cls.decode_cu_seq_lens = None
cls.decode_max_seq_len = None
def _get_cos_sin(self) -> Tuple[torch.Tensor, torch.Tensor]:
cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
sin = sin.view(-1, self.rotary_dim)
cos = cos.view(-1, self.rotary_dim)
return cos, sin
def _get_positions_with_offsets_mlu(
self,
positions: torch.Tensor,
offsets: torch.Tensor
) -> torch.Tensor:
if offsets.numel() != positions.numel():
raise Exception("rope offsets numel mismatch with positions, "
f"positions: {positions.numel()}, offsets: {offsets.numel()}")
return (positions + offsets).to(torch.int32)
def forward_impl(
self,
positions: torch.Tensor,
x: torch.Tensor,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
common_metadata: MLUCommonAttentionMetadata = get_common_metadata()
if common_metadata is None:
num_tokens, head_num, head_size = x.shape
x = mlu_ops.rotary_embedding(
x.view(1, num_tokens, head_num, head_size),
self.sin_,
self.cos_,
positions,
None,
not self.is_neox_style,
True,
False,
num_tokens
)
return x
else:
cu_seq_lens_ = common_metadata.query_start_loc
if offsets is not None:
if MLURotaryEmbedding.positions_ is None:
MLURotaryEmbedding.positions_ = (
self._get_positions_with_offsets_mlu(positions, offsets))
position_ids = MLURotaryEmbedding.positions_
discrete = True
elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
position_ids = positions
discrete = True
else:
position_ids = None
discrete = False
x = mlu_ops.rotary_embedding(
x,
self.sin_,
self.cos_,
position_ids,
cu_seq_lens_,
not self.is_neox_style,
discrete,
False,
MLURotaryEmbedding.max_seq_len
)
return x
def get_param(self, positions, discrete=False):
interleaved = True
if self.is_neox_style:
interleaved = False
if discrete:
position_ids = positions
discrete = discrete
else:
if MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
position_ids = positions
discrete = True
else:
position_ids = None
discrete = False
return position_ids, interleaved, discrete
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.outer(t, inv_freq)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cos = freqs_cis.real
sin = freqs_cis.imag * (-1 if self.inverse else 1)
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor | None = None,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
only_prefill: bool | None = False,
only_decode: bool | None = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.forward_impl(positions, query, offsets)
if key is not None:
self.forward_impl(positions, key, offsets)
return query, key
def rope_forward(
positions: torch.Tensor,
x: torch.Tensor,
layer_name: str,
offsets: torch.Tensor | None = None,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_impl(positions, x, offsets)
def rope_forward_fake(
positions: torch.Tensor,
x: torch.Tensor,
layer_name: str,
offsets: torch.Tensor | None = None,
) -> None:
return
direct_register_custom_op(
op_name="rope_forward",
op_func=rope_forward,
mutates_args=["x"],
fake_impl=rope_forward_fake,
dispatch_key=current_platform.dispatch_key,
)
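# Illustrative sketch (toy sizes, standalone) of the cache layout built in
# MLURotaryEmbedding.__init__ above: torch.polar of unit magnitude yields
# (cos, sin) directly, and for neox-style layouts each half is tiled so a
# per-position row becomes [cos, cos, sin, sin] over the rotary dimension.
def _demo_neox_cos_sin_cache(max_pos: int = 4, rotary_dim: int = 4, base: float = 10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    freqs = torch.outer(torch.arange(max_pos, dtype=torch.float), inv_freq)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    assert torch.allclose(freqs_cis.real, freqs.cos())     # real part is cos
    assert torch.allclose(freqs_cis.imag, freqs.sin())     # imaginary part is sin
    cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_pos, rotary_dim]
    return torch.tile(cache.reshape(max_pos, 2, -1), (1, 1, 2)).reshape(max_pos, -1)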

View File

@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
import torch
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
DeepseekScalingRotaryEmbedding,
yarn_get_mscale,
)
from vllm.model_executor.layers.rotary_embedding.common import (
rotate_gptj,
rotate_neox,
yarn_find_correction_range,
yarn_linear_ramp_mask,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLUDeepseekScalingRotaryEmbedding(MLURotaryEmbedding, DeepseekScalingRotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
inverse: bool = False,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
self.inverse = inverse
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def forward_mlu_rot(self, input, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len):
"""only one input rotary implementation"""
if input is None:
return None
if self.rotary_dim < self.head_size:
input_pass = input[..., self.rotary_dim:]
input_rot = input[..., :self.rotary_dim]
input_rot = mlu_ops.rotary_embedding(
input_rot,
self.sin_,
self.cos_,
position_ids,
cu_seq_lens,
interleaved,
discrete,
False,
max_seq_len
)
if self.rotary_dim < self.head_size:
input = torch.cat((input_rot, input_pass), dim=-1)
else:
input = input_rot
return input
def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor | None = None,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
only_prefill: bool | None = False,
only_decode: bool | None = False,
discrete: bool | None = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
position_ids, interleaved, discrete = self.get_param(positions, discrete)
cu_seq_lens = MLURotaryEmbedding.cu_seq_lens
max_seq_len = MLURotaryEmbedding.max_seq_len
# for MLA
attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
_, attn_metadata = next(iter(attn_metadata.items()))
if isinstance(attn_metadata, MLACommonMetadata):
if only_prefill:
cu_seq_lens = MLURotaryEmbedding.prefill_cu_seq_lens
max_seq_len = MLURotaryEmbedding.prefill_max_seq_len
elif only_decode:
cu_seq_lens = MLURotaryEmbedding.decode_cu_seq_lens
max_seq_len = MLURotaryEmbedding.decode_max_seq_len
query = self.forward_mlu_rot(query, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len)
key = self.forward_mlu_rot(key, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len)
return query, key
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base ** (
torch.arange(
0,
self.rotary_dim,
2,
dtype=torch.float,
device=current_platform.device_type,
)
/ self.rotary_dim
)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
self.rotary_dim,
self.base,
self.max_position_embeddings,
)
# Get n-d rotational scaling corrected for extrapolation
device = current_platform.device_type
inv_freq_mask = ((
1
- yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
) * self.extrapolation_factor).to(device)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(
self.max_position_embeddings * self.scaling_factor,
device=current_platform.device_type,
dtype=torch.float32,
)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale * (-1 if self.inverse else 1)
cache = torch.cat((cos, sin), dim=-1)
return cache
forward = MLURotaryEmbedding.forward
forward_native = forward_oot
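# Worked example (hypothetical DeepSeek-style settings): with
# original max_position_embeddings = 4096 and scaling_factor = 40, the cache
# built by _compute_cos_sin_cache above spans 4096 * 40 = 163840 positions, and
# every cos/sin row is multiplied by
#   mscale = yarn_get_mscale(40, mscale) / yarn_get_mscale(40, mscale_all_dim) * attn_factor
# which reduces to attn_factor when mscale == mscale_all_dim (the common case).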

View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.model_executor.layers.rotary_embedding.dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLUDynamicNTKAlphaRotaryEmbedding(MLURotaryEmbedding, DynamicNTKAlphaRotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_alpha: float,
dtype: torch.dtype,
) -> None:
self.scaling_alpha = scaling_alpha
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)

View File

@@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.model_executor.layers.rotary_embedding.dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLUDynamicNTKScalingRotaryEmbedding(MLURotaryEmbedding, DynamicNTKScalingRotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
) -> None:
self.scaling_factor = scaling_factor
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Union
import torch
from vllm.platforms import current_platform
from vllm.model_executor.layers.rotary_embedding.linear_scaling_rope import LinearScalingRotaryEmbedding
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLULinearScalingRotaryEmbedding(MLURotaryEmbedding, LinearScalingRotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factors: list[float] | float,
dtype: torch.dtype,
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors: list[float] = scaling_factors # noqa
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
# Lazy initialized.
self._scaling_factor_to_offset: dict[float, int]
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
device = current_platform.device_type
if self.is_neox_style:
half_dim = self.rotary_dim // 2
inv_freq = 1.0 / (
base
** (torch.arange(0, self.rotary_dim, 1, dtype=torch.float32, device=device)
% half_dim * 2 / self.rotary_dim)
)
else:
inv_freq = 1.0 / (
base
** (torch.arange(0, self.rotary_dim, 1, dtype=torch.float32, device=device)
// 2 * 2 / self.rotary_dim
)
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
cache_list: list[torch.Tensor] = []
# offsets to the next cache in a tensor.
# Each offset corresponds to the same index in scaling_factors.
offsets: list[int] = []
device = current_platform.device_type
for scaling_factor in self.scaling_factors:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * scaling_factor
t = torch.arange(max_len, dtype=torch.float, device=device)
t = t / scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
if not cache_list:
offset = 0
else:
last_offset = offsets[-1]
next_max_len = cache_list[-1].shape[0]
offset = last_offset + next_max_len
offsets.append(offset)
cache_list.append(cache)
self._scaling_factor_to_offset = {
float(scaling_factor): offsets[i]
for i, scaling_factor in enumerate(self.scaling_factors)
}
assert len(self.scaling_factors) == len(offsets)
return torch.cat(cache_list, dim=0)
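# Illustrative note (sketch, hypothetical numbers): with scaling_factors = [1.0, 2.0]
# and max_position_embeddings = 2048, the loop above builds one cache of 2048 rows and
# one of 4096 rows, each of width rotary_dim (cos half concatenated with sin half), so
# _scaling_factor_to_offset becomes {1.0: 0, 2.0: 2048} and the returned tensor has
# 6144 rows.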

View File

@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLULlama3RotaryEmbedding(MLURotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
scaling_factor: float,
low_freq_factor: float,
high_freq_factor: float,
orig_max_position: int,
) -> None:
self.scaling_factor = scaling_factor
self.low_freq_factor = low_freq_factor
self.high_freq_factor = high_freq_factor
self.orig_max_position = orig_max_position
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)

View File

@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.model_executor.layers.rotary_embedding.common import yarn_get_mscale
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLUMRotaryEmbedding(MLURotaryEmbedding, MRotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: list[int] | None = None,
mrope_interleaved: bool = False,
# YaRN parameters.
*,
scaling_factor: float | None = None,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
if self.scaling_factor is not None:
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
else:
self.mscale = 1.0
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings by 4x to get
# a larger cos and sin cache.
self.cache_max_position_num = max_position_embeddings * 4
MLURotaryEmbedding.__init__(
self,
head_size,
rotary_dim,
self.cache_max_position_num,
base,
is_neox_style,
dtype,
)
self.mrope_section = mrope_section
self.mrope_interleaved = mrope_interleaved
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
def _apply_mrope(self, positions):
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
num_section = len(self.mrope_section)
mrope_section = self.mrope_section * 2
def _apply(x):
x = torch.cat([
m[i % num_section]
for i, m in enumerate(x.split(mrope_section, dim=-1))
],
dim=-1)
return x
return _apply(cos), _apply(sin)
def _apply_interleaved_mrope(self, positions):
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THWTHW...TT], preserving frequency continuity.
"""
mrope_section = self.mrope_section
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
def _apply(x):
x_t = x[0].clone()
x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
offset = self.rotary_dim // 2
x_t[..., 1 + offset:mrope_section[1] * 3 + offset:3] = x[1, ..., 1 + offset:mrope_section[1] * 3 + offset:3]
x_t[..., 2 + offset:mrope_section[2] * 3 + offset:3] = x[2, ..., 2 + offset:mrope_section[2] * 3 + offset:3]
return x_t
return _apply(cos), _apply(sin)
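# Illustrative note (sketch): the slicing above keeps the temporal (T) channel at
# indices where i % 3 == 0, takes the height (H) channel at i % 3 == 1 up to
# mrope_section[1] * 3, and the width (W) channel at i % 3 == 2 up to
# mrope_section[2] * 3; indices past those limits fall back to T, yielding the
# interleaved [THWTHW...TT] layout. mrope_section = [16, 24, 24] is a hypothetical
# example that sums to rotary_dim // 2 = 64.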
def precompute_sin_cos_cache(
self,
positions: torch.Tensor
):
'''
Call this function before running the decoder layers to
precompute the sin/cos cache for mrope.
'''
if positions.ndim == 1:
return
assert positions.ndim == 2
assert self.mrope_section
if self.mrope_interleaved:
cos, sin = self._apply_interleaved_mrope(positions)
else:
cos, sin = self._apply_mrope(positions)
self.mrope_cos_cache = cos
self.mrope_sin_cache = sin
self.mrope_cu_seq_lens = torch.zeros(2, dtype=torch.int32, device=positions.device)
num_tokens = positions.shape[-1]
self.mrope_cu_seq_lens[1] = num_tokens
def forward_oot(
self,
positions: torch.Tensor,
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
assert positions.ndim == 1 or positions.ndim == 2
if positions.ndim == 1:
return MLURotaryEmbedding.forward_oot(self, positions, x)
assert self.mrope_cos_cache is not None and self.mrope_sin_cache is not None,\
"call precompute_sin_cos_cache first!"
num_tokens = positions.shape[-1]
x = mlu_ops.rotary_embedding(x,
self.mrope_sin_cache,
self.mrope_cos_cache,
None,
self.mrope_cu_seq_lens,
not self.is_neox_style,
False,
False,
num_tokens)
return x
forward = MLURotaryEmbedding.forward

File diff suppressed because it is too large

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple
from tqdm import tqdm
from vllm.config import ModelConfig
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def initialize_dummy_weights_normal_dist(
model: torch.nn.Module,
low: float = -1e-3,
high: float = 1e-3,
std: float = 0.5,
seed: int = 1234,
) -> None:
"""
Initialize the weights of a PyTorch model with values drawn from a normal distribution.
Floating point parameters are initialized with a normal distribution whose mean is randomly
sampled from [low, high] and standard deviation given by std (0.5 by default). Integer parameters are
initialized with random integers in [floor(low), ceil(high)). The initialization is performed
in a batched and efficient way for both floating point and integer parameters.
Optimized version: Uses shared pinned memory based on the largest parameter block size
to minimize H2D transfers, sacrificing global uniqueness for performance.
Args:
model (torch.nn.Module): The model whose weights will be initialized.
low (float): Lower bound for sampling the mean of the normal distribution (for float params).
high (float): Upper bound for sampling the mean of the normal distribution (for float params).
std (float): Standard deviation for the normal distribution (for float params).
seed (int): Random seed for reproducibility.
"""
# Randomly sample the mean for the normal distribution from [low, high]
rng = np.random.RandomState(seed)
mean = float(rng.uniform(low, high, 1).item())
# Create a CPU generator for reproducibility
cpu_gen = torch.Generator(device="cpu")
cpu_gen.manual_seed(seed)
# Collect parameters: separate into floating point and integer types
float_params: List[Tuple[str, torch.Tensor]] = []
int_params: List[Tuple[str, torch.Tensor]] = []
for name, t in tqdm(model.state_dict().items(), desc="Gen dummy weights: Collect params"):
if not isinstance(t, torch.Tensor):
continue
if torch.is_floating_point(t):
float_params.append((name, t))
elif t.dtype in (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64):
int_params.append((name, t))
# -------- Floating point parameters: optimized shared memory initialization --------
if float_params:
# Find the largest parameter block size
max_float_elems = max(p.numel() for _, p in float_params)
# Create shared pinned memory buffer based on largest parameter
shared_float_buffer = torch.empty(max_float_elems, dtype=torch.float32, device="cpu", pin_memory=True)
shared_float_buffer.normal_(mean=mean, std=std, generator=cpu_gen)
# Copy shared buffer to device once
device_buffer = shared_float_buffer.to(next(iter(float_params))[1].device, non_blocking=True)
for _, p in tqdm(float_params, desc="Gen dummy weights: Init float params"):
n = p.numel()
# Extract from device buffer (may reuse same values for different parameters)
view = device_buffer[:n].view(p.shape)
# torch.normal_ does not support dtypes < fp16, so cast via fp16 if needed
if torch.finfo(p.dtype).bits < 16:
tmp = view.to(torch.float16)
tmp = tmp.to(p.dtype)
else:
tmp = view.to(p.dtype)
# Copy from device buffer to parameter (D2D copy, much faster)
p.data.copy_(tmp)
# -------- Integer parameters: optimized shared memory initialization --------
if int_params:
# Find the largest parameter block size
max_int_elems = max(p.numel() for _, p in int_params)
int_low = int(np.floor(low))
int_high = int(np.ceil(high))
if int_high == int_low:
int_high = int_low + 1 # Ensure at least one possible value
# Create shared pinned memory buffer based on largest parameter
shared_int_buffer = torch.randint(
low=int_low,
high=int_high,
size=(max_int_elems,),
dtype=torch.int64,
generator=cpu_gen,
device="cpu",
pin_memory=True
)
# Copy shared buffer to device once
device_int_buffer = shared_int_buffer.to(next(iter(int_params))[1].device, non_blocking=True)
for _, p in tqdm(int_params, desc="Gen dummy weights: Init int params"):
n = p.numel()
# Extract from device buffer (may reuse same values for different parameters)
view = device_int_buffer[:n].view(p.shape)
tmp = view.to(p.dtype)
# Copy from device buffer to parameter (D2D copy, much faster)
p.data.copy_(tmp)
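# Usage sketch (hypothetical, not executed here): only the call pattern matters;
# any torch.nn.Module works the same way.
#
#   toy = torch.nn.Linear(128, 128)
#   initialize_dummy_weights_normal_dist(toy, low=-1e-3, high=1e-3, std=0.5, seed=1234)
#
# Float weights then hold values drawn once into the shared pinned buffer and copied
# per parameter, so different parameters may reuse the same values.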
SMOOTHQUANT_METHOD = "smoothquant"
MULTIMODAL_ARCH_KEYWORDS = {"VL", "Vision", "Multimodal"}
def vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
'''
=============================
Modify by vllm_mlu
=============================
@brief: use torch.normal_ instead of torch.uniform_ for distinguishable logits
std=0.5 is used for more distinguishable logits
'''
# === Default parameter setup (Original values as fallback) ===
low_val = -1e-3
high_val = 1e-3
std_val = 0.5
# === Model and Quantization Check Logic ===
quant_method = getattr(model_config, "quantization", None)
# Attempt to get the architectures list from model_config
archs = getattr(model_config, "architectures", []) or []
# Determine if the model is multimodal (based on architecture names)
is_multimodal = any(
keyword in arch
for arch in archs
for keyword in MULTIMODAL_ARCH_KEYWORDS
)
# === Apply SmoothQuant + Multimodal Parameters ===
if is_multimodal and quant_method == SMOOTHQUANT_METHOD:
# (smoothquant) + Multimodal specific values to mitigate NaN overflow
std_val = 1e-4
initialize_dummy_weights_normal_dist(
model,
low=low_val,
high=high_val,
std=std_val
)
# add a sync to make sure the weights are initialized
torch.mlu.synchronize()
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(
DummyModelLoader,
DummyModelLoader.load_weights,
vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights
)

View File

@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import os
import tempfile
import time
import torch
from torch import nn
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, TensorDeserializer, TensorizerArgs,
_check_tensors_on_meta_device, _resize_lora_embeddings,
is_valid_deserialization_uri)
from vllm.platforms import current_platform
from vllm.logger import init_logger
try:
from tensorizer.stream_io import open_stream
from tensorizer.utils import (convert_bytes, get_mem_usage,
no_init_or_tensor)
except ImportError:
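# NOTE (assumption): `tensorizer` is expected to already be bound to a placeholder
# module here (as in upstream vLLM) when the import fails, so the placeholder_attr
# calls below only raise once the attributes are actually used; otherwise this
# branch would raise NameError.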
open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
logger = init_logger(__name__)
def deserialize_tensorizer_model(model: nn.Module,
tensorizer_config: TensorizerConfig) -> None:
tensorizer_args = tensorizer_config._construct_tensorizer_args()
if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
raise ValueError(
f"{tensorizer_config.tensorizer_uri} is not a valid "
f"tensorizer URI. Please check that the URI is correct. "
f"It must either point to a local existing file, or have a "
f"S3, HTTP or HTTPS scheme.")
before_mem = get_mem_usage()
start = time.perf_counter()
'''
=============================
Modify by vllm_mlu
=============================
@brief: use mlu device
'''
device = ''
if current_platform.is_out_of_tree():
device = f'mlu:{torch.mlu.current_device()}'
elif current_platform.is_xpu():
device = f'xpu:{torch.xpu.current_device()}'
else:
device = f'cuda:{torch.cuda.current_device()}'
with open_stream(
tensorizer_config.tensorizer_uri,
mode="rb",
**tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
stream,
dtype=tensorizer_config.dtype,
device=device,
**tensorizer_args.deserialization_kwargs) as deserializer:
deserializer.load_into_module(model)
end = time.perf_counter()
'''
==================
End of MLU Hijack
==================
'''
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
deserializer.close()
logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
end - start, per_second)
logger.info("Memory usage before: %s", before_mem)
logger.info("Memory usage after: %s", after_mem)
_check_tensors_on_meta_device(model)
_resize_lora_embeddings(model)
del model.vllm_tensorized_marker
def serialize_extra_artifacts(
tensorizer_args: TensorizerArgs,
served_model_name: Union[str, list[str], None]) -> None:
if not isinstance(served_model_name, str):
raise ValueError(
f"served_model_name must be a str for serialize_extra_artifacts, "
f"not {type(served_model_name)}.")
'''
=============================
Modify by vllm_mlu
=============================
@brief: use local file
'''
import shutil
from pathlib import Path
local_model_path = Path(served_model_name)
if not local_model_path.exists() or not local_model_path.is_dir():
raise ValueError(
f"served_model_name must be a valid local directory in offline mode, "
f"but got: {served_model_name}"
)
'''
==================
End of MLU Hijack
==================
'''
with tempfile.TemporaryDirectory() as tmpdir:
'''
=============================
Modify by vllm_mlu
=============================
@brief: copy local file
'''
logger.info("Copying local model from %s to temporary directory %s",
local_model_path, tmpdir)
shutil.copytree(local_model_path, tmpdir, dirs_exist_ok=True)
'''
==================
End of MLU Hijack
==================
'''
for artifact in os.scandir(tmpdir):
if not artifact.is_file():
continue
with open(artifact.path, "rb") as f, open_stream(
f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
mode="wb+",
**tensorizer_args.stream_kwargs) as stream:
logger.info("Writing artifact %s", artifact.name)
stream.write(f.read())

View File

@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from torch import nn
from vllm.config import ModelConfig
from vllm.model_executor.model_loader.tensorizer import is_vllm_tensorized
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
from vllm_mlu.model_executor.model_loader.tensorizer import deserialize_tensorizer_model
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__model_executor__model_loader__tensorizer_loader__TensorizerLoader__load_weights(
self,
model: nn.Module,
model_config: ModelConfig
) -> None:
"""Load serialized model weights with tensorizer.
Expects a vLLM-tensorized model. See the
examples/others/tensorize_vllm_model.py example script
for serializing vLLM models."""
if is_vllm_tensorized(self.tensorizer_config):
tensorizer_config = self._patch_tensorizer_config(model_config)
deserialize_tensorizer_model(model, tensorizer_config)
else:
model.load_weights(self._get_weights_iterator())
MluHijackObject.apply_hijack(
TensorizerLoader,
TensorizerLoader.load_weights,
vllm__model_executor__model_loader__tensorizer_loader__TensorizerLoader__load_weights
)

View File

@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm import ModelRegistry
def register_model():
from .deepseek_v4 import MLUDeepseekV4ForCausalLM # noqa: F401
ModelRegistry.register_model(
"DeepseekV4ForCausalLM",
"vllm_mlu.model_executor.models.deepseek_v4:MLUDeepseekV4ForCausalLM")

View File

@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from math import lcm
from typing import TYPE_CHECKING
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.models.config import (HybridAttentionMambaModelConfig,
MambaModelConfig)
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
@classmethod
def vllm__module_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config(
cls,
vllm_config: "VllmConfig"
) -> None:
"""
Ensure that page size of attention layers is greater than or
equal to the mamba layers. If not, automatically set the attention
block size to ensure that it is. If the attention page size is
strictly greater than the mamba page size, we pad the mamba page size
to make them equal.
Args:
vllm_config: vLLM Config
"""
# Save the user input before it gets modified by MambaModelConfig
mamba_block_size = vllm_config.cache_config.mamba_block_size
# Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config)
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
if cache_config.cache_dtype == "auto":
kv_cache_dtype = model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# get attention page size (for 1 token)
# Attention backend constraints:
# - FlashAttention (FA) requires block size to be multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: kernel_block_size 128 alignment
# * Other MLA backends: kernel_block_size 64 alignment
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
attn_page_size_1_token = MLAAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
else:
kernel_block_alignment_size = 16
if (
current_platform.is_device_capability(100)
and model_config.get_head_size() == 256
and (
envs.VLLM_ATTENTION_BACKEND is None
or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER"
)
):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
# head size 256 with block size 16 is not supported on Blackwell.
kernel_block_alignment_size = 32
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
model_config=model_config,
)
# get mamba page size
mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
block_size=model_config.max_model_len,
).page_size_bytes
# The model may be marked as hybrid while its mamba layers are
# skipped via config; in that case, return directly.
if mamba_page_size == 0:
return
if cache_config.enable_prefix_caching:
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
# Mamba2 SSD kernel uses a chunk_size, e.g. 256
# Align the block to the kernel: use lowest multiple of chunk_size
# of attention tokens that would fit mamba_page_size:
# e.g. for mamba page size = 788kB
# attn_1_token = 2kB -> fits ~394 tokens
# then round up to a multiple of 256 -> 512 tokens
# End result:
# attn_block_size = 512
# mamba_block_size = 512 (aligned to a multiple of chunk_size)
# TODO(tdoublep): this constraint can be relaxed fairly
# easily by changing the way we layout chunks in the
# mamba2 kernels.
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
else:
# Without prefix caching, select minimum valid attention block size
# to minimize mamba state padding
# Calculate minimum attention block size that satisfies both:
# 1. Backend alignment requirements (kernel_block_alignment_size)
# 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
attn_block_size = kernel_block_alignment_size * cdiv(
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: support qwen3-next
'''
if (vllm_config.mlu_config.enable_mamba_split_page_size):
vllm_config.mlu_config.mamba_to_attn_block_ratio = cdiv(attn_block_size, cache_config.block_size)
cache_config.mamba_page_size_padded = cache_config.block_size * attn_page_size_1_token
return
'''
==================
End of MLU Hijack
==================
'''
# override attention block size if either (a) the
# user has not set it or (b) the user has set it
# too small.
if cache_config.block_size is None or cache_config.block_size < attn_block_size:
cache_config.block_size = attn_block_size
logger.info(
"Setting attention block size to %d tokens "
"to ensure that attention page size is >= mamba page size.",
attn_block_size,
)
# compute new attention page size
attn_page_size = cache_config.block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size
if attn_page_size == mamba_page_size:
# don't need to pad mamba page size
return
# pad mamba page size to exactly match attention
if (
cache_config.mamba_page_size_padded is None
or cache_config.mamba_page_size_padded != attn_page_size
):
cache_config.mamba_page_size_padded = attn_page_size
mamba_padding_pct = (
100 * (attn_page_size - mamba_page_size) / mamba_page_size
)
logger.info(
"Padding mamba page size by %.2f%% to ensure "
"that mamba page size and attention page size are "
"exactly equal.",
mamba_padding_pct,
)
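# Worked example (hypothetical sizes, reusing the prefix-caching comment above): with
# attn_page_size_1_token = 2kB and kernel_block_alignment_size = 16, a mamba page of
# 788kB gives attn_block_size = 16 * cdiv(788, 32) = 400, hence attn_page_size = 800kB
# and mamba_page_size_padded = 800kB, i.e. roughly 1.5% padding.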
MluHijackObject.apply_hijack(HybridAttentionMambaModelConfig,
HybridAttentionMambaModelConfig.verify_and_update_config,
vllm__module_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config)

File diff suppressed because it is too large

View File

@@ -0,0 +1,607 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import (
Any, List, Tuple, Optional, Dict, Union, ClassVar, Literal,
Protocol, overload, runtime_checkable)
from typing_extensions import TypeIs
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.config import VllmConfig
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_gather_into_list,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter,
)
from vllm.distributed import (
get_tp_group,
get_pp_group,
get_dp_group,
get_data_parallel_group_rank,
get_data_parallel_group_world_size,
get_dense_mlp_tp_world_size,
get_tp_world_world_size,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_logits_tp_world_size,
get_parallel_rank_with_group,
get_tp_world_group,
get_tp_world_rank,
GroupCoordinator,
)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
logger = init_logger(__name__)
# alias after refactor
DataParallelRuntimeParams = MLUDPMetadata
def enable_data_parallel():
return get_dp_group().world_size > 1
def enable_emb_logits_custom_parallel():
return get_logits_tp_world_size() != get_tensor_model_parallel_world_size()
def enable_dense_mlp_custom_parallel():
return get_dense_mlp_tp_world_size() != get_tp_world_world_size()
def get_runtime_infos_per_dp_group(
num_tokens: int, num_requests: int, all_prefill: bool, seq_lens: List[int],
device: torch.device, vllm_config: VllmConfig) -> Tuple[List[int], List[bool]]:
dp_tensor = torch.tensor([num_tokens, num_requests, int(all_prefill)]).to(device, non_blocking=True)
outputs = tensor_model_parallel_all_gather_into_list(dp_tensor, get_dp_group())
outputs = torch.cat(outputs).tolist() # d2h
dp_world_size = get_data_parallel_group_world_size()
dp_is_prefill, dp_query_lens, dp_group_bs, seq_len_per_batch = [], [], [], []
for i in range(0, 3 * dp_world_size, 3):
dp_query_lens.append(outputs[i])
dp_group_bs.append(outputs[i + 1])
dp_is_prefill.append(bool(outputs[i + 2]))
# Only run communication if mcc is enabled and is prefill.
if vllm_config.mlu_config.is_dpsk_mcc_enabled and all(dp_is_prefill):
assert len(seq_lens) == num_requests
seq_len_per_batch = [torch.empty([bs], dtype=dp_tensor.dtype, device=device) for bs in dp_group_bs]
seq_lens_tensor = torch.tensor(seq_lens, dtype=dp_tensor.dtype, device=device)
torch.distributed.all_gather(seq_len_per_batch, seq_lens_tensor, group=get_dp_group().device_group)
seq_len_per_batch=torch.cat(seq_len_per_batch).tolist()
else:
seq_len_per_batch = [0] * sum(dp_group_bs)
return dp_query_lens, dp_group_bs, dp_is_prefill, seq_len_per_batch
def get_deepseek_layer_split_list(
dp_query_lens: List[int], dp_group_bs: List[int]
) -> Tuple[Optional[List[int]], Optional[List[int]], Optional[List[int]]]:
if len(dp_query_lens) != len(dp_group_bs) or len(dp_query_lens) != get_data_parallel_group_world_size():
logger.warning(f"dp_query_lens length: {len(dp_query_lens)} != dp_group_bs length: {len(dp_group_bs)}, "
f"disable deepseek layer split")
return None, None, None
emb_query_lens, logits_batch_sizes, dense_attn_token_split_list = None, None, None
all_dp_query_lens, all_dp_group_bs = [], []
for i in range(len(dp_query_lens)):
all_dp_query_lens.extend([dp_query_lens[i]] * get_tensor_model_parallel_world_size())
all_dp_group_bs.extend([dp_group_bs[i]] * get_tensor_model_parallel_world_size())
if get_logits_tp_world_size() != get_tensor_model_parallel_world_size():
slice_start = get_tp_world_rank() // get_logits_tp_world_size() * get_logits_tp_world_size()
slice_end = slice_start + get_logits_tp_world_size()
emb_query_lens = all_dp_query_lens[slice_start:slice_end]
logits_batch_sizes = all_dp_group_bs[slice_start:slice_end]
if get_dense_mlp_tp_world_size() != get_tp_world_world_size():
slice_start = get_tp_world_rank() // get_dense_mlp_tp_world_size() * get_dense_mlp_tp_world_size()
slice_end = slice_start + get_dense_mlp_tp_world_size()
dense_attn_token_split_list = all_dp_query_lens[slice_start:slice_end]
return emb_query_lens, logits_batch_sizes, dense_attn_token_split_list
def get_dp_metadata(
num_tokens: int,
data_parallel_size: int,
data_parallel_rank: int,
tensor_parallel_size: int,
prefill_dispatch_use_RS_AG: bool,
) -> DataParallelRuntimeParams:
"""
Get dp params during dummy runs or model graph capture. These two cases do not carry
dp_params into the forward call, because we do not want to hijack too much.
"""
dp_query_lens = [num_tokens] * data_parallel_size
in_prefill = get_forward_context().attn_metadata is None # dummy run
dp_is_prefill = [in_prefill] * data_parallel_size
emb_query_lens, logits_batch_sizes, dense_attn_token_split_list = None, None, None
if get_logits_tp_world_size() != get_tensor_model_parallel_world_size():
emb_query_lens = [num_tokens] * get_logits_tp_world_size()
logits_batch_sizes = None # dummy run and capture model does not contain logits
if get_dense_mlp_tp_world_size() != get_tp_world_world_size():
dense_attn_token_split_list = [num_tokens] * get_dense_mlp_tp_world_size()
return MLUDPMetadata.make_oot(data_parallel_rank,
data_parallel_size,
tensor_parallel_size,
dp_query_lens,
dp_is_prefill,
prefill_dispatch_use_RS_AG,
emb_query_lens=emb_query_lens,
logits_batch_sizes=logits_batch_sizes,
dense_attn_token_split_list=dense_attn_token_split_list)
def remove_paddings_after_all_gather(
hidden_states: torch.Tensor,
padding_to_token_num: int,
token_num_list: List[int],
) -> torch.Tensor:
dp_group_tensors = []
offset = 0
for token_num in token_num_list:
if token_num != 0:
dp_group_tensors.append(hidden_states[offset:offset+token_num])
offset += padding_to_token_num
if len(dp_group_tensors) == 1:
hidden_states = dp_group_tensors[0]
else:
hidden_states = torch.cat(dp_group_tensors)
return hidden_states
def tensor_model_parallel_all_gather_dp(
group_num_tokens: List[int],
rank: int,
hidden_states: Optional[torch.Tensor],
group: GroupCoordinator,
hidden_size: int = None,
dtype: torch.dtype = None,
device: torch.device = None) -> torch.Tensor:
"""
All gather in the group.
Input is a 2-D tensor, and can have different shape in the first dim,
for example, [4, 7, 5, 8], [2, 5, 4, 0].
"""
num_tokens_equal = all(x == group_num_tokens[0] for x in group_num_tokens)
if num_tokens_equal:
hidden_states = tensor_model_parallel_all_gather(
input_=hidden_states, dim=0, tp_group=group)
else:
max_num_tokens = max(group_num_tokens)
num_padding = max_num_tokens - group_num_tokens[rank]
if num_padding > 0:
if hidden_states is None:
hidden_states = torch.empty((max_num_tokens, hidden_size),
dtype=dtype, device=device)
else:
hidden_states = F.pad(hidden_states, (0, 0, 0, num_padding))
hidden_states = tensor_model_parallel_all_gather(
input_=hidden_states, dim=0, tp_group=group)
hidden_states = remove_paddings_after_all_gather(
hidden_states, max_num_tokens, group_num_tokens)
return hidden_states
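# Illustrative note (sketch, hypothetical sizes): with group_num_tokens = [4, 7, 5, 8]
# every rank pads its hidden_states to 8 rows before the all-gather, the gathered
# tensor has 32 rows, and remove_paddings_after_all_gather keeps rows [0:4], [8:15],
# [16:21] and [24:32], i.e. 24 valid rows in total.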
def tensor_model_parallel_all_gather_op_v2(
input_: torch.Tensor,
dim_size_list: List[int],
group_coordinator: GroupCoordinator,
non_leading_dim_size: int,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
"""
All gather the input tensor across model parallel group with only communication ops.
Note: compared to `tensor_model_parallel_all_gather_dp`, this method supports different
sizes in the first dim, and does not involve padding operation.
"""
all_size_equal = all([dim_size == dim_size_list[0] for dim_size in dim_size_list])
output_shape = (sum(dim_size_list), non_leading_dim_size)
output = torch.empty(output_shape, device=device, dtype=dtype)
if input_ is None:
input_ = torch.empty((0, non_leading_dim_size), device=device, dtype=dtype)
if all_size_equal:
torch.distributed.all_gather_into_tensor(
output, input_, group=group_coordinator.device_group)
else:
# Note: torch.split splits the tensor into chunks. And each chunk
# is a view of the original tensor.
tensor_list = torch.split(output, dim_size_list, dim=0)
torch.distributed.all_gather(
list(tensor_list), input_, group=group_coordinator.device_group)
return output
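# Illustrative note (sketch, hypothetical sizes): with dim_size_list = [2, 5, 4, 0]
# the output is preallocated with 11 rows and torch.split yields four views of
# 2, 5, 4 and 0 rows; torch.distributed.all_gather writes each rank's tensor straight
# into its view, so no padding is needed, and a rank with 0 tokens contributes an
# empty (0, non_leading_dim_size) tensor.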
def process_post_attention_communication(
hidden_states: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
tp_group: Any = None,
):
"""
Processes distributed communication operations after attention computation.
This function performs necessary communication operations after attention computation
to ensure data synchronization across different parallel groups.
Supports two modes:
1. Tensor parallel mode: Uses tp_group for all-reduce and all-gather operations
2. Data parallel mode: Uses reduce-scatter and all-gather for global synchronization
Args:
hidden_states: Hidden states tensor after attention computation, can be None
dp_params: Data parallel runtime parameters containing token distribution and padding info
hidden_size: Dimension size of hidden states
dtype: Data type of the tensor
device: Device where the tensor is located
tp_group: Tensor parallel group, if None uses data parallel mode
Returns:
Hidden states tensor after communication synchronization processing
Note:
- When prefill_pad_to_token_num != -1, padding and unpadding operations will be performed
- Function selects optimal communication path based on token count and parallel strategy
"""
if tp_group is not None:
if dp_params.token_num != 0:
hidden_states = tensor_model_parallel_all_reduce(
hidden_states)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.dense_attn_token_split_list,
rank=get_parallel_rank_with_group(tp_group),
hidden_states=hidden_states,
group=tp_group,
)
else:
if dp_params.prefill_pad_to_token_num != -1:
# pad hidden_states to use reduce_scatter and global all gather
pad_num = dp_params.prefill_pad_to_token_num - dp_params.token_num
if pad_num != 0:
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_num))
hidden_states = tensor_model_parallel_reduce_scatter(
hidden_states, dim=0)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.attn_token_split_list_reduce_scatter,
rank=get_tp_world_rank(),
hidden_states=hidden_states,
group=get_tp_world_group(),
)
# get origin hidden_states for moe compute
hidden_states = remove_paddings_after_all_gather(
hidden_states, dp_params.prefill_pad_to_token_num,
dp_params.token_split_list)
else:
hidden_states = tensor_model_parallel_all_reduce(
hidden_states)
all_gather_group = get_dp_group()
all_gather_rank = get_data_parallel_group_rank()
hidden_states = tensor_model_parallel_all_gather_dp(
dp_params.token_split_list, all_gather_rank, hidden_states,
all_gather_group, hidden_size, dtype, device)
return hidden_states
def dp_model_forward(
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
embedding_layer: nn.Module,
model_norm_layer: nn.Module,
start_layer: int,
end_layer: int,
layers: List[nn.Module],
layer_input_norm_name: str,
prefill_dispatch_use_RS_AG: bool,
streams: Optional[Dict[str, torch.mlu.Stream]] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
"""run model with dp."""
if dp_params is None:
dp_params = get_dp_metadata(positions.numel(),
get_data_parallel_group_world_size(),
get_data_parallel_group_rank(),
get_tensor_model_parallel_world_size(),
prefill_dispatch_use_RS_AG)
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
if embedding_layer.__class__.__name__ == "DPVocabParallelEmbedding":
hidden_states = embedding_layer(input_ids, dp_params=dp_params)
else:
hidden_states = embedding_layer(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(start_layer, end_layer):
is_first_layer = (i == start_layer)
is_last_layer = (i == end_layer - 1)
next_input_layernorm = None
if not is_last_layer:
next_input_layernorm = getattr(layers[i+1], layer_input_norm_name)
hidden_states, residual = layers[i](
positions=positions,
hidden_states=hidden_states,
residual=residual,
dp_params=dp_params,
is_first_layer=is_first_layer,
is_last_layer=is_last_layer,
streams=streams,
next_input_layernorm=next_input_layernorm,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = model_norm_layer(hidden_states)
return hidden_states
def dp_layer_forward(
input_norm: nn.Module,
self_attn: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
mlp_kwargs: List[Dict[str, Any]],
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
hidden_size: int,
hidden_states_dtype: torch.dtype,
is_first_layer: bool = False,
is_last_layer: bool = False,
next_input_layernorm: Optional[nn.Module] = None,
enable_all2all: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Run a layer with dp, dispatching to all2all, rs+ag, or the common path.
Because the all2all forward args often differ from the common mlp args,
mlp_kwargs[-1] is always the all2all kwargs. For example:
Deepseek with all2all enabled: mlp_kwargs is [{mlp common forward kwargs}, {mlp all2all kwargs}].
Deepseek without all2all: mlp_kwargs is [{mlp common forward kwargs}].
"""
if dp_params.layer_use_reduce_scatter:
common_metadata = get_common_metadata()
is_decode_only = common_metadata is not None and common_metadata.is_decode_only
use_all2all = enable_all2all and is_decode_only and isinstance(mlp, SparseMoeMlp)
forward_func = _dp_forward_layer_all2all if use_all2all else _dp_forward_layer_rs_ag
hidden_states, residual = forward_func(input_norm,
self_attn,
post_norm,
mlp,
mlp_kwargs,
positions,
hidden_states,
residual,
dp_params,
is_first_layer,
is_last_layer,
next_input_layernorm)
else:
hidden_states, residual = _dp_forward_layer_common(input_norm,
self_attn,
post_norm,
mlp,
mlp_kwargs,
positions,
hidden_states,
residual,
dp_params,
hidden_size,
hidden_states_dtype)
return hidden_states, residual
def _dp_forward_layer_rs_ag(
input_norm: nn.Module,
self_attn: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
mlp_kwargs: List[Dict[str, Any]],
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
is_first_layer: bool,
is_last_layer: bool,
next_input_layernorm: List[Optional[nn.Module]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""run layer with rs+ag."""
if residual is None:
residual = hidden_states
# We move the input_layernorm of i+1 layer to the end of i layer.
# But for the first layer, we need to do input_layernorm first.
if is_first_layer:
hidden_states = input_norm(hidden_states)
# Self Attention
hidden_states = self_attn(
positions=positions,
hidden_states=hidden_states,
)
# add residual here for the first layer
if is_first_layer and get_tensor_model_parallel_rank() == 0:
hidden_states = hidden_states + residual
hidden_states = tensor_model_parallel_reduce_scatter(
hidden_states, dim=0)
# move norm between rs and ag
if is_first_layer:
residual = hidden_states
hidden_states = post_norm(hidden_states)
else:
hidden_states, residual = post_norm(hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.attn_token_split_list_reduce_scatter,
rank=get_tp_world_rank(),
hidden_states=hidden_states,
group=get_tp_world_group(),
)
# mlp, use all cards
hidden_states = mlp(hidden_states, **mlp_kwargs[0])
hidden_states = tensor_model_parallel_reduce_scatter(
hidden_states, dim=0, tp_group=get_tp_world_group())
if is_last_layer:
hidden_states = hidden_states + residual
residual = None
else:
# To reduce layernorm computation, we move the layernorm of i+1 layer to
# the end of i layer. Besides, we fuse residual addition into layernorm.
assert next_input_layernorm is not None
hidden_states, residual = next_input_layernorm(hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.moe_token_split_list_reduce_scatter,
rank=get_tensor_model_parallel_rank(),
hidden_states=hidden_states,
group=get_tp_group(),
)
return hidden_states, residual
def _dp_forward_layer_all2all(
input_norm: nn.Module,
self_attn: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
mlp_kwargs: List[Dict[str, Any]],
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
is_first_layer: bool,
is_last_layer: bool,
next_input_layernorm: List[Optional[nn.Module]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""run layer with all2all."""
if residual is None:
residual = hidden_states
# We move the input_layernorm of i+1 layer to the end of i layer.
# But for the first layer, we need to do input_layernorm first.
if is_first_layer:
hidden_states = input_norm(hidden_states)
# Self Attention
hidden_states = self_attn(
positions=positions,
hidden_states=hidden_states,
)
# add residual here for the first layer
if is_first_layer and get_tensor_model_parallel_rank() == 0:
hidden_states = hidden_states + residual
hidden_states = tensor_model_parallel_reduce_scatter(
hidden_states, dim=0)
# move norm between rs and ag
if is_first_layer:
residual = hidden_states
hidden_states = post_norm(hidden_states)
else:
# add residual in norm for other layers
hidden_states, residual = post_norm(hidden_states, residual)
hidden_states = mlp.forward_all2all(hidden_states, **mlp_kwargs[-1])
if is_last_layer:
hidden_states = hidden_states + residual
residual = None
else:
# To reduce layernorm computation, we move the layernorm of i+1 layer to
# the end of i layer. Besides, we fuse residual addition into layernorm.
assert next_input_layernorm is not None
hidden_states, residual = next_input_layernorm(hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.moe_token_split_list_reduce_scatter,
rank=get_tensor_model_parallel_rank(),
hidden_states=hidden_states,
group=get_tp_group(),
)
return hidden_states, residual
def _dp_forward_layer_common(
input_norm: nn.Module,
self_attn: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
mlp_kwargs: List[Dict[str, Any]],
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
hidden_size: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""run layer with common."""
if residual is None:
residual = hidden_states
hidden_states = input_norm(hidden_states)
hidden_states = self_attn(
positions=positions,
hidden_states=hidden_states,
)
# add residual here
if get_tensor_model_parallel_rank() == 0:
hidden_states = hidden_states + residual
hidden_states = process_post_attention_communication(
hidden_states, dp_params, hidden_size, dtype, positions.device, None
)
residual = hidden_states[dp_params.token_num_offset:
dp_params.token_num_offset + dp_params.token_num]
hidden_states = post_norm(hidden_states)
hidden_states = mlp(hidden_states, **mlp_kwargs[0])
hidden_states = tensor_model_parallel_all_reduce(
hidden_states, tp_group=get_tp_world_group())
# add residual here
hidden_states = hidden_states[dp_params.token_num_offset:
dp_params.token_num_offset+dp_params.token_num]
hidden_states = hidden_states + residual
residual = hidden_states
return hidden_states, residual

View File

@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from typing import Callable, Optional, List, Union, Tuple
from vllm_mlu import _mlu_ops as mlu_ops
from vllm.attention import AttentionMetadata
from vllm.sequence import IntermediateTensors
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from transformers import PretrainedConfig
def hunyuan_decoder_layer_forward_base(
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_layernorm: Callable,
self_attn: Callable,
post_layernorm: Callable,
mlp: Callable,
kv_states: Optional[Tuple[torch.Tensor]] = None,
apply_residual_connection_post_layernorm: bool = False,
position_name: str = 'positions',
input_norm_fuse_en: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
smooth_quant_scale = None
if input_norm_fuse_en:
layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
else:
layernorm_output = input_layernorm(hidden_states)
smooth_quant_scale = None
if apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self Attention
attention_output, ori_kv_states = self_attn(
**{position_name: positions},
hidden_states=layernorm_output,
residual=residual,
kv_states=kv_states,
smooth_quant_scale=smooth_quant_scale,
)
layernorm_output = post_layernorm(attention_output)
if apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# Fully Connected
hidden_states = mlp(layernorm_output, residual)
return hidden_states, ori_kv_states
def decoder_layer_forward_base(
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_layernorm: Callable,
self_attn: Callable,
post_layernorm: Callable,
mlp: Callable,
apply_residual_connection_post_layernorm: bool = False,
position_name: str = 'positions',
input_norm_fuse_en: bool = False,
post_norm_fuse_en: bool = False,
) -> torch.Tensor:
if input_norm_fuse_en:
layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
else:
layernorm_output = input_layernorm(hidden_states)
smooth_quant_scale = None
if apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self Attention
attention_output = self_attn(
**{position_name: positions},
hidden_states=layernorm_output,
residual=residual,
smooth_quant_scale=smooth_quant_scale,
)
if post_norm_fuse_en:
layernorm_output, smooth_quant_scale = post_layernorm(attention_output)
else:
layernorm_output = post_layernorm(attention_output)
smooth_quant_scale = None
if apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# Fully Connected
kwargs = dict()
if post_norm_fuse_en:
kwargs['smooth_quant_scale'] = smooth_quant_scale
hidden_states = mlp(layernorm_output, residual, **kwargs)
return hidden_states
def decoder_model_forward_base(
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
layers: torch.nn.ModuleList,
embed_input_ids: Callable,
norm: Callable
) -> torch.Tensor:
hidden_states = embed_input_ids(input_ids)
for i in range(len(layers)):
layer = layers[i]
hidden_states = layer(
positions,
hidden_states,
)
hidden_states = norm(hidden_states)
return hidden_states
def hunyuan_decoder_model_forward_base_pp(
config: PretrainedConfig,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
layers: torch.nn.ModuleList,
start_layer: int,
end_layer: int,
embed_input_ids: Callable,
norm: Callable,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
cla_factor = getattr(config, "cla_share_factor", 1)
prev_kv_states = None
for i in range(start_layer, end_layer):
layer = layers[i]
hidden_states, kv_states = layer(
positions,
hidden_states,
prev_kv_states,
)
if (i - start_layer) % cla_factor == 0:
prev_kv_states = kv_states
else:
prev_kv_states = None
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
})
hidden_states = norm(hidden_states)
return hidden_states
def decoder_model_forward_base_pp(
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
layers: torch.nn.ModuleList,
start_layer: int,
end_layer: int,
embed_input_ids: Callable,
norm: Callable,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(start_layer, end_layer):
layer = layers[i]
hidden_states = layer(
positions,
hidden_states,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
})
hidden_states = norm(hidden_states)
return hidden_states
def is_smoothquant(quant_config: QuantizationConfig) -> bool:
return (quant_config is not None
and quant_config.get_name() == "SmoothQuant")
def is_per_token_smoothquant(quant_config: QuantizationConfig) -> bool:
return (is_smoothquant(quant_config)
and quant_config.input_quant_method == "per_token")
def compute_in_loop(func: Callable,
input: torch.Tensor,
chunk_size: int,
feature_size: Optional[int] = None,
**kwargs):
"""
Divides the input into chunks along the leading dimension (dimension 0) and
computes the chunks in a loop instead of all at once.
Args:
feature_size: size of the output feature dimension. Provide it when
the output's feature dimension differs from the input's feature dimension.
"""
total = input.shape[0]
# directly compute if there is only one chunk
if chunk_size >= total:
return func(input, **kwargs)
feature_size = feature_size or input.shape[1]
output = input.new_empty(total, feature_size)
num_chunks = (total + chunk_size - 1) // chunk_size
for i in range(num_chunks):
start = i * chunk_size
end = min((i + 1) * chunk_size, total)
output[start : end] = func(input[start : end], **kwargs)
return output
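# Usage sketch (hypothetical): with total = 10 rows and chunk_size = 4 the loop runs
# num_chunks = 3 times over slices [0:4], [4:8] and [8:10]; with chunk_size >= 10 the
# whole input is handed to func in a single call, e.g.
#   out = compute_in_loop(linear_layer, activations, chunk_size=4096)
# where linear_layer and activations are placeholder names.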

View File

@@ -0,0 +1,507 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.models.dp_utils import DataParallelRuntimeParams
from vllm_mlu.v1.attention.backends.mla.flashmla import (
FlashMLAPrefillMetadata, FlashMLAMetadata, MLACommonMetadata
)
from vllm_mlu.v1.attention.backends.utils import (
COMMON_METADATA_STR,
MLUCommonAttentionMetadata,
)
SEQUENCE_DIM_PARITION_THRESHOLD = 1024
def get_common_and_layer_metadata(
attn_metadata: Optional[dict],
) -> Tuple[Optional[MLUCommonAttentionMetadata], Optional[AttentionMetadata]]:
"""
Returns the common metadata and layer metadata from the given attention metadata.
"""
if attn_metadata is None:
return None, None
if isinstance(attn_metadata, dict):
assert COMMON_METADATA_STR in attn_metadata, (
f"attn_metadata must contain {COMMON_METADATA_STR} key"
)
assert len({id(v) for v in attn_metadata.values()}) == 2, (
f"attn_metadata should be a dict with two values, one for {COMMON_METADATA_STR} and "
f"the other for layers."
)
common_metadata = attn_metadata[COMMON_METADATA_STR]
layer_metadata = next((v for k, v in attn_metadata.items() if k != COMMON_METADATA_STR), None)
return common_metadata, layer_metadata
return None, attn_metadata
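# Illustrative note (sketch): in the dict form this helper expects something like
#   {COMMON_METADATA_STR: common_md,
#    "model.layers.0.self_attn.attn": layer_md,
#    "model.layers.1.self_attn.attn": layer_md}
# where the layer names are hypothetical and every layer entry is the same object,
# which is what the id()-based assertion above checks.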
def should_skip_partition(layer_metadata, common_metadata) -> bool:
"""Helper function to simplify partition condition check"""
is_layer_metadata_invalid = (layer_metadata is None
or layer_metadata.prefill is None
or layer_metadata.query_start_loc is None
or layer_metadata.query_start_loc.numel() == 0)
is_common_metadata_invalid = common_metadata is None or not common_metadata.is_prefill_only
return is_layer_metadata_invalid or is_common_metadata_invalid
def attn_mcc_plan(
attn_metadata: Any,
dp_params: DataParallelRuntimeParams,
parts_to_split: int,
) -> Tuple[int, int]:
"""
Returns the number of parts for batch size dimension and the number of parts for sequence length dimension.
"""
# In the procedure of a dummy run, attn_metadata is an instance of MLACommonMetadata
if not isinstance(attn_metadata, (dict, MLACommonMetadata, type(None))):
raise TypeError(f"attn_metadata must be dict or MLACommonMetadata, got {type(attn_metadata)}")
if isinstance(attn_metadata, dict):
common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
else:
common_metadata, layer_metadata = None, attn_metadata
if dp_params is None:
# We don't support mcc with decode yet.
if should_skip_partition(layer_metadata, common_metadata):
return 1, 1
# Splitting along the batch size dimension takes priority over the sequence length dimension,
# and without dp we ensure that each subtask is non-empty.
num_prefills = layer_metadata.query_start_loc.numel() - 1
if num_prefills > 1:
return min(parts_to_split, num_prefills), 1
try:
max_query_len = torch.diff(layer_metadata.query_start_loc).max().item()
except RuntimeError:
return 1, 1
if max_query_len < SEQUENCE_DIM_PARITION_THRESHOLD:
return 1, 1
return 1, min(parts_to_split, max_query_len)
else:
if not all(is_prefill for is_prefill in dp_params.dp_is_prefill):
return 1, 1
max_bs = max(dp_params.batch_sizes)
if max_bs > 1:
# Ensure parts_to_split does not exceed max_bs to avoid unnecessary splits
if max(dp_params.token_split_list) < SEQUENCE_DIM_PARITION_THRESHOLD:
return 1, 1
return min(parts_to_split, max_bs), 1
else:
if max(dp_params.token_split_list) < SEQUENCE_DIM_PARITION_THRESHOLD:
return 1, 1
return 1, parts_to_split
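# Decision examples (sketch, no dp, hypothetical metadata): three prefill requests with
# parts_to_split = 8 split on the batch dimension -> (3, 1); a single request of 4096
# tokens splits on the sequence dimension -> (1, 8); a single request of 512 tokens is
# below SEQUENCE_DIM_PARITION_THRESHOLD -> (1, 1).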
def get_data_num_and_offset(total_size, parts_to_split):
"""
Get the data size and start offset for each part.
For example, total_size 11 with parts_to_split 4 gives sizes [3, 3, 3, 2] and offsets [0, 3, 6, 9];
total_size 8 with parts_to_split 4 gives sizes [2, 2, 2, 2] and offsets [0, 2, 4, 6].
"""
# Calculate the quotient and remainder of total_size divided by parts_to_split
quotient = total_size // parts_to_split
remainder = total_size % parts_to_split
data_num_list = [quotient + 1] * remainder + [quotient] * (parts_to_split - remainder)
offset_list = [0] + list(itertools.accumulate(data_num_list))
return data_num_list, offset_list[:-1]
def split_dp_params(
dp_params: DataParallelRuntimeParams,
bs_parts_to_split: int,
seq_parts_to_split: int,
attn_data_parallel_size: int,
attn_tensor_parallel_size: int,
prefill_dispatch_use_RS_AG: bool,
dp_rank_: int,
) -> List[DataParallelRuntimeParams]:
    assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
        "Splitting the batch and sequence dimensions concurrently is not supported."
if dp_params is None:
return [None] * bs_parts_to_split * seq_parts_to_split
if bs_parts_to_split * seq_parts_to_split == 1:
        return [dp_params]
if bs_parts_to_split == 1:
results : List[DataParallelRuntimeParams] = []
dp_seq_lens = []
for seq_len in dp_params.seq_lens:
tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
dp_seq_lens.append(tokens)
query_lens_per_dp_rank = []
# For each dp rank, the batch size is 0 or 1.
bs_offset = 0
for i in range(attn_data_parallel_size):
if dp_params.batch_sizes[i] > 0:
seq_len = dp_params.seq_lens[bs_offset]
tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
query_lens_per_dp_rank.append(tokens)
bs_offset += dp_params.batch_sizes[i]
else:
query_lens_per_dp_rank.append([0] * seq_parts_to_split)
for i in range(seq_parts_to_split):
            dp_is_prefill = [True] * attn_data_parallel_size
results.append(MLUDPMetadata.make_oot(
data_parallel_rank=dp_rank_,
data_parallel_size=attn_data_parallel_size,
tensor_parallel_size=attn_tensor_parallel_size,
dp_token_nums=[query_lens_per_dp_rank[j][i] for j in range(attn_data_parallel_size)],
dp_is_prefill=dp_is_prefill,
prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
seq_lens=[seq_lens[i] for seq_lens in dp_seq_lens],
batch_sizes=dp_params.batch_sizes,
))
return results
bs_per_dp = dp_params.batch_sizes # [bs_rank_0, bs_rank_1, ...]
seq_lens_per_dp = dp_params.seq_lens # [seq_len_bs_0, seq_len_bs_1,...]
# [[bs_rank_0_part_0, bs_rank_0_part_1,...], [bs_rank_1_part_0, bs_rank_1_part_1,...], ...]
split_bs_per_dp = []
# [[
# [bs0_part_0_rank_0, bs1_part_0_rank_0, ...],
# [bs0_part_1_rank_0, bs1_part_1_rank_0, ...],
# ...
# ],
# [
# [bs0_part_0_rank_1, bs1_part_0_rank_1, ...],
# [bs0_part_1_rank_1, bs1_part_1_rank_1, ...],
# ...
# ],
# ]
split_query_lens_per_dp = []
for dp_rank in range(attn_data_parallel_size):
_bs, _offset = get_data_num_and_offset(bs_per_dp[dp_rank], bs_parts_to_split)
split_bs_per_dp.append(_bs)
split_query_lens_per_dp.append([])
for i in range(bs_parts_to_split):
start = sum(bs_per_dp[:dp_rank]) + _offset[i]
end = start + _bs[i]
split_query_lens_per_dp[-1].append(dp_params.seq_lens[start:end])
results : List[DataParallelRuntimeParams] = []
for i in range(bs_parts_to_split):
dp_query_lens = [sum(split_query_lens_per_dp[dp_rank][i]) for dp_rank in range(attn_data_parallel_size)]
seq_lens = []
for dp_rank in range(attn_data_parallel_size):
seq_lens += split_query_lens_per_dp[dp_rank][i]
        batch_sizes = [split_bs_per_dp[dp_rank][i] for dp_rank in range(attn_data_parallel_size)]
        dp_is_prefill = [True] * attn_data_parallel_size
results.append(MLUDPMetadata.make_oot(
data_parallel_rank=dp_rank_,
data_parallel_size=attn_data_parallel_size,
tensor_parallel_size=attn_tensor_parallel_size,
dp_token_nums=dp_query_lens,
dp_is_prefill=dp_is_prefill,
prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
seq_lens=seq_lens,
batch_sizes=batch_sizes,
))
return results
def split_input(
input: torch.Tensor,
bs_parts_to_split: int,
seq_parts_to_split: int,
attn_metadata_list: List[AttentionMetadata],
) -> List[torch.Tensor]:
    assert seq_parts_to_split == 1 or bs_parts_to_split == 1, \
        "Splitting the batch and sequence dimensions concurrently is not supported."
if input is None:
return [None] * bs_parts_to_split * seq_parts_to_split
if bs_parts_to_split * seq_parts_to_split == 1:
        return [input]
token_num_list = [0] * len(attn_metadata_list)
for i, metadata in enumerate(attn_metadata_list):
common_metadata, layer_metadata = get_common_and_layer_metadata(metadata)
if layer_metadata is not None:
token_num_list[i] = layer_metadata.num_actual_tokens
# A special case for dummy run
if layer_metadata is None and i == 0:
token_num_list[i] = input.shape[0]
results = list()
for i in range(bs_parts_to_split * seq_parts_to_split):
start = sum(token_num_list[:i])
end = start + token_num_list[i]
results.append(input[start:end])
return results
def split_positions(
positions: torch.Tensor,
bs_parts_to_split: int,
seq_parts_to_split: int,
attn_metadata: AttentionMetadata,
) -> List[torch.Tensor]:
if seq_parts_to_split == 1:
return [positions] * bs_parts_to_split
common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
total_tokens = layer_metadata.num_actual_tokens if layer_metadata is not None else 0
tokens, offsets = get_data_num_and_offset(total_tokens, seq_parts_to_split)
positions_list = []
for i in range(seq_parts_to_split):
positions_list.append(positions[offsets[i]: offsets[i] + tokens[i]])
return positions_list
def split_attn_metadata(
attn_metadata: dict,
bs_parts_to_split: int,
seq_parts_to_split: int,
) -> List[Any]:
""" attn_metdata is a dict, which contains common and layer metadata."""
    assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
        "Splitting the batch and sequence dimensions concurrently is not supported."
if bs_parts_to_split == 1 and seq_parts_to_split == 1:
        return [attn_metadata]
if attn_metadata is None:
return [None] * bs_parts_to_split * seq_parts_to_split
if seq_parts_to_split > 1:
common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
if common_metadata is None or not hasattr(common_metadata, 'num_actual_tokens'):
raise ValueError("common_metadata is invalid or missing num_actual_tokens")
num_prefill_tokens = common_metadata.num_actual_tokens
tokens, offsets = get_data_num_and_offset(num_prefill_tokens, seq_parts_to_split)
device = common_metadata.seq_lens.device
sub_common_metadata, sub_layer_metadata = [], []
for i in range(seq_parts_to_split):
            # query_start_loc tensor, which indexes positions in the input.
query_start_loc_tensor = torch.empty_like(common_metadata.query_start_loc)
query_start_loc_tensor[0] = 0
query_start_loc_tensor[1] = tokens[i]
# seq_lens tensor
seq_lens_tensor = torch.tensor(
[offsets[i] + tokens[i]],
dtype=common_metadata.seq_lens.dtype,
device=device
)
            # seq_start_loc tensor, which indicates positions in the sequence (kv cache).
seq_start_loc_tensor = torch.empty_like(common_metadata.seq_start_loc)
seq_start_loc_tensor[0] = offsets[i]
seq_start_loc_tensor[1] = offsets[i] + tokens[i]
# max_query_len scalar
max_query_len = tokens[i]
# num_actual_tokens scalar
num_actual_tokens = tokens[i]
# num_input_tokens scalar
num_input_tokens = num_actual_tokens
# infer_mode
infer_mode = common_metadata.infer_mode
# update common metadata
sub_common_metadata.append(MLUCommonAttentionMetadata(
query_start_loc=query_start_loc_tensor,
query_start_loc_cpu=common_metadata.query_start_loc_cpu, # FIXME: split when used
seq_lens=seq_lens_tensor,
seq_lens_cpu=common_metadata.seq_lens_cpu, # FIXME: split when used
num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu, # FIXME: split when used
num_reqs=common_metadata.num_reqs, # FIXME: split when used
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
max_seq_len=max_query_len,
block_table_tensor=common_metadata.block_table_tensor, # FIXME: split when used
slot_mapping=common_metadata.slot_mapping, # FIXME: split when used
seq_start_loc=seq_start_loc_tensor,
num_input_tokens=num_input_tokens,
infer_mode=infer_mode,
num_prefill_query_tokens=tokens[i],
num_prefill_kv_tokens=offsets[i] + tokens[i],
))
# slot_mapping tensor
slot_mapping = layer_metadata.slot_mapping[offsets[i]:offsets[i] + tokens[i]]
# update layer metadata
REQUIRED_NUM_DECODES = 0
REQUIRED_NUM_DECODE_TOKENS = 0
REQUIRED_NUM_PREFILLS = 1
if not hasattr(layer_metadata, 'num_prefills') or \
layer_metadata.num_prefills is None:
raise ValueError("layer_metadata.num_prefills is required")
assert layer_metadata.num_decodes == REQUIRED_NUM_DECODES and \
layer_metadata.num_decode_tokens == REQUIRED_NUM_DECODE_TOKENS and \
layer_metadata.num_prefills == REQUIRED_NUM_PREFILLS, (
f"num_decodes, num_decode_tokens, num_prefills must be {REQUIRED_NUM_DECODES}, {REQUIRED_NUM_DECODE_TOKENS}, "
f"{REQUIRED_NUM_PREFILLS}, but got {layer_metadata.num_decodes}, {layer_metadata.num_decode_tokens}, "
f"{layer_metadata.num_prefills}."
)
            assert layer_metadata.prefill.chunked_context is None, (
                "chunked_context is only set for prefills with a chunked context, "
                "which is not supported when mcc is enabled."
            )
prefill_metadata = FlashMLAPrefillMetadata(
block_table=layer_metadata.prefill.block_table,
query_start_loc=query_start_loc_tensor,
max_query_len=max_query_len,
chunked_context=None,
num_prefills=layer_metadata.prefill.num_prefills,
max_seq_len=layer_metadata.prefill.max_seq_len,
)
            # Note: for sequence dimension partition, we provide the cu_seqlens_kv field to
            # indicate the key/value size for the flash attention operator.
prefill_metadata.cu_seqlens_kv = torch.empty_like(prefill_metadata.query_start_loc)
prefill_metadata.cu_seqlens_kv[0] = 0
prefill_metadata.cu_seqlens_kv[1] = offsets[i] + tokens[i]
sub_layer_metadata.append(FlashMLAMetadata(
num_reqs=layer_metadata.num_reqs,
max_query_len=max_query_len,
max_seq_len=max_query_len,
num_actual_tokens=num_actual_tokens,
query_start_loc=query_start_loc_tensor,
slot_mapping=slot_mapping,
num_decodes=layer_metadata.num_decodes,
num_decode_tokens=layer_metadata.num_decode_tokens,
num_prefills=layer_metadata.num_prefills,
num_prefill_tokens=tokens[i],
head_dim=layer_metadata.head_dim,
decode=layer_metadata.decode,
prefill=prefill_metadata,
))
sub_attn_metadata_list = []
for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
sub_attn_metadata_dict = {}
for key, value in attn_metadata.items():
if key == COMMON_METADATA_STR:
sub_attn_metadata_dict[key] = common_meta
else:
sub_attn_metadata_dict[key] = layer_meta
sub_attn_metadata_list.append(sub_attn_metadata_dict)
return sub_attn_metadata_list
elif bs_parts_to_split > 1:
common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
if not hasattr(layer_metadata, 'num_prefills') or layer_metadata.num_prefills is None:
raise ValueError("layer_metadata.num_prefills is required")
total_batch = layer_metadata.num_prefills
batch_sizes, offsets = get_data_num_and_offset(total_batch, bs_parts_to_split)
sub_common_metadata, sub_layer_metadata = [], []
for i in range(bs_parts_to_split):
# query_start_loc tensor
start, end = offsets[i], offsets[i] + batch_sizes[i]
query_start_loc_tensor = common_metadata.query_start_loc[start:end+1].clone()
if i > 0:
query_start_loc_tensor -= common_metadata.query_start_loc[start]
# block_table
block_tables = torch.empty(
(batch_sizes[i], 0),
dtype=layer_metadata.prefill.block_table.dtype,
device=layer_metadata.prefill.block_table.device,
)
# seq_lens tensor
seq_lens_tensor = common_metadata.seq_lens[start:end].clone()
# seq_start_loc tensor
seq_start_loc_tensor = query_start_loc_tensor
# max_query_len scalar
max_query_len = seq_lens_tensor.max().item() if seq_lens_tensor.numel() > 0 else 0
# num_actual_tokens scalar
num_actual_tokens = seq_start_loc_tensor[-1].item()
# num_input_tokens scalar
num_input_tokens = num_actual_tokens
# infer_mode
infer_mode = common_metadata.infer_mode
# slot_mapping tensor
slot_mapping_start = 0
for data in sub_common_metadata:
slot_mapping_start += data.num_actual_tokens
slot_mapping_tensor = layer_metadata.slot_mapping[
slot_mapping_start:slot_mapping_start + num_actual_tokens
]
# update common metadata
sub_common_metadata.append(MLUCommonAttentionMetadata(
query_start_loc=query_start_loc_tensor,
query_start_loc_cpu=common_metadata.query_start_loc_cpu, # FIXME: split when used
seq_lens=seq_lens_tensor,
seq_lens_cpu=common_metadata.seq_lens_cpu, # FIXME: split when used
num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu, # FIXME: split when used
num_reqs=common_metadata.num_reqs, # FIXME: split when used
block_table_tensor=common_metadata.block_table_tensor, # FIXME: split when used
slot_mapping=common_metadata.slot_mapping, # FIXME: split when used
seq_start_loc=seq_start_loc_tensor,
max_query_len=max_query_len,
max_seq_len=max_query_len,
num_actual_tokens=num_actual_tokens,
num_input_tokens=num_input_tokens,
infer_mode=infer_mode,
num_prefill_query_tokens=num_actual_tokens,
num_prefill_kv_tokens=num_actual_tokens,
))
# update layer_metadata
prefill_metadata = FlashMLAPrefillMetadata(
block_table=block_tables,
query_start_loc=query_start_loc_tensor,
max_query_len=max_query_len,
chunked_context=None,
num_prefills=batch_sizes[i],
max_seq_len=max_query_len,
)
sub_layer_metadata.append(FlashMLAMetadata(
num_reqs=batch_sizes[i],
max_query_len=max_query_len,
max_seq_len=max_query_len,
num_actual_tokens=num_actual_tokens,
query_start_loc=query_start_loc_tensor,
slot_mapping=slot_mapping_tensor,
                num_decodes=layer_metadata.num_decodes,  # unused field
                num_decode_tokens=0,  # unused field
num_prefills=batch_sizes[i],
num_prefill_tokens=num_actual_tokens,
head_dim=layer_metadata.head_dim,
decode=layer_metadata.decode,
prefill=prefill_metadata,
))
sub_attn_metadata_list = []
for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
sub_attn_metadata_dict = {}
for key, value in attn_metadata.items():
if key == COMMON_METADATA_STR:
sub_attn_metadata_dict[key] = common_meta
else:
sub_attn_metadata_dict[key] = layer_meta
sub_attn_metadata_list.append(sub_attn_metadata_dict)
return sub_attn_metadata_list
def execute_with_updated_forward_context(
vllm_config: VllmConfig,
attn_metadata: AttentionMetadata,
func: Callable,
kwargs: Dict[str, Any],
):
with set_forward_context(attn_metadata, vllm_config):
return func(**kwargs)
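# A rough sketch (kept as a comment, not invoked here) of how the helpers above are
# expected to compose when mcc splits a prefill-only step into sub-tasks; the real
# call site lives in the model runner, and the kwarg names ("positions",
# "hidden_states") are illustrative assumptions only.
#
#     def run_with_mcc(vllm_config, attn_metadata, dp_params, positions, hidden, forward_fn, parts=2):
#         bs_parts, seq_parts = attn_mcc_plan(attn_metadata, dp_params, parts)
#         metadata_list = split_attn_metadata(attn_metadata, bs_parts, seq_parts)
#         inputs = split_input(hidden, bs_parts, seq_parts, metadata_list)
#         positions_list = split_positions(positions, bs_parts, seq_parts, attn_metadata)
#         outputs = []
#         for sub_metadata, sub_input, sub_positions in zip(metadata_list, inputs, positions_list):
#             outputs.append(execute_with_updated_forward_context(
#                 vllm_config, sub_metadata, forward_fn,
#                 {"positions": sub_positions, "hidden_states": sub_input}))
#         return torch.cat(outputs, dim=0)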

View File

@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Type, Union
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.models.registry import (
_LazyRegisteredModel, _RegisteredModel, _ModelRegistry)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__model_executor__models__registry___ModelRegistry__register_model(
self,
model_arch: str,
model_cls: Union[type[nn.Module], str],
) -> None:
"""
Register an external model to be used in vLLM.
`model_cls` can be either:
- A [`torch.nn.Module`][] class directly referencing the model.
- A string in the format `<module>:<class>` which can be used to
lazily import the model. This is useful to avoid initializing CUDA
when importing the model and thus the related error
`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
"""
if not isinstance(model_arch, str):
msg = f"`model_arch` should be a string, not a {type(model_arch)}"
raise TypeError(msg)
'''
=============================
Modify by vllm_mlu
=============================
@brief: change mlu models register log level
'''
if model_arch in self.models:
if isinstance(model_cls, str) and "MLU" in model_cls:
logger.debug(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls)
else:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls)
'''
==================
End of MLU Hijack
==================
'''
if isinstance(model_cls, str):
split_str = model_cls.split(":")
if len(split_str) != 2:
msg = "Expected a string in the format `<module>:<class>`"
raise ValueError(msg)
model = _LazyRegisteredModel(*split_str)
elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
model = _RegisteredModel.from_model_cls(model_cls)
else:
msg = ("`model_cls` should be a string or PyTorch model class, "
f"not a {type(model_arch)}")
raise TypeError(msg)
self.models[model_arch] = model
MluHijackObject.apply_hijack(
_ModelRegistry,
_ModelRegistry.register_model,
vllm__model_executor__models__registry___ModelRegistry__register_model
)
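# Usage is unchanged from upstream vLLM; only the log level differs for MLU overrides.
# A hypothetical registration (the architecture name and module path below are
# illustrative only):
#
#     from vllm import ModelRegistry
#     ModelRegistry.register_model(
#         "DeepSeekV4ForCausalLM",
#         "vllm_mlu.model_executor.models.deepseek_v4:MLUDeepSeekV4ForCausalLM")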

View File

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import json
import os
import torch
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.forward_context import get_forward_context
def set_attn_compute_dtype_v1(attn_metadata, dtype: torch.dtype):
'''
set attn compute_dtype for v1
'''
if isinstance(attn_metadata, dict):
        for metadata in attn_metadata.values():
            metadata.compute_dtype = dtype
    else:
        attn_metadata.compute_dtype = dtype
def set_attn_compute_dtype(dtype: torch.dtype):
'''
set attn compute_dtype.
    TODO: FA may standardize on half-precision computation in the future;
    set_attn_compute_dtype might then be deprecated and removed.
'''
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
set_attn_compute_dtype_v1(attn_metadata, dtype)
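# A minimal usage sketch (illustrative): force the attention backends of the current
# forward pass to compute in float16; must be called inside a forward context.
#
#     set_attn_compute_dtype(torch.float16)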
def is_tie_word_embeddings(
model_config: ModelConfig,
org_tie_word_embeddings: bool
) -> bool:
'''
    The vLLM language model config for a multimodal model may carry a wrong
    tie_word_embeddings value (for example, InternVL3.5-38B and InternVL3.5-30B-A3B).
    This function is a workaround.
'''
from vllm.lora.utils import get_adapter_absolute_path
if not model_config.is_multimodal_model:
return org_tie_word_embeddings
model_path = get_adapter_absolute_path(model_config.model)
config_path = os.path.join(model_path, "config.json")
if not os.path.exists(config_path):
return org_tie_word_embeddings
tie_word_embeddings = org_tie_word_embeddings
with open(config_path) as f:
config = json.load(f)
    # First, check whether tie_word_embeddings is set in the top-level config.
if config.get("tie_word_embeddings") is not None:
tie_word_embeddings = config["tie_word_embeddings"]
    # Then, check whether it is overridden in the language model (llm_config) section.
if (config.get("llm_config") is not None
and config["llm_config"].get("tie_word_embeddings") is not None):
tie_word_embeddings = config["llm_config"]["tie_word_embeddings"]
return tie_word_embeddings
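# Illustrative precedence (the values below are hypothetical): a nested llm_config
# entry overrides the top-level one, which in turn overrides the value vLLM derived.
#
#     config.json: {"tie_word_embeddings": false,
#                   "llm_config": {"tie_word_embeddings": true}}
#     -> is_tie_word_embeddings(model_config, org_tie_word_embeddings=False) returns True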

View File

@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Callable, Any
import torch
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.distributed import (
get_parallel_rank_with_group,
get_parallel_world_size_with_group,
)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
vllm__model_executor__parameter__BasevLLMParameter____init__org = BasevLLMParameter.__init__
def vllm__model_executor__parameter__BasevLLMParameter____init__(
self,
data: torch.Tensor,
weight_loader: Callable,
tp_group: Any = None
):
vllm__model_executor__parameter__BasevLLMParameter____init__org(
self, data, weight_loader
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add self.tp_group, world_size and tp_rank
'''
if tp_group is not None:
self.tp_group = tp_group
self.tp_world_size = get_parallel_world_size_with_group(self.tp_group)
self.tp_rank = get_parallel_rank_with_group(self.tp_group)
'''
=================
End of MLU Hijack
=================
'''
MluHijackObject.apply_hijack(BasevLLMParameter,
BasevLLMParameter.__init__,
vllm__model_executor__parameter__BasevLLMParameter____init__)
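# With this hijack in place, a parameter constructed with an explicit group (a sketch;
# "group", "weight", and "loader" are assumed to come from the caller) also exposes
# tp_world_size and tp_rank:
#
#     param = BasevLLMParameter(data=weight, weight_loader=loader, tp_group=group)
#     shard = param.data.chunk(param.tp_world_size, dim=0)[param.tp_rank]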

View File

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
Warmup kernels used during model execution.
This is useful specifically for JIT'ed kernels as we don't want JIT'ing to
happen during model execution.
"""
from typing import TYPE_CHECKING
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.v1.worker.gpu_worker import Worker
logger = init_logger(__name__)
def kernel_warmup(worker: "Worker"):
'''
=============================
Modify by vllm_mlu
=============================
    @brief: skip DeepGEMM warmup, FlashInfer autotune, and
            FlashInfer attention warmup
'''
'''
==================
End of MLU Hijack
==================
'''
pass