nvidia_deeprecommender: increase gpu utilization (pytorch#459)
Grow the parameter and batch sizes so GPU utilization matches the original Netflix dataset.
eircfb authored Sep 16, 2021
1 parent 8c8deab commit 35e2828
Showing 3 changed files with 26 additions and 23 deletions.
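
For scale, here is a minimal sketch (illustration only, not part of the commit) comparing the old toy input, a single 15178-wide row, with the new 256 x 197951 batch sized to the Netflix movie count:

```python
import torch

# Old toy input: a single row over a 15178-item vocabulary.
old = torch.randn(1, 15178)
# New toy input: batch of 256 over the full Netflix movie count (197951).
new = torch.randn(256, 197951)

for name, t in (("old", old), ("new", new)):
    mib = t.numel() * t.element_size() / 2**20
    print(f"{name}: shape={tuple(t.shape)}, {mib:.1f} MiB")
```

The much larger batch (roughly 190 MiB in float32) is what gives the GPU enough work per iteration, per the commit message.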
6 changes: 2 additions & 4 deletions torchbenchmark/models/nvidia_deeprecommender/__init__.py
@@ -41,12 +41,10 @@ def __init__(self, device="cpu", jit=False):
self.infer_obj = DeepRecommenderInferenceBenchmark(device = self.device, jit = jit)

def get_module(self):
# FIXME: This is not implemented.
raise NotImplementedError("deeprecommender is work in progress")
if self.eval_mode:
return lambda x: self.eval(), [0]
return self.infer_obj.rencoder, (self.infer_obj.toyinputs,)

return lambda x: self.train(), [0]
return self.train_obj.rencoder, (self.train_obj.toyinputs,)

def set_eval(self):
self.eval_mode = True
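
With the NotImplementedError gone, get_module() returns the real autoencoder plus its example inputs. A minimal sketch of how a harness might consume that pair, assuming the usual torchbenchmark layout where this __init__.py defines a class named Model:

```python
import torch
from torchbenchmark.models.nvidia_deeprecommender import Model  # assumed class name

m = Model(device="cpu")   # device="cuda" would exercise the GPU path
m.set_eval()
module, example_inputs = m.get_module()   # (rencoder, (toyinputs,))
with torch.no_grad():
    out = module(*example_inputs)
print(out.shape)
```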
19 changes: 10 additions & 9 deletions torchbenchmark/models/nvidia_deeprecommender/nvinfer.py
@@ -119,8 +119,13 @@ def __init__(self, device = 'cpu', jit=False, usecommandlineargs = False) :

self.toytest = True

self.batch_size = 256

# number of movies in netflix training set.
self.node_count = 197951

if self.toytest:
self.toyinputs = torch.randn(1,15178).to(device)
self.toyinputs = torch.randn(self.batch_size,self.node_count).to(device)

if usecommandlineargs:
self.args = getCommandLineArgs()
@@ -135,10 +140,6 @@ def __init__(self, device = 'cpu', jit=False, usecommandlineargs = False) :

self.args = getBenchmarkArgs(forcecuda)

if jit == True:
# jit not supported, quit init
return

args = processArgState(self.args)

self.params = dict()
@@ -166,7 +167,7 @@ def __init__(self, device = 'cpu', jit=False, usecommandlineargs = False) :
self.eval_params['data_dir'] = self.args.path_to_eval_data

if self.toytest:
self.rencoder = model.AutoEncoder(layer_sizes=[15178] + [int(l) for l in self.args.hidden_layers.split(',')],
self.rencoder = model.AutoEncoder(layer_sizes=[self.node_count] + [int(l) for l in self.args.hidden_layers.split(',')],
nl_type=self.args.non_linearity_type,
is_constrained=self.args.constrained,
dp_drop_prob=self.args.drop_prob,
@@ -195,11 +196,11 @@ def __init__(self, device = 'cpu', jit=False, usecommandlineargs = False) :
print('######################################################')

self.rencoder.eval()

if self.args.jit:
self.rencoder = torch.jit.trace(self.rencoder, (self.toyinputs, ))

if self.args.use_cuda: self.rencoder = self.rencoder.cuda()

if self.args.jit:
self.rencoder = torch.jit.script(self.rencoder)

if self.toytest == False:
self.inv_userIdMap = {v: k for k, v in self.data_layer.userIdMap.items()}
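
The inference benchmark now sizes its toy input at the full Netflix vocabulary (197951 movies, batch 256) and, as shown above, can JIT the model via torch.jit.trace on that input. A minimal, self-contained sketch of the same pattern, using a stand-in two-layer autoencoder rather than the benchmark's model.AutoEncoder:

```python
import torch
import torch.nn as nn

node_count, batch_size = 197951, 256        # values from the diff
toyinputs = torch.randn(batch_size, node_count)

# Stand-in for model.AutoEncoder: one hidden layer, SELU activation.
rencoder = nn.Sequential(
    nn.Linear(node_count, 512),
    nn.SELU(),
    nn.Linear(512, node_count),
)
rencoder.eval()

# Trace against the full-size toy batch so the graph sees benchmark shapes.
traced = torch.jit.trace(rencoder, (toyinputs,))

if torch.cuda.is_available():
    traced = traced.cuda()
    toyinputs = toyinputs.cuda()

with torch.no_grad():
    out = traced(toyinputs)
```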
24 changes: 14 additions & 10 deletions torchbenchmark/models/nvidia_deeprecommender/nvtrain.py
@@ -116,7 +116,7 @@ def processTrainArgState(args) :
print(args)

if args.forcecpu and args.forcecuda:
print("Error, force cpu and cuda cannot bother be set")
print("Error, force cpu and cuda cannot both be set")
quit()

args.use_cuda = torch.cuda.is_available() # global flag
@@ -187,7 +187,12 @@ def TrainInit(self, device="cpu", jit=False, processCommandLine = False):

# Force test to run in toy mode. Single call of fake data to model.
self.toytest = True
self.toybatch = 15178
self.toybatch = 256

# number of movies in netflix training set.
self.toyvocab = 197951

self.toyinputs = torch.randn(self.toybatch, self.toyvocab)

if (processCommandLine) :
self.args = getTrainCommandLineArgs()
@@ -202,10 +207,6 @@ def TrainInit(self, device="cpu", jit=False, processCommandLine = False):
# unknown device string, quit init
return

# jit not supported, quit here if jit is requested
if jit == True:
return

self.args.forcecuda = forcecuda
self.args.forcecpu = not forcecuda

@@ -236,7 +237,7 @@ def TrainInit(self, device="cpu", jit=False, processCommandLine = False):

# must set eval batch size to 1 to make sure no examples are missed
if self.toytest:
self.rencoder = model.AutoEncoder(layer_sizes=[15178] + [int(l) for l in self.args.hidden_layers.split(',')],
self.rencoder = model.AutoEncoder(layer_sizes=[self.toyvocab] + [int(l) for l in self.args.hidden_layers.split(',')],
nl_type=self.args.non_linearity_type,
is_constrained=self.args.constrained,
dp_drop_prob=self.args.drop_prob,
@@ -267,6 +268,9 @@ def TrainInit(self, device="cpu", jit=False, processCommandLine = False):
print(self.rencoder)
print('######################################################')
print('######################################################')

if jit:
self.rencoder = torch.jit.trace(self.rencoder, (self.toyinputs,))

if self.args.use_cuda:
gpu_ids = [int(g) for g in self.args.gpu_ids.split(',')]
@@ -276,7 +280,10 @@ def TrainInit(self, device="cpu", jit=False, processCommandLine = False):
if len(gpu_ids)>1:
self.rencoder = nn.DataParallel(self.rencoder,
device_ids=gpu_ids)

self.rencoder = self.rencoder.cuda()
self.toyinputs = self.toyinputs.to(device)


if self.args.optimizer == "adam":
self.optimizer = optim.Adam(self.rencoder.parameters(),
@@ -307,9 +314,6 @@ def TrainInit(self, device="cpu", jit=False, processCommandLine = False):
if self.args.noise_prob > 0.0:
self.dp = nn.Dropout(p=self.args.noise_prob)

if self.toytest:
self.toyinputs = torch.randn(128,self.toybatch).to(device)

def DoTrain(self):

self.rencoder.train()
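
On the training side, toyinputs likewise becomes a 256 x 197951 random batch, created up front and moved to the device alongside the model. A minimal sketch of a single training step at that size, with a stand-in network, an illustrative learning rate, and plain MSE loss in place of the benchmark's model.AutoEncoder and its loss:

```python
import torch
import torch.nn as nn
import torch.optim as optim

toyvocab, toybatch = 197951, 256            # values from the diff
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in autoencoder; the benchmark builds model.AutoEncoder from args.
rencoder = nn.Sequential(
    nn.Linear(toyvocab, 512),
    nn.SELU(),
    nn.Linear(512, toyvocab),
).to(device)

optimizer = optim.Adam(rencoder.parameters(), lr=1e-3)  # lr is illustrative
toyinputs = torch.randn(toybatch, toyvocab, device=device)

rencoder.train()
optimizer.zero_grad()
loss = nn.functional.mse_loss(rencoder(toyinputs), toyinputs)
loss.backward()
optimizer.step()
print(f"toy step done, loss={loss.item():.4f}")
```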
