diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index d647ea55e..54346b64a 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -41,8 +41,8 @@ class HypernetworkModule(torch.nn.Module): # Add a fully-connected layer linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) - # Add an activation func - if activation_func == "linear" or activation_func is None: + # Add an activation func except last layer + if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 3: pass elif activation_func in self.activation_dict: linears.append(self.activation_dict[activation_func]()) @@ -53,7 +53,7 @@ class HypernetworkModule(torch.nn.Module): if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) - # Add dropout expect last layer + # Add dropout except last layer if use_dropout and i < len(layer_structure) - 3: linears.append(torch.nn.Dropout(p=0.3))