spa/.claude/skills/thread-manager/node_modules/onnxruntime-web/lib/wasm/session-options.ts

157 lines
5.2 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {InferenceSession} from 'onnxruntime-common';
import {iterateExtraOptions} from './options-utils';
import {allocWasmString} from './string-utils';
import {getInstance} from './wasm-factory';
/**
 * Translates a graph optimization level name into the numeric code passed to the native session options.
 *
 * @param graphOptimizationLevel - expected to be 'disabled', 'basic', 'extended' or 'all'.
 * @returns 0, 1, 2 or 99 for the four names above, in that order.
 * @throws Error if the value is not exactly one of the four recognised names.
 */
const getGraphOptimzationLevel = (graphOptimizationLevel: string|unknown): number => {
  // A Map keyed by the level name; lookups use the same exact-match semantics for strings as a switch statement.
  const codeByLevel = new Map<unknown, number>([
    ['disabled', 0],
    ['basic', 1],
    ['extended', 2],
    ['all', 99],
  ]);

  const code = codeByLevel.get(graphOptimizationLevel);
  if (code === undefined) {
    throw new Error(`unsupported graph optimization level: ${graphOptimizationLevel}`);
  }
  return code;
};
/**
 * Translates an execution mode name into the numeric code passed to the native session options.
 *
 * @param executionMode - 'sequential' or 'parallel'.
 * @returns 0 for 'sequential', 1 for 'parallel'.
 * @throws Error if a value outside the declared union reaches this function at runtime.
 */
const getExecutionMode = (executionMode: 'sequential'|'parallel'): number => {
  if (executionMode === 'sequential') {
    return 0;
  }
  if (executionMode === 'parallel') {
    return 1;
  }
  // Only reachable when the static type was bypassed (e.g. untyped JavaScript callers).
  throw new Error(`unsupported execution mode: ${executionMode}`);
};
/**
 * Fills in default extra session configuration on the given options object, in place.
 *
 * Ensures `options.extra` and `options.extra.session` exist, then sets
 * `options.extra.session.use_ort_model_bytes_directly` to '1' unless it already holds a truthy value.
 *
 * @param options - the session options to mutate.
 */
const appendDefaultOptions = (options: InferenceSession.SessionOptions): void => {
  let extra = options.extra;
  if (!extra) {
    extra = {};
    options.extra = extra;
  }

  let sessionConfig = extra.session as Record<string, string>| undefined;
  if (!sessionConfig) {
    sessionConfig = {};
    extra.session = sessionConfig;
  }

  // An existing truthy value (including the string '0') is left untouched.
  if (!sessionConfig.use_ort_model_bytes_directly) {
    // eslint-disable-next-line camelcase
    sessionConfig.use_ort_model_bytes_directly = '1';
  }
};
/**
 * Registers the requested execution providers on a native session-options object.
 *
 * 'wasm' and 'cpu' entries are skipped without any native call. 'xnnpack' is appended under the name 'XNNPACK'.
 * Any other name is rejected.
 *
 * @param sessionOptionsHandle - handle of the native session options to append to.
 * @param executionProviders - provider names, or config objects carrying a `name`.
 * @param allocs - list passed to `allocWasmString` for every provider-name string that is allocated.
 * @throws Error for an unsupported provider name, or when the native append call returns a non-zero status.
 */
const setExecutionProviders =
    (sessionOptionsHandle: number, executionProviders: readonly InferenceSession.ExecutionProviderConfig[],
     allocs: number[]): void => {
      for (const provider of executionProviders) {
        const requestedName = typeof provider === 'string' ? provider : provider.name;

        // These names need no explicit registration here.
        if (requestedName === 'wasm' || requestedName === 'cpu') {
          continue;
        }
        if (requestedName !== 'xnnpack') {
          throw new Error(`not supported EP: ${requestedName}`);
        }

        const nativeName = 'XNNPACK';
        const nativeNameOffset = allocWasmString(nativeName, allocs);
        const status = getInstance()._OrtAppendExecutionProvider(sessionOptionsHandle, nativeNameOffset);
        if (status !== 0) {
          throw new Error(`Can't append execution provider: ${nativeName}`);
        }
      }
    };
/**
 * Builds a native session-options object from the JavaScript-side session options.
 *
 * Unset fields are given defaults: graph optimization 'all', CPU memory arena enabled, memory pattern enabled,
 * 'sequential' execution, log severity 2, log verbosity 0 and profiling disabled. The defaults are written into
 * the object the caller passed in (when one was passed); that existing behaviour is preserved.
 *
 * @param options - optional session options; a fresh empty object is used when omitted.
 * @returns a tuple of the native session-options handle and the list of WASM allocations recorded while building
 *     it. The allocations are returned rather than freed here; presumably the caller frees them once they are no
 *     longer needed - confirm against the call site.
 * @throws Error when an option value is invalid or a native call reports failure. Before rethrowing, the native
 *     handle (if one was created) is released and every allocation recorded so far is freed.
 */
export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => {
  const wasm = getInstance();
  let sessionOptionsHandle = 0;
  const allocs: number[] = [];

  // `sessionOptions` is the caller's object when one was given, otherwise a new one. Everything below reads from
  // `sessionOptions` rather than `options`, so defaults injected here still apply when `options` is undefined.
  const sessionOptions: InferenceSession.SessionOptions = options || {};
  appendDefaultOptions(sessionOptions);

  try {
    if (sessionOptions.graphOptimizationLevel === undefined) {
      sessionOptions.graphOptimizationLevel = 'all';
    }
    const graphOptimizationLevel = getGraphOptimzationLevel(sessionOptions.graphOptimizationLevel);

    if (sessionOptions.enableCpuMemArena === undefined) {
      sessionOptions.enableCpuMemArena = true;
    }

    if (sessionOptions.enableMemPattern === undefined) {
      sessionOptions.enableMemPattern = true;
    }

    if (sessionOptions.executionMode === undefined) {
      sessionOptions.executionMode = 'sequential';
    }
    const executionMode = getExecutionMode(sessionOptions.executionMode);

    // An offset of 0 is passed to the native call when no log id was supplied.
    let logIdDataOffset = 0;
    if (sessionOptions.logId !== undefined) {
      logIdDataOffset = allocWasmString(sessionOptions.logId, allocs);
    }

    if (sessionOptions.logSeverityLevel === undefined) {
      sessionOptions.logSeverityLevel = 2;  // Default to warning
    } else if (
        typeof sessionOptions.logSeverityLevel !== 'number' || !Number.isInteger(sessionOptions.logSeverityLevel) ||
        sessionOptions.logSeverityLevel < 0 || sessionOptions.logSeverityLevel > 4) {
      throw new Error(`log severity level is not valid: ${sessionOptions.logSeverityLevel}`);
    }

    if (sessionOptions.logVerbosityLevel === undefined) {
      sessionOptions.logVerbosityLevel = 0;  // Default to 0
    } else if (
        typeof sessionOptions.logVerbosityLevel !== 'number' ||
        !Number.isInteger(sessionOptions.logVerbosityLevel)) {
      throw new Error(`log verbosity level is not valid: ${sessionOptions.logVerbosityLevel}`);
    }

    if (sessionOptions.enableProfiling === undefined) {
      sessionOptions.enableProfiling = false;
    }

    sessionOptionsHandle = wasm._OrtCreateSessionOptions(
        graphOptimizationLevel, !!sessionOptions.enableCpuMemArena, !!sessionOptions.enableMemPattern, executionMode,
        !!sessionOptions.enableProfiling, 0, logIdDataOffset, sessionOptions.logSeverityLevel,
        sessionOptions.logVerbosityLevel);
    if (sessionOptionsHandle === 0) {
      throw new Error('Can\'t create session options');
    }

    if (sessionOptions.executionProviders) {
      setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs);
    }

    // Must read `sessionOptions.extra`, not `options.extra`: appendDefaultOptions() always populates the former,
    // and the default entries would otherwise be dropped when the function is called without options.
    if (sessionOptions.extra !== undefined) {
      iterateExtraOptions(sessionOptions.extra, '', new WeakSet<Record<string, unknown>>(), (key, value) => {
        const keyDataOffset = allocWasmString(key, allocs);
        const valueDataOffset = allocWasmString(value, allocs);

        if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
          throw new Error(`Can't set a session config entry: ${key} - ${value}`);
        }
      });
    }

    return [sessionOptionsHandle, allocs];
  } catch (e) {
    // Undo everything created so far so a failed call leaves nothing allocated on the native side.
    if (sessionOptionsHandle !== 0) {
      wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
    }
    // Wrap the call so `_free` receives only the pointer (forEach would also pass index and array).
    allocs.forEach(alloc => wasm._free(alloc));
    throw e;
  }
};