
Commit a336fe6

MickyasTA committed Apr 26, 2024
1 parent 5f85986
Showing 2 changed files with 55 additions and 7 deletions.
14 changes: 7 additions & 7 deletions Agent_Actor_Critic.py
@@ -46,14 +46,14 @@ def learn(self,state,reward,state_,done,):
         state_value_=torch.squeeze(state_value_)

         action_probs=torch.distributions.Categorical(probabilites) # Categorical must be capitalized; the lowercase name is a module and is not callable
-        log_prob=action_probs.Log_prob(self.action)
+        log_prob=action_probs.log_prob(self.action) # we calculate the log probability of the action we took

-        delta = reward + self.gamma*state_value_*(1-int(done)) - state_value
-        actor_loss = -log_prob*delta
-        critic_loss=delta**2
-        total_loss= actor_loss+critic_loss
+        delta = reward + self.gamma*state_value_*(1-int(done)) - state_value # we calculate the TD error
+        actor_loss = -log_prob*delta # we calculate the actor loss
+        critic_loss = delta**2 # we calculate the critic loss
+        total_loss = actor_loss+critic_loss

-        gradient=total_loss.backward()
-        self.actor_critic.optimizer.step()
+        total_loss.backward() # we backpropagate the gradients; backward() returns None, so there is nothing to assign
+        self.actor_critic.optimizer.step() # we update the weights of the network
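
The hunk above implements a one-step (TD(0)) actor-critic update: the critic's value estimates produce the TD error delta, the actor is pushed along the log-probability of the taken action weighted by delta, and the critic is regressed on the squared TD error. For reference, here is a minimal self-contained sketch of the same update; the actor_critic_update name and the network(state) -> (probabilities, value) API are assumptions for illustration, and the zero_grad() call is assumed because the hunk does not show where gradients are cleared:

import torch

def actor_critic_update(network, optimizer, state, action, reward, state_, done, gamma=0.99):
    # Assumed network API: returns action probabilities and the state value
    probabilities, state_value = network(state)
    _, state_value_ = network(state_)
    state_value = torch.squeeze(state_value)
    state_value_ = torch.squeeze(state_value_)

    dist = torch.distributions.Categorical(probabilities)
    log_prob = dist.log_prob(action)  # action is a tensor holding the taken action

    # One-step TD error: r + gamma*V(s')*(1 - done) - V(s)
    delta = reward + gamma * state_value_ * (1 - int(done)) - state_value

    actor_loss = -log_prob * delta  # policy-gradient term weighted by the TD error
    critic_loss = delta ** 2        # squared TD error drives V(s) toward the target
    total_loss = actor_loss + critic_loss

    optimizer.zero_grad()  # assumed: without this, gradients accumulate across steps
    total_loss.backward()
    optimizer.step()

Without a zero_grad() call somewhere in the training loop, gradients from previous transitions accumulate and corrupt each update.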


48 changes: 48 additions & 0 deletions main.py
@@ -0,0 +1,48 @@
+import os
+import gym
+import numpy as np
+from Agent_Actor_Critic import Agent
+from utils import plot_learning_curve
+
+if __name__=='__main__':
+    env=gym.make('CartPole-v1')
+    agent=Agent(learning_rate=1e-5,n_actions=env.action_space.n) # n_actions is the number of discrete actions, not the space itself
+    n_games=1800
+    filename='cartpole.png'
+    figure_file='plots/'+filename
+
+    best_score=env.reward_range[0]
+    score_history=[]
+    load_checkpoint=False # set to True to evaluate a saved model instead of training
+    check_path_dir="tmp/actor_critic"
+    if load_checkpoint and os.path.exists(check_path_dir):
+        agent.load_model()
+
+    for i in range(n_games):
+        observation=env.reset() # resets the environment to an initial state and returns the initial observation
+        done=False
+        score=0
+        while not done:
+            action=agent.choose_action(observation)
+            # gym (pre-0.26) returns a 4-tuple; gymnasium returns (obs, reward, terminated, truncated, info)
+            observation_,reward,done,info=env.step(action) # step requires the chosen action
+            score+=reward
+            if not load_checkpoint:
+                agent.learn(observation,reward,observation_,done)
+            observation=observation_
+        score_history.append(score)
+        avg_score=np.mean(score_history[-100:]) # mean, not sum, of the last 100 scores
+
+        if avg_score>best_score:
+            best_score=avg_score
+            if not load_checkpoint:
+                agent.save_model()
+        print('episode',i,'score %.1f' % score,'avg_score %.1f' % avg_score)
+
+    x=[i+1 for i in range(n_games)]
+    plot_learning_curve(x,score_history,figure_file)
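
main.py imports plot_learning_curve from utils, which is not part of this commit. A plausible minimal implementation, assuming the signature implied by the call plot_learning_curve(x, score_history, figure_file) and the same 100-game window used for avg_score:

import numpy as np
import matplotlib.pyplot as plt

def plot_learning_curve(x, scores, figure_file):
    # Running average over the previous 100 scores (assumed window)
    running_avg = np.zeros(len(scores))
    for i in range(len(scores)):
        running_avg[i] = np.mean(scores[max(0, i - 100):(i + 1)])
    plt.plot(x, running_avg)
    plt.title('Running average of previous 100 scores')
    plt.savefig(figure_file)

Note that the plots/ directory must exist before plt.savefig runs, just as tmp/actor_critic must exist before a checkpoint can be loaded or saved.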
