init
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
235
vllm_vacc/vllm/model_executor/models/memory/allocator.py
Normal file
@@ -0,0 +1,235 @@
import torch
import torch.nn as nn
from typing import List


class VaccHugeMemoryAllocator(nn.Module):
    # self._active_bytes is the number of bytes actually in use in the tensor buffers;
    # use it to slice the source buffer.
    # self._src_buffer_array must not be freed: the source buffer is the maximum-size buffer.
    # self._block_bytes is the maximum size of a single buffer block.
    def __init__(self, blocks, dtype=torch.bfloat16, use_contiguous=False):
        super().__init__()
        self._total_blocks = blocks
        self._dtype = dtype
        self._enable = False
        self._src_buffer_array = None
        self._block_bytes = 0
        self._max_tokens = 0
        self._hiddens = 0
        self._active_bytes = 0
        self._use_contiguous_buffer = use_contiguous
        # Max tokens for the dynamically sized buffer; a dynamic buffer is
        # usually larger than a normal one.
        self._dynamic_max_tokens = 0
        self._dynamic_block_bytes = 0

    # Allocate the maximum-size buffers once; they are never freed.
    def init_buffers(self, max_tokens, hiddens):
        self._max_tokens = max_tokens
        self._hiddens = hiddens

        try:
            import torch_vacc
            self._block_bytes = self._max_tokens * self._hiddens \
                * self.get_dtype_bytes(self._dtype)
            self._all_bytes = self._block_bytes * self._total_blocks

            if self._use_contiguous_buffer:
                # Allocate one [N * 3]-sized byte buffer in a single call.
                self._src_buffer = torch.zeros(self._all_bytes,
                                               dtype=torch.uint8,
                                               device="vacc")
                tmp_buffer_array = self._src_buffer.view(self._total_blocks, -1)
                self._src_buffer_array = [tmp_buffer_array[i]
                                          for i in range(self._total_blocks)]
            else:
                # Allocate three separate [N]-sized byte buffers in one pass.
                self._src_buffer_array = [torch.zeros(self._block_bytes,
                                                      dtype=torch.uint8,
                                                      device="vacc")
                                          for i in range(self._total_blocks)]
            self._enable = True
        except Exception as e:
            print(f"vacc huge buffer alloc failed: {e}")

    # For networks that need a dynamic buffer; the dynamic buffer may be
    # somewhat larger than the normal max_tokens buffers.
    def init_buffers_with_dynamic(self, max_tokens, dynamic_tokens, hiddens, dynamic_buffers_mask: List):
        self._max_tokens = max_tokens
        self._dynamic_max_tokens = dynamic_tokens
        self._hiddens = hiddens

        dynamic_buffers_count = sum(dynamic_buffers_mask)
        normal_buffers_count = self._total_blocks - dynamic_buffers_count

        self._block_bytes = self._max_tokens * self._hiddens \
            * self.get_dtype_bytes(self._dtype)
        self._dynamic_block_bytes = self._dynamic_max_tokens * self._hiddens \
            * self.get_dtype_bytes(self._dtype)

        self._all_bytes = self._block_bytes * normal_buffers_count + \
            self._dynamic_block_bytes * dynamic_buffers_count

        # print("creating recyclable buffers: dynamic count ->", dynamic_buffers_count,
        #       " normal count ->", normal_buffers_count,
        #       " dynamic buffer size ->", self._dynamic_block_bytes,
        #       " normal size ->", self._block_bytes)

        try:
            assert self._use_contiguous_buffer is False, \
                "dynamic recyclable memory buffers currently support only separate (non-contiguous) buffers"

            self._src_buffer_array = []
            for i in range(self._total_blocks):
                if dynamic_buffers_mask[i]:
                    self._src_buffer_array.append(torch.zeros(self._dynamic_block_bytes,
                                                              dtype=torch.uint8,
                                                              device="vacc"))
                else:
                    # Same [N]-sized byte buffer as the normal (non-dynamic) case.
                    self._src_buffer_array.append(torch.zeros(self._block_bytes,
                                                              dtype=torch.uint8,
                                                              device="vacc"))
            self._enable = True
        except Exception as e:
            print("vacc huge buffer alloc failed.", e)

    # Slice per-tensor buffers out of the source buffers (48K * blocks).
    # These slices back the target tensors.
    # The tensor positions must be analysed per model; e.g. DeepSeek has 4 such buffers.
    # Allocate the buffers for the actual input token count whenever a new request arrives.
    # Note: mind the dtype; the underlying buffers are created as uint8.
    def alloc_memory_buffers(self, tokens, dtype=torch.bfloat16):
        if tokens > self._max_tokens:
            print("alloc memory buffer failed: tokens is larger than max_tokens.", self._max_tokens)
            return None

        self._active_bytes = tokens * self._hiddens * self.get_dtype_bytes(dtype)
        return [sub_array[:self._active_bytes]
                for sub_array in self._src_buffer_array]
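
    # Usage sketch (hypothetical sizes): for a 1024-token request with hiddens=7168,
    # each returned slice is 1024 * 7168 * 2 bytes of uint8 and must be reinterpreted:
    #   bufs = allocator.alloc_memory_buffers(1024, torch.bfloat16)
    #   hidden = bufs[0].view(torch.bfloat16).view(1024, 7168)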

    @property
    def memory_buffers(self):
        return self._src_buffer_array

    # Allocate 1/N buffers.
    # @param part: total number of regions to split the source buffer into
    # @param return_buffer_list: indices of the regions to return; if empty, return all of them
    # Splits the buffer into N parts and returns the requested parts.
    def alloc_1_div_N_buffers(self, part=2, return_buffer_list=None):
        if not hasattr(self, "_src_buffer"):
            print("the 1/N allocator needs a contiguous buffer")
            return None

        if return_buffer_list is None:
            return_buffer_list = []
        assert isinstance(return_buffer_list, list), "return_buffer_list must be a list"

        # Split the data into `part` regions, viewed as uint8.
        tmp_buffer_array = self._src_buffer.view(part, -1)
        # If return_buffer_list is empty, return all part buffers.
        if len(return_buffer_list) == 0:
            return [tmp_buffer_array[i] for i in range(part)]

        return [tmp_buffer_array[i] for i in return_buffer_list]
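
    # Sketch (requires use_contiguous=True at construction; hypothetical values):
    #   halves = allocator.alloc_1_div_N_buffers(2)        # both halves as uint8 views
    #   sixth = allocator.alloc_1_div_N_buffers(6, [3])    # only part 3 of six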


    def get_dtype_bytes(self, dtype):
        if isinstance(dtype, torch.dtype):
            if dtype in [torch.float16, torch.bfloat16, torch.half]:
                return 2
            elif dtype in [torch.float32, torch.float, torch.int32]:
                return 4
            elif dtype in [torch.float64, torch.double, torch.int64]:
                return 8
            elif dtype in [torch.int8, torch.uint8, torch.bool]:
                return 1
            else:
                return 1
        elif dtype == int:
            return 8
        elif dtype == float:
            return 8
        elif dtype == bool:
            return 1
        return 0

    @property
    def enable(self):
        return self._enable

    @property
    def blocks(self):
        return self._total_blocks

    @property
    def max_tokens(self):
        return self._max_tokens

    @property
    def hiddens(self):
        return self._hiddens

    @property
    def active_bytes(self):
        return self._active_bytes

    @property
    def block_bytes(self):
        return self._block_bytes

    @property
    def dynamic_block_bytes(self):
        return self._dynamic_block_bytes


class LLMMemoryRecycler:
    def __init__(self):
        self.count = 3
        self.embedding_output = None
        self.moe_shared_mlp_output = None
        self.mla_oproj_output = None
        # self.moe_expert_output = None

    def clear(self):
        self.embedding_output = None
        self.moe_shared_mlp_output = None
        self.mla_oproj_output = None
        # self.moe_expert_output = None

    @property
    def EMBEDDING_OUT_BUFFER(self):
        return self.embedding_output

    @property
    def MOE_SHARED_MLP_OUT_BUFFER(self):
        return self.moe_shared_mlp_output

    @property
    def MLA_OPROJ_OUT_BUFFER(self):
        return self.mla_oproj_output


def alloc_memory_recycler_llm(tokens,
                              alloctor: VaccHugeMemoryAllocator,
                              recycler: LLMMemoryRecycler,
                              dtype: torch.dtype = torch.bfloat16):
    if not alloctor.enable:
        print("memory allocator is not initialized.")
        return False

    recycler.clear()
    out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)

    if out_buffers is None:
        print("LLM memory recycler buffer allocation failed; disabling it.")
        return False

    if len(out_buffers) != recycler.count:
        print("memory recycler buffer count does not match the LLM.")
        return False

    recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
    recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
    recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
    # recycler.moe_expert_output = out_buffers[3].view(dtype).view(tokens, alloctor.hiddens)
    return True
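
# Per-request wiring sketch (illustrative sizes, not from this repo):
#   allocator = VaccHugeMemoryAllocator(3)
#   allocator.init_buffers(max_tokens=48 * 1024, hiddens=7168)
#   recycler = LLMMemoryRecycler()
#   if alloc_memory_recycler_llm(1024, allocator, recycler, torch.bfloat16):
#       emb = recycler.EMBEDDING_OUT_BUFFER  # a [1024, 7168] bfloat16 view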
116
vllm_vacc/vllm/model_executor/models/memory/deepseek_v3_memory_recycler.py
Normal file
@@ -0,0 +1,116 @@
import torch
from .allocator import VaccHugeMemoryAllocator, LLMMemoryRecycler
'''
DeepSeek supports 48K input tokens; there are 3 buffers that can be recycled:
1. the parallel_embedding output buffer
2. the mla_oproj output buffer
3. the moe_shared_mlp output buffer
# 4. the moe_expert output buffer
Each buffer is 48K * 7168 * 2 bytes.
'''
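
# The docstring's size arithmetic, spelled out: 48 * 1024 tokens * 7168 hidden
# * 2 bytes (bfloat16) = 704,643,072 bytes, i.e. 672 MiB per buffer.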


class DeepseekV3MemoryRecycler(LLMMemoryRecycler):
    def __init__(self):
        super().__init__()
        # self.moe_expert_output = None

    # @property
    # def MOE_EXPERT_OUT_BUFFER(self):
    #     return self.moe_expert_output


class DeepseekMTPMemoryRecycler(LLMMemoryRecycler):
    def __init__(self):
        super().__init__()
        self.dynamic_output = None
        self.deepseek_mtp_layer_input = None

    @property
    def DYNAMIC_OUTPUT_BUFFER(self):
        return self.dynamic_output

    @property
    def DEEPSEEK_MTP_LAYER_INPUT(self):
        return self.deepseek_mtp_layer_input


def alloc_memory_recycler_deepseek_v3(tokens,
                                      alloctor: VaccHugeMemoryAllocator,
                                      recycler: DeepseekV3MemoryRecycler,
                                      dtype: torch.dtype = torch.bfloat16):
    if not alloctor.enable:
        print("memory allocator is not initialized.")
        return False

    recycler.clear()
    out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)

    if out_buffers is None:
        print("deepseek_v3 memory recycler buffer allocation failed; disabling it.")
        return False

    if len(out_buffers) != recycler.count:
        print("memory recycler buffer count does not match deepseek_v3.")
        return False

    recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
    recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
    recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
    # recycler.moe_expert_output = out_buffers[3].view(dtype).view(tokens, alloctor.hiddens)
    return True


def alloc_memory_recycler_deepseek_mtp(tokens,
                                       alloctor: VaccHugeMemoryAllocator,
                                       world_size: int,
                                       recycler: DeepseekMTPMemoryRecycler,
                                       dtype: torch.dtype = torch.bfloat16):
    if not alloctor.enable:
        print("memory allocator is not initialized.")
        return False

    recycler.clear()
    out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)

    if out_buffers is None:
        print("deepseek_mtp memory recycler buffer allocation failed; disabling it.")
        return False

    if len(out_buffers) != recycler.count:
        print("memory recycler buffer count does not match deepseek_mtp.")
        return False

    # The MTP memory layout:
    # 1. The DeepSeek main model and the decoder_layer of the draft model:
    #    a. embedding_output
    #    b. mla_oproj_output
    #    c. moe_shared_mlp_output
    #    Reason: the moe output is reorganized once as previous_hidden_states, and
    #    the reorganized buffer is placed in buffer 0 [a.].
    #    Putting embedding_output first would also work, but the related addresses
    #    would all have to shift backwards; for ease of understanding, embedding is
    #    kept last and takes no part in buffer re-partitioning and reuse.
    # 2. The deepseek_mtp draft model (this strategy is not enabled):
    #    a. dynamic_output takes 1/2
    #    b. mtp_input takes 1/6 and sits right after the dynamic_output buffer
    # dynamic_buffer = alloctor.alloc_1_div_N_buffers(2, [0,])
    # mtp_input_buffer = alloctor.alloc_1_div_N_buffers(6, [3,])
    recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
    recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
    recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)

    # Half of the buffer is freely allocated for broadcast; MTP broadcasts previous_hidden_states.
    memory_buffers = alloctor.memory_buffers
    recycler.dynamic_output = memory_buffers[1]  # shared with mla_oproj
    # 1/6 is used as the MTP decoder-layer input; that op is special-cased and
    # only needs 1/tp of the output buffer.
    # recycler.deepseek_mtp_layer_input = mtp_input_buffer[0]

    mtp_layer_input_dims = alloctor.hiddens * 2 // world_size
    mtp_layer_input_numels = tokens * mtp_layer_input_dims

    from vllm import envs
    if envs.VLLM_USE_V1:
        # v1 does not re-cache previous_hidden_states, so the moe buffer is still in
        # use; use the attention o-buffer to hold the MTP preprocessing scratch space instead.
        recycler.deepseek_mtp_layer_input = memory_buffers[1].view(dtype)[:mtp_layer_input_numels].view(tokens, mtp_layer_input_dims)
    else:
        # Shared with the moe output.
        recycler.deepseek_mtp_layer_input = memory_buffers[2].view(dtype)[:mtp_layer_input_numels].view(tokens, mtp_layer_input_dims)
    return True
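
# Sizing sketch for the MTP layer input above (hypothetical values):
#   hiddens=7168, world_size=8 -> mtp_layer_input_dims = 7168 * 2 // 8 = 1792
#   tokens=1024                -> mtp_layer_input_numels = 1024 * 1792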
132
vllm_vacc/vllm/model_executor/models/memory/memory_recycling.py
Normal file
@@ -0,0 +1,132 @@
import os
import torch
from .allocator import VaccHugeMemoryAllocator
from .deepseek_v3_memory_recycler import DeepseekMTPMemoryRecycler

VLLM_MODEL_MODE = os.environ.get("VLLM_MODEL_MODE", "deepseek")

global huge_memory_alloctor
global memory_recycler
huge_memory_alloctor = None
memory_recycler = None


# Call this function whenever a new request comes in.
def alloc_memory_recycler(tokens, dtype=torch.bfloat16, **argv):
    global huge_memory_alloctor
    global memory_recycler

    vllm_model = argv.get('vllm_model')
    if vllm_model is None:
        print("model info is empty; falling back to VLLM_MODEL_MODE")
        vllm_model = VLLM_MODEL_MODE

    # TODO: use a default memory-recycle schedule
    # if vllm_model in ['xxx']:
    #     vllm_model = "llm_default"

    memory_recycler = None
    if vllm_model == "deepseek":
        from .deepseek_v3_memory_recycler import DeepseekV3MemoryRecycler, alloc_memory_recycler_deepseek_v3
        memory_recycler = DeepseekV3MemoryRecycler()
        state = alloc_memory_recycler_deepseek_v3(tokens, huge_memory_alloctor, memory_recycler, dtype)
        if not state:
            del memory_recycler
            memory_recycler = None
        return state

    if vllm_model == "deepseek_mtp":
        from .deepseek_v3_memory_recycler import DeepseekMTPMemoryRecycler, alloc_memory_recycler_deepseek_mtp
        memory_recycler = DeepseekMTPMemoryRecycler()
        if argv.get('world_size') is None:
            print("mtp needs the TP world size; memory recycler allocation failed")
            return False

        state = alloc_memory_recycler_deepseek_mtp(tokens, huge_memory_alloctor, argv['world_size'], memory_recycler, dtype)
        if not state:
            del memory_recycler
            memory_recycler = None
        return state

    if vllm_model == "qwen3_moe":
        from .qwen3_moe_memory_recycler import QWen3MoeMemoryRecycler, alloc_memory_recycler_qwen3_moe
        memory_recycler = QWen3MoeMemoryRecycler()
        state = alloc_memory_recycler_qwen3_moe(tokens, huge_memory_alloctor, memory_recycler, dtype)
        if not state:
            del memory_recycler
            memory_recycler = None
        return state

    if vllm_model == "llm_default":
        from .allocator import LLMMemoryRecycler, alloc_memory_recycler_llm
        memory_recycler = LLMMemoryRecycler()
        state = alloc_memory_recycler_llm(tokens, huge_memory_alloctor, memory_recycler, dtype)
        if not state:
            del memory_recycler
            memory_recycler = None
        return state
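
# Dispatch sketch (kwargs as read from the branches above; values illustrative):
#   ok = alloc_memory_recycler(1024, torch.bfloat16,
#                              vllm_model="deepseek_mtp", world_size=8)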


# Under the LLM pipeline-parallel scheme, for each part other than stage 0,
# receiving the hiddens and residual tensors from the preceding part also
# follows the memory-reuse rules. This path is independent of the LLM forward
# pass and has to be maintained separately:
#   hiddens  corresponds to the moe mlp buffer of the LLM forward pass
#   residual corresponds to the embedding buffer of the LLM forward pass
def alloc_pipeline_parallel_recycler_buffer(size: torch.Size, dtype: torch.dtype, key: str):
    global huge_memory_alloctor
    if huge_memory_alloctor is None:
        return None

    intermize_tensor_dict = {
        "hidden_states": 2,
        "attention": 1,
        "residual": 0
    }

    if key not in intermize_tensor_dict:
        return None

    src_tensors = huge_memory_alloctor.memory_buffers[intermize_tensor_dict[key]]

    all_bytes = size.numel() * huge_memory_alloctor.get_dtype_bytes(dtype)
    return src_tensors[:all_bytes].view(dtype).view(size)
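
# Sketch (hypothetical shape): map a received pipeline tensor onto a recycled view.
#   hiddens = alloc_pipeline_parallel_recycler_buffer(
#       torch.Size([1024, 7168]), torch.bfloat16, "hidden_states")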


# Call this function when the server starts and when every worker starts.
def init_huge_memory_allocator(max_tokens, hidden_size, vllm_model=None):
    global huge_memory_alloctor
    if vllm_model is None:
        print("model info is empty; falling back to VLLM_MODEL_MODE")
        vllm_model = VLLM_MODEL_MODE

    if huge_memory_alloctor is not None:
        del huge_memory_alloctor
        torch.vacc.empty_cache()
        huge_memory_alloctor = None

    if vllm_model == "deepseek":
        huge_memory_alloctor = VaccHugeMemoryAllocator(3)
        huge_memory_alloctor.init_buffers(max_tokens, hidden_size)
        return True

    # deepseek_mtp sets use_contiguous = True
    # deepseek_mtp buffer recycling:
    #   buffer[0]: normal_buffer  -> embedding_output
    #   buffer[1]: dynamic_buffer -> mla_oproj_output, dynamic_output
    #   buffer[2]: normal_buffer  -> moe_shared_mlp_output, deepseek_mtp_layer_input
    if vllm_model == "deepseek_mtp":
        # The last block of the DeepSeek dynamic tokens is used for the mtp-weights.
        # dynamic_buffer_max_tokens = max_tokens + 128; for now MTP only supports 48K.
        dynamic_buffer_max_tokens = max_tokens
        # The dynamic buffer uses 128 extra tokens for broadcasting
        # positions and input tokens.
        deepseek_mtp_max_tokens = max_tokens

        huge_memory_alloctor = VaccHugeMemoryAllocator(3)
        # huge_memory_alloctor.init_buffers(max_tokens, 7168)
        huge_memory_alloctor.init_buffers_with_dynamic(deepseek_mtp_max_tokens, dynamic_buffer_max_tokens, hidden_size, [False, True, False])
        return True

    if vllm_model == "qwen3_moe":
        huge_memory_alloctor = VaccHugeMemoryAllocator(3)
        huge_memory_alloctor.init_buffers(max_tokens, hidden_size)
        return True
    return False
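
# Startup-order sketch: the allocator must exist before any recycler allocation.
#   init_huge_memory_allocator(48 * 1024, 7168, vllm_model="deepseek")   # once per worker
#   alloc_memory_recycler(tokens, torch.bfloat16, vllm_model="deepseek") # per request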
38
vllm_vacc/vllm/model_executor/models/memory/qwen3_moe_memory_recycler.py
Normal file
@@ -0,0 +1,38 @@
import torch
from .allocator import VaccHugeMemoryAllocator, LLMMemoryRecycler
'''
QWen3-Moe supports 56K input tokens; there are 3 buffers that can be recycled:
1. the parallel_embedding output buffer
2. the mla_oproj output buffer
3. the moe_shared_mlp output buffer
'''
class QWen3MoeMemoryRecycler(LLMMemoryRecycler):
    def __init__(self):
        super().__init__()


def alloc_memory_recycler_qwen3_moe(tokens,
                                    alloctor: VaccHugeMemoryAllocator,
                                    recycler: QWen3MoeMemoryRecycler,
                                    dtype: torch.dtype = torch.bfloat16):
    if not alloctor.enable:
        print("memory allocator is not initialized.")
        return False

    recycler.clear()
    out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)

    if out_buffers is None:
        print("qwen3_moe memory recycler buffer allocation failed; disabling it.")
        return False

    if len(out_buffers) != recycler.count:
        print("memory recycler buffer count does not match qwen3_moe.")
        return False

    recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
    recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
    recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
    return True