Merge branch 'pro_hot_fix' into 'master'
fix sparse dense training

See merge request data/monolith!2114

GitOrigin-RevId: 90ec226aabb12f63eb1e0c8cfff5b6e490462008
王才华 authored and monolith committed Aug 30, 2023
1 parent 77335d9 commit fc14ddb
Showing 7 changed files with 54 additions and 58 deletions.
1 change: 1 addition & 0 deletions monolith/native_training/BUILD
@@ -1258,5 +1258,6 @@ py_library(
     deps = [
         ":distribution_utils",
         ":yarn_runtime",
+        "//monolith/native_training/model_export:export_context",
     ],
 )
6 changes: 4 additions & 2 deletions monolith/native_training/cpu_training.py
@@ -101,7 +101,7 @@
 from monolith.native_training.model_export import saved_model_exporters
 from monolith.native_training.model_export import export_context
 from monolith.native_training.model_export.export_context import \
-    is_exporting, is_exporting_distributed, ExportMode
+    is_exporting, is_exporting_distributed, is_dry_run_or_exporting, ExportMode
 from monolith.native_training.native_task import NativeTask
 from monolith.native_training.prefetch_queue import \
     enqueue_dicts_with_queue_return, EnqueueHook
@@ -807,7 +807,7 @@ def wrapped_input_fn():
     if isinstance(ds, tf.data.Dataset):
       if enable_reorder:
         ds = ds.map(reorder_parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
-      if use_dataservice:
+      if use_dataservice and not is_dry_run_or_exporting():
         # This is a temporary hack. Will revisit here once we decided to
         # do the remanagement.
         tmp_mlp_env = mlp_utils.MLPEnv()
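The new guard keeps export and dry-run graphs away from the tf.data service entirely: no dispatcher is running in those modes, so any registration would fail or hang. A minimal sketch of the gating pattern, where register_with_data_service is a hypothetical stand-in for the MLPEnv-based setup that follows in cpu_training.py:

import tensorflow as tf

# Sketch only (not part of the commit). register_with_data_service is a
# hypothetical stand-in for the data-service registration done above.
def make_input_fn(use_dataservice, is_dry_run_or_exporting,
                  register_with_data_service):
  def input_fn():
    ds = tf.data.Dataset.range(100).batch(8)
    # Export/dry-run graphs are built without a running dispatcher, so the
    # data-service path must be skipped for them.
    if use_dataservice and not is_dry_run_or_exporting():
      ds = register_with_data_service(ds)
    return ds
  return input_fn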
@@ -1830,6 +1830,8 @@ def _do_worker_train(config: DistributedCpuTrainingConfig,
   params.mode = tf.estimator.ModeKeys.PREDICT
   native_task = params.instantiate()
   training = CpuTraining(config, native_task)
+  if config.enable_partial_sync_training or config.use_dataservice:
+    training = sync_training_hooks.EofAwareTask(training, config.use_dataservice)
   estimator = tf.estimator.Estimator(training.create_model_fn(), config=run_config)
   estimator.train(training._task.create_item_input_fn(
       items_path), max_steps=params.train.max_steps)
@@ -175,10 +175,7 @@ class ParquetDatasetOp : public DatasetOpKernel {
           LOG(INFO) << "end_of_sequence of " << dataset()->file_name_;
         } else {
           counter_++;
-          if (counter_ % 1000 == 0) {
-            LOG(INFO) << "consume " << counter_ << "examples from "
-                      << dataset()->file_name_;
-          }
+          LOG_EVERY_N_SEC(INFO, 60) << "consume " << counter_ << " examples from " << dataset()->file_name_;
         }
       } else if (dataset()->output_pb_type_ == "examplebatch") {
         ExampleBatch example_batch;
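Switching from a count-based trigger (every 1000 examples) to glog's LOG_EVERY_N_SEC keeps log volume constant regardless of throughput, and the new message also restores the missing space before "examples". A rough Python analogue of the same time-based throttle:

import logging
import time

# Sketch only: mirrors LOG_EVERY_N_SEC(INFO, 60) -- emit at most one progress
# line per 60-second window, no matter how fast counter advances.
_last_emit = 0.0

def log_progress(counter, file_name, interval_sec=60.0):
  global _last_emit
  now = time.monotonic()
  if now - _last_emit >= interval_sec:
    _last_emit = now
    logging.info('consume %d examples from %s', counter, file_name)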
72 changes: 29 additions & 43 deletions monolith/native_training/mlp_utils.py
@@ -26,29 +26,15 @@
 from tensorflow_estimator.python.estimator.util import _DatasetInitializerHook
 from monolith.native_training.distribution_utils import get_mpi_rank, \
     get_mpi_size, get_mpi_local_size, enable_sync_training, get_device_str
+from monolith.native_training.model_export.export_context import \
+    is_exporting, is_dry_run_or_exporting


 FLAGS = flags.FLAGS

 from monolith.native_training import yarn_runtime


-def kill_by_port(port: int):
-  process = Popen(["lsof", "-i", ":{0}".format(port)], stdout=PIPE, stderr=PIPE)
-  stdout, stderr = process.communicate()
-  try:
-    pid = None
-    for process in str(stdout.decode("utf-8")).split("\n")[1:]:
-      data = [x for x in process.split(" ") if x]
-      if data and len(data) > 1:
-        pid = int(data[1])
-        break
-    print('pid is', pid)
-  except:
-    pass
-  if pid is not None:
-    os.kill(pid, signal.SIGKILL)
-
-
 def check_port(host: str, port: int, timeout: float = 1) -> bool:
   is_ipv6 = ':' in host.strip('[]')
   skt = socket.socket(socket.AF_INET6 if is_ipv6 else socket.AF_INET,
@@ -420,36 +406,36 @@ def mlp_pass(dispatcher_role: str = 'dispatcher',
   def begin(self):
     self._initializer = self._iterator.initializer
     self._broadcast_dataset_id = None
-    self._rank = -1
     graph = tf.compat.v1.get_default_graph()
-    if enable_sync_training() and not hasattr(graph, 'dry_run'):
-      try:
-        enable_bps = int(os.getenv("MONOLITH_WITH_BYTEPS", "0"))
-        if enable_bps:
-          import byteps.tensorflow as hvd
-        else:
-          import horovod.tensorflow as hvd
-
-        dataset_ids = tf.compat.v1.get_collection(key='registed_dataset_id')
-        if dataset_ids is not None and len(dataset_ids) > 0:
-          dataset_id = dataset_ids[0]
-          if dataset_id is not None:
-            self._rank = hvd.rank()
-            #with tf.device(None), tf.device(get_device_str(True)):
-            self._broadcast_dataset_id = [
-                dataset_id,
-                hvd.broadcast(tensor=dataset_id,
-                              root_rank=0,
-                              name="broadcast_dataset_id")
-            ]
-            graph.clear_collection(name='registed_dataset_id')
-      except Exception as e:
-        logging.info(f'import byteps/horovod error: {e}')
+    if not is_dry_run_or_exporting():
+      self._rank = -1
+      if enable_sync_training():
+        try:
+          enable_bps = int(os.getenv("MONOLITH_WITH_BYTEPS", "0"))
+          if enable_bps:
+            import byteps.tensorflow as hvd
+          else:
+            import horovod.tensorflow as hvd
+
+          dataset_ids = tf.compat.v1.get_collection(key='registed_dataset_id')
+          if dataset_ids is not None and len(dataset_ids) > 0:
+            dataset_id = dataset_ids[0]
+            if dataset_id is not None:
+              self._rank = hvd.rank()
+              #with tf.device(None), tf.device(get_device_str(True)):
+              self._broadcast_dataset_id = [
+                  dataset_id,
+                  hvd.broadcast(tensor=dataset_id,
+                                root_rank=0,
+                                name="broadcast_dataset_id")
+              ]
+              graph.clear_collection(name='registed_dataset_id')
+        except Exception as e:
+          logging.info(f'import byteps/horovod error: {e}')

   def after_create_session(self, session, coord):
     del coord
-    if self._broadcast_dataset_id is not None:
+    if self._broadcast_dataset_id is not None and not is_dry_run_or_exporting():
       dataset_id, bc_dataset_id = session.run(self._broadcast_dataset_id)
       logging.info(
           f'dataset_id is {dataset_id}, bc_dataset_id is {bc_dataset_id}, rank {self._rank}'
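In synchronous training every Horovod rank must attach to the tf.data service dataset registered by rank 0, which is what the broadcast above guarantees; dry-run and export graphs have no Horovod context, hence the new is_dry_run_or_exporting() guard around both the broadcast op construction and its session.run. A minimal sketch of the broadcast pattern, assuming an initialized Horovod and a scalar dataset_id tensor from data-service registration:

import horovod.tensorflow as hvd

# Sketch only: pair the locally registered id with rank 0's id so that
# after_create_session can make every worker consume the same dataset.
def make_broadcast_dataset_id(dataset_id):
  return [
      dataset_id,
      hvd.broadcast(tensor=dataset_id, root_rank=0,
                    name='broadcast_dataset_id'),
  ]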
6 changes: 6 additions & 0 deletions monolith/native_training/model_export/export_context.py
@@ -133,3 +133,9 @@ def enter_export_mode(mode: ExportMode, export_ctx=None):
   finally:
     EXPORT_MODE = ExportMode.NONE
     EXPORT_CTX = None
+
+
+@monolith_export
+def is_dry_run_or_exporting():
+  graph = tf.compat.v1.get_default_graph()
+  return is_exporting() or hasattr(graph, 'dry_run')
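The dry-run half of the helper is a per-graph convention: any graph carrying a dry_run attribute counts as a dry run. This diff only shows the hasattr side of that contract; the site that sets the attribute lives elsewhere in monolith. A sketch of how a caller would mark a graph, assuming the new helper is importable:

import tensorflow as tf
from monolith.native_training.model_export.export_context import \
    is_dry_run_or_exporting

# Sketch only: mark the default graph as a dry run. Any value works, since
# the helper tests hasattr(graph, 'dry_run') rather than the value itself.
graph = tf.compat.v1.get_default_graph()
graph.dry_run = True
assert is_dry_run_or_exporting()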
9 changes: 5 additions & 4 deletions monolith/native_training/runner_utils.py
@@ -36,7 +36,7 @@
 from monolith.native_training.monolith_checkpoint_state_pb2 import MonolithCheckpointState
 from monolith.native_training.net_utils import AddressFamily
 from monolith.native_training import save_utils
-from monolith.native_training.mlp_utils import mlp_pass, add_mpi_exception_hook, MLPEnv, kill_by_port
+from monolith.native_training.mlp_utils import mlp_pass, add_mpi_exception_hook, MLPEnv

 FLAGS = flags.FLAGS
 old_isabs = os.path.isabs
@@ -239,6 +239,10 @@ def __post_init__(self):
     except:
       logging.info("update RunnerConfig failed")

+    if self.enable_gpu_training and self.enable_partial_sync_training:
+      if (self.index <= 0 or self.index is None) and self.server_type == 'worker':
+        self.index = int(os.environ.get('OMPI_COMM_WORLD_RANK') or '0')
+
     if self.kafka_topics:
       if isinstance(self.kafka_topics, str):
         self.kafka_topics = self.kafka_topics.split(',')
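Under an Open MPI launcher each worker process can recover its rank from OMPI_COMM_WORLD_RANK, which is what the new fallback does. One caveat worth noting: in Python 3, `self.index <= 0` raises TypeError when index is None, so the `or self.index is None` arm is never reached. A sketch of the same intent with the checks in the safe order:

import os

# Sketch only: derive a worker index from Open MPI's environment, with the
# None test ordered first so a missing index cannot trip `None <= 0`.
def resolve_worker_index(index, server_type):
  if server_type == 'worker' and (index is None or index <= 0):
    return int(os.environ.get('OMPI_COMM_WORLD_RANK') or '0')
  return index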
@@ -383,9 +387,6 @@ def monolith_discovery(runner_conf: RunnerConfig):

     logging.info('enter monolith_discovery!')
     yield discovery
-    mlp_env = MLPEnv()
-    if mlp_env.avaiable:
-      kill_by_port(mlp_env.ssh_port)
   except Exception as e:
     raise e
   finally:
13 changes: 8 additions & 5 deletions monolith/native_training/sync_training_hooks.py
@@ -269,6 +269,7 @@ class EofAwareTask:
   def __init__(self, task: native_task.NativeTask, use_dataservice: bool = False):
     self._ori_task = task
     self.use_dataservice = use_dataservice
+    logging.info(f'init EofAwareTask')

   def create_input_fn(self, mode):
@@ -278,6 +279,9 @@ def new_input_fn_factory(input_fn):

     def new_input_fn():
       ds = input_fn()
+      if export_context.is_dry_run_or_exporting():
+        return ds
+
       ds = datasets.CacheOneDataset(ds)

       # There are 2 reasons why we need a map here:
Expand All @@ -286,7 +290,7 @@ def new_input_fn():
# the original data after we wrap the input_fn output.
def map_fn(features, eof):
if isinstance(features, dict):
logging.info(EofAwareTask.EOF_KEY)
logging.info(f"in map_fn: {EofAwareTask.EOF_KEY}")
return {**features, EofAwareTask.EOF_KEY: eof}
logging.info('map_fn keys: 1, 2')
return {"1": features, "2": eof}
@@ -305,20 +309,20 @@ def create_model_fn(self):
     model_fn = self._ori_task.create_model_fn()

     def new_model_fn_factory(model_fn):
-      if export_context.is_exporting():
+      if export_context.is_dry_run_or_exporting():
         return model_fn

       def new_model_fn(features, mode, config):
         if EofAwareTask.EOF_KEY in features:
+          logging.info(f"in model_fn: {EofAwareTask.EOF_KEY}")
           eof = features[EofAwareTask.EOF_KEY]
           features.pop(EofAwareTask.EOF_KEY)
           real_features = features
         else:
           real_features, eof = features["1"], features["2"]
         spec: tf.estimator.EstimatorSpec = model_fn(real_features, mode, config)
         training_hooks = spec.training_hooks or ()
-        training_hooks = list(training_hooks)
-        training_hooks.append(self.EofHook(eof))
+        training_hooks = [self.EofHook(eof)] + list(training_hooks)
         spec = spec._replace(training_hooks=training_hooks)
         return spec
@@ -349,4 +353,3 @@ def after_run(self, run_context, run_values):
       logging.info(f'rank {hvd_lib.rank()} request_stop, results is {run_values.results}, before')
       run_context.request_stop()
       logging.info(f'rank {hvd_lib.rank()} request_stop, results is {run_values.results}, after')
-