# mypy: ignore-errors

import vllm.model_executor.models.config

from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec


@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
    """
    Ensure that the page size of the attention layers is greater than or
    equal to that of the mamba layers. If not, automatically raise the
    attention block size to make it so. If the attention page size is
    strictly greater than the mamba page size, pad the mamba page size so
    that the two become exactly equal.

    Args:
        vllm_config: vLLM Config
    """
    logger = init_logger(__name__)

    # Enable FULL_AND_PIECEWISE by default
    MambaModelConfig.verify_and_update_config(vllm_config)

    cache_config = vllm_config.cache_config
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config

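    # cache_config.cache_dtype is either "auto" (follow the model's dtype)
    # or a string naming the KV-cache dtype such as "fp8";
    # STR_DTYPE_TO_TORCH_DTYPE maps that string to the torch dtype used
    # for the page-size computations below.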
    if cache_config.cache_dtype == "auto":
        kv_cache_dtype = model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    # get attention page size (for 1 token)
    attn_page_size_1_token = FullAttentionSpec(
        block_size=1,
        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
        head_size=model_config.get_head_size(),
        dtype=kv_cache_dtype).page_size_bytes

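    # The spec above is built with block_size=1 purely for sizing:
    # page_size_bytes is then the per-token KV-cache footprint of one
    # attention layer (roughly 2 [K and V] * num_kv_heads * head_size *
    # dtype width; the exact formula is whatever FullAttentionSpec uses).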
    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
        model_config=model_config,
    )

    # get mamba page size
    mamba_page_size = MambaSpec(
        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
        block_size=model_config.max_model_len,
    ).page_size_bytes

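    # The state shapes and dtypes above come from hooks on the resolved
    # model class. Unlike attention, the mamba state is a fixed-size
    # per-sequence buffer, so block_size is set to max_model_len and
    # page_size_bytes is effectively the size of one sequence's full state.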
    block_alignment_bytes = 128

    # Some attention backends (e.g. FA) only support setting the block
    # size to a multiple of 16, so suggest a value that works; the
    # 128-token alignment used here satisfies that (note: FA is currently
    # not compatible with mamba layers, use FlashInfer instead).
    attn_block_size = block_alignment_bytes * cdiv(
        mamba_page_size, block_alignment_bytes * attn_page_size_1_token)

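    # Illustrative arithmetic (hypothetical sizes, not from any real model):
    # with attn_page_size_1_token = 16_384 bytes and
    # mamba_page_size = 40_000_000 bytes,
    #   cdiv(40_000_000, 128 * 16_384) = cdiv(40_000_000, 2_097_152) = 20
    # so attn_block_size = 128 * 20 = 2560 tokens, and the attention page is
    # 2560 * 16_384 = 41_943_040 bytes >= mamba_page_size. Since cdiv is
    # ceiling division, this is the smallest 128-token-aligned block whose
    # page covers the mamba state; the remaining ~4.86% gap is what the
    # mamba padding further below absorbs.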
    # override attention block size if either (a) the
    # user has not set it or (b) the user has set it
    # too small.
    if (cache_config.block_size is None
            or cache_config.block_size < attn_block_size):
        cache_config.block_size = attn_block_size
        logger.info(
            "Setting attention block size to %d tokens "
            "to ensure that attention page size is >= mamba page size.",
            attn_block_size)

    # compute new attention page size
    attn_page_size = cache_config.block_size * attn_page_size_1_token

    assert attn_page_size >= mamba_page_size

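    # At this point attn_page_size is guaranteed to cover the mamba page:
    # either the user's block_size was already large enough, or it was
    # raised to the aligned minimum computed above.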
    if attn_page_size == mamba_page_size:
        # don't need to pad mamba page size
        return

    # pad mamba page size to exactly match attention
    if (cache_config.mamba_page_size_padded is None
            or cache_config.mamba_page_size_padded != attn_page_size):
        cache_config.mamba_page_size_padded = attn_page_size
        mamba_padding_pct = 100 * (attn_page_size -
                                   mamba_page_size) / mamba_page_size
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.", mamba_padding_pct)


vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
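# The assignment above monkey-patches the upstream class, so the override
# takes effect as soon as this module is imported. A minimal usage sketch
# (the module path below is hypothetical; use wherever this file lives in
# the plugin):
#
#   import vllm_npu_plugin.hybrid_config_patch  # noqa: F401  (hypothetical)
#
#   # From here on, HybridAttentionMambaModelConfig.verify_and_update_config
#   # dispatches to the patched function defined above.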