컴퓨터비전(CV)/논문 구현
[논문 구현] Fully Convolutional Networks for Semantic Segmentation (FCN)
jun0823
2022. 3. 16. 21:30
반응형
https://github.com/327aem/paper_review/tree/main/Segmentation
GitHub - 327aem/paper_review: 논문 리뷰/구현
논문 리뷰/구현. Contribute to 327aem/paper_review development by creating an account on GitHub.
github.com
import torch
import torch.nn as nn
import torchvision.models as models
def _bilinear_kernel(channels, size):
    """Return a (channels, channels, size, size) weight performing per-channel bilinear upsampling.

    The layout matches ``nn.ConvTranspose2d`` weights (in_channels, out_channels, kH, kW);
    only the diagonal (channel c -> channel c) is non-zero.
    """
    factor = (size + 1) // 2
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    taps = 1 - (torch.arange(size, dtype=torch.float32) - center).abs() / factor
    kernel2d = taps[:, None] * taps[None, :]
    weight = torch.zeros(channels, channels, size, size)
    idx = torch.arange(channels)
    weight[idx, idx] = kernel2d
    return weight


def _conv_from_linear(linear, conv):
    """Copy a fully connected layer's parameters into an equivalent convolution.

    Does nothing unless ``linear`` is an ``nn.Linear`` whose weight has exactly as many
    elements as ``conv.weight`` and whose output width equals ``conv.out_channels``
    (e.g. VGG16's Linear(512*7*7, 4096) -> Conv2d(512, 4096, 7)).
    """
    if not isinstance(linear, nn.Linear):
        return
    if linear.weight.numel() != conv.weight.numel() or linear.out_features != conv.out_channels:
        return
    with torch.no_grad():
        # Linear input is the (C, H, W)-flattened feature map, so a reshape
        # reproduces the same dot products as a convolution.
        conv.weight.copy_(linear.weight.reshape(conv.weight.shape))
        if linear.bias is not None and conv.bias is not None:
            conv.bias.copy_(linear.bias)


def _resize_to(tensor, size):
    """Bilinearly resize the last two dims of ``tensor`` to ``size`` (no-op if already equal)."""
    if tuple(tensor.shape[-2:]) == tuple(size):
        return tensor
    return nn.functional.interpolate(tensor, size=tuple(size), mode="bilinear", align_corners=False)


class FCN16(nn.Module):
    """FCN-16s semantic segmentation network built on a VGG16-style backbone.

    Args:
        model: backbone exposing ``features`` (VGG16 layout: 31 modules, pools at
            indices 4/9/16/23/30) and ``classifier`` (Linear, ReLU, Dropout, Linear,
            ReLU, Dropout, Linear). It is NOT modified; its feature layers are shared
            and its first two Linear layers are converted into convolutions
            (weights copied when shapes allow).
        hidden: channel width of the first VGG stage; pool4/pool5 have ``8 * hidden``
            channels (64 -> 512 for VGG16).
        kernel_size, padding: accepted for backward compatibility; unused.
        num_classes: number of output score channels.

    The output has shape (N, num_classes, H, W) matching the input's spatial size.
    Inputs should be at least 32x32 so pool5 is non-empty.
    """

    def __init__(self, model, hidden=64, kernel_size=3, padding=1, num_classes=21):
        super(FCN16, self).__init__()
        features = list(model.features)
        self.block1 = nn.Sequential(*features[:5])
        self.block2 = nn.Sequential(*features[5:10])
        self.block3 = nn.Sequential(*features[10:17])
        self.block4 = nn.Sequential(*features[17:24])
        self.block5 = nn.Sequential(*features[24:31])

        head = list(model.classifier)
        pool_channels = 8 * hidden
        fc_width = getattr(head[0], "out_features", 4096)

        # padding=3 keeps the pool5 resolution, so the 2x upsample lines up with pool4.
        conv6 = nn.Conv2d(pool_channels, fc_width, 7, padding=3)
        conv7 = nn.Conv2d(fc_width, fc_width, 1)
        _conv_from_linear(head[0], conv6)
        _conv_from_linear(head[3], conv7)
        # Reuse the backbone's (parameter-free) ReLU/Dropout modules after each conv.
        self.fc6 = nn.Sequential(conv6, *head[1:3])
        self.fc7 = nn.Sequential(conv7, *head[4:6])
        self.block_score = nn.Conv2d(fc_width, num_classes, 1)

        self.score_pool4 = nn.Conv2d(pool_channels, num_classes, kernel_size=1)
        # k=4,s=2,p=1 is an exact 2x upsample; k=32,s=16,p=8 is an exact 16x upsample.
        self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, 4, stride=2, padding=1, bias=False)
        self.upscore16 = nn.ConvTranspose2d(num_classes, num_classes, 32, stride=16, padding=8, bias=False)
        # Start both upsamplers as bilinear interpolation (still learnable).
        with torch.no_grad():
            self.upscore2.weight.copy_(_bilinear_kernel(num_classes, 4))
            self.upscore16.weight.copy_(_bilinear_kernel(num_classes, 32))

    def forward(self, x):
        """Return per-pixel class scores of shape (N, num_classes, H, W) for input ``x``."""
        pool4 = self.block4(self.block3(self.block2(self.block1(x))))
        pool5 = self.block5(pool4)
        score = self.block_score(self.fc7(self.fc6(pool5)))

        score_pool4 = self.score_pool4(pool4)
        # When H or W is not a multiple of 32, pool5 * 2 can be one pixel short of pool4.
        upscore2 = _resize_to(self.upscore2(score), score_pool4.shape[-2:])
        fused = upscore2 + score_pool4
        return _resize_to(self.upscore16(fused), x.shape[-2:])
반응형