Skip to content

Commit

Permalink
Merge branch 'master' into sde
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Jul 16, 2020
2 parents 8f5279e + 208890d commit a96970c
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Others:
- Split the ``collect_rollout()`` method for off-policy algorithms
- Added ``_on_step()`` for off-policy base class
- Optimized replay buffer size by removing the need of ``next_observations`` numpy array
- Ignored errors from newer pytype version

Documentation:
^^^^^^^^^^^^^^
Expand Down
5 changes: 3 additions & 2 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAl
env = data["env"]

# noinspection PyArgumentList
model = cls(policy=data["policy_class"], env=env, device='auto', _init_setup_model=False)
model = cls(policy=data["policy_class"], env=env,
device='auto', _init_setup_model=False) # pytype: disable=not-instantiable,wrong-keyword-args

# load parameters
model.__dict__.update(data)
Expand All @@ -350,7 +351,7 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAl
# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
if model.use_sde:
model.policy.reset_noise()
model.policy.reset_noise() # pytype: disable=attribute-error
return model

def set_random_seed(self, seed: Optional[int] = None) -> None:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BaseModel':
device = get_device(device)
saved_variables = th.load(path, map_location=device)
# Create policy object
model = cls(**saved_variables['data'])
model = cls(**saved_variables['data']) # pytype: disable=not-instantiable
# Load weights
model.load_state_dict(saved_variables['state_dict'])
model.to(device)
Expand Down

0 comments on commit a96970c

Please sign in to comment.