Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
zzazzz committed Dec 21, 2024
1 parent 0e7b0e0 commit bc68490
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 51 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/ci-cd-pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,34 @@ on:

jobs:

train_model:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11.7'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Log in to WandB
run: |
echo "WANDB_API_KEY=${{ secrets.WANDB_API_KEY }}" >> $GITHUB_ENV
wandb login ${{ secrets.WANDB_API_KEY }}
- name: Run model training
run: |
python model_training.py
validate_model:
needs: train_model
runs-on: ubuntu-latest

steps:
Expand Down
1 change: 1 addition & 0 deletions model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,5 @@ def compute_metrics(pred):

# Save the model
trainer.save_model("model")
image_processor.save_pretrained('model')
wandb.finish()
98 changes: 47 additions & 51 deletions validate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,66 +12,62 @@
model = Swinv2ForImageClassification.from_pretrained(model)
image_processor = AutoImageProcessor.from_pretrained(model)

# Load the dataset
data_dir = "data"
ds = load_dataset("imagefolder", data_dir=data_dir)

# Preprocessing
_transforms = Compose([
Resize((200, 200)),
GaussianBlur(kernel_size=(1, 5)),
RandomAdjustSharpness(sharpness_factor=2),
RandomEqualize(),
ToTensor()
])

def preprocess_test(example):
example["pixel_values"] = _transforms(example["image"].convert("RGB"))
return example

test_ds = ds["test"].map(preprocess_test, remove_columns=["image"])

# Function for predictions
def predict(model, image):
inputs = image_processor(images=image, return_tensors="pt")
# Fungsi untuk memprediksi satu gambar
def predict_single_image(image):
inputs = image_processor(images=image, return_tensors="pt").to(device)
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
return predicted_class_idx

# Evaluate the model
def evaluate_model(test_ds):
true_labels = []
pred_labels = []
# Direktori dataset test
test_dir = 'data/test'

for sample in test_ds:
pixel_values = sample['pixel_values']
true_label = sample['label']

# Get prediction for the image
predicted_label = predict(model, pixel_values)
true_labels = []
pred_labels = []

# Iterasi semua subfolder dalam direktori test
for class_idx, class_name in id2label.items(): # id2label harus didefinisikan sebagai {0: "cardboard", 1: "glass", ...}
class_folder = os.path.join(test_dir, class_name)
if not os.path.isdir(class_folder):
continue # Lewati jika bukan folder

for image_name in os.listdir(class_folder):
image_path = os.path.join(class_folder, image_name)
try:
# Load image
image = Image.open(image_path).convert("RGB")

# Predict label
predicted_label = predict_single_image(image)

# Simpan true dan predicted label
true_labels.append(class_idx)
pred_labels.append(predicted_label)

true_labels.append(true_label)
pred_labels.append(predicted_label)

true_labels = np.array(true_labels)
pred_labels = np.array(pred_labels)
except Exception as e:
print(f"Error processing {image_path}: {e}")

cm = confusion_matrix(true_labels, pred_labels)
report = classification_report(true_labels, pred_labels)
# Konversi ke numpy array
true_labels = np.array(true_labels)
pred_labels = np.array(pred_labels)

return cm, report
# Confusion Matrix
cm = confusion_matrix(true_labels, pred_labels)

# Evaluate and plot the confusion matrix
cm, report = evaluate_model(test_ds)
# Classification Report
report = classification_report(true_labels, pred_labels, target_names=id2label.values())
print("Classification Report:")
print(report)

# Plot confusion matrix
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
# Plot Confusion Matrix
def plot_confusion_matrix(cm, classes):
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

# Print classification report
print("Classification Report:\n", report)
# Plot confusion matrix
plot_confusion_matrix(cm, list(id2label.values()))

0 comments on commit bc68490

Please sign in to comment.