Skip to content

Commit

Permalink
model for detecting driver status
Browse files Browse the repository at this point in the history
  • Loading branch information
xurror committed May 10, 2020
1 parent 0e41822 commit bb3040f
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions src/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
import numpy as np
import random, shutil

from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Dropout, Conv2D, Flatten, Dense, MaxPool2D, BatchNormalization

def generator(dir, gen=image.ImageDataGenerator(rescale=1./255), shuffle=True, batch_size=1, target_size=(24, 24), class_mode='categorical'):
return gen.flow_from_directory(dir, batch_size=batch_size, shuffle=shuffle, color_mode='grayscale', class_mode=class_mode, target_size=target_size)

BS = 32
TS = (24, 24)
train_batch = generator('src/data/train', shuffle=True, batch_size=BS, target_size=TS)
valid_batch = generator('src/data/valid', shuffle=True, batch_size=BS, target_size=TS)
SPE = len(train_batch.classes)//BS
VS = len(valid_batch.classes)//BS
print(SPE,VS)

model = Sequential([
Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(24,24,1)),
MaxPool2D(pool_size=(1,1)),
Conv2D(32,(3,3),activation='relu'),
MaxPool2D(pool_size=(1,1)),
#32 convolution filters used each of size 3x3
#again
Conv2D(64, (3, 3), activation='relu'),
MaxPool2D(pool_size=(1,1)),

#64 convolution filters used each of size 3x3
#choose the best features via pooling

#randomly turn neurons on and off to improve convergence
Dropout(0.25),
#flatten since too many dimensions, we only want a classification output
Flatten(),
#fully connected to get all relevant data
Dense(128, activation='relu'),
#one more dropout for convergence' sake :)
Dropout(0.5),
#output a softmax to squash the matrix into output probabilities
Dense(2, activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

model.fit_generator(train_batch, validation_data=valid_batch,epochs=15, steps_per_epoch=SPE, validation_steps=VS)

model.save('src/models/cnnCat2.h5', overwrite=True)

Binary file added src/models/cnnCat2.h5
Binary file not shown.

0 comments on commit bb3040f

Please sign in to comment.