forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tf.io.browserHTTPRequest (tensorflow#1030)
* Add tf.io.browserHTTPRequest * Allows model artifacts to be sent via a multipart/form-data HTTP request Towards: tensorflow#13
- Loading branch information
Showing
6 changed files
with
527 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
/** | ||
* @license | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* ============================================================================= | ||
*/ | ||
|
||
/** | ||
* IOHandler implementations based on HTTP requests in the web browser. | ||
* | ||
* Uses [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). | ||
*/ | ||
|
||
import {assert} from '../util'; | ||
|
||
import {getModelArtifactsInfoForKerasJSON} from './io_utils'; | ||
// tslint:disable-next-line:max-line-length | ||
import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestConfig} from './types'; | ||
|
||
class BrowserHTTPRequest implements IOHandler { | ||
protected readonly path: string; | ||
protected readonly requestInit: RequestInit; | ||
|
||
readonly DEFAULT_METHOD = 'POST'; | ||
|
||
constructor(path: string, requestInit?: RequestInit) { | ||
assert( | ||
path != null && path.length > 0, | ||
'URL path for browserHTTPRequest must not be null, undefined or ' + | ||
'empty.'); | ||
this.path = path; | ||
|
||
if (requestInit != null && requestInit.body != null) { | ||
throw new Error( | ||
'requestInit is expected to have no pre-existing body, but has one.'); | ||
} | ||
this.requestInit = requestInit || {}; | ||
} | ||
|
||
async save(modelArtifacts: ModelArtifacts): Promise<SaveResult> { | ||
if (modelArtifacts.modelTopology instanceof ArrayBuffer) { | ||
throw new Error( | ||
'BrowserHTTPRequest.save() does not support saving model topology ' + | ||
'in binary formats yet.'); | ||
} | ||
|
||
const init = Object.assign({method: this.DEFAULT_METHOD}, this.requestInit); | ||
init.body = new FormData(); | ||
|
||
const weightsManifest: WeightsManifestConfig = [{ | ||
paths: ['./model.weights.bin'], | ||
weights: modelArtifacts.weightSpecs, | ||
}]; | ||
const modelTopologyAndWeightManifest = { | ||
modelTopology: modelArtifacts.modelTopology, | ||
weightsManifest | ||
}; | ||
|
||
init.body.append( | ||
'model.json', | ||
new Blob( | ||
[JSON.stringify(modelTopologyAndWeightManifest)], | ||
{type: 'application/json'}), | ||
'model.json'); | ||
|
||
if (modelArtifacts.weightData != null) { | ||
init.body.append( | ||
'model.weights.bin', | ||
new Blob( | ||
[modelArtifacts.weightData], {type: 'application/octet-stream'}), | ||
'model.weights.bin'); | ||
} | ||
|
||
const response = await fetch(this.path, init); | ||
|
||
if (response.status === 200) { | ||
return { | ||
modelArtifactsInfo: getModelArtifactsInfoForKerasJSON(modelArtifacts), | ||
responses: [response], | ||
}; | ||
} else { | ||
throw new Error( | ||
`BrowserHTTPRequest.save() failed due to HTTP response status ` + | ||
`${response.status}.`); | ||
} | ||
} | ||
|
||
// TODO(cais): Add load to unify this IOHandler type and the mechanism | ||
// that currently underlies `tf.loadModel('path')` in tfjs-layers. | ||
// See: https://github.com/tensorflow/tfjs/issues/290 | ||
} | ||
|
||
// tslint:disable:max-line-length | ||
/** | ||
* Creates an IOHandler subtype that sends model artifacts to HTTP server. | ||
* | ||
* An HTTP request of the `multipart/form-data` mime type will be sent to the | ||
* `path` URL. The form data includes artifacts that represent the topology | ||
* and/or weights of the model. In the case of Keras-style `tf.Model`, two | ||
* blobs (files) exist in form-data: | ||
* - A JSON file consisting of `modelTopology` and `weightsManifest`. | ||
* - A binary weights file consisting of the concatenated weight values. | ||
* These files are in the same format as the one generated by | ||
* [tensorflowjs_converter](https://js.tensorflow.org/tutorials/import-keras.html). | ||
* | ||
* The following code snippet exemplifies the client-side code that uses this | ||
* function: | ||
* | ||
* ```js | ||
* const model = tf.sequential(); | ||
* model.add( | ||
* tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'})); | ||
* | ||
* const saveResult = await model.save(tf.io.browserHTTPRequest( | ||
* 'http://model-server:5000/upload', {method: 'PUT'})); | ||
* console.log(saveResult); | ||
* ``` | ||
* | ||
* The following Python code snippet based on the | ||
* [flask](https://github.com/pallets/flask) server framework implements a | ||
* server that can receive the request. Upon receiving the model artifacts | ||
* via the requst, this particular server reconsistutes instances of | ||
* [Keras Models](https://keras.io/models/model/) in memory. | ||
* | ||
* ```python | ||
* # pip install -U flask flask-cors keras tensorflow tensorflowjs | ||
* | ||
* from __future__ import absolute_import | ||
* from __future__ import division | ||
* from __future__ import print_function | ||
* | ||
* import io | ||
* | ||
* from flask import Flask, Response, request | ||
* from flask_cors import CORS, cross_origin | ||
* import tensorflow as tf | ||
* import tensorflowjs as tfjs | ||
* import werkzeug.formparser | ||
* | ||
* | ||
* class ModelReceiver(object): | ||
* | ||
* def __init__(self): | ||
* self._model = None | ||
* self._model_json_bytes = None | ||
* self._model_json_writer = None | ||
* self._weight_bytes = None | ||
* self._weight_writer = None | ||
* | ||
* @property | ||
* def model(self): | ||
* self._model_json_writer.flush() | ||
* self._weight_writer.flush() | ||
* self._model_json_writer.seek(0) | ||
* self._weight_writer.seek(0) | ||
* | ||
* json_content = self._model_json_bytes.read() | ||
* weights_content = self._weight_bytes.read() | ||
* return tfjs.converters.deserialize_keras_model( | ||
* json_content, | ||
* weight_data=[weights_content], | ||
* use_unique_name_scope=True) | ||
* | ||
* def stream_factory(self, | ||
* total_content_length, | ||
* content_type, | ||
* filename, | ||
* content_length=None): | ||
* # Note: this example code is *not* thread-safe. | ||
* if filename == 'model.json': | ||
* self._model_json_bytes = io.BytesIO() | ||
* self._model_json_writer = io.BufferedWriter(self._model_json_bytes) | ||
* return self._model_json_writer | ||
* elif filename == 'model.weights.bin': | ||
* self._weight_bytes = io.BytesIO() | ||
* self._weight_writer = io.BufferedWriter(self._weight_bytes) | ||
* return self._weight_writer | ||
* | ||
* | ||
* def main(): | ||
* app = Flask('model-server') | ||
* CORS(app) | ||
* app.config['CORS_HEADER'] = 'Content-Type' | ||
* | ||
* model_receiver = ModelReceiver() | ||
* | ||
* @app.route('/upload', methods=['POST']) | ||
* @cross_origin() | ||
* def upload(): | ||
* print('Handling request...') | ||
* werkzeug.formparser.parse_form_data( | ||
* request.environ, stream_factory=model_receiver.stream_factory) | ||
* print('Received model:') | ||
* with tf.Graph().as_default(), tf.Session(): | ||
* model = model_receiver.model | ||
* model.summary() | ||
* # You can perform `model.predict()`, `model.fit()`, | ||
* # `model.evaluate()` etc. here. | ||
* return Response(status=200) | ||
* | ||
* app.run('localhost', 5000) | ||
* | ||
* | ||
* if __name__ == '__main__': | ||
* main() | ||
* ``` | ||
* | ||
* @param path URL path. Can be an absolute HTTP path (e.g., | ||
* 'http://localhost:8000/model-upload)') or a relative path (e.g., | ||
* './model-upload'). | ||
* @param requestInit Request configurations to be used when sending | ||
* HTTP request to server using `fetch`. It can contain fields such as | ||
* `method`, `credentials`, `headers`, `mode`, etc. See | ||
* https://developer.mozilla.org/en-US/docs/Web/API/Request/Request | ||
* for more information. `requestInit` must not have a body, because the body | ||
* will be set by TensorFlow.js. File blobs representing | ||
* the model topology (filename: 'model.json') and the weights of the | ||
* model (filename: 'model.weights.bin') will be appended to the body. | ||
* If `requestInit` has a `body`, an Error will be thrown. | ||
* @returns An instance of `IOHandler`. | ||
*/ | ||
// tslint:enable:max-line-length | ||
export function browserHTTPRequest( | ||
path: string, requestInit?: RequestInit): IOHandler { | ||
return new BrowserHTTPRequest(path, requestInit); | ||
} |
Oops, something went wrong.