Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/mikro-orm.providers.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { EntityManager, MetadataStorage, MikroORM, type AnyEntity, type ForkOptions } from '@mikro-orm/core';
import { EntityManager, EntitySchema, MetadataStorage, MikroORM, type AnyEntity, type ForkOptions } from '@mikro-orm/core';
import { Scope, type InjectionToken, type Provider, type Type } from '@nestjs/common';

import {
Expand Down Expand Up @@ -117,7 +117,7 @@ export function createMikroOrmRepositoryProviders(entities: EntityName<AnyEntity
const inject = contextName ? getEntityManagerToken(contextName) : EntityManager;

(entities || []).forEach(entity => {
const meta = metadata.find(meta => meta.class === entity);
const meta = entity instanceof EntitySchema ? entity.meta : metadata.find(meta => meta.class === entity);
const repository = meta?.repository as unknown as (() => InjectionToken) | undefined;

if (repository) {
Expand Down
47 changes: 47 additions & 0 deletions tests/entities/baz.entity.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import { EntityRepository, EntitySchema, defineEntity, p } from '@mikro-orm/core';

// EntitySchema test
export interface IBaz {
id: number;
name: string;
}

export class BazRepository extends EntityRepository<IBaz> {

customMethod(): string {
return 'custom';
}

}

export const Baz = new EntitySchema<IBaz>({
name: 'Baz',
repository: () => BazRepository,
properties: {
id: { type: 'number', primary: true },
name: { type: 'string' },
},
});

// defineEntity test
export interface IQux {
id: number;
title: string;
}

export class QuxRepository extends EntityRepository<IQux> {

anotherCustomMethod(): string {
return 'another-custom';
}

}

export const Qux = defineEntity({
name: 'Qux',
repository: () => QuxRepository,
properties: {
id: p.integer().primary(),
title: p.string(),
},
});
57 changes: 51 additions & 6 deletions tests/mikro-orm.module.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
} from '@mikro-orm/nestjs';
import { Bar } from './entities/bar.entity.js';
import { Foo } from './entities/foo.entity.js';
import { Baz, BazRepository, Qux, QuxRepository } from './entities/baz.entity.js';

const testOptions = defineConfig({
dbName: ':memory:',
Expand Down Expand Up @@ -91,7 +92,7 @@ describe('MikroORM Module', () => {
imports: [MikroOrmModule.forRoot(testOptions)],
}).compile();

const orm = module.get<MikroORM>(MikroORM);
const orm = module.get(MikroORM);
expect(orm).toBeDefined();
expect(orm.config.get('contextName')).toBe('default');
expect(module.get<EntityManager>(EntityManager)).toBeDefined();
Expand All @@ -109,7 +110,7 @@ describe('MikroORM Module', () => {
],
}).compile();

const orm = module.get<MikroORM>(MikroORM);
const orm = module.get(MikroORM);
expect(orm).toBeDefined();
expect(orm.config.get('contextName')).toBe('default');
expect(module.get<EntityManager>(EntityManager)).toBeDefined();
Expand All @@ -127,7 +128,7 @@ describe('MikroORM Module', () => {
],
}).compile();

const orm = module.get<MikroORM>(MikroORM);
const orm = module.get(MikroORM);
expect(orm).toBeDefined();
expect(orm.config.get('contextName')).toBe('default');
expect(module.get<EntityManager>(EntityManager)).toBeDefined();
Expand All @@ -149,7 +150,7 @@ describe('MikroORM Module', () => {
],
}).compile();

const orm = module.get<MikroORM>(MikroORM);
const orm = module.get(MikroORM);
expect(orm).toBeDefined();
expect(orm.config.get('contextName')).toBe('default');
expect(module.get<EntityManager>(EntityManager)).toBeDefined();
Expand Down Expand Up @@ -261,7 +262,7 @@ describe('MikroORM Module', () => {
imports: [...MikroOrmModule.forRoot([testOptions]), MikroOrmModule.forFeature([Foo])],
}).compile();

const orm = module.get<MikroORM>(MikroORM);
const orm = module.get(MikroORM);
const entityManager = module.get<EntityManager>(EntityManager);
const repository = module.get<EntityRepository<Foo>>(getRepositoryToken(Foo));

Expand All @@ -285,7 +286,7 @@ describe('MikroORM Module', () => {
],
}).compile();

const orm = module.get<MikroORM>(MikroORM);
const orm = module.get(MikroORM);
const entityManager = module.get<EntityManager>(EntityManager);
const repository = module.get<EntityRepository<Foo>>(getRepositoryToken(Foo));

Expand All @@ -295,6 +296,50 @@ describe('MikroORM Module', () => {

await orm.close();
});

it('forFeature should return custom repository for EntitySchema (GH6701)', async () => {
const module = await Test.createTestingModule({
imports: [
MikroOrmModule.forRoot({
...testOptions,
entities: [Baz],
}),
MikroOrmModule.forFeature([Baz]),
],
}).compile();

const orm = module.get(MikroORM);
const repository = module.get(BazRepository);

expect(orm).toBeDefined();
expect(repository).toBeDefined();
expect(repository).toBeInstanceOf(BazRepository);
expect(repository.customMethod()).toBe('custom');

await orm.close();
});

it('forFeature should return custom repository for defineEntity (GH6701)', async () => {
const module = await Test.createTestingModule({
imports: [
MikroOrmModule.forRoot({
...testOptions,
entities: [Qux],
}),
MikroOrmModule.forFeature([Qux]),
],
}).compile();

const orm = module.get(MikroORM);
const repository = module.get(QuxRepository);

expect(orm).toBeDefined();
expect(repository).toBeDefined();
expect(repository).toBeInstanceOf(QuxRepository);
expect(repository.anotherCustomMethod()).toBe('another-custom');

await orm.close();
});
});

describe('Multiple Databases', () => {
Expand Down