Now it is time to define the training step. We shall use tf.GradientTape to update the image. Let us define a train_step(image) function which computes the gradients and updates the image's pixel values in each training step. The train_step function implements the following steps:
1. Calculate the outputs, i.e. the style and content representations of the input image, using extractor, the StyleContentModel object.
2. Call the style_content_loss function to get the weighted loss of the input image.
3. Record all of these operations using tf.GradientTape().
4. From the loss thus obtained, calculate the gradients using tape.gradient(loss, image).
5. Apply these gradients using opt.apply_gradients.
6. Finally, update the image as per the gradients and clip the pixel values to the 0-1 range (a sketch of the clipping helper follows this list).
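The clipping in the last step is usually done with a small helper. If clip_0_1 is not already defined in your notebook, a minimal version matching the name used in the code below is:

def clip_0_1(image):
    # Keep every pixel value inside the valid [0, 1] range.
    return tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)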
Note:
@tf.function converts a Python function to its graph representation for faster execution, especially when the function consists of many small ops. The pattern to follow is to define the training step function, which is the most computationally intensive one, and decorate it with @tf.function.
tf.GradientTape() records the operations so that they can be used for automatic differentiation during optimization. It is highly recommended to go through the official docs to get a bigger picture of this.
optimizer.apply_gradients applies the computed gradients to the corresponding variables.
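To see these three pieces working together in isolation, here is a minimal, self-contained sketch. The variable x, the quadratic loss, and the learning rate are illustrative only and not part of the style-transfer model:

import tensorflow as tf

x = tf.Variable(3.0)
opt = tf.optimizers.Adam(learning_rate=0.1)

@tf.function
def minimize_step():
    with tf.GradientTape() as tape:
        loss = (x - 1.0) ** 2          # toy loss, minimized at x == 1
    grad = tape.gradient(loss, x)      # d(loss)/dx, from the recorded tape
    opt.apply_gradients([(grad, x)])   # gradient update applied to x
    return loss

for _ in range(200):
    minimize_step()
print(x.numpy())  # approaches 1.0

The same three calls (the tape context, tape.gradient, and apply_gradients) appear in train_step below, with the image tensor in place of x.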
Use the following code:
@tf.function()
def train_step(image):
    with tf.GradientTape() as tape:
        # Forward pass: style and content representations of the image.
        outputs = extractor(image)
        # Weighted combination of the style and content losses.
        loss = style_content_loss(outputs)

    # Gradient of the loss with respect to the image pixels.
    grad = tape.gradient(loss, image)
    opt.apply_gradients([(grad, image)])
    # Keep pixel values in the valid [0, 1] range.
    image.assign(clip_0_1(image))
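This snippet assumes that extractor, style_content_loss, clip_0_1, and the optimizer opt were all defined in earlier steps. If you still need an optimizer, a plain Adam instance works; the hyperparameters below are just one reasonable choice:

opt = tf.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)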
Now run a few steps to test:
train_step(image)
train_step(image)
train_step(image)
tensor_to_image(image)
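tensor_to_image converts the optimized tensor back into a displayable PIL image. If it was not defined earlier in your notebook, a minimal sketch, assuming a single image in the batch and pixel values in [0, 1], is:

import numpy as np
import PIL.Image

def tensor_to_image(tensor):
    # Scale [0, 1] floats to [0, 255] and convert to 8-bit integers.
    tensor = np.array(tensor * 255, dtype=np.uint8)
    # Drop the leading batch dimension if one is present.
    if np.ndim(tensor) > 3:
        assert tensor.shape[0] == 1
        tensor = tensor[0]
    return PIL.Image.fromarray(tensor)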