Skip to content

Commit

Permalink
Update load_meta_cnn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
shillyshallysxy authored Dec 11, 2018
1 parent 77244cf commit acc05fc
Showing 1 changed file with 13 additions and 17 deletions.
30 changes: 13 additions & 17 deletions emotion_classifier_tensorflow_version/CNN/load_meta_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
channel = 1
default_height = 48
default_width = 48
confusion_matrix = True
confusion_matrix = False
use_advanced_method = True
emotion_labels = ['angry', 'disgust:', 'fear', 'happy', 'sad', 'surprise', 'neutral']
num_class = len(emotion_labels)
Expand All @@ -41,14 +41,7 @@
logits = graph.get_tensor_by_name('project/output/logits:0')


def produce_result(images_):
images_ = np.multiply(np.array(images_), 1./255)
pred_logits_ = sess.run(tf.nn.softmax(logits), {x_input: images_, dropout: 1.0})
return pred_logits_


def produce_advanced_result(images_):
images_ = np.multiply(np.array(images_), 1. / 255)
def advance_image(images_):
rsz_img = []
rsz_imgs = []
for image_ in images_:
Expand All @@ -64,6 +57,15 @@ def produce_advanced_result(images_):
rsz_img.append(np.reshape(cv2.resize(cv2.flip(image_, 1), (default_height, default_width)),
[default_height, default_width, channel]))
rsz_imgs.append(np.array(rsz_img))
return rsz_imgs


def produce_result(images_):
images_ = np.multiply(np.array(images_), 1. / 255)
if use_advanced_method:
rsz_imgs = advance_image(images_)
else:
rsz_imgs = [images_]
pred_logits_ = []
for rsz_img in rsz_imgs:
pred_logits_.append(sess.run(tf.nn.softmax(logits), {x_input: rsz_img, dropout: 1.0}))
Expand All @@ -72,10 +74,7 @@ def produce_advanced_result(images_):

def produce_results(images_):
results = []
if use_advanced_method:
pred_logits_ = produce_advanced_result(images_)
else:
pred_logits_ = produce_result(images_)
pred_logits_ = produce_result(images_)
pred_logits_list_ = np.array(np.reshape(np.argmax(pred_logits_, axis=1), [-1])).tolist()
for num in range(num_class):
results.append(pred_logits_list_.count(num))
Expand All @@ -102,10 +101,7 @@ def produce_confusion_matrix(images_list_, total_num_):
def predict_emotion(image_):
image_ = cv2.resize(image_, (default_height, default_width))
image_ = np.reshape(image_, [-1, default_height, default_width, channel])
if use_advanced_method:
return produce_advanced_result(image_)[0]
else:
return produce_result(image_)[0]
return produce_result(image_)[0]


def face_detect(image_path, casc_path_=casc_path):
Expand Down

0 comments on commit acc05fc

Please sign in to comment.