Skip to content

Commit

Permalink
Merge branch 'dev' of https://github.com/datamllab/rlcard into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
hsywhu committed Sep 13, 2019
2 parents ad81c8e + d2edcb3 commit fe5e4aa
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ sonar-project.properties
docs/rst
docs/sphinx
experiments/
newtest/
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ before_script:
- pip install python-coveralls
- pip install pytest-cover
script:
- py.test test/test_utils.py --cov=rlcard/utils/
- py.test test/ --cov=rlcard
after_success:
- coveralls
2 changes: 1 addition & 1 deletion rlcard/agents/dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import rlcard
from rlcard.utils.utils import *

Transition = namedtuple('Transition', '"state', 'action', 'reward', 'next_state', 'done'])
Transition = namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])

class DQNAgent(object):

Expand Down
8 changes: 8 additions & 0 deletions rlcard/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,11 @@ def make_plot(self, save_path = ''):
os.makedirs(save_dir)

fig.savefig(save_path)

def close_file(self):
''' Close the created file objects
'''
if self.log_path != None:
self.log_file.close()
if self.csv_path != None:
self.csv_file.close()
34 changes: 34 additions & 0 deletions test/utils/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest
import random
import numpy as np
import os
from rlcard.utils.logger import Logger

class TestLoggerMethos(unittest.TestCase):

def test_log(self):
logger = Logger(xlabel="x", ylabel="y", legend="test", log_path="./newtest/test_log.txt")
logger.log("test text")
f = open("./newtest/test_log.txt", "r")
contents = f.read()
self.assertEqual(contents, "test text\n")
logger.close_file()

def test_add_point(self):
logger = Logger(xlabel="x", ylabel="y", legend="test", csv_path="./newtest/test_csv.csv")
logger.add_point(x=1, y=1)
self.assertEqual(logger.xs[0], 1)
self.assertEqual(logger.ys[0], 1)

def test_make_plot(self):
logger = Logger(xlabel="x", ylabel="y", legend="test")
for x in range(10):
logger.add_point(x=x, y=x*x)
logger.make_plot(save_path='./newtest/test.png')

def test_close_file(self):
logger = Logger(xlabel="x", ylabel="y", legend="test", log_path="./newtest/test_log.txt",csv_path="./newtest/test_csv.csv")
logger.close_file()

if __name__ == '__main__':
unittest.main()
File renamed without changes.

0 comments on commit fe5e4aa

Please sign in to comment.