Skip to content

Commit

Permalink
fix: Conversion of FieldState is in correct place
Browse files Browse the repository at this point in the history
This allows autobahn to work correctly.
  • Loading branch information
alex-ong committed Jul 12, 2020
1 parent ae0ac85 commit ee05354
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 12 deletions.
24 changes: 20 additions & 4 deletions nestris_ocr/calibration/field_view.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from nestris_ocr.colors import Colors
from nestris_ocr.calibration.image_canvas import ImageCanvas
import tkinter as tk
import numpy as np
from PIL import Image

from nestris_ocr.colors import Colors
from nestris_ocr.calibration.image_canvas import ImageCanvas
from nestris_ocr.utils.lib import tryGetInt
from nestris_ocr.ocr_state.field_state import FieldState


class FieldView(ImageCanvas):
Expand All @@ -18,8 +20,7 @@ def updateField(self, field, level):
if success and level != self.current_level:
self.color_table.setLevel(level)

if isinstance(field, str): # convert back to numpy array...
field = convertStringField(field)
field = convert_field(field)

lut = np.array(self.color_table.colors)
image = lut[field]
Expand All @@ -29,6 +30,21 @@ def updateField(self, field, level):
self.updateImage(image)


# convert from string or fieldstate to numpy array.
def convert_field(field):
if isinstance(field, str): # convert back to numpy array...
field = convertStringField(field)

if isinstance(field, FieldState):
field = field.data

if not isinstance(field, np.ndarray):
raise TypeError("Cannot convert field to correct type: " + str(type(field)))

return field


# convert string field to numpy array.
def convertStringField(strfield):
field = np.zeros((200,), dtype=np.uint8)
for index, item in enumerate(strfield):
Expand Down
11 changes: 11 additions & 0 deletions nestris_ocr/network/cached_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,25 @@ def sendResult(self, message, timeStamp):
# print(self.lastMessage,"\n",message)
self.lastMessage = message.copy()
message["time"] = timeStamp

message = prePackMessage(message, self.protocol)

if self.printPacket:
print(message)

packed, binary = packMessage(message, self.protocol)

self.client.sendMessage(packed, binary)

self.lastSend = time.time()


def prePackMessage(dictionary, protocol):
if dictionary["field"]:
dictionary["field"] = dictionary["field"].serialize()
return dictionary


def packMessage(dictionary, protocol):
if protocol in ["LEGACY", "AUTOBAHN", "FILE"]:
return json.dumps(dictionary), False
Expand All @@ -45,6 +55,7 @@ def packMessage(dictionary, protocol):
def sameMessage(dict1, dict2):
if dict1 is None:
return False

for key in dict1.keys():
if key in dict2:
if dict1[key] != dict2[key]:
Expand Down
25 changes: 18 additions & 7 deletions nestris_ocr/ocr_state/field_state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from nestris_ocr.config import config # TODO: remove this dependency.
from nestris_ocr.network.byte_stuffer import prePackField

Expand All @@ -14,8 +15,14 @@ def blockCountAdjusted(self):

def __eq__(self, other):
if isinstance(other, self.__class__):
return False
# return self.__dict__ == other.__dict__
result = np.array_equal(self.data, other.data)
return result

return False

# In Python3, don't implement __ne__
# def __ne__(self, other):
# ...

def piece_spawned(self, other):
return False
Expand All @@ -31,9 +38,13 @@ def serialize(self):
result = prePackField(result)
result = result.tobytes()
else:
result2 = []
for y in range(20):
temp = "".join(str(result[y, x]) for x in range(10))
result2.append(temp)
result = "".join(str(r) for r in result2)
return self.simple_string()
return result

def simple_string(self):
one_d = self.data.flatten()
result = "".join(str(r) for r in one_d)
return result

def __str__(self):
return self.simple_string()
2 changes: 1 addition & 1 deletion nestris_ocr/scan_strat/base_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def to_dict(self):
result["lines"] = dict_zfill(self.lines, 3)
result["score"] = dict_zfill(self.score, 6)
result["level"] = dict_zfill(self.level, 2)
result["field"] = self.field.serialize() if self.field else None
result["field"] = self.field
result["preview"] = self.preview
result["gameid"] = self.gameid
result.update(self.piece_stats.toDict())
Expand Down

0 comments on commit ee05354

Please sign in to comment.