import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch.nn as nn
class LayerNameResolver:
    """
    Provides intuitive ways to select layers for tuning in large models.

    This resolver allows users to specify layers using patterns, types, or
    indices rather than requiring exact layer names. Every lookup operates on
    ``self.layer_map``, a mapping from layer name to module.
    """

    def __init__(
        self, model: nn.Module, layer_map: Optional[Dict[str, nn.Module]] = None
    ):
        """
        Initialize the resolver with a model.

        Args:
            model: The neural network model to analyze.
            layer_map: Optional mapping of layer names to modules. A provided
                mapping is stored as-is (not copied). When omitted, the mapping
                is built from ``model.named_modules()``, skipping the unnamed
                root module.
        """
        self.model = model
        self.layer_map: Dict[str, nn.Module]
        if layer_map is not None:
            self.layer_map = layer_map
        elif model is not None:
            # Derive the map from the model so the resolver is usable without an
            # explicit map; the root module is reported under the empty name
            # and is not a selectable layer.
            self.layer_map = {
                name: module for name, module in model.named_modules() if name
            }
        else:
            self.layer_map = {}

    def _check_selectors(
        self, selectors: Dict[str, Union[str, List[str], int, List[int]]]
    ) -> None:
        """
        Validate selector configurations before processing.

        Args:
            selectors: Dictionary with selection criteria.

        Raises:
            ValueError: If conflicting or invalid selectors are detected.
        """
        # 'indices' and 'range' are mutually exclusive ways to pick from matches.
        if "indices" in selectors and "range" in selectors:
            raise ValueError("Cannot use both 'indices' and 'range' selectors together")

        if "range" in selectors:
            range_val = selectors["range"]
            if not isinstance(range_val, list):
                raise ValueError(
                    "'range' must be a list of [start, end] or [start, end, step]"
                )
            if len(range_val) < 2 or len(range_val) > 3:
                raise ValueError(
                    "'range' must contain 2 or 3 elements: [start, end] or [start, end, step]"
                )
            if not all(isinstance(v, int) for v in range_val):
                raise ValueError("All 'range' values must be integers")
            # All elements are ints at this point.
            start: int = range_val[0]  # type: ignore
            end: int = range_val[1]  # type: ignore
            if start < 0:
                raise ValueError("'range' start value must be non-negative")
            if end <= start:
                raise ValueError("'range' end value must be greater than start value")
            # A zero step makes slicing fail and a negative step with
            # start < end silently selects nothing, so require step >= 1.
            if len(range_val) == 3 and range_val[2] <= 0:  # type: ignore
                raise ValueError("'range' step value must be a positive integer")

        if "indices" in selectors:
            indices = selectors["indices"]
            if isinstance(indices, int):
                if indices < 0:
                    raise ValueError("Index values must be non-negative")
            elif isinstance(indices, list):
                if not all(isinstance(idx, int) for idx in indices):
                    raise ValueError("All 'indices' must be integers")
                int_indices: List[int] = [
                    idx for idx in indices if isinstance(idx, int)
                ]
                if any(idx < 0 for idx in int_indices):
                    raise ValueError("All 'indices' must be non-negative")
            else:
                raise ValueError("'indices' must be an integer or list of integers")

    def resolve(
        self,
        selectors: Union[
            str, List[str], Dict[str, Union[str, List[str], int, List[int]]]
        ],
    ) -> List[str]:
        """
        Resolve layer name selectors to actual layer names in the model.

        Args:
            selectors: One or more selectors to match layers. Can be:
                - A single string pattern (e.g., "encoder.*attention")
                - A list of string patterns (results are combined, duplicates
                  removed, in first-match order)
                - A dictionary with keys (all given filters must match):
                    - 'pattern': String or list of regex patterns (any may match)
                    - 'type': Layer class name or list of names (e.g., 'Linear')
                    - 'contains': String or list of substrings (any may match)
                    - 'indices': Index or indices into the name-sorted matches
                      (e.g., [0, 2, 4] for first, third, fifth)
                    - 'range': [start, end] or [start, end, step] slice of the
                      name-sorted matches; step must be positive

        Returns:
            List of resolved layer names that match the selectors. For dict
            selectors, an empty list when nothing matches.

        Raises:
            ValueError: If the selectors are of an unsupported type, are
                invalid/conflicting, or index past the matched layers.
        """
        if isinstance(selectors, str):
            return self._resolve_pattern(selectors)
        if isinstance(selectors, list) and all(isinstance(s, str) for s in selectors):
            matched_names: List[str] = []
            for selector in selectors:
                matched_names.extend(self._resolve_pattern(selector))
            # dict.fromkeys drops duplicates while keeping first-match order,
            # unlike a set whose iteration order is arbitrary.
            return list(dict.fromkeys(matched_names))
        if isinstance(selectors, dict):
            return self._resolve_dict(selectors)
        raise ValueError(
            "Selectors must be a string, list of strings, or a dictionary with selection criteria"
        )

    @staticmethod
    def _as_str_list(value: Any) -> List[str]:
        """Normalize a str-or-list selector value to a list of strings.

        A string becomes a one-element list, non-string list elements are
        dropped, and any other value yields an empty list.
        """
        if isinstance(value, str):
            return [value]
        if isinstance(value, list):
            return [v for v in value if isinstance(v, str)]
        return []

    def _resolve_dict(
        self, selectors: Dict[str, Union[str, List[str], int, List[int]]]
    ) -> List[str]:
        """Apply dictionary selectors: filter names, sort them, then pick.

        Args:
            selectors: Dictionary with selection criteria (see ``resolve``).

        Returns:
            Matching layer names, sorted by name unless narrowed by 'indices'
            (which returns them in the requested index order). Empty if no
            layer passes the filters.

        Raises:
            ValueError: If the selectors are invalid or out of bounds.
        """
        self._check_selectors(selectors)
        matched_layers = set(self.layer_map.keys())

        # Patterns are OR-ed together; the result is AND-ed with other filters.
        if "pattern" in selectors:
            pattern_matches = set()
            for pattern in self._as_str_list(selectors["pattern"]):
                pattern_matches.update(self._resolve_pattern(pattern))
            matched_layers &= pattern_matches

        if "contains" in selectors:
            needles = self._as_str_list(selectors["contains"])
            matched_layers = {
                name for name in matched_layers if any(c in name for c in needles)
            }

        if "type" in selectors:
            types = self._as_str_list(selectors["type"])
            matched_layers = {
                name
                for name in matched_layers
                if type(self.layer_map[name]).__name__ in types
            }

        # Sort so that 'indices' and 'range' refer to a deterministic order.
        matched_list = sorted(matched_layers)
        if not matched_list:
            return []

        if "indices" in selectors:
            indices_val = selectors["indices"]
            # _check_selectors guarantees an int or a list of non-negative ints.
            if isinstance(indices_val, int):
                indices = [indices_val]
            else:
                indices = [i for i in indices_val if isinstance(i, int)]
            if indices and max(indices) >= len(matched_list):
                raise ValueError(
                    f"Index {max(indices)} is out of bounds for {len(matched_list)} matched layers"
                )
            return [matched_list[i] for i in indices]

        if "range" in selectors:
            # _check_selectors guarantees 2-3 ints with 0 <= start < end and
            # a positive step.
            range_vals: List[int] = selectors["range"]  # type: ignore
            start, end = range_vals[0], range_vals[1]
            step = range_vals[2] if len(range_vals) > 2 else 1
            if end > len(matched_list):
                raise ValueError(
                    f"Range end {end} exceeds the number of matched layers ({len(matched_list)})"
                )
            return matched_list[start:end:step]

        return matched_list

    def _resolve_pattern(self, pattern: str) -> List[str]:
        """
        Resolve a regex pattern to matching layer names.

        Args:
            pattern: Regex pattern searched anywhere in each layer name. If it
                is not a valid regex, plain substring matching is used instead.

        Returns:
            List of layer names that match the pattern, in layer_map order.
        """
        try:
            regex = re.compile(pattern)
        except re.error:
            # Not a valid regex: fall back to simple substring matching.
            return [name for name in self.layer_map if pattern in name]
        return [name for name in self.layer_map if regex.search(name)]

    def _get_layers_by_depth(self, depth: int) -> List[str]:
        """
        Get layer names at a specific depth in the model hierarchy.

        Depth is determined by the number of '.' in the layer name.
        For example, 'model.encoder.layer0' has depth 2.

        Args:
            depth: The depth level to retrieve layers from.

        Returns:
            List of layer names at the specified depth.
        """
        return [name for name in self.layer_map if name.count(".") == depth]

    def _get_layer_types(self) -> Dict[str, List[str]]:
        """
        Get a mapping of layer types to layer names.

        Returns:
            Dictionary where keys are layer class names (e.g., 'Linear',
            'Conv2d') and values are lists of layer names of that type.
        """
        type_map: Dict[str, List[str]] = {}
        for name, layer in self.layer_map.items():
            type_map.setdefault(type(layer).__name__, []).append(name)
        return type_map

    def _get_layers_by_parent(self, parent_name: str) -> List[str]:
        """
        Get all direct child layers of a parent layer.

        Args:
            parent_name: Name of the parent layer, with or without a trailing
                '.'. An empty name selects the top-level layers.

        Returns:
            List of layer names that are direct children of the parent.
        """
        if not parent_name:
            # Root parent: direct children are the names without any '.'.
            return [name for name in self.layer_map if name and "." not in name]
        prefix = parent_name if parent_name.endswith(".") else parent_name + "."
        return [
            name
            for name in self.layer_map
            if name.startswith(prefix) and "." not in name[len(prefix) :]
        ]

    def _filter_layers_by_attribute(
        self, attribute_name: str, attribute_value: Any
    ) -> List[str]:
        """
        Filter layers based on a specific attribute value.

        Args:
            attribute_name: Name of the attribute to check.
            attribute_value: Expected value of the attribute (compared with ==).

        Returns:
            List of layer names that have the attribute with a matching value.
        """
        matched_layers = []
        for name, layer in self.layer_map.items():
            if hasattr(layer, attribute_name):
                if getattr(layer, attribute_name) == attribute_value:
                    matched_layers.append(name)
        return matched_layers

    def _filter_layers_by_function(
        self, filter_fn: Callable[[nn.Module], bool]
    ) -> List[str]:
        """
        Filter layers using a custom function.

        Args:
            filter_fn: Function that takes a layer and returns True if it
                should be included.

        Returns:
            List of layer names that pass the filter function.
        """
        return [name for name, layer in self.layer_map.items() if filter_fn(layer)]

    def _get_layer_parameter_count(self) -> Dict[str, int]:
        """
        Count parameters for each layer in the model.

        ``Module.parameters()`` recurses into submodules by default, so a
        container layer's count includes its children's parameters.

        Returns:
            Dictionary mapping layer names to their parameter counts.
        """
        return {
            name: sum(p.numel() for p in layer.parameters())
            for name, layer in self.layer_map.items()
        }

    def _get_top_n_layers_by_parameters(self, n: int = 10) -> List[Tuple[str, int]]:
        """
        Get the top N layers with the most parameters.

        Args:
            n: Number of top layers to return.

        Returns:
            List of (layer_name, parameter_count) tuples sorted by parameter
            count, largest first; ties keep layer_map order.
        """
        param_counts = self._get_layer_parameter_count()
        sorted_layers = sorted(param_counts.items(), key=lambda x: x[1], reverse=True)
        return sorted_layers[:n]

    def _find_common_prefix(self) -> str:
        """
        Find the common prefix shared by all layer names.

        The prefix is computed character by character, so it may end in the
        middle of a name component. This is useful for identifying the root
        module name.

        Returns:
            The common prefix string ("" if there are no layers).
        """
        if not self.layer_map:
            return ""
        names = list(self.layer_map.keys())
        prefix = names[0]
        for name in names[1:]:
            # Shrink the candidate until it prefixes the current name.
            while not name.startswith(prefix):
                prefix = prefix[:-1]
                if not prefix:
                    return ""
        return prefix