Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-21 04:00:15 +00:00)
Commit: 大改 (major overhaul)
vllm_npu/patch/platform/__init__.py (new file, 30 lines)
@@ -0,0 +1,30 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import vllm_npu.patch.platform.patch_config  # noqa
import vllm_npu.patch.platform.patch_distributed  # noqa
import vllm_npu.patch.platform.patch_mamba_config  # noqa
import vllm_npu.patch.platform.patch_sched_yield  # noqa

if os.getenv("DYNAMIC_EPLB", "false") == "true" or os.getenv(
        "EXPERT_MAP_RECORD", "false") == "true":
    import vllm_npu.patch.platform.patch_multiproc_executor  # noqa

if os.getenv("SHM_BARRIER", "true") == "true":
    import vllm_npu.patch.platform.patch_core  # noqa
    import vllm_npu.patch.platform.patch_message_queue  # noqa
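Note: every module imported above works by import-time monkey-patching: importing the module rebinds an attribute on an upstream vLLM class or module to a replacement defined in the patch file. The following standalone sketch shows the pattern in isolation; the names DemoTarget and patched_method are illustrative only and are not part of vllm or vllm_npu.

class DemoTarget:
    def method(self):
        return "original"


def patched_method(self):
    return "patched"


# Importing one of the patch modules has the same effect as this assignment:
# the attribute on the upstream class is rebound to the replacement function.
DemoTarget.method = patched_method

assert DemoTarget().method() == "patched"

Because the patches run at import time, gating the imports behind environment variables in this __init__ is what makes features such as DYNAMIC_EPLB, EXPERT_MAP_RECORD, and SHM_BARRIER opt-in.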
vllm_npu/patch/platform/patch_config.py (new file, 234 lines)
@@ -0,0 +1,234 @@
import ast

import vllm.envs as envs
from vllm.config.speculative import SpeculativeConfig
from vllm.logger import logger


def __post_init__(self):

    # Note: "method" is a new parameter that helps to extend the
    # configuration of non-model-based proposers, and the "model" parameter
    # will be used to set the draft model, eagle head, or additional weight
    # when needed. If users do not specify "method", the speculative method
    # will be detected automatically if possible. If the speculative method
    # can not be detected, it will be considered as the "draft_model" by
    # default.

    if self.model is None and self.num_speculative_tokens is not None:
        # TODO(Shangming): Refactor mtp configuration logic when supporting
        if (self.target_model_config
                and self.target_model_config.hf_text_config.model_type
                in ("deepseek_v3", "deepseek_v32", "mimo", "ernie4_5_moe",
                    "qwen3_next")):
            # use the draft model from the same model:
            self.model = self.target_model_config.model
            # Align the quantization of draft model for cases such as
            # --quantization fp8 with a bf16 checkpoint.
            if not self.quantization:
                self.quantization = self.target_model_config.quantization
        elif self.method in ("ngram", "[ngram]"):
            self.model = "ngram"
        else:
            raise ValueError("num_speculative_tokens was provided but without "
                             "speculative model.")

    # Automatically configure the method for ngram when "model" is used
    # instead of "method"
    if self.method is None and (self.model is not None
                                and self.model in ("ngram", "[ngram]")):
        self.method = "ngram"

    if self.method in ("ngram", "[ngram]"):
        # Unified to "ngram" internally
        self.method = "ngram"
        # Set default values if not provided
        if (self.prompt_lookup_min is None and self.prompt_lookup_max is None):
            # TODO(woosuk): Tune these values. They are arbitrarily chosen.
            self.prompt_lookup_min = 5
            self.prompt_lookup_max = 5
        elif self.prompt_lookup_min is None:
            assert self.prompt_lookup_max is not None
            self.prompt_lookup_min = self.prompt_lookup_max
        elif self.prompt_lookup_max is None:
            assert self.prompt_lookup_min is not None
            self.prompt_lookup_max = self.prompt_lookup_min

        # Validate values
        if self.prompt_lookup_min < 1:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
        if self.prompt_lookup_max < 1:
            raise ValueError(
                f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
        if self.prompt_lookup_min > self.prompt_lookup_max:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must "
                f"be <= prompt_lookup_max={self.prompt_lookup_max}")

        # TODO: current we still need extract vocab_size from target model
        # config, in future, we may try refactor it out, and set
        # draft related config as None here.
        self.draft_model_config = self.target_model_config
        self.draft_parallel_config = self.target_parallel_config
    else:
        self.prompt_lookup_max = 0
        self.prompt_lookup_min = 0

        if self.model is not None:
            # TODO: Move this import to the top once `ModelConfig`
            # lives in `vllm.config.model`.
            from vllm.config import ModelConfig
            self.draft_model_config = ModelConfig(
                model=self.model,
                runner="draft",
                tokenizer=self.target_model_config.tokenizer,
                tokenizer_mode=self.target_model_config.tokenizer_mode,
                trust_remote_code=self.target_model_config.trust_remote_code,
                allowed_local_media_path=self.target_model_config.
                allowed_local_media_path,
                allowed_media_domains=self.target_model_config.
                allowed_media_domains,
                dtype=self.target_model_config.dtype,
                seed=self.target_model_config.seed,
                revision=self.revision,
                code_revision=self.code_revision,
                tokenizer_revision=self.target_model_config.tokenizer_revision,
                spec_target_max_model_len=self.target_model_config.
                max_model_len,
                quantization=self.quantization,
                enforce_eager=self.target_model_config.enforce_eager,
                max_logprobs=self.target_model_config.max_logprobs,
                hf_overrides=SpeculativeConfig.hf_config_override,
            )

            # Automatically detect the method
            if self.method in ('eagle', 'eagle3'):
                pass
            # examples:
            # yuhuili/EAGLE-LLaMA3-Instruct-8B
            # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
            # AngelSlim/Qwen3-8B_eagle3
            elif "eagle-" in self.draft_model_config.model.lower():
                self.method = "eagle"
            elif "eagle3" in self.draft_model_config.model.lower():
                self.method = "eagle3"
            elif self.draft_model_config.hf_config.model_type == "medusa":
                self.method = "medusa"
            elif (self.draft_model_config.hf_config.model_type ==
                  "mlp_speculator"):
                self.method = "mlp_speculator"
            elif (self.draft_model_config.hf_config.model_type
                  in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
                self.method = "deepseek_mtp"
                if self.num_speculative_tokens > 1:
                    logger.warning(
                        "All Deepseek MTP models only have " \
                        "one layer. Might need some code changes " \
                        "to support multiple layers."
                    )
            elif (self.draft_model_config.hf_config.model_type == "ernie_mtp"):
                self.method = "ernie_mtp"
                if self.num_speculative_tokens > 1:
                    logger.warning(
                        "All Ernie MTP models only have " \
                        "one layer. Might need some code changes " \
                        "to support multiple layers."
                    )
            elif (self.draft_model_config.hf_config.model_type ==
                  "qwen3_next_mtp"):
                self.method = "qwen3_next_mtp"
                if self.num_speculative_tokens > 1:
                    logger.warning(
                        "All Qwen3Next MTP models only have " \
                        "one layer. Might need some code changes " \
                        "to support multiple layers."
                    )
            elif (self.draft_model_config.hf_config.model_type
                  in ("longcat_flash_mtp")):
                self.method = "longcat_flash_mtp"
                if self.num_speculative_tokens > 1:
                    logger.warning(
                        "LongCat MTP models only have " \
                        "one layer. Might need some code changes " \
                        "to support multiple layers."
                    )
            else:
                self.method = "draft_model"
                raise NotImplementedError(
                    "Speculative decoding with draft model is not "
                    "supported yet. Please consider using other "
                    "speculative decoding methods such as ngram, medusa, "
                    "eagle, or deepseek_mtp.")

            # Replace hf_config for EAGLE draft_model
            if self.method in ("eagle", "eagle3"):
                if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
                    raise ValueError(
                        "Chunked prefill and EAGLE are not compatible "
                        "when using V0.")

                from vllm.transformers_utils.configs import SpeculatorsConfig
                from vllm.transformers_utils.configs.eagle import EAGLEConfig

                if isinstance(self.draft_model_config.hf_config,
                              (EAGLEConfig, SpeculatorsConfig)):
                    pass
                else:
                    eagle_config = EAGLEConfig(
                        self.draft_model_config.hf_config,
                        method=self.method,
                        model_type="eagle")
                    self.draft_model_config.hf_config = eagle_config

            if (self.num_speculative_tokens is not None
                    and hasattr(self.draft_model_config.hf_config,
                                "num_lookahead_tokens")):
                self.draft_model_config.hf_config.num_lookahead_tokens = \
                    self.num_speculative_tokens

            n_predict = getattr(self.draft_model_config.hf_config, "n_predict",
                                None)
            if n_predict is not None:
                if self.num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    self.num_speculative_tokens = n_predict
                elif self.num_speculative_tokens > n_predict and \
                        self.num_speculative_tokens % n_predict != 0:
                    # Ensure divisibility for MTP module reuse.
                    raise ValueError(
                        f"num_speculative_tokens:{self.num_speculative_tokens}"
                        f" must be divisible by {n_predict=}")

            if self.speculative_token_tree is None:
                # Generate chain of tokens.
                self.speculative_token_tree = str([
                    (i + 1) * (0, ) for i in range(self.num_speculative_tokens)
                ])
            else:
                # Sort the token tree breadth-first.
                tree_choices = ast.literal_eval(self.speculative_token_tree)
                self.speculative_token_tree = str(
                    sorted(tree_choices, key=lambda t: (len(t), t)))

            self.draft_tensor_parallel_size = \
                SpeculativeConfig._verify_and_get_draft_tp(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size,
                    self.draft_model_config.hf_config
                )

            self.draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    self.max_model_len,
                    self.draft_model_config.max_model_len,
                    self.target_model_config.max_model_len,
                ))

            self.draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size))


SpeculativeConfig.__post_init__ = __post_init__
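Note: when speculative_token_tree is unset, the patched __post_init__ above builds a chain-shaped tree with one all-zeros path per lookahead token. A standalone sketch of that expression follows; num_speculative_tokens = 3 is an arbitrary example value, not a default from the patch.

num_speculative_tokens = 3
speculative_token_tree = str([
    (i + 1) * (0, ) for i in range(num_speculative_tokens)
])
print(speculative_token_tree)  # [(0,), (0, 0), (0, 0, 0)]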
vllm_npu/patch/platform/patch_core.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import signal
from typing import Optional

from vllm.config import ParallelConfig
from vllm.logger import logger
from vllm.transformers_utils.config import \
    maybe_register_config_serialize_by_value
from vllm.utils import decorate_logs, set_process_title
from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc


def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
    """Launch EngineCore busy loop in background process."""

    from vllm.distributed.device_communicators.shm_broadcast import \
        MessageQueue  # noqa

    # Signal handler used for graceful termination.
    # SystemExit exception is only raised once to allow this and worker
    # processes to terminate without error
    shutdown_requested = False

    # Ensure we can serialize transformer config after spawning
    maybe_register_config_serialize_by_value()

    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit()

    # Either SIGTERM or SIGINT will terminate the engine_core
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    engine_core: Optional[EngineCoreProc] = None
    try:
        parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
        if parallel_config.data_parallel_size > 1 or dp_rank > 0:
            set_process_title("EngineCore", f"DP{dp_rank}")
            decorate_logs()
            # Set data parallel rank for this engine process.
            parallel_config.data_parallel_rank = dp_rank
            parallel_config.data_parallel_rank_local = local_dp_rank
            engine_core = DPEngineCoreProc(*args, **kwargs)
        else:
            set_process_title("EngineCore")
            decorate_logs()
            engine_core = EngineCoreProc(*args, **kwargs)

        engine_core.run_busy_loop()

    except SystemExit:
        logger.debug("EngineCore exiting.")
        raise
    except Exception as e:
        if engine_core is None:
            logger.exception("EngineCore failed to start.")
        else:
            logger.exception("EngineCore encountered a fatal error.")
            engine_core._send_engine_dead()
        raise e
    finally:
        if engine_core is not None:
            engine_core.shutdown()


EngineCoreProc.run_engine_core = run_engine_core
vllm_npu/patch/platform/patch_distributed.py (new file, 115 lines)
@@ -0,0 +1,115 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from vllm/model_executor/models/qwen2_vl.py
# This file is a part of the vllm-ascend project.

import torch
import vllm.envs as envs_vllm
from vllm.config import ParallelConfig

from vllm_npu.utils import is_310p


def parallel_config_get_dp_port(self) -> int:
    """
    We might need to initialize process groups in multiple
    processes that is related to data parallelism,
    e.g. both in the worker and in the engine, which
    can live in different processes. To avoid port conflicts, we
    increment the port number each time we need to initialize a
    new process group related to data parallelism.
    """
    answer = self.data_parallel_master_port
    self.data_parallel_master_port += 1

    # NOTE: Get port from envs directly when using torchrun
    port = envs_vllm.VLLM_DP_MASTER_PORT if envs_vllm.VLLM_DP_MASTER_PORT else answer
    return port


ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port


class NullHandle:

    def __init__(self):
        pass

    def wait(self):
        pass


def communication_adaptation_310p():

    def broadcast310p_wrapper(fn):

        def broadcast310p(tensor, src, group=None, async_op=False):
            if tensor.device == torch.device('cpu'):
                return fn(tensor, src, group, async_op)
            rank = torch.distributed.get_rank(group)
            world_size = torch.distributed.get_world_size(group)
            tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
            tensor_list[rank] = tensor
            torch.distributed.all_gather(tensor_list, tensor, group=group)
            tensor[...] = tensor_list[src]
            if async_op:
                return NullHandle()
            else:
                return None

        return broadcast310p

    torch.distributed.broadcast = broadcast310p_wrapper(
        torch.distributed.broadcast)
    torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper(
        torch.distributed.distributed_c10d.broadcast)

    def all_reduce_wrapper_310p(fn):

        def all_reduce(
            tensor,
            op=torch.distributed.ReduceOp.SUM,
            group=None,
            async_op=False,
        ):
            if tensor.dtype != torch.int64:
                return fn(tensor, op, group, async_op)
            rank = torch.distributed.get_rank(group)
            world_size = torch.distributed.get_world_size(group)
            tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
            tensor_list[rank] = tensor
            torch.distributed.all_gather(tensor_list, tensor, group=group)
            if op == torch.distributed.ReduceOp.SUM:
                return torch.stack(tensor_list).sum(0)
            elif op == torch.distributed.ReduceOp.MAX:
                return torch.tensor(
                    torch.stack(tensor_list).cpu().numpy().max(0),
                    device=tensor.device,
                )
            else:
                raise RuntimeError(f"not implement op {op}")

        return all_reduce

    torch.distributed.all_reduce = all_reduce_wrapper_310p(
        torch.distributed.all_reduce)
    torch.distributed.distributed_c10d.all_reduce = all_reduce_wrapper_310p(
        torch.distributed.distributed_c10d.all_reduce)


if is_310p():
    communication_adaptation_310p()
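Note: the 310P adaptation above emulates broadcast (and int64 all_reduce) with all_gather: every rank contributes its tensor, then keeps only the source rank's entry (or reduces the gathered list). A pure-Python illustration of that data flow follows; plain integers stand in for per-rank tensors, and this is not torch.distributed code.

world_size = 4
src = 1
per_rank_values = [10, 20, 30, 40]   # value each rank holds before the call

# all_gather: every rank ends up with a copy of the full list
gathered = list(per_rank_values)

# broadcast result: every rank overwrites its value with the source rank's entry
broadcast_result = [gathered[src]] * world_size
print(broadcast_result)              # [20, 20, 20, 20]

# int64 all_reduce(SUM) result: every rank sums the gathered entries locally
allreduce_result = sum(gathered)
print(allreduce_result)              # 100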
vllm_npu/patch/platform/patch_mamba_config.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# mypy: ignore-errors
import vllm.model_executor.models.config
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec


@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
    """
    Ensure that page size of attention layers is greater than or
    equal to the mamba layers. If not, automatically set the attention
    block size to ensure that it is. If the attention page size is
    strictly greater than the mamba page size, we pad the mamba page size
    to make them equal.

    Args:
        vllm_config: vLLM Config
    """
    logger = init_logger(__name__)
    # Enable FULL_AND_PIECEWISE by default
    MambaModelConfig.verify_and_update_config(vllm_config)

    cache_config = vllm_config.cache_config
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config

    if cache_config.cache_dtype == "auto":
        kv_cache_dtype = model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    # get attention page size (for 1 token)
    attn_page_size_1_token = FullAttentionSpec(
        block_size=1,
        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
        head_size=model_config.get_head_size(),
        dtype=kv_cache_dtype).page_size_bytes

    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
        model_config=model_config,
    )

    # get mamba page size
    mamba_page_size = MambaSpec(
        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
        block_size=model_config.max_model_len,
    ).page_size_bytes

    block_alignment_bytes = 128

    # some attention backends (e.g. FA) only support setting
    # block size to multiple of 16, so let's suggest a value
    # that would work (note: FA is currently not compatible
    # with mamba layers, use FlashInfer instead).
    attn_block_size = block_alignment_bytes * cdiv(
        mamba_page_size, block_alignment_bytes * attn_page_size_1_token)

    # override attention block size if either (a) the
    # user has not set it or (b) the user has set it
    # too small.
    if (cache_config.block_size is None
            or cache_config.block_size < attn_block_size):
        cache_config.block_size = attn_block_size
        logger.info(
            "Setting attention block size to %d tokens "
            "to ensure that attention page size is >= mamba page size.",
            attn_block_size)

    # compute new attention page size
    attn_page_size = \
        cache_config.block_size * attn_page_size_1_token

    assert attn_page_size >= mamba_page_size

    if attn_page_size == mamba_page_size:
        # don't need to pad mamba page size
        return

    # pad mamba page size to exactly match attention
    if (cache_config.mamba_page_size_padded is None
            or cache_config.mamba_page_size_padded != attn_page_size):
        cache_config.mamba_page_size_padded = (attn_page_size)
        mamba_padding_pct = 100 * (attn_page_size -
                                   mamba_page_size) / mamba_page_size
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.", mamba_padding_pct)


vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
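Note: a worked example of the attention block-size formula above, with made-up page sizes, showing how cdiv rounds the block size up to the smallest multiple of block_alignment_bytes (128) whose attention page covers the mamba page. The byte counts are illustrative assumptions, not values produced by any real model config.

def cdiv(a, b):
    # ceiling division, matching vllm.utils.cdiv
    return -(-a // b)


attn_page_size_1_token = 576      # hypothetical attention bytes per token
mamba_page_size = 1_048_576       # hypothetical bytes for one mamba state
block_alignment_bytes = 128

attn_block_size = block_alignment_bytes * cdiv(
    mamba_page_size, block_alignment_bytes * attn_page_size_1_token)
print(attn_block_size)                                               # 1920 tokens
print(attn_block_size * attn_page_size_1_token >= mamba_page_size)   # True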
vllm_npu/patch/platform/patch_message_queue.py (new file, 164 lines)
@@ -0,0 +1,164 @@
import time
from contextlib import contextmanager
from typing import Optional

import vllm.envs as envs
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                  MessageQueue,
                                                                  ShmRingBuffer,
                                                                  SpinTimer)
from vllm.distributed.utils import sched_yield
from vllm.logger import logger
from vllm.utils import (get_ip, get_mp_context, get_open_port,
                        get_open_zmq_ipc_path, is_valid_ipv6_address)
from zmq import IPV6, XPUB, XPUB_VERBOSE, Context  # type: ignore

VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL


def __init__(
    self,
    n_reader,  # number of all readers
    n_local_reader,  # number of local readers through shared memory
    local_reader_ranks: Optional[list[int]] = None,
    max_chunk_bytes: int = 1024 * 1024 * 10,
    max_chunks: int = 10,
    connect_ip: Optional[str] = None,
):
    if local_reader_ranks is None:
        local_reader_ranks = list(range(n_local_reader))
    else:
        assert len(local_reader_ranks) == n_local_reader
    self.n_local_reader = n_local_reader
    n_remote_reader = n_reader - n_local_reader
    self.n_remote_reader = n_remote_reader

    context = Context()

    if n_local_reader > 0:
        # for local readers, we will:
        # 1. create a shared memory ring buffer to communicate small data
        # 2. create a publish-subscribe socket to communicate large data
        self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
                                    max_chunks)

        # XPUB is very similar to PUB,
        # except that it can receive subscription messages
        # to confirm the number of subscribers
        self.local_socket = context.socket(XPUB)
        # set the verbose option so that we can receive every subscription
        # message. otherwise, we will only receive the first subscription
        # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
        self.local_socket.setsockopt(XPUB_VERBOSE, True)
        local_subscribe_addr = get_open_zmq_ipc_path()
        logger.debug("Binding to %s", local_subscribe_addr)
        self.local_socket.bind(local_subscribe_addr)

        self.current_idx = 0
        self.writer_lock = get_mp_context().Lock()
    else:
        self.buffer = None  # type: ignore
        local_subscribe_addr = None
        self.local_socket = None
        self.current_idx = -1

    remote_addr_ipv6 = False
    if n_remote_reader > 0:
        # for remote readers, we will:
        # create a publish-subscribe socket to communicate large data
        if not connect_ip:
            connect_ip = get_ip()
        self.remote_socket = context.socket(XPUB)
        self.remote_socket.setsockopt(XPUB_VERBOSE, True)
        remote_subscribe_port = get_open_port()
        if is_valid_ipv6_address(connect_ip):
            self.remote_socket.setsockopt(IPV6, 1)
            remote_addr_ipv6 = True
            connect_ip = f"[{connect_ip}]"
        socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
        self.remote_socket.bind(socket_addr)
        remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
    else:
        remote_subscribe_addr = None
        self.remote_socket = None

    self._is_writer = True
    self._is_local_reader = False
    self.local_reader_rank = -1
    # rank does not matter for remote readers
    self._is_remote_reader = False
    self._read_spin_timer = SpinTimer()

    self.handle = Handle(
        local_reader_ranks=local_reader_ranks,
        buffer_handle=self.buffer.handle()
        if self.buffer is not None else None,
        local_subscribe_addr=local_subscribe_addr,
        remote_subscribe_addr=remote_subscribe_addr,
        remote_addr_ipv6=remote_addr_ipv6,
    )

    logger.info("vLLM message queue communication handle: %s", self.handle)


@contextmanager
def acquire_write(self, timeout: Optional[float] = None):
    assert self._is_writer, "Only writers can acquire write"
    start_time = time.monotonic()
    n_warning = 1
    while True:
        with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
            read_count = sum(metadata_buffer[1:])
            written_flag = metadata_buffer[0]
            if written_flag and read_count != self.buffer.n_reader:
                # this block is written and not read by all readers
                # for writers, `self.current_idx` is the next block to write
                # if this block is not ready to write,
                # we need to wait until it is read by all readers

                # Release the processor to other threads
                sched_yield()

                # if we time out, raise an exception
                elapsed = time.monotonic() - start_time
                if timeout is not None and elapsed > timeout:
                    raise TimeoutError

                # if we wait for a long time, log a message
                if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
                    logger.info(
                        "No available shared memory broadcast block found"
                        " in %s seconds. This typically happens when some"
                        " processes are hanging or doing some"
                        " time-consuming work (e.g. compilation)",
                        VLLM_RINGBUFFER_WARNING_INTERVAL)
                    n_warning += 1

                continue
            # found a block that is either
            # (1) not written
            # (2) read by all readers

            with self.writer_lock:
                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
            self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
            break


MessageQueue.__init__ = __init__
MessageQueue.acquire_write = acquire_write
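Note: acquire_write above decides whether a chunk is reusable from the ring-buffer metadata layout: byte 0 is the written flag and bytes 1..n_reader are per-reader read flags. A standalone illustration of that check follows; the example values are arbitrary and the bytearray stands in for the shared-memory metadata buffer.

n_reader = 2
metadata_buffer = bytearray([1, 1, 0])   # written; read by reader 0, not yet by reader 1

written_flag = metadata_buffer[0]
read_count = sum(metadata_buffer[1:])

busy = bool(written_flag) and read_count != n_reader
print(busy)   # True -> the writer must sched_yield() and retry this chunk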
vllm_npu/patch/platform/patch_multiproc_executor.py (new file, 151 lines)
@@ -0,0 +1,151 @@
import threading
import weakref
from concurrent.futures import ThreadPoolExecutor
from multiprocessing.synchronize import Lock as LockType
from typing import Optional

import vllm.v1.executor.multiproc_executor
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.utils import (get_distributed_init_method, get_loopback_ip,
                        get_mp_context, get_open_port)
from vllm.v1.executor.abstract import FailureCallback
from vllm.v1.executor.multiproc_executor import (
    MultiprocExecutor, UnreadyWorkerProcHandle, WorkerProc,
    set_multiprocessing_worker_envs)


class AscendMultiprocExecutor(MultiprocExecutor):
    supports_pp: bool = True

    def _init_executor(self) -> None:
        # Call self.shutdown at exit to clean up
        # and ensure workers will be terminated.
        self._finalizer = weakref.finalize(self, self.shutdown)
        self.is_failed = False
        self.shutdown_event = threading.Event()
        self.failure_callback: Optional[FailureCallback] = None
        self.io_thread_pool: Optional[ThreadPoolExecutor] = None

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
        pp_parallel_size = self.parallel_config.pipeline_parallel_size
        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
            f"world_size ({self.world_size}) must be equal to the "
            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
            f"_parallel_size ({pp_parallel_size}). ")

        # Set multiprocessing envs
        set_multiprocessing_worker_envs()

        # Multiprocessing-based executor does not support multi-node setting.
        # Since it only works for single node, we can use the loopback address
        # get_loopback_ip() for communication.
        distributed_init_method = get_distributed_init_method(
            get_loopback_ip(), get_open_port())

        # Initialize worker and set up message queues for SchedulerOutputs
        # and ModelRunnerOutputs
        max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
        self.rpc_broadcast_mq = MessageQueue(self.world_size,
                                             self.world_size,
                                             max_chunk_bytes=max_chunk_bytes)
        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()

        # Create workers
        context = get_mp_context()
        shared_worker_lock = context.Lock()
        unready_workers: list[UnreadyWorkerProcHandle] = []
        success = False
        try:
            for rank in range(self.world_size):
                unready_workers.append(
                    AscendWorkerProc.make_worker_process(
                        vllm_config=self.vllm_config,
                        local_rank=rank,
                        rank=rank,
                        distributed_init_method=distributed_init_method,
                        input_shm_handle=scheduler_output_handle,
                        shared_worker_lock=shared_worker_lock,
                    ))

            # Workers must be created before wait_for_ready to avoid
            # deadlock, since worker.init_device() does a device sync.
            self.workers = WorkerProc.wait_for_ready(unready_workers)

            # Ensure message queues are ready. Will deadlock if re-ordered
            # Must be kept consistent with the WorkerProc.
            self.rpc_broadcast_mq.wait_until_ready()
            for w in self.workers:
                w.worker_response_mq.wait_until_ready()

            self.start_worker_monitor()
            success = True
        finally:
            if not success:
                # Clean up the worker procs if there was a failure.
                # Close death_writers first to signal workers to exit
                for uw in unready_workers:
                    if uw.death_writer is not None:
                        uw.death_writer.close()
                self._ensure_worker_termination(
                    [uw.proc for uw in unready_workers])

        # For pipeline parallel, we use a thread pool for asynchronous
        # execute_model.
        if self.max_concurrent_batches > 1:
            # Note: must use only 1 IO thread to keep dequeue sequence
            # from the response queue
            # _async_aggregate_workers_output also assumes a single IO thread
            self.io_thread_pool = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="mp_exec_io")

        self.output_rank = self._get_output_rank()
        self.has_connector = self.vllm_config.kv_transfer_config is not None


class AscendWorkerProc(WorkerProc):

    @staticmethod
    def make_worker_process(
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        input_shm_handle,  # Receive SchedulerOutput
        shared_worker_lock: LockType,
    ) -> UnreadyWorkerProcHandle:
        context = get_mp_context()
        # (reader, writer)
        reader, writer = context.Pipe(duplex=False)

        # Create death pipe to detect parent process exit
        death_reader, death_writer = context.Pipe(duplex=False)

        process_kwargs = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "input_shm_handle": input_shm_handle,
            "ready_pipe": (reader, writer),
            "death_pipe": death_reader,
            "shared_worker_lock": shared_worker_lock,
        }
        # Run EngineCore busy loop in background process.
        proc = context.Process(
            target=WorkerProc.worker_main,
            kwargs=process_kwargs,
            name=f"VllmWorker-{rank}",
            daemon=False,
        )

        proc.start()
        writer.close()
        # Keep death_writer open in parent - when parent exits,
        # death_reader in child will get EOFError
        return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)


vllm.v1.executor.multiproc_executor.MultiprocExecutor = AscendMultiprocExecutor
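Note: make_worker_process above relies on a "death pipe": the parent keeps the write end open, and a worker blocked on the read end sees EOFError once the parent exits (or closes the writer explicitly, as in the failure path of _init_executor). A minimal standalone sketch of the idea follows; child() and the pipe names here are illustrative, not vLLM APIs, and the spawn context is used so the child only inherits the read end.

import multiprocessing as mp


def child(death_reader):
    try:
        death_reader.recv()      # blocks while the parent's write end is open
    except EOFError:
        print("parent is gone, worker shutting down")


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    death_reader, death_writer = ctx.Pipe(duplex=False)
    proc = ctx.Process(target=child, args=(death_reader,))
    proc.start()
    death_writer.close()         # simulates the parent going away
    proc.join()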
vllm_npu/patch/platform/patch_sched_yield.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import sys

import vllm.distributed.utils
from vllm.platforms import CpuArchEnum, Platform

is_arm = (Platform.get_cpu_architecture() == CpuArchEnum.ARM)

USE_SCHED_YIELD = (
    ((sys.version_info[:3] >= (3, 11, 1)) or
     (sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8))
    and not is_arm)

vllm.distributed.utils.USE_SCHED_YIELD = USE_SCHED_YIELD