sac_learner
SACLearner(actuator_ids, sensor_ids, agent_objective, *, replay_size=int(1000000.0), fc_dims=(256, 256), activation='torch.nn.ReLU', gamma=0.99, polyak=0.995, lr=0.001, batch_size=100, update_after=1000, update_every=50)
Bases: ActiveLearner
Learner class for the palaestrAI SAC agent.
Initialize the SAC learner.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `actuator_ids` | `list[str]` | The IDs of the actuators the learner uses to interact with the environment. | required |
| `sensor_ids` | `list[str]` | The IDs of the sensors the learner observes in the environment. | required |
| `agent_objective` | `Objective` | The objective function that converts environment rewards into an objective for the agent. | required |
| `replay_size` | `int` | Maximum length of the replay buffer. | `int(1000000.0)` |
| `fc_dims` | `Sequence[int]` | Sizes of the hidden layers of the agent's actor and critic networks ("fc" stands for "fully connected"). | `(256, 256)` |
| `activation` | `str` | Activation function to use, given as a dotted import path. | `'torch.nn.ReLU'` |
| `gamma` | `float` | Discount factor, between 0 and 1. | `0.99` |
| `polyak` | `float` | Interpolation factor for Polyak averaging of the target networks. The target networks are updated towards the main networks according to \(\theta_{\text{targ}} \leftarrow \rho\,\theta_{\text{targ}} + (1-\rho)\,\theta\), where \(\rho\) is `polyak`. Always between 0 and 1, usually close to 1. | `0.995` |
| `lr` | `float` | Learning rate, used for both policy and value learning. | `0.001` |
| `batch_size` | `int` | Minibatch size for SGD. | `100` |
| `update_after` | `int` | Number of environment interactions to collect before starting gradient descent updates. This ensures the replay buffer is full enough for useful updates. | `1000` |
| `update_every` | `int` | Number of environment interactions that elapse between gradient descent updates. Regardless of how long you wait between updates, the ratio of environment interactions to gradient steps is locked to 1, so each update round performs `update_every` gradient steps. | `50` |
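The `polyak` update rule can be illustrated with a minimal sketch. In a PyTorch learner it would normally be applied to parameter tensors; the version below uses plain lists of floats so it runs without dependencies. The helper name `polyak_update` is hypothetical and not part of the flowcean API.

```python
from __future__ import annotations


def polyak_update(target: list[float], source: list[float], rho: float = 0.995) -> list[float]:
    """Hypothetical helper: move each target parameter towards its source value in place.

    Implements theta_targ <- rho * theta_targ + (1 - rho) * theta.
    """
    if not 0.0 <= rho <= 1.0:
        raise ValueError("rho must lie in [0, 1]")
    if len(target) != len(source):
        raise ValueError("target and source must have the same length")
    for i, (t, s) in enumerate(zip(target, source)):
        target[i] = rho * t + (1.0 - rho) * s
    return target


# Usage: with rho = 0.5 the target closes half the remaining gap on each update.
target = [0.0, 0.0]
source = [1.0, 2.0]
for _ in range(3):
    polyak_update(target, source, rho=0.5)
print(target)  # [0.875, 1.75]
```

A value of `polyak` close to 1 makes the target networks trail the main networks slowly, which stabilizes the bootstrapped value targets.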
Source code in src/flowcean/palaestrai/sac_learner.py
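How `update_after` and `update_every` interact can be sketched as a schedule of gradient steps. This sketch assumes the trigger condition used in the OpenAI Spinning Up SAC training loop, whose parameter descriptions the table above closely follows; whether `SACLearner` checks the condition in exactly this way is an assumption. The function `update_schedule` is hypothetical.

```python
from __future__ import annotations


def update_schedule(total_steps: int, update_after: int = 1000, update_every: int = 50) -> dict[int, int]:
    """Hypothetical helper: map each environment step that triggers an update
    round to the number of gradient steps performed in that round."""
    schedule: dict[int, int] = {}
    for t in range(total_steps):
        # Updates start once enough interactions are collected, then fire
        # every `update_every` steps.
        if t >= update_after and t % update_every == 0:
            # Performing `update_every` gradient steps per round keeps the
            # ratio of environment interactions to gradient steps at 1.
            schedule[t] = update_every
    return schedule


sched = update_schedule(1200)
print(sorted(sched))        # [1000, 1050, 1100, 1150]
print(sum(sched.values()))  # 200
```

With 1200 interactions and the default settings, the 200 interactions collected from step 1000 onward are matched by 200 gradient steps.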