[Refactor] MLP weight prefetch to consistency with MoE Model's prefetching in terms of code and usage (#6442)
### What this PR does / why we need it?
Refactor MLP weight prefetch to consistency with MoE Model's prefetching
in terms of code and usage.
Environments VLLM_ASCEND_ENABLE_PREFETCH_MLP,
VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE and
VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE is removed, usage as following:
--additional-config '{"weight_prefetch_config": { "enabled": true,
"prefetch_ratio": {"mlp": { "gate_up": 1.0, "down": 1.0} }}}'
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: leo-pony <nengjunma@outlook.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import logger
|
||||
@@ -48,9 +49,7 @@ class AscendConfig:
|
||||
|
||||
# Dump / PrecisionDebugger configuration
|
||||
self.dump_config_path = additional_config.get("dump_config_path", None)
|
||||
|
||||
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
|
||||
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
|
||||
self._construct_weight_prefetch_config(additional_config)
|
||||
self.layer_sharding = additional_config.get("layer_sharding", None)
|
||||
logger.info_once(
|
||||
f"Linear layer sharding enabled with config: {self.layer_sharding}. "
|
||||
@@ -138,6 +137,29 @@ class AscendConfig:
|
||||
"enable_kv_nz is only supported in pd scenario and can only be used in D node."
|
||||
)
|
||||
|
||||
def _construct_weight_prefetch_config(self, additional_config):
|
||||
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
|
||||
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
|
||||
# Deprecated env var handling for backward compatibility
|
||||
if os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0") == "1":
|
||||
MAX_PREFETCH_WEIGHT_SIZE: int = 18 * 1024 * 1024
|
||||
gate_up_prefetch_size = int(os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
|
||||
down_prefetch_szie = int(os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
|
||||
self.weight_prefetch_config.set_mlp_pre_version_compatibale_config(
|
||||
gate_up_prefetch_size, down_prefetch_szie
|
||||
)
|
||||
logger.info_once(
|
||||
f"MLP weight prefetch enabled from env variable VLLM_ASCEND_ENABLE_PREFETCH_MLP."
|
||||
f"gate_up_prefetch_size={gate_up_prefetch_size}, "
|
||||
f"down_prefetch_szie={down_prefetch_szie}."
|
||||
)
|
||||
warnings.warn(
|
||||
"VLLM_ASCEND_ENABLE_PREFETCH_MLP is deprecated and will be removed in a v0.16.0 version. "
|
||||
"Please use weight_prefetch_config in additional-config for now instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
class FinegrainedTPConfig:
|
||||
"""
|
||||
@@ -305,18 +327,28 @@ class WeightPrefetchConfig:
|
||||
Configuration Object for weight_prefetch_config from additional_config
|
||||
"""
|
||||
|
||||
mlp_pre_version_compatibale_config: dict = {}
|
||||
|
||||
prefetch_ratio: dict = {
|
||||
"attn": {
|
||||
"qkv": 1.0,
|
||||
"o": 1.0,
|
||||
},
|
||||
"moe": {"gate_up": 0.8},
|
||||
"mlp": {"gate_up": 1, "down": 1.0},
|
||||
}
|
||||
|
||||
def __init__(self, weight_prefetch_config: dict):
|
||||
self.enabled = weight_prefetch_config.get("enabled", False)
|
||||
self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", self.prefetch_ratio)
|
||||
|
||||
def set_mlp_pre_version_compatibale_config(self, gate_up_prefetch_size: int, down_prefetch_size: int):
|
||||
config = {
|
||||
"gate_up": gate_up_prefetch_size,
|
||||
"down": down_prefetch_size,
|
||||
}
|
||||
self.mlp_pre_version_compatibale_config = config
|
||||
|
||||
|
||||
class EplbConfig:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user