Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Permalink
Enable passing custom fetch function to BrowserHTTPRequest constructor (
Browse files Browse the repository at this point in the history
  • Loading branch information
caisq authored Dec 6, 2018
1 parent e32084d commit 2ff431b
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 19 deletions.
47 changes: 31 additions & 16 deletions src/io/browser_http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,29 @@ export class BrowserHTTPRequest implements IOHandler {
protected readonly path: string|string[];
protected readonly requestInit: RequestInit;

private readonly fetchFunc: Function;

readonly DEFAULT_METHOD = 'POST';

static readonly URL_SCHEME_REGEX = /^https?:\/\//;

constructor(
path: string|string[], requestInit?: RequestInit,
private readonly weightPathPrefix?: string) {
if (typeof fetch === 'undefined') {
throw new Error(
// tslint:disable-next-line:max-line-length
'browserHTTPRequest is not supported outside the web browser without a fetch polyfill.');
private readonly weightPathPrefix?: string, fetchFunc?: Function) {
if (fetchFunc == null) {
if (typeof fetch === 'undefined') {
throw new Error(
'browserHTTPRequest is not supported outside the web browser ' +
'without a fetch polyfill.');
}
this.fetchFunc = fetch;
} else {
assert(
typeof fetchFunc === 'function',
'Must pass a function that matches the signature of ' +
'`fetch` (see ' +
'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)');
this.fetchFunc = fetchFunc;
}

assert(
Expand Down Expand Up @@ -98,7 +110,7 @@ export class BrowserHTTPRequest implements IOHandler {
'model.weights.bin');
}

const response = await fetch(this.path as string, init);
const response = await this.fetchFunc(this.path as string, init);

if (response.ok) {
return {
Expand Down Expand Up @@ -130,7 +142,7 @@ export class BrowserHTTPRequest implements IOHandler {
*/
private async loadBinaryTopology(): Promise<ArrayBuffer> {
try {
const response = await fetch(this.path[0], this.requestInit);
const response = await this.fetchFunc(this.path[0], this.requestInit);
if (!response.ok) {
throw new Error(
`BrowserHTTPRequest.load() failed due to HTTP response: ${
Expand All @@ -144,7 +156,8 @@ export class BrowserHTTPRequest implements IOHandler {

protected async loadBinaryModel(): Promise<ModelArtifacts> {
const graphPromise = this.loadBinaryTopology();
const manifestPromise = await fetch(this.path[1], this.requestInit);
const manifestPromise =
await this.fetchFunc(this.path[1], this.requestInit);
if (!manifestPromise.ok) {
throw new Error(`BrowserHTTPRequest.load() failed due to HTTP response: ${
manifestPromise.statusText}`);
Expand All @@ -168,7 +181,7 @@ export class BrowserHTTPRequest implements IOHandler {

protected async loadJSONModel(): Promise<ModelArtifacts> {
const modelConfigRequest =
await fetch(this.path as string, this.requestInit);
await this.fetchFunc(this.path as string, this.requestInit);
if (!modelConfigRequest.ok) {
throw new Error(`BrowserHTTPRequest.load() failed due to HTTP response: ${
modelConfigRequest.statusText}`);
Expand Down Expand Up @@ -216,8 +229,8 @@ export class BrowserHTTPRequest implements IOHandler {

return [
weightSpecs,
concatenateArrayBuffers(
await loadWeightsAsArrayBuffer(fetchURLs, this.requestInit))
concatenateArrayBuffers(await loadWeightsAsArrayBuffer(
fetchURLs, this.requestInit, this.fetchFunc))
];
}
}
Expand All @@ -242,7 +255,7 @@ export function parseUrl(url: string): [string, string] {
return [prefix + '/', suffix];
}

function isHTTPScheme(url: string): boolean {
export function isHTTPScheme(url: string): boolean {
return url.match(BrowserHTTPRequest.URL_SCHEME_REGEX) != null;
}

Expand Down Expand Up @@ -404,11 +417,13 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
* 'model.weights.bin') will be appended to the body. If `requestInit` has a
* `body`, an Error will be thrown.
* @param weightPathPrefix Optional, this specifies the path prefix for weight
* files, by default this is calculated from the path param.
* files, by default this is calculated from the path param.
* @param fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
* the `fetch` from node-fetch can be used here.
* @returns An instance of `IOHandler`.
*/
export function browserHTTPRequest(
path: string|string[], requestInit?: RequestInit,
weightPathPrefix?: string): IOHandler {
return new BrowserHTTPRequest(path, requestInit, weightPathPrefix);
path: string|string[], requestInit?: RequestInit, weightPathPrefix?: string,
fetchFunc?: Function): IOHandler {
return new BrowserHTTPRequest(path, requestInit, weightPathPrefix, fetchFunc);
}
52 changes: 52 additions & 0 deletions src/io/browser_http_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1032,4 +1032,56 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
expect(() => tf.io.browserHTTPRequest(['path1/model.pb'])).toThrow();
});
});

it('Overriding BrowserHTTPRequest fetchFunc', async () => {
const weightManifest1: tf.io.WeightsManifestConfig = [{
paths: ['weightfile0'],
weights: [
{
name: 'dense/kernel',
shape: [3, 1],
dtype: 'float32',
},
{
name: 'dense/bias',
shape: [2],
dtype: 'float32',
}
]
}];
const floatData = new Float32Array([1, 3, 3, 7, 4]);

const fetchInputs: RequestInfo[] = [];
const fetchInits: RequestInit[] = [];
async function customFetch(
input: RequestInfo, init?: RequestInit): Promise<Response> {
fetchInputs.push(input);
fetchInits.push(init);

if (input === './model.json') {
return new Response(
JSON.stringify({
modelTopology: modelTopology1,
weightsManifest: weightManifest1
}),
{status: 200});
} else if (input === './weightfile0') {
return new Response(floatData, {status: 200});
} else {
return new Response(null, {status: 404});
}
}

const handler = tf.io.browserHTTPRequest(
'./model.json', {credentials: 'include'}, null, customFetch);
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);

expect(fetchInputs).toEqual(['./model.json', './weightfile0']);
expect(fetchInits).toEqual([
{credentials: 'include'}, {credentials: 'include'}
]);
});
});
3 changes: 2 additions & 1 deletion src/io/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import './indexed_db';
import './local_storage';

import {browserFiles} from './browser_files';
import {browserHTTPRequest} from './browser_http';
import {browserHTTPRequest, isHTTPScheme} from './browser_http';
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils';
import {fromMemory, withSaveHandler} from './passthrough';
import {IORouterRegistry} from './router_registry';
Expand All @@ -46,6 +46,7 @@ export {
getModelArtifactsInfoForJSON,
getSaveHandlers,
IOHandler,
isHTTPScheme,
LoadHandler,
loadWeights,
ModelArtifacts,
Expand Down
11 changes: 9 additions & 2 deletions src/io/weights_loader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,20 @@ import {DTYPE_VALUE_SIZE_MAP, WeightsManifestConfig, WeightsManifestEntry} from
*
* @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.
* @param requestOptions RequestInit (options) for the HTTP requests.
* @param fetchFunc Optional overriding value for the `window.fetch` function.
* @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same
* length as `fetchURLs`.
*/
export async function loadWeightsAsArrayBuffer(
fetchURLs: string[], requestOptions?: RequestInit): Promise<ArrayBuffer[]> {
fetchURLs: string[], requestOptions?: RequestInit, fetchFunc?: Function):
Promise<ArrayBuffer[]> {
if (fetchFunc == null) {
fetchFunc = fetch;
}

// Create the requests for all of the weights in parallel.
const requests = fetchURLs.map(fetchURL => fetch(fetchURL, requestOptions));
const requests = fetchURLs.map(
fetchURL => fetchFunc(fetchURL, requestOptions));
const responses = await Promise.all(requests);
const buffers =
await Promise.all(responses.map(response => response.arrayBuffer()));
Expand Down

0 comments on commit 2ff431b

Please sign in to comment.