scfocus.model.SAC
- class scfocus.model.SAC(state_dim, hidden_dim, action_dim, action_space, actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma, device)[source]
Implementation of the Soft Actor-Critic (SAC) algorithm for reinforcement learning.
- critic_1, critic_2
Two Q-networks (also known as critics) that estimate the state-action value.
- Type:
- target_critic_1, target_critic_2
Target Q-networks used for stabilizing learning via soft updates.
- Type:
- actor_optimizer
Optimizer for updating the actor network.
- Type:
torch.optim.Optimizer
- critic_1_optimizer, critic_2_optimizer
Optimizers for updating the two critic networks.
- Type:
torch.optim.Optimizer
- log_alpha
Learnable temperature parameter for entropy regularization.
- Type:
torch.Tensor
- log_alpha_optimizer
Optimizer for updating the temperature parameter.
- Type:
torch.optim.Optimizer
- target_entropy
Target entropy used for entropy regularization.
- Type:
float
- gamma
Discount factor for future rewards.
- Type:
float
- tau
Soft update coefficient for target networks.
- Type:
float
- device
Device (CPU or GPU) on which the networks and tensors should be stored.
- Type:
torch.device
- calc_target(rewards, next_states, dones)[source]
Computes the target Q-values for a batch of transitions.
- soft_update(net, target_net)[source]
Moves the parameters of the target network toward those of the main network using a soft update rule controlled by tau.
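The soft update rule mentioned above is commonly implemented as Polyak averaging. The sketch below illustrates that rule on plain Python floats instead of network parameters, so the arithmetic is easy to follow. The helper name `polyak_update` and its list-based signature are hypothetical and are not part of scfocus; whether `soft_update` uses exactly this form is an assumption.

```python
def polyak_update(params, target_params, tau):
    """Return target parameters moved a fraction `tau` toward `params`.

    Hypothetical helper for illustration only: it works on lists of floats,
    whereas the documented soft_update(net, target_net) takes networks.
    """
    return [(1.0 - tau) * t + tau * p for p, t in zip(params, target_params)]


online = [1.0, -2.0, 0.5]   # stand-in for the main network's parameters
target = [0.0, 0.0, 0.0]    # stand-in for the target network's parameters

# A small tau moves the target only slightly on each call.
target = polyak_update(online, target, tau=0.005)

# tau = 1.0 would copy the online parameters outright.
copied = polyak_update(online, [0.0, 0.0, 0.0], tau=1.0)
```

With a small `tau`, the target parameters trail the main network slowly, which is what keeps the critic targets stable between updates.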
- __init__(state_dim, hidden_dim, action_dim, action_space, actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma, device)[source]
Initialize the SAC agent.
- Parameters:
state_dim (int) – Dimensionality of the state space.
hidden_dim (int) – Dimensionality of the hidden layers in the neural networks.
action_dim (int) – Dimensionality of the action space.
actor_lr (float) – Learning rate for the actor.
critic_lr (float) – Learning rate for the critics.
alpha_lr (float) – Learning rate for the temperature parameter alpha.
target_entropy (float) – Target entropy for the policy.
tau (float) – Soft update factor for the target networks.
gamma (float) – Discount factor.
device (str or torch.device) – Device on which to run the computations (e.g., ‘cuda’ or ‘cpu’).
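To show how `alpha_lr` and `target_entropy` interact, here is a minimal sketch of the usual SAC temperature update, assuming the loss mean(-alpha * (log_prob + target_entropy)) with alpha = exp(log_alpha). The function `temperature_step` is hypothetical, and it takes a plain gradient-descent step rather than using an optimizer such as the documented `log_alpha_optimizer`; the actual update in scfocus may differ.

```python
import math


def temperature_step(log_alpha, log_probs, target_entropy, alpha_lr):
    """One gradient-descent step on log_alpha (hypothetical helper).

    Assumed loss: mean(-exp(log_alpha) * (log_prob + target_entropy)).
    """
    alpha = math.exp(log_alpha)
    # Positive when the policy's entropy estimate (-mean log_prob)
    # is below the target entropy.
    mean_gap = sum(lp + target_entropy for lp in log_probs) / len(log_probs)
    # Derivative of the assumed loss with respect to log_alpha.
    grad = -alpha * mean_gap
    return log_alpha - alpha_lr * grad


# Entropy estimate 0.3 is below the target 1.0, so log_alpha increases.
raised = temperature_step(0.0, [-0.2, -0.4], target_entropy=1.0, alpha_lr=0.1)

# Entropy estimate 2.0 is above the target 1.0, so log_alpha decreases.
lowered = temperature_step(0.0, [-2.0, -2.0], target_entropy=1.0, alpha_lr=0.1)
```

Optimizing `log_alpha` rather than alpha directly keeps the temperature positive without needing a constraint.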
Methods
__init__(state_dim, hidden_dim, action_dim, ...) – Initialize the SAC agent.
calc_target(rewards, next_states, dones) – Calculate the TD targets for the critics.
soft_update(net, target_net) – Perform a soft update of the target network parameters.
take_action(state) – Take an action given the current state.
update(transition_dict) – Update the agent's networks using a batch of transitions.
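The TD target that `calc_target` is described as computing can be sketched for a single transition as follows. This assumes the standard SAC target with one sampled next action: the smaller of the two target-critic estimates, plus an entropy bonus, discounted by `gamma` and masked by the `done` flag. The function `td_target` and its scalar arguments are hypothetical; a discrete-action variant would instead take an expectation over the action probabilities, and the source does not say which form scfocus uses.

```python
def td_target(reward, done, next_q1, next_q2, next_log_prob, alpha, gamma):
    """Soft TD target for one transition (hypothetical scalar helper).

    next_q1, next_q2: target-critic estimates for the sampled next action.
    next_log_prob:    log-probability of that action under the current policy.
    """
    # Clipped double-Q estimate plus the entropy bonus.
    soft_value = min(next_q1, next_q2) - alpha * next_log_prob
    # No bootstrapping from terminal states (done == 1).
    return reward + gamma * (1.0 - done) * soft_value


non_terminal = td_target(reward=1.0, done=0.0, next_q1=2.0, next_q2=3.0,
                         next_log_prob=-1.0, alpha=0.2, gamma=0.99)
terminal = td_target(reward=1.0, done=1.0, next_q1=2.0, next_q2=3.0,
                     next_log_prob=-1.0, alpha=0.2, gamma=0.99)
```

Taking the minimum of the two target critics is the usual guard against overestimated Q-values, which is why the class keeps two critics and two target critics.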