Skip to content

Commit 1369f9e

Browse files
authoredMar 21, 2025··
fix: allow using currentModel() and currentOperation calls in policy rules (#2050)
1 parent 57c6120 commit 1369f9e

File tree

3 files changed

+91
-11
lines changed

3 files changed

+91
-11
lines changed
 

Diff for: ‎packages/schema/src/language-server/validator/function-invocation-validator.ts

+33-10
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import {
1111
isDataModel,
1212
isDataModelAttribute,
1313
isDataModelFieldAttribute,
14+
isInvocationExpr,
1415
isLiteralExpr,
1516
} from '@zenstackhq/language/ast';
1617
import {
@@ -21,6 +22,7 @@ import {
2122
isDataModelFieldReference,
2223
isEnumFieldReference,
2324
isFromStdlib,
25+
isValidationAttribute,
2426
} from '@zenstackhq/sdk';
2527
import { AstNode, streamAst, ValidationAcceptor } from 'langium';
2628
import { match, P } from 'ts-pattern';
@@ -70,20 +72,21 @@ export default class FunctionInvocationValidator implements AstValidator<Express
7072
}
7173

7274
// validate the context allowed for the function
73-
const exprContext = match(containerAttribute?.decl.$refText)
74-
.with('@default', () => ExpressionContext.DefaultValue)
75-
.with(P.union('@@allow', '@@deny', '@allow', '@deny'), () => ExpressionContext.AccessPolicy)
76-
.with('@@validate', () => ExpressionContext.ValidationRule)
77-
.with('@@index', () => ExpressionContext.Index)
78-
.otherwise(() => undefined);
75+
const exprContext = this.getExpressionContext(containerAttribute);
7976

8077
// get the context allowed for the function
8178
const funcAllowedContext = getFunctionExpressionContext(funcDecl);
8279

83-
if (exprContext && !funcAllowedContext.includes(exprContext)) {
84-
accept('error', `function "${funcDecl.name}" is not allowed in the current context: ${exprContext}`, {
85-
node: expr,
86-
});
80+
if (funcAllowedContext.length > 0 && (!exprContext || !funcAllowedContext.includes(exprContext))) {
81+
accept(
82+
'error',
83+
`function "${funcDecl.name}" is not allowed in the current context${
84+
exprContext ? ': ' + exprContext : ''
85+
}`,
86+
{
87+
node: expr,
88+
}
89+
);
8790
return;
8891
}
8992

@@ -121,6 +124,8 @@ export default class FunctionInvocationValidator implements AstValidator<Express
121124
!isEnumFieldReference(secondArg) &&
122125
// `auth()...` expression
123126
!isAuthOrAuthMemberAccess(secondArg) &&
127+
// static function calls that are runtime constants: `currentModel`, `currentOperation`
128+
!this.isStaticFunctionCall(secondArg) &&
124129
// array of literal/enum
125130
!(
126131
isArrayExpr(secondArg) &&
@@ -148,6 +153,24 @@ export default class FunctionInvocationValidator implements AstValidator<Express
148153
}
149154
}
150155

156+
private getExpressionContext(containerAttribute: DataModelAttribute | DataModelFieldAttribute | undefined) {
157+
if (!containerAttribute) {
158+
return undefined;
159+
}
160+
if (isValidationAttribute(containerAttribute)) {
161+
return ExpressionContext.ValidationRule;
162+
}
163+
return match(containerAttribute?.decl.$refText)
164+
.with('@default', () => ExpressionContext.DefaultValue)
165+
.with(P.union('@@allow', '@@deny', '@allow', '@deny'), () => ExpressionContext.AccessPolicy)
166+
.with('@@index', () => ExpressionContext.Index)
167+
.otherwise(() => undefined);
168+
}
169+
170+
private isStaticFunctionCall(expr: Expression) {
171+
return isInvocationExpr(expr) && ['currentModel', 'currentOperation'].includes(expr.function.$refText);
172+
}
173+
151174
private validateArgs(funcDecl: FunctionDecl, args: Argument[], accept: ValidationAcceptor) {
152175
let success = true;
153176
for (let i = 0; i < funcDecl.params.length; i++) {

Diff for: ‎packages/sdk/src/validation.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import {
77
type TypeDef,
88
} from './ast';
99

10-
function isValidationAttribute(attr: DataModelAttribute | DataModelFieldAttribute) {
10+
export function isValidationAttribute(attr: DataModelAttribute | DataModelFieldAttribute) {
1111
return attr.decl.ref?.attributes.some((attr) => attr.decl.$refText === '@@@validation');
1212
}
1313

Diff for: ‎tests/regression/tests/issue-1984.test.ts

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import { loadModel, loadModelWithError, loadSchema } from '@zenstackhq/testtools';
2+
3+
describe('issue 1984', () => {
4+
it('regression1', async () => {
5+
const { enhance } = await loadSchema(
6+
`
7+
model User {
8+
id Int @id @default(autoincrement())
9+
access String
10+
11+
@@allow('all',
12+
contains(auth().access, currentModel()) ||
13+
contains(auth().access, currentOperation()))
14+
}
15+
`
16+
);
17+
18+
const db1 = enhance();
19+
await expect(db1.user.create({ data: { access: 'foo' } })).toBeRejectedByPolicy();
20+
21+
const db2 = enhance({ id: 1, access: 'aUser' });
22+
await expect(db2.user.create({ data: { access: 'aUser' } })).toResolveTruthy();
23+
24+
const db3 = enhance({ id: 1, access: 'do-create-read' });
25+
await expect(db3.user.create({ data: { access: 'do-create-read' } })).toResolveTruthy();
26+
27+
const db4 = enhance({ id: 1, access: 'do-read' });
28+
await expect(db4.user.create({ data: { access: 'do-read' } })).toBeRejectedByPolicy();
29+
});
30+
31+
it('regression2', async () => {
32+
await expect(
33+
loadModelWithError(
34+
`
35+
model User {
36+
id Int @id @default(autoincrement())
37+
modelName String
38+
@@validate(contains(modelName, currentModel()))
39+
}
40+
`
41+
)
42+
).resolves.toContain('function "currentModel" is not allowed in the current context: ValidationRule');
43+
});
44+
45+
it('regression3', async () => {
46+
await expect(
47+
loadModelWithError(
48+
`
49+
model User {
50+
id Int @id @default(autoincrement())
51+
modelName String @contains(currentModel())
52+
}
53+
`
54+
)
55+
).resolves.toContain('function "currentModel" is not allowed in the current context: ValidationRule');
56+
});
57+
});

0 commit comments

Comments
 (0)