2020-07-21 17:41发布
在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。
想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:
def
apply_dropout(m):
if
type
(m)
=
nn.Dropout:
m.train()
# coding: utf-8
import
torch
torch.nn as nn
numpy as np
class
SimpleNet(nn.Module):
__init__(
self
):
super
(SimpleNet,
).__init__()
.fc
nn.Linear(
8
,
)
.dropout
nn.Dropout(
0.5
forward(
, x):
x
.fc(x)
.dropout(x)
return
net
SimpleNet()
torch.FloatTensor([
1
]
*
net.train()
y
net(x)
print
(
'train mode result: '
, y)
net.
eval
()
'eval mode result: '
'eval2 mode result: '
apply
(apply_dropout)
'apply eval result:'
运行结果:
可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。
最多设置5个标签!
在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。
想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:
def
apply_dropout(m):
if
type
(m)
=
=
nn.Dropout:
m.train()
下面是完整demo代码:
# coding: utf-8
import
torch
import
torch.nn as nn
import
numpy as np
class
SimpleNet(nn.Module):
def
__init__(
self
):
super
(SimpleNet,
self
).__init__()
self
.fc
=
nn.Linear(
8
,
8
)
self
.dropout
=
nn.Dropout(
0.5
)
def
forward(
self
, x):
x
=
self
.fc(x)
x
=
self
.dropout(x)
return
x
net
=
SimpleNet()
x
=
torch.FloatTensor([
1
]
*
8
)
net.train()
y
=
net(x)
print
(
'train mode result: '
, y)
net.
eval
()
y
=
net(x)
print
(
'eval mode result: '
, y)
net.
eval
()
y
=
net(x)
print
(
'eval2 mode result: '
, y)
def
apply_dropout(m):
if
type
(m)
=
=
nn.Dropout:
m.train()
net.
eval
()
net.
apply
(apply_dropout)
y
=
net(x)
print
(
'apply eval result:'
, y)
运行结果:
可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。
一周热门 更多>