add qwen3

Author: Chranos
Date:   2026-02-04 17:22:39 +08:00
parent d1c0f68ab4
commit 8511fe8530
1932 changed files with 300426 additions and 0 deletions


@@ -0,0 +1,5 @@
import vllm_mlu.worker.mlu_worker
import vllm_mlu.worker.mlu_model_runner
import vllm_mlu.worker.mlu_multi_step_model_runner
import vllm_mlu.worker.cache_engine
import vllm_mlu.worker.mlu_enc_dec_model_runner
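Importing these worker submodules is what activates the patches: each module below ends with module-level MluHijackObject.apply_hijack(...) calls that swap the listed vLLM methods for the MLU variants defined above them. The real helper lives in vllm_mlu.mlu_hijack_utils and is not part of this diff, so the following is only a hedged sketch of the monkey-patching pattern it appears to implement, not the actual implementation.

# Hypothetical illustration only -- the real MluHijackObject.apply_hijack may
# keep the original attribute for rollback, handle staticmethods, etc.
def apply_hijack_sketch(cls, target, replacement):
    # `target` is either the original attribute (its __name__ is reused)
    # or a plain string naming a brand-new method to attach to the class.
    name = target if isinstance(target, str) else target.__name__
    setattr(cls, name, replacement)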


@@ -0,0 +1,152 @@
"""CacheEngine class for managing the KV cache."""
from typing import List, Tuple, Optional
import torch
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import is_pin_memory_available, get_dtype_size
from vllm.worker.cache_engine import CacheEngine
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)


def vllm__worker__cache_engine__CacheEngine___allocate_kv_cache(
    self,
    num_blocks: int,
    device: str,
) -> List[List[torch.Tensor]]:
    """Allocates KV cache on the specified device."""
    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
        num_blocks, self.block_size, self.num_kv_heads, self.head_size)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add kv_cache_scale for int8 support
    '''
    kv_cache_scales_shape = self.attn_backend.get_kv_cache_scale_shape(
        num_blocks, self.block_size, self.num_kv_heads)
    pin_memory = is_pin_memory_available() if device == "cpu" else False
    kv_cache: List[List[torch.Tensor]] = []
    for _ in range(self.num_attention_layers):
        # null block in CpuGpuBlockAllocator requires at least that
        # block to be zeroed-out.
        # We zero-out everything for simplicity.
        kv_cache_ = torch.zeros(kv_cache_shape,
                                dtype=self.dtype,
                                pin_memory=pin_memory,
                                device=device)
        if self.dtype == torch.int8:
            kv_cache_scale_ = torch.zeros(kv_cache_scales_shape,
                                          dtype=torch.float32,
                                          pin_memory=pin_memory,
                                          device=device)
        else:
            kv_cache_scale_ = torch.tensor([],
                                           dtype=torch.float32,
                                           device=device)
        kv_cache.append([kv_cache_, kv_cache_scale_])
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return kv_cache


def vllm__worker__cache_engine__CacheEngine__swap_in(self, src_to_dst: torch.Tensor) -> None:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: swap kv_cache_scale for int8 support
    '''
    for i in range(self.num_attention_layers):
        # swap kv_cache
        self.attn_backend.swap_blocks(self.cpu_cache[i][0], self.gpu_cache[i][0],
                                      src_to_dst)
        if self.dtype == torch.int8:
            # swap kv_cache_scale
            self.attn_backend.swap_blocks(self.cpu_cache[i][1], self.gpu_cache[i][1],
                                          src_to_dst)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''


def vllm__worker__cache_engine__CacheEngine__swap_out(self, src_to_dst: torch.Tensor) -> None:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: swap kv_cache_scale for int8 support
    '''
    for i in range(self.num_attention_layers):
        # swap kv_cache
        self.attn_backend.swap_blocks(self.gpu_cache[i][0], self.cpu_cache[i][0],
                                      src_to_dst)
        if self.dtype == torch.int8:
            # swap kv_cache_scale
            self.attn_backend.swap_blocks(self.gpu_cache[i][1], self.cpu_cache[i][1],
                                          src_to_dst)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''


vllm__worker__cache_engine__CacheEngine__get_cache_block_size__org = CacheEngine.get_cache_block_size


@staticmethod
def vllm__worker__cache_engine__CacheEngine__get_cache_block_size(
    cache_config: CacheConfig,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
) -> int:
    kv_cache_total_size = vllm__worker__cache_engine__CacheEngine__get_cache_block_size__org(
        cache_config=cache_config,
        model_config=model_config,
        parallel_config=parallel_config,
    )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: compute kv_cache_scale total size
    '''
    num_heads = model_config.get_num_kv_heads(parallel_config)
    num_attention_layers = model_config.get_num_attention_layers(parallel_config)
    kv_cache_scale_total_size = 0
    if cache_config.cache_dtype == 'int8':
        key_cache_scale_block = cache_config.block_size * num_heads
        value_cache_scale_block = key_cache_scale_block
        scale_total = num_attention_layers * (key_cache_scale_block + value_cache_scale_block)
        dtype_size = get_dtype_size(torch.float32)
        kv_cache_scale_total_size = dtype_size * scale_total
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return kv_cache_total_size + kv_cache_scale_total_size


MluHijackObject.apply_hijack(CacheEngine,
                             CacheEngine._allocate_kv_cache,
                             vllm__worker__cache_engine__CacheEngine___allocate_kv_cache)
MluHijackObject.apply_hijack(CacheEngine,
                             CacheEngine.swap_in,
                             vllm__worker__cache_engine__CacheEngine__swap_in)
MluHijackObject.apply_hijack(CacheEngine,
                             CacheEngine.swap_out,
                             vllm__worker__cache_engine__CacheEngine__swap_out)
MluHijackObject.apply_hijack(CacheEngine,
                             CacheEngine.get_cache_block_size,
                             vllm__worker__cache_engine__CacheEngine__get_cache_block_size)
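For the int8 path above, every cache block gains two float32 scale entries per token slot and KV head (one for the key cache, one for the value cache) in each attention layer. A hedged worked example with made-up shapes, only to make that bookkeeping concrete:

# Hypothetical configuration, purely illustrative; real values come from the
# ModelConfig/CacheConfig objects passed to get_cache_block_size.
block_size = 16
num_kv_heads = 8
num_attention_layers = 32
float32_bytes = 4

scales_per_layer = 2 * (block_size * num_kv_heads)       # key + value scales
scale_bytes_per_block = num_attention_layers * scales_per_layer * float32_bytes
print(scale_bytes_per_block)  # 32768 extra bytes per block when cache_dtype == 'int8'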


@@ -0,0 +1,329 @@
import itertools
from typing import List, Optional, Tuple
import torch
import torch.distributed
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelInput
from vllm.worker.mlu_enc_dec_model_runner import MLUEncoderDecoderModelRunner
from vllm.worker.model_runner import _get_graph_batch_size
from vllm.utils import make_tensor_with_pad
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)


class MLUEncoderDecoderModelRunner_V2(MLUEncoderDecoderModelRunner):

    @torch.inference_mode()
    def profile_run(self) -> None:
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []

        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            logger.info("Starting profile run for multi-modal models.")

        batch_size = 0
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            decoder_dummy_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry,
                                          is_encoder_data=False)
            encoder_dummy_data \
                = self.input_registry.dummy_data_for_profiling(
                    self.model_config,
                    seq_len,
                    self.mm_registry,
                    is_encoder_data=True)

            # Having more tokens is over-conservative but otherwise fine
            assert len(
                decoder_dummy_data.seq_data.prompt_token_ids
            ) >= seq_len, (
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}"
            )

            assert decoder_dummy_data.multi_modal_data is None or \
                encoder_dummy_data.multi_modal_data is None, (
                    "Multi-modal data can't be provided in both encoder and decoder"
                )

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: decoder_dummy_data.seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                encoder_seq_data=encoder_dummy_data.seq_data,
                cross_block_table=None,
                multi_modal_data=decoder_dummy_data.multi_modal_data
                or encoder_dummy_data.multi_modal_data,
                multi_modal_placeholders=decoder_dummy_data.
                multi_modal_placeholders
                or encoder_dummy_data.multi_modal_placeholders)
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # Use an empty tensor instead of `None` to force Dynamo to pass it by
        # reference, rather than by specializing on the value `None`.
        # The `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: support kv cache int8
        '''
        kv_caches = []
        for _ in range(num_layers):
            kv_cache_ = torch.tensor([], dtype=torch.float32, device=self.device)
            kv_cache_scale_ = torch.tensor([], dtype=torch.float32, device=self.device)
            kv_caches.append([kv_cache_, kv_cache_scale_])
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()
        return

    def _prepare_encoder_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        model_input: EncoderDecoderModelInput,
    ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        """Helper method to prepare the encoder- and cross-attn-related
        model inputs based on a given sequence group. These additional inputs
        are used to augment an already-computed `EncoderDecoderModelInput`
        data structure which already has decoder-related model inputs
        populated.

        Sets the following attn_metadata fields:
        * `num_encoder_tokens`
        * `encoder_seq_lens`
        * `encoder_seq_lens_tensor`
        * `max_encoder_seq_len`
        * `cross_slot_mapping`
        * `cross_block_tables`

        Constructs a new model inputs data structure, based on
        (1) the existing fields in the `model_inputs` argument,
        and (2) the following additional fields which are
        computed (or in the case of `attn_metadata`, updated)
        by this function:
        * attn_metadata
        * encoder_input_tokens
        * encoder_input_positions

        Arguments:
        * seq_group_metadata_list: list of sequence groups for which to
          compute inputs
        * model_inputs: model inputs data structure with decoder-oriented
          fields already computed.

        Return:
        * Updated model inputs data structure
        """
        if len(seq_group_metadata_list) == 0:
            return (model_input.attn_metadata, None, None)

        # Since we are not supporting chunked prefill either the entire
        # batch is prefill or it is decode
        is_prompt = seq_group_metadata_list[0].is_prompt

        # Build encoder inputs
        encoder_seq_lens: List[int] = []
        if is_prompt:
            # Prefill phase.
            cross_block_tables = self._empty_int32_tensor().view(
                len(seq_group_metadata_list), -1)

            # Extract input tokens/positions, cross-attention slot-mapping,
            # & seq len from each sequence group metadata
            (
                encoder_input_tokens,
                encoder_input_positions,
                cross_slot_mapping,
            ) = (
                [],
                [],
                [],
            )
            for seq_group_metadata in seq_group_metadata_list:
                # Build seq lens
                seq_len = seq_group_metadata.encoder_seq_data.get_len()
                token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
                encoder_seq_lens.append(seq_len)

                # Build slot mapping
                is_profile_run = (seq_group_metadata.block_tables is None)
                if is_profile_run:
                    # During memory profiling, the block tables are not
                    # initialized yet. In this case, we just use a dummy
                    # slot mapping.
                    # In embeddings, the block tables are {seq_id: None}.
                    cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
                else:
                    for i in range(0, seq_len):
                        block_number = seq_group_metadata.cross_block_table[
                            i // self.block_size]
                        block_offset = i % self.block_size
                        slot = block_number * self.block_size + block_offset
                        cross_slot_mapping.append(slot)

                # Build encoder input tokens
                encoder_input_tokens.extend(token_ids)
                encoder_input_positions.extend(list(range(0, seq_len)))

            # Convert tokens/positions & cross-attention
            # slot-mapping to encoder input tensors
            encoder_input_tokens_tensor = self._list_to_long_tensor(
                encoder_input_tokens)
            encoder_input_positions_tensor = self._list_to_long_tensor(
                encoder_input_positions)
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: since the `slot_mapping` parameter in tmo.reshape_paged_cache
            only supports int32, change dtype to int32.
            '''
            cross_slot_mapping_tensor = self._list_to_int32_tensor(
                cross_slot_mapping)
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
        else:
            # Decode phase.
            encoder_input_tokens_tensor = self._empty_long_tensor()
            encoder_input_positions_tensor = self._empty_long_tensor()
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: since the `slot_mapping` parameter in tmo.reshape_paged_cache
            only supports int32, change dtype to int32.
            '''
            cross_slot_mapping_tensor = self._empty_int32_tensor()
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
            # Extract cross-attention block tables &
            # seq len from each sequence group metadata.
            # Cross-attention block tables are empty
            # during vLLM memory profiling.
            cross_block_tables = []
            for seq_group_metadata in seq_group_metadata_list:
                for _ in range(len(seq_group_metadata.seq_data)):
                    encoder_seq_lens.append(
                        seq_group_metadata.encoder_seq_data.get_len())
                    cross_block_table = seq_group_metadata.cross_block_table
                    cross_block_tables.append([] if (
                        cross_block_table is None) else cross_block_table)

            if (model_input.attn_metadata is not None
                    and model_input.attn_metadata.use_cuda_graph):
                # We will be using CUDA graph replay for this decode.
                max_len_of_block_table = self.get_max_block_per_batch()
                batch_size = len(encoder_seq_lens)
                graph_batch_size = _get_graph_batch_size(batch_size)
                assert graph_batch_size >= batch_size
                cuda_graph_pad_size = graph_batch_size - batch_size
                # extend the cross_block_tables and encoder_seq_lens to match
                # the graph_batch_size.
                cross_block_tables.extend([[]
                                           for _ in range(cuda_graph_pad_size)
                                           ])
                encoder_seq_lens.extend(
                    itertools.repeat(1, cuda_graph_pad_size))
            else:
                max_len_of_block_table = max(
                    len(block_table) for block_table in cross_block_tables)

            cross_block_tables = make_tensor_with_pad(
                cross_block_tables,
                max_len=max_len_of_block_table,
                pad=0,
                dtype=torch.int32,
                device=self.device,
            )

        # Compute encoder sequence lengths & encoder
        # sequence starting offset tensors
        max_encoder_seq_len = max(encoder_seq_lens, default=0)
        encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
        encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
                                            1,
                                            dtype=torch.int32,
                                            device=self.device)
        torch.cumsum(encoder_seq_lens_tensor,
                     dim=0,
                     dtype=encoder_seq_start_loc.dtype,
                     out=encoder_seq_start_loc[1:])

        # Update attention metadata with encoder-oriented attributes
        attn_metadata = model_input.attn_metadata
        assert attn_metadata is not None
        (
            attn_metadata.num_encoder_tokens,
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor,
            attn_metadata.max_encoder_seq_len,
            attn_metadata.encoder_seq_start_loc,
            attn_metadata.cross_slot_mapping,
            attn_metadata.cross_block_tables,
        ) = (
            sum(encoder_seq_lens),
            encoder_seq_lens,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
            encoder_seq_start_loc,
            cross_slot_mapping_tensor,
            cross_block_tables,
        )

        return (attn_metadata, encoder_input_tokens_tensor,
                encoder_input_positions_tensor)


MluHijackObject.apply_hijack(MLUEncoderDecoderModelRunner,
                             MLUEncoderDecoderModelRunner.profile_run,
                             MLUEncoderDecoderModelRunner_V2.profile_run)
MluHijackObject.apply_hijack(MLUEncoderDecoderModelRunner,
                             MLUEncoderDecoderModelRunner._prepare_encoder_model_input_tensors,
                             MLUEncoderDecoderModelRunner_V2._prepare_encoder_model_input_tensors)
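The prefill branch above flattens every encoder token into a slot index via block_number * block_size + block_offset, and stores the result as int32 because tmo.reshape_paged_cache only accepts int32 slot mappings. A small, purely illustrative example of that arithmetic (all values below are hypothetical):

# Hypothetical block table and sequence length, only to show the mapping.
block_size = 16
cross_block_table = [7, 3]          # made-up physical block numbers
seq_len = 20
cross_slot_mapping = [
    cross_block_table[i // block_size] * block_size + i % block_size
    for i in range(seq_len)
]
# tokens 0..15 land in slots 112..127 (block 7); tokens 16..19 in slots 48..51 (block 3)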

File diff suppressed because it is too large.


@@ -0,0 +1,28 @@
import torch
from typing import List
from vllm.worker.mlu_multi_step_model_runner import MLUMultiStepModelRunner
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
logger = init_logger(__name__)


def vllm__worker__mlu_multi_step_model_runner__MLUMultiStepModelRunner__capture_model_with_context(
    self,
    kv_caches: List[List[torch.Tensor]],
    num_gpu_blocks: int
) -> None:
    return self._base_model_runner.capture_model_with_context(kv_caches, num_gpu_blocks)


def vllm__worker__mlu_multi_step_model_runner__MLUMultiStepModelRunner__reset_capture_context(self):
    return self._base_model_runner.reset_capture_context()


MluHijackObject.apply_hijack(MLUMultiStepModelRunner,
                             "capture_model_with_context",
                             vllm__worker__mlu_multi_step_model_runner__MLUMultiStepModelRunner__capture_model_with_context)
MluHijackObject.apply_hijack(MLUMultiStepModelRunner,
                             "reset_capture_context",
                             vllm__worker__mlu_multi_step_model_runner__MLUMultiStepModelRunner__reset_capture_context)


@@ -0,0 +1,303 @@
import torch
import gc
import functools
from collections import defaultdict
from typing import Tuple
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm.model_executor.layers.vocab_parallel_embedding import (VocabParallelEmbedding,
                                                                 ParallelLMHead)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear)
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor import set_random_seed
from vllm.worker.mlu_worker import MLUWorker
from vllm.logger import init_logger
logger = init_logger(__name__)


def default_act_range_value():
    return {
        "x": None,
        "split": None,
        "is_linear": False,
        "is_qkv": False,
        "q_proj_size": 0,
        "num_kv_head_replicas": 1,
        "is_merge": False,
        "input_id": []
    }


class MLUWorker_V2(MLUWorker):

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.mlu.empty_cache()
        torch.mlu.reset_peak_memory_stats()

        free_memory_pre_profile, total_gpu_memory = torch.mlu.mem_get_info()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()
        torch.mlu.synchronize()

        self._assert_memory_footprint_increased_during_profiling()

        # Get the peak memory allocation recorded by torch
        peak_memory = torch.mlu.memory_stats()["allocated_bytes.all.peak"]

        # Check for any memory left around that may have been allocated on the
        # gpu outside of `torch`. NCCL operations, for example, can use a few
        # GB during a forward pass
        torch.mlu.empty_cache()
        torch_allocated_bytes = torch.mlu.memory_stats(
        )["allocated_bytes.all.current"]
        total_allocated_bytes = torch.mlu.mem_get_info(
        )[1] - torch.mlu.mem_get_info()[0]

        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations

        available_kv_cache_memory = (
            total_gpu_memory * self.cache_config.gpu_memory_utilization -
            peak_memory)

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        cache_block_size = self.get_cache_block_size_bytes()
        if cache_block_size == 0:
            num_gpu_blocks = 0
            num_cpu_blocks = 0
        else:
            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
            num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                 cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        logger.info(
            "Memory profiling results: total_gpu_memory=%.2fGiB"
            " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
            " memory_usage_post_profile=%.2fGiB"
            " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
            " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
            (total_gpu_memory - free_memory_pre_profile) / (1024**3),
            (peak_memory - non_torch_allocations) / (1024**3),
            total_allocated_bytes / (1024**3),
            non_torch_allocations / (1024**3),
            available_kv_cache_memory / (1024**3),
            self.cache_config.gpu_memory_utilization)

        # Final cleanup
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()

        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: record init memory usage
        '''
        # Record memory usage
        self.peak_memory = peak_memory
        self.block_memory = available_kv_cache_memory
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        return num_gpu_blocks, num_cpu_blocks

    def _warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: support context mlugraph
            '''
            if self.model_config.use_context_mlugraph():
                # Capture MLUGraph for both prefill and decode
                self.model_runner.capture_model_with_context(self.gpu_cache,
                                                             self.cache_config.num_gpu_blocks)
            else:
                # Capture MLUGraph for decode only
                self.model_runner.capture_model(self.gpu_cache,
                                                self.cache_config.num_gpu_blocks)
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_latency(self):
        start, end = self.model_runner.time_markers
        return start.elapsed_time(end)

    def get_memory_usage(self):
        return (self.peak_memory, self.block_memory,
                self.num_gpu_blocks, self.num_cpu_blocks)

    def recapture_model(
        self,
        context_batch_size_to_capture,
        context_seq_len_to_capture
    ) -> None:
        # Reset history capture context
        self.model_runner.reset_capture_context()
        # Re-capture context and decoder mlugraph
        self.model_runner.model_config.context_batch_size_to_capture = context_batch_size_to_capture
        self.model_runner.model_config.context_seq_len_to_capture = context_seq_len_to_capture
        self._warm_up_model()

    def stat_tensor(self, name, tensor, act_range, key, dim):
        logger.debug(f"name:{name}, key:{key}, dim:{dim}, tensor.shape:{tensor.shape}")
        hidden_dim = tensor.shape[-1]
        # TODO: torch.max has a bug which generates nan/inf, so load the tensor
        # to CPU to do torch.max, and convert back to MLU after the bug is fixed.
        # The pytorch jira: http://jira.cambricon.com/browse/PYTORCH-12199
        tensor = tensor.view(-1, hidden_dim).abs()
        comming_max = torch.max(tensor, dim=dim)[0].float()
        if act_range[name][key] is None:
            act_range[name][key] = comming_max
        else:
            act_range[name][key] = torch.max(act_range[name][key], comming_max)

    def stat_input_hook(self, m, x, y, name, act_range, is_linear, is_save_input_id):
        if isinstance(x, tuple):
            x = x[0]
        if isinstance(y, tuple):
            y = y[0]
        logger.debug(f"name:{name}, x.shape:{x.shape}, y.shape:{y.shape}, m.weight.shape:{m.weight.shape}")
        if is_linear:
            self.stat_tensor(name, x, act_range, "x", 0)
        if act_range[name]["is_qkv"] and is_save_input_id and ".0." in name:
            x_cpu = x.clone().to("cpu")
            act_range[name]["input_id"].append(x_cpu)

    def setup_smooth_hook(self, is_save_input_id: bool = False):
        model = self.model_runner.model
        self.act_range = defaultdict(default_act_range_value)
        self.hooks = []
        linear_class_list = (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
        other_class_list = (VocabParallelEmbedding, ParallelLMHead)
        class_list = linear_class_list + other_class_list
        row_class_list = (RowParallelLinear,)
        for name, m in model.named_modules():
            if isinstance(m, FeedForward):
                m.use_bt_ffn = False
            if isinstance(m, SparseMoeMlp):
                m.is_use_fused_moe = False
            if isinstance(m, class_list):
                is_linear = True if isinstance(m, linear_class_list) else False
                split_type = "row" if isinstance(m, row_class_list) else "col"
                self.act_range[name]["split"] = split_type
                self.act_range[name]["is_linear"] = is_linear
                if isinstance(m, QKVParallelLinear):
                    self.act_range[name]["is_qkv"] = True
                    self.act_range[name]["q_proj_size"] = m.num_heads * m.head_size
                    self.act_range[name]["num_kv_head_replicas"] = m.num_kv_head_replicas
                self.act_range[name]["is_merge"] = isinstance(m, MergedColumnParallelLinear)
                logger.info(f"rank:{self.rank}, add hook to {name}, is_linear:{is_linear}, split_type:{split_type}")
                self.hooks.append(m.register_forward_hook(functools.partial(self.stat_input_hook,
                                                                             name=name, act_range=self.act_range,
                                                                             is_linear=is_linear,
                                                                             is_save_input_id=is_save_input_id)))

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()

    def get_act_range(self):
        act_range = defaultdict(default_act_range_value)
        for layer_name, layer_range in self.act_range.items():
            for tensor_key, tensor_value in layer_range.items():
                if isinstance(tensor_value, torch.Tensor):
                    act_range[layer_name][tensor_key] = tensor_value.to("cpu")
                elif tensor_key == "input_id" and isinstance(tensor_value, list):
                    input_id_len = len(tensor_value)
                    for i in range(input_id_len):
                        if isinstance(tensor_value[i], torch.Tensor):
                            act_range[layer_name][tensor_key].append(tensor_value[i].to("cpu"))
                        else:
                            act_range[layer_name][tensor_key].append(tensor_value[i])
                else:
                    act_range[layer_name][tensor_key] = tensor_value
        return act_range

    @torch.no_grad()
    def get_named_parameters(self):
        name_parameters = {}
        for name, param in self.model_runner.model.named_parameters():
            name_parameters[name] = param.to("cpu")
        return name_parameters


MluHijackObject.apply_hijack(MLUWorker,
                             MLUWorker.determine_num_available_blocks,
                             MLUWorker_V2.determine_num_available_blocks)
MluHijackObject.apply_hijack(MLUWorker,
                             MLUWorker._warm_up_model,
                             MLUWorker_V2._warm_up_model)
MluHijackObject.apply_hijack(MLUWorker,
                             "get_latency",
                             MLUWorker_V2.get_latency)
MluHijackObject.apply_hijack(MLUWorker,
                             "get_memory_usage",
                             MLUWorker_V2.get_memory_usage)
MluHijackObject.apply_hijack(MLUWorker,
                             "recapture_model",
                             MLUWorker_V2.recapture_model)
MluHijackObject.apply_hijack(MLUWorker,
                             "stat_tensor",
                             MLUWorker_V2.stat_tensor)
MluHijackObject.apply_hijack(MLUWorker,
                             "stat_input_hook",
                             MLUWorker_V2.stat_input_hook)
MluHijackObject.apply_hijack(MLUWorker,
                             "setup_smooth_hook",
                             MLUWorker_V2.setup_smooth_hook)
MluHijackObject.apply_hijack(MLUWorker,
                             "remove_hooks",
                             MLUWorker_V2.remove_hooks)
MluHijackObject.apply_hijack(MLUWorker,
                             "get_act_range",
                             MLUWorker_V2.get_act_range)
MluHijackObject.apply_hijack(MLUWorker,
                             "get_named_parameters",
                             MLUWorker_V2.get_named_parameters)
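A hedged, purely numerical illustration of how determine_num_available_blocks turns the profiled peak memory into a GPU block count; every number below is made up, and the real values come from torch.mlu.mem_get_info, the profile run, and get_cache_block_size_bytes:

# Hypothetical figures, only to make the arithmetic concrete.
total_gpu_memory = 32 * 1024**3          # 32 GiB device
gpu_memory_utilization = 0.9
peak_memory = 10 * 1024**3               # assumed peak usage during profile_run
cache_block_size = 2 * 1024**2           # assumed bytes per KV-cache block

available_kv_cache_memory = (total_gpu_memory * gpu_memory_utilization
                             - peak_memory)
num_gpu_blocks = max(int(available_kv_cache_memory // cache_block_size), 0)
print(num_gpu_blocks)  # 9625 blocks in this made-up configuration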