import "reflect-metadata";
import { CircularDependencyError, ProviderNotFoundError } from "../errors";
import { ContextualModuleRef, ModuleRef } from "../module/ModuleRef";
import type { Constructor } from "../types";
import type { ForwardRefFn } from "./forwardRef";
import {
    type BaseProvider,
    type DynamicModule,
    INJECT_METADATA_KEY,
    INJECTABLE_METADATA_KEY,
    MODULE_METADATA_KEY,
    type ModuleOptions,
    type Provider,
    type ProviderToken,
    type ResolvedProvider,
} from "./types";

/** Entry of the `design:paramtypes` metadata array read for constructor injection. */
interface ParamType {
    name?: string;
}

type LifecycleHook = "onModuleInit" | "onModuleDestroy";
type LifecycleHookFn = () => Promise<void> | void;
type LifecycleAware = Partial<Record<LifecycleHook, LifecycleHookFn>>;

/** Structural shape this container treats as a `forwardRef(...)` wrapper. */
interface ForwardRefLike {
    forwardRef: () => unknown;
}

/**
 * Module-aware dependency-injection container.
 *
 * Providers are stored in named scopes: one per registered module (keyed by
 * the module class name) plus a global scope (`globalContext`). Singleton
 * providers are cached per scope; other scopes are rebuilt on every
 * resolution.
 */
export class Container {
    /** Normalized providers per scope, keyed by unwrapped token. */
    private readonly providers = new Map<
        string,
        Map<ProviderToken, ResolvedProvider>
    >();
    /** Constructed singleton instances per scope, keyed by unwrapped token. */
    private readonly instances = new Map<string, Map<ProviderToken, unknown>>();
    /** Merged decorator + dynamic options per module name. */
    private readonly moduleOptions = new Map<string, ModuleOptions>();
    /** Lazy stand-ins handed out to break `forwardRef` cycles. */
    private readonly proxies = new Map<string, Map<ProviderToken, unknown>>();
    /** In-flight singleton constructions, shared by concurrent requests. */
    private readonly pendingResolutions = new Map<
        string,
        Map<ProviderToken, Promise<unknown>>
    >();
    private rootModuleName?: string;

    constructor(private readonly globalContext = "global") {
        this.ensureScope(this.globalContext);
    }

    /**
     * Registers `module` and remembers its class name as the root module.
     * Global-scope lookups fall back to the root module's providers.
     */
    public async registerRootModule(
        module: Constructor | DynamicModule,
    ): Promise<void> {
        const moduleClass = (module as DynamicModule).module || module;
        this.rootModuleName = (moduleClass as Constructor).name;
        await this.registerModule(module);
    }

    public getRootModuleName(): string | undefined {
        return this.rootModuleName;
    }

    /**
     * Registers a module (class, dynamic module, promised dynamic module or
     * forwardRef wrapper), its imports (recursively) and its providers.
     *
     * Classes without module metadata that are not dynamic modules are
     * ignored. Modules are identified by class name; a second registration
     * under the same name is a no-op.
     */
    public async registerModule(
        module:
            | Constructor
            | DynamicModule
            | Promise<DynamicModule>
            | ForwardRefFn,
    ): Promise<void> {
        const unwrappedModule = await this.unwrapModule(module);
        const isDynamic =
            typeof unwrappedModule === "object" &&
            unwrappedModule !== null &&
            "module" in unwrappedModule;
        const moduleClass = isDynamic
            ? (unwrappedModule as DynamicModule).module
            : (unwrappedModule as Constructor);
        const dynamicOptions: ModuleOptions = isDynamic
            ? (unwrappedModule as DynamicModule)
            : {};
        const metadataOptions: ModuleOptions | undefined = Reflect.getMetadata(
            MODULE_METADATA_KEY,
            moduleClass,
        );

        if (!metadataOptions && !isDynamic) return;

        const moduleName = moduleClass.name;
        if (this.moduleOptions.has(moduleName)) return;

        const resolvedImports = await Promise.all(
            [
                ...(metadataOptions?.imports ?? []),
                ...(dynamicOptions.imports ?? []),
            ].map((imp) => this.unwrapModule(imp)),
        );

        const options: ModuleOptions = {
            ...metadataOptions,
            ...dynamicOptions,
            imports: resolvedImports,
            providers: [
                ...(metadataOptions?.providers ?? []),
                ...(dynamicOptions.providers ?? []),
            ],
            exports: [
                ...(metadataOptions?.exports ?? []),
                ...(dynamicOptions.exports ?? []),
            ],
        };
        // Record the module before recursing so that cyclic imports terminate.
        this.moduleOptions.set(moduleName, options);
        this.ensureScope(moduleName);

        for (const imported of resolvedImports) {
            await this.registerModule(imported);
        }

        const scopeProviders = this.providers.get(moduleName)!;
        for (const provider of options.providers ?? []) {
            const normalized = this.normalizeProvider(provider, moduleName);
            scopeProviders.set(normalized.token, normalized);
        }
        // The module class itself is resolvable within its own scope.
        scopeProviders.set(
            moduleClass,
            this.normalizeProvider(moduleClass, moduleName),
        );
    }

    /** Registers a provider directly in the global scope. */
    public registerProvider(provider: Provider): void {
        const normalized = this.normalizeProvider(provider, this.globalContext);
        this.providers
            .get(this.globalContext)!
            .set(normalized.token, normalized);
    }

    /**
     * Resolves `token` as seen from `moduleName` (or the global scope).
     * @throws ProviderNotFoundError when no visible provider matches.
     * @throws CircularDependencyError on a non-forwardRef dependency cycle.
     */
    public async get<T>(
        token: ProviderToken,
        moduleName?: string,
    ): Promise<T> {
        return this.resolve<T>(token, moduleName);
    }

    /** Eagerly resolves every registered provider, scope by scope, in order. */
    public async resolveAll(): Promise<void> {
        for (const [scope, scopeProviders] of this.providers) {
            const moduleName = scope === this.globalContext ? undefined : scope;
            for (const token of scopeProviders.keys()) {
                await this.resolve(token, moduleName);
            }
        }
    }

    /** Awaits `hook` on every cached (truthy) instance that defines it. */
    public async callLifecycleHook(hook: LifecycleHook): Promise<void> {
        for (const scopeInstances of this.instances.values()) {
            for (const instance of scopeInstances.values()) {
                if (!instance) continue;
                const lifecycleInstance = instance as LifecycleAware;
                const hookFn = lifecycleInstance[hook];
                if (typeof hookFn === "function") {
                    await hookFn.call(lifecycleInstance);
                }
            }
        }
    }

    /**
     * Core resolution routine.
     *
     * @param path Providers currently being built by *this* chain of
     *   resolutions. Cycle detection uses this per-chain path rather than
     *   container-wide state, so siblings resolved concurrently (e.g. two
     *   constructor parameters sharing a transient dependency) are not
     *   mistaken for a cycle.
     */
    private async resolve<T>(
        token: ProviderToken,
        moduleName?: string,
        path: readonly ResolvedProvider[] = [],
    ): Promise<T> {
        const unwrappedToken = this.unwrapToken<ProviderToken>(token);
        const context = moduleName ?? this.globalContext;

        // ModuleRef is synthesized for the requesting scope, never registered.
        if (unwrappedToken === ModuleRef) {
            return new ContextualModuleRef(this, context) as T;
        }

        const provider = this.lookupProvider(unwrappedToken, context);
        const scopeName = provider.moduleName ?? context;
        const isSingleton = provider.scope === "singleton";
        this.ensureScope(scopeName);

        if (isSingleton) {
            // Prefer the finished instance over a forwardRef proxy, and use
            // `has` so a singleton whose value is `undefined` is not rebuilt.
            const scopeInstances = this.instances.get(scopeName)!;
            if (scopeInstances.has(unwrappedToken)) {
                return scopeInstances.get(unwrappedToken) as T;
            }
            const proxy = this.proxies.get(scopeName)!.get(unwrappedToken);
            if (proxy !== undefined) {
                return proxy as T;
            }
            const pending = this.pendingResolutions
                .get(scopeName)!
                .get(unwrappedToken);
            if (pending) {
                return pending as Promise<T>;
            }
        }

        if (path.includes(provider)) {
            if (this.isForwardRef(token)) {
                const proxy = this.createProxy<T>(unwrappedToken, scopeName);
                this.proxies.get(scopeName)!.set(unwrappedToken, proxy);
                return proxy;
            }
            const chain = [
                ...path.map((entry) => entry.token),
                unwrappedToken,
            ];
            throw new CircularDependencyError(
                chain.map((entry) => this.describeToken(entry)),
            );
        }

        const nextPath = [...path, provider];
        if (!isSingleton) {
            return this.instantiate<T>(provider, context, nextPath);
        }

        // The synchronous part of `instantiate` (the whole dependency descent)
        // runs before the pending entry is registered, so a real cycle reaches
        // the `path` check above instead of awaiting its own pending promise.
        const scopePending = this.pendingResolutions.get(scopeName)!;
        const promise: Promise<T> = (async () => {
            try {
                const instance = await this.instantiate<T>(
                    provider,
                    context,
                    nextPath,
                );
                this.instances.get(scopeName)!.set(unwrappedToken, instance);
                return instance;
            } finally {
                // `await` above always suspends, so this runs only after
                // `promise` is registered below. Deleting just our own entry
                // ensures a failed construction is retried, not cached.
                if (scopePending.get(unwrappedToken) === promise) {
                    scopePending.delete(unwrappedToken);
                }
            }
        })();
        scopePending.set(unwrappedToken, promise);
        return promise;
    }

    /**
     * Finds the provider visible from `context`. For the global scope this
     * also retries the lookup from the root module's point of view.
     * @throws ProviderNotFoundError when nothing matches.
     */
    private lookupProvider(
        token: ProviderToken,
        context: string,
    ): ResolvedProvider {
        let provider = this.findProvider(token, context);
        if (
            !provider &&
            context === this.globalContext &&
            this.rootModuleName
        ) {
            provider = this.findProvider(token, this.rootModuleName);
        }
        if (!provider) {
            throw new ProviderNotFoundError(this.describeToken(token), context);
        }
        return provider;
    }

    /**
     * Produces a new value for `provider` (useValue / useFactory /
     * useExisting / useClass, in that order of precedence), resolving
     * dependencies in the provider's own module scope.
     */
    private async instantiate<T>(
        provider: ResolvedProvider,
        context: string,
        path: readonly ResolvedProvider[],
    ): Promise<T> {
        const scope = provider.moduleName ?? context;

        if (provider.useValue !== undefined) {
            return provider.useValue as T;
        }
        if (provider.useFactory) {
            const dependencies = await Promise.all(
                (provider.inject ?? []).map((dep) =>
                    this.resolve(dep, scope, path),
                ),
            );
            return (await provider.useFactory(...dependencies)) as T;
        }
        if (provider.useExisting) {
            return this.resolve<T>(
                provider.useExisting as ProviderToken,
                scope,
                path,
            );
        }
        if (provider.useClass) {
            const targetClass = provider.useClass;
            const paramTypes: ParamType[] =
                Reflect.getMetadata("design:paramtypes", targetClass) ?? [];
            // Explicit inject tokens override the reflected parameter type.
            const injectTokens: ProviderToken[] =
                Reflect.getMetadata(INJECT_METADATA_KEY, targetClass) ?? [];
            const dependencies = await Promise.all(
                paramTypes.map((paramType, index) =>
                    this.resolve(
                        injectTokens[index] ?? (paramType as ProviderToken),
                        scope,
                        path,
                    ),
                ),
            );
            return new targetClass(...dependencies) as T;
        }
        throw new Error(
            `Invalid provider configuration for token ${this.describeToken(provider.token)}`,
        );
    }

    /**
     * Looks up `token` in `context`: own providers first, then exported
     * providers (or the module class itself) of imported modules, then the
     * global scope, then, for the global context only, the root module.
     */
    private findProvider(
        token: ProviderToken,
        context: string,
    ): ResolvedProvider | undefined {
        const unwrappedToken = this.unwrapToken<ProviderToken>(token);

        const ownProviders = this.providers.get(context);
        if (ownProviders?.has(unwrappedToken)) {
            return ownProviders.get(unwrappedToken);
        }

        for (const imported of this.moduleOptions.get(context)?.imports ?? []) {
            const importedModule =
                (imported as DynamicModule).module || (imported as Constructor);
            if (typeof importedModule !== "function") continue;

            const importedName = importedModule.name;
            const importedProviders = this.providers.get(importedName);
            const importedOptions = this.moduleOptions.get(importedName);
            if (
                importedProviders?.has(unwrappedToken) &&
                importedOptions?.exports?.some(
                    (exportToken) =>
                        this.unwrapToken(exportToken) === unwrappedToken,
                )
            ) {
                return importedProviders.get(unwrappedToken);
            }
            if (unwrappedToken === importedModule) {
                return importedProviders?.get(unwrappedToken);
            }
        }

        if (context !== this.globalContext) {
            const globalProviders = this.providers.get(this.globalContext);
            if (globalProviders?.has(unwrappedToken)) {
                return globalProviders.get(unwrappedToken);
            }
        }

        if (context === this.globalContext && this.rootModuleName) {
            const rootProviders = this.providers.get(this.rootModuleName);
            if (rootProviders?.has(unwrappedToken)) {
                return rootProviders.get(unwrappedToken);
            }
        }

        return undefined;
    }

    /** Converts a class or provider object into a `ResolvedProvider` record. */
    private normalizeProvider(
        provider: Provider,
        moduleName?: string,
    ): ResolvedProvider {
        if (typeof provider === "function") {
            const metadata = Reflect.getMetadata(
                INJECTABLE_METADATA_KEY,
                provider,
            );
            return {
                token: provider,
                scope: metadata?.scope ?? "singleton",
                useClass: provider,
                moduleName,
            };
        }

        const base = {
            token: this.unwrapToken(provider.provide) as ProviderToken,
            scope: provider.scope ?? "singleton",
            moduleName,
        };

        if ("useClass" in provider) {
            return { ...base, useClass: provider.useClass };
        }
        if ("useValue" in provider) {
            return { ...base, useValue: provider.useValue };
        }
        if ("useFactory" in provider) {
            return {
                ...base,
                useFactory: provider.useFactory,
                inject: provider.inject,
            };
        }
        if ("useExisting" in provider) {
            return { ...base, useExisting: provider.useExisting };
        }

        const exhaustiveCheck: never = provider;
        throw new Error(
            `Invalid provider definition for token ${String((exhaustiveCheck as BaseProvider).provide)}`,
        );
    }

    /** Creates the per-scope maps for `name` if they do not exist yet. */
    private ensureScope(name: string): void {
        if (!this.providers.has(name)) {
            this.providers.set(name, new Map());
        }
        if (!this.instances.has(name)) {
            this.instances.set(name, new Map());
        }
        if (!this.proxies.has(name)) {
            this.proxies.set(name, new Map());
        }
        if (!this.pendingResolutions.has(name)) {
            this.pendingResolutions.set(name, new Map());
        }
    }

    /** Returns `forwardRef()`'s target for a forwardRef wrapper, else the token itself. */
    private unwrapToken<T>(token: ProviderToken): T {
        return (this.isForwardRef(token) ? token.forwardRef() : token) as T;
    }

    /** Unwraps a forwardRef wrapper and/or awaits a promised module. */
    private async unwrapModule(
        module: unknown,
    ): Promise<Constructor | DynamicModule> {
        const target = this.isForwardRef(module) ? module.forwardRef() : module;
        return (await target) as Constructor | DynamicModule;
    }

    /** True for non-null objects carrying a `forwardRef` property. */
    private isForwardRef(value: unknown): value is ForwardRefLike {
        return (
            typeof value === "object" && value !== null && "forwardRef" in value
        );
    }

    /** Human-readable token name for error messages (class name or `String()`). */
    private describeToken(token: unknown): string {
        return typeof token === "function" ? token.name : String(token);
    }

    /**
     * Returns a proxy that forwards property reads / `in` checks to the
     * singleton instance stored for `token` in `context` once it exists.
     * Reading a property before then throws; `then` reads as `undefined` so
     * the proxy is not treated as a thenable.
     */
    private createProxy<T>(token: ProviderToken, context: string): T {
        const currentInstance = (): unknown =>
            this.instances.get(context)?.get(token);

        return new Proxy({} as object, {
            get: (_target, prop, receiver) => {
                if (prop === "then") return undefined;
                const instance = currentInstance();
                if (!instance) {
                    throw new Error(
                        `Circular dependency instance for ${this.describeToken(token)} is not yet available. It might be accessed too early (e.g. in a constructor or during static initialization).`,
                    );
                }
                const value = Reflect.get(instance as object, prop, receiver);
                return typeof value === "function"
                    ? value.bind(instance)
                    : value;
            },
            has: (_target, prop) => {
                const instance = currentInstance();
                return instance ? Reflect.has(instance as object, prop) : false;
            },
        }) as T;
    }
}