from dataclasses import dataclass
from typing import Any, Optional

import torch
import torch.nn.functional as F

from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context


@dataclass
class AscendCommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both GPU and CPU versions.
    """

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    seq_lens_cpu: torch.Tensor
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""

    seq_lens: torch.Tensor
    """Same as seq_lens_cpu, kept for compatibility with newer attention
    metadata (such as GDN)."""

    num_computed_tokens_cpu: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""

    num_reqs: int
    """Number of requests"""

    num_actual_tokens: int
    """Total number of tokens in batch"""

    max_query_len: int
    """Longest query length among the requests in the batch"""

    decode_token_per_req: int
    """Number of decode tokens per request"""

    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor
    actual_seq_lengths_q: list[int]

    positions: Optional[torch.Tensor] = None
    attn_mask: Optional[torch.Tensor] = None
    spec_attn_mask: Optional[torch.Tensor] = None
    attn_state: Any = None
    enable_dbo_across_dp: bool = False
    is_only_prefill: bool = False
    graph_pad_size: int = -1

    # num_input_tokens refers to the total number of tokens including
    # padding tokens. It is used to handle some padding operations.
    num_input_tokens: int = 0

    # NOTE: This is a temporary solution for rotary embedding in MLA.
    cos: Optional[torch.Tensor] = None
    sin: Optional[torch.Tensor] = None
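# Illustrative example (hypothetical values, not part of the module): for a
# reordered batch of 3 requests -- two single-token decodes followed by one
# 4-token prefill -- the shared metadata would look like:
#
#   query_start_loc   = [0, 1, 2, 6]   # cumulative query lengths, (3 + 1,)
#   seq_lens          = [17, 9, 4]     # computed + newly scheduled tokens
#   num_reqs          = 3
#   num_actual_tokens = 6              # 1 + 1 + 4
#   max_query_len     = 4              # the prefill request
#
# split_decodes_and_prefills below walks query_start_loc to find the
# decode/prefill boundary of such a batch.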
""" max_query_len = common_attn_metadata.max_query_len num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu if max_query_len <= decode_threshold: return num_reqs, 0, num_tokens, 0 query_lens = query_start_loc[1:] - query_start_loc[:-1] is_prefill = query_lens > decode_threshold if not torch.any(is_prefill): return num_reqs, 0, num_tokens, 0 first_prefill = is_prefill.int().argmax(dim=-1).item() num_decodes = first_prefill num_prefills = num_reqs - num_decodes num_decode_tokens = query_start_loc[first_prefill].item() num_prefill_tokens = num_tokens - num_decode_tokens return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) def wait_for_kv_layer_from_connector(layer_name: str): if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): return connector = get_kv_transfer_group() forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata if attn_metadata is None: return # TODO: assert ascendMetadata connector.wait_for_layer_load(layer_name) def maybe_save_kv_layer_to_connector( layer_name: str, kv_cache_layer: List[torch.Tensor], ): if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): return connector = get_kv_transfer_group() forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata if attn_metadata is None: return # TODO: assert ascendMetadata connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata) def round_up(val: int, align: int) -> int: if align == 0: return 0 return -(val // -align) * align def trans_rope_weight(weight, rope_dim): if rope_dim == 0: return weight.contiguous() nope_part = weight[..., :-rope_dim, :] rope_part = weight[..., -rope_dim:, :] reordered_rope_part = torch.cat( (rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2) return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous() def transdata(nd_mat, block_size: tuple = (16, 16)): r = round_up(nd_mat.shape[0], block_size[0]) c = round_up(nd_mat.shape[1], block_size[1]) r_pad = r - nd_mat.shape[0] c_pad = c - nd_mat.shape[1] nd_mat = F.pad(nd_mat, (0, r_pad, 0, c_pad)) nz_mat = torch.permute( torch.reshape( nd_mat, (r // block_size[0], block_size[0], c // block_size[1], block_size[1]), ), [2, 0, 1, 3], ) nz_mat = torch.reshape( nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])) return nz_mat