Commit 6a4fa73a authored by discus0434's avatar discus0434
Browse files

small fix

parent 97749b7c
Loading
Loading
Loading
Loading
+3 −4
Original line number Diff line number Diff line
@@ -51,10 +51,9 @@ class HypernetworkModule(torch.nn.Module):
            if add_layer_norm:
                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

            # Add dropout
            if use_dropout:
                p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2
                linears.append(torch.nn.Dropout(p=p))
            # Add dropout expect last layer
            if use_dropout and i < len(layer_structure) - 3:
                linears.append(torch.nn.Dropout(p=0.3))

        self.linear = torch.nn.Sequential(*linears)