Use pydantic models for entities (mlcommons#366)
* Implement benchmark pydantic model

* Implement CubeModel

* Make benchmark description optional

* Implement DatasetModel

* Remove unused import

* Implement ResultModel

* Progress on Benchmark Pydantic Model

* Revert "Progress on Benchmark Pydantic Model"

This reverts commit f8513af.

* Use model validation

* Allow passing entity models at instantiation

* Make description optional

* Use benchmark model for testing

* Implement Benchmark Entity as pydantic model

* Implement Cube entity as a pydantic model

* Fix Cube pydantic implementation

* Fix benchmark-specific tests

* Implement Dataset entity as pydantic model

* Add dataset status validator

* Remove DatasetModel

* Implement Result entity as pydantic model

* Fix MockCube

* Fix benchmark submission

* Change to schema. Filter private fields

* Fix Dataset commands

* Fix MLCube commands

* Fix result commands

* Fix tests after merge

* Remove unused mocks

* Fix tests

* Fix entity dicts returning None

* Fix list commands

* Update old entity parameters

* Update compatibility test mocks

* Update tests to use mocks

* Update tests to specific parameters

* Fix linter issues

* Fix typo

* Remove redundant result schema

* Remove outdated init docstring args

* Add DeployableSchema

* empty commit for cloudbuild

* remove extra code introduced by merge main

---------

Co-authored-by: hasan7n <[email protected]>
aristizabal95 and hasan7n authored Feb 8, 2023
1 parent f798481 commit c1f01c0
Showing 39 changed files with 538 additions and 1,327 deletions.
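
The common thread in the diffs below is that each entity (Benchmark, Cube, Dataset, Result) is now declared as a pydantic model, so field parsing and validation happen when the object is constructed instead of in ad-hoc is_valid()/todict() helpers. As a rough orientation, a minimal sketch of what the reworked Benchmark entity might look like follows; the base class, defaults, and validator are assumptions inferred from the field names visible in the diffs, not the actual medperf implementation.

from typing import Optional
from pydantic import BaseModel, validator


class Benchmark(BaseModel):
    # Server-assigned fields stay optional until the benchmark is uploaded.
    id: Optional[int] = None
    name: str
    description: Optional[str] = None  # "Make benchmark description optional"
    docs_url: str = ""
    demo_dataset_tarball_url: str
    demo_dataset_tarball_hash: str = ""
    demo_dataset_generated_uid: Optional[str] = None
    data_preparation_mlcube: int
    reference_model_mlcube: int
    data_evaluator_mlcube: int
    metadata: dict = {}
    state: str = "DEVELOPMENT"
    approval_status: str = "PENDING"

    @validator("name")
    def name_within_length(cls, v):
        # Mirrors the manual length check removed from SubmitBenchmark.is_valid()
        assert 0 < len(v) < 20, "Name is invalid"
        return v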
6 changes: 3 additions & 3 deletions cli/medperf/commands/benchmark/benchmark.py
@@ -53,11 +53,11 @@ def submit(
"name": name,
"description": description,
"docs_url": docs_url,
"demo_url": demo_url,
"demo_hash": demo_hash,
"demo_dataset_tarball_url": demo_url,
"demo_dataset_tarball_hash": demo_hash,
"data_preparation_mlcube": data_preparation_mlcube,
"reference_model_mlcube": reference_model_mlcube,
"evaluator_mlcube": evaluator_mlcube,
"data_evaluator_mlcube": evaluator_mlcube,
}
SubmitBenchmark.run(benchmark_info)
cleanup()
4 changes: 2 additions & 2 deletions cli/medperf/commands/benchmark/list.py
@@ -20,11 +20,11 @@ def run(local: bool = False, mine: bool = False):

data = [
[
bmark.uid if bmark.uid is not None else bmark.generated_uid,
bmark.id if bmark.id is not None else bmark.generated_uid,
bmark.name,
bmark.description,
bmark.state,
bmark.approval_status.value,
bmark.approval_status,
]
for bmark in benchmarks
]
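
Dropping the .value access above suggests the approval status is now stored as a plain value on the model. One way pydantic allows this is sketched below; whether the actual entities use use_enum_values or simply declare the field as a string is an assumption, and the ApprovableSchema name is illustrative.

from enum import Enum
from pydantic import BaseModel


class Status(str, Enum):
    PENDING = "PENDING"
    APPROVED = "APPROVED"
    REJECTED = "REJECTED"


class ApprovableSchema(BaseModel):
    approval_status: Status = Status.PENDING

    class Config:
        # Store the enum's value on the instance, so callers can render
        # entity.approval_status directly instead of entity.approval_status.value
        use_enum_values = True


print(ApprovableSchema().approval_status)  # -> "PENDING"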
127 changes: 19 additions & 108 deletions cli/medperf/commands/benchmark/submit.py
@@ -1,14 +1,12 @@
import os
import shutil
import logging
from medperf.enums import Status
import validators

import medperf.config as config
from medperf.entities.benchmark import Benchmark
from medperf.utils import get_file_sha1, generate_tmp_uid, storage_path
from medperf.commands.compatibility_test import CompatibilityTestExecution
from medperf.exceptions import InvalidArgumentError, InvalidEntityError
from medperf.exceptions import InvalidEntityError


class SubmitBenchmark:
@@ -29,8 +27,6 @@ def run(cls, benchmark_info: dict, force_test: bool = True):
"""
ui = config.ui
submission = cls(benchmark_info, force_test)
if not submission.is_valid():
raise InvalidArgumentError("Invalid benchmark information")

with ui.interactive():
ui.text = "Getting additional information"
@@ -45,97 +41,38 @@ def run(cls, benchmark_info: dict, force_test: bool = True):
def __init__(self, benchmark_info: dict, force_test: bool = True):
self.comms = config.comms
self.ui = config.ui
self.name = benchmark_info["name"]
self.description = benchmark_info["description"]
self.docs_url = benchmark_info["docs_url"]
self.demo_url = benchmark_info["demo_url"]
self.demo_hash = benchmark_info["demo_hash"]
self.demo_uid = None
self.data_preparation_mlcube = benchmark_info["data_preparation_mlcube"]
self.reference_model_mlcube = benchmark_info["reference_model_mlcube"]
self.data_evaluator_mlcube = benchmark_info["evaluator_mlcube"]
self.results = None
self.bmk = Benchmark(**benchmark_info)
self.force_test = force_test

def is_valid(self) -> bool:
"""Validates that user-provided benchmark information is correct
Returns:
bool: Wether or not the benchmark information is valid
"""
name_valid_length = 0 < len(self.name) < 20
desc_valid_length = len(self.description) < 100
docs_url_valid = self.docs_url == "" or validators.url(self.docs_url)
demo_url_valid = self.demo_url == "" or validators.url(self.demo_url)
demo_hash_if_no_url = self.demo_url or self.demo_hash
prep_uid_valid = self.data_preparation_mlcube.isdigit()
model_uid_valid = self.reference_model_mlcube.isdigit()
eval_uid_valid = self.data_evaluator_mlcube.isdigit()

valid_tests = [
("name", name_valid_length, "Name is invalid"),
("description", desc_valid_length, "Description is too long"),
("docs_url", docs_url_valid, "Documentation URL is invalid"),
("demo_url", demo_url_valid, "Demo Dataset Tarball URL is invalid"),
(
"demo_hash",
demo_hash_if_no_url,
"Demo Datset Hash must be provided if no URL passed",
),
(
"data_preparation_mlcube",
prep_uid_valid,
"Data Preparation MLCube UID is invalid",
),
(
"reference_model_mlcube",
model_uid_valid,
"Reference Model MLCube is invalid",
),
(
"data_evaluator_mlcube",
eval_uid_valid,
"Data Evaluator MLCube is invalid",
),
]

valid = True
for attr, test, error_msg in valid_tests:
if not test:
valid = False
self.ui.print_error(error_msg)

return valid

def get_extra_information(self):
"""Retrieves information that must be populated automatically,
like hash, generated uid and test results
"""
tmp_uid = self.demo_hash if self.demo_hash else generate_tmp_uid()
demo_dset_path = self.comms.get_benchmark_demo_dataset(self.demo_url, tmp_uid)
bmk_demo_url = self.bmk.demo_dataset_tarball_url
bmk_demo_hash = self.bmk.demo_dataset_tarball_hash
tmp_uid = bmk_demo_hash if bmk_demo_hash else generate_tmp_uid()
demo_dset_path = self.comms.get_benchmark_demo_dataset(bmk_demo_url, tmp_uid)
demo_hash = get_file_sha1(demo_dset_path)
if self.demo_hash and demo_hash != self.demo_hash:
logging.error(
f"Demo dataset hash mismatch: {demo_hash} != {self.demo_hash}"
)
if bmk_demo_hash and demo_hash != bmk_demo_hash:
logging.error(f"Demo dataset hash mismatch: {demo_hash} != {bmk_demo_hash}")
raise InvalidEntityError(
"Demo dataset hash does not match the provided hash"
)
self.demo_hash = demo_hash
self.bmk.demo_dataset_tarball_hash = demo_hash
demo_uid, results = self.run_compatibility_test()
self.demo_uid = demo_uid
self.results = results
self.bmk.demo_dataset_generated_uid = demo_uid
self.bmk.metadata = {"results": results}

def run_compatibility_test(self):
"""Runs a compatibility test to ensure elements are compatible,
and to extract additional information required for submission
"""
self.ui.print("Running compatibility test")
data_prep = self.data_preparation_mlcube
model = self.reference_model_mlcube
evaluator = self.data_evaluator_mlcube
demo_url = self.demo_url
demo_hash = self.demo_hash
data_prep = self.bmk.data_preparation_mlcube
model = self.bmk.reference_model_mlcube
evaluator = self.bmk.data_evaluator_mlcube
demo_url = self.bmk.demo_dataset_tarball_url
demo_hash = self.bmk.demo_dataset_tarball_hash
benchmark = Benchmark.tmp(data_prep, model, evaluator, demo_url, demo_hash)
_, data_uid, _, results = CompatibilityTestExecution.run(
benchmark.generated_uid, force_test=self.force_test
@@ -147,34 +84,8 @@ def run_compatibility_test(self):

return data_uid, results

def todict(self):
return {
"name": self.name,
"description": self.description,
"docs_url": self.docs_url,
"demo_dataset_tarball_url": self.demo_url,
"demo_dataset_tarball_hash": self.demo_hash,
"demo_dataset_generated_uid": self.demo_uid,
"data_preparation_mlcube": int(self.data_preparation_mlcube),
"reference_model_mlcube": int(self.reference_model_mlcube),
"data_evaluator_mlcube": int(self.data_evaluator_mlcube),
"state": "OPERATION",
"is_valid": True,
"approval_status": Status.PENDING.value,
"metadata": {"results": self.results.results},
"id": None,
"models": [int(self.reference_model_mlcube)], # not in the server (OK)
"created_at": None,
"modified_at": None,
"approved_at": None,
"owner": None,
"is_active": True,
"user_metadata": {},
}

def submit(self):
body = self.todict()
updated_body = Benchmark(body).upload()
updated_body = self.bmk.upload()
return updated_body

def to_permanent_path(self, bmk_dict: dict):
@@ -183,7 +94,7 @@ def to_permanent_path(self, bmk_dict: dict):
Args:
bmk_dict (dict): dictionary containing updated information of the submitted benchmark
"""
bmk = Benchmark(bmk_dict)
bmk = Benchmark(**bmk_dict)
bmks_storage = storage_path(config.benchmarks_storage)
old_bmk_loc = os.path.join(bmks_storage, bmk.generated_uid)
new_bmk_loc = bmk.path
@@ -192,5 +103,5 @@
os.rename(old_bmk_loc, new_bmk_loc)

def write(self, updated_body):
bmk = Benchmark(updated_body)
bmk = Benchmark(**updated_body)
bmk.write()
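
With is_valid() and todict() gone, the field checks run when the Benchmark is built from the user-provided dict, and the entity itself is what gets uploaded. A minimal usage sketch under that assumption, reusing the hypothetical Benchmark model sketched earlier (the example values and error handling are illustrative, not taken from medperf):

from pydantic import ValidationError

benchmark_info = {
    "name": "a-benchmark-name-well-over-twenty-characters",  # fails the length validator
    "description": None,
    "docs_url": "",
    "demo_dataset_tarball_url": "https://example.com/demo.tar.gz",
    "demo_dataset_tarball_hash": "",
    "data_preparation_mlcube": "1",  # pydantic coerces numeric strings to int
    "reference_model_mlcube": "2",
    "data_evaluator_mlcube": "3",
}

try:
    bmk = Benchmark(**benchmark_info)
except ValidationError as err:
    # Each failing field is reported individually, replacing the old
    # per-field print_error() loop in SubmitBenchmark.is_valid().
    print(err)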
24 changes: 15 additions & 9 deletions cli/medperf/commands/compatibility_test.py
@@ -100,11 +100,11 @@ def prepare_test(self):
"""
if self.benchmark_uid:
self.benchmark = Benchmark.get(self.benchmark_uid)
self.set_cube_uid("data_prep", self.benchmark.data_preparation)
self.set_cube_uid("model", self.benchmark.reference_model)
self.set_cube_uid("evaluator", self.benchmark.evaluator)
self.demo_dataset_url = self.benchmark.demo_dataset_url
self.demo_dataset_hash = self.benchmark.demo_dataset_hash
self.set_cube_uid("data_prep", self.benchmark.data_preparation_mlcube)
self.set_cube_uid("model", self.benchmark.reference_model_mlcube)
self.set_cube_uid("evaluator", self.benchmark.data_evaluator_mlcube)
self.demo_dataset_url = self.benchmark.demo_dataset_tarball_url
self.demo_dataset_hash = self.benchmark.demo_dataset_tarball_hash
else:
self.set_cube_uid("data_prep")
self.set_cube_uid("model")
@@ -115,8 +115,8 @@ def execute_benchmark(self):
"""
if (
not self.benchmark_uid
or self.benchmark.data_preparation != self.data_prep
or self.benchmark.evaluator != self.evaluator
or self.benchmark.data_preparation_mlcube != self.data_prep
or self.benchmark.data_evaluator_mlcube != self.evaluator
):
self.benchmark = Benchmark.tmp(self.data_prep, self.model, self.evaluator)
self.benchmark_uid = self.benchmark.generated_uid
@@ -190,12 +190,18 @@ def set_data_uid(self):
if self.data_uid is not None:
self.dataset = Dataset.get(self.data_uid)
# to avoid 'None' as a uid
self.data_prep = self.dataset.preparation_cube_uid
self.data_prep = self.dataset.data_preparation_mlcube
else:
logging.info("Using benchmark demo dataset")
data_path, labels_path = self.download_demo_data()
self.data_uid = DataPreparation.run(
None, self.data_prep, data_path, labels_path, run_test=True,
None,
self.data_prep,
data_path,
labels_path,
run_test=True,
name="demo_data",
location="local",
)
self.dataset = Dataset.get(self.data_uid)

6 changes: 3 additions & 3 deletions cli/medperf/commands/dataset/associate.py
@@ -18,13 +18,13 @@ def run(data_uid: str, benchmark_uid: int, approved=False, force_test=False):
comms = config.comms
ui = config.ui
dset = Dataset.get(data_uid)
if dset.uid is None:
if dset.id is None:
msg = "The provided dataset is not registered."
raise InvalidArgumentError(msg)

benchmark = Benchmark.get(benchmark_uid)

if str(dset.preparation_cube_uid) != str(benchmark.data_preparation):
if str(dset.data_preparation_mlcube) != str(benchmark.data_preparation_mlcube):
raise InvalidArgumentError(
"The specified dataset wasn't prepared for this benchmark"
)
@@ -44,6 +44,6 @@ def run(data_uid: str, benchmark_uid: int, approved=False, force_test=False):
if approved:
ui.print("Generating dataset benchmark association")
metadata = {"test_result": result.results}
comms.associate_dset(dset.uid, benchmark_uid, metadata)
comms.associate_dset(dset.id, benchmark_uid, metadata)
else:
raise CleanExit("Dataset association operation cancelled.")
11 changes: 3 additions & 8 deletions cli/medperf/commands/dataset/create.py
@@ -99,7 +99,7 @@ def get_prep_cube(self):
cube_uid = self.prep_cube_uid
if cube_uid is None:
benchmark = Benchmark.get(self.benchmark_uid)
cube_uid = benchmark.data_preparation
cube_uid = benchmark.data_preparation_mlcube
self.ui.print(f"Benchmark Data Preparation: {benchmark.name}")
self.ui.text = f"Retrieving data preparation cube: '{cube_uid}'"
self.cube = Cube.get(cube_uid)
@@ -215,19 +215,14 @@ def todict(self) -> dict:
"name": self.name,
"description": self.description,
"location": self.location,
"data_preparation_mlcube": self.cube.uid,
"data_preparation_mlcube": self.cube.id,
"input_data_hash": self.in_uid,
"generated_uid": self.generated_uid,
"split_seed": 0, # Currently this is not used
"generated_metadata": self.get_temp_stats(),
"status": Status.PENDING.value, # not in the server
"state": "OPERATION",
"separate_labels": self.labels_specified, # not in the server
"is_valid": True,
"user_metadata": {},
"created_at": None,
"modified_at": None,
"owner": None,
}

def get_temp_stats(self):
@@ -246,5 +241,5 @@ def write(self) -> str:
filename (str, optional): name of the file. Defaults to config.reg_file.
"""
dataset_dict = self.todict()
dataset = Dataset(dataset_dict)
dataset = Dataset(**dataset_dict)
dataset.write()
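
The registration dict above is now unpacked directly into the Dataset entity, which is why caller-side stubs such as created_at, modified_at and owner disappear from todict(). Below is a sketch of how the "dataset status validator" mentioned in the commit log might look; the field types, defaults, and allowed status values are assumptions based on the dict above, not the actual model.

from typing import Optional
from pydantic import BaseModel, validator


class Dataset(BaseModel):
    id: Optional[int] = None
    name: str
    description: Optional[str] = None
    location: str
    data_preparation_mlcube: int
    input_data_hash: str
    generated_uid: str
    split_seed: int = 0
    generated_metadata: dict = {}
    status: str = "PENDING"        # local-only field, not sent to the server
    state: str = "DEVELOPMENT"
    separate_labels: bool = False  # local-only field
    user_metadata: dict = {}

    @validator("status")
    def status_is_known(cls, v):
        # Keep the local approval status within the values the CLI understands.
        assert v in {"PENDING", "APPROVED", "REJECTED"}, f"Invalid status: {v}"
        return v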
6 changes: 3 additions & 3 deletions cli/medperf/commands/dataset/list.py
@@ -28,10 +28,10 @@ def run(local: bool = False, mine: bool = False):
# Get local dsets information
dsets_data = [
[
dset.uid if dset.uid is not None else dset.generated_uid,
dset.id if dset.id is not None else dset.generated_uid,
dset.name,
dset.preparation_cube_uid,
dset.uid is not None,
dset.data_preparation_mlcube,
dset.id is not None,
True,
]
for dset in dsets
6 changes: 3 additions & 3 deletions cli/medperf/commands/dataset/submit.py
@@ -20,7 +20,7 @@ def run(data_uid: str, approved=False):
ui = config.ui
dset = Dataset.get(data_uid)

if dset.uid is not None:
if dset.id is not None:
# TODO: should get_dataset and update locally. solves existing issue?
raise InvalidArgumentError("This dataset has already been registered")
remote_dsets = comms.get_user_datasets()
@@ -30,7 +30,7 @@ def run(data_uid: str, approved=False):
if remote_dset["generated_uid"] == dset.generated_uid
]
if len(remote_dset) == 1:
dset = Dataset(remote_dset[0])
dset = Dataset(**remote_dset[0])
dset.write()
ui.print(f"Remote dataset {dset.name} detected. Updating local dataset.")
return
@@ -42,7 +42,7 @@ def run(data_uid: str, approved=False):
if approved:
ui.print("Uploading...")
updated_dset_dict = dset.upload()
updated_dset = Dataset(updated_dset_dict)
updated_dset = Dataset(**updated_dset_dict)

old_dset_loc = dset.path
new_dset_loc = updated_dset.path
2 changes: 1 addition & 1 deletion cli/medperf/commands/mlcube/list.py
@@ -18,7 +18,7 @@ def run(local: bool = False, mine: bool = False):
headers = ["MLCube UID", "Name", "State"]
cubes_data = [
[
cube.uid if cube.uid is not None else cube.generated_uid,
cube.id if cube.id is not None else cube.generated_uid,
cube.name,
cube.state,
]