[Model] Support DeepSeek-V4

This commit is contained in:
chenxb002
2026-04-24 09:50:34 +08:00
commit b9925203b8
172 changed files with 44780 additions and 0 deletions

15
vllm_mlu/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
def register_mlu_platform():
"""Register the MLU platform."""
return "vllm_mlu.platforms.mlu.MLUPlatform"
def register_mlu_hijack():
"""Register the MLU models and hijack."""
from vllm_mlu import mlu_hijack
from vllm_mlu.model_executor.models import register_model
register_model()
return

1853
vllm_mlu/_mlu_ops.py Normal file

File diff suppressed because it is too large Load Diff

107
vllm_mlu/_mlu_utils.py Normal file
View File

@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import os
import torch
import vllm.envs as envs
def _check_env(env, default=False):
if env in os.environ:
return os.environ[env].lower() in ["true", "1"]
return default
def _check_env_value(env, default=0):
if env in os.environ:
if not os.environ[env].isdigit():
raise ValueError(f"'{env}' should be set with integer")
value = int(os.environ[env])
return value
return default
def _check_env_float(env, default=0):
if env in os.environ:
try:
value = float(os.environ[env])
except ValueError:
raise ValueError(f"'{env}' should be set with float")
return value
return default
# VLLM_LATENCY_DEBUG: Get more kernel info for benchmark latency.
VLLM_LATENCY_DEBUG = _check_env("VLLM_LATENCY_DEBUG", default=False)
# VLLM_LATENCY_DEBUG_NO_DEVICE: Get more kernel info(without device) for benchmark latency.
VLLM_LATENCY_DEBUG_NO_DEVICE = _check_env("VLLM_LATENCY_DEBUG_NO_DEVICE", default=False)
# VLLM_DUMP_TENSORS: Dump each layer outputs when running vLLM inference.
VLLM_DUMP_OUTPUTS = _check_env("VLLM_DUMP_OUTPUTS", default=False)
# VLLM_DUMP_MLU_INFO: Get device info when running vLLM inference.
VLLM_DUMP_MLU_INFO = _check_env("VLLM_DUMP_MLU_INFO", default=False)
# VLLM_DUMP_MLU_INFO_DEBUG: Dump device debug info when running vLLM inference.
VLLM_DUMP_MLU_INFO_DEBUG = _check_env("VLLM_DUMP_MLU_INFO_DEBUG", default=False)
# VLLM_SCHEDULER_PROFILE: Profiling vLLM scheduler.
VLLM_SCHEDULER_PROFILE = _check_env("VLLM_SCHEDULER_PROFILE", default=False)
# VLLM_GRAPH_DEBUG: Debug the graph status when running decoder, default value is True.
# Set to False to disable warning messages.
VLLM_GRAPH_DEBUG = _check_env("VLLM_GRAPH_DEBUG", default=True)
# VLLM_AVG_MOE_EN: make moe experts workload balance, default value is False.
VLLM_AVG_MOE_EN = _check_env("VLLM_AVG_MOE_EN", default=False) or _check_env("VLLM_RANDOM_MOE_EN", default=False)
VLLM_RANDOM_MOE_EN = _check_env("VLLM_RANDOM_MOE_EN", default=False)
# VLLM_LOGITS_USE_ALL_GATHER: use allgather for logits collection, default value is False.
VLLM_LOGITS_USE_ALL_GATHER = _check_env("VLLM_LOGITS_USE_ALL_GATHER", default=False)
VLLM_LATENCY_DEBUG_EN = (VLLM_LATENCY_DEBUG or VLLM_LATENCY_DEBUG_NO_DEVICE)
VLLM_LATENCY_DEBUG_WITH_DEVICE_EN = (VLLM_LATENCY_DEBUG and not VLLM_LATENCY_DEBUG_NO_DEVICE)
VLLM_DUMP_MLU_INFO_EN = (VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and VLLM_DUMP_MLU_INFO)
VLLM_DUMP_MLU_INFO_DEBUG = (VLLM_DUMP_MLU_INFO_DEBUG and VLLM_DUMP_MLU_INFO_EN)
# VLLM_V1_USE_UNCHUNK_SCHED: v1 use unchunk scheduler, default value is True.
VLLM_V1_USE_UNCHUNK_SCHED = _check_env("VLLM_V1_USE_UNCHUNK_SCHED", default=True)
# VLLM_V1_MIN_PREFILL_BATCH: the min scheduling batch in v1, default is 1.
VLLM_V1_MIN_PREFILL_BATCH = _check_env_value("VLLM_V1_MIN_PREFILL_BATCH", default=1)
# VLLM_V1_USE_FULL_GRAPH: v1 use full graph capture, default value is True.
VLLM_V1_USE_FULL_GRAPH = _check_env("VLLM_V1_USE_FULL_GRAPH", default=True)
# VLLM_V1_BENCHMARK: v1 benchmark, default value is False.
VLLM_V1_BENCHMARK = _check_env("VLLM_V1_BENCHMARK", default=False)
# VLLM_MTP_DEBUG: use to show mtp accepted rate, default value is False.
VLLM_MTP_DEBUG = _check_env("VLLM_MTP_DEBUG", default=False)
# VLLM_MTP_NO_QUANT: mtp use origin dtype, quant_config use None
VLLM_MTP_NO_QUANT = _check_env("VLLM_MTP_NO_QUANT", default=False)
# VLLM_MTP_FIXED_ACCEPTANCE_RATE: use fixed acceptance rate, default value is None.
VLLM_MTP_FIXED_ACCEPTANCE_RATE = _check_env_float("VLLM_MTP_FIXED_ACCEPTANCE_RATE", default=None)
# VLLM_MTP_NO_QUANT: mtp use origin dtype, quant_config use None
VLLM_MTP_NO_QUANT = _check_env("VLLM_MTP_NO_QUANT", default=False)
# VLLM_V1_UNCHUNK_SCHED_LOG: print v1 unchunk scheduler state
VLLM_V1_UNCHUNK_SCHED_LOG = _check_env("VLLM_V1_UNCHUNK_SCHED_LOG", default=False)
# VLLM_MOE_PREFILL_CHUNK_SIZE: in number of tokens. enabled when > 0.
VLLM_MOE_PREFILL_CHUNK_SIZE = _check_env_value("VLLM_MOE_PREFILL_CHUNK_SIZE", default=0)
# VLLM_CI_ACCURACY_TEST: CI accuracy test, default value is False.
VLLM_CI_ACCURACY_TEST = _check_env("VLLM_CI_ACCURACY_TEST", default=False)
# VLLM_DISAGG_TRANS_ALL_BLOCKS: optimize the performance of disagg
VLLM_DISAGG_TRANS_ALL_BLOCKS = _check_env("VLLM_DISAGG_TRANS_ALL_BLOCKS", default=True)
# vllm disagg debug
VLLM_DISAGG_CNPX_EXECUTE = _check_env("VLLM_DISAGG_CNPX_EXECUTE", default=False)
VLLM_DISAGG_CNPX_REQUEST = _check_env("VLLM_DISAGG_CNPX_REQUEST", default=False)
VLLM_DISAGG_FAKE_DECODER = _check_env("VLLM_DISAGG_FAKE_DECODER", default=False)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

351
vllm_mlu/attention/layer.py Normal file
View File

@@ -0,0 +1,351 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, cast
import torch
from torch import nn
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import MLAAttentionImpl
from vllm.attention.layer import Attention, MLAAttention, _init_kv_cache_quant
from vllm.attention.selector import get_attn_backend
from vllm.config.cache import CacheConfig
from vllm.config.vllm import QuantizationConfig, VllmConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.platforms import current_platform
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm_mlu.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.v1.kv_cache_interface import (
MLUFullAttentionSpec,
MLUMLAAttentionSpec,
MLUSlidingWindowSpec,
)
@maybe_transfer_kv_layer
def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
kwargs: dict[str, Any] = {},
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
'''
=============================
Modify by vllm_mlu
=============================
@brief: add return for self.impl.forward and it's param kwargs
'''
output = self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
kwargs=kwargs,
)
'''
==================
End of MLU Hijack
==================
'''
return output
class Attention_MluHijack(Attention):
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER
if self.sliding_window is not None:
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace SlidingWindowSpec with MLUSlidingWindowSpec.
'''
return MLUSlidingWindowSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
sliding_window=self.sliding_window,
)
'''
==================
End of MLU Hijack
==================
'''
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace FullAttentionSpec with MLUFullAttentionSpec.
'''
return MLUFullAttentionSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
)
'''
==================
End of MLU Hijack
==================
'''
class MLAAttention_MluHijack(MLAAttention):
def __init__(
self,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
kv_b_proj: ColumnParallelLinear,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_sparse: bool = False,
indexer: object | None = None,
**extra_impl_args,
) -> None:
nn.Module.__init__(self)
self.num_heads = num_heads
self.scale = scale
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
# self.head_size = kv_lora_rank + qk_rope_head_dim
self.layer_name = prefix
'''
=============================
Modify by vllm_mlu
=============================
@brief: insert num_kv_heads for mlu platform
'''
self.head_size = qk_nope_head_dim + qk_rope_head_dim
self.num_kv_heads = extra_impl_args.pop("num_kv_heads", None)
if self.num_kv_heads is None:
self.num_kv_heads = num_heads
self.decoder_attn_dtype = None
decoder_attn_dtype = get_current_vllm_config().mlu_config.decoder_attn_dtype
if decoder_attn_dtype in ["int8", "fp8_e4m3", "fp8"]:
self.decoder_attn_dtype = (
torch.int8 if decoder_attn_dtype == "int8"
else torch.float8_e4m3fn
)
extra_impl_args['decoder_attn_dtype'] = self.decoder_attn_dtype
'''
==================
End of MLU Hijack
==================
'''
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
)
dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
self.head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=True,
use_sparse=use_sparse,
)
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
self.num_heads,
self.head_size,
self.scale,
self.num_kv_heads,
None, # alibi_slops
None, # sliding_window
kv_cache_dtype,
None, # logits_soft_cap
AttentionType.DECODER, # attn_dtype
None, # kv_sharing_target_layer_name
**extra_impl_args,
)
self.dtype = dtype
self.use_direct_call = not current_platform.opaque_attention_op()
if current_platform.is_out_of_tree():
self.use_direct_call = False
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
'''
=============================
Modify by vllm_mlu
=============================
@brief: support kv8 and deepseek v3.2
'''
self.kv_cache = [
[torch.tensor([]), torch.tensor([]), torch.tensor([])]
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
self.impl.use_mla = True
'''
==================
End of MLU Hijack
==================
'''
self.use_sparse = use_sparse
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
kv_cache_dtype = kv_cache_dtype_str_to_dtype(
self.kv_cache_dtype, vllm_config.model_config
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace MLAAttentionSpec with MLUMLAAttentionSpec.
'''
index_head_dim, index_n_heads = 0, 0
if vllm_config.model_config.hf_text_config.model_type == "deepseek_v32":
index_head_dim = vllm_config.model_config.hf_text_config.index_head_dim
index_n_heads = 1
if vllm_config.model_config.hf_text_config.model_type == "deepseek_v4":
index_head_dim = vllm_config.model_config.hf_text_config.index_head_dim
index_n_heads = 1
return MLUMLAAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_size,
dtype=kv_cache_dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype,
index_head_dim=index_head_dim,
index_n_heads=index_n_heads,
)
'''
==================
End of MLU Hijack
==================
'''
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output_shape: torch.Size | None = None,
kwargs: dict[str, Any] = {},
) -> torch.Tensor:
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
assert not self.use_direct_call, "MLU-V1 does not support direct call."
if self.attn_backend.accept_output_buffer:
output_lse = None
output_shape = (output_shape if output_shape is not None else query.shape)
output_shape = [output_shape[0], self.num_heads * self.v_head_dim]
output = torch.empty(
output_shape,
dtype=self.dtype if query.dtype == torch.int8 else query.dtype,
device=query.device,
)
hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.v_head_dim)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.v_head_dim)
if not kwargs:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name
)
attn_output_list = output
else:
attn_output_list = unified_attention_with_output(
query, key, value, output, self.layer_name, kwargs=kwargs)
if isinstance(attn_output_list, (list, tuple)) and len(attn_output_list) > 1:
output_lse = attn_output_list[1]
if output_lse is not None:
return output.view(-1, hidden_size), output_lse
else:
return output.view(-1, hidden_size)
'''
==================
End of MLU Hijack
==================
'''
else:
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name
)
MluHijackObject.apply_hijack(
Attention,
Attention.get_kv_cache_spec,
Attention_MluHijack.get_kv_cache_spec,
)
MluHijackObject.apply_hijack(
MLAAttention,
MLAAttention.__init__,
MLAAttention_MluHijack.__init__,
)
MluHijackObject.apply_hijack(
MLAAttention,
MLAAttention.get_kv_cache_spec,
MLAAttention_MluHijack.get_kv_cache_spec,
)
MluHijackObject.apply_hijack(
MLAAttention,
MLAAttention.forward,
MLAAttention_MluHijack.forward,
)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import inspect
from collections.abc import Callable
from functools import wraps
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
def maybe_transfer_kv_layer(func: Callable) -> Callable:
"""Decorator that handles KV layer transfer prior and after execution of
an attention layer, if enabled. Otherwise, the wrapper is a no-op.
On entry: waits for the KV layer from the connector.
On exit: saves the KV layer to the connector.
"""
# Import at runtime to avoid circular dependency
from vllm.attention.layer import get_attention_context
# Inspect the signature ONCE when the decorator is applied.
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
# Find the index of 'layer_name' parameter.
try:
layer_name_index = param_names.index("layer_name")
except ValueError as e:
raise TypeError(
f"Function {func.__name__} must have a 'layer_name' parameter"
) from e
@wraps(func)
def wrapper(*args, **kwargs):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return func(*args, **kwargs)
layer_name: str = args[layer_name_index]
# Extract attention context (layer-specific metadata, layer, and kv_cache)
attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
connector = get_kv_transfer_group()
if attn_metadata is None or not connector.has_connector_metadata():
return func(*args, **kwargs)
# Wait for KV layer on entry
connector.wait_for_layer_load(layer_name)
# Execute the function
result = func(*args, **kwargs)
# Save KV cache layer on exit
if kwargs is None or kwargs.get("save_kv_layer", True):
connector.save_kv_layer(layer_name, kv_cache, attn_metadata)
return result
return wrapper

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
This module defines a framework for sampling benchmark requests from various
datasets. Each dataset subclass of BenchmarkDataset must implement sample
generation. Supported dataset types include:
- ShareGPT
- Random (synthetic)
- Sonnet
- BurstGPT
- HuggingFace
- VisionArena
"""
from tempfile import NamedTemporaryFile
import numpy as np
from vllm.benchmarks.datasets import RandomMultiModalDataset
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__benchmarks__datasets__RandomMultiModalDataset__generate_synthetic_video(
self, width: int, height: int, num_frames: int
) -> dict:
"""Generate synthetic video with random values.
Creates a video with random pixel values, encodes it to MP4 format,
and returns the content as bytes.
"""
import cv2
random_pixels = self._rng.integers(
0,
256,
(num_frames, height, width, 3),
dtype=np.uint8,
)
# Create a temporary video file in memory
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
fps = 30 # frames per second
with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
temp_path = temp_file.name
# Create video writer
video_writer = cv2.VideoWriter(
temp_path, fourcc=fourcc, fps=fps, frameSize=(width, height)
)
if not video_writer.isOpened():
raise RuntimeError("Failed to create video writer")
for frame in random_pixels:
video_writer.write(frame)
video_writer.release()
temp_file.close()
# Read the video file content
with open(temp_path, "rb") as f:
video_content = f.read()
return {"bytes": video_content}
MluHijackObject.apply_hijack(
RandomMultiModalDataset,
RandomMultiModalDataset.generate_synthetic_video,
vllm__benchmarks__datasets__RandomMultiModalDataset__generate_synthetic_video,
)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# SPDX-License-Identifier: Apache-2.0
import operator
from typing import Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.platforms import current_platform
from vllm.logger import init_logger
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fx_utils import is_func
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
class FixFunctionalizationPass_MluHijack(FixFunctionalizationPass):
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
if current_platform.is_xpu():
logger.debug(
"XPU platform does not support fix functionalizationpass currently."
)
return
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
for node in graph.nodes:
'''
=============================
Modify by vllm_mlu
=============================
@brief: skip custom op on mlu
'''
if current_platform.is_out_of_tree():
continue # skip the count on mlu
'''
==================
End of MLU Hijack
==================
'''
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
kwargs = node.kwargs
at_target = node.args[0]
if at_target == torch.ops._C.rotary_embedding.default:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = self.getitem_users(node)
if (
is_func(query, operator.getitem)
and is_func(key, operator.getitem)
and query.args[0] == key.args[0]
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
and all(
is_func(user, torch.ops.aten.slice_scatter.default)
for getitem_node in getitem_nodes.values()
for user in getitem_node.users
)
):
# Pattern where query and key are slices of an mm_node.
# While functionalized, results at [1] and [2] are scattered
# back into mm_node. So after de-functionalization, we can
# just use mm_node directly.
mm_node = query.args[0].args[0]
for user in getitem_nodes.values():
for user_of_getitem in user.users:
if is_func(
user_of_getitem, torch.ops.aten.slice_scatter.default
):
user_of_getitem.replace_all_uses_with(mm_node)
self._remove(user_of_getitem)
self._remove(user)
self.insert_defunctionalized(graph, node)
self._remove(node)
else:
# Directly replace the auto_functionalize(rotary_embedding)
# with the inplace rotary_embedding. In theory, we shouldn't
# do this blindly, but in practice in vLLM it's ok. The best
# solution is to use auto_functionalization_v2 and then use
# inductor's builtin defunctionalization (reinplacing) pass.
mutated_args = {1: "query", 2: "key"}
self.defunctionalize(graph, node, mutated_args)
# rms_norm replacements avoid the most copies for LLaMa.
elif at_target == torch.ops._C.fused_add_rms_norm.default:
mutated_args = {1: "input", 2: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
mutated_args = {1: "result", 2: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
mutated_args = {1: "result", 2: "scale", 3: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target in [
torch.ops._C.rms_norm.default,
torch.ops._C.rms_norm_static_fp8_quant.default,
]:
mutated_args = {1: "result"}
self.defunctionalize(graph, node, mutated_args)
# For some reason we need to specify the args for both
# silu_and_mul and silu_and_mul_quant. The kwargs
# pathway gets the wrong answer.
elif at_target == torch.ops._C.silu_and_mul.default:
mutated_args = {1: "result"}
self.defunctionalize(
graph, node, mutated_args, args=("result", "input")
)
elif at_target == torch.ops._C.silu_and_mul_quant.default:
mutated_args = {1: "result"}
self.defunctionalize(
graph, node, mutated_args, args=("result", "input", "scale")
)
elif (
hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
):
mutated_args = {1: "result", 2: "result_block_scale"}
self.defunctionalize(
graph,
node,
mutated_args,
args=(
"result",
"result_block_scale",
"input",
"input_global_scale",
),
)
# Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
elif at_target == torch.ops._C.fused_qk_norm_rope.default:
mutated_args = {1: "qkv"}
args = (
"qkv",
"num_heads_q",
"num_heads_k",
"num_heads_v",
"head_dim",
"eps",
"q_weight",
"k_weight",
"cos_sin_cache",
"is_neox",
"position_ids",
)
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
else:
continue # skip the count
count += 1
self.dump_graph(graph, "before_cleanup")
# Remove the nodes all at once
count_removed = len(self.nodes_to_remove)
for node in self.nodes_to_remove:
graph.erase_node(node)
logger.debug(
"De-functionalized %s nodes, removed %s nodes", count, count_removed
)
self.nodes_to_remove.clear()
MluHijackObject.apply_hijack(
FixFunctionalizationPass,
FixFunctionalizationPass.__call__,
FixFunctionalizationPass_MluHijack.__call__
)

View File

@@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import dataclasses
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any
from unittest.mock import patch
import torch
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import weak_ref_tensors
from vllm.compilation.cuda_graph import (
CUDAGraphEntry,
CUDAGraphWrapper,
CUDAGraphOptions,
)
from vllm_mlu.v1.attention.backends.utils import MLUInferMode
logger = init_logger(__name__)
'''
=============================
Modify by vllm_mlu
=============================
@brief: specialized graph entry for prefill graphs
'''
@dataclasses.dataclass
class PrefillGraphEntry:
batch_size: int = 0
seq_len: int = 0
cudagraph: torch.mlu.MLUGraph | None = None
output: Any | None = None
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: list[int] | None = None
'''
==================
End of MLU Hijack
==================
'''
class MLUGraphWrapper(CUDAGraphWrapper):
def __init__(
self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
cudagraph_options: CUDAGraphOptions | None = None,
):
super().__init__(runnable, vllm_config, runtime_mode, cudagraph_options)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add separate dict for prefill graph entries
'''
self.prefill_mlugraph_entry: PrefillGraphEntry | None = None
'''
==================
End of MLU Hijack
==================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: check if running in prefill mode
'''
def is_running_in_prefill(self, entry: PrefillGraphEntry | None = None) -> bool:
forward_context = get_forward_context()
if forward_context.attn_metadata is None:
return False
infer_mode = forward_context.attn_metadata['common_metadata'].infer_mode
seq_lens_cpu = forward_context.attn_metadata['common_metadata'].seq_lens_cpu
if entry is not None \
and infer_mode == MLUInferMode.PREFILL_ONLY \
and seq_lens_cpu.size(0) == entry.batch_size \
and (seq_lens_cpu == entry.seq_len).all().item():
return True
return False
'''
==================
End of MLU Hijack
==================
'''
def __call__(
self,
is_capturing_prefill: bool = False,
prefill_enable_mlugraph: bool = False,
prefill_batch_size: int = 0,
prefill_seq_len: int = 0,
is_running_drafter: bool = False,
*args, **kwargs):
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
if (
cudagraph_runtime_mode == CUDAGraphMode.NONE
or cudagraph_runtime_mode != self.runtime_mode
):
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
# running without cudagraphs.
# We do not trigger capture/replay if the runtime mode is not
# matches. This enables properly dispatching to the correct
# CUDAGraphWrapper when nesting multiple instances with different
# runtime modes.
return self.runnable(*args, **kwargs)
'''
=============================
Modify by vllm_mlu
=============================
@brief: handle prefill graph separately
@brief: skip check in running drafter model
'''
if is_capturing_prefill: # PREFILL capture
self.prefill_mlugraph_entry = PrefillGraphEntry(
batch_size=prefill_batch_size,
seq_len=prefill_seq_len)
else: # FULL/DECODE capture
if batch_descriptor not in self.concrete_cudagraph_entries:
# create a new entry for this batch descriptor
self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(
batch_descriptor=batch_descriptor
)
if ((self.is_running_in_prefill(self.prefill_mlugraph_entry) and prefill_enable_mlugraph)
or is_capturing_prefill):
entry = self.prefill_mlugraph_entry
logger.debug(
f"Hitting a prefill cudagraph on {self.runtime_mode.name}, "
f"batch_size: {entry.batch_size}, seq_len: {entry.seq_len}")
else: # FULL/DECODE capture
entry = self.concrete_cudagraph_entries[batch_descriptor]
logger.debug(
"Hitting a decode cudagraph on (%s, %s)",
self.runtime_mode.name,
entry.batch_descriptor,
)
if entry.cudagraph is None:
if self.cudagraph_options.debug_log_enable:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every
# shape. E.g. we only log it for the first subgraph in
# piecewise mode.
if is_capturing_prefill:
logger.debug(
"Capturing a prefill cudagraph on (%s, batch_size=%d, seq_len=%d)",
self.runtime_mode.name,
entry.batch_size,
entry.seq_len,
)
else:
logger.debug(
"Capturing a decode cudagraph on (%s, %s)",
self.runtime_mode.name,
entry.batch_descriptor,
)
if ((not is_capturing_prefill) and (not is_running_drafter)):
# validate that cudagraph capturing is legal at this point.
validate_cudagraph_capturing_enabled()
'''
==================
End of MLU Hijack
==================
'''
input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.mlu.MLUGraph()
with ExitStack() as stack:
if self.cudagraph_options.gc_disable:
# during every model forward for piecewise cudagraph
# mode, we will capture many pieces of cudagraphs
# (roughly one per layer). running gc again and again
# across layers will make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(patch("torch.mlu.empty_cache", lambda: None))
if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
else:
set_graph_pool_id(current_platform.graph_pool_handle())
# mind-exploding: carefully manage the reference and memory.
with torch.mlu.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = self.runnable(*args, **kwargs)
if self.cudagraph_options.weak_ref_output:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph in piecewise cuadgraph mode, because
# the output of the last graph will not be used by
# any other cuda graph.
output = weak_ref_tensors(output)
# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph
compilation_counter.num_cudagraph_captured += 1
# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
f"Input addresses for cudagraphs are different "
f"during replay. Expected {entry.input_addresses}, "
f"got {new_input_addresses}"
)
entry.cudagraph.replay()
return entry.output

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

71
vllm_mlu/config/model.py Normal file
View File

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.config.model import ModelConfig
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__config__model__ModelConfig__is_embedding_task(self) -> bool:
return self.runner_type == "pooling"
def vllm__config__model__ModelConfig__get_head_size(self) -> int:
# TODO remove hard code
if self.is_deepseek_mla:
qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0)
if self.use_mla:
return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
else:
qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0)
if qk_rope_head_dim and qk_nope_head_dim:
return qk_rope_head_dim + qk_nope_head_dim
if hasattr(self.hf_text_config, "model_type") and (
self.hf_text_config.model_type == "zamba2"
):
return self.hf_text_config.attention_head_dim
if self.is_attention_free:
return 0
# NOTE: Some configs may set head_dim=None in the config
if getattr(self.hf_text_config, "head_dim", None) is not None:
return self.hf_text_config.head_dim
# NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head`
if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None:
return self.hf_text_config.hidden_size_per_head
# FIXME(woosuk): This may not be true for all models.
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust num_heads and num_attention_heads.
'''
if hasattr(self.hf_text_config, "num_heads"):
num_attention_heads = self.hf_text_config.num_heads
else:
num_attention_heads = self.hf_text_config.num_attention_heads
return (self.hf_text_config.hidden_size // num_attention_heads)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(
ModelConfig,
"is_embedding_task",
vllm__config__model__ModelConfig__is_embedding_task,
)
MluHijackObject.apply_hijack(
ModelConfig,
ModelConfig.get_head_size,
vllm__config__model__ModelConfig__get_head_size,
)

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing_extensions import Self
from vllm.config.scheduler import SchedulerConfig
from vllm.logger import init_logger
from vllm_mlu._mlu_utils import VLLM_V1_BENCHMARK
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__config__scheduler__SchedulerConfig__verify_max_model_len(
self, max_model_len: int,
) -> Self:
'''
=============================
Modify by vllm_mlu
=============================
@brief: This restriction is removed when VLLM_V1_BENCHMARK is set to True
'''
if not VLLM_V1_BENCHMARK:
if (
self.max_num_batched_tokens < max_model_len
and not self.enable_chunked_prefill
):
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({max_model_len}). "
"This effectively limits the maximum sequence length to "
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len."
)
'''
==================
End of MLU Hijack
==================
'''
if self.max_num_batched_tokens < self.max_num_seqs:
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
"be greater than or equal to max_num_seqs "
f"({self.max_num_seqs})."
)
if self.max_num_batched_tokens > self.max_num_seqs * max_model_len:
logger.warning(
"max_num_batched_tokens (%d) exceeds max_num_seqs "
"* max_model_len (%d). This may lead to unexpected behavior.",
self.max_num_batched_tokens,
self.max_num_seqs * max_model_len,
)
if self.max_num_partial_prefills > 1:
if not self.enable_chunked_prefill:
raise ValueError(
"Chunked prefill must be enabled to set "
"max_num_partial_prefills > 1."
)
if self.long_prefill_token_threshold > max_model_len:
raise ValueError(
"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) cannot be greater "
f"than the max_model_len ({max_model_len})."
)
if self.max_long_partial_prefills > self.max_num_partial_prefills:
raise ValueError(
f"{self.max_long_partial_prefills=} must be less than or equal to "
f"{self.max_num_partial_prefills=}."
)
return self
MluHijackObject.apply_hijack(
SchedulerConfig,
SchedulerConfig.verify_max_model_len,
vllm__config__scheduler__SchedulerConfig__verify_max_model_len,
)

View File

@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.config.parallel import ParallelConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
@staticmethod
def vllm__config__speculative__SpeculativeConfig__create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
"""
'''
=============================
Modify by vllm_mlu
@brief: add draft data parallel parameters
=============================
'''
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
tensor_parallel_size=speculative_draft_tensor_parallel_size,
distributed_executor_backend=target_parallel_config.distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce,
ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight,
placement_group=target_parallel_config.placement_group,
# add draft data parallel parameters
data_parallel_size=target_parallel_config.data_parallel_size,
data_parallel_size_local=target_parallel_config.data_parallel_size_local,
data_parallel_master_ip=target_parallel_config.data_parallel_master_ip,
data_parallel_rpc_port=target_parallel_config.data_parallel_rpc_port,
)
'''
==================
End of MLU Hijack
==================
'''
return draft_parallel_config
vllm__config__speculative__SpeculativeConfig____post_init___org = SpeculativeConfig.__post_init__
def vllm__config__speculative__SpeculativeConfig____post_init__(self):
if self.model is None and self.num_speculative_tokens is not None and self.method is None:
self.method = "mtp"
vllm__config__speculative__SpeculativeConfig____post_init___org(self)
MluHijackObject.apply_hijack(
SpeculativeConfig,
SpeculativeConfig.create_draft_parallel_config,
vllm__config__speculative__SpeculativeConfig__create_draft_parallel_config,
)
MluHijackObject.apply_hijack(
SpeculativeConfig,
SpeculativeConfig.__post_init__,
vllm__config__speculative__SpeculativeConfig____post_init__,
)

213
vllm_mlu/config/vllm.py Normal file
View File

@@ -0,0 +1,213 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import os
from vllm.config.vllm import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__config__vllm__VllmConfig___set_cudagraph_sizes(self):
"""
vLLM defines the default candidate list of batch sizes for CUDA graph
capture as:
```python
max_graph_size = min(max_num_seqs * 2, 512)
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
# up to max_graph_size
cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
range(256, max_graph_size + 1, 16))
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
will be the final sizes to capture cudagraph (in ascending order).
These sizes are used to capture and reuse CUDA graphs for
performance-critical paths (e.g., decoding). Capturing enables
significantly faster kernel dispatch by avoiding Python overhead. The
list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
most GPUs), which controls the total allowed number of tokens in a
batch. Since each sequence may have a variable number of tokens, the
maximum usable batch size will depend on actual sequence lengths.
Example:
With `max_num_batched_tokens = 8192`, and typical sequences
averaging ~32 tokens, most practical batch sizes fall below 256.
However, the system will still allow capture sizes up to 512 if
shape and memory permit.
Note:
If users explicitly specify cudagraph capture sizes in the
compilation config, those will override this default logic.
At runtime:
- If batch size <= one of the `cudagraph_capture_sizes`, the closest
padded CUDA graph will be used.
- If batch size > largest `cudagraph_capture_sizes`, cudagraph will
not be used.
"""
if hasattr(self.compilation_config, "_has_set_capture_list"):
# avoid set capture list twice while init
return
if (
self.model_config is not None
and not self.model_config.enforce_eager
and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
):
# determine the initial max_cudagraph_capture_size
max_cudagraph_capture_size = (
self.compilation_config.max_cudagraph_capture_size
)
if max_cudagraph_capture_size is None:
max_cudagraph_capture_size = min(
self.scheduler_config.max_num_seqs * 2, 512
)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
assert max_cudagraph_capture_size >= 1, (
"Maximum cudagraph size should be greater than or equal to 1 "
"when using cuda graph."
)
# determine the cudagraph_capture_sizes
if self.compilation_config.cudagraph_capture_sizes is not None:
assert len(self.compilation_config.cudagraph_capture_sizes) > 0, (
"cudagraph_capture_sizes should contain at least one element "
"when using cuda graph."
)
# de-duplicate the sizes provided by the config
dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
cudagraph_capture_sizes = [
i for i in dedup_sizes if i <= max_num_tokens
]
# sort to make sure the sizes are in ascending order
cudagraph_capture_sizes.sort()
else:
cudagraph_capture_sizes = [
i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
]
if max_cudagraph_capture_size >= 8:
# Step size 8 for small batch sizes, up to 256(not included)
cudagraph_capture_sizes += list(
range(8, min(max_cudagraph_capture_size + 1, 256), 8)
)
if max_cudagraph_capture_size >= 256:
# Step size 16 for larger batch sizes
cudagraph_capture_sizes += list(
range(256, max_cudagraph_capture_size + 1, 16)
)
'''
=============================
Modify by vllm_mlu
=============================
@brief:
1) check batch_size_capture_list when enable mtp because bs * (K + 1)
may greater than max_num_batched_tokens
2) capture MLUGraph by given batch list
'''
mlu_graph_capture_list = os.getenv("MLU_GRAPH_CAPTURE_LIST", None)
if mlu_graph_capture_list:
if "-" in mlu_graph_capture_list:
batch_info = mlu_graph_capture_list.split("-")
assert len(batch_info) == 3, \
f"Got invalid graph_capture_list={mlu_graph_capture_list}, " + \
f"but expected format 'min_bs-max_bs(may not include)-step'."
start, end, step = mlu_graph_capture_list.split("-")
cudagraph_capture_sizes = [1, 2, 4] + [
i for i in range(int(start), int(end), int(step))
]
cudagraph_capture_sizes = sorted(list(set(cudagraph_capture_sizes)))
else:
cudagraph_capture_sizes = [int(x) for x in mlu_graph_capture_list.split(",")]
if (self.speculative_config is not None
and self.speculative_config.num_speculative_tokens > 0
):
K = self.speculative_config.num_speculative_tokens
cudagraph_capture_sizes = [x * (1 + K) for x in cudagraph_capture_sizes]
cudagraph_capture_sizes = [
size for size in cudagraph_capture_sizes
if size <= self.scheduler_config.max_num_batched_tokens
]
'''
==================
End of MLU Hijack
==================
'''
if (
self.parallel_config.tensor_parallel_size > 1
and self.compilation_config.pass_config.enable_sequence_parallelism
):
cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
cudagraph_capture_sizes
)
# user-specific compilation_config.max_cudagraph_capture_size get
# truncated to valid_max_size when they are inconsistent.
valid_max_size = (
cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
)
if (
self.compilation_config.max_cudagraph_capture_size is not None
and self.compilation_config.max_cudagraph_capture_size != valid_max_size
):
# raise error only when both two flags are user-specified
# and they are inconsistent with each other
if self.compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"customized max_cudagraph_capture_size"
f"(={self.compilation_config.max_cudagraph_capture_size}) "
"should be consistent with the max value of "
f"cudagraph_capture_sizes(={valid_max_size})"
)
logger.warning(
"Truncating max_cudagraph_capture_size to %d",
valid_max_size,
)
# always set the final max_cudagraph_capture_size
self.compilation_config.max_cudagraph_capture_size = valid_max_size
if self.compilation_config.cudagraph_capture_sizes is not None and len(
cudagraph_capture_sizes
) < len(self.compilation_config.cudagraph_capture_sizes):
# If users have specified capture sizes, we only need to
# compare the lens before and after modification since the modified
# list is only the subset of the original list.
logger.warning(
(
"cudagraph_capture_sizes specified in compilation_config"
" %s is overridden by config %s"
),
self.compilation_config.cudagraph_capture_sizes,
cudagraph_capture_sizes,
)
# always write back the final sizes
self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes
else:
# no cudagraph in use
self.compilation_config.max_cudagraph_capture_size = 0
self.compilation_config.cudagraph_capture_sizes = []
# complete the remaining process.
self.compilation_config.post_init_cudagraph_sizes()
setattr(self.compilation_config, "_has_set_capture_list", True)
MluHijackObject.apply_hijack(
VllmConfig,
VllmConfig._set_cudagraph_sizes,
vllm__config__vllm__VllmConfig___set_cudagraph_sizes,
)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,319 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# cn_api based pytorch pluggable allocator to implement sleep mode.
import dataclasses
import gc
import os
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any
import torch
from vllm.logger import init_logger
from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__)
def find_loaded_library(lib_name) -> str | None:
"""
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which includes the
shared libraries loaded by the process. We can use this file to find the path of the
a loaded library.
""" # noqa
found_line = None
with open("/proc/self/maps") as f:
for line in f:
if lib_name in line:
found_line = line
break
if found_line is None:
# the library is not loaded in the current process
return None
# if lib_name is libcudart, we need to match a line with:
# address /path/to/libcudart-hash.so.11.0
start = found_line.index("/")
path = found_line[start:].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(lib_name), (
f"Unexpected filename: {filename} for library {lib_name}"
)
return path
cnmem_available = False
try:
from vllm_mlu.vllm_mlu_C import (
init_module,
python_create_and_map,
python_unmap_and_release,
python_cn_memcpy,
)
lib_name = find_loaded_library("vllm_mlu_C")
cnmem_available = True
except ModuleNotFoundError as e:
logger.error("Failed to import cnmem_allocator:%s", e)
init_module = None
python_create_and_map = None
python_unmap_and_release = None
lib_name = None
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
HandleType = tuple[int, int, int, int]
@dataclasses.dataclass
class AllocationData:
handle: HandleType
tag: str
cpu_backup_tensor: torch.Tensor | None = None
def create_and_map(allocation_handle: HandleType) -> None:
python_create_and_map(*allocation_handle)
def unmap_and_release(allocation_handle: HandleType) -> None:
python_unmap_and_release(*allocation_handle)
def get_pluggable_allocator(
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
python_free_func: Callable[[int], tuple[int, int, int, int]]
) -> torch.mlu.memory.MLUPluggableAllocator:
init_module(python_malloc_fn, python_free_func)
new_alloc = torch.mlu.memory.MLUPluggableAllocator(
lib_name, "my_malloc", "my_free"
)
return new_alloc
@contextmanager
def use_memory_pool_with_allocator(
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
python_free_func: Callable[[int], tuple[int, int, int, int]]):
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch.mlu.memory.MemPool(new_alloc._allocator)
with torch.mlu.memory.use_mem_pool(mem_pool):
yield mem_pool, new_alloc
class CnMemAllocator:
"""
A singleton class that manages a memory pool for MLU tensors.
The memory in this pool can be offloaded or discarded when the
allocator sleeps.
Inside the `use_memory_pool(tag)` context, all tensors created will
be allocated in the memory pool, and has the same tag as the
tag passed to the context.
When we call `sleep`, all tensors with the specified tag will be
offloaded to CPU memory, and the rest of the tensors will be discarded.
When we call `wake_up`, all tensors that are previously offloaded
will be loaded back to GPU memory, and the rest of the tensors will
have empty memory.
Why it needs to be a singleton?
When allocated tensors are garbage collected, PyTorch will call
the free callback, which will call the `python_free_callback` method.
The C-extension uses a global variable to store the function of an
instance of this class. If we create multiple instances of this class,
the global variable will be overwritten and the free callback will
not work as expected.
"""
instance: "CnMemAllocator" = None
default_tag: str = "default"
@staticmethod
def get_instance() -> "CnMemAllocator":
"""
CnMemAllocator is a singleton class.
We cannot call the constructor directly.
Call this method to get the instance.
"""
assert cnmem_available, "cnmem allocator is not available"
if CnMemAllocator.instance is None:
CnMemAllocator.instance = CnMemAllocator()
return CnMemAllocator.instance
def __init__(self):
conf = os.environ.get("PYTORCH_MLU_ALLOC_CONF", "")
assert "expandable_segments:True" not in conf, (
"Expandable segments are not compatible with memory pool. "
"Please track https://github.com/pytorch/pytorch/issues/147851 "
"for the latest updates."
)
self.pointer_to_data: dict[int, AllocationData] = {}
self.current_tag: str = CnMemAllocator.default_tag
self.allocator_and_pools: dict[str, Any] = {}
# Creating strong references to the two callbacks here to prevent
# these ephemeral bound-method objects being garbage collected.
# See discussions in https://github.com/vllm-project/vllm/pull/22724
self.python_malloc_callback = self._python_malloc_callback
self.python_free_callback = self._python_free_callback
def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
"""
Internal method to store the allocation data
when memory is allocated in the memory pool."""
py_d_mem = allocation_handle[2]
self.pointer_to_data[py_d_mem] = AllocationData(
allocation_handle, self.current_tag
)
logger.debug(
"Allocated %s bytes for %s with address %s from cnmem allocator",
allocation_handle[1],
self.current_tag,
py_d_mem,
)
return
def _python_free_callback(self, ptr: int) -> HandleType:
"""
Internal method to look up the allocation data
when memory is freed in the memory pool."""
data = self.pointer_to_data.pop(ptr)
if data.cpu_backup_tensor is not None:
data.cpu_backup_tensor = None
logger.debug(
"Freed %s bytes for %s with address %s from cnmem allocator",
data.handle[1],
data.tag,
ptr,
)
return data.handle
def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
"""
Put the allocator in sleep mode.
All data in the memory allocation with the specified tag will be
offloaded to CPU memory, and others will be discarded.
:param offload_tags: The tags of the memory allocation that will be
offloaded. The rest of the memory allocation will be discarded.
"""
if offload_tags is None:
# by default, allocated tensors are offloaded
# when the allocator sleeps
offload_tags = (CnMemAllocator.default_tag, )
elif isinstance(offload_tags, str):
offload_tags = (offload_tags,)
assert isinstance(offload_tags, tuple)
total_bytes = 0
backup_bytes = 0
for ptr, data in self.pointer_to_data.items():
handle = data.handle
total_bytes += handle[1]
if data.tag in offload_tags:
backup_bytes += handle[1]
size_in_bytes = handle[1]
cpu_backup_tensor = torch.empty(
size_in_bytes,
dtype=torch.uint8,
device="cpu",
pin_memory=is_pin_memory_available(),
)
cpu_ptr = cpu_backup_tensor.data_ptr()
python_cn_memcpy(cpu_ptr, ptr, size_in_bytes)
data.cpu_backup_tensor = cpu_backup_tensor
unmap_and_release(handle)
logger.info(
"CnMemAllocator: sleep freed %.2f GiB memory in total, of which "
"%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded "
"directly.",
total_bytes / 1024**3,
backup_bytes / 1024**3,
(total_bytes - backup_bytes) / 1024**3,
)
gc.collect()
torch.mlu.empty_cache()
def wake_up(self, tags: list[str] | None = None) -> None:
"""
Wake up the allocator from sleep mode.
All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory.
:param tags: The tags of the memory allocation that will be loaded
back to GPU memory. If None, all memory allocation will be loaded
back to GPU memory.
"""
for ptr, data in self.pointer_to_data.items():
if tags is None or data.tag in tags:
handle = data.handle
create_and_map(handle)
if data.cpu_backup_tensor is not None:
cpu_backup_tensor = data.cpu_backup_tensor
if cpu_backup_tensor is not None:
size_in_bytes = (
cpu_backup_tensor.numel() * cpu_backup_tensor.element_size()
)
cpu_ptr = cpu_backup_tensor.data_ptr()
python_cn_memcpy(ptr, cpu_ptr, size_in_bytes)
data.cpu_backup_tensor = None
@contextmanager
def use_memory_pool(self, tag: str | None = None):
"""
A context manager to use the memory pool.
All memory allocation created inside the context will be allocated
in the memory pool, and has the specified tag.
:param tag: The tag of the memory allocation. If None, the default tag
will be used.
"""
if tag is None:
tag = CnMemAllocator.default_tag
assert isinstance(tag, str)
old_tag = self.current_tag
self.current_tag = tag
with use_memory_pool_with_allocator(
self.python_malloc_callback, self.python_free_callback
) as data:
# start to hit another PyTorch bug in PyTorch 2.6,
# possibly because of gc-related issue w.r.t. the allocator and
# the memory pool.
# to avoid the issue, we keep a reference of the data.
# see https://github.com/pytorch/pytorch/issues/146431 .
self.allocator_and_pools[tag] = data
yield
# PyTorch's bug, calling torch.cuda.empty_cache() will error
# when using pluggable allocator, see
# https://github.com/pytorch/pytorch/issues/145168 .
# if we have some memory allocated and then freed,
# the memory will not be released, e.g. in online quantization,
# where the model is created in higher precision, and then
# quantized in lower precision.
# Find all unused allocations and manually release them.
# TODO: we should expose `empty_cache` method in the memory pool.
# TODO: ask for help from PyTorch team to expose this method.
allocations = data[0].snapshot()
for allocation in allocations:
if allocation["allocated_size"] == 0:
handle = self._python_free_callback(allocation["address"])
unmap_and_release(handle)
self.current_tag = old_tag
def get_current_usage(self) -> int:
"""
Get the total number of bytes allocated in the memory pool.
"""
sum_bytes: int = 0
for ptr, data in self.pointer_to_data.items():
handle = data.handle
sum_bytes += handle[1]
return sum_bytes

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase,
)
class MLUCommunicator(DeviceCommunicatorBase):
def __init__(
self,
cpu_group: ProcessGroup,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = ""
):
super().__init__(cpu_group, device, device_group, unique_name)
# init device according to rank
self.device = torch.mlu.current_device()
self.ca_comm: CustomAllreduce | None = None

View File

@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
MLUKVConnectors: dict[str, tuple[str, str]] = {
"MLUSharedStorageConnector": (
"vllm_mlu.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
"SharedStorageConnector"
),
"MLUNixlConnector": (
"vllm_mlu.distributed.kv_transfer.kv_connector.v1.nixl_connector",
"MLUNixlConnector"
),
}
for name, (module_path, class_name) in MLUKVConnectors.items():
if name not in KVConnectorFactory._registry:
KVConnectorFactory.register_connector(name, module_path, class_name)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
from vllm_mlu.mlu_hijack_utils import MluHijackObject
class LMCacheConnectorV1_MluHijack(LMCacheConnectorV1):
def response_remote_alloc_once(self) -> None:
self._lmcache_engine.response_remote_alloc_once()
def request_remote_memory_send(self) -> None:
self._lmcache_engine.request_remote_memory_send()
MluHijackObject.apply_hijack(LMCacheConnectorV1,
"response_remote_alloc_once",
LMCacheConnectorV1_MluHijack.response_remote_alloc_once)
MluHijackObject.apply_hijack(LMCacheConnectorV1,
"request_remote_memory_send",
LMCacheConnectorV1_MluHijack.request_remote_memory_send)

View File

@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import math
import threading
import time
import uuid
from collections import defaultdict
from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
import zmq
from vllm import envs
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.logger import init_logger
from vllm.platforms import _Backend
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
EngineId, NixlConnectorWorker, NixlAgentMetadata, NixlConnectorScheduler, NixlConnector)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
Transfer = tuple[int, float] # (xfer_handle, start_time)
GET_META_MSG = b"get_meta_msg"
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
from nixl._api import nixl_agent as NixlWrapper
logger.info("NIXL is available")
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
class MLUNixlConnector(NixlConnector):
def __init__(
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super(NixlConnector, self).__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler : MLUNixlConnectorScheduler | None = (
MLUNixlConnectorScheduler(vllm_config, self.engine_id)
)
self.connector_worker: MLUNixlConnectorWorker | None = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = MLUNixlConnectorWorker(vllm_config, self.engine_id)
class MLUNixlConnectorScheduler(NixlConnectorScheduler):
"""Implementation of Scheduler side methods"""
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
'''
=============================
Modify by vllm_mlu
=============================
@brief: kv transfer info
'''
if request.kv_transfer_params.get("do_remote_prefill", False):
logger.info(f"NIXLConnector update_state_after_alloc: request_id={request.request_id}, "
f"num_prompt_tokens={request.num_prompt_tokens}, "
f"num_external_tokens={num_external_tokens}, "
f"kv_transfer_params={request.kv_transfer_params}")
'''
==================
End of MLU Hijack
==================
'''
params = request.kv_transfer_params
logger.debug(
"NIXLConnector update_state_after_alloc: "
"num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens,
params,
)
if not params:
return
if params.get("do_remote_decode"):
self._reqs_in_batch.add(request.request_id)
if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
# save all blocks
block_ids = blocks.get_block_ids()[0]
# TODO: skip the blocks that are already in the host xfer buffer.
# Currently, the host xfer buffer block is 1-to-1 mapped to device
# kv blocks, so host blocks won't be flushed as long as its device
# block is not overwritten; and it will be safe to skip saving them
# to host xfer buffer.
if block_ids:
self._reqs_need_save[request.request_id] = (request, block_ids)
elif params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(
p in params
for p in ("remote_engine_id", "remote_host", "remote_port")
):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
local_block_ids = (
blocks.get_unhashed_block_ids()
if num_external_tokens > 0
else []
)
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (
request,
local_block_ids,
)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer",
params,
)
else:
assert num_external_tokens == 0
# Only trigger 1 KV transfer per request.
params["do_remote_prefill"] = False
class MLUNixlConnectorWorker(NixlConnectorWorker):
"""Implementation of Worker side methods"""
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
_, first_kv_cache = next(iter(kv_caches.items()))
'''
=============================
Add by vllm_mlu
=============================
@brief: not support kv8
'''
if not isinstance(first_kv_cache, torch.Tensor):
kv_caches = {key: value[0] for key, value in kv_caches.items()}
_, first_kv_cache = next(iter(kv_caches.items()))
'''
==================
End of MLU Hijack
==================
'''
kv_elem_size = first_kv_cache.element_size()
# TODO(tms): Find a more robust way to detect and handle MLA
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# will only affects the strides. For MLA instead, we make require no
# such thing and resort to the standard layout.
'''
=============================
Add by vllm_mlu
=============================
@brief: support mla
'''
use_mla = first_kv_cache.shape[0] == 1
'''
==================
End of MLU Hijack
==================
'''
assert use_mla == self.use_mla
# TODO (NickLucche) not compatible with hybrid allocator. Enforce check
# once it goes live, as a single kv layout is expected for xfers.
if use_mla:
# MLA case.
'''
=============================
Add by vllm_mlu
=============================
@brief: support mla
'''
self.num_blocks = first_kv_cache.shape[1]
'''
==================
End of MLU Hijack
==================
'''
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
'''
=============================
Add by vllm_mlu
=============================
@brief: MLU kv_cache layout is [2 (k and v), num_blocks, kv_heads, block_size, head_dim]
'''
n_kv_heads, block_size, head_dim = block_shape[-3:]
'''
==================
End of MLU Hijack
==================
'''
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
assert block_size == self.block_size
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)
logger.info(
"Registering KV_Caches: use_mla: %s, num_blocks: %s, "
"block_shape: %s, per_layer_kv_cache_shape: %s", use_mla,
self.num_blocks, block_shape, first_kv_cache.shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.kv_caches = kv_caches
kv_caches_base_addr = []
caches_data = []
# Note(tms): I modified this from the original region setup code.
# K and V are now in different regions. Advantage is that we can
# elegantly support MLA and any cases where the K and V tensors
# are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \
else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len
caches_data.append(
(base_addr, region_len, cache.device.index, ""))
kv_caches_base_addr.append(base_addr)
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
self.num_regions = len(caches_data)
self.num_layers = len(self.kv_caches.keys())
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs)
logger.debug("Done registering descs")
self._registered_descs.append(descs)
# Register local/src descr for NIXL xfer.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
# select agent_meta.num_blocks instead of self.num_blocks for
# local descr, and that makes handling regular flow less clean.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, self.tp_rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.tp_rank)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
# NIXL_INIT_AGENT to be used for preparations of local descs.
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)
# After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
tp_size=self.world_size,
block_len=self.block_len,
attn_backend_name=self.backend_name)
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
daemon=True,
name="nixl_handshake_listener")
self._nixl_handshake_listener_t.start()
ready_event.wait()

View File

@@ -0,0 +1,450 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import hashlib
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
from vllm_mlu.v1.attention.backends.flash_mla import MLAFlashAttentionCommonMetadata
logger = init_logger(__name__)
@dataclass
class ReqMeta:
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# Is store or load
is_store: bool
mm_hashes: list[str]
@staticmethod
def make_meta(
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
mm_hashes: list[str],
) -> "ReqMeta":
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = (
block_offsets.reshape((1, block_size))
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
)
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
return ReqMeta(
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
is_store=is_store,
mm_hashes=mm_hashes,
)
@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] = field(default_factory=list)
def add_request(
self,
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
mm_hashes: list[str],
) -> None:
self.requests.append(
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, mm_hashes)
)
class SharedStorageConnector(KVConnectorBase_V1):
# NOTE: This is Simple debug implementation of the KV connector.
# It save / load the KV cache to / from the disk.
# It does extra work which will overwrite the existing prefix-cache in GPU
# - to remove the overhead, need to add some "mask" in the ReqMeta class
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(
vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config,
)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Request] = {}
self._storage_path = self._kv_transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp")
logger.info(self._kv_transfer_config)
logger.info("Shared storage path is %s", self._storage_path)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""Inject the KV cache into the layer.
Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache
layer. In shape [2, num_pages, page_size, xxx] if not
using MLA, [num_pages, page_size, xxx] otherwise.
src_kv_cache (torch.Tensor): the source KV cache. In shape
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise.
slot_mapping (torch.Tensor): the slot mapping. In shape
[num_tokens].
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLAFlashAttentionCommonMetadata):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
num_pages * page_size, -1
)
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1
)
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata
metadata: KVConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, SharedStorageConnectorMetadata)
if metadata is None:
logger.warning(
"In connector.start_load_kv, but the connector metadata is None"
)
return
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
logger.warning("In connector.start_load_kv, but the attn_metadata is None")
return
# Load the KV for each request each layer
for request in metadata.requests:
if request.is_store:
continue
logger.info(
"Inject KV cache of %d tokens to the paged memory",
len(request.slot_mapping),
)
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE/MLP etc.
kv_cache_attr = getattr(layer, "kv_cache", None)
if kv_cache_attr is None:
continue
kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLAFlashAttentionCommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
for request in connector_metadata.requests:
if request.is_store:
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
tensors = {"kv_cache": kv_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)
def wait_for_save(self):
return
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# NOTE: in this debug implementation, we assume that the prompt is
# cached_prompt + newly_generated_single_token
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
# with the block granularity. And it expects the returned blocks and
# num_computed_tokens to also be aligned with the block granularity.
if not self._found_match_for_request(request):
return 0, False
logger.info("External Cache Hit!")
# Now, first num_tokens_to_check tokens are hit, we need to prepare
# the metadata for the worker connector to correctly load the KV
token_ids = request.prompt_token_ids or []
num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)
return num_tokens_to_check - num_computed_tokens, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
"""
Update KVConnector state after block allocation.
If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
if num_external_tokens > 0:
self._requests_need_load[request.request_id] = request
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = SharedStorageConnectorMetadata()
total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs:
token_ids = new_req.prompt_token_ids or []
mm_hashes = [f.identifier for f in new_req.mm_features]
if new_req.req_id in self._requests_need_load:
meta.add_request(
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=False,
mm_hashes=mm_hashes,
)
total_need_load += 1
else:
# NOTE: here, we set the store and load being exclusive,
# but a single request can have both store and load.
# NOTE(rob): for this debug implementation, we only cache
# the original prompt tokens.
if not self._found_match_for_prompt(token_ids, mm_hashes):
meta.add_request(
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=True,
mm_hashes=mm_hashes,
)
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
if not resumed_from_preemption or req_id not in self._requests_need_load:
continue
num_computed_tokens = cached_reqs.num_computed_tokens[i]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
new_block_ids = cached_reqs.new_block_ids[i]
# NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[req_id]
total_tokens = num_computed_tokens + num_new_tokens
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
assert new_block_ids is not None
block_ids = new_block_ids[0]
meta.add_request(
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False,
mm_hashes=[f.identifier for f in request.mm_features],
)
total_need_load += 1
assert total_need_load == len(self._requests_need_load)
self._requests_need_load.clear()
return meta
# ==============================
# Helper functions
# ==============================
def _found_match_for_request(
self,
request: "Request",
) -> bool:
"""Check if the cache is hit for the request."""
return self._found_match_for_prompt(
list(request.prompt_token_ids or []),
[f.identifier for f in request.mm_features],
)
def _found_match_for_prompt(
self,
prompt_token_ids: list[int],
mm_hashes: list[str],
) -> bool:
num_tokens_to_check = align_to_block_size(
len(prompt_token_ids) - 1, self._block_size
)
foldername = self._generate_foldername_debug(
torch.tensor(prompt_token_ids)[:num_tokens_to_check],
mm_hashes,
create_folder=False,
)
return os.path.exists(foldername)
def _generate_foldername_debug(
self,
token_ids: torch.Tensor,
mm_hashes: list[str],
create_folder=False,
) -> str:
"""Generate a folder name based on the hash of the bytes of the input
ids.
"""
token_bytes = token_ids.numpy().tobytes()
# Add mm_hashes to the bytes being hashed to avoid path traversal and
# to create a canonical key.
if mm_hashes:
mm_str = "-".join(mm_hashes)
token_bytes += mm_str.encode("utf-8")
input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()
foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder:
os.makedirs(foldername, exist_ok=True)
return foldername
def _generate_filename_debug(
self,
layer_name: str,
token_ids: torch.Tensor,
mm_hashes: list[str],
) -> str:
"""Generate a file name based on the layer name and the hash
of the bytes of the input ids.
"""
foldername = self._generate_foldername_debug(
token_ids, mm_hashes=mm_hashes, create_folder=True
)
return os.path.join(foldername, f"{layer_name}.safetensors")
def align_to_block_size(num_tokens: int, block_size) -> int:
"""Align the number of tokens to the block size."""
return (num_tokens - 1) // block_size * block_size

View File

@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from contextlib import contextmanager, nullcontext
from typing import Optional
from dataclasses import dataclass
import torch
from vllm.distributed.parallel_state import (
GroupCoordinator,
GraphCaptureContext,
get_pp_group,
get_tp_group,
)
from vllm.distributed.mlu_parallel_state import(
get_moe_expert_parallel_world_size,
get_moe_expert_parallel_rank,
get_moe_expert_parallel_group,
)
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
@dataclass
class MLUGraphCaptureContext:
stream: torch.mlu.Stream
@contextmanager
def mlu_graph_capture(device: torch.device):
"""
`graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the
current CUDA stream when the context manager is entered and reset to the
default stream when the context manager is exited. This is to ensure that
the graph capture is running on a separate stream from the default stream,
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
context = MLUGraphCaptureContext(torch.mlu.Stream(device=device))
with get_tp_group().graph_capture(context), get_pp_group().graph_capture(context):
yield context
@contextmanager
def vllm__distributed__parallel_state__GroupCoordinator__graph_capture(
self,
graph_capture_context: GraphCaptureContext | None = None,
):
if graph_capture_context is None:
stream = torch.mlu.Stream()
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
# only cuda uses this function,
# so we don't abstract it into the base class
maybe_ca_context = nullcontext()
from vllm_mlu.distributed.device_communicators.mlu_communicator import (
MLUCommunicator,
)
if self.device_communicator is not None:
assert isinstance(self.device_communicator, MLUCommunicator)
ca_comm = self.device_communicator.ca_comm
if ca_comm is not None:
maybe_ca_context = ca_comm.capture() # type: ignore
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch.mlu.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
with torch.mlu.stream(stream), maybe_ca_context:
yield graph_capture_context
@dataclass
class CnclEPBuffer:
dispatch_send_token_tensor: torch.Tensor
dispatch_recv_token_tensor: torch.Tensor
combine_send_token_tensor: torch.Tensor
combine_recv_token_tensor: torch.Tensor
class CnclEP:
def __init__(self,
dispatch_token_size: int,
combine_token_size: int,
max_num_tokens_per_rank: int,
num_global_experts: int,
use_quant_dispatch: bool = True) -> None:
nranks = get_moe_expert_parallel_world_size()
rank = get_moe_expert_parallel_rank()
moe_ep_group = get_moe_expert_parallel_group()
self.max_num_tokens_per_rank = max_num_tokens_per_rank
self.use_quant_dispatch = use_quant_dispatch
(
handle,
exchange_info_size,
exchange_info,
dispatch_send_token_tensor,
dispatch_recv_token_tensor,
combine_send_token_tensor,
combine_recv_token_tensor
) = mlu_ops.moe_all2all_create(dispatch_token_size,
combine_token_size,
num_global_experts,
max_num_tokens_per_rank,
rank,
nranks)
self.handle = handle
self.buffer = CnclEPBuffer(
dispatch_send_token_tensor,
dispatch_recv_token_tensor,
combine_send_token_tensor,
combine_recv_token_tensor)
assert exchange_info.ndim == 1, "exchange_info should be 1D"
all_exchange_info = torch.empty((nranks, exchange_info.size(0)),
dtype=exchange_info.dtype,
device=exchange_info.device)
exchange_info = exchange_info.unsqueeze(0)
torch.distributed.all_gather_into_tensor(all_exchange_info,
exchange_info,
group=moe_ep_group.cpu_group,
async_op=False)
mlu_ops.moe_all2all_init(self.handle, all_exchange_info)
torch.distributed.barrier(group=moe_ep_group.cpu_group)
def dispatch(self,
token_byte: int,
token_num: int,
send_layout: torch.Tensor,
send_token_num: torch.Tensor,
recv_layout: torch.Tensor,
recv_token_num: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
) -> None:
'''
The returned tensors are in-placed modified, we could directly use them
after dispatch finishes.
'''
mlu_ops.moe_all2all_dispatch(self.handle,
token_byte,
token_num,
send_layout,
send_token_num,
recv_layout,
recv_token_num,
send_token,
recv_token)
def combine(self,
token_byte: int,
token_num: int,
send_src_layout: torch.Tensor,
send_dst_layout: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
) ->None:
mlu_ops.moe_all2all_combine(self.handle,
token_byte,
token_num,
send_src_layout,
send_dst_layout,
send_token,
recv_token)
def destroy(self) -> None:
mlu_ops.moe_all2all_destroy(self.handle)
_CNCLEP: CnclEP | None = None
_CNCLEP_BF16: CnclEP | None = None
def get_cnclep(use_quant_dispatch: bool = True) -> CnclEP:
if use_quant_dispatch:
assert _CNCLEP is not None, "cnclep is not initialized"
return _CNCLEP
else:
assert _CNCLEP_BF16 is not None, "cnclep_bf16 is not initialized"
return _CNCLEP_BF16
def init_cnclep(dispatch_token_size: int,
combine_token_size: int,
max_num_tokens_per_rank: int,
num_global_experts: int,
use_quant_dispatch: bool = True):
if use_quant_dispatch:
global _CNCLEP
assert _CNCLEP is None, "cnclep has been initialized"
_CNCLEP = CnclEP(dispatch_token_size,
combine_token_size,
max_num_tokens_per_rank,
num_global_experts,
use_quant_dispatch)
else:
global _CNCLEP_BF16
assert _CNCLEP_BF16 is None, "cnclep_bf16 has been initialized"
_CNCLEP_BF16 = CnclEP(dispatch_token_size,
combine_token_size,
max_num_tokens_per_rank,
num_global_experts,
use_quant_dispatch)
def cnclep_dispatch(token_byte: int,
token_num: int,
send_layout: torch.Tensor,
send_token_num: torch.Tensor,
recv_layout: torch.Tensor,
recv_token_num: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
use_quant_dispatch: bool = True,
):
if use_quant_dispatch:
_CNCLEP.dispatch(token_byte,
token_num,
send_layout,
send_token_num,
recv_layout,
recv_token_num,
send_token,
recv_token)
else:
_CNCLEP_BF16.dispatch(token_byte,
token_num,
send_layout,
send_token_num,
recv_layout,
recv_token_num,
send_token,
recv_token)
def cnclep_combine(token_byte: int,
token_num: int,
send_src_layout: torch.Tensor,
send_dst_layout: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
use_quant_dispatch: bool = True,
):
if use_quant_dispatch:
_CNCLEP.combine(token_byte,
token_num,
send_src_layout,
send_dst_layout,
send_token,
recv_token)
else:
_CNCLEP_BF16.combine(token_byte,
token_num,
send_src_layout,
send_dst_layout,
send_token,
recv_token)
def destroy_cnclep():
global _CNCLEP
if _CNCLEP:
_CNCLEP.destroy()
_CNCLEP = None
global _CNCLEP_BF16
if _CNCLEP_BF16:
_CNCLEP_BF16.destroy()
_CNCLEP_BF16 = None
MluHijackObject.apply_hijack(GroupCoordinator,
GroupCoordinator.graph_capture,
vllm__distributed__parallel_state__GroupCoordinator__graph_capture)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,294 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import get_args
from vllm.platforms import current_platform
from vllm.config import (
ModelConfig,
VllmConfig,
SchedulerConfig,
)
from vllm.config.cache import CacheDType
from vllm.engine.arg_utils import (
EngineArgs,
_raise_unsupported_error,
)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
import vllm_mlu._mlu_utils as mlu_envs
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
@classmethod
def vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults(
cls,
model_config: ModelConfig,
) -> tuple[bool, bool]:
if model_config.runner_type != "pooling":
'''
=============================
Modify by vllm_mlu
=============================
@brief: mlu-v1 default use unchunked scheduler
'''
if mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED:
default_chunked_prefill = False
else:
default_chunked_prefill = True
'''
==================
End of MLU Hijack
==================
'''
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
default_prefix_caching = not model_config.is_hybrid
else:
assert model_config.pooler_config is not None
pooling_type = model_config.pooler_config.pooling_type
incremental_prefill_supported = (
pooling_type is not None
and pooling_type.lower() == "last"
and getattr(model_config.hf_config, "is_causal", True)
)
default_chunked_prefill = incremental_prefill_supported
default_prefix_caching = incremental_prefill_supported
return default_chunked_prefill, default_prefix_caching
def vllm__engine__arg_utils__EngineArgs___set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig
) -> None:
"""Set Default Arguments for V1 Engine."""
(
default_chunked_prefill,
default_prefix_caching,
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = default_chunked_prefill
logger.debug(
"%s chunked prefill by default",
"Enabling" if default_chunked_prefill else "Disabling",
)
elif (
model_config.runner_type == "pooling"
and self.enable_chunked_prefill
and not default_chunked_prefill
):
logger.warning(
"This model does not officially support chunked prefill. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
)
if self.enable_prefix_caching is None:
self.enable_prefix_caching = default_prefix_caching
logger.debug(
"%s prefix caching by default",
"Enabling" if default_prefix_caching else "Disabling",
)
elif (
model_config.runner_type == "pooling"
and self.enable_prefix_caching
and not default_prefix_caching
):
logger.warning(
"This model does not officially support prefix caching. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
)
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
(
default_max_num_batched_tokens,
default_max_num_seqs,
) = self.get_batch_defaults(world_size)
orig_max_num_batched_tokens = self.max_num_batched_tokens
orig_max_num_seqs = self.max_num_seqs
if self.max_num_seqs is None:
self.max_num_seqs = default_max_num_seqs.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: only set max_num_batched_tokens when enable chunked_prefill
'''
if self.max_num_batched_tokens is None:
self.max_num_batched_tokens = default_max_num_batched_tokens.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
if orig_max_num_batched_tokens is None:
if not self.enable_chunked_prefill:
# If max_model_len is too short, use the default for higher throughput.
self.max_num_batched_tokens = max(
model_config.max_model_len,
self.max_num_batched_tokens,
)
# When using default settings,
# Ensure max_num_batched_tokens does not exceed model limit.
# Some models (e.g., Whisper) have embeddings tied to max length.
self.max_num_batched_tokens = min(
self.max_num_seqs * model_config.max_model_len,
self.max_num_batched_tokens,
)
logger.debug(
"Defaulting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens,
usage_context.value if usage_context else None,
)
if orig_max_num_seqs is None:
if self.max_num_batched_tokens is not None: # For type checking
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
logger.debug(
"Defaulting max_num_seqs to %d for %s usage context.",
self.max_num_seqs,
usage_context.value if usage_context else None,
)
'''
==================
End of MLU Hijack
==================
'''
_VALID_QUANT_ATTN_QKV_DTYPE = ['int8', 'fp8', 'fp8_e4m3']
def vllm__engine__arg_utils__EngineArgs__create_engine_config(
self,
usage_context: UsageContext | None = None,
headless: bool = False,
) -> VllmConfig:
"""
Create the VllmConfig.
NOTE: If VllmConfig is incompatible, we raise an error.
"""
'''
=============================
Modify by vllm_mlu
=============================
@brief: add data parallel params to parallel config.
'''
if self.mlu_config and "decoder_attn_dtype" in self.mlu_config:
if self.mlu_config.get("decoder_attn_dtype") in ["int8", "fp8", "fp8_e4m3"]:
self.kv_cache_dtype = self.mlu_config.get("decoder_attn_dtype")
engine_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org(
self, usage_context, headless)
world_size = engine_config.parallel_config.world_size_across_dp
tensor_parallel_size = engine_config.parallel_config.tensor_parallel_size
embedding_tp_size = engine_config.mlu_config.layer_embedding_logit_tp_size
if embedding_tp_size:
assert embedding_tp_size >= tensor_parallel_size and embedding_tp_size <= world_size, (
f"embedding_tp_size = {embedding_tp_size} out of bounds. "
f"Require {tensor_parallel_size} ≤ size ≤ {world_size}")
dense_mlp_tp_size = engine_config.mlu_config.layer_dense_mlp_tp_size
if dense_mlp_tp_size:
assert dense_mlp_tp_size >= 1 and dense_mlp_tp_size <= world_size, (
f"dense_mlp_tp_size = {dense_mlp_tp_size} out of bounds. Require 1 ≤ size ≤ {world_size}")
if dense_mlp_tp_size != world_size:
assert not engine_config.mlu_config.is_dpsk_mcc_enabled, (
"dense_mlp_tp_size is not supported when dpsk mcc is enabled.")
if engine_config.model_config.is_longcat_flash and tensor_parallel_size > 1:
raise ValueError("For now, for longcat model, custom dense mlp tp split in data parallel requires dpXtp1. "
"Necessity of this constraint requires further investigation.")
if engine_config.model_config.is_longcat_flash and dense_mlp_tp_size < tensor_parallel_size:
raise ValueError(f"For longcat model, custom dense mlp tp_size {dense_mlp_tp_size} "
f"must be greater than or equal to tensor_parallel_size {tensor_parallel_size}")
if engine_config.model_config.is_deepseek_mla and dense_mlp_tp_size % tensor_parallel_size != 0:
raise ValueError(f"For deepseek mla model, custom mlp tp size {dense_mlp_tp_size} must "
f"be divisible by {tensor_parallel_size}")
if ((engine_config.parallel_config.data_parallel_size > 1 or engine_config.speculative_config is not None
or engine_config.mlu_config.prefill_use_sequence_parallel) and engine_config.mlu_config.prefill_enable_mlugraph):
logger.info("Data parallel or sequence parallel or speculative is enabled, forcing context mlugraph to be disabled.")
engine_config.mlu_config.prefill_enable_mlugraph = False
if engine_config.mlu_config.decoder_attn_dtype:
if engine_config.mlu_config.decoder_attn_dtype not in get_args(CacheDType):
raise ValueError(f"MLU backend does not support {engine_config.mlu_config.decoder_attn_dtype} "
f"decoder_attn_dtype for now")
is_glm4_moe = (hasattr(engine_config.model_config.hf_text_config, "model_type") and
engine_config.model_config.hf_text_config.model_type == "glm4_moe")
if (not (engine_config.model_config.is_deepseek_mla or is_glm4_moe)
and engine_config.mlu_config.decoder_attn_dtype != "auto"):
raise ValueError(f"mlu_config.decoder_attn_dtype only support deepseek_mla and glm4_moe model")
# sequence parallel checks
if (engine_config.mlu_config.prefill_use_sequence_parallel
and engine_config.model_config.hf_text_config.model_type not in ["deepseek_v32", "deepseek_v3"]):
raise ValueError("Prefill sequence parallel can only use in deepseek model.")
if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.scheduler_config.enable_chunked_prefill:
raise ValueError("Prefill sequence parallel can not use with chunked prefill for now.")
if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.mlu_config.is_dpsk_mcc_enabled:
raise ValueError("Prefill sequence parallel can not use with mcc.")
if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.parallel_config.data_parallel_size > 1:
raise ValueError("Prefill sequence parallel can not use with data parallel.")
if (engine_config.mlu_config.prefill_use_sequence_parallel
and engine_config.model_config.hf_text_config.model_type == "deepseek_v3"
and engine_config.quant_config.get_name() != "SmoothQuant"):
raise ValueError("Prefill sequence parallel can only use SmoothQuant for deepseek_v3.")
# disagg constraint
# 1、only support deepseek-v3/r1
# 2、unsupport kv8
if self.kv_transfer_config is not None:
if engine_config.model_config.hf_config.model_type != "deepseek_v3":
raise ValueError("Disagg only support DeepDeek-V3/R1")
if engine_config.cache_config.cache_dtype == "int8":
raise ValueError("Disagg does not support KV cache dtype is int8")
if engine_config.cache_config.enable_prefix_caching:
raise ValueError("Disagg does not support prefix caching")
if isinstance(self.kv_transfer_config, dict):
kv_connector = self.kv_transfer_config.get("kv_connector")
kv_role = self.kv_transfer_config.get("kv_role")
else:
kv_connector = self.kv_transfer_config.kv_connector
kv_role = self.kv_transfer_config.kv_role
if kv_connector != "LMCacheConnectorV1":
raise ValueError("Disagg only support LMCacheConnectorV1 connector")
if kv_role == "kv_consumer":
if not self.enable_chunked_prefill:
raise ValueError("Disagg decoder only support chunk scheduler")
'''
==================
End of MLU Hijack
==================
'''
return engine_config
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs._set_default_args,
vllm__engine__arg_utils__EngineArgs___set_default_args)
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs.create_engine_config,
vllm__engine__arg_utils__EngineArgs__create_engine_config)
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs.get_chunked_prefill_prefix_caching_defaults,
vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

158
vllm_mlu/entrypoints/llm.py Normal file
View File

@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from tqdm import tqdm
from typing import Callable
from vllm.entrypoints.llm import LLM
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.logger import init_logger
import vllm_mlu._mlu_utils as mlu_envs
from vllm_mlu.mlu_metric import LLMMetric
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__entrypoints__llm__LLM__get_mlu_metrics(
self,
metrics_idx_start,
only_average,
input_len,
output_len,
tp_nums,
quantization,
show_per_iter=False,
is_embedding_task=False,
mm_kwargs=None,
total_prefill_steps=1,
num_speculative_tokens=0,
dp_size=1,
) -> None:
'''
@brief:该函数用来打印vLLM调用generate接口过程中代码统计的各项性能指标数据
@params:
metrics_idx_start: 考虑存在调用generate接口为warmup过程的情况
因此设置该参数可忽略统计[0,metrics_idx_start)之间的数据,默认为0,即所有性能数据有效。
only_average: True 只打印N次调用generate接口的平均性能 False 打印每次调用generate接口的性能及其均值 若N次性能数据波动较大需自行排查测试环境是否稳定。
其余参数:均为模型配置参数
'''
if mlu_envs.VLLM_LATENCY_DEBUG_EN:
batch_size = self.metric.batch_size_list[-1] * dp_size
if mm_kwargs or is_embedding_task:
# The multimodal and pooling model doesn't support the hfu feature yet.
hfu_info, io_efficiency = None, None
else:
hfu_info, io_efficiency = self.llm_engine.get_hfu_info(batch_size, input_len, output_len)
self.metric.calc_metric(
self.llm_engine.model_config.model,
self.llm_engine.model_config.dtype,
metrics_idx_start, only_average,
input_len, output_len, tp_nums,
quantization, show_per_iter,
is_embedding_task, mm_kwargs, total_prefill_steps,
num_speculative_tokens, dp_size=dp_size, hfu_info=hfu_info, io_efficiency=io_efficiency)
else:
print("Warnning:please set VLLM_LATENCY_DEBUG=true!")
def vllm__entrypoints__llm__LLM___run_engine(
self, *, use_tqdm: bool | Callable[..., tqdm] = True
) -> list[RequestOutput | PoolingRequestOutput]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
pbar = tqdm_func(
total=num_requests,
desc="Processed prompts",
dynamic_ncols=True,
postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
)
'''
=============================
Added by vllm_mlu
=============================
'''
if mlu_envs.VLLM_LATENCY_DEBUG_EN:
total_request_num = self.llm_engine.get_num_unfinished_requests()
e2e_start_time = self.metric.get_mlu_cost_time()
if not self.llm_engine.model_config.is_embedding_task():
peak_memory, block_memory, num_gpu_blocks, num_cpu_blocks = \
self.llm_engine.get_memory_usage()
self.metric.update_memory_usage(peak_memory, block_memory,
num_gpu_blocks, num_cpu_blocks)
'''
==================
End of addition
==================
'''
# Run the engine.
outputs: list[RequestOutput | PoolingRequestOutput] = []
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
n = len(output.outputs)
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids) * n
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs
)
out_spd = total_out_toks / pbar.format_dict["elapsed"]
pbar.postfix = (
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s"
)
pbar.update(n)
else:
pbar.update(1)
if pbar.n == num_requests:
pbar.refresh()
if use_tqdm:
pbar.close()
'''
=============================
Added by vllm_mlu
=============================
'''
if mlu_envs.VLLM_LATENCY_DEBUG_EN:
e2e_end_time = self.metric.get_mlu_cost_time()
e2e_latency = e2e_end_time - e2e_start_time
engine_step_latency, model_forward_latency, mm_encoder_latency = self.llm_engine.get_latency()
self.metric.update_step_latency(engine_step_latency)
if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
self.metric.update_step_latency_device(model_forward_latency)
self.metric.update_mm_encoder_latency_device(mm_encoder_latency)
self.metric.add_metrics(total_request_num, e2e_latency)
'''
==================
End of addition
==================
'''
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id))
LLM.metric = LLMMetric()
MluHijackObject.apply_hijack(LLM,
"get_mlu_metrics",
vllm__entrypoints__llm__LLM__get_mlu_metrics)
MluHijackObject.apply_hijack(LLM,
LLM._run_engine,
vllm__entrypoints__llm__LLM___run_engine)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from fastapi import Request
from fastapi.responses import Response
import vllm_mlu._mlu_utils as mlu_envs
from vllm.entrypoints.openai.api_server import (
router, engine_client
)
from vllm_mlu.logger import logger
if mlu_envs.VLLM_SCHEDULER_PROFILE:
logger.info(
"vLLM V1 Scheduler Profiler is enabled in the API server. Please use "
"'tools/utils/post_scheduler_view_action.py' to dump profiling data "
"after all requests finished.")
@router.post("/v1/start_scheduler_profile")
async def start_scheduler_profile(raw_request: Request):
logger.info("VLLM-V1 starting scheduler profiler...")
await engine_client(raw_request).start_scheduler_profile()
return Response(status_code=200)
@router.post("/v1/stop_scheduler_profile")
async def stop_scheduler_profile(raw_request: Request):
logger.info("VLLM-V1 scheduler stopping profiler...")
await engine_client(raw_request).stop_scheduler_profile()
return Response(status_code=200)

41
vllm_mlu/envs.py Normal file
View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import os
from typing import Any, Callable, Dict
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
# begin-env-vars-definition
env_variables: Dict[str, Callable[[], Any]] = {
# max compile thread num
"MAX_JOBS":
lambda: os.getenv("MAX_JOBS", None),
"CMAKE_BUILD_TYPE":
lambda: os.getenv("CMAKE_BUILD_TYPE"),
"COMPILE_CUSTOM_KERNELS":
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
"VERBOSE":
lambda: bool(int(os.getenv('VERBOSE', '0'))),
"LD_LIBRARY_PATH":
lambda: os.getenv("LD_LIBRARY_PATH", None),
"CXX_COMPILER":
lambda: os.getenv("CXX_COMPILER", None),
"C_COMPILER":
lambda: os.getenv("C_COMPILER", None)
}
# end-env-vars-definition
def __getattr__(name: str):
# lazy evaluation of environment variables
if name in env_variables:
return env_variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
return list(env_variables.keys())

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

47
vllm_mlu/logger.py Normal file
View File

@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import logging
from typing import cast
from vllm.logger import _VllmLogger
class _ColorFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
if not record.name.startswith('vllm_mlu'):
return True
if record.levelno == logging.INFO:
record.msg = f"\033[32m{record.msg}\033[0m"
elif record.levelno == logging.WARNING:
record.msg = f"\033[33m{record.msg}\033[0m"
return True
def _apply_mlu_color(logger):
if not logger.handlers:
return
for h in logger.handlers:
if any(isinstance(f, _ColorFilter) for f in h.filters):
return
h.addFilter(_ColorFilter())
def _mlu_init_logger(name: str) -> logging.Logger:
"""Initialize loggers for vllm_mlu module,
and keep the configuration consistent with the vllm module"""
mlu_logger = logging.getLogger(name)
vllm_logger = logging.Logger.manager.loggerDict.get('vllm', None)
if vllm_logger:
mlu_logger.setLevel(vllm_logger.level)
mlu_logger.propagate = vllm_logger.propagate
mlu_logger.handlers = vllm_logger.handlers
return mlu_logger
def init_logger(name: str) -> _VllmLogger:
vllm_logger = cast(_VllmLogger, _mlu_init_logger(name))
_apply_mlu_color(vllm_logger)
return vllm_logger
logger = init_logger(__name__)

41
vllm_mlu/lora/__init__.py Normal file
View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.layers.column_parallel_linear import (
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithShardedLoRA,
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
RowParallelLinearWithLoRA,
RowParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.utils import LoRAMapping
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA
__all__ = [
"BaseLayerWithLoRA",
"VocabParallelEmbeddingWithLoRA",
"LogitsProcessorWithLoRA",
"ColumnParallelLinearWithLoRA",
"ColumnParallelLinearWithShardedLoRA",
"MergedColumnParallelLinearWithLoRA",
"MergedColumnParallelLinearWithShardedLoRA",
"MergedQKVParallelLinearWithLoRA",
"MergedQKVParallelLinearWithShardedLoRA",
"QKVParallelLinearWithLoRA",
"QKVParallelLinearWithShardedLoRA",
"RowParallelLinearWithLoRA",
"RowParallelLinearWithShardedLoRA",
"ReplicatedLinearWithLoRA",
"LoRAMapping",
"FusedMoEWithLoRA",
]

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
from vllm.platforms import current_platform
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__lora__layers__row_parallel_linear__BaseLinearLayerWithLoRA__apply(
self,
x: torch.Tensor,
bias: torch.Tensor | None,
residual: torch.Tensor | None = None,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual in matmul
'''
output = self.base_layer.quant_method.apply(self.base_layer, x, bias, residual)
'''
==================
End of MLU Hijack
==================
'''
# In transformers backend, x and output have extra batch dimension like
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
# therefore we need to flatten the batch dimensions.
if x.ndim == 3 and output.ndim == 3:
output = output.flatten(0, 1)
x = x.flatten(0, 1)
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
)
if not current_platform.can_update_inplace():
output = lora_output
return output
MluHijackObject.apply_hijack(
BaseLinearLayerWithLoRA,
BaseLinearLayerWithLoRA.apply,
vllm__lora__layers__row_parallel_linear__BaseLinearLayerWithLoRA__apply,
)

View File

@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.lora.layers.column_parallel_linear import ColumnParallelLinearWithLoRA
from vllm_mlu.mlu_hijack_utils import MluHijackObject
vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward_org = ColumnParallelLinearWithLoRA.forward
'''
=============================
Modify by vllm_mlu
=============================
@brief: add smooth_quant_scale and use_tp_weight parameters.
'''
def vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward(
self,
input_,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
assert not use_tp_weight, "LoRa does not support use_tp_weight yet."
assert smooth_quant_scale is None, "LoRA does not support smooth quant yet."
return vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward_org(self, input_)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithLoRA.forward,
vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward,
)

View File

@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.distributed import (
split_tensor_along_last_dim,
tensor_model_parallel_all_reduce,
)
from vllm.lora.layers.row_parallel_linear import (
RowParallelLinearWithLoRA,
RowParallelLinearWithShardedLoRA,
)
from vllm.platforms import current_platform
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__lora__layers__row_parallel_linear__RowParallelLinearWithShardedLoRA__apply(
self,
x: torch.Tensor,
bias: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual and bias in matmul
'''
output = self.base_layer.quant_method.apply(
self.base_layer, x, bias, residual)
'''
==================
End of MLU Hijack
==================
'''
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
buffer = torch.zeros(
(self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
shrunk_buffer: torch.Tensor | None = self.punica_wrapper.add_shrink(
buffer, x, self.lora_a_stacked, 1.0
)
if not current_platform.can_update_inplace():
buffer = shrunk_buffer
if self.tp_size > 1:
buffer = tensor_model_parallel_all_reduce(buffer)
# following S-LoRA, allows the fusing of all_gather and all_reduce
# by adding the column partitioned lora output to a slice of output
# tensor, which is a partial sum due to row parallel. All that
# remains is a standard all_reduce. User should be aware though that
# the output is not the same as a normal row_parallel, it should be
# reduced before being used
# NOTE offset are based on the rank.
shard_size = self.lora_b_stacked[0].shape[2]
offset_start = self.tp_rank * shard_size
lora_output: torch.Tensor | None = self.punica_wrapper.add_expand(
output,
buffer,
self.lora_b_stacked,
self.output_slices,
offset_start=offset_start,
add_input=True,
)
if not current_platform.can_update_inplace():
output = lora_output
output = output.view(*out_orig_shape)
return output
def vllm__lora__layers__row_parallel_linear__RowParallelLinearWithLoRA__forward(
self,
input_: torch.Tensor,
residual: torch.Tensor | None = None,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
output: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add parameters `residual`, `smooth_quant_scale`, `use_tp_weight` and `output`
to keep parameters consistent with RowParallelLinear.forward.
'''
assert (not use_tp_weight) and output is None, (
f"RowParallelLinearWithLoRA.forward does not support use_tp_wight=True"
f" or pass output parameters.")
'''
==================
End of MLU Hijack
==================
'''
# Set up backprop all-reduce.
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
# TODO: simplify code below
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size
)
input_parallel = splitted_input[self.tp_rank].contiguous()
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1) apply residual fusion in matmul like RowParallelLinear
2) add bias in matmul, not after all reduce
'''
# Matrix multiply.
bias_ = (
None if (self.base_layer.tp_rank > 0 or self.base_layer.skip_bias_add)
else self.base_layer.bias
)
residual_ = None if self.base_layer.tp_rank > 0 else residual
output_parallel = self.apply(input_parallel, bias_, residual_)
'''
==================
End of MLU Hijack
==================
'''
if self.base_layer.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
'''
=============================
Modify by vllm_mlu
=============================
@brief: do not add bias after all_reduce
'''
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
'''
==================
End of MLU Hijack
==================
'''
if not self.base_layer.return_bias:
return output
return output, output_bias
MluHijackObject.apply_hijack(
RowParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA.apply,
vllm__lora__layers__row_parallel_linear__RowParallelLinearWithShardedLoRA__apply,
)
MluHijackObject.apply_hijack(
RowParallelLinearWithLoRA,
RowParallelLinearWithLoRA.forward,
vllm__lora__layers__row_parallel_linear__RowParallelLinearWithLoRA__forward,
)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm_mlu.lora.ops.triton_ops.sgmv_expand import sgmv_expand_mlu
from vllm_mlu.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice_mlu
from vllm_mlu.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink_mlu
from vllm_mlu.lora.ops.triton_ops.lora_shrink_op import lora_shrink
from vllm_mlu.lora.ops.triton_ops.lora_expand_op import lora_expand
__all__ = [
"sgmv_expand_mlu",
"sgmv_expand_slice_mlu",
"sgmv_shrink_mlu",
"lora_expand",
"lora_shrink"
]

View File

@@ -0,0 +1,308 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
Utilities for Punica kernel construction.
"""
from vllm.triton_utils import tl, triton
'''
=============================
Modify by vllm_mlu
=============================
@brief: modify mm triton
1) add parameter offset_n: mlu add offset_n of matrix B,
value: tl.arange(0, BLOCK_N) + pid_n * BLOCK_N, shape: [BLOCK_N]
add parameter N: mlu add column number of matrix B
2) tiled_b always need mask in case offset_n > N
'''
@triton.jit
def mm_k(
a_ptr,
b_ptr,
ak_stride,
bk_stride,
offset_n,
offset_k,
K: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
N: tl.constexpr,
CAST_TYPE: tl.constexpr,
b_dtype: tl.constexpr,
):
"""
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
B (k x n), iterate, through the K dimension to compute the partial/complete
matrix block product.
If SPLIT_K == 1, the output m x n product is complete.
If SPLIT_K > 1, the thread block computes partial outputs. The partial
outputs are then atomically summed in the caller code.
Args:
a_ptr: Array of pointers, identifying rows of A
b_ptr: Array of pointers, identifying columns of B
ak_stride: K dimension stride of the A matrix
bk_stride: K dimension stride of the B matrix
K: Length of the K dimension
BLOCK_M: M dimension of the output block m x n
BLOCK_N: N dimension of the output block m x n
BLOCK_K: K dimension atom
EVEN_K: True if the blocks of A and B can be loaded without any
masking.
SPLIT_K: Parameter signifying parallelism in the K dimension.
CAST_TYPE: if True, cast the values from the A matrix to the B
matrix dtype.
b_dtype: datatype of the B matrix
"""
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
tiled_a = tl.load(a_ptr)
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N, other=0.0)
else:
tiled_a = tl.load(a_ptr,
mask=offset_k[None, :]
< K - k * (BLOCK_K * SPLIT_K),
other=0)
tiled_b = tl.load(b_ptr,
mask=(offset_k[:, None]
< K - k * (BLOCK_K * SPLIT_K)) & (offset_n < N)[None, :],
other=0.0)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * SPLIT_K * ak_stride
b_ptr += BLOCK_K * SPLIT_K * bk_stride
return accumulator
'''
==================
End of MLU Hijack
==================
'''
@triton.jit
def do_expand_kernel(
pid_n,
lora_index,
slice_id,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
M_LEN,
ram, # array identifying the rows of Input ptr to operate on
slice_start_loc,
# input ptr strides
input_d0_stride,
input_d1_stride,
input_d2_stride,
# lora ptr strides
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr,
# out ptr strides
output_d0_stride,
output_d1_stride,
# constants
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SAME_STRIDE: tl.constexpr,
SLICE_NUM: tl.constexpr,
EVEN_K: tl.constexpr,
CAST_TYPE: tl.constexpr,
ADD_INPUTS: tl.constexpr,
):
"""
Given an array of integers that identifies the rows of A, ram,
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
a slice_id that identifies the input/output slice,
compute the matrix product and store in the appropriate output location.
Given that this is an expand kernel, we don't perform any split-K reduction
as the K dimension is assumed to be small.
"""
# ls_d*_ptr can be either an integer or a pointer
if SAME_STRIDE: # 'same_stride': True
# integer
cur_lora_d0_stride = ls_d0_ptr
cur_lora_d1_stride = ls_d1_ptr
cur_lora_d2_stride = ls_d2_ptr
else:
# pointer
cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
# Identify the input_ptr and lora_ptr from slice_id.
if SLICE_NUM == 1:
cur_input_ptr = input_ptr
cur_lora_ptr = lora_ptr
else:
cur_input_ptr = input_ptr + slice_id * input_d0_stride
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
tl.pointer_type(out_ptr.dtype.element_ty))
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1) remove rbn definition: mlu doesn't support contiguous and
will handle as head corruption
2) re-write b_ptr, use offset_n to identify its position
'''
# Identify the column indices of B to process.
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
# rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
# Identify A and B block pointers
offset_k = tl.arange(0, BLOCK_K)
a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
offset_k[None, :] * input_d2_stride)
# b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
# offset_k[:, None] * cur_lora_d2_stride +
# rbn[None, :] * cur_lora_d1_stride)
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
offset_k[:, None] * cur_lora_d2_stride +
offset_n[None, :] * cur_lora_d1_stride)
# Compute the block matrix product.
SPLIT_K = 1
accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride, offset_n,
offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, N,
CAST_TYPE, cur_lora_ptr.dtype.element_ty)
'''
==================
End of MLU Hijack
==================
'''
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
if SLICE_NUM == 1:
cur_slice_start = slice_start_loc
else:
cur_slice_start = tl.load(slice_start_loc + slice_id)
# Identify the C output pointers to store the results of the accumulator.
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
offset_cm = tl.arange(0, BLOCK_M)
c_ptr = (out_ptr + ram[:, None] * output_d0_stride +
offset_cn[None, :] * output_d1_stride)
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :]
< (cur_slice_start + N))
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@triton.jit
def do_shrink_kernel(
pid_n,
pid_sk,
slice_id,
lora_index,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
M_LEN,
ram,
# input strides
input_d0_stride,
input_d1_stride,
# lora strides
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
# output strides
output_d0_stride,
output_d1_stride,
output_d2_stride,
scaling,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
SLICE_NUM: tl.constexpr,
):
"""
Given an array of integers that identifies the rows of A, ram,
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
a slice_id that identifies the input/output slice, compute the
matrix product and store in the appropriate output location.
"""
# Identify the lora_ptr from slice_id.
if SLICE_NUM == 1:
# current lora ptr
cur_lora_ptr = lora_ptr
else:
# current lora ptr
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
tl.pointer_type(input_ptr.dtype.element_ty))
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1) remove rbn definition: mlu doesn't support contiguous and
will handle as head corruption
2) re-write b_ptr, use offset_n to identify its position
'''
# Identify the column indices of B to process.
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
# rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
# Identify A and B block pointers
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
a_ptr = (input_ptr + ram[:, None] * input_d0_stride +
offset_k[None, :] * input_d1_stride)
# b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
# rbn[None, :] * lora_d1_stride +
# offset_k[:, None] * lora_d2_stride)
b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
offset_n[None, :] * lora_d1_stride +
offset_k[:, None] * lora_d2_stride)
# Compute partial/complete block matrix product.
accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_n, offset_k,
K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, N, False,
cur_lora_ptr.dtype.element_ty)
'''
==================
End of MLU Hijack
==================
'''
# Identify the C output pointers to store the results of the accumulator.
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_cm = tl.arange(0, BLOCK_M)
cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
slice_id * output_d0_stride)
c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[
None, :] * output_d2_stride
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)
accumulator *= scaling
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask)

View File

@@ -0,0 +1,308 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
'''
=============================
Modify by vllm_mlu
=============================
@brief: use vllm_mlu hijacked kernel
'''
from vllm_mlu.lora.ops.triton_ops.kernel_utils import do_expand_kernel
'''
==================
End of MLU Hijack
==================
'''
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
@triton.jit
def _lora_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
slice_start_loc,
input_d0_stride,
input_d1_stride,
input_d2_stride, # 1
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr, # 1
output_d0_stride,
output_d1_stride, # 1
output_hs_ptr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
SLICE_NUM: tl.constexpr,
SAME_STRIDE: tl.constexpr,
):
cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M)
pid_mn = tl.program_id(axis=0)
pid_m = pid_mn % cta_m_num
pid_n = (pid_mn // cta_m_num) % cta_n_num
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
if lora_id == -1:
# Early exit for the no-lora case.
return
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
cta_m_offset = pid_m * BLOCK_M
if cta_m_offset >= lora_m_size:
# Early exit CTA.
return
# When the output dimensions of each slice are the same,cur_n=N, otherwise
# cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's
# qkv linear.
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
if pid_n * BLOCK_N >= curr_N:
# Early exit CTA.
return
# num rows this CTA should process.
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
# Identify all rows that this CTA should process.
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
cta_lora_seq_indices = (
token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
)
# Load all relevant row indices.
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
ram = tl.load(cta_lora_seq_indices + offset_m)
do_expand_kernel(
pid_n,
lora_id,
slice_id,
input_ptr,
lora_ptr,
out_ptr,
curr_N,
K,
cta_m_len,
ram, # array identifying the rows of Input ptr to operate on
slice_start_loc,
# input ptr strides
input_d0_stride,
input_d1_stride,
input_d2_stride,
# lora ptr strides
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr,
# out ptr strides
output_d0_stride,
output_d1_stride,
# constants
BLOCK_M,
BLOCK_N,
BLOCK_K,
SAME_STRIDE,
SLICE_NUM,
EVEN_K,
CAST_TYPE,
ADD_INPUTS,
)
@torch.inference_mode()
def _lora_expand(
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
lora_b_weights: list[torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
output_tensor: torch.Tensor, # shape [num_tokens, hidden_size * num_slices]
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (list[torch.Tensor]): lora'b weight
output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that
LoRA doesn't apply to that token.
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
the A matrix grouped by LoRA IDs.
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
of tokens that are to be processed by LoRA ID lora_ids[i]
lora_token_start_loc (torch.Tensor): A cumulative sum of
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
lora_token_start_loc[i], along with num_tokens_per_lora[i]
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
offset_start (int, optional): Offset start for output_tensor.
Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
for weight in lora_b_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(0) == len(lora_b_weights)
assert output_tensor.is_contiguous()
# metadata sanity check.
M = inputs.size(1)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
(
slice_start_tensor,
lora_ptr_tensor,
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
hidden_sizes_tensor,
same_stride,
MAX_N,
) = _get_lora_b_ptr(lora_b_weights, offset_start, inputs.device)
K = lora_b_weights[0].shape[-1] # K= rank
ADD_INPUTS = add_inputs
MAX_LORAS = lora_ids.size(0)
CAST_TYPE = False
NUM_SLICES = len(lora_b_weights)
# Triton kernel configs.
BLOCK_M = 64
BLOCK_N = 128
BLOCK_K = 16
NUM_WARPS = 4
NUM_CTAS = 1
NUM_STAGES = 2
EVEN_K = K % BLOCK_K == 0 # type: ignore
if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
# TODO (varun): This grid formulation maximizes parallelization at the
# cost of wasteful thread block launch when only a few input tokens require
# LoRA. This might not be the best in all cases.
grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks simply exit.
MAX_LORAS,
)
_lora_expand_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
M,
MAX_N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
slice_start_tensor,
inputs.stride(0),
inputs.stride(1),
inputs.stride(2),
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
hidden_sizes_tensor,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
NUM_SLICES,
same_stride,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
num_stages=NUM_STAGES,
)
return
def _lora_expand_fake(
inputs: torch.Tensor,
lora_b_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
return
'''
=============================
Modify by vllm_mlu
=============================
@brief: use only vllm operand
'''
lora_expand = _lora_expand
'''
==================
End of MLU Hijack
==================
'''

View File

@@ -0,0 +1,258 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
'''
=============================
Modify by vllm_mlu
=============================
@brief: use vllm_mlu hijacked kernel
'''
from vllm_mlu.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
'''
==================
End of MLU Hijack
==================
'''
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
@triton.jit
def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
token_indices_sorted_by_lora_ids, num_tokens_per_lora,
lora_token_start_loc, lora_ids, scaling,
input_d0_stride, input_d1_stride, lora_d0_stride,
lora_d1_stride, lora_d2_stride, output_d0_stride,
output_d1_stride, output_d2_stride,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr):
cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M)
pid_sk_m_n = tl.program_id(axis=0)
pid_sk = pid_sk_m_n % SPLIT_K
pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
if lora_id == -1:
# Early exit for the no-lora case.
return
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
cta_m_offset = pid_m * BLOCK_M
if cta_m_offset >= lora_m_size:
# Early exit CTA.
return
# num rows this CTA should process.
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
# Identify all rows that this CTA should process.
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
lora_m_indices_start + cta_m_offset)
# Load all relevant row indices.
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
ram = tl.load(cta_lora_seq_indices + offset_m)
do_shrink_kernel(
pid_n,
pid_sk,
slice_id,
lora_id,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
cta_m_len,
ram, # array identifying the rows of Input ptr to operate on
# input strides
input_d0_stride,
input_d1_stride,
# lora strides
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
# output strides
output_d0_stride,
output_d1_stride,
output_d2_stride,
scaling,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
SLICE_NUM)
@torch.inference_mode()
def _lora_shrink(
inputs: torch.Tensor, # shape [num_tokens, hidden_size]
lora_a_weights: list[
torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
scaling: float,
) -> None:
"""
Args:
inputs (torch.Tensor): Input tensor
lora_a_weights (list[torch.Tensor]): LoRA weights
output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that
LoRA doesn't apply to that token.
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
the A matrix grouped by LoRA IDs.
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
of tokens that are to be processed by LoRA ID lora_ids[i]
lora_token_start_loc (torch.Tensor): A cumulative sum of
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
lora_token_start_loc[i], along with num_tokens_per_lora[i]
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
scaling (float): Scaling factor.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype == lora_a_weights[0].dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
for weight in lora_a_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(1) == lora_a_weights[0].size(-1)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
# metadata sanity check
M = inputs.size(0)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
NUM_SLICES = len(lora_a_weights)
MAX_LORAS = lora_ids.size(0)
# Triton kernel configs
BLOCK_M = 32
BLOCK_N = 16
BLOCK_K = 256 if M < 128 else 32
SPLIT_K = 64 if M < 128 else 8
NUM_WARPS = 4
NUM_CTAS = 1
NUM_STAGES = 2
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
# TODO (varun): This grid formulation maximizes parallelization at the
# cost of wasteful thread block launch when only few of the input tokens
# require LoRA. This might not be the best in all cases.
grid = (
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks exit early.
MAX_LORAS,
)
_lora_shrink_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_strides_d0,
lora_strides_d1,
lora_strides_d2,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor.stride(2),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
NUM_SLICES,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
num_stages=NUM_STAGES,
)
return
def _lora_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
scaling: float,
) -> None:
return
'''
=============================
Modify by vllm_mlu
=============================
@brief: use only vllm operand
'''
lora_shrink = _lora_shrink
'''
==================
End of MLU Hijack
==================
'''

View File

@@ -0,0 +1,238 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import triton
import triton.language as tl
from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit
def _sgmv_expand_kernel_mlu(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
xm_stride,
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
The sgmv's expand triton kernel is based on GroupGEMM.
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = tl.arange(0, BLOCK_K)
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
offset_k[None, :] * xk_stride
b_ptr = lora_ptr + l0_stride * lora_index + \
offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride
'''
==================
End of MLU Hijack
==================
'''
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K)):
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
if EVEN_K:
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
else:
tiled_a = tl.load(a_ptr,
mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)),
other=0)
tiled_b = tl.load(b_ptr,
mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)),
other=0)
'''
==================
End of MLU Hijack
==================
'''
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * xk_stride
b_ptr += BLOCK_K * lora_n_stride
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N)
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@torch.inference_mode()
def sgmv_expand_mlu(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
'''
=============================
Modify by vllm_mlu
=============================
@brief: Workaround: Adjust block size to meet mlu restrictions.
The grid of mlu triton kernel must less than 65536, it will be out of bound when
the input seq is very long, and causes runtime error. So we need to adjust the block
size to avoid this.
'''
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
'''
==================
End of MLU Hijack
==================
'''
BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
batches,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: call _sgmv_expand_kernel_mlu
'''
_sgmv_expand_kernel_mlu[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
)
'''
==================
End of MLU Hijack
==================
'''
return

View File

@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import triton
import triton.language as tl
from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit
def _sgmv_expand_slice_kernel_mlu(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
xm_stride,
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
slice_offset,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
Similar to the 'sgmv_expand' operator, but with an added parameter
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
might be that in the future, we could implement a fusion operator to
achieve the current functionality instead of having to call it multiple
times.
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = tl.arange(0, BLOCK_K)
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
offset_k[None, :] * xk_stride
b_ptr = lora_ptr + l0_stride * lora_index + \
offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride
'''
==================
End of MLU Hijack
==================
'''
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K)):
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
if EVEN_K:
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
else:
tiled_a = tl.load(a_ptr,
mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)),
other=0)
tiled_b = tl.load(b_ptr,
mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)),
other=0)
'''
==================
End of MLU Hijack
==================
'''
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * xk_stride
b_ptr += BLOCK_K * lora_n_stride
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
(slice_offset + N))
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@torch.inference_mode()
def sgmv_expand_slice_mlu(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
) -> None:
"""_summary_
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert slice_size == lora_b_weights.size(-2)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
'''
=============================
Modify by vllm_mlu
=============================
@brief: Workaround: Adjust block size to meet mlu restrictions.
The grid of mlu triton kernel must less than 65536, it will be out of bound when
the input seq is very long, and causes runtime error. So we need to adjust the block
size to avoid this.
'''
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
'''
==================
End of MLU Hijack
==================
'''
BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
batches,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: call _sgmv_expand_kernel_mlu
'''
_sgmv_expand_slice_kernel_mlu[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
slice_offset,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
)
'''
==================
End of MLU Hijack
==================
'''
return

View File

@@ -0,0 +1,231 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import triton
import triton.language as tl
from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit
def _sgmv_shrink_kernel_mlu(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
scaling,
xm_stride, # hidden_size
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
):
"""
The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
introducing SPLIT-K can improve performance
"""
pid = tl.program_id(axis=0)
pid_sk = tl.program_id(axis=1)
cur_batch = tl.program_id(axis=2)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
offset_k[None, :] * xk_stride
b_ptr = lora_ptr + l0_stride * lora_index + offset_n[None, :] * lora_k_stride + \
offset_k[:, None] * lora_n_stride
'''
==================
End of MLU Hijack
==================
'''
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
if EVEN_K:
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
tiled_a = tl.load(a_ptr,
mask=((offset_k[None, :] < k_remaining) & (offset_m[:, None] < M)),
other=0.0)
tiled_b = tl.load(b_ptr,
mask=((offset_k[:, None] < k_remaining) & (offset_n[None, :] < N)),
other=0.0)
'''
==================
End of MLU Hijack
==================
'''
accumulator += tl.dot(tiled_a, tiled_b)
a_ptr += BLOCK_K * SPLIT_K * xk_stride
b_ptr += BLOCK_K * SPLIT_K * lora_n_stride
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N)
accumulator *= scaling
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
@torch.inference_mode()
def sgmv_shrink_mlu(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_a_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
scaling (float): Scaling factor.
"""
assert inputs.dtype == lora_a_weights.dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
assert lora_a_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_a_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous()
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
assert lora_a_weights.size(1) == 1
lora_a_weights = lora_a_weights.squeeze(dim=1)
else:
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
assert lora_a_weights.is_contiguous()
assert output_tensor.is_contiguous()
# TODO tuning this config
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
'''
=============================
Modify by vllm_mlu
=============================
@brief: Workaround: adjust block size to meet mlu restrictions.
The grid of mlu triton kernel must less than 65536, it will be out of bound when
the input seq is very long, and causes runtime error. So we need to adjust the block
size to avoid this.
'''
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 16)
'''
==================
End of MLU Hijack
==================
'''
BLOCK_K = 32
SPLIT_K = 8
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
SPLIT_K,
batches,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: call _sgmv_shrink_kernel_mlu
'''
_sgmv_shrink_kernel_mlu[grid](
inputs,
lora_a_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_a_weights.stride(0),
lora_a_weights.stride(1),
lora_a_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
)
'''
==================
End of MLU Hijack
==================
'''
return

View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
from math import ceil
_MLU_MAX_GRID_SIZE = 65536
def adjust_kernel_block_size(
m: int,
block_m: int,
n: int,
block_n: int
) -> Tuple[int, int]:
"""Adjust block size to meet mlu triton grid restrictions.
Calculation of the max block size in candidates list:
LLama3.1-8b-tp1 max n is 14336
LLama3.1-70b-tp4 max n is 7168
LLama3.1-405b-tp8 max n is 6656
when n is 14336, the max sequence length of block size 256 can be
floor(65536 / ceil(14336 / 256)) * 256 = 299520.
"""
candidates_list = [16, 32, 64, 96, 128, 192, 256]
candidates_list_len = len(candidates_list)
m_idx = 1
n_idx = 0 if block_n == 16 else 1
while m_idx < candidates_list_len and n_idx < candidates_list_len:
block_m = candidates_list[m_idx]
block_n = candidates_list[n_idx]
if ceil(m / block_m) * ceil(n / block_n) < _MLU_MAX_GRID_SIZE:
break
if m_idx < candidates_list_len:
m_idx += 1
if n_idx < candidates_list_len:
n_idx += 1
if ceil(m / block_m) * ceil(n / block_n) >= _MLU_MAX_GRID_SIZE:
raise ValueError(f"the max seq len {m} is too long for lora triton kernel")
return block_m, block_n

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Optional, Tuple, Union, final
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm_mlu.lora.ops.triton_ops import sgmv_expand_mlu
from vllm_mlu.lora.ops.triton_ops import sgmv_expand_slice_mlu
from vllm_mlu.lora.ops.triton_ops import sgmv_shrink_mlu
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
@final
class PunicaWrapperMLU(PunicaWrapperCPU):
"""
PunicaWrapperMLU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica triton kernel.
"""
def _shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink_mlu(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
def _expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_inputs: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_mlu(
x,
w_t_all,
y,
*self.prefill_metadata,
add_inputs,
)
def _expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_slice_mlu(
x,
w_t_all,
y,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_inputs,
)

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from dataclasses import dataclass
from typing import List
from vllm.forward_context import DPMetadata
@dataclass
class MLUDPMetadata(DPMetadata):
# mlu platform arguments
# token num for current dp group
token_num: int = None
# token num offset for current dp group
token_num_offset: int = None
# whether we can use reduce scatter for both attn layer and mlp layer
layer_use_reduce_scatter: bool = False
# token num need to be pad for prefill, then we can do reduce scatter +
# all gather to optimize comm time
prefill_pad_to_token_num: int = -1
# token num in each dp group, the list length is attn data parallel size
# used to do all gather in dp groups after all reduce in attn
token_split_list: List[int] = None
# token num in each card, the list length is world size
# used to do all gather in all cards after reduce scatter in attn
attn_token_split_list_reduce_scatter: List[int] = None
# token num in each tp group, the list length is tensor parallel size
# used to do all gather in tp groups after reduce scatter in moe
moe_token_split_list_reduce_scatter: List[int] = None
# prefill or decode stage in each dp group
dp_is_prefill: List[bool] = None
# ADDITIONAL fields for merged compute and communication.
# Global sequence lengths for each batch size for prefill stage.
seq_lens: List[int] = None
# Batch sizes for each attn dp rank for prefill stage.
batch_sizes: List[int] = None
# ADDITIONAL fields for custom split for embedding, logits and dense mlp layer
# token num in each emb tp group, the list length is tensor parallel size
# used to do all gather in emb tp groups after reduce scatter in moe
emb_token_split_list: List[int] = None
# batch sizes in each logits tp group, the list length is tensor parallel size
# used to do all gather in logits tp groups after reduce scatter in moe
logits_batch_split_list: List[int] = None
# token num in each dense mlp group, the list length is dense mlp tp size
# used to do one more all gather after dense mlp and before reduce scatter
dense_attn_token_split_list: List[int] = None
@staticmethod
def make_oot(
data_parallel_rank: int,
data_parallel_size: int,
tensor_parallel_size: int,
dp_token_nums: List[int],
dp_is_prefill: List[bool],
prefill_dispatch_use_RS_AG: bool,
seq_lens: List[int] = None,
batch_sizes: List[int] = None,
emb_query_lens: List[int] = None,
logits_batch_sizes: List[int] = None,
dense_attn_token_split_list: List[int] = None,
) -> "MLUDPMetadata":
token_num_offset = sum(dp_token_nums[:data_parallel_rank])
token_num = dp_token_nums[data_parallel_rank]
token_split_list = dp_token_nums
attn_can_use_reduce_scatter = all(
(num != 0 and num % tensor_parallel_size == 0)
for num in token_split_list
)
all_split_token_num_equal = all(
num == token_split_list[0] for num in token_split_list
)
layer_can_use_reduce_scatter = (
attn_can_use_reduce_scatter and all_split_token_num_equal
)
attn_token_split_list_reduce_scatter = None
moe_token_split_list_reduce_scatter = None
prefill_pad_to_token_num = -1
tp_world_size = data_parallel_size * tensor_parallel_size
if layer_can_use_reduce_scatter:
attn_token_split_list_reduce_scatter = (
[token_split_list[0] // tensor_parallel_size] * tp_world_size
)
moe_token_split_list_reduce_scatter = (
attn_token_split_list_reduce_scatter[:tensor_parallel_size]
)
elif (
prefill_dispatch_use_RS_AG
and all(is_prefill for is_prefill in dp_is_prefill)
):
dp_group_max_token_nums = max(dp_token_nums)
prefill_pad_to_token_num = (
(dp_group_max_token_nums + tensor_parallel_size - 1)
// tensor_parallel_size
) * tensor_parallel_size
attn_token_split_list_reduce_scatter = (
[prefill_pad_to_token_num // tensor_parallel_size] * tp_world_size
)
return MLUDPMetadata(
max_tokens_across_dp_cpu=None,
num_tokens_across_dp_cpu=None,
token_num=token_num,
token_num_offset=token_num_offset,
token_split_list=token_split_list,
layer_use_reduce_scatter=layer_can_use_reduce_scatter,
prefill_pad_to_token_num=prefill_pad_to_token_num,
attn_token_split_list_reduce_scatter=attn_token_split_list_reduce_scatter,
moe_token_split_list_reduce_scatter=moe_token_split_list_reduce_scatter,
seq_lens=seq_lens,
batch_sizes=batch_sizes,
dp_is_prefill=dp_is_prefill,
emb_token_split_list=emb_query_lens,
logits_batch_split_list=logits_batch_sizes,
dense_attn_token_split_list=dense_attn_token_split_list,
)

79
vllm_mlu/mlu_hijack.py Normal file
View File

@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import importlib.util
from vllm_mlu._mlu_utils import *
from vllm_mlu.logger import logger
def is_module_available(module_name):
spec = importlib.util.find_spec(module_name)
return spec is not None
def check_environ_compatibility():
if is_module_available('apex'):
logger.error(f"The `apex` package is currently present in your environment, "
f"which may cause model accuracy issues or other problems. It is "
f"strongly recommended that you uninstall it before using vLLM.")
# Check environment compatibility first before applying mlu hijack.
check_environ_compatibility()
logger.info(f"[MLU] Apply Monkey Patch.")
# Apply v1 hijack
import vllm_mlu.v1.engine.core
import vllm_mlu.v1.engine.core_client
import vllm_mlu.v1.engine.llm_engine
import vllm_mlu.v1.engine.async_llm
import vllm_mlu.v1.core.sched.scheduler
import vllm_mlu.v1.core.single_type_kv_cache_manager
import vllm_mlu.v1.core.kv_cache_utils
import vllm_mlu.v1.core.kv_cache_manager
import vllm_mlu.v1.executor.abstract
import vllm_mlu.v1.executor.ray_executor
import vllm_mlu.v1.executor.multiproc_executor
import vllm_mlu.v1.sample.rejection_sampler
import vllm_mlu.v1.worker.lora_model_runner_mixin
import vllm_mlu.v1.worker.block_table
import vllm_mlu.v1.worker.gpu_input_batch
import vllm_mlu.v1.worker.kv_connector_model_runner_mixin
import vllm_mlu.v1.attention.backends.gdn_attn
import vllm_mlu.v1.attention.backends.mla.flashmla
import vllm_mlu.compilation.fix_functionalization
# Apply common hijack
import vllm_mlu.attention.layer
import vllm_mlu.benchmarks.datasets
import vllm_mlu.config.model
import vllm_mlu.config.scheduler
import vllm_mlu.config.speculative
import vllm_mlu.config.vllm
import vllm_mlu.utils
import vllm_mlu.distributed.parallel_state
import vllm_mlu.distributed.kv_transfer.kv_connector.factory
import vllm_mlu.engine.arg_utils
import vllm_mlu.entrypoints.llm
import vllm_mlu.lora.layers.base_linear
import vllm_mlu.lora.layers.row_parallel_linear
import vllm_mlu.lora.layers.column_parallel_linear
import vllm_mlu.model_executor.parameter
import vllm_mlu.model_executor.layers.linear
import vllm_mlu.model_executor.layers.rotary_embedding
import vllm_mlu.model_executor.layers.quantization.utils.w8a8_utils
import vllm_mlu.model_executor.layers.quantization.fp8
import vllm_mlu.model_executor.layers.activation
import vllm_mlu.model_executor.layers.layernorm
import vllm_mlu.model_executor.layers.fused_moe.layer
import vllm_mlu.model_executor.model_loader.tensorizer_loader
import vllm_mlu.model_executor.models.registry
import vllm_mlu.model_executor.models.config
import vllm_mlu.multimodal.utils
if is_module_available('lmcache'):
import vllm_mlu.distributed.kv_transfer.kv_connector.v1.lmcache_connector
if VLLM_CI_ACCURACY_TEST:
import vllm_mlu.model_executor.model_loader.dummy_loader
if VLLM_SCHEDULER_PROFILE:
import vllm_mlu.entrypoints.openai.api_server

View File

@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.logger import init_logger
logger = init_logger(__name__)
IS_GATED=False
class MluHijackObject:
hijack_objs = []
@classmethod
def apply_hijack(cls, obj, org_func, hijack_func,
verify_orig_func_exists: bool = False):
"""
Optional Args:
verify_orig_func_exists (bool): If True, verifies that hijack succeeds
"""
cls.hijack_objs.append((obj, org_func, hijack_func))
if type(org_func) == str:
org_func_name = org_func
else:
if isinstance(org_func, property):
split_name = org_func.fget.__name__.split('__')
else:
split_name = org_func.__name__.split('__')
org_func_name = split_name[-1]
if org_func_name == "":
assert split_name[-2] != "", f"invalid {org_func.__name__} to apply hijack"
org_func_name = split_name[-2] + "__"
if len(split_name) >= 3 and split_name[-3] == "":
org_func_name = "__" + org_func_name
if verify_orig_func_exists and not hasattr(obj, org_func_name):
raise AttributeError(f"function {org_func_name} is not part of {obj}")
setattr(obj, org_func_name, hijack_func)
if (verify_orig_func_exists and getattr(obj, org_func_name) is not hijack_func):
raise AttributeError(
f"function {org_func_name} of {obj} failed to be swapped to {hijack_func}")
@classmethod
def undo_hijack(cls, obj_ = None, hijack_func_ = None):
if obj_ and hijack_func_:
for obj, org_func, hijack_func in cls.hijack_objs:
if obj_ == obj and hijack_func == hijack_func_:
if type(org_func) == str:
if hasattr(obj, org_func):
delattr(obj, org_func)
else:
org_func_name = org_func.__name__
setattr(obj, org_func_name, org_func)
return
for obj, org_func, hijack_func in cls.hijack_objs:
if type(org_func) == str:
if hasattr(obj, org_func):
delattr(obj, org_func)
else:
org_func_name = org_func.__name__
setattr(obj, org_func_name, org_func)
TypedDict = {
"hidden_size": 0,
"vocab_size": 0,
"ffn_inner_size": 0,
"moe_inner_size": 0,
"layer_num": 0,
"moe_layer_num": 0,
"head_num": 0,
"head_size": 0,
"head_num_kv": 0,
"tp_num": 0,
"shared_expert_intermediate_size": 0,
"shared_experts": 0,
"qk_nope_head_dim": 0,
"qk_rope_head_dim": 0,
"q_lora_rank": 0.0,
"num_attention_heads": 0,
"kv_lora_rank": 0,
"v_head_dim": 0,
"use_gated_ffn": False,
"experts_num": 0,
"topk_num": 0,
"use_causal_mask": False,
"cla_coeffient": 0,
"kv_cache_dtype": "",
"smooth_quant_type": "",
"data_type": "",
"model_type": "",
"filter_data_type": "",
}
def set_is_gated(flag):
global IS_GATED
IS_GATED=flag
def get_is_gated():
return IS_GATED

412
vllm_mlu/mlu_metric.py Normal file
View File

@@ -0,0 +1,412 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import time
import statistics
import pandas as pd
import numpy as np
import json
import os
from datetime import datetime
from vllm.logger import init_logger
from vllm_mlu._mlu_utils import VLLM_LATENCY_DEBUG_WITH_DEVICE_EN, VLLM_DUMP_MLU_INFO_EN
from vllm.model_executor.layers.quantization import get_quantization_config
logger = init_logger(__name__)
millisecond2second_unit = 1000
class LLMMetric:
def __init__(self)->None:
self.batch_size_list = []
self.context_latency_list = []
self.e2e_latency_list = []
self.per_token_latency_list = [ [] ]
self.per_token_latency_device_list = [ [] ]
self.mm_encoder_latency_device_list = [ [] ]
self.peak_memory = 0
self.block_memory = 0
self.num_total_gpu_blocks = 0
self.num_total_cpu_blocks = 0
self.num_free_gpu_blocks_list = [ [] ]
self.num_free_cpu_blocks_list = [ [] ]
self.num_spec_tokens = 0
self.draft_acceptance_rate = 0.0
self.context_latency_device = 0.0
self.generate_latency_device = 0.0
self.mm_encoder_latency_device = 0.0
def reset_metric(self):
self.batch_size_list = []
self.context_latency_list = []
self.e2e_latency_list = []
self.per_token_latency_list = [ [] ]
self.per_token_latency_device_list = [ [] ]
self.mm_encoder_latency_device_list = [ [] ]
self.num_free_gpu_blocks_list = [ [] ]
self.num_free_cpu_blocks_list = [ [] ]
self.num_spec_tokens = 0
self.draft_acceptance_rate = 0.0
@classmethod
def get_mlu_cost_time(cls):
torch.mlu.synchronize()
return time.perf_counter()
def is_prefill_stage(self):
return len(self.per_token_latency_list[-1]) == 0
def update_memory_usage(self, peak_memory, block_memory, num_total_gpu_blocks, num_total_cpu_blocks):
self.peak_memory = peak_memory
self.block_memory = block_memory
self.num_total_gpu_blocks = num_total_gpu_blocks
self.num_total_cpu_blocks = num_total_cpu_blocks
def update_step_block_usage(self, num_free_gpu_blocks, num_free_cpu_blocks):
self.num_free_gpu_blocks_list[-1].append(num_free_gpu_blocks)
self.num_free_cpu_blocks_list[-1].append(num_free_cpu_blocks)
def update_step_latency(self, step_latency):
if isinstance(step_latency, list):
self.per_token_latency_list[-1].extend(step_latency)
else:
self.per_token_latency_list[-1].append(step_latency)
def update_step_latency_device(self, step_latency):
if isinstance(step_latency, list):
self.per_token_latency_device_list[-1].extend(step_latency)
else:
self.per_token_latency_device_list[-1].append(step_latency)
def update_mm_encoder_latency_device(self, step_latency):
if isinstance(step_latency, list):
if len(step_latency) == 0:
return
assert len(step_latency) == 1, f"Not supported! Model with multi mm encoder steps. {len(step_latency)} {step_latency}"
self.mm_encoder_latency_device_list[-1].extend(step_latency)
else:
self.mm_encoder_latency_device_list[-1].append(step_latency)
def update_spec_decode_metrics(self, spec_decode_metrics):
self.num_spec_tokens = spec_decode_metrics.num_spec_tokens
self.draft_acceptance_rate = spec_decode_metrics.draft_acceptance_rate
def add_metrics(self, batch_size, e2e_latency)->None:
self.batch_size_list.append(batch_size)
self.e2e_latency_list.append(e2e_latency)
self.per_token_latency_list.append([]) # new iter
self.per_token_latency_device_list.append([])
self.mm_encoder_latency_device_list.append([])
self.num_free_gpu_blocks_list.append([])
self.num_free_cpu_blocks_list.append([])
def get_weight_dtype_str(self, model_path, model_dtype, quantization) -> str:
# get weight dtype based on quantization config if exists
if quantization == 'fp8':
return quantization
if quantization is not None:
quant_method = get_quantization_config(quantization)
# combine the model path with the quantization config file name
quant_config_paths = quant_method.get_config_filenames()
# if there are multiple quantization config files, return the first one existed
for quant_config_path in quant_config_paths:
quant_config_path = os.path.join(model_path, quant_config_path)
# check if the quantization config file exists
if not os.path.exists(quant_config_path):
continue
with open(quant_config_path, 'r') as f:
quant_config = json.load(f)
quant_config = quant_method.from_config(quant_config)
# for smoothquant and weightonly, return the quantization name with the weight bits
if quantization == "smoothquant" or quantization == ["weightonly"]:
return "{}-int{}".format(quant_config.get_name(), quant_config.weight_bits)
else:
# for other quantization methods, return the quantization name
return quant_config.get_name()
# if the quantization config file does not exist, just return the quanization name
return quant_config_path.get_name()
else:
# remove the prefix of model dtype from torch config
return str(model_dtype).split(".")[-1]
def to_csv(self, filename: str, show_per_iter=False) -> None:
if show_per_iter:
df = pd.DataFrame(self.metrics_data)
df = pd.DataFrame([df.iloc[-1]], columns=df.columns)
memory_df = pd.DataFrame(self.memory_metrics_data)
memory_df = pd.DataFrame([memory_df.iloc[-1]], columns=memory_df.columns)
else:
df = pd.DataFrame(self.metrics_data)
memory_df = pd.DataFrame(self.memory_metrics_data)
df_mean = df.mean().round(3)
memory_df_mean = memory_df.mean().round(3)
header = ["datetime", "model",
"weight dtype", self.batch_size_name,
]
header = header + list(self.mm_kwargs.keys())
header = header + ["input len", "output len", "tp",
self.context_latency_name, self.per_token_latency_name]
data = [datetime.now().strftime("%Y-%m-%d %H:%M:%S"), self.model,
self.weight_dtype_str, int(self.metrics_data[self.batch_size_name][0])]
data = data + [self.mm_kwargs[k] for k in self.mm_kwargs.keys()]
data = data + [self.input_len, self.output_len, self.tp,
df_mean[self.context_latency_name], df_mean[self.per_token_latency_name]]
if self.num_spec_tokens > 0:
header += [self.per_step_latency_name]
data += [df_mean[self.per_step_latency_name]]
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
if self.is_v1_multimodal:
header += [self.mm_encoder_latency_device_name,]
data += [df_mean[self.mm_encoder_latency_device_name],]
header += [self.context_latency_device_name, self.per_token_latency_device_name]
data += [df_mean[self.context_latency_device_name], df_mean[self.per_token_latency_device_name]]
header += [self.e2e_latency_name, self.e2e_throughput_name, self.decoder_throughput_name,]
if self.num_spec_tokens > 0:
header += [self.k_name, self.acceptance_rate_name]
header += [self.decode_times_name,
self.peak_memory_name, self.block_memory_name]
data += [
df_mean[self.e2e_latency_name], df_mean[self.e2e_throughput_name], df_mean[self.decoder_throughput_name],
]
if self.num_spec_tokens > 0:
data += [self.num_spec_tokens, df_mean[self.acceptance_rate_name],]
data += [df_mean[self.decode_times_name], memory_df_mean[self.peak_memory_name], memory_df_mean[self.block_memory_name],]
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and self.save_hfu_info:
header += [self.context_hfu_name, self.decoder_hfu_name, self.decoder_io_efficiency_name]
data += [
df_mean[self.context_hfu_name], df_mean[self.decoder_hfu_name],
df_mean[self.decoder_io_efficiency_name]
]
data_dict = dict(zip(header, data))
df_csv = pd.DataFrame(data_dict, index=[0])
append = False
if os.path.isfile(filename):
try:
df_old = pd.read_csv(filename)
append = (df_old.columns.tolist() == header)
except Exception as e:
logger.info(f"Existing {filename} failed to be read and will be overwritten")
if append:
df_csv.to_csv(filename, mode='a', header=False, index=False)
logger.info(f"Metric appended to existing {filename}")
else:
df_csv.to_csv(filename, index=False)
logger.info(f"Metric written to {filename}")
def calc_metric(self, model, model_dtype, metrics_idx_start, only_average,
input_len, output_len, tp_nums, quantization,
show_per_iter=False, is_embedding_task=False, mm_kwargs=None,
total_prefill_steps=1, num_spec_tokens=0, dp_size=1, hfu_info=None, io_efficiency=0.0) -> None:
keep_digits = 2
def round_fn(data):
return round(data, keep_digits)
metrics_idx_end = len(self.per_token_latency_list) - 1 # without last []
idx_range = range(metrics_idx_start, metrics_idx_end)
# specify entries to write to csv
self.is_v1_multimodal = mm_kwargs
self.mm_kwargs = mm_kwargs if mm_kwargs else {} # multimodal args
self.batch_size_name = "batch size"
self.input_len = input_len
self.output_len = output_len
self.tp = tp_nums
self.dp = dp_size
self.model = model
self.context_latency_name = "context latency(ms)"
self.mm_encoder_latency_device_name = "multimodal encoder latency device(ms)"
self.context_latency_device_name = "context latency device(ms)"
if num_spec_tokens > 0:
self.per_step_latency_name = "per step latency(ms)"
self.per_token_latency_device_name = "per step latency device(ms)"
else:
self.per_token_latency_device_name = "per token latency device(ms)"
self.per_token_latency_name = "per token latency(ms)"
self.e2e_latency_name = "e2e latency(ms)"
self.e2e_throughput_name = "e2e throughput(tokens/s)"
self.decoder_throughput_name = "decoder throughput(tokens/s)"
self.k_name = "K"
self.acceptance_rate_name = "acceptance rate"
self.decode_times_name = "decode times"
self.weight_dtype_str = self.get_weight_dtype_str(model, model_dtype, quantization)
self.num_spec_tokens = num_spec_tokens
rate_list=[]
rate=0
if num_spec_tokens > 0:
for i in range(metrics_idx_end):
if len(self.per_token_latency_list[i]) - total_prefill_steps == 0:
logger.warning("For now output_len is 0, no need mtp info, if you need mtp info, please increase output_len.")
rate_list.append(0.0)
else:
rate_list.append(((self.output_len - 1) / (float)(len(self.per_token_latency_list[i]) - total_prefill_steps) - 1) / num_spec_tokens)
rate = statistics.fmean(rate_list[metrics_idx_start: metrics_idx_end])
metrics_data = [
(
self.batch_size_name, [self.dp * int(self.batch_size_list[i]) for i in idx_range]
),
(
self.context_latency_name, [round_fn(millisecond2second_unit * sum(self.per_token_latency_list[i][:total_prefill_steps])) for i in idx_range]
),
(
self.per_token_latency_name, [
0.0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \
round_fn(statistics.fmean(self.per_token_latency_list[i][total_prefill_steps:]) * (len(self.per_token_latency_list[i]) - total_prefill_steps) / (self.output_len - 1) * millisecond2second_unit) for i in idx_range
]
),
]
if num_spec_tokens > 0:
metrics_data += [(self.per_step_latency_name, [
0.0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \
round_fn(statistics.fmean(self.per_token_latency_list[i][total_prefill_steps:]) * millisecond2second_unit) for i in idx_range
])]
metrics_data += [
(
self.e2e_latency_name, [round_fn(millisecond2second_unit * self.e2e_latency_list[i]) for i in idx_range]
),
(
self.e2e_throughput_name, [
round_fn(self.dp * (output_len / self.e2e_latency_list[i]) * self.batch_size_list[i]) \
for i in idx_range
]
),
(
self.decoder_throughput_name, [
0.0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \
round_fn(self.dp * ((output_len-1) / sum(self.per_token_latency_list[i][total_prefill_steps:])) * self.batch_size_list[i]) \
for i in idx_range
]
),
(
self.decode_times_name, [
0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \
len(self.per_token_latency_list[i][total_prefill_steps:]) for i in idx_range
]
),
]
if num_spec_tokens > 0:
metrics_data.append((self.k_name, num_spec_tokens))
metrics_data.append((self.acceptance_rate_name, [rate_list[i] for i in idx_range]))
insert_latency_device = VLLM_LATENCY_DEBUG_WITH_DEVICE_EN
if insert_latency_device:
device_item_idx = 3
if self.is_v1_multimodal:
mm_encoder_latency_device = [round_fn(sum(self.mm_encoder_latency_device_list[i])) for i in idx_range]
metrics_data.insert(device_item_idx, (self.mm_encoder_latency_device_name, mm_encoder_latency_device))
device_item_idx = device_item_idx + 1
context_latency_device = [round_fn(sum(self.per_token_latency_device_list[i][:total_prefill_steps])) for i in idx_range]
per_token_latency_device = [0.0 if len(self.per_token_latency_device_list[i]) <= total_prefill_steps else \
round_fn(statistics.fmean(self.per_token_latency_device_list[i][total_prefill_steps:])) for i in idx_range]
metrics_data.insert(device_item_idx, (self.context_latency_device_name, context_latency_device))
metrics_data.insert(device_item_idx + 1, (self.per_token_latency_device_name, per_token_latency_device))
self.metrics_data = dict(metrics_data)
# Print
df = pd.DataFrame(self.metrics_data)
if show_per_iter:
df = pd.DataFrame([df.iloc[-1]], columns=df.columns)
else:
df.loc["Average(" + str(metrics_idx_end-metrics_idx_start) + "iters)"] = df.mean().round(keep_digits)
if only_average:
df = pd.DataFrame([df.iloc[-1]], columns=df.columns)
df.index.name = 'iter index'
df[self.batch_size_name] = df[self.batch_size_name].astype(int)
if num_spec_tokens > 0:
df[self.k_name] = df[self.k_name].astype(int)
self.peak_memory_name = "profile memory(GB)"
self.block_memory_name = "total cache memory(GB)"
memory_metrics_data = [
(
self.peak_memory_name, [round_fn(self.peak_memory / 1024 / 1024 / 1024) for i in idx_range]
),
(
self.block_memory_name, [round_fn(self.block_memory / 1024 / 1024 / 1024) for i in idx_range]
),
]
self.memory_metrics_data = dict(memory_metrics_data)
# Print
memory_df = pd.DataFrame(self.memory_metrics_data)
if show_per_iter:
memory_df = pd.DataFrame([memory_df.iloc[-1]], columns=memory_df.columns)
else:
memory_df.loc["Average(" + str(metrics_idx_end-metrics_idx_start) + "iters)"] = memory_df.mean().round(keep_digits)
if only_average:
memory_df = pd.DataFrame([memory_df.iloc[-1]], columns=memory_df.columns)
memory_df.index.name = 'iter index'
pd.set_option('display.colheader_justify', 'center')
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
print("********************************* Test Info****************************")
mm_params_text = " ".join(f"{key}:{value}" for key, value in self.mm_kwargs.items())
print("Generation Config {} input len:{} output len:{} tp_nums:{} quantization:{}".format(
mm_params_text, input_len,output_len,tp_nums,quantization))
self.context_latency_device = np.mean(self.metrics_data['context latency device(ms)'])
self.generate_latency_device = np.mean(self.metrics_data[self.per_token_latency_device_name])
if self.is_v1_multimodal:
self.mm_encoder_latency_device = np.mean(self.metrics_data[self.mm_encoder_latency_device_name])
print("*************************Performance Info******************************")
print(f"Total prefill steps: {total_prefill_steps}")
print(df.to_string())
if not is_embedding_task:
# embedding task does not do profile run, so does not have memory infos
print(memory_df.to_string())
if insert_latency_device :
context_latency = np.mean(self.metrics_data['context latency device(ms)'])
generate_latency = np.mean(self.metrics_data[self.per_token_latency_device_name])
if num_spec_tokens > 0:
print("MTP token accept rate: {:.2f}%".format(rate*100))
self.dump_performance_info(hfu_info, io_efficiency)
avg_latency_e2e = sum(sum(self.per_token_latency_list[i]) for i in idx_range) / len(idx_range)
print("Avg latency without host time is :", avg_latency_e2e)
print("***********************************************************************")
# collect context_hfu and
self.save_hfu_info = False
if insert_latency_device:
if VLLM_DUMP_MLU_INFO_EN:
try:
import device_info
self.save_hfu_info = True
except:
logger.info(f"try import device_info failed. try pip install device_info.")
self.context_hfu_name = "Context HFU"
self.decoder_hfu_name = "Decoder HFU"
self.decoder_io_efficiency_name = "Decoder IO Efficiency"
if self.save_hfu_info:
self.metrics_data[self.context_hfu_name] = hfu_info["context_hfu"] * 100
self.metrics_data[self.decoder_hfu_name] = hfu_info["decoder_hfu"] * 100
self.metrics_data[self.decoder_io_efficiency_name] = io_efficiency * 100
if csv_path := os.getenv("OUTPUT_CSV_PATH"):
try:
if dir_path := os.path.dirname(csv_path):
os.makedirs(dir_path, exist_ok=True)
self.to_csv(csv_path, show_per_iter=show_per_iter)
except Exception as e:
logger.error(f"Invalid OUTPUT_CSV_PATH: {csv_path} to dump metrics, Error: {e}")
def dump_performance_info(self, hfu_info, io_efficiency):
try:
if VLLM_DUMP_MLU_INFO_EN and hfu_info != None:
hfu_info["context_hfu"] = hfu_info["context_hfu"] / (self.context_latency_device / millisecond2second_unit)
hfu_info["decoder_hfu"] = hfu_info["decoder_hfu"] / (self.generate_latency_device / millisecond2second_unit)
io_efficiency = io_efficiency / self.generate_latency_device
print(f"Context HFU-visible: {hfu_info['context_hfu']:.3%}")
print(f"Decoder HFU-visible: {hfu_info['decoder_hfu']:.3%}")
print(f"Decoder IO Efficiency: {io_efficiency:.3%}")
elif hfu_info != None:
print(f"Context FLOPS-visible: {hfu_info['context_flops']}")
print(f"Decoder FLOPS-visible: {hfu_info['decoder_flops']}")
else:
logger.info("Unsupport dump performance information")
except Exception as e:
logger.error(f"Failed to dump performance information: {str(e)}")

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.model_executor.layers.activation import QuickGELU
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
def vllm__model_executor__activation__QuickGELU__forward_oot(self, x: torch.Tensor) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: implement forward_oot
'''
return mlu_ops.active(x, 'quick_gelu', False)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(QuickGELU,
QuickGELU.forward_oot,
vllm__model_executor__activation__QuickGELU__forward_oot)

View File

@@ -0,0 +1,277 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import math
from typing import Callable
from scipy.linalg import hadamard
import torch
from torch import nn
import torch.nn.functional as F
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
def hadamard_transform_ref(x, scale=1.0):
"""
x: (..., dim)
out: (..., dim)
"""
x_shape = x.shape
dim = x.shape[-1]
x = x.reshape(-1, dim)
log_dim = math.ceil(math.log2(dim))
dim_padded = 2 ** log_dim
if dim != dim_padded:
x = F.pad(x, (0, dim_padded - dim))
out = F.linear(
x,
torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device),
)
out = out * scale
return out[..., :dim].reshape(*x_shape)
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
assert x.dtype == torch.bfloat16
hidden_size = x.size(-1)
return hadamard_transform_ref(x, scale=hidden_size ** -0.5)
class Compressor(nn.Module):
def __init__(self,
vllm_config: VllmConfig,
rope,
compress_ratio: int = 4,
head_dim: int = 512,
rotate: bool = False,
prefix: str = "",
**kwargs,):
super().__init__()
config = vllm_config.model_config.hf_config
self.dim = config.dim
self.head_dim = head_dim
self.rope_head_dim =config.rope_head_dim
self.nope_head_dim = head_dim - config.rope_head_dim
self.compress_ratio = compress_ratio
self.overlap = compress_ratio == 4
self.rotate = rotate
coff = 1 + self.overlap
self.norm_eps = config.norm_eps
self.window_size = config.window_size
self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
# wkv and wgate in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient.
# The first half of dimensions for overlapping compression and second half for normal compression.
self.wkv = ReplicatedLinear(
self.dim,
coff * self.head_dim,
bias=False,
quant_config=None,
params_dtype = torch.float32,
prefix=f"{prefix}.wkv",
)
self.wgate = ReplicatedLinear(
self.dim,
coff * self.head_dim,
bias=False,
quant_config=None,
params_dtype = torch.float32,
prefix=f"{prefix}.wgate",
)
self.norm = RMSNorm(self.head_dim, self.norm_eps)
self.rotary_emb = rope
hf_config = vllm_config.model_config.hf_config
assert hasattr(hf_config, "cached_state_num"), \
f"cached_state_num is not set in hf_config"
cached_state_num = hf_config.cached_state_num
self.register_buffer(
"kv_state",
torch.zeros(cached_state_num, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"score_state",
torch.full(
(cached_state_num, coff * compress_ratio, coff * self.head_dim),
float("-inf"),
dtype=torch.float32,
),
persistent=False,
)
self.hadamard_matrix = torch.tensor(
hadamard(self.head_dim, dtype=float), dtype=torch.get_default_dtype(), device="mlu")
def overlap_transform(self, tensor: torch.Tensor, value=0):
# tensor: [b,s,r,2d]
b, s, _, _ = tensor.size()
ratio, d = self.compress_ratio, self.head_dim
new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
return new_tensor
def forward_decode(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
kv_cache: torch.Tensor,
window_offset: int,
compressor_slot_mapping: torch.Tensor,
):
x = x.float()
kv_pack, _ = self.wkv(x)
score_pack, _ = self.wgate(x)
mlu_ops.fused_compress_single_kv(
kv=kv_pack.unsqueeze(1), # (token, D) -> (B, S, D)
score=score_pack.unsqueeze(1), # (token, D) -> (B, S, D)
position=positions,
ape=self.ape,
kv_state=self.kv_state,
score_state=self.score_state,
gamma=self.norm.weight,
sin=self.rotary_emb.sin_,
cos=self.rotary_emb.cos_,
hadamard_matrix=self.hadamard_matrix,
slot_mapping=compressor_slot_mapping,
kv_cache=kv_cache,
kv_cache_scale=None,
eps=self.norm_eps,
overlap=self.overlap,
rotate=self.rotate,
state_idx=batch_to_kv_state,
)
# Here, return fake compressed_kv.
return None
def forward(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
kv_cache: torch.Tensor,
window_offset: int,
compressor_slot_mapping: torch.Tensor,
):
common_metadata = get_common_metadata()
forward_func: Callable = (
self.forward_prefill if common_metadata.is_prefill_only
else self.forward_decode
)
return forward_func(
x,
positions,
attn_metadata,
batch_to_kv_state,
kv_cache,
window_offset,
compressor_slot_mapping,
)
def forward_prefill(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
kv_cache: torch.Tensor,
window_offset: int,
compressor_slot_mapping: torch.Tensor,
):
common_metadata = get_common_metadata()
seq_lens = common_metadata.seq_lens
query_start_loc = common_metadata.query_start_loc
query_lens = query_start_loc[1:] - query_start_loc[:-1]
ratio, overlap = self.compress_ratio, self.overlap
dtype = x.dtype
x = x.float()
kv_pack, _ = self.wkv(x)
score_pack, _ = self.wgate(x)
compress_lens = query_lens // self.compress_ratio
cu_compress_lens = torch.cat([
torch.tensor([0], dtype=compress_lens.dtype, device=compress_lens.device),
torch.cumsum(compress_lens, dim=0)],
)
compress_positions = []
for i in range(len(seq_lens)):
seqlen = (query_start_loc[i+1] - query_start_loc[i]).item()
remainder = seqlen % ratio
cutoff = seqlen - remainder
pos = positions[query_start_loc[i]: query_start_loc[i+1]]
positions_ = pos[:cutoff:ratio].contiguous()
compress_positions.append(positions_)
kv_positions = torch.cat(compress_positions, dim=0)
total_compress_len = cu_compress_lens[-1].item()
kv = torch.empty(
[total_compress_len, self.head_dim],
dtype=kv_pack.dtype,
device=kv_pack.device,
)
mlu_ops.fused_compress_multi_kv(
kv = kv_pack,
score = score_pack,
kv_state = self.kv_state,
score_state = self.score_state,
state_batch_idx = batch_to_kv_state,
cu_seqlens = query_start_loc,
ape = self.ape,
max_seqlen = common_metadata.max_query_len,
overlap = overlap,
compressed_kv = kv,
)
if kv.size(0) == 0:
return kv.unsqueeze(-2).to(dtype) # (compress_token_num, 1, head_size)
kv = self.norm(kv.to(dtype))
kv_rope = kv[..., -self.rope_head_dim:].unsqueeze(-2)
# use compressed cu_seqlens here, so can not call rotary_emb directly
kv_rope = mlu_ops.rotary_embedding(
kv_rope,
self.rotary_emb.sin_,
self.rotary_emb.cos_,
kv_positions,
torch.tensor([0, kv_positions.size(0)], dtype=torch.int32, device=kv_positions.device), # cu_seqlens
True, # interleaved
True, # discrete
False,
common_metadata.max_query_len,
)
if self.rotate:
kv = rotate_activation(kv)
mlu_ops.reshape_paged_cache(
kv.unsqueeze(1),
None,
kv_cache,
None,
compressor_slot_mapping,
)
return kv.unsqueeze(-2) # (compress_token_num, 1, head_size)

View File

@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional
import torch
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_gather)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm_mlu.model_executor.models.dp_utils import (
tensor_model_parallel_all_gather_dp, DataParallelRuntimeParams)
class DPLogitsProcessor(LogitsProcessor):
"""DP LogitsProcessor."""
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
dp_params: Optional[DataParallelRuntimeParams] = None,
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
batch_sizes = None
if (lm_head.tp_group is not None
and dp_params is not None
and dp_params.logits_batch_split_list is not None):
batch_sizes = dp_params.logits_batch_split_list
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=batch_sizes,
rank=lm_head.tp_rank,
hidden_states=hidden_states,
group=lm_head.tp_group,
)
logits = lm_head.quant_method.apply(
lm_head, hidden_states, bias=embedding_bias)
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits, tp_group=lm_head.tp_group)
else:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits, tp_group=lm_head.tp_group)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[..., : self.org_vocab_size]
if batch_sizes is not None:
offset = sum(batch_sizes[:lm_head.tp_rank])
logits = logits[offset : offset + batch_sizes[lm_head.tp_rank]]
return logits
def forward(
self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
dp_params: Optional[DataParallelRuntimeParams] = None,
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
# Get the logits for the next tokens.
logits = self._get_logits(
hidden_states, lm_head, embedding_bias, dp_params)
if logits is not None:
if self.soft_cap is not None:
logits = logits / self.soft_cap
logits = torch.tanh(logits)
logits = logits * self.soft_cap
if self.scale != 1.0:
logits *= self.scale
return logits

View File

@@ -0,0 +1,219 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
method_has_implemented_embedding,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod,
VocabParallelEmbedding,
DEFAULT_VOCAB_PADDING_SIZE,
get_masked_input_and_mask,
pad_vocab_size,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed.communication_op import (
tensor_model_parallel_all_reduce,
)
from vllm.distributed import (
divide,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_logits_tp_group,
get_logits_tp_world_size,
get_logits_tp_rank,
)
from vllm_mlu.model_executor.models.dp_utils import (
DataParallelRuntimeParams,
tensor_model_parallel_all_gather_dp,
)
class DPVocabParallelEmbedding(VocabParallelEmbedding):
"""DP Embedding parallelized in the vocabulary dimension."""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
torch.nn.Module.__init__(self)
"""
=============================
Modify by vllm_mlu
=============================
@brief: add self.tp_group, world_size and tp_rank to support other parallel
"""
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_world_size = get_tensor_model_parallel_world_size()
self.tp_group = None
logits_tp_world_size = get_logits_tp_world_size()
if logits_tp_world_size != self.tp_world_size:
self.tp_group = get_logits_tp_group()
self.tp_world_size = logits_tp_world_size
self.tp_rank = get_logits_tp_rank()
# Keep the input dimensions.
tp_rank = self.tp_rank
self.tp_size = self.tp_world_size
"""
=================
End of MLU Hijack
=================
"""
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings,
self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
self.embedding_dim = embedding_dim
quant_method = None
if quant_config is not None:
quant_method = quant_config.get_quant_method(self, prefix=prefix)
if quant_method is None:
quant_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self) is VocabParallelEmbedding
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method))
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(quant_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
self.quant_method: QuantizeMethodBase = quant_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the vocaburaly dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
self.tp_size)
assert (self.shard_indices.num_elements_padded ==
self.num_embeddings_per_partition)
self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index -
self.shard_indices.org_vocab_start_index)
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.quant_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
def forward(self, input_,
dp_params: Optional[DataParallelRuntimeParams] = None):
token_split_list = None
if (dp_params is not None
and self.tp_group is not None
and dp_params.emb_token_split_list is not None):
token_split_list = dp_params.emb_token_split_list
input_ = tensor_model_parallel_all_gather_dp(
group_num_tokens=token_split_list,
rank=self.tp_rank,
hidden_states=input_.reshape(-1, 1),
group=self.tp_group,
).reshape(-1)
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = get_masked_input_and_mask(
input_,
self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index,
)
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self, masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group)
if token_split_list is not None:
offset = sum(token_split_list[:self.tp_rank])
output = output[offset : offset + token_split_list[self.tp_rank]]
return output
class DPParallelLMHead(DPVocabParallelEmbedding):
"""DP Parallelized LM head.
NOTE: A copy of ParallelLMHead class, and only change its parent
from VocabParallelEmbedding to DPVocabParallelEmbedding.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size, quant_config,
prefix)
self.quant_config = quant_config
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def tie_weights(self, embed_tokens: VocabParallelEmbedding):
"""Tie the weights with word embeddings."""
# GGUF quantized embed_tokens.
if self.quant_config and self.quant_config.get_name() == "gguf":
return embed_tokens
else:
self.weight = embed_tokens.weight
return self
def forward(self, input_):
del input_
raise RuntimeError("LMHead's weights should be used in the sampler.")

View File

@@ -0,0 +1,224 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import torch.nn.functional as F
from typing import Any
from vllm.distributed import (
get_parallel_world_size_with_group,
get_parallel_rank_with_group,
)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
ColumnParallelLinear,
RowParallelLinear
)
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import set_is_gated
logger = init_logger(__name__)
class FeedForward(torch.nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
up_proj_name: str,
is_gated: bool,
down_proj_name: str,
bias: bool,
quant_config: QuantizationConfig | None = None,
skip_bias_add: bool = False,
reduce_results: bool = True,
prefix: str = "",
tp_group: Any = None,
keep_full_weights: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.is_gated = is_gated
self.bias = bias
self.up_proj_name = up_proj_name
self.down_proj_name = down_proj_name
self.quant_config = quant_config
self.is_initialized = False
self.skip_bias_add = skip_bias_add
self.reduce_results = reduce_results
self.use_bt_ffn = True
set_is_gated(self.is_gated)
# modify tp_size, tp_rank and tp_group when enable data parallel
self.tp_size = get_parallel_world_size_with_group(tp_group)
self.tp_rank = get_parallel_rank_with_group(tp_group)
self.tp_group = tp_group
self.keep_full_weights = keep_full_weights
if self.keep_full_weights:
self.tp_size = 1
self.tp_rank = 0
self.tp_group = None
# up_proj with gate or not
if self.is_gated:
up_proj = MergedColumnParallelLinear(hidden_size,
[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.{up_proj_name}",
tp_group=self.tp_group,
keep_full_weights=keep_full_weights)
else:
up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=bias,
skip_bias_add=skip_bias_add,
quant_config=quant_config,
prefix=f"{prefix}.{up_proj_name}",
tp_group=self.tp_group,
keep_full_weights=keep_full_weights)
self.register_module(up_proj_name, up_proj)
# down_proj
down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=bias,
skip_bias_add=skip_bias_add,
reduce_results=reduce_results,
quant_config=quant_config,
prefix=f"{prefix}.{down_proj_name}",
tp_group=self.tp_group,
keep_full_weights=keep_full_weights)
self.register_module(down_proj_name, down_proj)
def prepare_weight(self):
if not self.is_initialized:
# alpha and beta are 1.0 and 0.0 respectively due to the fact that we don't need residual for now
self.alpha = 1.0
self.beta = 0.0
# place it here to avoid the overhead of calling it in the forward pass
self.is_initialized = True
def _forward(self, hidden_states):
self.prepare_weight()
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
act_dict = {
"relu": F.relu,
"gelu": F.gelu,
"silu": F.silu,
}
fc1 = F.linear(hidden_states, up_proj.weight, bias=up_proj.bias)
if self.is_gated:
d = fc1.shape[-1] // 2
fc1 = act_dict[self.hidden_act](fc1[..., :d]) * fc1[..., d:]
else:
fc1 = act_dict[self.hidden_act](fc1)
fc2 = F.linear(fc1, down_proj.weight, bias=None)
fc2 = tensor_model_parallel_all_reduce(fc2)
if not self.skip_bias_add:
fc2 = fc2 + down_proj.bias if down_proj.bias is not None else fc2
return fc2
def forward_naive(
self,
hidden_states,
residual: torch.Tensor | None = None,
smooth_quant_scale: torch.Tensor | None = None
):
'''
used by quant_tools
'''
assert self.quant_config is None, "ffn naive forward dosen't support quantization"
assert smooth_quant_scale is None, "ffn naive forward dosen't support smooth_quant_scale"
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
residual_ = None if self.tp_rank > 0 else residual
fc1, bias = up_proj(hidden_states)
if bias is not None:
fc1 += bias
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
out, bias = down_proj(fc1, residual=residual_)
if self.skip_bias_add:
return out, bias
return out
def forward(
self,
hidden_states,
residual: torch.Tensor | None = None,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
output: torch.Tensor | None = None,
):
self.prepare_weight()
if self.use_bt_ffn is False:
return self.forward_naive(hidden_states, residual, None)
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
residual_ = None if self.tp_rank > 0 else residual
if (self.quant_config is None and not isinstance(up_proj, BaseLayerWithLoRA)
and not isinstance(down_proj, BaseLayerWithLoRA)):
# The matmul formula is the following:
# mul_out = alpha * (matmul(input, filter, transpose\_b=True) + bias) + beta * residual
# output = active(mul_out)
# Notes: We cannot use the activation function in matmul because it does not support gated operation
# we might support its in tmo matmul in the future
up_proj_weight = up_proj.weight
down_proj_weight = down_proj.weight
if self.keep_full_weights and use_tp_weight:
up_proj_weight = up_proj.tp_weight
down_proj_weight = down_proj.tp_weight
fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj_weight, up_proj.bias,
None, 'none', self.alpha, self.beta)
act_out = mlu_ops.active(fc1.float(), self.hidden_act, self.is_gated).to(dtype=fc1.dtype)
beta = 0.0
if residual_ is not None:
beta = 1.0
residual_ = residual_.view(-1, residual_.shape[-1])
out_ = mlu_ops.matmul(act_out, down_proj_weight, None, residual_, 'none', self.alpha, beta)
# bias if existed need to add after second matmul according to the original design of vllm
if self.reduce_results:
out = tensor_model_parallel_all_reduce(out_, self.tp_group)
else:
out = out_
# do the bias add if needed
if not self.skip_bias_add:
out = out + down_proj.bias if down_proj.bias is not None else out
else:
return out, down_proj.bias
else:
fc1, bias = up_proj(hidden_states, smooth_quant_scale=smooth_quant_scale, use_tp_weight=use_tp_weight)
if bias is not None:
fc1 += bias
input_scale= None
if (self.quant_config is not None and self.quant_config.get_name() == "SmoothQuant" and
self.quant_config.input_quant_method == "per_token" and not self.quant_config.is_fp8):
down_proj.quant_method.skip_quant_input = True
down_proj_smooth = down_proj.smooth
if self.keep_full_weights and use_tp_weight:
assert down_proj.tp_smooth is not None, "tp_smooth is not initialized"
down_proj_smooth = down_proj.tp_smooth
fc1, input_scale = mlu_ops.per_token_smooth_quantize(
fc1, down_proj_smooth, None, None, act_mode=self.hidden_act, is_gated=self.is_gated)
else:
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
out, bias = down_proj(
fc1, residual=residual_, smooth_quant_scale=input_scale,
use_tp_weight=use_tp_weight, output=output)
if self.skip_bias_add:
return out, bias
return out

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,935 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_moe_kernel_gptq_awq,
write_zeros_to_output,
get_default_config,
try_get_optimal_moe_config,
_get_config_quant_dtype,
)
from vllm.model_executor.layers.fused_moe.utils import (
activation_without_mul,
disable_inplace,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm_mlu.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
import vllm_mlu._mlu_ops as mlu_ops
logger = init_logger(__name__)
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
b_bias_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
stride_bbe, # bias expert stride
stride_bbn, # bias N stride
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
HAS_BIAS: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
'''
=============================
Modify by vllm_mlu
=============================
@brief: Split the program ID into two dimensions (pid_0 and pid_1)
'''
pid_0 = tl.program_id(axis=0)
pid_1 = tl.program_id(axis=1)
pid = pid_1 * tl.num_programs(axis=0) + pid_0
'''
==================
End of MLU Hijack
==================
'''
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(
c_ptr,
stride_cm,
stride_cn,
pid_n,
N,
offs_token,
token_mask,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
compute_type,
)
return
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
if use_int8_w8a16:
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
)
b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8 or use_int8_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
)
# channel-wise
elif per_channel_quant:
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
)
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
# tensor-wise
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
if HAS_BIAS:
# bias shape: [num_experts, N]
bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(
a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
if use_fp8_w8a8:
# acc used to enable fp8_fast_accum
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if HAS_BIAS:
accumulator = accumulator + bias[None, :]
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None,
B_zp: torch.Tensor | None,
topk_weights: torch.Tensor | None,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: dict[str, Any],
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: list[int] | None = None,
B_bias: torch.Tensor | None = None,
) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
assert block_shape is None or triton.cdiv(
B.size(-2), block_shape[0]
) == B_scale.size(-2)
assert block_shape is None or triton.cdiv(
B.size(-1), block_shape[1]
) == B_scale.size(-1)
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
M = A.size(0)
num_tokens = M * top_k
EM = sorted_token_ids.size(0)
if A.size(0) < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
# We assume that top_ids of each token is unique,
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
'''
=============================
Modify by vllm_mlu
=============================
@brief: Split the program ID into two dimensions (pid_0, pid_1)
'''
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']), triton.cdiv(
B.shape[1], META['BLOCK_SIZE_N']), )
assert not (use_int8_w8a16 or use_int4_w4a16)
'''
==================
End of MLU Hijack
==================
'''
HAS_BIAS = B_bias is not None
if (
(use_int8_w8a16 or use_int4_w4a16)
and block_shape is not None
and block_shape[1] > 0
):
assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
num_valid_tokens=num_tokens,
group_size=block_shape[1],
num_experts=B.size(0),
bit=4 if use_int4_w4a16 else 8,
)
config = config.copy()
config.update(
get_moe_wna16_block_config(
config=config,
use_moe_wna16_cuda=use_moe_wna16_cuda,
num_valid_tokens=num_tokens,
size_k=A.size(1),
size_n=B.size(1),
num_experts=B.size(1),
group_size=block_shape[1],
real_top_k=top_k,
block_size_m=config["BLOCK_SIZE_M"],
)
)
if use_moe_wna16_cuda:
bit = 4 if use_int4_w4a16 else 8
ops.moe_wna16_gemm(
A,
C,
B,
B_scale,
B_zp,
topk_weights if mul_routed_weight else None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
top_k,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
bit,
)
return
fused_moe_kernel_gptq_awq[grid](
A,
B,
C,
B_scale,
B_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.size(1),
A.size(1),
EM,
num_tokens,
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0),
B_scale.stride(2),
B_scale.stride(1),
B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0,
block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
**config,
)
else:
config = config.copy()
config["SPLIT_K"] = 1
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
if block_shape is not None:
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
fused_moe_kernel[grid](
A,
B,
C,
B_bias,
A_scale,
B_scale,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.size(1),
B.size(2),
EM,
num_tokens,
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
B_bias.stride(0) if B_bias is not None else 0,
B_bias.stride(1) if B_bias is not None else 0,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
HAS_BIAS=HAS_BIAS,
BLOCK_SIZE_K=BLOCK_SIZE_K,
**config,
)
def outplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
True,
activation,
apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
ocp_mx_scheme,
per_channel_quant,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1_scale,
a2_scale,
block_shape,
w1_bias,
w2_bias,
)
def outplace_fused_experts_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> None:
pass
direct_register_custom_op(
op_name="outplace_fused_experts_mlu",
op_func=outplace_fused_experts,
mutates_args=["hidden_states"],
fake_impl=outplace_fused_experts_fake,
dispatch_key="PrivateUse1",
tags=(
()
if is_torch_equal_or_newer("2.7.0")
else (torch.Tag.needs_fixed_stride_order,)
),
)
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
return torch.ops.vllm.outplace_fused_experts_mlu(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
RELU2_NO_MUL: str = activation_without_mul("relu2")
def fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
# Check constraints.
if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
elif ocp_mx_scheme is not None:
if ocp_mx_scheme in {
"w_mxfp4_a_mxfp4",
"w_mxfp4_a_mxfp6_e3m2",
"w_mxfp4_a_mxfp6_e2m3",
}:
# 16bit activation and fp4x2 packed weight
assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
elif ocp_mx_scheme in {
"w_mxfp6_e3m2_a_mxfp6_e3m2",
"w_mxfp6_e2m3_a_mxfp6_e2m3",
}:
assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
"hidden size mismatch"
)
else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
else:
assert hidden_states.size(1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
)
assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
num_tokens = hidden_states.size(0)
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
config_dtype = _get_config_dtype_str(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
ocp_mx_scheme=ocp_mx_scheme,
dtype=hidden_states.dtype,
)
# Note: for use_int8_w8a16 or use_int4_w4a16, the activations are
# quantized prior to calling fused_experts.
quant_dtype = _get_config_quant_dtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
ocp_mx_scheme=ocp_mx_scheme,
)
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.size(),
w2.size(),
top_k_num,
config_dtype,
block_shape=block_shape,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Only use the default config
'''
config = get_default_config(M, E, N, w1.shape[2], topk_ids.shape[1],
hidden_states.dtype, block_shape)
'''
==================
End of MLU Hijack
==================
'''
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
cache13 = torch.empty(
M * top_k_num * max(N, K),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
# This needs separate memory since it's used concurrently with cache1
intermediate_cache2 = torch.empty(
(M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
)
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
if inplace and not disable_inplace():
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
if ocp_mx_scheme is not None:
# TODO: On platforms for which `current_platform.supports_mx()` is True
# and for which we have a native OCP mx fused MOE kernel,
# this dequantization step should not be done.
if ocp_mx_scheme in {
OCP_MX_Scheme.w_mxfp4_a_mxfp4,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
}:
# Weight has to be dequantized for mxfp4 emulation.
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
w1_scale = None
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w1_scale = None
w2 = dequant_mxfp6(
w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
w1_scale = None
w2 = dequant_mxfp6(
w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
w2_scale = None
else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (
chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE, num_tokens),
)
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.size()
if tokens_in_chunk == 0:
break
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qcurr_hidden_states, a1q_scale = _fp8_quantize(
curr_hidden_states, a1_scale, block_shape)
else:
qcurr_hidden_states = curr_hidden_states
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
)
invoke_fused_moe_kernel(
qcurr_hidden_states,
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
curr_topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
apply_router_weight_on_input,
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w1_bias,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Activate by mlu_ops
'''
intermediate_cache2 = mlu_ops.active(intermediate_cache1.view(-1, N),
act_mode=activation,
is_gated=True)
'''
==================
End of MLU Hijack
==================
'''
a2q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qintermediate_cache2, a2q_scale = _fp8_quantize(
intermediate_cache2, a2_scale, block_shape)
else:
qintermediate_cache2 = intermediate_cache2
invoke_fused_moe_kernel(
qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
curr_topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w2_bias,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace moe_sum with torch.sum
Reference Links: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py#L1513
'''
if topk_ids.shape[1] == 2:
torch.add(
intermediate_cache3[:, 0],
intermediate_cache3[:, 1],
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
).squeeze(dim=1)
elif topk_ids.shape[1] > 2:
torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
'''
==================
End of MLU Hijack
==================
'''
return out_hidden_states

View File

@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional, Callable
import torch
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
def vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
#TODO: support `routed_scaling_factor`
assert routed_scaling_factor == 1.0, (
f"routed_scaling_factor {routed_scaling_factor} is not supported for MLU."
)
use_fused_kernel = topk_group is None
if use_fused_kernel:
assert not enable_eplb, f"MLU not support eplb in fused_moe kernel."
assert use_grouped_topk is False and num_expert_group is None and topk_group is None, \
f"Following params: use_grouped_topk, num_expert_group, topk_group are not support yet."
return mlu_ops.fused_moe(
x,
router_logits,
layer.w13_weight, layer.w2_weight,
None, None, # bias1, bias2
None, # residual
None, # input_smooth
None, # act_smooth
None, None, # w1_scale, w2_scale
top_k,
renormalize,
True, # gated
activation
)
else:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
if self.rocm_aiter_moe_enabled:
assert expert_map is None
return self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
MluHijackObject.apply_hijack(
UnquantizedFusedMoEMethod,
UnquantizedFusedMoEMethod.forward_oot,
vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot
)

View File

@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# SPDX-License-Identifier: Apache-2.0
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv, round_up
'''
=============================
Modify by vllm_mlu
=============================
@brief: Implementation of moe_align_block_size_triton.
Note: the implemtentation has been removed from vllm since the
cuda implementation is more efficient.
'''
@triton.jit
def moe_align_block_size_stage1(
topk_ids_ptr,
tokens_cnts_ptr,
num_experts: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = pid * tokens_per_thread
off_c = (pid + 1) * num_experts
for i in range(tokens_per_thread):
if start_idx + i < numel:
idx = tl.load(topk_ids_ptr + start_idx + i)
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
@triton.jit
def moe_align_block_size_stage2(
tokens_cnts_ptr,
num_experts: tl.constexpr,
):
pid = tl.program_id(0)
last_cnt = 0
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
last_cnt = last_cnt + token_cnt
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
):
last_cumsum = 0
off_cnt = num_experts * num_experts
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
tl.store(cumsum_ptr + i, last_cumsum)
tl.store(total_tokens_post_pad_ptr, last_cumsum)
@triton.jit
def moe_align_block_size_stage4(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = tl.load(cumsum_ptr + pid)
end_idx = tl.load(cumsum_ptr + pid + 1)
for i in range(start_idx, end_idx, block_size):
tl.store(expert_ids_ptr + i // block_size, pid)
start_idx = pid * tokens_per_thread
off_t = pid * num_experts
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
numel)):
expert_id = tl.load(topk_ids_ptr + i)
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
numel = topk_ids.numel()
grid = (num_experts, )
tokens_cnts = torch.zeros((num_experts + 1, num_experts),
dtype=torch.int32,
device=topk_ids.device)
cumsum = torch.zeros((num_experts + 1, ),
dtype=torch.int32,
device=topk_ids.device)
tokens_per_thread = cdiv(numel, num_experts)
sorted_token_ids.fill_(numel)
expert_ids.zero_()
moe_align_block_size_stage1[grid](
topk_ids,
tokens_cnts,
num_experts,
numel,
tokens_per_thread,
)
moe_align_block_size_stage2[grid](
tokens_cnts,
num_experts,
)
moe_align_block_size_stage3[(1, )](
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
)
moe_align_block_size_stage4[grid](
topk_ids,
sorted_token_ids,
expert_ids,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
)
'''
==================
End of MLU Hijack
==================
'''
def moe_align_block_size(
topk_ids: torch.Tensor,
block_size: int,
num_experts: int,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
- expert_map: A tensor of shape [num_experts] that maps the expert index
from the global space to the local index space of the current
expert parallel shard. If the expert is not in the current expert
parallel shard, the mapping is set to -1.
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length
should be padded to a multiple of block_size,
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids = torch.zeros((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Only use triton to implement moe_align_block_size
'''
moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
'''
==================
End of MLU Hijack
==================
'''
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# SPDX-License-Identifier: Apache-2.0
from math import prod
from typing import List, Optional, Tuple
import torch
from vllm.utils.math_utils import cdiv
def _fp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
block_shape: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.
"""
from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
assert block_shape is not None
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
return A, A_scale

View File

@@ -0,0 +1,278 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_world_size,
get_tp_group
)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.compressor import (
Compressor,
rotate_activation,
)
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
logger = init_logger(__name__)
class Indexer(torch.nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
rope,
compress_ratio: int = 4,
prefix: str = "",
**kwargs,
):
super().__init__()
config = vllm_config.model_config.hf_config
self.dim = config.dim
self.n_heads = config.index_n_heads
self.tp_size = get_tensor_model_parallel_world_size()
self.n_local_heads = config.index_n_heads // self.tp_size
self.head_dim = config.index_head_dim
self.rope_head_dim = config.rope_head_dim
self.index_topk = config.index_topk
self.q_lora_rank = config.q_lora_rank
self.window_size = config.window_size
self.block_size = vllm_config.cache_config.block_size
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=None,
prefix=f"{prefix}.wq_b",
)
self.weights_proj = ReplicatedLinear(
self.dim,
self.n_heads,
bias=False,
quant_config=None,
params_dtype = torch.bfloat16,
prefix=f"{prefix}.weights_proj",
)
self.softmax_scale = self.head_dim ** -0.5
self.merged_softmax_scale = (self.head_dim ** -0.5) * (self.n_heads ** -0.5)
self.compress_ratio = compress_ratio
self.max_model_len = vllm_config.model_config.max_model_len
self.rotary_emb = rope
self.tp_group = get_tp_group()
self.compressor = Compressor(vllm_config, self.rotary_emb, compress_ratio, self.head_dim, True, f"{prefix}.compressor")
self.freqs_cis = None
def forward_prefill(
self,
q: torch.Tensor,
k_cache: torch.Tensor,
weights: torch.Tensor,
attn_metadata: AttentionMetadata,
k_full: torch.Tensor,
context_lens: torch.Tensor,
):
assert attn_metadata.prefill.chunked_context is None, \
f"Prefill chunked context is not supported."
query_start_loc = attn_metadata.prefill.query_start_loc
cu_seq_q_lens = query_start_loc
cu_seq_k_lens = torch.zeros(
context_lens.size(0) + 1, dtype=torch.int32, device=q.device,
)
torch.cumsum(context_lens, dim=0, out=cu_seq_k_lens[1:])
attn_metadata.prefill.query_start_loc
seq_lens = torch.diff(cu_seq_k_lens)
batch_size = seq_lens.shape[0]
new_block_tables = torch.empty(
[attn_metadata.num_prefill_tokens, self.index_topk],
dtype=torch.int32,
device=q.device,
)
new_context_lens = torch.empty(
[attn_metadata.num_prefill_tokens],
dtype=torch.int32,
device=q.device,
)
q_seq_lens = cu_seq_q_lens[1:]-cu_seq_q_lens[:-1]
max_seq_len = q_seq_lens.max().item()
batch_size = q_seq_lens.size(0)
max_compressed_kv_len = max_seq_len // self.compress_ratio
kv_cache_block_table = torch.zeros([batch_size, max_compressed_kv_len], dtype=torch.int32, device=q.device)
# The layout of linear kv is as follows:
# | bs0_origin_kv | bs1_origin_kv | bs0_compressed_kv | bs1_compressed_kv |
for i in range(batch_size):
start = cu_seq_k_lens[i].item()
kv_cache_block_table[i] = torch.arange(
start, start + max_compressed_kv_len,
dtype=torch.int32,
device=q.device,
)
# offset total origin_kv len
kv_cache_block_table = kv_cache_block_table + cu_seq_q_lens[-1]
# query: (tokens, index_head, index_head_dim)
# k_full: (tokens, index_head_dim)
# weights: (tokens, index_head, 1)
mlu_ops.masked_indexer_select_paged_kv_prefill(
query=q,
key_value=k_full,
weights=weights.unsqueeze(-1),
kv_cache_block_table=kv_cache_block_table,
cu_seq_q_lens=cu_seq_q_lens,
cu_seq_k_lens=cu_seq_k_lens,
index_topk=self.index_topk,
kv_cache_block_size=self.block_size,
softmax_scale=self.merged_softmax_scale,
q_scale=None,
k_scale_cache=None,
sparse_block_table=new_block_tables,
sparse_context_lens=new_context_lens,
compress_ratio=self.compress_ratio,
kv_cache_block_table_offset=None,
)
return new_block_tables, new_context_lens
def forward_decode(
self,
q: torch.Tensor,
x: torch.Tensor,
k_cache: torch.Tensor,
weights: torch.Tensor,
attn_metadata: AttentionMetadata,
):
block_table = attn_metadata.decode.block_table
batch_size = block_table.shape[0]
seq_len = x.shape[0] // batch_size
q = q.view(batch_size, seq_len, *q.shape[1:])
weights = weights.view(batch_size, seq_len, *weights.shape[1:])
seq_lens = attn_metadata.decode.seq_lens
k_block_table = block_table
seq_len = x.shape[0] // batch_size
new_block_tables = torch.empty(
[batch_size, seq_len, self.index_topk],
dtype=torch.int32,
device=block_table.device,
)
new_context_lens = torch.empty(
[attn_metadata.num_decode_tokens],
dtype=torch.int32,
device=block_table.device,
)
kv_cache_block_table_offset=torch.empty(
[attn_metadata.num_decode_tokens],
dtype=torch.int32,
device=block_table.device,
)
kv_cache_block_table_offset.fill_(self.window_size)
mlu_ops.masked_indexer_select_paged_kv_decode(
query=q,
k_cache=k_cache,
weights=weights.unsqueeze(-1), # (bsz, seq_q, head_num, 1)
kv_cache_block_table=block_table,
k_context_lens=seq_lens // self.compress_ratio,
k_cache_block_table=k_block_table,
index_topk=self.index_topk,
kv_cache_block_size=self.block_size,
softmax_scale=self.merged_softmax_scale,
q_scale=None,
k_scale_cache=None,
sparse_block_table=new_block_tables,
sparse_context_lens=new_context_lens,
compress_ratio=self.compress_ratio,
kv_cache_block_table_offset=kv_cache_block_table_offset,
)
# [batch, seq_q, index_topk] -> [batch, index_topk]
new_block_tables = new_block_tables.squeeze(1)
return new_block_tables, new_context_lens
def forward(self,
x: torch.Tensor,
qr: torch.Tensor,
positions: torch.Tensor,
offsets: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
indexer_kv_cache: torch.Tensor,
compressor_slot_mapping: torch.Tensor,
):
common_metadata = get_common_metadata()
query_start_loc = common_metadata.query_start_loc
query_lens = query_start_loc[1:] - query_start_loc[:-1]
rd = self.rope_head_dim
q = self.wq_b(qr)[0]
q = q.unflatten(-1, (self.n_heads, self.head_dim))
self.rotary_emb(positions, q[..., -rd:], None, only_prefill=False)
q_pack = rotate_activation(q)
weights_pack = self.weights_proj(x)[0] # (tokens, index_local_head)
num_decode_tokens = attn_metadata.num_decode_tokens
compressed_kv = self.compressor(
x,
positions,
attn_metadata,
batch_to_kv_state,
indexer_kv_cache,
0,
compressor_slot_mapping,
)
if attn_metadata.prefill:
assert compressed_kv is not None and compressed_kv.dim() == 3
compressed_kv = compressed_kv.squeeze(-2)
compressed_context_lens = query_lens // self.compress_ratio
prefill_q = q_pack[num_decode_tokens:, ...]
prefill_weights = weights_pack[num_decode_tokens:, ...]
prefill_block_tables, prefill_context_lens = self.forward_prefill(
prefill_q,
indexer_kv_cache,
prefill_weights,
attn_metadata,
compressed_kv,
compressed_context_lens,
)
if attn_metadata.decode:
decode_x = x[:num_decode_tokens, ...]
decode_q = q_pack[:num_decode_tokens, ...]
decode_weights = weights_pack[attn_metadata.num_prefills:]
decode_block_tables, decode_context_lens = self.forward_decode(
decode_q,
decode_x,
indexer_kv_cache,
decode_weights,
attn_metadata,
)
if attn_metadata.prefill and attn_metadata.decode:
new_block_tables = torch.cat([prefill_block_tables, decode_block_tables], dim=0)
new_context_lens = torch.cat([prefill_context_lens, decode_context_lens], dim=0)
elif attn_metadata.prefill:
new_block_tables = prefill_block_tables
new_context_lens = prefill_context_lens
else:
new_block_tables = decode_block_tables
new_context_lens = decode_context_lens
return new_block_tables, new_context_lens

View File

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
import torch
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.models.layer_utils import is_per_token_smoothquant
@CustomOp.register("quant_fusion_rms_norm")
class QuantFusionRMSNorm(RMSNorm):
def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
super().__init__(hidden_size, variance_epsilon)
assert not isinstance(
proj.quant_method, UnquantizedLinearMethod
), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
proj.quant_method.skip_quant_input = True
if dynamic_quant := is_per_token_smoothquant(proj.quant_method.quant_config):
quant_scale = proj.smooth.data
else:
quant_scale = proj.scale_to_int.data
self.dynamic_quant = dynamic_quant
self.quant_scale = torch.nn.Parameter(quant_scale)
def forward(
self, x: torch.Tensor, residual: torch.Tensor | None = None
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
return mlu_ops.fused_rms_norm(
x,
residual,
self.weight.data,
None,
None,
self.variance_epsilon,
False,
self.quant_scale.data,
self.dynamic_quant,
)
@CustomOp.register("quant_fusion_layer_norm")
class QuantFusionLayerNorm(torch.nn.LayerNorm, CustomOp):
def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
super().__init__(hidden_size, variance_epsilon)
assert not isinstance(
proj.quant_method, UnquantizedLinearMethod
), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
proj.quant_method.skip_quant_input = True
if dynamic_quant := is_per_token_smoothquant(proj.quant_method.quant_config):
quant_scale = proj.smooth.data
else:
quant_scale = proj.scale_to_int.data
self.dynamic_quant = dynamic_quant
self.quant_scale = torch.nn.Parameter(quant_scale)
def forward(
self, x: torch.Tensor, residual: torch.Tensor | None = None
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
bias = None if self.bias is None else self.bias.data
return mlu_ops.fused_layer_norm(
x,
residual,
self.weight.data,
bias,
None,
self.eps,
False,
self.quant_scale.data,
self.dynamic_quant,
)
def vllm__model_executor__layers__layernorm__RMSNorm__forward_oot(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
out: torch.Tensor | None = None,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
org_shape = x.shape
x = x.reshape(-1, self.weight.data.shape[0])
if out is not None:
out = out.view(-1, self.weight.data.shape[0])
if residual is not None:
residual = residual.view(-1, self.weight.data.shape[0])
x = mlu_ops.fused_rms_norm(
x,
residual,
self.weight.data,
None,
None,
self.variance_epsilon,
True,
out=out,
)
else:
x = mlu_ops.fused_rms_norm(
x,
residual,
self.weight.data,
None,
None,
self.variance_epsilon,
False,
out=out,
)
if out is not None:
return x
if residual is None:
assert isinstance(x, torch.Tensor)
return x.view(org_shape)
assert isinstance(x, tuple)
assert len(x) == 2
return x[0].view(org_shape), x[1].view(org_shape)
MluHijackObject.apply_hijack(
RMSNorm,
RMSNorm.forward_oot,
vllm__model_executor__layers__layernorm__RMSNorm__forward_oot,
)

View File

@@ -0,0 +1,693 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional, Any
import torch
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, split_tensor_along_last_dim,
get_parallel_rank_with_group, get_parallel_world_size_with_group,
get_tp_world_group, get_tp_world_world_size, get_tp_world_rank)
from vllm.distributed.communication_op import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.linear import (
WEIGHT_LOADER_V2_SUPPORTED, UnquantizedLinearMethod, LinearBase,
ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED.extend([
"GPTQMluLinearMethod",
"AWQMluLinearMethod"
])
vllm__module_executor__layers__linear__LinearBase____init__org = LinearBase.__init__
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org = MergedColumnParallelLinear.weight_loader
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org = RowParallelLinear.weight_loader
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual parameter.
@brief: dispatch unquantized_gemm to mlu ops.
'''
def vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
residual: torch.Tensor | None = None
) -> torch.Tensor:
beta = 0.0
if residual is not None:
beta = 1.0
residual = residual.view(-1, residual.shape[-1])
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
return mlu_ops.matmul(x.reshape(x.numel() // x.shape[-1], x.shape[-1]),
layer.weight,
bias, residual, 'none', 1.0, beta).view(res_shape)
'''
==================
End of MLU Hijack
==================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group and keep_full_weights parameters.
'''
def vllm__module_executor__layers__linear__LinearBase____init__(
self,
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
vllm__module_executor__layers__linear__LinearBase____init__org(
self=self,
input_size=input_size,
output_size=output_size,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add self.tp_group, world_size and tp_rank to support data parallel and moe expert parallel
'''
self.tp_group = tp_group
self.tp_world_size = get_parallel_world_size_with_group(self.tp_group)
self.tp_size = self.tp_world_size
self.tp_rank = get_parallel_rank_with_group(self.tp_group)
self.keep_full_weights = keep_full_weights
if self.keep_full_weights or disable_tp:
self.tp_group = None
self.tp_world_size = 1
self.tp_size = self.tp_world_size
self.tp_rank = 0
self.tp_world_size_org = get_tp_world_world_size()
self.tp_rank_org = get_tp_world_rank()
'''
=================
End of MLU Hijack
=================
'''
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group and keep_full_weights parameters.
'''
def vllm__module_executor__layers__linear__ColumnParallelLinear____init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
output_sizes: list[int] | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
super(ColumnParallelLinear, self).__init__(
input_size,
output_size,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
tp_group=tp_group,
keep_full_weights=keep_full_weights,
return_bias=return_bias,
disable_tp=disable_tp,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: self.tp_size and self.tp_rank has been initialized in LinearBase.__init__
'''
# Divide the weight matrix along the last dimension.
# self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
# self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
'''
=================
End of MLU Hijack
=================
'''
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size) for output_size in self.output_sizes
]
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group in create_weights
'''
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
tp_group=self.tp_group,
)
'''
=================
End of MLU Hijack
=================
'''
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, dtype=params_dtype)
)
set_weight_attrs(
self.bias,
{
"output_dim": 0,
"weight_loader": self.weight_loader,
},
)
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add smooth_quant_scale and use_tp_weight parameters.
'''
def vllm__module_executor__layers__linear__ColumnParallelLinear__forward(
self,
input_,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add input_scale and use_tp_weight parameter.
'''
kwargs = {'bias': bias}
if use_tp_weight:
kwargs['use_tp_weight'] = use_tp_weight
if smooth_quant_scale is not None:
kwargs['input_scale'] = smooth_quant_scale
output_parallel = self.quant_method.apply(self, input_, **kwargs)
'''
==================
End of MLU Hijack
==================
'''
if self.gather_output and self.tp_size > 1:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group param to tensor_model_parallel_all_gather
'''
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel, dim=-1, tp_group=self.tp_group)
'''
=================
End of MLU Hijack
=================
'''
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group and keep_full_weights parameters.
'''
def vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
self.output_sizes = output_sizes
'''
=============================
Modify by vllm_mlu
=============================
@brief: checkout output_sizes after init to get self.tp_world_size
@brief: add keep_full_weights for dp parallelize shared expert
'''
super(MergedColumnParallelLinear, self).__init__(
input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
output_sizes=self.output_sizes,
prefix=prefix,
tp_group=tp_group,
keep_full_weights=keep_full_weights,
return_bias=return_bias,
disable_tp=disable_tp,
)
assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
if self.keep_full_weights:
tp_size = self.tp_world_size_org
if isinstance(self.quant_method, UnquantizedLinearMethod):
out_dim, in_dim = self.weight.shape
out_dim_tp = divide(out_dim, tp_size)
self.tp_weight = Parameter(
self.weight.data.new_empty((out_dim_tp, in_dim)),
requires_grad=False,
)
elif (isinstance(self.quant_method, SmoothQuantLinearMethod)
and quant_config.input_quant_method == "per_token"):
out_dim, in_dim = self.qweight.shape
out_dim_tp = divide(out_dim, tp_size)
self.tp_qweight = Parameter(
self.qweight.data.new_empty((out_dim_tp, in_dim)),
requires_grad=False,
)
self.tp_per_channel_scale = Parameter(
self.per_channel_scale.data.new_empty((out_dim_tp)),
requires_grad=False,
)
else:
raise TypeError(f"quant method is expected to be unquantized or smoothquant per-token")
'''
=================
End of MLU Hijack
=================
'''
'''
=================
End of MLU Hijack
=================
'''
def vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: int | None = None,
):
loaded_weight_orig = loaded_weight
output_dim = getattr(param, "output_dim", None)
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org(
self=self,
param=param,
loaded_weight=loaded_weight,
loaded_shard_id=loaded_shard_id,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add keep_full_weights for dp parallelize shared expert
'''
# load into tp weight
if self.keep_full_weights:
tp_size = self.tp_world_size_org
tp_rank = self.tp_rank_org
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
start_idx = tp_rank * shard_size
if isinstance(self.quant_method, UnquantizedLinearMethod):
tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size)
tp_weight_shard = self.tp_weight.narrow(output_dim, shard_offset, shard_size)
tp_weight_shard.copy_(tp_weight)
elif isinstance(self.quant_method, SmoothQuantLinearMethod):
if output_dim is None:
return
tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size)
if loaded_weight_orig.ndim == 1:
tp_weight_shard = self.tp_per_channel_scale.narrow(output_dim, shard_offset, shard_size)
elif loaded_weight_orig.ndim == 2:
tp_weight_shard = self.tp_qweight.narrow(output_dim, shard_offset, shard_size)
else:
raise ValueError("only support rank 1 and 2 when using tp_weight")
tp_weight_shard.copy_(tp_weight)
else:
raise TypeError(f"quant method is expected to be either unquantized or smoothquant")
'''
=================
End of MLU Hijack
=================
'''
def vllm__module_executor__layers__linear__RowParallelLinear____init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
super(RowParallelLinear, self).__init__(
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
tp_group=tp_group,
keep_full_weights=keep_full_weights,
return_bias=return_bias,
disable_tp=disable_tp,
)
# Divide the weight matrix along the last dimension
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group in create_weights
'''
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
tp_group=self.tp_group,
)
'''
=================
End of MLU Hijack
=================
'''
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(
self.bias,
{
"output_dim": 0,
"weight_loader": self.weight_loader,
},
)
else:
self.register_parameter("bias", None)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add keep_full_weights for dp parallelize shared expert
'''
if self.keep_full_weights:
tp_size = self.tp_world_size_org
if isinstance(self.quant_method, UnquantizedLinearMethod):
out_dim, in_dim = self.weight.data.shape
in_dim_tp = divide(in_dim, tp_size)
self.tp_weight = Parameter(self.weight.data.new_empty((out_dim, in_dim_tp)),
requires_grad=False)
elif (isinstance(self.quant_method, SmoothQuantLinearMethod)
and quant_config.input_quant_method == "per_token"):
out_dim, in_dim = self.qweight.data.shape
in_dim_tp = divide(in_dim, tp_size)
self.tp_qweight = Parameter(self.qweight.data.new_empty((out_dim, in_dim_tp)),
requires_grad=False)
if hasattr(self, "smooth"):
assert len(self.smooth.shape) == 1, "smooth should be a 1D tensor"
dim = self.smooth.shape[0]
dim_tp = divide(dim, tp_size)
self.tp_smooth = Parameter(self.smooth.data.new_empty((dim_tp)),
requires_grad=False)
else:
raise TypeError("quant method expected to be unquantized or smoothquant per-token")
'''
=================
End of MLU Hijack
=================
'''
self.update_param_tp_status()
def vllm__module_executor__layers__linear__RowParallelLinear__weight_loader(
self, param: Parameter, loaded_weight: torch.Tensor
):
input_dim = getattr(param, "input_dim", None)
loaded_weight_orig = loaded_weight
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org(
self=self,
param=param,
loaded_weight=loaded_weight,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add keep_full_weights for dp parallelize shared expert
'''
if self.keep_full_weights:
if input_dim is None:
return
tp_size = self.tp_world_size_org
tp_rank = self.tp_rank_org
shard_size = divide(loaded_weight_orig.shape[input_dim], tp_size)
start_idx = tp_rank * shard_size
if isinstance(self.quant_method, UnquantizedLinearMethod):
shard_view = self.weight.narrow(input_dim, start_idx, shard_size)
self.tp_weight.copy_(shard_view)
elif isinstance(self.quant_method, SmoothQuantLinearMethod):
if loaded_weight_orig.ndim == 1:
shard_view = self.smooth.narrow(input_dim, start_idx, shard_size)
self.tp_smooth.copy_(shard_view)
elif loaded_weight_orig.ndim == 2:
shard_view = self.qweight.narrow(input_dim, start_idx, shard_size)
self.tp_qweight.copy_(shard_view)
else:
raise ValueError("only rank 1 and 2 is supported for tp_weight")
else:
raise TypeError("quant method is expected to be UnquantizedLinearMethod and SmoothQuant")
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual, smooth_quant_scale, use_tp_weight and output parameters.
'''
def vllm__module_executor__layers__linear__RowParallelLinear__forward(
self,
input_,
residual: torch.Tensor | None = None,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
output: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.input_is_parallel:
input_parallel = input_
else:
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add additional matmul parameters.
'''
residual_ = None if self.tp_rank > 0 else residual
kwargs = {'bias': bias_, 'residual': residual_}
if use_tp_weight:
kwargs['use_tp_weight'] = use_tp_weight
if smooth_quant_scale is not None:
kwargs['input_scale'] = smooth_quant_scale
if output is not None:
kwargs['output'] = output
output_parallel = self.quant_method.apply(self, input_parallel, **kwargs)
'''
=================
End of MLU Hijack
=================
'''
if self.reduce_results and self.tp_size > 1:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tensor_model_parallel_all_reduce() with self.tp_group
'''
output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group)
'''
=================
End of MLU Hijack
=================
'''
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
'''
=================
End of MLU Hijack
=================
'''
MluHijackObject.apply_hijack(UnquantizedLinearMethod,
UnquantizedLinearMethod.apply,
vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply)
MluHijackObject.apply_hijack(LinearBase,
LinearBase.__init__,
vllm__module_executor__layers__linear__LinearBase____init__)
MluHijackObject.apply_hijack(ColumnParallelLinear,
ColumnParallelLinear.__init__,
vllm__module_executor__layers__linear__ColumnParallelLinear____init__)
MluHijackObject.apply_hijack(ColumnParallelLinear,
ColumnParallelLinear.forward,
vllm__module_executor__layers__linear__ColumnParallelLinear__forward)
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
MergedColumnParallelLinear.__init__,
vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__)
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
MergedColumnParallelLinear.weight_loader,
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.__init__,
vllm__module_executor__layers__linear__RowParallelLinear____init__)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.weight_loader,
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.forward,
vllm__module_executor__layers__linear__RowParallelLinear__forward)

View File

@@ -0,0 +1,744 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""Inference-only MOE model."""
from typing import Optional, Any, List, Dict
import torch
from torch import nn
from vllm.distributed import (
divide,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.distributed.parallel_state import(
cnclep_dispatch, cnclep_combine)
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
class LongCatSparseMoeMlp(SparseMoeMlp):
"""
sparse moe mlp layer specific to longcat model
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
up_proj_name: str,
is_gated: bool,
down_proj_name: str,
has_bias: bool,
skip_bias_add: bool = False,
renormalize:bool = False,
hidden_act: str = "silu",
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
is_use_fused_moe: bool = False,
expert_group: Optional[int] = 1,
topk_group: Optional[int] = 1,
scoring_func: str = "softmax",
topk_method: str = "",
routed_scaling_factor: float = 1.0,
tp_group: Any = None,
use_all2all: bool = False,
num_zero_experts: int = 0,
):
super().__init__(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
up_proj_name=up_proj_name,
is_gated=is_gated,
down_proj_name=down_proj_name,
has_bias=has_bias,
skip_bias_add=skip_bias_add,
renormalize=renormalize,
hidden_act=hidden_act,
params_dtype=params_dtype,
quant_config=quant_config,
is_use_fused_moe=is_use_fused_moe,
expert_group=expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
topk_method=topk_method,
routed_scaling_factor=routed_scaling_factor,
tp_group=tp_group,
use_all2all=use_all2all,
init_avg_moe=False,
)
self.num_zero_experts = num_zero_experts
self.total_experts_including_zero = self.num_total_experts + self.num_zero_experts
self.use_quant_all2all = use_all2all and quant_config is not None
self.zero_expert_size = divide(self.num_zero_experts, self.moe_ep_size)
self.start_zero_expert_id = (
self.num_total_experts + self.moe_ep_rank * ((self.num_zero_experts + self.moe_ep_size - 1) // self.moe_ep_size)
)
if VLLM_AVG_MOE_EN and not SparseMoeMlp.is_expert_avg:
n_tokens = SparseMoeMlp.max_batched_token * self.dp_size
expert_group = self.moe_ep_size
val = 1.0 / float(self.total_experts_including_zero)
SparseMoeMlp.reduce_weight = torch.full((n_tokens, top_k), val, device="mlu", dtype=torch.float32)
if VLLM_RANDOM_MOE_EN:
import numpy as np
# example deepseekv2: experts 160 topk 6
# avg list: 92, 8, 88, 45, 99, 9,... 118, 142, 116, 57, 104, 6,......
array = np.stack([np.random.permutation(self.total_experts_including_zero)[:top_k] for _ in range(n_tokens)])
table = torch.from_numpy(array.flatten()).to(device="mlu", dtype=torch.int32)
else:
# example deepseekv2: experts 160
# avg list: 0,20,40,60,80...120,140, 1,21,...121,141, 2...142, ...... 19,...159, 0,20,......
import math
batch_table = math.ceil(n_tokens * top_k / self.total_experts_including_zero) * self.total_experts_including_zero
hi_val = batch_table // self.total_experts_including_zero
table = (torch.arange(hi_val * num_experts, device="mlu", dtype=torch.int32) % num_experts).view(
hi_val, expert_group, num_experts // expert_group).transpose(1, 2)
if self.num_zero_experts > 0:
# Longcat model, for avg expert, we choose eight non-zero experts and four zero
# experts for each token accorrding to the paper.
assert num_experts == 512 and num_zero_experts == 256 and top_k == 12
assert num_zero_experts % expert_group == 0
non_zero_expert_num_per_token = 8
zero_expert_num_per_token = 4
zero_expert_table = torch.arange(
num_experts, num_experts + num_zero_experts, dtype=table.dtype, device=table.device).view(
expert_group, num_zero_experts // expert_group).transpose(0, 1).flatten()
non_zero_expert_table = table[0].flatten()
token_expert_list = []
for idx in range(0, num_experts // non_zero_expert_num_per_token):
token_expert_list.append(non_zero_expert_table[
idx * non_zero_expert_num_per_token:
idx * non_zero_expert_num_per_token + non_zero_expert_num_per_token])
token_expert_list.append(zero_expert_table[
idx * zero_expert_num_per_token:
idx * zero_expert_num_per_token + zero_expert_num_per_token])
avg_expert_table = torch.cat(token_expert_list)
table = avg_expert_table.repeat(hi_val)
SparseMoeMlp.expert_id = table.flatten()[:n_tokens * top_k].view(n_tokens, top_k)
SparseMoeMlp.is_expert_avg = True
def forward_experts_nofused_longcat(
self, hidden_states, total_num_experts, total_num_experts_per_rank,
topk_indices=None, topk_weights=None, residual_=None):
assert self.moe_ep_size == 1
assert not self.use_all2all
expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = mlu_ops.moe_gen_idx(
topk_indices.to(torch.int32), total_num_experts)
# no expert is routed, then expand_gather_idx, expand_scatter_idx has no item,
# expand_token_count and expand_cusum_token_count has item but the value is all zero
# so this rank should only return final_hidden_states with zero value
if cusum_token_count[-1] == 0:
final_hidden_states = torch.zeros_like(hidden_states,
dtype=hidden_states.dtype,
device=hidden_states.device)
return final_hidden_states
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_gather_idx, cusum_token_count,
start_expert_id=self.start_expert_id,
expert_size=self.end_expert_id - self.start_expert_id)
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_gather_idx, cusum_token_count,
start_expert_id=self.start_zero_expert_id,
expert_size=self.zero_expert_size)
expand_output_list = []
expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id +
1] - cusum_token_count[self.start_expert_id]
for expert_idx, num_tokens_per_expert in enumerate(expand_token_count[:self.num_total_experts]):
if num_tokens_per_expert > 0:
expert_hidden_states = expand_hidden_states[
expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]]
if expert_idx < self.num_total_experts:
expert_output = self.experts[expert_idx](expert_hidden_states)
else:
expert_output = expert_hidden_states
expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output
expand_output_list.append(expert_output)
expand_output = torch.cat(expand_output_list, dim=0)
num_normal_tokens = cusum_token_count[self.num_total_experts]
expand_hidden_states[:num_normal_tokens] = expand_output
# reduce normal experts
final_hidden_states = mlu_ops.moe_combine_result(
expand_hidden_states, topk_weights, scatter_idx,
residual_, cusum_token_count, start_expert_id=self.start_expert_id,
expert_size=self.end_expert_id - self.start_expert_id, bias=None)
# reduce zero experts
if self.moe_ep_size > 1 or self.moe_tp_rank == 0:
final_hidden_states = mlu_ops.moe_combine_result(
expand_hidden_states_zero, topk_weights, scatter_idx,
final_hidden_states, cusum_token_count, start_expert_id=self.start_zero_expert_id,
expert_size=self.zero_expert_size, bias=None,
output=final_hidden_states)
return final_hidden_states
# no compute-communication parallel, for prototyping only, not in actual use.
# subject to becoming stale
def forward_all2all_int8_longcat(
self, hidden_states, total_num_experts, total_num_experts_per_rank,
topk_indices=None, topk_weights=None, residual_=None):
ori_input_shape = hidden_states.shape
dtype = hidden_states.dtype
self.pack_params()
self.pack_params_after_loading()
w1=self.w13
w2=self.w2
bias2=self.b2
input_smooth=self.a13_scale_all_experts
act_smooth=self.a2_scale
w1_scale=self.w13_scale
w2_scale=self.w2_scale
act_mode=self.hidden_act
quant_input=None
max_m = hidden_states.shape[0]
reduce_weight = topk_weights
expert_id = topk_indices
expand_idx, combine_idx, token_count, cusum_token_count \
= mlu_ops.moe_gen_idx(expert_id, total_num_experts)
num_token_expand = hidden_states.shape[0] * self.top_k
dispatch_bytes = num_token_expand * self.dispatch_token_size
dispatch_send_token_tensor = (
self.dispatch_send_buffer[:dispatch_bytes]
.view(num_token_expand, self.dispatch_token_size)
)
quant_size = self.hidden_size
quant_input = dispatch_send_token_tensor[:, : quant_size]
input_scale = dispatch_send_token_tensor[:, quant_size :].view(torch.float32)
quant_input, input_scale = mlu_ops.moe_quantize(
hidden_states, input_smooth, None, token_count[:self.num_total_experts],
expand_idx, None,
output=quant_input,
output_scale=input_scale)
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.num_total_experts,
expert_size=self.num_zero_experts)
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
token_count[:self.num_total_experts], self.moe_ep_size)
cnclep_dispatch(self.dispatch_token_size,
num_token_expand,
dispatch_send_layout,
token_count[:self.num_total_experts],
self.dispatch_recv_layout,
self.dispatch_recv_token_num)
recv_token_num = self.dispatch_recv_token_num.view(
self.moe_ep_size, self.num_experts_per_rank)
pad_num = self.max_num_tokens_per_rank
(
gather_by_expert_index,
gather_by_rank_index,
tokens_per_local_expert,
token_sum
) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num)
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
dispatch_recv_token_tensor = (
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
.view(self.max_num_tokens_recv, self.dispatch_token_size))
mlu_ops.gather_split(dispatch_recv_token_tensor,
gather_by_expert_index,
token_sum,
self.quant_input_recv,
self.input_scale_recv)
max_m = self.max_num_tokens_per_expert
gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, w1,
tokens_per_local_expert,
None, None, None, None,
self.input_scale_recv.view(torch.float32).flatten(),
w1_scale, dtype, max_m)
# continue reusing self.quant_input_recv and self.input_scale_recv
quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2]
input_scale_fp32 = self.input_scale_recv.view(torch.float32).flatten()[:gemm_out.shape[0]]
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None,
tokens_per_local_expert,
output=quant_input,
output_scale=input_scale_fp32,
act_mode=act_mode,
is_gated=self.is_gated)
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, w2,
tokens_per_local_expert,
None, None, None, None, input_scale, w2_scale, dtype, max_m)
combine_send_token_tensor = self.combine_send_buffer.view(self.max_num_tokens_recv, -1).view(hidden_states.dtype)
mlu_ops.gather_split(gemm_out,
gather_by_rank_index,
token_sum,
combine_send_token_tensor,
None)
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(self.dispatch_recv_token_num, self.moe_ep_size)
combine_recv_layout = self.dispatch_recv_layout
# combine
combine_args = dict(
token_byte=self.hidden_size * 2,
token_num=num_token_expand,
send_src_layout=combine_send_layout,
send_dst_layout=combine_recv_layout,
send_token=None,
recv_token=None)
cnclep_combine(**combine_args)
numel_recv = num_token_expand * self.hidden_size
recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv]
.view(num_token_expand, self.hidden_size))
residual_ = None
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
residual_, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts, bias=bias2, output=hidden_states)
assert self.moe_ep_size > 1
# zero expert reduce
output = mlu_ops.moe_combine_result(
expand_hidden_states_zero, reduce_weight, combine_idx,
output, cusum_token_count, self.num_total_experts,
self.num_zero_experts, output=hidden_states)
return output.view(ori_input_shape)
# no compute-communication parallel, for prototyping only, not in actual use.
# subject to becoming stale
def forward_all2all_bf16_longcat(
self, hidden_states, total_num_experts, total_num_experts_per_rank,
topk_indices=None, topk_weights=None, residual_=None):
is_fp8_quant = isinstance(self.quant_config, Fp8Config)
ori_input_shape = hidden_states.shape
dtype = hidden_states.dtype
self.pack_params()
self.pack_params_after_loading()
w1=self.w13
w2=self.w2
bias1=self.b13
bias2=self.b2
gated=self.is_gated
act_mode=self.hidden_act
max_m = hidden_states.shape[0]
reduce_weight = topk_weights
expert_id = topk_indices
# gen_idx
expand_idx, combine_idx, token_count, cusum_token_count = \
mlu_ops.moe_gen_idx(expert_id, total_num_experts)
num_token_expand = hidden_states.shape[0] * self.top_k
dispatch_bytes = num_token_expand * self.dispatch_token_size
dispatch_send_token_tensor = (
self.dispatch_send_buffer[:dispatch_bytes]
.view(num_token_expand, self.dispatch_token_size)
.view(hidden_states.dtype)
)
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts)
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.num_total_experts,
expert_size=self.num_zero_experts)
dispatch_send_token_tensor.copy_(expand_hidden_states)
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
token_count[:self.num_total_experts], self.moe_ep_size)
cnclep_dispatch(self.dispatch_token_size,
num_token_expand,
dispatch_send_layout,
token_count[:self.num_total_experts],
self.dispatch_recv_layout,
self.dispatch_recv_token_num,
use_quant_dispatch=False,
)
recv_token_num = self.dispatch_recv_token_num.view(
self.moe_ep_size, self.num_experts_per_rank)
pad_num = self.max_num_tokens_per_rank
(
gather_by_expert_index,
gather_by_rank_index,
tokens_per_local_expert,
token_sum
) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num)
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
dispatch_recv_token_tensor = (
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
.view(self.max_num_tokens_recv, self.dispatch_token_size)
.view(hidden_states.dtype)
)
self.quant_input_recv = self.quant_input_recv.view(hidden_states.dtype)
mlu_ops.gather_split(dispatch_recv_token_tensor,
gather_by_expert_index,
token_sum,
self.quant_input_recv)
max_m = self.max_num_tokens_per_expert
gemm_out = mlu_ops.group_gemm(
self.quant_input_recv, w1, tokens_per_local_expert,
None, None, None, None, max_m)
act_out = mlu_ops.moe_active(
gemm_out, act_mode, gated)
gemm_out = mlu_ops.group_gemm(
act_out, w2, tokens_per_local_expert,
None, None, None, None, max_m)
combine_send_token_tensor = self.combine_send_buffer.view(
self.max_num_tokens_recv, -1).view(hidden_states.dtype)
mlu_ops.gather_split(gemm_out,
gather_by_rank_index,
token_sum,
combine_send_token_tensor,
None)
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(
self.dispatch_recv_token_num, self.moe_ep_size)
combine_recv_layout = self.dispatch_recv_layout
combine_args = dict(
token_byte=self.hidden_size * 2,
token_num=num_token_expand,
send_src_layout=combine_send_layout,
send_dst_layout=combine_recv_layout,
send_token=None,
recv_token=None,
use_quant_dispatch=False,
)
cnclep_combine(**combine_args)
numel_recv = num_token_expand * self.hidden_size
recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv]
.view(num_token_expand, self.hidden_size))
residual_ = None
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
residual_, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts, bias=bias2, output=hidden_states)
# zero expert reduce
output = mlu_ops.moe_combine_result(
expand_hidden_states_zero, reduce_weight, combine_idx,
output, cusum_token_count, self.num_total_experts,
self.num_zero_experts, output=hidden_states)
return output.view(ori_input_shape)
def forward_before_dispatch(self, hidden_states: torch.Tensor,
topk_indices: torch.Tensor):
# gate and softmax topk is called in router for longcat
# other models can do these operations here
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(
topk_indices, self.total_experts_including_zero)
num_token_expand = hidden_states.shape[0] * self.top_k
dispatch_bytes = num_token_expand * self.dispatch_token_size
dispatch_send_token_tensor = (
self.dispatch_send_buffer[:dispatch_bytes]
.view(num_token_expand, self.dispatch_token_size)
)
if self.use_quant_all2all:
hidden_states_stride = self.hidden_size
quant_input = dispatch_send_token_tensor[:, : hidden_states_stride]
input_scale = dispatch_send_token_tensor[:, hidden_states_stride :].view(torch.float32)
# expand input + quantize
quant_input, input_scale = mlu_ops.moe_quantize(
hidden_states, self.a13_scale_all_experts, None,
token_count[:self.num_total_experts],
expand_idx, None,
output=quant_input,
output_scale=input_scale)
# expand input of zero-expert
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.num_total_experts,
expert_size=self.num_zero_experts)
else:
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts)
dispatch_send_token_tensor = dispatch_send_token_tensor.view(
hidden_states.dtype)
dispatch_send_token_tensor.copy_(expand_hidden_states)
del expand_hidden_states
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.num_total_experts,
expert_size=self.num_zero_experts)
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
token_count[:self.num_total_experts], self.moe_ep_size)
return combine_idx, token_count, cusum_token_count, dispatch_send_layout, expand_hidden_states_zero
def forward_dispatch(self, token_num: int, dispatch_send_layout: torch.Tensor,
token_count: torch.Tensor):
num_token_expand = token_num * self.top_k
cnclep_dispatch(self.dispatch_token_size,
num_token_expand,
dispatch_send_layout,
token_count[:self.num_total_experts],
self.dispatch_recv_layout,
self.dispatch_recv_token_num,
use_quant_dispatch=self.use_quant_all2all)
def forward_before_combine(self, hidden_states_dtype: torch.dtype):
recv_token_num = self.dispatch_recv_token_num.view(
self.moe_ep_size, self.num_experts_per_rank)
(
gather_by_expert_index,
gather_by_rank_index,
tokens_per_local_expert,
token_sum,
cusum_token_count
) = mlu_ops.moe_all2all_gen_gather_index(
recv_token_num, self.max_num_tokens_per_rank,
return_cusum_token_count=True)
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
dispatch_recv_token_tensor = (
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
.view(self.max_num_tokens_recv, self.dispatch_token_size))
max_m = self.max_num_tokens_per_expert
if self.use_quant_all2all:
mlu_ops.gather_split(dispatch_recv_token_tensor,
gather_by_expert_index,
token_sum,
self.quant_input_recv,
self.input_scale_recv)
# OPT: input_scale_recv_flatten can reuse self.input_scale_recv
input_scale_recv_flatten = self.input_scale_recv.view(torch.float32).flatten()
gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, self.w13,
tokens_per_local_expert,
None, None, None, None,
input_scale_recv_flatten,
self.w13_scale, hidden_states_dtype, max_m)
quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2]
input_scale_fp32 = input_scale_recv_flatten[:gemm_out.shape[0]]
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, self.a2_scale, None,
tokens_per_local_expert,
output=quant_input,
output_scale=input_scale_fp32,
act_mode=self.hidden_act,
is_gated=self.is_gated)
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, self.w2, tokens_per_local_expert,
None, None, None, None, input_scale, self.w2_scale,
hidden_states_dtype, max_m)
else:
dispatch_recv_token_tensor = dispatch_recv_token_tensor.view(hidden_states_dtype)
self.input_recv = self.input_recv.view(hidden_states_dtype)
mlu_ops.gather_split(dispatch_recv_token_tensor,
gather_by_expert_index,
token_sum,
self.input_recv)
gemm_out = mlu_ops.group_gemm(
self.input_recv, self.w13, tokens_per_local_expert,
None, None, None, None, max_m)
act_out = self.input_recv[:, :gemm_out.shape[-1] // 2]
act_out = mlu_ops.moe_active(
gemm_out, self.hidden_act, self.is_gated, output=act_out,
bias=None, cusum_token_count=cusum_token_count,
start_expert_id=0, expert_size=self.num_experts_per_rank)
gemm_out = mlu_ops.group_gemm(
act_out, self.w2, tokens_per_local_expert,
None, None, None, None, max_m)
combine_send_token_tensor = self.combine_send_buffer.view(
self.max_num_tokens_recv, -1).view(hidden_states_dtype)
mlu_ops.gather_split(gemm_out,
gather_by_rank_index,
token_sum,
combine_send_token_tensor,
None)
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(
self.dispatch_recv_token_num, self.moe_ep_size)
return combine_send_layout
def forward_combine(self, token_num: int, combine_send_layout: torch.Tensor):
num_token_expand = token_num * self.top_k
# combine_recv_layout(self.dispatch_recv_layout) is calculated when cnclep_dispatch
# because dispatch and combine are inverse operation
cnclep_combine(token_byte=self.hidden_size * 2,
token_num=num_token_expand,
send_src_layout=combine_send_layout,
send_dst_layout=self.dispatch_recv_layout,
send_token=None,
recv_token=None,
use_quant_dispatch=self.use_quant_all2all)
def forward_after_combine(self, token_num: int,
reduce_weight: torch.Tensor,
combine_idx: torch.Tensor,
cusum_token_count: torch.Tensor,
expand_hidden_states_zero: torch.Tensor,
output_tensor_dtype: torch.dtype,
output_tensor: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None):
num_token_expand = token_num * self.top_k
numel_recv = num_token_expand * self.hidden_size
recv_token = (self.combine_recv_buffer.view(output_tensor_dtype)[:numel_recv]
.view(num_token_expand, self.hidden_size))
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
residual, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts, bias=self.b2, output=output_tensor)
output = mlu_ops.moe_combine_result(
expand_hidden_states_zero, reduce_weight, combine_idx,
output, cusum_token_count, self.num_total_experts,
self.num_zero_experts, output=output_tensor)
return output
# no compute-communication parallel, for prototyping only, not in actual use.
# subject to becoming stale
def forward_group_experts_longcat(
self, hidden_states, total_num_experts, total_num_experts_per_rank,
topk_indices=None, topk_weights=None, residual_=None,
expand_idx=None, combine_idx=None, token_count=None, cusum_token_count=None):
is_fp8_quant = isinstance(self.quant_config, Fp8Config)
ori_input_shape = hidden_states.shape
dtype = hidden_states.dtype
self.pack_params()
self.pack_params_after_loading()
w1=self.w13
w2=self.w2
bias1=self.b13
bias2=self.b2
input_smooth=self.a13_scale
act_smooth=self.a2_scale
w1_scale=self.w13_scale
w2_scale=self.w2_scale
gated=self.is_gated
act_mode=self.hidden_act
quant_input=None
start_expert_id=self.start_expert_id
expert_size = w1.size(0)
max_m = hidden_states.shape[0]
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
residual_ = residual_.view(-1, residual_.size(-1)) if residual_ is not None else None
# Check smooth quant parameters.
per_token_sq = False
if not is_fp8_quant:
check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
if all(x is not None for x in check_list):
per_token_sq = True
if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present "
"and absent at the same time.")
expert_id = topk_indices
reduce_weight = topk_weights
# gen_idx
if expert_id is not None:
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(expert_id, total_num_experts)
# check quant
if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token':
raise NotImplementedError
elif per_token_sq:
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=start_expert_id,
expert_size=expert_size)
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.start_zero_expert_id,
expert_size=self.zero_expert_size)
quant_input, input_scale = mlu_ops.moe_quantize(
expand_hidden_states, input_smooth, None,
token_count[start_expert_id:start_expert_id+expert_size])
else:
expand_hidden_states = mlu_ops.moe_expand_input(hidden_states, expand_idx,
cusum_token_count, start_expert_id, expert_size)
expand_hidden_states_zero = mlu_ops.moe_expand_input(hidden_states, expand_idx,
cusum_token_count, self.start_zero_expert_id, self.zero_expert_size)
if (is_fp8_quant and self.quant_config.activation_quant_method == 'per_token') or per_token_sq:
gemm_out = mlu_ops.smooth_quant_group_gemm(
quant_input, w1,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, input_scale, w1_scale, dtype, max_m)
else:
gemm_out = mlu_ops.group_gemm(expand_hidden_states, w1,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, max_m)
# add_bias_active
if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token':
raise NotImplementedError
elif per_token_sq:
quant_input = quant_input[:, :gemm_out.shape[-1] // 2]
input_scale = input_scale[:gemm_out.shape[0]]
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None,
token_count[start_expert_id:start_expert_id+expert_size],
output=quant_input,
output_scale=input_scale,
act_mode=act_mode,
is_gated=self.is_gated)
if ((is_fp8_quant and self.quant_config.activation_quant_method == 'per_token')
or per_token_sq):
# Remove the reference to gemm_out tensor.
# If that was the only reference, the tensors memory becomes eligible for deallocation
# So that we can reuse this memory for the new allocation of next gemm operation
# del gemm_out
gemm_out = mlu_ops.smooth_quant_group_gemm(
quant_input, w2,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, input_scale, w2_scale, dtype, max_m,
output=expand_hidden_states)
else:
act_out = mlu_ops.moe_active(
gemm_out, act_mode, gated, gemm_out[:,:gemm_out.shape[-1]//2],
bias1, cusum_token_count, start_expert_id, expert_size)
gemm_out = mlu_ops.group_gemm(
act_out, w2, token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, max_m,
output=expand_hidden_states)
output = mlu_ops.moe_combine_result(
gemm_out, reduce_weight, combine_idx,
residual_, cusum_token_count, start_expert_id,
expert_size, bias2)
if self.moe_ep_size > 1 or self.moe_tp_rank == 0:
output = mlu_ops.moe_combine_result(
expand_hidden_states_zero, reduce_weight, combine_idx,
output, cusum_token_count, self.start_zero_expert_id,
self.zero_expert_size, bias2,
output=output)
return output.view(ori_input_shape)

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.model_executor.layers.quantization import (
QUANTIZATION_METHODS, register_quantization_config
)
MLU_QUANTIZATION_METHODS= [
"smoothquant",
"weightonly",
"awq_mlu",
"gptq_mlu",
]
def register_fake_mlu_quantization_methods():
for quant_method in MLU_QUANTIZATION_METHODS:
if quant_method not in QUANTIZATION_METHODS:
QUANTIZATION_METHODS.append(quant_method)
def remove_fake_mlu_quantization_methods():
for quant_method in MLU_QUANTIZATION_METHODS:
if quant_method in QUANTIZATION_METHODS:
QUANTIZATION_METHODS.remove(quant_method)
def register_real_mlu_quantization_methods():
remove_fake_mlu_quantization_methods()
from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig
from vllm_mlu.model_executor.layers.quantization.awq_mlu import AWQMluConfig
from vllm_mlu.model_executor.layers.quantization.gptq_mlu import GPTQMluConfig
register_quantization_config("weightonly")(WeightOnlyConfig)
register_quantization_config("smoothquant")(SmoothQuantConfig)
register_quantization_config("awq_mlu")(AWQMluConfig)
register_quantization_config("gptq_mlu")(GPTQMluConfig)

View File

@@ -0,0 +1,412 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional, Tuple
import torch
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.scalar_type import ScalarType, scalar_types
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
# We only support gptq and awq over 300 serials and only support int4 and int8 precision
def query_mlu_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if has_zp:
# AWQ style, unsigned + zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
return [scalar_types.uint4b8, scalar_types.uint8b128]
def check_mlu_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
supported_types = query_mlu_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Mlu does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
return (False, f"Mlu does not support group_size = {group_size}. "
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True
# @register_quantization_config("awq_mlu")
class AWQMluConfig(QuantizationConfig):
"""Config class for AWQMlu.
Reference: https://arxiv.org/abs/2306.00978
"""
# num_bits -> type
TYPE_MAP = {
4: {
False: scalar_types.uint4b8,
True: scalar_types.uint4,
},
8: {
False: scalar_types.uint8b128,
True: scalar_types.uint8,
}
}
VERSION = ["gemm"]
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
lm_head_quantized: bool,
version: str = "gemm",
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
self.lm_head_quantized = lm_head_quantized
self.pack_factor = 32 // self.weight_bits
self.version = version
self.support_scale_zeros = False
if self.weight_bits not in [4, 8]:
raise ValueError(
"Currently, only 4/8-bit weight quantization is supported for "
f"AWQMlu, but got {self.weight_bits} bits.")
if self.version not in self.VERSION:
raise ValueError(
"Currently, only gemm, gemv version is supported for "
f"AWQMlu, but got verion:{self.version}.")
if self.version in ["gemm"]:
self.order_map = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]}
self.reverse_order_map = {4 : [0, 4, 1, 5, 2, 6, 3, 7], 8: [0, 2, 1, 3]}
else:
self.order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
self.reverse_order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
def __repr__(self) -> str:
return (f"AWQMluConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point}), "
f"lm_head_quantized={self.lm_head_quantized})")
@classmethod
def get_name(cls) -> str:
return "awq_mlu"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@staticmethod
def get_config_filenames() -> List[str]:
return ["quant_config.json", "quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQMluConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
version = cls.get_from_keys_or(config, ["version"],
default="gemm")
return cls(weight_bits, group_size, zero_point, lm_head_quantized, version)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQMluLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return AWQMluLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_awq_mlu_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "awq"
or user_quant == "awq_mlu")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "awq":
logger.info("Detected that the model can run with awq_mlu"
", however you specified quantization=awq explicitly,"
" so forcing awq. Use quantization=awq_mlu for"
" faster inference")
return None
@classmethod
def is_awq_mlu_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
has_zp = quant_config.get("zero_point", None)
version = quant_config.get("version", "gemm")
if quant_method != "awq":
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or has_zp is None):
return False
if num_bits not in cls.TYPE_MAP:
return False
if version not in cls.VERSION:
return False
return check_mlu_supported(quant_type=cls.TYPE_MAP[num_bits][has_zp],
group_size=group_size,
has_zp=has_zp)
class AWQMluLinearMethod(LinearMethodBase):
"""Linear method for AWQMlu.
Args:
quant_config: The AWQMlu quantization config.
"""
def __init__(self, quant_config: AWQMluConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
weight_loader = extra_weight_attrs.get("weight_loader")
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
qzeros = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
scales = GroupQuantScaleParameter(data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader)
layer.register_parameter("qweight", qweight)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
packed_qweight, scale_zeros = self.extract_autoawq(layer)
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
layer.qzeros = None
layer.scales = None
else:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
if scale_zeros is not None:
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
else:
layer.qzeros = None
layer.scales = torch.nn.Parameter(layer.scales.data.transpose(0, 1).contiguous(), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.quant_config.zero_point and not self.quant_config.support_scale_zeros:
output = mlu_ops.matmul(x, layer.qweight, bias)
if residual is not None:
output = output + residual
else:
output = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
layer.qzeros,
bias,
residual,
"none",
self.quant_config.weight_bits)
return output
def extract_autoawq(self, layer: torch.nn.Module):
qweight = layer.qweight.data
qzeros = layer.qzeros.data
scales = layer.scales.data
bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
# Unpack the qweight and qzeros tensors
iweight, izeros = self.unpack_awq_int32_into_int8(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = self.reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
if izeros is not None:
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
scales = scales.repeat_interleave(group_size, dim=0)
if izeros is not None:
izeros = izeros.repeat_interleave(group_size, dim=0)
fweight = (iweight - izeros) * scales
else:
fweight = iweight * scales
# transpose [ci, co] -> [co, ci]
fweight = fweight.transpose(0, 1)
return fweight, None
if self.quant_config.zero_point and self.quant_config.support_scale_zeros and izeros is not None:
scale_zeros = izeros.to(scales.dtype) * -1 * scales
# transpose [ci, co] -> [co, ci]
scale_zeros = scale_zeros.transpose(0, 1)
else:
scale_zeros = None
# transpose [ci, co] -> [co, ci]
iweight = iweight.to(torch.int8).transpose(0, 1)
if bits == 4:
higher_bit_tensor = iweight[:, 1::2]
lower_bit_tensor = iweight[:, 0::2]
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
else:
packed_qweight = iweight
return packed_qweight, scale_zeros
def unpack_awq_int32_into_int8(self, qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qweight.device)
dtype = torch.int16 if bits == 8 else torch.int8
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(dtype)
iweights = iweights.view(iweights.shape[0], -1)
if not self.quant_config.zero_point or self.quant_config.support_scale_zeros:
iweights = torch.bitwise_and(iweights - 2**(bits - 1), (2 ** bits) - 1)
# unpacking columnwise
if qzeros is not None:
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(dtype)
izeros = izeros.view(izeros.shape[0], -1)
if not self.quant_config.zero_point:
izeros = torch.bitwise_and(izeros - 2**(bits - 1), (2 ** bits) - 1)
else:
izeros = None
return iweights, izeros
def reverse_awq_order(self, iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(iweights.shape[-1], dtype=torch.int32, device=iweights.device)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, self.quant_config.reverse_order_map[bits]]
reverse_order_tensor = reverse_order_tensor.view(-1)
rweights = iweights[:, reverse_order_tensor]
if izeros is not None:
rzeros = izeros[:, reverse_order_tensor]
return rweights, rzeros
def combine_low_bits(self, tensor_a, tensor_b):
"""
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
Args:
tensor_a (torch.Tensor): First tensor of type int8.
tensor_b (torch.Tensor): Second tensor of type int8.
Returns:
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
"""
# 确保输入是 int8 类型
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
raise ValueError("Both tensors must be of int8 type.")
# 提取每个 tensor 的低4位
low_bits_a = torch.bitwise_and(tensor_a, 0x0F) # 保留 tensor_a 的低4位
low_bits_b = torch.bitwise_and(tensor_b, 0x0F) # 保留 tensor_b 的低4位
# 将 tensor_a 的低4位左移4位
shifted_low_bits_a = low_bits_a << 4
# 组合两个 tensor 的低4位
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
return combined

View File

@@ -0,0 +1,753 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import functools
from functools import partial
import importlib.util
from typing import Any, Callable, Optional, Union
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from typing import Any, Dict, List, Optional, Callable
from vllm import envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.quantization.fp8 import (
get_flashinfer_moe_backend,
ACTIVATION_SCHEMES,
Fp8Config,
Fp8LinearMethod,
Fp8MoeBackend,
Fp8MoEMethod,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
validate_fp8_block_shape
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, cutlass_block_fp8_supported, cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale,
maybe_create_device_identity, Fp8LinearOp)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter, ChannelQuantScaleParameter,
ModelWeightParameter, PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
import vllm_mlu._mlu_ops as mlu_ops
logger = init_logger(__name__)
def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if (
current_platform.is_cuda()
and (
current_platform.is_device_capability(100)
or current_platform.is_device_capability(90)
)
and envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
):
backend = get_flashinfer_moe_backend()
if backend == FlashinferMoeBackend.TENSORRT_LLM:
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
return Fp8MoeBackend.FLASHINFER_TRTLLM
else:
if block_quant and current_platform.is_device_capability(100):
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency "
"instead."
)
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
return Fp8MoeBackend.FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: disable marlin for MLU backend.
'''
if current_platform.is_rocm() or current_platform.is_out_of_tree():
use_marlin = False
'''
==================
End of MLU Hijack
==================
'''
if use_marlin:
logger.info_once("Using Marlin backend for FP8 MoE")
return Fp8MoeBackend.MARLIN
# deepGEMM on supported platforms with block-quantized weights
if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
if not has_deep_gemm():
logger.warning_once("DeepGEMM backend requested but not available.")
elif is_deep_gemm_supported():
logger.info_once("Using DeepGEMM backend for FP8 MoE")
return Fp8MoeBackend.DEEPGEMM
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
if (
current_platform.is_cuda()
and current_platform.is_device_capability(100)
and block_quant
):
logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
# default to Triton
logger.info_once("Using Triton backend for FP8 MoE")
return Fp8MoeBackend.TRITON
Fp8Config____init____org = Fp8Config.__init__
def vllm__model_executor__layers__quantization__fp8__Fp8Config____init__(
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: list[str] | None = None,
weight_block_size: list[int] | None = None,
activation_quant_method: Optional[str] = None,
weight_quant_method: Optional[str] = None,
) -> None:
super(Fp8Config, self).__init__()
Fp8Config____init____org(
self,
is_checkpoint_fp8_serialized,
activation_scheme,
ignored_layers,
weight_block_size
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add class members activation_quant_method and weight_quant_method to
indicate the granularity of quantization.
'''
self.activation_quant_method = activation_quant_method
self.weight_quant_method = weight_quant_method
assert (self.weight_block_size or \
self.activation_quant_method == "per_token" and self.weight_quant_method == "per_channel"
and self.activation_scheme == "dynamic"), "Only support block-wise quantization, or "\
"input dynamic per-token weight per-channel quantization yet."
'''
==================
End of MLU Hijack
==================
'''
@classmethod
def vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config(
cls, config: Dict[str, Any]
) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = "fp8" in quant_method
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
if not ignored_layers:
ignored_layers = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add config members activation_quant_method and weight_quant_method to
indicate the granularity of quantization.
'''
activation_quant_method = cls.get_from_keys_or(config,
["activation_quant_method"],
'per_token')
weight_quant_method = cls.get_from_keys_or(config,
["weight_quant_method"],
None)
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=weight_block_size,
activation_quant_method=activation_quant_method,
weight_quant_method=weight_quant_method)
'''
==================
End of MLU Hijack
==================
'''
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group.
'''
tp_group = extra_weight_attrs.get("tp_group", None)
'''
==================
End of MLU Hijack
==================
'''
if self.block_quant:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
validate_fp8_block_shape(
layer,
input_size,
output_size,
input_size_per_partition,
output_partition_sizes,
self.weight_block_size,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group.
'''
# WEIGHT
if self.quant_config.is_checkpoint_fp8_serialized:
weight = create_fp8_weight_parameter(
output_size_per_partition, input_size_per_partition, weight_loader
)
else:
# For non-serialized checkpoints, use original dtype
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
tp_group=tp_group,
)
'''
==================
End of MLU Hijack
==================
'''
layer.register_parameter("weight", weight)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
if not self.block_quant:
'''
=============================
Modify by vllm_mlu
=============================
@brief: Support weight per channel quantization.
@brief: Add tp_group to enable custom split.
'''
if self.weight_per_channel:
scale = ChannelQuantScaleParameter(
data=torch.empty(sum(output_partition_sizes), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
tp_group=tp_group,
)
else:
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
layer.register_parameter("weight_scale", scale)
'''
==================
End of MLU Hijack
==================
'''
else:
assert not self.act_q_static
assert self.weight_block_size is not None
scale = create_fp8_scale_parameter(
BlockQuantScaleParameter,
output_partition_sizes,
input_size_per_partition,
self.weight_block_size,
weight_loader,
)
set_weight_attrs(scale, {"scale_type": "weight_scale"})
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
# INPUT ACTIVATION SCALE
if self.act_q_static:
scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__(
self,
quant_config: Fp8Config
):
self.quant_config = quant_config
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
if vllm_is_batch_invariant():
self.use_marlin = False
# AITER is only supported on ROCm and only for FP8_FNUZ
# and at the moment are MI300 series
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
self.use_deep_gemm = is_deep_gemm_supported()
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None
if self.block_quant:
# Marlin doesn't support block-wise fp8
self.use_marlin = False
self.act_q_static = self.quant_config.activation_scheme == "static"
if self.weight_block_size:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add config members activation_quant_method and weight_quant_method to
indicate the granularity of quantization.
'''
self.weight_per_channel = (self.quant_config.weight_quant_method == 'per_channel')
self.activation_per_token = (self.quant_config.activation_quant_method == 'per_token')
if self.weight_per_channel and self.activation_per_token:
self.use_marlin = False
'''
==================
End of MLU Hijack
==================
'''
if self.block_quant:
assert not self.act_q_static
assert self.weight_block_size is not None
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape,
)
Fp8LinearMethod__process_weights_after_loading__org = Fp8LinearMethod.process_weights_after_loading
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading(
self,
layer: Module,
) -> None:
'''
=============================
Modify by vllm_mlu
=============================
@brief: For dynamic activation and channel-wise weight quantization,
additional processing is not needed.
'''
if (self.quant_config.is_checkpoint_fp8_serialized
and self.weight_per_channel
and self.quant_config.activation_scheme == "dynamic"):
return
'''
==================
End of MLU Hijack
==================
'''
Fp8LinearMethod__process_weights_after_loading__org(self=self, layer=layer)
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert residual is None, "Fp8Linear residual is not supported yet."
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported.
if vllm_is_batch_invariant():
if self.block_quant:
assert self.weight_block_size is not None
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
else:
# per-tensor/channel: dequant to BF16 and run GEMM
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)
if weight_scale.numel() == 1:
# Per-tensor: simple scalar multiplication
weight_bf16 = weight_fp8 * weight_scale
else:
# Multiple scales (fused modules like QKV)
# Try to infer correct broadcasting
# weight is [K, N], scale could be [num_logical_weights]
# Need to figure out how to broadcast - for now just try
# direct multiplication
if (
weight_scale.dim() == 1
and weight_scale.shape[0] == weight_fp8.shape[0]
):
# Per-row scaling
weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
else:
# Fallback
weight_bf16 = weight_fp8 * weight_scale
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
if self.use_marlin:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
if self.block_quant:
assert self.weight_block_size is not None
from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear)
return apply_w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Use activation per token quantization based on quantization config.
'''
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
weight_per_channel=self.weight_per_channel,
activation_per_token=self.activation_per_token)
'''
==================
End of MLU Hijack
==================
'''
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__(
self,
quant_config: Fp8Config,
layer: torch.nn.Module
):
super(Fp8MoEMethod, self).__init__(layer.moe_config)
self.layer = layer
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
self.fp8_backend = get_fp8_moe_backend(self.block_quant)
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
if self.block_quant:
assert self.weight_block_size == [128, 128], (
f"Only support weight_block_size == [128, 128], "
f"got {self.weight_block_size}"
)
self.flashinfer_moe_fn = partial(
flashinfer_cutlass_moe_fp8,
moe=self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
self.allow_cutlass_block_scaled_grouped_gemm = (
self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: In mlu, always set self.use_marlin as False.
'''
self.use_marlin = False
'''
==================
End of MLU Hijack
==================
'''
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Use moe_softmax_topk and moe_sigmoid_topk of mlu_ops to implement FusedMoE.select_experts
'''
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
if scoring_func == "softmax":
topk_weights, topk_ids = mlu_ops.moe_softmax_topk(
router_logits,
top_k,
renormalize,
num_expert_group,
topk_group,
route_scale=routed_scaling_factor,
)
elif scoring_func == "sigmoid":
topk_weights, topk_ids = mlu_ops.moe_sigmoid_topk(
router_logits,
top_k,
renormalize,
num_expert_group,
topk_group,
routed_scaling_factor,
e_score_correction_bias,
)
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
# gen_idx
ori_input_shape = x.shape
x = x.reshape(-1, x.size(-1))
router_logits = router_logits.reshape(-1, router_logits.size(-1))
expert_num = router_logits.size(-1)
tokens_num = x.size(0)
expert_size = layer.w13_weight.size(0)
expand_idx, combine_idx, token_count, cumsum_token_count = mlu_ops.moe_gen_idx(
topk_ids, expert_num
)
expand_hidden_states = mlu_ops.moe_expand_input(
x, expand_idx, cumsum_token_count, 0, expert_size
)
quant_input, input_scale = _fp8_quantize(
expand_hidden_states, A_scale=None, block_shape=self.quant_config.weight_block_size
)
gemm1_out = mlu_ops.smooth_quant_group_gemm(
quant_input,
layer.w13_weight,
token_count,
expand_idx=None,
c=None,
alpha=None,
beta=None,
a_scale=input_scale.T.contiguous(),
b_scale=layer.w13_weight_scale_inv,
dtype=x.dtype,
max_m=tokens_num,
)
act_out = mlu_ops.active(gemm1_out, activation, is_gated=True)
act_out_quantize, act_out_scale = _fp8_quantize(
act_out, A_scale=None, block_shape=self.quant_config.weight_block_size
)
gemm2_out = mlu_ops.smooth_quant_group_gemm(
act_out_quantize,
layer.w2_weight,
token_count,
expand_idx=None,
c=None,
alpha=None,
beta=None,
a_scale=act_out_scale.T.contiguous(),
b_scale=layer.w2_weight_scale_inv,
dtype=x.dtype,
max_m=tokens_num,
)
output = mlu_ops.moe_combine_result(
gemm2_out,
topk_weights,
combine_idx,
residual=None,
cusum_token_count=cumsum_token_count,
start_expert_id=0,
expert_size=expert_size,
bias=None,
)
return output.view(ori_input_shape)
"""
==================
End of MLU Hijack
==================
"""
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.apply,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply
)
MluHijackObject.apply_hijack(
Fp8Config,
Fp8Config.__init__,
vllm__model_executor__layers__quantization__fp8__Fp8Config____init__
)
MluHijackObject.apply_hijack(
Fp8Config,
Fp8Config.from_config,
vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config
)
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.create_weights,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights
)
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.__init__,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__
)
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.process_weights_after_loading,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading
)
MluHijackObject.apply_hijack(
Fp8MoEMethod,
Fp8MoEMethod.__init__,
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__
)
MluHijackObject.apply_hijack(
Fp8MoEMethod,
Fp8MoEMethod.apply,
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply
)

View File

@@ -0,0 +1,440 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from fractions import Fraction
from typing import Any, Dict, List, Optional, Tuple
import torch
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
# We only support gptq and awq over 300 serials and only support int4 and int8 precision
def query_mlu_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if has_zp:
# AWQ style, unsigned + zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
return [scalar_types.uint4b8, scalar_types.uint8b128]
def check_mlu_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
supported_types = query_mlu_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Mlu does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
return (False, f"Mlu does not support group_size = {group_size}. "
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True
# @register_quantization_config("gptq_mlu")
class GPTQMluConfig(QuantizationConfig):
"""Config class for GPTQMlu.
Reference: https://arxiv.org/abs/2210.17323
"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
(4, False): scalar_types.uint4b8,
(8, False): scalar_types.uint8b128,
}
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.lm_head_quantized = lm_head_quantized
self.pack_factor = Fraction(32, self.weight_bits)
self.support_scale_zeros = False
self.use_native = self.desc_act or (not self.is_sym and not self.support_scale_zeros)
if self.weight_bits not in [4, 8]:
raise ValueError(
"Currently, only 4/8-bit weight quantization is "
f"supported for GPTQMlu, but got {self.weight_bits} bits.")
def __repr__(self) -> str:
return (f"GPTQMluConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}),"
f"lm_head_quantized={self.lm_head_quantized}")
@classmethod
def get_name(cls) -> str:
return "gptq_mlu"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quant_config.json", "quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMluConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQMluLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQMluLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@classmethod
def is_gptq_mlu_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
if quant_method != "gptq":
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
if (num_bits, sym) not in cls.TYPE_MAP:
return False
return check_mlu_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size, has_zp=False)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_gptq_mlu_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "gptq_mlu")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
return None
class GPTQMluLinearMethod(LinearMethodBase):
"""Linear method for GPTQMlu.
Args:
quant_config: The GPTQMlu quantization config.
"""
def __init__(self, quant_config: GPTQMluConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs.get("weight_loader")
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if (input_size != input_size_per_partition) and (self.quant_config.group_size !=
-1) and (not self.quant_config.desc_act):
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
g_idx = RowvLLMParameter(data=torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
qzeros_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scale_and_zero_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.device = layer.qweight.data.device
packed_qweight, scale_zeros = self.extract_autogptq(layer)
if self.quant_config.use_native:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
layer.qzeros = None
layer.scales = None
else:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
if scale_zeros is not None:
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
else:
layer.qzeros = None
layer.scales = torch.nn.Parameter(layer.scales.transpose(0, 1).contiguous(), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.quant_config.use_native:
output = mlu_ops.matmul(x, layer.qweight, bias)
if residual is not None:
output = output + residual
else:
output = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
layer.qzeros,
bias,
residual,
"none",
self.quant_config.weight_bits)
return output
def extract_autogptq(self, layer: torch.nn.Module):
scales = layer.scales.data
bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
# Unpack the qweight and qzeros tensors
iweight = self.unpack_gptq_qweight_int32_into_int8(layer.qweight.data, bits)
izeros = self.unpack_gptq_qzeros_int32_into_int8(layer.qzeros.data, bits)
if self.quant_config.use_native:
if self.quant_config.desc_act:
scales = torch.index_select(scales, 0, layer.g_idx)
if izeros is not None:
izeros = torch.index_select(izeros, 0, layer.g_idx)
else:
scales = scales.repeat_interleave(group_size, dim=0)
if izeros is not None:
izeros = izeros.repeat_interleave(group_size, dim=0)
if izeros is not None:
fweight = (iweight - izeros) * scales
else:
fweight = iweight * scales
# transpose [ci, co] -> [co, ci]
fweight = fweight.transpose(0, 1)
return fweight, None
if not self.quant_config.is_sym and self.quant_config.support_scale_zeros and izeros is not None:
scale_zeros = izeros.to(scales.dtype) * -1 * scales
# transpose [ci, co] -> [co, ci]
scale_zeros = scale_zeros.transpose(0, 1)
else:
# for is_sym is true now, so make iweight to sign value and ignore qzeros
iweight = torch.bitwise_and(iweight - 2**(bits - 1), (2 ** bits) - 1)
scale_zeros = None
# transpose [ci, co] -> [co, ci]
iweight = iweight.to(torch.int8).transpose(0, 1)
if bits == 4:
higher_bit_tensor = iweight[:, 1::2]
lower_bit_tensor = iweight[:, 0::2]
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
else:
packed_qweight = iweight
return packed_qweight, scale_zeros
def unpack_gptq_qweight_int32_into_int8(self, qweight: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qweight.device).unsqueeze(0)
dtype = torch.int16 if bits == 8 else torch.int8
weight = torch.bitwise_right_shift(
torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
shifts.unsqueeze(-1),
).to(dtype)
weight = torch.bitwise_and(weight, (2**bits) - 1)
weight = weight.reshape(-1, weight.shape[-1])
return weight
def unpack_gptq_qzeros_int32_into_int8(self, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device).unsqueeze(0)
dtype = torch.int16 if bits == 8 else torch.int8
zeros = torch.bitwise_right_shift(
torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
shifts.unsqueeze(0),
).to(dtype)
zeros = zeros + 1
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
zeros = zeros.reshape(qzeros.shape[0], -1)
return zeros
def combine_low_bits(self, tensor_a, tensor_b):
"""
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
Args:
tensor_a (torch.Tensor): First tensor of type int8.
tensor_b (torch.Tensor): Second tensor of type int8.
Returns:
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
"""
# 确保输入是 int8 类型
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
raise ValueError("Both tensors must be of int8 type.")
# 提取每个 tensor 的低4位
low_bits_a = torch.bitwise_and(tensor_a, 0x0F) # 保留 tensor_a 的低4位
low_bits_b = torch.bitwise_and(tensor_b, 0x0F) # 保留 tensor_b 的低4位
# 将 tensor_a 的低4位左移4位
shifted_low_bits_a = low_bits_a << 4
# 组合两个 tensor 的低4位
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
return combined

View File

@@ -0,0 +1,337 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
ModelWeightParameter,
RowvLLMParameter)
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.quantization.utils.common_utils import (str_dtype_to_torch,
str_dtype_to_bits,
is_fp8_str_dtype)
# @register_quantization_config("smoothquant")
class SmoothQuantConfig(QuantizationConfig):
"""Config class for SmoothQuant.
"""
def __init__(
self,
quant_mode: str, # smoothquant
input_quant_method: str, # per token/per tensor
group_size: int,
weight_precision: str,
activation_precision: str,
only_expert_per_group: bool,
expert_weight_precision: str,
expert_activation_precision: str,
force_use_weightonly_except_expert: bool,
) -> None:
super().__init__()
self.quant_mode = quant_mode
self.input_quant_method = input_quant_method
self.group_size = group_size
self.weight_precision = weight_precision
self.activation_precision = activation_precision
self.only_expert_per_group = only_expert_per_group
self.expert_weight_precision = expert_weight_precision
self.expert_activation_precision = expert_activation_precision
self.force_use_weightonly_except_expert = force_use_weightonly_except_expert
if quant_mode == "SmoothQuant" and (self.input_quant_method != "per_token" and self.input_quant_method != "per_tensor"):
raise ValueError(
"Currently, only per_token or per_tensor input quantization is supported for "
f"SmoothQuant, but got {self.input_quant_method}.")
self.weight_bits = str_dtype_to_bits(self.weight_precision)
self.expert_weight_bits = str_dtype_to_bits(self.expert_weight_precision)
if self.weight_precision == 'int4':
self.weight_dtype = torch.int8
else:
self.weight_dtype = str_dtype_to_torch(self.weight_precision)
if self.expert_weight_precision == 'int4':
self.expert_weight_dtype = torch.int8
else:
self.expert_weight_dtype = str_dtype_to_torch(self.expert_weight_precision)
self.is_fp8 = is_fp8_str_dtype(self.weight_precision)
self.expert_is_fp8 = is_fp8_str_dtype(self.expert_weight_precision)
self.pack_factor = 8 // self.weight_bits
self.expert_pack_factor = 8 // self.expert_weight_bits
def __repr__(self) -> str:
return (f"SmoothQuantConfig(input_quant_method={self.input_quant_method}, "
f"quant_mode={self.quant_mode}, "
f"group_size={self.group_size}, "
f"weight_precision={self.weight_precision}, "
f"activation_precision={self.activation_precision}, "
f"only_expert_per_group={self.only_expert_per_group}, "
f"expert_weight_precision={self.expert_weight_precision}, "
f"expert_activation_precision={self.expert_activation_precision}, "
f"force_use_weightonly_except_expert={self.force_use_weightonly_except_expert})")
@classmethod
def get_name(self) -> str:
return "SmoothQuant"
@classmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@staticmethod
def get_config_filenames() -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
quant_mode = cls.get_from_keys(config, ["quant_mode"])
input_quant_method = cls.get_from_keys(config, ["input_quant_method"])
group_size = cls.get_from_keys_or(config, ["group_size"], 1)
weight_precision = cls.get_from_keys_or(config, ["weight_precision"], "int8")
activation_precision = cls.get_from_keys_or(config, ["activation_precision"], "int8")
only_expert_per_group = cls.get_from_keys_or(config, ["only_expert_per_group"], False)
expert_weight_precision = cls.get_from_keys_or(config, ["expert_weight_precision"], None)
expert_activation_precision = cls.get_from_keys_or(config, ["expert_activation_precision"], None)
force_use_weightonly_except_expert = cls.get_from_keys_or(config, ["force_use_weightonly_except_expert"], False)
if expert_weight_precision is None:
expert_weight_precision = weight_precision
if group_size > 1 and only_expert_per_group and weight_precision == 'int4':
weight_precision = 'int8'
if expert_activation_precision is None:
expert_activation_precision = activation_precision
return cls(quant_mode=quant_mode,
input_quant_method=input_quant_method,
group_size=group_size,
weight_precision=weight_precision,
activation_precision=activation_precision,
only_expert_per_group=only_expert_per_group,
expert_weight_precision=expert_weight_precision,
expert_activation_precision=expert_activation_precision,
force_use_weightonly_except_expert=force_use_weightonly_except_expert)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["SmoothQuantLinearMethod"]:
if isinstance(layer, LinearBase):
return SmoothQuantLinearMethod(self, prefix)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class SmoothQuantLinearMethod(LinearMethodBase):
"""Linear method for SmoothQuant.
Args:
quant_config: The SmoothQuant quantization config.
"""
def __init__(self, quant_config: SmoothQuantConfig, prefix: str):
self.quant_config = quant_config
# for per-tensor case, we can skip quant input for the first attn|ffn linear
# and fusion this step in layernorm to get better performance
self.skip_quant_input = False
self.compute_dtype = torch.get_default_dtype()
self.is_expert = 'expert' in prefix and "shared_expert" not in prefix
self.weight_dtype = quant_config.expert_weight_dtype if self.is_expert else quant_config.weight_dtype
self.pack_factor = quant_config.expert_pack_factor if self.is_expert else quant_config.pack_factor
self.is_fp8 = quant_config.expert_is_fp8 if self.is_expert else quant_config.is_fp8
if quant_config.only_expert_per_group and self.is_expert and quant_config.group_size > 1:
self.is_group_quant = True
elif quant_config.only_expert_per_group is False and quant_config.group_size > 1:
self.is_group_quant = True
else:
self.is_group_quant = False
self.has_smooth = self.quant_config.input_quant_method == "per_token" and (
self.quant_config.force_use_weightonly_except_expert is False or self.is_expert)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor != 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
weight_loader = extra_weight_attrs.get("weight_loader")
group_num = 1
if self.is_group_quant:
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
f"The input size {input_size_per_partition} is not aligned with the quantized "
f"weight shape. This can be caused by too large "
f"tensor parallel size. group_size: {self.quant_config.group_size}.")
group_num = (input_size + self.quant_config.group_size - 1) // self.quant_config.group_size
if input_size_per_partition != input_size:
group_num = (input_size_per_partition + self.quant_config.group_size - 1) // self.quant_config.group_size
qweight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // self.pack_factor,
device="mlu",
dtype=self.weight_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
if self.is_group_quant:
per_channel_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
group_num,
device="mlu",
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
else:
per_channel_scale = ChannelQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
device="mlu",
dtype=torch.float32,
),
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("qweight", qweight)
layer.register_parameter("per_channel_scale", per_channel_scale)
if self.has_smooth:
smooth = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
device="mlu",
dtype=torch.float32,
),
input_dim=0,
weight_loader=weight_loader,
)
set_weight_attrs(smooth, {
"ignore_warning": True,
})
layer.register_parameter("smooth", smooth)
if self.quant_config.input_quant_method == "per_tensor":
scale_to_int = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
device="mlu",
dtype=torch.float32,
),
input_dim=0,
weight_loader=weight_loader,
)
set_weight_attrs(scale_to_int, {
"ignore_warning": True,
})
layer.register_parameter("scale_to_int", scale_to_int)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.has_smooth and layer.smooth.dtype != torch.float:
layer.smooth = layer.smooth.to(torch.float)
if self.quant_config.input_quant_method == "per_tensor" and layer.scale_to_int.dtype != torch.float:
layer.scale_to_int = layer.scale_to_int.to(torch.float)
if layer.per_channel_scale.dtype != torch.float:
layer.per_channel_scale = layer.per_channel_scale.to(torch.float)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.per_channel_scale = Parameter(layer.per_channel_scale.data, requires_grad=False)
if self.has_smooth:
layer.smooth = Parameter(layer.smooth.data, requires_grad=False)
if self.quant_config.input_quant_method == "per_tensor":
layer.scale_to_int = Parameter(layer.scale_to_int.data, requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
use_tp_weight : bool = False,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
layer_smooth = layer.smooth if self.has_smooth else None
layer_qweight = layer.qweight
layer_per_channel_scale = layer.per_channel_scale
if use_tp_weight:
if hasattr(layer, 'tp_smooth'):
layer_smooth = layer.tp_smooth
if hasattr(layer, 'tp_qweight'):
layer_qweight = layer.tp_qweight
if hasattr(layer, 'tp_per_channel_scale'):
layer_per_channel_scale = layer.tp_per_channel_scale
quant_input = None
if self.skip_quant_input:
quant_input = x
elif self.quant_config.input_quant_method == "per_token":
if self.is_fp8:
quant_input, input_scale = mlu_ops.scaled_quantize(x,
layer_smooth,
quant_type=self.weight_dtype,
quant_mode='dynamic_per_token')
else:
quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer_smooth, None)
elif self.quant_config.input_quant_method == "per_tensor":
quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)
else:
raise ValueError(
"Currently, only per_token or per_tensor input quantization is supported for "
f"SmoothQuant, but got {self.input_quant_method}.")
quant_input_shape = quant_input.shape
if len(quant_input_shape) > 2:
quant_input = quant_input.view(-1, quant_input_shape[-1])
input_scale = input_scale.view(-1)
if residual is not None and len(residual.shape) > 2:
residual = residual.view(-1, residual.shape[-1])
if self.is_fp8:
out = mlu_ops.scaled_matmul(quant_input, layer_qweight, input_scale,
layer_per_channel_scale,
self.compute_dtype if hasattr(self, 'compute_dtype') else x.dtype,
bias,
c=residual, act_mode="none",quant_bit_size=8,
alpha=1.0, beta=1.0, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None)
if output is not None:
out = out.view(output.shape)
output.copy_(out)
out = output
else:
if output is not None:
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
layer_per_channel_scale, self.compute_dtype, bias, residual, output=output)
else:
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
layer_per_channel_scale, self.compute_dtype, bias, residual)
if len(quant_input_shape) > 2:
out = out.view(*quant_input_shape[:-1], out.shape[-1])
return out

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
QUANTIZATION_CHOICES = ['int8', 'int4', 'e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
INTERGER_DTYPES = [torch.uint8, torch.uint16, torch.uint32, torch.uint64, torch.int8, torch.int16, torch.short,
torch.int32, torch.int, torch.int64, torch.long]
FLOAT_DTYPES = [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.bfloat16,
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, torch.half]
FP8_DTYPE = [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
FP8_STR_DTYPE = ['e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
GEMM_GROUP_SIZE = [64, 128, 256, 512]
_STR_TO_TORCH_DTYPE_DICT = dict(
bfloat16=torch.bfloat16,
float16=torch.float16,
float32=torch.float32,
int64=torch.int64,
int32=torch.int32,
int8=torch.int8,
bool=torch.bool,
e4m3fn=torch.float8_e4m3fn,
e4m3fnuz=torch.float8_e4m3fnuz,
e5m2=torch.float8_e5m2,
e5m2fnuz=torch.float8_e5m2fnuz,
)
TORCH_DTYPE_TO_STR_DICT = {
torch.bfloat16: "bfloat16",
torch.float16: "float16",
torch.float32: "float32",
torch.int64: "int64",
torch.int32: "int32",
torch.int8: "int8",
torch.bool: "bool",
torch.float8_e4m3fn: "e4m3fn",
torch.float8_e4m3fnuz: "e4m3fnuz",
torch.float8_e5m2: "e5m2",
torch.float8_e5m2fnuz: "e5m2fnuz",
}
STR_DTYPE_TO_BITS_DICT = {
"bfloat16": 16,
"float16": 16,
"float32": 32,
"int64": 64,
"int32": 32,
"int8": 8,
'int4': 4,
"bool": 1,
"e4m3fn": 8,
"e4m3fnuz": 8,
"e5m2": 8,
"e5m2fnuz": 8,
}
def str_dtype_to_torch(str_dtype: str):
'''
convert torch dytpe to str dtype
'''
ret = _STR_TO_TORCH_DTYPE_DICT.get(str_dtype)
dtype = ret if ret is not None else torch.float16
return dtype
def torch_dtype_to_str(dtype: torch.dtype):
'''
convert torch dytpe to str dtype
'''
ret = TORCH_DTYPE_TO_STR_DICT.get(dtype)
str_dtype = ret if ret is not None else "float16"
return str_dtype
def str_dtype_to_bits(str_dtype):
'''
convert torch dtype to bits size
'''
ret = STR_DTYPE_TO_BITS_DICT.get(str_dtype)
bits = ret if ret is not None else 8
return bits
def is_integer_dtype(dtype: torch.dtype):
'''
check whether is integer or not
'''
return dtype in INTERGER_DTYPES
def is_float_dtype(dtype: torch.dtype):
'''
check whether is float or not
'''
return dtype in FLOAT_DTYPES
def is_fp8_dtype(dtype: torch.dtype):
'''
judge fp8 torch dtype
'''
return dtype in FP8_DTYPE
def is_fp8_str_dtype(str_dtype: str):
'''
judge fp8 str dtype
'''
return str_dtype in FP8_STR_DTYPE

View File

@@ -0,0 +1,424 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/pull/2575
import functools
import json
import os
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_per_token_group_quant_fp8_colmajor)
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
'''
=============================
Modify by vllm_mlu
=============================
@brief: get total core for split triton kernel
'''
import triton.backends.mlu.driver as driver
_devprob = driver.BangUtils().get_device_properties(torch.mlu.current_device())
TOTAL_CLUSTER_NUM = _devprob.get("cluster_num")
TOTAL_CORE_NUM = TOTAL_CLUSTER_NUM * _devprob.get("core_num_per_cluster")
'''
==================
End of MLU Hijack
==================
'''
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
and weight.shape[1] % 128 == 0)
if current_platform.is_rocm():
# TODO this is never used, as cutlass_block_fp8_supported is False
scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
input_2d.shape[:-1])[::-1]
scale_b_shape = (weight_scale.view(-1, 1)
if weight_scale.dim() <= 1 else weight_scale.T).shape
ar, ac = scale_a_shape
br, bc = scale_b_shape
if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
or br not in (1, weight.shape[0])):
shape_supported_by_cutlass = False
if cutlass_block_fp8_supported and shape_supported_by_cutlass:
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=True)
output = ops.cutlass_scaled_mm(q_input,
weight.T,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale.T)
else:
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=False)
output = w8a8_block_fp8_matmul(q_input,
weight,
x_scale,
weight_scale,
block_size,
output_dtype=input.dtype)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
'''
=============================
Modify by vllm_mlu
=============================
@brief: split for limit the memory usage(65536)
'''
group_per_block = 1
while M >= 65536:
group_per_block *= 2
M = x.numel() // (group_size * group_per_block)
'''
==================
End of MLU Hijack
==================
'''
if column_major_scales:
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device,
dtype=torch.float32).permute(-1, -2)
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
'''
=============================
Modify by vllm_mlu
=============================
@brief: set num_warps to 1 for triton-mlu
'''
num_warps = 1
num_stages = 1
'''
==================
End of MLU Hijack
==================
'''
if column_major_scales:
_per_token_group_quant_fp8_colmajor[(M, )](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
x_s.stride(1),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: replaced the 'scaled_quantize' kernel from the 'tmo' library with
'_per_token_group_quant_fp8' kernel
'''
# Check if x is contiguous, if not, create a new tensor for contiguous x
if not x.is_contiguous():
x = x.contiguous()
x_origin_shape = x.shape
x = x.reshape(*x.shape[:-1], -1, group_size)
x_q, x_s = mlu_ops.scaled_quantize(x,
None,
quant_type=dtype,
quant_mode='dynamic_per_token')
x_q = x_q.reshape(x_origin_shape)
'''
==================
End of MLU Hijack
==================
'''
return x_q, x_s
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
'''
=============================
Modify by vllm_mlu
=============================
@brief: split for limit the memory usage(65536)
'''
num_block_size_all = num_pid_m * num_pid_n
num_block_size_per = num_block_size_all // tl.num_programs(axis=0)
num_block_size_rem = num_block_size_all % tl.num_programs(axis=0)
core_deal_num_block_size = num_block_size_per + (pid < num_block_size_rem)
core_deal_num_block_start = num_block_size_per * pid + min(num_block_size_rem, pid)
for pid_i in range(0, core_deal_num_block_size):
pid_in_core_deal_block = core_deal_num_block_start + pid_i
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid_in_core_deal_block // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid_in_core_deal_block % group_size_m)
pid_n = (pid_in_core_deal_block % num_pid_in_group) // group_size_m
'''
==================
End of MLU Hijack
==================
'''
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
'''
=============================
Modify by vllm_mlu
=============================
@brief: replaced the 'scaled_matmul' kernel from the 'tmo' library with
'_w8a8_block_fp8_matmul' kernel
'''
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert B.ndim == 2 and Bs.ndim == 2
if (B.shape[0] % 128 == 0) and (B.shape[1] % 128 == 0):
C = mlu_ops.scaled_matmul(A, B, As, Bs, output_dtype, bias=None, c=None, act_mode="none",
quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None)
else:
# NOTE(wulingchao): scaled_matmul 底层算子只支持n和k是128的倍数
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype)
# Default config
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
# BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 1,
"num_stages": 1,
}
def grid(META):
return (TOTAL_CORE_NUM, )
_w8a8_block_fp8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
'''
==================
End of MLU Hijack
==================
'''
return C

View File

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional, Callable
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, USE_ROWWISE_TORCH_SCALED_MM, cutlass_w8a8_scaled_mm,
flashinfer_w8a8_scaled_mm, rocm_per_tensor_w8a8_scaled_mm,
torch_per_tensor_w8a8_scaled_mm, torch_per_token_w8a8_scaled_mm,
torch_channelwise_w8a8_scaled_mm)
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def mlu_w8a8_scaled_mm(
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
output_shape: list, **kwargs
) -> torch.Tensor:
output = mlu_ops.scaled_matmul(
qinput, # a
weight, # b
scale_a, # a_scale
scale_b, # b_scale
out_dtype, # output_dtype
bias, # bias
c=None, act_mode="none",quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None
)
return output.view(*output_shape)
def dispatch_w8a8_scaled_mm(
preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool,
weight_per_channel: bool, activation_per_token: bool
) -> Callable[..., torch.Tensor]:
if per_tensor_weights and per_tensor_activations:
if preferred_backend == "rocm":
return rocm_per_tensor_w8a8_scaled_mm
if preferred_backend == "flashinfer":
return flashinfer_w8a8_scaled_mm
if preferred_backend == "cutlass":
return cutlass_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
return cutlass_w8a8_scaled_mm
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
if (
not per_tensor_weights
and not per_tensor_activations
and USE_ROWWISE_TORCH_SCALED_MM
):
return torch_per_token_w8a8_scaled_mm
# Normally, torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
'''
=============================
Modify by vllm_mlu
=============================
@brief: dispatch to mlu_w8a8_scaled_mm
'''
if weight_per_channel and activation_per_token:
return mlu_w8a8_scaled_mm
'''
==================
End of MLU Hijack
==================
'''
return torch_channelwise_w8a8_scaled_mm
def vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype | None = None,
input_scale: torch.Tensor | None = None,
input_scale_ub: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
weight_per_channel: bool = True,
activation_per_token: bool = True,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
'''
=============================
Modify by vllm_mlu
=============================
@brief: add mlu_fp8_supported
'''
self.mlu_fp8_supported = False
if weight_per_channel and activation_per_token:
self.mlu_fp8_supported = True
'''
==================
End of MLU Hijack
==================
'''
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[1]]
if out_dtype is None:
out_dtype = input.dtype
if self.mlu_fp8_supported:
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add support for activation-per-token weight-per-channel quantization.
'''
qinput, x_scale = mlu_ops.scaled_quantize(
input_2d,# x
None, # scale
None, # zero
None, # scale_ub
quant_type=torch.float8_e4m3fn,
quant_mode='dynamic_per_token'
)
output_shape = [*input.shape[:-1], weight.shape[0]]
'''
==================
End of MLU Hijack
==================
'''
else:
# If input not quantized
# TODO(luka) remove this path if not used anymore
if input.dtype != current_platform.fp8_dtype():
qinput, x_scale = self.quant_fp8(
input_2d,
input_scale,
input_scale_ub,
)
else:
qinput, x_scale = input_2d, input_scale
# Must have dim() conditions
# In per-token quant scenario, when the number of token is 1,
# the scale will only have 1 elements.
# Without checking the dim(),
# we cannot distingushes between per-tensor and per-token quant.
# Example:
# When the number of token is 1, per-token scale is [[1]]
# When per-tensor scale is [1] or ().
per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
self.preferred_backend, per_tensor_weights, per_tensor_activations,
weight_per_channel, activation_per_token)
return w8a8_scaled_mm_func(
qinput=qinput,
weight=weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
output_shape=output_shape,
)
MluHijackObject.apply_hijack(
Fp8LinearOp,
Fp8LinearOp.apply,
vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply
)

View File

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm_mlu import _mlu_ops as mlu_ops
from vllm.logger import init_logger
logger = init_logger(__name__)
# @register_quantization_config("weightonly")
class WeightOnlyConfig(QuantizationConfig):
"""Config class for WeightOnly.
"""
def __init__(
self,
weight_bits: int,
quant_mode: str, # weight_only
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.quant_mode = quant_mode
if quant_mode == "WeightOnly" and (self.weight_bits != 8 and self.weight_bits != 4):
raise ValueError(
"Currently, only 8/4-bit weight quantization is supported for "
f"weight_only, but got {self.weight_bits} bits.")
self.pack_factor = 8 // self.weight_bits
def __repr__(self) -> str:
return (f"WeightOnlyConfig(weight_bits={self.weight_bits}, "
f"quant_mode={self.quant_mode})")
def get_name(self) -> str:
return "WeightOnly"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@staticmethod
def get_config_filenames() -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "WeightOnlyConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
try:
quant_mode = cls.get_from_keys(config, ["quant_mode"])
except Exception:
quant_mode = "WeightOnly"
return cls(weight_bits, quant_mode)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["WeightOnlyLinearMethod"]:
if isinstance(layer, LinearBase):
return WeightOnlyLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class WeightOnlyLinearMethod(LinearMethodBase):
"""Linear method for WeightOnly.
Args:
quant_config: The WeightOnly quantization config.
"""
def __init__(self, quant_config: WeightOnlyConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> Dict[str, Any]:
output_size_per_partition = sum(output_partition_sizes)
if self.quant_config.quant_mode == "WeightOnly":
scale_and_zero_input_dim = None
if output_size != output_size_per_partition:
scale_and_zero_input_dim = 0
qweight = Parameter(
torch.empty(
output_size_per_partition,
input_size_per_partition // self.quant_config.pack_factor,
device="mlu",
dtype=torch.int8,
),
requires_grad=False,
)
set_weight_attrs(qweight, {
"input_dim": 1,
"output_dim": 0,
})
scales = Parameter(
torch.empty(
output_size_per_partition,
device="mlu",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": scale_and_zero_input_dim,
"output_dim": 0,
})
layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if layer.scales.dtype != torch.float:
layer.scales = Parameter(layer.scales.to(torch.float), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
x_shape = x.shape
if len(x_shape) > 2:
x = x.view(-1, x_shape[-1])
out = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
None,
bias,
residual,
"none",
self.quant_config.weight_bits)
if len(x_shape) > 2:
out = out.view(*x_shape[:-1], out.shape[-1])
return out

View File

@@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import math
from typing import Any
import torch
from vllm.logger import init_logger
import vllm.model_executor.layers.rotary_embedding as rotary_embedding
from vllm.model_executor.layers.rotary_embedding import (
_ROPE_DICT,
RotaryEmbedding,
)
from vllm.model_executor.layers.rotary_embedding import (
_ROPE_DICT,
DualChunkRotaryEmbedding,
DynamicNTKAlphaRotaryEmbedding,
DynamicNTKScalingRotaryEmbedding,
Llama4VisionRotaryEmbedding,
MRotaryEmbedding,
NTKScalingRotaryEmbedding,
Phi3LongRoPEScaledRotaryEmbedding,
YaRNScalingRotaryEmbedding,
)
from .base import MLURotaryEmbedding
from .deepseek_scaling_rope import MLUDeepseekScalingRotaryEmbedding
from .dynamic_ntk_alpha_rope import MLUDynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import MLUDynamicNTKScalingRotaryEmbedding
from .linear_scaling_rope import MLULinearScalingRotaryEmbedding
from .llama3_rope import MLULlama3RotaryEmbedding
from .mrope import MLUMRotaryEmbedding
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def get_long_max_model_max_position_emb(max_position_embeddings, scaling_factor):
if MLURotaryEmbedding.max_seq_len != None and \
MLURotaryEmbedding.max_seq_len > max_position_embeddings * scaling_factor:
logger.warning(f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) is different with " +
f"max_position_embedding ({max_position_embeddings}) * scaling_factor ({scaling_factor}) " +
"from model's config.json, This may lead to incorrect model outputs or MLU errors. " +
f"Make sure the value is correct and within the model context size. " +
f"Set max_position_embedding={MLURotaryEmbedding.max_seq_len}.")
return math.ceil(MLURotaryEmbedding.max_seq_len / scaling_factor)
return max_position_embeddings
def vllm__model_executor__layers__rotary_embedding__get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: float,
is_neox_style: bool = True,
rope_scaling: dict[str, Any] | None = None,
dtype: torch.dtype | None = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: dict[str, Any] | None = None,
inverse: bool = False
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if dual_chunk_attention_config is not None:
dual_chunk_attention_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in dual_chunk_attention_config.items()
if k != "sparse_attention_config"
}
dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
else:
dual_chunk_attention_args = None
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling_args,
dual_chunk_attention_args,
dtype,
inverse,
)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if dual_chunk_attention_config is not None:
extra_kwargs = {
k: v
for k, v in dual_chunk_attention_config.items()
if k in ("chunk_size", "local_size")
}
rotary_emb = DualChunkRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
**extra_kwargs,
)
elif not rope_scaling:
rotary_emb = MLURotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype,
inverse=inverse,
)
else:
scaling_type = rope_scaling["rope_type"]
if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
rotary_emb = MLULlama3RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
scaling_factor,
low_freq_factor,
high_freq_factor,
original_max_position,
)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MLUMRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
)
else:
rotary_emb = MLURotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
inverse=inverse,
)
elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"]
rotary_emb = MLULinearScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "ntk":
scaling_factor = rope_scaling["factor"]
mixed_b = rope_scaling.get('mixed_b', None)
rotary_emb = NTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
mixed_b,
)
elif scaling_type == "dynamic":
if "alpha" in rope_scaling:
scaling_alpha = rope_scaling["alpha"]
rotary_emb = MLUDynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
)
elif "factor" in rope_scaling:
scaling_factor = rope_scaling["factor"]
rotary_emb = MLUDynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
else:
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"apply_yarn_scaling",
)
}
if "mrope_section" in rope_scaling:
extra_kwargs.pop("apply_yarn_scaling", None)
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
scaling_factor=scaling_factor,
**extra_kwargs,
)
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: update original_max_position
'''
original_max_position = get_long_max_model_max_position_emb(
original_max_position, scaling_factor,
)
'''
==================
End of MLU Hijack
==================
'''
rotary_emb = YaRNScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type == "deepseek_yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
)
}
'''
=============================
Modify by vllm_mlu
=============================
@brief: update original_max_position
'''
original_max_position = get_long_max_model_max_position_emb(
original_max_position, scaling_factor,
)
'''
==================
End of MLU Hijack
==================
'''
rotary_emb = MLUDeepseekScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
inverse,
**extra_kwargs,
)
elif scaling_type == "longrope":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size,
rotary_dim,
max_position,
original_max_position,
base,
is_neox_style,
dtype,
short_factor,
long_factor,
**extra_kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb
MluHijackObject.apply_hijack(
rotary_embedding,
rotary_embedding.get_rope,
vllm__model_executor__layers__rotary_embedding__get_rope,
)

View File

@@ -0,0 +1,302 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.utils import (
get_common_metadata,
MLUCommonAttentionMetadata,
)
from vllm_mlu.v1.attention.backends.mla.flashmla import MLACommonMetadata
from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context
logger = init_logger(__name__)
@CustomOp.register("rotary_embedding_mlu")
class MLURotaryEmbedding(RotaryEmbedding, CustomOp):
cu_seq_lens : torch.Tensor = None
max_seq_len : int = None
max_model_len : int = None
is_prompt : bool = False
is_chunked : bool = False
positions_: torch.Tensor = None
chunked_prefill_enabled: bool = False
prefill_cu_seq_lens: torch.Tensor = None
prefill_max_seq_len: int = None
decode_cu_seq_lens: torch.Tensor = None
decode_max_seq_len: int = None
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
inverse: bool = False,
) -> None:
CustomOp.__init__(self)
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
# TODO(mgoin): disabled for now due to failures
# Flashinfer only supports head_size=64, 128, 256, 512.
# https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
# self.use_flashinfer = (self.enabled()
# and dtype in (torch.float16, torch.bfloat16)
# and current_platform.is_cuda()
# and has_flashinfer()
# and self.head_size in [64, 128, 256, 512])
self.use_flashinfer = False
self.inverse = inverse
# For vlm v1
# 1. mlu rope run in eager mode
# 2. all layer use layer0's rope to inference
prefix = "global_rope"
vllm_config = get_current_vllm_config()
self.use_direct_call = False
if not self.use_direct_call:
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
pass
else:
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import DeepseekScalingRotaryEmbedding
from vllm.model_executor.layers.rotary_embedding.yarn_scaling_rope import YaRNScalingRotaryEmbedding
if MLURotaryEmbedding.max_seq_len != None \
and self.max_position_embeddings < MLURotaryEmbedding.max_seq_len and \
not isinstance(self, (YaRNScalingRotaryEmbedding, DeepseekScalingRotaryEmbedding)):
logger.warning(f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) is different with " +
f"max_position_embedding ({max_position_embeddings}) from model's config.json, " +
f"This may lead to incorrect model outputs or MLU errors. " +
f"Make sure the value is correct and within the model context size. " +
f"Set max_position_embedding={MLURotaryEmbedding.max_seq_len}.")
self.max_position_embeddings = MLURotaryEmbedding.max_seq_len
cache = self._compute_cos_sin_cache()
from vllm_mlu.model_executor.layers.rotary_embedding.linear_scaling_rope import MLULinearScalingRotaryEmbedding
if isinstance(self, MLULinearScalingRotaryEmbedding):
logger.debug(f"Using mlu defining _compute_cos_sin_cache due to the special tensor composition")
elif is_neox_style:
cache_pos = cache.shape[0]
cache = cache.reshape(cache_pos, 2, -1)
cache = torch.tile(cache, (1, 1, 2)).reshape(cache_pos, -1)
else:
cache = cache.repeat_interleave(2, dim=-1)
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.cos_, self.sin_ = self._get_cos_sin()
@classmethod
def set_mlu_var_v1(
cls,
common_metadata: MLUCommonAttentionMetadata
) -> None:
cls.unset_mlu_var()
cls.cu_seq_lens = common_metadata.query_start_loc
cls.max_seq_len = common_metadata.max_query_len
cls.is_prompt = common_metadata.is_prefill_only
cls.is_chunked = common_metadata.is_chunked
# for MLA
attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
_, attn_metadata = next(iter(attn_metadata.items()))
if isinstance(attn_metadata, MLACommonMetadata):
prefill_metadata = attn_metadata.prefill
decode_metadata = attn_metadata.decode
if prefill_metadata:
cls.prefill_max_seq_len = prefill_metadata.max_query_len
cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
else:
cls.prefill_max_seq_len = cls.max_seq_len
cls.prefill_cu_seq_lens = cls.cu_seq_lens
if decode_metadata:
cls.decode_max_seq_len = decode_metadata.max_query_len
cls.decode_cu_seq_lens = decode_metadata.query_start_loc
else:
cls.decode_max_seq_len = cls.max_seq_len
cls.decode_cu_seq_lens = cls.cu_seq_lens
# for sp
sp_context = get_sp_forward_context()
if sp_context is not None and sp_context.is_v32:
prefill_metadata = sp_context.sp_attn_metadata.prefill
cls.is_chunked = True
cls.prefill_max_seq_len = prefill_metadata.max_query_len
cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
@classmethod
def unset_mlu_var(cls):
cls.cu_seq_lens = None
cls.max_seq_len = None
cls.is_prompt = False
cls.is_chunked = False
cls.positions_ = None
cls.chunked_prefill_enabled = False
cls.prefill_cu_seq_lens = None
cls.prefill_max_seq_len = None
cls.decode_cu_seq_lens = None
cls.decode_max_seq_len = None
def _get_cos_sin(self) -> Tuple[torch.Tensor, torch.Tensor]:
cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
sin = sin.view(-1, self.rotary_dim)
cos = cos.view(-1, self.rotary_dim)
return cos, sin
def _get_positions_with_offsets_mlu(
self,
positions: torch.Tensor,
offsets: torch.Tensor
) -> torch.Tensor:
if offsets.numel() != positions.numel():
raise Exception("rope offsets numel mismatch with positions, "
f"positions: {positions.numel()}, offsets: {offsets.numel()}")
return (positions + offsets).to(torch.int32)
def forward_impl(
self,
positions: torch.Tensor,
x: torch.Tensor,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
common_metadata: MLUCommonAttentionMetadata = get_common_metadata()
if common_metadata is None:
num_tokens, head_num, head_size = x.shape
x = mlu_ops.rotary_embedding(
x.view(1, num_tokens, head_num, head_size),
self.sin_,
self.cos_,
positions,
None,
not self.is_neox_style,
True,
False,
num_tokens
)
return x
else:
cu_seq_lens_ = common_metadata.query_start_loc
if offsets is not None:
if MLURotaryEmbedding.positions_ is None:
MLURotaryEmbedding.positions_ = (
self._get_positions_with_offsets_mlu(positions, offsets))
position_ids = MLURotaryEmbedding.positions_
discrete = True
elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
position_ids = positions
discrete = True
else:
position_ids = None
discrete = False
x = mlu_ops.rotary_embedding(
x,
self.sin_,
self.cos_,
position_ids,
cu_seq_lens_,
not self.is_neox_style,
discrete,
False,
MLURotaryEmbedding.max_seq_len
)
return x
def get_param(self, positions, discrete=False):
interleaved = True
if self.is_neox_style:
interleaved = False
if discrete:
position_ids = positions
discrete = discrete
else:
if MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
position_ids = positions
discrete = True
else:
position_ids = None
discrete = False
return position_ids, interleaved, discrete
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.outer(t, inv_freq)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cos = freqs_cis.real
sin = freqs_cis.imag * (-1 if self.inverse else 1)
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor | None = None,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
only_prefill: bool | None = False,
only_decode: bool | None = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.forward_impl(positions, query, offsets)
if key is not None:
self.forward_impl(positions, key, offsets)
return query, key
def rope_forward(
positions: torch.Tensor,
x: torch.Tensor,
layer_name: str,
offsets: torch.Tensor | None = None,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_impl(positions, x, offsets)
def rope_forward_fake(
positions: torch.Tensor,
x: torch.Tensor,
layer_name: str,
offsets: torch.Tensor | None = None,
) -> None:
return
direct_register_custom_op(
op_name="rope_forward",
op_func=rope_forward,
mutates_args=["x"],
fake_impl=rope_forward_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
import torch
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
DeepseekScalingRotaryEmbedding,
yarn_get_mscale,
)
from vllm.model_executor.layers.rotary_embedding.common import (
rotate_gptj,
rotate_neox,
yarn_find_correction_range,
yarn_linear_ramp_mask,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLUDeepseekScalingRotaryEmbedding(MLURotaryEmbedding, DeepseekScalingRotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
inverse: bool = False,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
self.inverse = inverse
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def forward_mlu_rot(self, input, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len):
"""only one input rotary implementation"""
if input is None:
return None
if self.rotary_dim < self.head_size:
input_pass = input[..., self.rotary_dim:]
input_rot = input[..., :self.rotary_dim]
input_rot = mlu_ops.rotary_embedding(
input_rot,
self.sin_,
self.cos_,
position_ids,
cu_seq_lens,
interleaved,
discrete,
False,
max_seq_len
)
if self.rotary_dim < self.head_size:
input = torch.cat((input_rot, input_pass), dim=-1)
else:
input = input_rot
return input
def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor | None = None,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
only_prefill: bool | None = False,
only_decode: bool | None = False,
discrete: bool | None = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
position_ids, interleaved, discrete = self.get_param(positions, discrete)
cu_seq_lens = MLURotaryEmbedding.cu_seq_lens
max_seq_len = MLURotaryEmbedding.max_seq_len
# for MLA
attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
_, attn_metadata = next(iter(attn_metadata.items()))
if isinstance(attn_metadata, MLACommonMetadata):
if only_prefill:
cu_seq_lens = MLURotaryEmbedding.prefill_cu_seq_lens
max_seq_len = MLURotaryEmbedding.prefill_max_seq_len
elif only_decode:
cu_seq_lens = MLURotaryEmbedding.decode_cu_seq_lens
max_seq_len = MLURotaryEmbedding.decode_max_seq_len
query = self.forward_mlu_rot(query, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len)
key = self.forward_mlu_rot(key, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len)
return query, key
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base ** (
torch.arange(
0,
self.rotary_dim,
2,
dtype=torch.float,
device=current_platform.device_type,
)
/ self.rotary_dim
)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
self.rotary_dim,
self.base,
self.max_position_embeddings,
)
# Get n-d rotational scaling corrected for extrapolation
device = current_platform.device_type
inv_freq_mask = ((
1
- yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
) * self.extrapolation_factor).to(device)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(
self.max_position_embeddings * self.scaling_factor,
device=current_platform.device_type,
dtype=torch.float32,
)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale * (-1 if self.inverse else 1)
cache = torch.cat((cos, sin), dim=-1)
return cache
forward = MLURotaryEmbedding.forward
forward_native = forward_oot

View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.model_executor.layers.rotary_embedding.dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLUDynamicNTKAlphaRotaryEmbedding(MLURotaryEmbedding, DynamicNTKAlphaRotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_alpha: float,
dtype: torch.dtype,
) -> None:
self.scaling_alpha = scaling_alpha
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)

View File

@@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.model_executor.layers.rotary_embedding.dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLUDynamicNTKScalingRotaryEmbedding(MLURotaryEmbedding, DynamicNTKScalingRotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
) -> None:
self.scaling_factor = scaling_factor
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Union
import torch
from vllm.platforms import current_platform
from vllm.model_executor.layers.rotary_embedding.linear_scaling_rope import LinearScalingRotaryEmbedding
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLULinearScalingRotaryEmbedding(MLURotaryEmbedding, LinearScalingRotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factors: list[float] | float,
dtype: torch.dtype,
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors: list[float] = scaling_factors # noqa
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
# Lazy initialized.
self._scaling_factor_to_offset: dict[float, int]
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
device = current_platform.device_type
if self.is_neox_style:
half_dim = self.rotary_dim // 2
inv_freq = 1.0 / (
base
** (torch.arange(0, self.rotary_dim, 1, dtype=torch.float32, device=device)
% half_dim * 2 / self.rotary_dim)
)
else:
inv_freq = 1.0 / (
base
** (torch.arange(0, self.rotary_dim, 1, dtype=torch.float32, device=device)
// 2 * 2 / self.rotary_dim
)
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
cache_list: list[torch.Tensor] = []
# offsets to the next cache in a tensor.
# Each offset corresponds to the same index in scaling_factors.
offsets: list[int] = []
device = current_platform.device_type
for scaling_factor in self.scaling_factors:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * scaling_factor
t = torch.arange(max_len, dtype=torch.float, device=device)
t = t / scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
if not cache_list:
offset = 0
else:
last_offset = offsets[-1]
next_max_len = cache_list[-1].shape[0]
offset = last_offset + next_max_len
offsets.append(offset)
cache_list.append(cache)
self._scaling_factor_to_offset = {
float(scaling_factor): offsets[i]
for i, scaling_factor in enumerate(self.scaling_factors)
}
assert len(self.scaling_factors) == len(offsets)
return torch.cat(cache_list, dim=0)

View File

@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLULlama3RotaryEmbedding(MLURotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
scaling_factor: float,
low_freq_factor: float,
high_freq_factor: float,
orig_max_position: int,
) -> None:
self.scaling_factor = scaling_factor
self.low_freq_factor = low_freq_factor
self.high_freq_factor = high_freq_factor
self.orig_max_position = orig_max_position
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)

View File

@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.model_executor.layers.rotary_embedding.common import yarn_get_mscale
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLUMRotaryEmbedding(MLURotaryEmbedding, MRotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: list[int] | None = None,
mrope_interleaved: bool = False,
# YaRN parameters.
*,
scaling_factor: float | None = None,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
if self.scaling_factor is not None:
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
else:
self.mscale = 1.0
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get
# a larger the cos and sin cache.
self.cache_max_position_num = max_position_embeddings * 4
MLURotaryEmbedding.__init__(
self,
head_size,
rotary_dim,
self.cache_max_position_num,
base,
is_neox_style,
dtype,
)
self.mrope_section = mrope_section
self.mrope_interleaved = mrope_interleaved
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
def _apply_mrope(self, positions):
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
num_section = len(self.mrope_section)
mrope_section = self.mrope_section * 2
def _apply(x):
x = torch.cat([
m[i % num_section]
for i, m in enumerate(x.split(mrope_section, dim=-1))
],
dim=-1)
return x
return _apply(cos), _apply(sin)
def _apply_interleaved_mrope(self, positions):
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity.
"""
mrope_section = self.mrope_section
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
def _apply(x):
x_t = x[0].clone()
x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
offset = self.rotary_dim // 2
x_t[..., 1 + offset:mrope_section[1] * 3 + offset:3] = x[1, ..., 1 + offset:mrope_section[1] * 3 + offset:3]
x_t[..., 2 + offset:mrope_section[2] * 3 + offset:3] = x[2, ..., 2 + offset:mrope_section[2] * 3 + offset:3]
return x_t
return _apply(cos), _apply(sin)
def precompute_sin_cos_cache(
self,
positions: torch.Tensor
):
'''
call this function before forward decoder layers
precompute sin/cos cache for mrope
'''
if positions.ndim == 1:
return
assert positions.ndim == 2
assert self.mrope_section
if self.mrope_interleaved:
cos, sin = self._apply_interleaved_mrope(positions)
else:
cos, sin = self._apply_mrope(positions)
self.mrope_cos_cache = cos
self.mrope_sin_cache = sin
self.mrope_cu_seq_lens = torch.zeros(2, dtype=torch.int32, device=positions.device)
num_tokens = positions.shape[-1]
self.mrope_cu_seq_lens[1] = num_tokens
def forward_oot(
self,
positions: torch.Tensor,
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
assert positions.ndim == 1 or positions.ndim == 2
if positions.ndim == 1:
return MLURotaryEmbedding.forward_oot(self, positions, x)
assert self.mrope_cos_cache is not None and self.mrope_sin_cache is not None,\
"call precompute_sin_cos_cache first!"
num_tokens = positions.shape[-1]
x = mlu_ops.rotary_embedding(x,
self.mrope_sin_cache,
self.mrope_cos_cache,
None,
self.mrope_cu_seq_lens,
not self.is_neox_style,
False,
False,
num_tokens)
return x
forward = MLURotaryEmbedding.forward

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple
from tqdm import tqdm
from vllm.config import ModelConfig
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def initialize_dummy_weights_normal_dist(
model: torch.nn.Module,
low: float = -1e-3,
high: float = 1e-3,
std: float = 0.5,
seed: int = 1234,
) -> None:
"""
Initialize the weights of a PyTorch model with values drawn from a normal distribution.
Floating point parameters are initialized with a normal distribution whose mean is randomly
sampled from [low, high] and standard deviation is fixed at 0.5. Integer parameters are
initialized with random integers in [floor(low), ceil(high)). The initialization is performed
in a batched and efficient way for both floating point and integer parameters.
Optimized version: Uses shared pinned memory based on the largest parameter block size
to minimize H2D transfers, sacrificing global uniqueness for performance.
Args:
model (torch.nn.Module): The model whose weights will be initialized.
low (float): Lower bound for sampling the mean of the normal distribution (for float params).
high (float): Upper bound for sampling the mean of the normal distribution (for float params).
std (float): Standard deviation for the normal distribution (for float params).
seed (int): Random seed for reproducibility.
"""
# Randomly sample the mean for the normal distribution from [low, high]
rng = np.random.RandomState(seed)
mean = float(rng.uniform(low, high, 1).item())
# Create a CPU generator for reproducibility
cpu_gen = torch.Generator(device="cpu")
cpu_gen.manual_seed(seed)
# Collect parameters: separate into floating point and integer types
float_params: List[Tuple[str, torch.Tensor]] = []
int_params: List[Tuple[str, torch.Tensor]] = []
for name, t in tqdm(model.state_dict().items(), desc="Gen dummy weights: Collect params"):
if not isinstance(t, torch.Tensor):
continue
if torch.is_floating_point(t):
float_params.append((name, t))
elif t.dtype in (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64):
int_params.append((name, t))
# -------- Floating point parameters: optimized shared memory initialization --------
if float_params:
# Find the largest parameter block size
max_float_elems = max(p.numel() for _, p in float_params)
# Create shared pinned memory buffer based on largest parameter
shared_float_buffer = torch.empty(max_float_elems, dtype=torch.float32, device="cpu", pin_memory=True)
shared_float_buffer.normal_(mean=mean, std=std, generator=cpu_gen)
# Copy shared buffer to device once
device_buffer = shared_float_buffer.to(next(iter(float_params))[1].device, non_blocking=True)
for _, p in tqdm(float_params, desc="Gen dummy weights: Init float params"):
n = p.numel()
# Extract from device buffer (may reuse same values for different parameters)
view = device_buffer[:n].view(p.shape)
# torch.normal_ does not support dtypes < fp16, so cast via fp16 if needed
if torch.finfo(p.dtype).bits < 16:
tmp = view.to(torch.float16)
tmp = tmp.to(p.dtype)
else:
tmp = view.to(p.dtype)
# Copy from device buffer to parameter (D2D copy, much faster)
p.data.copy_(tmp)
# -------- Integer parameters: optimized shared memory initialization --------
if int_params:
# Find the largest parameter block size
max_int_elems = max(p.numel() for _, p in int_params)
int_low = int(np.floor(low))
int_high = int(np.ceil(high))
if int_high == int_low:
int_high = int_low + 1 # Ensure at least one possible value
# Create shared pinned memory buffer based on largest parameter
shared_int_buffer = torch.randint(
low=int_low,
high=int_high,
size=(max_int_elems,),
dtype=torch.int64,
generator=cpu_gen,
device="cpu",
pin_memory=True
)
# Copy shared buffer to device once
device_int_buffer = shared_int_buffer.to(next(iter(int_params))[1].device, non_blocking=True)
for _, p in tqdm(int_params, desc="Gen dummy weights: Init int params"):
n = p.numel()
# Extract from device buffer (may reuse same values for different parameters)
view = device_int_buffer[:n].view(p.shape)
tmp = view.to(p.dtype)
# Copy from device buffer to parameter (D2D copy, much faster)
p.data.copy_(tmp)
SMOOTHQUANT_METHOD = "smoothquant"
MULTIMODAL_ARCH_KEYWORDS = {"VL", "Vision", "Multimodal"}
def vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
'''
=============================
Modify by vllm_mlu
=============================
@brief: use torch.normal_ instead of torch.uniform_ for distinguishable logits
std=0.5 is used for better distinguishable logits
'''
# === Default parameter setup (Original values as fallback) ===
low_val = -1e-3
high_val = 1e-3
std_val = 0.5
# === Model and Quantization Check Logic ===
quant_method = getattr(model_config, "quantization", None)
# Attempt to get the architectures list from model_config
archs = getattr(model_config, "architectures", []) or []
# Determine if the model is multimodal (based on architecture names)
is_multimodal = any(
keyword in arch
for arch in archs
for keyword in MULTIMODAL_ARCH_KEYWORDS
)
# === Apply SmoothQuant + Multimodal Parameters ===
if is_multimodal and quant_method == SMOOTHQUANT_METHOD:
# (smoothquant) + Multimodal specific values to mitigate NaN overflow
std_val = 1e-4
initialize_dummy_weights_normal_dist(
model,
low=low_val,
high=high_val,
std=std_val
)
# add a sync to make sure the weights are initialized
torch.mlu.synchronize()
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(
DummyModelLoader,
DummyModelLoader.load_weights,
vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights
)

View File

@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import time
import torch
from torch import nn
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, TensorDeserializer, TensorizerArgs,
_check_tensors_on_meta_device, _resize_lora_embeddings,
is_valid_deserialization_uri)
from vllm.platforms import current_platform
from vllm.logger import init_logger
try:
from tensorizer.stream_io import open_stream
from tensorizer.utils import (convert_bytes, get_mem_usage,
no_init_or_tensor)
except ImportError:
open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
logger = init_logger(__name__)
def deserialize_tensorizer_model(model: nn.Module,
tensorizer_config: TensorizerConfig) -> None:
tensorizer_args = tensorizer_config._construct_tensorizer_args()
if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
raise ValueError(
f"{tensorizer_config.tensorizer_uri} is not a valid "
f"tensorizer URI. Please check that the URI is correct. "
f"It must either point to a local existing file, or have a "
f"S3, HTTP or HTTPS scheme.")
before_mem = get_mem_usage()
start = time.perf_counter()
'''
=============================
Modify by vllm_mlu
=============================
@brief: use mlu device
'''
device = ''
if current_platform.is_out_of_tree():
device = f'mlu:{torch.mlu.current_device()}'
elif current_platform.is_xpu():
device = f'xpu:{torch.xpu.current_device()}'
else:
device = f'cuda:{torch.cuda.current_device()}'
with open_stream(
tensorizer_config.tensorizer_uri,
mode="rb",
**tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
stream,
dtype=tensorizer_config.dtype,
device=device,
**tensorizer_args.deserialization_kwargs) as deserializer:
deserializer.load_into_module(model)
end = time.perf_counter()
'''
==================
End of MLU Hijack
==================
'''
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
deserializer.close()
logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
end - start, per_second)
logger.info("Memory usage before: %s", before_mem)
logger.info("Memory usage after: %s", after_mem)
_check_tensors_on_meta_device(model)
_resize_lora_embeddings(model)
del model.vllm_tensorized_marker
def serialize_extra_artifacts(
tensorizer_args: TensorizerArgs,
served_model_name: Union[str, list[str], None]) -> None:
if not isinstance(served_model_name, str):
raise ValueError(
f"served_model_name must be a str for serialize_extra_artifacts, "
f"not {type(served_model_name)}.")
'''
=============================
Modify by vllm_mlu
=============================
@brief: use local file
'''
import shutil
from pathlib import Path
local_model_path = Path(served_model_name)
if not local_model_path.exists() or not local_model_path.is_dir():
raise ValueError(
f"served_model_name must be a valid local directory in offline mode, "
f"but got: {served_model_name}"
)
'''
==================
End of MLU Hijack
==================
'''
with tempfile.TemporaryDirectory() as tmpdir:
'''
=============================
Modify by vllm_mlu
=============================
@brief: copy local file
'''
logger.info("Copying local model from %s to temporary directory %s",
local_model_path, tmpdir)
shutil.copytree(local_model_path, tmpdir, dirs_exist_ok=True)
'''
==================
End of MLU Hijack
==================
'''
for artifact in os.scandir(tmpdir):
if not artifact.is_file():
continue
with open(artifact.path, "rb") as f, open_stream(
f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
mode="wb+",
**tensorizer_args.stream_kwargs) as stream:
logger.info("Writing artifact %s", artifact.name)
stream.write(f.read())

View File

@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from torch import nn
from vllm.config import ModelConfig
from vllm.model_executor.model_loader.tensorizer import is_vllm_tensorized
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
from vllm_mlu.model_executor.model_loader.tensorizer import deserialize_tensorizer_model
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__model_executor__model_loader__tensorizer_loader__TensorizerLoader__load_weights(
self,
model: nn.Module,
model_config: ModelConfig
) -> None:
"""Load serialized model weights with tensorizer.
Expects a vLLM-tensorized model. See the
examples/others/tensorize_vllm_model.py example script
for serializing vLLM models."""
if is_vllm_tensorized(self.tensorizer_config):
tensorizer_config = self._patch_tensorizer_config(model_config)
deserialize_tensorizer_model(model, tensorizer_config)
else:
model.load_weights(self._get_weights_iterator())
MluHijackObject.apply_hijack(
TensorizerLoader,
TensorizerLoader.load_weights,
vllm__model_executor__model_loader__tensorizer_loader__TensorizerLoader__load_weights
)

View File

@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm import ModelRegistry
def register_model():
from .deepseek_v4 import MLUDeepseekV4ForCausalLM # noqa: F401
ModelRegistry.register_model(
"DeepseekV4ForCausalLM",
"vllm_mlu.model_executor.models.deepseek_v4:MLUDeepseekV4ForCausalLM")

View File

@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from math import lcm
from typing import TYPE_CHECKING
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.models.config import (HybridAttentionMambaModelConfig,
MambaModelConfig)
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
@classmethod
def vllm__module_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config(
cls,
vllm_config: "VllmConfig"
) -> None:
"""
Ensure that page size of attention layers is greater than or
equal to the mamba layers. If not, automatically set the attention
block size to ensure that it is. If the attention page size is
strictly greater than the mamba page size, we pad the mamba page size
to make them equal.
Args:
vllm_config: vLLM Config
"""
# Save the user input before it gets modified by MambaModelConfig
mamba_block_size = vllm_config.cache_config.mamba_block_size
# Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config)
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
if cache_config.cache_dtype == "auto":
kv_cache_dtype = model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# get attention page size (for 1 token)
# Attention backend constraints:
# - FlashAttention (FA) requires block size to be multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: kernel_block_size 128 alignment
# * Other MLA backends: kernel_block_size 64 alignment
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
attn_page_size_1_token = MLAAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
else:
kernel_block_alignment_size = 16
if (
current_platform.is_device_capability(100)
and model_config.get_head_size() == 256
and (
envs.VLLM_ATTENTION_BACKEND is None
or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER"
)
):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that`
# head size 256 and block size 16 is not supported on blackwell.
kernel_block_alignment_size = 32
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
model_config=model_config,
)
# get mamba page size
mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
block_size=model_config.max_model_len,
).page_size_bytes
# Model may be marked as is_hybrid
# but mamba is skipped via config,
# return directly
if mamba_page_size == 0:
return
if cache_config.enable_prefix_caching:
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
# Mamba2 SSD kernel uses a chunk_size, e.g. 256
# Align the block to the kernel: use lowest multiple of chunk_size
# of attention tokens that would fit mamba_page_size:
# e.g. for mamba page size = 788kB
# attn_1_token = 2kB -> fits ~394 tokens
# then round up to a mulitple of 256 -> 512 tokens
# End result:
# attn_block_size = 512
# mamba_block_size = 512 (aligned to a multiple of chunk_size)
# TODO(tdoublep): this constraint can be relaxed fairly
# easily by changing the way we layout chunks in the
# mamba2 kernels.
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
else:
# Without prefix caching, select minimum valid attention block size
# to minimize mamba state padding
# Calculate minimum attention block size that satisfies both:
# 1. Backend alignment requirements (kernel_block_alignment_size)
# 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
attn_block_size = kernel_block_alignment_size * cdiv(
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: support qwen3-next
'''
if (vllm_config.mlu_config.enable_mamba_split_page_size):
vllm_config.mlu_config.mamba_to_attn_block_ratio = cdiv(attn_block_size, cache_config.block_size)
cache_config.mamba_page_size_padded = cache_config.block_size * attn_page_size_1_token
return
'''
==================
End of MLU Hijack
==================
'''
# override attention block size if either (a) the
# user has not set it or (b) the user has set it
# too small.
if cache_config.block_size is None or cache_config.block_size < attn_block_size:
cache_config.block_size = attn_block_size
logger.info(
"Setting attention block size to %d tokens "
"to ensure that attention page size is >= mamba page size.",
attn_block_size,
)
# compute new attention page size
attn_page_size = cache_config.block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size
if attn_page_size == mamba_page_size:
# don't need to pad mamba page size
return
# pad mamba page size to exactly match attention
if (
cache_config.mamba_page_size_padded is None
or cache_config.mamba_page_size_padded != attn_page_size
):
cache_config.mamba_page_size_padded = attn_page_size
mamba_padding_pct = (
100 * (attn_page_size - mamba_page_size) / mamba_page_size
)
logger.info(
"Padding mamba page size by %.2f%% to ensure "
"that mamba page size and attention page size are "
"exactly equal.",
mamba_padding_pct,
)
MluHijackObject.apply_hijack(HybridAttentionMambaModelConfig,
HybridAttentionMambaModelConfig.verify_and_update_config,
vllm__module_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,607 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import (
Any, List, Tuple, Optional, Dict, Union, ClassVar, Literal,
Protocol, overload, runtime_checkable)
from typing_extensions import TypeIs
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.config import VllmConfig
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_gather_into_list,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter,
)
from vllm.distributed import (
get_tp_group,
get_pp_group,
get_dp_group,
get_data_parallel_group_rank,
get_data_parallel_group_world_size,
get_dense_mlp_tp_world_size,
get_tp_world_world_size,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_logits_tp_world_size,
get_parallel_rank_with_group,
get_tp_world_group,
get_tp_world_rank,
GroupCoordinator,
)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
logger = init_logger(__name__)
# alias after refactor
DataParallelRuntimeParams = MLUDPMetadata
def enable_data_parallel():
return get_dp_group().world_size > 1
def enable_emb_logits_custom_parallel():
return get_logits_tp_world_size() != get_tensor_model_parallel_world_size()
def enable_dense_mlp_custom_parallel():
return get_dense_mlp_tp_world_size() != get_tp_world_world_size()
def get_runtime_infos_per_dp_group(
num_tokens: int, num_requests: int, all_prefill: bool, seq_lens: List[int],
device: torch.device, vllm_config: VllmConfig) -> Tuple[List[int], List[bool]]:
dp_tensor = torch.tensor([num_tokens, num_requests, int(all_prefill)]).to(device, non_blocking=True)
outputs = tensor_model_parallel_all_gather_into_list(dp_tensor, get_dp_group())
outputs = torch.cat(outputs).tolist() # d2h
dp_world_size = get_data_parallel_group_world_size()
dp_is_prefill, dp_query_lens, dp_group_bs, seq_len_per_batch = [], [], [], []
for i in range(0, 3 * dp_world_size, 3):
dp_query_lens.append(outputs[i])
dp_group_bs.append(outputs[i + 1])
dp_is_prefill.append(bool(outputs[i + 2]))
# Only run communication if mcc is enabled and is prefill.
if vllm_config.mlu_config.is_dpsk_mcc_enabled and all(dp_is_prefill):
assert len(seq_lens) == num_requests
seq_len_per_batch = [torch.empty([bs], dtype=dp_tensor.dtype, device=device) for bs in dp_group_bs]
seq_lens_tensor = torch.tensor(seq_lens, dtype=dp_tensor.dtype, device=device)
torch.distributed.all_gather(seq_len_per_batch, seq_lens_tensor, group=get_dp_group().device_group)
seq_len_per_batch=torch.cat(seq_len_per_batch).tolist()
else:
seq_len_per_batch = [0] * sum(dp_group_bs)
return dp_query_lens, dp_group_bs, dp_is_prefill, seq_len_per_batch
def get_deepseek_layer_split_list(
dp_query_lens: List[int], dp_group_bs: List[int]
) -> Tuple[Optional[List[int]], Optional[List[int]], Optional[List[int]]]:
if len(dp_query_lens) != len(dp_group_bs) or len(dp_query_lens) != get_data_parallel_group_world_size():
logger.warning(f"dp_query_lens length: {len(dp_query_lens)} != dp_group_bs length: {len(dp_group_bs)}, "
f"disable deepseek layer split")
return None, None, None
emb_query_lens, logits_batch_sizes, dense_attn_token_split_list = None, None, None
all_dp_query_lens, all_dp_group_bs = [], []
for i in range(len(dp_query_lens)):
all_dp_query_lens.extend([dp_query_lens[i]] * get_tensor_model_parallel_world_size())
all_dp_group_bs.extend([dp_group_bs[i]] * get_tensor_model_parallel_world_size())
if get_logits_tp_world_size() != get_tensor_model_parallel_world_size():
slice_start = get_tp_world_rank() // get_logits_tp_world_size() * get_logits_tp_world_size()
slice_end = slice_start + get_logits_tp_world_size()
emb_query_lens = all_dp_query_lens[slice_start:slice_end]
logits_batch_sizes = all_dp_group_bs[slice_start:slice_end]
if get_dense_mlp_tp_world_size() != get_tp_world_world_size():
slice_start = get_tp_world_rank() // get_dense_mlp_tp_world_size() * get_dense_mlp_tp_world_size()
slice_end = slice_start + get_dense_mlp_tp_world_size()
dense_attn_token_split_list = all_dp_query_lens[slice_start:slice_end]
return emb_query_lens, logits_batch_sizes, dense_attn_token_split_list
def get_dp_metadata(
num_tokens: int,
data_parallel_size: int,
data_parallel_rank: int,
tensor_parallel_size: int,
prefill_dispatch_use_RS_AG: bool,
) -> DataParallelRuntimeParams:
"""
Get dp params when dummy run or capture model graph. These two cases do not have
dp_params when forward call, because we do not want to hijack to much.
"""
dp_query_lens = [num_tokens] * data_parallel_size
in_prefill = get_forward_context().attn_metadata is None # dummy run
dp_is_prefill = [in_prefill] * data_parallel_size
emb_query_lens, logits_batch_sizes, dense_attn_token_split_list = None, None, None
if get_logits_tp_world_size() != get_tensor_model_parallel_world_size():
emb_query_lens = [num_tokens] * get_logits_tp_world_size()
logits_batch_sizes = None # dummy run and capture model does not contain logits
if get_dense_mlp_tp_world_size() != get_tp_world_world_size():
dense_attn_token_split_list = [num_tokens] * get_dense_mlp_tp_world_size()
return MLUDPMetadata.make_oot(data_parallel_rank,
data_parallel_size,
tensor_parallel_size,
dp_query_lens,
dp_is_prefill,
prefill_dispatch_use_RS_AG,
emb_query_lens=emb_query_lens,
logits_batch_sizes=logits_batch_sizes,
dense_attn_token_split_list=dense_attn_token_split_list)
def remove_paddings_after_all_gather(
hidden_states: torch.Tensor,
padding_to_token_num: int,
token_num_list: List[int],
) -> torch.Tensor:
dp_group_tensors = []
offset = 0
for token_num in token_num_list:
if token_num != 0:
dp_group_tensors.append(hidden_states[offset:offset+token_num])
offset += padding_to_token_num
if len(dp_group_tensors) == 1:
hidden_states = dp_group_tensors[0]
else:
hidden_states = torch.cat(dp_group_tensors)
return hidden_states
def tensor_model_parallel_all_gather_dp(
group_num_tokens: List[int],
rank: int,
hidden_states: Optional[torch.Tensor],
group: GroupCoordinator,
hidden_size: int = None,
dtype: torch.dtype = None,
device: torch.device = None) -> torch.Tensor:
"""
All gather in the group.
Input is a 2-D tensor, and can have different shape in the first dim,
for example, [4, 7, 5, 8], [2, 5, 4, 0].
"""
num_tokens_equal = all(x == group_num_tokens[0] for x in group_num_tokens)
if num_tokens_equal:
hidden_states = tensor_model_parallel_all_gather(
input_=hidden_states, dim=0, tp_group=group)
else:
max_num_tokens = max(group_num_tokens)
num_padding = max_num_tokens - group_num_tokens[rank]
if num_padding > 0:
if hidden_states is None:
hidden_states = torch.empty((max_num_tokens, hidden_size),
dtype=dtype, device=device)
else:
hidden_states = F.pad(hidden_states, (0, 0, 0, num_padding))
hidden_states = tensor_model_parallel_all_gather(
input_=hidden_states, dim=0, tp_group=group)
hidden_states = remove_paddings_after_all_gather(
hidden_states, max_num_tokens, group_num_tokens)
return hidden_states
def tensor_model_parallel_all_gather_op_v2(
input_: torch.Tensor,
dim_size_list: List[int],
group_coordinator: GroupCoordinator,
non_leading_dim_size: int,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
"""
All gather the input tensor across model parallel group with only communication ops.
Note: compared to `tensor_model_parallel_all_gather_dp`, this method supports different
sizes in the first dim, and does not involve padding operation.
"""
all_size_equal = all([dim_size == dim_size_list[0] for dim_size in dim_size_list])
output_shape = (sum(dim_size_list), non_leading_dim_size)
output = torch.empty(output_shape, device=device, dtype=dtype)
if input_ is None:
input_ = torch.empty((0, non_leading_dim_size), device=device, dtype=dtype)
if all_size_equal:
torch.distributed.all_gather_into_tensor(
output, input_, group=group_coordinator.device_group)
else:
# Note: torch.split splits the tensor into chunks. And each chunk
# is a view of the original tensor.
tensor_list = torch.split(output, dim_size_list, dim=0)
torch.distributed.all_gather(
list(tensor_list), input_, group=group_coordinator.device_group)
return output
def process_post_attention_communication(
hidden_states: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
tp_group: Any = None,
):
"""
Processes distributed communication operations after attention computation.
This function performs necessary communication operations after attention computation
to ensure data synchronization across different parallel groups.
Supports two modes:
1. Tensor parallel mode: Uses tp_group for all-reduce and all-gather operations
2. Data parallel mode: Uses reduce-scatter and all-gather for global synchronization
Args:
hidden_states: Hidden states tensor after attention computation, can be None
dp_params: Data parallel runtime parameters containing token distribution and padding info
hidden_size: Dimension size of hidden states
dtype: Data type of the tensor
device: Device where the tensor is located
tp_group: Tensor parallel group, if None uses data parallel mode
Returns:
Hidden states tensor after communication synchronization processing
Note:
- When prefill_pad_to_token_num != -1, padding and unpadding operations will be performed
- Function selects optimal communication path based on token count and parallel strategy
"""
if tp_group is not None:
if dp_params.token_num != 0:
hidden_states = tensor_model_parallel_all_reduce(
hidden_states)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.dense_attn_token_split_list,
rank=get_parallel_rank_with_group(tp_group),
hidden_states=hidden_states,
group=tp_group,
)
else:
if dp_params.prefill_pad_to_token_num != -1:
# pad hidden_states to use reduce_scatter and global all gather
pad_num = dp_params.prefill_pad_to_token_num - dp_params.token_num
if pad_num != 0:
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_num))
hidden_states = tensor_model_parallel_reduce_scatter(
hidden_states, dim=0)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.attn_token_split_list_reduce_scatter,
rank=get_tp_world_rank(),
hidden_states=hidden_states,
group=get_tp_world_group(),
)
# get origin hidden_states for moe compute
hidden_states = remove_paddings_after_all_gather(
hidden_states, dp_params.prefill_pad_to_token_num,
dp_params.token_split_list)
else:
hidden_states = tensor_model_parallel_all_reduce(
hidden_states)
all_gather_group = get_dp_group()
all_gather_rank = get_data_parallel_group_rank()
hidden_states = tensor_model_parallel_all_gather_dp(
dp_params.token_split_list, all_gather_rank, hidden_states,
all_gather_group, hidden_size, dtype, device)
return hidden_states
def dp_model_forward(
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
embedding_layer: nn.Module,
model_norm_layer: nn.Module,
start_layer: int,
end_layer: int,
layers: List[nn.Module],
layer_input_norm_name: str,
prefill_dispatch_use_RS_AG: bool,
streams: Optional[Dict[str, torch.mlu.Stream]] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
"""run model with dp."""
if dp_params is None:
dp_params = get_dp_metadata(positions.numel(),
get_data_parallel_group_world_size(),
get_data_parallel_group_rank(),
get_tensor_model_parallel_world_size(),
prefill_dispatch_use_RS_AG)
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
if embedding_layer.__class__.__name__ == "DPVocabParallelEmbedding":
hidden_states = embedding_layer(input_ids, dp_params=dp_params)
else:
hidden_states = embedding_layer(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(start_layer, end_layer):
is_first_layer = (i == start_layer)
is_last_layer = (i == end_layer - 1)
next_input_layernorm = None
if not is_last_layer:
next_input_layernorm = getattr(layers[i+1], layer_input_norm_name)
hidden_states, residual = layers[i](
positions=positions,
hidden_states=hidden_states,
residual=residual,
dp_params=dp_params,
is_first_layer=is_first_layer,
is_last_layer=is_last_layer,
streams=streams,
next_input_layernorm=next_input_layernorm,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = model_norm_layer(hidden_states)
return hidden_states
def dp_layer_forward(
input_norm: nn.Module,
self_attn: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
mlp_kwargs: List[Dict[str, Any]],
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
hidden_size: int,
hidden_states_dtype: torch.dtype,
is_first_layer: bool = False,
is_last_layer: bool = False,
next_input_layernorm: Optional[nn.Module] = None,
enable_all2all: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
run layer with dp. dispatch all2all or rs+ag or common.
For mlp_kwargs, because all2all forward args is often different with common mlp args.
So here we decide that the mlp_kwargs[-1] is always all2all kwargs. For example:
Deepseek enable all2all, mlp_kwargs will be: [{mlp common forward kwargs}, {mlp all2all kwargs}].
Deepseek does not enable all2all, mlp_kwargs will be: [{mlp common forward kwargs}].
"""
if dp_params.layer_use_reduce_scatter:
common_metadata = get_common_metadata()
is_decode_only = common_metadata is not None and common_metadata.is_decode_only
use_all2all = enable_all2all and is_decode_only and isinstance(mlp, SparseMoeMlp)
forward_func = _dp_forward_layer_all2all if use_all2all else _dp_forward_layer_rs_ag
hidden_states, residual = forward_func(input_norm,
self_attn,
post_norm,
mlp,
mlp_kwargs,
positions,
hidden_states,
residual,
dp_params,
is_first_layer,
is_last_layer,
next_input_layernorm)
else:
hidden_states, residual = _dp_forward_layer_common(input_norm,
self_attn,
post_norm,
mlp,
mlp_kwargs,
positions,
hidden_states,
residual,
dp_params,
hidden_size,
hidden_states_dtype)
return hidden_states, residual
def _dp_forward_layer_rs_ag(
input_norm: nn.Module,
self_attn: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
mlp_kwargs: List[Dict[str, Any]],
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
is_first_layer: bool,
is_last_layer: bool,
next_input_layernorm: List[Optional[nn.Module]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""run layer with rs+ag."""
if residual is None:
residual = hidden_states
# We move the input_layernorm of i+1 layer to the end of i layer.
# But for the first layer, we need to do input_layernorm first.
if is_first_layer:
hidden_states = input_norm(hidden_states)
# Self Attention
hidden_states = self_attn(
positions=positions,
hidden_states=hidden_states,
)
# add residual here for the first layer
if is_first_layer and get_tensor_model_parallel_rank() == 0:
hidden_states = hidden_states + residual
hidden_states = tensor_model_parallel_reduce_scatter(
hidden_states, dim=0)
# move norm between rs and ag
if is_first_layer:
residual = hidden_states
hidden_states = post_norm(hidden_states)
else:
hidden_states, residual = post_norm(hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.attn_token_split_list_reduce_scatter,
rank=get_tp_world_rank(),
hidden_states=hidden_states,
group=get_tp_world_group(),
)
# mlp, use all cards
hidden_states = mlp(hidden_states, **mlp_kwargs[0])
hidden_states = tensor_model_parallel_reduce_scatter(
hidden_states, dim=0, tp_group=get_tp_world_group())
if is_last_layer:
hidden_states = hidden_states + residual
residual = None
else:
# To reduce layernorm computation, we move the layernorm of i+1 layer to
# the end of i layer. Besides, we fuse residual addition into layernorm.
assert next_input_layernorm is not None
hidden_states, residual = next_input_layernorm(hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.moe_token_split_list_reduce_scatter,
rank=get_tensor_model_parallel_rank(),
hidden_states=hidden_states,
group=get_tp_group(),
)
return hidden_states, residual
def _dp_forward_layer_all2all(
input_norm: nn.Module,
self_attn: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
mlp_kwargs: List[Dict[str, Any]],
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
is_first_layer: bool,
is_last_layer: bool,
next_input_layernorm: List[Optional[nn.Module]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""run layer with all2all."""
if residual is None:
residual = hidden_states
# We move the input_layernorm of i+1 layer to the end of i layer.
# But for the first layer, we need to do input_layernorm first.
if is_first_layer:
hidden_states = input_norm(hidden_states)
# Self Attention
hidden_states = self_attn(
positions=positions,
hidden_states=hidden_states,
)
# add residual here for the first layer
if is_first_layer and get_tensor_model_parallel_rank() == 0:
hidden_states = hidden_states + residual
hidden_states = tensor_model_parallel_reduce_scatter(
hidden_states, dim=0)
# move norm between rs and ag
if is_first_layer:
residual = hidden_states
hidden_states = post_norm(hidden_states)
else:
# add residual in norm for other layers
hidden_states, residual = post_norm(hidden_states, residual)
hidden_states = mlp.forward_all2all(hidden_states, **mlp_kwargs[-1])
if is_last_layer:
hidden_states = hidden_states + residual
residual = None
else:
# To reduce layernorm computation, we move the layernorm of i+1 layer to
# the end of i layer. Besides, we fuse residual addition into layernorm.
assert next_input_layernorm is not None
hidden_states, residual = next_input_layernorm(hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.moe_token_split_list_reduce_scatter,
rank=get_tensor_model_parallel_rank(),
hidden_states=hidden_states,
group=get_tp_group(),
)
return hidden_states, residual
def _dp_forward_layer_common(
input_norm: nn.Module,
self_attn: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
mlp_kwargs: List[Dict[str, Any]],
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
hidden_size: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""run layer with common."""
if residual is None:
residual = hidden_states
hidden_states = input_norm(hidden_states)
hidden_states = self_attn(
positions=positions,
hidden_states=hidden_states,
)
# add residual here
if get_tensor_model_parallel_rank() == 0:
hidden_states = hidden_states + residual
hidden_states = process_post_attention_communication(
hidden_states, dp_params, hidden_size, dtype, positions.device, None
)
residual = hidden_states[dp_params.token_num_offset:
dp_params.token_num_offset + dp_params.token_num]
hidden_states = post_norm(hidden_states)
hidden_states = mlp(hidden_states, **mlp_kwargs[0])
hidden_states = tensor_model_parallel_all_reduce(
hidden_states, tp_group=get_tp_world_group())
# add residual here
hidden_states = hidden_states[dp_params.token_num_offset:
dp_params.token_num_offset+dp_params.token_num]
hidden_states = hidden_states + residual
residual = hidden_states
return hidden_states, residual

Some files were not shown because too many files have changed in this diff Show More