Created January 3, 2019 20:07
-
-
Save cipher982/25e84ae8b87dc6a16df23c5682ff9af5 to your computer and use it in GitHub Desktop.
QNetwork
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class QNetwork(nn.Module):
    """Q-value network: maps a state to one estimated value per action.

    A three-layer fully connected network: two ReLU hidden layers followed by a
    linear output layer. The output layer has no activation, so each output
    element is an unbounded value estimate, not a probability or likelihood.
    """

    def __init__(self, state_size, action_size, seed, fc1_units=64, fc2_units=64):
        """Initialize parameters and build the model.

        Params
        ======
            state_size (int): Dimension of each state (input features)
            action_size (int): Number of output values, one per action
            seed (int): Random seed
            fc1_units (int): Number of nodes in first hidden layer
            fc2_units (int): Number of nodes in second hidden layer
        """
        super().__init__()
        # NOTE: torch.manual_seed reseeds torch's *global* RNG (a process-wide
        # side effect) and returns the default Generator, which is stored here.
        # It runs before the layers are built so that, for a given seed, the
        # layers' weight initialization is reproducible.
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(state_size, fc1_units)
        self.fc2 = nn.Linear(fc1_units, fc2_units)
        self.fc3 = nn.Linear(fc2_units, action_size)

    def forward(self, state):
        """Build a network that maps state -> action values.

        Params
        ======
            state (torch.Tensor): input whose last dimension is ``state_size``

        Returns
        =======
            torch.Tensor: same leading dimensions as ``state``, with last
            dimension ``action_size`` (one value estimate per action)
        """
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        # Linear output (no softmax): raw per-action value estimates, not likelihoods.
        return self.fc3(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.