134 lines
6.0 KiB
JavaScript
134 lines
6.0 KiB
JavaScript
"use strict";
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
exports.setSessionOptions = void 0;
|
|
const options_utils_1 = require("./options-utils");
|
|
const string_utils_1 = require("./string-utils");
|
|
const wasm_factory_1 = require("./wasm-factory");
|
|
/**
 * Translates a graph optimization level name into the numeric level code used when
 * creating the native session options.
 *
 * @param {string} graphOptimizationLevel - one of 'disabled', 'basic', 'extended' or 'all'.
 * @returns {number} 0, 1, 2 or 99 respectively.
 * @throws {Error} when the name is not one of the supported levels.
 */
const getGraphOptimzationLevel = (graphOptimizationLevel) => {
    // A Map (rather than a plain object) avoids matching inherited keys such as 'toString'.
    const levelByName = new Map([
        ['disabled', 0],
        ['basic', 1],
        ['extended', 2],
        ['all', 99],
    ]);
    const level = levelByName.get(graphOptimizationLevel);
    if (level === undefined) {
        throw new Error(`unsupported graph optimization level: ${graphOptimizationLevel}`);
    }
    return level;
};
|
|
/**
 * Translates an execution mode name into the numeric code used when creating the
 * native session options.
 *
 * @param {string} executionMode - 'sequential' or 'parallel'.
 * @returns {number} 0 for 'sequential', 1 for 'parallel'.
 * @throws {Error} for any other value.
 */
const getExecutionMode = (executionMode) => {
    if (executionMode === 'sequential') {
        return 0;
    }
    if (executionMode === 'parallel') {
        return 1;
    }
    throw new Error(`unsupported execution mode: ${executionMode}`);
};
|
|
/**
 * Adds default `extra.session` config entries to `options`, mutating it in place.
 *
 * Ensures `options.extra` and `options.extra.session` exist. It then sets
 * `use_ort_model_bytes_directly` to '1' unless the caller already supplied a value.
 * An explicit value such as `false`, `0`, '0' or '' is kept; only a missing
 * (`undefined` / `null`) value is replaced.
 *
 * @param {object} options - session options object to update; must not be null/undefined.
 */
const appendDefaultOptions = (options) => {
    if (!options.extra) {
        options.extra = {};
    }
    if (!options.extra.session) {
        options.extra.session = {};
    }
    const session = options.extra.session;
    // Default only when unset: a truthiness check would silently override an explicit
    // `false` / `0` / '' that a caller uses to opt out of this behavior.
    if (session.use_ort_model_bytes_directly === undefined || session.use_ort_model_bytes_directly === null) {
        // eslint-disable-next-line camelcase
        session.use_ort_model_bytes_directly = '1';
    }
};
|
|
/**
 * Registers the requested execution providers on a native session options handle.
 *
 * 'wasm' and 'cpu' need no registration and are skipped. 'xnnpack' is registered
 * under the native name 'XNNPACK'. Any other provider name is rejected.
 *
 * @param {number} sessionOptionsHandle - native session options handle to append to.
 * @param {Array<string|{name: string}>} executionProviders - provider names, or objects carrying a `name`.
 * @param {number[]} allocs - allocation list passed through to `allocWasmString`
 *   (presumably so the caller can free those strings later).
 * @throws {Error} for an unsupported provider name, or when the native append call
 *   returns a non-zero status.
 */
const setExecutionProviders = (sessionOptionsHandle, executionProviders, allocs) => {
    for (const provider of executionProviders) {
        const requestedName = typeof provider === 'string' ? provider : provider.name;
        if (requestedName === 'wasm' || requestedName === 'cpu') {
            // Nothing to register for these providers.
            continue;
        }
        if (requestedName !== 'xnnpack') {
            throw new Error(`not supported EP: ${requestedName}`);
        }
        const nativeName = 'XNNPACK';
        const nameOffset = (0, string_utils_1.allocWasmString)(nativeName, allocs);
        const status = (0, wasm_factory_1.getInstance)()._OrtAppendExecutionProvider(sessionOptionsHandle, nameOffset);
        if (status !== 0) {
            throw new Error(`Can't append execution provider: ${nativeName}`);
        }
    }
};
|
|
/**
 * Builds a native session options object from JavaScript session options.
 *
 * Unset fields get these defaults:
 *   - graphOptimizationLevel 'all'
 *   - enableCpuMemArena true
 *   - enableMemPattern true
 *   - executionMode 'sequential'
 *   - logSeverityLevel 2 (warning)
 *   - logVerbosityLevel 0
 *   - enableProfiling false
 * When an options object is passed, these defaults are written back into it.
 * `appendDefaultOptions` also adds default `extra.session` entries. All `extra` entries are
 * forwarded to native code, including when `options` is undefined/null.
 *
 * @param {object} [options] - session options; `undefined`/`null` means "use defaults".
 * @returns {[number, number[]]} the native session options handle, and the wasm offsets
 *   allocated while building it (returned so the caller can free them).
 * @throws {Error} if an option value is invalid or a native call reports failure. Before
 *   rethrowing, the native handle (if created) is released and every allocation is freed.
 */
const setSessionOptions = (options) => {
    const wasm = (0, wasm_factory_1.getInstance)();
    let sessionOptionsHandle = 0;
    const allocs = [];
    const sessionOptions = options || {};
    appendDefaultOptions(sessionOptions);
    try {
        // Every read goes through `sessionOptions`, not `options`. That way the defaults,
        // including the `extra` entries added above, also apply when no options were given.
        if (sessionOptions.graphOptimizationLevel === undefined) {
            sessionOptions.graphOptimizationLevel = 'all';
        }
        const graphOptimizationLevel = getGraphOptimzationLevel(sessionOptions.graphOptimizationLevel);
        if (sessionOptions.enableCpuMemArena === undefined) {
            sessionOptions.enableCpuMemArena = true;
        }
        if (sessionOptions.enableMemPattern === undefined) {
            sessionOptions.enableMemPattern = true;
        }
        if (sessionOptions.executionMode === undefined) {
            sessionOptions.executionMode = 'sequential';
        }
        const executionMode = getExecutionMode(sessionOptions.executionMode);
        let logIdDataOffset = 0;
        if (sessionOptions.logId !== undefined) {
            logIdDataOffset = (0, string_utils_1.allocWasmString)(sessionOptions.logId, allocs);
        }
        if (sessionOptions.logSeverityLevel === undefined) {
            sessionOptions.logSeverityLevel = 2; // Default to warning
        }
        else if (!Number.isInteger(sessionOptions.logSeverityLevel) ||
            sessionOptions.logSeverityLevel < 0 || sessionOptions.logSeverityLevel > 4) {
            // Number.isInteger is false for non-numbers, so no separate typeof check is needed.
            throw new Error(`log severity level is not valid: ${sessionOptions.logSeverityLevel}`);
        }
        if (sessionOptions.logVerbosityLevel === undefined) {
            sessionOptions.logVerbosityLevel = 0; // Default to 0
        }
        else if (!Number.isInteger(sessionOptions.logVerbosityLevel)) {
            throw new Error(`log verbosity level is not valid: ${sessionOptions.logVerbosityLevel}`);
        }
        if (sessionOptions.enableProfiling === undefined) {
            sessionOptions.enableProfiling = false;
        }
        sessionOptionsHandle = wasm._OrtCreateSessionOptions(graphOptimizationLevel, !!sessionOptions.enableCpuMemArena, !!sessionOptions.enableMemPattern, executionMode, !!sessionOptions.enableProfiling, 0, logIdDataOffset, sessionOptions.logSeverityLevel, sessionOptions.logVerbosityLevel);
        if (sessionOptionsHandle === 0) {
            throw new Error('Can\'t create session options');
        }
        if (sessionOptions.executionProviders) {
            setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs);
        }
        if (sessionOptions.extra !== undefined) {
            (0, options_utils_1.iterateExtraOptions)(sessionOptions.extra, '', new WeakSet(), (key, value) => {
                const keyDataOffset = (0, string_utils_1.allocWasmString)(key, allocs);
                const valueDataOffset = (0, string_utils_1.allocWasmString)(value, allocs);
                if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
                    throw new Error(`Can't set a session config entry: ${key} - ${value}`);
                }
            });
        }
        return [sessionOptionsHandle, allocs];
    }
    catch (e) {
        if (sessionOptionsHandle !== 0) {
            wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
        }
        // Call _free with exactly one offset. Passing `wasm._free` straight to forEach would
        // also hand it the index and the array.
        allocs.forEach((alloc) => wasm._free(alloc));
        throw e;
    }
};
|
|
exports.setSessionOptions = setSessionOptions;
|
|
//# sourceMappingURL=session-options.js.map
|