import torch
from torch import nn
def comp_conv2d(conv2d, X):
X = X.reshape((1, 1) + X.shape)
Y = conv2d(X)
return Y.reshape(Y.shape[2:])
conv2d = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1)
X = torch.rand(size=(8, 8))
print(comp_conv2d(conv2d, X).shape) # 8+2*2-5+1, 8-3+1+1*2
# 将高度和步副设置为2
conv2d = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1)
print(comp_conv2d(conv2d, X).shape)
#稍微复杂的例子
conv2d = nn.Conv2d(1, 1, kernel_size=(3,5), padding=(0, 1), stride=(3,4))
print(comp_conv2d(conv2d, X).shape)
将高度和步副设置为2
稍微复杂的例子