Using TensorFlow.js with web workers in Angular
TensorFlow.js is a JavaScript library for training and deploying machine learning models. In combination with Angular it can be used to build elegant web apps, for example for in-browser inference. The only problem: tensor operations are computationally expensive, and JavaScript runs single-threaded in the browser. Every time the main thread is occupied by an expensive, long-running tensor operation, the UI is blocked and users cannot interact with the web app. To avoid this, we need to run our expensive operations in separate background threads with web workers.
Angular supports creating web workers with the CLI: ng generate web-worker <location>. This creates a worker.ts scaffold and configures the project for web workers if it isn't already.
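The generated app.worker.ts scaffold looks roughly like this (the exact content may vary slightly between Angular versions) and is the place where the tensor operations will later live:
/// <reference lib="webworker" />

addEventListener('message', ({ data }) => {
  const response = `worker response to ${data}`;
  postMessage(response);
});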
Serializing and deserializing tensors for web workers
In a service, the web worker is invoked with the tensor you want to transform as a parameter. All data sent to the worker has to be serialized and deserialized again inside the worker; this also applies to tf.Tensor:
import {tensor3d} from "@tensorflow/tfjs";

const worker = new Worker(new URL('./app.worker', import.meta.url));
worker.onmessage = ({data}) => {
  // deserialize the result sent back by the worker
  const transformedTensor = tensor3d(data.serializedTensor.data, data.serializedTensor.shape);
  console.log(`worker result: ${transformedTensor}`);
};
// serialize the tensor before sending it to the worker
worker.postMessage({serializedTensor: {data: tensor.dataSync(), shape: tensor.shape}});
Within the web worker, expensive tensor operations can be performed that would otherwise block the main thread. To avoid memory leaks, the web worker should be terminated, either from the service with worker.terminate() or within the worker with close():
/// <reference lib="webworker" />
import {tensor3d, Tensor3D} from "@tensorflow/tfjs";

addEventListener('message', ({data}) => {
  const tensor: Tensor3D = tensor3d(data.serializedTensor.data, data.serializedTensor.shape); // deserialize
  const transformedTensor = someExpensiveTensorComputation(tensor);
  postMessage({serializedTensor: {data: transformedTensor.dataSync(), shape: transformedTensor.shape}}); // pass the result back to the service
  close(); // close the web worker to prevent memory leaks
});
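If you would rather control the worker's lifecycle from the service, a minimal sketch of the terminate() variant (reusing the onmessage handler from the service code above) could look like this:
worker.onmessage = ({data}) => {
  const transformedTensor = tensor3d(data.serializedTensor.data, data.serializedTensor.shape);
  worker.terminate(); // shut the worker down from the service instead of calling close() inside it
  // ...continue working with transformedTensor
};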
Serializing models
Inference with tf.GraphModel.predict() is also performed on the main thread and will block the UI. To execute it in the background with a web worker, you can either save the model to IndexedDB or localStorage and load it inside the web worker (a short sketch of this variant follows after the next snippet), or you can serialize the model and deserialize it in the worker. The latter approach can be achieved by serializing the model's weightData ArrayBuffer to a Base64 string:
import {GraphModel} from "@tensorflow/tfjs";
import * as base64 from "base64-arraybuffer";

serializeModel(model: GraphModel): string {
  const artifacts = (<any>model).artifacts; // access the model's internal artifacts
  artifacts.weightData = base64.encode(artifacts.weightData); // encode the ArrayBuffer as a Base64 string
  return JSON.stringify(artifacts);
}

deserializeModel(json: string): GraphModel {
  const artifacts = JSON.parse(json);
  artifacts.weightData = base64.decode(artifacts.weightData); // decode the Base64 string back into an ArrayBuffer
  const model = new GraphModel({}); // dummy handler, the artifacts are loaded manually below
  model.loadSync(artifacts);
  return model;
}
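For completeness, here is a minimal sketch of the IndexedDB variant mentioned above, assuming a tfjs version where GraphModel.save() is available; the key 'indexeddb://my-model' is just an example:
// service (main thread): persist the model once under an IndexedDB key
await model.save('indexeddb://my-model');
worker.postMessage({modelKey: 'indexeddb://my-model'});

// app.worker.ts: load the model from IndexedDB and run inference in the background
import {loadGraphModel} from "@tensorflow/tfjs";

addEventListener('message', async ({data}) => {
  const model = await loadGraphModel(data.modelKey);
  // ...call model.predict(...) here without blocking the UI
});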
Some improvements and syntactic sugar
Sometimes it is useful to wrap the call to the web worker in a convenient Promise that can be passed around:
expensiveTensorOperation(tensor: Tensor3D): Promise<Tensor3D> {
  const promise: Promise<Tensor3D> = new Promise((resolve, reject) => {
    worker.onmessage = ({data}) => {
      // deserialize the worker's result before resolving
      resolve(tensor3d(data.serializedTensor.data, data.serializedTensor.shape));
    };
    worker.onerror = (e: ErrorEvent) => {
      e.preventDefault(); // prevent the exception from propagating out of the web worker
      reject(new Error(e.message));
    };
  });
  // serialize the tensor before sending it to the worker
  worker.postMessage({serializedTensor: {data: tensor.dataSync(), shape: tensor.shape}});
  return promise;
}
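Assuming the method lives in a service that is injected as tensorWorkerService (the names are just placeholders), the call site then becomes a plain await:
const result: Tensor3D = await this.tensorWorkerService.expensiveTensorOperation(inputTensor);
result.print(); // the UI stayed responsive while the worker was busy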
If you have many operations you want to offload to web workers, you can create a dedicated class to pass the serialized tensors around as value objects:
import {Rank, ShapeMap, Tensor, tensor} from "@tensorflow/tfjs";

export class STensor<R extends Rank = Rank> {
  private data: Uint8Array | Int32Array | Float32Array;
  private shape: ShapeMap[R];

  constructor(tensor: Tensor<R>) {
    this.data = tensor.dataSync(); // serialize the tensor's values
    this.shape = tensor.shape;
  }

  public static deserialize<R extends Rank = Rank>(serializedTensor: STensor<R>): Tensor<R> {
    return tensor(serializedTensor.data, serializedTensor.shape as ShapeMap[R]);
  }
}
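A hedged sketch of how service and worker could use the value object; someExpensiveTensorComputation is the same placeholder as above, and STensor has to be imported on both sides:
// service: serialize and send
worker.postMessage(new STensor(tensor));

// app.worker.ts: deserialize, transform, and send the result back
addEventListener('message', ({data}) => {
  const input = STensor.deserialize<Rank.R3>(data);
  const transformed = someExpensiveTensorComputation(input);
  postMessage(new STensor(transformed));
});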