Skip to content

Commit

Permalink
Creates the bucket from PKB rather than from within the VM that PKB la…
Browse files Browse the repository at this point in the history
…unches.

Sets the TPU zone in the MLPerf benchmark.

-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=254907013
  • Loading branch information
tohaowu authored and cwilkes committed Jun 25, 2019
1 parent 1c47a38 commit eb47798
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 20 deletions.
4 changes: 4 additions & 0 deletions CHANGES.next.md
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,7 @@
- RobustRemoteCommand now retries on ssh connections refused.
- Fixed a bug that was causing AWS capacity reservation creation to fail.
- GceNetworkSpec's zone is equal to the VM's zone
- Creates the bucket from PKB rather than from within the VM that PKB launches
in TPU test.
- Sets the TPU zone in the MLPerf benchmark.

4 changes: 4 additions & 0 deletions perfkitbenchmarker/linux_benchmarks/mlperf_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ def Run(benchmark_spec):
'export PYTHONPATH=$PYTHONPATH:$PWD/tpu/models && '
'cd results/v0.5.0/google/{code_path} && '
'sed -i "s/python /python3 /g" run_helper*.sh && '
'sed -i "s/tpu_zone=[^[:space:]]*/tpu_zone={tpu_zone}/g" '
'run_helper*.sh && '
'mkdir -p $MLP_HOST_OUTPUT_DIR && '
'{cmd}'.format(
model_dir=benchmark_spec.model_dir,
Expand All @@ -230,6 +232,8 @@ def Run(benchmark_spec):
tpu_eval=(benchmark_spec.tpu_groups['eval'].GetName()
if benchmark_spec.tpus else ''),
code_path=code_path,
tpu_zone=(benchmark_spec.tpu_groups['train'].GetZone()
if benchmark_spec.tpus else ''),
cmd=cmd))
if cuda_toolkit.CheckNvidiaGpuExists(vm):
mlperf_benchmark_cmd = '{env} {cmd}'.format(
Expand Down
28 changes: 8 additions & 20 deletions perfkitbenchmarker/linux_benchmarks/mnist_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,14 @@ def Prepare(benchmark_spec):
vm.Install('tensorflow_models')
if benchmark_spec.tpus:
storage_service = gcs.GoogleCloudStorageService()
storage_service.PrepareVM(vm)
benchmark_spec.storage_service = storage_service
model_dir = 'gs://{}'.format(FLAGS.run_uri)
benchmark_spec.model_dir = model_dir
vm.RemoteCommand(
'{gsutil} mb -c regional -l {location} {model_dir}'.format(
gsutil=vm.gsutil_path,
location=util.GetRegionFromZone(
benchmark_spec.tpu_groups['train'].GetZone()),
model_dir=benchmark_spec.model_dir), should_log=True)
vm.RemoteCommand(
'{gsutil} acl ch -u {service_account}:W {model_dir}'.format(
gsutil=vm.gsutil_path,
service_account=benchmark_spec.gcp_service_account,
model_dir=benchmark_spec.model_dir), should_log=True)
bucket = 'pkb{}'.format(FLAGS.run_uri)
benchmark_spec.bucket = bucket
benchmark_spec.model_dir = 'gs://{}'.format(bucket)
location = benchmark_spec.tpu_groups['train'].GetZone()
storage_service.PrepareService(util.GetRegionFromZone(location))
storage_service.MakeBucket(bucket)
storage_service.ChmodBucket(benchmark_spec.gcp_service_account, 'W', bucket)
else:
benchmark_spec.model_dir = '/tmp'

Expand Down Expand Up @@ -364,9 +357,4 @@ def Cleanup(benchmark_spec):
required to run the benchmark.
"""
if benchmark_spec.tpus:
vm = benchmark_spec.vms[0]
vm.RemoteCommand(
'{gsutil} rm -r {model_dir}'.format(
gsutil=vm.gsutil_path,
model_dir=benchmark_spec.model_dir), should_log=True)
benchmark_spec.storage_service.CleanupVM(vm)
benchmark_spec.storage_service.DeleteBucket(benchmark_spec.bucket)
13 changes: 13 additions & 0 deletions perfkitbenchmarker/providers/gcp/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ def EmptyBucket(self, bucket):
['gsutil', '-m', 'rm', '-r',
'gs://%s/*' % bucket])

def ChmodBucket(self, account, access, bucket):
  """Grants an account a permission on a bucket via `gsutil acl ch`.

  Args:
    account: string, the user to be granted.
    access: string, the permission to be granted.
    bucket: string, the name of the bucket to change.
  """
  # gsutil expects the grant as a single "<user>:<permission>" argument.
  grant = '{account}:{access}'.format(account=account, access=access)
  bucket_url = 'gs://{}'.format(bucket)
  vm_util.IssueCommand(['gsutil', 'acl', 'ch', '-u', grant, bucket_url])

def PrepareVM(self, vm):
vm.Install('wget')
# Unfortunately there isn't one URL scheme that works for both
Expand Down

0 comments on commit eb47798

Please sign in to comment.