Skip to content

Commit 0860162

Browse files
committed
feat: Allow custom client location
This breaks the type & runtime dependency on a hardcoded `@prisma/client` client output location, allowing custom location (eg when having multiple clients). This adds a layer of runtime safety to validate the DMMF, which can now be passed in the configuration, allowing all sorts of runtime hacks. See #18 & #19.
1 parent 27da08f commit 0860162

15 files changed

+150
-84
lines changed

README.md

+15
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,21 @@ _Tip: the current encryption key is already part of the decryption keys, no need
247247
Key rotation on existing fields (decrypt with old key and re-encrypt with the
248248
new one) is done by [data migrations](#migrations).
249249

250+
## Custom Prisma Client Location
251+
252+
If you are generating your Prisma client to a custom location, you'll need to
253+
tell the middleware where to look for the DMMF _(the internal AST generated by Prisma that we use to read those triple-slash comments)_:
254+
255+
```ts
256+
import { Prisma } from '../my/prisma/client'
257+
258+
prismaClient.$use(
259+
fieldEncryptionMiddleware({
260+
dmmf: Prisma.dmmf
261+
})
262+
)
263+
```
264+
250265
**Roadmap:**
251266

252267
- [x] Provide multiple decryption keys

package.json

+4-4
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,14 @@
3333
"postbuild": "chmod +x ./dist/generator/main.js && cd node_modules/.bin && ln -sf ../../dist/generator/main.js ./prisma-field-encryption",
3434
"generate": "run-s generate:*",
3535
"generate:prisma": "prisma generate",
36-
"generate:dmmf": "ts-node ./src/scripts/generateDMMF.ts",
3736
"test": "run-s test:**",
3837
"test:types": "tsc --noEmit",
3938
"test:unit": "jest --config jest.config.unit.json",
4039
"pretest:integration": "cp -f ./prisma/db.test.sqlite ./prisma/db.integration.sqlite",
4140
"test:integration": "jest --config jest.config.integration.json --runInBand",
4241
"test:coverage:merge": "nyc merge ./coverage ./coverage/coverage-final.json",
4342
"test:coverage:report": "nyc report -t ./coverage --r html -r lcov -r clover",
44-
"ci": "run-s generate build test",
43+
"ci": "run-s build test",
4544
"prepare": "husky install",
4645
"premigrate": "run-s build generate",
4746
"migrate": "ts-node ./src/tests/migrate.ts"
@@ -50,14 +49,15 @@
5049
"@47ng/cloak": "^1.1.0-beta.2",
5150
"@prisma/generator-helper": "^3.13.0",
5251
"immer": "^9.0.12",
53-
"object-path": "^0.11.8"
52+
"object-path": "^0.11.8",
53+
"zod": "^3.15.1"
5454
},
5555
"peerDependencies": {
5656
"@prisma/client": "^3.8.0"
5757
},
5858
"devDependencies": {
5959
"@commitlint/config-conventional": "^16.2.4",
60-
"@prisma/client": "^3.13.0",
60+
"@prisma/client": "3.13.0",
6161
"@prisma/sdk": "^3.13.0",
6262
"@types/jest": "^27.4.1",
6363
"@types/node": "^17.0.29",

prisma/schema.prisma

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ datasource db {
66
generator client {
77
provider = "prisma-client-js"
88
previewFeatures = ["interactiveTransactions"]
9+
output = "../src/tests/.generated/client"
910
}
1011

1112
// generator fieldEncryptionMigrations {

src/dmmf.ts

+3-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
import { Prisma } from '@prisma/client'
21
import { errors, warnings } from './errors'
3-
import type { DMMF, FieldConfiguration } from './types'
2+
import { DMMFDocument, dmmfDocumentParser, FieldConfiguration } from './types'
43

54
export interface ConnectionDescriptor {
65
modelName: string
@@ -23,13 +22,8 @@ export type DMMFModels = Record<string, DMMFModelDescriptor> // key: model name
2322

2423
const supportedCursorTypes = ['Int', 'String']
2524

26-
export function analyseDMMF(dmmf: DMMF = Prisma.dmmf): DMMFModels {
27-
// todo: Make it robust against changes in the DMMF structure
28-
// (can happen as it's an undocumented API)
29-
// - Prisma.dmmf does not exist
30-
// - Models are not located there, or empty -> warning
31-
// - Model objects don't conform to what we need (parse with zod)
32-
25+
export function analyseDMMF(input: DMMFDocument): DMMFModels {
26+
const dmmf = dmmfDocumentParser.parse(input)
3327
const allModels = dmmf.datamodel.models
3428

3529
return allModels.reduce<DMMFModels>((output, model) => {

src/encryption.ts

+34-31
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ const writeOperations = [
5959

6060
const whereClauseRegExp = /\.where\./
6161

62-
export function encryptOnWrite(
63-
params: MiddlewareParams,
62+
export function encryptOnWrite<Models extends string, Actions extends string>(
63+
params: MiddlewareParams<Models, Actions>,
6464
keys: KeysConfiguration,
6565
models: DMMFModels,
6666
operation: string
@@ -71,42 +71,45 @@ export function encryptOnWrite(
7171

7272
const encryptionErrors: string[] = []
7373

74-
const mutatedParams = produce(params, (draft: Draft<MiddlewareParams>) => {
75-
visitInputTargetFields(
76-
draft,
77-
models,
78-
function encryptFieldValue({
79-
fieldConfig,
80-
value: clearText,
81-
path,
82-
model,
83-
field
84-
}) {
85-
if (!fieldConfig.encrypt) {
86-
return
87-
}
88-
if (whereClauseRegExp.test(path)) {
89-
console.warn(warnings.whereClause(operation, path))
90-
}
91-
try {
92-
const cipherText = encryptStringSync(clearText, keys.encryptionKey)
93-
objectPath.set(draft.args, path, cipherText)
94-
} catch (error) {
95-
encryptionErrors.push(
96-
errors.fieldEncryptionError(model, field, path, error)
97-
)
74+
const mutatedParams = produce(
75+
params,
76+
(draft: Draft<MiddlewareParams<Models, Actions>>) => {
77+
visitInputTargetFields(
78+
draft,
79+
models,
80+
function encryptFieldValue({
81+
fieldConfig,
82+
value: clearText,
83+
path,
84+
model,
85+
field
86+
}) {
87+
if (!fieldConfig.encrypt) {
88+
return
89+
}
90+
if (whereClauseRegExp.test(path)) {
91+
console.warn(warnings.whereClause(operation, path))
92+
}
93+
try {
94+
const cipherText = encryptStringSync(clearText, keys.encryptionKey)
95+
objectPath.set(draft.args, path, cipherText)
96+
} catch (error) {
97+
encryptionErrors.push(
98+
errors.fieldEncryptionError(model, field, path, error)
99+
)
100+
}
98101
}
99-
}
100-
)
101-
})
102+
)
103+
}
104+
)
102105
if (encryptionErrors.length > 0) {
103106
throw new Error(errors.encryptionErrorReport(operation, encryptionErrors))
104107
}
105108
return mutatedParams
106109
}
107110

108-
export function decryptOnRead(
109-
params: MiddlewareParams,
111+
export function decryptOnRead<Models extends string, Actions extends string>(
112+
params: MiddlewareParams<Models, Actions>,
110113
result: any,
111114
keys: KeysConfiguration,
112115
models: DMMFModels,

src/errors.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import type { Prisma } from '@prisma/client'
1+
import type { DMMFField, DMMFModel } from './types'
22

33
const header = '[prisma-field-encryption]'
44

@@ -8,7 +8,7 @@ const prefixWarning = (input: string) => `${header} Warning: ${input}`
88
export const errors = {
99
// Setup errors
1010
noEncryptionKey: prefixError('no encryption key provided.'),
11-
unsupportedFieldType: (model: Prisma.DMMF.Model, field: Prisma.DMMF.Field) =>
11+
unsupportedFieldType: (model: DMMFModel, field: DMMFField) =>
1212
prefixError(
1313
`encryption enabled for field ${model.name}.${field.name} of unsupported type ${field.type}: only String fields can be encrypted.`
1414
),

src/generator/runtime/visitRecords.ts

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
1-
import type { PrismaClient } from '@prisma/client'
21
import { defaultProgressReport, ProgressReportCallback } from './progressReport'
32

4-
export type RecordVisitor<Cursor> = (
3+
export type RecordVisitor<PrismaClient, Cursor> = (
54
client: PrismaClient,
65
cursor: Cursor | undefined
76
) => Promise<Cursor | undefined>
87

9-
export interface VisitRecordsArgs<Cursor> {
8+
export interface VisitRecordsArgs<PrismaClient, Cursor> {
109
modelName: string
1110
client: PrismaClient
1211
getTotalCount: () => Promise<number>
13-
migrateRecord: RecordVisitor<Cursor>
12+
migrateRecord: RecordVisitor<PrismaClient, Cursor>
1413
reportProgress?: ProgressReportCallback
1514
}
1615

17-
export async function visitRecords<Cursor>({
16+
export async function visitRecords<PrismaClient, Cursor>({
1817
modelName,
1918
client,
2019
getTotalCount,
2120
migrateRecord,
2221
reportProgress = defaultProgressReport
23-
}: VisitRecordsArgs<Cursor>) {
22+
}: VisitRecordsArgs<PrismaClient, Cursor>) {
2423
const totalCount = await getTotalCount()
2524
if (totalCount === 0) {
2625
return 0

src/index.ts

+9-6
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@ import { analyseDMMF } from './dmmf'
22
import { configureKeys, decryptOnRead, encryptOnWrite } from './encryption'
33
import type { Configuration, Middleware, MiddlewareParams } from './types'
44

5-
export function fieldEncryptionMiddleware(
6-
config: Configuration = {}
7-
): Middleware {
5+
export function fieldEncryptionMiddleware<
6+
Models extends string = any,
7+
Actions extends string = any
8+
>(config: Configuration = {}): Middleware<Models, Actions> {
89
// This will throw if the encryption key is missing
910
// or if anything is invalid.
1011
const keys = configureKeys(config)
11-
const models = analyseDMMF()
12+
const models = analyseDMMF(
13+
config.dmmf ?? require('@prisma/client').Prisma.dmmf
14+
)
1215

1316
return async function fieldEncryptionMiddleware(
14-
params: MiddlewareParams,
15-
next: (params: MiddlewareParams) => Promise<any>
17+
params: MiddlewareParams<Models, Actions>,
18+
next: (params: MiddlewareParams<Models, Actions>) => Promise<any>
1619
) {
1720
if (!params.model) {
1821
// Unsupported operation

src/scripts/generateDMMF.ts

-12
This file was deleted.

src/tests/.generated/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
client/

src/tests/prismaClient.ts

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import { PrismaClient } from '@prisma/client'
21
import { fieldEncryptionMiddleware } from '../index'
2+
import { Prisma, PrismaClient } from './.generated/client'
33

44
export const TEST_ENCRYPTION_KEY =
55
'k1.aesgcm256.OsqVmAOZBB_WW3073q1wU4ag0ap0ETYAYMh041RuxuI='
@@ -33,7 +33,8 @@ client.$use(async (params, next) => {
3333

3434
client.$use(
3535
fieldEncryptionMiddleware({
36-
encryptionKey: TEST_ENCRYPTION_KEY
36+
encryptionKey: TEST_ENCRYPTION_KEY,
37+
dmmf: Prisma.dmmf
3738
})
3839
)
3940

src/types.ts

+55-5
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,66 @@
1-
import { Prisma } from '@prisma/client'
1+
/**
2+
* Prisma types --
3+
*
4+
* We're copying just what we need for local type safety
5+
* without importing Prisma-generated types, as the location
6+
* of the generated client can be unknown (when using custom
7+
* or multiple client locations).
8+
*/
29

3-
// Prisma types --
10+
import { z } from 'zod'
411

5-
export type MiddlewareParams = Prisma.MiddlewareParams
6-
export type Middleware = Prisma.Middleware
7-
export type DMMF = typeof Prisma.dmmf
12+
/**
13+
* Not ideal to use `any` on model & action, but Prisma's
14+
* strong typing there actually prevents using the correct
15+
* type without excessive generics wizardry.
16+
*/
17+
export type MiddlewareParams<Models extends string, Actions extends string> = {
18+
model?: Models
19+
action: Actions
20+
args: any
21+
dataPath: string[]
22+
runInTransaction: boolean
23+
}
24+
25+
export type Middleware<
26+
Models extends string,
27+
Actions extends string,
28+
Result = any
29+
> = (
30+
params: MiddlewareParams<Models, Actions>,
31+
next: (params: MiddlewareParams<Models, Actions>) => Promise<Result>
32+
) => Promise<Result>
33+
34+
const dmmfFieldParser = z.object({
35+
name: z.string(),
36+
isList: z.boolean(),
37+
isUnique: z.boolean(),
38+
isId: z.boolean(),
39+
type: z.string(),
40+
documentation: z.string().optional()
41+
})
42+
43+
const dmmfModelParser = z.object({
44+
name: z.string(),
45+
fields: z.array(dmmfFieldParser)
46+
})
47+
48+
export const dmmfDocumentParser = z.object({
49+
datamodel: z.object({
50+
models: z.array(dmmfModelParser)
51+
})
52+
})
53+
54+
export type DMMFModel = z.TypeOf<typeof dmmfModelParser>
55+
export type DMMFField = z.TypeOf<typeof dmmfFieldParser>
56+
export type DMMFDocument = z.TypeOf<typeof dmmfDocumentParser>
857

958
// Internal types --
1059

1160
export interface Configuration {
1261
encryptionKey?: string
1362
decryptionKeys?: string[]
63+
dmmf?: DMMFDocument
1464
}
1565

1666
export interface FieldConfiguration {

src/visitor.test.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ describe('visitor', () => {
3636

3737
test('visitInputTargetFields - simple example', async () => {
3838
const models = analyseDMMF(await dmmf)
39-
const params: MiddlewareParams = {
39+
const params: MiddlewareParams<any, any> = {
4040
action: 'create',
4141
model: 'User',
4242
args: {
@@ -63,7 +63,7 @@ describe('visitor', () => {
6363

6464
test('visitInputTargetFields - nested create', async () => {
6565
const models = analyseDMMF(await dmmf)
66-
const params: MiddlewareParams = {
66+
const params: MiddlewareParams<any, any> = {
6767
action: 'create',
6868
model: 'User',
6969
args: {

src/visitor.ts

+10-4
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,11 @@ const makeVisitor = (models: DMMFModels, visitor: TargetFieldVisitorFn) =>
6060
return state
6161
}
6262

63-
export function visitInputTargetFields(
64-
params: MiddlewareParams,
63+
export function visitInputTargetFields<
64+
Models extends string,
65+
Actions extends string
66+
>(
67+
params: MiddlewareParams<Models, Actions>,
6568
models: DMMFModels,
6669
visitor: TargetFieldVisitorFn
6770
) {
@@ -70,8 +73,11 @@ export function visitInputTargetFields(
7073
})
7174
}
7275

73-
export function visitOutputTargetFields(
74-
params: MiddlewareParams,
76+
export function visitOutputTargetFields<
77+
Models extends string,
78+
Actions extends string
79+
>(
80+
params: MiddlewareParams<Models, Actions>,
7581
result: any,
7682
models: DMMFModels,
7783
visitor: TargetFieldVisitorFn

0 commit comments

Comments
 (0)