Skip to content

Commit

Permalink
Minor cleanup and bug fixes in train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-jeff committed Mar 6, 2024
1 parent 16dc746 commit 88447a5
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 39 deletions.
10 changes: 4 additions & 6 deletions train/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torch
torch
torchvision
tensorboardX
boto3
Expand All @@ -13,11 +13,12 @@ jmespath
joblib
networkx
numpy
opencv-python-headless
opencv-python-headless==4.6.0.66
packaging
Pillow
protobuf==3.20.1
pyparsing
pyrr
python-dateutil
PyWavelets
PyYAML
Expand All @@ -27,16 +28,13 @@ s3transfer
scikit-image
scikit-learn
scipy
simplejson
six
threadpoolctl
tifffile
typing_extensions
urllib3

pyrr
simplejson

visii

## If running into dependency issues, install the version-specific requirements below in a virtual environment
# albumentations==1.2.1
Expand Down
13 changes: 10 additions & 3 deletions train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@
help="folder to output images and model checkpoints",
)
parser.add_argument("--sigma", default=4, help="keypoint creation sigma")
parser.add_argument("--local_rank", type=int)
parser.add_argument("--local-rank", type=int, default=0)


parser.add_argument("--save", action="store_true", help="save a batch and quit")
parser.add_argument(
Expand Down Expand Up @@ -242,7 +243,7 @@

if opt.pretrained:
if opt.net_path is not None:
net.load_state_dict(torch.load(opt.net))
net.load_state_dict(torch.load(opt.net_path))
else:
print("Error: Did not specify path to pretrained weights.")
quit()
Expand Down Expand Up @@ -374,8 +375,14 @@ def _runnetwork(epoch, train_loader, syn=False):
"loss/train_bel", np.mean(loss_avg_to_log["loss_belief"]), epoch
)

start_epoch = 1
if opt.pretrained and opt.net_path is not None:
# we started with a saved checkpoint, we start numbering
# checkpoints after the loaded one
start_epoch = int(os.path.splitext(os.path.basename(opt.net_path).split('_')[2])[0]) + 1
print(f"Starting at epoch {start_epoch}")

for epoch in range(1, opt.epochs + 1):
for epoch in range(start_epoch, opt.epochs + 1):

_runnetwork(epoch, trainingdata)

Expand Down
32 changes: 2 additions & 30 deletions train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import torch.utils.data as data
import glob
import os
# import boto3
import io

from PIL import Image
Expand Down Expand Up @@ -143,7 +142,7 @@ def loadweights(root):
]

weights.sort()
return weights
return weights


def loadimages_inference(root, extensions):
Expand Down Expand Up @@ -937,35 +936,8 @@ def draw_cube(self, points, color=(0, 255, 0)):

for l in line_order:
self.draw_line(points[l[0]], points[l[1]], color, line_width=2)
# Draw center
self.draw_dot(points[0], point_color=color, point_radius=6)
# # draw front
# self.draw_line(points[0], points[1], color)
# self.draw_line(points[1], points[2], color)
# self.draw_line(points[3], points[2], color)
# self.draw_line(points[3], points[0], color)

# # draw back
# self.draw_line(points[4], points[5], color)
# self.draw_line(points[6], points[5], color)
# self.draw_line(points[6], points[7], color)
# self.draw_line(points[4], points[7], color)

# # draw sides
# self.draw_line(points[0], points[4], color)
# self.draw_line(points[7], points[3], color)
# self.draw_line(points[5], points[1], color)
# self.draw_line(points[2], points[6], color)

# # draw dots
# self.draw_dot(points[0], point_color=color, point_radius=4)
# self.draw_dot(points[1], point_color=color, point_radius=4)

# # draw x on the top
# self.draw_line(points[0], points[5], color)
# self.draw_line(points[1], points[4], color)

# # Draw center
# self.draw_dot(points[8], point_color=color, point_radius=6)

for i in range(9):
self.draw_text(points[i], str(i), (255, 0, 0))
Expand Down

0 comments on commit 88447a5

Please sign in to comment.