Last active
June 19, 2024 02:13
-
-
Save Algomancer/8f0b8d7cc26657659af663d9ab1721a0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class SubspaceLinear(nn.Linear):
    """Base class for linear layers whose effective parameters come from a subspace.

    Subclasses implement `subspace_weights` to return the weight and bias
    that `forward` should use for the current call.
    """

    def subspace_weights(self) -> "tuple[torch.Tensor, torch.Tensor | None]":
        """
        Return the weight and bias to use for the current forward pass.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor | None]
            The weight and bias passed to `F.linear`.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement subspace_weights()"
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the SubspaceLinear layer. Calls `subspace_weights` to
        obtain the weight and bias, then applies them with `F.linear`.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.

        Returns
        -------
        torch.Tensor
            The output tensor after applying the linear transformation.
        """
        w, b = self.subspace_weights()
        return F.linear(x, w, b)
class LineLinear(SubspaceLinear):
    """Linear layer whose parameters lie on a line between two endpoints.

    The effective weight (and bias, when present) is
    ``(1 - alpha) * primary + alpha * alt``, where ``primary`` is the
    parameter inherited from `nn.Linear` and ``alt`` is the additional
    parameter created here.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to the parent constructor.
    alpha : float, optional
        Initial interpolation coefficient (keyword-only). Defaults to 0.0,
        which makes the layer use only the primary weight and bias until
        `alpha` is reassigned.
    """

    def __init__(self, *args, alpha: float = 0.0, **kwargs):
        super().__init__(*args, **kwargs)
        # Interpolation coefficient; callers may overwrite this attribute
        # at any time (e.g. once per training step).
        self.alpha = alpha
        self.weight_alt = nn.Parameter(torch.zeros_like(self.weight))
        if self.bias is None:
            # Layer was built with bias=False: there is no bias endpoint
            # to mirror, so register an explicit empty slot instead.
            self.register_parameter("bias_alt", None)
        else:
            self.bias_alt = nn.Parameter(torch.zeros_like(self.bias))

    def initialize_parameters(self, init_fn) -> None:
        """
        Initialize the additional weights using a provided function.

        Parameters
        ----------
        init_fn : Callable
            The function to initialize the weights. It is called once with
            `weight_alt` and, when the layer has a bias, once with `bias_alt`.
        """
        init_fn(self.weight_alt)
        if self.bias_alt is not None:
            init_fn(self.bias_alt)

    def subspace_weights(self) -> "tuple[torch.Tensor, torch.Tensor | None]":
        """
        Compute the weight and bias as a linear combination of two sets of parameters.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor | None]
            The combined weight and bias tensors. The bias is None when the
            layer has no bias.
        """
        alpha = self.alpha
        w = (1 - alpha) * self.weight + alpha * self.weight_alt
        if self.bias is None:
            return w, None
        b = (1 - alpha) * self.bias + alpha * self.bias_alt
        return w, b
# Then, in the training loop, randomly sample alpha from Uniform[0, 1] and assign it to each layer's `alpha` attribute.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment