Skip to content

Commit

Permalink
Fix benchmark_single_table with custom synthesizers and timeout (#337)
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho authored Aug 22, 2024
1 parent 2d8e00d commit 0aa0637
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 29 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
'appdirs>=1.3',
'boto3>=1.28,<2',
'botocore>=1.31,<2',
'cloudpickle>=2.1.0',
'compress-pickle>=1.2.0',
'humanfriendly>=8.2',
"numpy>=1.21.0,<2.0.0;python_version<'3.10'",
Expand Down
69 changes: 46 additions & 23 deletions sdgym/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import pickle
import tracemalloc
import warnings
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path

import boto3
import cloudpickle
import compress_pickle
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -318,6 +320,26 @@ def _score(
return output


@contextmanager
def multiprocessing_context():
"""Override multiprocessing ForkingPickler to use cloudpickle."""
original_dump = multiprocessing.reduction.ForkingPickler.dumps
original_load = multiprocessing.reduction.ForkingPickler.loads
original_method = multiprocessing.get_start_method()

multiprocessing.set_start_method('spawn', force=True)
multiprocessing.reduction.ForkingPickler.dumps = cloudpickle.dumps
multiprocessing.reduction.ForkingPickler.loads = cloudpickle.loads

try:
yield
finally:
# Restore original methods
multiprocessing.set_start_method(original_method, force=True)
multiprocessing.reduction.ForkingPickler.dumps = original_dump
multiprocessing.reduction.ForkingPickler.loads = original_load


def _score_with_timeout(
timeout,
synthesizer,
Expand All @@ -329,32 +351,33 @@ def _score_with_timeout(
modality=None,
dataset_name=None,
):
with multiprocessing.Manager() as manager:
output = manager.dict()
process = multiprocessing.Process(
target=_score,
args=(
synthesizer,
data,
metadata,
metrics,
output,
compute_quality_score,
compute_diagnostic_score,
modality,
dataset_name,
),
)
with multiprocessing_context():
with multiprocessing.Manager() as manager:
output = manager.dict()
process = multiprocessing.Process(
target=_score,
args=(
synthesizer,
data,
metadata,
metrics,
output,
compute_quality_score,
compute_diagnostic_score,
modality,
dataset_name,
),
)

process.start()
process.join(timeout)
process.terminate()
process.start()
process.join(timeout)
process.terminate()

output = dict(output)
if output.get('timeout'):
LOGGER.error('Timeout running %s on dataset %s;', synthesizer['name'], dataset_name)
output = dict(output)
if output.get('timeout'):
LOGGER.error('Timeout running %s on dataset %s;', synthesizer['name'], dataset_name)

return output
return output


def _format_output(
Expand Down
23 changes: 17 additions & 6 deletions sdgym/synthesizers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def get_trained_synthesizer(self, data, metadata):
obj:
The trained synthesizer.
"""
return get_trained_synthesizer_fn(data, metadata)
return self.synthesizer_fn['get_trained_synthesizer_fn'](data, metadata)

def sample_from_synthesizer(self, synthesizer, num_samples):
"""Sample the desired number of samples from the given synthesizer.
Expand All @@ -139,11 +139,22 @@ def sample_from_synthesizer(self, synthesizer, num_samples):
pandas.DataFrame:
The synthetic data.
"""
return sample_from_synthesizer_fn(synthesizer, num_samples)

NewSynthesizer.__name__ = f'Custom:{display_name}'

return NewSynthesizer
return self.synthesizer_fn['sample_from_synthesizer_fn'](synthesizer, num_samples)

CustomSynthesizer = type(
f'Custom:{display_name}',
(NewSynthesizer,),
{
'synthesizer_fn': {
'get_trained_synthesizer_fn': get_trained_synthesizer_fn,
'sample_from_synthesizer_fn': sample_from_synthesizer_fn,
},
},
)
CustomSynthesizer.__name__ = f'Custom:{display_name}'
CustomSynthesizer.__module__ = 'sdgym.synthesizers.generate'
globals()[f'Custom:{display_name}'] = CustomSynthesizer
return CustomSynthesizer


def create_multi_table_synthesizer(
Expand Down

0 comments on commit 0aa0637

Please sign in to comment.