spa/.claude/skills/thread-manager/node_modules/onnxruntime-web/lib/wasm/session-handler.js

99 lines
3.9 KiB
JavaScript

"use strict";
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
Object.defineProperty(exports, "__esModule", { value: true });
exports.OnnxruntimeWebAssemblySessionHandler = void 0;
const fs_1 = require("fs");
const onnxruntime_common_1 = require("onnxruntime-common");
const util_1 = require("util");
const proxy_wrapper_1 = require("./proxy-wrapper");
// Set to true once `initOrt` has resolved inside `loadModel`; later `loadModel` calls see it and
// skip runtime initialization.
let ortInit;
// Numeric code passed to `initOrt` for each supported `env.logLevel` name.
const logLevelCodes = new Map([
    ['verbose', 0],
    ['info', 1],
    ['warning', 2],
    ['error', 3],
    ['fatal', 4],
]);
/**
 * Translates an `env.logLevel` name into the numeric code handed to `initOrt`.
 *
 * @param {string} logLevel - one of 'verbose', 'info', 'warning', 'error' or 'fatal'.
 * @returns {number} 0 for 'verbose' up to 4 for 'fatal'.
 * @throws {Error} when `logLevel` is not one of the supported names.
 */
const getLogLevel = (logLevel) => {
    const code = logLevelCodes.get(logLevel);
    if (code === undefined) {
        throw new Error(`unsupported logging level: ${logLevel}`);
    }
    return code;
};
/**
 * Session handler that creates and runs ONNX inference sessions through the functions exported by
 * `./proxy-wrapper` (initOrt, createSession*, run, releaseSession, endProfiling).
 *
 * After `loadModel` resolves, `sessionId`, `inputNames` and `outputNames` describe the session.
 */
class OnnxruntimeWebAssemblySessionHandler {
    /**
     * Fetches a model over the network and hands its bytes to the proxy's `createSessionAllocate`.
     *
     * The ArrayBuffer holding the HTTP response body is only referenced locally, so it can be
     * garbage-collected once this method returns.
     *
     * @param {string} path - URL of the model to fetch.
     * @returns {Promise<*>} the value produced by the proxy's `createSessionAllocate`; `loadModel`
     *   passes it on to `createSessionFinalize`.
     * @throws {Error} when the HTTP response is not successful (`response.ok` is false).
     */
    async createSessionAllocate(path) {
        const response = await fetch(path);
        if (!response.ok) {
            // Without this check an error page (e.g. a 404 HTML body) would be copied into the
            // wasm heap as model bytes and fail later with a far less helpful message.
            throw new Error(`failed to fetch model '${path}': ${response.status} ${response.statusText}`);
        }
        const arrayBuffer = await response.arrayBuffer();
        return (0, proxy_wrapper_1.createSessionAllocate)(new Uint8Array(arrayBuffer));
    }
    /**
     * Creates an inference session from a model path/URL or from in-memory model bytes.
     *
     * The first call in this module initializes the runtime with `env.wasm.numThreads` and
     * `env.logLevel`. A string argument is read from the file system when `fetch` is unavailable
     * (Node.js) and fetched over the network otherwise. Any non-string argument is passed directly
     * to the proxy's `createSession`.
     *
     * @param {string|Uint8Array} pathOrBuffer - model file path / URL, or the serialized model.
     * @param {object} [options] - session options forwarded unchanged to the proxy.
     * @returns {Promise<void>} resolves once `sessionId`, `inputNames` and `outputNames` are set.
     */
    async loadModel(pathOrBuffer, options) {
        if (!ortInit) {
            await (0, proxy_wrapper_1.initOrt)(onnxruntime_common_1.env.wasm.numThreads, getLogLevel(onnxruntime_common_1.env.logLevel));
            ortInit = true;
        }
        let sessionInfo;
        if (typeof pathOrBuffer !== 'string') {
            sessionInfo = await (0, proxy_wrapper_1.createSession)(pathOrBuffer, options);
        }
        else if (typeof fetch === 'undefined') {
            // Node.js: no fetch available, read the model from the local file system.
            const model = await (0, util_1.promisify)(fs_1.readFile)(pathOrBuffer);
            sessionInfo = await (0, proxy_wrapper_1.createSession)(model, options);
        }
        else {
            // Browser: fetch the model and move it onto the wasm heap, then finalize the session.
            const modelData = await this.createSessionAllocate(pathOrBuffer);
            sessionInfo = await (0, proxy_wrapper_1.createSessionFinalize)(modelData, options);
        }
        [this.sessionId, this.inputNames, this.outputNames] = sessionInfo;
    }
    /**
     * Releases the session created by `loadModel`.
     *
     * @returns {Promise<*>} the promise returned by the proxy's `releaseSession`.
     */
    async dispose() {
        return (0, proxy_wrapper_1.releaseSession)(this.sessionId);
    }
    /**
     * Runs inference on the loaded session.
     *
     * @param {Object<string, *>} feeds - input tensors keyed by input name; each must expose
     *   `type`, `dims` and `data`.
     * @param {Object<string, *>} fetches - requested outputs keyed by output name (values are ignored).
     * @param {object} [options] - run options forwarded unchanged to the proxy.
     * @returns {Promise<Object<string, Tensor>>} one Tensor per requested output, keyed by name.
     * @throws {Error} when a feed or fetch name is not an input/output of the session.
     */
    async run(feeds, fetches, options) {
        const inputArray = [];
        const inputIndices = [];
        for (const [name, tensor] of Object.entries(feeds)) {
            const index = this.inputNames.indexOf(name);
            if (index === -1) {
                throw new Error(`invalid input '${name}'`);
            }
            inputArray.push(tensor);
            inputIndices.push(index);
        }
        const outputIndices = [];
        for (const name of Object.keys(fetches)) {
            // TODO: support pre-allocated output
            const index = this.outputNames.indexOf(name);
            if (index === -1) {
                throw new Error(`invalid output '${name}'`);
            }
            outputIndices.push(index);
        }
        // Tensors are sent to the proxy as [type, dims, data] tuples; outputs come back in the same shape.
        const outputs = await (0, proxy_wrapper_1.run)(this.sessionId, inputIndices, inputArray.map(t => [t.type, t.dims, t.data]), outputIndices, options);
        const result = {};
        for (let i = 0; i < outputs.length; i++) {
            const [type, dims, data] = outputs[i];
            result[this.outputNames[outputIndices[i]]] = new onnxruntime_common_1.Tensor(type, data, dims);
        }
        return result;
    }
    /** Profiling is not implemented for this backend; this is a no-op. */
    startProfiling() {
        // TODO: implement profiling
    }
    /**
     * Asks the proxy to end profiling for this session.
     *
     * The returned promise is deliberately not awaited (`void`), so this method returns immediately
     * and a rejection is not handled here.
     */
    endProfiling() {
        void (0, proxy_wrapper_1.endProfiling)(this.sessionId);
    }
}
// CommonJS export of the session handler class (the `exports` binding was pre-declared near the top of the file).
exports.OnnxruntimeWebAssemblySessionHandler = OnnxruntimeWebAssemblySessionHandler;
//# sourceMappingURL=session-handler.js.map