Skip to content

Commit

Permalink
update data export mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
shex1627 committed Jan 19, 2022
1 parent 4560b7f commit e08963f
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 55 deletions.
36 changes: 17 additions & 19 deletions posture_monitor/posture_monitor_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,43 +35,28 @@
alert_on = True
track_data_on = True
program_on = True
# KEY_ALERT_ON = Key.f7
# KEY_ALERT_OFF = Key.f8
# KEY_EXIT = Key.f6
KEY_ALERT_TOGGLE = Key.f6
KEY_TRACK_DATA_TOGGLE = Key.f7
KEY_CAMERA_TOGGLE = Key.f8
KEY_EXIT = Key.f4


# def on_press_start(key, key_alert_on=KEY_ALERT_ON, exit=KEY_EXIT):
# global alert_toggle
# global camera_on
# #print("Key pressed: {0}".format(key))
# if key == key_alert_on:
# logging.info("alert on")
# alert_toggle = True
# return False

# if key == exit:
# print('exiting...')
# sys.exit()


def on_press_loop(key):
global alert_on
global track_data_on
global camera_on
global program_on

if key == KEY_TRACK_DATA_TOGGLE:
"""turn off alert if data tracking is off."""
track_data_on = not track_data_on
if not track_data_on:
alert_on = False
logging.info(f"track data toggle to {track_data_on}")
return True

if key == KEY_ALERT_TOGGLE:
"""turn on data tracking if alert is on as well."""
alert_on = not alert_on
if alert_on:
track_data_on = True
Expand All @@ -84,6 +69,7 @@ def on_press_loop(key):
return True

if key == KEY_EXIT:
"""turn off camera before program exits"""
camera_on = not camera_on
program_on = False
logging.info(f"exiting program")
Expand All @@ -100,13 +86,15 @@ def main():
cap = cv2.VideoCapture(0)
while program_on:
with Listener(on_press=on_press_loop) as listener:
#listener.join()
# For webcam input:
print("loop listener")

if not camera_on:
cv2.destroyAllWindows()
else:
cv2.namedWindow("Posture Monitor", cv2.WINDOW_AUTOSIZE)
cv2.setWindowProperty('Posture Monitor', cv2.WND_PROP_TOPMOST, 1)

with mp_pose.Pose(
min_detection_confidence=0.5,
min_tracking_confidence=0.5) as pose:
Expand Down Expand Up @@ -159,6 +147,16 @@ def main():
for i in range(len(alerts_trigger)):
cv2.putText(img=image_flip, text=alerts_trigger[i], org=(100,y_init+i*y_increment),
fontFace=font, fontScale=font_scale, color=(0, 255, 0), thickness=6, lineType=cv2.LINE_AA)

# show on windows if tracking and alert are on
y_increment = 60
y_init = 400
font_scale = 1
cv2.putText(img=image_flip, text=f"alert_on: {alert_on}", org=(0, y_init),
fontFace=font, fontScale=font_scale, color=(0, 255, 0), thickness=4, lineType=cv2.LINE_AA)
cv2.putText(img=image_flip, text=f"track_data_on: {track_data_on}", org=(0,y_init+y_increment),
fontFace=font, fontScale=font_scale, color=(0, 255, 0), thickness=4, lineType=cv2.LINE_AA)

# Flip the image horizontally for a selfie-view display.
scale_percent = 50 # percent of original size
width = int(image_flip.shape[1] * scale_percent / 100)
Expand All @@ -180,7 +178,6 @@ def main():
if __name__ == '__main__':
parser = argparse.ArgumentParser()
#parser.add_argument("--input_file", help="data file with filename field", default=False)
args = parser.parse_args()
parser.add_argument('--window_size',
default='small',
#const='all',
Expand All @@ -192,6 +189,7 @@ def main():
# this should default true
parser.add_argument('--windows_on_top',
help="flag to determine if window is always on top of all other windows", action="store_true")
args = parser.parse_args()
main()

# run main
30 changes: 30 additions & 0 deletions posture_monitor/src/PostureMetricTs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
import os
import json
from typing import Callable, OrderedDict, List, Dict
from posture_monitor.src.util import OrderedDefaultDict
from collections import defaultdict
Expand Down Expand Up @@ -55,6 +56,35 @@ def get_past_data(self, start_time: int=None, seconds: int=1,
#print(f"data: {data}")
return past_data


def reset_data(self):
""" Remove all historical data."""
self.second_to_frame_scores = OrderedDefaultDict()
self.second_to_avg_frame_scores = defaultdict(lambda : self.fillna)

def export_data(self, data_dir: str):
"""
export data to data dir, then reset data
"""
# do nothing if no data to export
if len(self.second_to_avg_frame_scores) <= 0:
return
metric_filepath = os.path.join(data_dir, self.name+".json")
# load historical data if exist
if os.path.exists(metric_filepath):
with open(metric_filepath, 'r') as infile:
historical_data = json.load(infile)
else:
historical_data = defaultdict(lambda : self.fillna)
# update with latest data
current_data = {str(ts): value for ts, value in self.second_to_avg_frame_scores.items()}
historical_data.update(current_data)
# dump updated data
with open(metric_filepath, 'w') as outfile:
json.dump(historical_data, outfile, indent=4)
# reset data
self.reset_data()


class PostureSubMetricTs(PostureMetricTs):
"""
Expand Down
27 changes: 21 additions & 6 deletions posture_monitor/src/PostureSession.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def __init__(self, metricTsDict: Dict[str, PostureMetricTs], alertRules: List[Po
if data_dir:
self.init_data_dir()

self.alertTS = PostureMetricTs("alert_ts",lambda landmarks: 0)

def update_metrics(self, landmarks, export_data=True) -> None:
"""Update all the metrics. If a metric is a submetric, then feed in the metricDict"""
for metric in self.metrics.values():
Expand Down Expand Up @@ -70,11 +72,16 @@ def check_posture_alert(self, trigger_sound=True) -> List[str]:
for alert in self.alertRules:
if alert.alert_trigger(self.metrics):
alerts_triggered.append(alert.name)
if trigger_sound and len(alerts_triggered):
self.trigger_alert_sound()
if len(alerts_triggered):
self.alertTS.second_to_avg_frame_scores[get_time()] = 1
if trigger_sound:
self.trigger_alert_sound()
else:
self.alertTS.second_to_avg_frame_scores[get_time()] = 0
return alerts_triggered

def init_data_dir(self) -> None:
""" create session directory, copies all the session config to directory."""
if not os.path.exists(self.data_dir):
# Create a new directory because it does not exist
os.makedirs(self.data_dir)
Expand All @@ -85,7 +92,12 @@ def init_data_dir(self) -> None:
copy(self.config_datapath, self.session_data_dir)

def export_data(self):
"""export all metrics data if last export time is old enough."""
"""export all metrics data if last export time is old enough.
Also resets all the data.
metric data are json files, each key is integer time, value is the
avg value of the metric during the second.
"""
logger.debug("renaming session data dir based on end_time")
now = get_time()
new_session_data_dir = os.path.join(self.data_dir, f"session_{self.start_time}_{now}")
Expand All @@ -95,8 +107,11 @@ def export_data(self):
logger.debug("writing to file")
for metric_name in self.metrics:
metric_filepath = os.path.join(self.session_data_dir, metric_name+".json")
with open(metric_filepath, 'w') as outfile:
data = self.metrics[metric_name].second_to_avg_frame_scores
json.dump(data, outfile, indent=4)
# with open(metric_filepath, 'w') as outfile:
# data = self.metrics[metric_name].second_to_avg_frame_scores
# json.dump(data, outfile, indent=4)
self.metrics[metric_name].export_data(self.session_data_dir)

self.alertTS.export_data(self.session_data_dir)


6 changes: 2 additions & 4 deletions test/unit/test_PostureKDeltaAlert.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
def test_alert_trigger():
test_metricTs = \
PostureMetricTs("test_metric",
metric_func=lambda landmarks: landmarks[0],
data_dir="test_data")
metric_func=lambda landmarks: landmarks[0])
now = get_time()
test_metricTs.second_to_avg_frame_scores.update({
now -1: 1
Expand All @@ -26,8 +25,7 @@ def test_alert_trigger():
def test_alert_not_trigger():
test_metricTs = \
PostureMetricTs("test_metric",
metric_func=lambda landmarks: landmarks[0],
data_dir="test_data")
metric_func=lambda landmarks: landmarks[0])
now = get_time()
test_metricTs.second_to_avg_frame_scores.update({
now -1: 0.1
Expand Down
82 changes: 66 additions & 16 deletions test/unit/test_PostureMetricTS.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,43 @@
from collections import defaultdict
import time
import os
from turtle import update
import pytest
from unittest.mock import patch
import json
import shutil

dir_path = os.path.dirname(os.path.realpath(__file__))
from posture_monitor.src.PostureMetricTs import PostureMetricTs, PostureSubMetricTs
from posture_monitor.src.PostureMetricTs import if_metric_fail_avg_and_last_second
from posture_monitor.src.util import get_time
from unittest.mock import patch


TEMP_DIR = 'test/temp_test_dir'
def create_temp_dir():
if not os.path.exists(TEMP_DIR):
# Create a new directory because it does not exist
print("creating temp dir")
os.makedirs(TEMP_DIR)
else:
print('temp dir exists')

def cleanup_temp_dir():
if os.path.exists(TEMP_DIR):
shutil.rmtree(TEMP_DIR)


@pytest.fixture()
def test_metricTs():
test_metricTs = \
PostureMetricTs("test_metric",
metric_func=lambda landmarks: landmarks[0],
data_dir="test_data")
metric_func=lambda landmarks: landmarks[0])
return test_metricTs

def test_update():
test_metricTs = \
PostureMetricTs("test_metric",
metric_func=lambda landmarks: landmarks[0],
data_dir="test_data")
metric_func=lambda landmarks: landmarks[0])

test_landmarks_1 = [1]
test_landmarks_2 = [4]
Expand All @@ -46,8 +63,7 @@ def test_update():
def test_get_past_data_full_data():
test_metricTs = \
PostureMetricTs("test_metric",
metric_func=lambda landmarks: landmarks[0],
data_dir="test_data")
metric_func=lambda landmarks: landmarks[0])


now = get_time()
Expand All @@ -68,8 +84,7 @@ def test_get_past_data_full_data():
def test_get_past_data_skip_data():
test_metricTs = \
PostureMetricTs("test_metric",
metric_func=lambda landmarks: landmarks[0],
data_dir="test_data")
metric_func=lambda landmarks: landmarks[0])


now = get_time()
Expand All @@ -91,8 +106,7 @@ def test_get_past_data_skip_data():
def test_posture_checking_function():
test_metricTs = \
PostureMetricTs("test_metric",
metric_func=lambda landmarks: landmarks[0],
data_dir="test_data")
metric_func=lambda landmarks: landmarks[0])
test_metricTs_dict = {'test_metric': test_metricTs}
sub_metric_func = if_metric_fail_avg_and_last_second('test_metric',
threshold_rule = lambda metric: int(metric > 0.5),
Expand All @@ -115,8 +129,7 @@ def test_posture_checking_function():
def test_posture_submetrics_update():
test_metricTs = \
PostureMetricTs("test_metric",
metric_func=lambda landmarks: landmarks[0],
data_dir="test_data")
metric_func=lambda landmarks: landmarks[0])
test_metricTs_dict = {'test_metric': test_metricTs}
sub_metric_func = if_metric_fail_avg_and_last_second('test_metric',
threshold_rule = lambda metric: int(metric > 0.5),
Expand Down Expand Up @@ -144,8 +157,7 @@ def test_posture_submetrics_update():
def test_posture_submetrics_get_past_data():
test_metricTs = \
PostureMetricTs("test_metric",
metric_func=lambda landmarks: landmarks[0],
data_dir="test_data")
metric_func=lambda landmarks: landmarks[0])
test_metricTs_dict = {'test_metric': test_metricTs}
sub_metric_func = if_metric_fail_avg_and_last_second('test_metric',
threshold_rule = lambda metric: int(metric > 0.5),
Expand All @@ -165,4 +177,42 @@ def test_posture_submetrics_get_past_data():
})
#test_SubMetricTs.update(test_metricTs_dict)
result = test_SubMetricTs.get_past_data(seconds=4)
assert result==[0.1, 0.1, 1, 1]
assert result==[0.1, 0.1, 1, 1]


def test_export_data():
"""
create fake json, dump it as historical data
create new data
call expoert method
load updated file
assert if data is updated
assert if data is reset
clean up temp_dir
"""
create_temp_dir()
test_metricTs = \
PostureMetricTs("test_metric",
metric_func=lambda landmarks: landmarks[0])

historical_data = {1:1, 2:2}
data_file = os.path.join(TEMP_DIR, test_metricTs.name+".json")

with open(data_file, 'w') as outfile:
json.dump(historical_data, outfile, indent=4)

new_data = {3:3, 4:4}
test_metricTs.second_to_avg_frame_scores.update(new_data)
test_metricTs.export_data(TEMP_DIR)

with open(data_file, 'r') as infile:
update_data = json.load(infile)

assert update_data == {
str(i):i for i in range(1, 5)
}
assert len(test_metricTs.second_to_avg_frame_scores) == 0
cleanup_temp_dir()
Loading

0 comments on commit e08963f

Please sign in to comment.