[Bugfix] Follow vLLM Qwen-Moe/VL and KV Connector change to fix broken CI (#2181)
### What this PR does / why we need it? This PR fixes broken CI: 1. Fix the e2eb6ecd8 changes — in that commit, vLLM fused the gate and up projections in the vision MLP, which can improve performance by reducing one matrix multiplication. So this PR does the following: - Specify that the two linear layers are fused as `mlp.gate_up_proj` when loading the weights. - Use a SiluAndMul activation function. 2. Fix aefeea0fde: update ModelRunnerOutput parameters to adapt to its changes. 3. Fix [vllm-commit](https://github.com/vllm-project/vllm/pull/20815/files#diff-3ffb829a39ab2b3e4706aa28f5e476815f36c3a87b98d6a66514ebedc8f3ffb4R354-R356): fix Qwen MoE. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - vLLM version: v0.10.0 - vLLM main: fed5849d3f --------- Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
@@ -94,6 +94,8 @@ from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
if not vllm_version_is("0.10.0"):
|
||||
from vllm.tasks import GenerationTask, SupportedTask
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import \
|
||||
KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr # type: ignore[import-untyped]
|
||||
@@ -1472,8 +1474,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states: torch.Tensor,
|
||||
num_scheduled_tokens: int,
|
||||
num_scheduled_tokens_np: np.ndarray,
|
||||
finished_sending: Optional[set[str]],
|
||||
finished_receiving: Optional[set[str]],
|
||||
finished_sending: Optional[set[str]] = None,
|
||||
finished_recving: Optional[set[str]] = None,
|
||||
kv_connector_output: Optional["KVConnectorOutput"] = None,
|
||||
) -> ModelRunnerOutput:
|
||||
assert self.input_batch.num_reqs ==\
|
||||
len(self.input_batch.pooling_params), \
|
||||
@@ -1499,6 +1502,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
pooler_output.append(raw_output.data.cpu())
|
||||
else:
|
||||
pooler_output.append(None)
|
||||
extra_args = ({
|
||||
"finished_sending": finished_sending,
|
||||
"finished_recving": finished_recving
|
||||
} if vllm_version_is("0.10.0") else {
|
||||
"kv_connector_output": kv_connector_output
|
||||
})
|
||||
|
||||
return ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
@@ -1508,8 +1517,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=pooler_output,
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_receiving)
|
||||
**extra_args,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
@@ -1533,7 +1542,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_scheduled_tokens_np, finished_sending,
|
||||
finished_recving) = (self._process_reqs(scheduler_output,
|
||||
intermediate_tensors))
|
||||
|
||||
kv_connector_output = None
|
||||
if not vllm_version_is("0.10.0"):
|
||||
kv_connector_output = KVConnectorOutput(
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving)
|
||||
finished_sending = None
|
||||
finished_recving = None
|
||||
with ProfileExecuteDuration().capture_async("post process"):
|
||||
# Broadcast PP output for external_launcher (torchrun)
|
||||
# to make sure we are synced across pp ranks
|
||||
@@ -1545,7 +1560,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if not get_pp_group().is_last_rank:
|
||||
# For mid-pipeline stages, return the hidden states.
|
||||
if not broadcast_pp_output:
|
||||
if finished_sending or finished_recving:
|
||||
if kv_connector_output is not None:
|
||||
hidden_states.kv_connector_output = kv_connector_output
|
||||
else:
|
||||
#TODO: Remove this after we drop vllm v0.10.0
|
||||
hidden_states.finished_sending = finished_sending
|
||||
hidden_states.finished_recving = finished_recving
|
||||
return hidden_states
|
||||
@@ -1557,7 +1575,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.input_batch.pooling_params:
|
||||
return self._pool(hidden_states, num_scheduled_tokens,
|
||||
num_scheduled_tokens_np,
|
||||
finished_sending, finished_recving)
|
||||
finished_sending, finished_recving,
|
||||
kv_connector_output)
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
if broadcast_pp_output:
|
||||
@@ -1691,17 +1710,23 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().clear_connector_metadata()
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
spec_token_ids=spec_token_ids,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
pooler_output=[],
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
)
|
||||
extra_args = ({
|
||||
"finished_sending": finished_sending,
|
||||
"finished_recving": finished_recving
|
||||
} if vllm_version_is("0.10.0") else {
|
||||
"kv_connector_output": kv_connector_output
|
||||
})
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
spec_token_ids=spec_token_ids,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
pooler_output=[],
|
||||
**extra_args,
|
||||
)
|
||||
|
||||
durations = ProfileExecuteDuration().pop_captured_sync()
|
||||
if durations:
|
||||
|
||||
@@ -209,12 +209,27 @@ class NPUWorker(WorkerBase):
|
||||
if not has_kv_transfer_group():
|
||||
return None
|
||||
|
||||
new_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||
if output.finished_sending or output.finished_recving:
|
||||
new_output = copy.copy(new_output)
|
||||
new_output.finished_sending = output.finished_sending
|
||||
new_output.finished_recving = output.finished_recving
|
||||
output = new_output
|
||||
is_legacy = vllm_version_is("0.10.0")
|
||||
|
||||
if is_legacy:
|
||||
finished_sending = output.finished_sending
|
||||
finished_recving = output.finished_recving
|
||||
else:
|
||||
kv_connector_output = output.kv_connector_output
|
||||
finished_sending = kv_connector_output.finished_sending
|
||||
finished_recving = kv_connector_output.finished_recving
|
||||
|
||||
if not finished_sending and not finished_recving:
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
|
||||
if is_legacy:
|
||||
new_output.finished_sending = finished_sending
|
||||
new_output.finished_recving = finished_recving
|
||||
else:
|
||||
new_output.kv_connector_output = kv_connector_output
|
||||
return new_output
|
||||
|
||||
assert isinstance(output, ModelRunnerOutput)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user