forked from AaronJny/captcha_detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path classify_model.py
56 lines (49 loc) · 1.77 KB
/
classify_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*- coding: utf-8 -*-
# @Time : 2020/11/13
# @Author : AaronJny
# @File : classify_model.py
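# @Desc : SiameseNetwork classifier model and a loader that restores its weights from ClassifyConfig.MODEL_PATH.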
import tensorflow as tf
from tensorflow.keras import layers
from config import ClassifyConfig
class SiameseNetwork(tf.keras.Model):
    """Two-stage CNN that scores a group of images with a shared feature extractor.

    Every element of ``inputs`` is passed through the same ``net_stage_1``
    (one set of weights shared by all inputs). The resulting feature maps are
    concatenated along the last (channel) axis, and ``net_stage_2`` reduces
    them to a single sigmoid output per sample. That output is presumably a
    "same class" / similarity score; confirm against the training code.

    Image size, channel count and dropout rate are read from ``ClassifyConfig``.
    """

    def __init__(self):
        super().__init__()
        # Channels-last image shape, as declared by the stage-1 Input layer.
        image_shape = (*ClassifyConfig.IMAGE_SIZE, ClassifyConfig.IMAGE_CHANNELS)

        # Shared feature extractor, applied independently to every input image.
        # NOTE: layer order is significant for weight loading; do not reorder.
        self.net_stage_1 = tf.keras.Sequential([
            layers.Input(shape=image_shape),
            layers.Conv2D(6, (3, 3), padding='same'),
            layers.MaxPooling2D((2, 2), 2),
            layers.Dropout(ClassifyConfig.DROPOUT_RATE),
            layers.ReLU(),
            layers.Conv2D(16, (5, 5)),
            layers.MaxPooling2D((2, 2), 2),
            layers.Dropout(ClassifyConfig.DROPOUT_RATE),
            layers.ReLU(),
        ])

        # Head: fuses the concatenated per-image features into one output.
        # It has no Input layer; its input channel count depends on how many
        # images are concatenated, so it is built on first call / model.build().
        self.net_stage_2 = tf.keras.Sequential([
            layers.Conv2D(6, (3, 3)),
            layers.MaxPooling2D((2, 2), 2),
            layers.Dropout(ClassifyConfig.DROPOUT_RATE),
            layers.ReLU(),
            layers.Flatten(),
            layers.Dense(84),
            layers.ReLU(),
            layers.Dense(1, activation='sigmoid'),
        ])

    @tf.function
    def call(self, inputs, training=None, mask=None):
        """Run the shared extractor on each input image and score the result.

        Args:
            inputs: sequence of image tensors. Each is assumed to be shaped
                ``(batch, *ClassifyConfig.IMAGE_SIZE, ClassifyConfig.IMAGE_CHANNELS)``.
                ``load_classify_model`` builds the model for two such inputs.
            training: forwarded explicitly to both sub-networks. Their Dropout
                layers therefore honour it even when ``call`` is invoked
                directly rather than through ``Model.__call__``.
            mask: unused; kept for the Keras ``call`` signature.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid outputs in (0, 1).
        """
        # The same stage-1 network (shared weights) embeds every input image.
        features = [
            self.net_stage_1(image, training=training) for image in inputs
        ]
        # Stack the per-image feature maps along the channel axis for the head.
        merged = tf.concat(features, axis=-1)
        return self.net_stage_2(merged, training=training)
def load_classify_model(model_path=None):
    """Build a ``SiameseNetwork`` for image pairs and load its saved weights.

    Args:
        model_path: weights file or checkpoint passed to ``Model.load_weights``.
            When ``None`` (the default), ``ClassifyConfig.MODEL_PATH`` is used.
            This keeps the original zero-argument behaviour.

    Returns:
        The built ``SiameseNetwork`` with its weights restored.

    Errors raised by ``load_weights`` (e.g. a missing or incompatible weights
    file) are not caught here and propagate to the caller.
    """
    if model_path is None:
        model_path = ClassifyConfig.MODEL_PATH

    # Batch dimension left as None; one entry per image in the pair.
    input_shape = (None, *ClassifyConfig.IMAGE_SIZE, ClassifyConfig.IMAGE_CHANNELS)

    model = SiameseNetwork()
    # Build explicitly so that every variable exists before weights are
    # restored. This matters most for net_stage_2, which has no Input layer
    # and only learns its input shape once the two feature maps are concatenated.
    model.build([input_shape, input_shape])
    model.load_weights(model_path)
    return model