Deep Reinforcement Learning Hands-On

Custom layers

In the previous section, we briefly mentioned the nn.Module class as the base parent for all NN building blocks exposed by PyTorch. It's not just a unifying parent for the existing layers; it's much more than that. By subclassing nn.Module, you can create your own building blocks, which can be stacked together, reused later, and integrated into the PyTorch framework seamlessly.

At its core, nn.Module provides quite rich functionality to its children:

  • It tracks all submodules that the current module includes. For example, your building block can have two feed-forward layers that together perform the block's transformation.
  • It provides functions to deal with all parameters of the registered submodules. You can obtain a full list of the module's parameters (parameters() method), zero their gradients (zero_grad() method), move the module to CPU or GPU (to(device) method), serialize and deserialize it (state_dict() and load_state_dict()), and even perform generic transformations using your own callable (apply() method). A short sketch of these calls follows this list.
  • It establishes the convention of module application to data. Every module needs to perform its data transformation in the forward() method by overriding it.
  • There are some more functions, such as the ability to register hook functions to tweak the module's transformation or gradient flow, but they are more for advanced use cases.
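
To make this list more concrete, here is a minimal sketch (not part of the book's sample) that builds a tiny two-layer block and exercises the calls mentioned above:

import torch
import torch.nn as nn

block = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# all registered parameters (the weights and biases of both Linear layers)
for name, p in block.named_parameters():
    print(name, p.shape)

block.zero_grad()                    # reset the gradients of every parameter
block.to(torch.device("cpu"))        # move the whole block to a device
state = block.state_dict()           # serializable snapshot of the parameters
block.load_state_dict(state)         # ... and the way back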

These functionalities allow us to nest our submodules into higher-level models in a unified way, which is extremely useful when dealing with complexity. Whether it is a simple one-layer linear transformation or a 1001-layer ResNet monster, as long as it follows the conventions of nn.Module, it can be handled in the same way. This is very handy for code simplicity and reusability.

To make our lives simpler when following the preceding convention, the PyTorch authors streamlined the creation of modules through careful design and a good dose of Python magic. So, to create a custom module, we usually need to do only two things: register submodules and implement the forward() method. Let's look at how this can be done for our Sequential example from the previous section, but in a more generic and reusable way (the full sample is in Chapter03/01_modules.py):

class OurModule(nn.Module):
    def __init__(self, num_inputs, num_classes, dropout_prob=0.3):
        super(OurModule, self).__init__()
        self.pipe = nn.Sequential(
            nn.Linear(num_inputs, 5),
            nn.ReLU(),
            nn.Linear(5, 20),
            nn.ReLU(),
            nn.Linear(20, num_classes),
            nn.Dropout(p=dropout_prob),
            nn.Softmax(dim=1)
        )

This is our module class, which inherits from nn.Module. In the constructor, we pass three parameters: the size of the input, the size of the output, and an optional dropout probability. The first thing we need to do is call the parent's constructor to let it initialize itself. In the second step, we create the already familiar nn.Sequential with a bunch of layers and assign it to our class field named pipe. By assigning a Sequential instance to our field, we automatically register this module (nn.Sequential inherits from nn.Module, as does everything in the nn package). To register it, we don't need to call anything; we just assign our submodules to fields. After the constructor finishes, all those fields will be registered automatically (if you really want to, nn.Module also provides the add_module() method to register submodules explicitly):

    def forward(self, x):
        return self.pipe(x)

Here, we override the forward function with our implementation of the data transformation. As our module is a very simple wrapper around other layers, we just need to ask them to transform the data. Note that to apply a module to data, you need to call the module as a callable (that is, pretend that the module instance is a function and call it with the arguments) and not use the forward() method of the nn.Module class. This is because nn.Module overrides the __call__() method, which is used when we treat an instance as a callable. This method does some nn.Module magic and then calls your forward() method. If you call forward() directly, you'll bypass the nn.Module machinery, which can give you wrong results, as the small sketch below illustrates.
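
One way to see the difference is with a forward hook, which is executed by the __call__() machinery. The following small sketch (my own illustration, not from the book's sample) shows that the hook fires when the module is called as a callable, but not when forward() is invoked directly:

import torch
import torch.nn as nn

layer = nn.Linear(2, 2)
# the hook is invoked by __call__() after every forward pass
layer.register_forward_hook(
    lambda module, inp, out: print("hook fired, output shape:", out.shape))

x = torch.FloatTensor([[1.0, 2.0]])
_ = layer(x)            # the hook message is printed
_ = layer.forward(x)    # silent: the hook machinery is bypassed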

So, that's what we need to do to define our own module. Now, let's use it:

if __name__ == "__main__":
    net = OurModule(num_inputs=2, num_classes=3)
    v = torch.FloatTensor([[2, 3]])
    out = net(v)
    print(net)
    print(out)

We create our module, providing it with the desired number of inputs and outputs, then we create a tensor and ask our module to transform it, following the same convention of using the module as a callable. Then we print our network's structure (nn.Module overrides __str__() and __repr__()) to represent the inner structure in a nice way. The last thing we show is the result of the network's transformation.

The output of our code should look like this:

rl_book_samples/Chapter03$ python 01_modules.py
OurModule(
  (pipe): Sequential(
    (0): Linear(in_features=2, out_features=5, bias=True)
    (1): ReLU()
    (2): Linear(in_features=5, out_features=20, bias=True)
    (3): ReLU()
    (4): Linear(in_features=20, out_features=3, bias=True)
    (5): Dropout(p=0.3)
    (6): Softmax(dim=1)
  )
)
tensor([[ 0.3672,  0.3469,  0.2859]])
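
The printed structure also confirms that the pipe field was registered as a submodule during the constructor. If you prefer to check registration programmatically rather than by eye, a short sketch like this (not part of the book's sample, reusing the net instance created above) will do:

for name, child in net.named_children():
    print(name, type(child).__name__)            # prints: pipe Sequential
print(sum(p.numel() for p in net.parameters()))  # total number of parameters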

Of course, everything that was said about the dynamic nature of PyTorch is still true. Your forward() method gets control for every batch of data, so if you want to perform some complex transformation based on the data being processed, such as hierarchical softmax or a random choice of net to apply, nothing stops you from doing so. The number of arguments to your forward() method is also not limited to one. So, if you want, you can write a module with multiple required parameters and dozens of optional arguments, and it will work fine.
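
As an illustration of that flexibility, here is a hedged sketch (the TwoHeaded class name and its structure are my own, not from the book's sample): forward() takes two required tensors plus an optional flag, and randomly picks one of two internal heads to apply:

import random

import torch
import torch.nn as nn


class TwoHeaded(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(TwoHeaded, self).__init__()
        self.head_a = nn.Linear(num_inputs, num_outputs)
        self.head_b = nn.Linear(num_inputs, num_outputs)

    def forward(self, x, bias, use_a=None):
        # data-dependent (here: random) choice of the subnetwork to apply
        if use_a is None:
            use_a = random.random() < 0.5
        head = self.head_a if use_a else self.head_b
        return head(x) + bias


net = TwoHeaded(num_inputs=2, num_outputs=3)
v = torch.FloatTensor([[2, 3]])
b = torch.zeros(1, 3)
print(net(v, b))               # a randomly chosen head
print(net(v, b, use_a=True))   # head_a forced explicitly

PyTorch will happily track gradients through whichever branch was executed on that call, which is exactly the dynamic-graph behavior mentioned above.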

Now we need to get familiar with two important pieces of the PyTorch library, which will simplify our lives: loss functions and optimizers.