Skip to content

Commit

Permalink
Fix ortho init when bias=False with custom policy (DLR-RM#126)
Browse files Browse the repository at this point in the history
* Update policies.py

fix AttributeError occurred when use "bias=False" linear layer in custom FeaturesExtractor DLR-RM#124

* Update changelog.rst

 update the changelog accordingly

* Update changelog.rst

Co-authored-by: Kong Lingchao <[email protected]>
Co-authored-by: Antonin RAFFIN <[email protected]>
  • Loading branch information
3 people authored Jul 25, 2020
1 parent 8353056 commit bd2aae0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Bug Fixes:
- Fixed a bug in the ``close()`` method of ``SubprocVecEnv``, causing wrappers further down in the wrapper stack to not be closed. (@NeoExtended)
- Fix target for updating q values in SAC: the entropy term was not conditioned by terminals states
- Use ``cloudpickle.load`` instead of ``pickle.load`` in ``CloudpickleWrapper``. (@shwang)
- Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37)

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -356,4 +357,4 @@ And all the contributors:
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis @joeljosephjin @shwang
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37
3 changes: 2 additions & 1 deletion stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ def init_weights(module: nn.Module, gain: float = 1) -> None:
"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
nn.init.orthogonal_(module.weight, gain=gain)
module.bias.data.fill_(0.0)
if module.bias is not None:
module.bias.data.fill_(0.0)

@abstractmethod
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
Expand Down

0 comments on commit bd2aae0

Please sign in to comment.