99 lines
3.9 KiB
JavaScript
99 lines
3.9 KiB
JavaScript
"use strict";
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
exports.OnnxruntimeWebAssemblySessionHandler = void 0;
|
|
const fs_1 = require("fs");
|
|
const onnxruntime_common_1 = require("onnxruntime-common");
|
|
const util_1 = require("util");
|
|
const proxy_wrapper_1 = require("./proxy-wrapper");
|
|
// Module-wide ORT WebAssembly runtime initialization state, shared by every
// OnnxruntimeWebAssemblySessionHandler in this module. It is falsy until loadModel()
// initializes the runtime via proxy-wrapper's initOrt().
let ortInit;
|
|
// Numeric log levels keyed by the level names accepted for onnxruntime-common's env.logLevel.
// A Map (rather than an object literal) ensures inherited keys such as 'toString' are not matched.
const logLevelValues = new Map([
    ['verbose', 0],
    ['info', 1],
    ['warning', 2],
    ['error', 3],
    ['fatal', 4],
]);
/**
 * Converts a log level name into the numeric level that loadModel() passes to initOrt().
 *
 * @param {string} logLevel - one of 'verbose', 'info', 'warning', 'error' or 'fatal' (case-sensitive).
 * @returns {number} 0 for 'verbose' up to 4 for 'fatal'.
 * @throws {Error} if logLevel is not one of the names above.
 */
const getLogLevel = (logLevel) => {
    const level = logLevelValues.get(logLevel);
    if (level === undefined) {
        throw new Error(`unsupported logging level: ${logLevel}`);
    }
    return level;
};
|
|
/**
 * Session handler backed by the ORT WebAssembly runtime.
 *
 * All real work is delegated to the functions of './proxy-wrapper'. This class keeps the session id
 * and the input/output name lists returned when a session is created. It translates between
 * name-keyed feeds/fetches and the index-based arrays that the proxy functions take.
 */
class OnnxruntimeWebAssemblySessionHandler {
    /**
     * Fetches a model from a URL and hands its bytes to proxy-wrapper's createSessionAllocate().
     * The ArrayBuffer that held the HTTP response is only referenced locally, so it is released
     * once this method returns.
     *
     * @param {string} path - URL of the model.
     * @returns {Promise<*>} the value resolved by proxy-wrapper's createSessionAllocate(); loadModel()
     *     passes it on to createSessionFinalize().
     * @throws {Error} if the HTTP response is not successful (status outside 200-299).
     */
    async createSessionAllocate(path) {
        const response = await fetch(path);
        // Without this check an error page (e.g. a 404 body) would be passed to the runtime as model bytes.
        if (!response.ok) {
            throw new Error(`failed to load model: ${path} (HTTP status ${response.status})`);
        }
        const arrayBuffer = await response.arrayBuffer();
        return (0, proxy_wrapper_1.createSessionAllocate)(new Uint8Array(arrayBuffer));
    }
    /**
     * Initializes the ORT WebAssembly runtime on first use, then creates a session.
     * On success sets this.sessionId, this.inputNames and this.outputNames.
     *
     * @param {string|*} pathOrBuffer - a model file path (Node.js), a model URL (other environments),
     *     or the model bytes themselves (passed unchanged to proxy-wrapper's createSession()).
     * @param {object} [options] - session options forwarded unchanged to the proxy.
     * @throws {Error} if env.logLevel is unsupported, or if runtime initialization or session creation fails.
     */
    async loadModel(pathOrBuffer, options) {
        if (!ortInit) {
            // Cache the in-flight promise (rather than a flag set after completion) so concurrent
            // loadModel() calls share one initOrt() call instead of each starting their own.
            ortInit = Promise.resolve((0, proxy_wrapper_1.initOrt)(onnxruntime_common_1.env.wasm.numThreads, getLogLevel(onnxruntime_common_1.env.logLevel)));
        }
        const pendingInit = ortInit;
        try {
            await pendingInit;
        }
        catch (e) {
            // Forget a failed initialization so a later call can retry. The identity check avoids
            // clobbering a newer attempt started by another caller in the meantime.
            if (ortInit === pendingInit) {
                ortInit = undefined;
            }
            throw e;
        }
        if (typeof pathOrBuffer === 'string') {
            // Detect Node.js from process.versions.node rather than from a missing global fetch:
            // Node.js 18+ defines fetch, which would otherwise send local file paths to fetch() and fail.
            const isNode = typeof process !== 'undefined' && !!process.versions && !!process.versions.node;
            if (isNode) {
                const model = await (0, util_1.promisify)(fs_1.readFile)(pathOrBuffer);
                [this.sessionId, this.inputNames, this.outputNames] = await (0, proxy_wrapper_1.createSession)(model, options);
            }
            else {
                // Browser: fetch the model and move it to the wasm heap, then create the session.
                const modelData = await this.createSessionAllocate(pathOrBuffer);
                [this.sessionId, this.inputNames, this.outputNames] = await (0, proxy_wrapper_1.createSessionFinalize)(modelData, options);
            }
        }
        else {
            [this.sessionId, this.inputNames, this.outputNames] = await (0, proxy_wrapper_1.createSession)(pathOrBuffer, options);
        }
    }
    /**
     * Releases the session identified by this.sessionId through the proxy.
     */
    async dispose() {
        return (0, proxy_wrapper_1.releaseSession)(this.sessionId);
    }
    /**
     * Runs the session.
     *
     * @param {object} feeds - input tensors keyed by input name; each must expose type, dims and data.
     * @param {object} fetches - object whose keys are the requested output names (values are ignored).
     * @param {object} [options] - run options forwarded unchanged to the proxy.
     * @returns {Promise<object>} a Tensor for each requested output, keyed by output name.
     * @throws {Error} `invalid input '<name>'` / `invalid output '<name>'` for names unknown to the session.
     */
    async run(feeds, fetches, options) {
        const inputArray = [];
        const inputIndices = [];
        for (const [name, tensor] of Object.entries(feeds)) {
            const index = this.inputNames.indexOf(name);
            if (index === -1) {
                throw new Error(`invalid input '${name}'`);
            }
            inputArray.push(tensor);
            inputIndices.push(index);
        }
        const outputIndices = [];
        // TODO: support pre-allocated output (the values of `fetches` are currently ignored).
        for (const name of Object.keys(fetches)) {
            const index = this.outputNames.indexOf(name);
            if (index === -1) {
                throw new Error(`invalid output '${name}'`);
            }
            outputIndices.push(index);
        }
        const outputs = await (0, proxy_wrapper_1.run)(this.sessionId, inputIndices, inputArray.map(t => [t.type, t.dims, t.data]), outputIndices, options);
        const result = {};
        // The proxy returns [type, dims, data] tuples in the order of outputIndices; the Tensor
        // constructor takes (type, data, dims).
        for (let i = 0; i < outputs.length; i++) {
            const [type, dims, data] = outputs[i];
            result[this.outputNames[outputIndices[i]]] = new onnxruntime_common_1.Tensor(type, data, dims);
        }
        return result;
    }
    startProfiling() {
        // TODO: implement profiling
    }
    /**
     * Asks the proxy to end profiling for this session. This method is synchronous, so the proxy's
     * result (presumably a promise) is intentionally discarded with `void` rather than awaited.
     */
    endProfiling() {
        void (0, proxy_wrapper_1.endProfiling)(this.sessionId);
    }
}
|
|
exports.OnnxruntimeWebAssemblySessionHandler = OnnxruntimeWebAssemblySessionHandler;
|
|
//# sourceMappingURL=session-handler.js.map
|