dingxiaohan committed Jun 7, 2021
1 parent 02d8dbd commit c8a830f
Showing 5 changed files with 549 additions and 37 deletions.
81 changes: 56 additions & 25 deletions insert_bn.py
@@ -14,6 +14,7 @@
 from repvgg import get_RepVGG_func_by_name, RepVGGBlock
 import PIL
 from utils import load_checkpoint
+from noris_dataset import ImageNetNoriDataset


 # Get the mean and std on every conv3x3 (before the bias-adding) on the train set. Then use such data to initialize BN layers and insert them after conv3x3.
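For reference, the arithmetic behind the comment above: if a bias-free conv3x3 output has per-channel statistics (mean, var), then a BatchNorm2d whose running stats are set to those values, with gamma = sqrt(var + eps) and beta = mean + bias, reproduces the original conv-plus-bias output in eval mode. A minimal sketch of that initialization, using a hypothetical helper name (bn_from_stats); the repository's switch_bnstat_to_convbn may differ in its details:

import torch
import torch.nn as nn

def bn_from_stats(running_mean, running_var, conv_bias, eps=1e-5):
    # Hypothetical helper (not repo code): build a BatchNorm2d that initially
    # computes x + conv_bias for inputs matching the observed statistics.
    c = running_mean.numel()
    bn = nn.BatchNorm2d(c, eps=eps)
    with torch.no_grad():
        bn.running_mean.copy_(running_mean.reshape(c))
        bn.running_var.copy_(running_var.reshape(c))
        bn.weight.copy_(torch.sqrt(running_var.reshape(c) + eps))      # gamma
        bn.bias.copy_(running_mean.reshape(c) + conv_bias.reshape(c))  # beta
    return bn

# In eval mode: bn(x) = gamma * (x - mean) / sqrt(var + eps) + beta = x + conv_bias,
# so inserting the BN after the bias-free conv leaves the network output unchanged
# while giving subsequent training a sensibly initialized BN layer.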
@@ -37,11 +38,15 @@
                     help='resolution (default: 224) for test')


-def update_running_mean_var(x, running_mean, running_var, momentum=0.9):
+def update_running_mean_var(x, running_mean, running_var, momentum=0.9, is_first_batch=False):
     mean = x.mean(dim=(0, 2, 3), keepdim=True)
     var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
-    running_mean = momentum * running_mean + (1.0 - momentum) * mean
-    running_var = momentum * running_var + (1.0 - momentum) * var
+    if is_first_batch:
+        running_mean = mean
+        running_var = var
+    else:
+        running_mean = momentum * running_mean + (1.0 - momentum) * mean
+        running_var = momentum * running_var + (1.0 - momentum) * var
     return running_mean, running_var
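A toy illustration (not from the commit) of why the first-batch seeding matters: running_var now starts at zero, and with momentum = 0.9 a plain exponential moving average over the few batches the script averages (args.num_batches below) stays biased low, at roughly (1 - 0.9**N) of the true value after N batches; copying the first batch's statistics removes that bias.

import torch

def ema(values, momentum=0.9, seed_first=False):
    # Toy version of update_running_mean_var's update rule, started from zero.
    running = torch.zeros(())
    for i, v in enumerate(values):
        if seed_first and i == 0:
            running = v
        else:
            running = momentum * running + (1.0 - momentum) * v
    return running

batches = [torch.tensor(2.0)] * 10      # pretend every batch variance is 2.0
print(ema(batches))                     # ~1.30: only about 65% of the true value
print(ema(batches, seed_first=True))    # 2.0: unbiased thanks to the seeding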

 # Record the mean and std like a BN layer but do no normalization
@@ -50,12 +55,15 @@ def __init__(self, num_features):
         super(BNStatistics, self).__init__()
         shape = (1, num_features, 1, 1)
         self.register_buffer('running_mean', torch.zeros(shape))
-        self.register_buffer('running_var', torch.ones(shape))
+        self.register_buffer('running_var', torch.zeros(shape))
+        self.is_first_batch = True

     def forward(self, x):
         if self.running_mean.device != x.device:
             self.running_mean = self.running_mean.to(x.device)
             self.running_var = self.running_var.to(x.device)
-        self.running_mean, self.running_var = update_running_mean_var(x, self.running_mean, self.running_var, momentum=0.9)
+        self.running_mean, self.running_var = update_running_mean_var(x, self.running_mean, self.running_var, momentum=0.9, is_first_batch=self.is_first_batch)
+        self.is_first_batch = False
         return x
+
 # This is designed to insert BNStat layer between Conv2d(without bias) and its bias
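The wrapper class that performs this insertion is collapsed in this view. A hypothetical sketch of the idea only, with an assumed class name (ConvStatBias) rather than the repository's actual implementation: strip the bias off the deploy-mode conv, let BNStatistics observe the pre-bias activations, then add the bias back.

import torch.nn as nn

class ConvStatBias(nn.Module):
    # Hypothetical illustration: wraps an existing Conv2d that carries a bias.
    def __init__(self, conv: nn.Conv2d):
        super().__init__()
        assert conv.bias is not None
        self.register_buffer('conv_bias', conv.bias.detach().clone().reshape(1, -1, 1, 1))
        conv.bias = None                                 # the conv now runs bias-free
        self.conv = conv
        self.stat = BNStatistics(conv.out_channels)      # the class shown above

    def forward(self, x):
        return self.stat(self.conv(x)) + self.conv_bias  # stats see the pre-bias output

After the statistics pass, each recorded (running_mean, running_var) pair can be turned into a real BN layer as sketched near the top of this diff.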
@@ -136,14 +144,7 @@ def insert_bn():

     repvgg_build_func = get_RepVGG_func_by_name(args.arch)

-    model = repvgg_build_func(deploy=True)
-
-    if not torch.cuda.is_available():
-        print('using CPU, this will be slow')
-        use_gpu = False
-    else:
-        model = model.cuda()
-        use_gpu = True
+    model = repvgg_build_func(deploy=True).cuda()

     load_checkpoint(model, args.weights)

@@ -154,9 +155,14 @@
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
     if args.resolution == 224:
+        # trans = transforms.Compose([
+        #     transforms.Resize(256),
+        #     transforms.CenterCrop(224),
+        #     transforms.ToTensor(),
+        #     normalize])
         trans = transforms.Compose([
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             normalize])
     else:
@@ -165,34 +171,59 @@
             transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
             normalize])
+    print('data aug: ', trans)
+    if os.path.exists('/home/dingxiaohan/ndp/imagenet.train.nori.list'):
+        train_dataset = ImageNetNoriDataset('/home/dingxiaohan/ndp/imagenet.train.nori.list', trans)
+    else:
+        traindir = os.path.join(args.data, 'train')
+        train_dataset = datasets.ImageFolder(traindir, trans)

-    traindir = os.path.join(args.data, 'train')
-    train_dataset = datasets.ImageFolder(traindir, trans)
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
         batch_size=args.batch_size, shuffle=False,
         num_workers=args.workers, pin_memory=True)

     batch_time = AverageMeter('Time', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')

     progress = ProgressMeter(
-        min(len(train_loader), args.num_batches),   # If num_batches > the total num of batches in the train set, we just use the whole set
-        [batch_time],
-        prefix='BN Stats: ')
+        min(len(train_loader), args.num_batches),
+        [batch_time, losses, top1, top5],
+        prefix='BN stat: ')

+    criterion = nn.CrossEntropyLoss().cuda()
+
     with torch.no_grad():
         end = time.time()
         for i, (images, target) in enumerate(train_loader):
-            if i == args.num_batches:
+            if i >= args.num_batches:
                 break
-            if use_gpu:
-                images = images.cuda(non_blocking=True)
-            # Just forward
-            model(images)
+            images = images.cuda(non_blocking=True)
+            target = target.cuda(non_blocking=True)
+
+            # compute output
+            output = model(images)
+            loss = criterion(output, target)
+
+            # measure accuracy and record loss
+            acc1, acc5 = accuracy(output, target, topk=(1, 5))
+            losses.update(loss.item(), images.size(0))
+            top1.update(acc1[0], images.size(0))
+            top5.update(acc5[0], images.size(0))
+
+            # measure elapsed time
             batch_time.update(time.time() - end)
             end = time.time()
+
             if i % 10 == 0:
                 progress.display(i)
+
+
+        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
+              .format(top1=top1, top5=top5))
+
     switch_bnstat_to_convbn(model)

     torch.save(model.state_dict(), args.save)
