diff --git a/package.json b/package.json index 715dfb26e..e85f7cdcb 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-monorepo", - "version": "2.10.2", + "version": "2.11.0", "description": "", "scripts": { "build": "pnpm -r --filter=\"!./packages/ide/*\" build", diff --git a/packages/ide/jetbrains/build.gradle.kts b/packages/ide/jetbrains/build.gradle.kts index a2fc573a5..8412fc899 100644 --- a/packages/ide/jetbrains/build.gradle.kts +++ b/packages/ide/jetbrains/build.gradle.kts @@ -9,7 +9,7 @@ plugins { } group = "dev.zenstack" -version = "2.10.2" +version = "2.11.0" repositories { mavenCentral() diff --git a/packages/ide/jetbrains/package.json b/packages/ide/jetbrains/package.json index 331700e89..694b784e9 100644 --- a/packages/ide/jetbrains/package.json +++ b/packages/ide/jetbrains/package.json @@ -1,6 +1,6 @@ { "name": "jetbrains", - "version": "2.10.2", + "version": "2.11.0", "displayName": "ZenStack JetBrains IDE Plugin", "description": "ZenStack JetBrains IDE plugin", "homepage": "https://zenstack.dev", diff --git a/packages/language/package.json b/packages/language/package.json index aa9fdb382..1f4d72647 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/language", - "version": "2.10.2", + "version": "2.11.0", "displayName": "ZenStack modeling language compiler", "description": "ZenStack modeling language compiler", "homepage": "https://zenstack.dev", diff --git a/packages/misc/redwood/package.json b/packages/misc/redwood/package.json index f49e46f0b..159ab6d29 100644 --- a/packages/misc/redwood/package.json +++ b/packages/misc/redwood/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/redwood", "displayName": "ZenStack RedwoodJS Integration", - "version": "2.10.2", + "version": "2.11.0", "description": "CLI and runtime for integrating ZenStack with RedwoodJS projects.", "repository": { "type": "git", diff --git a/packages/plugins/openapi/package.json b/packages/plugins/openapi/package.json index 508dce2b5..e29c8f0e7 100644 --- a/packages/plugins/openapi/package.json +++ b/packages/plugins/openapi/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/openapi", "displayName": "ZenStack Plugin and Runtime for OpenAPI", - "version": "2.10.2", + "version": "2.11.0", "description": "ZenStack plugin and runtime supporting OpenAPI", "main": "index.js", "repository": { diff --git a/packages/plugins/openapi/src/rest-generator.ts b/packages/plugins/openapi/src/rest-generator.ts index e6da0268b..98c6abcb6 100644 --- a/packages/plugins/openapi/src/rest-generator.ts +++ b/packages/plugins/openapi/src/rest-generator.ts @@ -906,16 +906,24 @@ export class RESTfulOpenAPIGenerator extends OpenAPIGeneratorBase { }, }; + let idFieldSchema: OAPI.SchemaObject = { type: 'string' }; + if (idFields.length === 1) { + // FIXME: JSON:API actually requires id field to be a string, + // but currently the RESTAPIHandler returns the original data + // type as declared in the ZModel schema. + idFieldSchema = this.fieldTypeToOpenAPISchema(idFields[0].type); + } + if (mode === 'create') { // 'id' is required if there's no default value const idFields = model.fields.filter((f) => isIdField(f)); if (idFields.length === 1 && !hasAttribute(idFields[0], '@default')) { - properties = { id: { type: 'string' }, ...properties }; + properties = { id: idFieldSchema, ...properties }; toplevelRequired.unshift('id'); } } else { // 'id' always required for read and update - properties = { id: { type: 'string' }, ...properties }; + properties = { id: idFieldSchema, ...properties }; toplevelRequired.unshift('id'); } diff --git a/packages/plugins/openapi/tests/openapi-restful.test.ts b/packages/plugins/openapi/tests/openapi-restful.test.ts index 8fd0880ff..51d16e888 100644 --- a/packages/plugins/openapi/tests/openapi-restful.test.ts +++ b/packages/plugins/openapi/tests/openapi-restful.test.ts @@ -84,7 +84,7 @@ model Bar { const { name: output } = tmp.fileSync({ postfix: '.yaml' }); - const options = buildOptions(model, modelFile, output, '3.1.0'); + const options = buildOptions(model, modelFile, output, specVersion); await generate(model, options, dmmf); console.log(`OpenAPI specification generated for ${specVersion}: ${output}`); @@ -324,7 +324,7 @@ model Foo { const { name: output } = tmp.fileSync({ postfix: '.yaml' }); - const options = buildOptions(model, modelFile, output, '3.1.0'); + const options = buildOptions(model, modelFile, output, specVersion); await generate(model, options, dmmf); console.log(`OpenAPI specification generated for ${specVersion}: ${output}`); @@ -340,6 +340,28 @@ model Foo { } }); + it('int field as id', async () => { + const { model, dmmf, modelFile } = await loadZModelAndDmmf(` +plugin openapi { + provider = '${normalizePath(path.resolve(__dirname, '../dist'))}' +} + +model Foo { + id Int @id @default(autoincrement()) +} + `); + + const { name: output } = tmp.fileSync({ postfix: '.yaml' }); + + const options = buildOptions(model, modelFile, output, '3.0.0'); + await generate(model, options, dmmf); + console.log(`OpenAPI specification generated: ${output}`); + await OpenAPIParser.validate(output); + + const parsed = YAML.parse(fs.readFileSync(output, 'utf-8')); + expect(parsed.components.schemas.Foo.properties.id.type).toBe('integer'); + }); + it('exposes individual fields from a compound id as attributes', async () => { const { model, dmmf, modelFile } = await loadZModelAndDmmf(` plugin openapi { diff --git a/packages/plugins/swr/package.json b/packages/plugins/swr/package.json index a3b0c1018..ba61604c7 100644 --- a/packages/plugins/swr/package.json +++ b/packages/plugins/swr/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/swr", "displayName": "ZenStack plugin for generating SWR hooks", - "version": "2.10.2", + "version": "2.11.0", "description": "ZenStack plugin for generating SWR hooks", "main": "index.js", "repository": { diff --git a/packages/plugins/tanstack-query/package.json b/packages/plugins/tanstack-query/package.json index 713457bf6..70d538457 100644 --- a/packages/plugins/tanstack-query/package.json +++ b/packages/plugins/tanstack-query/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/tanstack-query", "displayName": "ZenStack plugin for generating tanstack-query hooks", - "version": "2.10.2", + "version": "2.11.0", "description": "ZenStack plugin for generating tanstack-query hooks", "main": "index.js", "exports": { diff --git a/packages/plugins/trpc/package.json b/packages/plugins/trpc/package.json index 2aa5ad5d5..c900331e5 100644 --- a/packages/plugins/trpc/package.json +++ b/packages/plugins/trpc/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/trpc", "displayName": "ZenStack plugin for tRPC", - "version": "2.10.2", + "version": "2.11.0", "description": "ZenStack plugin for tRPC", "main": "index.js", "repository": { diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 51d026db7..1f6f106aa 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/runtime", "displayName": "ZenStack Runtime Library", - "version": "2.10.2", + "version": "2.11.0", "description": "Runtime of ZenStack for both client-side and server-side environments.", "repository": { "type": "git", @@ -80,6 +80,10 @@ "types": "./zod-utils.d.ts", "default": "./zod-utils.js" }, + "./encryption": { + "types": "./encryption/index.d.ts", + "default": "./encryption/index.js" + }, "./package.json": { "default": "./package.json" } diff --git a/packages/runtime/src/constants.ts b/packages/runtime/src/constants.ts index 36acf8c83..495e1853d 100644 --- a/packages/runtime/src/constants.ts +++ b/packages/runtime/src/constants.ts @@ -67,3 +67,15 @@ export const PRISMA_MINIMUM_VERSION = '5.0.0'; * Prefix for auxiliary relation field generated for delegated models */ export const DELEGATE_AUX_RELATION_PREFIX = 'delegate_aux'; + +/** + * Prisma actions that can have a write payload + */ +export const ACTIONS_WITH_WRITE_PAYLOAD = [ + 'create', + 'createMany', + 'createManyAndReturn', + 'update', + 'updateMany', + 'upsert', +]; diff --git a/packages/runtime/src/cross/nested-write-visitor.ts b/packages/runtime/src/cross/nested-write-visitor.ts index c69f9d203..ba4b232a6 100644 --- a/packages/runtime/src/cross/nested-write-visitor.ts +++ b/packages/runtime/src/cross/nested-write-visitor.ts @@ -4,7 +4,7 @@ import type { FieldInfo, ModelMeta } from './model-meta'; import { resolveField } from './model-meta'; import { MaybePromise, PrismaWriteActionType, PrismaWriteActions } from './types'; -import { getModelFields } from './utils'; +import { enumerate, getModelFields } from './utils'; type NestingPathItem = { field?: FieldInfo; model: string; where: any; unique: boolean }; @@ -310,31 +310,33 @@ export class NestedWriteVisitor { payload: any, nestingPath: NestingPathItem[] ) { - for (const field of getModelFields(payload)) { - const fieldInfo = resolveField(this.modelMeta, model, field); - if (!fieldInfo) { - continue; - } + for (const item of enumerate(payload)) { + for (const field of getModelFields(item)) { + const fieldInfo = resolveField(this.modelMeta, model, field); + if (!fieldInfo) { + continue; + } - if (fieldInfo.isDataModel) { - if (payload[field]) { - // recurse into nested payloads - for (const [subAction, subData] of Object.entries(payload[field])) { - if (this.isPrismaWriteAction(subAction) && subData) { - await this.doVisit(fieldInfo.type, subAction, subData, payload[field], fieldInfo, [ - ...nestingPath, - ]); + if (fieldInfo.isDataModel) { + if (item[field]) { + // recurse into nested payloads + for (const [subAction, subData] of Object.entries(item[field])) { + if (this.isPrismaWriteAction(subAction) && subData) { + await this.doVisit(fieldInfo.type, subAction, subData, item[field], fieldInfo, [ + ...nestingPath, + ]); + } } } - } - } else { - // visit plain field - if (this.callback.field) { - await this.callback.field(fieldInfo, action, payload[field], { - parent: payload, - nestingPath, - field: fieldInfo, - }); + } else { + // visit plain field + if (this.callback.field) { + await this.callback.field(fieldInfo, action, item[field], { + parent: item, + nestingPath, + field: fieldInfo, + }); + } } } } diff --git a/packages/runtime/src/encryption/index.ts b/packages/runtime/src/encryption/index.ts new file mode 100644 index 000000000..d4cb31db6 --- /dev/null +++ b/packages/runtime/src/encryption/index.ts @@ -0,0 +1,67 @@ +import { _decrypt, _encrypt, ENCRYPTION_KEY_BYTES, getKeyDigest, loadKey } from './utils'; + +/** + * Default encrypter + */ +export class Encrypter { + private key: CryptoKey | undefined; + private keyDigest: string | undefined; + + constructor(private readonly encryptionKey: Uint8Array) { + if (encryptionKey.length !== ENCRYPTION_KEY_BYTES) { + throw new Error(`Encryption key must be ${ENCRYPTION_KEY_BYTES} bytes`); + } + } + + /** + * Encrypts the given data + */ + async encrypt(data: string): Promise { + if (!this.key) { + this.key = await loadKey(this.encryptionKey, ['encrypt']); + } + + if (!this.keyDigest) { + this.keyDigest = await getKeyDigest(this.encryptionKey); + } + + return _encrypt(data, this.key, this.keyDigest); + } +} + +/** + * Default decrypter + */ +export class Decrypter { + private keys: Array<{ key: CryptoKey; digest: string }> = []; + + constructor(private readonly decryptionKeys: Uint8Array[]) { + if (decryptionKeys.length === 0) { + throw new Error('At least one decryption key must be provided'); + } + + for (const key of decryptionKeys) { + if (key.length !== ENCRYPTION_KEY_BYTES) { + throw new Error(`Decryption key must be ${ENCRYPTION_KEY_BYTES} bytes`); + } + } + } + + /** + * Decrypts the given data + */ + async decrypt(data: string): Promise { + if (this.keys.length === 0) { + this.keys = await Promise.all( + this.decryptionKeys.map(async (key) => ({ + key: await loadKey(key, ['decrypt']), + digest: await getKeyDigest(key), + })) + ); + } + + return _decrypt(data, async (digest) => + this.keys.filter((entry) => entry.digest === digest).map((entry) => entry.key) + ); + } +} diff --git a/packages/runtime/src/encryption/utils.ts b/packages/runtime/src/encryption/utils.ts new file mode 100644 index 000000000..51ab41dc7 --- /dev/null +++ b/packages/runtime/src/encryption/utils.ts @@ -0,0 +1,96 @@ +import { z } from 'zod'; + +export const ENCRYPTER_VERSION = 1; +export const ENCRYPTION_KEY_BYTES = 32; +export const IV_BYTES = 12; +export const ALGORITHM = 'AES-GCM'; +export const KEY_DIGEST_BYTES = 8; + +const encoder = new TextEncoder(); +const decoder = new TextDecoder(); + +const encryptionMetaSchema = z.object({ + // version + v: z.number(), + // algorithm + a: z.string(), + // key digest + k: z.string(), +}); + +export async function loadKey(key: Uint8Array, keyUsages: KeyUsage[]): Promise { + return crypto.subtle.importKey('raw', key, ALGORITHM, false, keyUsages); +} + +export async function getKeyDigest(key: Uint8Array) { + const rawDigest = await crypto.subtle.digest('SHA-256', key); + return new Uint8Array(rawDigest.slice(0, KEY_DIGEST_BYTES)).reduce( + (acc, byte) => acc + byte.toString(16).padStart(2, '0'), + '' + ); +} + +export async function _encrypt(data: string, key: CryptoKey, keyDigest: string): Promise { + const iv = crypto.getRandomValues(new Uint8Array(IV_BYTES)); + const encrypted = await crypto.subtle.encrypt( + { + name: ALGORITHM, + iv, + }, + key, + encoder.encode(data) + ); + + // combine IV and encrypted data into a single array of bytes + const cipherBytes = [...iv, ...new Uint8Array(encrypted)]; + + // encryption metadata + const meta = { v: ENCRYPTER_VERSION, a: ALGORITHM, k: keyDigest }; + + // convert concatenated result to base64 string + return `${btoa(JSON.stringify(meta))}.${btoa(String.fromCharCode(...cipherBytes))}`; +} + +export async function _decrypt(data: string, findKey: (digest: string) => Promise): Promise { + const [metaText, cipherText] = data.split('.'); + if (!metaText || !cipherText) { + throw new Error('Malformed encrypted data'); + } + + let metaObj: unknown; + try { + metaObj = JSON.parse(atob(metaText)); + } catch (error) { + throw new Error('Malformed metadata'); + } + + // parse meta + const { a: algorithm, k: keyDigest } = encryptionMetaSchema.parse(metaObj); + + // find a matching decryption key + const keys = await findKey(keyDigest); + if (keys.length === 0) { + throw new Error('No matching decryption key found'); + } + + // convert base64 back to bytes + const bytes = Uint8Array.from(atob(cipherText), (c) => c.charCodeAt(0)); + + // extract IV from the head + const iv = bytes.slice(0, IV_BYTES); + const cipher = bytes.slice(IV_BYTES); + let lastError: unknown; + + for (const key of keys) { + let decrypted: ArrayBuffer; + try { + decrypted = await crypto.subtle.decrypt({ name: algorithm, iv }, key, cipher); + } catch (err) { + lastError = err; + continue; + } + return decoder.decode(decrypted); + } + + throw lastError; +} diff --git a/packages/runtime/src/enhancements/edge/encryption.ts b/packages/runtime/src/enhancements/edge/encryption.ts new file mode 120000 index 000000000..9931fc8ea --- /dev/null +++ b/packages/runtime/src/enhancements/edge/encryption.ts @@ -0,0 +1 @@ +../node/encryption.ts \ No newline at end of file diff --git a/packages/runtime/src/enhancements/node/create-enhancement.ts b/packages/runtime/src/enhancements/node/create-enhancement.ts index adec1fdf2..07f905182 100644 --- a/packages/runtime/src/enhancements/node/create-enhancement.ts +++ b/packages/runtime/src/enhancements/node/create-enhancement.ts @@ -10,6 +10,7 @@ import type { } from '../../types'; import { withDefaultAuth } from './default-auth'; import { withDelegate } from './delegate'; +import { withEncrypted } from './encryption'; import { withJsonProcessor } from './json-processor'; import { Logger } from './logger'; import { withOmit } from './omit'; @@ -20,7 +21,7 @@ import type { PolicyDef } from './types'; /** * All enhancement kinds */ -const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate']; +const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate', 'encryption']; /** * Options for {@link createEnhancement} @@ -100,6 +101,7 @@ export function createEnhancement( } const hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password')); + const hasEncrypted = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@encrypted')); const hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit')); const hasDefaultAuth = allFields.some((field) => field.defaultValueProvider); const hasTypeDefField = allFields.some((field) => field.isTypeDef); @@ -120,13 +122,22 @@ export function createEnhancement( } } - // password enhancement must be applied prior to policy because it changes then length of the field + // password and encrypted enhancement must be applied prior to policy because it changes then length of the field // and can break validation rules like `@length` if (hasPassword && kinds.includes('password')) { // @password proxy result = withPassword(result, options); } + if (hasEncrypted && kinds.includes('encryption')) { + if (!options.encryption) { + throw new Error('Encryption options are required for @encrypted enhancement'); + } + + // @encrypted proxy + result = withEncrypted(result, options); + } + // 'policy' and 'validation' enhancements are both enabled by `withPolicy` if (kinds.includes('policy') || kinds.includes('validation')) { result = withPolicy(result, options, context); diff --git a/packages/runtime/src/enhancements/node/default-auth.ts b/packages/runtime/src/enhancements/node/default-auth.ts index 03ce3750c..e6162a2d2 100644 --- a/packages/runtime/src/enhancements/node/default-auth.ts +++ b/packages/runtime/src/enhancements/node/default-auth.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable @typescript-eslint/no-explicit-any */ +import { ACTIONS_WITH_WRITE_PAYLOAD } from '../../constants'; import { FieldInfo, NestedWriteVisitor, @@ -50,15 +51,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { // base override protected async preprocessArgs(action: PrismaProxyActions, args: any) { - const actionsOfInterest: PrismaProxyActions[] = [ - 'create', - 'createMany', - 'createManyAndReturn', - 'update', - 'updateMany', - 'upsert', - ]; - if (actionsOfInterest.includes(action)) { + if (args && ACTIONS_WITH_WRITE_PAYLOAD.includes(action)) { const newArgs = await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); return newArgs; } diff --git a/packages/runtime/src/enhancements/node/delegate.ts b/packages/runtime/src/enhancements/node/delegate.ts index 80fd09f17..06c1526e5 100644 --- a/packages/runtime/src/enhancements/node/delegate.ts +++ b/packages/runtime/src/enhancements/node/delegate.ts @@ -180,47 +180,102 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { return; } - for (const kind of ['select', 'include'] as const) { - if (args[kind] && typeof args[kind] === 'object') { - for (const [field, value] of Object.entries(args[kind])) { - const fieldInfo = resolveField(this.options.modelMeta, model, field); - if (!fieldInfo) { - continue; - } + // there're two cases where we need to inject polymorphic base hierarchy for fields + // defined in base models + // 1. base fields mentioned in select/include clause + // { select: { fieldFromBase: true } } => { select: { delegate_aux_[Base]: { fieldFromBase: true } } } + // 2. base fields mentioned in _count select/include clause + // { select: { _count: { select: { fieldFromBase: true } } } } => { select: { delegate_aux_[Base]: { select: { _count: { select: { fieldFromBase: true } } } } } } + // + // Note that although structurally similar, we need to correctly deal with different injection location of the "delegate_aux" hierarchy + + // selectors for the above two cases + const selectors = [ + // regular select: { select: { field: true } } + (payload: any) => ({ data: payload.select, kind: 'select' as const, isCount: false }), + // regular include: { include: { field: true } } + (payload: any) => ({ data: payload.include, kind: 'include' as const, isCount: false }), + // select _count: { select: { _count: { select: { field: true } } } } + (payload: any) => ({ + data: payload.select?._count?.select, + kind: 'select' as const, + isCount: true, + }), + // include _count: { include: { _count: { select: { field: true } } } } + (payload: any) => ({ + data: payload.include?._count?.select, + kind: 'include' as const, + isCount: true, + }), + ]; + + for (const selector of selectors) { + const { data, kind, isCount } = selector(args); + if (!data || typeof data !== 'object') { + continue; + } - if (this.isDelegateOrDescendantOfDelegate(fieldInfo?.type) && value) { - // delegate model, recursively inject hierarchy - if (args[kind][field]) { - if (args[kind][field] === true) { - // make sure the payload is an object - args[kind][field] = {}; - } - await this.injectSelectIncludeHierarchy(fieldInfo.type, args[kind][field]); + for (const [field, value] of Object.entries(data)) { + const fieldInfo = resolveField(this.options.modelMeta, model, field); + if (!fieldInfo) { + continue; + } + + if (this.isDelegateOrDescendantOfDelegate(fieldInfo?.type) && value) { + // delegate model, recursively inject hierarchy + if (data[field]) { + if (data[field] === true) { + // make sure the payload is an object + data[field] = {}; } + await this.injectSelectIncludeHierarchy(fieldInfo.type, data[field]); } + } - // refetch the field select/include value because it may have been - // updated during injection - const fieldValue = args[kind][field]; + // refetch the field select/include value because it may have been + // updated during injection + const fieldValue = data[field]; - if (fieldValue !== undefined) { - if (fieldValue.orderBy) { - // `orderBy` may contain fields from base types - enumerate(fieldValue.orderBy).forEach((item) => - this.injectWhereHierarchy(fieldInfo.type, item) - ); - } + if (fieldValue !== undefined) { + if (fieldValue.orderBy) { + // `orderBy` may contain fields from base types + enumerate(fieldValue.orderBy).forEach((item) => + this.injectWhereHierarchy(fieldInfo.type, item) + ); + } - if (this.injectBaseFieldSelect(model, field, fieldValue, args, kind)) { - delete args[kind][field]; - } else if (fieldInfo.isDataModel) { - let nextValue = fieldValue; - if (nextValue === true) { - // make sure the payload is an object - args[kind][field] = nextValue = {}; + let injected = false; + if (!isCount) { + // regular select/include injection + injected = await this.injectBaseFieldSelect(model, field, fieldValue, args, kind); + if (injected) { + // if injected, remove the field from the original payload + delete data[field]; + } + } else { + // _count select/include injection, inject into an empty payload and then merge to the proper location + const injectTarget = { [kind]: {} }; + injected = await this.injectBaseFieldSelect(model, field, fieldValue, injectTarget, kind, true); + if (injected) { + // if injected, remove the field from the original payload + delete data[field]; + if (Object.keys(data).length === 0) { + // if the original "_count" payload becomes empty, remove it + delete args[kind]['_count']; } - await this.injectSelectIncludeHierarchy(fieldInfo.type, nextValue); + // finally merge the injection into the original payload + const merged = deepmerge(args[kind], injectTarget[kind]); + args[kind] = merged; + } + } + + if (!injected && fieldInfo.isDataModel) { + let nextValue = fieldValue; + if (nextValue === true) { + // make sure the payload is an object + data[field] = nextValue = {}; } + await this.injectSelectIncludeHierarchy(fieldInfo.type, nextValue); } } } @@ -272,7 +327,8 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { field: string, value: any, selectInclude: any, - context: 'select' | 'include' + context: 'select' | 'include', + forCount = false // if the injection is for a "{ _count: { select: { field: true } } }" payload ) { const fieldInfo = resolveField(this.options.modelMeta, model, field); if (!fieldInfo?.inheritedFrom) { @@ -286,16 +342,12 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { const baseRelationName = this.makeAuxRelationName(base); // prepare base layer select/include - // let selectOrInclude = 'select'; let thisLayer: any; if (target.include) { - // selectOrInclude = 'include'; thisLayer = target.include; } else if (target.select) { - // selectOrInclude = 'select'; thisLayer = target.select; } else { - // selectInclude = 'include'; thisLayer = target.select = {}; } @@ -303,7 +355,22 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { if (!thisLayer[baseRelationName]) { thisLayer[baseRelationName] = { [context]: {} }; } - thisLayer[baseRelationName][context][field] = value; + if (forCount) { + // { _count: { select: { field: true } } } => { delegate_aux_[Base]: { select: { _count: { select: { field: true } } } } } + if ( + !thisLayer[baseRelationName][context]['_count'] || + typeof thisLayer[baseRelationName][context] !== 'object' + ) { + thisLayer[baseRelationName][context]['_count'] = {}; + } + thisLayer[baseRelationName][context]['_count'] = deepmerge( + thisLayer[baseRelationName][context]['_count'], + { select: { [field]: value } } + ); + } else { + // { select: { field: true } } => { delegate_aux_[Base]: { select: { field: true } } } + thisLayer[baseRelationName][context][field] = value; + } break; } else { if (!thisLayer[baseRelationName]) { diff --git a/packages/runtime/src/enhancements/node/encryption.ts b/packages/runtime/src/enhancements/node/encryption.ts new file mode 100644 index 000000000..4859a1225 --- /dev/null +++ b/packages/runtime/src/enhancements/node/encryption.ts @@ -0,0 +1,165 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +/* eslint-disable @typescript-eslint/no-unused-vars */ + +import { ACTIONS_WITH_WRITE_PAYLOAD } from '../../constants'; +import { + FieldInfo, + NestedWriteVisitor, + enumerate, + getModelFields, + resolveField, + type PrismaWriteActionType, +} from '../../cross'; +import { Decrypter, Encrypter } from '../../encryption'; +import { CustomEncryption, DbClientContract, SimpleEncryption } from '../../types'; +import { InternalEnhancementOptions } from './create-enhancement'; +import { Logger } from './logger'; +import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; +import { QueryUtils } from './query-utils'; + +/** + * Gets an enhanced Prisma client that supports `@encrypted` attribute. + * + * @private + */ +export function withEncrypted( + prisma: DbClient, + options: InternalEnhancementOptions +): DbClient { + return makeProxy( + prisma, + options.modelMeta, + (_prisma, model) => new EncryptedHandler(_prisma as DbClientContract, model, options), + 'encryption' + ); +} + +class EncryptedHandler extends DefaultPrismaProxyHandler { + private queryUtils: QueryUtils; + private logger: Logger; + private encryptionKey: CryptoKey | undefined; + private encryptionKeyDigest: string | undefined; + private decryptionKeys: Array<{ key: CryptoKey; digest: string }> = []; + private encrypter: Encrypter | undefined; + private decrypter: Decrypter | undefined; + + constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { + super(prisma, model, options); + + this.queryUtils = new QueryUtils(prisma, options); + this.logger = new Logger(prisma); + + if (!options.encryption) { + throw this.queryUtils.unknownError('Encryption options must be provided'); + } + + if (this.isCustomEncryption(options.encryption!)) { + if (!options.encryption.encrypt || !options.encryption.decrypt) { + throw this.queryUtils.unknownError('Custom encryption must provide encrypt and decrypt functions'); + } + } else { + if (!options.encryption.encryptionKey) { + throw this.queryUtils.unknownError('Encryption key must be provided'); + } + + this.encrypter = new Encrypter(options.encryption.encryptionKey); + this.decrypter = new Decrypter([ + options.encryption.encryptionKey, + ...(options.encryption.decryptionKeys || []), + ]); + } + } + + private isCustomEncryption(encryption: CustomEncryption | SimpleEncryption): encryption is CustomEncryption { + return 'encrypt' in encryption && 'decrypt' in encryption; + } + + private async encrypt(field: FieldInfo, data: string): Promise { + if (this.isCustomEncryption(this.options.encryption!)) { + return this.options.encryption.encrypt(this.model, field, data); + } + + return this.encrypter!.encrypt(data); + } + + private async decrypt(field: FieldInfo, data: string): Promise { + if (this.isCustomEncryption(this.options.encryption!)) { + return this.options.encryption.decrypt(this.model, field, data); + } + + return this.decrypter!.decrypt(data); + } + + // base override + protected async preprocessArgs(action: PrismaProxyActions, args: any) { + if (args && ACTIONS_WITH_WRITE_PAYLOAD.includes(action)) { + await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); + } + return args; + } + + // base override + protected async processResultEntity(method: PrismaProxyActions, data: T): Promise { + if (!data || typeof data !== 'object') { + return data; + } + + for (const value of enumerate(data)) { + await this.doPostProcess(value, this.model); + } + + return data; + } + + private async doPostProcess(entityData: any, model: string) { + const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData); + + for (const field of getModelFields(entityData)) { + // don't decrypt null, undefined or empty string values + if (!entityData[field]) continue; + + const fieldInfo = await resolveField(this.options.modelMeta, realModel, field); + if (!fieldInfo) { + continue; + } + + if (fieldInfo.isDataModel) { + const items = + fieldInfo.isArray && Array.isArray(entityData[field]) ? entityData[field] : [entityData[field]]; + for (const item of items) { + // recurse + await this.doPostProcess(item, fieldInfo.type); + } + } else { + const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted'); + if (shouldDecrypt) { + try { + entityData[field] = await this.decrypt(fieldInfo, entityData[field]); + } catch (error) { + this.logger.warn(`Decryption failed, keeping original value: ${error}`); + } + } + } + } + } + + private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { + const visitor = new NestedWriteVisitor(this.options.modelMeta, { + field: async (field, _action, data, context) => { + // don't encrypt null, undefined or empty string values + if (!data) return; + + const encAttr = field.attributes?.find((attr) => attr.name === '@encrypted'); + if (encAttr && field.type === 'String') { + try { + context.parent[field.name] = await this.encrypt(field, data); + } catch (error) { + this.queryUtils.unknownError(`Encryption failed for field ${field.name}: ${error}`); + } + } + }, + }); + + await visitor.visit(model, action, args); + } +} diff --git a/packages/runtime/src/enhancements/node/password.ts b/packages/runtime/src/enhancements/node/password.ts index 8c1aeb959..a2fdae42c 100644 --- a/packages/runtime/src/enhancements/node/password.ts +++ b/packages/runtime/src/enhancements/node/password.ts @@ -1,7 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-unused-vars */ -import { DEFAULT_PASSWORD_SALT_LENGTH } from '../../constants'; +import { ACTIONS_WITH_WRITE_PAYLOAD, DEFAULT_PASSWORD_SALT_LENGTH } from '../../constants'; import { NestedWriteVisitor, type PrismaWriteActionType } from '../../cross'; import { DbClientContract } from '../../types'; import { InternalEnhancementOptions } from './create-enhancement'; @@ -39,8 +39,7 @@ class PasswordHandler extends DefaultPrismaProxyHandler { // base override protected async preprocessArgs(action: PrismaProxyActions, args: any) { - const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; - if (args && args.data && actionsOfInterest.includes(action)) { + if (args && ACTIONS_WITH_WRITE_PAYLOAD.includes(action)) { await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); } return args; diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index 7c4df97c1..fe31a5058 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import type { z } from 'zod'; +import { FieldInfo } from './cross'; export type PrismaPromise = Promise & Record PrismaPromise>; @@ -133,6 +134,11 @@ export type EnhancementOptions = { * The `isolationLevel` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. */ transactionIsolationLevel?: TransactionIsolationLevel; + + /** + * The encryption options for using the `encrypted` enhancement. + */ + encryption?: SimpleEncryption | CustomEncryption; }; /** @@ -145,7 +151,7 @@ export type EnhancementContext = { /** * Kinds of enhancements to `PrismaClient` */ -export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate'; +export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encryption'; /** * Function for transforming errors. @@ -166,3 +172,39 @@ export type ZodSchemas = { */ input?: Record>; }; + +/** + * Simple encryption settings for processing fields marked with `@encrypted`. + */ +export type SimpleEncryption = { + /** + * The encryption key. + */ + encryptionKey: Uint8Array; + + /** + * Optional list of all decryption keys that were previously used to encrypt the data + * , for supporting key rotation. The `encryptionKey` field value is automatically + * included for decryption. + * + * When the encrypted data is persisted, a metadata object containing the digest of the + * encryption key is stored alongside the data. This digest is used to quickly determine + * the correct decryption key to use when reading the data. + */ + decryptionKeys?: Uint8Array[]; +}; + +/** + * Custom encryption settings for processing fields marked with `@encrypted`. + */ +export type CustomEncryption = { + /** + * Encryption function. + */ + encrypt: (model: string, field: FieldInfo, plain: string) => Promise; + + /** + * Decryption function + */ + decrypt: (model: string, field: FieldInfo, cipher: string) => Promise; +}; diff --git a/packages/schema/package.json b/packages/schema/package.json index 45740d15f..93fdfc1f0 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack Language Tools", "description": "FullStack enhancement for Prisma ORM: seamless integration from database to UI", - "version": "2.10.2", + "version": "2.11.0", "author": { "name": "ZenStack Team" }, diff --git a/packages/schema/src/cli/actions/generate.ts b/packages/schema/src/cli/actions/generate.ts index d697504ee..229a9ddd8 100644 --- a/packages/schema/src/cli/actions/generate.ts +++ b/packages/schema/src/cli/actions/generate.ts @@ -37,8 +37,8 @@ export async function generate(projectPath: string, options: Options) { // check for multiple versions of Zenstack packages const packages = getZenStackPackages(projectPath); - if (packages) { - const versions = new Set(packages.map((p) => p.version)); + if (packages.length > 0) { + const versions = new Set(packages.map((p) => p.version).filter((v): v is string => !!v)); if (versions.size > 1) { console.warn( colors.yellow( diff --git a/packages/schema/src/cli/actions/info.ts b/packages/schema/src/cli/actions/info.ts index dddef9e27..c212babf4 100644 --- a/packages/schema/src/cli/actions/info.ts +++ b/packages/schema/src/cli/actions/info.ts @@ -16,7 +16,9 @@ export async function info(projectPath: string) { console.log('Installed ZenStack Packages:'); const versions = new Set(); for (const { pkg, version } of packages) { - versions.add(version); + if (version) { + versions.add(version); + } console.log(` ${colors.green(pkg.padEnd(20))}\t${version}`); } diff --git a/packages/schema/src/cli/actions/init.ts b/packages/schema/src/cli/actions/init.ts index 5790997e6..1016d61a9 100644 --- a/packages/schema/src/cli/actions/init.ts +++ b/packages/schema/src/cli/actions/init.ts @@ -63,8 +63,7 @@ export async function init(projectPath: string, options: Options) { if (sampleModelGenerated) { console.log(`Sample model generated at: ${colors.blue(zmodelFile)} -Please check the following guide on how to model your app: - https://zenstack.dev/#/modeling-your-app.`); +Learn how to use ZenStack: https://zenstack.dev/docs.`); } else if (prismaSchema) { console.log( `Your current Prisma schema "${prismaSchema}" has been copied to "${zmodelFile}". diff --git a/packages/schema/src/cli/cli-util.ts b/packages/schema/src/cli/cli-util.ts index b822f75ee..54ac123bd 100644 --- a/packages/schema/src/cli/cli-util.ts +++ b/packages/schema/src/cli/cli-util.ts @@ -227,13 +227,13 @@ export async function getPluginDocuments(services: ZModelServices, fileName: str return result; } -export function getZenStackPackages(projectPath: string) { +export function getZenStackPackages(projectPath: string): Array<{ pkg: string; version: string | undefined }> { let pkgJson: { dependencies: Record; devDependencies: Record }; const resolvedPath = path.resolve(projectPath); try { pkgJson = require(path.join(resolvedPath, 'package.json')); } catch { - return undefined; + return []; } const packages = [ @@ -245,7 +245,7 @@ export function getZenStackPackages(projectPath: string) { try { const resolved = require.resolve(`${pkg}/package.json`, { paths: [resolvedPath] }); // eslint-disable-next-line @typescript-eslint/no-var-requires - return { pkg, version: require(resolved).version }; + return { pkg, version: require(resolved).version as string }; } catch { return { pkg, version: undefined }; } @@ -286,7 +286,7 @@ export async function checkNewVersion() { return; } - if (latestVersion && semver.gt(latestVersion, currVersion)) { + if (latestVersion && currVersion && semver.gt(latestVersion, currVersion)) { console.log(`A newer version ${colors.cyan(latestVersion)} is available.`); } } diff --git a/packages/schema/src/cli/index.ts b/packages/schema/src/cli/index.ts index c58db8c43..62084ce9b 100644 --- a/packages/schema/src/cli/index.ts +++ b/packages/schema/src/cli/index.ts @@ -73,7 +73,7 @@ export const checkAction = async (options: Parameters[1]): export function createProgram() { const program = new Command('zenstack'); - program.version(getVersion(), '-v --version', 'display CLI version'); + program.version(getVersion()!, '-v --version', 'display CLI version'); const schemaExtensions = ZModelLanguageMetaData.fileExtensions.join(', '); diff --git a/packages/schema/src/language-server/validator/attribute-application-validator.ts b/packages/schema/src/language-server/validator/attribute-application-validator.ts index a7c0fef9a..0e1d8e885 100644 --- a/packages/schema/src/language-server/validator/attribute-application-validator.ts +++ b/packages/schema/src/language-server/validator/attribute-application-validator.ts @@ -25,7 +25,7 @@ import { isRelationshipField, resolved, } from '@zenstackhq/sdk'; -import { ValidationAcceptor, streamAst } from 'langium'; +import { ValidationAcceptor, streamAllContents, streamAst } from 'langium'; import pluralize from 'pluralize'; import { AstValidator } from '../types'; import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils'; @@ -138,6 +138,9 @@ export default class AttributeApplicationValidator implements AstValidator { + if (isDataModelFieldReference(node) && hasAttribute(node.target.ref as DataModelField, '@encrypted')) { + accept('error', `Encrypted fields cannot be used in policy rules`, { node }); + } + }); + } + private validatePolicyKinds( kind: string, candidates: string[], diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts index 9054c82c6..630bf0085 100644 --- a/packages/schema/src/language-server/validator/datamodel-validator.ts +++ b/packages/schema/src/language-server/validator/datamodel-validator.ts @@ -33,6 +33,10 @@ export default class DataModelValidator implements AstValidator { validateDuplicatedDeclarations(dm, getModelFieldsWithBases(dm), accept); this.validateAttributes(dm, accept); this.validateFields(dm, accept); + + if (dm.superTypes.length > 0) { + this.validateInheritance(dm, accept); + } } private validateFields(dm: DataModel, accept: ValidationAcceptor) { @@ -407,6 +411,26 @@ export default class DataModelValidator implements AstValidator { }); } } + + private validateInheritance(dm: DataModel, accept: ValidationAcceptor) { + const seen = [dm]; + const todo: DataModel[] = dm.superTypes.map((superType) => superType.ref!); + while (todo.length > 0) { + const current = todo.shift()!; + if (seen.includes(current)) { + accept( + 'error', + `Circular inheritance detected: ${seen.map((m) => m.name).join(' -> ')} -> ${current.name}`, + { + node: dm, + } + ); + return; + } + seen.push(current); + todo.push(...current.superTypes.map((superType) => superType.ref!)); + } + } } export interface MissingOppositeRelationData { diff --git a/packages/schema/src/language-server/validator/function-invocation-validator.ts b/packages/schema/src/language-server/validator/function-invocation-validator.ts index 8c11a2a72..343c75cad 100644 --- a/packages/schema/src/language-server/validator/function-invocation-validator.ts +++ b/packages/schema/src/language-server/validator/function-invocation-validator.ts @@ -87,7 +87,17 @@ export default class FunctionInvocationValidator implements AstValidator(expr.args[0]?.value); + if (arg && !allCasing.includes(arg)) { + accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, { + node: expr.args[0], + }); + } + } else if ( funcAllowedContext.includes(ExpressionContext.AccessPolicy) || funcAllowedContext.includes(ExpressionContext.ValidationRule) ) { diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index ba8c50feb..689ddaf2c 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -24,7 +24,6 @@ import { isArrayExpr, isDataModel, isGeneratorDecl, - isReferenceExpr, isTypeDef, type Model, } from '@zenstackhq/sdk/ast'; @@ -45,6 +44,7 @@ import { } from 'ts-morph'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '..'; +import { getConcreteModels, getDiscriminatorField } from '../../../utils/ast-utils'; import { execPackage } from '../../../utils/exec-utils'; import { CorePlugins, getPluginCustomOutputFolder } from '../../plugin-utils'; import { trackPrismaSchemaError } from '../../prisma'; @@ -407,9 +407,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara this.model.declarations .filter((d): d is DataModel => isDelegateModel(d)) .forEach((dm) => { - const concreteModels = this.model.declarations.filter( - (d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm) - ); + const concreteModels = getConcreteModels(dm); if (concreteModels.length > 0) { delegateInfo.push([dm, concreteModels]); } @@ -579,7 +577,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara const typeName = typeAlias.getName(); const payloadRecord = delegateInfo.find(([delegate]) => `$${delegate.name}Payload` === typeName); if (payloadRecord) { - const discriminatorDecl = this.getDiscriminatorField(payloadRecord[0]); + const discriminatorDecl = getDiscriminatorField(payloadRecord[0]); if (discriminatorDecl) { source = `${payloadRecord[1] .map( @@ -826,15 +824,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara .filter((n) => n.getName().startsWith(DELEGATE_AUX_RELATION_PREFIX)); } - private getDiscriminatorField(delegate: DataModel) { - const delegateAttr = getAttribute(delegate, '@@delegate'); - if (!delegateAttr) { - return undefined; - } - const arg = delegateAttr.args[0]?.value; - return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined; - } - private saveSourceFile(sf: SourceFile) { if (this.options.preserveTsFiles) { saveSourceFile(sf); diff --git a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts index 645e02cd1..0d792bdc1 100644 --- a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts +++ b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts @@ -839,16 +839,18 @@ export class ExpressionWriter { operation = this.options.operationContext; } - this.block(() => { - if (operation === 'postUpdate') { - // 'postUpdate' policies are not delegated to relations, just use constant `false` here - // e.g.: - // @@allow('all', check(author)) should not delegate "postUpdate" to author - this.writer.write(`${fieldRef.target.$refText}: ${FALSE}`); - } else { - const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation); - this.writer.write(`${fieldRef.target.$refText}: ${targetGuardFunc}(context, db)`); - } - }); + this.block(() => + this.writeFieldCondition(fieldRef, () => { + if (operation === 'postUpdate') { + // 'postUpdate' policies are not delegated to relations, just use constant `false` here + // e.g.: + // @@allow('all', check(author)) should not delegate "postUpdate" to author + this.writer.write(FALSE); + } else { + const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation); + this.writer.write(`${targetGuardFunc}(context, db)`); + } + }) + ); } } diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 8206f797b..9ffe41dcb 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -454,6 +454,8 @@ export class PolicyGenerator { writer: CodeBlockWriter, sourceFile: SourceFile ) { + // first handle several cases where a constant function can be used + if (kind === 'update' && allows.length === 0) { // no allow rule for 'update', policy is constant based on if there's // post-update counterpart diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 96a3b15f5..a0bde1769 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -57,6 +57,7 @@ import path from 'path'; import semver from 'semver'; import { name } from '.'; import { getStringLiteral } from '../../language-server/validator/utils'; +import { getConcreteModels } from '../../utils/ast-utils'; import { execPackage } from '../../utils/exec-utils'; import { isDefaultWithAuth } from '../enhancer/enhancer-utils'; import { @@ -320,9 +321,7 @@ export class PrismaSchemaGenerator { } // collect concrete models inheriting this model - const concreteModels = decl.$container.declarations.filter( - (d) => isDataModel(d) && d !== decl && d.superTypes.some((base) => base.ref === decl) - ); + const concreteModels = getConcreteModels(decl); // generate an optional relation field in delegate base model to each concrete model concreteModels.forEach((concrete) => { diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index 3316a90a9..a0a0a41f8 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -171,6 +171,29 @@ function hasSome(field: Any[], search: Any[]): Boolean { function isEmpty(field: Any[]): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) +/** + * The name of the model for which the policy rule is defined. If the rule is + * inherited to a sub model, this function returns the name of the sub model. + * + * @param optional parameter to control the casing of the returned value. Valid + * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults + * to "original". + */ +function currentModel(casing: String?): String { +} @@@expressionContext([AccessPolicy]) + +/** + * The operation for which the policy rule is defined for. Note that a rule with + * "all" operation is expanded to "create", "read", "update", and "delete" rules, + * and the function returns corresponding value for each expanded version. + * + * @param optional parameter to control the casing of the returned value. Valid + * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults + * to "original". + */ +function currentOperation(casing: String?): String { +} @@@expressionContext([AccessPolicy]) + /** * Marks an attribute to be only applicable to certain field types. */ @@ -552,6 +575,14 @@ attribute @@auth() @@@supportTypeDef */ attribute @password(saltLength: Int?, salt: String?) @@@targetField([StringField]) + +/** + * Indicates that the field is encrypted when storing in the DB and should be decrypted when read + * + * ZenStack uses the Web Crypto API to encrypt and decrypt the field. + */ +attribute @encrypted() @@@targetField([StringField]) + /** * Indicates that the field should be omitted when read from the generated services. */ diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index a6fab7ea5..f59ee7faa 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -2,6 +2,7 @@ import { BinaryExpr, DataModel, DataModelAttribute, + DataModelField, Expression, InheritableNode, isBinaryExpr, @@ -9,12 +10,20 @@ import { isDataModelField, isInvocationExpr, isModel, + isReferenceExpr, isTypeDef, Model, ModelImport, TypeDef, } from '@zenstackhq/language/ast'; -import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk'; +import { + getAttribute, + getInheritanceChain, + getRecursiveBases, + hasAttribute, + isDelegateModel, + isFromStdlib, +} from '@zenstackhq/sdk'; import { AstNode, copyAstNode, @@ -94,6 +103,9 @@ function filterBaseAttribute(forModel: DataModel, base: DataModel, attr: DataMod // uninheritable attributes for delegate inheritance (they reference fields from the base) const uninheritableFromDelegateAttributes = ['@@unique', '@@index', '@@fulltext']; + // attributes that are inherited but can be overridden + const overrideAttributes = ['@@schema']; + if (uninheritableAttributes.includes(attr.decl.$refText)) { return false; } @@ -107,6 +119,11 @@ function filterBaseAttribute(forModel: DataModel, base: DataModel, attr: DataMod return false; } + if (hasAttribute(forModel, attr.decl.$refText) && overrideAttributes.includes(attr.decl.$refText)) { + // don't inherit an attribute if it's overridden in the sub model + return false; + } + return true; } @@ -310,3 +327,27 @@ export function findUpInheritance(start: DataModel, target: DataModel): DataMode } return undefined; } + +/** + * Gets all concrete models that inherit from the given delegate model + */ +export function getConcreteModels(dataModel: DataModel): DataModel[] { + if (!isDelegateModel(dataModel)) { + return []; + } + return dataModel.$container.declarations.filter( + (d): d is DataModel => isDataModel(d) && d !== dataModel && d.superTypes.some((base) => base.ref === dataModel) + ); +} + +/** + * Gets the discriminator field for the given delegate model + */ +export function getDiscriminatorField(dataModel: DataModel) { + const delegateAttr = getAttribute(dataModel, '@@delegate'); + if (!delegateAttr) { + return undefined; + } + const arg = delegateAttr.args[0]?.value; + return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined; +} diff --git a/packages/schema/src/utils/version-utils.ts b/packages/schema/src/utils/version-utils.ts index 0e2de705d..3a2daae57 100644 --- a/packages/schema/src/utils/version-utils.ts +++ b/packages/schema/src/utils/version-utils.ts @@ -1,5 +1,5 @@ /* eslint-disable @typescript-eslint/no-var-requires */ -export function getVersion() { +export function getVersion(): string | undefined { try { return require('../package.json').version; } catch { diff --git a/packages/schema/tests/schema/validation/cyclic-inheritance.test.ts b/packages/schema/tests/schema/validation/cyclic-inheritance.test.ts new file mode 100644 index 000000000..494dad2be --- /dev/null +++ b/packages/schema/tests/schema/validation/cyclic-inheritance.test.ts @@ -0,0 +1,39 @@ +import { loadModelWithError } from '../../utils'; + +describe('Cyclic inheritance', () => { + it('abstract inheritance', async () => { + const errors = await loadModelWithError( + ` + abstract model A extends B {} + abstract model B extends A {} + model C extends B { + id Int @id + } + ` + ); + expect(errors).toContain('Circular inheritance detected: A -> B -> A'); + expect(errors).toContain('Circular inheritance detected: B -> A -> B'); + expect(errors).toContain('Circular inheritance detected: C -> B -> A -> B'); + }); + + it('delegate inheritance', async () => { + const errors = await loadModelWithError( + ` + model A extends B { + typeA String + @@delegate(typeA) + } + model B extends A { + typeB String + @@delegate(typeB) + } + model C extends B { + id Int @id + } + ` + ); + expect(errors).toContain('Circular inheritance detected: A -> B -> A'); + expect(errors).toContain('Circular inheritance detected: B -> A -> B'); + expect(errors).toContain('Circular inheritance detected: C -> B -> A -> B'); + }); +}); diff --git a/packages/sdk/package.json b/packages/sdk/package.json index 58fd11aee..a49c0bbdd 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "2.10.2", + "version": "2.11.0", "description": "ZenStack plugin development SDK", "main": "index.js", "scripts": { diff --git a/packages/sdk/src/code-gen.ts b/packages/sdk/src/code-gen.ts index 7b26cc0c4..67833b788 100644 --- a/packages/sdk/src/code-gen.ts +++ b/packages/sdk/src/code-gen.ts @@ -47,6 +47,11 @@ export async function saveProject(project: Project) { * Emit a TS project to JS files. */ export async function emitProject(project: Project) { + // ignore type checking for all source files + for (const sf of project.getSourceFiles()) { + sf.insertStatements(0, '// @ts-nocheck'); + } + const errors = project.getPreEmitDiagnostics().filter((d) => d.getCategory() === DiagnosticCategory.Error); if (errors.length > 0) { console.error('Error compiling generated code:'); diff --git a/packages/sdk/src/typescript-expression-transformer.ts b/packages/sdk/src/typescript-expression-transformer.ts index 9a884ebdf..801db4d4f 100644 --- a/packages/sdk/src/typescript-expression-transformer.ts +++ b/packages/sdk/src/typescript-expression-transformer.ts @@ -20,6 +20,7 @@ import { isNullExpr, isThisExpr, } from '@zenstackhq/language/ast'; +import { getContainerOfType } from 'langium'; import { P, match } from 'ts-pattern'; import { ExpressionContext } from './constants'; import { getEntityCheckerFunctionName } from './names'; @@ -40,6 +41,8 @@ type Options = { operationContext?: 'read' | 'create' | 'update' | 'postUpdate' | 'delete'; }; +type Casing = 'original' | 'upper' | 'lower' | 'capitalize' | 'uncapitalize'; + // a registry of function handlers marked with @func const functionHandlers = new Map(); @@ -150,7 +153,7 @@ export class TypeScriptExpressionTransformer { } const args = expr.args.map((arg) => arg.value); - return handler.value.call(this, args, normalizeUndefined); + return handler.value.call(this, expr, args, normalizeUndefined); } // #region function invocation handlers @@ -168,7 +171,7 @@ export class TypeScriptExpressionTransformer { } @func('length') - private _length(args: Expression[]) { + private _length(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); const min = getLiteral(args[1]); const max = getLiteral(args[2]); @@ -188,7 +191,7 @@ export class TypeScriptExpressionTransformer { } @func('contains') - private _contains(args: Expression[], normalizeUndefined: boolean) { + private _contains(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); const caseInsensitive = getLiteral(args[2]) === true; let result: string; @@ -201,34 +204,34 @@ export class TypeScriptExpressionTransformer { } @func('startsWith') - private _startsWith(args: Expression[], normalizeUndefined: boolean) { + private _startsWith(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); const result = `${field}?.startsWith(${this.transform(args[1], normalizeUndefined)})`; return this.ensureBoolean(result); } @func('endsWith') - private _endsWith(args: Expression[], normalizeUndefined: boolean) { + private _endsWith(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); const result = `${field}?.endsWith(${this.transform(args[1], normalizeUndefined)})`; return this.ensureBoolean(result); } @func('regex') - private _regex(args: Expression[]) { + private _regex(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); const pattern = getLiteral(args[1]); return this.ensureBooleanTernary(args[0], field, `new RegExp(${JSON.stringify(pattern)}).test(${field})`); } @func('email') - private _email(args: Expression[]) { + private _email(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); return this.ensureBooleanTernary(args[0], field, `z.string().email().safeParse(${field}).success`); } @func('datetime') - private _datetime(args: Expression[]) { + private _datetime(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); return this.ensureBooleanTernary( args[0], @@ -238,20 +241,20 @@ export class TypeScriptExpressionTransformer { } @func('url') - private _url(args: Expression[]) { + private _url(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); return this.ensureBooleanTernary(args[0], field, `z.string().url().safeParse(${field}).success`); } @func('has') - private _has(args: Expression[], normalizeUndefined: boolean) { + private _has(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); const result = `${field}?.includes(${this.transform(args[1], normalizeUndefined)})`; return this.ensureBoolean(result); } @func('hasEvery') - private _hasEvery(args: Expression[], normalizeUndefined: boolean) { + private _hasEvery(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); return this.ensureBooleanTernary( args[0], @@ -261,7 +264,7 @@ export class TypeScriptExpressionTransformer { } @func('hasSome') - private _hasSome(args: Expression[], normalizeUndefined: boolean) { + private _hasSome(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); return this.ensureBooleanTernary( args[0], @@ -271,13 +274,13 @@ export class TypeScriptExpressionTransformer { } @func('isEmpty') - private _isEmpty(args: Expression[]) { + private _isEmpty(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); return `(!${field} || ${field}?.length === 0)`; } @func('check') - private _check(args: Expression[]) { + private _check(_invocation: InvocationExpr, args: Expression[]) { if (!isDataModelFieldReference(args[0])) { throw new TypeScriptExpressionTransformerError(`First argument of check() must be a field`); } @@ -309,6 +312,52 @@ export class TypeScriptExpressionTransformer { return `${entityCheckerFunc}(input.${fieldRef.target.$refText}, context)`; } + private toStringWithCaseChange(value: string, casing: Casing) { + if (!value) { + return "''"; + } + return match(casing) + .with('original', () => `'${value}'`) + .with('upper', () => `'${value.toUpperCase()}'`) + .with('lower', () => `'${value.toLowerCase()}'`) + .with('capitalize', () => `'${value.charAt(0).toUpperCase() + value.slice(1)}'`) + .with('uncapitalize', () => `'${value.charAt(0).toLowerCase() + value.slice(1)}'`) + .exhaustive(); + } + + @func('currentModel') + private _currentModel(invocation: InvocationExpr, args: Expression[]) { + let casing: Casing = 'original'; + if (args[0]) { + casing = getLiteral(args[0]) as Casing; + } + + const containingModel = getContainerOfType(invocation, isDataModel); + if (!containingModel) { + throw new TypeScriptExpressionTransformerError('currentModel() must be called inside a model'); + } + return this.toStringWithCaseChange(containingModel.name, casing); + } + + @func('currentOperation') + private _currentOperation(_invocation: InvocationExpr, args: Expression[]) { + let casing: Casing = 'original'; + if (args[0]) { + casing = getLiteral(args[0]) as Casing; + } + + if (!this.options.operationContext) { + throw new TypeScriptExpressionTransformerError( + 'currentOperation() must be called inside an access policy rule' + ); + } + let contextOperation = this.options.operationContext; + if (contextOperation === 'postUpdate') { + contextOperation = 'update'; + } + return this.toStringWithCaseChange(contextOperation, casing); + } + private ensureBoolean(expr: string) { if (this.options.context === ExpressionContext.ValidationRule) { // all fields are optional in a validation context, so we treat undefined diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 46c2a82c1..ecb6895eb 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -544,8 +544,16 @@ export function getModelFieldsWithBases(model: DataModel, includeDelegate = true } } -export function getRecursiveBases(dataModel: DataModel, includeDelegate = true): DataModel[] { +export function getRecursiveBases( + dataModel: DataModel, + includeDelegate = true, + seen = new Set() +): DataModel[] { const result: DataModel[] = []; + if (seen.has(dataModel)) { + return result; + } + seen.add(dataModel); dataModel.superTypes.forEach((superType) => { const baseDecl = superType.ref; if (baseDecl) { @@ -553,7 +561,7 @@ export function getRecursiveBases(dataModel: DataModel, includeDelegate = true): return; } result.push(baseDecl); - result.push(...getRecursiveBases(baseDecl, includeDelegate)); + result.push(...getRecursiveBases(baseDecl, includeDelegate, seen)); } }); return result; diff --git a/packages/server/package.json b/packages/server/package.json index 806471ecb..85b5ec809 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/server", - "version": "2.10.2", + "version": "2.11.0", "displayName": "ZenStack Server-side Adapters", "description": "ZenStack server-side adapters", "homepage": "https://zenstack.dev", diff --git a/packages/server/src/api/rest/index.ts b/packages/server/src/api/rest/index.ts index 1107fbc64..16de93637 100644 --- a/packages/server/src/api/rest/index.ts +++ b/packages/server/src/api/rest/index.ts @@ -1103,8 +1103,8 @@ class RequestHandler extends APIHandlerBase { where: this.makePrismaIdFilter(typeInfo.idFields, resourceId), }); return { - status: 204, - body: undefined, + status: 200, + body: { meta: {} }, }; } diff --git a/packages/server/tests/adapter/express.test.ts b/packages/server/tests/adapter/express.test.ts index 0627990e7..85ccc8a21 100644 --- a/packages/server/tests/adapter/express.test.ts +++ b/packages/server/tests/adapter/express.test.ts @@ -190,7 +190,7 @@ describe('Express adapter tests - rest handler', () => { expect(r.body.data.attributes.email).toBe('user1@def.com'); r = await request(app).delete(makeUrl('/api/user/user1')); - expect(r.status).toBe(204); + expect(r.status).toBe(200); expect(await prisma.user.findMany()).toHaveLength(0); }); }); diff --git a/packages/server/tests/adapter/fastify.test.ts b/packages/server/tests/adapter/fastify.test.ts index f03066e4f..ed4da3c72 100644 --- a/packages/server/tests/adapter/fastify.test.ts +++ b/packages/server/tests/adapter/fastify.test.ts @@ -233,7 +233,7 @@ describe('Fastify adapter tests - rest handler', () => { expect(r.json().data.attributes.email).toBe('user1@def.com'); r = await app.inject({ method: 'DELETE', url: '/api/user/user1' }); - expect(r.statusCode).toBe(204); + expect(r.statusCode).toBe(200); expect(await prisma.user.findMany()).toHaveLength(0); }); }); diff --git a/packages/server/tests/adapter/hono.test.ts b/packages/server/tests/adapter/hono.test.ts index 3fc1bb9da..fc55e1647 100644 --- a/packages/server/tests/adapter/hono.test.ts +++ b/packages/server/tests/adapter/hono.test.ts @@ -167,7 +167,7 @@ describe('Hono adapter tests - rest handler', () => { expect((await unmarshal(r)).data.attributes.email).toBe('user1@def.com'); r = await handler(makeRequest('DELETE', makeUrl(makeUrl('/api/user/user1')))); - expect(r.status).toBe(204); + expect(r.status).toBe(200); expect(await prisma.user.findMany()).toHaveLength(0); }); }); diff --git a/packages/server/tests/adapter/next.test.ts b/packages/server/tests/adapter/next.test.ts index b8652de7c..733b30ade 100644 --- a/packages/server/tests/adapter/next.test.ts +++ b/packages/server/tests/adapter/next.test.ts @@ -307,7 +307,7 @@ model M { expect(resp.body.data.attributes.value).toBe(2); }); - await makeTestClient('/m/1', options).del('/').expect(204); + await makeTestClient('/m/1', options).del('/').expect(200); expect(await prisma.m.count()).toBe(0); }); }); diff --git a/packages/server/tests/adapter/sveltekit.test.ts b/packages/server/tests/adapter/sveltekit.test.ts index d9663a2b6..650f89f85 100644 --- a/packages/server/tests/adapter/sveltekit.test.ts +++ b/packages/server/tests/adapter/sveltekit.test.ts @@ -164,7 +164,7 @@ describe('SvelteKit adapter tests - rest handler', () => { expect((await unmarshal(r)).data.attributes.email).toBe('user1@def.com'); r = await handler(makeRequest('DELETE', makeUrl(makeUrl('/api/user/user1')))); - expect(r.status).toBe(204); + expect(r.status).toBe(200); expect(await prisma.user.findMany()).toHaveLength(0); }); }); diff --git a/packages/server/tests/api/rest.test.ts b/packages/server/tests/api/rest.test.ts index 2a59e6067..b36755055 100644 --- a/packages/server/tests/api/rest.test.ts +++ b/packages/server/tests/api/rest.test.ts @@ -2340,8 +2340,8 @@ describe('REST server tests', () => { prisma, }); - expect(r.status).toBe(204); - expect(r.body).toBeUndefined(); + expect(r.status).toBe(200); + expect(r.body).toMatchObject({ meta: {} }); }); it('deletes an item with compound id', async () => { @@ -2355,8 +2355,8 @@ describe('REST server tests', () => { path: `/postLike/1${idDivider}user1`, prisma, }); - expect(r.status).toBe(204); - expect(r.body).toBeUndefined(); + expect(r.status).toBe(200); + expect(r.body).toMatchObject({ meta: {} }); }); it('returns 404 if the user does not exist', async () => { diff --git a/packages/testtools/package.json b/packages/testtools/package.json index c47c29c65..74db32925 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "2.10.2", + "version": "2.11.0", "description": "ZenStack Test Tools", "main": "index.js", "private": true, diff --git a/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts b/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts index 91a385db0..7a555e0cd 100644 --- a/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts @@ -378,6 +378,66 @@ describe('Polymorphism Test', () => { ).resolves.toHaveLength(1); }); + it('read with counting relation defined in base', async () => { + const { enhance } = await loadSchema( + ` + + model A { + id Int @id @default(autoincrement()) + type String + bs B[] + cs C[] + @@delegate(type) + } + + model A1 extends A { + a1 Int + type1 String + @@delegate(type1) + } + + model A2 extends A1 { + a2 Int + } + + model B { + id Int @id @default(autoincrement()) + a A @relation(fields: [aId], references: [id]) + aId Int + b Int + } + + model C { + id Int @id @default(autoincrement()) + a A @relation(fields: [aId], references: [id]) + aId Int + c Int + } + `, + { enhancements: ['delegate'] } + ); + const db = enhance(); + + const a2 = await db.a2.create({ + data: { a1: 1, a2: 2, bs: { create: [{ b: 1 }, { b: 2 }] }, cs: { create: [{ c: 1 }] } }, + include: { _count: { select: { bs: true } } }, + }); + expect(a2).toMatchObject({ a1: 1, a2: 2, _count: { bs: 2 } }); + + await expect( + db.a2.findFirst({ select: { a1: true, _count: { select: { bs: true } } } }) + ).resolves.toStrictEqual({ + a1: 1, + _count: { bs: 2 }, + }); + + await expect(db.a.findFirst({ select: { _count: { select: { bs: true, cs: true } } } })).resolves.toMatchObject( + { + _count: { bs: 2, cs: 1 }, + } + ); + }); + it('order by base fields', async () => { const { db, user } = await setup(); diff --git a/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts index d149a6392..67fc456af 100644 --- a/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts @@ -571,4 +571,84 @@ describe('Polymorphic Policy Test', () => { expect(foundPost2.foo).toBeUndefined(); expect(foundPost2.bar).toBeUndefined(); }); + + it('respects concrete policies when read as base optional relation', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + asset Asset? + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + type String + + @@delegate(type) + @@allow('all', true) + } + + model Post extends Asset { + title String + private Boolean + @@allow('create', true) + @@deny('read', private) + } + ` + ); + + const fullDb = enhance(undefined, { kinds: ['delegate'] }); + await fullDb.user.create({ data: { id: 1 } }); + await fullDb.post.create({ data: { title: 'Post1', private: true, user: { connect: { id: 1 } } } }); + await expect(fullDb.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({ + asset: expect.objectContaining({ type: 'Post' }), + }); + + const db = enhance(); + const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } }); + expect(read.asset).toBeTruthy(); + expect(read.asset.title).toBeUndefined(); + }); + + it('respects concrete policies when read as base required relation', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + asset Asset @relation(fields: [assetId], references: [id]) + assetId Int @unique + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + user User? + type String + + @@delegate(type) + @@allow('all', true) + } + + model Post extends Asset { + title String + private Boolean + @@deny('read', private) + } + ` + ); + + const fullDb = enhance(undefined, { kinds: ['delegate'] }); + await fullDb.post.create({ data: { id: 1, title: 'Post1', private: true, user: { create: { id: 1 } } } }); + await expect(fullDb.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({ + asset: expect.objectContaining({ type: 'Post' }), + }); + + const db = enhance(); + const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } }); + expect(read).toBeTruthy(); + expect(read.asset.title).toBeUndefined(); + }); }); diff --git a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts new file mode 100644 index 000000000..71d32769f --- /dev/null +++ b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts @@ -0,0 +1,478 @@ +import { FieldInfo } from '@zenstackhq/runtime'; +import { loadSchema, loadModelWithError } from '@zenstackhq/testtools'; +import path from 'path'; + +describe('Encrypted test', () => { + let origDir: string; + const encryptionKey = new Uint8Array(Buffer.from('AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=', 'base64')); + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + afterEach(async () => { + process.chdir(origDir); + }); + + it('Simple encryption test', async () => { + const { enhance, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + }`, + { + enhancements: ['encryption'], + enhanceOptions: { + encryption: { encryptionKey }, + }, + } + ); + + const sudoDb = enhance(undefined, { kinds: [] }); + + const db = enhance(); + + const create = await db.user.create({ + data: { + id: '1', + encrypted_value: 'abc123', + }, + }); + + const read = await db.user.findUnique({ + where: { + id: '1', + }, + }); + + const sudoRead = await sudoDb.user.findUnique({ + where: { + id: '1', + }, + }); + + const rawRead = await prisma.user.findUnique({ where: { id: '1' } }); + + expect(create.encrypted_value).toBe('abc123'); + expect(read.encrypted_value).toBe('abc123'); + expect(sudoRead.encrypted_value).not.toBe('abc123'); + expect(rawRead.encrypted_value).not.toBe('abc123'); + + // update + const updated = await db.user.update({ + where: { id: '1' }, + data: { encrypted_value: 'abc234' }, + }); + expect(updated.encrypted_value).toBe('abc234'); + await expect(db.user.findUnique({ where: { id: '1' } })).resolves.toMatchObject({ + encrypted_value: 'abc234', + }); + await expect(prisma.user.findUnique({ where: { id: '1' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc234', + }); + + // upsert with create + const upsertCreate = await db.user.upsert({ + where: { id: '2' }, + create: { + id: '2', + encrypted_value: 'abc345', + }, + update: { + encrypted_value: 'abc456', + }, + }); + expect(upsertCreate.encrypted_value).toBe('abc345'); + await expect(db.user.findUnique({ where: { id: '2' } })).resolves.toMatchObject({ + encrypted_value: 'abc345', + }); + await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc345', + }); + + // upsert with update + const upsertUpdate = await db.user.upsert({ + where: { id: '2' }, + create: { + id: '2', + encrypted_value: 'abc345', + }, + update: { + encrypted_value: 'abc456', + }, + }); + expect(upsertUpdate.encrypted_value).toBe('abc456'); + await expect(db.user.findUnique({ where: { id: '2' } })).resolves.toMatchObject({ + encrypted_value: 'abc456', + }); + await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc456', + }); + + // createMany + await db.user.createMany({ + data: [ + { id: '3', encrypted_value: 'abc567' }, + { id: '4', encrypted_value: 'abc678' }, + ], + }); + await expect(db.user.findUnique({ where: { id: '3' } })).resolves.toMatchObject({ + encrypted_value: 'abc567', + }); + await expect(prisma.user.findUnique({ where: { id: '3' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc567', + }); + + // createManyAndReturn + await expect( + db.user.createManyAndReturn({ + data: [ + { id: '5', encrypted_value: 'abc789' }, + { id: '6', encrypted_value: 'abc890' }, + ], + }) + ).resolves.toEqual( + expect.arrayContaining([ + { id: '5', encrypted_value: 'abc789' }, + { id: '6', encrypted_value: 'abc890' }, + ]) + ); + await expect(db.user.findUnique({ where: { id: '5' } })).resolves.toMatchObject({ + encrypted_value: 'abc789', + }); + await expect(prisma.user.findUnique({ where: { id: '5' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc789', + }); + }); + + it('Works with nullish values', async () => { + const { enhance, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + encrypted_value String? @encrypted() + }`, + { + enhancements: ['encryption'], + enhanceOptions: { + encryption: { encryptionKey }, + }, + } + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: '1', encrypted_value: '' } })).resolves.toMatchObject({ + encrypted_value: '', + }); + await expect(prisma.user.findUnique({ where: { id: '1' } })).resolves.toMatchObject({ encrypted_value: '' }); + + await expect(db.user.create({ data: { id: '2' } })).resolves.toMatchObject({ + encrypted_value: null, + }); + await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.toMatchObject({ encrypted_value: null }); + + await expect(db.user.create({ data: { id: '3', encrypted_value: null } })).resolves.toMatchObject({ + encrypted_value: null, + }); + await expect(prisma.user.findUnique({ where: { id: '3' } })).resolves.toMatchObject({ encrypted_value: null }); + }); + + it('Decrypts nested fields', async () => { + const { enhance, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + posts Post[] + } + + model Post { + id String @id @default(cuid()) + title String @encrypted() + author User @relation(fields: [authorId], references: [id]) + authorId String + } + `, + { + enhancements: ['encryption'], + enhanceOptions: { + encryption: { encryptionKey }, + }, + } + ); + + const db = enhance(); + + const create = await db.user.create({ + data: { + id: '1', + posts: { create: { title: 'Post1' } }, + }, + include: { posts: true }, + }); + expect(create.posts[0].title).toBe('Post1'); + + const read = await db.user.findUnique({ + where: { + id: '1', + }, + include: { posts: true }, + }); + expect(read.posts[0].title).toBe('Post1'); + + const rawRead = await prisma.user.findUnique({ where: { id: '1' }, include: { posts: true } }); + expect(rawRead.posts[0].title).not.toBe('Post1'); + }); + + it('Multi-field encryption test', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + x1 String @encrypted() + x2 String @encrypted() + }`, + { + enhancements: ['encryption'], + enhanceOptions: { + encryption: { encryptionKey }, + }, + } + ); + + const db = enhance(); + + const create = await db.user.create({ + data: { + id: '1', + x1: 'abc123', + x2: '123abc', + }, + }); + + const read = await db.user.findUnique({ + where: { + id: '1', + }, + }); + + expect(create).toMatchObject({ x1: 'abc123', x2: '123abc' }); + expect(read).toMatchObject({ x1: 'abc123', x2: '123abc' }); + }); + + it('Custom encryption test', async () => { + const { enhance, prisma } = await loadSchema(` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + }`); + + const db = enhance(undefined, { + kinds: ['encryption'], + encryption: { + encrypt: async (model: string, field: FieldInfo, data: string) => { + // Add _enc to the end of the input + return data + '_enc'; + }, + decrypt: async (model: string, field: FieldInfo, cipher: string) => { + // Remove _enc from the end of the input explicitly + if (cipher.endsWith('_enc')) { + return cipher.slice(0, -4); // Remove last 4 characters (_enc) + } + + return cipher; + }, + }, + }); + + const create = await db.user.create({ + data: { + id: '1', + encrypted_value: 'abc123', + }, + }); + + const read = await db.user.findUnique({ + where: { + id: '1', + }, + }); + + const rawRead = await prisma.user.findUnique({ + where: { + id: '1', + }, + }); + + expect(create.encrypted_value).toBe('abc123'); + expect(read.encrypted_value).toBe('abc123'); + expect(rawRead.encrypted_value).toBe('abc123_enc'); + }); + + it('Works with multiple decryption keys', async () => { + const { enhanceRaw: enhance, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + secret String @encrypted() + }` + ); + + const key1 = crypto.getRandomValues(new Uint8Array(32)); + const key2 = crypto.getRandomValues(new Uint8Array(32)); + + const db1 = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: key1 }, + }); + const user1 = await db1.user.create({ data: { secret: 'user1' } }); + + const db2 = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: key2 }, + }); + const user2 = await db2.user.create({ data: { secret: 'user2' } }); + + const dbAll = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: crypto.getRandomValues(new Uint8Array(32)), decryptionKeys: [key1, key2] }, + }); + const allUsers = await dbAll.user.findMany(); + expect(allUsers).toEqual(expect.arrayContaining([user1, user2])); + + const dbWithEncryptionKeyExplicitlyProvided = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: key1, decryptionKeys: [key1, key2] }, + }); + await expect(dbWithEncryptionKeyExplicitlyProvided.user.findMany()).resolves.toEqual( + expect.arrayContaining([user1, user2]) + ); + + const dbWithDuplicatedKeys = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: key1, decryptionKeys: [key1, key1, key2, key2] }, + }); + await expect(dbWithDuplicatedKeys.user.findMany()).resolves.toEqual(expect.arrayContaining([user1, user2])); + + const dbWithInvalidKeys = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: key1, decryptionKeys: [key2, crypto.getRandomValues(new Uint8Array(32))] }, + }); + await expect(dbWithInvalidKeys.user.findMany()).resolves.toEqual(expect.arrayContaining([user1, user2])); + + const dbWithMissingKeys = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: key2 }, + }); + const found = await dbWithMissingKeys.user.findMany(); + expect(found).not.toContainEqual(user1); + expect(found).toContainEqual(user2); + + const dbWithAllWrongKeys = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: crypto.getRandomValues(new Uint8Array(32)) }, + }); + const found1 = await dbWithAllWrongKeys.user.findMany(); + expect(found1).not.toContainEqual(user1); + expect(found1).not.toContainEqual(user2); + }); + + it('Only supports string fields', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @id @default(cuid()) + encrypted_value Bytes @encrypted() + }` + ) + ).resolves.toContain(`attribute \"@encrypted\" cannot be used on this type of field`); + }); + + it('Returns cipher text when decryption fails', async () => { + const { enhance, enhanceRaw, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + + @@allow('all', true) + }`, + { enhancements: ['encryption'] } + ); + + const db = enhance(undefined, { + kinds: ['encryption'], + encryption: { encryptionKey }, + }); + + const create = await db.user.create({ + data: { + id: '1', + encrypted_value: 'abc123', + }, + }); + expect(create.encrypted_value).toBe('abc123'); + + const db1 = enhanceRaw(prisma, undefined, { + encryption: { encryptionKey: crypto.getRandomValues(new Uint8Array(32)) }, + }); + const read = await db1.user.findUnique({ where: { id: '1' } }); + expect(read.encrypted_value).toBeTruthy(); + expect(read.encrypted_value).not.toBe('abc123'); + }); + + it('Works with length validation', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() @length(0, 6) + @@allow('all', true) + }`, + { + enhanceOptions: { encryption: { encryptionKey } }, + } + ); + + const db = enhance(); + + const create = await db.user.create({ + data: { + id: '1', + encrypted_value: 'abc123', + }, + }); + expect(create.encrypted_value).toBe('abc123'); + + await expect( + db.user.create({ + data: { id: '2', encrypted_value: 'abc1234' }, + }) + ).toBeRejectedByPolicy(); + }); + + it('Complains when encrypted fields are used in model-level policy rules', async () => { + await expect( + loadModelWithError(` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + @@allow('all', encrypted_value != 'abc123') + } + `) + ).resolves.toContain(`Encrypted fields cannot be used in policy rules`); + }); + + it('Complains when encrypted fields are used in field-level policy rules', async () => { + await expect( + loadModelWithError(` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + value Int @allow('all', encrypted_value != 'abc123') + } + `) + ).resolves.toContain(`Encrypted fields cannot be used in policy rules`); + }); +}); diff --git a/tests/integration/tests/enhancements/with-password/with-password.test.ts b/tests/integration/tests/enhancements/with-password/with-password.test.ts index b2fd89a65..a54d0c42d 100644 --- a/tests/integration/tests/enhancements/with-password/with-password.test.ts +++ b/tests/integration/tests/enhancements/with-password/with-password.test.ts @@ -14,7 +14,7 @@ describe('Password test', () => { }); it('password tests', async () => { - const { enhance } = await loadSchema(` + const { enhance, prisma } = await loadSchema(` model User { id String @id @default(cuid()) password String @password(saltLength: 16) @@ -38,6 +38,27 @@ describe('Password test', () => { }, }); expect(compareSync('abc456', r1.password)).toBeTruthy(); + + await db.user.createMany({ + data: [ + { id: '2', password: 'user2' }, + { id: '3', password: 'user3' }, + ], + }); + await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.not.toMatchObject({ password: 'user2' }); + const r2 = await db.user.findUnique({ where: { id: '2' } }); + expect(compareSync('user2', r2.password)).toBeTruthy(); + + const [u4] = await db.user.createManyAndReturn({ + data: [ + { id: '4', password: 'user4' }, + { id: '5', password: 'user5' }, + ], + }); + expect(compareSync('user4', u4.password)).toBeTruthy(); + await expect(prisma.user.findUnique({ where: { id: '4' } })).resolves.not.toMatchObject({ password: 'user4' }); + const r4 = await db.user.findUnique({ where: { id: '4' } }); + expect(compareSync('user4', r4.password)).toBeTruthy(); }); it('length tests', async () => { diff --git a/tests/integration/tests/enhancements/with-policy/currentModel.test.ts b/tests/integration/tests/enhancements/with-policy/currentModel.test.ts new file mode 100644 index 000000000..0b98314a4 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/currentModel.test.ts @@ -0,0 +1,185 @@ +import { loadModelWithError, loadSchema } from '@zenstackhq/testtools'; + +describe('currentModel tests', () => { + it('works in models', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentModel() == 'User') + } + + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentModel() == 'User') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with upper case', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('upper') == 'USER') + } + + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('upper') == 'Post') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with lower case', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('lower') == 'user') + } + + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('lower') == 'Post') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with capitalization', async () => { + const { enhance } = await loadSchema( + ` + model user { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('capitalize') == 'User') + } + + model post { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('capitalize') == 'post') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with uncapitalization', async () => { + const { enhance } = await loadSchema( + ` + model USER { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('uncapitalize') == 'uSER') + } + + model POST { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('uncapitalize') == 'POST') + } + ` + ); + + const db = enhance(); + await expect(db.USER.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.POST.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works when inherited from abstract base', async () => { + const { enhance } = await loadSchema( + ` + abstract model Base { + id Int @id + @@allow('read', true) + @@allow('create', currentModel() == 'User') + } + + model User extends Base { + } + + model Post extends Base { + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works when inherited from delegate base', async () => { + const { enhance } = await loadSchema( + ` + model Base { + id Int @id + type String + @@delegate(type) + + @@allow('read', true) + @@allow('create', currentModel() == 'User') + } + + model User extends Base { + } + + model Post extends Base { + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('complains when used outside policies', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @default(currentModel()) + } + ` + ) + ).resolves.toContain('function "currentModel" is not allowed in the current context: DefaultValue'); + }); + + it('complains when casing argument is invalid', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @id + @@allow('create', currentModel('foo') == 'User') + } + ` + ) + ).resolves.toContain('argument must be one of: "original", "upper", "lower", "capitalize", "uncapitalize"'); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/currentOperation.test.ts b/tests/integration/tests/enhancements/with-policy/currentOperation.test.ts new file mode 100644 index 000000000..c56713316 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/currentOperation.test.ts @@ -0,0 +1,154 @@ +import { loadModelWithError, loadSchema } from '@zenstackhq/testtools'; + +describe('currentOperation tests', () => { + it('works with specific rules', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation() == 'create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation() == 'read') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with all rule', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('all', currentOperation() == 'create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation() == 'read') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with upper case', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('upper') == 'CREATE') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('upper') == 'READ') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with lower case', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('lower') == 'create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('lower') == 'read') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with capitalization', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('capitalize') == 'Create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('capitalize') == 'create') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with uncapitalization', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('uncapitalize') == 'create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('uncapitalize') == 'read') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('complains when used outside policies', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @default(currentOperation()) + } + ` + ) + ).resolves.toContain('function "currentOperation" is not allowed in the current context: DefaultValue'); + }); + + it('complains when casing argument is invalid', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @id + @@allow('create', currentOperation('foo') == 'User') + } + ` + ) + ).resolves.toContain('argument must be one of: "original", "upper", "lower", "capitalize", "uncapitalize"'); + }); +}); diff --git a/tests/regression/tests/issue-1467.test.ts b/tests/regression/tests/issue-1467.test.ts new file mode 100644 index 000000000..374313e45 --- /dev/null +++ b/tests/regression/tests/issue-1467.test.ts @@ -0,0 +1,51 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1467', () => { + it('regression', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + type String + @@allow('all', true) + } + + model Container { + id Int @id @default(autoincrement()) + drink Drink @relation(fields: [drinkId], references: [id]) + drinkId Int + @@allow('all', true) + } + + model Drink { + id Int @id @default(autoincrement()) + name String @unique + containers Container[] + type String + + @@delegate(type) + @@allow('all', true) + } + + model Beer extends Drink { + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + await db.beer.create({ + data: { id: 1, name: 'Beer1' }, + }); + + await db.container.create({ data: { drink: { connect: { id: 1 } } } }); + await db.container.create({ data: { drink: { connect: { id: 1 } } } }); + + const beers = await db.beer.findFirst({ + select: { id: true, name: true, _count: { select: { containers: true } } }, + orderBy: { name: 'asc' }, + }); + expect(beers).toMatchObject({ _count: { containers: 2 } }); + }); +}); diff --git a/tests/regression/tests/issue-1647.test.ts b/tests/regression/tests/issue-1647.test.ts new file mode 100644 index 000000000..e93f63cfb --- /dev/null +++ b/tests/regression/tests/issue-1647.test.ts @@ -0,0 +1,69 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import fs from 'fs'; + +describe('issue 1647', () => { + it('inherits @@schema by default', async () => { + const { projectDir } = await loadSchema( + ` + datasource db { + provider = 'postgresql' + url = env('DATABASE_URL') + schemas = ['public', 'post'] + } + + generator client { + provider = 'prisma-client-js' + previewFeatures = ['multiSchema'] + } + + model Asset { + id Int @id + type String + @@delegate(type) + @@schema('public') + } + + model Post extends Asset { + title String + } + `, + { addPrelude: false, pushDb: false, getPrismaOnly: true } + ); + + const prismaSchema = fs.readFileSync(`${projectDir}/prisma/schema.prisma`, 'utf-8'); + expect(prismaSchema.split('\n').filter((l) => l.includes('@@schema("public")'))).toHaveLength(2); + }); + it('respects sub model @@schema overrides', async () => { + const { projectDir } = await loadSchema( + ` + datasource db { + provider = 'postgresql' + url = env('DATABASE_URL') + schemas = ['public', 'post'] + } + + generator client { + provider = 'prisma-client-js' + previewFeatures = ['multiSchema'] + } + + model Asset { + id Int @id + type String + @@delegate(type) + @@schema('public') + } + + model Post extends Asset { + title String + @@schema('post') + } + `, + { addPrelude: false, pushDb: false, getPrismaOnly: true } + ); + + const prismaSchema = fs.readFileSync(`${projectDir}/prisma/schema.prisma`, 'utf-8'); + expect(prismaSchema.split('\n').filter((l) => l.includes('@@schema("public")'))).toHaveLength(1); + expect(prismaSchema.split('\n').filter((l) => l.includes('@@schema("post")'))).toHaveLength(1); + }); +}); diff --git a/tests/regression/tests/issue-1930.test.ts b/tests/regression/tests/issue-1930.test.ts new file mode 100644 index 000000000..762369321 --- /dev/null +++ b/tests/regression/tests/issue-1930.test.ts @@ -0,0 +1,80 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1930', () => { + it('regression', async () => { + const { enhance } = await loadSchema( + ` +model Organization { + id String @id @default(cuid()) + entities Entity[] + + @@allow('all', true) +} + +model Entity { + id String @id @default(cuid()) + org Organization? @relation(fields: [orgId], references: [id]) + orgId String? + contents EntityContent[] + entityType String + isDeleted Boolean @default(false) + + @@delegate(entityType) + + @@allow('all', !isDeleted) +} + +model EntityContent { + id String @id @default(cuid()) + entity Entity @relation(fields: [entityId], references: [id]) + entityId String + + entityContentType String + + @@delegate(entityContentType) + + @@allow('create', true) + @@allow('read', check(entity)) +} + +model Article extends Entity { +} + +model ArticleContent extends EntityContent { + body String? +} + +model OtherContent extends EntityContent { + data Int +} + ` + ); + + const fullDb = enhance(undefined, { kinds: ['delegate'] }); + const org = await fullDb.organization.create({ data: {} }); + const article = await fullDb.article.create({ + data: { org: { connect: { id: org.id } } }, + }); + + const db = enhance(); + + // normal create/read + await expect( + db.articleContent.create({ + data: { body: 'abc', entity: { connect: { id: article.id } } }, + }) + ).toResolveTruthy(); + await expect(db.article.findFirst({ include: { contents: true } })).resolves.toMatchObject({ + contents: expect.arrayContaining([expect.objectContaining({ body: 'abc' })]), + }); + + // deleted article's contents are not readable + const deletedArticle = await fullDb.article.create({ + data: { org: { connect: { id: org.id } }, isDeleted: true }, + }); + const content1 = await fullDb.articleContent.create({ + data: { body: 'bcd', entity: { connect: { id: deletedArticle.id } } }, + }); + await expect(db.articleContent.findUnique({ where: { id: content1.id } })).toResolveNull(); + }); +});