Skip to content

Commit

Permalink
Add tf.io.browserHTTPRequest (tensorflow#1030)
Browse files Browse the repository at this point in the history
* Add tf.io.browserHTTPRequest

* Allows model artifacts to be sent via a multipart/form-data
  HTTP request

Towards: tensorflow#13
  • Loading branch information
caisq authored May 12, 2018
1 parent 323afa2 commit 5980684
Show file tree
Hide file tree
Showing 6 changed files with 527 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/io/browser_files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ const DEFAULT_FILE_NAME_PREFIX = 'model';
const DEFAULT_JSON_EXTENSION_NAME = '.json';
const DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';

export class BrowserDownloads implements IOHandler {
class BrowserDownloads implements IOHandler {
private readonly modelTopologyFileName: string;
private readonly weightDataFileName: string;
private readonly jsonAnchor: HTMLAnchorElement;
Expand Down Expand Up @@ -102,7 +102,7 @@ export class BrowserDownloads implements IOHandler {
}
}

export class BrowserFiles implements IOHandler {
class BrowserFiles implements IOHandler {
private readonly files: File[];

constructor(files: File[]) {
Expand Down
236 changes: 236 additions & 0 deletions src/io/browser_http.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

/**
* IOHandler implementations based on HTTP requests in the web browser.
*
* Uses [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
*/

import {assert} from '../util';

import {getModelArtifactsInfoForKerasJSON} from './io_utils';
// tslint:disable-next-line:max-line-length
import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestConfig} from './types';

class BrowserHTTPRequest implements IOHandler {
protected readonly path: string;
protected readonly requestInit: RequestInit;

readonly DEFAULT_METHOD = 'POST';

constructor(path: string, requestInit?: RequestInit) {
assert(
path != null && path.length > 0,
'URL path for browserHTTPRequest must not be null, undefined or ' +
'empty.');
this.path = path;

if (requestInit != null && requestInit.body != null) {
throw new Error(
'requestInit is expected to have no pre-existing body, but has one.');
}
this.requestInit = requestInit || {};
}

async save(modelArtifacts: ModelArtifacts): Promise<SaveResult> {
if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
throw new Error(
'BrowserHTTPRequest.save() does not support saving model topology ' +
'in binary formats yet.');
}

const init = Object.assign({method: this.DEFAULT_METHOD}, this.requestInit);
init.body = new FormData();

const weightsManifest: WeightsManifestConfig = [{
paths: ['./model.weights.bin'],
weights: modelArtifacts.weightSpecs,
}];
const modelTopologyAndWeightManifest = {
modelTopology: modelArtifacts.modelTopology,
weightsManifest
};

init.body.append(
'model.json',
new Blob(
[JSON.stringify(modelTopologyAndWeightManifest)],
{type: 'application/json'}),
'model.json');

if (modelArtifacts.weightData != null) {
init.body.append(
'model.weights.bin',
new Blob(
[modelArtifacts.weightData], {type: 'application/octet-stream'}),
'model.weights.bin');
}

const response = await fetch(this.path, init);

if (response.status === 200) {
return {
modelArtifactsInfo: getModelArtifactsInfoForKerasJSON(modelArtifacts),
responses: [response],
};
} else {
throw new Error(
`BrowserHTTPRequest.save() failed due to HTTP response status ` +
`${response.status}.`);
}
}

// TODO(cais): Add load to unify this IOHandler type and the mechanism
// that currently underlies `tf.loadModel('path')` in tfjs-layers.
// See: https://github.com/tensorflow/tfjs/issues/290
}

// tslint:disable:max-line-length
/**
* Creates an IOHandler subtype that sends model artifacts to HTTP server.
*
* An HTTP request of the `multipart/form-data` mime type will be sent to the
* `path` URL. The form data includes artifacts that represent the topology
* and/or weights of the model. In the case of Keras-style `tf.Model`, two
* blobs (files) exist in form-data:
* - A JSON file consisting of `modelTopology` and `weightsManifest`.
* - A binary weights file consisting of the concatenated weight values.
* These files are in the same format as the one generated by
* [tensorflowjs_converter](https://js.tensorflow.org/tutorials/import-keras.html).
*
* The following code snippet exemplifies the client-side code that uses this
* function:
*
* ```js
* const model = tf.sequential();
* model.add(
* tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
*
* const saveResult = await model.save(tf.io.browserHTTPRequest(
* 'http://model-server:5000/upload', {method: 'PUT'}));
* console.log(saveResult);
* ```
*
* The following Python code snippet based on the
* [flask](https://github.com/pallets/flask) server framework implements a
* server that can receive the request. Upon receiving the model artifacts
* via the requst, this particular server reconsistutes instances of
* [Keras Models](https://keras.io/models/model/) in memory.
*
* ```python
* # pip install -U flask flask-cors keras tensorflow tensorflowjs
*
* from __future__ import absolute_import
* from __future__ import division
* from __future__ import print_function
*
* import io
*
* from flask import Flask, Response, request
* from flask_cors import CORS, cross_origin
* import tensorflow as tf
* import tensorflowjs as tfjs
* import werkzeug.formparser
*
*
* class ModelReceiver(object):
*
* def __init__(self):
* self._model = None
* self._model_json_bytes = None
* self._model_json_writer = None
* self._weight_bytes = None
* self._weight_writer = None
*
* @property
* def model(self):
* self._model_json_writer.flush()
* self._weight_writer.flush()
* self._model_json_writer.seek(0)
* self._weight_writer.seek(0)
*
* json_content = self._model_json_bytes.read()
* weights_content = self._weight_bytes.read()
* return tfjs.converters.deserialize_keras_model(
* json_content,
* weight_data=[weights_content],
* use_unique_name_scope=True)
*
* def stream_factory(self,
* total_content_length,
* content_type,
* filename,
* content_length=None):
* # Note: this example code is *not* thread-safe.
* if filename == 'model.json':
* self._model_json_bytes = io.BytesIO()
* self._model_json_writer = io.BufferedWriter(self._model_json_bytes)
* return self._model_json_writer
* elif filename == 'model.weights.bin':
* self._weight_bytes = io.BytesIO()
* self._weight_writer = io.BufferedWriter(self._weight_bytes)
* return self._weight_writer
*
*
* def main():
* app = Flask('model-server')
* CORS(app)
* app.config['CORS_HEADER'] = 'Content-Type'
*
* model_receiver = ModelReceiver()
*
* @app.route('/upload', methods=['POST'])
* @cross_origin()
* def upload():
* print('Handling request...')
* werkzeug.formparser.parse_form_data(
* request.environ, stream_factory=model_receiver.stream_factory)
* print('Received model:')
* with tf.Graph().as_default(), tf.Session():
* model = model_receiver.model
* model.summary()
* # You can perform `model.predict()`, `model.fit()`,
* # `model.evaluate()` etc. here.
* return Response(status=200)
*
* app.run('localhost', 5000)
*
*
* if __name__ == '__main__':
* main()
* ```
*
* @param path URL path. Can be an absolute HTTP path (e.g.,
* 'http://localhost:8000/model-upload)') or a relative path (e.g.,
* './model-upload').
* @param requestInit Request configurations to be used when sending
* HTTP request to server using `fetch`. It can contain fields such as
* `method`, `credentials`, `headers`, `mode`, etc. See
* https://developer.mozilla.org/en-US/docs/Web/API/Request/Request
* for more information. `requestInit` must not have a body, because the body
* will be set by TensorFlow.js. File blobs representing
* the model topology (filename: 'model.json') and the weights of the
* model (filename: 'model.weights.bin') will be appended to the body.
* If `requestInit` has a `body`, an Error will be thrown.
* @returns An instance of `IOHandler`.
*/
// tslint:enable:max-line-length
export function browserHTTPRequest(
path: string, requestInit?: RequestInit): IOHandler {
return new BrowserHTTPRequest(path, requestInit);
}
Loading

0 comments on commit 5980684

Please sign in to comment.