Files
core/src/container/Container.ts
M1000fr 536fcfc336 feat: enhance asynchronous provider support and fix concurrent resolution bug
Allows factory providers to be asynchronous and ensures concurrent requests for the same singleton await the same promise instead of triggering circular dependency errors.
2026-01-11 17:11:41 +01:00

435 lines
12 KiB
TypeScript

import "reflect-metadata";
import { CircularDependencyError, ProviderNotFoundError } from "../errors";
import { ContextualModuleRef, ModuleRef } from "../module/ModuleRef";
import type { Constructor } from "../types";
import {
type BaseProvider,
INJECT_METADATA_KEY,
INJECTABLE_METADATA_KEY,
MODULE_METADATA_KEY,
type ModuleOptions,
type Provider,
type ProviderToken,
type ResolvedProvider,
} from "./types";
/** One entry of the `design:paramtypes` metadata array; only treated as a token here. */
type ParamType = { name?: string };
/** Names of the optional lifecycle methods the container can invoke on instances. */
type LifecycleHook = "onModuleInit" | "onModuleDestroy";
/** Shape of a lifecycle method: no arguments, may complete synchronously or asynchronously. */
type LifecycleHookFn = () => void | Promise<void>;
/** Any object implementing zero or more of the lifecycle methods. */
type LifecycleAware = { [Hook in LifecycleHook]?: LifecycleHookFn };
/**
 * Dependency-injection container.
 *
 * Providers live in named scopes: one scope per registered module (keyed by the
 * module class name) plus a global scope. Providers whose `scope` is
 * `"singleton"` are instantiated once per scope and cached; any other scope is
 * instantiated on every resolution. Resolution is asynchronous so factory
 * providers may return promises.
 */
export class Container {
	/** Provider definitions, keyed by scope name and then by unwrapped token. */
	private readonly providers = new Map<
		string,
		Map<ProviderToken, ResolvedProvider>
	>();
	/** Constructed singleton instances, keyed by scope name and then token. */
	private readonly instances = new Map<string, Map<ProviderToken, unknown>>();
	/** Module metadata for every registered module, keyed by class name. */
	private readonly moduleOptions = new Map<string, ModuleOptions>();
	/** Stand-ins handed out to break `forwardRef` cycles, keyed by scope and token. */
	private readonly proxies = new Map<string, Map<ProviderToken, unknown>>();
	/**
	 * In-flight singleton resolutions. Concurrent requests for the same
	 * singleton await this shared promise instead of building a second instance.
	 */
	private readonly pendingResolutions = new Map<
		string,
		Map<ProviderToken, Promise<unknown>>
	>();
	private rootModuleName?: string;

	constructor(private readonly globalContext = "global") {
		this.ensureScope(this.globalContext);
	}

	/** Registers `moduleClass` as the application root and registers it as a module. */
	public registerRootModule(moduleClass: Constructor): void {
		this.rootModuleName = moduleClass.name;
		this.registerModule(moduleClass);
	}

	/** Name of the module passed to `registerRootModule`, if any. */
	public getRootModuleName(): string | undefined {
		return this.rootModuleName;
	}

	/**
	 * Registers a module, its imports (recursively) and its providers from the
	 * module metadata. Classes without module metadata, and modules already
	 * registered under the same class name, are ignored. The module class itself
	 * is also registered as a provider in its own scope.
	 */
	public registerModule(moduleClass: Constructor): void {
		const options: ModuleOptions | undefined = Reflect.getMetadata(
			MODULE_METADATA_KEY,
			moduleClass,
		);
		if (!options) return;
		const moduleName = moduleClass.name;
		if (this.moduleOptions.has(moduleName)) return;
		// Record the module before walking its imports so import cycles terminate.
		this.moduleOptions.set(moduleName, options);
		this.ensureScope(moduleName);
		for (const imported of options.imports ?? []) {
			this.registerModule(
				this.unwrapToken(imported as ProviderToken<Constructor>),
			);
		}
		const moduleProviders = this.providers.get(moduleName)!;
		for (const provider of options.providers ?? []) {
			const normalized = this.normalizeProvider(provider, moduleName);
			moduleProviders.set(normalized.token, normalized);
		}
		moduleProviders.set(
			moduleClass,
			this.normalizeProvider(moduleClass, moduleName),
		);
	}

	/** Registers a provider in the global scope. */
	public registerProvider(provider: Provider): void {
		const normalized = this.normalizeProvider(provider, this.globalContext);
		this.providers
			.get(this.globalContext)!
			.set(normalized.token, normalized);
	}

	/**
	 * Resolves `token` as seen from `moduleName` (global scope when omitted).
	 *
	 * @throws ProviderNotFoundError when no visible provider matches.
	 * @throws CircularDependencyError on a dependency cycle not broken by `forwardRef`.
	 */
	public async get<T>(
		token: ProviderToken<T>,
		moduleName?: string,
	): Promise<T> {
		return this.resolve(token, moduleName);
	}

	/** Eagerly resolves every registered provider in every scope, one after another. */
	public async resolveAll(): Promise<void> {
		for (const [scope, scopedProviders] of this.providers) {
			const moduleName = scope === this.globalContext ? undefined : scope;
			for (const token of scopedProviders.keys()) {
				await this.resolve(token, moduleName);
			}
		}
	}

	/**
	 * Invokes `hook` on every cached singleton instance that defines it, awaiting
	 * each call in turn. An object cached under several tokens (e.g. through
	 * `useExisting` aliases or a shared `useValue`) is only invoked once.
	 */
	public async callLifecycleHook(hook: LifecycleHook): Promise<void> {
		const visited = new Set<unknown>();
		for (const scopedInstances of this.instances.values()) {
			for (const instance of scopedInstances.values()) {
				if (!instance || visited.has(instance)) continue;
				visited.add(instance);
				const lifecycleInstance = instance as LifecycleAware;
				const hookFn = lifecycleInstance[hook];
				if (typeof hookFn === "function") {
					await hookFn.call(lifecycleInstance);
				}
			}
		}
	}

	/**
	 * Resolves `token` from `moduleName` (global scope when omitted).
	 *
	 * @param path Tokens currently being instantiated on this call chain
	 *   (ancestors of this request). Cycle detection is done against this
	 *   per-chain path rather than global state, so unrelated concurrent
	 *   resolutions of the same token never look like a cycle.
	 */
	private async resolve<T>(
		token: ProviderToken<T>,
		moduleName?: string,
		path: readonly ProviderToken[] = [],
	): Promise<T> {
		const unwrappedToken = this.unwrapToken(token) as ProviderToken<T>;
		const context = moduleName ?? this.globalContext;
		if (unwrappedToken === ModuleRef) {
			return new ContextualModuleRef(this, context) as T;
		}

		const provider =
			this.findProvider(unwrappedToken, context) ??
			(context === this.globalContext && this.rootModuleName
				? this.findProvider(unwrappedToken, this.rootModuleName)
				: undefined);
		if (!provider) {
			throw new ProviderNotFoundError(
				this.tokenName(unwrappedToken),
				context,
			);
		}
		const scopeName = provider.moduleName ?? context;

		// Checked before the singleton cache: an ancestor's pending promise would
		// never settle if awaited from inside its own dependency tree.
		if (path.includes(unwrappedToken)) {
			if (
				typeof token === "object" &&
				token !== null &&
				"forwardRef" in token
			) {
				this.ensureScope(scopeName);
				const scopedProxies = this.proxies.get(scopeName)!;
				if (!scopedProxies.has(unwrappedToken)) {
					scopedProxies.set(
						unwrappedToken,
						this.createProxy(unwrappedToken, scopeName),
					);
				}
				return scopedProxies.get(unwrappedToken) as T;
			}
			throw new CircularDependencyError(
				[...path, unwrappedToken].map((item) => this.tokenName(item)),
			);
		}

		const childPath: ProviderToken[] = [...path, unwrappedToken];
		if (provider.scope !== "singleton") {
			return this.instantiate<T>(
				provider,
				unwrappedToken,
				scopeName,
				childPath,
			);
		}

		this.ensureScope(scopeName);
		const scopedInstances = this.instances.get(scopeName)!;
		// `has` rather than `!== undefined`: a factory may legitimately yield undefined.
		if (scopedInstances.has(unwrappedToken)) {
			return scopedInstances.get(unwrappedToken) as T;
		}
		const scopedPending = this.pendingResolutions.get(scopeName)!;
		const pending = scopedPending.get(unwrappedToken);
		if (pending) {
			return pending as Promise<T>;
		}

		const promise: Promise<T> = this.instantiate<T>(
			provider,
			unwrappedToken,
			scopeName,
			childPath,
		)
			.then((instance) => {
				scopedInstances.set(unwrappedToken, instance);
				return instance;
			})
			.finally(() => {
				// Runs asynchronously, i.e. after the `set` below; only clear our own entry.
				if (scopedPending.get(unwrappedToken) === promise) {
					scopedPending.delete(unwrappedToken);
				}
			});
		scopedPending.set(unwrappedToken, promise);
		return promise;
	}

	/**
	 * Builds a new value for `provider` (value, factory, alias or class),
	 * resolving its dependencies from `scopeName`.
	 *
	 * @param path Call chain including `token`, forwarded to dependency resolution.
	 * @throws Error when the provider defines none of the supported strategies.
	 */
	private async instantiate<T>(
		provider: ResolvedProvider,
		token: ProviderToken,
		scopeName: string,
		path: readonly ProviderToken[],
	): Promise<T> {
		if (provider.useValue !== undefined) {
			return provider.useValue as T;
		}
		if (provider.useFactory) {
			const dependencies = await Promise.all(
				(provider.inject ?? []).map((dep) =>
					this.resolve(dep, scopeName, path),
				),
			);
			return (await provider.useFactory(...dependencies)) as T;
		}
		if (provider.useExisting) {
			return this.resolve<T>(
				provider.useExisting as ProviderToken<T>,
				scopeName,
				path,
			);
		}
		if (provider.useClass) {
			const targetClass = provider.useClass;
			const paramTypes: ParamType[] =
				Reflect.getMetadata("design:paramtypes", targetClass) || [];
			const injectTokens: ProviderToken[] =
				Reflect.getMetadata(INJECT_METADATA_KEY, targetClass) || [];
			// Explicit inject tokens override the emitted constructor parameter types.
			const dependencies = await Promise.all(
				paramTypes.map((paramType, index) =>
					this.resolve(
						(injectTokens[index] ||
							(paramType as ProviderToken)) as ProviderToken<unknown>,
						scopeName,
						path,
					),
				),
			);
			return new targetClass(...dependencies) as T;
		}
		throw new Error(
			`Invalid provider configuration for token ${this.tokenName(token)}`,
		);
	}

	/**
	 * Finds the provider for `token` visible from `context`, checking in order:
	 * the scope's own providers; each imported module's providers when exported
	 * (or the imported module class itself); the global scope; and, for the
	 * global context, the root module's own providers.
	 */
	private findProvider(
		token: ProviderToken,
		context: string,
	): ResolvedProvider | undefined {
		const unwrappedToken = this.unwrapToken(token) as ProviderToken;
		const moduleProviders = this.providers.get(context);
		if (moduleProviders?.has(unwrappedToken)) {
			return moduleProviders.get(unwrappedToken);
		}
		const options = this.moduleOptions.get(context);
		if (options?.imports) {
			for (const imported of options.imports) {
				const unwrappedImport = this.unwrapToken(
					imported as ProviderToken<Constructor>,
				);
				const importedProviders = this.providers.get(
					unwrappedImport.name,
				);
				const importedOptions = this.moduleOptions.get(
					unwrappedImport.name,
				);
				if (
					importedProviders?.has(unwrappedToken) &&
					importedOptions?.exports?.some(
						(exportToken) =>
							this.unwrapToken(exportToken) === unwrappedToken,
					)
				) {
					return importedProviders.get(unwrappedToken);
				}
				if (unwrappedToken === unwrappedImport) {
					return importedProviders?.get(unwrappedToken);
				}
			}
		}
		if (context !== this.globalContext) {
			const globalProviders = this.providers.get(this.globalContext);
			if (globalProviders?.has(unwrappedToken)) {
				return globalProviders.get(unwrappedToken);
			}
		}
		if (context === this.globalContext && this.rootModuleName) {
			const rootProviders = this.providers.get(this.rootModuleName);
			if (rootProviders?.has(unwrappedToken)) {
				return rootProviders.get(unwrappedToken);
			}
		}
		return undefined;
	}

	/**
	 * Converts any accepted provider form into a `ResolvedProvider`. A bare class
	 * becomes a `useClass` provider whose scope comes from its injectable
	 * metadata; every form defaults to `"singleton"` scope.
	 */
	private normalizeProvider(
		provider: Provider,
		moduleName?: string,
	): ResolvedProvider {
		if (typeof provider === "function") {
			const metadata = Reflect.getMetadata(
				INJECTABLE_METADATA_KEY,
				provider,
			);
			return {
				token: provider,
				scope: metadata?.scope ?? "singleton",
				useClass: provider,
				moduleName,
			};
		}
		const base = {
			token: this.unwrapToken(provider.provide) as ProviderToken,
			scope: provider.scope ?? "singleton",
			moduleName,
		};
		if ("useClass" in provider) {
			return { ...base, useClass: provider.useClass };
		}
		if ("useValue" in provider) {
			return { ...base, useValue: provider.useValue };
		}
		if ("useFactory" in provider) {
			return {
				...base,
				useFactory: provider.useFactory,
				inject: provider.inject,
			};
		}
		if ("useExisting" in provider) {
			return { ...base, useExisting: provider.useExisting };
		}
		const exhaustiveCheck: never = provider;
		throw new Error(
			`Invalid provider definition for token ${String((exhaustiveCheck as BaseProvider).provide)}`,
		);
	}

	/** Creates the per-scope maps for `name` if they do not exist yet. */
	private ensureScope(name: string): void {
		if (!this.providers.has(name)) {
			this.providers.set(name, new Map());
		}
		if (!this.instances.has(name)) {
			this.instances.set(name, new Map());
		}
		if (!this.proxies.has(name)) {
			this.proxies.set(name, new Map());
		}
		if (!this.pendingResolutions.has(name)) {
			this.pendingResolutions.set(name, new Map());
		}
	}

	/** Returns the target of a `forwardRef` wrapper, or the token unchanged. */
	private unwrapToken<T>(token: ProviderToken<T>): T {
		if (
			typeof token === "object" &&
			token !== null &&
			"forwardRef" in token
		) {
			return (token as { forwardRef: () => T }).forwardRef();
		}
		return token as T;
	}

	/** Human-readable token name for error messages (class name or `String()`). */
	private tokenName(token: ProviderToken): string {
		return typeof token === "function" ? token.name : String(token);
	}

	/**
	 * Creates a stand-in used to break a `forwardRef` cycle. Property reads and
	 * `in` checks are forwarded to the singleton cached for `token` in `context`
	 * once it exists; reading a property before then throws.
	 */
	private createProxy<T>(token: ProviderToken<T>, context: string): T {
		const lookup = (): unknown => this.instances.get(context)?.get(token);
		return new Proxy({} as object, {
			get: (_target, prop) => {
				// Report "not a thenable" so the proxy can be returned from async code.
				if (prop === "then") return undefined;
				const instance = lookup();
				if (!instance) {
					throw new Error(
						`Circular dependency instance for ${this.tokenName(
							token,
						)} is not yet available. It might be accessed too early (e.g. in a constructor or during static initialization).`,
					);
				}
				// Use the real instance as receiver so getters see their own `this`
				// (private `#fields` would throw with the proxy as receiver).
				const value = Reflect.get(instance as object, prop);
				return typeof value === "function"
					? value.bind(instance)
					: value;
			},
			has: (_target, prop) => {
				const instance = lookup();
				return instance ? Reflect.has(instance as object, prop) : false;
			},
		}) as T;
	}
}