Commit 03a1e288 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

turns out LayerNorm also has weight and bias and needs to be pre-multiplied...

turns out LayerNorm also has weight and bias and needs to be pre-multiplied and trained for hypernets
parent e4877722
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -52,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
            self.load_state_dict(state_dict)
        else:
            for layer in self.linear:
                if type(layer) == torch.nn.Linear:
                if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                    layer.weight.data.normal_(mean=0.0, std=0.01)
                    layer.bias.data.zero_()

@@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
    def trainables(self):
        layer_structure = []
        for layer in self.linear:
            if type(layer) == torch.nn.Linear:
            if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                layer_structure += [layer.weight, layer.bias]
        return layer_structure