Skip to content

Commit

Permalink
Merge pull request #31 from beiyuouo/dev
Browse files Browse the repository at this point in the history
v0.1.1
  • Loading branch information
beiyuouo authored Sep 12, 2022
2 parents d823bea + bb5a765 commit d9d1cd5
Show file tree
Hide file tree
Showing 16 changed files with 286 additions and 247 deletions.
154 changes: 90 additions & 64 deletions src/apis/scaffold.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,86 @@
import os
import subprocess
import sys
import typing

from loguru import logger

from components.auto_label import ClusterLabeler
from components.config import Config
from components.utils import ToolBox
from factories.resnet import ResNet
from components.auto_label import ClusterLabeler

# Character substitution table used by diagnose_task's "Convert bad code" step:
# each key found in a task name is meant to be replaced by its value.
# The keys appear to be Cyrillic/Greek look-alikes of ASCII letters.
# NOTE(review): a couple of entries render as identity mappings (key == value);
# confirm the keys really are distinct non-ASCII code points, otherwise they
# are no-ops.
BADCODE = {
"а": "a",
"е": "e",
"e": "e",
"i": "i",
"і": "i",
"ο": "o",
"с": "c",
"ԁ": "d",
"ѕ": "s",
"һ": "h",
"у": "y",
"р": "p",
}


def diagnose_task(task_name: typing.Optional[str]) -> typing.Optional[str]:
    """
    Input detection and normalization

    Validates a task name and rewrites it into the label form used for
    directory names: separators become underscores and every character
    listed in ``BADCODE`` is replaced by its mapped counterpart.

    :param task_name: Raw task name, e.g. ``dog-shaped cookie``.
    :return: The normalized task name, e.g. ``dog_shaped_cookie``.
    :raise TypeError: If ``task_name`` is empty, not a string, shorter than
        two characters, or contains a character that is illegal in filenames.
    """
    if not task_name or not isinstance(task_name, str) or len(task_name) < 2:
        raise TypeError(f"({task_name})TASK should be string type data")

    # Filename contains illegal characters
    inv = {"\\", "/", ":", "*", "?", "<", ">", "|"}
    if s := set(task_name) & inv:
        raise TypeError(f"({task_name})TASK contains invalid characters({s})")

    # Normalized separator
    for separator in (" ", ",", "-"):
        task_name = task_name.replace(separator, "_")

    # Convert bad code
    # str.replace returns a new string (str is immutable), so the result must
    # be assigned back; otherwise this whole step silently does nothing.
    for code, right_code in BADCODE.items():
        task_name = task_name.replace(code, right_code)

    # Spaces were already turned into "_" above, so this only trims any other
    # leading/trailing whitespace (tabs, newlines).
    task_name = task_name.strip()
    logger.debug(f"Diagnose task | task_name={task_name}")

    return task_name


class Scaffold:
_model = None

@staticmethod
@logger.catch()
def new():
"""
[dev for challenger] Initialize the project directory
Usage: python main.py new
prompt[en] --> Please click each image containing a dog-shaped cookie
task=`dog_shaped_cookie`
---
>>> input("prompt[en] --> ")
prompt[en] --> Please click each image containing a dog-shaped cookie
=> `dog_shaped_cookie`
prompt[en] --> horse with white legs
=> `horse_with_white_legs`
prompt[en] --> ""
=> raise TypeError
>>> input(f"Use AI to automatically label datasets? {choices} --> ")
1. Copy the unbinary image files to the automatically opened folder.
2. OkAction,Waiting for AI to automatically label.
>>> input(f"Start automatic training? {choices} --> ")
3. Check the results of automatic classification(Manual correction).
4. If the error rate is too high, it is recommended to cancel the training,
otherwise the training workflow can be continued.
:return:
"""
boolean_yes = "y"
boolean_no = "n"
choices = {boolean_yes, boolean_no}

# Prepend the detector to avoid invalid interactions
task = ToolBox.split_prompt(input("prompt[en] --> "), lang="en")
auto_label = input("auto_label? [y/n] --> ")
if auto_label in ["y", "Y"]:
task = diagnose_task(task)

# IF AUTO-LABEL
prompts = f"Use AI to automatically label datasets? {choices} --> "
while (auto_label := input(prompts)) not in choices:
continue
if auto_label == "y":
data_dir = os.path.join(Config.DIR_DATABASE, task)
unlabel_dir = os.path.join(data_dir, "unlabel")
if not os.path.exists(unlabel_dir):
os.makedirs(unlabel_dir)

os.system(f"start {unlabel_dir}")
# Create and open un-labeled dir
unlabel_dir = os.path.join(data_dir, "unlabel")
os.makedirs(unlabel_dir, exist_ok=True)
if sys.platform == "win32":
os.startfile(unlabel_dir)
else:
opener = "open" if sys.platform == "darwin" else "xdg-open"
subprocess.call([opener, unlabel_dir])

# Block the main process, waiting for manual operation
input(
"please put all the images in the `unlabel` folder and press any key to continue..."
)

labeler = ClusterLabeler(data_dir=data_dir)
labeler.run()
logger.info("Auto labeling completed")

cmd_train = input("start to train now? [y/n] --> ")
if cmd_train in ["y", "Y"]:
ClusterLabeler(data_dir=data_dir).run()
logger.success("Auto labeling completed")

# IF AUTO-TRAIN
prompts = f"Start automatic training? {choices} --> "
while (cmd_train := input(prompts)) not in choices:
continue
if cmd_train == "y":
Scaffold.train(task=task)

@staticmethod
@logger.catch()
def train(
task: str,
epochs: typing.Optional[int] = None,
batch_size: typing.Optional[int] = None,
task: str, epochs: typing.Optional[int] = None, batch_size: typing.Optional[int] = None
):
"""
Train the specified model and output an ONNX object
Expand Down Expand Up @@ -137,9 +128,7 @@ def val(task: str):
@staticmethod
@logger.catch()
def trainval(
task: str,
epochs: typing.Optional[int] = None,
batch_size: typing.Optional[int] = None,
task: str, epochs: typing.Optional[int] = None, batch_size: typing.Optional[int] = None
):
"""
Connect train and val
Expand All @@ -151,7 +140,44 @@ def trainval(
:param batch_size:
:return:
"""
# Scaffold.train.__func__(task, epochs, batch_size)
# Scaffold.val.__func__(task)
Scaffold.train(task, epochs, batch_size)
Scaffold.val(task)


def diagnose_task(task_name: typing.Optional[str]) -> typing.Optional[str]:
    """
    Input detection and normalization

    Validates a task name and rewrites it into the label form used for
    directory names: separators become underscores and look-alike characters
    from the local ``badcode`` table are replaced by their mapped counterpart.

    :param task_name: Raw task name, e.g. ``dog-shaped cookie``.
    :return: The normalized task name, e.g. ``dog_shaped_cookie``.
    :raise TypeError: If ``task_name`` is empty, not a string, shorter than
        two characters, or contains a character that is illegal in filenames.
    """
    if not task_name or not isinstance(task_name, str) or len(task_name) < 2:
        raise TypeError(f"({task_name})TASK should be string type data")

    # Filename contains illegal characters
    inv = {"\\", "/", ":", "*", "?", "<", ">", "|"}
    if s := set(task_name) & inv:
        raise TypeError(f"({task_name})TASK contains invalid characters({s})")

    # Normalized separator
    for separator in (" ", ",", "-"):
        task_name = task_name.replace(separator, "_")

    # Convert bad code
    badcode = {
        "а": "a",
        "е": "e",
        "e": "e",
        "i": "i",
        "і": "i",
        "ο": "o",
        "с": "c",
        "ԁ": "d",
        "ѕ": "s",
        "һ": "h",
        "у": "y",
        "р": "p",
    }
    # str.replace returns a new string (str is immutable), so the result must
    # be assigned back; otherwise this whole step silently does nothing.
    for code, right_code in badcode.items():
        task_name = task_name.replace(code, right_code)

    # Spaces were already turned into "_" above, so this only trims any other
    # leading/trailing whitespace (tabs, newlines).
    task_name = task_name.strip()
    logger.debug(f"Diagnose task | task_name={task_name}")

    return task_name
4 changes: 3 additions & 1 deletion src/components/auto_label/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .cluster import ClusterLabeler
from .cluster import ClusterLabeler

__all__ = ["ClusterLabeler"]
8 changes: 4 additions & 4 deletions src/components/auto_label/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import os
from typing import List
import typing


class BaseLabeler:
def __init__(self, data_dir, num_class: int = 2, labels: List[str] = None) -> None:
def __init__(self, data_dir, num_class: int = 2, labels: typing.List[str] = None):
self.data_dir = data_dir
self.num_class = num_class
if labels:
self.labels = labels
elif num_class == 2:
self.labels = ["yes", "bad"]
# elif num_class == 2:
# self.labels = ["yes", "bad"]
else:
self.labels = [f"class_{i}" for i in range(num_class)]

Expand Down
37 changes: 15 additions & 22 deletions src/components/auto_label/cluster.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from glob import glob
import os
import shutil
import sys
from glob import glob
from typing import List
from loguru import logger

import numpy as np
from PIL import Image
from loguru import logger
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

from components.utils import ToolBox
from .base import BaseLabeler
from .emb import *
from .img2emb import Img2Emb
from components.utils import ToolBox


class ClusterLabeler(BaseLabeler):
Expand All @@ -29,10 +28,7 @@ def __init__(
) -> None:
super().__init__(data_dir, num_class, labels)
self.img2emb = Img2Emb(
model=model,
layer=layer,
layer_output_size=layer_output_size,
save=save,
model=model, layer=layer, layer_output_size=layer_output_size, save=save
)
self.num_feat = num_feat
self._dir_unlabel = os.path.join(self.data_dir, "unlabel")
Expand All @@ -48,34 +44,31 @@ def __init__(
def run(self):
self.images = []
for ext in ToolBox.IMAGE_EXT:
self.images.extend(
glob(os.path.join(self._dir_unlabel, f"**/*.{ext}"), recursive=True)
)
self.images.extend(glob(os.path.join(self._dir_unlabel, f"**/*.{ext}"), recursive=True))
self.images = sorted(self.images)

if len(self.images) == 0:
raise ValueError(f"No images found in {self._dir_unlabel}")

logger.info(f"Found {len(self.images)} images in {self._dir_unlabel}")
logger.info("Extracting embeddings...")
logger.debug("Extracting embeddings...")
for i, img in enumerate(self.images):
img = Image.open(img)
emb = self.img2emb.get_emb(img)
self.embs.append(emb)
if (i + 1) % 100 == 0:
logger.info(f"Extracted {(i+1)} embeddings")
logger.info(f"Extracted {(i + 1)} embeddings")
logger.info("Embeddings extracted")

self.embs = np.array(self.embs)
logger.info("PCA..., shape of embs: {}".format(self.embs.shape))
logger.info(f"PCA..., shape of embs: {self.embs.shape}")
self.embs = PCA(n_components=self.num_feat).fit_transform(self.embs)
logger.info("PCA done, shape of embs: {}".format(self.embs.shape))

logger.info("Clustering...")
logger.info(f"PCA done, shape of embs: {self.embs.shape}")
logger.debug("Clustering...")
kmeans = KMeans(n_clusters=self.num_class).fit(self.embs)
logger.info("Clustering done")

labels_ = np.array(kmeans.labels_)
logger.debug("Clustering done")
labels_ = np.array(kmeans.labels_) # noqa
logger.info("Saving labels...")
for i, label in enumerate(labels_):
label = self.labels[label]
Expand All @@ -85,4 +78,4 @@ def run(self):
if not os.path.exists(label_path):
os.makedirs(label_path)
shutil.copy(img, os.path.join(label_path, img_name))
logger.info("Labels saved")
logger.debug("Labels saved")
6 changes: 2 additions & 4 deletions src/components/auto_label/emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_distance_function(distance):
elif distance == "l1":
return l1_distance
else:
raise ValueError("Unknown distance function: {}".format(distance))
raise ValueError(f"Unknown distance function: {distance}")


def get_distance_matrix(embs, distance="cosine"):
Expand Down Expand Up @@ -88,7 +88,5 @@ def get_sorted_distance_matrix(embs, distance="cosine"):
distance_matrix = get_distance_matrix(embs, distance)
sorted_distance_matrix = {}
for i in range(len(embs)):
sorted_distance_matrix[i] = sorted(
enumerate(distance_matrix[i]), key=lambda x: x[1]
)
sorted_distance_matrix[i] = sorted(enumerate(distance_matrix[i]), key=lambda x: x[1])
return sorted_distance_matrix
Loading

0 comments on commit d9d1cd5

Please sign in to comment.