mirror of
https://github.com/WongKinYiu/yolov7.git
synced 2025-02-23 12:59:13 +08:00
main code
update default activation function
This commit is contained in:
parent
e44853eb4b
commit
09b8e34ae5
@ -102,7 +102,7 @@ class Conv(nn.Module):
|
||||
super(Conv, self).__init__()
|
||||
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
|
||||
self.bn = nn.BatchNorm2d(c2)
|
||||
self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
|
||||
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.bn(self.conv(x)))
|
||||
@ -477,7 +477,7 @@ class RepConv(nn.Module):
|
||||
|
||||
padding_11 = autopad(k, p) - k // 2
|
||||
|
||||
self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
|
||||
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
|
||||
|
||||
if deploy:
|
||||
self.rbr_reparam = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=True)
|
||||
|
Loading…
Reference in New Issue
Block a user