add qwen3

This commit is contained in:
Chranos
2026-02-04 17:22:39 +08:00
parent d1c0f68ab4
commit 8511fe8530
1932 changed files with 300426 additions and 0 deletions

View File

@@ -0,0 +1 @@
from . import mlu_hijack

View File

@@ -0,0 +1,122 @@
from torch.utils import collect_env as torch_collect_env
import os
import re
def _check_env(env, default=False):
if env in os.environ:
return os.environ[env].lower() in ["true", "1"]
return default
def _check_env_value(env, default=0):
if env in os.environ:
if not os.environ[env].isdigit():
raise ValueError(f"'{env}' should be set with integer")
value = int(os.environ[env])
return value
return default
def get_device_name(device_id: int = 0) -> str:
r"""Gets the name of a device.
Args:
device_id (int): device id for which to return the device name.
Returns:
str: the name of the device, e.g. MLU370.
"""
run_lambda = torch_collect_env.run
try:
out = torch_collect_env.run_and_read_all(run_lambda, "cnmon -l")
matches = re.findall(r'MLU\d+(?:-\w+)?', out)
return matches[device_id]
except Exception as e:
raise RuntimeError(f"No device found with ID {device_id}.") from e
def get_device_major_capability(device_id: int = 0) -> int:
r"""Gets the cuda major capability of a device.
Args:
device_id (int): device id for which to return the device capability.
Returns:
int: the major cuda capability of the device.
"""
try:
device_name = get_device_name(device_id)
return int(device_name[3])
except Exception as e:
raise Exception(f"Fail to parse device capability with ID: {device_id}.")
# USE_PAGED: Select the vLLM running mode, default value depends on current platform.
USE_PAGED = _check_env("USE_PAGED", default=(get_device_major_capability() > 3))
# VLLM_LATENCY_DEBUG: Get more kernel info for benchmark latency.
VLLM_LATENCY_DEBUG = _check_env("VLLM_LATENCY_DEBUG", default=False)
# VLLM_LATENCY_DEBUG_NO_DEVICE: Get more kernel info (without device info) for benchmark latency.
VLLM_LATENCY_DEBUG_NO_DEVICE = _check_env("VLLM_LATENCY_DEBUG_NO_DEVICE", default=False)
# VLLM_DUMP_OUTPUTS: Dump each layer's outputs when running vLLM inference.
VLLM_DUMP_OUTPUTS = _check_env("VLLM_DUMP_OUTPUTS", default=False)
# VLLM_DUMP_CPU_INFO: Get cpu info when running vLLM inference.
VLLM_DUMP_CPU_INFO = _check_env("VLLM_DUMP_CPU_INFO", default=False)
# VLLM_DUMP_MLU_INFO: Get device info when running vLLM inference.
VLLM_DUMP_MLU_INFO = _check_env("VLLM_DUMP_MLU_INFO", default=False)
# VLLM_SCHEDULER_PROFILE: Profiling vLLM scheduler.
VLLM_SCHEDULER_PROFILE = _check_env("VLLM_SCHEDULER_PROFILE", default=False)
# VLLM_GRAPH_DEBUG: Debug the graph status when running decoder, default value is True.
# Set to False to disable warning messages.
VLLM_GRAPH_DEBUG = _check_env("VLLM_GRAPH_DEBUG", default=True)
# CHUNKED_PIPELINE_PARALLEL_EN: use chunked pipeline parallel, default value is False.
CHUNKED_PIPELINE_PARALLEL_EN = _check_env("CHUNKED_PIPELINE_PARALLEL_EN", default=False)
# CONTEXT_PARALLEL_EN: use context parallel, default value is False.
CONTEXT_PARALLEL_EN = _check_env("CONTEXT_PARALLEL_EN", default=False)
# EXPERT_PARALLEL_EN: use expert parallel, default value is False.
EXPERT_PARALLEL_EN = _check_env("EXPERT_PARALLEL_EN", default=False)
VLLM_LATENCY_DEBUG_EN = (VLLM_LATENCY_DEBUG or VLLM_LATENCY_DEBUG_NO_DEVICE)
VLLM_LATENCY_DEBUG_WITH_DEVICE_EN = (VLLM_LATENCY_DEBUG and not VLLM_LATENCY_DEBUG_NO_DEVICE)
VLLM_DUMP_CPU_INFO_EN = (VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and VLLM_DUMP_CPU_INFO)
VLLM_DUMP_MLU_INFO_EN = (VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and VLLM_DUMP_MLU_INFO)
CUSTOM_VLLM_HIJACK_EN = (CHUNKED_PIPELINE_PARALLEL_EN or CONTEXT_PARALLEL_EN or EXPERT_PARALLEL_EN)
VLLM_PRELOAD_SIZE = _check_env_value("VLLM_PRELOAD_SIZE", default=0)
# ATTN_PARALLEL_NUM & FFN_PARALLEL_NUM: use context comm cmpt parallel.
ATTN_PARALLEL_NUM = 'ATTN_PARALLEL_NUM'
FFN_PARALLEL_NUM = 'FFN_PARALLEL_NUM'
# This class is used by layers; BlockSizeInfo exposes BLOCK_SIZE to models/layers.
class BlockSizeInfo:
BLOCK_SIZE = -1
@classmethod
def set_block_size(cls, a: int):
if USE_PAGED:
if a not in (-1, 16):
raise ValueError("Block sizes other than 16 are not supported in paged mode; please check the '--block-size' value.")
cls.BLOCK_SIZE = 16
else:
cls.BLOCK_SIZE = 2048 if a == -1 else a
def check_context_comm_cmpt_parallel():
return (ATTN_PARALLEL_NUM in os.environ) or (FFN_PARALLEL_NUM in os.environ)
# Module-level prefill/decode flag; default to False so get_is_prompt() is safe before the first set.
IS_PROMPT = False
def set_is_prompt(flag):
global IS_PROMPT
IS_PROMPT = flag
def get_is_prompt():
return IS_PROMPT
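For reference, a minimal usage sketch of the helpers above. It assumes an MLU environment where cnmon is available (USE_PAGED is resolved at import time by probing the device), and the flag values chosen here are arbitrary:

import os

# Flags are parsed once at import time, so set them before importing the module.
os.environ["VLLM_DUMP_OUTPUTS"] = "1"   # "true"/"1" (case-insensitive) enables a boolean flag
os.environ["VLLM_PRELOAD_SIZE"] = "4"   # must be a digit string, otherwise ValueError

from vllm_mlu import _mlu_utils

# In paged mode only a block size of 16 is accepted; -1 picks the default
# (16 when paged, 2048 otherwise).
_mlu_utils.BlockSizeInfo.set_block_size(-1)
print(_mlu_utils.BlockSizeInfo.BLOCK_SIZE)

# IS_PROMPT is a module-level switch toggled between prefill and decode.
_mlu_utils.set_is_prompt(True)
assert _mlu_utils.get_is_prompt()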

View File

@@ -0,0 +1,4 @@
import vllm_mlu.attention.backends
import vllm_mlu.attention.ops
import vllm_mlu.attention.layer
import vllm_mlu.attention.selector

View File

@@ -0,0 +1 @@
import vllm_mlu.attention.backends.mlu_attn

View File

@@ -0,0 +1,802 @@
import torch
from contextlib import contextmanager
from itertools import accumulate
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
from vllm import _mlu_ops as mlu_ops
from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionType)
from vllm.attention.backends.utils import (
PAD_SLOT_ID, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args)
from vllm.forward_context import get_forward_context
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
make_tensor_with_pad)
from vllm.attention.backends.mlu_attn import (
MLUFlashAttentionBackend, MLUFlashAttentionMetadataBuilder,
MLUFlashAttentionMetadata, MLUFlashAttentionImpl,
MLUFlashAttentionState, _get_query_key_seq_metadata,
_get_causal_option)
from vllm_mlu._mlu_utils import USE_PAGED
from vllm_mlu.mlu_hijack_utils import MluHijackObject
class MLUFlashAttentionBackend_V2(MLUFlashAttentionBackend):
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 80, 96, 128, 160, 192, 224, 256, 512, 576]
@staticmethod
def get_impl_cls() -> Type["MLUFlashAttentionImpl_V2"]:
return MLUFlashAttentionImpl_V2
@staticmethod
def get_builder_cls() -> Type["MLUFlashAttentionMetadataBuilder_V2"]:
return MLUFlashAttentionMetadataBuilder_V2
@staticmethod
def get_state_cls() -> Type["MLUFlashAttentionState_V2"]:
return MLUFlashAttentionState_V2
@staticmethod
def get_kv_cache_scale_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
) -> Tuple[int, ...]:
return (2, num_blocks, num_kv_heads, block_size)
@staticmethod
def copy_blocks(
kv_caches: List[List[torch.Tensor]],
src_to_dists: torch.Tensor,
) -> None:
key_caches = [kv_cache[0][0] for kv_cache in kv_caches]
value_caches = [kv_cache[0][1] for kv_cache in kv_caches]
mlu_ops.copy_blocks(key_caches, value_caches, src_to_dists)
kv_cache_scales = [kv_cache[1] for kv_cache in kv_caches]
if len(kv_cache_scales) > 0 and kv_cache_scales[0].numel() > 0:
key_cache_scales = [kv_cache_scale[0] for kv_cache_scale in kv_cache_scales]
value_cache_scales = [kv_cache_scale[1] for kv_cache_scale in kv_cache_scales]
mlu_ops.copy_blocks(key_cache_scales, value_cache_scales, src_to_dists)
class MLUMLAFlashAttentionBackend(MLUFlashAttentionBackend):
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (1, num_blocks, num_kv_heads, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
mlu_ops.swap_blocks(dst_key_cache, src_key_cache, src_to_dst)
@staticmethod
def get_kv_cache_scale_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
) -> Tuple[int, ...]:
return (1, num_blocks, num_kv_heads, block_size)
@staticmethod
def copy_blocks(
kv_caches: List[List[torch.Tensor]],
src_to_dists: torch.Tensor,
) -> None:
key_caches = [kv_cache[0][0] for kv_cache in kv_caches]
mlu_ops.copy_blocks(key_caches, None, src_to_dists)
kv_cache_scales = [kv_cache[1] for kv_cache in kv_caches]
if len(kv_cache_scales) > 0 and kv_cache_scales[0].numel() > 0:
key_cache_scales = [kv_cache_scale[0] for kv_cache_scale in kv_cache_scales]
mlu_ops.copy_blocks(key_cache_scales, None, src_to_dists)
class MLUFlashAttentionState_V2(MLUFlashAttentionState):
def __init__(self, runner: "ModelRunnerBase"):
MLUFlashAttentionState.__init__(self, runner)
@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID if USE_PAGED else 0,
dtype=torch.int32,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
def graph_capture_get_metadata_for_batch(
self, batch_size: int, is_encoder_decoder_model: bool = False):
assert self._is_graph_capturing
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1,
max_decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=(self.runner.max_seq_len_to_capture if USE_PAGED
else min(self.runner.block_size, self.runner.max_seq_len_to_capture)),
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
use_cuda_graph=True,
)
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in\
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or " \
f"'FLASH_ATTN', but "\
f"got '{self.runner.attn_backend.get_name()}'"
self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata)
return attn_metadata
def get_graph_input_buffers(
self,
attn_metadata: AttentionMetadata,
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
input_buffers = {
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": None,
"block_tables": None,
}
if attn_metadata.num_prefills > 0:
input_buffers["seq_lens_tensor"] = attn_metadata.prefill_metadata.seq_lens_tensor
input_buffers["block_tables"] = attn_metadata.prefill_metadata.block_tables
else:
input_buffers["seq_lens_tensor"] = attn_metadata.decode_metadata.seq_lens_tensor
input_buffers["block_tables"] = attn_metadata.decode_metadata.block_tables
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in\
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or "\
f"'FLASH_ATTN', but "\
f"got '{self.runner.attn_backend.get_name()}'"
self._add_additonal_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
return input_buffers
def prepare_graph_input_buffers(
self,
input_buffers: Dict[str, Any],
attn_metadata: AttentionMetadata,
is_encoder_decoder_model: bool = False) -> None:
metadata = attn_metadata.prefill_metadata if \
attn_metadata.num_prefills > 0 else attn_metadata.decode_metadata
input_buffers["seq_lens_tensor"].copy_(
metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
metadata.block_tables, non_blocking=True)
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in\
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or "\
f"'FLASH_ATTN', but "\
f"got '{self.runner.attn_backend.get_name()}'"
self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers)
@contextmanager
def graph_capture_with_context(
self,
ctx_graph_batch_size: int,
max_batch_size: int,
max_num_tokens: int
):
self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_num_tokens, ),
PAD_SLOT_ID if USE_PAGED else 0,
dtype=torch.int32,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
# block tables used for decode mlugraph input buffer
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
# block tables used for context mlugraph input buffer
self._ctx_graph_block_tables = torch.zeros((ctx_graph_batch_size, 0),
dtype=self._graph_block_tables.dtype,
device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
del self._ctx_graph_block_tables
def fill_seq_lens_tensor(
self,
seq_len: int
) -> None:
self._graph_seq_lens.fill_(seq_len)
def graph_capture_get_metadata_for_context(
self,
batch_size: int,
seq_len: int,
is_encoder_decoder_model: bool = False
) -> MLUFlashAttentionMetadata:
assert self._is_graph_capturing
query_start_loc = torch.zeros(batch_size + 1,
dtype=torch.int32,
device=self.runner.device)
seq_start_loc = torch.zeros(batch_size + 1,
dtype=torch.int32,
device=self.runner.device)
context_lens_tensor = torch.zeros(batch_size,
dtype=torch.int32,
device=self.runner.device)
torch.cumsum(self._graph_seq_lens[:batch_size],
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
torch.cumsum(self._graph_seq_lens[:batch_size],
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
num_tokens = batch_size * seq_len
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=batch_size,
num_prefill_tokens=num_tokens,
num_decode_tokens=0,
slot_mapping=self._graph_slot_mapping[:num_tokens],
multi_modal_placeholder_index_maps=None,
seq_lens=[seq_len] * batch_size,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=seq_len,
max_decode_query_len=0,
max_prefill_seq_len=seq_len,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=self._ctx_graph_block_tables,
use_cuda_graph=True,
)
return attn_metadata
class MLUFlashAttentionMetadataBuilder_V2(MLUFlashAttentionMetadataBuilder):
def build(
self,
seq_lens: List[int],
query_lens: List[int],
cuda_graph_pad_size: int,
batch_size: int
) -> MLUFlashAttentionMetadata:
'''
=============================
Modify by vllm_mlu
=============================
@brief: Use the original func if context mlugraph is not used.
'''
if not self.runner.model_config.use_context_mlugraph():
return super().build(seq_lens,
query_lens,
cuda_graph_pad_size,
batch_size)
'''
==================
End of MLU Hijack
==================
'''
prefix_cache_hit = any([
inter_data.prefix_cache_hit
for inter_data in self.input_builder.inter_data_list
])
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled,
prefix_cache_hit)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
max_decode_query_len = max(decode_query_lens)
else:
max_decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([
(PAD_SLOT_ID if USE_PAGED else 0)] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables)
else:
if USE_PAGED:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
else:
block_tables = make_tensor_without_pad(
self.block_tables,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
'''
=============================
Modify by vllm_mlu
=============================
@brief: Check if we can use context mlugraph for the given input.
'''
if num_decode_tokens == 0 and self.num_prefills > 0:
ctx_graph_bs, ctx_graph_seq_len = (
self.runner.model_config.get_context_mlugraph_bs_and_seq())
use_captured_graph = len(seq_lens) == ctx_graph_bs and all(
seq_len == ctx_graph_seq_len for seq_len in seq_lens)
'''
==================
End of MLU Hijack
==================
'''
return MLUFlashAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
class MLUFlashAttentionImpl_V2(MLUFlashAttentionImpl):
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: List[torch.Tensor],
attn_metadata: MLUFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
use_mla: bool = False,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
output = torch.ops.vllm.unified_flash_attention_v2(
query,
key,
value,
self.num_heads,
self.head_size,
self.num_kv_heads,
kv_cache,
self.kv_cache_dtype,
k_scale,
v_scale,
self.scale,
attn_type.value,
self.sliding_window,
self.alibi_slopes,
self.logits_soft_cap,
use_mla,
)
return output
def unified_flash_attention_v2(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: List[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
attn_type_int_val: int,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
use_mla: bool = False,
) -> torch.Tensor:
# Convert integer attn_type to enum
try:
attn_type = AttentionType(attn_type_int_val)
except ValueError as err:
raise AttributeError(
f"Invalid attention type {str(attn_type_int_val)}") from err
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, MLUFlashAttentionMetadata)
attn_metadata: MLUFlashAttentionMetadata = current_metadata
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
v_head_size = value.size(1) // num_kv_heads
if (key is not None) and (value is not None):
key = key.view(-1, num_kv_heads, head_size)
if use_mla and attn_metadata.prefill_metadata:
value = value.view(-1, num_kv_heads, v_head_size)
else:
value = value.view(-1, num_kv_heads, head_size)
if kv_cache[0].numel() > 0:
kv_cache_, kv_cache_scale_ = kv_cache
key_cache = kv_cache_[0]
value_cache = None if use_mla else kv_cache_[1]
key_cache_scale, value_cache_scale = None, None
if kv_cache_scale_.numel() > 0:
key_cache_scale = kv_cache_scale_[0]
value_cache_scale = None if use_mla else kv_cache_scale_[1]
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the KV
# cache is already populated with the cross-attention tensor.
# Thus, we skip cache updates during this time.
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
if USE_PAGED:
value_to_cache = None if use_mla else value
if use_mla and attn_metadata.prefill_metadata:
# MLA save cache info in models before flashattn
pass
else:
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_paged_cache(key,
value_to_cache,
key_cache,
value_cache,
key_cache_scale,
value_cache_scale,
updated_slot_mapping.flatten())
else:
mlu_ops.reshape_paged_cache(key,
value_to_cache,
key_cache,
value_cache,
updated_slot_mapping.flatten())
else:
# FIXME: After TMO-1496 is completed, remove this code.
if key.stride() != value.stride():
key = key.contiguous()
value = value.contiguous()
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_linear_cache(key,
value,
key_cache,
value_cache,
key_cache_scale,
value_cache_scale,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, # packed
None, # context_seq_offset
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
else:
mlu_ops.reshape_linear_cache(key,
value,
key_cache,
value_cache,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, # packed
None, # context_seq_offset
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
if use_mla and attn_metadata.prefill_metadata:
output = torch.empty(query.shape[0], query.shape[1], v_head_size, dtype=query.dtype, device="mlu")
else:
output = torch.empty_like(query)
(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
decode_query = query[num_prefill_query_tokens:]
# QKV for prefill.
query = query[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens
if prefill_meta := attn_metadata.prefill_metadata:
alibi_slopes = None if alibi_slopes is None else \
alibi_slopes.repeat(attn_metadata.num_prefills, 1)
# Prompt run.
if (kv_cache[0].numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]
mlu_ops.flash_attention(query,
key,
value,
output[:num_prefill_query_tokens],
q_seq_start_loc,
k_seq_start_loc,
alibi_slopes,
None,
q_seq_len,
k_seq_len,
softmax_scale,
_get_causal_option(attn_type),
-1 if window_size is None \
else window_size[0],
-1 if window_size is None \
else window_size[1],
torch.float,
False)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
mlu_ops.flash_attention(query,
key_cache,
value_cache,
output[:num_prefill_kv_tokens],
prefill_meta.query_start_loc,
prefill_meta.seq_start_loc,
alibi_slopes,
None,
prefill_meta.max_query_len,
max_seq_len,
softmax_scale,
True,
-1 if window_size is None \
else window_size[0],
-1 if window_size is None \
else window_size[1],
torch.float,
False,
prefill_meta.block_tables)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
alibi_slopes = None if alibi_slopes is None \
else alibi_slopes.repeat(attn_metadata.num_decode_tokens, 1)
decode_query = decode_query.view(-1, 1, num_heads, head_size)
decode_out = output[num_prefill_query_tokens:].view(-1, 1, num_heads, head_size)
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None
if decode_meta.max_decode_query_len > 1:
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1")
mlu_ops.flash_attention(decode_query,
key_cache,
value_cache,
decode_out,
decode_meta.query_start_loc,
decode_meta.seq_start_loc,
alibi_slopes,
None,
decode_meta.max_decode_query_len,
decode_meta.max_decode_seq_len,
softmax_scale,
True,
-1 if window_size is None \
else window_size[0],
-1 if window_size is None \
else window_size[1],
torch.float,
False,
decode_meta.block_tables)
else:
# Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg,
max_context_len,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
if use_mla:
value_cache = key_cache
value_cache_scale = key_cache_scale
mlu_ops.single_query_cached_kv_attn(decode_query,
key_cache,
value_cache,
decode_out,
block_tables_arg,
seq_lens_arg,
key_cache_scale,
value_cache_scale,
alibi_slopes,
max_context_len,
-1 if window_size is None \
else window_size[0],
-1 if window_size is None \
else window_size[1],
softmax_scale)
# Reshape the output tensor.
if use_mla and attn_metadata.prefill_metadata:
return output.view(num_tokens, num_kv_heads * v_head_size)
return output.view(num_tokens, hidden_size)
def unified_flash_attention_v2_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: List[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
attn_type_int_val: int,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)
direct_register_custom_op(
op_name="unified_flash_attention_v2",
op_func=unified_flash_attention_v2,
mutates_args=["kv_cache"],
fake_impl=unified_flash_attention_v2_fake,
)
def make_tensor_without_pad(
x: List[List[int]],
dtype: torch.dtype,
device: Union[str, torch.device] = "mlu",
pin_memory: bool = False,
) -> torch.Tensor:
return torch.tensor(x,
dtype=dtype,
device=device,
pin_memory=pin_memory and str(device) == "cpu")
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
MLUFlashAttentionBackend.get_supported_head_sizes,
MLUFlashAttentionBackend_V2.get_supported_head_sizes)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
MLUFlashAttentionBackend.get_impl_cls,
MLUFlashAttentionBackend_V2.get_impl_cls)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
MLUFlashAttentionBackend.get_builder_cls,
MLUFlashAttentionBackend_V2.get_builder_cls)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
MLUFlashAttentionBackend.get_state_cls,
MLUFlashAttentionBackend_V2.get_state_cls)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
"get_kv_cache_scale_shape",
MLUFlashAttentionBackend_V2.get_kv_cache_scale_shape)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
MLUFlashAttentionBackend.copy_blocks,
MLUFlashAttentionBackend_V2.copy_blocks)
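For orientation, a small sketch of the kv cache layout these overrides assume: each entry of kv_caches is a [cache, scale] pair, with the scale shapes coming from get_kv_cache_scale_shape above and the MLA cache shape from get_kv_cache_shape; the standard cache shape follows the forward() docstring later in this file, and the sizes and dtypes here are placeholders:

import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 128

# Standard attention: stacked key/value cache plus an optional per-token scale tensor.
kv_cache = torch.empty(2, num_blocks, block_size, num_kv_heads, head_size)   # per forward() docstring
kv_scale = torch.empty(2, num_blocks, num_kv_heads, block_size)              # get_kv_cache_scale_shape

# MLA: a single latent cache, so the leading dim is 1 and the value cache is unused.
mla_cache = torch.empty(1, num_blocks, num_kv_heads, block_size, head_size)  # MLA get_kv_cache_shape
mla_scale = torch.empty(1, num_blocks, num_kv_heads, block_size)             # MLA get_kv_cache_scale_shape

# copy_blocks receives one such [cache, scale] pair per layer:
kv_caches = [[kv_cache, kv_scale]]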

View File

@@ -0,0 +1,118 @@
"""Attention layer."""
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.layer import Attention
from vllm_mlu.attention.selector import vllm__attention__selector__get_attn_backend as get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm_mlu._mlu_utils import USE_PAGED
from vllm_mlu.mlu_hijack_utils import MluHijackObject
'''
=============================
Modify by vllm_mlu
=============================
@brief: add an arg use_mla to get_attn_backend, _cached_get_attn_backend,
and which_attn_to_use
'''
'''
==================
End of MLU Hijack
==================
'''
def vllm__attention__layer__Attention__init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
use_mla: bool = False,
prefix: str = "",
) -> None:
super(Attention, self).__init__()
self.use_mla = use_mla
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
sliding_window = cache_config.sliding_window
is_attention_free = cache_config.is_attention_free
else:
kv_cache_dtype = "auto"
block_size = 16
sliding_window = None
is_attention_free = False
if num_kv_heads is None:
num_kv_heads = num_heads
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self._k_scale = 1.0
self._v_scale = 1.0
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None:
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
# The k/v_scale will then be converted back to native float32
# values after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
block_size, is_attention_free,
blocksparse_params is not None,
use_mla=use_mla)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap)
def vllm__attention__layer__Attention__forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type,
use_mla=self.use_mla)
MluHijackObject.apply_hijack(Attention,
Attention.__init__,
vllm__attention__layer__Attention__init__)
MluHijackObject.apply_hijack(Attention,
Attention.forward,
vllm__attention__layer__Attention__forward)
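The apply_hijack calls above rebind Attention.__init__ and Attention.forward to the MLU variants at import time. A rough sketch of the underlying monkey-patch idea (MluHijackObject's real bookkeeping, e.g. remembering originals so they can be restored, is not shown here):

class Original:
    def forward(self, x):
        return x

def patched_forward(self, x):
    # extra MLU-specific handling would go here
    return 2 * x

# Roughly what apply_hijack(Original, Original.forward, patched_forward) amounts to:
Original.forward = patched_forward

assert Original().forward(3) == 6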

View File

@@ -0,0 +1 @@
import vllm_mlu.attention.ops.prefix_prefill

View File

@@ -0,0 +1,157 @@
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
import torch
import triton
import triton.language as tl
import vllm.attention.ops.prefix_prefill
from vllm_mlu.mlu_hijack_utils import MluHijackObject
if triton.__version__ >= "2.1.0":
@torch.inference_mode()
def vllm__attention__ops__prefix_prefill__context_attention_fwd(q,
k,
v,
o,
k_cache,
v_cache,
b_loc,
b_start_loc,
b_seq_len,
b_ctx_len,
max_input_len,
alibi_slopes=None,
sliding_window=None):
'''
=============================
Modify by vllm_mlu
=============================
@brief: a block size of 64 uses too much memory, so force BLOCK = 16.
'''
BLOCK = 16
'''
==================
End of MLU Hijack
==================
'''
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded = triton.next_power_of_2(Lk)
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # (batch, head, blocks over max_input_len)
num_warps = 8
if alibi_slopes is not None:
assert Lk == Lk_padded
vllm.attention.ops.prefix_prefill._fwd_kernel_alibi[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
alibi_slopes,
v_cache.shape[3],
8,
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4
), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
vllm.attention.ops.prefix_prefill._fwd_kernel[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
v_cache.shape[3],
8,
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
SLIDING_WINDOW=sliding_window if sliding_window is not None else 0,
num_warps=num_warps,
num_stages=1,
)
return
MluHijackObject.apply_hijack(vllm.attention.ops.prefix_prefill,
vllm.attention.ops.prefix_prefill.context_attention_fwd,
vllm__attention__ops__prefix_prefill__context_attention_fwd)

View File

@@ -0,0 +1,802 @@
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims
"""
import torch
import triton
import triton.language as tl
torch_dtype: tl.constexpr = torch.float16
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@triton.jit
def max_fn(x, y):
return tl.math.max(x, y)
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
stride).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
stride)
rng_keep = rng_output > dropout_p
return rng_keep
@triton.jit
def load_fn(block_ptr, first, second, pad):
if first and second:
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
elif first:
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit
def _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
actual_seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
PADDED_HEAD: tl.constexpr,
):
# loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
k = load_fn(
K_block_ptr,
PADDED_HEAD,
MASK_STEPS and (n_extra_tokens != 0),
"zero",
)
if PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
# if not is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M],
actual_seqlen_k,
dtype=tl.int32)
size_n = start_n + OFFS_N[None, :]
mask = size_n < boundary_m[:, None]
qk = tl.where(mask, qk, float("-inf"))
if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ----
qk += tl.dot(q, k)
if bias_ptr is not None:
bias = load_fn(bias_ptr, False, MASK_STEPS
and (n_extra_tokens != 0), "zero")
# While bias is added after multiplying qk with sm_scale, our
# optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk += bias * 1.44269504089
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
philox_offset = (batch_philox_offset +
start_m * BLOCK_M * actual_seqlen_k + start_n -
BLOCK_N)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p,
-p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
# -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None]
if not PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
# -- update m_i and l_i
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, BLOCK_N))
return acc, l_i, m_i
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 64,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 128,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 128,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"PRE_LOAD_V": True,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 64,
"BLOCK_N": 64,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 32,
"BLOCK_N": 32,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config(
{
"BLOCK_M": 16,
"BLOCK_N": 16,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
)
@triton.jit
def attn_fwd(
Q,
K,
V,
bias,
sm_scale,
L,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
stride_bz,
stride_bh,
stride_bm,
stride_bn,
cu_seqlens_q,
cu_seqlens_k,
dropout_p,
philox_seed,
philox_offset_base,
encoded_softmax,
HQ: tl.constexpr,
HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q:
return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
else:
cu_seqlens_q_start = 0
cu_seqlens_k_start = 0
seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if IS_CAUSAL:
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn
# matrix
n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this WG is
# part of the blocks that are all 0. We exit early.
if n_blocks <= 0:
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
off_h_q * stride_oh)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0
# for these masked blocks.
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
# tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
# Compute pointers for all the tensors used in this kernel.
q_offset = (off_z * stride_qz + off_h_q * stride_qh +
cu_seqlens_q_start * stride_qm)
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_offset = (off_z * stride_kz + off_h_k * stride_kh +
cu_seqlens_k_start * stride_kn)
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
v_offset = (off_z * stride_vz + off_h_k * stride_vh +
cu_seqlens_k_start * stride_vk)
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
if BIAS_TYPE != 0:
bias_ptr = tl.make_block_ptr(
base=bias + off_h_q * stride_bh,
shape=(seqlen_q, seqlen_k),
strides=(stride_bm, stride_bn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
bias_ptr = None
if ENABLE_DROPOUT:
batch_philox_offset = philox_offset_base \
+ (off_z * HQ + off_h_q) \
* seqlen_q * seqlen_k
else:
batch_philox_offset = 0
# We can ask to return the dropout mask without actually doing any dropout.
# In this case, we return an invalid pointer to indicate the mask is not
# valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.make_block_ptr(
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
shape=(seqlen_q, seqlen_k),
strides=(seqlen_k, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
encoded_softmax_block_ptr = 0
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale = sm_scale * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q = load_fn(Q_block_ptr, True, padded_head, "zero")
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
# block. In this case we might exceed n_blocks so pick the min.
masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks
block_min = 0
block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min,
block_max,
0,
0,
0,
bias_ptr,
# IS_CAUSAL, ....
False,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
False,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
padded_head,
)
block_min = block_max
block_max = n_blocks * BLOCK_N
tl.debug_barrier()
# Remaining blocks, if any, are masked.
if masked_blocks > 0:
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, n_full_blocks))
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
True,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
padded_head,
)
# epilogue
acc = acc / l_i[:, None]
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
causal_start_idx,
dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None] >=
out_mask_boundary[None, :])
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
# overflow_size = end_m_idx - seqlen_q
# if overflow_size > 0:
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# # This is a > check because mask being 0 blocks the store.
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
# else:
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
off_h_q * stride_oh)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
def check_args(
q,
k,
v,
o,
varlen=True,
max_seqlens=None,
cu_seqlens_q=None,
cu_seqlens_k=None,
):
assert q.dim() == k.dim() and q.dim() == v.dim()
if varlen:
assert q.dim() == 3
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
assert cu_seqlens_q is not None
assert cu_seqlens_k is not None
assert len(cu_seqlens_q) == len(cu_seqlens_k)
else:
assert q.dim() == 4
batch, nheads_q, seqlen_q, head_size = q.shape
_, nheads_k, seqlen_k, _ = k.shape
assert max_seqlens > 0
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
assert q.dtype == k.dtype and q.dtype == v.dtype
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
class _attention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
):
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
check_args(
q,
k,
v,
o,
varlen=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Pad the head dim up to the smallest supported size that is >= head_size.
unpadded_head_dims = {32, 64, 128, 256}
if head_size not in unpadded_head_dims:
padded_d_model = None
for i in sorted(unpadded_head_dims):
if i > head_size:
padded_d_model = i
break
assert padded_d_model is not None
else:
padded_d_model = head_size
grid = lambda META: (
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
nheads_q,
batch,
)
encoded_softmax = None
# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
if bias is not None:
bias_strides = (
bias.stride(0),
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
else:
bias_strides = (0, 0, 0, 0)
attn_fwd[grid](
q,
k,
v,
bias,
sm_scale,
None,
o,
*q_strides,
*k_strides,
*v_strides,
*o_strides,
*bias_strides,
cu_seqlens_q,
cu_seqlens_k,
dropout_p=0.0,
philox_seed=philox_seed,
philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax,
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
IS_CAUSAL=causal,
VARLEN=True,
BLOCK_DMODEL=padded_d_model,
BIAS_TYPE=0 if bias is None else 1,
ENABLE_DROPOUT=False,
RETURN_ENCODED_SOFTMAX=False,
)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = head_size
ctx.causal = causal
ctx.dropout_p = 0.0
ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = False
return o, encoded_softmax
triton_attention = _attention.apply
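A hedged usage sketch of triton_attention for the varlen layout it implements; it needs a Triton-capable device, and the device string, dtype, and sizes below are placeholders:

import torch

device = "cuda"  # or "mlu", depending on the platform
nheads, head_size = 8, 128

# Two sequences of lengths 5 and 11, packed along dim 0 (total 16 tokens).
cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.int32, device=device)
q = torch.randn(16, nheads, head_size, dtype=torch.float16, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)

out, _ = triton_attention(
    q, k, v, None,           # o=None lets forward() allocate the output
    cu_seqlens, cu_seqlens,  # query and key share the same lengths here
    11, 11,                  # max seqlen for q and k
    True,                    # causal
    head_size ** -0.5,       # sm_scale
)
print(out.shape)             # torch.Size([16, 8, 128])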

View File

@@ -0,0 +1,303 @@
import enum
import os
from contextlib import contextmanager
from functools import lru_cache
from typing import Generator, Optional, Type
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_global_forced_attn_backend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR
from vllm_mlu._mlu_utils import USE_PAGED
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.attention.selector import _Backend, backend_name_to_enum
from vllm.attention import selector
logger = init_logger(__name__)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add MLU_MLA_FLASH_ATTN for deepseekv2 MLA.
'''
_Backend.MLU_MLA_FLASH_ATTN = enum.auto()
'''
==================
End of MLU Hijack
==================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add an arg use_mla to get_attn_backend, _cached_get_attn_backend,
and which_attn_to_use
'''
'''
==================
End of MLU Hijack
==================
'''
def vllm__attention__selector__get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_blocksparse: bool = False,
use_mla: bool = False,
) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
# private function.
return vllm__attention__selector___cached_get_attn_backend(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
block_size=block_size,
is_attention_free=is_attention_free,
is_blocksparse=is_blocksparse,
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
)
@lru_cache(maxsize=None)
def vllm__attention__selector___cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_blocksparse: bool = False,
use_v1: bool = False,
use_mla: bool = False,
) -> Type[AttentionBackend]:
if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import (
BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend
backend = vllm__attention__selector__which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
is_attention_free, use_v1, use_mla)
if backend == _Backend.FLASH_ATTN:
logger.info("Using Flash Attention backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
if backend == _Backend.FLASH_ATTN_VLLM_V1:
from vllm.v1.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend as FlashAttentionBackendV1)
return FlashAttentionBackendV1
if backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import ( # noqa: F401
XFormersBackend)
return XFormersBackend
elif backend == _Backend.ROCM_FLASH:
logger.info("Using ROCmFlashAttention backend.")
from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401
ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA:
assert current_platform.is_cpu(), RuntimeError(
"Torch SDPA backend is only used for the CPU device.")
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
elif backend == _Backend.OPENVINO:
logger.info("Using OpenVINO Attention backend.")
from vllm.attention.backends.openvino import OpenVINOAttentionBackend
return OpenVINOAttentionBackend
elif backend == _Backend.IPEX:
assert current_platform.is_xpu(), RuntimeError(
"IPEX attention backend is only used for the XPU device.")
logger.info("Using IPEX attention backend.")
from vllm.attention.backends.ipex_attn import IpexAttnBackend
return IpexAttnBackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
elif backend == _Backend.HPU_ATTN:
logger.info("Using HPUAttention backend.")
from vllm.attention.backends.hpu_attn import HPUAttentionBackend
return HPUAttentionBackend
elif backend == _Backend.PALLAS:
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
elif backend == _Backend.MLU_MLA_FLASH_ATTN:
        logger.info("Using MLUMLAFlashAttention backend.")
from vllm_mlu.attention.backends.mlu_attn import MLUMLAFlashAttentionBackend
return MLUMLAFlashAttentionBackend
elif backend == _Backend.MLU_FLASH_ATTN:
logger.info("Using MLUFlashAttention backend.")
from vllm.attention.backends.mlu_attn import MLUFlashAttentionBackend
return MLUFlashAttentionBackend
elif backend == _Backend.NO_ATTENTION:
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionBackend)
return PlaceholderAttentionBackend
else:
raise ValueError("Invalid attention backend.")
def vllm__attention__selector__which_attn_to_use(head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
use_v1: bool = False,
use_mla: bool = False) -> _Backend:
"""Returns which flash attention backend to use."""
# Default case.
selected_backend = _Backend.FLASH_ATTN
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
if is_attention_free:
return _Backend.NO_ATTENTION
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
backend_by_global_setting: Optional[_Backend] = (
get_global_forced_attn_backend())
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the environment variable and override if specified
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if current_platform.is_cpu():
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA
if current_platform.is_openvino():
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO
if current_platform.is_xpu():
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX
if current_platform.is_tpu():
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS
if current_platform.is_mlu():
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add MLU_MLA_FLASH_ATTN for deepseekv2 MLA.
'''
'''
==================
End of MLU Hijack
==================
'''
if use_mla:
return _Backend.MLU_MLA_FLASH_ATTN
if selected_backend != _Backend.MLU_FLASH_ATTN:
logger.debug("Cannot use %s backend on MLU.", selected_backend)
return _Backend.MLU_FLASH_ATTN
if current_platform.is_rocm():
# AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if not current_platform.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH
if current_platform.is_hpu():
return _Backend.HPU_ATTN
if use_v1:
return _Backend.FLASH_ATTN_VLLM_V1
# FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN:
if not current_platform.has_device_capability(80):
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
selected_backend = _Backend.XFORMERS
elif dtype not in (torch.float16, torch.bfloat16):
logger.info(
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
selected_backend = _Backend.XFORMERS
elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
logger.info(
"Cannot use FlashAttention-2 backend for FP8 KV cache.")
logger.warning(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
"VLLM_ATTENTION_BACKEND=FLASHINFER")
selected_backend = _Backend.XFORMERS
elif block_size % 16 != 0:
logger.info(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
selected_backend = _Backend.XFORMERS
# FlashAttn is valid for the model, checking if the package is installed.
if selected_backend == _Backend.FLASH_ATTN:
try:
import vllm.vllm_flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in supported_sizes:
logger.info(
"Cannot use FlashAttention-2 backend for head size %d.",
head_size)
selected_backend = _Backend.XFORMERS
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the "
"vllm.vllm_flash_attn package is not found. "
"Make sure that vllm_flash_attn was built and installed "
"(on by default).")
selected_backend = _Backend.XFORMERS
return selected_backend
MluHijackObject.apply_hijack(selector,
selector.get_attn_backend,
vllm__attention__selector__get_attn_backend)
MluHijackObject.apply_hijack(selector,
selector._cached_get_attn_backend,
vllm__attention__selector___cached_get_attn_backend)
MluHijackObject.apply_hijack(selector,
selector.which_attn_to_use,
vllm__attention__selector__which_attn_to_use)
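# --------------------------------------------------------------------------
# Illustrative note (an assumption, not part of the hijack): after the three
# apply_hijack calls above, code that goes through vLLM's selector should
# resolve the MLU backends, e.g.
#
#   from vllm.attention.selector import get_attn_backend
#   backend_cls = get_attn_backend(head_size=576, dtype=torch.bfloat16,
#                                  kv_cache_dtype=None, block_size=16,
#                                  is_attention_free=False, use_mla=True)
#   # -> MLUMLAFlashAttentionBackend when running on an MLU platform
# --------------------------------------------------------------------------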

View File

@@ -0,0 +1,138 @@
from typing import Tuple
from vllm.logger import init_logger
from vllm.config import ModelConfig, CacheConfig, LoRAConfig
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
vllm__config__LoRAConfig__verify_with_model_config_org = LoRAConfig.verify_with_model_config
def vllm__config__CacheConfig___verify_cache_dtype(self) -> None:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add kv_cache_dtype int8 support
'''
if self.cache_dtype == "auto":
pass
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor")
elif self.cache_dtype == 'int8':
logger.info(
"Using int8 data type to store kv cache. It reduces the MLU "
"memory footprint and boosts the performance. ")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
'''
==================
End of MLU Hijack
==================
'''
def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type == 'deepseek_v2':
# feature flag MLA
return 1
total_num_kv_heads = self.get_total_num_kv_heads()
# If tensor parallelism is used, we divide the number of KV heads by
# the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head.
return max(1,
total_num_kv_heads // parallel_config.tensor_parallel_size)
def vllm__config__ModelConfig__get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
'''
=============================
Modify by vllm_mlu
=============================
@brief: override the deepseek_v2 head size to 576 for MLA (kv_lora_rank + qk_rope_head_dim).
'''
return 576
'''
==================
End of MLU Hijack
==================
'''
if self.is_attention_free:
return 0
if hasattr(self.hf_text_config, "head_dim"):
return self.hf_text_config.head_dim
# FIXME(woosuk): This may not be true for all models.
return (self.hf_text_config.hidden_size //
self.hf_text_config.num_attention_heads)
def vllm__config__ModelConfig__set_context_mlugraph_info(
self, enable_context_mlugraph: bool, batch_size: int, seq_len: int) -> None:
self.enable_context_mlugraph = enable_context_mlugraph
self.context_batch_size_to_capture = batch_size
self.context_seq_len_to_capture = seq_len
def vllm__config__ModelConfig__use_context_mlugraph(self) -> bool:
return hasattr(self, "enable_context_mlugraph") and self.enable_context_mlugraph
def vllm__config__ModelConfig__get_context_mlugraph_bs_and_seq(self) -> Tuple[int, int]:
return self.context_batch_size_to_capture, self.context_seq_len_to_capture
def vllm__config__LoRAConfig__verify_with_model_config(self, model_config: ModelConfig):
'''
=============================
Modify by vllm_mlu
=============================
@brief: do not support quantization with lora for now
'''
if model_config.quantization:
raise ValueError("vllm mlu does not support quantization with lora for now")
'''
==================
End of MLU Hijack
==================
'''
vllm__config__LoRAConfig__verify_with_model_config_org(self, model_config)
@property
def vllm__config__ModelConfig__is_deepseek_v2(self) -> bool:
result = hasattr(
self.hf_text_config,
"model_type") and self.hf_text_config.model_type == 'deepseek_v2'
return result
MluHijackObject.apply_hijack(ModelConfig,
"is_deepseek_v2",
vllm__config__ModelConfig__is_deepseek_v2)
MluHijackObject.apply_hijack(ModelConfig,
"set_context_mlugraph_info",
vllm__config__ModelConfig__set_context_mlugraph_info)
MluHijackObject.apply_hijack(ModelConfig,
"use_context_mlugraph",
vllm__config__ModelConfig__use_context_mlugraph)
MluHijackObject.apply_hijack(ModelConfig,
"get_context_mlugraph_bs_and_seq",
vllm__config__ModelConfig__get_context_mlugraph_bs_and_seq)
MluHijackObject.apply_hijack(CacheConfig,
CacheConfig._verify_cache_dtype,
vllm__config__CacheConfig___verify_cache_dtype)
MluHijackObject.apply_hijack(ModelConfig,
ModelConfig.get_head_size,
vllm__config__ModelConfig__get_head_size)
MluHijackObject.apply_hijack(ModelConfig,
ModelConfig.get_num_kv_heads,
vllm__config__ModelConfig__get_num_kv_heads)
MluHijackObject.apply_hijack(LoRAConfig,
LoRAConfig.verify_with_model_config,
vllm__config__LoRAConfig__verify_with_model_config)
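# --------------------------------------------------------------------------
# Illustrative note (an assumption, not part of the hijack): with the hijacks
# above in place, a deepseek_v2 ModelConfig is expected to report the MLA
# friendly shapes, e.g.
#
#   model_config.get_head_size()                    # -> 576
#   model_config.get_num_kv_heads(parallel_config)  # -> 1
#   model_config.is_deepseek_v2                     # -> True
# --------------------------------------------------------------------------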

View File

@@ -0,0 +1 @@
import vllm_mlu.core.block_manager

View File

@@ -0,0 +1,56 @@
from vllm.sequence import SequenceGroup, SequenceStatus
from vllm_mlu._mlu_utils import USE_PAGED
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.core.block_manager import SelfAttnBlockSpaceManager
from vllm.utils import Device
from vllm.logger import init_logger
logger = init_logger(__name__)
def vllm__core__block_manager__SelfAttnBlockSpaceManager__can_append_slots(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
"""Determine if there is enough space in the GPU KV cache to continue
generation of the specified sequence group.
We use a worst-case heuristic: assume each touched block will require a
new allocation (either via CoW or new block). We can append slots if the
number of touched blocks is less than the number of free blocks.
"Lookahead slots" are slots that are allocated in addition to the slots
for known tokens. The contents of the lookahead slots are not defined.
This is used by speculative decoding when speculating future tokens.
"""
'''
=============================
Modify by vllm_mlu
=============================
@brief: optimize the allocation strategy for unpaged mode
'''
if not USE_PAGED:
return True
else:
num_touched_blocks = 0
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
block_table = self.block_tables[seq.seq_id]
num_touched_blocks += (
block_table.get_num_blocks_touched_by_append_slots(
token_ids=block_table.get_unseen_token_ids(
seq.get_token_ids()),
num_lookahead_slots=num_lookahead_slots,
))
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
Device.GPU)
return num_touched_blocks <= num_free_gpu_blocks
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(SelfAttnBlockSpaceManager,
SelfAttnBlockSpaceManager.can_append_slots,
vllm__core__block_manager__SelfAttnBlockSpaceManager__can_append_slots)

View File

@@ -0,0 +1,328 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import deque
from typing import Deque, List, Optional, Set, Callable
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.scheduler import (SchedulingBudget, SchedulerPrefillOutputs,
SchedulerRunningOutputs, SchedulerOutputs, Scheduler)
from vllm.sequence import SequenceGroup
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
logger = init_logger(__name__)
vllm__core__scheduler__Scheduler____init____org = Scheduler.__init__
vllm__core__scheduler__Scheduler___schedule_prefills__org = Scheduler._schedule_prefills
vllm__core__scheduler__Scheduler___schedule_running__org = Scheduler._schedule_running
vllm__core__scheduler__Scheduler___schedule__org = Scheduler._schedule
def vllm__core__scheduler__Scheduler__init_scheduler_view(self):
logger.info(f"vLLM scheduler profiling start...")
self.df = pd.DataFrame(
data={
'waiting': [], 'running': [], 'swapped': [], 'finished': [],
'wait_to_run_reqs': [], 'run_to_wait_reqs': [], 'wait_to_run_tokens': [],
'batch_utils': [], 'block_utils': [], 'preempt_ratio': []
},
dtype=np.float32
)
self.sched_step = 0
self.running_seqs = 0
self.waiting_seqs = 0
self.swapped_seqs = 0
self.finished_seqs = 0
self.total_seqs = 0
self.running_to_waiting_seqs = 0
self.waiting_to_running_seqs = 0
self.wait_to_run_tokens = 0
self.batch_utils = 0
self.block_utils = 0
self.preempt_ratio = 0
self.finished_seq_groups = []
def summary_finished_seq_groups(seq_groups: List[SequenceGroup]):
df = pd.DataFrame(
data={
'ttft/s': [], 'time_in_queue/s': [], 'context_latency/s': [], 'decoder_latency/s': []
},
dtype=np.float32
)
for seq_group in seq_groups:
ttft = seq_group.metrics.first_token_time - seq_group.metrics.arrival_time
time_in_queue = seq_group.metrics.time_in_queue
context_latency = seq_group.metrics.first_token_time - seq_group.metrics.first_scheduled_time
decoder_latency = seq_group.metrics.finished_time - seq_group.metrics.first_token_time
decoder_token_num = seq_group.get_seqs()[0].get_output_len() - 1
per_token_latency = decoder_latency if decoder_token_num == 0 \
else decoder_latency / decoder_token_num
df_ = pd.DataFrame(
[[ttft, time_in_queue, context_latency, decoder_latency, per_token_latency, decoder_token_num]],
columns=['ttft/s', 'time_in_queue/s', 'context_latency/s', 'decoder_latency/s', 'per_token_latency/s', 'decoder_tokens'],
index=[str(seq_group.request_id)]
)
df = pd.concat([df, df_])
sum_, max_, mean_, min_, p99_ = df.sum(), df.max(), df.mean(), df.min(), df.quantile(0.99)
df.loc['Sum'] = sum_
df.loc['Max'] = max_
df.loc['Mean'] = mean_
df.loc['Min'] = min_
df.loc['P99'] = p99_
return df
def vllm__core__scheduler__Scheduler__save_scheduler_view(self, scheduler_idx=0):
logger.info(f"vLLM scheduler profiling save...")
plt.rcParams.update({'font.size': 8})
figure = plt.figure(figsize=(6.4, 5.6))
gs = figure.add_gridspec(3, hspace=0)
axes = gs.subplots(sharex=True, sharey=False)
figure.suptitle("Cambricon vLLM Scheduler View")
# scheduler queue view
self.df.plot(ax=axes[0], y=['waiting', 'running', 'swapped', 'finished'])
axes[0].set_xlabel('X-LLMEngineStep', loc='left')
axes[0].set_ylabel('Y-ReqNum', loc='top')
# utilization
self.df.plot(ax=axes[1], y=['batch_utils', 'block_utils', 'preempt_ratio'])
axes[1].set_xlabel('X-LLMEngineStep', loc='left')
axes[1].set_ylabel('Y-Utilization(%)', loc='top')
# token view
self.df.plot(ax=axes[2], y=['wait_to_run_tokens'])
axes[2].set_xlabel('X-LLMEngineStep', loc='left')
axes[2].set_ylabel('Y-TokenNum', loc='top')
for ax in axes:
ax.label_outer()
ax.legend(loc='upper right')
figure.tight_layout()
figure.savefig(f"vllm_scheduler{scheduler_idx}_view.svg", dpi=300, format='svg')
plt.close(figure)
time_df = summary_finished_seq_groups(self.finished_seq_groups)
sched_df = self.df.copy(deep=True)
max_, mean_, min_ = sched_df.max(), sched_df.mean(), sched_df.min()
sched_df.loc["Max"] = max_
sched_df.loc["Mean"] = mean_
sched_df.loc["Min"] = min_
with pd.option_context('display.max_rows', None,
'display.max_columns', None,
'display.max_colwidth', None,
'display.float_format', '{:^6,.2f}'.format,
'expand_frame_repr', False):
logger.info(sched_df.loc[["Max", "Mean", "Min"]])
logger.info(time_df.loc[["Sum", "Max", "Mean", "Min", "P99"]])
sched_df.astype(str).to_csv(f"vllm_scheduler{scheduler_idx}_step_view.csv", mode="w")
time_df.astype(str).to_csv(f"vllm_scheduler{scheduler_idx}_reqs_view.csv", mode="w")
def vllm__core__scheduler__Scheduler____init__(
self,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
output_proc_callback: Optional[Callable] = None,
) -> None:
vllm__core__scheduler__Scheduler____init____org(
self=self,
scheduler_config=scheduler_config,
cache_config=cache_config,
lora_config=lora_config,
pipeline_parallel_size=pipeline_parallel_size,
output_proc_callback=output_proc_callback
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add for scheduler profiling
'''
self.init_scheduler_view()
'''
==================
End of MLU Hijack
==================
'''
def vllm__core__scheduler__Scheduler___schedule_prefills(
self,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
) -> SchedulerPrefillOutputs:
prefills = vllm__core__scheduler__Scheduler___schedule_prefills__org(
self=self,
budget=budget,
curr_loras=curr_loras,
enable_chunking=enable_chunking
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add for scheduler profiling
'''
self.waiting_to_running_seqs = len(prefills.seq_groups)
self.wait_to_run_tokens = budget.num_batched_tokens
'''
==================
End of MLU Hijack
==================
'''
return prefills
def vllm__core__scheduler__Scheduler___schedule_running(
self,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
) -> SchedulerRunningOutputs:
running_scheduled = vllm__core__scheduler__Scheduler___schedule_running__org(
self=self,
budget=budget,
curr_loras=curr_loras,
enable_chunking=enable_chunking
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add for scheduler profiling
'''
self.running_to_waiting_seqs += len(running_scheduled.preempted)
'''
==================
End of MLU Hijack
==================
'''
return running_scheduled
def vllm__core__scheduler__Scheduler___schedule(self) -> SchedulerOutputs:
scheduler_outputs = vllm__core__scheduler__Scheduler___schedule__org(self)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add for scheduler profiling
'''
self.sched_step += 1
self.running_seqs = len(self.running)
self.waiting_seqs = len(self.waiting)
self.swapped_seqs = len(self.swapped)
total_seqs_ = self.running_seqs + self.waiting_seqs + self.swapped_seqs + self.finished_seqs
if total_seqs_ == 0:
        return scheduler_outputs
if total_seqs_ > self.total_seqs:
self.total_seqs = total_seqs_
self.batch_utils = self.running_seqs / self.scheduler_config.max_num_seqs
self.block_utils = (self.block_manager.num_total_gpu_blocks -
self.block_manager.get_num_free_gpu_blocks()) / self.block_manager.num_total_gpu_blocks
self.preempt_ratio = self.running_to_waiting_seqs / self.total_seqs
    df_ = pd.DataFrame(
        [[self.waiting_seqs, self.running_seqs, self.swapped_seqs, self.finished_seqs,
          self.waiting_to_running_seqs, self.running_to_waiting_seqs, self.wait_to_run_tokens,
          self.batch_utils, self.block_utils, self.preempt_ratio]],
        columns=['waiting', 'running', 'swapped', 'finished',
                 'wait_to_run_reqs', 'run_to_wait_reqs', 'wait_to_run_tokens',
                 'batch_utils', 'block_utils', 'preempt_ratio'],
        index=[str(self.sched_step)])
self.df = pd.concat([self.df, df_])
'''
==================
End of MLU Hijack
==================
'''
return scheduler_outputs
def vllm__core__scheduler__Scheduler__free_finished_seq_groups(self) -> None:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add for scheduler profiling
'''
finished_seq_groups_ = []
remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running:
self._free_finished_seq_group(seq_group)
if not seq_group.is_finished():
remaining.append(seq_group)
else:
finished_seq_groups_.append(seq_group)
self.finished_seqs += len(finished_seq_groups_)
self.finished_seq_groups += finished_seq_groups_
'''
==================
End of MLU Hijack
==================
'''
self.running = remaining
# Handle async stopped sequence groups
# (ones that reached max model len)
if self._async_stopped:
for seq_group in self._async_stopped:
self._free_seq_group_cross_attn_blocks(seq_group)
self._finished_requests_ids.append(seq_group.request_id)
# Free finished seqs
self._free_finished_seqs(seq_group)
self._async_stopped.clear()
def vllm__core__scheduler__Scheduler____del__(self):
'''
=============================
Modify by vllm_mlu
=============================
@brief: add for scheduler profiling
'''
self.save_scheduler_view()
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(Scheduler,
Scheduler.__init__,
vllm__core__scheduler__Scheduler____init__)
MluHijackObject.apply_hijack(Scheduler,
Scheduler._schedule_prefills,
vllm__core__scheduler__Scheduler___schedule_prefills)
MluHijackObject.apply_hijack(Scheduler,
Scheduler._schedule_running,
vllm__core__scheduler__Scheduler___schedule_running)
MluHijackObject.apply_hijack(Scheduler,
Scheduler._schedule,
vllm__core__scheduler__Scheduler___schedule)
MluHijackObject.apply_hijack(Scheduler,
Scheduler.free_finished_seq_groups,
vllm__core__scheduler__Scheduler__free_finished_seq_groups)
MluHijackObject.apply_hijack(Scheduler,
"__del__",
vllm__core__scheduler__Scheduler____del__)
MluHijackObject.apply_hijack(Scheduler,
"init_scheduler_view",
vllm__core__scheduler__Scheduler__init_scheduler_view)
MluHijackObject.apply_hijack(Scheduler,
"save_scheduler_view",
vllm__core__scheduler__Scheduler__save_scheduler_view)
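# --------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the hijack): with the scheduler
# profiling hijacks installed, each scheduler accumulates one row per
# _schedule() call and, on save_scheduler_view(idx), writes
#   vllm_scheduler{idx}_view.svg        # queue / utilization / token plots
#   vllm_scheduler{idx}_step_view.csv   # per-step scheduler statistics
#   vllm_scheduler{idx}_reqs_view.csv   # per-request latency summary
# into the current working directory.
# --------------------------------------------------------------------------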

View File

@@ -0,0 +1 @@
import vllm_mlu.distributed.parallel_state

View File

@@ -0,0 +1,134 @@
import torch
from contextlib import contextmanager, nullcontext
from torch.distributed import Backend
from typing import Any, Dict, List, Optional, Tuple, Union
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed.parallel_state import (GroupCoordinator,
GraphCaptureContext)
from vllm.logger import init_logger
logger = init_logger(__name__)
vllm__distributed__parallel_state__GroupCoordinator____init____org = GroupCoordinator.__init__
def vllm__distributed__parallel_state__GroupCoordinator____init__(
self,
group_ranks: List[List[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
use_pynccl: bool,
use_custom_allreduce: bool,
use_tpu_communicator: bool,
use_hpu_communicator: bool,
use_xpu_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
'''
=============================
Modify by vllm_mlu
=============================
@brief: disable pynccl and custom_allreduce by default
'''
if use_pynccl or use_custom_allreduce:
        logger.debug("Disable pynccl and custom_allreduce when using MLU backend.")
vllm__distributed__parallel_state__GroupCoordinator____init____org(
self=self,
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=torch_distributed_backend,
use_pynccl=False,
use_custom_allreduce=False,
use_tpu_communicator=use_tpu_communicator,
use_hpu_communicator=use_hpu_communicator,
use_xpu_communicator=use_xpu_communicator,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name
)
'''
==================
End of MLU Hijack
==================
'''
def vllm__distributed__parallel_state__GroupCoordinator__gather(
self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
if self.xpu_communicator is not None and \
not self.xpu_communicator.disabled:
return self.xpu_communicator.gather(input_, self.rank_in_group,
dst, dim)
'''
=============================
Modify by vllm_mlu
=============================
@brief: use p2p communication to reduce gather host time for the non-driver workers.
NOTE: this hijack function should be REMOVED once torch is upgraded to 2.4.0.
'''
rank = self.rank_in_group
gather_list = None
if rank == dst:
gather_list = [
torch.empty_like(input_) for _ in range(dst)
] + [input_] + [
torch.empty_like(input_) for _ in range(dst + 1, world_size)
]
send_recv_op_list = []
if rank != dst:
op = torch.distributed.P2POp(torch.distributed.isend,
input_,
self.ranks[dst],
group=self.device_group)
send_recv_op_list.append(op)
else:
for r in range(0, world_size):
if r == dst:
continue
op = torch.distributed.P2POp(torch.distributed.irecv,
gather_list[r],
self.ranks[r],
group=self.device_group)
send_recv_op_list.append(op)
reqs = torch.distributed.batch_isend_irecv(send_recv_op_list)
for req in reqs:
req.wait()
'''
==================
End of MLU Hijack
==================
'''
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
MluHijackObject.apply_hijack(GroupCoordinator,
GroupCoordinator.__init__,
vllm__distributed__parallel_state__GroupCoordinator____init__)
MluHijackObject.apply_hijack(GroupCoordinator,
GroupCoordinator.gather,
vllm__distributed__parallel_state__GroupCoordinator__gather)
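# --------------------------------------------------------------------------
# Illustrative note (an assumption, not part of the hijack): the P2P based
# gather above is intended to behave like torch.distributed.gather, e.g. on
# a 2-rank group
#
#   out = group.gather(local_tensor, dst=0, dim=0)
#   # rank 0: tensor concatenated along dim 0 from both ranks
#   # rank 1: None
# --------------------------------------------------------------------------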

View File

@@ -0,0 +1,409 @@
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import get_is_gated, ModelConfig
import ctypes
import json
from vllm.transformers_utils.config import get_config
from vllm.entrypoints.llm import LLM
from vllm_mlu._mlu_utils import VLLM_DUMP_CPU_INFO_EN, VLLM_DUMP_MLU_INFO_EN
logger = init_logger(__name__)
def get_deepseek_v2_flops(bcfg, batch, seq_len, hidden_size):
ATTN_PAD_SIZE = 192
qk_nope_head_dim = bcfg.qk_nope_head_dim
qk_rope_head_dim = bcfg.qk_rope_head_dim
v_head_dim = bcfg.v_head_dim
q_lora_rank = bcfg.q_lora_rank
kv_lora_rank = bcfg.kv_lora_rank
context_atn_pre = 2 * batch * seq_len * \
(hidden_size * q_lora_rank + \
hidden_size * (kv_lora_rank + qk_rope_head_dim) + \
q_lora_rank * bcfg.head_num * (qk_nope_head_dim + qk_rope_head_dim) + \
kv_lora_rank * bcfg.head_num * (qk_nope_head_dim + v_head_dim))
context_atn_qk = 2 * batch * seq_len * seq_len * bcfg.head_num * ATTN_PAD_SIZE
context_atn_qkv = 2 * batch * seq_len * seq_len * bcfg.head_num * ATTN_PAD_SIZE
context_atn_post = 2 * batch * seq_len * bcfg.head_num * v_head_dim * hidden_size
return context_atn_pre, context_atn_qk, context_atn_qkv, context_atn_post
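# --------------------------------------------------------------------------
# Worked example (illustrative numbers, not a measured configuration): for
# batch=1, seq_len=2, head_num=2 the attention score term above evaluates to
#   context_atn_qk = 2 * 1 * 2 * 2 * 2 * 192 = 3072 FLOPs,
# i.e. the QK^T and PV matmuls scale quadratically with the sequence length
# and linearly with the padded head size ATTN_PAD_SIZE = 192.
# --------------------------------------------------------------------------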
class FlopsInfo(ctypes.Structure):
_fields_ = [("context_flops", ctypes.c_double),
("decoder_flops", ctypes.c_double)]
class LLMDumpInfo:
def __init__(self,
tensor_parallel_size=None,
dtype=None, kv_cache_dtype=None,
quantization=None,
model=None, batch_size=None,
input_len=None,
output_len=None,
trust_remote_code=None)->None:
self.so_file = None
self.dev_info = None
self.cpu_info = None
self.lib = None
self.hfu_info = None
self.flops_info = None
self.ctypes_model_config = ModelConfig()
self.io_efficiency = 0
self.context_latency_device = 0
self.generate_latency_device = 0
self.tensor_parallel_size = tensor_parallel_size
self.dtype = dtype
self.kv_cache_dtype = kv_cache_dtype
self.quantization = quantization
self.batch_size = batch_size
self.input_len = input_len
self.output_len = output_len
self.model = model
self.model_config = None
try:
from vllm_mlu.device_info import get_info_inner
self.so_file,self.dev_info,self.cpu_info,self.lib = get_info_inner(self.so_file, self.dev_info, self.cpu_info, self.lib)
except:
logger.info("Cannot get device info")
def init_param(self,
tensor_parallel_size=None,
dtype=None,
kv_cache_dtype=None,
quantization=None,
model=None,
batch_size=None,
input_len=None,
output_len=None,
trust_remote_code=None,
context_latency_device=None,
generate_latency_device=None):
if tensor_parallel_size != None:
self.tensor_parallel_size = tensor_parallel_size
if dtype != None:
self.dtype = dtype
if kv_cache_dtype != None:
self.kv_cache_dtype = kv_cache_dtype
if quantization != None:
self.quantization = quantization
if model != None:
self.model = model
if batch_size != None:
self.batch_size = batch_size
if input_len != None:
self.input_len = input_len
if output_len != None:
self.output_len = output_len
if trust_remote_code != None:
self.trust_remote_code = trust_remote_code
if context_latency_device != None:
self.context_latency_device = context_latency_device
if generate_latency_device != None:
self.generate_latency_device = generate_latency_device
        # parse the model config
if self.model_config == None and self.model != None and self.trust_remote_code != None:
self.model_config = get_config(self.model, self.trust_remote_code)
def initialize_ctypes_model_config(self, model_cfg, tp_num, weight_dtype, kv_cache_dtype, quantization):
# prepare input
self.ctypes_model_config.hidden_size = model_cfg.hidden_size
self.ctypes_model_config.vocab_size = model_cfg.vocab_size
self.ctypes_model_config.cla_coeffient = 1.0
possible_keys_ffn_size = [
# chatglm3-6b-32k
"ffn_hidden_size",
# llama3-8b-hf
"intermediate_size",
]
possible_kv_heads = [
# chatglm3-6b-32k
"multi_query_group_num",
# llama3-8b-hf
"num_key_value_heads",
# falcon-180B-chat
"num_kv_heads",
]
possible_num_attention_heads = [
"num_attention_heads",
"n_heads",
]
moe_size=None
ffn_size=None
if getattr(model_cfg, "moe_intermediate_size", None):
moe_size = getattr(model_cfg, "moe_intermediate_size", None)
for key in possible_keys_ffn_size:
ffn_size = getattr(model_cfg, key, None)
if ffn_size is not None:
break
if model_cfg.model_type in ['bloom'] and ffn_size is None:
ffn_size = model_cfg.hidden_size * 4
if model_cfg.model_type in ['qwen']:
ffn_size = model_cfg.intermediate_size // 2
if ffn_size is None and moe_size is None:
            logger.warning("The model's config.json does not contain any of the following "
                           "keys to determine the ffn_size or moe_size: "
                           f"{possible_keys_ffn_size}. ")
for key in possible_num_attention_heads:
num_attention_heads = getattr(model_cfg, key, None)
if num_attention_heads is not None:
break
if num_attention_heads is None:
            logger.error("The model's config.json does not contain any of the following "
                         "keys to determine the num_attention_heads: "
                         f"{possible_num_attention_heads}. ")
for key in possible_kv_heads:
kv_heads = getattr(model_cfg, key, None)
if kv_heads is not None:
break
if kv_heads is None:
            logger.warning("The model's config.json does not contain any of the following "
                           "keys to determine the kv_heads: "
                           f"{possible_kv_heads}, using num_attention_heads instead")
            kv_heads = num_attention_heads
self.ctypes_model_config.ffn_inner_size = 0 if ffn_size is None else ffn_size
self.ctypes_model_config.moe_inner_size = 0 if moe_size is None else moe_size
self.ctypes_model_config.moe_layer_num = 0 if moe_size is None else model_cfg.num_hidden_layers
self.ctypes_model_config.layer_num = model_cfg.num_hidden_layers
self.ctypes_model_config.head_num = num_attention_heads
self.ctypes_model_config.head_size = self.ctypes_model_config.hidden_size / self.ctypes_model_config.head_num
self.ctypes_model_config.head_num_kv = kv_heads
self.ctypes_model_config.tp_num = tp_num
if hasattr(model_cfg, "shared_expert_intermediate_size") and model_cfg.shared_expert_intermediate_size is not None:
self.ctypes_model_config.shared_expert_intermediate_size = model_cfg.shared_expert_intermediate_size
else:
self.ctypes_model_config.shared_expert_intermediate_size = 0
self.ctypes_model_config.use_gated_ffn = get_is_gated()
if hasattr(model_cfg, "n_shared_experts") and model_cfg.n_shared_experts is not None:
self.ctypes_model_config.shared_expert_intermediate_size = model_cfg.n_shared_experts * moe_size
else:
self.ctypes_model_config.shared_experts = 0
if hasattr(model_cfg, "num_experts") and model_cfg.num_experts is not None:
self.ctypes_model_config.experts_num = model_cfg.num_experts
if model_cfg.model_type == 'hunyuan':
self.ctypes_model_config.topk_num = model_cfg.moe_topk
else:
self.ctypes_model_config.topk_num = model_cfg.num_experts_per_tok
elif hasattr(model_cfg, "num_local_experts"):
self.ctypes_model_config.experts_num = model_cfg.num_local_experts
if model_cfg.model_type == 'hunyuan':
self.ctypes_model_config.topk_num = model_cfg.moe_topk
else:
self.ctypes_model_config.topk_num = model_cfg.num_experts_per_tok
elif hasattr(model_cfg, "n_routed_experts"):
self.ctypes_model_config.experts_num = model_cfg.n_routed_experts
if model_cfg.model_type == 'hunyuan':
self.ctypes_model_config.topk_num = model_cfg.moe_topk
else:
self.ctypes_model_config.topk_num = model_cfg.num_experts_per_tok
else:
self.ctypes_model_config.experts_num = 0
if hasattr(model_cfg, "model_type") and model_cfg.model_type is not None:
self.ctypes_model_config.model_type = model_cfg.model_type.encode('utf-8')
# when adding a moe model, need fix moe/ffn info, like
# moe_inner_size, ffn_inner_size, moe_layer_num, shared_expert_intermediate_size.
# add for mixtral
if model_cfg.model_type == "mixtral":
self.ctypes_model_config.moe_inner_size = ffn_size
self.ctypes_model_config.ffn_inner_size = 0
self.ctypes_model_config.moe_layer_num = model_cfg.num_hidden_layers
# add for deepseek-v2
if model_cfg.model_type == "deepseek_v2":
if hasattr(model_cfg, "first_k_dense_replace") and model_cfg.first_k_dense_replace is not None:
self.ctypes_model_config.moe_layer_num = model_cfg.num_hidden_layers - model_cfg.first_k_dense_replace
if hasattr(model_cfg, "qk_nope_head_dim") and model_cfg.qk_nope_head_dim is not None:
self.ctypes_model_config.qk_nope_head_dim = model_cfg.qk_nope_head_dim
if hasattr(model_cfg, "qk_rope_head_dim") and model_cfg.qk_rope_head_dim is not None:
self.ctypes_model_config.qk_rope_head_dim = model_cfg.qk_rope_head_dim
if hasattr(model_cfg, "v_head_dim") and model_cfg.v_head_dim is not None:
self.ctypes_model_config.v_head_dim = model_cfg.v_head_dim
if hasattr(model_cfg, "q_lora_rank") and model_cfg.q_lora_rank is not None:
self.ctypes_model_config.q_lora_rank = model_cfg.q_lora_rank
else:
self.ctypes_model_config.q_lora_rank = 0
if hasattr(model_cfg, "kv_lora_rank") and model_cfg.kv_lora_rank is not None:
self.ctypes_model_config.kv_lora_rank = model_cfg.kv_lora_rank
# add for Hunyuan
if model_cfg.model_type == "hunyuan":
            self.ctypes_model_config.cla_coeffient = 0.5  # the hunyuan model uses CLA2
if hasattr(model_cfg, "num_shared_expert") and model_cfg.num_shared_expert is not None:
self.ctypes_model_config.shared_expert_intermediate_size = model_cfg.num_shared_expert * model_cfg.intermediate_size
if not self.ctypes_model_config.moe_inner_size and model_cfg.intermediate_size is not None:
self.ctypes_model_config.moe_inner_size = model_cfg.intermediate_size
if not self.ctypes_model_config.moe_layer_num and hasattr(model_cfg, "num_experts"):
self.ctypes_model_config.moe_layer_num = model_cfg.num_hidden_layers
        self.ctypes_model_config.use_causal_mask = True  # flash attention in vLLM only uses a causal mask
if weight_dtype == "auto":
self.ctypes_model_config.data_type = b'float16'
else:
self.ctypes_model_config.data_type = weight_dtype.encode('utf-8')
if quantization != None:
with open(self.model + "/quantize_config.json", 'r') as file:
config = json.load(file)
if config["quant_mode"] == "SmoothQuant":
self.ctypes_model_config.smooth_quant_type = b"SmoothQuant"
else:
self.ctypes_model_config.smooth_quant_type = b'invalid'
self.ctypes_model_config.filter_data_type = ("int" + str(config['bits'])).encode('utf-8')
else:
self.ctypes_model_config.smooth_quant_type = b'invalid'
self.ctypes_model_config.filter_data_type = self.ctypes_model_config.data_type
if kv_cache_dtype == "auto":
self.ctypes_model_config.kv_cache_dtype = self.ctypes_model_config.data_type
else:
self.ctypes_model_config.kv_cache_dtype = kv_cache_dtype.encode('utf-8')
def get_flops(self, bcfg, once_batch, input_seq_len, output_length, flops_info):
self.batch_size = once_batch
seq_len = input_seq_len
hidden_size = bcfg.hidden_size
voc_size = bcfg.vocab_size
ffn_size = bcfg.ffn_inner_size
moe_size = bcfg.moe_inner_size
shared_expert_intermediate_size = bcfg.shared_expert_intermediate_size
layer_num = bcfg.layer_num
out_seq = output_length
seq_len_decode = seq_len + out_seq / 2
r = bcfg.head_num / bcfg.head_num_kv
bsh2 = self.batch_size * seq_len * hidden_size * hidden_size
cla_coeffient = bcfg.cla_coeffient
if bcfg.model_type == b'deepseek_v2':
context_atn_pre, context_atn_qk, context_atn_qkv, context_atn_post = (
get_deepseek_v2_flops(bcfg, self.batch_size, seq_len, hidden_size)
)
else:
context_atn_pre = 2 * bsh2 + 4 * bsh2 / r * cla_coeffient
context_atn_qk = 2 * self.batch_size * seq_len * seq_len * hidden_size
context_atn_qkv = 2 * self.batch_size * seq_len * seq_len * hidden_size
context_atn_post = 2 * self.batch_size * seq_len * hidden_size * hidden_size
context_lm_head = 2 * self.batch_size * seq_len * hidden_size * voc_size
context_ffn = 0
bh2 = self.batch_size * hidden_size * hidden_size
decode_atn_pre = 2 * bh2 + 4 * bh2 / r * cla_coeffient
decode_atn_qk = 2 * self.batch_size * seq_len_decode * hidden_size
decode_atn_qkv = 2 * self.batch_size * seq_len_decode * hidden_size
decode_atn_post = 2 * self.batch_size * hidden_size * hidden_size
decode_lm_head = 2 * self.batch_size * hidden_size * voc_size
decode_ffn = 0
coeffient = 6 if bcfg.use_gated_ffn else 4
if bcfg.experts_num == 0:
context_ffn = coeffient * self.batch_size * seq_len * hidden_size * ffn_size
decode_ffn = coeffient * self.batch_size * hidden_size * ffn_size
else:
context_ffn = self.batch_size * seq_len * hidden_size * (coeffient * (moe_size * bcfg.topk_num + shared_expert_intermediate_size) + 2 * bcfg.experts_num)
decode_ffn = self.batch_size * hidden_size * (coeffient * (moe_size * bcfg.topk_num + shared_expert_intermediate_size) + 2 * bcfg.experts_num)
if bcfg.use_causal_mask:
c = 0.5
context_atn_qk *= c
context_atn_qkv *= c
flops_info.context_flops = context_lm_head
flops_info.decoder_flops = decode_lm_head
if bcfg.kv_cache_dtype != b"int8":
flops_info.context_flops += (layer_num * (context_atn_qk + context_atn_qkv))
flops_info.decoder_flops += (layer_num * (decode_atn_qk + decode_atn_qkv))
else:
flops_info.context_flops += (layer_num * (context_atn_qk + context_atn_qkv))
flops_info.decoder_flops += (layer_num * (decode_atn_qk + decode_atn_qkv))
if bcfg.smooth_quant_type == b"invalid":
flops_info.context_flops += (layer_num * (context_atn_pre + context_atn_post + context_ffn))
flops_info.decoder_flops += (layer_num * (decode_atn_pre + decode_atn_post + decode_ffn))
else:
flops_info.context_flops += (layer_num * (context_atn_pre + context_atn_post + context_ffn))
flops_info.decoder_flops += (layer_num * (decode_atn_pre + decode_atn_post + decode_ffn))
def capture_cpu_info(self):
if VLLM_DUMP_CPU_INFO_EN and self.cpu_info:
try:
from vllm_mlu.device_info import capture_cpu_info
self.cpu_info = capture_cpu_info(self.cpu_info, my_rank=0)
except:
logger.info("Unsupport capture_cpu_info function")
def memory_usage(self):
if VLLM_DUMP_CPU_INFO_EN and self.cpu_info:
try:
from vllm_mlu.device_info import memory_usage
self.cpu_info = memory_usage(self.cpu_info)
except:
logger.info("Unsupport memory_usage function")
def analyze_perf_data(self, rank=0):
try:
from vllm_mlu.device_info import analyze_perf_data
analyze_perf_data(self.cpu_info, self.lib)
except:
logger.info("Cannot analyze perf data, no analyze_perf_data function")
def get_decoder_io_efficiency(self, ctypes_model_config, lib, batch_size, input_len, output_len, generate_latency_device):
try:
from vllm_mlu.device_info import get_decoder_io_efficiency
self.io_efficiency = get_decoder_io_efficiency(ctypes_model_config, lib, batch_size, input_len, output_len, generate_latency_device)
except:
logger.info("Unsupport io_efficiency get_decoder_io_efficiency function")
def get_device_output_info(self,
model_cfg,
batch_size,
input_seq_len,
output_length,
tp_num,
weight_dtype,
kv_cache_dtype,
quantization):
self.initialize_ctypes_model_config(model_cfg, tp_num, weight_dtype, kv_cache_dtype, quantization)
if VLLM_DUMP_CPU_INFO_EN and self.so_file:
self.analyze_perf_data()
if VLLM_DUMP_MLU_INFO_EN and self.lib:
from vllm_mlu.device_info import get_flops_inner, HFUInfo
self.hfu_info = HFUInfo()
get_flops_inner(self.ctypes_model_config, batch_size, input_seq_len, output_length, tp_num, self.hfu_info, self.lib)
self.get_decoder_io_efficiency(self.ctypes_model_config,
self.lib,
self.batch_size,
self.input_len,
self.output_len,
self.generate_latency_device)
else:
self.flops_info = FlopsInfo()
self.get_flops(self.ctypes_model_config, batch_size, input_seq_len, output_length, self.flops_info)
def has_information_dump(self):
if self.dev_info and self.dev_info.so_file:
return True
return False
def dump(self):
self.get_device_output_info(self.model_config,
self.batch_size,
self.input_len,
self.output_len,
self.tensor_parallel_size,
self.dtype,
self.kv_cache_dtype,
self.quantization)
try:
from vllm_mlu.device_info import dump
dump(LLM.dump_info)
except:
logger.info("Unsupport dump device/cpu information")
def dump_performance_info(self):
try:
from vllm_mlu.device_info import dump_information
dump_information(LLM.dump_info)
except:
logger.info("Unsupport dump performance information")

View File

@@ -0,0 +1,2 @@
import vllm_mlu.engine.arg_utils
import vllm_mlu.engine.llm_engine

View File

@@ -0,0 +1,120 @@
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
from vllm_mlu._mlu_utils import (BlockSizeInfo, USE_PAGED, get_device_name)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
logger = init_logger(__name__)
vllm__engine__arg_utils__EngineArgs__create_model_config_org = EngineArgs.create_model_config
vllm__engine__arg_utils__EngineArgs__create_engine_config_org = EngineArgs.create_engine_config
vllm__engine__arg_utils__EngineArgs__add_cli_args_org = EngineArgs.add_cli_args
vllm_engine__arg_utils__EngineArgs____post_init__org = EngineArgs.__post_init__
def vllm_engine__arg_utils__EngineArgs____post_init__(self,) -> None:
'''
=============================
Add by vllm_mlu
=============================
@brief: 1. On MLU3XX devices, when tensor_parallel_size > 1, enforce_eager is forced to True.
        2. For unpaged mode, set default block_size=2048.
'''
unsupport_graph_device = "3" in get_device_name()
if unsupport_graph_device and self.tensor_parallel_size > 1 and self.enforce_eager != True:
self.enforce_eager = True
        logger.warning("The current device only supports eager mode when tensor_parallel_size > 1. "
                       "The param enforce_eager is forced to True.")
if not USE_PAGED and self.block_size == 16:
self.block_size = 2048
'''
==================
End of MLU Hijack
==================
'''
vllm_engine__arg_utils__EngineArgs____post_init__org(self)
def vllm__engine__arg_utils__EngineArgs__create_model_config(self) -> ModelConfig:
model_config = vllm__engine__arg_utils__EngineArgs__create_model_config_org(self)
'''
=============================
Modify by vllm_mlu
=============================
@brief: set context mlugraph info for model config
'''
model_config.set_context_mlugraph_info(
getattr(self, "enable_context_mlugraph", False),
getattr(self, "context_batch_size_to_capture", None),
getattr(self, "context_seq_len_to_capture", None))
'''
==================
End of MLU Hijack
==================
'''
return model_config
def vllm__engine__arg_utils__EngineArgs__create_engine_config(self) -> VllmConfig:
'''
=============================
Modify by vllm_mlu
=============================
@brief: disable custom_all_reduce, re-set block_size to support paged and unpaged mode.
'''
# MLU not support custom all reduce
self.disable_custom_all_reduce = True
BlockSizeInfo.set_block_size(self.block_size)
if not USE_PAGED and self.enable_chunked_prefill:
raise ValueError("Not support chunked_prefill in unpaged mode.")
engine_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org(self)
engine_config.cache_config.block_size = BlockSizeInfo.BLOCK_SIZE
'''
==================
End of MLU Hijack
==================
'''
return engine_config
@staticmethod
def vllm__engine__arg_utils__EngineArgs__add_cli_args(
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser = vllm__engine__arg_utils__EngineArgs__add_cli_args_org(parser)
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1. remove the block_size choices and set its default value to -1
        2. add 'int8' to the kv_cache_dtype choices
'''
for action in parser._actions:
if action.dest == "block_size":
action.choices = None
action.default = -1
elif action.dest == "kv_cache_dtype":
action.choices += ['int8']
'''
==================
End of MLU Hijack
==================
'''
return parser
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs.__post_init__,
vllm_engine__arg_utils__EngineArgs____post_init__)
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs.create_model_config,
vllm__engine__arg_utils__EngineArgs__create_model_config)
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs.create_engine_config,
vllm__engine__arg_utils__EngineArgs__create_engine_config)
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs.add_cli_args,
vllm__engine__arg_utils__EngineArgs__add_cli_args)
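# --------------------------------------------------------------------------
# Usage sketch (hypothetical command line; the flags come from vLLM's
# EngineArgs, only the values are illustrative): after the CLI hijack above,
# --block-size accepts any value (default -1) and --kv-cache-dtype
# additionally accepts 'int8', e.g.
#
#   python -m vllm.entrypoints.openai.api_server \
#       --model /path/to/model --kv-cache-dtype int8 --tensor-parallel-size 4
# --------------------------------------------------------------------------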

View File

@@ -0,0 +1,35 @@
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
logger = init_logger(__name__)
# for client init/reset server scheduler profile data
async def vllm__engine__async_llm_engine__AsyncLLMEngine__init_scheduler_view(self):
for scheduler in self.engine.scheduler:
if hasattr(scheduler, "init_scheduler_view"):
scheduler.init_scheduler_view()
else:
            logger.warning("Cannot find any scheduler view, " +
                           "please 'export VLLM_SCHEDULER_PROFILE=true' first.")
# for client pulling server scheduler profile data
async def vllm__engine__async_llm_engine__AsyncLLMEngine__save_scheduler_view(self):
for idx, scheduler in enumerate(self.engine.scheduler):
if hasattr(scheduler, "save_scheduler_view"):
scheduler.save_scheduler_view(idx)
else:
            logger.warning("Cannot find any scheduler view, " +
                           "please 'export VLLM_SCHEDULER_PROFILE=true' first.")
MluHijackObject.apply_hijack(AsyncLLMEngine,
"init_scheduler_view",
vllm__engine__async_llm_engine__AsyncLLMEngine__init_scheduler_view)
MluHijackObject.apply_hijack(AsyncLLMEngine,
"save_scheduler_view",
vllm__engine__async_llm_engine__AsyncLLMEngine__save_scheduler_view)

View File

@@ -0,0 +1,209 @@
import time
from typing import Optional, List, Union, Mapping
from vllm.engine.llm_engine import LLMEngine
from vllm_mlu._mlu_utils import USE_PAGED, BlockSizeInfo
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.sampling_params import SamplingParams
from vllm.lora.request import LoRARequest
from vllm.logger import init_logger
from vllm.inputs import PromptType
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import deprecate_kwargs
logger = init_logger(__name__)
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def vllm_engine__llm_engine__LLMEngine__add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
scheduler as `engine.step()` is called. The exact scheduling policy is
determined by the scheduler.
Args:
request_id: The unique ID of the request.
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
:class:`~vllm.PoolingParams` for pooling.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Details:
- Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None.
- Create `n` number of :class:`~vllm.Sequence` objects.
- Create a :class:`~vllm.SequenceGroup` object
from the list of :class:`~vllm.Sequence`.
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
Example:
>>> # initialize engine
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> # set request arguments
>>> example_prompt = "Who is the president of the United States?"
>>> sampling_params = SamplingParams(temperature=0.0)
>>> request_id = 0
>>>
>>> # add the request to the engine
>>> engine.add_request(
>>> str(request_id),
>>> example_prompt,
>>> SamplingParams(temperature=0.0))
>>> # continue the request processing
>>> ...
"""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if priority != 0 and not self.scheduler_config.policy == "priority":
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
if isinstance(params, SamplingParams) \
and (params.guided_decoding or params.logits_processors) \
and self.scheduler_config.num_scheduler_steps > 1:
raise ValueError(
"Guided decoding and logits processors are not supported "
"in multi-step decoding")
if arrival_time is None:
arrival_time = time.time()
if self.tokenizer is not None:
self._validate_token_prompt(
prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request))
preprocessed_inputs = self.input_preprocessor.preprocess(
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
processed_inputs = self.input_processor(preprocessed_inputs)
'''
=============================
Added by vllm_mlu
=============================
@brief: check input_len + output_len <= block_size
'''
def check_block_size_valid(input_len, output_len):
if BlockSizeInfo.BLOCK_SIZE < input_len + output_len:
            raise ValueError(f"BLOCK_SIZE({BlockSizeInfo.BLOCK_SIZE}) can't be smaller than " +
                             f"input_len({input_len}) + output_len({output_len}).")
if isinstance(params, SamplingParams):
output_len = params.max_tokens
# Check for 'prompt_token_ids' in different levels of processed_inputs
if not USE_PAGED:
for key in ['prompt_token_ids', 'encoder', 'decoder']:
if key in processed_inputs:
if key == 'prompt_token_ids':
input_len = len(processed_inputs[key])
elif isinstance(processed_inputs[key], dict) and 'prompt_token_ids' in processed_inputs[key]:
input_len = len(processed_inputs[key]['prompt_token_ids'])
else:
continue
check_block_size_valid(input_len, output_len)
'''
==================
End of modification
==================
'''
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
)
def vllm__engine__llm_engine__LLMEngine__get_latency(self):
latency = self.model_executor.get_latency()
return latency
def vllm__engine__llm_engine__LLMEngine__get_memory_usage(self):
return self.model_executor.get_memory_usage()
def vllm__engine__llm_engine__LLMEngine__get_block_usage(self):
    assert len(self.scheduler) == 1, "Only pipeline_parallel_size=1 is supported."
num_free_gpu_blocks = self.scheduler[0].block_manager.get_num_free_gpu_blocks()
num_free_cpu_blocks = self.scheduler[0].block_manager.get_num_free_cpu_blocks()
return (num_free_gpu_blocks, num_free_cpu_blocks)
# for client init/reset server scheduler profile data
def vllm__engine__llm_engine__LLMEngine__init_scheduler_view(self):
for scheduler in self.scheduler:
if hasattr(scheduler, "init_scheduler_view"):
scheduler.init_scheduler_view()
else:
            logger.warning("Cannot find any scheduler view, " +
                           "please 'export VLLM_SCHEDULER_PROFILE=true' first.")
# for client pulling server scheduler profile data
def vllm__engine__llm_engine__LLMEngine__save_scheduler_view(self):
for idx, scheduler in enumerate(self.scheduler):
if hasattr(scheduler, "save_scheduler_view"):
scheduler.save_scheduler_view(idx)
else:
            logger.warning("Cannot find any scheduler view, " +
                           "please 'export VLLM_SCHEDULER_PROFILE=true' first.")
MluHijackObject.apply_hijack(LLMEngine,
"init_scheduler_view",
vllm__engine__llm_engine__LLMEngine__init_scheduler_view)
MluHijackObject.apply_hijack(LLMEngine,
"save_scheduler_view",
vllm__engine__llm_engine__LLMEngine__save_scheduler_view)
MluHijackObject.apply_hijack(LLMEngine,
LLMEngine.add_request,
vllm_engine__llm_engine__LLMEngine__add_request)
MluHijackObject.apply_hijack(LLMEngine,
"get_latency",
vllm__engine__llm_engine__LLMEngine__get_latency)
MluHijackObject.apply_hijack(LLMEngine,
"get_memory_usage",
vllm__engine__llm_engine__LLMEngine__get_memory_usage)
MluHijackObject.apply_hijack(LLMEngine,
"get_block_usage",
vllm__engine__llm_engine__LLMEngine__get_block_usage)
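# Illustrative usage sketch (assumption: `engine` is an LLMEngine created after the hijacks
# above are applied; variable names are only for illustration).
#
#   peak_mem, block_mem, n_gpu_blocks, n_cpu_blocks = engine.get_memory_usage()
#   free_gpu_blocks, free_cpu_blocks = engine.get_block_usage()   # pipeline_parallel_size must be 1
#   last_step_latency = engine.get_latency()                      # forwarded to the model executor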

View File

@@ -0,0 +1,6 @@
from enum import Enum
class RPCSchedulerProfileRequest(Enum):
INIT_SCHEDULER_VIEW = 1
SAVE_SCHEDULER_VIEW = 2

View File

@@ -0,0 +1,32 @@
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.logger import init_logger
from vllm_mlu.engine.multiprocessing import RPCSchedulerProfileRequest
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
class MQLLMEngineClient_V2(MQLLMEngineClient):
async def init_scheduler_view(self):
"""Send INIT_SCHEDULER_VIEW request to RPC Server."""
await self._send_one_way_rpc_request(
request=RPCSchedulerProfileRequest.INIT_SCHEDULER_VIEW,
socket=self.input_socket)
async def save_scheduler_view(self):
"""Send SAVE_SCHEDULER_VIEW request to RPC Server."""
await self._send_one_way_rpc_request(
request=RPCSchedulerProfileRequest.SAVE_SCHEDULER_VIEW,
socket=self.input_socket)
MluHijackObject.apply_hijack(MQLLMEngineClient,
"init_scheduler_view",
MQLLMEngineClient_V2.init_scheduler_view)
MluHijackObject.apply_hijack(MQLLMEngineClient,
"save_scheduler_view",
MQLLMEngineClient_V2.save_scheduler_view)
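# Illustrative async usage sketch (assumption: `client` is an MQLLMEngineClient connected to a
# running engine and VLLM_SCHEDULER_PROFILE is enabled on the server side).
#
#   async def profile_scheduler(client: MQLLMEngineClient):
#       await client.init_scheduler_view()     # reset the server-side profile data
#       ...                                    # submit requests as usual
#       await client.save_scheduler_view()     # ask every scheduler to dump its view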

View File

@@ -0,0 +1,183 @@
import pickle
from typing import Iterator, List, Optional, Union
import cloudpickle
import zmq
from vllm import SamplingParams
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (RPCAbortRequest, RPCProcessRequest,
RPCUProfileRequest)
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (IPC_DATA_EXT, IPC_HEALTH_EXT,
IPC_INPUT_EXT, IPC_OUTPUT_EXT,
RPCAbortRequest, RPCProcessRequest,
RPCUProfileRequest)
from vllm.engine.multiprocessing.engine import (MQLLMEngine,
POLLING_TIMEOUT_MS)
from vllm.logger import init_logger
from vllm_mlu.engine.multiprocessing import RPCSchedulerProfileRequest
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
vllm__engine__multiprocessing__engine__MQLLMEngine____init____org = MQLLMEngine.__init__
class MQLLMEngine_V2(MQLLMEngine):
def __init__(self,
ipc_path: str,
use_async_sockets: bool,
*args,
log_requests: bool = True,
**kwargs) -> None:
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and sent over the socket, which frees
# the python object to be reused again.
kwargs['use_cached_outputs'] = True
self.engine = LLMEngine(*args, **kwargs)
self.log_requests = log_requests
self.use_async_sockets = use_async_sockets
if self.use_async_sockets:
self.engine.process_request_outputs_callback = \
self._async_socket_engine_callback
self.ctx = zmq.Context() # type: ignore[attr-defined]
# Receive input from the client.
self.input_socket = self.ctx.socket(zmq.constants.PULL)
self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")
# Send output stream back to client.
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
# Send heartbeats back to client.
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
# Error state.
self._errored_with: Optional[BaseException] = None
self.collect_scheduler_view = False
def run_engine_loop(self):
"""Core busy loop of the LLMEngine."""
while True:
if not self.engine.has_unfinished_requests():
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
# When there's no work, check on engine health and send
# health status back to client
self._health_check()
self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.")
# Handle any input from the client.
self.handle_new_input()
'''
=============================
Added by vllm_mlu
=============================
@brief: support scheduler view
'''
if self.collect_scheduler_view:
self.collect_scheduler_view = False
continue
'''
==================
End of MLU Hijack
==================
'''
# Engine step.
request_outputs = self.engine_step()
# Send request outputs (if async, done in engine_step callback).
if not self.use_async_sockets:
self._send_outputs(request_outputs)
def handle_new_input(self):
"""Handle new input from the socket"""
try:
while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)
'''
=============================
Added by vllm_mlu
=============================
@brief: support scheduler view
'''
if isinstance(request, RPCProcessRequest):
if len(frames) > 1:
# Use cloudpickle for logits processors
assert isinstance(request.params, SamplingParams)
lprocs = cloudpickle.loads(frames[1].buffer)
request.params.logits_processors = lprocs
self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
elif isinstance(request, RPCUProfileRequest):
if request == RPCUProfileRequest.START_PROFILE:
self.start_profile()
else:
self.stop_profile()
elif isinstance(request, RPCSchedulerProfileRequest):
self.collect_scheduler_view = True
if request == RPCSchedulerProfileRequest.INIT_SCHEDULER_VIEW:
self.init_scheduler_view()
elif request == RPCSchedulerProfileRequest.SAVE_SCHEDULER_VIEW:
self.save_scheduler_view()
else:
raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")
'''
==================
End of MLU Hijack
==================
'''
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
raise e
def init_scheduler_view(self):
"""Init scheduler view."""
self.engine.init_scheduler_view()
def save_scheduler_view(self):
"""Save scheduler view."""
self.engine.save_scheduler_view()
MluHijackObject.apply_hijack(MQLLMEngine,
MQLLMEngine.__init__,
MQLLMEngine_V2.__init__)
MluHijackObject.apply_hijack(MQLLMEngine,
MQLLMEngine.run_engine_loop,
MQLLMEngine_V2.run_engine_loop)
MluHijackObject.apply_hijack(MQLLMEngine,
MQLLMEngine.handle_new_input,
MQLLMEngine_V2.handle_new_input)
MluHijackObject.apply_hijack(MQLLMEngine,
"init_scheduler_view",
MQLLMEngine_V2.init_scheduler_view)
MluHijackObject.apply_hijack(MQLLMEngine,
"save_scheduler_view",
MQLLMEngine_V2.save_scheduler_view)
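# Illustrative request-flow sketch (summary of the code above, not an additional code path):
#
#   client.save_scheduler_view()
#     -> RPCSchedulerProfileRequest.SAVE_SCHEDULER_VIEW is pickled onto the input socket
#     -> MQLLMEngine.handle_new_input() sets collect_scheduler_view and calls save_scheduler_view()
#     -> LLMEngine.save_scheduler_view() dumps the view of each scheduler
#   run_engine_loop() then skips one engine_step() because collect_scheduler_view was set.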

View File

@@ -0,0 +1 @@
import vllm_mlu.entrypoints.llm

View File

@@ -0,0 +1,313 @@
import time
from tqdm import tqdm
from typing import Optional, List, Union, Dict, Any
from vllm.entrypoints.llm import LLM
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption)
from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_args
from vllm_mlu._mlu_utils import VLLM_LATENCY_DEBUG_EN, VLLM_LATENCY_DEBUG_WITH_DEVICE_EN
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.mlu_metric import LLMMetric
from vllm_mlu.dump_info import LLMDumpInfo
from vllm.logger import init_logger
logger = init_logger(__name__)
@deprecate_args(
start_index=2, # Ignore self and model
is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
additional_message=(
"All positional arguments other than `model` will be "
"replaced with keyword arguments in an upcoming version."),
)
def vllm__entrypoints__llm__LLM____init__(
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
override_pooler_config: Optional[PoolerConfig] = None,
**kwargs,
) -> None:
'''
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False.
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1) Initialize LLMDumpInfo
2) Initialize context mlugraph params
'''
LLM.dump_info.init_param(
tensor_parallel_size=tensor_parallel_size, dtype=dtype,
kv_cache_dtype=kwargs.get('kv_cache_dtype', 'default_value'),
quantization=quantization,
model=model, trust_remote_code=kwargs.get('trust_remote_code', 'default_value'))
enable_context_mlugraph = kwargs.pop("enable_context_mlugraph", False)
context_batch_size_to_capture = kwargs.pop("context_batch_size_to_capture", None)
context_seq_len_to_capture = kwargs.pop("context_seq_len_to_capture", None)
'''
==================
End of MLU Hijack
==================
'''
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
engine_args = EngineArgs(
model=model,
task=task,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config,
**kwargs,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: set context mlugraph params for EngineArgs
'''
setattr(engine_args, "enable_context_mlugraph", enable_context_mlugraph)
setattr(engine_args, "context_batch_size_to_capture", context_batch_size_to_capture)
setattr(engine_args, "context_seq_len_to_capture", context_seq_len_to_capture)
'''
==================
End of MLU Hijack
==================
'''
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
self.engine_class = self.get_engine_class()
# TODO(rob): enable mp by default (issue with fork vs spawn)
self.llm_engine = self.engine_class.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()
'''
=============================
Modify by vllm_mlu
=============================
@brief: Get Cpuinfo member for vllm
'''
LLM.dump_info.memory_usage()
'''
==================
End of MLU Hijack
==================
'''
def vllm__entrypoints__llm__LLM__get_metrics(
self,
metrics_idx_start,
only_average,
input_len,
output_len,
tp_nums,
quantization,
dump_info=None,
show_per_iter=False,
) -> None:
'''
@brief: Print the performance metrics collected while vLLM runs the generate interface.
@params:
metrics_idx_start: If some generate calls are only warmup runs, set this parameter to skip the
data collected in [0, metrics_idx_start). Defaults to 0, i.e. all performance data is counted.
only_average: True prints only the average performance over the N generate calls; False prints the
performance of every generate call as well as the average. If the N measurements fluctuate heavily,
check whether the test environment is stable.
Other parameters: model configuration parameters.
'''
if VLLM_LATENCY_DEBUG_EN:
self.metric.calc_metric(self.llm_engine.model_config.model,
self.llm_engine.model_config.dtype,
metrics_idx_start, only_average,
input_len, output_len, tp_nums,
quantization, dump_info, show_per_iter)
else:
print("Warnning:please set VLLM_LATENCY_DEBUG=true!")
def vllm__entrypoints__llm__LLM___run_engine(
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(
total=num_requests,
desc="Processed prompts",
dynamic_ncols=True,
postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"),
)
'''
=============================
Added by vllm_mlu
=============================
'''
is_latency_debug = VLLM_LATENCY_DEBUG_EN
# Record start
if is_latency_debug:
total_request_num = self.llm_engine.get_num_unfinished_requests()
self.dump_info.capture_cpu_info()
peak_memory, block_memory, num_total_gpu_blocks, num_total_cpu_blocks = \
self.llm_engine.get_memory_usage()
self.metric.update_memory_usage(peak_memory, block_memory, num_total_gpu_blocks, num_total_cpu_blocks)
e2e_start_time = self.metric.get_mlu_cost_time()
'''
==================
End of addition
==================
'''
# Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():
'''
=============================
Added by vllm_mlu
=============================
'''
if is_latency_debug:
self.dump_info.memory_usage()
start_time = self.metric.get_mlu_cost_time()
'''
==================
End of addition
==================
'''
step_outputs = self.llm_engine.step()
'''
=============================
Added by vllm_mlu
=============================
'''
if is_latency_debug:
end_time = self.metric.get_mlu_cost_time()
step_latency = end_time - start_time
if len(step_outputs) > 0:
batch_size = len(step_outputs)
assert batch_size == total_request_num, \
f"LLM has received {total_request_num} requests, but only processed {batch_size} requests in the current step.\n" + \
f"If you are running benchmark_latency test, please check if the input is correct.\n" + \
f"Otherwise, please set env VLLM_LATENCY_DEBUG=false, then run test again.\n"
num_free_gpu_blocks, num_free_cpu_blocks = self.llm_engine.get_block_usage()
self.metric.update_step_block_usage(num_free_gpu_blocks, num_free_cpu_blocks)
self.metric.update_step_latency(step_latency)
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
self.metric.update_step_latency_device(self.llm_engine.get_latency())
self.dump_info.memory_usage()
'''
==================
End of addition
==================
'''
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs)
out_spd = (total_out_toks /
pbar.format_dict["elapsed"])
pbar.postfix = (
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1)
'''
=============================
Added by vllm_mlu
=============================
'''
if is_latency_debug:
e2e_end_time = self.metric.get_mlu_cost_time()
e2e_latency = e2e_end_time - e2e_start_time
self.metric.add_metrics(batch_size, e2e_latency)
'''
==================
End of addition
==================
'''
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# their previous requests.
return sorted(outputs, key=lambda x: int(x.request_id))
LLM.metric = LLMMetric()
LLM.dump_info = LLMDumpInfo()
MluHijackObject.apply_hijack(LLM,
LLM.__init__,
vllm__entrypoints__llm__LLM____init__)
MluHijackObject.apply_hijack(LLM,
"get_metrics",
vllm__entrypoints__llm__LLM__get_metrics)
MluHijackObject.apply_hijack(LLM,
LLM._run_engine,
vllm__entrypoints__llm__LLM___run_engine)
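# Illustrative usage sketch (assumption: the model path, prompts and sizes below are placeholders,
# and VLLM_LATENCY_DEBUG must be exported before the process starts so _mlu_utils picks it up).
#
#   llm = LLM(model="/path/to/model", tensor_parallel_size=1)
#   outputs = llm.generate(prompts, sampling_params)        # first call may be a warmup run
#   llm.get_metrics(metrics_idx_start=1,                    # skip the warmup measurements
#                   only_average=True,
#                   input_len=128, output_len=128,
#                   tp_nums=1, quantization=None)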

View File

@@ -0,0 +1 @@
import vllm_mlu.entrypoints.openai.serving_engine

View File

@@ -0,0 +1,49 @@
from http import HTTPStatus
from typing import Optional
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.serving_engine import OpenAIServing, AnyRequest
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
logger = init_logger(__name__)
async def vllm__entrypoints__openai__serving_engine__OpenAIServing___check_model(
self,
request: AnyRequest,
) -> Optional[ErrorResponse]:
if self._is_model_supported(request.model):
return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
return None
if request.model in [
prompt_adapter.prompt_adapter_name
for prompt_adapter in self.prompt_adapter_requests
]:
return None
'''
=============================
Modify by vllm_mlu
=============================
@brief: when the client sends a request with model=init/save_scheduler_view,
the scheduler will dump its profile data.
'''
if request.model == "init_scheduler_view":
await self.engine_client.init_scheduler_view()
if request.model == "save_scheduler_view":
await self.engine_client.save_scheduler_view()
'''
==================
End of MLU Hijack
==================
'''
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
MluHijackObject.apply_hijack(OpenAIServing,
OpenAIServing._check_model,
vllm__entrypoints__openai__serving_engine__OpenAIServing___check_model)
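# Illustrative client sketch (assumption: an OpenAI-compatible server is listening on
# localhost:8000; the endpoint path is the standard completions route, nothing added here).
#
#   curl http://localhost:8000/v1/completions \
#        -H "Content-Type: application/json" \
#        -d '{"model": "save_scheduler_view", "prompt": "x"}'
#
# The response is a NotFoundError, but save_scheduler_view() has already been awaited by then.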

View File

@@ -0,0 +1,3 @@
import vllm_mlu.executor.mlu_executor
import vllm_mlu.executor.multiproc_mlu_executor
import vllm_mlu.executor.ray_mlu_executor

View File

@@ -0,0 +1,35 @@
from vllm.executor.mlu_executor import MLUExecutor
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__executor__mlu_executor__MLUExecutor__get_latency(self) -> float:
'''
requires that torch.mlu.synchronize() be executed before this function
for getting an accurate reading
'''
latency = self.driver_worker.get_latency()
return latency
def vllm__executor__mlu_executor__MLUExecutor__recapture_model(
self,
context_batch_size_to_capture,
context_seq_len_to_capture
) -> None:
return self.driver_worker.recapture_model(context_batch_size_to_capture,
context_seq_len_to_capture)
def vllm__executor__mlu_executor__MLUExecutor__get_memory_usage(self):
return self.driver_worker.get_memory_usage()
MluHijackObject.apply_hijack(MLUExecutor,
"get_latency",
vllm__executor__mlu_executor__MLUExecutor__get_latency)
MluHijackObject.apply_hijack(MLUExecutor,
"recapture_model",
vllm__executor__mlu_executor__MLUExecutor__recapture_model)
MluHijackObject.apply_hijack(MLUExecutor,
"get_memory_usage",
vllm__executor__mlu_executor__MLUExecutor__get_memory_usage)
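# Illustrative usage sketch (assumption: `executor` is the MLUExecutor owned by a running engine).
#
#   torch.mlu.synchronize()                  # required before get_latency() for an accurate reading
#   step_latency = executor.get_latency()
#   peak_mem, block_mem, n_gpu_blocks, n_cpu_blocks = executor.get_memory_usage()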

View File

@@ -0,0 +1,16 @@
from vllm.executor.multiproc_mlu_executor import MultiprocessingMLUExecutor
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__executor__multiproc_mlu_executor__MultiprocessingMLUExecutor__recapture_model(
self,
context_batch_size_to_capture,
context_seq_len_to_capture
) -> None:
return self._run_workers("recapture_model",
context_batch_size_to_capture=context_batch_size_to_capture,
context_seq_len_to_capture=context_seq_len_to_capture)
MluHijackObject.apply_hijack(MultiprocessingMLUExecutor,
"recapture_model",
vllm__executor__multiproc_mlu_executor__MultiprocessingMLUExecutor__recapture_model)

View File

@@ -0,0 +1,267 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional
import vllm.envs as envs
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id)
from vllm_mlu._mlu_utils import VLLM_LATENCY_DEBUG, VLLM_LATENCY_DEBUG_NO_DEVICE
from vllm.executor.ray_mlu_executor import RayMLUExecutor
from vllm_mlu.mlu_hijack_utils import MluHijackObject
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
def vllm__executor__ray_mlu_executor__RayMLUExecutor___init_workers_ray(
self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if (self.parallel_config.tensor_parallel_size == 1
and self.parallel_config.pipeline_parallel_size == 1):
# For single GPU case, we use a ray worker with constrained memory.
num_gpus = self.cache_config.gpu_memory_utilization
else:
# Otherwise, the ray workers are allocated with a full GPU.
num_gpus = 1
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
# Used in ray compiled DAG: indexed first by PP rank,
# and then TP rank. In other words, the inner list is
# the TP group of workers for a PP rank.
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
if self.use_ray_spmd_worker:
self.workers.append(worker)
else:
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
**worker_wrapper_kwargs)
else:
# Else, added to the list of workers.
self.workers.append(worker)
logger.debug("workers: %s", self.workers)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")
worker_ips = [
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
for worker in self.workers
]
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
def sort_by_driver_then_worker_ip(worker):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the worker is on a node with a smaller IP address, it
should be placed first.
"""
ip = ray.get(worker.get_node_ip.remote())
return (ip != driver_ip, ip_counts[ip], ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True)
node_workers = defaultdict(list) # node id -> list of worker ranks
node_gpus = defaultdict(list) # node id -> list of gpu ids
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
# string sorting is not sufficient.
# see https://github.com/vllm-project/vllm/issues/5590
gpu_ids = [int(x) for x in gpu_ids]
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
all_ips = set(worker_ips + [driver_ip])
n_ips = len(all_ips)
n_nodes = len(node_workers)
if n_nodes != n_ips:
raise RuntimeError(
f"Every node should have a unique IP address. Got {n_nodes}"
f" nodes with node ids {list(node_workers.keys())} and "
f"{n_ips} unique IP addresses {all_ips}. Please check your"
" network configuration. If you set `VLLM_HOST_IP` or "
"`HOST_IP` environment variable, make sure it is unique for"
" each node.")
VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"MLU_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id])),
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
**({
"VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
} if envs.VLLM_ATTENTION_BACKEND is not None else {}),
"VLLM_LATENCY_DEBUG":
'1' if VLLM_LATENCY_DEBUG else '0',
"VLLM_LATENCY_DEBUG_NO_DEVICE":
'1' if VLLM_LATENCY_DEBUG_NO_DEVICE else '0',
}, ) for (node_id, _) in worker_node_and_gpu_ids]
self._env_vars_for_all_workers = (
all_args_to_update_environment_variables)
self._run_workers("update_environment_variables",
all_args=self._get_env_vars_to_be_updated())
if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip = "127.0.0.1"
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs = [
self._get_worker_kwargs(
local_rank=node_workers[node_id].index(rank),
rank=rank,
distributed_init_method=distributed_init_method,
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
if self.use_ray_spmd_worker:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(
self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (pp_rank * self.parallel_config.tensor_parallel_size
) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[RayWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: List[RayWorkerWrapper] = []
# Enforce rank order for correct rank to return final output.
for index, worker in enumerate(self.workers):
# The driver worker is rank 0 and not in self.workers.
rank = index + 1
if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)
def vllm__executor__ray_mlu_executor__RayMLUExecutor__get_latency(self):
'''
requires that torch.mlu.synchronize() be executed before this function
for getting an accurate reading
'''
return self.driver_worker.execute_method("get_latency")
def vllm__executor__ray_mlu_executor__RayMLUExecutor__recapture_model(
self,
context_batch_size_to_capture,
context_seq_len_to_capture
) -> None:
return self._run_workers("recapture_model",
context_batch_size_to_capture=context_batch_size_to_capture,
context_seq_len_to_capture=context_seq_len_to_capture)
def vllm__executor__ray_mlu_executor__RayMLUExecutor__get_memory_usage(self):
return self.driver_worker.execute_method("get_memory_usage")
MluHijackObject.apply_hijack(RayMLUExecutor,
RayMLUExecutor._init_workers_ray,
vllm__executor__ray_mlu_executor__RayMLUExecutor___init_workers_ray)
MluHijackObject.apply_hijack(RayMLUExecutor,
"get_latency",
vllm__executor__ray_mlu_executor__RayMLUExecutor__get_latency)
MluHijackObject.apply_hijack(RayMLUExecutor,
"recapture_model",
vllm__executor__ray_mlu_executor__RayMLUExecutor__recapture_model)
MluHijackObject.apply_hijack(RayMLUExecutor,
"get_memory_usage",
vllm__executor__ray_mlu_executor__RayMLUExecutor__get_memory_usage)

View File

@@ -0,0 +1,4 @@
import vllm_mlu.lora.ops
import vllm_mlu.lora.fully_sharded_layers
import vllm_mlu.lora.layers
import vllm_mlu.lora.punica

View File

@@ -0,0 +1,65 @@
from typing import Optional
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.lora.fully_sharded_layers import RowParallelLinearWithShardedLoRA
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__lora__fully_sharded_layers__RowParallelLinearWithShardedLoRA__apply(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
residual: Optional[torch.Tensor]
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual and bias in matmul
'''
output = self.base_layer.quant_method.apply(
self.base_layer, x, bias, residual)
'''
==================
End of MLU Hijack
==================
'''
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros(
(x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device,
)
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
buffer = tensor_model_parallel_all_reduce(buffer)
# following S-LoRA, allows the fusing of all_gather and all_reduce
# by adding the column partitioned lora output to a slice of output
# tensor, which is a partial sum due to row parallel. All that
# remains is a standard all_reduce. User should be aware though that
# the output is not the same as a normal row_parallel, it should be
# reduced before being used
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
if self.bias_stacked is not None:
bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
bias = bias[self.punica_wrapper.token_lora_indices]
bias[self.punica_wrapper.token_lora_indices == -1] = 0
output += bias
self.punica_wrapper.add_expand_slice(output, buffer,
self.lora_b_stacked, start_idx,
shard_size)
output = output.view(*out_orig_shape)
return output
MluHijackObject.apply_hijack(RowParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA.apply,
vllm__lora__fully_sharded_layers__RowParallelLinearWithShardedLoRA__apply)

View File

@@ -0,0 +1,219 @@
from typing import List, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
split_tensor_along_last_dim,
tensor_model_parallel_all_reduce)
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
RowParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
apply_bias)
from vllm_mlu.model_executor.layers.rotary_embedding import (
MLURotaryEmbedding, MLULinearScalingRotaryEmbedding)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
vllm__lora__layers__ColumnParallelLinearWithLoRA__forward_org = ColumnParallelLinearWithLoRA.forward
'''
=============================
Modify by vllm_mlu
=============================
@brief: add smooth_quant_scale parameter.
'''
def vllm__lora__layers__ColumnParallelLinearWithLoRA__forward(
self,
input_,
smooth_quant_scale: Optional[torch.Tensor] = None
):
assert smooth_quant_scale is None, "LoRA does not support smooth quant yet."
return vllm__lora__layers__ColumnParallelLinearWithLoRA__forward_org(self, input_)
'''
==================
End of MLU Hijack
==================
'''
def vllm__lora__layers__RowParallelLinearWithLoRA__apply(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
residual: Optional[torch.Tensor]
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual and bias in matmul
'''
output = self.base_layer.quant_method.apply(
self.base_layer, x, bias, residual)
'''
==================
End of MLU Hijack
==================
'''
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
return output
def vllm__lora__layers__RowParallelLinearWithLoRA__forward(
self,
input_: torch.Tensor,
residual: Optional[torch.Tensor] = None
):
# Set up backprop all-reduce.
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
# TODO: simplify code below
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1) apply residual fusion in matmul like RowParallelLinear
2) add bias in matmul, not after all reduce
'''
# Matrix multiply.
bias_ = (None if (self.base_layer.tp_rank > 0 or self.base_layer.skip_bias_add) else self.base_layer.bias)
residual_ = None if self.base_layer.tp_rank > 0 else residual
output_parallel = self.apply(input_parallel, bias_, residual_)
'''
==================
End of MLU Hijack
==================
'''
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
'''
=============================
Modify by vllm_mlu
=============================
@brief: do not add bias after all_reduce
'''
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
'''
==================
End of MLU Hijack
==================
'''
return output, output_bias
def vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
scaling_factors = (list(lora_config.long_lora_scaling_factors)
if lora_config.long_lora_scaling_factors else [])
'''
=============================
Modify by vllm_mlu
=============================
@brief: change LinearScalingRotaryEmbedding to MLULinearScalingRotaryEmbedding
'''
base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
self.base_layer, MLULinearScalingRotaryEmbedding) else 1.0)
scaling_factors = sorted(
list(set([base_scaling_factor] + scaling_factors)))
self.base_layer = MLULinearScalingRotaryEmbedding(
self.base_layer.head_size,
self.base_layer.rotary_dim,
self.base_layer.max_position_embeddings,
self.base_layer.base,
self.base_layer.is_neox_style,
scaling_factors,
self.base_layer.dtype,
)
'''
==================
End of MLU Hijack
==================
'''
def vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__forward(
self,
positions: torch.Tensor,
qk: torch.Tensor
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: change function prototype to meet forward_mlu in rope
'''
return self.base_layer(
positions,
qk,
offsets=self.punica_wrapper.long_lora_indices,
)
'''
==================
End of MLU Hijack
==================
'''
@classmethod
def vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
'''
=============================
Modify by vllm_mlu
=============================
@brief: change origin rope type to mlu rope
'''
return (type(source_layer) is MLULinearScalingRotaryEmbedding
or type(source_layer) is MLURotaryEmbedding)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(RowParallelLinearWithLoRA,
RowParallelLinearWithLoRA.apply,
vllm__lora__layers__RowParallelLinearWithLoRA__apply)
MluHijackObject.apply_hijack(ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithLoRA.forward,
vllm__lora__layers__ColumnParallelLinearWithLoRA__forward)
MluHijackObject.apply_hijack(RowParallelLinearWithLoRA,
RowParallelLinearWithLoRA.forward,
vllm__lora__layers__RowParallelLinearWithLoRA__forward)
MluHijackObject.apply_hijack(LinearScalingRotaryEmbeddingWithLora,
LinearScalingRotaryEmbeddingWithLora.create_lora_weights,
vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__create_lora_weights)
MluHijackObject.apply_hijack(LinearScalingRotaryEmbeddingWithLora,
LinearScalingRotaryEmbeddingWithLora.forward,
vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__forward)
MluHijackObject.apply_hijack(LinearScalingRotaryEmbeddingWithLora,
LinearScalingRotaryEmbeddingWithLora.can_replace_layer,
vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__can_replace_layer)

View File

@@ -0,0 +1,3 @@
import vllm_mlu.lora.ops.sgmv_expand
import vllm_mlu.lora.ops.sgmv_expand_slice
import vllm_mlu.lora.ops.sgmv_shrink

View File

@@ -0,0 +1,233 @@
import torch
import triton
import triton.language as tl
from vllm_mlu.lora.ops.utils import adjust_kernel_block_size
@triton.jit
def _sgmv_expand_kernel_mlu(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
xm_stride,
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
The sgmv's expand triton kernel is based on GroupGEMM.
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = tl.arange(0, BLOCK_K)
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
a_ptr = (input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride +
offset_k[None, :] * xk_stride, )
b_ptr = (lora_ptr + l0_stride * lora_index +
offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride)
'''
==================
End of MLU Hijack
==================
'''
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K)):
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
if EVEN_K:
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
else:
tiled_a = tl.load(a_ptr,
mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)),
other=0)
tiled_b = tl.load(b_ptr,
mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)),
other=0)
'''
==================
End of MLU Hijack
==================
'''
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * xk_stride
b_ptr += BLOCK_K * lora_n_stride
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N)
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@torch.inference_mode()
def sgmv_expand_mlu(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora_b weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
'''
=============================
Modify by vllm_mlu
=============================
@brief: Workaround: Adjust block size to meet mlu restrictions.
The grid size of an MLU Triton kernel must be less than 65536; it would go out of bounds when
the input sequence is very long, causing a runtime error, so we adjust the block
size to avoid this.
'''
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
'''
==================
End of MLU Hijack
==================
'''
BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
batches,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: call _sgmv_expand_kernel_mlu
'''
_sgmv_expand_kernel_mlu[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
)
'''
==================
End of MLU Hijack
==================
'''
return
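# Illustrative shape sketch (assumption: batches=2 with sequence lengths [4, 6], LoRA rank 16
# and hidden size 4096; the values only show how the arguments relate to each other).
#
#   inputs:               (10, 16)              # token_nums x rank
#   lora_b_weights:       (num_loras, 4096, 16) # or (num_loras, 1, 4096, 16) before the squeeze
#   output_tensor:        (10, 4096)            # token_nums x hidden_size
#   b_seq_start_loc:      (2,)                  # cumulative start offset of each sequence
#   seq_len_tensor:       (2,)                  # [4, 6]
#   lora_indices_tensor:  (2,)                  # -1 means "no LoRA" for that sequence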

View File

@@ -0,0 +1,244 @@
import torch
import triton
import triton.language as tl
from vllm_mlu.lora.ops.utils import adjust_kernel_block_size
@triton.jit
def _sgmv_expand_slice_kernel_mlu(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
xm_stride,
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
slice_offset,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
Similar to the 'sgmv_expand' operator, but with an added parameter
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
might be that in the future, we could implement a fusion operator to
achieve the current functionality instead of having to call it multiple
times.
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = tl.arange(0, BLOCK_K)
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
a_ptr = (input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride +
offset_k[None, :] * xk_stride, )
b_ptr = (lora_ptr + l0_stride * lora_index +
offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride)
'''
==================
End of MLU Hijack
==================
'''
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K)):
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
if EVEN_K:
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
else:
tiled_a = tl.load(a_ptr,
mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)),
other=0)
tiled_b = tl.load(b_ptr,
mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)),
other=0)
'''
==================
End of MLU Hijack
==================
'''
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * xk_stride
b_ptr += BLOCK_K * lora_n_stride
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
(slice_offset + N))
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@torch.inference_mode()
def sgmv_expand_slice_mlu(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
) -> None:
"""_summary_
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora_b weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert slice_size == lora_b_weights.size(-2)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
'''
=============================
Modify by vllm_mlu
=============================
@brief: Workaround: Adjust block size to meet mlu restrictions.
The grid size of an MLU Triton kernel must be less than 65536; it would go out of bounds when
the input sequence is very long, causing a runtime error, so we adjust the block
size to avoid this.
'''
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
'''
==================
End of MLU Hijack
==================
'''
BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
batches,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: call _sgmv_expand_kernel_mlu
'''
_sgmv_expand_slice_kernel_mlu[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
slice_offset,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
)
'''
==================
End of MLU Hijack
==================
'''
return
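# Illustrative sketch (assumption: q/k/v LoRA outputs packed into one output tensor with
# per-slice sizes (4096, 1024, 1024); this only shows what slice_offset / slice_size mean).
#
#   sgmv_expand_slice_mlu(..., slice_offset=0,    slice_size=4096, add_inputs=True)  # q slice
#   sgmv_expand_slice_mlu(..., slice_offset=4096, slice_size=1024, add_inputs=True)  # k slice
#   sgmv_expand_slice_mlu(..., slice_offset=5120, slice_size=1024, add_inputs=True)  # v slice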

View File

@@ -0,0 +1,226 @@
import torch
import triton
import triton.language as tl
from vllm_mlu.lora.ops.utils import adjust_kernel_block_size
@triton.jit
def _sgmv_shrink_kernel_mlu(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
scaling,
xm_stride, # hidden_size
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
):
"""
The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
introducing SPLIT-K can improve performance
"""
pid = tl.program_id(axis=0)
pid_sk = tl.program_id(axis=1)
cur_batch = tl.program_id(axis=2)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
a_ptr = (input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride +
offset_k[None, :] * xk_stride)
b_ptr = (lora_ptr + l0_stride * lora_index + offset_n[None, :] * lora_k_stride +
offset_k[:, None] * lora_n_stride)
'''
==================
End of MLU Hijack
==================
'''
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
if EVEN_K:
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
tiled_a = tl.load(a_ptr,
mask=((offset_k[None, :] < k_remaining) & (offset_m[:, None] < M)),
other=0.0)
tiled_b = tl.load(b_ptr,
mask=((offset_k[:, None] < k_remaining) & (offset_n[None, :] < N)),
other=0.0)
'''
==================
End of MLU Hijack
==================
'''
accumulator += tl.dot(tiled_a, tiled_b)
a_ptr += BLOCK_K * SPLIT_K * xk_stride
b_ptr += BLOCK_K * SPLIT_K * lora_n_stride
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N)
accumulator *= scaling
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
@torch.inference_mode()
def sgmv_shrink_mlu(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_a_weights (torch.Tensor): lora_a weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
scaling (float): Scaling factor.
"""
assert inputs.dtype == lora_a_weights.dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
assert lora_a_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_a_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous()
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
assert lora_a_weights.size(1) == 1
lora_a_weights = lora_a_weights.squeeze(dim=1)
else:
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
assert lora_a_weights.is_contiguous()
assert output_tensor.is_contiguous()
# TODO tuning this config
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
'''
=============================
Modify by vllm_mlu
=============================
@brief: Workaround: adjust block size to meet mlu restrictions.
The grid size of an MLU Triton kernel must be less than 65536; it would go out of bounds when
the input sequence is very long, causing a runtime error, so we adjust the block
size to avoid this.
'''
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 16)
'''
==================
End of MLU Hijack
==================
'''
BLOCK_K = 32
SPLIT_K = 8
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
SPLIT_K,
batches,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: call _sgmv_shrink_kernel_mlu
'''
_sgmv_shrink_kernel_mlu[grid](
inputs,
lora_a_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_a_weights.stride(0),
lora_a_weights.stride(1),
lora_a_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
)
'''
==================
End of MLU Hijack
==================
'''
return
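# Illustrative grid sketch (assumption: rank N = 64, max_seq_length = 8192, batches = 4):
#
#   BLOCK_M, BLOCK_N = adjust_kernel_block_size(8192, 32, 64, 16)        # -> (32, 16)
#   grid = (triton.cdiv(8192, 32) * triton.cdiv(64, 16), SPLIT_K, 4)     # = (1024, 8, 4)
#   # the first grid dimension (1024) stays well under the 65536 limit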

View File

@@ -0,0 +1,38 @@
from typing import Tuple
from math import ceil
_MLU_MAX_GRID_SIZE = 65536
def adjust_kernel_block_size(
m: int,
block_m: int,
n: int,
block_n: int
) -> Tuple[int, int]:
"""Adjust block size to meet mlu triton grid restrictions.
Calculation of the max block size in candidates list:
LLama3.1-8b-tp1 max n is 14336
LLama3.1-70b-tp4 max n is 7168
LLama3.1-405b-tp8 max n is 6656
when n is 14336, the max sequence length of block size 256 can be
floor(65536 / ceil(14336 / 256)) * 256 = 299520.
"""
candidates_list = [16, 32, 64, 96, 128, 192, 256]
candidates_list_len = len(candidates_list)
m_idx = 1
n_idx = 0 if block_n == 16 else 1
while m_idx < candidates_list_len and n_idx < candidates_list_len:
block_m = candidates_list[m_idx]
block_n = candidates_list[n_idx]
if ceil(m / block_m) * ceil(n / block_n) < _MLU_MAX_GRID_SIZE:
break
if m_idx < candidates_list_len:
m_idx += 1
if n_idx < candidates_list_len:
n_idx += 1
if ceil(m / block_m) * ceil(n / block_n) >= _MLU_MAX_GRID_SIZE:
raise ValueError(f"the max seq len {m} is too long for lora triton kernel")
return block_m, block_n
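# Illustrative sanity check (assumption: the 14336 case comes from the Llama-3.1-8B figure in the
# docstring above; this mirrors how the function walks the candidate list).
#
#   adjust_kernel_block_size(1024, 32, 4096, 32)      # -> (32, 32): ceil(1024/32) * ceil(4096/32) < 65536
#   adjust_kernel_block_size(299520, 32, 14336, 32)   # -> (256, 256): 1170 * 56 = 65520 < 65536
#   adjust_kernel_block_size(300000, 32, 14336, 32)   # raises ValueError: even (256, 256) gives 65632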

View File

@@ -0,0 +1,115 @@
from typing import Optional
import torch
from vllm.lora.punica import PunicaWrapper
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.lora.ops.sgmv_expand import sgmv_expand_mlu
from vllm_mlu.lora.ops.sgmv_expand_slice import sgmv_expand_slice_mlu
from vllm_mlu.lora.ops.sgmv_shrink import sgmv_shrink_mlu
def vllm__lora__punica__PunicaWrapper__shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
# No LoRA request, so return directly
if self.no_lora:
return
'''
=============================
Modify by vllm_mlu
=============================
@brief: Change function from sgmv_shrink to sgmv_shrink_mlu.
'''
sgmv_shrink_mlu(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
'''
==================
End of MLU Hijack
==================
'''
def vllm__lora__punica__PunicaWrapper__expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool,
):
# No LoRA request, so return directly
if self.no_lora:
return
'''
=============================
Modify by vllm_mlu
=============================
@brief: Change function from sgmv_expand to sgmv_expand_mlu.
'''
sgmv_expand_mlu(
x,
w_t_all,
y,
*self.prefill_metadata,
add_input,
)
'''
==================
End of MLU Hijack
==================
'''
def vllm__lora__punica__PunicaWrapper__expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool,
):
# No LoRA request, so return directly
if self.no_lora:
return
'''
=============================
Modify by vllm_mlu
=============================
@brief: Change function from sgmv_expand_slice to sgmv_expand_slice_mlu.
'''
sgmv_expand_slice_mlu(
x,
w_t_all,
y,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_input,
)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(PunicaWrapper,
PunicaWrapper.shrink_prefill,
vllm__lora__punica__PunicaWrapper__shrink_prefill)
MluHijackObject.apply_hijack(PunicaWrapper,
PunicaWrapper.expand_prefill,
vllm__lora__punica__PunicaWrapper__expand_prefill)
MluHijackObject.apply_hijack(PunicaWrapper,
PunicaWrapper.expand_slice_prefill,
vllm__lora__punica__PunicaWrapper__expand_slice_prefill)

View File

@@ -0,0 +1 @@
from . import model_executor

View File

@@ -0,0 +1,2 @@
from . import layers
from . import model_loader

View File

@@ -0,0 +1,2 @@
from . import feed_forward
from . import linear

View File

@@ -0,0 +1,98 @@
import torch
from typing import Optional
from vllm_mlu.mlu_hijack_utils import MluHijackObject, set_is_gated
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_group
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed import get_tensor_model_parallel_rank
from vllm import _mlu_ops as mlu_ops
from vllm.lora.layers import BaseLayerWithLoRA
from vllm_mlu._mlu_utils import *
def vllm_mlu__model_executor__layers__feed_forward__FeedForward__forward(
self,
hidden_states,
residual: Optional[torch.Tensor] = None
):
self.prepare_weight()
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
residual_ = None if self.tp_rank > 0 else residual
if (self.use_bt_ffn and not isinstance(up_proj, BaseLayerWithLoRA)
and not isinstance(down_proj, BaseLayerWithLoRA)):
# The matmul formula is the following:
# mul_out = alpha * (matmul(input, filter, transpose_b=True) + bias) + beta * residual
# output = active(mul_out)
# Note: we cannot use the activation function inside matmul because it does not
# support gated operations; we may support it in tmo matmul in the future.
fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj.weight, up_proj.bias,
None, 'none', self.alpha, self.beta)
act_out = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
beta = 1.0 if residual_ is not None else 0.0
'''
=======================================
Modify by custom vllm_mlu
=======================================
@brief: call the parallel op and skip the original all-reduce if parallel_num is set
'''
is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
if is_parallel_enable:
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
out_ = mlu_ops.matmul_allreduce(cncl_comm, act_out, down_proj.weight, None, residual_,
self.alpha, beta, self.parallel_num)
else:
out_ = mlu_ops.matmul(act_out, down_proj.weight, None, residual_, 'none', self.alpha, beta)
'''
=======================================
End of custom MLU Hijack
=======================================
'''
# Per vLLM's original design, the bias (if present) must be added after the second matmul.
'''
=============================
Modify by custom vllm_mlu
=============================
@brief: when preload_size is set, call GroupCoordinator.all_reduce() directly with
async_op so the all_reduce overlaps with the weight preload
'''
if self.reduce_results and self.tp_size > 1 and not is_parallel_enable:
if hasattr(self, 'preload_size') and self.preload_size > 0 and not self.is_prompt:
handle = get_tp_group().all_reduce(out_, async_op=True)
_MB = 1 << 20
mlu_ops.preload(self.preloaded_weights[0].data, self.preload_size * _MB)
preloaded_weights_size = self.preloaded_weights[0].numel() * self.preloaded_weights[0].element_size()
if preloaded_weights_size < (self.preload_size * _MB) and len(self.preloaded_weights) > 1:
mlu_ops.preload(self.preloaded_weights[1].data, (self.preload_size * _MB) - preloaded_weights_size)
handle.wait()
out = out_
else:
out = tensor_model_parallel_all_reduce(out_)
else:
out = out_
'''
=========================
End of custom MLU Hijack
=========================
'''
# do the bias add if needed
if not self.skip_bias_add:
out = out + down_proj.bias if down_proj.bias is not None else out
else:
return out, down_proj.bias
else:
fc1, bias = up_proj(hidden_states)
if bias is not None:
fc1 += bias
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
out, bias = down_proj(fc1, residual=residual_)
if self.skip_bias_add:
return out, bias
return out
MluHijackObject.apply_hijack(FeedForward,
FeedForward.forward,
vllm_mlu__model_executor__layers__feed_forward__FeedForward__forward)
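# For reference, a minimal plain-PyTorch sketch of the formula described in the
# comment above (illustrative only; the real mlu_ops.matmul / mlu_ops.active run
# on the MLU device, and the gate/up packing order below is an assumption):
def _reference_matmul(x, weight, bias=None, residual=None, alpha=1.0, beta=0.0):
    # mul_out = alpha * (matmul(x, weight, transpose_b=True) + bias) + beta * residual
    out = x @ weight.t()
    if bias is not None:
        out = out + bias
    out = alpha * out
    if residual is not None:
        out = out + beta * residual
    return out

def _reference_gated_act(fc1, act=torch.nn.functional.silu):
    # Gated activation applied after the first matmul: the packed gate/up
    # projection is split along the last dim and the activated gate half
    # multiplies the up half.
    gate, up = fc1.chunk(2, dim=-1)
    return act(gate) * up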

View File

@@ -0,0 +1,116 @@
from typing import Optional
import torch
from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_group
from vllm.distributed import get_tensor_model_parallel_rank, split_tensor_along_last_dim
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.layers.linear import UnquantizedLinearMethod, RowParallelLinear
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
def vllm__model_executor__layers__linear__UnquantizedLinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
beta = 1.0 if residual is not None else 0.0
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
'''
=====================================================
Modify by custom vllm_mlu
=====================================================
@brief: call parallel op if parallel_num is set
'''
if hasattr(self, 'parallel_num') and get_is_prompt():
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
return mlu_ops.matmul_allreduce(cncl_comm, x.view(-1, x.shape[-1]), layer.weight,
bias, residual, 1.0, beta, self.parallel_num).view(res_shape)
return mlu_ops.matmul(x.view(-1, x.shape[-1]), layer.weight, bias, residual, 'none', 1.0, beta).view(res_shape)
'''
=====================================================
End of custom MLU Hijack
=====================================================
'''
def vllm__model_executor__layers__linear__RowParallelLinear__forward(
self,
input_,
residual: Optional[torch.Tensor] = None
):
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
residual_ = None if self.tp_rank > 0 else residual
'''
=====================================================
Modify by custom vllm_mlu
=====================================================
@brief: skip the original all-reduce if parallel_num is set
'''
is_parallel_enable = hasattr(self.quant_method, 'parallel_num') and get_is_prompt()
'''
=====================================================
End of custom MLU Hijack
=====================================================
'''
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_,
residual=residual_)
'''
=============================
Modify by custom vllm_mlu
=============================
@brief: when preload_size is set, call GroupCoordinator.all_reduce() directly with
async_op so the all_reduce overlaps with the weight preload
'''
if self.reduce_results and self.tp_size > 1 and not is_parallel_enable:
if hasattr(self, 'preload_size') and self.preload_size > 0 and not self.is_prompt:
handle = get_tp_group().all_reduce(output_parallel, async_op=True)
_MB = 1 << 20
mlu_ops.preload(self.preloaded_weights[0].data, self.preload_size * _MB)
preloaded_weights_size = self.preloaded_weights[0].numel() * self.preloaded_weights[0].element_size()
if preloaded_weights_size < (self.preload_size * _MB) and len(self.preloaded_weights) > 1:
mlu_ops.preload(self.preloaded_weights[1].data, (self.preload_size * _MB) - preloaded_weights_size)
handle.wait()
output = output_parallel
else:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
'''
=========================
End of custom MLU Hijack
=========================
'''
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
MluHijackObject.undo_hijack(UnquantizedLinearMethod,
UnquantizedLinearMethod.apply)
MluHijackObject.apply_hijack(UnquantizedLinearMethod,
UnquantizedLinearMethod.apply,
vllm__model_executor__layers__linear__UnquantizedLinearMethod__apply)
MluHijackObject.undo_hijack(RowParallelLinear,
RowParallelLinear.forward)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.forward,
vllm__model_executor__layers__linear__RowParallelLinear__forward)
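# A stripped-down sketch of the async all-reduce / weight-preload overlap used in
# both hijacked forwards above (illustrative; the helper name is hypothetical, and
# only mlu_ops.preload and GroupCoordinator.all_reduce(..., async_op=True) are
# taken from the code above):
def _allreduce_with_preload(output, tp_group, preloaded_weights, preload_size_mb):
    _MB = 1 << 20
    budget = preload_size_mb * _MB
    # 1. Launch the tensor-parallel all-reduce without blocking the host.
    handle = tp_group.all_reduce(output, async_op=True)
    # 2. While the collective is in flight, prefetch up to `budget` bytes of the
    #    next layer's weights with the MLU preload primitive.
    for w in preloaded_weights:
        nbytes = min(w.numel() * w.element_size(), budget)
        mlu_ops.preload(w.data, nbytes)
        budget -= nbytes
        if budget <= 0:
            break
    # 3. Block until the all-reduce result is ready before using it.
    handle.wait()
    return output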

View File

@@ -0,0 +1 @@
from . import loader

View File

@@ -0,0 +1,143 @@
import os
import torch
from torch import nn
from typing import Optional
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.config import VllmConfig, ModelConfig, ParallelConfig
from vllm_mlu._mlu_utils import *
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
logger = init_logger(__name__)
def get_parallel_num(
model_config: ModelConfig,
parallel_config: ParallelConfig
):
attention_parallel_num = os.environ.get(ATTN_PARALLEL_NUM)
ffn_parallel_num = os.environ.get(FFN_PARALLEL_NUM)
if attention_parallel_num and attention_parallel_num.isdecimal():
attention_parallel_num = int(attention_parallel_num)
else:
attention_parallel_num = 0
if ffn_parallel_num and ffn_parallel_num.isdecimal():
ffn_parallel_num = int(ffn_parallel_num)
else:
ffn_parallel_num = 0
if parallel_config.tensor_parallel_size == 1:
raise ValueError("Can not use context_comm_cmpt_parallel when tp num is 1.")
if (attention_parallel_num <= 0 and ffn_parallel_num <= 0):
raise ValueError("attention_parallel_num and ffn_parallel_num must be positive integers.")
hidden_size = model_config.get_hidden_size()
ffn_parallel_num = max(ffn_parallel_num, 1)
if hidden_size % ffn_parallel_num != 0:
raise ValueError(f"Hidden_size: {hidden_size} must be divisible by ffn_parallel_num: {ffn_parallel_num}")
return attention_parallel_num, ffn_parallel_num
def get_attr_by_path(obj, path):
# Split the path by dots to get individual attributes
attributes = path.split('.')
# Iterate through the attributes to access nested members
for attr in attributes:
if not hasattr(obj, attr):
return None
obj = getattr(obj, attr)
return obj
def set_custom_attributes(model, model_config, parallel_config):
attn_row_parallel_layers = []
attn_weights = []
ffn_row_parallel_layers = []
ffn_weights = []
sparse_moe_mlp_layers = []
for module in model.modules():
if module.__class__.__name__ == "FeedForward":
ffn_weight = []
if hasattr(module, "up_proj_name"):
up_proj_name = getattr(module, "up_proj_name")
up_proj = getattr(module, up_proj_name)
if hasattr(up_proj, "weight"):
ffn_weight.append(up_proj.weight)
if hasattr(module, "down_proj_name"):
down_proj_name = getattr(module, "down_proj_name")
down_proj = getattr(module, down_proj_name)
if hasattr(down_proj, "weight"):
ffn_weight.append(down_proj.weight)
if ffn_weight is not None:
ffn_weights.append(ffn_weight)
ffn_row_parallel_layers.append(module)
for child_module in module.children():
if child_module.__class__.__name__ == "Attention":
for sibling_module in module.children():
if sibling_module.__class__.__name__ == "QKVParallelLinear":
if hasattr(sibling_module, "weight"):
weight = getattr(sibling_module, "weight")
attn_weights.append([weight])
if sibling_module.__class__.__name__ == "RowParallelLinear":
attn_row_parallel_layers.append(sibling_module)
if module.__class__.__name__ == "SparseMoeMlp" or issubclass(module.__class__, SparseMoeMlp):
sparse_moe_mlp_layers.append(module)
if VLLM_PRELOAD_SIZE > 0:
if (len(attn_row_parallel_layers) \
== len(attn_weights) \
== len(ffn_row_parallel_layers) \
== len(ffn_weights)) and \
len(attn_row_parallel_layers) != 0:
for i in range(len(attn_row_parallel_layers)):
attn_row_parallel_layers[i].preloaded_weights = ffn_weights[i]
attn_row_parallel_layers[i].preload_size = VLLM_PRELOAD_SIZE
if i < len(attn_row_parallel_layers) - 1:
ffn_row_parallel_layers[i].preloaded_weights = attn_weights[i+1]
ffn_row_parallel_layers[i].preload_size = VLLM_PRELOAD_SIZE
else:
logger.warning("%s does not support preload weight!", model.__class__.__name__)
# context compute communication parallel
if check_context_comm_cmpt_parallel():
attention_parallel_num, ffn_parallel_num = get_parallel_num(model_config, parallel_config)
for o_proj in attn_row_parallel_layers:
setattr(o_proj.quant_method, 'parallel_num', attention_parallel_num)
if len(sparse_moe_mlp_layers) != 0:
for sparse_moe_mlp in sparse_moe_mlp_layers:
setattr(sparse_moe_mlp, 'parallel_num', ffn_parallel_num)
else:
for ffn in ffn_row_parallel_layers:
setattr(ffn, 'parallel_num', ffn_parallel_num)
vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model__org = DefaultModelLoader.load_model
def vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model(
self, vllm_config: VllmConfig) -> nn.Module:
model = vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model__org(
self, vllm_config=vllm_config)
'''
=============================
Modify by custom vllm_mlu
=============================
@brief: Set custom optimization attributes based on layer names in the model.
'''
set_custom_attributes(model, vllm_config.model_config, vllm_config.parallel_config)
'''
=========================
End of custom MLU Hijack
=========================
'''
return model
MluHijackObject.apply_hijack(DefaultModelLoader,
DefaultModelLoader.load_model,
vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model)

View File

@@ -0,0 +1,17 @@
### Overview
This hijack code implements communication-computation parallelism for vLLM's context (prefill) phase. When enabled, it can improve the Context Latency metric for certain data sizes and split counts. It is an optional feature and is disabled by default.
### How to enable
- Set the environment variables ATTN_PARALLEL_NUM and FFN_PARALLEL_NUM to positive integers to control the number of communication-computation parallel splits for the attention and FFN parts, respectively. The two variables are independent and may be enabled together. For example, `export ATTN_PARALLEL_NUM=2 FFN_PARALLEL_NUM=4` enables parallelism for both parts, splitting the attention data into 2 chunks and the FFN data into 4 chunks. A usage sketch is given after the Notes section below.
- tensor_parallel_size must be greater than 1.
- When communication-computation parallelism is enabled for the FFN part, hidden_size must be divisible by FFN_PARALLEL_NUM.
### Notes
- Due to kernel limitations, when this feature is enabled, Mixtral-series models and Qwen2-series models (including Qwen1.5 and Qwen2.5) under smoothquant quantization only support batch_size = 1, and the kernel's split count defaults to 4 (ATTN_PARALLEL_NUM has no effect).
- Under smoothquant quantization, the FFN part of vllm_mlu does not call the tmo matmul kernel, so communication-computation fusion does not take effect for that part.
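The sketch below shows one way to enable the feature from Python (the model path and tensor-parallel size are placeholders; the environment variables only need to be set before the engine loads the model):

```python
import os

# Split the attention data into 2 chunks and the FFN data into 4 chunks.
os.environ["ATTN_PARALLEL_NUM"] = "2"
os.environ["FFN_PARALLEL_NUM"] = "4"

from vllm import LLM

# tensor_parallel_size must be > 1 for this feature.
llm = LLM(model="/path/to/model", tensor_parallel_size=4)
print(llm.generate(["Hello"])[0].outputs[0].text)
```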

View File

@@ -0,0 +1 @@
from . import model_executor

View File

@@ -0,0 +1,3 @@
from . import custom_model
from . import layers
from . import models

View File

@@ -0,0 +1,62 @@
import torch
from typing import Optional
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import tensor_model_parallel_all_reduce, get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm_mlu.model_executor.custom_model.custom import CustomMoeBlock
def vllm__module_executor__custom_model__CustomMoeBlock__forward(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
shared_output = F.sigmoid(
self.shared_expert_gate(hidden_states)) * shared_output
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
residual_ = None if self.rank > 0 else residual
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: call fused_moe
'''
params = [hidden_states, router_logits, self.w1, self.w2, None, None,
residual_, self.input_smooth, self.act_smooth, self.w1_scale, self.w2_scale,
self.top_k, self.config.norm_topk_prob, self.config.is_gated, self.config.hidden_act, 0]
if hasattr(self, 'parallel_num') and get_is_prompt():
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
params.extend([self.parallel_num, cncl_comm])
final_hidden_states = mlu_ops.fused_moe(*params)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
reduce_results = (self.config.use_parallel_residual == False)
if reduce_results:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
MluHijackObject.apply_hijack(CustomMoeBlock,
CustomMoeBlock.forward,
vllm__module_executor__custom_model__CustomMoeBlock__forward)

View File

@@ -0,0 +1,2 @@
from . import quantization
from . import sparse_moe_mlp

View File

@@ -0,0 +1,51 @@
import torch
from typing import Optional
from vllm import _mlu_ops as mlu_ops
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm_mlu._mlu_utils import get_is_prompt
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm_mlu__model_executor__layers__quantization__smoothquant__SmoothQuantLinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
quant_input = None
input_scale = None
if self.quant_config.input_quant_method == "per_token":
quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer.smooth, None)
if self.quant_config.input_quant_method == "per_tensor":
quant_input = x if self.skip_quant_input else mlu_ops.quantize(x, layer.scale_to_int, None)
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: call parallel op
'''
if hasattr(self, 'parallel_num') and get_is_prompt():
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
params = [cncl_comm, quant_input, input_scale, layer.qweight, layer.per_channel_scale,
self.compute_dtype, bias, residual, 1.0, 1.0, self.parallel_num]
out = mlu_ops.smooth_quant_matmul_allreduce(*params)
else:
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer.qweight,
layer.per_channel_scale, self.compute_dtype, bias, residual)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
return out
MluHijackObject.apply_hijack(SmoothQuantLinearMethod,
SmoothQuantLinearMethod.apply,
vllm_mlu__model_executor__layers__quantization__smoothquant__SmoothQuantLinearMethod__apply)

View File

@@ -0,0 +1,89 @@
"""Inference-only MOE model."""
import torch
from torch import nn
from typing import Optional
from vllm.distributed import tensor_model_parallel_all_reduce, get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
def vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
orig_hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# expert_logits: [num_tokens, self.num_experts_per_rank]
expert_logits, _ = self.gate(hidden_states)
final_hidden_states = self.forward_experts(hidden_states, expert_logits, residual)
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: disable the reduce if the parallel op is used
'''
is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
if self.tp_size > 1 and not is_parallel_enable:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
output = final_hidden_states.view(orig_hidden_states_shape)
return output
def vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward_experts(
self,
hidden_states,
expert_logits,
residual: Optional[torch.Tensor] = None
):
residual_ = None if self.tp_rank > 0 else residual
if self.is_use_fused_moe:
self.pack_params()
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: call fused_moe (fusing the all_reduce when parallel_num is set)
'''
is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
if is_parallel_enable:
residual_ = residual
params = [hidden_states, expert_logits, self.w13, self.w2, self.b13, self.b2,
residual_, self.a13_scale, self.a2_scale, self.w13_scale, self.w2_scale,
self.top_k, self.renormalize, self.is_gated, self.hidden_act, self.start_expert_id]
if is_parallel_enable:
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
params.extend([self.parallel_num, cncl_comm])
final_hidden_states = mlu_ops.fused_moe(*params)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
else:
final_hidden_states = self.forward_experts_nofused(hidden_states, expert_logits)
if residual_ is not None:
final_hidden_states = final_hidden_states + residual_
return final_hidden_states
MluHijackObject.apply_hijack(SparseMoeMlp,
SparseMoeMlp.forward,
vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward)
MluHijackObject.apply_hijack(SparseMoeMlp,
SparseMoeMlp.forward_experts,
vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward_experts)

View File

@@ -0,0 +1,3 @@
from . import mixtral_quant
from . import qwen2
from . import qwen2_moe

View File

@@ -0,0 +1,299 @@
import torch
from typing import List, Optional
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm.model_executor.models.mixtral_quant import MixtralAttention
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionType)
from vllm.attention.backends.utils import get_num_prefill_decode_query_kv_tokens
from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.attention.backends.mlu_attn import (MLUFlashAttentionMetadata,
_get_query_key_seq_metadata,
_get_causal_option)
def vllm__model_executor__models__mixtral__MixtralAttention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: call flash_attn_sq_mm_allreduce to finish forward
'''
if (attn_metadata.prefill_metadata) and \
(kv_cache[0].numel() > 0) and \
(hasattr(self.o_proj, 'quant_method')) and \
(isinstance(self.o_proj.quant_method, SmoothQuantLinearMethod)) and \
(self.o_proj.quant_method.quant_config.input_quant_method == "per_token"):
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
return torch.ops.vllm.context_attn_comm_cmpt_parallel_flash_attention_v2(
q, k, v,
self.num_heads, self.head_dim, self.num_kv_heads,
kv_cache, self.attn.impl.kv_cache_dtype,
1.0, 1.0, self.scaling,
cncl_comm,
self.o_proj.smooth, self.o_proj.qweight,
self.o_proj.per_channel_scale.to(torch.float),
self.o_proj.quant_method.parallel_num,
residual, self.attn.impl.sliding_window, self.attn.impl.alibi_slopes
)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
MluHijackObject.apply_hijack(MixtralAttention,
MixtralAttention.forward,
vllm__model_executor__models__mixtral__MixtralAttention__forward)
def context_attn_comm_cmpt_parallel_flash_attention_v2(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: List[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
cncl_comm: int,
smooth: torch.Tensor,
qweight: torch.Tensor,
per_channel_scale: torch.Tensor,
parallel_num: int,
residual: Optional[torch.Tensor] = None,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, MLUFlashAttentionMetadata)
attn_metadata: MLUFlashAttentionMetadata = current_metadata
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
if (key is not None) and (value is not None):
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
kv_cache_, kv_cache_scale_ = kv_cache
key_cache = kv_cache_[0]
value_cache = kv_cache_[1]
key_cache_scale, value_cache_scale = None, None
if kv_cache_scale_.numel() > 0:
key_cache_scale = kv_cache_scale_[0]
value_cache_scale = kv_cache_scale_[1]
# if not specified in self.attn.forward params, use default DECODER
attn_type = AttentionType.DECODER
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the KV
# cache is already populated with the cross-attention tensor.
# Thus, we skip cache updates during this time.
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
if USE_PAGED:
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_paged_cache(key,
value,
key_cache,
value_cache,
key_cache_scale,
value_cache_scale,
attn_metadata.slot_mapping.flatten())
else:
mlu_ops.reshape_paged_cache(key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping.flatten())
else:
# FIXME: After TMO-1496 is completed, remove this code.
if key.stride() != value.stride():
key = key.contiguous()
value = value.contiguous()
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_linear_cache(key,
value,
key_cache,
value_cache,
key_cache_scale,
value_cache_scale,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, # packed
None, # context_seq_offset
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
else:
mlu_ops.reshape_linear_cache(key,
value,
key_cache,
value_cache,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, # packed
None, # context_seq_offset
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
decode_query = query[num_prefill_query_tokens:]
# QKV for prefill.
query = query[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens
alibi_slopes = None if alibi_slopes is None else \
alibi_slopes.repeat(attn_metadata.num_prefills, 1)
prefill_meta = attn_metadata.prefill_metadata
# Prompt run.
if (prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]
output = mlu_ops.flash_attn_sq_mm_allreduce(cncl_comm,
query, key, value,
q_seq_start_loc, k_seq_start_loc,
alibi_slopes, None,
smooth, qweight,
per_channel_scale, None,
q_seq_len, k_seq_len,
softmax_scale, _get_causal_option(attn_type),
-1 if window_size is None \
else window_size[0],
-1 if window_size is None \
else window_size[1],
torch.float, parallel_num)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
output = mlu_ops.flash_attn_sq_mm_allreduce(cncl_comm,
query, key_cache, value_cache,
prefill_meta.query_start_loc, prefill_meta.seq_start_loc,
alibi_slopes, None,
smooth, qweight,
per_channel_scale, None,
prefill_meta.max_query_len, max_seq_len,
softmax_scale, True,
-1 if window_size is None \
else window_size[0],
-1 if window_size is None \
else window_size[1],
torch.float, parallel_num)
# Add residual.
if residual is not None:
output = output + residual
return output
def context_attn_comm_cmpt_parallel_flash_attention_v2_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: List[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
cncl_comm: int,
smooth: torch.Tensor,
qweight: torch.Tensor,
per_channel_scale: torch.Tensor,
parallel_num: int,
residual: Optional[torch.Tensor] = None,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)
direct_register_custom_op(
op_name="context_attn_comm_cmpt_parallel_flash_attention_v2",
op_func=context_attn_comm_cmpt_parallel_flash_attention_v2,
mutates_args=["kv_cache"],
fake_impl=context_attn_comm_cmpt_parallel_flash_attention_v2_fake,
)

View File

@@ -0,0 +1,90 @@
import torch
from typing import Optional
from vllm.attention import AttentionMetadata
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm.model_executor.models.qwen2 import Qwen2Attention
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
def vllm__model_executor__models__qwen2__Qwen2Attention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert smooth_quant_scale is None
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack q & k to fit tmo.apply_rotary
'''
qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
'''
==================
End of MLU Hijack
==================
'''
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: call flash_attn_sq_mm_allreduce to finish forward
'''
if (attn_metadata.prefill_metadata) and \
(kv_cache[0].numel() > 0) and \
(hasattr(self.o_proj, 'quant_method')) and \
(isinstance(self.o_proj.quant_method, SmoothQuantLinearMethod)) and \
(self.o_proj.quant_method.quant_config.input_quant_method == "per_token"):
rank = get_tensor_model_parallel_rank()
pg = get_tensor_model_parallel_group().device_group
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
return torch.ops.vllm.context_attn_comm_cmpt_parallel_flash_attention_v2(
q, k, v,
self.num_heads, self.head_dim, self.num_kv_heads,
kv_cache, self.attn.impl.kv_cache_dtype,
1.0, 1.0, self.scaling,
cncl_comm,
self.o_proj.smooth, self.o_proj.qweight,
self.o_proj.per_channel_scale.to(torch.float),
self.o_proj.quant_method.parallel_num,
residual, self.attn.impl.sliding_window, self.attn.impl.alibi_slopes
)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
MluHijackObject.undo_hijack(Qwen2Attention,
Qwen2Attention.forward)
MluHijackObject.apply_hijack(Qwen2Attention,
Qwen2Attention.forward,
vllm__model_executor__models__qwen2__Qwen2Attention__forward)

View File

@@ -0,0 +1,58 @@
import torch
import torch.nn.functional as F
from typing import Optional
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import get_tensor_model_parallel_rank, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm.attention import AttentionMetadata
from vllm_mlu.model_executor.models.qwen2_moe import Qwen2MoeSparseMoeBlock
def vllm_mlu__model_executor__models__qwen2_moe__Qwen2MoeSparseMoeBlock__forward(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
gate_output = self.shared_expert_gate(hidden_states)
shared_output = F.sigmoid(gate_output[0]) * shared_output
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.forward_experts(hidden_states, router_logits, residual)
'''
=====================================================
Modify by Context Communication Computation Parallel
=====================================================
@brief: disable the reduce if the parallel op is used
'''
is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
if self.tp_size > 1:
if is_parallel_enable:
shared_output = tensor_model_parallel_all_reduce(shared_output)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
else:
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
'''
=====================================================
End of Context Communication Computation Parallel
=====================================================
'''
return final_hidden_states.view(num_tokens, hidden_dim)
MluHijackObject.apply_hijack(Qwen2MoeSparseMoeBlock,
Qwen2MoeSparseMoeBlock.forward,
vllm_mlu__model_executor__models__qwen2_moe__Qwen2MoeSparseMoeBlock__forward)

View File

@@ -0,0 +1,32 @@
### Overview
This hijack code preloads the next layer's weights during vLLM's decode-phase communication, thereby reducing decode latency.
### Supported models
Only the following models are supported; quantized models and MoE models are not supported.
- Baichuan
- Bloom
- ChatGLM
- Falcon
- GPTNeoX
- Llama
- Qwen
- Qwen2
### Supported boards
The 300 series is not supported; all other series are supported.
### Usage
- Set the environment variable `export VLLM_PRELOAD_SIZE=<PRELOAD_SIZE>`, where `<PRELOAD_SIZE>` is the size of the weights to preload, in MB. A usage sketch is given after the table below.
- Parameter tuning reference: in a low-bandwidth environment, for the Llama-65B model, the performance gains for different batch_size and preload_size combinations are as follows.
| batch_size \ preload_size (MB) | 8 | 16 | 24 | 32 | 48 | 64 |
|:--------------:|:----:|:----:|:----:|:----:|:----:|:----:|
| 1 | 4.9% | 10.0%| 9.5% | 6.7% |-2.4% | -7.1%|
| 8 | 3.2% | 6.3% | 8.9% | 11.2%| 6.0% | 1.8% |
| 16 | 2.3% | 5.1% | 7.5% | 9.2% | 8.3% | 4.3% |
| 24 | 2.3% | 4.8% | 7.4% | 9.1% | 9.5% | 6.0% |
| 32 | 2.1% | 4.3% | 7.0% | 8.7% | 10.1%| 8.1% |
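A minimal usage sketch (the model path and tensor-parallel size are placeholders; the variable only needs to be set before the engine starts):

```python
import os

# Preload up to 32 MB of the next layer's weights during each decode all-reduce.
os.environ["VLLM_PRELOAD_SIZE"] = "32"

from vllm import LLM

llm = LLM(model="/path/to/Llama-65B", tensor_parallel_size=8)
print(llm.generate(["Hello"])[0].outputs[0].text)
```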

View File

@@ -0,0 +1 @@
from . import distributed

Some files were not shown because too many files have changed in this diff.