diff --git a/src/container/Container.ts b/src/container/Container.ts index 9783a65..3e221df 100644 --- a/src/container/Container.ts +++ b/src/container/Container.ts @@ -29,6 +29,7 @@ export class Container { private readonly instances = new Map>(); private readonly moduleOptions = new Map(); private readonly resolving = new Set(); + private readonly proxies = new Map>(); private rootModuleName?: string; constructor(private readonly globalContext = "global") { @@ -59,7 +60,9 @@ export class Container { if (options.imports) { for (const imported of options.imports) { - this.registerModule(imported); + this.registerModule( + this.unwrapToken(imported as ProviderToken), + ); } } @@ -120,49 +123,69 @@ export class Container { token: ProviderToken, moduleName?: string, ): Promise { + const unwrappedToken = this.unwrapToken(token) as ProviderToken; const context = moduleName ?? this.globalContext; - if (token === ModuleRef) { + if (unwrappedToken === ModuleRef) { return new ContextualModuleRef(this, context) as T; } - let provider = this.findProvider(token, context); + let provider = this.findProvider(unwrappedToken, context); if ( !provider && context === this.globalContext && this.rootModuleName ) { - provider = this.findProvider(token, this.rootModuleName); + provider = this.findProvider(unwrappedToken, this.rootModuleName); } if (!provider) { throw new ProviderNotFoundError( - typeof token === "function" ? token.name : String(token), + typeof unwrappedToken === "function" + ? unwrappedToken.name + : String(unwrappedToken), context, ); } if (provider.scope === "singleton") { - const existing = this.instances - .get(provider.moduleName ?? context) - ?.get(provider.token); + const scopeName = provider.moduleName ?? context; + const proxy = this.proxies.get(scopeName)?.get(unwrappedToken); + if (proxy !== undefined) { + return proxy as T; + } + + const existing = this.instances.get(scopeName)?.get(unwrappedToken); if (existing !== undefined) { return existing as T; } } - if (this.resolving.has(token)) { + if (this.resolving.has(unwrappedToken)) { + if ( + typeof token === "object" && + token !== null && + "forwardRef" in token + ) { + const scopeName = provider.moduleName ?? context; + const proxy = this.createProxy(unwrappedToken, scopeName); + this.ensureScope(scopeName); + this.proxies.get(scopeName)!.set(unwrappedToken, proxy); + return proxy as T; + } const stack = [...this.resolving].map((item) => typeof item === "function" ? item.name : String(item), ); stack.push( - typeof token === "function" ? token.name : String(token), + typeof unwrappedToken === "function" + ? unwrappedToken.name + : String(unwrappedToken), ); throw new CircularDependencyError(stack); } - this.resolving.add(token); + this.resolving.add(unwrappedToken); try { let instance: T; @@ -188,25 +211,10 @@ export class Container { const injectTokens: ProviderToken[] = Reflect.getMetadata(INJECT_METADATA_KEY, targetClass) || []; - this.debug(`[Container] Instantiating ${targetClass.name}`); - this.debug( - "[Container] ParamTypes:", - paramTypes.map((p) => p?.name || String(p)), - ); - this.debug( - "[Container] InjectTokens:", - injectTokens.map((token) => - token ? String(token) : "undefined", - ), - ); - const dependencies = await Promise.all( paramTypes.map(async (paramType, index) => { const depToken = injectTokens[index] || (paramType as ProviderToken); - this.debug( - `[Container] Resolving dependency ${index}: ${String(depToken)}`, - ); return this.resolve( depToken as ProviderToken, provider.moduleName ?? context, @@ -229,7 +237,7 @@ export class Container { return instance; } finally { - this.resolving.delete(token); + this.resolving.delete(unwrappedToken); } } @@ -237,39 +245,50 @@ export class Container { token: ProviderToken, context: string, ): ResolvedProvider | undefined { + const unwrappedToken = this.unwrapToken(token) as ProviderToken; const moduleProviders = this.providers.get(context); - if (moduleProviders?.has(token)) { - return moduleProviders.get(token); + if (moduleProviders?.has(unwrappedToken)) { + return moduleProviders.get(unwrappedToken); } const options = this.moduleOptions.get(context); if (options?.imports) { for (const imported of options.imports) { - const importedProviders = this.providers.get(imported.name); - const importedOptions = this.moduleOptions.get(imported.name); + const unwrappedImport = this.unwrapToken( + imported as ProviderToken, + ); + const importedProviders = this.providers.get( + unwrappedImport.name, + ); + const importedOptions = this.moduleOptions.get( + unwrappedImport.name, + ); if ( - importedProviders?.has(token) && - importedOptions?.exports?.includes(token) + importedProviders?.has(unwrappedToken) && + importedOptions?.exports?.some( + (exportToken) => + this.unwrapToken(exportToken) === unwrappedToken, + ) ) { - return importedProviders.get(token); + return importedProviders.get(unwrappedToken); } - if (token === imported) { - return importedProviders?.get(token); + if (unwrappedToken === unwrappedImport) { + return importedProviders?.get(unwrappedToken); } } } if (context !== this.globalContext) { const globalProviders = this.providers.get(this.globalContext); - if (globalProviders?.has(token)) { - return globalProviders.get(token); + if (globalProviders?.has(unwrappedToken)) { + return globalProviders.get(unwrappedToken); } } if (context === this.globalContext && this.rootModuleName) { const rootProviders = this.providers.get(this.rootModuleName); - if (rootProviders?.has(token)) { - return rootProviders.get(token); + if (rootProviders?.has(unwrappedToken)) { + return rootProviders.get(unwrappedToken); } } @@ -294,7 +313,7 @@ export class Container { } const base = { - token: provider.provide, + token: this.unwrapToken(provider.provide) as ProviderToken, scope: provider.scope ?? "singleton", moduleName, }; @@ -329,11 +348,46 @@ export class Container { if (!this.instances.has(name)) { this.instances.set(name, new Map()); } - } - - private debug(...args: unknown[]): void { - if (process.env.ALVEO_CONTAINER_DEBUG === "true") { - console.debug(...args); + if (!this.proxies.has(name)) { + this.proxies.set(name, new Map()); } } + + private unwrapToken(token: ProviderToken): T { + if ( + typeof token === "object" && + token !== null && + "forwardRef" in token + ) { + return (token as { forwardRef: () => T }).forwardRef(); + } + return token as T; + } + + private createProxy(token: ProviderToken, context: string): T { + const self = this; + return new Proxy({} as object, { + get(_target, prop, receiver) { + if (prop === "then") return undefined; + const instance = self.instances.get(context)?.get(token); + if (!instance) { + throw new Error( + `Circular dependency instance for ${ + typeof token === "function" + ? token.name + : String(token) + } is not yet available. It might be accessed too early (e.g. in a constructor or during static initialization).`, + ); + } + const value = Reflect.get(instance as object, prop, receiver); + return typeof value === "function" + ? value.bind(instance) + : value; + }, + has(_target, prop) { + const instance = self.instances.get(context)?.get(token); + return instance ? Reflect.has(instance as object, prop) : false; + }, + }) as T; + } } diff --git a/src/container/forwardRef.ts b/src/container/forwardRef.ts new file mode 100644 index 0000000..039848a --- /dev/null +++ b/src/container/forwardRef.ts @@ -0,0 +1,21 @@ +/** + * Interface for a function that returns a type. + */ +export type ForwardReference = () => T; + +/** + * Interface for the object returned by forwardRef. + */ +export interface ForwardRefFn { + forwardRef: ForwardReference; +} + +/** + * Allows to refer to a reference which is not yet defined. + * Useful for circular dependencies between classes or modules. + * + * @param fn A function that returns the reference + */ +export function forwardRef(fn: ForwardReference): ForwardRefFn { + return { forwardRef: fn }; +} diff --git a/src/container/types.ts b/src/container/types.ts index 3e927f9..53213f6 100644 --- a/src/container/types.ts +++ b/src/container/types.ts @@ -1,8 +1,13 @@ import type { Constructor, Type } from "../types"; +import type { ForwardRefFn } from "./forwardRef"; export type ProviderScope = "singleton" | "transient"; -export type ProviderToken = Type | string | symbol; +export type ProviderToken = + | Type + | string + | symbol + | ForwardRefFn>; export interface BaseProvider { provide: ProviderToken; @@ -49,7 +54,7 @@ export interface InjectableOptions { } export interface ModuleOptions { - imports?: Constructor[]; + imports?: (Constructor | ForwardRefFn)[]; providers?: Provider[]; exports?: ProviderToken[]; } diff --git a/src/index.ts b/src/index.ts index c5f6cc5..cbdea2f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,4 +1,6 @@ export { Container } from "./container/Container"; +export type { ForwardReference, ForwardRefFn } from "./container/forwardRef"; +export { forwardRef } from "./container/forwardRef"; export type { ClassProvider, ExistingProvider,