scfocus.model.SAC

class scfocus.model.SAC(state_dim, hidden_dim, action_dim, action_space, actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma, device)

Implementation of the Soft Actor-Critic (SAC) algorithm for reinforcement learning.

actor

The policy network that outputs actions given states.

Type:

Policynet

critic_1, critic_2

Two Q-networks (also known as critics) that estimate the state-action value.

Type:

Qnet

target_critic_1, target_critic_2

Target Q-networks used for stabilizing learning via soft updates.

Type:

Qnet

actor_optimizer

Optimizer for updating the actor network.

Type:

torch.optim.Optimizer

critic_1_optimizer, critic_2_optimizer

Optimizers for updating the two critic networks.

Type:

torch.optim.Optimizer

log_alpha

Logarithm of the learnable temperature parameter alpha used for entropy regularization.

Type:

torch.Tensor

log_alpha_optimizer

Optimizer for updating the temperature parameter.

Type:

torch.optim.Optimizer

target_entropy

Target policy entropy toward which the temperature alpha is adjusted.

Type:

float

gamma

Discount factor for future rewards.

Type:

float

tau

Soft update coefficient for target networks.

Type:

float

device

Device (CPU or GPU) on which the networks and tensors should be stored.

Type:

torch.device

take_action(state)

Given a state, returns an action sampled from the actor network.

calc_target(rewards, next_states, dones)

Computes the target Q-values for a batch of transitions.
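
The arithmetic behind such a target can be sketched in plain Python. This is a minimal illustration of the standard soft TD target, assuming the usual SAC form (minimum of the two target critics plus an entropy bonus); the function name td_targets and its arguments are hypothetical, and the actual method presumably operates on torch tensors and evaluates the target critics itself.

```python
def td_targets(rewards, next_q1, next_q2, next_entropy, dones, gamma, alpha):
    """Soft TD targets: r + gamma * (1 - done) * (min(Q1', Q2') + alpha * H).

    All arguments except gamma and alpha are equal-length lists of floats.
    This helper is hypothetical and not part of scfocus.
    """
    targets = []
    for r, q1, q2, h, d in zip(rewards, next_q1, next_q2, next_entropy, dones):
        # Take the smaller of the two target-critic estimates, then add the
        # entropy bonus weighted by the temperature alpha.
        soft_value = min(q1, q2) + alpha * h
        # Terminal transitions (done == 1) do not bootstrap from the next state.
        targets.append(r + gamma * (1.0 - d) * soft_value)
    return targets


example = td_targets(
    rewards=[1.0, 0.5],
    next_q1=[2.0, 3.0],
    next_q2=[1.5, 4.0],
    next_entropy=[0.2, 0.1],
    dones=[0.0, 1.0],
    gamma=0.9,
    alpha=0.5,
)
```

For the second transition, dones is 1, so the target reduces to the reward alone.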

soft_update(net, target_net)

Moves the parameters of the target network toward those of the main network using a soft update with coefficient tau.
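
A soft update is commonly written as target = tau * param + (1 - tau) * target. The sketch below shows that rule on plain lists of floats; it assumes this common convention (the documentation does not state which side tau weights), and polyak_update is a hypothetical helper rather than the library method.

```python
def polyak_update(params, target_params, tau):
    """Return new target parameters: tau * param + (1 - tau) * target.

    Hypothetical stand-in for a soft update; real networks would apply the
    same rule to every parameter tensor.
    """
    return [tau * p + (1.0 - tau) * t for p, t in zip(params, target_params)]


online = [1.0, -2.0]   # parameters of the main network
target = [0.0, 0.0]    # parameters of the target network
for _ in range(3):
    # Each call closes a fraction tau of the remaining gap.
    target = polyak_update(online, target, tau=0.5)
```

With a small tau the target network changes slowly, which is what makes the targets it produces more stable than those of the main network.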

update(transition_dict)

Performs a training update using a batch of transitions.
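
One part of a SAC training update is the adjustment of the temperature. The sketch below assumes the commonly used loss mean((entropy - target_entropy) * exp(log_alpha)), with the entropy treated as a constant, and applies a single plain gradient-descent step. The helper temperature_step is hypothetical, and the class itself uses log_alpha_optimizer rather than a hand-written step.

```python
import math


def temperature_step(log_alpha, entropies, target_entropy, lr):
    """One gradient-descent step on the assumed temperature loss.

    loss(log_alpha) = exp(log_alpha) * mean(entropy - target_entropy)
    d loss / d log_alpha = exp(log_alpha) * mean(entropy - target_entropy)
    """
    gap = sum(h - target_entropy for h in entropies) / len(entropies)
    grad = math.exp(log_alpha) * gap
    return log_alpha - lr * grad


log_alpha = math.log(0.1)

# Policy entropy below the target: the temperature is raised.
raised = temperature_step(log_alpha, [0.2, 0.4], target_entropy=1.0, lr=0.01)

# Policy entropy above the target: the temperature is lowered.
lowered = temperature_step(log_alpha, [2.0, 3.0], target_entropy=1.0, lr=0.01)
```

Optimizing log_alpha instead of alpha keeps the temperature exp(log_alpha) positive without any explicit constraint.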

__init__(state_dim, hidden_dim, action_dim, action_space, actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma, device)

Initialize the SAC agent.

Parameters:
  • state_dim (int) – Dimensionality of the state space.

  • hidden_dim (int) – Dimensionality of the hidden layers in the neural networks.

  • action_dim (int) – Dimensionality of the action space.

  • actor_lr (float) – Learning rate for the actor.

  • critic_lr (float) – Learning rate for the critics.

  • alpha_lr (float) – Learning rate for the temperature parameter alpha.

  • target_entropy (float) – Target entropy for the policy.

  • tau (float) – Soft update factor for the target networks.

  • gamma (float) – Discount factor.

  • device (str or torch.device) – Device on which to run the computations (e.g., 'cuda' or 'cpu').

Methods

__init__(state_dim, hidden_dim, action_dim, ...)

Initialize the SAC agent.

calc_target(rewards, next_states, dones)

Calculate the TD targets for the critics.

soft_update(net, target_net)

Perform a soft update of the target network parameters.

take_action(state)

Take an action given the current state.

update(transition_dict)

Update the agent's networks using a batch of transitions.