cs501r_f2018:lab5

Revisions: 2018/10/03 16:46 (cat, "Grading standards:") and 2021/06/30 23:42 (current).
<code python>
    self.vgg = models.vgg16(pretrained=True).features.eval()
    for i, m in enumerate(self.vgg.children()):
        if isinstance(m, nn.ReLU):  # set the ReLU layers to NOT operate in place:
            m.inplace = False       # autograd has a hard time going backwards through in-place ops

        if i in requested:
            def curry(i):
                ...  # (excerpt ends here)
</code>
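The `curry(i)` wrapper is there because a Python closure looks up a loop variable when it runs, not when it is defined, so hooks created directly inside the loop would all see the last value. A minimal pure-Python sketch of the difference, with plain functions standing in for PyTorch forward hooks (`make_hooks_naive` and `make_hooks_curried` are hypothetical names):

<code python>
def make_hooks_naive(names):
    # Each lambda refers to the comprehension variable `name`, which is looked
    # up only when the hook runs; by then the loop has finished.
    return [lambda store: store.append(name) for name in names]


def make_hooks_curried(names):
    def curry(name):
        # `name` is a parameter of curry, so each hook keeps its own value.
        def hook(store):
            store.append(name)
        return hook
    return [curry(name) for name in names]


naive, curried = [], []
for h in make_hooks_naive(["conv1_1", "conv2_1", "conv3_1"]):
    h(naive)
for h in make_hooks_curried(["conv1_1", "conv2_1", "conv3_1"]):
    h(curried)

print(naive)    # ['conv3_1', 'conv3_1', 'conv3_1']
print(curried)  # ['conv1_1', 'conv2_1', 'conv3_1']
</code>

The same late lookup applies to any lambda defined in such a loop that reads the loop variable `m`.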
You are welcome and encouraged to submit any other style transfer photographs you have, as long as you also submit the required image. Show us the awesome results you can generate!
  
----
==== Hints and tips ====
  
A former student contributed the following:

Normalizing the image with the ImageNet mean and standard deviation before it goes into VGG, at every optimization step, is critical. Here's what I did.

Some extra things I did in the code snippet (in case they are useful):

  * I changed the VGG code to use a dict keyed by layer name, which in my opinion made things a lot easier.
  * I swapped the max-pooling layers for average-pooling layers.
  * I used a style scale of 500000 and a content scale of 1 (not in this code).
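For intuition on the pooling swap: VGG16 pools with non-overlapping 2x2 windows (kernel 2, stride 2). Max pooling keeps only the largest activation in each window and sends gradient back only to that position; average pooling blends all four values and spreads the gradient evenly across them. A minimal pure-Python sketch on one 4x4 feature map (`pool2x2` and `mean` are hypothetical helpers, not PyTorch functions):

<code python>
def pool2x2(grid, reduce):
    """Non-overlapping 2x2 pooling (kernel 2, stride 2, no padding) on a 2D list."""
    rows, cols = len(grid), len(grid[0])
    return [
        [reduce([grid[r][c], grid[r][c + 1], grid[r + 1][c], grid[r + 1][c + 1]])
         for c in range(0, cols - 1, 2)]
        for r in range(0, rows - 1, 2)
    ]


def mean(values):
    return sum(values) / len(values)


feature_map = [
    [1.0, 3.0, 0.0, 0.0],
    [5.0, 7.0, 0.0, 8.0],
    [2.0, 2.0, 4.0, 4.0],
    [2.0, 2.0, 4.0, 4.0],
]

print(pool2x2(feature_map, max))   # [[7.0, 8.0], [2.0, 4.0]]
print(pool2x2(feature_map, mean))  # [[4.0, 2.0], [2.0, 4.0]]
</code>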
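
PS: `torchvision.transforms.Normalize()` won't work here in the torchvision releases this lab was written against: it was a plain image transform rather than an `nn.Module` with a `forward()`, and it normalized the tensor in place, which breaks backpropagation to the image. (Newer releases implement it as an `nn.Module`, but the `Normalization` module below works either way.)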
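Whichever module you use, the computation is a per-channel affine map, (x - mean) / std, with the ImageNet statistics the pretrained VGG expects. Because the student's module applies it inside the forward pass, the image being optimized stays in plain [0, 1] RGB; the inverse, x * std + mean, is only needed if you store the image in normalized form instead. A minimal pure-Python sketch on a single RGB pixel (`normalize` and `denormalize` are hypothetical helpers; in the lab this happens on whole tensors):

<code python>
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize(pixel, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    # Per-channel (x - mean) / std: the same formula as Normalization.forward.
    return tuple((x - m) / s for x, m, s in zip(pixel, mean, std))


def denormalize(pixel, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    # Inverse map, x * std + mean, back to [0, 1] RGB.
    return tuple(x * s + m for x, m, s in zip(pixel, mean, std))


gray = (0.5, 0.5, 0.5)
z = normalize(gray)
print([round(v, 3) for v in z])  # [0.066, 0.196, 0.418]
restored = denormalize(z)
print(all(abs(a - b) < 1e-12 for a, b in zip(restored, gray)))  # True
</code>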

<code python>
from collections import OrderedDict

import torch
import torch.nn as nn
import torchvision.models as models


class Normalization(nn.Module):
    """Normalize an image batch (N, 3, H, W) with the ImageNet statistics VGG was trained on."""

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super(Normalization, self).__init__()
        # Buffers move with .to(device); shape (3, 1, 1) broadcasts over (N, 3, H, W).
        self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(-1, 1, 1))

    def forward(self, img):
        return (img - self.mean) / self.std


class VGGIntermediate(nn.Module):
    def __init__(self, requested=(), transforms=None):
        super(VGGIntermediate, self).__init__()

        # Modules applied to the input before VGG (normalization by default).
        # nn.ModuleList registers them, so .to(device) also moves their buffers.
        if transforms is None:
            transforms = [Normalization()]
        self.transforms = nn.ModuleList(transforms)
        self.vgg = models.vgg16(pretrained=True).features.eval()

        layers_in_order = [
            "conv1_1", "relu1_1", "conv1_2", "relu1_2", "maxpool1",
            "conv2_1", "relu2_1", "conv2_2", "relu2_2", "maxpool2",
            "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "maxpool3",
            "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "maxpool4",
            "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "maxpool5",
        ]

        self.intermediates = OrderedDict()
        for idx, layer_name in enumerate(layers_in_order):
            m = self.vgg[idx]
            if isinstance(m, nn.ReLU):
                # In-place ReLU can overwrite activations captured by the hooks
                # and trip up autograd, so run it out of place.
                m.inplace = False
            elif isinstance(m, nn.MaxPool2d):
                # Replace max pooling with average pooling over the same window.
                m = nn.AvgPool2d(m.kernel_size, m.stride, m.padding)
                self.vgg[idx] = m

            if layer_name in requested:

                def curry(name):
                    # Bind `name` now; a closure over the loop variable would only see its last value.
                    def hook(module, input, output):
                        self.intermediates[name] = output

                    return hook

                m.register_forward_hook(curry(layer_name))

    def forward(self, x):
        for transform in self.transforms:
            x = transform(x)
        self.vgg(x)  # run the full network; the hooks fill self.intermediates
        return self.intermediates
</code>
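The hand-written `layers_in_order` list follows a regular pattern: the five blocks of `vgg16().features` contain 2, 2, 3, 3 and 3 convolutions, each followed by a ReLU, and each block ends in a max pool, 31 modules in all. A sketch that generates the same names plus a name-to-index dict (`VGG16_BLOCKS` and `vgg16_layer_names` are hypothetical), which also translates layer names into the integer indices used by the original index-based `requested` check:

<code python>
VGG16_BLOCKS = (2, 2, 3, 3, 3)  # number of conv layers in each VGG16 block


def vgg16_layer_names(blocks=VGG16_BLOCKS):
    # Mirrors the order of vgg16().features: conv, relu, ..., maxpool per block.
    names = []
    for b, n_convs in enumerate(blocks, start=1):
        for c in range(1, n_convs + 1):
            names += [f"conv{b}_{c}", f"relu{b}_{c}"]
        names.append(f"maxpool{b}")
    return names


names = vgg16_layer_names()
index_of = {name: i for i, name in enumerate(names)}

print(len(names))  # 31
print([index_of[n] for n in ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")])
# [1, 6, 11, 18, 25]
</code>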
  
cs501r_f2018/lab5.1538585214.txt.gz · Last modified: 2021/06/30 23:40 (external edit)