Skip to content

Commit

Permalink
add resnet128
Browse files Browse the repository at this point in the history
  • Loading branch information
NTT123 authored Jul 2, 2022
1 parent fbdcf76 commit 7fd5adf
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions resnet_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ def __call__(self, x: chex.Array, batched: bool = False):
return action_logits[0, 0, 0, :], value[0, 0, 0, 0]


class ResnetPolicyValueNet128(ResnetPolicyValueNet):
"""Create a resnet of 128 channels, 5 blocks"""

def __init__(
self, input_dims, num_actions: int, dim: int = 128, num_resblock: int = 5
):
super().__init__(input_dims, num_actions, dim, num_resblock)


class ResnetPolicyValueNet256(ResnetPolicyValueNet):
"""Create a resnet of 256 channels, 6 blocks"""

Expand Down

0 comments on commit 7fd5adf

Please sign in to comment.