本篇文章为大家展示了如何在pytorch获取vgg16-feature层的输出,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。
import numpy as np import torch from torchvision import models from torch.autograd import Variable import torchvision.transforms as transforms class CNNShow(): def __init__(self, model): self.model = model self.model.eval() self.created_image = self.image_for_pytorch(np.uint8(np.random.uniform(150, 180, (224, 224, 3)))) def show(self): x = self.created_image for index, layer in enumerate(self.model): print(index,layer) x = layer(x) def image_for_pytorch(self,Data): transform = transforms.Compose([ transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) ] ) imData = transform(Data) imData = Variable(torch.unsqueeze(imData, dim=0), requires_grad=True) return imData if __name__ == '__main__': pretrained_model = models.vgg16(pretrained=True).features CNN = CNNShow(pretrained_model) CNN.show()
上述内容就是如何在pytorch获取vgg16-feature层的输出,你们学到知识或技能了吗?如果还想学到更多技能或者丰富自己的知识储备,欢迎关注创新互联成都网站设计公司行业资讯频道。
另外有需要云服务器可以了解下创新互联scvps.cn,海内外云服务器15元起步,三天无理由+7*72小时售后在线,公司持有idc许可证,提供“云服务器、裸金属服务器、高防服务器、香港服务器、美国服务器、虚拟主机、免备案服务器”等云主机租用服务以及企业上云的综合解决方案,具有“安全稳定、简单易用、服务可用性高、性价比高”等特点与优势,专为企业上云打造定制,能够满足用户丰富、多元化的应用场景需求。