diff --git a/insert_bn.py b/insert_bn.py index 6663e95..1562a50 100644 --- a/insert_bn.py +++ b/insert_bn.py @@ -124,6 +124,9 @@ def directly_insert_bn_without_init(model): block.rbr_reparam.dilation, block.rbr_reparam.groups, bias=False)) # Note bias=False convbn.add_module('bn', nn.BatchNorm2d(block.rbr_reparam.out_channels)) + convbn.add_module('relu', nn.ReLU()) + print('conv bn relu') + block.nonlinearity = nn.Identity() #TODO note this block.__delattr__('rbr_reparam') block.rbr_reparam = convbn diff --git a/quant_train.py b/quant_train.py new file mode 100644 index 0000000..cc673bd --- /dev/null +++ b/quant_train.py @@ -0,0 +1,462 @@ +import argparse +import os +import random +import shutil +import time +import warnings + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +from utils import accuracy, AverageMeter, ProgressMeter, log_msg, WarmupCosineAnnealingLR +from noris_dataset import ImageNetNoriDataset +import math +import copy + +best_acc1 = 0 + +IMAGENET_TRAINSET_SIZE = 1281167 + +parser = argparse.ArgumentParser(description='PyTorch Quant') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('-a', '--arch', metavar='ARCH', default='RepVGG-A0') +parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', + help='number of data loading workers (default: 4)') +parser.add_argument('--epochs', default=8, type=int, metavar='N', + help='number of epochs for each run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total 
' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--val-batch-size', default=100, type=int, metavar='V', + help='validation batch size') +parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, + metavar='LR', help='learning rate for finetuning', dest='lr') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') +parser.add_argument('--dist-url', default='tcp://127.0.0.1:23333', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. ') +parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') +parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. 
def sgd_optimizer(model, lr, momentum, weight_decay):
    """Build an SGD optimizer with one parameter group per trainable parameter.

    Args:
        model: module whose ``named_parameters()`` are optimized.
        lr: base learning rate; parameters whose name contains ``'bias'``
            get ``2 * lr``.
        momentum: SGD momentum, shared by every group.
        weight_decay: weight decay applied to parameters with at least two
            dimensions; parameters with fewer dimensions get 0.

    Returns:
        A ``torch.optim.SGD`` instance over the per-parameter groups.
    """
    param_groups = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            # Parameters with fewer than 2 dims are excluded from weight decay.
            below_2d = param.ndimension() < 2
            if below_2d:
                print('set weight decay=0 for {}'.format(name))
            param_groups.append({
                'params': [param],
                # Caffe-style convention: double the learning rate for biases.
                'lr': 2 * lr if 'bias' in name else lr,
                'weight_decay': 0 if below_2d else weight_decay,
            })
    return torch.optim.SGD(param_groups, lr, momentum=momentum)
This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + global best_acc1 + args.gpu = gpu + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + + # 1. Build and load base model + from repvgg import get_RepVGG_func_by_name + repvgg_build_func = get_RepVGG_func_by_name(args.arch) + base_model = repvgg_build_func(deploy=True) + from insert_bn import directly_insert_bn_without_init + directly_insert_bn_without_init(base_model) + if args.base_weights is not None: + assert args.last_weights is None + base_weights = {} + for k, v in torch.load(args.base_weights).items(): + base_weights[k.replace('restore', 'rbr_reparam')] = v #TODO + base_model.load_state_dict(base_weights) + # 2. 
+ from repvgg_quantized import RepVGGQuant + if 'A' in args.arch: + # RepVGG-A has 2, 4, 14 layers in the middle 3 stages. We only split the 14-layer stage + stage_sections = {3: 4} # split stage3 into 4 sections + # stage_sections = {} #TODO + elif 'B' in args.arch: + # RepVGG-B has 4, 6, 16 layers in the middle 3 stages. We split stage2 and stage3 + stage_sections = {2:3, 3:4} # split stage2 into 3 sections and stage3 into 4 sections + else: + raise ValueError('TODO') + + # "1_2_3-0-1-2-3" + # Parse the quant set. For example, "1_2_3-0-1-2-3" + quant_stagesections = [] + ss = args.quant.split('_') + for s in ss: + if len(s) == 1: + stage_idx = int(s) + assert stage_idx < 5 + quant_stagesections.append(stage_idx) + else: + sections = s.split('-') + stage_idx = int(sections[0]) + for i in range(1, len(sections)): + section_idx = int(sections[i]) + quant_stagesections.append((stage_idx, section_idx)) + + qat_model = RepVGGQuant(repvgg_model=base_model, stage_sections=stage_sections, quant_stagesections=quant_stagesections) + + qat_model.prepare_quant() + + if args.last_weights is not None: + assert args.base_weights is None + base_model.load_state_dict(torch.load(args.last_weights)) + + #=================================================== + # From now on, the code will be very similar to ordinary training + # =================================================== + + if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0): + for n, p in qat_model.named_parameters(): + print(n, p.size()) + for n, p in qat_model.named_buffers(): + print(n, p.size()) + # You will see it now has quantization-related parameters (zero-points and scales) + + if not torch.cuda.is_available(): + print('using CPU, this will be slow') + elif args.distributed: + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + qat_model.cuda(args.gpu) + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int((args.workers + 
ngpus_per_node - 1) / ngpus_per_node) + qat_model = torch.nn.parallel.DistributedDataParallel(qat_model, device_ids=[args.gpu]) + else: + qat_model.cuda() + qat_model = torch.nn.parallel.DistributedDataParallel(qat_model) + elif args.gpu is not None: + torch.cuda.set_device(args.gpu) + qat_model = qat_model.cuda(args.gpu) + else: + # DataParallel will divide and allocate batch_size to all available GPUs + qat_model = torch.nn.DataParallel(qat_model).cuda() + + + criterion = nn.CrossEntropyLoss().cuda(args.gpu) + optimizer = sgd_optimizer(qat_model, args.lr, args.momentum, args.weight_decay) + + warmup_epochs = 1 + lr_scheduler = WarmupCosineAnnealingLR(optimizer=optimizer, T_cosine_max=args.epochs * IMAGENET_TRAINSET_SIZE // args.batch_size // ngpus_per_node, + eta_min=0, warmup=warmup_epochs * IMAGENET_TRAINSET_SIZE // args.batch_size // ngpus_per_node) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. 
+ loc = 'cuda:{}'.format(args.gpu) + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint['epoch'] + best_acc1 = checkpoint['best_acc1'] + if args.gpu is not None: + # best_acc1 may be from a checkpoint from a different GPU + best_acc1 = best_acc1.to(args.gpu) + qat_model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['scheduler']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_trans = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ]) + + # train_dataset = datasets.ImageFolder(traindir, transform=train_trans) + train_dataset = ImageNetNoriDataset('/home/dingxiaohan/ndp/imagenet.train.nori.list', train_trans) #TODO + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler) + + + val_trans = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ]) + + # val_dataset = datasets.ImageFolder(valdir, val_trans) + val_dataset = ImageNetNoriDataset('/home/dingxiaohan/ndp/imagenet.val.nori.list', val_trans) #TODO + + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=args.val_batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + if args.evaluate: + 
def train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler):
    """Run one training epoch over ``train_loader``.

    For every batch: forward pass, loss, top-1/top-5 accuracy bookkeeping,
    backward pass, optimizer step, and (when given) one ``lr_scheduler`` step,
    i.e. the scheduler advances per iteration rather than per epoch.

    Args:
        train_loader: iterable yielding ``(images, target)`` batches.
        model: network to train; switched to train mode here.
        criterion: loss callable taking ``(output, target)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``gpu`` and ``print_freq``.
        lr_scheduler: per-iteration scheduler, or ``None`` to skip scheduling.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    meters = [batch_time, data_time, losses, top1, top5]
    progress = ProgressMeter(len(train_loader), meters,
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    tick = time.time()
    for step, (images, target) in enumerate(train_loader):
        # Time spent waiting for the data loader.
        data_time.update(time.time() - tick)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.cuda(args.gpu, non_blocking=True)

        output = model(images)
        loss = criterion(output, target)

        # Record loss and accuracy, weighted by the batch size.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Wall-clock time for the whole iteration, then restart the timer.
        batch_time.update(time.time() - tick)
        tick = time.time()

        scheduled = lr_scheduler is not None
        if scheduled:
            lr_scheduler.step()

        if step % args.print_freq == 0:
            progress.display(step)
        if scheduled and step % 1000 == 0:
            print('cur lr: ', lr_scheduler.get_lr()[0])
class RepVGGQuant(nn.Module):
    """Re-package a RepVGG model so chosen stages/sections can be QAT-prepared.

    The stages of ``repvgg_model`` are re-registered under ``self.body``:

    * ``stage0`` is always registered unchanged.
    * A stage listed in ``stage_sections`` is split into that many sections,
      registered as ``stage{S}_{k}``.
    * Any other stage is registered as ``stage{S}``.

    A stage or section selected for quantization is rebuilt as an
    ``nn.Sequential`` with a ``QuantStub`` named ``quant`` first and a
    ``DeQuantStub`` named ``dequant`` last. The wrapped blocks keep the keys
    ``'0'``, ``'1'``, ... so the stubs do not shift parameter names.

    Args:
        repvgg_model: model exposing ``stage0`` .. ``stage4``, ``gap`` and
            ``linear``.
        stage_sections: mapping ``stage index -> number of sections``; must
            not contain stage 0.
        quant_stagesections: iterable of quantization requests. A
            ``(stage, section)`` pair selects one section. A bare ``int``
            selects a whole stage; for a split stage this means every one of
            its sections.
    """

    def __init__(self,
                 repvgg_model,
                 stage_sections,
                 quant_stagesections):
        super(RepVGGQuant, self).__init__()
        self.body = nn.Sequential()
        assert 0 not in stage_sections
        self.body.add_module('stage0', repvgg_model.stage0)

        # Number of sections of every split stage, needed later to expand a
        # whole-stage request into its sections.
        self._section_counts = {}

        for stage_idx in [1, 2, 3, 4]:
            origin_stage = getattr(repvgg_model, 'stage{}'.format(stage_idx))
            quant_whole_stage = stage_idx in quant_stagesections
            if stage_idx in stage_sections:
                sections = stage_sections[stage_idx]
                self._section_counts[stage_idx] = sections
                origin_blocks = list(origin_stage.children())
                blocks_per_section = math.ceil(len(origin_blocks) / sections)
                for section_idx in range(sections):
                    start = section_idx * blocks_per_section
                    # Slicing clamps at the end of the list, so the last
                    # section may be shorter than the others.
                    cur_section_blocks = origin_blocks[start:start + blocks_per_section]
                    # A bare stage index requests the whole stage, i.e. all
                    # of its sections (previously it was silently ignored and
                    # prepare_quant() then failed on the missing module).
                    do_quant = quant_whole_stage or (stage_idx, section_idx) in quant_stagesections
                    self.body.add_module('stage{}_{}'.format(stage_idx, section_idx),
                                         self._make_section(cur_section_blocks, do_quant))
            elif quant_whole_stage:
                self.body.add_module('stage{}'.format(stage_idx),
                                     self._make_section(origin_stage.children(), True))
            else:
                self.body.add_module('stage{}'.format(stage_idx), origin_stage)

        self.quant_stagesections = quant_stagesections
        print('quant setting: ', self.quant_stagesections)

        self.gap = repvgg_model.gap
        self.linear = repvgg_model.linear

    @staticmethod
    def _make_section(blocks, do_quant):
        """Wrap ``blocks`` in an nn.Sequential, adding quant/dequant stubs if asked."""
        # An OrderedDict (not a list) keeps the block keys '0', '1', ...
        # independent of whether the stubs are present.
        od = OrderedDict()
        if do_quant:
            od['quant'] = QuantStub()
        for i, b in enumerate(blocks):
            od[str(i)] = b
        if do_quant:
            od['dequant'] = DeQuantStub()
        return nn.Sequential(od)

    def _quant_targets(self):
        """Expand the quant requests into unique ``(stage, section_or_None)`` pairs."""
        targets = []
        for q in self.quant_stagesections:
            if isinstance(q, int):
                count = self._section_counts.get(q)
                if count is None:
                    expanded = [(q, None)]
                else:
                    expanded = [(q, section_idx) for section_idx in range(count)]
            else:
                expanded = [(q[0], q[1])]
            for target in expanded:
                # Skip duplicates so no module is prepared twice.
                if target not in targets:
                    targets.append(target)
        return targets

    def forward(self, x):
        """Run body -> gap -> flatten to (batch, -1) -> linear."""
        out = self.body(x)
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    # From https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
    def fuse_model(self):
        """Fuse conv/bn/relu in every plain nn.Sequential that has a ``conv`` child."""
        # Snapshot the module list because fusing replaces children in place.
        for m in list(self.modules()):
            if type(m) is nn.Sequential and hasattr(m, 'conv'):
                torch.quantization.fuse_modules(m, ['conv', 'bn', 'relu'], inplace=True)

    def prepare_quant(self):
        """Fuse the model, then attach the default fbgemm QAT qconfig to each target.

        Raises:
            AttributeError: if a requested stage/section is not a child of
                ``self.body``.
        """
        # From https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
        self.fuse_model()
        qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        for stage_idx, section_idx in self._quant_targets():
            if section_idx is None:
                quant_stage_or_section = getattr(self.body, 'stage{}'.format(stage_idx))
                print('prepared quant for stage', stage_idx)
            else:
                quant_stage_or_section = getattr(self.body, 'stage{}_{}'.format(stage_idx, section_idx))
                print('prepared quant for stage', stage_idx, 'section', section_idx)
            quant_stage_or_section.qconfig = qconfig
            torch.quantization.prepare_qat(quant_stage_or_section, inplace=True)
class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine annealing, evaluated in closed form.

    ``last_epoch`` counts calls to ``step()``, so the schedule is expressed in
    whatever unit the caller steps in (iterations or epochs).

    Args:
        optimizer: wrapped optimizer.
        T_cosine_max: step index at which the cosine phase reaches
            ``eta_min``. This is an absolute index that includes the warmup
            steps.
        eta_min: learning rate at the end of annealing.
        last_epoch: index of the last step, ``-1`` to start fresh.
        warmup: number of warmup steps, during which the learning rate grows
            linearly from 0 towards each group's base learning rate.
    """

    def __init__(self, optimizer, T_cosine_max, eta_min=0, last_epoch=-1, warmup=0):
        # These must be set before the base constructor runs, because it
        # performs an initial step() that calls get_lr().
        self.eta_min = eta_min
        self.T_cosine_max = T_cosine_max
        self.warmup = warmup
        super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``last_epoch``."""
        if self.last_epoch < self.warmup:
            return [self.last_epoch / self.warmup * base_lr for base_lr in self.base_lrs]

        cosine_steps = self.T_cosine_max - self.warmup
        if cosine_steps <= 0:
            # Degenerate schedule with no cosine phase: go straight to
            # eta_min instead of dividing by zero.
            progress = 1.0
        else:
            # Clamp so that stepping past T_cosine_max holds eta_min rather
            # than riding the cosine back up.
            progress = min(1.0, (self.last_epoch - self.warmup) / cosine_steps)
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * progress)) / 2
                for base_lr in self.base_lrs]


def log_msg(message, log_file):
    """Print ``message`` to stdout and append it as one line to ``log_file``."""
    print(message)
    with open(log_file, 'a') as f:
        print(message, file=f)