diff --git a/pytorch_classification/Test5_resnet/train.py b/pytorch_classification/Test5_resnet/train.py index 3efe897af..d3e9923df 100644 --- a/pytorch_classification/Test5_resnet/train.py +++ b/pytorch_classification/Test5_resnet/train.py @@ -21,7 +21,7 @@ def main(): transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])} - + data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root path image_path = data_root + "/data_set/flower_data/" # flower data set path