import "reflect-metadata";
import { CircularDependencyError, ProviderNotFoundError } from "../errors";
import { ContextualModuleRef, ModuleRef } from "../module/ModuleRef";
import type { Constructor } from "../types";
import {
  type BaseProvider,
  INJECT_METADATA_KEY,
  INJECTABLE_METADATA_KEY,
  MODULE_METADATA_KEY,
  type ModuleOptions,
  type Provider,
  type ProviderToken,
  type ResolvedProvider,
} from "./types";

/** Minimal view of an entry in the `design:paramtypes` reflection metadata. */
interface ParamType {
  name?: string;
}

type LifecycleHook = "onModuleInit" | "onModuleDestroy";
type LifecycleHookFn = () => Promise<void> | void;
type LifecycleAware = Partial<Record<LifecycleHook, LifecycleHookFn>>;

/**
 * Dependency-injection container.
 *
 * Providers live in named scopes: one per registered module (keyed by the
 * module class name) plus a global scope. Providers whose `scope` is
 * `"singleton"` are constructed once per owning scope and cached; any other
 * scope yields a fresh instance on every resolution. Resolution is
 * asynchronous because factories may return promises.
 */
export class Container {
  /** Provider definitions, keyed by scope name and then by token. */
  private readonly providers = new Map<
    string,
    Map<ProviderToken, ResolvedProvider>
  >();
  /** Constructed singleton instances, keyed by scope name and then by token. */
  private readonly instances = new Map<string, Map<ProviderToken, unknown>>();
  /** Module metadata of every registered module, keyed by module class name. */
  private readonly moduleOptions = new Map<string, ModuleOptions>();
  /** Placeholder proxies handed out to break `forwardRef` cycles. */
  private readonly proxies = new Map<string, Map<ProviderToken, unknown>>();
  /** In-flight singleton constructions, shared by concurrent requests. */
  private readonly pendingResolutions = new Map<
    string,
    Map<ProviderToken, Promise<unknown>>
  >();
  private rootModuleName?: string;

  constructor(private readonly globalContext = "global") {
    this.ensureScope(this.globalContext);
  }

  /** Registers `moduleClass` (and its import graph) and remembers it as the root. */
  public registerRootModule(moduleClass: Constructor): void {
    this.rootModuleName = moduleClass.name;
    this.registerModule(moduleClass);
  }

  public getRootModuleName(): string | undefined {
    return this.rootModuleName;
  }

  /**
   * Registers a module class, its imports (recursively) and its providers.
   * Classes without module metadata are ignored, and a module is registered
   * at most once, which also stops recursion through mutual imports.
   */
  public registerModule(moduleClass: Constructor): void {
    const options: ModuleOptions | undefined = Reflect.getMetadata(
      MODULE_METADATA_KEY,
      moduleClass,
    );
    if (!options) return;

    const moduleName = moduleClass.name;
    if (this.moduleOptions.has(moduleName)) return;
    this.moduleOptions.set(moduleName, options);
    this.ensureScope(moduleName);

    for (const imported of options.imports ?? []) {
      this.registerModule(
        this.unwrapToken<Constructor>(imported as ProviderToken),
      );
    }

    const moduleProviders = this.scopeOf(this.providers, moduleName);
    for (const provider of options.providers ?? []) {
      const normalized = this.normalizeProvider(provider, moduleName);
      moduleProviders.set(normalized.token, normalized);
    }
    // The module class is itself injectable inside its own scope.
    moduleProviders.set(
      moduleClass,
      this.normalizeProvider(moduleClass, moduleName),
    );
  }

  /** Registers a provider in the global scope. */
  public registerProvider(provider: Provider): void {
    const normalized = this.normalizeProvider(provider, this.globalContext);
    this.scopeOf(this.providers, this.globalContext).set(
      normalized.token,
      normalized,
    );
  }

  /**
   * Resolves `token`, optionally from the point of view of `moduleName`.
   *
   * @throws ProviderNotFoundError when no visible provider matches.
   * @throws CircularDependencyError on a cycle not broken by `forwardRef`.
   */
  public async get<T>(
    token: ProviderToken<T>,
    moduleName?: string,
  ): Promise<T> {
    return this.resolve<T>(token, moduleName);
  }

  /** Eagerly resolves every registered provider in every scope, one at a time. */
  public async resolveAll(): Promise<void> {
    for (const [scope, scopeProviders] of this.providers) {
      for (const token of scopeProviders.keys()) {
        await this.resolve(
          token,
          scope === this.globalContext ? undefined : scope,
        );
      }
    }
  }

  /**
   * Invokes `hook` on every cached instance that defines it, awaiting each
   * call in turn. Falsy cached values are skipped.
   */
  public async callLifecycleHook(hook: LifecycleHook): Promise<void> {
    for (const scopeInstances of this.instances.values()) {
      for (const instance of scopeInstances.values()) {
        if (!instance) continue;
        const lifecycleInstance = instance as LifecycleAware;
        const hookFn = lifecycleInstance[hook];
        if (typeof hookFn === "function") {
          await hookFn.call(lifecycleInstance);
        }
      }
    }
  }

  /**
   * Core resolution routine.
   *
   * @param token The requested token, possibly wrapped in `forwardRef`.
   * @param moduleName Scope to resolve from; the global scope when omitted.
   * @param path Tokens currently being constructed on THIS request chain.
   *   Cycle detection consults only this chain, so independent requests for
   *   the same token that happen to run concurrently (e.g. two constructor
   *   parameters sharing a transient dependency) are not mistaken for a cycle.
   */
  private async resolve<T>(
    token: ProviderToken<T>,
    moduleName?: string,
    path: readonly ProviderToken[] = [],
  ): Promise<T> {
    const unwrappedToken = this.unwrapToken(token) as ProviderToken<T>;
    const context = moduleName ?? this.globalContext;

    if (unwrappedToken === ModuleRef) {
      return new ContextualModuleRef(this, context) as T;
    }

    let provider = this.findProvider(unwrappedToken, context);
    if (!provider && context === this.globalContext && this.rootModuleName) {
      // Also consider what the root module can see through its imports.
      provider = this.findProvider(unwrappedToken, this.rootModuleName);
    }
    if (!provider) {
      throw new ProviderNotFoundError(this.tokenName(unwrappedToken), context);
    }

    const ownerScope = provider.moduleName ?? context;
    const isSingleton = provider.scope === "singleton";

    if (isSingleton) {
      const proxy = this.proxies.get(ownerScope)?.get(unwrappedToken);
      if (proxy !== undefined) return proxy as T;
      const existing = this.instances.get(ownerScope)?.get(unwrappedToken);
      if (existing !== undefined) return existing as T;
    }

    if (path.includes(unwrappedToken)) {
      if (this.isForwardRef(token)) {
        // Hand out a placeholder that forwards to the real instance once the
        // in-progress construction has cached it.
        const proxy = this.createProxy<T>(unwrappedToken, ownerScope);
        this.scopeOf(this.proxies, ownerScope).set(unwrappedToken, proxy);
        return proxy;
      }
      throw new CircularDependencyError(
        [...path, unwrappedToken].map((item) => this.tokenName(item)),
      );
    }

    const childPath: ProviderToken[] = [...path, unwrappedToken];

    if (!isSingleton) {
      return this.instantiate<T>(provider, unwrappedToken, ownerScope, childPath);
    }

    const pending = this.pendingResolutions
      .get(ownerScope)
      ?.get(unwrappedToken);
    if (pending) return pending as Promise<T>;

    const pendingInScope = this.scopeOf(this.pendingResolutions, ownerScope);
    const creation = this.instantiate<T>(
      provider,
      unwrappedToken,
      ownerScope,
      childPath,
    ).then((instance) => {
      this.scopeOf(this.instances, ownerScope).set(unwrappedToken, instance);
      return instance;
    });
    pendingInScope.set(unwrappedToken, creation);
    try {
      return await creation;
    } finally {
      // Always cleared once settled -- even when construction finished without
      // awaiting -- so no stale entry survives and a failure can be retried.
      pendingInScope.delete(unwrappedToken);
    }
  }

  /**
   * Builds a new instance from `provider`'s recipe, resolving its
   * dependencies in `scope` with `path` as the current request chain.
   * Does not read or write any cache.
   */
  private async instantiate<T>(
    provider: ResolvedProvider,
    token: ProviderToken,
    scope: string,
    path: readonly ProviderToken[],
  ): Promise<T> {
    // Key presence, not value, marks a value provider: `useValue: undefined`
    // is a legitimate registration.
    if ("useValue" in provider) {
      return provider.useValue as T;
    }
    if (provider.useFactory) {
      const dependencies = await Promise.all(
        (provider.inject ?? []).map((dep) => this.resolve(dep, scope, path)),
      );
      return (await provider.useFactory(...dependencies)) as T;
    }
    if (provider.useExisting) {
      return this.resolve<T>(
        provider.useExisting as ProviderToken<T>,
        scope,
        path,
      );
    }
    if (provider.useClass) {
      const targetClass = provider.useClass;
      const paramTypes: ParamType[] =
        Reflect.getMetadata("design:paramtypes", targetClass) ?? [];
      // Explicit inject tokens override the reflected parameter types.
      const injectTokens: ProviderToken[] =
        Reflect.getMetadata(INJECT_METADATA_KEY, targetClass) ?? [];
      const dependencies = await Promise.all(
        paramTypes.map((paramType, index) =>
          this.resolve(
            injectTokens[index] || (paramType as ProviderToken),
            scope,
            path,
          ),
        ),
      );
      return new targetClass(...dependencies) as T;
    }
    throw new Error(
      `Invalid provider configuration for token ${this.tokenName(token)}`,
    );
  }

  /**
   * Finds the provider for `token` visible from `context`, checking in order:
   * the scope's own providers; each direct import that exports the token (or
   * the imported module class itself); the global scope (from a module scope);
   * and the root module's providers (from the global scope).
   */
  private findProvider(
    token: ProviderToken,
    context: string,
  ): ResolvedProvider | undefined {
    const unwrappedToken = this.unwrapToken(token) as ProviderToken;

    const ownProvider = this.providers.get(context)?.get(unwrappedToken);
    if (ownProvider) return ownProvider;

    for (const imported of this.moduleOptions.get(context)?.imports ?? []) {
      const importedModule = this.unwrapToken<Constructor>(
        imported as ProviderToken,
      );
      const importedProvider = this.providers
        .get(importedModule.name)
        ?.get(unwrappedToken);
      // Keep searching instead of giving up when this import lacks the token.
      if (!importedProvider) continue;

      const isExported = this.moduleOptions
        .get(importedModule.name)
        ?.exports?.some(
          (exportToken) => this.unwrapToken(exportToken) === unwrappedToken,
        );
      if (isExported || unwrappedToken === importedModule) {
        return importedProvider;
      }
    }

    if (context !== this.globalContext) {
      const globalProvider = this.providers
        .get(this.globalContext)
        ?.get(unwrappedToken);
      if (globalProvider) return globalProvider;
    } else if (this.rootModuleName) {
      const rootProvider = this.providers
        .get(this.rootModuleName)
        ?.get(unwrappedToken);
      if (rootProvider) return rootProvider;
    }
    return undefined;
  }

  /**
   * Converts a class or provider object into the internal provider record
   * owned by `moduleName`. The scope defaults to `"singleton"`.
   */
  private normalizeProvider(
    provider: Provider,
    moduleName?: string,
  ): ResolvedProvider {
    if (typeof provider === "function") {
      const metadata = Reflect.getMetadata(INJECTABLE_METADATA_KEY, provider);
      return {
        token: provider,
        scope: metadata?.scope ?? "singleton",
        useClass: provider,
        moduleName,
      };
    }

    const base = {
      token: this.unwrapToken(provider.provide) as ProviderToken,
      scope: provider.scope ?? "singleton",
      moduleName,
    };
    if ("useClass" in provider) {
      return { ...base, useClass: provider.useClass };
    }
    if ("useValue" in provider) {
      return { ...base, useValue: provider.useValue };
    }
    if ("useFactory" in provider) {
      return {
        ...base,
        useFactory: provider.useFactory,
        inject: provider.inject,
      };
    }
    if ("useExisting" in provider) {
      return { ...base, useExisting: provider.useExisting };
    }

    const exhaustiveCheck: never = provider;
    throw new Error(
      `Invalid provider definition for token ${String((exhaustiveCheck as BaseProvider).provide)}`,
    );
  }

  /** Creates every per-scope map for `name` that does not exist yet. */
  private ensureScope(name: string): void {
    this.scopeOf(this.providers, name);
    this.scopeOf(this.instances, name);
    this.scopeOf(this.proxies, name);
    this.scopeOf(this.pendingResolutions, name);
  }

  /** Returns the inner map of `maps` for scope `name`, creating it if missing. */
  private scopeOf<V>(
    maps: Map<string, Map<ProviderToken, V>>,
    name: string,
  ): Map<ProviderToken, V> {
    let scoped = maps.get(name);
    if (!scoped) {
      scoped = new Map<ProviderToken, V>();
      maps.set(name, scoped);
    }
    return scoped;
  }

  /** True when `token` is a `forwardRef` wrapper object. */
  private isForwardRef(
    token: unknown,
  ): token is { forwardRef: () => unknown } {
    return typeof token === "object" && token !== null && "forwardRef" in token;
  }

  /** Returns the token a `forwardRef` wrapper points to; other tokens pass through. */
  private unwrapToken<T = ProviderToken>(token: ProviderToken): T {
    return (this.isForwardRef(token) ? token.forwardRef() : token) as T;
  }

  /** Human-readable token name for error messages (class name or `String()`). */
  private tokenName(token: unknown): string {
    return typeof token === "function" ? token.name : String(token);
  }

  /**
   * Creates a placeholder for a singleton still under construction. Every
   * property access is forwarded to the instance cached for `token` in
   * `context`, and throws while that instance does not exist yet.
   */
  private createProxy<T>(token: ProviderToken, context: string): T {
    const lookup = (): unknown => this.instances.get(context)?.get(token);
    const displayName = this.tokenName(token);
    return new Proxy({} as object, {
      get(_target, prop, receiver) {
        // Resolving a promise with this proxy reads `then`; report none so it
        // is not treated as a thenable (and the trap does not throw early).
        if (prop === "then") return undefined;
        const instance = lookup();
        if (!instance) {
          throw new Error(
            `Circular dependency instance for ${displayName} is not yet available. It might be accessed too early (e.g. in a constructor or during static initialization).`,
          );
        }
        const value = Reflect.get(instance as object, prop, receiver);
        return typeof value === "function" ? value.bind(instance) : value;
      },
      has(_target, prop) {
        const instance = lookup();
        return instance ? Reflect.has(instance as object, prop) : false;
      },
    }) as T;
  }
}