Source code for lib.agents.sequentialCritic


import numpy as np
import torch
import torch.nn            as nn
import torch.nn.functional as F

class SequentialCritic(nn.Module):
    '''Fully connected critic network mapping a (state, action) pair to one value.

    The state is passed through a stack of ``Linear -> BatchNorm1d ->
    activation`` blocks. The action is concatenated to the running
    features immediately before the block whose index equals
    ``mergeLayer``. A final linear layer reduces the result to a single
    quality value per sample.
    '''

    def __init__(self, stateSize, actionSize, layers=[10, 5],
                 activations=[F.tanh, F.tanh], mergeLayer=0,
                 batchNormalization=True):
        '''Build the hidden blocks and the scalar output layer.

        Parameters
        ----------
        stateSize : int
            Number of features in the state input.
        actionSize : int
            Number of features in the action input.
        layers : list of int, optional
            Output width of each hidden linear layer, in order
            (the default is [10, 5]).
        activations : list of callable, optional
            One activation per hidden layer, applied after the
            normalization step (the default is [F.tanh, F.tanh]).
        mergeLayer : int, optional
            Index of the hidden layer whose input is widened by
            ``actionSize`` so the action can be concatenated there
            (the default is 0, i.e. the action joins the raw state).
        batchNormalization : bool, optional
            When True a ``BatchNorm1d`` follows every hidden linear
            layer; when False the normalization step is a no-op
            (the default is True).

        Notes
        -----
        ``forward`` zips ``layers`` and ``activations`` together, so only
        as many blocks as the shorter of the two are applied. If
        ``mergeLayer`` does not match any hidden-layer index the action
        is never merged into the network.
        '''
        super(SequentialCritic, self).__init__()

        self.stateSize = stateSize
        self.actionSize = actionSize
        # Copy the sequences so instances never share (or mutate) the
        # default-argument lists declared in the signature.
        self.layers = list(layers)
        self.activations = list(activations)
        self.mergeLayer = mergeLayer
        self.batchNormalization = batchNormalization

        # Generate the fully connected layer functions.
        # These MUST be nn.ModuleList rather than plain Python lists:
        # otherwise the sub-modules are not registered with this module,
        # so their weights are missing from .parameters()/.state_dict(),
        # are not moved by .to()/.cuda(), and ignore .train()/.eval().
        self.fcLayers = nn.ModuleList()
        self.bns = nn.ModuleList()

        oldN = stateSize
        for i, layer in enumerate(self.layers):
            if i == mergeLayer:
                # The action is concatenated in front of this layer.
                oldN += actionSize

            self.fcLayers.append(nn.Linear(oldN, layer))
            if batchNormalization:
                self.bns.append(nn.BatchNorm1d(num_features=layer))
            else:
                # Keep one entry per layer so forward() can zip the lists.
                self.bns.append(nn.Identity())
            oldN = layer

        # ------------------------------------------------------
        # The final layer will only need to supply a quality
        # function. This is a single value for an action
        # provided. Ideally, you would want to provide a
        # OHE action sequence for most purposes ...
        # ------------------------------------------------------
        self.fcFinal = nn.Linear(oldN, 1)

    def forward(self, x, action):
        '''Compute the critic value for a batch of state/action pairs.

        Parameters
        ----------
        x : torch.Tensor
            State batch; concatenated with ``action`` along ``dim=1``, so
            it is expected to be 2-D with ``stateSize`` columns.
        action : torch.Tensor
            Action batch with the same number of rows as ``x`` and
            ``actionSize`` columns.

        Returns
        -------
        torch.Tensor
            Output of the final linear layer, one value per row.
        '''
        blocks = zip(self.bns, self.fcLayers, self.activations)
        for i, (bn, fc, a) in enumerate(blocks):
            if i == self.mergeLayer:
                x = torch.cat((x, action), dim=1)
            x = a(bn(fc(x)))

        return self.fcFinal(x)