#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from collections.abc import Iterable
from typing import Any, List, Optional, Union

import torch
import torch.nn.functional as F
import vllm
import vllm.envs as envs
from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionMetadata, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, tensor_model_parallel_all_gather,
                              tensor_model_parallel_reduce_scatter)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Attention  # noqa: F401
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM  # noqa: F401
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Model
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)
from vllm.sequence import IntermediateTensors

from vllm_npu.ascend_config import get_ascend_config
from vllm_npu.attention.attention_v1 import AscendAttentionState


def all_gather_and_maybe_unpad(
    hidden_states: torch.Tensor,
    pad_size: int,
) -> torch.Tensor:
    """All-gather ``hidden_states`` along dim 0 and strip trailing padding.

    Args:
        hidden_states: Local tensor handed to
            ``tensor_model_parallel_all_gather`` with dimension ``0``.
        pad_size: Number of trailing rows to drop from the gathered result.
            Nothing is dropped when it is not positive.

    Returns:
        The gathered tensor, without its last ``pad_size`` rows when
        ``pad_size > 0``.
    """
    gathered = tensor_model_parallel_all_gather(hidden_states, 0)
    # The slice indexes two dimensions, so the gathered tensor is expected
    # to have at least two of them.
    return gathered[:-pad_size, :] if pad_size > 0 else gathered


def maybe_pad_and_reduce_scatter(
    hidden_states: torch.Tensor,
    pad_size: int,
) -> torch.Tensor:
    """Optionally pad ``hidden_states`` and reduce-scatter it along dim 0.

    Args:
        hidden_states: Tensor to reduce-scatter.
        pad_size: Amount of padding appended at the end of the
            second-to-last dimension (dim 0 for a 2-D tensor) before the
            collective. No padding is applied when it is not positive.

    Returns:
        The result of ``tensor_model_parallel_reduce_scatter`` on the
        (possibly padded) tensor with dimension ``0``.
    """
    to_scatter = hidden_states
    if pad_size > 0:
        # Pad spec is (last_dim_left, last_dim_right, prev_left, prev_right).
        to_scatter = F.pad(hidden_states, (0, 0, 0, pad_size))
    return tensor_model_parallel_reduce_scatter(to_scatter, 0)


class CustomQwen2Attention(Qwen2Attention):
    """``Qwen2Attention`` variant with a dedicated torchair decode path.

    Construction is delegated to ``Qwen2Attention``; in addition the
    torchair graph switch is read once from the Ascend config and cached on
    the instance as ``torchair_graph_enabled``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        rope_scaling: Optional[tuple] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
    ) -> None:
        """Forward every argument to ``Qwen2Attention`` and cache the flag."""
        super().__init__(
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            max_position=max_position,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=prefix,
            attn_type=attn_type,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.torchair_graph_enabled = (
            get_ascend_config().torchair_graph_config.enabled)

    def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            kv_cache: Optional[torch.Tensor] = None,
            attn_metadata: Optional[AttentionMetadata] = None
    ) -> torch.Tensor:
        """Run QKV projection, rotary embedding, attention and out-projection.

        Two paths exist:

        * torchair decode path, taken when torchair graph mode is enabled
          and ``attn_metadata`` reports ``AscendAttentionState.DecodeOnly``:
          the attention backend implementation is invoked directly with the
          explicit ``kv_cache`` and ``attn_metadata``.
        * default path: the regular ``self.attn`` layer is called.

        Args:
            positions: Positions passed to the rotary embedding.
            hidden_states: Input to ``qkv_proj``.
            kv_cache: KV cache for this layer; only used on the torchair
                decode path.
            attn_metadata: Attention metadata; ``None`` selects the default
                path.

        Returns:
            The first element of the ``o_proj`` output.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        query, key, value = fused_qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1)

        torchair_decode = (
            self.torchair_graph_enabled and attn_metadata is not None
            and attn_metadata.attn_state == AscendAttentionState.DecodeOnly)

        if torchair_decode:
            query, key = self.rotary_emb(positions,
                                         query,
                                         key,
                                         is_prefill=False,
                                         is_qwen_torchair=True)
            forward_kwargs = {}
            if envs.VLLM_USE_V1:
                # Under V1 an output buffer shaped like the query is
                # pre-allocated and handed to the backend.
                forward_kwargs['output'] = torch.empty(query.shape,
                                                       dtype=query.dtype,
                                                       device=query.device)
            attn_output = self.attn.impl.forward(self.attn,
                                                 query,
                                                 key,
                                                 value,
                                                 kv_cache=kv_cache,
                                                 attn_metadata=attn_metadata,
                                                 trace_flag=False,
                                                 **forward_kwargs)
        else:
            # Exact type match on purpose: only the base RotaryEmbedding
            # class receives `is_qwen_torchair`; every other class
            # (subclasses included) is called with the plain signature.
            if type(self.rotary_emb) is RotaryEmbedding:
                query, key = self.rotary_emb(positions,
                                             query,
                                             key,
                                             is_qwen_torchair=True)
            else:
                query, key = self.rotary_emb(positions, query, key)
            attn_output = self.attn(query, key, value)

        projected, _ = self.o_proj(attn_output)
        return projected


class CustomQwen2DecoderLayer(nn.Module):
    """Qwen2 decoder layer built around ``CustomQwen2Attention``.

    Holds the attention block, a ``Qwen2MLP`` and the two ``RMSNorm``
    layers (pre-attention and post-attention).
    """

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the sub-modules from the HF ``config``.

        Args:
            config: Hugging Face Qwen2 configuration.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention block and the MLP.
            prefix: Module-name prefix; ``.self_attn`` / ``.mlp`` are
                appended for the respective sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        dual_chunk_attention_config = getattr(config,
                                              "dual_chunk_attention_config",
                                              None)

        # By default, Qwen2 uses causal attention as it is a decoder-only
        # model. You can override the HF config with `is_causal=False` to
        # enable bidirectional attention, which is used in some embedding
        # models (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
        attn_type = (AttentionType.DECODER if getattr(
            config, "is_causal", True) else AttentionType.ENCODER_ONLY)

        self.self_attn = CustomQwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.mlp = Qwen2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        kv_cache: Optional[torch.Tensor] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and MLP, threading the residual through.

        Args:
            positions: Positions forwarded to the attention block.
            hidden_states: Layer input.
            residual: Residual from the previous layer, or ``None`` when no
                residual exists yet (the input itself then becomes the
                residual).
            kv_cache: KV cache forwarded to the attention block.
            attn_metadata: Attention metadata forwarded to the attention
                block.

        Returns:
            ``(hidden_states, residual)`` after the MLP.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        return self.mlp(hidden_states), residual


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for
        # qwen2-vl, otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen2Model(Qwen2Model):
    """``Qwen2Model`` whose decoder layers default to the custom layer.

    The forward pass additionally hands a per-layer KV cache and the
    attention metadata to every decoder layer.
    """

    def __init__(
            self,
            *,
            vllm_config: VllmConfig,
            prefix: str = "",
            decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
        """Delegate construction to ``Qwen2Model``."""
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=decoder_layer_type)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder layers owned by this pipeline-parallel rank.

        Args:
            input_ids: Token ids; embedded on the first rank when
                ``inputs_embeds`` is not supplied.
            positions: Positions forwarded to every layer.
            kv_caches: Optional list of KV caches, indexed relative to
                ``self.start_layer``.
            attn_metadata: Attention metadata forwarded to every layer.
            intermediate_tensors: ``hidden_states`` / ``residual`` from the
                previous rank; must be given on non-first ranks.
            inputs_embeds: Pre-computed embeddings that take precedence
                over ``input_ids`` on the first rank.

        Returns:
            The normalised hidden states on the last rank, otherwise an
            ``IntermediateTensors`` carrying ``hidden_states`` and
            ``residual``.
        """
        if get_pp_group().is_first_rank:
            residual = None
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer_idx in range(self.start_layer, self.end_layer):
            if kv_caches is None:
                layer_kv_cache = None
            else:
                # kv_caches only covers this rank's layers, hence the offset.
                layer_kv_cache = kv_caches[layer_idx - self.start_layer]
            hidden_states, residual = self.layers[layer_idx](
                positions,
                hidden_states,
                residual,
                kv_cache=layer_kv_cache,
                attn_metadata=attn_metadata)

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Causal-LM wrapper whose backbone is ``CustomQwen2Model``."""

    # add `CustomQwen2Model` to init self.model

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the LM head and the logits processor.

        Args:
            vllm_config: Source of the HF config, quantisation config and
                LoRA config.
            prefix: Module-name prefix; ``model`` / ``lm_head`` are joined
                to it through ``maybe_prefix``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.lora_config = vllm_config.lora_config
        self.quant_config = quant_config

        self.model = CustomQwen2Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if not get_pp_group().is_last_rank:
            # Ranks other than the last one get a placeholder head.
            self.lm_head = PPMissingLayer()
        elif config.tie_word_embeddings:
            # Tied weights: the embedding module doubles as the head.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size,
                                          quant_config=quant_config,
                                          prefix=maybe_prefix(
                                              prefix, "lm_head"))

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Call the backbone with all arguments passed positionally."""
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata=None,  # type: ignore
    ) -> Optional[torch.Tensor]:
        """Run the logits processor with ``lm_head`` on ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``weights`` through ``AutoWeightsLoader``.

        Entries prefixed with ``lm_head.`` are skipped when the config ties
        the word embeddings.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated
            as a set of names).
        """
        skip_prefixes = (["lm_head."]
                         if self.config.tie_word_embeddings else None)
        return AutoWeightsLoader(
            self,
            skip_prefixes=skip_prefixes,
        ).load_weights(weights)


# Rebind the `Qwen2ForCausalLM` attribute of vLLM's qwen2 module to the custom
# class defined above.
vllm.model_executor.models.qwen2.Qwen2ForCausalLM = CustomQwen2ForCausalLM