Project - Introduction to Neural Style Transfer using Deep Learning & TensorFlow 2 (Art Generation Project)


Define the Training Step

Now it is time to define the training step. We shall use tf.GradientTape to compute the gradients that are then used to update the image.

Let us define a train_step(image) function that computes the gradient of the loss and updates the image pixel values on each training step.

In defining the train_step function, the following steps are implemented:

  • Calculate the outputs, which are the style and content representations of the input image, using extractor, the StyleContentModel object. Then call style_content_loss to get the weighted loss of the input image. Record all these operations using tf.GradientTape().

  • Using the loss obtained this way, calculate the gradients with tape.gradient(loss, image).

  • Then, apply these gradients using opt.apply_gradients.

  • Finally, clip the updated pixel values to the 0-1 range using clip_0_1 and assign the result back to the image.
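
To see these four steps in isolation, here is a minimal, self-contained sketch. It relies on stand-ins rather than the project's objects: the hypothetical toy_loss measures the mean squared distance to a fixed target tensor in place of extractor plus style_content_loss, clip_0_1 is assumed to be a tf.clip_by_value wrapper, and the Adam optimizer with learning rate 0.02 is an illustrative choice, not necessarily the project's setting.

```python
import tensorflow as tf

# Hypothetical stand-in for extractor + style_content_loss:
# the "loss" simply pulls every pixel toward a fixed target value.
target = tf.fill([1, 8, 8, 3], 0.25)

def toy_loss(img):
    return tf.reduce_mean(tf.square(img - target))

def clip_0_1(img):
    # Assumed definition: keep pixel values inside [0, 1].
    return tf.clip_by_value(img, clip_value_min=0.0, clip_value_max=1.0)

# The optimized image must be a tf.Variable so the optimizer can update it in place.
toy_image = tf.Variable(tf.random.uniform([1, 8, 8, 3], seed=0))
opt = tf.keras.optimizers.Adam(learning_rate=0.02)

@tf.function()
def toy_train_step(img):
    with tf.GradientTape() as tape:        # 1. record the forward pass
        loss = toy_loss(img)
    grad = tape.gradient(loss, img)         # 2. gradient of the loss w.r.t. the pixels
    opt.apply_gradients([(grad, img)])      # 3. optimizer updates the pixel values
    img.assign(clip_0_1(img))               # 4. clip back into [0, 1]
    return loss

initial_loss = float(toy_loss(toy_image))
for _ in range(50):
    toy_train_step(toy_image)
final_loss = float(toy_loss(toy_image))
print(f"loss before: {initial_loss:.4f}, after: {final_loss:.4f}")
```

Swapping toy_loss for the real extractor and style_content_loss recovers the structure of the project's train_step; the clipping step matters because the optimizer itself knows nothing about the valid pixel range.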

Note:

  • @tf.function converts a Python function into a graph representation for faster execution, especially when the function consists of many small ops. The recommended pattern is to decorate the most computationally intensive function, here the training step, with @tf.function.

  • tf.GradientTape() records the operations executed inside its context so that they can be used for automatic differentiation during optimization. The official TensorFlow documentation on tf.GradientTape and automatic differentiation is highly recommended for a fuller picture.

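  • optimizer.apply_gradients takes a list of (gradient, variable) pairs and updates each variable in place according to the optimizer's update rule, which is why image has to be a tf.Variable.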
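
The three notes can be checked with a small standalone sketch, assuming TensorFlow 2.x. The names trace_log, squared_norm and v are hypothetical, and plain SGD with learning rate 0.1 is used only so the update can be verified by hand (v becomes v - 0.1 * 2v).

```python
import tensorflow as tf

# 1) tf.function: Python side effects in the body run only while tracing,
#    so repeated calls with the same input shape/dtype reuse one graph.
trace_log = []

@tf.function
def squared_norm(x):
    trace_log.append("traced")  # executes once per trace, not once per call
    return tf.reduce_sum(tf.square(x))

first = squared_norm(tf.constant([1.0, 2.0]))   # traces the graph
second = squared_norm(tf.constant([3.0, 4.0]))  # same signature: graph reused

# 2) tf.GradientTape: record operations on a variable, then differentiate.
v = tf.Variable([1.0, 2.0])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.square(v))
grad = tape.gradient(loss, v)  # d/dv sum(v^2) = 2v -> [2.0, 4.0]

# 3) apply_gradients: plain SGD computes v <- v - learning_rate * grad.
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
opt.apply_gradients([(grad, v)])  # v is now approximately [0.8, 1.6]

print(len(trace_log), float(first), float(second), grad.numpy(), v.numpy())
```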

INSTRUCTIONS
  • Use the following code:

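    @tf.function()
    def train_step(image):
        # Record the forward pass so it can be differentiated.
        with tf.GradientTape() as tape:
            outputs = extractor(image)          # style and content representations
            loss = style_content_loss(outputs)  # weighted style + content loss

        grad = tape.gradient(loss, image)       # gradient of the loss w.r.t. the pixels
        opt.apply_gradients([(grad, image)])    # optimizer updates the image variable
        image.assign(clip_0_1(image))           # keep pixel values in [0, 1]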
    
  • Now run a few steps to test:

    train_step(image)
    train_step(image)
    train_step(image)
    tensor_to_image(image)
    