diff --git a/resnet_policy.py b/resnet_policy.py index f8acfd9..257ce79 100644 --- a/resnet_policy.py +++ b/resnet_policy.py @@ -93,6 +93,15 @@ def __call__(self, x: chex.Array, batched: bool = False): return action_logits[0, 0, 0, :], value[0, 0, 0, 0] +class ResnetPolicyValueNet128(ResnetPolicyValueNet): + """Create a resnet of 128 channels, 5 blocks""" + + def __init__( + self, input_dims, num_actions: int, dim: int = 128, num_resblock: int = 5 + ): + super().__init__(input_dims, num_actions, dim, num_resblock) + + class ResnetPolicyValueNet256(ResnetPolicyValueNet): """Create a resnet of 256 channels, 6 blocks"""