From 1cb99d1f79006ffe07b6656694dad40a8eb15a46 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Jun 2025 00:37:17 +0000 Subject: [PATCH 1/9] Initial plan for issue From 8d59084438e0d9aad2a910df0b6c255986bf0f03 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Jun 2025 00:40:49 +0000 Subject: [PATCH 2/9] Initial analysis and plan for disentangling merging from library transforms Co-authored-by: sordonia <1689148+sordonia@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e77b18947..9cb1d8d6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "mttl" version = "0.0.1" description = "Multi-Task Transfer Learning with Adapters" readme = "README.md" -requires-python = ">=3.9, <3.12" # ray don't support 3.12 +requires-python = ">=3.9, <3.13" # ray don't support 3.12 dynamic = ["dependencies"] classifiers = [ From d194abf30fa1974fc3b062d06861398a1e1311b9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Jun 2025 00:45:03 +0000 Subject: [PATCH 3/9] Disentangle merging routines from library transforms Co-authored-by: sordonia <1689148+sordonia@users.noreply.github.com> --- mttl/models/library/__init__.py | 1 + mttl/models/library/library_transforms.py | 165 +--------------- mttl/models/library/merging.py | 226 ++++++++++++++++++++++ 3 files changed, 236 insertions(+), 156 deletions(-) create mode 100644 mttl/models/library/merging.py diff --git a/mttl/models/library/__init__.py b/mttl/models/library/__init__.py index e69de29bb..5ce771bc1 100644 --- a/mttl/models/library/__init__.py +++ b/mttl/models/library/__init__.py @@ -0,0 +1 @@ +from .merging import ties_merge, weighted_linear_merge, wudi_merge diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index bb601f44c..cd2a06145 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -222,56 +222,15 @@ def __init__(self, config: WudiMergeConfig = None): @torch.no_grad() def transform(self, library) -> Expert: - device = "cuda" if torch.cuda.is_available() else "cpu" + from mttl.models.library.merging import wudi_merge + if type(library) == str: library = ExpertLibrary.get_expert_library(library) expert_names = list(library.keys()) experts = [library[name] for name in expert_names] - logger.info("Merging {} experts using WuDi merge".format(len(experts))) - - base_expert = copy.deepcopy(experts[0]) - base_expert.name = "wudi_merged_expert" - - # Get all parameter keys that we want to merge - keys = [key for key in base_expert.expert_weights.keys()] - - for key in keys: - # Stack all expert weights for this parameter - values = torch.stack([expert.expert_weights[key] for expert in experts]) - - values = values.to(device) - - # Initialize merged vector as sum of all vectors - merging_vector = torch.nn.Parameter(torch.sum(values, dim=0)) - optimizer = torch.optim.Adam( - [merging_vector], lr=self.config.lr, weight_decay=0 - ) - - # Compute L2 norms - l2_norms = torch.square( - torch.norm(values.reshape(values.shape[0], -1), p=2, dim=-1) - ) - - # Optimize merging vector - for _ in tqdm(range(self.config.iter), desc=f"Optimizing parameter {key}"): - disturbing_vectors = merging_vector.unsqueeze(0) - values - inner_product = torch.matmul(disturbing_vectors, 
values.transpose(1, 2)) - - loss = torch.sum( - torch.square(inner_product) / l2_norms.unsqueeze(-1).unsqueeze(-1) - ) - loss = loss.requires_grad_(True) # Ensure loss requires gradients - optimizer.zero_grad() - loss.backward() - optimizer.step() - - merging_vector = merging_vector / len(experts) - # Update base expert weights with optimized merging vector - base_expert.expert_weights[key].data.copy_(merging_vector.data.cpu()) - - return base_expert + return wudi_merge(experts, self.config) @dataclass @@ -290,55 +249,15 @@ def __init__(self, config: WeightedLinearMergeConfig = None): @torch.no_grad() def transform(self, library) -> Expert: + from mttl.models.library.merging import weighted_linear_merge + if type(library) == str: library = ExpertLibrary.get_expert_library(library) expert_names = list(library.keys()) experts = [library[name] for name in expert_names] - logger.info("Averaging {} experts".format(len(experts))) - - base_expert = copy.deepcopy(experts[0]) - base_expert.name = "weighted_expert" - - if self.config.weights is not None: - assert set(self.config.weights.keys()) == set( - expert_names - ), "Weights must have the same keys as the experts" - if not (1 - 1e-6) <= sum(self.config.weights.values()) <= (1 + 1e-6): - logger.warning( - "Weights do not sum to 1.0, please make sure this is intended" - ) - - # scale the base expert - for k, v in base_expert.expert_weights.items(): - base_expert.expert_weights[k] *= self.config.weights[expert_names[0]] - - for _, expert in zip(expert_names[1:], experts[1:]): - # Validate that the expert is compatible - assert type(expert.expert_info.expert_config) == type( - base_expert.expert_info.expert_config - ), "Expert configs must be the same type" - assert set(expert.expert_weights.keys()) == set( - base_expert.expert_weights.keys() - ), "Expert weights must have the same keys" - - weight = 1.0 - if self.config.weights is not None: - weight = self.config.weights[expert.expert_info.expert_name] - - for k, v in expert.expert_weights.items(): - base_expert.expert_weights[k] += v * weight - - # Normalize the final expert - if self.config.weights is None: - for k, v in base_expert.expert_weights.items(): - base_expert.expert_weights[k] /= len(experts) - - # manually change the config of the expert to remove the tie_params - base_expert.expert_config.tie_params = None - - return base_expert + return weighted_linear_merge(experts, self.config) @dataclass @@ -360,81 +279,15 @@ def __init__(self, config: TiesMergeConfig = None): @torch.no_grad() def transform(self, library) -> Expert: + from mttl.models.library.merging import ties_merge + if type(library) == str: library = ExpertLibrary.get_expert_library(library) expert_names = list(library.keys()) experts = [library[name] for name in expert_names] - logger.info("Averaging {} experts".format(len(experts))) - - base_expert = copy.deepcopy(experts[0]) - base_expert.name = "ties_weighted_expert" - - state_dict_keys = list(base_expert.expert_weights.keys()) - - # Build n_tasks x D experts - # TODO: No need to build this matrix, can be done 1 expert at a time - expert_vectors = [] - for expert in experts: - expert_vectors += [ - torch.nn.utils.parameters_to_vector( - list(expert.expert_weights[k] for k in state_dict_keys) - ) - ] - - expert_vectors = torch.stack(expert_vectors, dim=0) - per_exp_th = expert_vectors.abs().quantile(1.0 - self.config.top_k, dim=1) - keep_param = expert_vectors.abs() >= per_exp_th.view(-1, 1) - - mean_valid_per_task = keep_param.float().mean(1) - assert 
torch.all((mean_valid_per_task - self.config.top_k).abs() < 1e-4) - - used, kept, total = 0, 0, 0 - - for param_name in state_dict_keys: - # stack the expert weights - expert_weights = torch.stack( - [expert.expert_weights[param_name] for expert in experts], dim=0 - ) - - # keep weights over the threshold - TH = per_exp_th.view(-1, *((1,) * (expert_weights.ndim - 1))) - keep_mask = expert_weights.abs() >= TH - expert_weights = expert_weights * keep_mask - - if self.config.only_sparsify: - final_param = expert_weights.mean(0) - used += keep_mask.sum().item() - else: - # sign majority vote - sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign() - sign_per_dim = expert_weights.sum(0, keepdim=True).sign() - - # keep only weights whose sign agree with the majority - use_for_avg = expert_weights.sign() == sign_per_dim - - deno = use_for_avg.sum(0).clamp(min=1.0) - sum_param = (expert_weights * use_for_avg).sum(0) - final_param = sum_param / deno - used += (use_for_avg & (sign_per_dim != 0.0)).sum().item() - - kept += (expert_weights.abs() > TH).sum() - total += expert_weights.numel() - - base_expert.expert_weights[param_name].data.copy_(final_param) - - logger.info( - "Params not reset to 0 in TIES merge: {:.10f}%".format(100.0 * kept / total) - ) - logger.info( - "Params used to compute TIES mean: {:.10f}%".format(100.0 * used / total) - ) - - # manually change the config of the expert to remove the tie_params - base_expert.expert_config.tie_params = None - - return base_expert + return ties_merge(experts, self.config) @dataclass diff --git a/mttl/models/library/merging.py b/mttl/models/library/merging.py new file mode 100644 index 000000000..b41e218b3 --- /dev/null +++ b/mttl/models/library/merging.py @@ -0,0 +1,226 @@ +""" +Standalone merging functions for Expert objects. + +This module provides standalone merging routines that take a list of Expert objects +as input and return a merged Expert. These functions are decoupled from library +transforms and can be used independently. +""" + +import copy +from typing import List + +import torch +from tqdm.auto import tqdm + +from mttl.logging import logger +from mttl.models.library.expert import Expert +from mttl.models.library.library_transforms import ( + TiesMergeConfig, + WeightedLinearMergeConfig, + WudiMergeConfig, +) + + +def wudi_merge(experts: List[Expert], config: WudiMergeConfig) -> Expert: + """ + Merge experts using WuDi merge algorithm. 
+ + Args: + experts: List of Expert objects to merge + config: WudiMergeConfig containing merge parameters + + Returns: + Expert: Merged expert + """ + if not experts: + raise ValueError("Cannot merge empty list of experts") + + device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info("Merging {} experts using WuDi merge".format(len(experts))) + + base_expert = copy.deepcopy(experts[0]) + base_expert.name = "wudi_merged_expert" + + # Get all parameter keys that we want to merge + keys = [key for key in base_expert.expert_weights.keys()] + + for key in keys: + # Stack all expert weights for this parameter + values = torch.stack([expert.expert_weights[key] for expert in experts]) + + values = values.to(device) + + # Initialize merged vector as sum of all vectors + merging_vector = torch.nn.Parameter(torch.sum(values, dim=0)) + optimizer = torch.optim.Adam( + [merging_vector], lr=config.lr, weight_decay=0 + ) + + # Compute L2 norms + l2_norms = torch.square( + torch.norm(values.reshape(values.shape[0], -1), p=2, dim=-1) + ) + + # Optimize merging vector + for _ in tqdm(range(config.iter), desc=f"Optimizing parameter {key}"): + disturbing_vectors = merging_vector.unsqueeze(0) - values + inner_product = torch.matmul(disturbing_vectors, values.transpose(1, 2)) + + loss = torch.sum( + torch.square(inner_product) / l2_norms.unsqueeze(-1).unsqueeze(-1) + ) + loss = loss.requires_grad_(True) # Ensure loss requires gradients + optimizer.zero_grad() + loss.backward() + optimizer.step() + + merging_vector = merging_vector / len(experts) + # Update base expert weights with optimized merging vector + base_expert.expert_weights[key].data.copy_(merging_vector.data.cpu()) + + return base_expert + + +def weighted_linear_merge(experts: List[Expert], config: WeightedLinearMergeConfig) -> Expert: + """ + Merge experts using weighted linear averaging. 
+ + Args: + experts: List of Expert objects to merge + config: WeightedLinearMergeConfig containing merge parameters + + Returns: + Expert: Merged expert + """ + if not experts: + raise ValueError("Cannot merge empty list of experts") + + expert_names = [expert.name for expert in experts] + logger.info("Averaging {} experts".format(len(experts))) + + base_expert = copy.deepcopy(experts[0]) + base_expert.name = "weighted_expert" + + if config.weights is not None: + assert set(config.weights.keys()) == set( + expert_names + ), "Weights must have the same keys as the experts" + if not (1 - 1e-6) <= sum(config.weights.values()) <= (1 + 1e-6): + logger.warning( + "Weights do not sum to 1.0, please make sure this is intended" + ) + + # scale the base expert + for k, v in base_expert.expert_weights.items(): + base_expert.expert_weights[k] *= config.weights[expert_names[0]] + + for expert in experts[1:]: + # Validate that the expert is compatible + assert type(expert.expert_info.expert_config) == type( + base_expert.expert_info.expert_config + ), "Expert configs must be the same type" + assert set(expert.expert_weights.keys()) == set( + base_expert.expert_weights.keys() + ), "Expert weights must have the same keys" + + weight = 1.0 + if config.weights is not None: + weight = config.weights[expert.expert_info.expert_name] + + for k, v in expert.expert_weights.items(): + base_expert.expert_weights[k] += v * weight + + # Normalize the final expert + if config.weights is None: + for k, v in base_expert.expert_weights.items(): + base_expert.expert_weights[k] /= len(experts) + + # manually change the config of the expert to remove the tie_params + base_expert.expert_config.tie_params = None + + return base_expert + + +def ties_merge(experts: List[Expert], config: TiesMergeConfig) -> Expert: + """ + Merge experts using TIES merge algorithm. 
+ + Args: + experts: List of Expert objects to merge + config: TiesMergeConfig containing merge parameters + + Returns: + Expert: Merged expert + """ + if not experts: + raise ValueError("Cannot merge empty list of experts") + + logger.info("Averaging {} experts".format(len(experts))) + + base_expert = copy.deepcopy(experts[0]) + base_expert.name = "ties_weighted_expert" + + state_dict_keys = list(base_expert.expert_weights.keys()) + + # Build n_tasks x D experts + # TODO: No need to build this matrix, can be done 1 expert at a time + expert_vectors = [] + for expert in experts: + expert_vectors += [ + torch.nn.utils.parameters_to_vector( + list(expert.expert_weights[k] for k in state_dict_keys) + ) + ] + + expert_vectors = torch.stack(expert_vectors, dim=0) + per_exp_th = expert_vectors.abs().quantile(1.0 - config.top_k, dim=1) + keep_param = expert_vectors.abs() >= per_exp_th.view(-1, 1) + + mean_valid_per_task = keep_param.float().mean(1) + assert torch.all((mean_valid_per_task - config.top_k).abs() < 1e-4) + + used, kept, total = 0, 0, 0 + + for param_name in state_dict_keys: + # stack the expert weights + expert_weights = torch.stack( + [expert.expert_weights[param_name] for expert in experts], dim=0 + ) + + # keep weights over the threshold + TH = per_exp_th.view(-1, *((1,) * (expert_weights.ndim - 1))) + keep_mask = expert_weights.abs() >= TH + expert_weights = expert_weights * keep_mask + + if config.only_sparsify: + final_param = expert_weights.mean(0) + used += keep_mask.sum().item() + else: + # sign majority vote + sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign() + sign_per_dim = expert_weights.sum(0, keepdim=True).sign() + + # keep only weights whose sign agree with the majority + use_for_avg = expert_weights.sign() == sign_per_dim + + deno = use_for_avg.sum(0).clamp(min=1.0) + sum_param = (expert_weights * use_for_avg).sum(0) + final_param = sum_param / deno + used += (use_for_avg & (sign_per_dim != 0.0)).sum().item() + + kept += (expert_weights.abs() > TH).sum() + total += expert_weights.numel() + + base_expert.expert_weights[param_name].data.copy_(final_param) + + logger.info( + "Params not reset to 0 in TIES merge: {:.10f}%".format(100.0 * kept / total) + ) + logger.info( + "Params used to compute TIES mean: {:.10f}%".format(100.0 * used / total) + ) + + # manually change the config of the expert to remove the tie_params + base_expert.expert_config.tie_params = None + + return base_expert \ No newline at end of file From 4499daaf59db9c5a6f2228db6723111a36bf88ad Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Jun 2025 13:34:52 +0000 Subject: [PATCH 4/9] Reorganize merging functions into separate files and move outside library module Co-authored-by: sordonia <1689148+sordonia@users.noreply.github.com> --- mttl/models/__init__.py | 1 + mttl/models/library/__init__.py | 2 +- mttl/models/library/library_transforms.py | 6 +- mttl/models/library/merging.py | 226 ---------------------- mttl/models/merging/__init__.py | 23 +++ mttl/models/merging/arrow.py | 183 ++++++++++++++++++ mttl/models/merging/phatgoose.py | 92 +++++++++ mttl/models/merging/ties.py | 99 ++++++++++ mttl/models/merging/weighted_linear.py | 72 +++++++ mttl/models/merging/wudi.py | 75 +++++++ 10 files changed, 549 insertions(+), 230 deletions(-) delete mode 100644 mttl/models/library/merging.py create mode 100644 mttl/models/merging/__init__.py create mode 100644 mttl/models/merging/arrow.py create mode 100644 
mttl/models/merging/phatgoose.py create mode 100644 mttl/models/merging/ties.py create mode 100644 mttl/models/merging/weighted_linear.py create mode 100644 mttl/models/merging/wudi.py diff --git a/mttl/models/__init__.py b/mttl/models/__init__.py index e69de29bb..8b1378917 100644 --- a/mttl/models/__init__.py +++ b/mttl/models/__init__.py @@ -0,0 +1 @@ + diff --git a/mttl/models/library/__init__.py b/mttl/models/library/__init__.py index 5ce771bc1..9239607c0 100644 --- a/mttl/models/library/__init__.py +++ b/mttl/models/library/__init__.py @@ -1 +1 @@ -from .merging import ties_merge, weighted_linear_merge, wudi_merge +from mttl.models.merging import ties_merge, weighted_linear_merge, wudi_merge diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index cd2a06145..b88b47338 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -222,7 +222,7 @@ def __init__(self, config: WudiMergeConfig = None): @torch.no_grad() def transform(self, library) -> Expert: - from mttl.models.library.merging import wudi_merge + from mttl.models.merging import wudi_merge if type(library) == str: library = ExpertLibrary.get_expert_library(library) @@ -249,7 +249,7 @@ def __init__(self, config: WeightedLinearMergeConfig = None): @torch.no_grad() def transform(self, library) -> Expert: - from mttl.models.library.merging import weighted_linear_merge + from mttl.models.merging import weighted_linear_merge if type(library) == str: library = ExpertLibrary.get_expert_library(library) @@ -279,7 +279,7 @@ def __init__(self, config: TiesMergeConfig = None): @torch.no_grad() def transform(self, library) -> Expert: - from mttl.models.library.merging import ties_merge + from mttl.models.merging import ties_merge if type(library) == str: library = ExpertLibrary.get_expert_library(library) diff --git a/mttl/models/library/merging.py b/mttl/models/library/merging.py deleted file mode 100644 index b41e218b3..000000000 --- a/mttl/models/library/merging.py +++ /dev/null @@ -1,226 +0,0 @@ -""" -Standalone merging functions for Expert objects. - -This module provides standalone merging routines that take a list of Expert objects -as input and return a merged Expert. These functions are decoupled from library -transforms and can be used independently. -""" - -import copy -from typing import List - -import torch -from tqdm.auto import tqdm - -from mttl.logging import logger -from mttl.models.library.expert import Expert -from mttl.models.library.library_transforms import ( - TiesMergeConfig, - WeightedLinearMergeConfig, - WudiMergeConfig, -) - - -def wudi_merge(experts: List[Expert], config: WudiMergeConfig) -> Expert: - """ - Merge experts using WuDi merge algorithm. 
- - Args: - experts: List of Expert objects to merge - config: WudiMergeConfig containing merge parameters - - Returns: - Expert: Merged expert - """ - if not experts: - raise ValueError("Cannot merge empty list of experts") - - device = "cuda" if torch.cuda.is_available() else "cpu" - logger.info("Merging {} experts using WuDi merge".format(len(experts))) - - base_expert = copy.deepcopy(experts[0]) - base_expert.name = "wudi_merged_expert" - - # Get all parameter keys that we want to merge - keys = [key for key in base_expert.expert_weights.keys()] - - for key in keys: - # Stack all expert weights for this parameter - values = torch.stack([expert.expert_weights[key] for expert in experts]) - - values = values.to(device) - - # Initialize merged vector as sum of all vectors - merging_vector = torch.nn.Parameter(torch.sum(values, dim=0)) - optimizer = torch.optim.Adam( - [merging_vector], lr=config.lr, weight_decay=0 - ) - - # Compute L2 norms - l2_norms = torch.square( - torch.norm(values.reshape(values.shape[0], -1), p=2, dim=-1) - ) - - # Optimize merging vector - for _ in tqdm(range(config.iter), desc=f"Optimizing parameter {key}"): - disturbing_vectors = merging_vector.unsqueeze(0) - values - inner_product = torch.matmul(disturbing_vectors, values.transpose(1, 2)) - - loss = torch.sum( - torch.square(inner_product) / l2_norms.unsqueeze(-1).unsqueeze(-1) - ) - loss = loss.requires_grad_(True) # Ensure loss requires gradients - optimizer.zero_grad() - loss.backward() - optimizer.step() - - merging_vector = merging_vector / len(experts) - # Update base expert weights with optimized merging vector - base_expert.expert_weights[key].data.copy_(merging_vector.data.cpu()) - - return base_expert - - -def weighted_linear_merge(experts: List[Expert], config: WeightedLinearMergeConfig) -> Expert: - """ - Merge experts using weighted linear averaging. 
- - Args: - experts: List of Expert objects to merge - config: WeightedLinearMergeConfig containing merge parameters - - Returns: - Expert: Merged expert - """ - if not experts: - raise ValueError("Cannot merge empty list of experts") - - expert_names = [expert.name for expert in experts] - logger.info("Averaging {} experts".format(len(experts))) - - base_expert = copy.deepcopy(experts[0]) - base_expert.name = "weighted_expert" - - if config.weights is not None: - assert set(config.weights.keys()) == set( - expert_names - ), "Weights must have the same keys as the experts" - if not (1 - 1e-6) <= sum(config.weights.values()) <= (1 + 1e-6): - logger.warning( - "Weights do not sum to 1.0, please make sure this is intended" - ) - - # scale the base expert - for k, v in base_expert.expert_weights.items(): - base_expert.expert_weights[k] *= config.weights[expert_names[0]] - - for expert in experts[1:]: - # Validate that the expert is compatible - assert type(expert.expert_info.expert_config) == type( - base_expert.expert_info.expert_config - ), "Expert configs must be the same type" - assert set(expert.expert_weights.keys()) == set( - base_expert.expert_weights.keys() - ), "Expert weights must have the same keys" - - weight = 1.0 - if config.weights is not None: - weight = config.weights[expert.expert_info.expert_name] - - for k, v in expert.expert_weights.items(): - base_expert.expert_weights[k] += v * weight - - # Normalize the final expert - if config.weights is None: - for k, v in base_expert.expert_weights.items(): - base_expert.expert_weights[k] /= len(experts) - - # manually change the config of the expert to remove the tie_params - base_expert.expert_config.tie_params = None - - return base_expert - - -def ties_merge(experts: List[Expert], config: TiesMergeConfig) -> Expert: - """ - Merge experts using TIES merge algorithm. 
- - Args: - experts: List of Expert objects to merge - config: TiesMergeConfig containing merge parameters - - Returns: - Expert: Merged expert - """ - if not experts: - raise ValueError("Cannot merge empty list of experts") - - logger.info("Averaging {} experts".format(len(experts))) - - base_expert = copy.deepcopy(experts[0]) - base_expert.name = "ties_weighted_expert" - - state_dict_keys = list(base_expert.expert_weights.keys()) - - # Build n_tasks x D experts - # TODO: No need to build this matrix, can be done 1 expert at a time - expert_vectors = [] - for expert in experts: - expert_vectors += [ - torch.nn.utils.parameters_to_vector( - list(expert.expert_weights[k] for k in state_dict_keys) - ) - ] - - expert_vectors = torch.stack(expert_vectors, dim=0) - per_exp_th = expert_vectors.abs().quantile(1.0 - config.top_k, dim=1) - keep_param = expert_vectors.abs() >= per_exp_th.view(-1, 1) - - mean_valid_per_task = keep_param.float().mean(1) - assert torch.all((mean_valid_per_task - config.top_k).abs() < 1e-4) - - used, kept, total = 0, 0, 0 - - for param_name in state_dict_keys: - # stack the expert weights - expert_weights = torch.stack( - [expert.expert_weights[param_name] for expert in experts], dim=0 - ) - - # keep weights over the threshold - TH = per_exp_th.view(-1, *((1,) * (expert_weights.ndim - 1))) - keep_mask = expert_weights.abs() >= TH - expert_weights = expert_weights * keep_mask - - if config.only_sparsify: - final_param = expert_weights.mean(0) - used += keep_mask.sum().item() - else: - # sign majority vote - sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign() - sign_per_dim = expert_weights.sum(0, keepdim=True).sign() - - # keep only weights whose sign agree with the majority - use_for_avg = expert_weights.sign() == sign_per_dim - - deno = use_for_avg.sum(0).clamp(min=1.0) - sum_param = (expert_weights * use_for_avg).sum(0) - final_param = sum_param / deno - used += (use_for_avg & (sign_per_dim != 0.0)).sum().item() - - kept += (expert_weights.abs() > TH).sum() - total += expert_weights.numel() - - base_expert.expert_weights[param_name].data.copy_(final_param) - - logger.info( - "Params not reset to 0 in TIES merge: {:.10f}%".format(100.0 * kept / total) - ) - logger.info( - "Params used to compute TIES mean: {:.10f}%".format(100.0 * used / total) - ) - - # manually change the config of the expert to remove the tie_params - base_expert.expert_config.tie_params = None - - return base_expert \ No newline at end of file diff --git a/mttl/models/merging/__init__.py b/mttl/models/merging/__init__.py new file mode 100644 index 000000000..5286c9560 --- /dev/null +++ b/mttl/models/merging/__init__.py @@ -0,0 +1,23 @@ +""" +Standalone merging and transformation functions for Expert objects. + +This module provides standalone routines that take a list of Expert objects +as input and return merged or transformed experts. These functions are decoupled +from library transforms and can be used independently. 
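+
+Example (a minimal sketch; assumes library is an already-loaded ExpertLibrary
+whose experts share the same weight keys and adapter config type, and the
+top_k value is only illustrative):
+
+    from mttl.models.library.library_transforms import TiesMergeConfig
+    from mttl.models.merging import ties_merge
+
+    experts = [library[name] for name in library.keys()]
+    merged_expert = ties_merge(experts, TiesMergeConfig(top_k=0.2))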
+""" + +from .wudi import wudi_merge +from .weighted_linear import weighted_linear_merge +from .ties import ties_merge +from .arrow import arrow_transform +from .phatgoose import extract_phatgoose_prototypes, validate_phatgoose_training, initialize_phatgoose_gates + +__all__ = [ + "wudi_merge", + "weighted_linear_merge", + "ties_merge", + "arrow_transform", + "extract_phatgoose_prototypes", + "validate_phatgoose_training", + "initialize_phatgoose_gates" +] \ No newline at end of file diff --git a/mttl/models/merging/arrow.py b/mttl/models/merging/arrow.py new file mode 100644 index 000000000..a0a00342f --- /dev/null +++ b/mttl/models/merging/arrow.py @@ -0,0 +1,183 @@ +""" +Arrow transform algorithm for Expert objects. + +This module implements Arrow transform that extracts input directions most affected +by the linear transforms in expert weights. +""" + +import copy +from collections import defaultdict +from typing import Dict, List + +import torch + +from mttl.logging import logger +from mttl.models.library.expert import Expert +from mttl.models.library.library_transforms import ArrowTransformConfig +from mttl.models.modifiers.base import get_target_2_source_param_mapping + + +def arrow_transform(experts: List[Expert], config: ArrowTransformConfig) -> Dict[str, Dict[str, torch.Tensor]]: + """ + Apply Arrow transform to experts to extract input directions most affected by linear transforms. + + Args: + experts: List of Expert objects to transform + config: ArrowTransformConfig containing transform parameters + + Returns: + Dict mapping expert names to their Arrow prototypes (layer_name -> prototype vector) + """ + if not experts: + raise ValueError("Cannot transform empty list of experts") + + logger.info(f"Computing Arrow prototypes for {len(experts)} experts") + + vectors = {} + eigvals = {} + + for expert in experts: + expert_name = expert.name + logger.info(f"Computing SVD for expert {expert_name}") + vectors[expert_name] = {} + eigvals[expert_name] = {} + + # get parameters tied during training + param_map = get_target_2_source_param_mapping( + expert.expert_weights.items(), + expert.expert_info.expert_config.tie_params, + ) + if config.tie_params != "default": + # get parameters we wish to tie for Arrow + _tied_params = get_target_2_source_param_mapping( + expert.expert_weights.items(), config.tie_params + ) + # Make sure that params tied during training are also tied for Arrow + if any(key not in _tied_params for key in param_map): + logger.warning( + "Some parameters that are tied during training are not tied during Arrow computation." 
+ ) + param_map = _tied_params + + tied_params = list(param_map.keys()) + list(param_map.values()) + assert all( + "lora_b" not in param_name for param_name in tied_params + ), "Support for tied B not available" + assert all( + "lora_a" in param_name for param_name in tied_params + ), "Only support tied As for now" + + # Now that we know only A's are tied, we can proceed using only the parent names + tied_parents = _get_unique_parent_names(tied_params) + untied_parents = [ + parent + for parent in _get_unique_parent_names(expert.expert_weights.keys()) + if parent not in tied_parents + ] + + # Build a mapping from source to target parameters + tied_param_bins = defaultdict(list) + for tgt_name, src_name in param_map.items(): + parent_src = ".".join(src_name.split(".")[:-1]) + parent_tgt = ".".join(tgt_name.split(".")[:-1]) + tied_param_bins[parent_src].append(parent_tgt) + for parent in untied_parents: + tied_param_bins[parent] = [] + + for parent_name, dependents in tied_param_bins.items(): + logger.info(f"\tComputing SVD for parameter {parent_name}") + + parent_names = [parent_name] + A_name, B_name = f"{parent_name}.lora_a", f"{parent_name}.lora_b" + As = [expert.expert_weights[A_name]] + Bs = [expert.expert_weights[B_name]] + + for tied_module in dependents: + logger.info(f"\t\t\tTying Arrow with {tied_module}") + As += [expert.expert_weights[f"{tied_module}.lora_a"]] + Bs += [expert.expert_weights[f"{tied_module}.lora_b"]] + parent_names += [tied_module] + + if len(As) > 1: + if config.tie_op == "concat": + # Mimicking phi-2 behavior + assert config.ab_only + assert all( + torch.allclose(A, As[0]) for A in As + ), "A should be the same for all tied parameters" + A = As[0] + B = torch.cat(Bs, dim=1) + elif config.tie_op == "sum": + # A1B1 + A2B2 == [A1 A2] [B1; B2]. 
+ A = torch.cat(As, dim=1) + B = torch.cat(Bs, dim=0) + else: + raise NotImplementedError() + else: + A, B = As[0], Bs[0] + + # Reshape As and Bs (needed for Poly / MHR weights) + rank = expert.expert_config.lora_rank + A = A.reshape(-1, rank).float() + B = B.reshape(rank, -1).float() + + W = (A @ B).T # out_features, in_features + + if config.ab_only: + U_W, Sigma_W, _ = _low_rank_svd(A, B) + top_value = Sigma_W[0] ** 2 + top_vector = U_W[:, 0] + else: + raise NotImplementedError("Base model weights not supported in standalone function") + + # Save eigenvector and eigenvalue + for parent in parent_names: + assert parent not in vectors[expert_name] + vectors[expert_name][parent] = top_vector.real.cpu() + eigvals[expert_name][parent] = top_value.item() + + # Apply scaling if requested + if config.scale: + output = {} + for expert_name, expert_data in vectors.items(): + output[expert_name] = {} + for layer_name, vector in expert_data.items(): + vector = vector * eigvals[expert_name][layer_name] + output[expert_name][layer_name] = vector + return output + else: + return vectors + + +def _get_unique_parent_names(alist): + """ + if adict.keys() = ['model.layer1.lora_a', 'model.layer.lora_b', 'model.layer2.lora_a'] + output will be {'model.layer1', 'model.layer2'} + """ + dict_keys = sorted(list(set(".".join(k.split(".")[:-1]) for k in alist))) + return dict_keys + + +def _low_rank_svd(A, B): + """Faster SVD computation for low rank matrices""" + # Compute SVD of A + U_A, Sigma_A, V_A = torch.svd(A) + + # Compute SVD of B.T (transpose of B) + U_B, Sigma_B, V_B = torch.svd(B.T) + + # Compute product matrix C = Sigma_A * (V_A.T @ V_B) * Sigma_B + C = Sigma_A.diag_embed() @ V_A.t() @ V_B @ Sigma_B.diag_embed() + + # Compute SVD of the product matrix C + U_C, Sigma_C, V_C = torch.svd(C) + + # Construct the final SVD components of W + U_W = U_A @ U_C + V_W_T = V_C.t() @ U_B.t() + + diff_AB = (U_W.T @ U_A).abs().diag() + if diff_AB[0] < 0.9: + logger.debug("The first singular vector of U_A and U_AB are not aligned") + + return U_W, Sigma_C, V_W_T \ No newline at end of file diff --git a/mttl/models/merging/phatgoose.py b/mttl/models/merging/phatgoose.py new file mode 100644 index 000000000..01caecc9b --- /dev/null +++ b/mttl/models/merging/phatgoose.py @@ -0,0 +1,92 @@ +""" +Phatgoose transform algorithm for Expert objects. + +This module implements Phatgoose transform that computes prototype vectors +for expert selection through training selector gates. +""" + +import re +from typing import Dict, List + +import torch + +from mttl.logging import logger +from mttl.models.library.expert import Expert +from mttl.models.library.library_transforms import PhatgooseTransformConfig +from mttl.models.containers import ExpertContainer + + +def extract_phatgoose_prototypes(model) -> Dict[str, torch.Tensor]: + """ + Extract Phatgoose prototypes from a trained model with selector gates. 
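+
+    Iterates over model.model.named_modules() and, for every ExpertContainer
+    whose selector exposes get_prototypes(), collects the returned vectors
+    under keys of the form "<container_name>.selector.<prototype_name>.v".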
+ + Args: + model: Trained model with ExpertContainer modules containing selectors + + Returns: + Dict mapping layer names to prototype vectors + """ + prototypes = {} + for name, module in model.model.named_modules(): + if isinstance(module, ExpertContainer) and hasattr(module.selector, "get_prototypes"): + # expand dict + prototypes_module = {} + for k, v in module.selector.get_prototypes().items(): + prototypes_module[f"{name}.selector.{k}.v"] = v + prototypes = {**prototypes, **prototypes_module} + + return prototypes + + +def validate_phatgoose_training(model_state_before: Dict[str, torch.Tensor], + model_state_after: Dict[str, torch.Tensor]) -> bool: + """ + Validate that Phatgoose training only updated selector gates and not frozen parameters. + + Args: + model_state_before: Model state dict before training + model_state_after: Model state dict after training + + Returns: + True if training was valid, raises AssertionError otherwise + """ + frozen_sum_before, unfrozen_sum_before = 0, 0 + frozen_sum_after, unfrozen_sum_after = 0, 0 + + for key in model_state_before.keys(): + value_before = model_state_before[key] + value_after = model_state_after[key] + + if re.match(".*selector.gates.*.v", key): + unfrozen_sum_before += value_before.sum() + unfrozen_sum_after += value_after.sum() + else: + frozen_sum_before += value_before.sum() + frozen_sum_after += value_after.sum() + + assert frozen_sum_before == frozen_sum_after, "Frozen params changed during training" + assert unfrozen_sum_before != unfrozen_sum_after, "Unfrozen params did not change during training" + + return True + + +def initialize_phatgoose_gates(model) -> Dict[str, torch.Tensor]: + """ + Initialize Phatgoose selector gates to zero and return initial state. + + Args: + model: Model with selector gates to initialize + + Returns: + Dict of initial state for validation + """ + initial_state = {} + for key, value in model.state_dict().items(): + if re.match(".*selector.gates.*.v", key): + assert torch.allclose(value, torch.zeros_like(value)), "Gate should be 0 init" + value.requires_grad = True + else: + value.requires_grad = False + initial_state[key] = value.clone() + + return initial_state \ No newline at end of file diff --git a/mttl/models/merging/ties.py b/mttl/models/merging/ties.py new file mode 100644 index 000000000..fd9a43ddd --- /dev/null +++ b/mttl/models/merging/ties.py @@ -0,0 +1,99 @@ +""" +TIES merge algorithm for Expert objects. + +This module implements the TIES (Task Interference Elimination through Sparse merging) algorithm. +""" + +import copy +from typing import List + +import torch + +from mttl.logging import logger +from mttl.models.library.expert import Expert +from mttl.models.library.library_transforms import TiesMergeConfig + + +def ties_merge(experts: List[Expert], config: TiesMergeConfig) -> Expert: + """ + Merge experts using TIES merge algorithm. 
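+
+    Each expert's weights are first trimmed to the top_k fraction with the
+    largest magnitude (using a per-expert quantile threshold). Unless
+    config.only_sparsify is set, a per-dimension sign election then keeps only
+    the entries that agree with the majority sign, and the surviving entries
+    are averaged across experts.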
+ + Args: + experts: List of Expert objects to merge + config: TiesMergeConfig containing merge parameters + + Returns: + Expert: Merged expert + """ + if not experts: + raise ValueError("Cannot merge empty list of experts") + + logger.info("Averaging {} experts".format(len(experts))) + + base_expert = copy.deepcopy(experts[0]) + base_expert.name = "ties_weighted_expert" + + state_dict_keys = list(base_expert.expert_weights.keys()) + + # Build n_tasks x D experts + # TODO: No need to build this matrix, can be done 1 expert at a time + expert_vectors = [] + for expert in experts: + expert_vectors += [ + torch.nn.utils.parameters_to_vector( + list(expert.expert_weights[k] for k in state_dict_keys) + ) + ] + + expert_vectors = torch.stack(expert_vectors, dim=0) + per_exp_th = expert_vectors.abs().quantile(1.0 - config.top_k, dim=1) + keep_param = expert_vectors.abs() >= per_exp_th.view(-1, 1) + + mean_valid_per_task = keep_param.float().mean(1) + assert torch.all((mean_valid_per_task - config.top_k).abs() < 1e-4) + + used, kept, total = 0, 0, 0 + + for param_name in state_dict_keys: + # stack the expert weights + expert_weights = torch.stack( + [expert.expert_weights[param_name] for expert in experts], dim=0 + ) + + # keep weights over the threshold + TH = per_exp_th.view(-1, *((1,) * (expert_weights.ndim - 1))) + keep_mask = expert_weights.abs() >= TH + expert_weights = expert_weights * keep_mask + + if config.only_sparsify: + final_param = expert_weights.mean(0) + used += keep_mask.sum().item() + else: + # sign majority vote + sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign() + sign_per_dim = expert_weights.sum(0, keepdim=True).sign() + + # keep only weights whose sign agree with the majority + use_for_avg = expert_weights.sign() == sign_per_dim + + deno = use_for_avg.sum(0).clamp(min=1.0) + sum_param = (expert_weights * use_for_avg).sum(0) + final_param = sum_param / deno + used += (use_for_avg & (sign_per_dim != 0.0)).sum().item() + + kept += (expert_weights.abs() > TH).sum() + total += expert_weights.numel() + + base_expert.expert_weights[param_name].data.copy_(final_param) + + logger.info( + "Params not reset to 0 in TIES merge: {:.10f}%".format(100.0 * kept / total) + ) + logger.info( + "Params used to compute TIES mean: {:.10f}%".format(100.0 * used / total) + ) + + # manually change the config of the expert to remove the tie_params + base_expert.expert_config.tie_params = None + + return base_expert \ No newline at end of file diff --git a/mttl/models/merging/weighted_linear.py b/mttl/models/merging/weighted_linear.py new file mode 100644 index 000000000..3aba76878 --- /dev/null +++ b/mttl/models/merging/weighted_linear.py @@ -0,0 +1,72 @@ +""" +Weighted linear merge algorithm for Expert objects. + +This module implements weighted linear averaging of expert parameters. +""" + +import copy +from typing import List + +from mttl.logging import logger +from mttl.models.library.expert import Expert +from mttl.models.library.library_transforms import WeightedLinearMergeConfig + + +def weighted_linear_merge(experts: List[Expert], config: WeightedLinearMergeConfig) -> Expert: + """ + Merge experts using weighted linear averaging. 
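+
+    Each expert's weights are scaled by the value stored under its name in
+    config.weights and summed; when config.weights is None, all experts
+    contribute equally and the sum is divided by the number of experts.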
+ + Args: + experts: List of Expert objects to merge + config: WeightedLinearMergeConfig containing merge parameters + + Returns: + Expert: Merged expert + """ + if not experts: + raise ValueError("Cannot merge empty list of experts") + + expert_names = [expert.name for expert in experts] + logger.info("Averaging {} experts".format(len(experts))) + + base_expert = copy.deepcopy(experts[0]) + base_expert.name = "weighted_expert" + + if config.weights is not None: + assert set(config.weights.keys()) == set( + expert_names + ), "Weights must have the same keys as the experts" + if not (1 - 1e-6) <= sum(config.weights.values()) <= (1 + 1e-6): + logger.warning( + "Weights do not sum to 1.0, please make sure this is intended" + ) + + # scale the base expert + for k, v in base_expert.expert_weights.items(): + base_expert.expert_weights[k] *= config.weights[expert_names[0]] + + for expert in experts[1:]: + # Validate that the expert is compatible + assert type(expert.expert_info.expert_config) == type( + base_expert.expert_info.expert_config + ), "Expert configs must be the same type" + assert set(expert.expert_weights.keys()) == set( + base_expert.expert_weights.keys() + ), "Expert weights must have the same keys" + + weight = 1.0 + if config.weights is not None: + weight = config.weights[expert.expert_info.expert_name] + + for k, v in expert.expert_weights.items(): + base_expert.expert_weights[k] += v * weight + + # Normalize the final expert + if config.weights is None: + for k, v in base_expert.expert_weights.items(): + base_expert.expert_weights[k] /= len(experts) + + # manually change the config of the expert to remove the tie_params + base_expert.expert_config.tie_params = None + + return base_expert \ No newline at end of file diff --git a/mttl/models/merging/wudi.py b/mttl/models/merging/wudi.py new file mode 100644 index 000000000..b09ba112d --- /dev/null +++ b/mttl/models/merging/wudi.py @@ -0,0 +1,75 @@ +""" +WuDi (Weight Disturbance) merge algorithm for Expert objects. + +This module implements the WuDi merge algorithm from https://arxiv.org/pdf/2503.08099v1 +""" + +import copy +from typing import List + +import torch +from tqdm.auto import tqdm + +from mttl.logging import logger +from mttl.models.library.expert import Expert +from mttl.models.library.library_transforms import WudiMergeConfig + + +def wudi_merge(experts: List[Expert], config: WudiMergeConfig) -> Expert: + """ + Merge experts using WuDi merge algorithm. 
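+
+    For every parameter, the merged tensor is initialized as the sum of the
+    expert tensors and then optimized with Adam to reduce the norm-weighted
+    interference between the merged tensor and each expert's tensor (the loss
+    computed in the loop below), before being rescaled by the number of
+    experts.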
+ + Args: + experts: List of Expert objects to merge + config: WudiMergeConfig containing merge parameters + + Returns: + Expert: Merged expert + """ + if not experts: + raise ValueError("Cannot merge empty list of experts") + + device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info("Merging {} experts using WuDi merge".format(len(experts))) + + base_expert = copy.deepcopy(experts[0]) + base_expert.name = "wudi_merged_expert" + + # Get all parameter keys that we want to merge + keys = [key for key in base_expert.expert_weights.keys()] + + for key in keys: + # Stack all expert weights for this parameter + values = torch.stack([expert.expert_weights[key] for expert in experts]) + + values = values.to(device) + + # Initialize merged vector as sum of all vectors + merging_vector = torch.nn.Parameter(torch.sum(values, dim=0)) + optimizer = torch.optim.Adam( + [merging_vector], lr=config.lr, weight_decay=0 + ) + + # Compute L2 norms + l2_norms = torch.square( + torch.norm(values.reshape(values.shape[0], -1), p=2, dim=-1) + ) + + # Optimize merging vector + for _ in tqdm(range(config.iter), desc=f"Optimizing parameter {key}"): + disturbing_vectors = merging_vector.unsqueeze(0) - values + inner_product = torch.matmul(disturbing_vectors, values.transpose(1, 2)) + + loss = torch.sum( + torch.square(inner_product) / l2_norms.unsqueeze(-1).unsqueeze(-1) + ) + loss = loss.requires_grad_(True) # Ensure loss requires gradients + optimizer.zero_grad() + loss.backward() + optimizer.step() + + merging_vector = merging_vector / len(experts) + # Update base expert weights with optimized merging vector + base_expert.expert_weights[key].data.copy_(merging_vector.data.cpu()) + + return base_expert \ No newline at end of file From b2dea29c0d313317528fdcf7ed4199dd9cb60686 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Jun 2025 18:30:03 +0000 Subject: [PATCH 5/9] Fix circular import between library and merging modules --- mttl/models/library/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mttl/models/library/__init__.py b/mttl/models/library/__init__.py index 9239607c0..a553b468b 100644 --- a/mttl/models/library/__init__.py +++ b/mttl/models/library/__init__.py @@ -1 +1 @@ -from mttl.models.merging import ties_merge, weighted_linear_merge, wudi_merge +# Imports moved to avoid circular dependency with merging module From 67214fb105d51f6d8aa8a41fa4b49940c14ab3af Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Jun 2025 19:10:47 +0000 Subject: [PATCH 6/9] Format code with Black formatter Co-authored-by: sordonia <1689148+sordonia@users.noreply.github.com> --- mttl/models/library/library_transforms.py | 6 +-- mttl/models/merging/__init__.py | 20 ++++++---- mttl/models/merging/arrow.py | 22 ++++++----- mttl/models/merging/phatgoose.py | 48 ++++++++++++++--------- mttl/models/merging/ties.py | 8 ++-- mttl/models/merging/weighted_linear.py | 12 +++--- mttl/models/merging/wudi.py | 12 +++--- 7 files changed, 73 insertions(+), 55 deletions(-) diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index 4860910b2..3d1c00b85 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -348,7 +348,7 @@ def __init__(self, config: WudiMergeConfig = None): def transform(self, library) -> Expert: from mttl.models.merging import wudi_merge - + if type(library) == str: 
library = ExpertLibrary.get_expert_library(library) @@ -375,7 +375,7 @@ def __init__(self, config: WeightedLinearMergeConfig = None): @torch.no_grad() def transform(self, library) -> Expert: from mttl.models.merging import weighted_linear_merge - + if type(library) == str: library = ExpertLibrary.get_expert_library(library) @@ -405,7 +405,7 @@ def __init__(self, config: TiesMergeConfig = None): @torch.no_grad() def transform(self, library) -> Expert: from mttl.models.merging import ties_merge - + if type(library) == str: library = ExpertLibrary.get_expert_library(library) diff --git a/mttl/models/merging/__init__.py b/mttl/models/merging/__init__.py index 5286c9560..ca20604b1 100644 --- a/mttl/models/merging/__init__.py +++ b/mttl/models/merging/__init__.py @@ -2,22 +2,26 @@ Standalone merging and transformation functions for Expert objects. This module provides standalone routines that take a list of Expert objects -as input and return merged or transformed experts. These functions are decoupled +as input and return merged or transformed experts. These functions are decoupled from library transforms and can be used independently. """ from .wudi import wudi_merge -from .weighted_linear import weighted_linear_merge +from .weighted_linear import weighted_linear_merge from .ties import ties_merge from .arrow import arrow_transform -from .phatgoose import extract_phatgoose_prototypes, validate_phatgoose_training, initialize_phatgoose_gates +from .phatgoose import ( + extract_phatgoose_prototypes, + validate_phatgoose_training, + initialize_phatgoose_gates, +) __all__ = [ - "wudi_merge", - "weighted_linear_merge", + "wudi_merge", + "weighted_linear_merge", "ties_merge", "arrow_transform", "extract_phatgoose_prototypes", - "validate_phatgoose_training", - "initialize_phatgoose_gates" -] \ No newline at end of file + "validate_phatgoose_training", + "initialize_phatgoose_gates", +] diff --git a/mttl/models/merging/arrow.py b/mttl/models/merging/arrow.py index a0a00342f..4cae59d1b 100644 --- a/mttl/models/merging/arrow.py +++ b/mttl/models/merging/arrow.py @@ -1,7 +1,7 @@ """ Arrow transform algorithm for Expert objects. -This module implements Arrow transform that extracts input directions most affected +This module implements Arrow transform that extracts input directions most affected by the linear transforms in expert weights. """ @@ -17,25 +17,27 @@ from mttl.models.modifiers.base import get_target_2_source_param_mapping -def arrow_transform(experts: List[Expert], config: ArrowTransformConfig) -> Dict[str, Dict[str, torch.Tensor]]: +def arrow_transform( + experts: List[Expert], config: ArrowTransformConfig +) -> Dict[str, Dict[str, torch.Tensor]]: """ Apply Arrow transform to experts to extract input directions most affected by linear transforms. 
- + Args: experts: List of Expert objects to transform config: ArrowTransformConfig containing transform parameters - + Returns: Dict mapping expert names to their Arrow prototypes (layer_name -> prototype vector) """ if not experts: raise ValueError("Cannot transform empty list of experts") - + logger.info(f"Computing Arrow prototypes for {len(experts)} experts") - + vectors = {} eigvals = {} - + for expert in experts: expert_name = expert.name logger.info(f"Computing SVD for expert {expert_name}") @@ -128,7 +130,9 @@ def arrow_transform(experts: List[Expert], config: ArrowTransformConfig) -> Dict top_value = Sigma_W[0] ** 2 top_vector = U_W[:, 0] else: - raise NotImplementedError("Base model weights not supported in standalone function") + raise NotImplementedError( + "Base model weights not supported in standalone function" + ) # Save eigenvector and eigenvalue for parent in parent_names: @@ -180,4 +184,4 @@ def _low_rank_svd(A, B): if diff_AB[0] < 0.9: logger.debug("The first singular vector of U_A and U_AB are not aligned") - return U_W, Sigma_C, V_W_T \ No newline at end of file + return U_W, Sigma_C, V_W_T diff --git a/mttl/models/merging/phatgoose.py b/mttl/models/merging/phatgoose.py index 01caecc9b..18a999c28 100644 --- a/mttl/models/merging/phatgoose.py +++ b/mttl/models/merging/phatgoose.py @@ -19,74 +19,84 @@ def extract_phatgoose_prototypes(model) -> Dict[str, torch.Tensor]: """ Extract Phatgoose prototypes from a trained model with selector gates. - + Args: model: Trained model with ExpertContainer modules containing selectors - + Returns: Dict mapping layer names to prototype vectors """ prototypes = {} for name, module in model.model.named_modules(): - if isinstance(module, ExpertContainer) and hasattr(module.selector, "get_prototypes"): + if isinstance(module, ExpertContainer) and hasattr( + module.selector, "get_prototypes" + ): # expand dict prototypes_module = {} for k, v in module.selector.get_prototypes().items(): prototypes_module[f"{name}.selector.{k}.v"] = v prototypes = {**prototypes, **prototypes_module} - + return prototypes -def validate_phatgoose_training(model_state_before: Dict[str, torch.Tensor], - model_state_after: Dict[str, torch.Tensor]) -> bool: +def validate_phatgoose_training( + model_state_before: Dict[str, torch.Tensor], + model_state_after: Dict[str, torch.Tensor], +) -> bool: """ Validate that Phatgoose training only updated selector gates and not frozen parameters. 
- + Args: model_state_before: Model state dict before training model_state_after: Model state dict after training - + Returns: True if training was valid, raises AssertionError otherwise """ frozen_sum_before, unfrozen_sum_before = 0, 0 frozen_sum_after, unfrozen_sum_after = 0, 0 - + for key in model_state_before.keys(): value_before = model_state_before[key] value_after = model_state_after[key] - + if re.match(".*selector.gates.*.v", key): unfrozen_sum_before += value_before.sum() unfrozen_sum_after += value_after.sum() else: frozen_sum_before += value_before.sum() frozen_sum_after += value_after.sum() - - assert frozen_sum_before == frozen_sum_after, "Frozen params changed during training" - assert unfrozen_sum_before != unfrozen_sum_after, "Unfrozen params did not change during training" - + + assert ( + frozen_sum_before == frozen_sum_after + ), "Frozen params changed during training" + assert ( + unfrozen_sum_before != unfrozen_sum_after + ), "Unfrozen params did not change during training" + return True def initialize_phatgoose_gates(model) -> Dict[str, torch.Tensor]: """ Initialize Phatgoose selector gates to zero and return initial state. - + Args: model: Model with selector gates to initialize - + Returns: Dict of initial state for validation """ initial_state = {} for key, value in model.state_dict().items(): if re.match(".*selector.gates.*.v", key): - assert torch.allclose(value, torch.zeros_like(value)), "Gate should be 0 init" + assert torch.allclose( + value, torch.zeros_like(value) + ), "Gate should be 0 init" value.requires_grad = True else: value.requires_grad = False initial_state[key] = value.clone() - - return initial_state \ No newline at end of file + + return initial_state diff --git a/mttl/models/merging/ties.py b/mttl/models/merging/ties.py index fd9a43ddd..e001a746c 100644 --- a/mttl/models/merging/ties.py +++ b/mttl/models/merging/ties.py @@ -17,17 +17,17 @@ def ties_merge(experts: List[Expert], config: TiesMergeConfig) -> Expert: """ Merge experts using TIES merge algorithm. - + Args: experts: List of Expert objects to merge config: TiesMergeConfig containing merge parameters - + Returns: Expert: Merged expert """ if not experts: raise ValueError("Cannot merge empty list of experts") - + logger.info("Averaging {} experts".format(len(experts))) base_expert = copy.deepcopy(experts[0]) @@ -96,4 +96,4 @@ def ties_merge(experts: List[Expert], config: TiesMergeConfig) -> Expert: # manually change the config of the expert to remove the tie_params base_expert.expert_config.tie_params = None - return base_expert \ No newline at end of file + return base_expert diff --git a/mttl/models/merging/weighted_linear.py b/mttl/models/merging/weighted_linear.py index 3aba76878..8e1856d26 100644 --- a/mttl/models/merging/weighted_linear.py +++ b/mttl/models/merging/weighted_linear.py @@ -12,20 +12,22 @@ from mttl.models.library.library_transforms import WeightedLinearMergeConfig -def weighted_linear_merge(experts: List[Expert], config: WeightedLinearMergeConfig) -> Expert: +def weighted_linear_merge( + experts: List[Expert], config: WeightedLinearMergeConfig +) -> Expert: """ Merge experts using weighted linear averaging. 
- + Args: experts: List of Expert objects to merge config: WeightedLinearMergeConfig containing merge parameters - + Returns: Expert: Merged expert """ if not experts: raise ValueError("Cannot merge empty list of experts") - + expert_names = [expert.name for expert in experts] logger.info("Averaging {} experts".format(len(experts))) @@ -69,4 +71,4 @@ def weighted_linear_merge(experts: List[Expert], config: WeightedLinearMergeConf # manually change the config of the expert to remove the tie_params base_expert.expert_config.tie_params = None - return base_expert \ No newline at end of file + return base_expert diff --git a/mttl/models/merging/wudi.py b/mttl/models/merging/wudi.py index b09ba112d..a78e5830f 100644 --- a/mttl/models/merging/wudi.py +++ b/mttl/models/merging/wudi.py @@ -18,17 +18,17 @@ def wudi_merge(experts: List[Expert], config: WudiMergeConfig) -> Expert: """ Merge experts using WuDi merge algorithm. - + Args: experts: List of Expert objects to merge config: WudiMergeConfig containing merge parameters - + Returns: Expert: Merged expert """ if not experts: raise ValueError("Cannot merge empty list of experts") - + device = "cuda" if torch.cuda.is_available() else "cpu" logger.info("Merging {} experts using WuDi merge".format(len(experts))) @@ -46,9 +46,7 @@ def wudi_merge(experts: List[Expert], config: WudiMergeConfig) -> Expert: # Initialize merged vector as sum of all vectors merging_vector = torch.nn.Parameter(torch.sum(values, dim=0)) - optimizer = torch.optim.Adam( - [merging_vector], lr=config.lr, weight_decay=0 - ) + optimizer = torch.optim.Adam([merging_vector], lr=config.lr, weight_decay=0) # Compute L2 norms l2_norms = torch.square( @@ -72,4 +70,4 @@ def wudi_merge(experts: List[Expert], config: WudiMergeConfig) -> Expert: # Update base expert weights with optimized merging vector base_expert.expert_weights[key].data.copy_(merging_vector.data.cpu()) - return base_expert \ No newline at end of file + return base_expert From f2d6ee9c295426e7046be637099f7fe3367c1ec6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 11 Jun 2025 00:26:45 +0000 Subject: [PATCH 7/9] Update wudi_merge function with improvements from main branch Co-authored-by: sordonia <1689148+sordonia@users.noreply.github.com> --- mttl/models/merging/wudi.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/mttl/models/merging/wudi.py b/mttl/models/merging/wudi.py index a78e5830f..7b52f8314 100644 --- a/mttl/models/merging/wudi.py +++ b/mttl/models/merging/wudi.py @@ -45,7 +45,9 @@ def wudi_merge(experts: List[Expert], config: WudiMergeConfig) -> Expert: values = values.to(device) # Initialize merged vector as sum of all vectors - merging_vector = torch.nn.Parameter(torch.sum(values, dim=0)) + merging_vector = torch.nn.Parameter( + torch.sum(values, dim=0), requires_grad=True + ) optimizer = torch.optim.Adam([merging_vector], lr=config.lr, weight_decay=0) # Compute L2 norms @@ -54,18 +56,37 @@ def wudi_merge(experts: List[Expert], config: WudiMergeConfig) -> Expert: ) # Optimize merging vector - for _ in tqdm(range(config.iter), desc=f"Optimizing parameter {key}"): + pbar = tqdm(range(config.iter), desc=f"Optimizing parameter {key}") + prev_loss = float("inf") + patience = 5 # Number of steps to wait for improvement + no_improve_count = 0 + min_delta = 1e-4 # Minimum change in loss to be considered improvement + + for step in pbar: disturbing_vectors = merging_vector.unsqueeze(0) - values 
inner_product = torch.matmul(disturbing_vectors, values.transpose(1, 2)) loss = torch.sum( torch.square(inner_product) / l2_norms.unsqueeze(-1).unsqueeze(-1) ) - loss = loss.requires_grad_(True) # Ensure loss requires gradients optimizer.zero_grad() loss.backward() optimizer.step() + # Check if loss improvement is significant + if abs(prev_loss - loss.item()) < min_delta: + no_improve_count += 1 + else: + no_improve_count = 0 + + # Early stopping if no significant improvement for patience steps + if no_improve_count >= patience: + logger.info(f"Early stopping at step {step} due to minimal loss change") + break + + prev_loss = loss.item() + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + merging_vector = merging_vector / len(experts) # Update base expert weights with optimized merging vector base_expert.expert_weights[key].data.copy_(merging_vector.data.cpu()) From 5c543bf13d4b28c48344cc8edeeed3644ba27fe9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 11 Jun 2025 00:44:58 +0000 Subject: [PATCH 8/9] Add wudi_merge_after standalone function and refactor WudiMergeAfter class Co-authored-by: sordonia <1689148+sordonia@users.noreply.github.com> --- mttl/models/library/library_transforms.py | 105 +--------------- mttl/models/merging/__init__.py | 3 +- mttl/models/merging/wudi.py | 140 +++++++++++++++++++++- 3 files changed, 143 insertions(+), 105 deletions(-) diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index 3d1c00b85..0a6c7a2b4 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -222,119 +222,18 @@ class WudiMergeAfter(LibraryTransform): def __init__(self, config: WudiMergeConfig = None): super().__init__(config or WudiMergeConfig()) - def _get_task_vectors(self, expert): - """ - get the incremental weights for each layer, LoRA A outproduct LoRA B - """ - task_vectors = {} - for key in expert.expert_weights.keys(): - base_layer_name = key.split(".lora_")[ - 0 - ] # Get base layer name by removing .lora_a or .lora_b - if base_layer_name not in task_vectors: - task_vectors[base_layer_name] = None - - for layer in task_vectors.keys(): - lora_a = expert.expert_weights[f"{layer}.lora_a"] - lora_b = expert.expert_weights[f"{layer}.lora_b"] - task_vectors[layer] = lora_a.data @ lora_b.data - - return task_vectors - - def get_optimized_task_vector( - self, layer_name, task_vectors, iter, lr - ) -> torch.Tensor: - """ - min Σᵢ (1/||τᵢ,ₗ||²F) ||(τₘ,ₗ - τᵢ,ₗ)(τᵢ,ₗ)ᵀ||²F - - return the optimized merged task vector for each layer - """ - task_vectors = task_vectors.cuda() - merging_vector = torch.nn.Parameter((torch.sum(task_vectors, dim=0))) - optimizer = torch.optim.Adam([merging_vector], lr=lr, weight_decay=0) - - l2_norms = torch.square( - torch.norm(task_vectors.reshape(task_vectors.shape[0], -1), p=2, dim=-1) - ) - - pbar = tqdm(range(iter), desc=f"Optimizing parameter {layer_name}") - prev_loss = float("inf") - patience = 5 # Number of steps to wait for improvement - no_improve_count = 0 - min_delta = 1e-4 # Minimum change in loss to be considered improvement - - for step in pbar: - disturbing_vectors = merging_vector.unsqueeze(0) - task_vectors - inner_product = torch.matmul( - disturbing_vectors, task_vectors.transpose(1, 2) - ) - loss = torch.sum( - torch.square(inner_product) / l2_norms.unsqueeze(-1).unsqueeze(-1) - ) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # Check if loss improvement is 
significant - if abs(prev_loss - loss.item()) < min_delta: - no_improve_count += 1 - else: - no_improve_count = 0 - - # Early stopping if no significant improvement for patience steps - if no_improve_count >= patience: - logger.info(f"Early stopping at step {step} due to minimal loss change") - break - - prev_loss = loss.item() - pbar.set_postfix({"loss": f"{loss.item():.4f}"}) - return merging_vector - def transform(self, library, persist=True, recompute=False) -> dict: """ return the task merged vectors in each layer """ + from mttl.models.merging import wudi_merge_after if type(library) == str: library = ExpertLibrary.get_expert_library(library) expert_names = list(library.keys()) experts = [library[name] for name in expert_names] - logger.info("Merging {} experts using WuDi merge".format(len(experts))) - - one_expert = experts[0] - # get the layer names from the model - layer_names = [ - name.split(".lora_")[0] for name in one_expert.expert_weights.keys() - ] - layer_names = list(set(layer_names)) - - # get the task vectors for each expert - task_vectors_experts = {} - for expert in experts: - task_vectors = self._get_task_vectors(expert) - task_vectors_experts[expert.name] = task_vectors - task_merged_vectors = {} - # wudi merge the task vectors - for layer in layer_names: - - # get the experts for this layer - task_vectors = [ - task_vectors_experts[expert.name][layer] for expert in experts - ] - - task_vectors = torch.stack(task_vectors, dim=0) - # get the redundant task vector - merged_task_vector = self.get_optimized_task_vector( - layer_name=layer, - task_vectors=task_vectors, - iter=self.config.iter, - lr=self.config.lr, - ) - # save the merged task vector in each layer - task_merged_vectors[layer] = merged_task_vector / len(experts) - return task_merged_vectors + return wudi_merge_after(experts, self.config) @LibraryTransform.register("wudi_merge", WudiMergeConfig) diff --git a/mttl/models/merging/__init__.py b/mttl/models/merging/__init__.py index ca20604b1..b0f0b87b3 100644 --- a/mttl/models/merging/__init__.py +++ b/mttl/models/merging/__init__.py @@ -6,7 +6,7 @@ from library transforms and can be used independently. """ -from .wudi import wudi_merge +from .wudi import wudi_merge, wudi_merge_after from .weighted_linear import weighted_linear_merge from .ties import ties_merge from .arrow import arrow_transform @@ -18,6 +18,7 @@ __all__ = [ "wudi_merge", + "wudi_merge_after", "weighted_linear_merge", "ties_merge", "arrow_transform", diff --git a/mttl/models/merging/wudi.py b/mttl/models/merging/wudi.py index 7b52f8314..e7260d31c 100644 --- a/mttl/models/merging/wudi.py +++ b/mttl/models/merging/wudi.py @@ -5,7 +5,7 @@ """ import copy -from typing import List +from typing import Dict, List import torch from tqdm.auto import tqdm @@ -92,3 +92,141 @@ def wudi_merge(experts: List[Expert], config: WudiMergeConfig) -> Expert: base_expert.expert_weights[key].data.copy_(merging_vector.data.cpu()) return base_expert + + +def wudi_merge_after( + experts: List[Expert], config: WudiMergeConfig +) -> Dict[str, torch.Tensor]: + """ + Merge experts using WuDi merge algorithm after computing task vectors (LoRA A @ LoRA B). + + This variant computes the outer product of LoRA A and B matrices first, then applies + WuDi merge to the resulting task vectors for each layer. 
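+
+    Example (illustrative sketch: assumes `experts` is a non-empty list of LoRA
+    Experts loaded from an ExpertLibrary, and that a CUDA device is available
+    since the per-layer optimization runs on GPU):
+
+        >>> config = WudiMergeConfig()
+        >>> merged_task_vectors = wudi_merge_after(experts, config)
+        >>> len(merged_task_vectors)  # one merged task vector per LoRA-adapted layer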
+ + Args: + experts: List of Expert objects to merge + config: WudiMergeConfig containing merge parameters + + Returns: + Dict[str, torch.Tensor]: Dictionary mapping layer names to merged task vectors + """ + if not experts: + raise ValueError("Cannot merge empty list of experts") + + logger.info("Merging {} experts using WuDi merge after".format(len(experts))) + + one_expert = experts[0] + # Get the layer names from the model + layer_names = [name.split(".lora_")[0] for name in one_expert.expert_weights.keys()] + layer_names = list(set(layer_names)) + + # Get the task vectors for each expert + task_vectors_experts = {} + for expert in experts: + task_vectors = _get_task_vectors(expert) + task_vectors_experts[expert.name] = task_vectors + + task_merged_vectors = {} + # WuDi merge the task vectors + for layer in layer_names: + # Get the experts for this layer + task_vectors = [task_vectors_experts[expert.name][layer] for expert in experts] + + task_vectors = torch.stack(task_vectors, dim=0) + # Get the redundant task vector + merged_task_vector = _get_optimized_task_vector( + layer_name=layer, + task_vectors=task_vectors, + iter=config.iter, + lr=config.lr, + ) + + # Save the merged task vector in each layer + task_merged_vectors[layer] = merged_task_vector / len(experts) + + return task_merged_vectors + + +def _get_task_vectors(expert: Expert) -> Dict[str, torch.Tensor]: + """ + Get the incremental weights for each layer, LoRA A outer product LoRA B. + + Args: + expert: Expert object containing LoRA weights + + Returns: + Dict[str, torch.Tensor]: Dictionary mapping layer names to task vectors + """ + task_vectors = {} + for key in expert.expert_weights.keys(): + base_layer_name = key.split(".lora_")[ + 0 + ] # Get base layer name by removing .lora_a or .lora_b + if base_layer_name not in task_vectors: + task_vectors[base_layer_name] = None + + for layer in task_vectors.keys(): + lora_a = expert.expert_weights[f"{layer}.lora_a"] + lora_b = expert.expert_weights[f"{layer}.lora_b"] + task_vectors[layer] = lora_a.data @ lora_b.data + + return task_vectors + + +def _get_optimized_task_vector( + layer_name: str, task_vectors: torch.Tensor, iter: int, lr: float +) -> torch.Tensor: + """ + Minimize Σᵢ (1/||τᵢ,ₗ||²F) ||(τₘ,ₗ - τᵢ,ₗ)(τᵢ,ₗ)ᵀ||²F + + Return the optimized merged task vector for each layer. 
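+
+    Illustrative call (hypothetical shapes; requires a CUDA device because
+    task_vectors is moved to GPU inside this function):
+
+        >>> tvs = torch.stack([torch.randn(64, 32) for _ in range(3)])  # 3 experts
+        >>> merged = _get_optimized_task_vector("layer.0", tvs, iter=50, lr=1e-5)
+        >>> merged.shape
+        torch.Size([64, 32])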
+ + Args: + layer_name: Name of the layer being optimized + task_vectors: Stacked task vectors for the layer + iter: Number of optimization iterations + lr: Learning rate for optimization + + Returns: + torch.Tensor: Optimized merged task vector + """ + task_vectors = task_vectors.cuda() + merging_vector = torch.nn.Parameter((torch.sum(task_vectors, dim=0))) + optimizer = torch.optim.Adam([merging_vector], lr=lr, weight_decay=0) + + l2_norms = torch.square( + torch.norm(task_vectors.reshape(task_vectors.shape[0], -1), p=2, dim=-1) + ) + + pbar = tqdm(range(iter), desc=f"Optimizing parameter {layer_name}") + prev_loss = float("inf") + patience = 5 # Number of steps to wait for improvement + no_improve_count = 0 + min_delta = 1e-4 # Minimum change in loss to be considered improvement + + for step in pbar: + disturbing_vectors = merging_vector.unsqueeze(0) - task_vectors + inner_product = torch.matmul(disturbing_vectors, task_vectors.transpose(1, 2)) + loss = torch.sum( + torch.square(inner_product) / l2_norms.unsqueeze(-1).unsqueeze(-1) + ) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Check if loss improvement is significant + if abs(prev_loss - loss.item()) < min_delta: + no_improve_count += 1 + else: + no_improve_count = 0 + + # Early stopping if no significant improvement for patience steps + if no_improve_count >= patience: + logger.info(f"Early stopping at step {step} due to minimal loss change") + break + + prev_loss = loss.item() + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + return merging_vector From c47a536ee169aaab4711c5d29336dade97afa2a4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Jun 2025 01:41:37 +0000 Subject: [PATCH 9/9] Refactor ArrowTransform to use standalone arrow_transform function Co-authored-by: sordonia <1689148+sordonia@users.noreply.github.com> --- mttl/models/library/library_transforms.py | 255 ++++++++++++---------- 1 file changed, 139 insertions(+), 116 deletions(-) diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index 0a6c7a2b4..b95ea45f5 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -725,93 +725,19 @@ def _maybe_scale(self, vectors, eigvals): output[expert_name][layer_name] = torch.from_numpy(vector) return output - def _low_rank_svd(self, A, B): - """Faster SVD computation for low rank matrices""" - - # Compute SVD of A - U_A, Sigma_A, V_A = torch.svd(A) - - # Compute SVD of B.T (transpose of B) - U_B, Sigma_B, V_B = torch.svd(B.T) - - # Compute product matrix C = Sigma_A * (V_A.T @ V_B) * Sigma_B - # Since V_A and V_B are orthogonal, their product is also an orthogonal matrix - C = Sigma_A.diag_embed() @ V_A.t() @ V_B @ Sigma_B.diag_embed() - - # Compute SVD of the product matrix C - U_C, Sigma_C, V_C = torch.svd(C) - - # Construct the final SVD components of W - U_W = U_A @ U_C - V_W_T = V_C.t() @ U_B.t() - - diff_AB = (U_W.T @ U_A).abs().diag() - if diff_AB[0] < 0.9: - logger.debug("The first singular vector of U_A and U_AB are not aligned") - - return U_W, Sigma_C, V_W_T - - def _get_unique_parent_names(self, alist): - """ - if adict.keys() = ['model.layer1.lora_a', 'model.layer.lora_b', 'model.layer2.lora_a'] - output will be {'model.layer1', 'model.layer2'} - """ - dict_keys = sorted(list(set(".".join(k.split(".")[:-1]) for k in alist))) - return dict_keys - - @classmethod - @torch.no_grad() - def fetch(cls, library: Union[str, ExpertLibrary], 
config_hash: str): - """Fetch arrow prototypes from the library, raises ValueError if they are not computed. - - Args: - library (Union[str, ExpertLibrary]): ExpertLibrary object or its name - scale (bool): If True, scale the output by the eigenvalue - """ - if not isinstance(library, ExpertLibrary): - library = ExpertLibrary.get_expert_library(library) - - config_hash = config_hash or ArrowTransformConfig().save_name - - # try to fetch auxiliary data - protos = library.get_auxiliary_data(data_type=config_hash + "_protos") - return protos - - @torch.no_grad() - def transform( - self, - library, - persist=True, - recompute=False, - ) -> Expert: - logger.info("Arrow save name : {}".format(self.config.save_name)) - - if isinstance(library, str): - library = ExpertLibrary.get_expert_library(library) - + def _compute_with_base_model(self, experts_to_compute, library): + """Handle arrow computation when ab_only=False (includes base model weights)""" base_model = None - - # Try to fetch the precomputed Arrow prototypes - protos = self.fetch(library, self.config.save_name) - already_computed = [] - vectors = {} eigvals = {} - for expert_name, expert in library.items(): - if expert_name in protos and not recompute: - logger.info( - "Found precomputed Arrow prototypes for expert {}".format( - expert_name - ) - ) - already_computed.append(expert_name) - continue + for expert in experts_to_compute: + expert_name = expert.name logger.info(f"Computing SVD for expert {expert_name}") vectors[expert_name] = {} eigvals[expert_name] = {} - if base_model is None and not self.config.ab_only: + if base_model is None: training_config = expert.training_config training_config.model_modifier = None from mttl.models.lightning.expert_module import MultiExpertModule @@ -844,7 +770,6 @@ def transform( ), "Only support tied As for now" # Now that we know only A's are tied, we can proceed using only the parent names - # e.g. 'model.layers.30.self_attn.q_proj' instead of 'model.layers.30.self_attn.q_proj.lora_a' tied_parents = self._get_unique_parent_names(tied_params) untied_parents = [ @@ -856,8 +781,6 @@ def transform( ] # Build a mapping from source to target parameters - # e.g. : [] - # NOTE: list will be empty if the param is not tied to anything tied_param_bins = defaultdict(list) for tgt_name, src_name in param_map.items(): @@ -882,23 +805,16 @@ def transform( Bs += [expert.expert_weights[f"{tied_module}.lora_b"]] parent_names += [tied_module] - if not self.config.ab_only: - base_W += [ - base_model.model.state_dict()[f"{tied_module}.weight"] - ] + base_W += [base_model.model.state_dict()[f"{tied_module}.weight"]] if len(As) > 1: if self.config.tie_op == "concat": - # Mimicking phi-2 behavior - assert self.config.ab_only - assert all( - torch.allclose(A, As[0]) for A in As - ), "A should be the same for all tied parameters" - A = As[0] - B = torch.cat(Bs, dim=1) + # This shouldn't be used with base model + raise NotImplementedError( + "concat not supported with base model" + ) elif self.config.tie_op == "sum": # A1B1 + A2B2 == [A1 A2] [B1; B2]. 
- # We do it this way to leverage the low-rank SVD A = torch.cat(As, dim=1) B = torch.cat(Bs, dim=0) else: @@ -913,21 +829,16 @@ def transform( W = (A @ B).T # out_features, in_features - if self.config.ab_only: - U_W, Sigma_W, _ = self._low_rank_svd(A, B) - top_value = Sigma_W[0] ** 2 - bottom_vector = U_W[:, -1] - top_vector = U_W[:, 0] - else: - base_W += [ - base_model.model.state_dict()[f"{parent_name}.weight"] - ].float() - base_W = torch.stack(base_W).sum(0) - W += base_W - U, E, Vt = torch.linalg.svd(W) - top_vector = Vt[0] - bottom_vector = Vt[-1] - top_value = E[0] + # Add base model weights + base_W += [ + base_model.model.state_dict()[f"{parent_name}.weight"] + ].float() + base_W = torch.stack(base_W).sum(0) + W += base_W + U, E, Vt = torch.linalg.svd(W) + top_vector = Vt[0] + bottom_vector = Vt[-1] + top_value = E[0] # Check that top vector is indeed an eigenvector WTW = W.T @ W @@ -939,17 +850,129 @@ def transform( 2 ).sum() - # Save eigenvector and eigvenvalue + # Save eigenvector and eigenvalue for parent in parent_names: assert parent not in vectors[expert_name] - vectors[expert_name][parent] = top_vector.real.cpu().numpy() + vectors[expert_name][parent] = top_vector.real.cpu() eigvals[expert_name][parent] = top_value.item() - to_upload = [x for x in library.keys() if x not in already_computed] - new_protos = self._maybe_scale(vectors, eigvals) + return self._maybe_scale(vectors, eigvals) + + def _low_rank_svd(self, A, B): + """Faster SVD computation for low rank matrices""" + + # Compute SVD of A + U_A, Sigma_A, V_A = torch.svd(A) + + # Compute SVD of B.T (transpose of B) + U_B, Sigma_B, V_B = torch.svd(B.T) + + # Compute product matrix C = Sigma_A * (V_A.T @ V_B) * Sigma_B + # Since V_A and V_B are orthogonal, their product is also an orthogonal matrix + C = Sigma_A.diag_embed() @ V_A.t() @ V_B @ Sigma_B.diag_embed() + + # Compute SVD of the product matrix C + U_C, Sigma_C, V_C = torch.svd(C) + + # Construct the final SVD components of W + U_W = U_A @ U_C + V_W_T = V_C.t() @ U_B.t() + + diff_AB = (U_W.T @ U_A).abs().diag() + if diff_AB[0] < 0.9: + logger.debug("The first singular vector of U_A and U_AB are not aligned") + + return U_W, Sigma_C, V_W_T + + def _get_unique_parent_names(self, alist): + """ + if adict.keys() = ['model.layer1.lora_a', 'model.layer.lora_b', 'model.layer2.lora_a'] + output will be {'model.layer1', 'model.layer2'} + """ + dict_keys = sorted(list(set(".".join(k.split(".")[:-1]) for k in alist))) + return dict_keys + + @classmethod + @torch.no_grad() + def fetch(cls, library: Union[str, ExpertLibrary], config_hash: str): + """Fetch arrow prototypes from the library, raises ValueError if they are not computed. 
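+
+        Example (illustrative; the repository id is hypothetical and the class
+        is assumed to be ArrowTransform):
+
+            >>> protos = ArrowTransform.fetch("local://my_library", config_hash=None)
+            >>> list(protos.keys())  # expert names with precomputed Arrow prototypes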
+ + Args: + library (Union[str, ExpertLibrary]): ExpertLibrary object or its name + scale (bool): If True, scale the output by the eigenvalue + """ + if not isinstance(library, ExpertLibrary): + library = ExpertLibrary.get_expert_library(library) + + config_hash = config_hash or ArrowTransformConfig().save_name + + # try to fetch auxiliary data + protos = library.get_auxiliary_data(data_type=config_hash + "_protos") + return protos + + @torch.no_grad() + def transform( + self, + library, + persist=True, + recompute=False, + ) -> Expert: + from mttl.models.merging import arrow_transform + + logger.info("Arrow save name : {}".format(self.config.save_name)) + + if isinstance(library, str): + library = ExpertLibrary.get_expert_library(library) + + # Try to fetch the precomputed Arrow prototypes + protos = self.fetch(library, self.config.save_name) + already_computed = [] + + # Find experts that need to be computed + experts_to_compute = [] + for expert_name, expert in library.items(): + if expert_name in protos and not recompute: + logger.info( + "Found precomputed Arrow prototypes for expert {}".format( + expert_name + ) + ) + already_computed.append(expert_name) + else: + experts_to_compute.append(expert) + + # Use standalone function for experts that need computation + if experts_to_compute: + if self.config.ab_only: + # Use the standalone function for ab_only case + new_protos = arrow_transform(experts_to_compute, self.config) + else: + # Keep existing logic for base model case (ab_only=False) + new_protos = self._compute_with_base_model(experts_to_compute, library) + else: + new_protos = {} + + # Handle persistence + if persist and new_protos: + # Convert to the format expected by persistence + vectors = {} + eigvals = {} + for expert_name, expert_protos in new_protos.items(): + vectors[expert_name] = {} + eigvals[expert_name] = {} + for layer_name, vector in expert_protos.items(): + if self.config.scale: + # If scaling was applied, we need to extract the unscaled vector + # This is complex, so for now we'll store the scaled vector as is + vectors[expert_name][layer_name] = vector.cpu().numpy() + eigvals[expert_name][layer_name] = 1.0 # Placeholder + else: + vectors[expert_name][layer_name] = vector.cpu().numpy() + eigvals[expert_name][layer_name] = 1.0 # Placeholder + + to_upload = list(new_protos.keys()) + formatted_protos = self._maybe_scale(vectors, eigvals) - if persist and len(to_upload) > 0: - # add embeddings to the library with library.batched_commit(): for expert_name in to_upload: logger.info( @@ -958,7 +981,7 @@ def transform( for data_name, data in [ ("vectors", vectors), ("eigvals", eigvals), - ("protos", new_protos), + ("protos", formatted_protos), ]: library.add_auxiliary_data( data_type=self.config.save_name + "_" + data_name,