|
| 1 | +import os |
| 2 | +import torch |
| 3 | +import torchvision.transforms as transforms |
| 4 | +from PIL import Image |
| 5 | + |
| 6 | +model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet50', pretrained=True) |
| 7 | +model.eval() |
| 8 | + |
| 9 | + |
| 10 | +def load_and_preprocess_image(image_path): |
| 11 | + image = Image.open(image_path).convert('RGB') |
| 12 | + preprocess = transforms.Compose([ |
| 13 | + transforms.Resize((512, 512)), |
| 14 | + transforms.ToTensor(), |
| 15 | + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), |
| 16 | + ]) |
| 17 | + input_tensor = preprocess(image).unsqueeze(0) |
| 18 | + return input_tensor |
| 19 | + |
| 20 | + |
| 21 | +def remove_background(image_path, save_path, alpha_foreground=255, alpha_background=0): |
| 22 | + input_tensor = load_and_preprocess_image(image_path) |
| 23 | + |
| 24 | + with torch.no_grad(): |
| 25 | + output = model(input_tensor)['out'][0] |
| 26 | + output_predictions = output.argmax(0) |
| 27 | + |
| 28 | + # Convert the prediction to a binary mask (0 for background, 1 for foreground) |
| 29 | + mask = output_predictions.byte() |
| 30 | + |
| 31 | + # Convert the mask tensor to a PIL Image with mode 'L' (8-bit pixels, black and white) |
| 32 | + mask_pil = transforms.ToPILImage()(mask) |
| 33 | + |
| 34 | + # Load the image outside the if block |
| 35 | + image = Image.open(image_path).convert('RGBA') |
| 36 | + |
| 37 | + # Resize the mask to match the dimensions of the image |
| 38 | + mask_pil = mask_pil.resize((image.size[0], image.size[1])) |
| 39 | + |
| 40 | + # Apply the mask to the input image |
| 41 | + image_with_alpha = Image.alpha_composite(Image.new('RGBA', image.size, (255, 255, 255, alpha_background)), image) |
| 42 | + image_with_alpha.putalpha(mask_pil.point(lambda p: alpha_foreground if p else 0)) |
| 43 | + |
| 44 | + # Save the resulting image with transparent background |
| 45 | + image_with_alpha.save(save_path, format='PNG') |
| 46 | + |
| 47 | +if __name__ == "__main__": |
| 48 | + # Replace 'input_folder' with the path to the folder containing your input images |
| 49 | + input_folder = "/config/workspace/project/img" |
| 50 | + |
| 51 | + # Replace 'output_folder' with the path where you want to save the output images |
| 52 | + output_folder = "/config/workspace/project/out_img" |
| 53 | + |
| 54 | + for filename in os.listdir(input_folder): |
| 55 | + if filename.endswith(".jpg") or filename.endswith(".png") or filename.endswith(".avif"): |
| 56 | + image_path = os.path.join(input_folder, filename) |
| 57 | + save_path = os.path.join(output_folder, filename.replace(".jpg", ".png").replace(".png", ".png")) |
| 58 | + remove_background(image_path, save_path) |
0 commit comments