add vxpu
@@ -73,6 +73,11 @@ xvllm_environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_KUNLUN_ENABLE_INT8_BMM":
     lambda: (os.environ.get("VLLM_KUNLUN_ENABLE_INT8_BMM", "False").lower() in
              ("true", "1")),
+
+    "VLLM_KUNLUN_ENABLE_VXPU":
+    lambda: (os.environ.get("VLLM_KUNLUN_ENABLE_VXPU", "True").lower() in
+             ("true", "1")
+             )
 }
 
 # end-env-vars-definition
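For reference, the new flag follows the same truthy-string convention as the existing VLLM_KUNLUN_ENABLE_INT8_BMM switch: the raw value is lowercased and compared against ("true", "1"), with VXPU defaulting to enabled. A minimal sketch of that behavior outside the envs table (the _flag helper is illustrative, not part of the commit):

import os

def _flag(name: str, default: str) -> bool:
    # Same truthy test the table uses: only a case-insensitive
    # "true" or "1" counts as enabled; anything else is off.
    return os.environ.get(name, default).lower() in ("true", "1")

vxpu_enabled = _flag("VLLM_KUNLUN_ENABLE_VXPU", "True")           # defaults on
int8_bmm_enabled = _flag("VLLM_KUNLUN_ENABLE_INT8_BMM", "False")  # defaults off

Note that values such as "yes" or "on" read as disabled under this test; downstream code would normally consume the flag through the envs module (e.g. envs.VLLM_KUNLUN_ENABLE_VXPU) rather than re-reading os.environ.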
@@ -8,6 +8,9 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 
 
+# fix bfloat16 double size issue
+torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+
 logger = init_logger(__name__)
 
 class KunlunPlatform(Platform):
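torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction is a standard PyTorch knob: when True (the default), cuBLAS may use reduced-precision accumulation inside bfloat16 matmuls; forcing it to False makes the reductions run in full fp32, trading some throughput for accuracy, which is presumably what the "double size issue" comment is working around. A minimal sketch of the toggle:

import torch

# Default is True: cuBLAS may accumulate bf16 matmuls in reduced precision.
print(torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction)

# As in the commit: force fp32 accumulation for bf16 matmuls.
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False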
@@ -152,8 +155,10 @@ class KunlunPlatform(Platform):
         if parallel_config.worker_cls == "auto":
             if vllm_config.speculative_config:
                 if envs.VLLM_USE_V1:
+                    # parallel_config.worker_cls = \
+                    #     "vllm.v1.worker.gpu_worker.Worker"
                     parallel_config.worker_cls = \
-                        "vllm.v1.worker.gpu_worker.Worker"
+                        "vllm_kunlun.v1.worker.worker_v1.KunlunWorker"
                 else:
                     parallel_config.worker_cls = \
                         "vllm.spec_decode.spec_decode_worker.create_spec_worker"
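worker_cls here is a dotted import path rather than a class object; vLLM resolves it to the actual worker class later, when workers are spawned. A generic sketch of that resolution pattern (vLLM ships its own helper for this; the importlib version below is only illustrative):

import importlib

def resolve_worker_cls(qualname: str):
    # Split "pkg.module.ClassName" into module path and attribute,
    # import the module, and return the class.
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# After this commit, the V1 speculative path would resolve to:
# resolve_worker_cls("vllm_kunlun.v1.worker.worker_v1.KunlunWorker")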
@@ -162,8 +167,10 @@ class KunlunPlatform(Platform):
             else:
                 print(f"envs.VLLM_USE_V1 = {envs.VLLM_USE_V1}")
                 if envs.VLLM_USE_V1:
+                    # parallel_config.worker_cls = \
+                    #     "vllm.v1.worker.gpu_worker.Worker"
                     parallel_config.worker_cls = \
-                        "vllm.v1.worker.gpu_worker.Worker"
+                        "vllm_kunlun.v1.worker.worker_v1.KunlunWorker"
                 else:
                     parallel_config.worker_cls = "vllm.worker.worker.Worker"
 
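Taken together, the two hunks route every VLLM_USE_V1 path, speculative or not, to the Kunlun worker, while the non-V1 branches keep their upstream workers. A condensed, illustrative restatement of the selection logic after this commit (a summary aid, not code from the diff):

def pick_worker_cls(use_v1: bool, speculative: bool) -> str:
    # Mirrors the patched branches in KunlunPlatform: any V1 path
    # gets the Kunlun worker; non-V1 paths keep the upstream choices.
    if use_v1:
        return "vllm_kunlun.v1.worker.worker_v1.KunlunWorker"
    if speculative:
        return "vllm.spec_decode.spec_decode_worker.create_spec_worker"
    return "vllm.worker.worker.Worker"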