Merge branch 'hanzhi-demo' into 'master'
Monolith open source demo

See merge request data/monolith!1874

GitOrigin-RevId: 58a70fbb367190fb325a21d149ba0b5f4e4aca52
hanzhi713 authored and zlqiszlqbd committed Jan 4, 2023
1 parent 3123816 commit 26e2dee
Showing 16 changed files with 703 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -72,3 +72,6 @@ target/
/bazel-*
/.vscode

# Jupyter notebook
.ipynb_checkpoints

82 changes: 82 additions & 0 deletions markdown/demo/AWS-EKS.md
@@ -0,0 +1,82 @@
# Distributed async training on EKS

To scale to multiple machines and handle failure recovery, we can utilize container orchestration frameworks such as YARN and Kubernetes. Regardless of which tool you use, as long as the `TF_CONFIG` environment variable is correctly set for each worker and PS, training will work just fine.

In this tutorial, we will show how to set up distributed training using Kubernetes, Kubeflow, and AWS's Elastic Kubernetes Service (EKS). Kubeflow is used as the middleware that injects the `TF_CONFIG` environment variable into each worker container.
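
For reference, the injected `TF_CONFIG` looks roughly like the sketch below; the hostnames are illustrative, as the actual service names are generated by the operator.

```json
{
  "cluster": {
    "worker": ["monolith-train-worker-0.kubeflow.svc:2222",
               "monolith-train-worker-1.kubeflow.svc:2222"],
    "ps": ["monolith-train-ps-0.kubeflow.svc:2222"]
  },
  "task": {"type": "worker", "index": 0}
}
```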

## Prerequisites

Set up Kubeflow on AWS by following the official guide. It will also help you set up other tools such as the AWS CLI and eksctl. Make sure to complete:

- Prerequisites
- Create an EKS Cluster
- Vanilla Installation

https://awslabs.github.io/kubeflow-manifests/docs/deployment/


## Prepare the Monolith Docker image

TODO

## Write Spec and launch training

If you have completed all the prerequisites, `kubectl` should be able to connect to your cluster on AWS.
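
For example, the following sanity check should list the nodes of your EKS cluster:

```bash
kubectl get nodes
```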

Now, create a spec file called `aws-tfjob.yaml`.

```yaml
apiVersion: "kubeflow.org/v1"
kind: "TFJob"
metadata:
  name: "monolith-train"
  namespace: kubeflow
spec:
  runPolicy:
    cleanPodPolicy: None
  tfReplicaSpecs:
    Worker:
      replicas: 4
      restartPolicy: Never
      template:
        metadata:
          annotations:
            # solve RBAC permission problem
            sidecar.istio.io/inject: "false"
        spec:
          containers:
            - name: tensorflow
              image: YOUR_IMAGE
              args:
                - --model_dir=/tmp/model
    PS:
      replicas: 4
      restartPolicy: Never
      template:
        metadata:
          annotations:
            sidecar.istio.io/inject: "false"
        spec:
          containers:
            - name: tensorflow
              image: YOUR_IMAGE
              args:
                - --model_dir=/tmp/model
```
Then, launch training:
```bash
kubectl apply -f aws-tfjob.yaml
```

To view the status of workers, you can use

```bash
# use this to list pods
kubectl --namespace kubeflow get pods
# use this to get the logs of a worker
kubectl --namespace kubeflow logs monolith-train-worker-0
```
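
Since `TFJob` is a custom resource, you can also query the job object itself for the aggregate state reported by the training operator:

```bash
kubectl --namespace kubeflow get tfjob monolith-train -o yaml
```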

Of course, there are other middleware tools built on top of Kubeflow that help you keep track of training progress. Monolith's compatibility with TensorFlow means that tools built for TensorFlow will likely work with Monolith too.
45 changes: 45 additions & 0 deletions markdown/demo/BUILD
@@ -0,0 +1,45 @@
load("@rules_python//python:defs.bzl", "py_binary", "py_library")

package(default_visibility = ["//visibility:public"])

py_binary(
    name = "kafka_producer",
    srcs = ["kafka_producer.py"],
    deps = [
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)

py_binary(
    name = "kafka_receiver",
    srcs = ["kafka_receiver.py"],
    deps = [
        "//monolith/native_training:native_model",
    ],
)

py_binary(
    name = "demo_model",
    srcs = ["demo_model.py"],
    deps = [
        ":kafka_producer",
        ":kafka_receiver",
        "//monolith/native_training:native_model",
    ],
)

py_binary(
    name = "demo_local_runner",
    srcs = ["demo_local_runner.py"],
    deps = [
        ":demo_model",
    ],
)

py_binary(
    name = "preprocess_data",
    srcs = ["preprocess_data.py"],
    deps = [
        ":demo_model",
    ],
)
133 changes: 133 additions & 0 deletions markdown/demo/Batch.md
@@ -0,0 +1,133 @@
# Movie Ranking Batch Training

This tutorial demonstrates how to use Monolith to perform a movie ranking task. It is essentially the same as [TensorFlow's tutorial on movie ranking](https://www.tensorflow.org/recommenders/examples/basic_ranking), but written against Monolith's API. Through this tutorial, you'll learn the similarities and differences between Monolith and native TensorFlow. Additionally, we'll showcase how batch training and stream training are done with Monolith.

## Building the Model

Source code: [demo_model.py](./demo_model.py)

### Monolith Model API

```python
class MovieRankingModel(MonolithModel):

  def __init__(self, params):
    super().__init__(params)
    self.p = params
    self.p.serving.export_when_saving = True

  def input_fn(self, mode):
    return dataset

  def model_fn(self, features, mode):
    # `features` is one batch from the dataset returned by input_fn
    return EstimatorSpec(...)

  def serving_input_receiver_fn(self):
    return tf.estimator.export.ServingInputReceiver({...})
```

A Monolith model follows the template above. `input_fn` returns an instance of `tf.data.Dataset`. `model_fn` builds the graph for the forward pass and returns an `EstimatorSpec`. The `features` argument is an item from the dataset returned by `input_fn`. Finally, if you want to serve the model, you need to implement `serving_input_receiver_fn`.

### Prepare the dataset

We can use tfds to load the dataset. Then, we select the features that we're going to use and do some preprocessing. In our case, we need to convert user IDs and movie titles from strings to unique integer IDs.

```python
def get_preprocessed_dataset(size='100k') -> tf.data.Dataset:
  ratings = tfds.load(f"movielens/{size}-ratings", split="train")
  # For simplicity, we map each movie_title and user_id to numbers
  # by hashing. You can use other ways to number them to avoid
  # collision and better leverage Monolith's collision-free hash tables.
  max_b = (1 << 63) - 1
  return ratings.map(lambda x: {
      'mov': tf.strings.to_hash_bucket_fast([x['movie_title']], max_b),
      'uid': tf.strings.to_hash_bucket_fast([x['user_id']], max_b),
      'label': tf.expand_dims(x['user_rating'], axis=0)
  })
```
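
To sanity-check the preprocessing, you can pull a single element from the dataset (a hypothetical usage snippet, not part of the demo code):

```python
ds = get_preprocessed_dataset()
for item in ds.take(1):
  # each field is a length-1 tensor: hashed movie id, hashed user id, rating
  print(item['mov'].numpy(), item['uid'].numpy(), item['label'].numpy())
```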

### Write input_fn for batch training

To enable distributed training, our `input_fn` first shards the dataset according to the total number of workers, then batches it. Note that Monolith requires sparse features to be ragged tensors, so a `.map(to_ragged)` is required if they aren't already.

```python
def to_ragged(x):
  return {
      'mov': tf.RaggedTensor.from_tensor(x['mov']),
      'uid': tf.RaggedTensor.from_tensor(x['uid']),
      'label': x['label']
  }

def input_fn(self, mode):
  env = json.loads(os.environ['TF_CONFIG'])
  cluster = env['cluster']
  worker_count = len(cluster.get('worker', [])) + len(cluster.get('chief', []))
  dataset = get_preprocessed_dataset('25m')
  dataset = dataset.shard(worker_count, env['task']['index'])
  return dataset.batch(512, drop_remainder=True)\
      .map(to_ragged).prefetch(tf.data.AUTOTUNE)
```

### Build the model

```python
def model_fn(self, features, mode):
  # for sparse features, we declare an embedding table for each of them
  for s_name in ["mov", "uid"]:
    self.create_embedding_feature_column(s_name)

  mov_embedding, user_embedding = self.lookup_embedding_slice(
      features=['mov', 'uid'], slice_name='vec', slice_dim=32)
  ratings = tf.keras.Sequential([
      # Learn multiple dense layers.
      tf.keras.layers.Dense(256, activation="relu"),
      tf.keras.layers.Dense(64, activation="relu"),
      # Make rating predictions in the final layer.
      tf.keras.layers.Dense(1)
  ])
  rank = ratings(tf.concat((user_embedding, mov_embedding), axis=1))
  label = features['label']
  loss = tf.reduce_mean(tf.losses.mean_squared_error(rank, label))

  optimizer = tf.compat.v1.train.AdagradOptimizer(0.05)

  return EstimatorSpec(
      label=label,
      pred=rank,
      head_name="rank",
      loss=loss,
      optimizer=optimizer,
      classification=False)
```

In `model_fn`, we use `self.create_embedding_feature_column(feature_name)` to declare an embedding table for each feature name that requires an embedding. In our case, they are `mov` and `uid`. Note that these feature names must match what the `input_fn` provides.

Then, we use `self.lookup_embedding_slice` to look up all the embeddings at once. If your features require different embedding lengths, you can make multiple calls to `self.lookup_embedding_slice`, as sketched below. The rest is straightforward and identical to how you would do it in native TensorFlow in graph mode.
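
For instance, a minimal sketch of per-feature lookups with different dimensions, assuming the same return convention as the combined call above (one embedding per requested feature); the slice names here are hypothetical:

```python
# hypothetical: 32-dim movie embeddings, 16-dim user embeddings
(mov_embedding,) = self.lookup_embedding_slice(
    features=['mov'], slice_name='vec32', slice_dim=32)
(user_embedding,) = self.lookup_embedding_slice(
    features=['uid'], slice_name='vec16', slice_dim=16)
inputs = tf.concat((user_embedding, mov_embedding), axis=1)
```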

Finally, we return an `EstimatorSpec`. This `EstimatorSpec` is a wrapped version of `tf.estimator.EstimatorSpec` and thus has additional fields.

## Run distributed batch training locally

There are multiple ways to set up distributed training. In this tutorial, we'll use the parameter server (PS) training strategy: model weights are partitioned across the PS, while workers read data, pull weights from the PS, and perform training.

While we usually run distributed training on top of a job scheduler such as YARN and Kubernetes, it can be done locally too.

To launch a training job, we start multiple processes, some of which are workers and some of which are PS. TensorFlow uses the `TF_CONFIG` environment variable to define a cluster and the role of the current process in that cluster. This environment variable also enables service discovery between workers and PS. An example `TF_CONFIG`:

```python
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["host1:port", "host2:port", "host3:port"],
        "ps": ["host4:port", "host5:port"]
    },
    "task": {"type": "worker", "index": 1}
})
```

We provide a script for this: [demo_local_runner.py](./demo_local_runner.py). To run batch training, simply do

```bash
bazel run //markdown/demo:demo_local_runner -- --training_type=batch
```

9 changes: 9 additions & 0 deletions markdown/demo/README.md
@@ -0,0 +1,9 @@
# Monolith demo model and tutorials

This is a three-part tutorial on building Monolith models and launching training.

### [Part 1: building a model and launching distributed async batch training](./Batch.md)

### [Part 2: training with streaming input data](./Stream.md)

### [Part 3: launching distributed async training on the cloud](./AWS-EKS.md)
64 changes: 64 additions & 0 deletions markdown/demo/Stream.md
@@ -0,0 +1,64 @@
# Stream training tutorial
> This tutorial builds on the batch training tutorial. Please read it first if you haven't.

Monolith supports reading input data from a Kafka stream. To add stream training support to your model, simply change the `input_fn` to read data from a Kafka dataset.

## Kafka producer

Source code: [kafka_producer.py](./kafka_producer.py)

Let's create a Kafka producer for our movie-lens dataset. Kafka requires serializing everything to bytes, so we convert each data item in the dataset to a string by packing it into the standard TensorFlow `Example` protobuf.

```python
def serialize_one(data):
  # serialize a training instance to string
  return tf.train.Example(features=tf.train.Features(
      feature={
          'mov': tf.train.Feature(int64_list=tf.train.Int64List(value=data['mov'])),
          'uid': tf.train.Feature(int64_list=tf.train.Int64List(value=data['uid'])),
          'label': tf.train.Feature(float_list=tf.train.FloatList(value=data['label']))
      }
  )).SerializeToString()
```

Then, we create a `KafkaProducer`, iterate over the dataset, serialize each item, and write it to the desired Kafka topic.

```python
if __name__ == "__main__":
  ds = get_preprocessed_dataset()
  producer = KafkaProducer(bootstrap_servers=['127.0.0.1:9092'])
  for count, val in tqdm(enumerate(ds), total=len(ds)):
    # note: we omit error callback here for performance
    producer.send(
        "movie-train", key=str(count).encode('utf-8'),
        value=serialize_one(val), headers=[])
  producer.flush()
```
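
To verify that messages actually landed in the topic, you can read one back with Kafka's console consumer (assuming a local Kafka installation with its scripts on `PATH`):

```bash
kafka-console-consumer.sh --bootstrap-server 127.0.0.1:9092 \
  --topic movie-train --from-beginning --max-messages 1
```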

## Kafka consumer in the input_fn

Source code: [kafka_receiver.py](./kafka_receiver.py) and [demo_model.py](./demo_model.py)

Since the Kafka stream contains serialized `tf.train.Example` protos, we can use `tf.io.parse_example` to parse several of them at once.

```python
def decode_example(v):
  x = tf.io.parse_example(v, raw_feature_desc)
  return to_ragged(x)
```
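
Here `raw_feature_desc` describes the schema to parse. A sketch consistent with `serialize_one` above (each feature was written as a single-element list) would be:

```python
raw_feature_desc = {
    'mov': tf.io.FixedLenFeature([1], tf.int64),
    'uid': tf.io.FixedLenFeature([1], tf.int64),
    'label': tf.io.FixedLenFeature([1], tf.float32),
}
```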

In the `input_fn`, we use Monolith's utility function to create a Kafka dataset and use the function above to decode it. The parameter `poll_batch_size` determines how many serialized `Example`s to batch before sending them to `decode_example`; it effectively serves as the training batch size.

```python
def input_fn(self, mode):
  dataset = create_plain_kafka_dataset(
      topics=["movie-train"],
      group_id="cgonline",
      servers="127.0.0.1:9092",
      stream_timeout=10000,
      poll_batch_size=16,
      configuration=[
          "session.timeout.ms=7000",
          "max.poll.interval.ms=8000",
      ],
  )
  return dataset.map(lambda x: decode_example(x.message))
```
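
With the producer running, stream training can be launched through the same local runner as in the batch tutorial (assuming `demo_local_runner` also accepts `stream` as a `--training_type` value):

```bash
bazel run //markdown/demo:demo_local_runner -- --training_type=stream
```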
