from dataclasses import dataclass
from typing import Callable, Optional

import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase


def dispose_tensor(x: torch.Tensor) -> None:
    """Rebind `x` in place to a fresh 0-dim tensor.

    After the call `x` no longer references its previous storage; the new
    placeholder keeps the device and dtype of `x`.
    """
    x.set_(torch.empty([], device=x.device, dtype=x.dtype))


@dataclass
class LayerMetadata:
    """Metadata for a layer.
    """
    # The layer object (`None` until the layer has been registered).
    layer: Optional[LinearBase]
    # The `process_weights_after_loading` method from the quant method.
    post_method: Callable[[torch.nn.Module], None]
    # The weight tensor.
    weight: torch.Tensor
    # The index of the window (-1 until a window has been assigned).
    window_idx: int


@dataclass
class SharedWindowMetadata:
    """Metadata for a shared window.
    """
    # The weight tensor to be shared by layers.
    weight: torch.Tensor
    # The index of the layer this window's weight is equal to.
    data_layer_idx: int
    # The asynchronous broadcast work.
    work: Optional[torch.distributed.Work]


@dataclass
class SeriesMetadata:
    """Metadata for a weight shared series.
    """
    group: GroupCoordinator
    start_layer: int
    end_layer: int
    num_layers: int
    prefetch_step: int
    # Dummy weight to replace the loaded weight matrix. All the layers in
    # the series share the same dummy weight tensor.
    dummy_weight: torch.Tensor
    layers: list[LayerMetadata]
    # Shared windows for prefetching. The window size is
    # (`prefetch_step` + 1), as only the weights for the next
    # (`prefetch_step` + 1) layers need to be stored.
    shared_windows: list[SharedWindowMetadata]
    # The index of the window for the next coming layer.
    window_offset: int

    def is_source(self, layer_idx: int) -> bool:
        """Return whether this rank holds the real weight of `layer_idx`.

        Layer `layer_idx` is owned by the rank whose `rank_in_group` equals
        `layer_idx % world_size`.
        """
        return layer_idx % self.group.world_size == self.group.rank_in_group

    def post_process_after_loading(self) -> None:
        """Finish the initialization of the series after weight loading.

        For every layer, in order: broadcast the true weight from its source
        rank, run the saved `process_weights_after_loading` method, and then
        either keep the result in a shared window (first `prefetch_step`
        layers) or drop it on non-source ranks. Calls after the first one
        return immediately.
        """
        # This method only needs to be called once per series; a non-empty
        # window list means it already ran.
        if self.shared_windows:
            return

        for layer_idx in range(self.start_layer, self.end_layer):
            layer = self.layers[layer_idx - self.start_layer]
            is_source = self.is_source(layer_idx)
            # If the weight uses dummy weight, make a temporary copy such
            # that the post method call won't affect other layers which
            # also use the dummy weight.
            if not is_source:
                layer.weight.set_(torch.empty_like(self.dummy_weight))
            # Broadcast to get the true weight.
            dist.broadcast(layer.weight,
                           src=self.group.ranks[layer_idx %
                                                self.group.world_size],
                           group=self.group.device_group)
            assert layer.layer is not None
            # Call `process_weights_after_loading` from the quant method.
            layer.post_method(layer.layer)

            step = layer_idx - self.start_layer
            if step < self.prefetch_step:
                # Build the windows for the first `prefetch_step` layers.
                # The weights can be used for the first `prefetch_step`
                # layers in `forward()`, so also clone the weights.
                self.shared_windows.append(
                    SharedWindowMetadata(
                        weight=layer.weight.clone().detach(),
                        data_layer_idx=layer_idx,
                        work=None,
                    ))
                layer.window_idx = step
                # When the layer is not intended to be stored on this
                # device, link to the corresponding window's tensor.
                if not is_source:
                    layer.weight.set_(self.shared_windows[-1].weight)
            else:
                # Build one more window for prefetch. The weight is
                # useless, so just keep the shape.
                if step == self.prefetch_step:
                    self.shared_windows.append(
                        SharedWindowMetadata(
                            weight=torch.empty_like(layer.weight),
                            data_layer_idx=-1,
                            work=None,
                        ))
                # When the layer is not intended to be stored on this
                # device, dispose the tensor.
                if not is_source:
                    dispose_tensor(layer.weight)

        dispose_tensor(self.dummy_weight)

    def reach_layer(self, layer_idx: int) -> None:
        """Start the asynchronous prefetch triggered by reaching a layer.

        The layer to prefetch is the `prefetch_step`-th layer after
        `layer_idx`, wrapping around inside [`start_layer`, `end_layer`).

        Args:
            layer_idx: The (absolute) index of the layer being reached.
        """
        # The wrap-around must be computed on the offset inside the series,
        # not on the absolute index; otherwise the result is wrong whenever
        # `start_layer` is not a multiple of `num_layers`.
        next_offset = (layer_idx - self.start_layer +
                       self.prefetch_step) % self.num_layers
        # The index of the layer to be prefetched.
        next_layer_idx = next_offset + self.start_layer
        next_layer = self.layers[next_offset]
        # The index of the window to store the weight for the coming layer.
        next_layer.window_idx = self.window_offset
        window = self.shared_windows[next_layer.window_idx]
        # When the layer is not intended to be stored on this device, link
        # to the corresponding window's tensor.
        if not self.is_source(next_layer_idx):
            next_layer.weight.set_(window.weight)
        # Update `window_offset` by rolling one step.
        self.window_offset = (self.window_offset + 1) % (self.prefetch_step +
                                                         1)
        assert window.data_layer_idx != next_layer_idx
        window.data_layer_idx = next_layer_idx
        # Start asynchronous broadcast work.
        window.work = dist.broadcast(
            next_layer.weight,
            src=self.group.ranks[next_layer_idx % self.group.world_size],
            group=self.group.device_group,
            async_op=True)

    def wait_weight(self, layer_idx: int) -> None:
        """Wait until the weight of `layer_idx` is ready in its window.

        Args:
            layer_idx: The (absolute) index of the layer about to run.
        """
        # Find the asynchronous broadcast work and wait for it.
        assert self.shared_windows
        window = self.shared_windows[self.layers[layer_idx -
                                                 self.start_layer].window_idx]
        # Make sure the data in the corresponding shared window is for the
        # current layer.
        assert window.data_layer_idx == layer_idx
        if window.work is not None:
            window.work.wait()
            window.work = None


@dataclass
class LayerExternalMetadata:
    """External metadata for a layer.
    """
    series: SeriesMetadata
    layer_idx: int


# Series registry, keyed by series name.
_series_dict: dict[str, SeriesMetadata] = {}

# Per-layer lookup, keyed by `id(layer)` of the registered layer object.
_layer_external_dict: dict[int, LayerExternalMetadata] = {}


def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
                            layer_idx: int) -> Callable:
    """Wrap `forward` so that it first waits for the weight of `layer_idx`."""

    def wrapped_forward(*args, **kwargs):
        # Wait for the weight.
        series.wait_weight(layer_idx)
        return forward(*args, **kwargs)

    return wrapped_forward


def register_layer_to_shared_weight_series(
    series_name: str,
    group: GroupCoordinator,
    start_layer: int,
    end_layer: int,
    layer_idx: int,
    layer: LinearBase,
    prefetch_step: int = 1,
):
    """Register linear layers into a shared storage series.

    In a parallel group, each device stores a distinct, non-overlapping
    subset of layers from the series. All layers in a series must have the
    same structure (are isomorphic). The weight matrix for the i-th layer is
    stored on device (i % n), where n is the number of devices.

    After loading the model, you must call
    `post_process_after_loading_for_shared_weight_series(layer)` on any layer
    of this series to complete the initialization.

    During execution, each time a new layer is reached, you must call
    `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch
    the weights. The argument `prefetch_step` is a non-negative integer k
    that manages asynchronous weight prefetching. Each call to
    `reach_layer_for_shared_weight_series(current_layer)` will trigger an
    asynchronous prefetch for the weights of the k-th subsequent layer after
    `current_layer` within the series.

    Note: The layers are managed as a circular buffer. The index of the layer
    to prefetch is determined by the formula:
    - total_layers = end_layer - start_layer
    - prefetch_layer_idx =
        (layer_idx - start_layer + prefetch_step) % total_layers + start_layer

    To hold the weights for the current layer and the k prefetched layers, a
    pool of (k + 1) shared tensor buffers will be created for this series.

    Arguments:
        series_name: This name identifies which series this layer belongs to.
        group: The group coordinator for handling asynchronous
            communications. It is recommended to create a new group
            coordinator for each new series.
        start_layer: The index of the first layer in the series (inclusive).
        end_layer: The index of the last layer in the series (exclusive).
            Thus, the series includes all layers with indices in the range
            [start_layer, end_layer).
        layer_idx: The index of the current layer.
        layer: The linear layer object to register.
        prefetch_step: An integer that manages asynchronous weight
            prefetching. Setting it to 0 or 1 can cover most cases. It must
            satisfy 0 <= prefetch_step <= (end_layer - start_layer) - 2.
    """
    if series_name not in _series_dict:
        # The first registration of a series creates its metadata; the
        # per-layer slots are filled with placeholders until each layer
        # registers itself.
        num_layers = end_layer - start_layer
        assert num_layers > 0, "the series must contain at least one layer"
        assert 0 <= prefetch_step <= num_layers - 2, (
            "prefetch_step must be within [0, num_layers - 2]")
        _series_dict[series_name] = SeriesMetadata(
            group=group,
            start_layer=start_layer,
            end_layer=end_layer,
            num_layers=num_layers,
            prefetch_step=prefetch_step,
            dummy_weight=torch.empty_like(layer.weight),
            layers=[
                LayerMetadata(
                    layer=None,
                    post_method=lambda layer: None,
                    weight=torch.empty([]),
                    window_idx=-1,
                ) for _ in range(num_layers)
            ],
            shared_windows=[],
            window_offset=prefetch_step,
        )
    series = _series_dict[series_name]
    assert layer.quant_method is not None
    series.layers[layer_idx - start_layer] = LayerMetadata(
        layer=layer,
        post_method=layer.quant_method.process_weights_after_loading,
        weight=layer.weight,
        window_idx=-1,
    )
    # Discard the original `process_weights_after_loading` method such that
    # it won't be called by others.
    layer.quant_method.process_weights_after_loading = lambda layer: None
    # When the layer is not intended to be stored on this device, dispose
    # the tensor and skip weight loading.
    if not series.is_source(layer_idx):
        dispose_tensor(layer.weight)
        layer.weight.weight_loader = lambda *args, **kwargs: None
    layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx)
    _layer_external_dict[id(layer)] = LayerExternalMetadata(
        series=series,
        layer_idx=layer_idx,
    )


def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
    """Run the one-time post-loading step of the series `layer` belongs to.

    `layer` must have been registered with
    `register_layer_to_shared_weight_series`; otherwise a `KeyError` is
    raised by the lookup.
    """
    ext = _layer_external_dict[id(layer)]
    ext.series.post_process_after_loading()


def reach_layer_for_shared_weight_series(layer: LinearBase):
    """Trigger the weight prefetch associated with reaching `layer`.

    `layer` must have been registered with
    `register_layer_to_shared_weight_series`; otherwise a `KeyError` is
    raised by the lookup.
    """
    ext = _layer_external_dict[id(layer)]
    ext.series.reach_layer(ext.layer_idx)