add qwen3
1
vllm-v0.6.2/vllm_mlu/vllm_mlu/__init__.py
Normal file
@@ -0,0 +1 @@
from . import mlu_hijack
BIN
vllm-v0.6.2/vllm_mlu/vllm_mlu/__pycache__/config.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm-v0.6.2/vllm_mlu/vllm_mlu/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
122
vllm-v0.6.2/vllm_mlu/vllm_mlu/_mlu_utils.py
Normal file
@@ -0,0 +1,122 @@
from torch.utils import collect_env as torch_collect_env

import os
import re


def _check_env(env, default=False):
    if env in os.environ:
        return os.environ[env].lower() in ["true", "1"]
    return default


def _check_env_value(env, default=0):
    if env in os.environ:
        if not os.environ[env].isdigit():
            raise ValueError(f"'{env}' should be set with an integer")
        value = int(os.environ[env])
        return value
    return default


def get_device_name(device_id: int = 0) -> str:
    r"""Gets the name of a device.

    Args:
        device_id (int): device id for which to return the device name.

    Returns:
        str: the name of the device, e.g. MLU370.
    """
    run_lambda = torch_collect_env.run
    try:
        out = torch_collect_env.run_and_read_all(run_lambda, "cnmon -l")
        matches = re.findall(r'MLU\d+(?:-\w+)?', out)
        return matches[device_id]
    except Exception as e:
        raise Exception(f"No device found with ID {device_id}.") from e


def get_device_major_capability(device_id: int = 0) -> int:
    r"""Gets the major capability of a device.

    Args:
        device_id (int): device id for which to return the device capability.

    Returns:
        int: the major capability of the device.
    """
    try:
        device_name = get_device_name(device_id)
        return int(device_name[3])
    except Exception as e:
        raise Exception(f"Failed to parse device capability with ID: {device_id}.") from e


# USE_PAGED: Select the vLLM running mode, default value depends on the current platform.
USE_PAGED = _check_env("USE_PAGED", default=(get_device_major_capability() > 3))

# VLLM_LATENCY_DEBUG: Get more kernel info for benchmark latency.
VLLM_LATENCY_DEBUG = _check_env("VLLM_LATENCY_DEBUG", default=False)

# VLLM_LATENCY_DEBUG_NO_DEVICE: Get more kernel info (without device) for benchmark latency.
VLLM_LATENCY_DEBUG_NO_DEVICE = _check_env("VLLM_LATENCY_DEBUG_NO_DEVICE", default=False)

# VLLM_DUMP_OUTPUTS: Dump each layer's outputs when running vLLM inference.
VLLM_DUMP_OUTPUTS = _check_env("VLLM_DUMP_OUTPUTS", default=False)

# VLLM_DUMP_CPU_INFO: Get cpu info when running vLLM inference.
VLLM_DUMP_CPU_INFO = _check_env("VLLM_DUMP_CPU_INFO", default=False)

# VLLM_DUMP_MLU_INFO: Get device info when running vLLM inference.
VLLM_DUMP_MLU_INFO = _check_env("VLLM_DUMP_MLU_INFO", default=False)

# VLLM_SCHEDULER_PROFILE: Profile the vLLM scheduler.
VLLM_SCHEDULER_PROFILE = _check_env("VLLM_SCHEDULER_PROFILE", default=False)

# VLLM_GRAPH_DEBUG: Debug the graph status when running the decoder, default value is True.
# Set to False to disable warning messages.
VLLM_GRAPH_DEBUG = _check_env("VLLM_GRAPH_DEBUG", default=True)

# CHUNKED_PIPELINE_PARALLEL_EN: use chunked pipeline parallel, default value is False.
CHUNKED_PIPELINE_PARALLEL_EN = _check_env("CHUNKED_PIPELINE_PARALLEL_EN", default=False)

# CONTEXT_PARALLEL_EN: use context parallel, default value is False.
CONTEXT_PARALLEL_EN = _check_env("CONTEXT_PARALLEL_EN", default=False)

# EXPERT_PARALLEL_EN: use expert parallel, default value is False.
EXPERT_PARALLEL_EN = _check_env("EXPERT_PARALLEL_EN", default=False)

VLLM_LATENCY_DEBUG_EN = (VLLM_LATENCY_DEBUG or VLLM_LATENCY_DEBUG_NO_DEVICE)
VLLM_LATENCY_DEBUG_WITH_DEVICE_EN = (VLLM_LATENCY_DEBUG and not VLLM_LATENCY_DEBUG_NO_DEVICE)
VLLM_DUMP_CPU_INFO_EN = (VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and VLLM_DUMP_CPU_INFO)
VLLM_DUMP_MLU_INFO_EN = (VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and VLLM_DUMP_MLU_INFO)
CUSTOM_VLLM_HIJACK_EN = (CHUNKED_PIPELINE_PARALLEL_EN or CONTEXT_PARALLEL_EN or EXPERT_PARALLEL_EN)

VLLM_PRELOAD_SIZE = _check_env_value("VLLM_PRELOAD_SIZE", default=0)

# ATTN_PARALLEL_NUM & FFN_PARALLEL_NUM: use context comm/cmpt parallel.
ATTN_PARALLEL_NUM = 'ATTN_PARALLEL_NUM'
FFN_PARALLEL_NUM = 'FFN_PARALLEL_NUM'


# This class is used by layers; BlockSizeInfo exposes the BLOCK_SIZE inside model/layer code.
class BlockSizeInfo:
    BLOCK_SIZE = -1

    @classmethod
    def set_block_size(cls, a: int):
        if USE_PAGED:
            if a != -1 and a != 16:
                raise ValueError("Block sizes other than 16 are not supported in paged mode, "
                                 "please check the '--block-size' value.")
            cls.BLOCK_SIZE = 16
        else:
            cls.BLOCK_SIZE = 2048 if a == -1 else a


def check_context_comm_cmpt_parallel():
    return (ATTN_PARALLEL_NUM in os.environ) or (FFN_PARALLEL_NUM in os.environ)


def set_is_prompt(flag):
    global IS_PROMPT
    IS_PROMPT = flag


def get_is_prompt():
    return IS_PROMPT
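get_device_name shells out to cnmon -l and pulls the device model out of the listing with a regular expression; get_device_major_capability then reads the first digit after "MLU", so a 5-series part makes USE_PAGED default to on. A small self-contained sketch of that parsing step, using a made-up cnmon -l line rather than real driver output:

import re

# Hypothetical `cnmon -l` line, for illustration only; a real listing comes from the MLU driver.
fake_cnmon_output = "Card 0: MLU590-M9, Driver v5.x"

matches = re.findall(r'MLU\d+(?:-\w+)?', fake_cnmon_output)
device_name = matches[0]                 # "MLU590-M9"
major_capability = int(device_name[3])   # first digit after "MLU" -> 5
print(device_name, major_capability)     # USE_PAGED would then default to True (> 3)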
4
vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/__init__.py
Normal file
@@ -0,0 +1,4 @@
import vllm_mlu.attention.backends
import vllm_mlu.attention.ops
import vllm_mlu.attention.layer
import vllm_mlu.attention.selector
Binary file not shown.
@@ -0,0 +1 @@
import vllm_mlu.attention.backends.mlu_attn
Binary file not shown.
802
vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
Normal file
@@ -0,0 +1,802 @@
|
||||
import torch
|
||||
|
||||
from contextlib import contextmanager
|
||||
from itertools import accumulate
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm.attention.backends.abstract import (AttentionMetadata,
|
||||
AttentionType)
|
||||
from vllm.attention.backends.utils import (
|
||||
PAD_SLOT_ID, get_num_prefill_decode_query_kv_tokens,
|
||||
get_seq_len_block_table_args)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
|
||||
make_tensor_with_pad)
|
||||
from vllm.attention.backends.mlu_attn import (
|
||||
MLUFlashAttentionBackend, MLUFlashAttentionMetadataBuilder,
|
||||
MLUFlashAttentionMetadata, MLUFlashAttentionImpl,
|
||||
MLUFlashAttentionState, _get_query_key_seq_metadata,
|
||||
_get_causal_option)
|
||||
from vllm_mlu._mlu_utils import USE_PAGED
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
class MLUFlashAttentionBackend_V2(MLUFlashAttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [32, 64, 80, 96, 128, 160, 192, 224, 256, 512, 576]
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["MLUFlashAttentionImpl_V2"]:
|
||||
return MLUFlashAttentionImpl_V2
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["MLUFlashAttentionMetadataBuilder_V2"]:
|
||||
return MLUFlashAttentionMetadataBuilder_V2
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["MLUFlashAttentionState_V2"]:
|
||||
return MLUFlashAttentionState_V2
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_scale_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, num_kv_heads, block_size)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[List[torch.Tensor]],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0][0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[0][1] for kv_cache in kv_caches]
|
||||
mlu_ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
|
||||
kv_cache_scales = [kv_cache[1] for kv_cache in kv_caches]
|
||||
if len(kv_cache_scales) > 0 and kv_cache_scales[0].numel() > 0:
|
||||
key_cache_scales = [kv_cache_scale[0] for kv_cache_scale in kv_cache_scales]
|
||||
value_cache_scales = [kv_cache_scale[1] for kv_cache_scale in kv_cache_scales]
|
||||
mlu_ops.copy_blocks(key_cache_scales, value_cache_scales, src_to_dists)
|
||||
|
||||
|
||||
class MLUMLAFlashAttentionBackend(MLUFlashAttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (1, num_blocks, num_kv_heads, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache = src_kv_cache[0]
|
||||
dst_key_cache = dst_kv_cache[0]
|
||||
mlu_ops.swap_blocks(dst_key_cache, src_key_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_scale_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (1, num_blocks, num_kv_heads, block_size)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[List[torch.Tensor]],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0][0] for kv_cache in kv_caches]
|
||||
mlu_ops.copy_blocks(key_caches, None, src_to_dists)
|
||||
|
||||
kv_cache_scales = [kv_cache[1] for kv_cache in kv_caches]
|
||||
if len(kv_cache_scales) > 0 and kv_cache_scales[0].numel() > 0:
|
||||
key_cache_scales = [kv_cache_scale[0] for kv_cache_scale in kv_cache_scales]
|
||||
mlu_ops.copy_blocks(key_cache_scales, None, src_to_dists)
|
||||
|
||||
|
||||
class MLUFlashAttentionState_V2(MLUFlashAttentionState):
|
||||
|
||||
def __init__(self, runner: "ModelRunnerBase"):
|
||||
MLUFlashAttentionState.__init__(self, runner)
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
self._is_graph_capturing = True
|
||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||
PAD_SLOT_ID if USE_PAGED else 0,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
self._graph_seq_lens = torch.ones(max_batch_size,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
self._graph_block_tables = torch.from_numpy(
|
||||
self.runner.graph_block_tables).to(device=self.runner.device)
|
||||
yield
|
||||
self._is_graph_capturing = False
|
||||
del self._graph_slot_mapping
|
||||
del self._graph_seq_lens
|
||||
del self._graph_block_tables
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
||||
assert self._is_graph_capturing
|
||||
attn_metadata = self.runner.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size,
|
||||
slot_mapping=self._graph_slot_mapping[:batch_size],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
||||
max_query_len=1,
|
||||
max_decode_query_len=1,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=(self.runner.max_seq_len_to_capture if USE_PAGED
|
||||
else min(self.runner.block_size, self.runner.max_seq_len_to_capture)),
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self._graph_block_tables[:batch_size],
|
||||
use_cuda_graph=True,
|
||||
)
|
||||
if is_encoder_decoder_model:
|
||||
# The encoder decoder model works only with XFormers and
|
||||
# Flash Attention backend. Assert the same.
|
||||
assert self.runner.attn_backend.get_name() in\
|
||||
["XFORMERS", "FLASH_ATTN"], \
|
||||
f"Expected attn_backend name to be either 'XFORMERS' or " \
|
||||
f"'FLASH_ATTN', but "\
|
||||
f"got '{self.runner.attn_backend.get_name()}'"
|
||||
self._update_captured_metadata_for_enc_dec_model(
|
||||
batch_size=batch_size, attn_metadata=attn_metadata)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def get_graph_input_buffers(
|
||||
self,
|
||||
attn_metadata: AttentionMetadata,
|
||||
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
|
||||
input_buffers = {
|
||||
"slot_mapping": attn_metadata.slot_mapping,
|
||||
"seq_lens_tensor": None,
|
||||
"block_tables": None,
|
||||
}
|
||||
if attn_metadata.num_prefills > 0:
|
||||
input_buffers["seq_lens_tensor"] = attn_metadata.prefill_metadata.seq_lens_tensor
|
||||
input_buffers["block_tables"] = attn_metadata.prefill_metadata.block_tables
|
||||
else:
|
||||
input_buffers["seq_lens_tensor"] = attn_metadata.decode_metadata.seq_lens_tensor
|
||||
input_buffers["block_tables"] = attn_metadata.decode_metadata.block_tables
|
||||
if is_encoder_decoder_model:
|
||||
# The encoder decoder model works only with XFormers and
|
||||
# Flash Attention backend. Assert the same.
|
||||
assert self.runner.attn_backend.get_name() in\
|
||||
["XFORMERS", "FLASH_ATTN"], \
|
||||
f"Expected attn_backend name to be either 'XFORMERS' or "\
|
||||
f"'FLASH_ATTN', but "\
|
||||
f"got '{self.runner.attn_backend.get_name()}'"
|
||||
self._add_additonal_input_buffers_for_enc_dec_model(
|
||||
attn_metadata=attn_metadata, input_buffers=input_buffers)
|
||||
return input_buffers
|
||||
|
||||
def prepare_graph_input_buffers(
|
||||
self,
|
||||
input_buffers: Dict[str, Any],
|
||||
attn_metadata: AttentionMetadata,
|
||||
is_encoder_decoder_model: bool = False) -> None:
|
||||
metadata = attn_metadata.prefill_metadata if \
|
||||
attn_metadata.num_prefills > 0 else attn_metadata.decode_metadata
|
||||
|
||||
input_buffers["seq_lens_tensor"].copy_(
|
||||
metadata.seq_lens_tensor, non_blocking=True)
|
||||
input_buffers["block_tables"].copy_(
|
||||
metadata.block_tables, non_blocking=True)
|
||||
if is_encoder_decoder_model:
|
||||
# The encoder decoder model works only with XFormers and
|
||||
# Flash Attention backend. Assert the same.
|
||||
assert self.runner.attn_backend.get_name() in\
|
||||
["XFORMERS", "FLASH_ATTN"], \
|
||||
f"Expected attn_backend name to be either 'XFORMERS' or "\
|
||||
f"'FLASH_ATTN', but "\
|
||||
f"got '{self.runner.attn_backend.get_name()}'"
|
||||
self._prepare_input_buffers_for_enc_dec_model(
|
||||
attn_metadata, input_buffers)
|
||||
|
||||
@contextmanager
|
||||
def graph_capture_with_context(
|
||||
self,
|
||||
ctx_graph_batch_size: int,
|
||||
max_batch_size: int,
|
||||
max_num_tokens: int
|
||||
):
|
||||
self._is_graph_capturing = True
|
||||
self._graph_slot_mapping = torch.full((max_num_tokens, ),
|
||||
PAD_SLOT_ID if USE_PAGED else 0,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
self._graph_seq_lens = torch.ones(max_batch_size,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
# block tables used for decode mlugraph input buffer
|
||||
self._graph_block_tables = torch.from_numpy(
|
||||
self.runner.graph_block_tables).to(device=self.runner.device)
|
||||
# block tables used for context mlugraph input buffer
|
||||
self._ctx_graph_block_tables = torch.zeros((ctx_graph_batch_size, 0),
|
||||
dtype=self._graph_block_tables.dtype,
|
||||
device=self.runner.device)
|
||||
yield
|
||||
self._is_graph_capturing = False
|
||||
del self._graph_slot_mapping
|
||||
del self._graph_seq_lens
|
||||
del self._graph_block_tables
|
||||
del self._ctx_graph_block_tables
|
||||
|
||||
def fill_seq_lens_tensor(
|
||||
self,
|
||||
seq_len: int
|
||||
) -> None:
|
||||
self._graph_seq_lens.fill_(seq_len)
|
||||
|
||||
def graph_capture_get_metadata_for_context(
|
||||
self,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
is_encoder_decoder_model: bool = False
|
||||
) -> MLUFlashAttentionMetadata:
|
||||
assert self._is_graph_capturing
|
||||
|
||||
query_start_loc = torch.zeros(batch_size + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
seq_start_loc = torch.zeros(batch_size + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
context_lens_tensor = torch.zeros(batch_size,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
torch.cumsum(self._graph_seq_lens[:batch_size],
|
||||
dim=0,
|
||||
dtype=query_start_loc.dtype,
|
||||
out=query_start_loc[1:])
|
||||
torch.cumsum(self._graph_seq_lens[:batch_size],
|
||||
dim=0,
|
||||
dtype=seq_start_loc.dtype,
|
||||
out=seq_start_loc[1:])
|
||||
|
||||
num_tokens = batch_size * seq_len
|
||||
attn_metadata = self.runner.attn_backend.make_metadata(
|
||||
num_prefills=batch_size,
|
||||
num_prefill_tokens=num_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=self._graph_slot_mapping[:num_tokens],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
seq_lens=[seq_len] * batch_size,
|
||||
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
||||
max_query_len=seq_len,
|
||||
max_decode_query_len=0,
|
||||
max_prefill_seq_len=seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=self._ctx_graph_block_tables,
|
||||
use_cuda_graph=True,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class MLUFlashAttentionMetadataBuilder_V2(MLUFlashAttentionMetadataBuilder):
|
||||
|
||||
def build(
|
||||
self,
|
||||
seq_lens: List[int],
|
||||
query_lens: List[int],
|
||||
cuda_graph_pad_size: int,
|
||||
batch_size: int
|
||||
) -> MLUFlashAttentionMetadata:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Use origin func if do not use context mlugraph.
|
||||
'''
|
||||
if not self.runner.model_config.use_context_mlugraph():
|
||||
return super().build(seq_lens,
|
||||
query_lens,
|
||||
cuda_graph_pad_size,
|
||||
batch_size)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
prefix_cache_hit = any([
|
||||
inter_data.prefix_cache_hit
|
||||
for inter_data in self.input_builder.inter_data_list
|
||||
])
|
||||
for inter_data in self.input_builder.inter_data_list:
|
||||
self._add_seq_group(inter_data,
|
||||
self.input_builder.chunked_prefill_enabled,
|
||||
prefix_cache_hit)
|
||||
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
decode_query_lens = query_lens[self.num_prefills:]
|
||||
if len(decode_query_lens) > 0:
|
||||
max_decode_query_len = max(decode_query_lens)
|
||||
else:
|
||||
max_decode_query_len = 1
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
||||
|
||||
num_seqs = len(seq_lens)
|
||||
if use_captured_graph:
|
||||
self.slot_mapping.extend([
|
||||
(PAD_SLOT_ID if USE_PAGED else 0)] * cuda_graph_pad_size)
|
||||
self.block_tables.extend([] * cuda_graph_pad_size)
|
||||
num_decode_tokens = batch_size - self.num_prefill_tokens
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables)
|
||||
else:
|
||||
if USE_PAGED:
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
block_tables = make_tensor_without_pad(
|
||||
self.block_tables,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
||||
|
||||
assert device is not None
|
||||
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
||||
self.runner.pin_memory)
|
||||
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
||||
device,
|
||||
self.runner.pin_memory)
|
||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Check if we can use context mlugraph for the given input.
|
||||
'''
|
||||
if num_decode_tokens == 0 and self.num_prefills > 0:
|
||||
ctx_graph_bs, ctx_graph_seq_len = (
|
||||
self.runner.model_config.get_context_mlugraph_bs_and_seq())
|
||||
use_captured_graph = len(seq_lens) == ctx_graph_bs and all(
|
||||
seq_len == ctx_graph_seq_len for seq_len in seq_lens)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
return MLUFlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_decode_query_len=max_decode_query_len,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
query_start_loc=query_start_loc_tensor,
|
||||
seq_start_loc=seq_start_loc_tensor,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
)
|
||||
|
||||
|
||||
class MLUFlashAttentionImpl_V2(MLUFlashAttentionImpl):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: List[torch.Tensor],
|
||||
attn_metadata: MLUFlashAttentionMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
use_mla: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
||||
assert k_scale == 1.0 and v_scale == 1.0, (
|
||||
"key/v_scale is not supported in FlashAttention.")
|
||||
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
output = torch.ops.vllm.unified_flash_attention_v2(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.num_kv_heads,
|
||||
kv_cache,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
self.scale,
|
||||
attn_type.value,
|
||||
self.sliding_window,
|
||||
self.alibi_slopes,
|
||||
self.logits_soft_cap,
|
||||
use_mla,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def unified_flash_attention_v2(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
kv_cache: List[torch.Tensor],
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
attn_type_int_val: int,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
use_mla: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# Convert integer attn_type to enum
|
||||
try:
|
||||
attn_type = AttentionType(attn_type_int_val)
|
||||
except ValueError as err:
|
||||
raise AttributeError(
|
||||
f"Invalid attention type {str(attn_type_int_val)}") from err
|
||||
|
||||
current_metadata = get_forward_context()
|
||||
assert current_metadata is not None
|
||||
assert isinstance(current_metadata, MLUFlashAttentionMetadata)
|
||||
attn_metadata: MLUFlashAttentionMetadata = current_metadata
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
v_head_size = value.size(1) // num_kv_heads
|
||||
if (key is not None) and (value is not None):
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
if use_mla and attn_metadata.prefill_metadata:
|
||||
value = value.view(-1, num_kv_heads, v_head_size)
|
||||
else:
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
|
||||
if kv_cache[0].numel() > 0:
|
||||
kv_cache_, kv_cache_scale_ = kv_cache
|
||||
key_cache = kv_cache_[0]
|
||||
value_cache = None if use_mla else kv_cache_[1]
|
||||
key_cache_scale, value_cache_scale = None, None
|
||||
if kv_cache_scale_.numel() > 0:
|
||||
key_cache_scale = kv_cache_scale_[0]
|
||||
value_cache_scale = None if use_mla else kv_cache_scale_[1]
|
||||
|
||||
# We skip updating the KV cache under two conditions:
|
||||
# a. When the Attention Type is ENCODER. In this phase, we compute
|
||||
# only the encoder attention without updating the cache.
|
||||
# b. When both Key and Value are None. This occurs during
|
||||
# cross-attention computation in the decoding phase, where the KV
|
||||
# cache is already populated with the cross-attention tensor.
|
||||
# Thus, we skip cache updates during this time.
|
||||
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
|
||||
value is not None):
|
||||
if attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Update cross-attention KV cache (prefill-only)
|
||||
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
else:
|
||||
# Update self-attention KV cache (prefill/decode)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory profiling run.
|
||||
if USE_PAGED:
|
||||
value_to_cache = None if use_mla else value
|
||||
if use_mla and attn_metadata.prefill_metadata:
|
||||
# MLA save cache info in models before flashattn
|
||||
pass
|
||||
else:
|
||||
if kv_cache_dtype == 'int8':
|
||||
mlu_ops.quant_to_paged_cache(key,
|
||||
value_to_cache,
|
||||
key_cache,
|
||||
value_cache,
|
||||
key_cache_scale,
|
||||
value_cache_scale,
|
||||
updated_slot_mapping.flatten())
|
||||
else:
|
||||
mlu_ops.reshape_paged_cache(key,
|
||||
value_to_cache,
|
||||
key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping.flatten())
|
||||
else:
|
||||
# FIXME: After TMO-1496 is completed, remove this code.
|
||||
if key.stride() != value.stride():
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
if kv_cache_dtype == 'int8':
|
||||
mlu_ops.quant_to_linear_cache(key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
key_cache_scale,
|
||||
value_cache_scale,
|
||||
attn_metadata.cu_seq_lens,
|
||||
attn_metadata.max_seq_len,
|
||||
True, # packed
|
||||
None, # context_seq_offset
|
||||
attn_metadata.batch_ids,
|
||||
attn_metadata.slot_mapping_unpaged)
|
||||
else:
|
||||
mlu_ops.reshape_linear_cache(key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.cu_seq_lens,
|
||||
attn_metadata.max_seq_len,
|
||||
True, # packed
|
||||
None, # context_seq_offset
|
||||
attn_metadata.batch_ids,
|
||||
attn_metadata.slot_mapping_unpaged)
|
||||
if use_mla and attn_metadata.prefill_metadata:
|
||||
output = torch.empty(query.shape[0], query.shape[1], v_head_size, dtype=query.dtype, device="mlu")
|
||||
else:
|
||||
output = torch.empty_like(query)
|
||||
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||
num_decode_query_tokens) = \
|
||||
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
||||
decode_query = query[num_prefill_query_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_query_tokens]
|
||||
assert query.shape[0] == num_prefill_query_tokens
|
||||
assert decode_query.shape[0] == num_decode_query_tokens
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
alibi_slopes = None if alibi_slopes is None else \
|
||||
alibi_slopes.repeat(attn_metadata.num_prefills, 1)
|
||||
# Prompt run.
|
||||
if (kv_cache[0].numel() == 0 or prefill_meta.block_tables is None
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
# normal attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
|
||||
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
|
||||
|
||||
key = key[:num_prefill_kv_tokens]
|
||||
value = value[:num_prefill_kv_tokens]
|
||||
mlu_ops.flash_attention(query,
|
||||
key,
|
||||
value,
|
||||
output[:num_prefill_query_tokens],
|
||||
q_seq_start_loc,
|
||||
k_seq_start_loc,
|
||||
alibi_slopes,
|
||||
None,
|
||||
q_seq_len,
|
||||
k_seq_len,
|
||||
softmax_scale,
|
||||
_get_causal_option(attn_type),
|
||||
-1 if window_size is None \
|
||||
else window_size[0],
|
||||
-1 if window_size is None \
|
||||
else window_size[1],
|
||||
torch.float,
|
||||
False)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support prefix caching")
|
||||
assert prefill_meta.seq_lens is not None
|
||||
max_seq_len = max(prefill_meta.seq_lens)
|
||||
mlu_ops.flash_attention(query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
output[:num_prefill_kv_tokens],
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.seq_start_loc,
|
||||
alibi_slopes,
|
||||
None,
|
||||
prefill_meta.max_query_len,
|
||||
max_seq_len,
|
||||
softmax_scale,
|
||||
True,
|
||||
-1 if window_size is None \
|
||||
else window_size[0],
|
||||
-1 if window_size is None \
|
||||
else window_size[1],
|
||||
torch.float,
|
||||
False,
|
||||
prefill_meta.block_tables)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
alibi_slopes = None if alibi_slopes is None \
|
||||
else alibi_slopes.repeat(attn_metadata.num_decode_tokens, 1)
|
||||
decode_query = decode_query.view(-1, 1, num_heads, head_size)
|
||||
decode_out = output[num_prefill_query_tokens:].view(-1, 1, num_heads, head_size)
|
||||
# Use flash_attn_varlen_func kernel for speculative decoding
|
||||
# because different queries might have different lengths.
|
||||
assert decode_meta.max_decode_query_len is not None
|
||||
if decode_meta.max_decode_query_len > 1:
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support max_decode_query_len > 1")
|
||||
mlu_ops.flash_attention(decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
decode_out,
|
||||
decode_meta.query_start_loc,
|
||||
decode_meta.seq_start_loc,
|
||||
alibi_slopes,
|
||||
None,
|
||||
decode_meta.max_decode_query_len,
|
||||
decode_meta.max_decode_seq_len,
|
||||
softmax_scale,
|
||||
True,
|
||||
-1 if window_size is None \
|
||||
else window_size[0],
|
||||
-1 if window_size is None \
|
||||
else window_size[1],
|
||||
torch.float,
|
||||
False,
|
||||
decode_meta.block_tables)
|
||||
else:
|
||||
# Use flash_attn_with_kvcache for normal decoding.
|
||||
(
|
||||
seq_lens_arg,
|
||||
max_context_len,
|
||||
block_tables_arg,
|
||||
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
|
||||
if use_mla:
|
||||
value_cache = key_cache
|
||||
value_cache_scale = key_cache_scale
|
||||
mlu_ops.single_query_cached_kv_attn(decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
decode_out,
|
||||
block_tables_arg,
|
||||
seq_lens_arg,
|
||||
key_cache_scale,
|
||||
value_cache_scale,
|
||||
alibi_slopes,
|
||||
max_context_len,
|
||||
-1 if window_size is None \
|
||||
else window_size[0],
|
||||
-1 if window_size is None \
|
||||
else window_size[1],
|
||||
softmax_scale)
|
||||
|
||||
# Reshape the output tensor.
|
||||
if use_mla and attn_metadata.prefill_metadata:
|
||||
return output.view(num_tokens, num_kv_heads * v_head_size)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def unified_flash_attention_v2_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
kv_cache: List[torch.Tensor],
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
attn_type_int_val: int,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(query)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="unified_flash_attention_v2",
|
||||
op_func=unified_flash_attention_v2,
|
||||
mutates_args=["kv_cache"],
|
||||
fake_impl=unified_flash_attention_v2_fake,
|
||||
)
|
||||
|
||||
|
||||
def make_tensor_without_pad(
|
||||
x: List[List[int]],
|
||||
dtype: torch.dtype,
|
||||
device: Union[str, torch.device] = "mlu",
|
||||
pin_memory: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return torch.tensor(x,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
pin_memory=pin_memory and str(device) == "cpu")
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
|
||||
MLUFlashAttentionBackend.get_supported_head_sizes,
|
||||
MLUFlashAttentionBackend_V2.get_supported_head_sizes)
|
||||
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
|
||||
MLUFlashAttentionBackend.get_impl_cls,
|
||||
MLUFlashAttentionBackend_V2.get_impl_cls)
|
||||
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
|
||||
MLUFlashAttentionBackend.get_builder_cls,
|
||||
MLUFlashAttentionBackend_V2.get_builder_cls)
|
||||
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
|
||||
MLUFlashAttentionBackend.get_state_cls,
|
||||
MLUFlashAttentionBackend_V2.get_state_cls)
|
||||
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
|
||||
"get_kv_cache_scale_shape",
|
||||
MLUFlashAttentionBackend_V2.get_kv_cache_scale_shape)
|
||||
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
|
||||
MLUFlashAttentionBackend.copy_blocks,
|
||||
MLUFlashAttentionBackend_V2.copy_blocks)
|
||||
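The block of MluHijackObject.apply_hijack calls above is what reroutes the upstream MLUFlashAttentionBackend static methods to the _V2 variants at import time. The hijack helper itself is not part of this diff; as a rough mental model it behaves like a small monkey-patch registry along the lines of the sketch below (the class and method names other than apply_hijack are assumptions for illustration, not the real vllm_mlu implementation):

# Illustrative sketch only, not taken from the vllm_mlu sources.
class HijackRegistry:
    _originals = []

    @classmethod
    def apply_hijack(cls, target, old_attr, new_attr):
        # Accept either the attribute object or its name, mirroring how the
        # calls above pass both bound methods and plain strings.
        name = old_attr if isinstance(old_attr, str) else old_attr.__name__
        cls._originals.append((target, name, getattr(target, name, None)))
        setattr(target, name, new_attr)

    @classmethod
    def undo_all(cls):
        # Restore everything in reverse order of patching.
        for target, name, original in reversed(cls._originals):
            setattr(target, name, original)
        cls._originals.clear()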
118
vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/layer.py
Normal file
@@ -0,0 +1,118 @@
"""Attention layer."""
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.layer import Attention
from vllm_mlu.attention.selector import vllm__attention__selector__get_attn_backend as get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod

from vllm_mlu._mlu_utils import USE_PAGED
from vllm_mlu.mlu_hijack_utils import MluHijackObject

'''
=============================
Modify by vllm_mlu
=============================
@brief: add an arg use_mla for the functions get_attn_backend, _cached_get_attn_backend,
        which_attn_to_use
'''
'''
==================
End of MLU Hijack
==================
'''
def vllm__attention__layer__Attention__init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: Optional[int] = None,
    alibi_slopes: Optional[List[float]] = None,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    blocksparse_params: Optional[Dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    use_mla: bool = False,
    prefix: str = "",
) -> None:
    super(Attention, self).__init__()
    self.use_mla = use_mla
    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
        sliding_window = cache_config.sliding_window
        is_attention_free = cache_config.is_attention_free
    else:
        kv_cache_dtype = "auto"
        block_size = 16
        sliding_window = None
        is_attention_free = False
    if num_kv_heads is None:
        num_kv_heads = num_heads

    # The default k/v_scale is set to 1.0. This is ignored
    # when kv-cache is not fp8, and should be used with
    # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
    # expect the pre-quantized k/v_scale to be loaded along
    # with the model weights.
    self.kv_cache_dtype = kv_cache_dtype
    self._k_scale = 1.0
    self._v_scale = 1.0
    quant_method = quant_config.get_quant_method(
        self, prefix=prefix) if quant_config else None
    if quant_method is not None:
        assert isinstance(quant_method, BaseKVCacheMethod)
        # TODO (mgoin): kv cache dtype should be specified in the FP8
        # checkpoint config and become the "auto" behavior
        if self.kv_cache_dtype == "fp8_e5m2":
            raise ValueError("fp8_e5m2 kv-cache is not supported with "
                             "fp8 checkpoints.")
        # If quantization is enabled, we make "k_scale" and "v_scale"
        # parameters so that it can be loaded from the model checkpoint.
        # The k/v_scale will then be converted back to native float32
        # values after weight loading.
        self.quant_method = quant_method
        self.quant_method.create_weights(self)

    # During model initialization, the default dtype is set as the model
    # weight and activation dtype.
    dtype = torch.get_default_dtype()
    attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
                                    block_size, is_attention_free,
                                    blocksparse_params is not None,
                                    use_mla=use_mla)
    impl_cls = attn_backend.get_impl_cls()
    self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap)


def vllm__attention__layer__Attention__forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
    attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
    return self.impl.forward(query,
                             key,
                             value,
                             kv_cache,
                             attn_metadata,
                             self._k_scale,
                             self._v_scale,
                             attn_type=attn_type,
                             use_mla=self.use_mla)


MluHijackObject.apply_hijack(Attention,
                             Attention.__init__,
                             vllm__attention__layer__Attention__init__)
MluHijackObject.apply_hijack(Attention,
                             Attention.forward,
                             vllm__attention__layer__Attention__forward)
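With the two hijacks above installed, constructing a stock vllm.attention.layer.Attention goes through the MLU attention selector and carries the extra use_mla flag down to the backend implementation. A rough usage sketch on an MLU host (the dimensions are arbitrary example values, and it assumes importing vllm_mlu has already applied the hijacks):

# Illustrative sketch only; the sizes below are made-up example values.
import vllm_mlu  # noqa: F401  (importing the package applies the hijacks)
from vllm.attention.layer import Attention

num_heads, head_size, num_kv_heads = 32, 128, 8
attn = Attention(
    num_heads=num_heads,
    head_size=head_size,
    scale=head_size ** -0.5,
    num_kv_heads=num_kv_heads,
    use_mla=False,  # the argument added by the hijacked __init__
)
# attn.forward(...) now also forwards use_mla to the MLU backend impl.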
1
vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/ops/__init__.py
Normal file
@@ -0,0 +1 @@
import vllm_mlu.attention.ops.prefix_prefill
Binary file not shown.
157
vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/ops/prefix_prefill.py
Normal file
@@ -0,0 +1,157 @@
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

import torch
import triton
import triton.language as tl

import vllm.attention.ops.prefix_prefill
from vllm_mlu.mlu_hijack_utils import MluHijackObject

if triton.__version__ >= "2.1.0":

    @torch.inference_mode()
    def vllm__attention__ops__prefix_prefill__context_attention_fwd(q,
                                                                    k,
                                                                    v,
                                                                    o,
                                                                    k_cache,
                                                                    v_cache,
                                                                    b_loc,
                                                                    b_start_loc,
                                                                    b_seq_len,
                                                                    b_ctx_len,
                                                                    max_input_len,
                                                                    alibi_slopes=None,
                                                                    sliding_window=None):
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: the upstream block size uses too much memory when BLOCK is 64.
        '''
        BLOCK = 16
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        # round up Lk to a power of 2 - this is required for Triton block size
        Lk_padded = triton.next_power_of_2(Lk)

        sm_scale = 1.0 / (Lq**0.5)
        batch, head = b_seq_len.shape[0], q.shape[1]
        num_queries_per_kv = q.shape[1] // k.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,

        num_warps = 8 if Lk <= 64 else 8
        if alibi_slopes is not None:
            assert Lk == Lk_padded
            vllm.attention.ops.prefix_prefill._fwd_kernel_alibi[grid](
                q,
                k,
                v,
                k_cache,
                v_cache,
                b_loc,
                sm_scale,
                b_start_loc,
                b_seq_len,
                b_ctx_len,
                alibi_slopes,
                v_cache.shape[3],
                8,
                o,
                b_loc.stride(0),
                b_loc.stride(1),
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                v.stride(0),
                v.stride(1),
                v.stride(2),
                o.stride(0),
                o.stride(1),
                o.stride(2),
                k_cache.stride(0),
                k_cache.stride(1),
                k_cache.stride(2),
                k_cache.stride(3),
                k_cache.stride(
                    4
                ),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
                v_cache.stride(0),
                v_cache.stride(1),
                v_cache.stride(2),
                v_cache.stride(
                    3),  # [num_blocks, num_kv_heads, head_size, block_size]
                num_queries_per_kv=num_queries_per_kv,
                BLOCK_M=BLOCK,
                BLOCK_DMODEL=Lk,
                BLOCK_N=BLOCK,
                num_warps=num_warps,
                num_stages=1,
            )
            return

        vllm.attention.ops.prefix_prefill._fwd_kernel[grid](
            q,
            k,
            v,
            k_cache,
            v_cache,
            b_loc,
            sm_scale,
            b_start_loc,
            b_seq_len,
            b_ctx_len,
            v_cache.shape[3],
            8,
            o,
            b_loc.stride(0),
            b_loc.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            k_cache.stride(
                4),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            v_cache.stride(
                3),  # [num_blocks, num_kv_heads, head_size, block_size]
            num_queries_per_kv=num_queries_per_kv,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_DMODEL_PADDED=Lk_padded,
            BLOCK_N=BLOCK,
            SLIDING_WINDOW=sliding_window if sliding_window is not None else 0,
            num_warps=num_warps,
            num_stages=1,
        )
        return

    MluHijackObject.apply_hijack(vllm.attention.ops.prefix_prefill,
                                 vllm.attention.ops.prefix_prefill.context_attention_fwd,
                                 vllm__attention__ops__prefix_prefill__context_attention_fwd)
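Functionally, the only change relative to the upstream context_attention_fwd is pinning BLOCK to 16; the @brief note above records that the larger upstream block uses too much memory on MLU. The launch grid stays (batch, head, ceil(max_input_len / BLOCK)), so a smaller block simply launches more programs along the sequence axis. A tiny arithmetic sketch of that trade-off with made-up sizes:

import math

def launch_grid(batch, heads, max_input_len, block):
    # Mirrors grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) above.
    return (batch, heads, math.ceil(max_input_len / block))

# Example values only.
print(launch_grid(4, 32, 2048, 64))  # (4, 32, 32)  - a 64-wide block
print(launch_grid(4, 32, 2048, 16))  # (4, 32, 128) - BLOCK = 16 as in this file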
@@ -0,0 +1,802 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
|
||||
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
|
||||
(https://tridao.me/publications/flash2/flash2.pdf)
|
||||
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
|
||||
|
||||
Features supported:
|
||||
|
||||
1) Fwd with causal masking
|
||||
2) Any sequence lengths without padding (currently fwd kernel only)
|
||||
3) Support for different sequence lengths for q and k
|
||||
4) Nested tensor API currently does not support dropout or bias.
|
||||
|
||||
Not currently supported:
|
||||
|
||||
1) Non power of two head dims
|
||||
|
||||
"""
|
||||
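Feature 2 above means the forward kernel consumes packed, unpadded batches: all sequences are concatenated along the token axis and described by cumulative-length offsets, which is what the cu_seqlens_q / cu_seqlens_k arguments of attn_fwd index into per batch element when VARLEN is set. A small sketch of how such offsets relate to per-sequence lengths (the lengths are arbitrary example values):

import torch

# Example per-sequence lengths for a packed (unpadded) batch.
seq_lens_q = torch.tensor([5, 3, 7], dtype=torch.int32)

# cu_seqlens has a leading 0; sequence i occupies tokens
# cu_seqlens[i]:cu_seqlens[i + 1] of the packed query tensor.
cu_seqlens_q = torch.zeros(len(seq_lens_q) + 1, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(seq_lens_q, dim=0)
print(cu_seqlens_q.tolist())  # [0, 5, 8, 15]

# Inside attn_fwd (VARLEN=True) each program reads its batch element z and
# recovers seqlen_q = cu_seqlens_q[z + 1] - cu_seqlens_q[z] before doing work.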
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
torch_dtype: tl.constexpr = torch.float16
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv_fn(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
@triton.jit
|
||||
def max_fn(x, y):
|
||||
return tl.math.max(x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
ms = tl.arange(0, m)
|
||||
ns = tl.arange(0, n)
|
||||
return philox_offset + ms[:, None] * stride + ns[None, :]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
|
||||
stride).to(tl.uint32)
|
||||
# TODO: use tl.randint for better performance
|
||||
return tl.rand(philox_seed, rng_offsets)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
|
||||
stride)
|
||||
rng_keep = rng_output > dropout_p
|
||||
return rng_keep
|
||||
|
||||
|
||||
@triton.jit
|
||||
def load_fn(block_ptr, first, second, pad):
|
||||
if first and second:
|
||||
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
|
||||
elif first:
|
||||
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
|
||||
elif second:
|
||||
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
|
||||
else:
|
||||
tensor = tl.load(block_ptr)
|
||||
return tensor
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
actual_seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
block_min,
|
||||
block_max,
|
||||
offs_n_causal,
|
||||
masked_blocks,
|
||||
n_extra_tokens,
|
||||
bias_ptr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
OFFS_M: tl.constexpr,
|
||||
OFFS_N: tl.constexpr,
|
||||
PRE_LOAD_V: tl.constexpr,
|
||||
MASK_STEPS: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr,
|
||||
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||
PADDED_HEAD: tl.constexpr,
|
||||
):
|
||||
# loop over k, v, and update accumulator
|
||||
for start_n in range(block_min, block_max, BLOCK_N):
|
||||
# For padded blocks, we will overrun the tensor size if
|
||||
# we load all BLOCK_N. For others, the blocks are all within range.
|
||||
k = load_fn(
|
||||
K_block_ptr,
|
||||
PADDED_HEAD,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
"zero",
|
||||
)
|
||||
if PRE_LOAD_V:
|
||||
v = load_fn(
|
||||
V_block_ptr,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
PADDED_HEAD,
|
||||
"zero",
|
||||
)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
# We start from end of seqlen_k so only the first iteration would need
|
||||
# to be checked for padding if it is not a multiple of block_n
|
||||
# TODO: This can be optimized to only be true for the padded block.
|
||||
if MASK_STEPS: # noqa: SIM102
|
||||
# If this is the last block / iteration, we want to
|
||||
# mask if the sequence length is not a multiple of block size
|
||||
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
|
||||
# if not is_modulo_mn. last step might get wasted but that is okay.
|
||||
# check if this masking works for that case.
|
||||
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
|
||||
boundary_m = tl.full([BLOCK_M],
|
||||
actual_seqlen_k,
|
||||
dtype=tl.int32)
|
||||
size_n = start_n + OFFS_N[None, :]
|
||||
mask = size_n < boundary_m[:, None]
|
||||
qk = tl.where(mask, qk, float("-inf"))
|
||||
if IS_CAUSAL:
|
||||
causal_boundary = start_n + offs_n_causal
|
||||
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
|
||||
qk = tl.where(causal_mask, qk, float("-inf"))
|
||||
# -- compute qk ----
|
||||
qk += tl.dot(q, k)
|
||||
if bias_ptr is not None:
|
||||
bias = load_fn(bias_ptr, False, MASK_STEPS
|
||||
and (n_extra_tokens != 0), "zero")
|
||||
# While bias is added after multiplying qk with sm_scale, our
|
||||
# optimization to use 2^x instead of e^x results in an additional
|
||||
# scale factor of log2(e) which we must also multiply the bias with.
|
||||
qk += bias * 1.44269504089
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
qk = qk - m_ij[:, None]
|
||||
p = tl.math.exp2(qk)
|
||||
|
||||
# CAVEAT: Must update l_ij before applying dropout
|
||||
l_ij = tl.sum(p, 1)
|
||||
if ENABLE_DROPOUT:
|
||||
philox_offset = (batch_philox_offset +
|
||||
start_m * BLOCK_M * actual_seqlen_k + start_n -
|
||||
BLOCK_N)
|
||||
keep = dropout_mask(
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
dropout_p,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
actual_seqlen_k,
|
||||
)
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
tl.store(
|
||||
encoded_softmax_block_ptr,
|
||||
tl.where(keep, p,
|
||||
-p).to(encoded_softmax_block_ptr.type.element_ty),
|
||||
)
|
||||
p = tl.where(keep, p, 0.0)
|
||||
elif RETURN_ENCODED_SOFTMAX:
|
||||
tl.store(
|
||||
encoded_softmax_block_ptr,
|
||||
p.to(encoded_softmax_block_ptr.type.element_ty),
|
||||
)
|
||||
# -- update output accumulator --
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
acc = acc * alpha[:, None]
|
||||
if not PRE_LOAD_V:
|
||||
v = load_fn(
|
||||
V_block_ptr,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
PADDED_HEAD,
|
||||
"zero",
|
||||
)
|
||||
# -- update m_i and l_i
|
||||
l_i = l_i * alpha + l_ij
|
||||
# update m_i and l_i
|
||||
m_i = m_ij
|
||||
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
if bias_ptr is not None:
|
||||
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
|
||||
(0, BLOCK_N))
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 256,
|
||||
"BLOCK_N": 64,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 128,
|
||||
"BLOCK_N": 128,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 256,
|
||||
"BLOCK_N": 128,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 128,
|
||||
"BLOCK_N": 64,
|
||||
"PRE_LOAD_V": True,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 128,
|
||||
"BLOCK_N": 64,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 64,
|
||||
"BLOCK_N": 64,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 32,
|
||||
"BLOCK_N": 32,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8,
|
||||
),
|
||||
# TODO: This config fails with head_size not pow2 with data mismatches.
|
||||
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16,
|
||||
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 16,
|
||||
"BLOCK_N": 16,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
],
|
||||
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
|
||||
)
|
||||
@triton.jit
|
||||
def attn_fwd(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
bias,
|
||||
sm_scale,
|
||||
L,
|
||||
Out,
|
||||
stride_qz,
|
||||
stride_qh,
|
||||
stride_qm,
|
||||
stride_qk,
|
||||
stride_kz,
|
||||
stride_kh,
|
||||
stride_kn,
|
||||
stride_kk,
|
||||
stride_vz,
|
||||
stride_vh,
|
||||
stride_vk,
|
||||
stride_vn,
|
||||
stride_oz,
|
||||
stride_oh,
|
||||
stride_om,
|
||||
stride_on,
|
||||
stride_bz,
|
||||
stride_bh,
|
||||
stride_bm,
|
||||
stride_bn,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
philox_offset_base,
|
||||
encoded_softmax,
|
||||
HQ: tl.constexpr,
|
||||
HK: tl.constexpr,
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||
MAX_SEQLENS_Q: tl.constexpr,
|
||||
MAX_SEQLENS_K: tl.constexpr,
|
||||
VARLEN: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
PRE_LOAD_V: tl.constexpr,
|
||||
BIAS_TYPE: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr,
|
||||
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_h_q = tl.program_id(1)
|
||||
off_z = tl.program_id(2)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
if VARLEN:
|
||||
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
|
||||
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
|
||||
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
|
||||
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
|
||||
# small for all start_m so for those we return early.
|
||||
if start_m * BLOCK_M > seqlen_q:
|
||||
return
|
||||
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
|
||||
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
|
||||
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
|
||||
else:
|
||||
cu_seqlens_q_start = 0
|
||||
cu_seqlens_k_start = 0
|
||||
seqlen_q = MAX_SEQLENS_Q
|
||||
seqlen_k = MAX_SEQLENS_K
|
||||
|
||||
# Now we compute whether we need to exit early due to causal masking.
|
||||
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
|
||||
# are completely masked, resulting in 0s written to the output, and
|
||||
# inf written to LSE. We don't need to do any GEMMs in this case.
|
||||
# This block of code determines what N is, and if this WG is operating
|
||||
# on those M rows.
|
||||
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
|
||||
if IS_CAUSAL:
|
||||
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
|
||||
# If seqlen_q != seqlen_k, attn scores are rectangular which means
|
||||
# the causal mask boundary is bottom right aligned, and ends at either
|
||||
# the top edge (seqlen_q < seqlen_k) or left edge.
|
||||
# This captures the decrease in n_blocks if we have a rectangular attn
|
||||
# matrix
|
||||
n_blocks_seqlen = cdiv_fn(
|
||||
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
|
||||
# This is what adjusts the block_max for the current WG, only
|
||||
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
|
||||
n_blocks = min(n_blocks, n_blocks_seqlen)
|
||||
# If we have no blocks after adjusting for seqlen deltas, this WG is
|
||||
# part of the blocks that are all 0. We exit early.
|
||||
if n_blocks <= 0:
|
||||
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
|
||||
off_h_q * stride_oh)
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + o_offset,
|
||||
shape=(seqlen_q, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
|
||||
# We still need to write 0s to the result
|
||||
# tl.store(O_block_ptr,
|
||||
# acc.to(Out.type.element_ty), boundary_check=(0,1))
|
||||
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
|
||||
# + offs_m
|
||||
# We store inf to LSE, not -inf because in the bwd pass,
|
||||
# we subtract this
|
||||
# from qk which makes it -inf, such that exp(qk - inf) = 0
|
||||
# for these masked blocks.
|
||||
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
|
||||
# tl.store(l_ptrs, l)
|
||||
# TODO: Should dropout and return encoded softmax be handled here?
|
||||
return
|
||||
|
||||
# If MQA / GQA, set the K and V head offsets appropriately.
|
||||
GROUP_SIZE: tl.constexpr = HQ // HK
|
||||
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
|
||||
|
||||
n_extra_tokens = 0
|
||||
if seqlen_k < BLOCK_N:
|
||||
n_extra_tokens = BLOCK_N - seqlen_k
|
||||
elif seqlen_k % BLOCK_N:
|
||||
n_extra_tokens = seqlen_k % BLOCK_N
|
||||
padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
|
||||
|
||||
# Compute pointers for all the tensors used in this kernel.
|
||||
q_offset = (off_z * stride_qz + off_h_q * stride_qh +
|
||||
cu_seqlens_q_start * stride_qm)
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + q_offset,
|
||||
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
k_offset = (off_z * stride_kz + off_h_k * stride_kh +
|
||||
cu_seqlens_k_start * stride_kn)
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + k_offset,
|
||||
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
v_offset = (off_z * stride_vz + off_h_k * stride_vh +
|
||||
cu_seqlens_k_start * stride_vk)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + v_offset,
|
||||
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
if BIAS_TYPE != 0:
|
||||
bias_ptr = tl.make_block_ptr(
|
||||
base=bias + off_h_q * stride_bh,
|
||||
shape=(seqlen_q, seqlen_k),
|
||||
strides=(stride_bm, stride_bn),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
bias_ptr = None
|
||||
if ENABLE_DROPOUT:
|
||||
batch_philox_offset = philox_offset_base \
|
||||
+ (off_z * HQ + off_h_q) \
|
||||
* seqlen_q * seqlen_k
|
||||
else:
|
||||
batch_philox_offset = 0
|
||||
# We can ask to return the dropout mask without actually doing any dropout.
|
||||
# In this case, we return an invalid pointer to indicate that the mask is
# not valid.
|
||||
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.make_block_ptr(
|
||||
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
|
||||
shape=(seqlen_q, seqlen_k),
|
||||
strides=(seqlen_k, 1),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
encoded_softmax_block_ptr = 0
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
|
||||
# have native e^x support in HW.
|
||||
qk_scale = sm_scale * 1.44269504089
|
||||
# Q is loaded once at the beginning and shared by all N blocks.
|
||||
q = load_fn(Q_block_ptr, True, padded_head, "zero")
|
||||
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
|
||||
|
||||
# Here we compute how many full and masked blocks we have.
|
||||
padded_block_k = n_extra_tokens != 0
|
||||
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
|
||||
if IS_CAUSAL:
|
||||
# There are always at least BLOCK_M // BLOCK_N masked blocks.
|
||||
# Additionally there might be one more due to dissimilar seqlens.
|
||||
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
|
||||
else:
|
||||
# Padding on Q does not need to be masked in the FA loop.
|
||||
masked_blocks = padded_block_k
|
||||
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
|
||||
# block. In this case we might exceed n_blocks so pick the min.
|
||||
masked_blocks = min(masked_blocks, n_blocks)
|
||||
n_full_blocks = n_blocks - masked_blocks
|
||||
block_min = 0
|
||||
block_max = n_blocks * BLOCK_N
|
||||
# Compute for full blocks. Here we set causal to false regardless of its
|
||||
# value because there is no masking. Similarly we do not need padding.
|
||||
if n_full_blocks > 0:
|
||||
block_max = (n_blocks - masked_blocks) * BLOCK_N
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
|
||||
block_min,
|
||||
block_max,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
bias_ptr,
|
||||
# IS_CAUSAL, ....
|
||||
False,
|
||||
BLOCK_M,
|
||||
BLOCK_DMODEL,
|
||||
BLOCK_N,
|
||||
offs_m,
|
||||
offs_n,
|
||||
# _, MASK_STEPS, ...
|
||||
PRE_LOAD_V,
|
||||
False,
|
||||
ENABLE_DROPOUT,
|
||||
RETURN_ENCODED_SOFTMAX,
|
||||
padded_head,
|
||||
)
|
||||
block_min = block_max
|
||||
block_max = n_blocks * BLOCK_N
|
||||
|
||||
tl.debug_barrier()
|
||||
# Remaining blocks, if any, are masked and handled in the loop below.
|
||||
if masked_blocks > 0:
|
||||
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
|
||||
if bias_ptr is not None:
|
||||
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
|
||||
(0, n_full_blocks))
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
block_min,
|
||||
block_max,
|
||||
offs_n_causal,
|
||||
masked_blocks,
|
||||
n_extra_tokens,
|
||||
bias_ptr,
|
||||
IS_CAUSAL,
|
||||
BLOCK_M,
|
||||
BLOCK_DMODEL,
|
||||
BLOCK_N,
|
||||
offs_m,
|
||||
offs_n,
|
||||
# _, MASK_STEPS, ...
|
||||
PRE_LOAD_V,
|
||||
True,
|
||||
ENABLE_DROPOUT,
|
||||
RETURN_ENCODED_SOFTMAX,
|
||||
padded_head,
|
||||
)
|
||||
# epilogue
|
||||
acc = acc / l_i[:, None]
|
||||
if ENABLE_DROPOUT:
|
||||
acc = acc / (1 - dropout_p)
|
||||
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
|
||||
# then we have one block with a row of all NaNs which come from computing
|
||||
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
|
||||
# and store 0s where there are NaNs as these rows should've been zeroed out.
|
||||
end_m_idx = (start_m + 1) * BLOCK_M
|
||||
start_m_idx = start_m * BLOCK_M
|
||||
causal_start_idx = seqlen_q - seqlen_k
|
||||
acc = acc.to(Out.type.element_ty)
|
||||
if IS_CAUSAL: # noqa: SIM102
|
||||
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
|
||||
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
|
||||
causal_start_idx,
|
||||
dtype=tl.int32)
|
||||
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
|
||||
out_ptrs_mask = (mask_m_offsets[:, None] >=
|
||||
out_mask_boundary[None, :])
|
||||
z = 0.0
|
||||
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
|
||||
# write back LSE
|
||||
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
|
||||
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
|
||||
# few rows. This is only true for the last M block. For others,
|
||||
# overflow_size will be -ve
|
||||
# overflow_size = end_m_idx - seqlen_q
|
||||
# if overflow_size > 0:
|
||||
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
|
||||
# # This is a > check because mask being 0 blocks the store.
|
||||
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
|
||||
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
|
||||
# else:
|
||||
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||
|
||||
# write back O
|
||||
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
|
||||
off_h_q * stride_oh)
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + o_offset,
|
||||
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
# Need boundary check on this to make sure the padding from the
|
||||
# Q and KV tensors in both dims are not part of what we store back.
|
||||
# TODO: Do the boundary check optionally.
|
||||
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
|
||||
|
||||
|
||||
def check_args(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
varlen=True,
|
||||
max_seqlens=None,
|
||||
cu_seqlens_q=None,
|
||||
cu_seqlens_k=None,
|
||||
):
|
||||
assert q.dim() == k.dim() and q.dim() == v.dim()
|
||||
if varlen:
|
||||
assert q.dim() == 3
|
||||
total_q, nheads_q, head_size = q.shape
|
||||
total_k, nheads_k, _ = k.shape
|
||||
assert cu_seqlens_q is not None
|
||||
assert cu_seqlens_k is not None
|
||||
assert len(cu_seqlens_q) == len(cu_seqlens_k)
|
||||
else:
|
||||
assert q.dim() == 4
|
||||
batch, nheads_q, seqlen_q, head_size = q.shape
|
||||
_, nheads_k, seqlen_k, _ = k.shape
|
||||
assert max_seqlens > 0
|
||||
assert k.shape == v.shape
|
||||
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
|
||||
# TODO: Change assert if we support qkl f8 and v f16
|
||||
assert q.dtype == k.dtype and q.dtype == v.dtype
|
||||
assert head_size <= 256
|
||||
assert o.shape == q.shape
|
||||
assert (nheads_q % nheads_k) == 0
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlens_q,
|
||||
max_seqlens_k,
|
||||
causal=False,
|
||||
sm_scale=1.0,
|
||||
bias=None,
|
||||
):
|
||||
if o is None:
|
||||
o = torch.empty_like(q, dtype=v.dtype)
|
||||
|
||||
check_args(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
varlen=True,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
)
|
||||
if True: # varlen
|
||||
total_q, nheads_q, head_size = q.shape
|
||||
total_k, nheads_k, _ = k.shape
|
||||
batch = len(cu_seqlens_q) - 1
|
||||
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
|
||||
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
|
||||
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
|
||||
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
|
||||
else:
|
||||
batch, seqlen_q, nheads_q, head_size = q.shape
|
||||
_, seqlen_k, nheads_k, _ = k.shape
|
||||
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
|
||||
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
|
||||
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
|
||||
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
|
||||
|
||||
# Round head_size up to the closest supported power of 2 (at least 32).
|
||||
unpadded_head_dims = {32, 64, 128, 256}
|
||||
if head_size not in unpadded_head_dims:
|
||||
padded_d_model = None
|
||||
for i in unpadded_head_dims:
|
||||
if i > head_size:
|
||||
padded_d_model = i
|
||||
break
|
||||
assert padded_d_model is not None
|
||||
else:
|
||||
padded_d_model = head_size
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
|
||||
nheads_q,
|
||||
batch,
|
||||
)
|
||||
|
||||
encoded_softmax = None
|
||||
|
||||
# Seed the RNG so we get reproducible results for testing.
|
||||
philox_seed = 0x1BF52
|
||||
philox_offset = 0x1D4B42
|
||||
|
||||
if bias is not None:
|
||||
bias_strides = (
|
||||
bias.stride(0),
|
||||
bias.stride(1),
|
||||
bias.stride(2),
|
||||
bias.stride(3),
|
||||
)
|
||||
else:
|
||||
bias_strides = (0, 0, 0, 0)
|
||||
|
||||
attn_fwd[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
bias,
|
||||
sm_scale,
|
||||
None,
|
||||
o,
|
||||
*q_strides,
|
||||
*k_strides,
|
||||
*v_strides,
|
||||
*o_strides,
|
||||
*bias_strides,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
dropout_p=0.0,
|
||||
philox_seed=philox_seed,
|
||||
philox_offset_base=philox_offset,
|
||||
encoded_softmax=encoded_softmax,
|
||||
HQ=nheads_q,
|
||||
HK=nheads_k,
|
||||
ACTUAL_BLOCK_DMODEL=head_size,
|
||||
MAX_SEQLENS_Q=max_seqlens_q,
|
||||
MAX_SEQLENS_K=max_seqlens_k,
|
||||
IS_CAUSAL=causal,
|
||||
VARLEN=True,
|
||||
BLOCK_DMODEL=padded_d_model,
|
||||
BIAS_TYPE=0 if bias is None else 1,
|
||||
ENABLE_DROPOUT=False,
|
||||
RETURN_ENCODED_SOFTMAX=False,
|
||||
)
|
||||
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = head_size
|
||||
ctx.causal = causal
|
||||
ctx.dropout_p = 0.0
|
||||
ctx.philox_seed = philox_seed
|
||||
ctx.philox_offset = philox_offset
|
||||
ctx.encoded_softmax = encoded_softmax
|
||||
ctx.return_encoded_softmax = False
|
||||
return o, encoded_softmax
|
||||
|
||||
|
||||
triton_attention = _attention.apply
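# --------------------------------------------------------------------------
# Illustrative usage of `triton_attention` (a hedged sketch, not part of the
# kernel itself). Tensor shapes follow check_args() for the varlen layout:
# q/k/v are (total_tokens, nheads, head_size) with cumulative sequence
# lengths in cu_seqlens_*. The device string "mlu" and the concrete sizes
# below are illustrative assumptions, not a tested configuration.
def _example_triton_attention_usage():
    import torch
    nheads, head_size = 8, 128
    cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32, device="mlu")
    total_tokens = int(cu_seqlens[-1])
    q = torch.randn(total_tokens, nheads, head_size,
                    dtype=torch.float16, device="mlu")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # o=None lets forward() allocate the output; the max sequence length is 7.
    out, _ = triton_attention(q, k, v, None, cu_seqlens, cu_seqlens,
                              7, 7, True, head_size ** -0.5, None)
    return out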
|
||||
303
vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/selector.py
Normal file
@@ -0,0 +1,303 @@
|
||||
import enum
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache
|
||||
from typing import Generator, Optional, Type
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import get_global_forced_attn_backend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR
|
||||
|
||||
from vllm_mlu._mlu_utils import USE_PAGED
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
from vllm.attention.selector import _Backend, backend_name_to_enum
|
||||
from vllm.attention import selector
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add MLU_MLA_FLASH_ATTN for deepseekv2 MLA.
|
||||
'''
|
||||
_Backend.MLU_MLA_FLASH_ATTN = enum.auto()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add an argument `use_mla` to get_attn_backend, _cached_get_attn_backend,
and which_attn_to_use.
|
||||
'''
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
def vllm__attention__selector__get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
is_blocksparse: bool = False,
|
||||
use_mla: bool = False,
|
||||
) -> Type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||
# value to be returned from the cache if the value changes between calls.
|
||||
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
|
||||
# private function.
|
||||
return vllm__attention__selector___cached_get_attn_backend(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
block_size=block_size,
|
||||
is_attention_free=is_attention_free,
|
||||
is_blocksparse=is_blocksparse,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def vllm__attention__selector___cached_get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
is_blocksparse: bool = False,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
) -> Type[AttentionBackend]:
|
||||
if is_blocksparse:
|
||||
logger.info("Using BlocksparseFlashAttention backend.")
|
||||
from vllm.attention.backends.blocksparse_attn import (
|
||||
BlocksparseFlashAttentionBackend)
|
||||
return BlocksparseFlashAttentionBackend
|
||||
|
||||
backend = vllm__attention__selector__which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
|
||||
is_attention_free, use_v1, use_mla)
|
||||
if backend == _Backend.FLASH_ATTN:
|
||||
logger.info("Using Flash Attention backend.")
|
||||
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
||||
FlashAttentionBackend)
|
||||
return FlashAttentionBackend
|
||||
if backend == _Backend.FLASH_ATTN_VLLM_V1:
|
||||
from vllm.v1.attention.backends.flash_attn import ( # noqa: F401
|
||||
FlashAttentionBackend as FlashAttentionBackendV1)
|
||||
return FlashAttentionBackendV1
|
||||
if backend == _Backend.XFORMERS:
|
||||
logger.info("Using XFormers backend.")
|
||||
from vllm.attention.backends.xformers import ( # noqa: F401
|
||||
XFormersBackend)
|
||||
return XFormersBackend
|
||||
elif backend == _Backend.ROCM_FLASH:
|
||||
logger.info("Using ROCmFlashAttention backend.")
|
||||
from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401
|
||||
ROCmFlashAttentionBackend)
|
||||
return ROCmFlashAttentionBackend
|
||||
elif backend == _Backend.TORCH_SDPA:
|
||||
assert current_platform.is_cpu(), RuntimeError(
|
||||
"Torch SDPA backend is only used for the CPU device.")
|
||||
logger.info("Using Torch SDPA backend.")
|
||||
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
|
||||
return TorchSDPABackend
|
||||
elif backend == _Backend.OPENVINO:
|
||||
logger.info("Using OpenVINO Attention backend.")
|
||||
from vllm.attention.backends.openvino import OpenVINOAttentionBackend
|
||||
return OpenVINOAttentionBackend
|
||||
elif backend == _Backend.IPEX:
|
||||
assert current_platform.is_xpu(), RuntimeError(
|
||||
"IPEX attention backend is only used for the XPU device.")
|
||||
logger.info("Using IPEX attention backend.")
|
||||
from vllm.attention.backends.ipex_attn import IpexAttnBackend
|
||||
return IpexAttnBackend
|
||||
elif backend == _Backend.FLASHINFER:
|
||||
logger.info("Using Flashinfer backend.")
|
||||
from vllm.attention.backends.flashinfer import FlashInferBackend
|
||||
return FlashInferBackend
|
||||
elif backend == _Backend.HPU_ATTN:
|
||||
logger.info("Using HPUAttention backend.")
|
||||
from vllm.attention.backends.hpu_attn import HPUAttentionBackend
|
||||
return HPUAttentionBackend
|
||||
elif backend == _Backend.PALLAS:
|
||||
logger.info("Using Pallas backend.")
|
||||
from vllm.attention.backends.pallas import PallasAttentionBackend
|
||||
return PallasAttentionBackend
|
||||
elif backend == _Backend.MLU_MLA_FLASH_ATTN:
|
||||
logger.info("Using MLUFlashAttention backend.")
|
||||
from vllm_mlu.attention.backends.mlu_attn import MLUMLAFlashAttentionBackend
|
||||
return MLUMLAFlashAttentionBackend
|
||||
elif backend == _Backend.MLU_FLASH_ATTN:
|
||||
logger.info("Using MLUFlashAttention backend.")
|
||||
from vllm.attention.backends.mlu_attn import MLUFlashAttentionBackend
|
||||
return MLUFlashAttentionBackend
|
||||
elif backend == _Backend.NO_ATTENTION:
|
||||
from vllm.attention.backends.placeholder_attn import (
|
||||
PlaceholderAttentionBackend)
|
||||
return PlaceholderAttentionBackend
|
||||
else:
|
||||
raise ValueError("Invalid attention backend.")
|
||||
|
||||
|
||||
def vllm__attention__selector__which_attn_to_use(head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False) -> _Backend:
|
||||
"""Returns which flash attention backend to use."""
|
||||
# Default case.
|
||||
selected_backend = _Backend.FLASH_ATTN
|
||||
|
||||
# If there are no attention layers (e.g. we are running Mamba),
|
||||
# use the placeholder NO_ATTENTION
|
||||
if is_attention_free:
|
||||
return _Backend.NO_ATTENTION
|
||||
|
||||
# Check whether a particular choice of backend was
|
||||
# previously forced.
|
||||
#
|
||||
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
|
||||
# ENVIRONMENT VARIABLE.
|
||||
backend_by_global_setting: Optional[_Backend] = (
|
||||
get_global_forced_attn_backend())
|
||||
if backend_by_global_setting is not None:
|
||||
selected_backend = backend_by_global_setting
|
||||
else:
|
||||
# Check the environment variable and override if specified
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
|
||||
if current_platform.is_cpu():
|
||||
if selected_backend != _Backend.TORCH_SDPA:
|
||||
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||
return _Backend.TORCH_SDPA
|
||||
|
||||
if current_platform.is_openvino():
|
||||
if selected_backend != _Backend.OPENVINO:
|
||||
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
|
||||
return _Backend.OPENVINO
|
||||
|
||||
if current_platform.is_xpu():
|
||||
if selected_backend != _Backend.IPEX:
|
||||
logger.info("Cannot use %s backend on XPU.", selected_backend)
|
||||
return _Backend.IPEX
|
||||
|
||||
if current_platform.is_tpu():
|
||||
if selected_backend != _Backend.PALLAS:
|
||||
logger.info("Cannot use %s backend on TPU.", selected_backend)
|
||||
return _Backend.PALLAS
|
||||
|
||||
if current_platform.is_mlu():
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add MLU_MLA_FLASH_ATTN for deepseekv2 MLA.
|
||||
'''
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if use_mla:
|
||||
return _Backend.MLU_MLA_FLASH_ATTN
|
||||
if selected_backend != _Backend.MLU_FLASH_ATTN:
|
||||
logger.debug("Cannot use %s backend on MLU.", selected_backend)
|
||||
return _Backend.MLU_FLASH_ATTN
|
||||
|
||||
if current_platform.is_rocm():
|
||||
# AMD GPUs.
|
||||
selected_backend = (_Backend.ROCM_FLASH if selected_backend
|
||||
== _Backend.FLASH_ATTN else selected_backend)
|
||||
if selected_backend == _Backend.ROCM_FLASH:
|
||||
if not current_platform.has_device_capability(90):
|
||||
# not Instinct series GPUs.
|
||||
logger.info("flash_attn is not supported on NAVI GPUs.")
|
||||
else:
|
||||
logger.info("%s is not supported in AMD GPUs.", selected_backend)
|
||||
return _Backend.ROCM_FLASH
|
||||
|
||||
if current_platform.is_hpu():
|
||||
return _Backend.HPU_ATTN
|
||||
|
||||
if use_v1:
|
||||
return _Backend.FLASH_ATTN_VLLM_V1
|
||||
# FlashAttn in NVIDIA GPUs.
|
||||
if selected_backend == _Backend.FLASH_ATTN:
|
||||
if not current_platform.has_device_capability(80):
|
||||
# Volta and Turing NVIDIA GPUs.
|
||||
logger.info(
|
||||
"Cannot use FlashAttention-2 backend for Volta and Turing "
|
||||
"GPUs.")
|
||||
selected_backend = _Backend.XFORMERS
|
||||
elif dtype not in (torch.float16, torch.bfloat16):
|
||||
logger.info(
|
||||
"Cannot use FlashAttention-2 backend for dtype other than "
|
||||
"torch.float16 or torch.bfloat16.")
|
||||
selected_backend = _Backend.XFORMERS
|
||||
elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
|
||||
logger.info(
|
||||
"Cannot use FlashAttention-2 backend for FP8 KV cache.")
|
||||
logger.warning(
|
||||
"Please use FlashInfer backend with FP8 KV Cache for "
|
||||
"better performance by setting environment variable "
|
||||
"VLLM_ATTENTION_BACKEND=FLASHINFER")
|
||||
selected_backend = _Backend.XFORMERS
|
||||
elif block_size % 16 != 0:
|
||||
logger.info(
|
||||
"Cannot use FlashAttention-2 backend for block size not "
|
||||
"divisible by 16.")
|
||||
selected_backend = _Backend.XFORMERS
|
||||
|
||||
# FlashAttn is valid for the model, checking if the package is installed.
|
||||
if selected_backend == _Backend.FLASH_ATTN:
|
||||
try:
|
||||
import vllm.vllm_flash_attn # noqa: F401
|
||||
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
||||
FlashAttentionBackend)
|
||||
|
||||
supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
||||
if head_size not in supported_sizes:
|
||||
logger.info(
|
||||
"Cannot use FlashAttention-2 backend for head size %d.",
|
||||
head_size)
|
||||
selected_backend = _Backend.XFORMERS
|
||||
except ImportError:
|
||||
logger.info(
|
||||
"Cannot use FlashAttention-2 backend because the "
|
||||
"vllm.vllm_flash_attn package is not found. "
|
||||
"Make sure that vllm_flash_attn was built and installed "
|
||||
"(on by default).")
|
||||
selected_backend = _Backend.XFORMERS
|
||||
|
||||
return selected_backend
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(selector,
|
||||
selector.get_attn_backend,
|
||||
vllm__attention__selector__get_attn_backend)
|
||||
MluHijackObject.apply_hijack(selector,
|
||||
selector._cached_get_attn_backend,
|
||||
vllm__attention__selector___cached_get_attn_backend)
|
||||
MluHijackObject.apply_hijack(selector,
|
||||
selector.which_attn_to_use,
|
||||
vllm__attention__selector__which_attn_to_use)
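# Illustrative only (a hedged sketch, not part of this module): once the
# hijacks above are applied, callers resolve MLU backends through the patched
# selector. The argument values below are placeholders for illustration.
def _example_select_mlu_backend():
    import torch
    from vllm.attention.selector import get_attn_backend
    return get_attn_backend(
        head_size=576,            # MLA head size used for deepseek_v2 on MLU
        dtype=torch.bfloat16,
        kv_cache_dtype="auto",
        block_size=16,
        is_attention_free=False,
        use_mla=True,             # the argument added by this hijack
    )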
|
||||
138
vllm-v0.6.2/vllm_mlu/vllm_mlu/config.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from typing import Tuple
|
||||
from vllm.logger import init_logger
|
||||
from vllm.config import ModelConfig, CacheConfig, LoRAConfig
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
vllm__config__LoRAConfig__verify_with_model_config_org = LoRAConfig.verify_with_model_config
|
||||
|
||||
|
||||
def vllm__config__CacheConfig___verify_cache_dtype(self) -> None:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add kv_cache_dtype int8 support
|
||||
'''
|
||||
if self.cache_dtype == "auto":
|
||||
pass
|
||||
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
|
||||
logger.info(
|
||||
"Using fp8 data type to store kv cache. It reduces the GPU "
|
||||
"memory footprint and boosts the performance. "
|
||||
"Meanwhile, it may cause accuracy drop without a proper "
|
||||
"scaling factor")
|
||||
elif self.cache_dtype == 'int8':
|
||||
logger.info(
|
||||
"Using int8 data type to store kv cache. It reduces the MLU "
|
||||
"memory footprint and boosts the performance. ")
|
||||
else:
|
||||
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
|
||||
"""Returns the number of KV heads per GPU."""
|
||||
if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type == 'deepseek_v2':
|
||||
# MLA (deepseek_v2) uses a single shared latent KV head.
|
||||
return 1
|
||||
total_num_kv_heads = self.get_total_num_kv_heads()
|
||||
# If tensor parallelism is used, we divide the number of KV heads by
|
||||
# the tensor parallel size. We will replicate the KV heads in the
|
||||
# case where the number of KV heads is smaller than the tensor
|
||||
# parallel size so each GPU has at least one KV head.
|
||||
return max(1,
|
||||
total_num_kv_heads // parallel_config.tensor_parallel_size)
|
||||
|
||||
def vllm__config__ModelConfig__get_head_size(self) -> int:
|
||||
# TODO remove hard code
|
||||
if hasattr(self.hf_text_config, "model_type"
|
||||
) and self.hf_text_config.model_type == 'deepseek_v2':
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: override the head size for deepseek_v2 (returns 576 for MLA).
|
||||
'''
|
||||
return 576
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if self.is_attention_free:
|
||||
return 0
|
||||
|
||||
if hasattr(self.hf_text_config, "head_dim"):
|
||||
return self.hf_text_config.head_dim
|
||||
# FIXME(woosuk): This may not be true for all models.
|
||||
return (self.hf_text_config.hidden_size //
|
||||
self.hf_text_config.num_attention_heads)
|
||||
|
||||
def vllm__config__ModelConfig__set_context_mlugraph_info(
|
||||
self, enable_context_mlugraph: bool, batch_size: int, seq_len: int) -> None:
|
||||
self.enable_context_mlugraph = enable_context_mlugraph
|
||||
self.context_batch_size_to_capture = batch_size
|
||||
self.context_seq_len_to_capture = seq_len
|
||||
|
||||
|
||||
def vllm__config__ModelConfig__use_context_mlugraph(self) -> bool:
|
||||
return hasattr(self, "enable_context_mlugraph") and self.enable_context_mlugraph
|
||||
|
||||
|
||||
def vllm__config__ModelConfig__get_context_mlugraph_bs_and_seq(self) -> Tuple[int, int]:
|
||||
return self.context_batch_size_to_capture, self.context_seq_len_to_capture
|
||||
|
||||
|
||||
def vllm__config__LoRAConfig__verify_with_model_config(self, model_config: ModelConfig):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: do not support quantization with lora for now
|
||||
'''
|
||||
if model_config.quantization:
|
||||
raise ValueError("vllm mlu does not support quantization with lora for now")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
vllm__config__LoRAConfig__verify_with_model_config_org(self, model_config)
|
||||
|
||||
@property
|
||||
def vllm__config__ModelConfig__is_deepseek_v2(self) -> bool:
|
||||
result = hasattr(
|
||||
self.hf_text_config,
|
||||
"model_type") and self.hf_text_config.model_type == 'deepseek_v2'
|
||||
return result
|
||||
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
"is_deepseek_v2",
|
||||
vllm__config__ModelConfig__is_deepseek_v2)
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
"set_context_mlugraph_info",
|
||||
vllm__config__ModelConfig__set_context_mlugraph_info)
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
"use_context_mlugraph",
|
||||
vllm__config__ModelConfig__use_context_mlugraph)
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
"get_context_mlugraph_bs_and_seq",
|
||||
vllm__config__ModelConfig__get_context_mlugraph_bs_and_seq)
|
||||
MluHijackObject.apply_hijack(CacheConfig,
|
||||
CacheConfig._verify_cache_dtype,
|
||||
vllm__config__CacheConfig___verify_cache_dtype)
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
ModelConfig.get_head_size,
|
||||
vllm__config__ModelConfig__get_head_size)
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
ModelConfig.get_num_kv_heads,
|
||||
vllm__config__ModelConfig__get_num_kv_heads)
|
||||
MluHijackObject.apply_hijack(LoRAConfig,
|
||||
LoRAConfig.verify_with_model_config,
|
||||
vllm__config__LoRAConfig__verify_with_model_config)
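# Illustrative only (a hedged sketch): how the context-mlugraph helpers added
# above are intended to be used on an existing vllm.config.ModelConfig
# instance. The batch size and sequence length below are placeholder values.
def _example_context_mlugraph_flags(model_config):
    model_config.set_context_mlugraph_info(
        enable_context_mlugraph=True, batch_size=8, seq_len=4096)
    if model_config.use_context_mlugraph():
        return model_config.get_context_mlugraph_bs_and_seq()
    return None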
|
||||
1
vllm-v0.6.2/vllm_mlu/vllm_mlu/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
import vllm_mlu.core.block_manager
|
||||
Binary file not shown.
Binary file not shown.
56
vllm-v0.6.2/vllm_mlu/vllm_mlu/core/block_manager.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from vllm.sequence import SequenceGroup, SequenceStatus
|
||||
from vllm_mlu._mlu_utils import USE_PAGED
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.core.block_manager import SelfAttnBlockSpaceManager
|
||||
from vllm.utils import Device
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__core__block_manager__SelfAttnBlockSpaceManager__can_append_slots(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> bool:
|
||||
"""Determine if there is enough space in the GPU KV cache to continue
|
||||
generation of the specified sequence group.
|
||||
|
||||
We use a worst-case heuristic: assume each touched block will require a
|
||||
new allocation (either via CoW or new block). We can append slots if the
|
||||
number of touched blocks is less than the number of free blocks.
|
||||
|
||||
"Lookahead slots" are slots that are allocated in addition to the slots
|
||||
for known tokens. The contents of the lookahead slots are not defined.
|
||||
This is used by speculative decoding when speculating future tokens.
|
||||
"""
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: optimize the allocation strategy for unpaged mode
|
||||
'''
|
||||
if not USE_PAGED:
|
||||
return True
|
||||
else:
|
||||
num_touched_blocks = 0
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
num_touched_blocks += (
|
||||
block_table.get_num_blocks_touched_by_append_slots(
|
||||
token_ids=block_table.get_unseen_token_ids(
|
||||
seq.get_token_ids()),
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
))
|
||||
|
||||
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
|
||||
Device.GPU)
|
||||
return num_touched_blocks <= num_free_gpu_blocks
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(SelfAttnBlockSpaceManager,
|
||||
SelfAttnBlockSpaceManager.can_append_slots,
|
||||
vllm__core__block_manager__SelfAttnBlockSpaceManager__can_append_slots)
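# Illustrative only (a hedged worked example of the heuristic above, not part
# of the hijack): in paged mode, if the running sequences of a group would
# touch 6 blocks in total while only 5 GPU blocks are free, 6 > 5 and
# can_append_slots() returns False; in unpaged mode it always returns True.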
|
||||
328
vllm-v0.6.2/vllm_mlu/vllm_mlu/core/scheduler.py
Normal file
@@ -0,0 +1,328 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import deque
|
||||
from typing import Deque, List, Optional, Set, Callable
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.core.scheduler import (SchedulingBudget, SchedulerPrefillOutputs,
|
||||
SchedulerRunningOutputs, SchedulerOutputs, Scheduler)
|
||||
from vllm.sequence import SequenceGroup
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
vllm__core__scheduler__Scheduler____init____org = Scheduler.__init__
|
||||
vllm__core__scheduler__Scheduler___schedule_prefills__org = Scheduler._schedule_prefills
|
||||
vllm__core__scheduler__Scheduler___schedule_running__org = Scheduler._schedule_running
|
||||
vllm__core__scheduler__Scheduler___schedule__org = Scheduler._schedule
|
||||
|
||||
|
||||
def vllm__core__scheduler__Scheduler__init_scheduler_view(self):
|
||||
logger.info(f"vLLM scheduler profiling start...")
|
||||
|
||||
self.df = pd.DataFrame(
|
||||
data={
|
||||
'waiting': [], 'running': [], 'swapped': [], 'finished': [],
|
||||
'wait_to_run_reqs': [], 'run_to_wait_reqs': [], 'wait_to_run_tokens': [],
|
||||
'batch_utils': [], 'block_utils': [], 'preempt_ratio': []
|
||||
},
|
||||
dtype=np.float32
|
||||
)
|
||||
self.sched_step = 0
|
||||
self.running_seqs = 0
|
||||
self.waiting_seqs = 0
|
||||
self.swapped_seqs = 0
|
||||
self.finished_seqs = 0
|
||||
self.total_seqs = 0
|
||||
self.running_to_waiting_seqs = 0
|
||||
self.waiting_to_running_seqs = 0
|
||||
self.wait_to_run_tokens = 0
|
||||
self.batch_utils = 0
|
||||
self.block_utils = 0
|
||||
self.preempt_ratio = 0
|
||||
|
||||
self.finished_seq_groups = []
|
||||
|
||||
|
||||
def summary_finished_seq_groups(seq_groups: List[SequenceGroup]):
|
||||
df = pd.DataFrame(
|
||||
data={
|
||||
'ttft/s': [], 'time_in_queue/s': [], 'context_latency/s': [], 'decoder_latency/s': []
|
||||
},
|
||||
dtype=np.float32
|
||||
)
|
||||
for seq_group in seq_groups:
|
||||
ttft = seq_group.metrics.first_token_time - seq_group.metrics.arrival_time
|
||||
time_in_queue = seq_group.metrics.time_in_queue
|
||||
context_latency = seq_group.metrics.first_token_time - seq_group.metrics.first_scheduled_time
|
||||
decoder_latency = seq_group.metrics.finished_time - seq_group.metrics.first_token_time
|
||||
decoder_token_num = seq_group.get_seqs()[0].get_output_len() - 1
|
||||
per_token_latency = decoder_latency if decoder_token_num == 0 \
|
||||
else decoder_latency / decoder_token_num
|
||||
df_ = pd.DataFrame(
|
||||
[[ttft, time_in_queue, context_latency, decoder_latency, per_token_latency, decoder_token_num]],
|
||||
columns=['ttft/s', 'time_in_queue/s', 'context_latency/s', 'decoder_latency/s', 'per_token_latency/s', 'decoder_tokens'],
|
||||
index=[str(seq_group.request_id)]
|
||||
)
|
||||
df = pd.concat([df, df_])
|
||||
sum_, max_, mean_, min_, p99_ = df.sum(), df.max(), df.mean(), df.min(), df.quantile(0.99)
|
||||
df.loc['Sum'] = sum_
|
||||
df.loc['Max'] = max_
|
||||
df.loc['Mean'] = mean_
|
||||
df.loc['Min'] = min_
|
||||
df.loc['P99'] = p99_
|
||||
return df
|
||||
|
||||
|
||||
def vllm__core__scheduler__Scheduler__save_scheduler_view(self, scheduler_idx=0):
|
||||
logger.info(f"vLLM scheduler profiling save...")
|
||||
plt.rcParams.update({'font.size': 8})
|
||||
figure = plt.figure(figsize=(6.4, 5.6))
|
||||
gs = figure.add_gridspec(3, hspace=0)
|
||||
axes = gs.subplots(sharex=True, sharey=False)
|
||||
figure.suptitle("Cambricon vLLM Scheduler View")
|
||||
# scheduler queue view
|
||||
self.df.plot(ax=axes[0], y=['waiting', 'running', 'swapped', 'finished'])
|
||||
axes[0].set_xlabel('X-LLMEngineStep', loc='left')
|
||||
axes[0].set_ylabel('Y-ReqNum', loc='top')
|
||||
# utilization
|
||||
self.df.plot(ax=axes[1], y=['batch_utils', 'block_utils', 'preempt_ratio'])
|
||||
axes[1].set_xlabel('X-LLMEngineStep', loc='left')
|
||||
axes[1].set_ylabel('Y-Utilization(%)', loc='top')
|
||||
# token view
|
||||
self.df.plot(ax=axes[2], y=['wait_to_run_tokens'])
|
||||
axes[2].set_xlabel('X-LLMEngineStep', loc='left')
|
||||
axes[2].set_ylabel('Y-TokenNum', loc='top')
|
||||
for ax in axes:
|
||||
ax.label_outer()
|
||||
ax.legend(loc='upper right')
|
||||
figure.tight_layout()
|
||||
figure.savefig(f"vllm_scheduler{scheduler_idx}_view.svg", dpi=300, format='svg')
|
||||
plt.close(figure)
|
||||
|
||||
time_df = summary_finished_seq_groups(self.finished_seq_groups)
|
||||
|
||||
sched_df = self.df.copy(deep=True)
|
||||
max_, mean_, min_ = sched_df.max(), sched_df.mean(), sched_df.min()
|
||||
sched_df.loc["Max"] = max_
|
||||
sched_df.loc["Mean"] = mean_
|
||||
sched_df.loc["Min"] = min_
|
||||
with pd.option_context('display.max_rows', None,
|
||||
'display.max_columns', None,
|
||||
'display.max_colwidth', None,
|
||||
'display.float_format', '{:^6,.2f}'.format,
|
||||
'expand_frame_repr', False):
|
||||
logger.info(sched_df.loc[["Max", "Mean", "Min"]])
|
||||
logger.info(time_df.loc[["Sum", "Max", "Mean", "Min", "P99"]])
|
||||
sched_df.astype(str).to_csv(f"vllm_scheduler{scheduler_idx}_step_view.csv", mode="w")
|
||||
time_df.astype(str).to_csv(f"vllm_scheduler{scheduler_idx}_reqs_view.csv", mode="w")
|
||||
|
||||
|
||||
def vllm__core__scheduler__Scheduler____init__(
|
||||
self,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
pipeline_parallel_size: int = 1,
|
||||
output_proc_callback: Optional[Callable] = None,
|
||||
) -> None:
|
||||
vllm__core__scheduler__Scheduler____init____org(
|
||||
self=self,
|
||||
scheduler_config=scheduler_config,
|
||||
cache_config=cache_config,
|
||||
lora_config=lora_config,
|
||||
pipeline_parallel_size=pipeline_parallel_size,
|
||||
output_proc_callback=output_proc_callback
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add for scheduler profiling
|
||||
'''
|
||||
self.init_scheduler_view()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__core__scheduler__Scheduler___schedule_prefills(
|
||||
self,
|
||||
budget: SchedulingBudget,
|
||||
curr_loras: Optional[Set[int]],
|
||||
enable_chunking: bool = False,
|
||||
) -> SchedulerPrefillOutputs:
|
||||
prefills = vllm__core__scheduler__Scheduler___schedule_prefills__org(
|
||||
self=self,
|
||||
budget=budget,
|
||||
curr_loras=curr_loras,
|
||||
enable_chunking=enable_chunking
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add for scheduler profiling
|
||||
'''
|
||||
self.waiting_to_running_seqs = len(prefills.seq_groups)
|
||||
self.wait_to_run_tokens = budget.num_batched_tokens
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return prefills
|
||||
|
||||
|
||||
def vllm__core__scheduler__Scheduler___schedule_running(
|
||||
self,
|
||||
budget: SchedulingBudget,
|
||||
curr_loras: Optional[Set[int]],
|
||||
enable_chunking: bool = False,
|
||||
) -> SchedulerRunningOutputs:
|
||||
running_scheduled = vllm__core__scheduler__Scheduler___schedule_running__org(
|
||||
self=self,
|
||||
budget=budget,
|
||||
curr_loras=curr_loras,
|
||||
enable_chunking=enable_chunking
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add for scheduler profiling
|
||||
'''
|
||||
self.running_to_waiting_seqs += len(running_scheduled.preempted)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return running_scheduled
|
||||
|
||||
|
||||
def vllm__core__scheduler__Scheduler___schedule(self) -> SchedulerOutputs:
|
||||
scheduler_outputs = vllm__core__scheduler__Scheduler___schedule__org(self)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add for scheduler profiling
|
||||
'''
|
||||
self.sched_step += 1
|
||||
self.running_seqs = len(self.running)
|
||||
self.waiting_seqs = len(self.waiting)
|
||||
self.swapped_seqs = len(self.swapped)
|
||||
|
||||
total_seqs_ = self.running_seqs + self.waiting_seqs + self.swapped_seqs + self.finished_seqs
|
||||
if total_seqs_ == 0:
|
||||
return scheduler_outputs
|
||||
|
||||
if total_seqs_ > self.total_seqs:
|
||||
self.total_seqs = total_seqs_
|
||||
|
||||
self.batch_utils = self.running_seqs / self.scheduler_config.max_num_seqs
|
||||
self.block_utils = (self.block_manager.num_total_gpu_blocks -
|
||||
self.block_manager.get_num_free_gpu_blocks()) / self.block_manager.num_total_gpu_blocks
|
||||
self.preempt_ratio = self.running_to_waiting_seqs / self.total_seqs
|
||||
|
||||
df_ = pd.DataFrame(
|
||||
[[self.waiting_seqs, self.running_seqs, self.swapped_seqs,
|
||||
self.waiting_to_running_seqs, self.running_to_waiting_seqs, self.wait_to_run_tokens,
|
||||
self.batch_utils, self.block_utils, self.preempt_ratio]],
|
||||
columns=['waiting', 'running', 'swapped',
|
||||
'wait_to_run_reqs', 'run_to_wait_reqs', 'wait_to_run_tokens',
|
||||
'batch_utils', 'block_utils', 'preempt_ratio'],
|
||||
index=[str(self.sched_step)])
|
||||
self.df = pd.concat([self.df, df_])
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return scheduler_outputs
|
||||
|
||||
|
||||
def vllm__core__scheduler__Scheduler__free_finished_seq_groups(self) -> None:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add for scheduler profiling
|
||||
'''
|
||||
finished_seq_groups_ = []
|
||||
remaining: Deque[SequenceGroup] = deque()
|
||||
for seq_group in self.running:
|
||||
self._free_finished_seq_group(seq_group)
|
||||
if not seq_group.is_finished():
|
||||
remaining.append(seq_group)
|
||||
else:
|
||||
finished_seq_groups_.append(seq_group)
|
||||
|
||||
self.finished_seqs += len(finished_seq_groups_)
|
||||
self.finished_seq_groups += finished_seq_groups_
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
self.running = remaining
|
||||
|
||||
# Handle async stopped sequence groups
|
||||
# (ones that reached max model len)
|
||||
if self._async_stopped:
|
||||
for seq_group in self._async_stopped:
|
||||
self._free_seq_group_cross_attn_blocks(seq_group)
|
||||
self._finished_requests_ids.append(seq_group.request_id)
|
||||
|
||||
# Free finished seqs
|
||||
self._free_finished_seqs(seq_group)
|
||||
|
||||
self._async_stopped.clear()
|
||||
|
||||
|
||||
def vllm__core__scheduler__Scheduler____del__(self):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add for scheduler profiling
|
||||
'''
|
||||
self.save_scheduler_view()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(Scheduler,
|
||||
Scheduler.__init__,
|
||||
vllm__core__scheduler__Scheduler____init__)
|
||||
MluHijackObject.apply_hijack(Scheduler,
|
||||
Scheduler._schedule_prefills,
|
||||
vllm__core__scheduler__Scheduler___schedule_prefills)
|
||||
MluHijackObject.apply_hijack(Scheduler,
|
||||
Scheduler._schedule_running,
|
||||
vllm__core__scheduler__Scheduler___schedule_running)
|
||||
MluHijackObject.apply_hijack(Scheduler,
|
||||
Scheduler._schedule,
|
||||
vllm__core__scheduler__Scheduler___schedule)
|
||||
MluHijackObject.apply_hijack(Scheduler,
|
||||
Scheduler.free_finished_seq_groups,
|
||||
vllm__core__scheduler__Scheduler__free_finished_seq_groups)
|
||||
MluHijackObject.apply_hijack(Scheduler,
|
||||
"__del__",
|
||||
vllm__core__scheduler__Scheduler____del__)
|
||||
MluHijackObject.apply_hijack(Scheduler,
|
||||
"init_scheduler_view",
|
||||
vllm__core__scheduler__Scheduler__init_scheduler_view)
|
||||
MluHijackObject.apply_hijack(Scheduler,
|
||||
"save_scheduler_view",
|
||||
vllm__core__scheduler__Scheduler__save_scheduler_view)
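# Illustrative only (a hedged sketch): reading back the artifacts written by
# save_scheduler_view() above after the Scheduler is destroyed; the file names
# match the ones hard-coded in that function.
def _example_load_scheduler_profile(scheduler_idx=0):
    import pandas as pd
    step_view = pd.read_csv(
        f"vllm_scheduler{scheduler_idx}_step_view.csv", index_col=0)
    reqs_view = pd.read_csv(
        f"vllm_scheduler{scheduler_idx}_reqs_view.csv", index_col=0)
    return step_view, reqs_view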
|
||||
1
vllm-v0.6.2/vllm_mlu/vllm_mlu/distributed/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
import vllm_mlu.distributed.parallel_state
|
||||
Binary file not shown.
Binary file not shown.
134
vllm-v0.6.2/vllm_mlu/vllm_mlu/distributed/parallel_state.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import torch
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from torch.distributed import Backend
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.distributed.parallel_state import (GroupCoordinator,
|
||||
GraphCaptureContext)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
vllm__distributed__parallel_state__GroupCoordinator____init____org = GroupCoordinator.__init__
|
||||
|
||||
|
||||
def vllm__distributed__parallel_state__GroupCoordinator____init__(
|
||||
self,
|
||||
group_ranks: List[List[int]],
|
||||
local_rank: int,
|
||||
torch_distributed_backend: Union[str, Backend],
|
||||
use_pynccl: bool,
|
||||
use_custom_allreduce: bool,
|
||||
use_tpu_communicator: bool,
|
||||
use_hpu_communicator: bool,
|
||||
use_xpu_communicator: bool,
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: Optional[str] = None,
|
||||
):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: always disable pynccl and custom_allreduce on the MLU backend
|
||||
'''
|
||||
if use_pynccl or use_custom_allreduce:
|
||||
logger.debug(f"Disable pynccl and custom_allreduce when using MLU backend.")
|
||||
|
||||
vllm__distributed__parallel_state__GroupCoordinator____init____org(
|
||||
self=self,
|
||||
group_ranks=group_ranks,
|
||||
local_rank=local_rank,
|
||||
torch_distributed_backend=torch_distributed_backend,
|
||||
use_pynccl=False,
|
||||
use_custom_allreduce=False,
|
||||
use_tpu_communicator=use_tpu_communicator,
|
||||
use_hpu_communicator=use_hpu_communicator,
|
||||
use_xpu_communicator=use_xpu_communicator,
|
||||
use_message_queue_broadcaster=use_message_queue_broadcaster,
|
||||
group_name=group_name
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__distributed__parallel_state__GroupCoordinator__gather(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
NOTE: `dst` is the local rank of the destination rank.
|
||||
"""
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
if self.xpu_communicator is not None and \
|
||||
not self.xpu_communicator.disabled:
|
||||
return self.xpu_communicator.gather(input_, self.rank_in_group,
|
||||
dst, dim)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use p2p communication to reduce gather host time for non-driver workers.
NOTE: this hijack function should be REMOVED when torch is upgraded to 2.4.0.
|
||||
'''
|
||||
rank = self.rank_in_group
|
||||
|
||||
gather_list = None
|
||||
if rank == dst:
|
||||
gather_list = [
|
||||
torch.empty_like(input_) for _ in range(dst)
|
||||
] + [input_] + [
|
||||
torch.empty_like(input_) for _ in range(dst + 1, world_size)
|
||||
]
|
||||
|
||||
send_recv_op_list = []
|
||||
if rank != dst:
|
||||
op = torch.distributed.P2POp(torch.distributed.isend,
|
||||
input_,
|
||||
self.ranks[dst],
|
||||
group=self.device_group)
|
||||
send_recv_op_list.append(op)
|
||||
else:
|
||||
for r in range(0, world_size):
|
||||
if r == dst:
|
||||
continue
|
||||
op = torch.distributed.P2POp(torch.distributed.irecv,
|
||||
gather_list[r],
|
||||
self.ranks[r],
|
||||
group=self.device_group)
|
||||
send_recv_op_list.append(op)
|
||||
reqs = torch.distributed.batch_isend_irecv(send_recv_op_list)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(GroupCoordinator,
|
||||
GroupCoordinator.__init__,
|
||||
vllm__distributed__parallel_state__GroupCoordinator____init__)
|
||||
MluHijackObject.apply_hijack(GroupCoordinator,
|
||||
GroupCoordinator.gather,
|
||||
vllm__distributed__parallel_state__GroupCoordinator__gather)
|
||||
409
vllm-v0.6.2/vllm_mlu/vllm_mlu/dump_info.py
Normal file
@@ -0,0 +1,409 @@
|
||||
from vllm.logger import init_logger
|
||||
from vllm_mlu.mlu_hijack_utils import get_is_gated, ModelConfig
|
||||
import ctypes
|
||||
import json
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm_mlu._mlu_utils import VLLM_DUMP_CPU_INFO_EN, VLLM_DUMP_MLU_INFO_EN
|
||||
logger = init_logger(__name__)
|
||||
|
||||
def get_deepseek_v2_flops(bcfg, batch, seq_len, hidden_size):
|
||||
ATTN_PAD_SIZE = 192
|
||||
qk_nope_head_dim = bcfg.qk_nope_head_dim
|
||||
qk_rope_head_dim = bcfg.qk_rope_head_dim
|
||||
v_head_dim = bcfg.v_head_dim
|
||||
q_lora_rank = bcfg.q_lora_rank
|
||||
kv_lora_rank = bcfg.kv_lora_rank
|
||||
context_atn_pre = 2 * batch * seq_len * \
|
||||
(hidden_size * q_lora_rank + \
|
||||
hidden_size * (kv_lora_rank + qk_rope_head_dim) + \
|
||||
q_lora_rank * bcfg.head_num * (qk_nope_head_dim + qk_rope_head_dim) + \
|
||||
kv_lora_rank * bcfg.head_num * (qk_nope_head_dim + v_head_dim))
|
||||
context_atn_qk = 2 * batch * seq_len * seq_len * bcfg.head_num * ATTN_PAD_SIZE
|
||||
context_atn_qkv = 2 * batch * seq_len * seq_len * bcfg.head_num * ATTN_PAD_SIZE
|
||||
context_atn_post = 2 * batch * seq_len * bcfg.head_num * v_head_dim * hidden_size
|
||||
return context_atn_pre, context_atn_qk, context_atn_qkv, context_atn_post
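# Illustrative only (a hedged sketch): calling the helper above with
# DeepSeek-V2-style placeholder dimensions; the numbers are assumptions for
# illustration, not values read from any real config.json.
def _example_deepseek_v2_flops():
    from types import SimpleNamespace
    bcfg = SimpleNamespace(qk_nope_head_dim=128, qk_rope_head_dim=64,
                           v_head_dim=128, q_lora_rank=1536,
                           kv_lora_rank=512, head_num=128)
    pre, qk, qkv, post = get_deepseek_v2_flops(
        bcfg, batch=1, seq_len=4096, hidden_size=5120)
    return pre + qk + qkv + post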
|
||||
|
||||
class FlopsInfo(ctypes.Structure):
|
||||
_fields_ = [("context_flops", ctypes.c_double),
|
||||
("decoder_flops", ctypes.c_double)]
|
||||
|
||||
class LLMDumpInfo:
|
||||
def __init__(self,
|
||||
tensor_parallel_size=None,
|
||||
dtype=None, kv_cache_dtype=None,
|
||||
quantization=None,
|
||||
model=None, batch_size=None,
|
||||
input_len=None,
|
||||
output_len=None,
|
||||
trust_remote_code=None)->None:
|
||||
self.so_file = None
|
||||
self.dev_info = None
|
||||
self.cpu_info = None
|
||||
self.lib = None
|
||||
self.hfu_info = None
|
||||
self.flops_info = None
|
||||
self.ctypes_model_config = ModelConfig()
|
||||
self.io_efficiency = 0
|
||||
self.context_latency_device = 0
|
||||
self.generate_latency_device = 0
|
||||
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
self.dtype = dtype
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.quantization = quantization
|
||||
self.batch_size = batch_size
|
||||
self.input_len = input_len
|
||||
self.output_len = output_len
|
||||
self.model = model
|
||||
self.model_config = None
|
||||
|
||||
try:
|
||||
from vllm_mlu.device_info import get_info_inner
|
||||
self.so_file,self.dev_info,self.cpu_info,self.lib = get_info_inner(self.so_file, self.dev_info, self.cpu_info, self.lib)
|
||||
except:
|
||||
logger.info("Cannot get device info")
|
||||
|
||||
def init_param(self,
|
||||
tensor_parallel_size=None,
|
||||
dtype=None,
|
||||
kv_cache_dtype=None,
|
||||
quantization=None,
|
||||
model=None,
|
||||
batch_size=None,
|
||||
input_len=None,
|
||||
output_len=None,
|
||||
trust_remote_code=None,
|
||||
context_latency_device=None,
|
||||
generate_latency_device=None):
|
||||
if tensor_parallel_size != None:
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
if dtype != None:
|
||||
self.dtype = dtype
|
||||
if kv_cache_dtype != None:
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
if quantization != None:
|
||||
self.quantization = quantization
|
||||
if model != None:
|
||||
self.model = model
|
||||
if batch_size != None:
|
||||
self.batch_size = batch_size
|
||||
if input_len != None:
|
||||
self.input_len = input_len
|
||||
if output_len != None:
|
||||
self.output_len = output_len
|
||||
if trust_remote_code != None:
|
||||
self.trust_remote_code = trust_remote_code
|
||||
if context_latency_device != None:
|
||||
self.context_latency_device = context_latency_device
|
||||
if generate_latency_device != None:
|
||||
self.generate_latency_device = generate_latency_device
|
||||
|
||||
# parse the model config
|
||||
if self.model_config == None and self.model != None and self.trust_remote_code != None:
|
||||
self.model_config = get_config(self.model, self.trust_remote_code)
|
||||
|
||||
def initialize_ctypes_model_config(self, model_cfg, tp_num, weight_dtype, kv_cache_dtype, quantization):
|
||||
# prepare input
|
||||
self.ctypes_model_config.hidden_size = model_cfg.hidden_size
|
||||
self.ctypes_model_config.vocab_size = model_cfg.vocab_size
|
||||
self.ctypes_model_config.cla_coeffient = 1.0
|
||||
|
||||
possible_keys_ffn_size = [
|
||||
# chatglm3-6b-32k
|
||||
"ffn_hidden_size",
|
||||
# llama3-8b-hf
|
||||
"intermediate_size",
|
||||
]
|
||||
possible_kv_heads = [
|
||||
# chatglm3-6b-32k
|
||||
"multi_query_group_num",
|
||||
# llama3-8b-hf
|
||||
"num_key_value_heads",
|
||||
# falcon-180B-chat
|
||||
"num_kv_heads",
|
||||
]
|
||||
possible_num_attention_heads = [
|
||||
"num_attention_heads",
|
||||
"n_heads",
|
||||
]
|
||||
moe_size=None
|
||||
ffn_size=None
|
||||
if getattr(model_cfg, "moe_intermediate_size", None):
|
||||
moe_size = getattr(model_cfg, "moe_intermediate_size", None)
|
||||
for key in possible_keys_ffn_size:
|
||||
ffn_size = getattr(model_cfg, key, None)
|
||||
if ffn_size is not None:
|
||||
break
|
||||
if model_cfg.model_type in ['bloom'] and ffn_size is None:
|
||||
ffn_size = model_cfg.hidden_size * 4
|
||||
if model_cfg.model_type in ['qwen']:
|
||||
ffn_size = model_cfg.intermediate_size // 2
|
||||
if ffn_size is None and moe_size is None:
|
||||
logger.warning("The model's config.json does not contain any of the following"
|
||||
"keys to determine the ffn_size or moe_size: "
|
||||
f"{possible_keys_ffn_size}. ")
|
||||
|
||||
for key in possible_num_attention_heads:
|
||||
num_attention_heads = getattr(model_cfg, key, None)
|
||||
if num_attention_heads is not None:
|
||||
break
|
||||
if num_attention_heads is None:
|
||||
logger.error("The model's config.json does not contain any of the following"
|
||||
"keys to determine the num_attention_heads: "
|
||||
f"{possible_num_attention_heads}. ")
|
||||
|
||||
for key in possible_kv_heads:
|
||||
kv_heads = getattr(model_cfg, key, None)
|
||||
if kv_heads is not None:
|
||||
break
|
||||
|
||||
if kv_heads is None:
|
||||
logger.warning("The model's config.json does not contain any of the following"
|
||||
"keys to determine the kv_heads: "
|
||||
f"{possible_kv_heads}, use num_attention_heads to replace")
|
||||
kv_heads = model_cfg.num_attention_heads
|
||||
|
||||
self.ctypes_model_config.ffn_inner_size = 0 if ffn_size is None else ffn_size
|
||||
self.ctypes_model_config.moe_inner_size = 0 if moe_size is None else moe_size
|
||||
self.ctypes_model_config.moe_layer_num = 0 if moe_size is None else model_cfg.num_hidden_layers
|
||||
self.ctypes_model_config.layer_num = model_cfg.num_hidden_layers
|
||||
self.ctypes_model_config.head_num = num_attention_heads
|
||||
self.ctypes_model_config.head_size = self.ctypes_model_config.hidden_size // self.ctypes_model_config.head_num
|
||||
self.ctypes_model_config.head_num_kv = kv_heads
|
||||
self.ctypes_model_config.tp_num = tp_num
|
||||
if hasattr(model_cfg, "shared_expert_intermediate_size") and model_cfg.shared_expert_intermediate_size is not None:
|
||||
self.ctypes_model_config.shared_expert_intermediate_size = model_cfg.shared_expert_intermediate_size
|
||||
else:
|
||||
self.ctypes_model_config.shared_expert_intermediate_size = 0
|
||||
self.ctypes_model_config.use_gated_ffn = get_is_gated()
|
||||
if hasattr(model_cfg, "n_shared_experts") and model_cfg.n_shared_experts is not None:
|
||||
self.ctypes_model_config.shared_expert_intermediate_size = model_cfg.n_shared_experts * moe_size
|
||||
else:
|
||||
self.ctypes_model_config.shared_experts = 0
|
||||
if hasattr(model_cfg, "num_experts") and model_cfg.num_experts is not None:
|
||||
self.ctypes_model_config.experts_num = model_cfg.num_experts
|
||||
if model_cfg.model_type == 'hunyuan':
|
||||
self.ctypes_model_config.topk_num = model_cfg.moe_topk
|
||||
else:
|
||||
self.ctypes_model_config.topk_num = model_cfg.num_experts_per_tok
|
||||
elif hasattr(model_cfg, "num_local_experts"):
|
||||
self.ctypes_model_config.experts_num = model_cfg.num_local_experts
|
||||
if model_cfg.model_type == 'hunyuan':
|
||||
self.ctypes_model_config.topk_num = model_cfg.moe_topk
|
||||
else:
|
||||
self.ctypes_model_config.topk_num = model_cfg.num_experts_per_tok
|
||||
elif hasattr(model_cfg, "n_routed_experts"):
|
||||
self.ctypes_model_config.experts_num = model_cfg.n_routed_experts
|
||||
if model_cfg.model_type == 'hunyuan':
|
||||
self.ctypes_model_config.topk_num = model_cfg.moe_topk
|
||||
else:
|
||||
self.ctypes_model_config.topk_num = model_cfg.num_experts_per_tok
|
||||
else:
|
||||
self.ctypes_model_config.experts_num = 0
|
||||
if hasattr(model_cfg, "model_type") and model_cfg.model_type is not None:
|
||||
self.ctypes_model_config.model_type = model_cfg.model_type.encode('utf-8')
|
||||
# when adding a moe model, need fix moe/ffn info, like
|
||||
# moe_inner_size, ffn_inner_size, moe_layer_num, shared_expert_intermediate_size.
|
||||
# add for mixtral
|
||||
if model_cfg.model_type == "mixtral":
|
||||
self.ctypes_model_config.moe_inner_size = ffn_size
|
||||
self.ctypes_model_config.ffn_inner_size = 0
|
||||
self.ctypes_model_config.moe_layer_num = model_cfg.num_hidden_layers
|
||||
# add for deepseek-v2
|
||||
if model_cfg.model_type == "deepseek_v2":
|
||||
if hasattr(model_cfg, "first_k_dense_replace") and model_cfg.first_k_dense_replace is not None:
|
||||
self.ctypes_model_config.moe_layer_num = model_cfg.num_hidden_layers - model_cfg.first_k_dense_replace
|
||||
if hasattr(model_cfg, "qk_nope_head_dim") and model_cfg.qk_nope_head_dim is not None:
|
||||
self.ctypes_model_config.qk_nope_head_dim = model_cfg.qk_nope_head_dim
|
||||
if hasattr(model_cfg, "qk_rope_head_dim") and model_cfg.qk_rope_head_dim is not None:
|
||||
self.ctypes_model_config.qk_rope_head_dim = model_cfg.qk_rope_head_dim
|
||||
if hasattr(model_cfg, "v_head_dim") and model_cfg.v_head_dim is not None:
|
||||
self.ctypes_model_config.v_head_dim = model_cfg.v_head_dim
|
||||
if hasattr(model_cfg, "q_lora_rank") and model_cfg.q_lora_rank is not None:
|
||||
self.ctypes_model_config.q_lora_rank = model_cfg.q_lora_rank
|
||||
else:
|
||||
self.ctypes_model_config.q_lora_rank = 0
|
||||
if hasattr(model_cfg, "kv_lora_rank") and model_cfg.kv_lora_rank is not None:
|
||||
self.ctypes_model_config.kv_lora_rank = model_cfg.kv_lora_rank
|
||||
# add for Hunyuan
|
||||
if model_cfg.model_type == "hunyuan":
|
||||
self.ctypes_model_config.cla_coeffient = 0.5  # the hunyuan model uses CLA2
|
||||
if hasattr(model_cfg, "num_shared_expert") and model_cfg.num_shared_expert is not None:
|
||||
self.ctypes_model_config.shared_expert_intermediate_size = model_cfg.num_shared_expert * model_cfg.intermediate_size
|
||||
if not self.ctypes_model_config.moe_inner_size and model_cfg.intermediate_size is not None:
|
||||
self.ctypes_model_config.moe_inner_size = model_cfg.intermediate_size
|
||||
if not self.ctypes_model_config.moe_layer_num and hasattr(model_cfg, "num_experts"):
|
||||
self.ctypes_model_config.moe_layer_num = model_cfg.num_hidden_layers
|
||||
|
||||
self.ctypes_model_config.use_causal_mask = True  # flash attention in vLLM is only used with a causal mask
|
||||
|
||||
if weight_dtype == "auto":
|
||||
self.ctypes_model_config.data_type = b'float16'
|
||||
else:
|
||||
self.ctypes_model_config.data_type = weight_dtype.encode('utf-8')
|
||||
|
||||
if quantization is not None:
|
||||
with open(self.model + "/quantize_config.json", 'r') as file:
|
||||
config = json.load(file)
|
||||
if config["quant_mode"] == "SmoothQuant":
|
||||
self.ctypes_model_config.smooth_quant_type = b"SmoothQuant"
|
||||
else:
|
||||
self.ctypes_model_config.smooth_quant_type = b'invalid'
|
||||
self.ctypes_model_config.filter_data_type = ("int" + str(config['bits'])).encode('utf-8')
|
||||
else:
|
||||
self.ctypes_model_config.smooth_quant_type = b'invalid'
|
||||
self.ctypes_model_config.filter_data_type = self.ctypes_model_config.data_type
|
||||
|
||||
if kv_cache_dtype == "auto":
|
||||
self.ctypes_model_config.kv_cache_dtype = self.ctypes_model_config.data_type
|
||||
else:
|
||||
self.ctypes_model_config.kv_cache_dtype = kv_cache_dtype.encode('utf-8')
|
||||
|
||||
|
||||
def get_flops(self, bcfg, once_batch, input_seq_len, output_length, flops_info):
|
||||
self.batch_size = once_batch
|
||||
seq_len = input_seq_len
|
||||
hidden_size = bcfg.hidden_size
|
||||
voc_size = bcfg.vocab_size
|
||||
ffn_size = bcfg.ffn_inner_size
|
||||
moe_size = bcfg.moe_inner_size
|
||||
shared_expert_intermediate_size = bcfg.shared_expert_intermediate_size
|
||||
layer_num = bcfg.layer_num
|
||||
out_seq = output_length
|
||||
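# approximate the decode phase by its average context length: the k-th generated token attends to
# about seq_len + k tokens, which averages to seq_len + out_seq / 2 over the whole generation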
seq_len_decode = seq_len + out_seq / 2
|
||||
r = bcfg.head_num / bcfg.head_num_kv
|
||||
bsh2 = self.batch_size * seq_len * hidden_size * hidden_size
|
||||
cla_coeffient = bcfg.cla_coeffient
|
||||
|
||||
if bcfg.model_type == b'deepseek_v2':
|
||||
context_atn_pre, context_atn_qk, context_atn_qkv, context_atn_post = (
|
||||
get_deepseek_v2_flops(bcfg, self.batch_size, seq_len, hidden_size)
|
||||
)
|
||||
else:
|
||||
context_atn_pre = 2 * bsh2 + 4 * bsh2 / r * cla_coeffient
|
||||
context_atn_qk = 2 * self.batch_size * seq_len * seq_len * hidden_size
|
||||
context_atn_qkv = 2 * self.batch_size * seq_len * seq_len * hidden_size
|
||||
context_atn_post = 2 * self.batch_size * seq_len * hidden_size * hidden_size
|
||||
context_lm_head = 2 * self.batch_size * seq_len * hidden_size * voc_size
|
||||
context_ffn = 0
|
||||
bh2 = self.batch_size * hidden_size * hidden_size
|
||||
decode_atn_pre = 2 * bh2 + 4 * bh2 / r * cla_coeffient
|
||||
decode_atn_qk = 2 * self.batch_size * seq_len_decode * hidden_size
|
||||
decode_atn_qkv = 2 * self.batch_size * seq_len_decode * hidden_size
|
||||
decode_atn_post = 2 * self.batch_size * hidden_size * hidden_size
|
||||
decode_lm_head = 2 * self.batch_size * hidden_size * voc_size
|
||||
decode_ffn = 0
|
||||
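# a gated FFN (e.g. SwiGLU) has three projections (gate/up/down), i.e. 2 FLOPs x 3 matmuls = 6 per
# hidden x ffn weight element and token; a plain two-projection FFN gives 4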
coeffient = 6 if bcfg.use_gated_ffn else 4
|
||||
if bcfg.experts_num == 0:
|
||||
context_ffn = coeffient * self.batch_size * seq_len * hidden_size * ffn_size
|
||||
decode_ffn = coeffient * self.batch_size * hidden_size * ffn_size
|
||||
else:
|
||||
context_ffn = self.batch_size * seq_len * hidden_size * (coeffient * (moe_size * bcfg.topk_num + shared_expert_intermediate_size) + 2 * bcfg.experts_num)
|
||||
decode_ffn = self.batch_size * hidden_size * (coeffient * (moe_size * bcfg.topk_num + shared_expert_intermediate_size) + 2 * bcfg.experts_num)
|
||||
|
||||
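# with a causal mask only the lower-triangular half of the attention score matrix is computed,
# so the QK and PV FLOPs are halved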
if bcfg.use_causal_mask:
|
||||
c = 0.5
|
||||
context_atn_qk *= c
|
||||
context_atn_qkv *= c
|
||||
|
||||
flops_info.context_flops = context_lm_head
|
||||
flops_info.decoder_flops = decode_lm_head
|
||||
# both branches were identical; the int8 kv cache currently uses the same attention FLOPs estimate
flops_info.context_flops += (layer_num * (context_atn_qk + context_atn_qkv))
flops_info.decoder_flops += (layer_num * (decode_atn_qk + decode_atn_qkv))
|
||||
|
||||
# both branches were identical; smooth-quant models currently use the same projection/FFN FLOPs estimate
flops_info.context_flops += (layer_num * (context_atn_pre + context_atn_post + context_ffn))
flops_info.decoder_flops += (layer_num * (decode_atn_pre + decode_atn_post + decode_ffn))
|
||||
|
||||
def capture_cpu_info(self):
|
||||
if VLLM_DUMP_CPU_INFO_EN and self.cpu_info:
|
||||
try:
|
||||
from vllm_mlu.device_info import capture_cpu_info
|
||||
self.cpu_info = capture_cpu_info(self.cpu_info, my_rank=0)
|
||||
except Exception:
logger.info("capture_cpu_info function is not supported")
|
||||
|
||||
def memory_usage(self):
|
||||
if VLLM_DUMP_CPU_INFO_EN and self.cpu_info:
|
||||
try:
|
||||
from vllm_mlu.device_info import memory_usage
|
||||
self.cpu_info = memory_usage(self.cpu_info)
|
||||
except Exception:
logger.info("memory_usage function is not supported")
|
||||
|
||||
def analyze_perf_data(self, rank=0):
|
||||
try:
|
||||
from vllm_mlu.device_info import analyze_perf_data
|
||||
analyze_perf_data(self.cpu_info, self.lib)
|
||||
except Exception:
logger.info("Cannot analyze perf data: no analyze_perf_data function")
|
||||
|
||||
def get_decoder_io_efficiency(self, ctypes_model_config, lib, batch_size, input_len, output_len, generate_latency_device):
|
||||
try:
|
||||
from vllm_mlu.device_info import get_decoder_io_efficiency
|
||||
self.io_efficiency = get_decoder_io_efficiency(ctypes_model_config, lib, batch_size, input_len, output_len, generate_latency_device)
|
||||
except Exception:
logger.info("get_decoder_io_efficiency function is not supported")
|
||||
|
||||
def get_device_output_info(self,
|
||||
model_cfg,
|
||||
batch_size,
|
||||
input_seq_len,
|
||||
output_length,
|
||||
tp_num,
|
||||
weight_dtype,
|
||||
kv_cache_dtype,
|
||||
quantization):
|
||||
self.initialize_ctypes_model_config(model_cfg, tp_num, weight_dtype, kv_cache_dtype, quantization)
|
||||
if VLLM_DUMP_CPU_INFO_EN and self.so_file:
|
||||
self.analyze_perf_data()
|
||||
if VLLM_DUMP_MLU_INFO_EN and self.lib:
|
||||
from vllm_mlu.device_info import get_flops_inner, HFUInfo
|
||||
self.hfu_info = HFUInfo()
|
||||
get_flops_inner(self.ctypes_model_config, batch_size, input_seq_len, output_length, tp_num, self.hfu_info, self.lib)
|
||||
self.get_decoder_io_efficiency(self.ctypes_model_config,
|
||||
self.lib,
|
||||
self.batch_size,
|
||||
self.input_len,
|
||||
self.output_len,
|
||||
self.generate_latency_device)
|
||||
else:
|
||||
self.flops_info = FlopsInfo()
|
||||
self.get_flops(self.ctypes_model_config, batch_size, input_seq_len, output_length, self.flops_info)
|
||||
|
||||
def has_information_dump(self):
|
||||
if self.dev_info and self.dev_info.so_file:
|
||||
return True
|
||||
return False
|
||||
|
||||
def dump(self):
|
||||
self.get_device_output_info(self.model_config,
|
||||
self.batch_size,
|
||||
self.input_len,
|
||||
self.output_len,
|
||||
self.tensor_parallel_size,
|
||||
self.dtype,
|
||||
self.kv_cache_dtype,
|
||||
self.quantization)
|
||||
try:
|
||||
from vllm_mlu.device_info import dump
|
||||
dump(LLM.dump_info)
|
||||
except Exception:
logger.info("Dumping device/cpu information is not supported")
|
||||
|
||||
def dump_performance_info(self):
|
||||
try:
|
||||
from vllm_mlu.device_info import dump_information
|
||||
dump_information(LLM.dump_info)
|
||||
except Exception:
logger.info("Dumping performance information is not supported")
|
||||
|
||||
|
||||
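# Illustrative sketch (not part of the patch): a standalone per-layer estimate of the dense-model
# context-phase FLOPs that LLMDumpInfo.get_flops() accumulates above. All argument values in the
# usage note are hypothetical example numbers, not taken from any real config.
def _example_context_flops_per_layer(batch, seq_len, hidden, ffn,
                                     heads, kv_heads, gated=True, causal=True):
    bsh2 = batch * seq_len * hidden * hidden
    attn_pre = 2 * bsh2 + 4 * bsh2 * kv_heads / heads   # q projection + k/v projections (GQA-aware)
    attn_qk = 2 * batch * seq_len * seq_len * hidden    # q @ k^T
    attn_pv = 2 * batch * seq_len * seq_len * hidden    # softmax(scores) @ v
    attn_post = 2 * bsh2                                # output projection
    if causal:
        attn_qk *= 0.5
        attn_pv *= 0.5
    ffn_flops = (6 if gated else 4) * batch * seq_len * hidden * ffn
    return attn_pre + attn_qk + attn_pv + attn_post + ffn_flops

# e.g. _example_context_flops_per_layer(1, 1024, 4096, 11008, 32, 32) is roughly 4.2e11 FLOPs
# (about 0.42 TFLOPs) per layer, dominated by the gated FFN term.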
2
vllm-v0.6.2/vllm_mlu/vllm_mlu/engine/__init__.py
Normal file
2
vllm-v0.6.2/vllm_mlu/vllm_mlu/engine/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
import vllm_mlu.engine.arg_utils
|
||||
import vllm_mlu.engine.llm_engine
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
120
vllm-v0.6.2/vllm_mlu/vllm_mlu/engine/arg_utils.py
Normal file
120
vllm-v0.6.2/vllm_mlu/vllm_mlu/engine/arg_utils.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
|
||||
from vllm_mlu._mlu_utils import (BlockSizeInfo, USE_PAGED, get_device_name)
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
vllm__engine__arg_utils__EngineArgs__create_model_config_org = EngineArgs.create_model_config
|
||||
vllm__engine__arg_utils__EngineArgs__create_engine_config_org = EngineArgs.create_engine_config
|
||||
vllm__engine__arg_utils__EngineArgs__add_cli_args_org = EngineArgs.add_cli_args
|
||||
vllm_engine__arg_utils__EngineArgs____post_init__org = EngineArgs.__post_init__
|
||||
|
||||
|
||||
def vllm_engine__arg_utils__EngineArgs____post_init__(self,) -> None:
|
||||
'''
|
||||
=============================
|
||||
Add by vllm_mlu
|
||||
=============================
|
||||
@brief: 1. On MLU3XX devices, when tensor_parallel_size > 1, enforce_eager is forced to True.
|
||||
2. For unpaged mode, set default block_size=2048.
|
||||
'''
|
||||
unsupport_graph_device = "3" in get_device_name()
|
||||
if unsupport_graph_device and self.tensor_parallel_size > 1 and self.enforce_eager is not True:
|
||||
self.enforce_eager = True
|
||||
logger.warning("The current device only support eager mode, when the tensor_parallel_size > 1. "
|
||||
"The param enforce_eager is forced to set True")
|
||||
|
||||
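# in unpaged mode a whole sequence must fit in one block (see the input_len + output_len check
# added in llm_engine.py), so replace the paged default of 16 with a large block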
if not USE_PAGED and self.block_size == 16:
|
||||
self.block_size = 2048
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
vllm_engine__arg_utils__EngineArgs____post_init__org(self)
|
||||
|
||||
|
||||
def vllm__engine__arg_utils__EngineArgs__create_model_config(self) -> ModelConfig:
|
||||
model_config = vllm__engine__arg_utils__EngineArgs__create_model_config_org(self)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: set context mlugraph info for model config
|
||||
'''
|
||||
model_config.set_context_mlugraph_info(
|
||||
getattr(self, "enable_context_mlugraph", False),
|
||||
getattr(self, "context_batch_size_to_capture", None),
|
||||
getattr(self, "context_seq_len_to_capture", None))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return model_config
|
||||
|
||||
|
||||
def vllm__engine__arg_utils__EngineArgs__create_engine_config(self) -> VllmConfig:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: disable custom_all_reduce, re-set block_size to support paged and unpaged mode.
|
||||
'''
|
||||
# MLU does not support custom all-reduce
|
||||
self.disable_custom_all_reduce = True
|
||||
BlockSizeInfo.set_block_size(self.block_size)
|
||||
if not USE_PAGED and self.enable_chunked_prefill:
|
||||
raise ValueError("Not support chunked_prefill in unpaged mode.")
|
||||
|
||||
engine_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org(self)
|
||||
engine_config.cache_config.block_size = BlockSizeInfo.BLOCK_SIZE
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return engine_config
|
||||
|
||||
|
||||
@staticmethod
|
||||
def vllm__engine__arg_utils__EngineArgs__add_cli_args(
|
||||
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
parser = vllm__engine__arg_utils__EngineArgs__add_cli_args_org(parser)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: 1. remove block_size choices, set default value to -1
|
||||
2. add kv_cache_dtype choices of 'int8'
|
||||
'''
|
||||
for action in parser._actions:
|
||||
if action.dest == "block_size":
|
||||
action.choices = None
|
||||
action.default = -1
|
||||
elif action.dest == "kv_cache_dtype":
|
||||
action.choices += ['int8']
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return parser
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(EngineArgs,
|
||||
EngineArgs.__post_init__,
|
||||
vllm_engine__arg_utils__EngineArgs____post_init__)
|
||||
MluHijackObject.apply_hijack(EngineArgs,
|
||||
EngineArgs.create_model_config,
|
||||
vllm__engine__arg_utils__EngineArgs__create_model_config)
|
||||
MluHijackObject.apply_hijack(EngineArgs,
|
||||
EngineArgs.create_engine_config,
|
||||
vllm__engine__arg_utils__EngineArgs__create_engine_config)
|
||||
MluHijackObject.apply_hijack(EngineArgs,
|
||||
EngineArgs.add_cli_args,
|
||||
vllm__engine__arg_utils__EngineArgs__add_cli_args)
|
||||
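# Illustrative sketch (an assumption, not the actual vllm_mlu implementation): every hijack in this
# patch goes through MluHijackObject.apply_hijack(target_cls, original_attr_or_name, replacement).
# A minimal helper with the same call shape could look like the class below; the real object may
# also keep the originals so hijacks can be reverted.
class _ExampleHijack:
    _patched = []

    @classmethod
    def apply_hijack(cls, target_cls, old_attr, new_attr):
        # old_attr is either the original function object or just an attribute name string
        # (both call forms appear in this patch)
        name = old_attr if isinstance(old_attr, str) else old_attr.__name__
        cls._patched.append((target_cls, name, getattr(target_cls, name, None)))
        setattr(target_cls, name, new_attr)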
35
vllm-v0.6.2/vllm_mlu/vllm_mlu/engine/async_llm_engine.py
Normal file
35
vllm-v0.6.2/vllm_mlu/vllm_mlu/engine/async_llm_engine.py
Normal file
@@ -0,0 +1,35 @@
|
||||
|
||||
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# for client init/reset server scheduler profile data
|
||||
async def vllm__engine__async_llm_engine__AsyncLLMEngine__init_scheduler_view(self):
|
||||
for scheduler in self.engine.scheduler:
|
||||
if hasattr(scheduler, "init_scheduler_view"):
|
||||
scheduler.init_scheduler_view()
|
||||
else:
|
||||
logger.warning("Can not find any scheduler view, " +
|
||||
"please 'export VLLM_SCHEDULER_PROFILE=true' first.")
|
||||
|
||||
|
||||
# for client pulling server scheduler profile data
|
||||
async def vllm__engine__async_llm_engine__AsyncLLMEngine__save_scheduler_view(self):
|
||||
for idx, scheduler in enumerate(self.engine.scheduler):
|
||||
if hasattr(scheduler, "save_scheduler_view"):
|
||||
scheduler.save_scheduler_view(idx)
|
||||
else:
|
||||
logger.warning("Can not find any scheduler view, " +
|
||||
"please 'export VLLM_SCHEDULER_PROFILE=true' first.")
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(AsyncLLMEngine,
|
||||
"init_scheduler_view",
|
||||
vllm__engine__async_llm_engine__AsyncLLMEngine__init_scheduler_view)
|
||||
MluHijackObject.apply_hijack(AsyncLLMEngine,
|
||||
"save_scheduler_view",
|
||||
vllm__engine__async_llm_engine__AsyncLLMEngine__save_scheduler_view)
|
||||
209
vllm-v0.6.2/vllm_mlu/vllm_mlu/engine/llm_engine.py
Normal file
209
vllm-v0.6.2/vllm_mlu/vllm_mlu/engine/llm_engine.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import time
|
||||
from typing import Optional, List, Union, Mapping
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm_mlu._mlu_utils import USE_PAGED, BlockSizeInfo
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.logger import init_logger
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.utils import deprecate_kwargs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@deprecate_kwargs(
|
||||
"inputs",
|
||||
additional_message="Please use the 'prompt' parameter instead.",
|
||||
)
|
||||
def vllm_engine__llm_engine__LLMEngine__add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[PromptType] = None,
|
||||
params: Optional[Union[SamplingParams, PoolingParams]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None, # DEPRECATED
|
||||
) -> None:
|
||||
"""Add a request to the engine's request pool.
|
||||
|
||||
The request is added to the request pool and will be processed by the
|
||||
scheduler as `engine.step()` is called. The exact scheduling policy is
|
||||
determined by the scheduler.
|
||||
|
||||
Args:
|
||||
request_id: The unique ID of the request.
|
||||
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
|
||||
for more details about the format of each input.
|
||||
params: Parameters for sampling or pooling.
|
||||
:class:`~vllm.SamplingParams` for text generation.
|
||||
:class:`~vllm.PoolingParams` for pooling.
|
||||
arrival_time: The arrival time of the request. If None, we use
|
||||
the current monotonic time.
|
||||
trace_headers: OpenTelemetry trace headers.
|
||||
priority: The priority of the request.
|
||||
Only applicable with priority scheduling.
|
||||
|
||||
Details:
|
||||
- Set arrival_time to the current time if it is None.
|
||||
- Set prompt_token_ids to the encoded prompt if it is None.
|
||||
- Create `n` number of :class:`~vllm.Sequence` objects.
|
||||
- Create a :class:`~vllm.SequenceGroup` object
|
||||
from the list of :class:`~vllm.Sequence`.
|
||||
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
|
||||
|
||||
Example:
|
||||
>>> # initialize engine
|
||||
>>> engine = LLMEngine.from_engine_args(engine_args)
|
||||
>>> # set request arguments
|
||||
>>> example_prompt = "Who is the president of the United States?"
|
||||
>>> sampling_params = SamplingParams(temperature=0.0)
|
||||
>>> request_id = 0
|
||||
>>>
|
||||
>>> # add the request to the engine
|
||||
>>> engine.add_request(
|
||||
>>> str(request_id),
|
||||
>>> example_prompt,
|
||||
>>> SamplingParams(temperature=0.0))
|
||||
>>> # continue the request processing
|
||||
>>> ...
|
||||
"""
|
||||
if inputs is not None:
|
||||
prompt = inputs
|
||||
assert prompt is not None and params is not None
|
||||
|
||||
if lora_request is not None and not self.lora_config:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
"not enabled!")
|
||||
|
||||
if priority != 0 and not self.scheduler_config.policy == "priority":
|
||||
raise ValueError(f"Got priority {priority} but "
|
||||
"Priority scheduling is not enabled.")
|
||||
|
||||
if isinstance(params, SamplingParams) \
|
||||
and (params.guided_decoding or params.logits_processors) \
|
||||
and self.scheduler_config.num_scheduler_steps > 1:
|
||||
raise ValueError(
|
||||
"Guided decoding and logits processors are not supported "
|
||||
"in multi-step decoding")
|
||||
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
if self.tokenizer is not None:
|
||||
self._validate_token_prompt(
|
||||
prompt,
|
||||
tokenizer=self.get_tokenizer(lora_request=lora_request))
|
||||
|
||||
preprocessed_inputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
processed_inputs = self.input_processor(preprocessed_inputs)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Added by vllm_mlu
|
||||
=============================
|
||||
@brief: check input_len + output_len <= block_size
|
||||
'''
|
||||
def check_block_size_valid(input_len, output_len):
|
||||
if BlockSizeInfo.BLOCK_SIZE < input_len + output_len:
|
||||
raise ValueError(f"BLOCK_SIZE({BlockSizeInfo.BLOCK_SIZE}) can't smaller than " +
|
||||
f"input_len({input_len}) + output_len({output_len}).")
|
||||
|
||||
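# e.g. with the unpaged default block_size of 2048, a request with a 1900-token prompt and
# max_tokens=256 is rejected here, since 1900 + 256 > 2048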
if isinstance(params, SamplingParams):
|
||||
output_len = params.max_tokens
|
||||
|
||||
# Check for 'prompt_token_ids' in different levels of processed_inputs
|
||||
if not USE_PAGED:
|
||||
for key in ['prompt_token_ids', 'encoder', 'decoder']:
|
||||
if key in processed_inputs:
|
||||
if key == 'prompt_token_ids':
|
||||
input_len = len(processed_inputs[key])
|
||||
elif isinstance(processed_inputs[key], dict) and 'prompt_token_ids' in processed_inputs[key]:
|
||||
input_len = len(processed_inputs[key]['prompt_token_ids'])
|
||||
else:
|
||||
continue
|
||||
|
||||
check_block_size_valid(input_len, output_len)
|
||||
'''
|
||||
==================
|
||||
End of modification
|
||||
==================
|
||||
'''
|
||||
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
processed_inputs=processed_inputs,
|
||||
params=params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
|
||||
def vllm__engine__llm_engine__LLMEngine__get_latency(self):
|
||||
latency = self.model_executor.get_latency()
|
||||
return latency
|
||||
|
||||
|
||||
def vllm__engine__llm_engine__LLMEngine__get_memory_usage(self):
|
||||
return self.model_executor.get_memory_usage()
|
||||
|
||||
|
||||
def vllm__engine__llm_engine__LLMEngine__get_block_usage(self):
|
||||
assert len(self.scheduler) == 1, f"Only support pipeline_parallel_size=1."
|
||||
num_free_gpu_blocks = self.scheduler[0].block_manager.get_num_free_gpu_blocks()
|
||||
num_free_cpu_blocks = self.scheduler[0].block_manager.get_num_free_cpu_blocks()
|
||||
return (num_free_gpu_blocks, num_free_cpu_blocks)
|
||||
|
||||
|
||||
# for client init/reset server scheduler profile data
|
||||
def vllm__engine__llm_engine__LLMEngine__init_scheduler_view(self):
|
||||
for scheduler in self.scheduler:
|
||||
if hasattr(scheduler, "init_scheduler_view"):
|
||||
scheduler.init_scheduler_view()
|
||||
else:
|
||||
logger.warning("Can not find any scheduler view, " +
|
||||
"please 'export VLLM_SCHEDULER_PROFILE=true' first.")
|
||||
|
||||
|
||||
# for client pulling server scheduler profile data
|
||||
def vllm__engine__llm_engine__LLMEngine__save_scheduler_view(self):
|
||||
for idx, scheduler in enumerate(self.scheduler):
|
||||
if hasattr(scheduler, "save_scheduler_view"):
|
||||
scheduler.save_scheduler_view(idx)
|
||||
else:
|
||||
logger.warning("Can not find any scheduler view, " +
|
||||
"please 'export VLLM_SCHEDULER_PROFILE=true' first.")
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(LLMEngine,
|
||||
"init_scheduler_view",
|
||||
vllm__engine__llm_engine__LLMEngine__init_scheduler_view)
|
||||
MluHijackObject.apply_hijack(LLMEngine,
|
||||
"save_scheduler_view",
|
||||
vllm__engine__llm_engine__LLMEngine__save_scheduler_view)
|
||||
MluHijackObject.apply_hijack(LLMEngine,
|
||||
LLMEngine.add_request,
|
||||
vllm_engine__llm_engine__LLMEngine__add_request)
|
||||
MluHijackObject.apply_hijack(LLMEngine,
|
||||
"get_latency",
|
||||
vllm__engine__llm_engine__LLMEngine__get_latency)
|
||||
MluHijackObject.apply_hijack(LLMEngine,
|
||||
"get_memory_usage",
|
||||
vllm__engine__llm_engine__LLMEngine__get_memory_usage)
|
||||
MluHijackObject.apply_hijack(LLMEngine,
|
||||
"get_block_usage",
|
||||
vllm__engine__llm_engine__LLMEngine__get_block_usage)
|
||||
@@ -0,0 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RPCSchedulerProfileRequest(Enum):
|
||||
INIT_SCHEDULER_VIEW = 1
|
||||
SAVE_SCHEDULER_VIEW = 2
|
||||
@@ -0,0 +1,32 @@
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm_mlu.engine.multiprocessing import RPCSchedulerProfileRequest
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class MQLLMEngineClient_V2(MQLLMEngineClient):
|
||||
|
||||
async def init_scheduler_view(self):
|
||||
"""Send INIT_SCHEDULER_VIEW request to RPC Server."""
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCSchedulerProfileRequest.INIT_SCHEDULER_VIEW,
|
||||
socket=self.input_socket)
|
||||
|
||||
|
||||
async def save_scheduler_view(self):
|
||||
"""Send SAVE_SCHEDULER_VIEW request to RPC Server."""
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCSchedulerProfileRequest.SAVE_SCHEDULER_VIEW,
|
||||
socket=self.input_socket)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(MQLLMEngineClient,
|
||||
"init_scheduler_view",
|
||||
MQLLMEngineClient_V2.init_scheduler_view)
|
||||
MluHijackObject.apply_hijack(MQLLMEngineClient,
|
||||
"save_scheduler_view",
|
||||
MQLLMEngineClient_V2.save_scheduler_view)
|
||||
183
vllm-v0.6.2/vllm_mlu/vllm_mlu/engine/multiprocessing/engine.py
Normal file
183
vllm-v0.6.2/vllm_mlu/vllm_mlu/engine/multiprocessing/engine.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import pickle
|
||||
from typing import Iterator, List, Optional, Union
|
||||
|
||||
import cloudpickle
|
||||
import zmq
|
||||
|
||||
from vllm import SamplingParams
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.engine.multiprocessing import (RPCAbortRequest, RPCProcessRequest,
|
||||
RPCUProfileRequest)
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.engine.multiprocessing import (IPC_DATA_EXT, IPC_HEALTH_EXT,
|
||||
IPC_INPUT_EXT, IPC_OUTPUT_EXT,
|
||||
RPCAbortRequest, RPCProcessRequest,
|
||||
RPCUProfileRequest)
|
||||
from vllm.engine.multiprocessing.engine import (MQLLMEngine,
|
||||
POLLING_TIMEOUT_MS)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm_mlu.engine.multiprocessing import RPCSchedulerProfileRequest
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
vllm__engine__multiprocessing__engine__MQLLMEngine____init____org = MQLLMEngine.__init__
|
||||
|
||||
class MQLLMEngine_V2(MQLLMEngine):
|
||||
|
||||
def __init__(self,
|
||||
ipc_path: str,
|
||||
use_async_sockets: bool,
|
||||
*args,
|
||||
log_requests: bool = True,
|
||||
**kwargs) -> None:
|
||||
# For MQLLMEngine, we can use cached outputs, since each new request
|
||||
# output is immediately pickled and sent over the socket, which frees
|
||||
# the python object to be reused again.
|
||||
kwargs['use_cached_outputs'] = True
|
||||
|
||||
self.engine = LLMEngine(*args, **kwargs)
|
||||
self.log_requests = log_requests
|
||||
|
||||
self.use_async_sockets = use_async_sockets
|
||||
if self.use_async_sockets:
|
||||
self.engine.process_request_outputs_callback = \
|
||||
self._async_socket_engine_callback
|
||||
|
||||
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
|
||||
# Receive input from the client.
|
||||
self.input_socket = self.ctx.socket(zmq.constants.PULL)
|
||||
self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")
|
||||
|
||||
# Send output stream back to client.
|
||||
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
|
||||
|
||||
# Send heartbeats back to client.
|
||||
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
|
||||
|
||||
# IPC path for the data socket.
|
||||
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
|
||||
|
||||
# Error state.
|
||||
self._errored_with: Optional[BaseException] = None
|
||||
|
||||
self.collect_scheduler_view = False
|
||||
|
||||
def run_engine_loop(self):
|
||||
"""Core busy loop of the LLMEngine."""
|
||||
|
||||
while True:
|
||||
if not self.engine.has_unfinished_requests():
|
||||
# Poll until there is work to do.
|
||||
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
||||
# When there's no work, check on engine health and send
|
||||
# health status back to client
|
||||
self._health_check()
|
||||
self.engine.do_log_stats()
|
||||
logger.debug("Waiting for new requests in engine loop.")
|
||||
|
||||
# Handle any input from the client.
|
||||
self.handle_new_input()
|
||||
|
||||
'''
|
||||
=============================
|
||||
Add by vllm_mlu
|
||||
=============================
|
||||
@brief: support scheduler view
|
||||
'''
|
||||
if self.collect_scheduler_view:
|
||||
self.collect_scheduler_view = False
|
||||
continue
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
# Engine step.
|
||||
request_outputs = self.engine_step()
|
||||
|
||||
# Send request outputs (if async, done in engine_step callback).
|
||||
if not self.use_async_sockets:
|
||||
self._send_outputs(request_outputs)
|
||||
|
||||
def handle_new_input(self):
|
||||
"""Handle new input from the socket"""
|
||||
try:
|
||||
while self.input_socket.poll(timeout=0) != 0:
|
||||
frames = self.input_socket.recv_multipart(copy=False)
|
||||
request = pickle.loads(frames[0].buffer)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Add by vllm_mlu
|
||||
=============================
|
||||
@brief: support scheduler view
|
||||
'''
|
||||
|
||||
if isinstance(request, RPCProcessRequest):
|
||||
if len(frames) > 1:
|
||||
# Use cloudpickle for logits processors
|
||||
assert isinstance(request.params, SamplingParams)
|
||||
lprocs = cloudpickle.loads(frames[1].buffer)
|
||||
request.params.logits_processors = lprocs
|
||||
self._handle_process_request(request)
|
||||
elif isinstance(request, RPCAbortRequest):
|
||||
self._handle_abort_request(request)
|
||||
elif isinstance(request, RPCUProfileRequest):
|
||||
if request == RPCUProfileRequest.START_PROFILE:
|
||||
self.start_profile()
|
||||
else:
|
||||
self.stop_profile()
|
||||
elif isinstance(request, RPCSchedulerProfileRequest):
|
||||
self.collect_scheduler_view = True
|
||||
if request == RPCSchedulerProfileRequest.INIT_SCHEDULER_VIEW:
|
||||
self.init_scheduler_view()
|
||||
elif request == RPCSchedulerProfileRequest.SAVE_SCHEDULER_VIEW:
|
||||
self.save_scheduler_view()
|
||||
else:
|
||||
raise ValueError("Unknown RPCRequest Type: "
|
||||
f"{type(request)}")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
except Exception as e:
|
||||
self._set_errored(e)
|
||||
self._send_unhealthy(e)
|
||||
raise e
|
||||
|
||||
def init_scheduler_view(self):
|
||||
"""Init scheduler view."""
|
||||
self.engine.init_scheduler_view()
|
||||
|
||||
def save_scheduler_view(self):
|
||||
"""Save scheduler view."""
|
||||
self.engine.save_scheduler_view()
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(MQLLMEngine,
|
||||
MQLLMEngine.__init__,
|
||||
MQLLMEngine_V2.__init__)
|
||||
MluHijackObject.apply_hijack(MQLLMEngine,
|
||||
MQLLMEngine.run_engine_loop,
|
||||
MQLLMEngine_V2.run_engine_loop)
|
||||
MluHijackObject.apply_hijack(MQLLMEngine,
|
||||
MQLLMEngine.handle_new_input,
|
||||
MQLLMEngine_V2.handle_new_input)
|
||||
MluHijackObject.apply_hijack(MQLLMEngine,
|
||||
"init_scheduler_view",
|
||||
MQLLMEngine_V2.init_scheduler_view)
|
||||
MluHijackObject.apply_hijack(MQLLMEngine,
|
||||
"save_scheduler_view",
|
||||
MQLLMEngine_V2.save_scheduler_view)
|
||||
1
vllm-v0.6.2/vllm_mlu/vllm_mlu/entrypoints/__init__.py
Normal file
1
vllm-v0.6.2/vllm_mlu/vllm_mlu/entrypoints/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
import vllm_mlu.entrypoints.llm
|
||||
Binary file not shown.
Binary file not shown.
313
vllm-v0.6.2/vllm_mlu/vllm_mlu/entrypoints/llm.py
Normal file
313
vllm-v0.6.2/vllm_mlu/vllm_mlu/entrypoints/llm.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
from typing import Optional, List, Union, Dict, Any
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
|
||||
TaskOption)
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Counter, deprecate_args
|
||||
from vllm_mlu._mlu_utils import VLLM_LATENCY_DEBUG_EN, VLLM_LATENCY_DEBUG_WITH_DEVICE_EN
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.mlu_metric import LLMMetric
|
||||
from vllm_mlu.dump_info import LLMDumpInfo
|
||||
from vllm.logger import init_logger
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@deprecate_args(
|
||||
start_index=2, # Ignore self and model
|
||||
is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
|
||||
additional_message=(
|
||||
"All positional arguments other than `model` will be "
|
||||
"replaced with keyword arguments in an upcoming version."),
|
||||
)
|
||||
def vllm__entrypoints__llm__LLM____init__(
|
||||
self,
|
||||
model: str,
|
||||
tokenizer: Optional[str] = None,
|
||||
tokenizer_mode: str = "auto",
|
||||
skip_tokenizer_init: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
allowed_local_media_path: str = "",
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
seed: int = 0,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: float = 4,
|
||||
cpu_offload_gb: float = 0,
|
||||
enforce_eager: Optional[bool] = None,
|
||||
max_seq_len_to_capture: int = 8192,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
disable_async_output_proc: bool = False,
|
||||
hf_overrides: Optional[HfOverrides] = None,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
# After positional args are removed, move this right below `model`
|
||||
task: TaskOption = "auto",
|
||||
override_pooler_config: Optional[PoolerConfig] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
'''
|
||||
LLM constructor.
|
||||
|
||||
Note: if enforce_eager is unset (enforce_eager is None)
|
||||
it defaults to False.
|
||||
'''
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: 1) Initialize LLMDumpInfo
|
||||
2) Initialize context mlugraph params
|
||||
'''
|
||||
LLM.dump_info.init_param(
|
||||
tensor_parallel_size=tensor_parallel_size, dtype=dtype,
|
||||
kv_cache_dtype=kwargs.get('kv_cache_dtype', 'default_value'),
|
||||
quantization=quantization,
|
||||
model=model, trust_remote_code=kwargs.get('trust_remote_code', 'default_value'))
|
||||
|
||||
enable_context_mlugraph = kwargs.pop("enable_context_mlugraph", False)
|
||||
context_batch_size_to_capture = kwargs.pop("context_batch_size_to_capture", None)
|
||||
context_seq_len_to_capture = kwargs.pop("context_seq_len_to_capture", None)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if "disable_log_stats" not in kwargs:
|
||||
kwargs["disable_log_stats"] = True
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
task=task,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
trust_remote_code=trust_remote_code,
|
||||
allowed_local_media_path=allowed_local_media_path,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=dtype,
|
||||
quantization=quantization,
|
||||
revision=revision,
|
||||
tokenizer_revision=tokenizer_revision,
|
||||
seed=seed,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
swap_space=swap_space,
|
||||
cpu_offload_gb=cpu_offload_gb,
|
||||
enforce_eager=enforce_eager,
|
||||
max_seq_len_to_capture=max_seq_len_to_capture,
|
||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
hf_overrides=hf_overrides,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
override_pooler_config=override_pooler_config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: set context mlugraph params for EngineArgs
|
||||
'''
|
||||
setattr(engine_args, "enable_context_mlugraph", enable_context_mlugraph)
|
||||
setattr(engine_args, "context_batch_size_to_capture", context_batch_size_to_capture)
|
||||
setattr(engine_args, "context_seq_len_to_capture", context_seq_len_to_capture)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
# Logic to switch between engines is done at runtime instead of import
|
||||
# to avoid import order issues
|
||||
self.engine_class = self.get_engine_class()
|
||||
|
||||
# TODO(rob): enable mp by default (issue with fork vs spawn)
|
||||
self.llm_engine = self.engine_class.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.LLM_CLASS)
|
||||
|
||||
self.request_counter = Counter()
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Capture CPU info for vLLM
|
||||
'''
|
||||
LLM.dump_info.memory_usage()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__entrypoints__llm__LLM__get_metrics(
|
||||
self,
|
||||
metrics_idx_start,
|
||||
only_average,
|
||||
input_len,
|
||||
output_len,
|
||||
tp_nums,
|
||||
quantization,
|
||||
dump_info=None,
|
||||
show_per_iter=False,
|
||||
) -> None:
|
||||
'''
|
||||
@brief: Print the performance metrics collected while vLLM runs the generate interface.
@params:
metrics_idx_start: some generate calls may be warmup runs; set this parameter to skip the
statistics in [0, metrics_idx_start). Defaults to 0, i.e. all performance data is counted.
only_average: True prints only the average performance over the N generate calls; False prints
the performance of each generate call plus the average. If the N runs fluctuate a lot,
check whether the test environment is stable.
Remaining parameters: model configuration parameters.
|
||||
'''
|
||||
if VLLM_LATENCY_DEBUG_EN:
|
||||
self.metric.calc_metric(self.llm_engine.model_config.model,
|
||||
self.llm_engine.model_config.dtype,
|
||||
metrics_idx_start, only_average,
|
||||
input_len, output_len, tp_nums,
|
||||
quantization, dump_info, show_per_iter)
|
||||
else:
|
||||
print("Warnning:please set VLLM_LATENCY_DEBUG=true!")
|
||||
|
||||
|
||||
def vllm__entrypoints__llm__LLM___run_engine(
|
||||
self, *, use_tqdm: bool
|
||||
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
# Initialize tqdm.
|
||||
if use_tqdm:
|
||||
num_requests = self.llm_engine.get_num_unfinished_requests()
|
||||
pbar = tqdm(
|
||||
total=num_requests,
|
||||
desc="Processed prompts",
|
||||
dynamic_ncols=True,
|
||||
postfix=(f"est. speed input: {0:.2f} toks/s, "
|
||||
f"output: {0:.2f} toks/s"),
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Added by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
is_latency_debug = VLLM_LATENCY_DEBUG_EN
|
||||
# Record start
|
||||
if is_latency_debug:
|
||||
total_request_num = self.llm_engine.get_num_unfinished_requests()
|
||||
self.dump_info.capture_cpu_info()
|
||||
peak_memory, block_memory, num_total_gpu_blocks, num_total_cpu_blocks = \
|
||||
self.llm_engine.get_memory_usage()
|
||||
self.metric.update_memory_usage(peak_memory, block_memory, num_total_gpu_blocks, num_total_cpu_blocks)
|
||||
e2e_start_time = self.metric.get_mlu_cost_time()
|
||||
'''
|
||||
==================
|
||||
End of addition
|
||||
==================
|
||||
'''
|
||||
|
||||
# Run the engine.
|
||||
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
|
||||
total_in_toks = 0
|
||||
total_out_toks = 0
|
||||
while self.llm_engine.has_unfinished_requests():
|
||||
'''
|
||||
=============================
|
||||
Added by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
if is_latency_debug :
|
||||
self.dump_info.memory_usage()
|
||||
start_time = self.metric.get_mlu_cost_time()
|
||||
'''
|
||||
==================
|
||||
End of addition
|
||||
==================
|
||||
'''
|
||||
step_outputs = self.llm_engine.step()
|
||||
'''
|
||||
=============================
|
||||
Added by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
if is_latency_debug:
|
||||
end_time = self.metric.get_mlu_cost_time()
|
||||
step_latency = end_time - start_time
|
||||
if len(step_outputs) > 0:
|
||||
batch_size = len(step_outputs)
|
||||
assert batch_size == total_request_num, \
|
||||
f"LLM has received {total_request_num} requests, but only processed {batch_size} requests in the current step.\n" + \
|
||||
f"If you are running benchmark_latency test, please check if the input is correct.\n" + \
|
||||
f"Otherwise, please set env VLLM_LATENCY_DEBUG=false, then run test again.\n"
|
||||
num_free_gpu_blocks, num_free_cpu_blocks = self.llm_engine.get_block_usage()
|
||||
self.metric.update_step_block_usage(num_free_gpu_blocks, num_free_cpu_blocks)
|
||||
self.metric.update_step_latency(step_latency)
|
||||
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
|
||||
self.metric.update_step_latency_device(self.llm_engine.get_latency())
|
||||
self.dump_info.memory_usage()
|
||||
'''
|
||||
==================
|
||||
End of addition
|
||||
==================
|
||||
'''
|
||||
for output in step_outputs:
|
||||
if output.finished:
|
||||
outputs.append(output)
|
||||
if use_tqdm:
|
||||
if isinstance(output, RequestOutput):
|
||||
# Calculate tokens only for RequestOutput
|
||||
assert output.prompt_token_ids is not None
|
||||
total_in_toks += len(output.prompt_token_ids)
|
||||
in_spd = total_in_toks / pbar.format_dict["elapsed"]
|
||||
total_out_toks += sum(
|
||||
len(stp.token_ids) for stp in output.outputs)
|
||||
out_spd = (total_out_toks /
|
||||
pbar.format_dict["elapsed"])
|
||||
pbar.postfix = (
|
||||
f"est. speed input: {in_spd:.2f} toks/s, "
|
||||
f"output: {out_spd:.2f} toks/s")
|
||||
pbar.update(1)
|
||||
'''
|
||||
=============================
|
||||
Added by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
if is_latency_debug:
|
||||
e2e_end_time = self.metric.get_mlu_cost_time()
|
||||
e2e_latency = e2e_end_time - e2e_start_time
|
||||
self.metric.add_metrics(batch_size, e2e_latency)
|
||||
'''
|
||||
==================
|
||||
End of addition
|
||||
==================
|
||||
'''
|
||||
|
||||
if use_tqdm:
|
||||
pbar.close()
|
||||
# Sort the outputs by request ID.
|
||||
# This is necessary because some requests may be finished earlier than
|
||||
# its previous requests.
|
||||
return sorted(outputs, key=lambda x: int(x.request_id))
|
||||
|
||||
|
||||
LLM.metric = LLMMetric()
|
||||
|
||||
LLM.dump_info = LLMDumpInfo()
|
||||
|
||||
MluHijackObject.apply_hijack(LLM,
|
||||
LLM.__init__,
|
||||
vllm__entrypoints__llm__LLM____init__)
|
||||
MluHijackObject.apply_hijack(LLM,
|
||||
"get_metrics",
|
||||
vllm__entrypoints__llm__LLM__get_metrics)
|
||||
MluHijackObject.apply_hijack(LLM,
|
||||
LLM._run_engine,
|
||||
vllm__entrypoints__llm__LLM___run_engine)
|
||||
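# Illustrative usage sketch (not part of the patch; the model path and sizes below are placeholder
# assumptions). After this hijack, vllm.LLM accepts the MLU-specific context-mlugraph kwargs, and
# get_metrics() prints latency statistics when VLLM_LATENCY_DEBUG=true is exported.
if __name__ == "__main__":
    from vllm import SamplingParams
    llm = LLM(model="/path/to/model",
              tensor_parallel_size=1,
              enable_context_mlugraph=True,
              context_batch_size_to_capture=8,
              context_seq_len_to_capture=1024)
    outputs = llm.generate(["Hello"], SamplingParams(temperature=0.0, max_tokens=16))
    llm.get_metrics(metrics_idx_start=0, only_average=True, input_len=1, output_len=16,
                    tp_nums=1, quantization=None)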
@@ -0,0 +1 @@
|
||||
import vllm_mlu.entrypoints.openai.serving_engine
|
||||
@@ -0,0 +1,49 @@
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, AnyRequest
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
async def vllm__entrypoints__openai__serving_engine__OpenAIServing___check_model(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
) -> Optional[ErrorResponse]:
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
if request.model in [lora.lora_name for lora in self.lora_requests]:
|
||||
return None
|
||||
if request.model in [
|
||||
prompt_adapter.prompt_adapter_name
|
||||
for prompt_adapter in self.prompt_adapter_requests
|
||||
]:
|
||||
return None
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: when the client sends a request with model=init/save_scheduler_view,
the scheduler will dump its profile data.
|
||||
'''
|
||||
if request.model == "init_scheduler_view":
|
||||
await self.engine_client.init_scheduler_view()
|
||||
if request.model == "save_scheduler_view":
|
||||
await self.engine_client.save_scheduler_view()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return self.create_error_response(
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(OpenAIServing,
|
||||
OpenAIServing._check_model,
|
||||
vllm__entrypoints__openai__serving_engine__OpenAIServing___check_model)
|
||||
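# Illustrative client sketch (not part of the patch; the URL and the use of the `requests` library
# are example assumptions). Because _check_model above treats the special model names
# "init_scheduler_view" / "save_scheduler_view" as profile commands, a client can trigger a
# scheduler-profile dump with an ordinary completion request; the HTTP response is still a
# NotFoundError, but the dump has already run on the server (which must be started with
# 'export VLLM_SCHEDULER_PROFILE=true').
if __name__ == "__main__":
    import requests
    requests.post("http://localhost:8000/v1/completions",
                  json={"model": "save_scheduler_view", "prompt": ""})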
3
vllm-v0.6.2/vllm_mlu/vllm_mlu/executor/__init__.py
Normal file
3
vllm-v0.6.2/vllm_mlu/vllm_mlu/executor/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
import vllm_mlu.executor.mlu_executor
|
||||
import vllm_mlu.executor.multiproc_mlu_executor
|
||||
import vllm_mlu.executor.ray_mlu_executor
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
35
vllm-v0.6.2/vllm_mlu/vllm_mlu/executor/mlu_executor.py
Normal file
35
vllm-v0.6.2/vllm_mlu/vllm_mlu/executor/mlu_executor.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from vllm.executor.mlu_executor import MLUExecutor
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def vllm__executor__mlu_executor__MLUExecutor__get_latency(self) -> float:
|
||||
'''
|
||||
requires that torch.mlu.synchronize() be executed before this function
|
||||
for getting an accurate reading
|
||||
'''
|
||||
latency = self.driver_worker.get_latency()
|
||||
return latency
|
||||
|
||||
|
||||
def vllm__executor__mlu_executor__MLUExecutor__recapture_model(
|
||||
self,
|
||||
context_batch_size_to_capture,
|
||||
context_seq_len_to_capture
|
||||
) -> None:
|
||||
return self.driver_worker.recapture_model(context_batch_size_to_capture,
|
||||
context_seq_len_to_capture)
|
||||
|
||||
|
||||
def vllm__executor__mlu_executor__MLUExecutor__get_memory_usage(self):
|
||||
return self.driver_worker.get_memory_usage()
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(MLUExecutor,
|
||||
"get_latency",
|
||||
vllm__executor__mlu_executor__MLUExecutor__get_latency)
|
||||
MluHijackObject.apply_hijack(MLUExecutor,
|
||||
"recapture_model",
|
||||
vllm__executor__mlu_executor__MLUExecutor__recapture_model)
|
||||
MluHijackObject.apply_hijack(MLUExecutor,
|
||||
"get_memory_usage",
|
||||
vllm__executor__mlu_executor__MLUExecutor__get_memory_usage)
|
||||
@@ -0,0 +1,16 @@
|
||||
from vllm.executor.multiproc_mlu_executor import MultiprocessingMLUExecutor
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
def vllm__executor__multiproc_mlu_executor__MultiprocessingMLUExecutor__recapture_model(
|
||||
self,
|
||||
context_batch_size_to_capture,
|
||||
context_seq_len_to_capture
|
||||
) -> None:
|
||||
return self._run_workers("recapture_model",
|
||||
context_batch_size_to_capture=context_batch_size_to_capture,
|
||||
context_seq_len_to_capture=context_seq_len_to_capture)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(MultiprocessingMLUExecutor,
|
||||
"recapture_model",
|
||||
vllm__executor__multiproc_mlu_executor__MultiprocessingMLUExecutor__recapture_model)
|
||||
267
vllm-v0.6.2/vllm_mlu/vllm_mlu/executor/ray_mlu_executor.py
Normal file
267
vllm-v0.6.2/vllm_mlu/vllm_mlu/executor/ray_mlu_executor.py
Normal file
@@ -0,0 +1,267 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
get_vllm_instance_id)
|
||||
from vllm_mlu._mlu_utils import VLLM_LATENCY_DEBUG, VLLM_LATENCY_DEBUG_NO_DEVICE
|
||||
from vllm.executor.ray_mlu_executor import RayMLUExecutor
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
if ray is not None:
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__executor__ray_mlu_executor__RayMLUExecutor___init_workers_ray(
|
||||
self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
if (self.parallel_config.tensor_parallel_size == 1
|
||||
and self.parallel_config.pipeline_parallel_size == 1):
|
||||
# For single GPU case, we use a ray worker with constrained memory.
|
||||
num_gpus = self.cache_config.gpu_memory_utilization
|
||||
else:
|
||||
# Otherwise, the ray workers are allocated with a full GPU.
|
||||
num_gpus = 1
|
||||
|
||||
# The driver dummy worker does not actually use any resources.
|
||||
# It holds the resource for the driver worker.
|
||||
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
|
||||
# The remaining workers are the actual ray actors.
|
||||
self.workers: List[RayWorkerWrapper] = []
|
||||
|
||||
# Used in ray compiled DAG: indexed first by PP rank,
|
||||
# and then TP rank. In other words, the inner list is
|
||||
# the TP group of workers for a PP rank.
|
||||
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
|
||||
|
||||
if self.parallel_config.ray_workers_use_nsight:
|
||||
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
|
||||
ray_remote_kwargs)
|
||||
|
||||
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
|
||||
|
||||
# Create the workers.
|
||||
driver_ip = get_ip()
|
||||
worker_wrapper_kwargs = self._get_worker_wrapper_args()
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if not bundle.get("GPU", 0):
|
||||
continue
|
||||
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=bundle_id,
|
||||
)
|
||||
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=num_gpus,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
|
||||
|
||||
if self.use_ray_spmd_worker:
|
||||
self.workers.append(worker)
|
||||
else:
|
||||
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||
# If the worker is on the same node as the driver, we use it
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(
|
||||
**worker_wrapper_kwargs)
|
||||
else:
|
||||
# Else, added to the list of workers.
|
||||
self.workers.append(worker)
|
||||
|
||||
logger.debug("workers: %s", self.workers)
|
||||
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
|
||||
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
|
||||
raise ValueError(
|
||||
"Ray does not allocate any GPUs on the driver node. Consider "
|
||||
"adjusting the Ray placement group or running the driver on a "
|
||||
"GPU node.")
|
||||
|
||||
worker_ips = [
|
||||
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
|
||||
for worker in self.workers
|
||||
]
|
||||
ip_counts: Dict[str, int] = {}
|
||||
for ip in worker_ips:
|
||||
ip_counts[ip] = ip_counts.get(ip, 0) + 1
|
||||
|
||||
def sort_by_driver_then_worker_ip(worker):
|
||||
"""
|
||||
Sort the workers based on 3 properties:
|
||||
1. If the worker is on the same node as the driver (vllm engine),
|
||||
it should be placed first.
|
||||
2. Then, if the worker is on a node with fewer workers, it should
|
||||
be placed first.
|
||||
3. Finally, if the worker is on a node with a smaller IP address, it
should be placed first.
|
||||
"""
|
||||
ip = ray.get(worker.get_node_ip.remote())
|
||||
return (ip != driver_ip, ip_counts[ip], ip)
|
||||
|
||||
# After sorting, the workers on the same node will be
|
||||
# close to each other, and the workers on the driver
|
||||
# node will be placed first.
|
||||
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
|
||||
|
||||
# Get the set of GPU IDs used on each node.
|
||||
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
|
||||
use_dummy_driver=True)
|
||||
|
||||
node_workers = defaultdict(list) # node id -> list of worker ranks
|
||||
node_gpus = defaultdict(list) # node id -> list of gpu ids
|
||||
|
||||
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
|
||||
node_workers[node_id].append(i)
|
||||
# `gpu_ids` can be a list of strings or integers.
|
||||
# convert them to integers for consistency.
|
||||
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
|
||||
# string sorting is not sufficient.
|
||||
# see https://github.com/vllm-project/vllm/issues/5590
|
||||
gpu_ids = [int(x) for x in gpu_ids]
|
||||
node_gpus[node_id].extend(gpu_ids)
|
||||
for node_id, gpu_ids in node_gpus.items():
|
||||
node_gpus[node_id] = sorted(gpu_ids)
|
||||
|
||||
all_ips = set(worker_ips + [driver_ip])
|
||||
n_ips = len(all_ips)
|
||||
n_nodes = len(node_workers)
|
||||
|
||||
if n_nodes != n_ips:
|
||||
raise RuntimeError(
|
||||
f"Every node should have a unique IP address. Got {n_nodes}"
|
||||
f" nodes with node ids {list(node_workers.keys())} and "
|
||||
f"{n_ips} unique IP addresses {all_ips}. Please check your"
|
||||
" network configuration. If you set `VLLM_HOST_IP` or "
|
||||
"`HOST_IP` environment variable, make sure it is unique for"
|
||||
" each node.")
|
||||
|
||||
VLLM_INSTANCE_ID = get_vllm_instance_id()
|
||||
|
||||
# Set environment variables for the driver and workers.
|
||||
all_args_to_update_environment_variables = [({
|
||||
"MLU_VISIBLE_DEVICES":
|
||||
",".join(map(str, node_gpus[node_id])),
|
||||
"VLLM_INSTANCE_ID":
|
||||
VLLM_INSTANCE_ID,
|
||||
"VLLM_TRACE_FUNCTION":
|
||||
str(envs.VLLM_TRACE_FUNCTION),
|
||||
**({
|
||||
"VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
|
||||
} if envs.VLLM_ATTENTION_BACKEND is not None else {}),
|
||||
"VLLM_LATENCY_DEBUG":
|
||||
'1' if VLLM_LATENCY_DEBUG else '0',
|
||||
"VLLM_LATENCY_DEBUG_NO_DEVICE":
|
||||
'1' if VLLM_LATENCY_DEBUG_NO_DEVICE else '0',
|
||||
}, ) for (node_id, _) in worker_node_and_gpu_ids]
|
||||
|
||||
self._env_vars_for_all_workers = (
|
||||
all_args_to_update_environment_variables)
|
||||
|
||||
self._run_workers("update_environment_variables",
|
||||
all_args=self._get_env_vars_to_be_updated())
|
||||
|
||||
if len(node_gpus) == 1:
|
||||
# in single node case, we don't need to get the IP address.
|
||||
# the loopback address is sufficient
|
||||
# NOTE: a node may have several IP addresses, one for each
|
||||
# network interface. `get_ip()` might return any of them,
|
||||
# while they might not work for communication inside the node
|
||||
# if the network setup is complicated. Using the loopback address
|
||||
# solves this issue, as it always works for communication inside
|
||||
# the node.
|
||||
driver_ip = "127.0.0.1"
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
driver_ip, get_open_port())
|
||||
|
||||
# Initialize the actual workers inside worker wrapper.
|
||||
init_worker_all_kwargs = [
|
||||
self._get_worker_kwargs(
|
||||
local_rank=node_workers[node_id].index(rank),
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
|
||||
]
|
||||
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
|
||||
|
||||
self._run_workers("init_device")
|
||||
self._run_workers("load_model",
|
||||
max_concurrent_workers=self.parallel_config.
|
||||
max_parallel_loading_workers)
|
||||
|
||||
if self.use_ray_spmd_worker:
|
||||
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
|
||||
self.pp_tp_workers.append([])
|
||||
for tp_rank in range(
|
||||
self.parallel_config.tensor_parallel_size):
|
||||
# PP=2, TP=4
|
||||
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
|
||||
rank = (pp_rank * self.parallel_config.tensor_parallel_size
|
||||
) + tp_rank
|
||||
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
|
||||
assert pp_rank < len(self.pp_tp_workers)
|
||||
self.pp_tp_workers[pp_rank].append(self.workers[rank])
|
||||
|
||||
# This is the list of workers that are rank 0 of each TP group EXCEPT
|
||||
# global rank 0. These are the workers that will broadcast to the
|
||||
# rest of the workers.
|
||||
self.tp_driver_workers: List[RayWorkerWrapper] = []
|
||||
# This is the list of workers that are not drivers and not the first
|
||||
# worker in a TP group. These are the workers that will be
|
||||
# broadcasted to.
|
||||
self.non_driver_workers: List[RayWorkerWrapper] = []
|
||||
|
||||
# Enforce rank order for correct rank to return final output.
|
||||
for index, worker in enumerate(self.workers):
|
||||
# The driver worker is rank 0 and not in self.workers.
|
||||
rank = index + 1
|
||||
if rank % self.parallel_config.tensor_parallel_size == 0:
|
||||
self.tp_driver_workers.append(worker)
|
||||
else:
|
||||
self.non_driver_workers.append(worker)
|
||||
|
||||
|
||||
def vllm__executor__ray_mlu_executor__RayMLUExecutor__get_latency(self):
|
||||
'''
|
||||
requires that torch.mlu.synchronize() be executed before this function
|
||||
for getting an accurate reading
|
||||
'''
|
||||
return self.driver_worker.execute_method("get_latency")
|
||||
|
||||
|
||||
def vllm__executor__ray_mlu_executor__RayMLUExecutor__recapture_model(
|
||||
self,
|
||||
context_batch_size_to_capture,
|
||||
context_seq_len_to_capture
|
||||
) -> None:
|
||||
return self._run_workers("recapture_model",
|
||||
context_batch_size_to_capture=context_batch_size_to_capture,
|
||||
context_seq_len_to_capture=context_seq_len_to_capture)
|
||||
|
||||
|
||||
def vllm__executor__ray_mlu_executor__RayMLUExecutor__get_memory_usage(self):
|
||||
return self.driver_worker.execute_method("get_memory_usage")
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(RayMLUExecutor,
|
||||
RayMLUExecutor._init_workers_ray,
|
||||
vllm__executor__ray_mlu_executor__RayMLUExecutor___init_workers_ray)
|
||||
MluHijackObject.apply_hijack(RayMLUExecutor,
|
||||
"get_latency",
|
||||
vllm__executor__ray_mlu_executor__RayMLUExecutor__get_latency)
|
||||
MluHijackObject.apply_hijack(RayMLUExecutor,
|
||||
"recapture_model",
|
||||
vllm__executor__ray_mlu_executor__RayMLUExecutor__recapture_model)
|
||||
MluHijackObject.apply_hijack(RayMLUExecutor,
|
||||
"get_memory_usage",
|
||||
vllm__executor__ray_mlu_executor__RayMLUExecutor__get_memory_usage)
|
||||
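Every hijack module in this plugin follows the same pattern: define a free function named after the fully qualified method it replaces, then register it with MluHijackObject.apply_hijack. The snippet below is a minimal, illustrative sketch of what such a registry could look like; the class name HijackRegistry and the undo bookkeeping are assumptions for illustration, not the actual vllm_mlu implementation.

# Illustrative sketch only -- the real MluHijackObject may differ.
class HijackRegistry:
    _originals = {}  # (cls, method name) -> original attribute, kept so the patch can be undone

    @classmethod
    def apply_hijack(cls, target_cls, method, replacement):
        # `method` may be passed as the original function object or as its name.
        name = method if isinstance(method, str) else method.__name__
        cls._originals.setdefault((target_cls, name), getattr(target_cls, name, None))
        setattr(target_cls, name, replacement)

    @classmethod
    def undo_hijack(cls, target_cls, method):
        name = method if isinstance(method, str) else method.__name__
        original = cls._originals.pop((target_cls, name), None)
        if original is not None:
            setattr(target_cls, name, original)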
4
vllm-v0.6.2/vllm_mlu/vllm_mlu/lora/__init__.py
Normal file
@@ -0,0 +1,4 @@
import vllm_mlu.lora.ops
import vllm_mlu.lora.fully_sharded_layers
import vllm_mlu.lora.layers
import vllm_mlu.lora.punica
65
vllm-v0.6.2/vllm_mlu/vllm_mlu/lora/fully_sharded_layers.py
Normal file
@@ -0,0 +1,65 @@
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.lora.fully_sharded_layers import RowParallelLinearWithShardedLoRA
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def vllm__lora__fully_sharded_layers__RowParallelLinearWithShardedLoRA__apply(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
residual: Optional[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual and bias in matmul
|
||||
'''
|
||||
output = self.base_layer.quant_method.apply(
|
||||
self.base_layer, x, bias, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output, out_orig_shape = output.view(-1,
|
||||
output.shape[-1]), output.shape
|
||||
buffer = torch.zeros(
|
||||
(x.shape[0], self.lora_a_stacked.shape[2]),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
|
||||
buffer = tensor_model_parallel_all_reduce(buffer)
|
||||
|
||||
# following S-LoRA, allows the fusing of all_gather and all_reduce
|
||||
# by adding the column partitioned lora output to a slice of output
|
||||
# tensor, which is a partial sum due to row parallel. All that
|
||||
# remains is a standard all_reduce. User should be aware though that
|
||||
# the output is not the same as a normal row_parallel, it should be
|
||||
# reduced before being used
|
||||
shard_size = self.lora_b_stacked.shape[2]
|
||||
start_idx = self.tp_rank * shard_size
|
||||
|
||||
if self.bias_stacked is not None:
|
||||
bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
|
||||
bias = bias[self.punica_wrapper.token_lora_indices]
|
||||
bias[self.punica_wrapper.token_lora_indices == -1] = 0
|
||||
output += bias
|
||||
|
||||
self.punica_wrapper.add_expand_slice(output, buffer,
|
||||
self.lora_b_stacked, start_idx,
|
||||
shard_size)
|
||||
output = output.view(*out_orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(RowParallelLinearWithShardedLoRA,
|
||||
RowParallelLinearWithShardedLoRA.apply,
|
||||
vllm__lora__fully_sharded_layers__RowParallelLinearWithShardedLoRA__apply)
|
||||
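The hijack above routes bias and residual into the base layer's quant_method.apply so they are fused into the matmul on MLU. As a point of reference, the fused call computes roughly the following in plain PyTorch (a sketch assuming alpha = beta = 1.0; the real op performs this in a single fused kernel):

import torch

def fused_matmul_sketch(x: torch.Tensor, weight: torch.Tensor,
                        bias: torch.Tensor = None, residual: torch.Tensor = None) -> torch.Tensor:
    # out = x @ W^T (+ bias) (+ residual), all fused into one kernel on MLU
    out = x @ weight.t()
    if bias is not None:
        out = out + bias
    if residual is not None:
        out = out + residual
    return out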
219
vllm-v0.6.2/vllm_mlu/vllm_mlu/lora/layers.py
Normal file
@@ -0,0 +1,219 @@
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
||||
RowParallelLinearWithLoRA,
|
||||
LinearScalingRotaryEmbeddingWithLora,
|
||||
apply_bias)
|
||||
from vllm_mlu.model_executor.layers.rotary_embedding import (
|
||||
MLURotaryEmbedding, MLULinearScalingRotaryEmbedding)
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
|
||||
vllm__lora__layers__ColumnParallelLinearWithLoRA__forward_org = ColumnParallelLinearWithLoRA.forward
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add smooth_quant_scale parameter.
|
||||
'''
|
||||
def vllm__lora__layers__ColumnParallelLinearWithLoRA__forward(
|
||||
self,
|
||||
input_,
|
||||
smooth_quant_scale: Optional[torch.Tensor] = None
|
||||
):
|
||||
assert smooth_quant_scale is None, "LoRA does not support smooth quant yet."
|
||||
return vllm__lora__layers__ColumnParallelLinearWithLoRA__forward_org(self, input_)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def vllm__lora__layers__RowParallelLinearWithLoRA__apply(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
residual: Optional[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual and bias in matmul
|
||||
'''
|
||||
output = self.base_layer.quant_method.apply(
|
||||
self.base_layer, x, bias, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if self.bias_stacked is not None:
|
||||
self.indices = self.punica_wrapper.token_lora_indices
|
||||
output = apply_bias(
|
||||
self.indices,
|
||||
output,
|
||||
self.bias_stacked,
|
||||
)
|
||||
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
|
||||
self.lora_b_stacked, 1.0)
|
||||
return output
|
||||
|
||||
|
||||
def vllm__lora__layers__RowParallelLinearWithLoRA__forward(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None
|
||||
):
|
||||
# Set up backprop all-reduce.
|
||||
if self.base_layer.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
# TODO: simplify code below
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.base_layer.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: 1) apply residual fusion in matmul like RowParallelLinear
|
||||
2) add bias in matmul, not after all reduce
|
||||
'''
|
||||
# Matrix multiply.
|
||||
bias_ = (None if (self.base_layer.tp_rank > 0 or self.base_layer.skip_bias_add) else self.base_layer.bias)
|
||||
residual_ = None if self.base_layer.tp_rank > 0 else residual
|
||||
output_parallel = self.apply(input_parallel, bias_, residual_)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: do not add bias after all_reduce
|
||||
'''
|
||||
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output, output_bias
|
||||
|
||||
|
||||
def vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None,
|
||||
) -> None:
|
||||
scaling_factors = (list(lora_config.long_lora_scaling_factors)
|
||||
if lora_config.long_lora_scaling_factors else [])
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: change LinearScalingRotaryEmbedding to MLULinearScalingRotaryEmbedding
|
||||
'''
|
||||
base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
|
||||
self.base_layer, MLULinearScalingRotaryEmbedding) else 1.0)
|
||||
scaling_factors = sorted(
|
||||
list(set([base_scaling_factor] + scaling_factors)))
|
||||
self.base_layer = MLULinearScalingRotaryEmbedding(
|
||||
self.base_layer.head_size,
|
||||
self.base_layer.rotary_dim,
|
||||
self.base_layer.max_position_embeddings,
|
||||
self.base_layer.base,
|
||||
self.base_layer.is_neox_style,
|
||||
scaling_factors,
|
||||
self.base_layer.dtype,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
qk: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: change function prototype to meet forward_mlu in rope
|
||||
'''
|
||||
return self.base_layer(
|
||||
positions,
|
||||
qk,
|
||||
offsets=self.punica_wrapper.long_lora_indices,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
@classmethod
|
||||
def vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: change origin rope type to mlu rope
|
||||
'''
|
||||
return (type(source_layer) is MLULinearScalingRotaryEmbedding
|
||||
or type(source_layer) is MLURotaryEmbedding)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(RowParallelLinearWithLoRA,
|
||||
RowParallelLinearWithLoRA.apply,
|
||||
vllm__lora__layers__RowParallelLinearWithLoRA__apply)
|
||||
MluHijackObject.apply_hijack(ColumnParallelLinearWithLoRA,
|
||||
ColumnParallelLinearWithLoRA.forward,
|
||||
vllm__lora__layers__ColumnParallelLinearWithLoRA__forward)
|
||||
MluHijackObject.apply_hijack(RowParallelLinearWithLoRA,
|
||||
RowParallelLinearWithLoRA.forward,
|
||||
vllm__lora__layers__RowParallelLinearWithLoRA__forward)
|
||||
MluHijackObject.apply_hijack(LinearScalingRotaryEmbeddingWithLora,
|
||||
LinearScalingRotaryEmbeddingWithLora.create_lora_weights,
|
||||
vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__create_lora_weights)
|
||||
MluHijackObject.apply_hijack(LinearScalingRotaryEmbeddingWithLora,
|
||||
LinearScalingRotaryEmbeddingWithLora.forward,
|
||||
vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__forward)
|
||||
MluHijackObject.apply_hijack(LinearScalingRotaryEmbeddingWithLora,
|
||||
LinearScalingRotaryEmbeddingWithLora.can_replace_layer,
|
||||
vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__can_replace_layer)
|
||||
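The ColumnParallelLinearWithLoRA hijack above uses a thin wrapper pattern: keep a reference to the original bound function, accept the new keyword, assert it is unused, and delegate. A generic sketch of the same idea (the function and argument names here are illustrative, not part of vllm_mlu):

def make_forward_with_extra_arg(original_forward):
    # Wrap an existing forward so callers may pass an extra, not-yet-supported argument.
    def forward(self, input_, extra_scale=None):
        assert extra_scale is None, "extra_scale is not supported by this layer yet"
        return original_forward(self, input_)
    return forward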
3
vllm-v0.6.2/vllm_mlu/vllm_mlu/lora/ops/__init__.py
Normal file
@@ -0,0 +1,3 @@
import vllm_mlu.lora.ops.sgmv_expand
import vllm_mlu.lora.ops.sgmv_expand_slice
import vllm_mlu.lora.ops.sgmv_shrink
233
vllm-v0.6.2/vllm_mlu/vllm_mlu/lora/ops/sgmv_expand.py
Normal file
@@ -0,0 +1,233 @@
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm_mlu.lora.ops.utils import adjust_kernel_block_size
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _sgmv_expand_kernel_mlu(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_lens,
|
||||
lora_indices,
|
||||
xm_stride,
|
||||
xk_stride, # 1
|
||||
l0_stride, # hidden_size*max_rank
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
The sgmv's expand triton kernel is based on GroupGEMM.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
pid_m = pid // cta_n_num
|
||||
pid_n = pid % cta_n_num
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
if pid_m * BLOCK_M > M:
|
||||
return
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
|
||||
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
offset_k = tl.arange(0, BLOCK_K)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust kernel impl to fit mlu.
|
||||
'''
|
||||
a_ptr = (input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride +
|
||||
offset_k[None, :] * xk_stride, )
|
||||
b_ptr = (lora_ptr + l0_stride * lora_index +
|
||||
offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(tl.cdiv(K, BLOCK_K)):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust kernel impl to fit mlu.
|
||||
'''
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
|
||||
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
|
||||
else:
|
||||
tiled_a = tl.load(a_ptr,
|
||||
mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)),
|
||||
other=0)
|
||||
tiled_b = tl.load(b_ptr,
|
||||
mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)),
|
||||
other=0)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if CAST_TYPE:
|
||||
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
|
||||
accumulator += tl.dot(
|
||||
tiled_a,
|
||||
tiled_b,
|
||||
)
|
||||
a_ptr += BLOCK_K * xk_stride
|
||||
b_ptr += BLOCK_K * lora_n_stride
|
||||
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
|
||||
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
|
||||
offset_cn[None, :] * cn_stride)
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
c_mask = (offset_cm[:, None] <
|
||||
(cur_seq_start + M)) & (offset_cn[None, :] < N)
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
||||
tiled_c += tiled_out
|
||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def sgmv_expand_mlu(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (torch.Tensor): LoRA B weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into sequence. E.g., if the sequence length is [4, 6], it is
|
||||
[0, 4, 10].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
||||
length of the sequences in the batch.
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence lengths of the sequences in the
|
||||
batch.
|
||||
token_nums (int): The token numbers in the batch. Used to verify if the
|
||||
token numbers in the inputs matches the one in the metadata.
|
||||
add_inputs (bool, optional): Defaults to False, adds the final lora
|
||||
results to the output.
|
||||
"""
|
||||
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
assert lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(0) == token_nums
|
||||
assert inputs.size(1) == lora_b_weights.size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_b_weights.size(1) == 1
|
||||
lora_b_weights = lora_b_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
|
||||
|
||||
assert lora_b_weights.is_contiguous()
|
||||
|
||||
# TODO tuning this config
|
||||
|
||||
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Workaround: Adjust block size to meet mlu restrictions.
|
||||
|
||||
The grid of an MLU Triton kernel must be smaller than 65536; with a very long input
sequence the grid overflows and causes a runtime error, so the block sizes are
adjusted to stay within the limit.
|
||||
'''
|
||||
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
BLOCK_K = 16
|
||||
EVEN_K = K % BLOCK_K == 0
|
||||
ADD_INPUTS = add_inputs
|
||||
CAST_TYPE = False
|
||||
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
grid = (
|
||||
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
batches,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: call _sgmv_expand_kernel_mlu
|
||||
'''
|
||||
_sgmv_expand_kernel_mlu[grid](
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_b_weights.stride(0),
|
||||
lora_b_weights.stride(1),
|
||||
lora_b_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
ADD_INPUTS,
|
||||
CAST_TYPE,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return
|
||||
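To see why the block-size workaround matters, note that sgmv_expand_mlu launches cdiv(max_seq_length, BLOCK_M) * cdiv(N, BLOCK_N) programs on grid axis 0, and on MLU that product must stay below 65536. A quick arithmetic check with assumed sizes (the exact blocks adjust_kernel_block_size picks may differ):

from math import ceil

max_seq_length, N = 32768, 14336                      # assumed example sizes
naive = ceil(max_seq_length / 32) * ceil(N / 32)      # 32x32 blocks -> 458752, over the limit
adjusted = ceil(max_seq_length / 128) * ceil(N / 64)  # larger blocks -> 57344, fits under 65536
print(naive, adjusted)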
244
vllm-v0.6.2/vllm_mlu/vllm_mlu/lora/ops/sgmv_expand_slice.py
Normal file
@@ -0,0 +1,244 @@
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm_mlu.lora.ops.utils import adjust_kernel_block_size
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _sgmv_expand_slice_kernel_mlu(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_lens,
|
||||
lora_indices,
|
||||
xm_stride,
|
||||
xk_stride, # 1
|
||||
l0_stride, # hidden_size*max_rank
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
slice_offset,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
|
||||
Similar to the 'sgmv_expand' operator, but with an added parameter
|
||||
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
|
||||
might be that in the future, we could implement a fusion operator to
|
||||
achieve the current functionality instead of having to call it multiple
|
||||
times.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
pid_m = pid // cta_n_num
|
||||
pid_n = pid % cta_n_num
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
if pid_m * BLOCK_M > M:
|
||||
return
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
|
||||
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
offset_k = tl.arange(0, BLOCK_K)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust kernel impl to fit mlu.
|
||||
'''
|
||||
a_ptr = (input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride +
|
||||
offset_k[None, :] * xk_stride, )
|
||||
b_ptr = (lora_ptr + l0_stride * lora_index +
|
||||
offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(tl.cdiv(K, BLOCK_K)):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust kernel impl to fit mlu.
|
||||
'''
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
|
||||
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
|
||||
else:
|
||||
tiled_a = tl.load(a_ptr,
|
||||
mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)),
|
||||
other=0)
|
||||
tiled_b = tl.load(b_ptr,
|
||||
mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)),
|
||||
other=0)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if CAST_TYPE:
|
||||
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
|
||||
accumulator += tl.dot(
|
||||
tiled_a,
|
||||
tiled_b,
|
||||
)
|
||||
a_ptr += BLOCK_K * xk_stride
|
||||
b_ptr += BLOCK_K * lora_n_stride
|
||||
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
|
||||
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
|
||||
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
|
||||
offset_cn[None, :] * cn_stride)
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
|
||||
(slice_offset + N))
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
||||
tiled_c += tiled_out
|
||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def sgmv_expand_slice_mlu(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (torch.Tensor): LoRA B weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into sequence. E.g., if the sequence length is [4, 6], it is
|
||||
[0, 4, 10].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
||||
length of the sequences in the batch
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence lengths of the sequences
|
||||
in the batch
|
||||
token_nums (int): The token numbers in the batch. Used to verify if the
|
||||
token numbers in the inputs matches the one in the metadata.
|
||||
slice_offset (int): output_tensor's offset
|
||||
slice_size (int): current output_tensor's size
|
||||
add_inputs (bool, optional): Defaults to False, adds the final lora
|
||||
results to the output.
|
||||
"""
|
||||
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
assert lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(0) == token_nums
|
||||
assert inputs.size(1) == lora_b_weights.size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
assert slice_size == lora_b_weights.size(-2)
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_b_weights.size(1) == 1
|
||||
lora_b_weights = lora_b_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
|
||||
|
||||
assert lora_b_weights.is_contiguous()
|
||||
|
||||
# TODO tuning this config
|
||||
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Workaround: Adjust block size to meet mlu restrictions.
|
||||
|
||||
The grid of an MLU Triton kernel must be smaller than 65536; with a very long input
sequence the grid overflows and causes a runtime error, so the block sizes are
adjusted to stay within the limit.
|
||||
'''
|
||||
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
BLOCK_K = 16
|
||||
EVEN_K = K % BLOCK_K == 0
|
||||
ADD_INPUTS = add_inputs
|
||||
CAST_TYPE = False
|
||||
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
grid = (
|
||||
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
batches,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: call _sgmv_expand_kernel_mlu
|
||||
'''
|
||||
_sgmv_expand_slice_kernel_mlu[grid](
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_b_weights.stride(0),
|
||||
lora_b_weights.stride(1),
|
||||
lora_b_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
slice_offset,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
ADD_INPUTS,
|
||||
CAST_TYPE,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return
|
||||
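sgmv_expand_slice_mlu exists so several LoRA B matrices can write into disjoint column ranges of one packed output (for example a merged projection). Below is an eager-mode reference of what the kernel computes, useful for testing; this helper is illustrative and assumes CPU-readable index tensors and the shapes documented above:

import torch

def sgmv_expand_slice_reference(x, lora_b, out, seq_starts, seq_lens, lora_idx,
                                slice_offset, slice_size, add_inputs=False):
    # Per sequence: out[rows, off:off+size] (+)= x[rows] @ lora_b[idx].T
    for start, length, idx in zip(seq_starts.tolist(), seq_lens.tolist(), lora_idx.tolist()):
        if idx < 0:
            continue  # -1 means no LoRA for this sequence
        seg = x[start:start + length].to(lora_b.dtype) @ lora_b[idx].t()
        dst = out[start:start + length, slice_offset:slice_offset + slice_size]
        dst.copy_(dst + seg if add_inputs else seg)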
226
vllm-v0.6.2/vllm_mlu/vllm_mlu/lora/ops/sgmv_shrink.py
Normal file
@@ -0,0 +1,226 @@
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm_mlu.lora.ops.utils import adjust_kernel_block_size
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _sgmv_shrink_kernel_mlu(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_lens,
|
||||
lora_indices,
|
||||
scaling,
|
||||
xm_stride, # hidden_size
|
||||
xk_stride, # 1
|
||||
l0_stride, # hidden_size*max_rank
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
|
||||
The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
|
||||
introducing SPLIT-K can improve performance
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
pid_sk = tl.program_id(axis=1)
|
||||
cur_batch = tl.program_id(axis=2)
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
pid_m = pid // cta_n_num
|
||||
pid_n = pid % cta_n_num
|
||||
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
if pid_m * BLOCK_M > M:
|
||||
return
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
|
||||
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust kernel impl to fit mlu.
|
||||
'''
|
||||
a_ptr = (input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride +
|
||||
offset_k[None, :] * xk_stride)
|
||||
b_ptr = (lora_ptr + l0_stride * lora_index + offset_n[None, :] * lora_k_stride +
|
||||
offset_k[:, None] * lora_n_stride)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust kernel impl to fit mlu.
|
||||
'''
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
|
||||
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
|
||||
else:
|
||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||
tiled_a = tl.load(a_ptr,
|
||||
mask=((offset_k[None, :] < k_remaining) & (offset_m[:, None] < M)),
|
||||
other=0.0)
|
||||
tiled_b = tl.load(b_ptr,
|
||||
mask=((offset_k[:, None] < k_remaining) & (offset_n[None, :] < N)),
|
||||
other=0.0)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
accumulator += tl.dot(tiled_a, tiled_b)
|
||||
|
||||
a_ptr += BLOCK_K * SPLIT_K * xk_stride
|
||||
b_ptr += BLOCK_K * SPLIT_K * lora_n_stride
|
||||
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
|
||||
offset_cn[None, :] * cn_stride)
|
||||
c_mask = (offset_cm[:, None] <
|
||||
(cur_seq_start + M)) & (offset_cn[None, :] < N)
|
||||
accumulator *= scaling
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(c_ptr, accumulator, mask=c_mask)
|
||||
else:
|
||||
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def sgmv_shrink_mlu(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_a_weights (torch.Tensor): LoRA A weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into sequence. E.g., if the sequence length is [4, 6], it is
|
||||
[0, 4].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
||||
length of the sequences in the batch.
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence lengths of the sequences in the
|
||||
batch.
|
||||
token_nums (int): The token numbers in the batch. Used to verify if the
|
||||
token numbers in the inputs matches the one in the metadata.
|
||||
scaling (float): Scaling factor.
|
||||
"""
|
||||
assert inputs.dtype == lora_a_weights.dtype
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16]
|
||||
assert lora_a_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(0) == token_nums
|
||||
assert inputs.size(1) == lora_a_weights.size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
assert inputs.is_contiguous()
|
||||
|
||||
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
|
||||
assert lora_a_weights.size(1) == 1
|
||||
lora_a_weights = lora_a_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
|
||||
assert lora_a_weights.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
# TODO tuning this config
|
||||
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Workaround: adjust block size to meet mlu restrictions.
|
||||
|
||||
The grid of an MLU Triton kernel must be smaller than 65536; with a very long input
sequence the grid overflows and causes a runtime error, so the block sizes are
adjusted to stay within the limit.
|
||||
'''
|
||||
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 16)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
BLOCK_K = 32
|
||||
SPLIT_K = 8
|
||||
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
|
||||
grid = (
|
||||
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
SPLIT_K,
|
||||
batches,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: call _sgmv_shrink_kernel_mlu
|
||||
'''
|
||||
_sgmv_shrink_kernel_mlu[grid](
|
||||
inputs,
|
||||
lora_a_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
scaling,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_a_weights.stride(0),
|
||||
lora_a_weights.stride(1),
|
||||
lora_a_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
SPLIT_K,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return
|
||||
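The shrink kernel is a grouped GEMM with SPLIT-K accumulation: each sequence segment is projected from hidden_size down to the LoRA rank and scaled. An eager-mode reference of the same computation (an illustrative helper, assuming CPU-readable index tensors):

import torch

def sgmv_shrink_reference(x, lora_a, out, seq_starts, seq_lens, lora_idx, scaling):
    # Per sequence: out[rows, :rank] = scaling * (x[rows] @ lora_a[idx].T)
    rank = lora_a.size(1)
    for start, length, idx in zip(seq_starts.tolist(), seq_lens.tolist(), lora_idx.tolist()):
        if idx < 0:
            continue  # -1 means no LoRA for this sequence
        seg = x[start:start + length]
        out[start:start + length, :rank] = scaling * (seg @ lora_a[idx].t())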
38
vllm-v0.6.2/vllm_mlu/vllm_mlu/lora/ops/utils.py
Normal file
@@ -0,0 +1,38 @@
from typing import Tuple
from math import ceil

_MLU_MAX_GRID_SIZE = 65536

def adjust_kernel_block_size(
    m: int,
    block_m: int,
    n: int,
    block_n: int
) -> Tuple[int, int]:
    """Adjust block sizes to meet MLU Triton grid restrictions.

    Calculation of the max block size in the candidates list:

    Llama3.1-8b-tp1 max n is 14336
    Llama3.1-70b-tp4 max n is 7168
    Llama3.1-405b-tp8 max n is 6656

    When n is 14336, the max sequence length for block size 256 is
    floor(65536 / ceil(14336 / 256)) * 256 = 299520.
    """
    candidates_list = [16, 32, 64, 96, 128, 192, 256]
    candidates_list_len = len(candidates_list)
    m_idx = 1
    n_idx = 0 if block_n == 16 else 1
    while m_idx < candidates_list_len and n_idx < candidates_list_len:
        block_m = candidates_list[m_idx]
        block_n = candidates_list[n_idx]
        if ceil(m / block_m) * ceil(n / block_n) < _MLU_MAX_GRID_SIZE:
            break
        if m_idx < candidates_list_len:
            m_idx += 1
        if n_idx < candidates_list_len:
            n_idx += 1
    if ceil(m / block_m) * ceil(n / block_n) >= _MLU_MAX_GRID_SIZE:
        raise ValueError(f"the max seq len {m} is too long for the LoRA triton kernel")
    return block_m, block_n
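A quick usage check of the helper above with the numbers quoted in its docstring; with m = 299520 and n = 14336 the candidate walk ends at 256x256 blocks, giving a grid of 1170 * 56 = 65520, just under the limit:

from math import ceil
from vllm_mlu.lora.ops.utils import adjust_kernel_block_size

block_m, block_n = adjust_kernel_block_size(299520, 32, 14336, 32)
assert ceil(299520 / block_m) * ceil(14336 / block_n) < 65536
print(block_m, block_n)  # 256 256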
115
vllm-v0.6.2/vllm_mlu/vllm_mlu/lora/punica.py
Normal file
@@ -0,0 +1,115 @@
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.lora.punica import PunicaWrapper
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
from vllm_mlu.lora.ops.sgmv_expand import sgmv_expand_mlu
|
||||
from vllm_mlu.lora.ops.sgmv_expand_slice import sgmv_expand_slice_mlu
|
||||
from vllm_mlu.lora.ops.sgmv_shrink import sgmv_shrink_mlu
|
||||
|
||||
|
||||
def vllm__lora__punica__PunicaWrapper__shrink_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
#No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Change function from sgmv_shrink to sgmv_shrink_mlu.
|
||||
'''
|
||||
sgmv_shrink_mlu(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
scale,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__lora__punica__PunicaWrapper__expand_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
add_input: bool,
|
||||
):
|
||||
#No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Change function from sgmv_expand to sgmv_expand_mlu.
|
||||
'''
|
||||
sgmv_expand_mlu(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
add_input,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__lora__punica__PunicaWrapper__expand_slice_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
y_offset: Optional[int],
|
||||
y_slice_size: Optional[int],
|
||||
add_input: bool,
|
||||
):
|
||||
#No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Change function from sgmv_expand_slice to sgmv_expand_slice_mlu.
|
||||
'''
|
||||
sgmv_expand_slice_mlu(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
y_offset,
|
||||
y_slice_size,
|
||||
add_input,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(PunicaWrapper,
|
||||
PunicaWrapper.shrink_prefill,
|
||||
vllm__lora__punica__PunicaWrapper__shrink_prefill)
|
||||
MluHijackObject.apply_hijack(PunicaWrapper,
|
||||
PunicaWrapper.expand_prefill,
|
||||
vllm__lora__punica__PunicaWrapper__expand_prefill)
|
||||
MluHijackObject.apply_hijack(PunicaWrapper,
|
||||
PunicaWrapper.expand_slice_prefill,
|
||||
vllm__lora__punica__PunicaWrapper__expand_slice_prefill)
|
||||
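Taken together, the shrink and expand hooks implement the standard LoRA prefill update per sequence, y += (x @ A^T) * scale @ B^T, just routed through the MLU Triton kernels. An eager, single-adapter sketch of the equivalent math (illustrative only):

import torch

def lora_prefill_reference(x, lora_a, lora_b, y, scale=1.0):
    # shrink: project to rank with an fp32 accumulator; expand: project back and accumulate into y
    buffer = (x @ lora_a.t()).to(torch.float32) * scale
    y += buffer.to(lora_b.dtype) @ lora_b.t()
    return y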
@@ -0,0 +1 @@
|
||||
from . import model_executor
|
||||
@@ -0,0 +1,2 @@
|
||||
from . import layers
|
||||
from . import model_loader
|
||||
@@ -0,0 +1,2 @@
|
||||
from . import feed_forward
|
||||
from . import linear
|
||||
@@ -0,0 +1,98 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject, set_is_gated
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_group
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm.lora.layers import BaseLayerWithLoRA
|
||||
from vllm_mlu._mlu_utils import *
|
||||
|
||||
|
||||
def vllm_mlu__model_executor__layers__feed_forward__FeedForward__forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual: Optional[torch.Tensor] = None
|
||||
):
|
||||
self.prepare_weight()
|
||||
up_proj = getattr(self, self.up_proj_name)
|
||||
down_proj = getattr(self, self.down_proj_name)
|
||||
residual_ = None if self.tp_rank > 0 else residual
|
||||
if (self.use_bt_ffn and not isinstance(up_proj, BaseLayerWithLoRA)
|
||||
and not isinstance(down_proj, BaseLayerWithLoRA)):
|
||||
# The matmul formula is the following:
|
||||
# mul_out = alpha * (matmul(input, filter, transpose_b=True) + bias) + beta * residual
|
||||
# output = active(mul_out)
|
||||
# Notes: We cannot use the activation function inside matmul because it does not support gated operations;
# we might support it in tmo matmul in the future.
|
||||
fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj.weight, up_proj.bias,
|
||||
None, 'none', self.alpha, self.beta)
|
||||
act_out = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
|
||||
beta = 1.0 if residual_ is not None else 0.0
|
||||
'''
|
||||
=======================================
|
||||
Modify by custom vllm_mlu
|
||||
=======================================
|
||||
@brief: call parallel op and abandon original reduce if parallel_num is set
|
||||
'''
|
||||
is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
|
||||
if is_parallel_enable:
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
pg = get_tensor_model_parallel_group().device_group
|
||||
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
|
||||
out_ = mlu_ops.matmul_allreduce(cncl_comm, act_out, down_proj.weight, None, residual_,
|
||||
self.alpha, beta, self.parallel_num)
|
||||
else:
|
||||
out_ = mlu_ops.matmul(act_out, down_proj.weight, None, residual_, 'none', self.alpha, beta)
|
||||
'''
|
||||
=======================================
|
||||
End of custom MLU Hijack
|
||||
=======================================
|
||||
'''
|
||||
# bias if existed need to add after second matmul according to the original design of vllm
|
||||
'''
|
||||
=============================
|
||||
Modify by custom vllm_mlu
|
||||
=============================
|
||||
@brief: when preload_size is set, call GroupCoordinator.all_reduce() directly and
|
||||
use async_op to set all_reduce paralleled with preload
|
||||
'''
|
||||
if self.reduce_results and self.tp_size > 1 and not is_parallel_enable:
|
||||
if hasattr(self, 'preload_size') and self.preload_size > 0 and not self.is_prompt:
|
||||
handle = get_tp_group().all_reduce(out_, async_op=True)
|
||||
_MB = 1 << 20
|
||||
mlu_ops.preload(self.preloaded_weights[0].data, self.preload_size * _MB)
|
||||
preloaded_weights_size = self.preloaded_weights[0].numel() * self.preloaded_weights[0].element_size()
|
||||
if preloaded_weights_size < (self.preload_size * _MB) and len(self.preloaded_weights) > 1:
|
||||
mlu_ops.preload(self.preloaded_weights[1].data, (self.preload_size * _MB) - preloaded_weights_size)
|
||||
handle.wait()
|
||||
out = out_
|
||||
else:
|
||||
out = tensor_model_parallel_all_reduce(out_)
|
||||
else:
|
||||
out = out_
|
||||
'''
|
||||
=========================
|
||||
End of custom MLU Hijack
|
||||
=========================
|
||||
'''
|
||||
# do the bias add if needed
|
||||
if not self.skip_bias_add:
|
||||
out = out + down_proj.bias if down_proj.bias is not None else out
|
||||
else:
|
||||
return out, down_proj.bias
|
||||
else:
|
||||
fc1, bias = up_proj(hidden_states)
|
||||
if bias is not None:
|
||||
fc1 += bias
|
||||
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
|
||||
out, bias = down_proj(fc1, residual=residual_)
|
||||
if self.skip_bias_add:
|
||||
return out, bias
|
||||
return out
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(FeedForward,
|
||||
FeedForward.forward,
|
||||
vllm_mlu__model_executor__layers__feed_forward__FeedForward__forward)
|
||||
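The preload branch above overlaps the tensor-parallel all_reduce with prefetching the next weights: launch the collective with async_op=True, issue the preload, then wait before using the result. A minimal sketch of that overlap pattern with stock torch.distributed (mlu_ops.preload is replaced by a generic callback to keep the sketch self-contained):

import torch.distributed as dist

def all_reduce_with_overlap(tensor, prefetch_fn):
    handle = dist.all_reduce(tensor, async_op=True)  # kick off the collective without blocking
    prefetch_fn()                                    # e.g. prefetch/touch the next layer's weights
    handle.wait()                                    # tensor is only valid after wait()
    return tensor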
@@ -0,0 +1,116 @@
|
||||
from typing import Optional
|
||||
import torch
|
||||
from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_group
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, split_tensor_along_last_dim
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod, RowParallelLinear
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm_mlu._mlu_utils import *
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__UnquantizedLinearMethod__apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
beta = 1.0 if residual is not None else 0.0
|
||||
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
|
||||
'''
|
||||
=====================================================
|
||||
Modify by custom vllm_mlu
|
||||
=====================================================
|
||||
@brief: call parallel op if parallel_num is set
|
||||
'''
|
||||
if hasattr(self, 'parallel_num') and get_is_prompt():
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
pg = get_tensor_model_parallel_group().device_group
|
||||
cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
|
||||
return mlu_ops.matmul_allreduce(cncl_comm, x.view(-1, x.shape[-1]), layer.weight,
|
||||
bias, residual, 1.0, beta, self.parallel_num).view(res_shape)
|
||||
return mlu_ops.matmul(x.view(-1, x.shape[-1]), layer.weight, bias, residual, 'none', 1.0, beta).view(res_shape)
|
||||
'''
|
||||
=====================================================
|
||||
End of custom MLU Hijack
|
||||
=====================================================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__RowParallelLinear__forward(
    self,
    input_,
    residual: Optional[torch.Tensor] = None
):
    if self.input_is_parallel:
        input_parallel = input_
    else:
        tp_rank = get_tensor_model_parallel_rank()
        splitted_input = split_tensor_along_last_dim(
            input_, num_partitions=self.tp_size)
        input_parallel = splitted_input[tp_rank].contiguous()

    # Matrix multiply.
    assert self.quant_method is not None
    # Only fuse bias add into GEMM for rank 0 (this ensures that
    # bias will not get added more than once in TP>1 case)
    bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
    residual_ = None if self.tp_rank > 0 else residual
    '''
    =====================================================
    Modify by custom vllm_mlu
    =====================================================
    @brief: abandon the original reduce if parallel_num is set
    '''
    is_parallel_enable = hasattr(self.quant_method, 'parallel_num') and get_is_prompt()
    '''
    =====================================================
    End of custom MLU Hijack
    =====================================================
    '''
    output_parallel = self.quant_method.apply(self,
                                              input_parallel,
                                              bias=bias_,
                                              residual=residual_)
    '''
    =============================
    Modify by custom vllm_mlu
    =============================
    @brief: when preload_size is set, call GroupCoordinator.all_reduce() directly and
            use async_op so the all_reduce overlaps with the weight preload
    '''
    if self.reduce_results and self.tp_size > 1 and not is_parallel_enable:
        if hasattr(self, 'preload_size') and self.preload_size > 0 and not self.is_prompt:
            handle = get_tp_group().all_reduce(output_parallel, async_op=True)
            _MB = 1 << 20
            mlu_ops.preload(self.preloaded_weights[0].data, self.preload_size * _MB)
            preloaded_weights_size = self.preloaded_weights[0].numel() * self.preloaded_weights[0].element_size()
            if preloaded_weights_size < (self.preload_size * _MB) and len(self.preloaded_weights) > 1:
                mlu_ops.preload(self.preloaded_weights[1].data, (self.preload_size * _MB) - preloaded_weights_size)
            handle.wait()
            output = output_parallel
        else:
            output = tensor_model_parallel_all_reduce(output_parallel)
    else:
        output = output_parallel
    '''
    =========================
    End of custom MLU Hijack
    =========================
    '''
    output_bias = self.bias if self.skip_bias_add else None

    return output, output_bias


MluHijackObject.undo_hijack(UnquantizedLinearMethod,
                            UnquantizedLinearMethod.apply)
MluHijackObject.apply_hijack(UnquantizedLinearMethod,
                             UnquantizedLinearMethod.apply,
                             vllm__model_executor__layers__linear__UnquantizedLinearMethod__apply)
MluHijackObject.undo_hijack(RowParallelLinear,
                            RowParallelLinear.forward)
MluHijackObject.apply_hijack(RowParallelLinear,
                             RowParallelLinear.forward,
                             vllm__model_executor__layers__linear__RowParallelLinear__forward)
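The registrations above follow a simple monkey-patch pattern: undo any existing override, then rebind the upstream method to the vllm_mlu replacement. The MluHijackObject implementation itself is not shown in this diff, so the following is only a minimal sketch of the assumed mechanism (the class name and storage are hypothetical):

# Minimal sketch of the assumed hijack mechanism; not the real MluHijackObject.
class _HijackRegistrySketch:
    _originals = {}  # (class, method name) -> original function

    @classmethod
    def apply_hijack(cls, target_cls, org_method, new_method):
        key = (target_cls, org_method.__name__)
        # Remember the first original so it can be restored later.
        cls._originals.setdefault(key, getattr(target_cls, org_method.__name__))
        setattr(target_cls, org_method.__name__, new_method)

    @classmethod
    def undo_hijack(cls, target_cls, org_method):
        key = (target_cls, org_method.__name__)
        if key in cls._originals:
            setattr(target_cls, org_method.__name__, cls._originals.pop(key))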
@@ -0,0 +1 @@
from . import loader
@@ -0,0 +1,143 @@
import os
import torch
from torch import nn
from typing import Optional
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.config import VllmConfig, ModelConfig, ParallelConfig
from vllm_mlu._mlu_utils import *
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp


logger = init_logger(__name__)


def get_parallel_num(
    model_config: ModelConfig,
    parallel_config: ParallelConfig
):
    attention_parallel_num = os.environ.get(ATTN_PARALLEL_NUM)
    ffn_parallel_num = os.environ.get(FFN_PARALLEL_NUM)
    if attention_parallel_num and attention_parallel_num.isdecimal():
        attention_parallel_num = int(attention_parallel_num)
    else:
        attention_parallel_num = 0
    if ffn_parallel_num and ffn_parallel_num.isdecimal():
        ffn_parallel_num = int(ffn_parallel_num)
    else:
        ffn_parallel_num = 0

    if parallel_config.tensor_parallel_size == 1:
        raise ValueError("Can not use context_comm_cmpt_parallel when tp num is 1.")
    if attention_parallel_num <= 0 and ffn_parallel_num <= 0:
        raise ValueError("At least one of attention_parallel_num and ffn_parallel_num must be a positive integer.")

    hidden_size = model_config.get_hidden_size()
    ffn_parallel_num = max(ffn_parallel_num, 1)
    if hidden_size % ffn_parallel_num != 0:
        raise ValueError(f"Hidden_size: {hidden_size} must be divisible by ffn_parallel_num: {ffn_parallel_num}")

    return attention_parallel_num, ffn_parallel_num


def get_attr_by_path(obj, path):
    # Split the path by dots to get individual attributes.
    attributes = path.split('.')
    # Iterate through the attributes to access nested members.
    for attr in attributes:
        if not hasattr(obj, attr):
            return None
        obj = getattr(obj, attr)
    return obj


def set_custom_attributes(model, model_config, parallel_config):
    attn_row_parallel_layers = []
    attn_weights = []
    ffn_row_parallel_layers = []
    ffn_weights = []
    sparse_moe_mlp_layers = []
    for module in model.modules():
        if module.__class__.__name__ == "FeedForward":
            ffn_weight = []
            if hasattr(module, "up_proj_name"):
                up_proj_name = getattr(module, "up_proj_name")
                up_proj = getattr(module, up_proj_name)
                if hasattr(up_proj, "weight"):
                    ffn_weight.append(up_proj.weight)
            if hasattr(module, "down_proj_name"):
                down_proj_name = getattr(module, "down_proj_name")
                down_proj = getattr(module, down_proj_name)
                if hasattr(down_proj, "weight"):
                    ffn_weight.append(down_proj.weight)
            if ffn_weight:
                ffn_weights.append(ffn_weight)
                ffn_row_parallel_layers.append(module)
        for child_module in module.children():
            if child_module.__class__.__name__ == "Attention":
                for sibling_module in module.children():
                    if sibling_module.__class__.__name__ == "QKVParallelLinear":
                        if hasattr(sibling_module, "weight"):
                            weight = getattr(sibling_module, "weight")
                            attn_weights.append([weight])
                    if sibling_module.__class__.__name__ == "RowParallelLinear":
                        attn_row_parallel_layers.append(sibling_module)
        if module.__class__.__name__ == "SparseMoeMlp" or issubclass(module.__class__, SparseMoeMlp):
            sparse_moe_mlp_layers.append(module)

    if VLLM_PRELOAD_SIZE > 0:
        if (len(attn_row_parallel_layers)
                == len(attn_weights)
                == len(ffn_row_parallel_layers)
                == len(ffn_weights)) and \
                len(attn_row_parallel_layers) != 0:

            for i in range(len(attn_row_parallel_layers)):
                attn_row_parallel_layers[i].preloaded_weights = ffn_weights[i]
                attn_row_parallel_layers[i].preload_size = VLLM_PRELOAD_SIZE
                if i < len(attn_row_parallel_layers) - 1:
                    ffn_row_parallel_layers[i].preloaded_weights = attn_weights[i + 1]
                    ffn_row_parallel_layers[i].preload_size = VLLM_PRELOAD_SIZE
        else:
            logger.warning("%s does not support preload weight!", model.__class__.__name__)

    # Context compute-communication parallel.
    if check_context_comm_cmpt_parallel():
        attention_parallel_num, ffn_parallel_num = get_parallel_num(model_config, parallel_config)
        for o_proj in attn_row_parallel_layers:
            setattr(o_proj.quant_method, 'parallel_num', attention_parallel_num)

        if len(sparse_moe_mlp_layers) != 0:
            for sparse_moe_mlp in sparse_moe_mlp_layers:
                setattr(sparse_moe_mlp, 'parallel_num', ffn_parallel_num)
        else:
            for ffn in ffn_row_parallel_layers:
                setattr(ffn, 'parallel_num', ffn_parallel_num)


vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model__org = DefaultModelLoader.load_model


def vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model(
        self, vllm_config: VllmConfig) -> nn.Module:
    model = vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model__org(
        self, vllm_config=vllm_config)
    '''
    =============================
    Modify by custom vllm_mlu
    =============================
    @brief: according to the layer names in the model, set custom optimization attributes.
    '''
    set_custom_attributes(model, vllm_config.model_config, vllm_config.parallel_config)
    '''
    =========================
    End of custom MLU Hijack
    =========================
    '''
    return model


MluHijackObject.apply_hijack(DefaultModelLoader,
                             DefaultModelLoader.load_model,
                             vllm__model_executor__model_loader__loader__DefaultModelLoader__load_model)
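As a usage note for the get_attr_by_path helper above: it walks a dotted attribute path and returns None as soon as any segment is missing. A small self-contained illustration (the toy objects are hypothetical):

# Illustrative only: a toy attribute tree; real paths depend on the model architecture.
class _Node:
    pass

root = _Node()
root.layers = _Node()
root.layers.o_proj = _Node()
root.layers.o_proj.weight = "w"

assert get_attr_by_path(root, "layers.o_proj.weight") == "w"
assert get_attr_by_path(root, "layers.missing.weight") is None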
@@ -0,0 +1,17 @@
### Overview

This hijack code implements communication-computation parallelism for the vLLM context (prefill) phase. When enabled, it can improve the Context Latency metric for certain data sizes and split counts. It is currently an optional feature and is disabled by default.

### How to enable

- Set the environment variables ATTN_PARALLEL_NUM and FFN_PARALLEL_NUM to positive integers to control the number of communication-computation splits for the attention and ffn parts, respectively. The two variables are independent and can be enabled together. For example, export ATTN_PARALLEL_NUM=2 FFN_PARALLEL_NUM=4 enables parallelism for both parts, splitting the attention data into 2 chunks and the ffn data into 4 chunks (see the launch sketch after this README).

- tensor_parallel_size must be greater than 1.

- When enabling communication-computation parallelism for the ffn part, hidden_size must be divisible by FFN_PARALLEL_NUM.

### Notes

- When this feature is enabled, operator limitations mean that Mixtral-series models and Qwen2-series models (including Qwen1.5 and Qwen2.5) under smoothquant quantization only support batch_size = 1; in that case the operator uses a default split count of 4 and ATTN_PARALLEL_NUM has no effect.

- Under smoothquant quantization, the vllm_mlu ffn part does not call the tmo matmul operator, so communication-computation fusion does not take effect for that part.
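A minimal launch sketch for the settings above, assuming the standard vLLM offline-inference entry point; the model path is a placeholder, and the variables can equally be exported in the shell before starting a server:

import os

# Split the context-phase attention into 2 chunks and the ffn into 4 chunks.
os.environ["ATTN_PARALLEL_NUM"] = "2"
os.environ["FFN_PARALLEL_NUM"] = "4"

from vllm import LLM, SamplingParams

# tensor_parallel_size must be greater than 1 for this feature.
llm = LLM(model="/path/to/model", tensor_parallel_size=2)
print(llm.generate(["Hello"], SamplingParams(max_tokens=16))[0].outputs[0].text)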
@@ -0,0 +1 @@
from . import model_executor
@@ -0,0 +1,3 @@
from . import custom_model
from . import layers
from . import models
@@ -0,0 +1 @@
from . import custom
@@ -0,0 +1,62 @@
import torch
import torch.nn.functional as F
from typing import Optional
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import tensor_model_parallel_all_reduce, get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm_mlu.model_executor.custom_model.custom import CustomMoeBlock


def vllm__module_executor__custom_model__CustomMoeBlock__forward(
    self,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
    num_tokens, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    shared_output = None
    if self.shared_expert is not None:
        shared_output = self.shared_expert(hidden_states)
        if self.shared_expert_gate is not None:
            shared_output = F.sigmoid(
                self.shared_expert_gate(hidden_states)) * shared_output

    # router_logits: (num_tokens, n_experts)
    router_logits, _ = self.gate(hidden_states)
    residual_ = None if self.rank > 0 else residual
    '''
    =====================================================
    Modify by Context Communication Computation Parallel
    =====================================================
    @brief: call fused_moe
    '''
    params = [hidden_states, router_logits, self.w1, self.w2, None, None,
              residual_, self.input_smooth, self.act_smooth, self.w1_scale, self.w2_scale,
              self.top_k, self.config.norm_topk_prob, self.config.is_gated, self.config.hidden_act, 0]
    if hasattr(self, 'parallel_num') and get_is_prompt():
        rank = get_tensor_model_parallel_rank()
        pg = get_tensor_model_parallel_group().device_group
        cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
        params.extend([self.parallel_num, cncl_comm])
    final_hidden_states = mlu_ops.fused_moe(*params)
    '''
    =====================================================
    End of Context Communication Computation Parallel
    =====================================================
    '''

    if shared_output is not None:
        final_hidden_states = final_hidden_states + shared_output

    reduce_results = not self.config.use_parallel_residual
    if reduce_results:
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

    return final_hidden_states.view(num_tokens, hidden_dim)


MluHijackObject.apply_hijack(CustomMoeBlock,
                             CustomMoeBlock.forward,
                             vllm__module_executor__custom_model__CustomMoeBlock__forward)
@@ -0,0 +1,2 @@
from . import quantization
from . import sparse_moe_mlp
@@ -0,0 +1 @@
from . import smoothquant
@@ -0,0 +1,51 @@
import torch
from typing import Optional
from vllm import _mlu_ops as mlu_ops
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm_mlu._mlu_utils import get_is_prompt
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
from vllm_mlu.mlu_hijack_utils import MluHijackObject


def vllm_mlu__model_executor__layers__quantization__smoothquant__SmoothQuantLinearMethod__apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
    quant_input = None
    input_scale = None
    if self.quant_config.input_quant_method == "per_token":
        quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer.smooth, None)
    elif self.quant_config.input_quant_method == "per_tensor":
        quant_input = x if self.skip_quant_input else mlu_ops.quantize(x, layer.scale_to_int, None)

    '''
    =====================================================
    Modify by Context Communication Computation Parallel
    =====================================================
    @brief: call parallel op
    '''
    if hasattr(self, 'parallel_num') and get_is_prompt():
        rank = get_tensor_model_parallel_rank()
        pg = get_tensor_model_parallel_group().device_group
        cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
        params = [cncl_comm, quant_input, input_scale, layer.qweight, layer.per_channel_scale,
                  self.compute_dtype, bias, residual, 1.0, 1.0, self.parallel_num]
        out = mlu_ops.smooth_quant_matmul_allreduce(*params)
    else:
        out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer.qweight,
                                          layer.per_channel_scale, self.compute_dtype, bias, residual)
    '''
    =====================================================
    End of Context Communication Computation Parallel
    =====================================================
    '''
    return out


MluHijackObject.apply_hijack(SmoothQuantLinearMethod,
                             SmoothQuantLinearMethod.apply,
                             vllm_mlu__model_executor__layers__quantization__smoothquant__SmoothQuantLinearMethod__apply)
@@ -0,0 +1,89 @@
"""Inference-only MOE model."""
import torch
from torch import nn
from typing import Optional
from vllm.distributed import tensor_model_parallel_all_reduce, get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp


def vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward(
    self,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
    orig_hidden_states_shape = hidden_states.shape
    hidden_states = hidden_states.view(-1, self.hidden_size)
    # expert_logits: [num_tokens, self.num_experts_per_rank]
    expert_logits, _ = self.gate(hidden_states)
    final_hidden_states = self.forward_experts(hidden_states, expert_logits, residual)

    '''
    =====================================================
    Modify by Context Communication Computation Parallel
    =====================================================
    @brief: disable the reduce if the parallel op is used
    '''
    is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
    if self.tp_size > 1 and not is_parallel_enable:
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)
    '''
    =====================================================
    End of Context Communication Computation Parallel
    =====================================================
    '''

    output = final_hidden_states.view(orig_hidden_states_shape)
    return output


def vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward_experts(
    self,
    hidden_states,
    expert_logits,
    residual: Optional[torch.Tensor] = None
):
    residual_ = None if self.tp_rank > 0 else residual
    if self.is_use_fused_moe:
        self.pack_params()
        '''
        =====================================================
        Modify by Context Communication Computation Parallel
        =====================================================
        @brief: call the fused_moe all_reduce
        '''
        is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
        if is_parallel_enable:
            residual_ = residual
        params = [hidden_states, expert_logits, self.w13, self.w2, self.b13, self.b2,
                  residual_, self.a13_scale, self.a2_scale, self.w13_scale, self.w2_scale,
                  self.top_k, self.renormalize, self.is_gated, self.hidden_act, self.start_expert_id]
        if is_parallel_enable:
            rank = get_tensor_model_parallel_rank()
            pg = get_tensor_model_parallel_group().device_group
            cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
            params.extend([self.parallel_num, cncl_comm])
        final_hidden_states = mlu_ops.fused_moe(*params)
        '''
        =====================================================
        End of Context Communication Computation Parallel
        =====================================================
        '''
    else:
        final_hidden_states = self.forward_experts_nofused(hidden_states, expert_logits)
        if residual_ is not None:
            final_hidden_states = final_hidden_states + residual_

    return final_hidden_states


MluHijackObject.apply_hijack(SparseMoeMlp,
                             SparseMoeMlp.forward,
                             vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward)
MluHijackObject.apply_hijack(SparseMoeMlp,
                             SparseMoeMlp.forward_experts,
                             vllm_mlu__model_executor__layers__sparse_moe_mlp__SparseMoeMlp__forward_experts)
@@ -0,0 +1,3 @@
from . import mixtral_quant
from . import qwen2
from . import qwen2_moe
@@ -0,0 +1,299 @@
import torch
from typing import List, Optional
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm.model_executor.models.mixtral_quant import MixtralAttention
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
from vllm.attention.backends.abstract import (AttentionMetadata,
                                              AttentionType)
from vllm.attention.backends.utils import get_num_prefill_decode_query_kv_tokens
from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.attention.backends.mlu_attn import (MLUFlashAttentionMetadata,
                                              _get_query_key_seq_metadata,
                                              _get_causal_option)


def vllm__model_executor__models__mixtral__MixtralAttention__forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
    residual: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: pack q & k to fit tmo.apply_rotary
    '''
    qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
    self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    '''
    =====================================================
    Modify by Context Communication Computation Parallel
    =====================================================
    @brief: call flash_attn_sq_mm_allreduce to finish the forward
    '''
    if (attn_metadata.prefill_metadata) and \
            (kv_cache[0].numel() > 0) and \
            (hasattr(self.o_proj, 'quant_method')) and \
            (isinstance(self.o_proj.quant_method, SmoothQuantLinearMethod)) and \
            (self.o_proj.quant_method.quant_config.input_quant_method == "per_token"):
        rank = get_tensor_model_parallel_rank()
        pg = get_tensor_model_parallel_group().device_group
        cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)

        return torch.ops.vllm.context_attn_comm_cmpt_parallel_flash_attention_v2(
            q, k, v,
            self.num_heads, self.head_dim, self.num_kv_heads,
            kv_cache, self.attn.impl.kv_cache_dtype,
            1.0, 1.0, self.scaling,
            cncl_comm,
            self.o_proj.smooth, self.o_proj.qweight,
            self.o_proj.per_channel_scale.to(torch.float),
            self.o_proj.quant_method.parallel_num,
            residual, self.attn.impl.sliding_window, self.attn.impl.alibi_slopes
        )
    '''
    =====================================================
    End of Context Communication Computation Parallel
    =====================================================
    '''

    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add residual
    '''
    output, _ = self.o_proj(attn_output, residual)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return output


MluHijackObject.apply_hijack(MixtralAttention,
                             MixtralAttention.forward,
                             vllm__model_executor__models__mixtral__MixtralAttention__forward)


def context_attn_comm_cmpt_parallel_flash_attention_v2(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: List[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    cncl_comm: int,
    smooth: torch.Tensor,
    qweight: torch.Tensor,
    per_channel_scale: torch.Tensor,
    parallel_num: int,
    residual: Optional[torch.Tensor] = None,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    current_metadata = get_forward_context()
    assert current_metadata is not None
    assert isinstance(current_metadata, MLUFlashAttentionMetadata)
    attn_metadata: MLUFlashAttentionMetadata = current_metadata

    # Reshape the query, key, and value tensors.
    query = query.view(-1, num_heads, head_size)
    if (key is not None) and (value is not None):
        key = key.view(-1, num_kv_heads, head_size)
        value = value.view(-1, num_kv_heads, head_size)

    kv_cache_, kv_cache_scale_ = kv_cache
    key_cache = kv_cache_[0]
    value_cache = kv_cache_[1]
    key_cache_scale, value_cache_scale = None, None
    if kv_cache_scale_.numel() > 0:
        key_cache_scale = kv_cache_scale_[0]
        value_cache_scale = kv_cache_scale_[1]

    # If not specified in the self.attn.forward params, use the default DECODER.
    attn_type = AttentionType.DECODER

    # We skip updating the KV cache under two conditions:
    # a. When the Attention Type is ENCODER. In this phase, we compute
    #    only the encoder attention without updating the cache.
    # b. When both Key and Value are None. This occurs during
    #    cross-attention computation in the decoding phase, where the KV
    #    cache is already populated with the cross-attention tensor.
    #    Thus, we skip cache updates during this time.
    if (attn_type != AttentionType.ENCODER) and (key is not None) and (
            value is not None):
        if attn_type == AttentionType.ENCODER_DECODER:
            # Update cross-attention KV cache (prefill-only)
            updated_slot_mapping = attn_metadata.cross_slot_mapping
        else:
            # Update self-attention KV cache (prefill/decode)
            updated_slot_mapping = attn_metadata.slot_mapping

        # Reshape the input keys and values and store them in the cache.
        # If kv_cache is not provided, the new key and value tensors are
        # not cached. This happens during the initial memory profiling run.
        if USE_PAGED:
            if kv_cache_dtype == 'int8':
                mlu_ops.quant_to_paged_cache(key,
                                             value,
                                             key_cache,
                                             value_cache,
                                             key_cache_scale,
                                             value_cache_scale,
                                             attn_metadata.slot_mapping.flatten())
            else:
                mlu_ops.reshape_paged_cache(key,
                                            value,
                                            key_cache,
                                            value_cache,
                                            attn_metadata.slot_mapping.flatten())
        else:
            # FIXME: After TMO-1496 is completed, remove this code.
            if key.stride() != value.stride():
                key = key.contiguous()
                value = value.contiguous()
            if kv_cache_dtype == 'int8':
                mlu_ops.quant_to_linear_cache(key,
                                              value,
                                              key_cache,
                                              value_cache,
                                              key_cache_scale,
                                              value_cache_scale,
                                              attn_metadata.cu_seq_lens,
                                              attn_metadata.max_seq_len,
                                              True,  # packed
                                              None,  # context_seq_offset
                                              attn_metadata.batch_ids,
                                              attn_metadata.slot_mapping_unpaged)
            else:
                mlu_ops.reshape_linear_cache(key,
                                             value,
                                             key_cache,
                                             value_cache,
                                             attn_metadata.cu_seq_lens,
                                             attn_metadata.max_seq_len,
                                             True,  # packed
                                             None,  # context_seq_offset
                                             attn_metadata.batch_ids,
                                             attn_metadata.slot_mapping_unpaged)

    (num_prefill_query_tokens, num_prefill_kv_tokens,
     num_decode_query_tokens) = \
        get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
    decode_query = query[num_prefill_query_tokens:]
    # QKV for prefill.
    query = query[:num_prefill_query_tokens]
    assert query.shape[0] == num_prefill_query_tokens
    assert decode_query.shape[0] == num_decode_query_tokens

    alibi_slopes = None if alibi_slopes is None else \
        alibi_slopes.repeat(attn_metadata.num_prefills, 1)
    prefill_meta = attn_metadata.prefill_metadata
    # Prompt run.
    if (prefill_meta.block_tables is None
            or prefill_meta.block_tables.numel() == 0):
        # Normal attention.
        # When block_tables are not filled, it means q and k are the
        # prompt, and they have the same length.
        q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
            _get_query_key_seq_metadata(prefill_meta, True, attn_type)

        key = key[:num_prefill_kv_tokens]
        value = value[:num_prefill_kv_tokens]

        output = mlu_ops.flash_attn_sq_mm_allreduce(cncl_comm,
                                                    query, key, value,
                                                    q_seq_start_loc, k_seq_start_loc,
                                                    alibi_slopes, None,
                                                    smooth, qweight,
                                                    per_channel_scale, None,
                                                    q_seq_len, k_seq_len,
                                                    softmax_scale, _get_causal_option(attn_type),
                                                    -1 if window_size is None else window_size[0],
                                                    -1 if window_size is None else window_size[1],
                                                    torch.float, parallel_num)
    else:
        # Prefix-enabled attention.
        assert attn_type == AttentionType.DECODER, (
            "Only decoder-only models support prefix caching")
        assert prefill_meta.seq_lens is not None
        max_seq_len = max(prefill_meta.seq_lens)
        output = mlu_ops.flash_attn_sq_mm_allreduce(cncl_comm,
                                                    query, key_cache, value_cache,
                                                    prefill_meta.query_start_loc, prefill_meta.seq_start_loc,
                                                    alibi_slopes, None,
                                                    smooth, qweight,
                                                    per_channel_scale, None,
                                                    prefill_meta.max_query_len, max_seq_len,
                                                    softmax_scale, True,
                                                    -1 if window_size is None else window_size[0],
                                                    -1 if window_size is None else window_size[1],
                                                    torch.float, parallel_num)

    # Add residual.
    if residual is not None:
        output = output + residual
    return output


def context_attn_comm_cmpt_parallel_flash_attention_v2_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: List[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    cncl_comm: int,
    smooth: torch.Tensor,
    qweight: torch.Tensor,
    per_channel_scale: torch.Tensor,
    parallel_num: int,
    residual: Optional[torch.Tensor] = None,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    return torch.empty_like(query)


direct_register_custom_op(
    op_name="context_attn_comm_cmpt_parallel_flash_attention_v2",
    op_func=context_attn_comm_cmpt_parallel_flash_attention_v2,
    mutates_args=["kv_cache"],
    fake_impl=context_attn_comm_cmpt_parallel_flash_attention_v2_fake,
)
@@ -0,0 +1,90 @@
import torch
from typing import Optional
from vllm.attention import AttentionMetadata
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm.model_executor.models.qwen2 import Qwen2Attention
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod


def vllm__model_executor__models__qwen2__Qwen2Attention__forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
    residual: Optional[torch.Tensor] = None,
    smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert smooth_quant_scale is None
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: pack q & k to fit tmo.apply_rotary
    '''
    qk, _ = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
    self.rotary_emb(positions, qk.view(-1, self.num_heads + self.num_kv_heads, self.head_dim))
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    '''
    =====================================================
    Modify by Context Communication Computation Parallel
    =====================================================
    @brief: call flash_attn_sq_mm_allreduce to finish forward
    '''
    if (attn_metadata.prefill_metadata) and \
            (kv_cache[0].numel() > 0) and \
            (hasattr(self.o_proj, 'quant_method')) and \
            (isinstance(self.o_proj.quant_method, SmoothQuantLinearMethod)) and \
            (self.o_proj.quant_method.quant_config.input_quant_method == "per_token"):
        rank = get_tensor_model_parallel_rank()
        pg = get_tensor_model_parallel_group().device_group
        cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)

        return torch.ops.vllm.context_attn_comm_cmpt_parallel_flash_attention_v2(
            q, k, v,
            self.num_heads, self.head_dim, self.num_kv_heads,
            kv_cache, self.attn.impl.kv_cache_dtype,
            1.0, 1.0, self.scaling,
            cncl_comm,
            self.o_proj.smooth, self.o_proj.qweight,
            self.o_proj.per_channel_scale.to(torch.float),
            self.o_proj.quant_method.parallel_num,
            residual, self.attn.impl.sliding_window, self.attn.impl.alibi_slopes
        )
    '''
    =====================================================
    End of Context Communication Computation Parallel
    =====================================================
    '''
    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add residual
    '''
    output, _ = self.o_proj(attn_output, residual)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return output


MluHijackObject.undo_hijack(Qwen2Attention,
                            Qwen2Attention.forward)
MluHijackObject.apply_hijack(Qwen2Attention,
                             Qwen2Attention.forward,
                             vllm__model_executor__models__qwen2__Qwen2Attention__forward)
@@ -0,0 +1,58 @@
import torch
import torch.nn.functional as F
from typing import Optional
from vllm import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import get_tensor_model_parallel_rank, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_group
from vllm.attention import AttentionMetadata
from vllm_mlu.model_executor.models.qwen2_moe import Qwen2MoeSparseMoeBlock


def vllm_mlu__model_executor__models__qwen2_moe__Qwen2MoeSparseMoeBlock__forward(
    self,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
    num_tokens, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    shared_output = None
    if self.shared_expert is not None:
        shared_output = self.shared_expert(hidden_states)
        if self.shared_expert_gate is not None:
            gate_output = self.shared_expert_gate(hidden_states)
            shared_output = F.sigmoid(gate_output[0]) * shared_output

    # router_logits: (num_tokens, n_experts)
    router_logits, _ = self.gate(hidden_states)
    final_hidden_states = self.forward_experts(hidden_states, router_logits, residual)

    '''
    =====================================================
    Modify by Context Communication Computation Parallel
    =====================================================
    @brief: disable the reduce if the parallel op is used
    '''
    is_parallel_enable = hasattr(self, 'parallel_num') and get_is_prompt()
    if self.tp_size > 1:
        if is_parallel_enable:
            shared_output = tensor_model_parallel_all_reduce(shared_output)
            if shared_output is not None:
                final_hidden_states = final_hidden_states + shared_output
        else:
            if shared_output is not None:
                final_hidden_states = final_hidden_states + shared_output
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
    '''
    =====================================================
    End of Context Communication Computation Parallel
    =====================================================
    '''

    return final_hidden_states.view(num_tokens, hidden_dim)


MluHijackObject.apply_hijack(Qwen2MoeSparseMoeBlock,
                             Qwen2MoeSparseMoeBlock.forward,
                             vllm_mlu__model_executor__models__qwen2_moe__Qwen2MoeSparseMoeBlock__forward)
32
vllm-v0.6.2/vllm_mlu/vllm_mlu/mlu_custom/preload/README.md
Normal file
32
vllm-v0.6.2/vllm_mlu/vllm_mlu/mlu_custom/preload/README.md
Normal file
@@ -0,0 +1,32 @@
### Overview

This hijack code preloads the next layer's weights during vLLM's decode-phase communication, which reduces decode latency.

### Supported models

Only the following models are supported; quantized models and MOE models are not supported.
- Baichuan
- Bloom
- ChatGLM
- Falcon
- GPTNeoX
- Llama
- Qwen
- Qwen2

### Supported boards

The 300 series is not supported; other series are supported.

### Usage

- Set the environment variable export VLLM_PRELOAD_SIZE=<PRELOAD_SIZE>, where <PRELOAD_SIZE> is the amount of weight data to preload, in MB.
- Tuning reference: in a low-bandwidth environment, for the Llama-65B model, the performance gains for different batch_size and preload_size values are shown below (a usage sketch follows after the table).

| batch\preload |  8   |  16  |  24  |  32  |  48  |  64  |
|:-------------:|:----:|:----:|:----:|:----:|:----:|:----:|
|       1       | 4.9% | 10.0%| 9.5% | 6.7% |-2.4% | -7.1%|
|       8       | 3.2% | 6.3% | 8.9% | 11.2%| 6.0% | 1.8% |
|      16       | 2.3% | 5.1% | 7.5% | 9.2% | 8.3% | 4.3% |
|      24       | 2.3% | 4.8% | 7.4% | 9.1% | 9.5% | 6.0% |
|      32       | 2.1% | 4.3% | 7.0% | 8.7% | 10.1%| 8.1% |
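A minimal sketch of enabling the preload path, assuming the standard vLLM offline-inference entry point; the model path and sizes are placeholders, and preload_size should be picked from the table above:

import os

# Preload 32 MB of the next layer's weights during the decode all_reduce.
os.environ["VLLM_PRELOAD_SIZE"] = "32"

from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/Llama-65B", tensor_parallel_size=4)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)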
@@ -0,0 +1 @@
from . import distributed
Some files were not shown because too many files have changed in this diff.