import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
def weight_init(m):
    """
    Initialize a linear layer with Xavier-normal weights and zero biases.

    Parameters
    ----------
    m : torch.nn.Module
        Module to initialize. Only instances of `nn.Linear` are modified;
        every other module type is left untouched.

    Notes
    -----
    The weights (`m.weight`) are filled in place with
    `nn.init.xavier_normal_` (Glorot normal initialization) and the bias,
    when the layer has one, is filled with zeros. Layers created with
    `bias=False` store `bias` as `None` and only get their weights
    initialized.

    Because non-linear modules are ignored, the function can be handed
    directly to `nn.Module.apply` to initialize a whole network.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)
    # `nn.Linear(..., bias=False)` has `bias is None`; nothing to reset then.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
class Policynet(nn.Module):
    """
    Policy network for generating actions and corresponding log-probabilities.

    Parameters
    ----------
    state_dim : int
        Dimensionality of the input state.
    hidden_dim : int
        Dimensionality of the hidden representation.
    action_dim : int
        Dimensionality of the action space.
    action_space : tuple
        Tuple `(min_action, max_action)` giving the bounds into which the
        mean of the action distribution is rescaled.

    Attributes
    ----------
    nn : torch.nn.Sequential
        Two `Linear` + `ReLU` layers mapping the state to a hidden
        representation.
    fc_mu : torch.nn.Linear
        Maps the hidden representation to the (pre-squash) mean of the
        action distribution.
    fc_logstd : torch.nn.Linear
        Maps the hidden representation to the raw spread parameter, which is
        passed through a softplus in `forward`.
    min_action, max_action
        The two entries of `action_space`.

    Notes
    -----
    `forward` samples from a multivariate normal distribution with a
    diagonal covariance matrix. The softplus output (plus `1e-6`) is placed
    directly on the diagonal of the covariance matrix, i.e. it acts as the
    per-dimension *variance* of the distribution.

    Only the mean is squashed into `[min_action, max_action]`; the sampled
    action itself is not clipped. Actions are drawn with `rsample` so that
    gradients propagate through the sampling step.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, action_space):
        super().__init__()
        self.nn = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim, action_dim)
        self.fc_logstd = nn.Linear(hidden_dim, action_dim)
        self.min_action, self.max_action = action_space
        self.apply(weight_init)

    def forward(self, x):
        """
        Generate actions and corresponding log-probabilities for a state.

        Parameters
        ----------
        x : torch.Tensor
            Input state tensor of shape `(batch_size, state_dim)`. A single
            unbatched state of shape `(state_dim,)` is also accepted.

        Returns
        -------
        action : torch.Tensor
            Sampled actions with the same leading shape as `x`, i.e.
            `(batch_size, action_dim)` (or `(action_dim,)` for an unbatched
            state).
        logprob : torch.Tensor
            Log-probabilities of the sampled actions, reshaped to
            `(batch_size, 1)` (`(1, 1)` for an unbatched state).
        """
        hidden = self.nn(x)

        # tanh maps to [-1, 1]; rescale affinely to [min_action, max_action].
        mu = torch.tanh(self.fc_mu(hidden))
        mu = self.min_action + 0.5 * (mu + 1.0) * (self.max_action - self.min_action)

        # Softplus keeps the diagonal strictly positive; the epsilon guards
        # against a numerically singular covariance matrix.
        diag = F.softplus(self.fc_logstd(hidden)) + 1e-6

        # Batched diagonal matrices in one call; works for any number of
        # leading batch dimensions, including none.
        cov = torch.diag_embed(diag)

        mn = MultivariateNormal(mu, covariance_matrix=cov)
        action = mn.rsample()
        logprob = mn.log_prob(action)
        return action, logprob.view(-1, 1)
class Qnet(nn.Module):
    """
    Q-network for estimating the state-action value in reinforcement learning.

    Parameters
    ----------
    state_dim : int
        Dimensionality of the state space.
    hidden_dim : int
        Dimensionality of the hidden layers in the neural network.
    action_dim : int
        Dimensionality of the action space.

    Attributes
    ----------
    nn : torch.nn.Sequential
        Three `Linear` layers with `ReLU` activations in between, taking the
        concatenated state and action (`state_dim + action_dim` features)
        and producing a single output value.

    Notes
    -----
    `forward` concatenates the state `x` and action `a` along the feature
    (last) dimension and passes the result through `nn`.
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.nn = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        self.apply(weight_init)

    def forward(self, x, a):
        """
        Compute the Q-value given a state and an action.

        Parameters
        ----------
        x : torch.Tensor
            Input state tensor of shape `(batch_size, state_dim)`.
        a : torch.Tensor
            Input action tensor of shape `(batch_size, action_dim)`.

        Returns
        -------
        q_value : torch.Tensor
            The computed Q-values of shape `(batch_size, 1)`.

        Notes
        -----
        Concatenation happens along the last dimension, so unbatched inputs
        of shape `(state_dim,)` and `(action_dim,)` are accepted as well and
        yield an output of shape `(1,)`.
        """
        # dim=-1 equals dim=1 for (batch, features) inputs and additionally
        # supports a single unbatched state/action pair.
        state_action = torch.cat([x, a], dim=-1)
        return self.nn(state_action)
class SAC:
    """
    Implementation of the Soft Actor-Critic (SAC) algorithm for reinforcement learning.

    Attributes
    ----------
    actor : Policynet
        The policy network that outputs actions given states.
    critic_1, critic_2 : Qnet
        Two Q-networks (also known as critics) that estimate the state-action value.
    target_critic_1, target_critic_2 : Qnet
        Target Q-networks, initialized as copies of the critics and moved
        towards them via soft updates.
    actor_optimizer : torch.optim.Optimizer
        Optimizer for updating the actor network.
    critic_1_optimizer, critic_2_optimizer : torch.optim.Optimizer
        Optimizers for updating the two critic networks.
    log_alpha : torch.Tensor
        Learnable log of the entropy temperature, initialized to `log(0.01)`.
    log_alpha_optimizer : torch.optim.Optimizer
        Optimizer for updating the temperature parameter.
    target_entropy : float
        Target entropy used for entropy regularization.
    gamma : float
        Discount factor for future rewards.
    tau : float
        Soft update coefficient for target networks.
    device : torch.device
        Device (CPU or GPU) on which the networks and tensors are stored.

    Methods
    -------
    take_action(state)
        Given a state, returns an action sampled from the actor network.
    calc_target(rewards, next_states, dones)
        Computes the target Q-values for a batch of transitions.
    soft_update(net, target_net)
        Updates the target network towards the main network using a soft update rule.
    update(transition_dict)
        Performs a training update using a batch of transitions.
    """

    def __init__(
        self,
        state_dim,
        hidden_dim,
        action_dim,
        action_space,
        actor_lr,
        critic_lr,
        alpha_lr,
        target_entropy,
        tau,
        gamma,
        device,
    ):
        """
        Initialize the SAC agent.

        Parameters
        ----------
        state_dim : int
            Dimensionality of the state space.
        hidden_dim : int
            Dimensionality of the hidden layers in the neural networks.
        action_dim : int
            Dimensionality of the action space.
        action_space : tuple
            `(min_action, max_action)` bounds forwarded to the policy network.
        actor_lr : float
            Learning rate for the actor.
        critic_lr : float
            Learning rate for the critics.
        alpha_lr : float
            Learning rate for the temperature parameter alpha.
        target_entropy : float
            Target entropy for the policy.
        tau : float
            Soft update factor for the target networks.
        gamma : float
            Discount factor.
        device : str or torch.device
            Device on which to run the computations (e.g., 'cuda' or 'cpu').
        """
        self.actor = Policynet(state_dim, hidden_dim, action_dim, action_space).to(device)
        self.critic_1 = Qnet(state_dim, hidden_dim, action_dim).to(device)
        self.critic_2 = Qnet(state_dim, hidden_dim, action_dim).to(device)
        self.target_critic_1 = Qnet(state_dim, hidden_dim, action_dim).to(device)
        self.target_critic_2 = Qnet(state_dim, hidden_dim, action_dim).to(device)
        # Targets start out as exact copies of their critics.
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), lr=critic_lr)

        # Optimize log(alpha) rather than alpha so the temperature stays
        # positive; keep it on the same device as the networks.
        self.log_alpha = torch.tensor(
            np.log(0.01), dtype=torch.float, device=device, requires_grad=True
        )
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)

        self.target_entropy = target_entropy
        self.gamma = gamma
        self.tau = tau
        self.device = device

    def _to_tensor(self, data):
        """Return `data` as a new float32 tensor on `self.device`."""
        if isinstance(data, torch.Tensor):
            return data.detach().to(device=self.device, dtype=torch.float)
        # Going through a single ndarray avoids the very slow element-wise
        # conversion PyTorch performs for lists of arrays.
        return torch.tensor(np.asarray(data), dtype=torch.float, device=self.device)

    def take_action(self, state):
        """
        Take an action given the current state.

        Parameters
        ----------
        state : array_like
            Current state of the environment, either a single state of shape
            `(state_dim,)` or a batch of shape `(batch_size, state_dim)`.

        Returns
        -------
        action : numpy.ndarray
            Action sampled from the actor, with the same leading shape as
            `state` (a single state yields a single action vector).
        """
        state = self._to_tensor(state)
        unbatched = state.dim() == 1
        if unbatched:
            # The actor is fed a batch; add and later remove a batch axis.
            state = state.unsqueeze(0)
        with torch.no_grad():
            action = self.actor(state)[0]
        if unbatched:
            action = action.squeeze(0)
        return action.detach().cpu().numpy()

    def calc_target(self, rewards, next_states, dones):
        """
        Calculate the TD targets for the critics.

        Parameters
        ----------
        rewards : torch.Tensor
            Rewards received from the environment, shape `(batch_size, 1)`.
        next_states : torch.Tensor
            Next states observed from the environment.
        dones : torch.Tensor
            Termination flags (1 where the episode ended), shape `(batch_size, 1)`.

        Returns
        -------
        td_target : torch.Tensor
            Temporal difference targets for the critics. Computed without
            gradient tracking, so the result carries no autograd graph.
        """
        # Targets are regression constants; no graph is needed.
        with torch.no_grad():
            next_actions, log_prob = self.actor(next_states)
            entropy = -log_prob
            q1_value = self.target_critic_1(next_states, next_actions)
            q2_value = self.target_critic_2(next_states, next_actions)
            # Minimum of the two target critics plus the entropy bonus.
            next_value = torch.min(q1_value, q2_value) + self.log_alpha.exp() * entropy
            td_target = rewards + self.gamma * next_value * (1 - dones)
        return td_target

    def soft_update(self, net, target_net):
        """
        Perform a soft update of the target network parameters.

        Each target parameter becomes
        `(1 - tau) * target_param + tau * param`.

        Parameters
        ----------
        net : nn.Module
            The current network.
        target_net : nn.Module
            The target network to be updated in place.
        """
        with torch.no_grad():
            for param_target, param in zip(target_net.parameters(), net.parameters()):
                param_target.copy_(param_target * (1 - self.tau) + param * self.tau)

    def update(self, transition_dict):
        """
        Update the agent's networks using a batch of transitions.

        The steps are, in order: both critics are regressed onto the TD
        target, the actor is updated, the temperature is updated, and both
        target critics are soft-updated.

        Parameters
        ----------
        transition_dict : dict
            Dictionary containing the transitions. Should have keys:
            'states', 'actions', 'rewards', 'next_states', 'dones'.
        """
        states = self._to_tensor(transition_dict['states'])
        actions = self._to_tensor(transition_dict['actions'])
        rewards = self._to_tensor(transition_dict['rewards']).view(-1, 1)
        next_states = self._to_tensor(transition_dict['next_states'])
        dones = self._to_tensor(transition_dict['dones']).view(-1, 1)

        # --- critics ---------------------------------------------------
        td_target = self.calc_target(rewards, next_states, dones).detach()
        # `F.mse_loss` already averages over the batch.
        critic_1_loss = F.mse_loss(self.critic_1(states, actions), td_target)
        critic_2_loss = F.mse_loss(self.critic_2(states, actions), td_target)
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # --- actor -----------------------------------------------------
        new_actions, log_prob = self.actor(states)
        entropy = -log_prob
        q1_value = self.critic_1(states, new_actions)
        q2_value = self.critic_2(states, new_actions)
        actor_loss = torch.mean(
            -self.log_alpha.exp() * entropy - torch.min(q1_value, q2_value)
        )
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # --- temperature -------------------------------------------------
        # The entropy gap is detached so only log_alpha receives gradients.
        alpha_loss = torch.mean(
            (entropy - self.target_entropy).detach() * self.log_alpha.exp()
        )
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        # --- target networks ---------------------------------------------
        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)