Home Article Practice python code

python code

2024-06-17 19:34  views:476  source:Libra    

# This is the code from the p1ch2/3_cyclegan notebook
import torch
import torch.nn as nn
class ResNetBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    ``F`` is two reflection-padded 3x3 convolutions (``dim`` -> ``dim``
    channels), each followed by InstanceNorm2d, with a ReLU between them.
    Padding keeps the spatial size, so the skip connection lines up.
    """

    def __init__(self, dim):
        super().__init__()
        # Attribute name and layer order define the state-dict keys
        # (conv_block.1.*, conv_block.5.*); keep them stable for checkpoints.
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        """Return the residual branch as an ``nn.Sequential``.

        Layer order: pad, conv, norm, ReLU, pad, conv, norm.
        """

        def padded_conv_norm():
            # Reflection padding of 1 + 3x3 conv with padding=0 preserves H, W.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                nn.InstanceNorm2d(dim),
            ]

        layers = padded_conv_norm() + [nn.ReLU(True)] + padded_conv_norm()
        return nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
class ResNetGenerator(nn.Module):
    """ResNet-style image-to-image generator (CycleGAN layout).

    Layers, in order: a reflection-padded 7x7 stem conv, two stride-2 convs
    that each double the channel count, ``n_blocks`` ResNetBlock modules at
    ``ngf * 4`` channels, two stride-2 transposed convs that halve the
    channels back to ``ngf``, and a reflection-padded 7x7 conv + tanh head.
    When H and W are divisible by 4 the output has the same spatial size as
    the input.

    The layer order defines the ``model.<index>`` state-dict keys; do not
    reorder layers or previously saved weights will no longer load.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: channel width of the first conv layer.
        n_blocks: number of residual blocks in the bottleneck.

    Raises:
        ValueError: if ``n_blocks`` is negative.
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
        if n_blocks < 0:
            raise ValueError("n_blocks must be >= 0, got {}".format(n_blocks))
        super(ResNetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        # Stem: reflection pad 3 + 7x7 conv with padding=0 keeps H, W.
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True),
        ]

        # Downsampling: each stride-2 conv doubles the channel count.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [
                nn.Conv2d(
                    ngf * mult,
                    ngf * mult * 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=True,
                ),
                nn.InstanceNorm2d(ngf * mult * 2),
                nn.ReLU(True),
            ]

        # Residual trunk at the bottleneck width.
        mult = 2**n_downsampling
        for _ in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]

        # Upsampling: each stride-2 transposed conv halves the channel count;
        # output_padding=1 makes the spatial size exactly double.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    ngf * mult,
                    int(ngf * mult / 2),
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                    bias=True,
                ),
                nn.InstanceNorm2d(int(ngf * mult / 2)),
                nn.ReLU(True),
            ]

        # Head: 7x7 conv to output_nc channels, tanh bounds values to [-1, 1].
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Translate an image batch with values in [0, 1] to values in [0, 1].

        The network ends in tanh, so it operates in [-1, 1]. The input is
        mapped [0, 1] -> [-1, 1] with ``x * 2 - 1``, the exact inverse of the
        output mapping ``y / 2 + 0.5``. The previous ``input * 255`` fed
        [0, 255] values into the network, contradicting the 0-1 input contract.
        NOTE(review): assumes the pretrained weights were trained on [-1, 1]
        inputs (standard CycleGAN normalization); confirm against training.

        Usually one would write this differently for efficiency, e.g. by
        folding the affine rescaling into the first conv and the last layer.
        """
        return self.model(input * 2 - 1) / 2 + 0.5
def get_pretrained_model(model_path, map_location=None):
    """Load a frozen, inference-ready ResNetGenerator from a saved state dict.

    Args:
        model_path: path or file-like object passed to ``torch.load``; its
            contents must be a state dict matching ``ResNetGenerator()`` with
            default arguments.
        map_location: forwarded to ``torch.load`` (e.g. ``"cpu"`` to load
            weights saved on a GPU onto a CPU-only machine).

    Returns:
        The generator in eval mode with ``requires_grad`` disabled on every
        parameter.
    """
    netG = ResNetGenerator()
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # files (or pass weights_only=True on PyTorch versions that support it).
    model_data = torch.load(model_path, map_location=map_location)
    netG.load_state_dict(model_data)
    netG.eval()
    # Module.requires_grad_ already recurses into every parameter, so one call
    # suffices; the old per-parameter loop re-froze the whole module N times.
    netG.requires_grad_(False)
    return netG
if __name__ == "__main__":
    # Script entry: load pretrained generator weights and save a TorchScript
    # (traced) version of the model for deployment.
    import sys

    if len(sys.argv) < 3:
        # Usage errors belong on stderr, not stdout.
        print(
            "Call as {} zebra_weights.pt traced_zebra_model.pt".format(sys.argv[0]),
            file=sys.stderr,
        )
        sys.exit(1)
    model = get_pretrained_model(sys.argv[1], map_location="cpu")
    # The example input only drives tracing. Values in [0, 1] match the range
    # forward() expects, and 256 is divisible by 4, so the traced example's
    # output size equals its input size (227 would yield 228).
    example_input = torch.rand(1, 3, 256, 256)
    traced_model = torch.jit.trace(model, example_input)
    traced_model.save(sys.argv[2])



Disclaimer: The article above was added by a user and is provided only for typing practice and communication. It does not represent the views of this website, and this website assumes no legal responsibility for it. If it infringes on any of your rights, please contact us promptly and we will delete it.

字符:    改为:
去打字就可以设置个性皮肤啦!(O ^ ~ ^ O)