diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 8fc0495b..944d4a59 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -68,7 +68,9 @@ jobs: run: pnpm install --frozen-lockfile - name: Build - run: pnpm run build + run: | + pnpm run build + pnpm tsx packages/cli/scripts/post-build.ts - name: Lint run: pnpm run lint diff --git a/.github/workflows/publish-release.yml b/.github/workflows/publish-release.yml index 19457988..471c2b52 100644 --- a/.github/workflows/publish-release.yml +++ b/.github/workflows/publish-release.yml @@ -36,7 +36,9 @@ jobs: run: pnpm install --frozen-lockfile - name: Build - run: pnpm run build + run: | + pnpm run build + pnpm tsx packages/cli/scripts/post-build.ts - name: Get version from package.json id: version diff --git a/CLAUDE.md b/CLAUDE.md index 432af85c..4c48ffd2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -82,3 +82,4 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co - Database migrations still use Prisma CLI under the hood - Plugin system allows interception at ORM, Kysely, and entity mutation levels - Computed fields are evaluated at database level for performance +- The "ide/vscode" package by-design has a different version from the rest of the packages as VSCode doesn't allow pre-release versions in its marketplace. diff --git a/README.md b/README.md index f134a133..7f6dfbe0 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ -> V3 is currently in alpha phase and not ready for production use. Feedback and bug reports are greatly appreciated. Please visit this dedicated [discord channel](https://discord.com/channels/1035538056146595961/1352359627525718056) for chat and support. +> V3 is currently in beta phase and not ready for production use. Feedback and bug reports are greatly appreciated. Please visit this dedicated [discord channel](https://discord.com/channels/1035538056146595961/1352359627525718056) for chat and support. # What's ZenStack diff --git a/TODO.md b/TODO.md index 8ccc6729..cd66cb46 100644 --- a/TODO.md +++ b/TODO.md @@ -83,8 +83,10 @@ - [x] Error system - [x] Custom table name - [x] Custom field name + - [ ] Global omit - [ ] DbNull vs JsonNull - [ ] Migrate to tsdown + - [ ] @default validation - [ ] Benchmark - [x] Plugin - [x] Post-mutation hooks should be called after transaction is committed diff --git a/package.json b/package.json index 090976bb..039fb22b 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-v3", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "description": "ZenStack", "packageManager": "pnpm@10.12.1", "scripts": { diff --git a/packages/cli/package.json b/packages/cli/package.json index 586b70b5..63c4249f 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack CLI", "description": "FullStack database toolkit with built-in access control and automatic API generation.", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "type": "module", "author": { "name": "ZenStack Team" diff --git a/packages/cli/scripts/post-build.ts b/packages/cli/scripts/post-build.ts new file mode 100644 index 00000000..99d4e3fb --- /dev/null +++ b/packages/cli/scripts/post-build.ts @@ -0,0 +1,20 @@ +import fs from 'node:fs'; +import path from 'node:path'; +import { fileURLToPath } from 'node:url'; + +const token = process.env.TELEMETRY_TRACKING_TOKEN ?? ''; + +if (!token) { + console.warn('TELEMETRY_TRACKING_TOKEN is not set.'); +} + +const filesToProcess = ['dist/index.js', 'dist/index.cjs']; +const _dirname = path.dirname(fileURLToPath(import.meta.url)); + +for (const file of filesToProcess) { + console.log(`Processing ${file} for telemetry token...`); + const filePath = path.join(_dirname, '..', file); + const content = fs.readFileSync(filePath, 'utf-8'); + const updatedContent = content.replace('', token); + fs.writeFileSync(filePath, updatedContent, 'utf-8'); +} diff --git a/packages/cli/tsup.config.ts b/packages/cli/tsup.config.ts index c1881d32..2496f3ea 100644 --- a/packages/cli/tsup.config.ts +++ b/packages/cli/tsup.config.ts @@ -1,5 +1,3 @@ -import fs from 'node:fs'; -import path from 'node:path'; import { defineConfig } from 'tsup'; export default defineConfig({ @@ -12,19 +10,4 @@ export default defineConfig({ clean: true, dts: true, format: ['esm', 'cjs'], - onSuccess: async () => { - if (!process.env['TELEMETRY_TRACKING_TOKEN']) { - return; - } - const filesToProcess = ['dist/index.js', 'dist/index.cjs']; - for (const file of filesToProcess) { - console.log(`Processing ${file} for telemetry token...`); - const content = fs.readFileSync(path.join(__dirname, file), 'utf-8'); - const updatedContent = content.replace( - '', - process.env['TELEMETRY_TRACKING_TOKEN'], - ); - fs.writeFileSync(file, updatedContent, 'utf-8'); - } - }, }); diff --git a/packages/common-helpers/package.json b/packages/common-helpers/package.json index 11b3cbc4..965c3dd9 100644 --- a/packages/common-helpers/package.json +++ b/packages/common-helpers/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/common-helpers", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "description": "ZenStack Common Helpers", "type": "module", "scripts": { diff --git a/packages/create-zenstack/package.json b/packages/create-zenstack/package.json index 8fa79f85..edc194bc 100644 --- a/packages/create-zenstack/package.json +++ b/packages/create-zenstack/package.json @@ -1,6 +1,6 @@ { "name": "create-zenstack", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "description": "Create a new ZenStack project", "type": "module", "scripts": { diff --git a/packages/dialects/sql.js/package.json b/packages/dialects/sql.js/package.json index 4130d721..be06b085 100644 --- a/packages/dialects/sql.js/package.json +++ b/packages/dialects/sql.js/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/kysely-sql-js", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "description": "Kysely dialect for sql.js", "type": "module", "scripts": { @@ -25,6 +25,10 @@ "types": "./dist/index.d.cts", "default": "./dist/index.cjs" } + }, + "./package.json": { + "import": "./package.json", + "require": "./package.json" } }, "devDependencies": { diff --git a/packages/eslint-config/package.json b/packages/eslint-config/package.json index 17f184bc..690197dc 100644 --- a/packages/eslint-config/package.json +++ b/packages/eslint-config/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/eslint-config", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "type": "module", "private": true, "license": "MIT" diff --git a/packages/language/package.json b/packages/language/package.json index 155a12ef..2e130d7f 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/language", "description": "ZenStack ZModel language specification", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "license": "MIT", "author": "ZenStack Team", "files": [ @@ -59,13 +59,14 @@ }, "devDependencies": { "@types/pluralize": "^0.0.33", + "@types/tmp": "catalog:", + "@zenstackhq/common-helpers": "workspace:*", "@zenstackhq/eslint-config": "workspace:*", "@zenstackhq/typescript-config": "workspace:*", - "@zenstackhq/common-helpers": "workspace:*", "@zenstackhq/vitest-config": "workspace:*", + "glob": "^11.0.2", "langium-cli": "catalog:", - "tmp": "catalog:", - "@types/tmp": "catalog:" + "tmp": "catalog:" }, "volta": { "node": "18.19.1", diff --git a/packages/language/src/validators/attribute-application-validator.ts b/packages/language/src/validators/attribute-application-validator.ts index c04666df..dc376036 100644 --- a/packages/language/src/validators/attribute-application-validator.ts +++ b/packages/language/src/validators/attribute-application-validator.ts @@ -1,19 +1,20 @@ import { AstUtils, type ValidationAcceptor } from 'langium'; import pluralize from 'pluralize'; +import type { BinaryExpr, DataModel, Expression } from '../ast'; import { ArrayExpr, Attribute, AttributeArg, AttributeParam, - DataModelAttribute, DataField, DataFieldAttribute, + DataModelAttribute, InternalAttribute, ReferenceExpr, isArrayExpr, isAttribute, - isDataModel, isDataField, + isDataModel, isEnum, isReferenceExpr, isTypeDef, @@ -21,7 +22,8 @@ import { import { getAllAttributes, getStringLiteral, - hasAttribute, + isAuthOrAuthMemberAccess, + isCollectionPredicate, isDataFieldReference, isDelegateModel, isFutureExpr, @@ -31,7 +33,6 @@ import { typeAssignable, } from '../utils'; import type { AstValidator } from './common'; -import type { DataModel } from '../ast'; // a registry of function handlers marked with @check const attributeCheckers = new Map(); @@ -153,6 +154,7 @@ export default class AttributeApplicationValidator implements AstValidator { + if (!isDataFieldReference(node)) { + // not a field reference, skip + return false; + } + + // referenced field is not a member of the context model, skip + if (node.target.ref?.$container !== contextModel) { + return false; + } + + const field = node.target.ref as DataField; + if (!isRelationshipField(field)) { + // not a relation, skip + return false; + } + + if (isAuthOrAuthMemberAccess(node)) { + // field reference is from auth() or access from auth(), not a relation query + return false; + } + + // check if the the node is a reference inside a collection predicate scope by auth access, + // e.g., `auth().foo?[x > 0]` + + // make sure to skip the current level if the node is already an LHS of a collection predicate, + // otherwise we're just circling back to itself when visiting the parent + const startNode = + isCollectionPredicate(node.$container) && (node.$container as BinaryExpr).left === node + ? node.$container + : node; + const collectionPredicate = AstUtils.getContainerOfType(startNode.$container, isCollectionPredicate); + if (collectionPredicate && isAuthOrAuthMemberAccess(collectionPredicate.left)) { + return false; + } + + const relationAttr = field.attributes.find((attr) => attr.decl.ref?.name === '@relation'); + if (!relationAttr) { + // no "@relation", not owner side of the relation, match + return true; + } + + if (!relationAttr.args.some((arg) => arg.name === 'fields')) { + // no "fields" argument, can't be owner side of the relation, match + return true; + } + + return false; + }) + ) { + accept('error', `non-owned relation fields are not allowed in "create" rules`, { node: expr }); + } + } + + // TODO: design a way to let plugin register validation @check('@allow') @check('@deny') // @ts-expect-error @@ -199,9 +266,6 @@ export default class AttributeApplicationValidator implements AstValidator { - if (isDataFieldReference(node) && hasAttribute(node.target.ref as DataField, '@encrypted')) { - accept('error', `Encrypted fields cannot be used in policy rules`, { node }); - } - }); - } - private validatePolicyKinds( kind: string, candidates: string[], diff --git a/packages/language/src/validators/expression-validator.ts b/packages/language/src/validators/expression-validator.ts index 28c15fc6..cf74db06 100644 --- a/packages/language/src/validators/expression-validator.ts +++ b/packages/language/src/validators/expression-validator.ts @@ -207,12 +207,12 @@ export default class ExpressionValidator implements AstValidator { isDataFieldReference(expr.left) && (isThisExpr(expr.right) || isDataFieldReference(expr.right)) ) { - accept('error', 'comparison between model-typed fields are not supported', { node: expr }); + accept('error', 'comparison between models is not supported', { node: expr }); } else if ( isDataFieldReference(expr.right) && (isThisExpr(expr.left) || isDataFieldReference(expr.left)) ) { - accept('error', 'comparison between model-typed fields are not supported', { node: expr }); + accept('error', 'comparison between models is not supported', { node: expr }); } } else if ( (isDataModel(leftType) && !isNullExpr(expr.right)) || diff --git a/packages/language/test/expression-validation.test.ts b/packages/language/test/expression-validation.test.ts new file mode 100644 index 00000000..5a3179f3 --- /dev/null +++ b/packages/language/test/expression-validation.test.ts @@ -0,0 +1,100 @@ +import { describe, it } from 'vitest'; +import { loadSchema, loadSchemaWithError } from './utils'; + +describe('Expression Validation Tests', () => { + it('should reject model comparison', async () => { + await loadSchemaWithError( + ` + model User { + id Int @id + name String + posts Post[] + } + + model Post { + id Int @id + title String + author User @relation(fields: [authorId], references: [id]) + @@allow('all', author == this) + } + `, + 'comparison between models is not supported', + ); + }); + + it('should reject model comparison', async () => { + await loadSchemaWithError( + ` + model User { + id Int @id + name String + profile Profile? + address Address? + @@allow('read', profile == this) + } + + model Profile { + id Int @id + bio String + user User @relation(fields: [userId], references: [id]) + userId Int @unique + } + + model Address { + id Int @id + street String + user User @relation(fields: [userId], references: [id]) + userId Int @unique + } + `, + 'comparison between models is not supported', + ); + }); + + it('should allow auth comparison with auth type', async () => { + await loadSchema( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + name String + profile Profile? + @@allow('read', auth() == this) + } + + model Profile { + id Int @id + bio String + user User @relation(fields: [userId], references: [id]) + userId Int @unique + @@allow('read', auth() == user) + } + `, + ); + }); + + it('should reject auth comparison with non-auth type', async () => { + await loadSchemaWithError( + ` + model User { + id Int @id + name String + profile Profile? + } + + model Profile { + id Int @id + bio String + user User @relation(fields: [userId], references: [id]) + userId Int @unique + @@allow('read', auth() == this) + } + `, + 'incompatible operand types', + ); + }); +}); diff --git a/packages/language/test/utils.ts b/packages/language/test/utils.ts index fe558f41..b14bdabb 100644 --- a/packages/language/test/utils.ts +++ b/packages/language/test/utils.ts @@ -1,16 +1,20 @@ +import { invariant } from '@zenstackhq/common-helpers'; +import { glob } from 'glob'; +import fs from 'node:fs'; import os from 'node:os'; import path from 'node:path'; -import fs from 'node:fs'; -import { loadDocument } from '../src'; import { expect } from 'vitest'; -import { invariant } from '@zenstackhq/common-helpers'; +import { loadDocument } from '../src'; export async function loadSchema(schema: string) { // create a temp file const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`); fs.writeFileSync(tempFile, schema); - const r = await loadDocument(tempFile); - expect(r.success).toBe(true); + const r = await loadDocument(tempFile, getPluginModels()); + expect(r).toSatisfy( + (r) => r.success, + `Failed to load schema: ${(r as any).errors?.map((e) => e.toString()).join(', ')}`, + ); invariant(r.success); return r.model; } @@ -19,12 +23,21 @@ export async function loadSchemaWithError(schema: string, error: string | RegExp // create a temp file const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`); fs.writeFileSync(tempFile, schema); - const r = await loadDocument(tempFile); + const r = await loadDocument(tempFile, getPluginModels()); expect(r.success).toBe(false); invariant(!r.success); if (typeof error === 'string') { - expect(r.errors.some((e) => e.toString().toLowerCase().includes(error.toLowerCase()))).toBe(true); + expect(r).toSatisfy( + (r) => r.errors.some((e) => e.toString().toLowerCase().includes(error.toLowerCase())), + `Expected error message to include "${error}" but got: ${r.errors.map((e) => e.toString()).join(', ')}`, + ); } else { - expect(r.errors.some((e) => error.test(e))).toBe(true); + expect(r).toSatisfy( + (r) => r.errors.some((e) => error.test(e)), + `Expected error message to match "${error}" but got: ${r.errors.map((e) => e.toString()).join(', ')}`, + ); } } +function getPluginModels() { + return glob.sync(path.resolve(__dirname, '../../runtime/src/plugins/**/plugin.zmodel')); +} diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 2b06fb44..da696705 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/runtime", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "description": "ZenStack Runtime", "type": "module", "scripts": { diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base-dialect.ts similarity index 99% rename from packages/runtime/src/client/crud/dialects/base.ts rename to packages/runtime/src/client/crud/dialects/base-dialect.ts index 8afec156..9f314bf9 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base-dialect.ts @@ -1104,7 +1104,7 @@ export abstract class BaseCrudDialect { return (node as ValueNode).value === false || (node as ValueNode).value === 0; } - protected and(eb: ExpressionBuilder, ...args: Expression[]) { + and(eb: ExpressionBuilder, ...args: Expression[]) { const nonTrueArgs = args.filter((arg) => !this.isTrue(arg)); if (nonTrueArgs.length === 0) { return this.true(eb); @@ -1115,7 +1115,7 @@ export abstract class BaseCrudDialect { } } - protected or(eb: ExpressionBuilder, ...args: Expression[]) { + or(eb: ExpressionBuilder, ...args: Expression[]) { const nonFalseArgs = args.filter((arg) => !this.isFalse(arg)); if (nonFalseArgs.length === 0) { return this.false(eb); @@ -1126,7 +1126,7 @@ export abstract class BaseCrudDialect { } } - protected not(eb: ExpressionBuilder, ...args: Expression[]) { + not(eb: ExpressionBuilder, ...args: Expression[]) { return eb.not(this.and(eb, ...args)); } diff --git a/packages/runtime/src/client/crud/dialects/index.ts b/packages/runtime/src/client/crud/dialects/index.ts index 9d67009e..ede19cdd 100644 --- a/packages/runtime/src/client/crud/dialects/index.ts +++ b/packages/runtime/src/client/crud/dialects/index.ts @@ -1,7 +1,7 @@ import { match } from 'ts-pattern'; import type { SchemaDef } from '../../../schema'; import type { ClientOptions } from '../../options'; -import type { BaseCrudDialect } from './base'; +import type { BaseCrudDialect } from './base-dialect'; import { PostgresCrudDialect } from './postgresql'; import { SqliteCrudDialect } from './sqlite'; diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index 93722037..a71e987d 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -20,7 +20,7 @@ import { requireField, requireModel, } from '../../query-utils'; -import { BaseCrudDialect } from './base'; +import { BaseCrudDialect } from './base-dialect'; export class PostgresCrudDialect extends BaseCrudDialect { override get provider() { diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 34ece56e..69de608d 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -20,7 +20,7 @@ import { requireField, requireModel, } from '../../query-utils'; -import { BaseCrudDialect } from './base'; +import { BaseCrudDialect } from './base-dialect'; export class SqliteCrudDialect extends BaseCrudDialect { override get provider() { diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index c3bab79d..65d0d32b 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -7,7 +7,6 @@ import { UpdateResult, type Compilable, type IsolationLevel, - type QueryResult, type SelectQueryBuilder, } from 'kysely'; import { nanoid } from 'nanoid'; @@ -44,7 +43,7 @@ import { requireModel, } from '../../query-utils'; import { getCrudDialect } from '../dialects'; -import type { BaseCrudDialect } from '../dialects/base'; +import type { BaseCrudDialect } from '../dialects/base-dialect'; import { InputValidator } from '../validator'; export type CoreCrudOperation = @@ -66,10 +65,16 @@ export type CoreCrudOperation = export type AllCrudOperation = CoreCrudOperation | 'findUniqueOrThrow' | 'findFirstOrThrow'; +// context for nested relation operations export type FromRelationContext = { + // the model where the relation field is defined model: GetModels; + // the relation field name field: string; + // the parent entity's id fields and values ids: any; + // for relations owned by model, record the parent updates needed after the relation is processed + parentUpdates: Record; }; export abstract class BaseOperationHandler { @@ -258,7 +263,7 @@ export abstract class BaseOperationHandler { } let createFields: any = {}; - let parentUpdateTask: ((entity: any) => Promise) | undefined = undefined; + let updateParent: ((entity: any) => void) | undefined = undefined; let m2m: ReturnType = undefined; @@ -281,26 +286,11 @@ export abstract class BaseOperationHandler { ); Object.assign(createFields, parentFkFields); } else { - parentUpdateTask = (entity) => { - const query = kysely - .updateTable(fromRelation.model) - .set( - keyPairs.reduce( - (acc, { fk, pk }) => ({ - ...acc, - [fk]: entity[pk], - }), - {} as any, - ), - ) - .where((eb) => eb.and(fromRelation.ids)) - .modifyEnd( - this.makeContextComment({ - model: fromRelation.model, - operation: 'update', - }), - ); - return this.executeQuery(kysely, query, 'update'); + // record parent fk update after entity is created + updateParent = (entity) => { + for (const { fk, pk } of keyPairs) { + fromRelation.parentUpdates[fk] = entity[pk]; + } }; } } @@ -353,7 +343,7 @@ export abstract class BaseOperationHandler { createFields = baseCreateResult.remainingFields; } - const updatedData = this.fillGeneratedValues(modelDef, createFields); + const updatedData = this.fillGeneratedAndDefaultValues(modelDef, createFields); const idFields = getIdFields(this.schema, model); const query = kysely .insertInto(model) @@ -403,8 +393,8 @@ export abstract class BaseOperationHandler { } // finally update parent if needed - if (parentUpdateTask) { - await parentUpdateTask(createdEntity); + if (updateParent) { + updateParent(createdEntity); } return createdEntity; @@ -567,7 +557,10 @@ export abstract class BaseOperationHandler { select: fieldsToSelectObject(referencedPkFields) as any, }); if (!relationEntity) { - throw new NotFoundError(`Could not find the entity for connect action`); + throw new NotFoundError( + relationModel, + `Could not find the entity to connect for the relation "${relationField.name}"`, + ); } result = relationEntity; } @@ -605,10 +598,11 @@ export abstract class BaseOperationHandler { const relationFieldDef = this.requireField(contextModel, relationFieldName); const relationModel = relationFieldDef.type as GetModels; const tasks: Promise[] = []; - const fromRelationContext = { + const fromRelationContext: FromRelationContext = { model: contextModel, field: relationFieldName, ids: parentEntity, + parentUpdates: {}, }; for (const [action, subPayload] of Object.entries(payload)) { @@ -641,13 +635,7 @@ export abstract class BaseOperationHandler { } case 'connect': { - tasks.push( - this.connectRelation(kysely, relationModel, subPayload, { - model: contextModel, - field: relationFieldName, - ids: parentEntity, - }), - ); + tasks.push(this.connectRelation(kysely, relationModel, subPayload, fromRelationContext)); break; } @@ -656,16 +644,8 @@ export abstract class BaseOperationHandler { ...enumerate(subPayload).map((item) => this.exists(kysely, relationModel, item.where).then((found) => !found - ? this.create(kysely, relationModel, item.create, { - model: contextModel, - field: relationFieldName, - ids: parentEntity, - }) - : this.connectRelation(kysely, relationModel, found, { - model: contextModel, - field: relationFieldName, - ids: parentEntity, - }), + ? this.create(kysely, relationModel, item.create, fromRelationContext) + : this.connectRelation(kysely, relationModel, found, fromRelationContext), ), ), ); @@ -722,7 +702,7 @@ export abstract class BaseOperationHandler { newItem[fk] = fromRelation.ids[pk]; } } - return this.fillGeneratedValues(modelDef, newItem); + return this.fillGeneratedAndDefaultValues(modelDef, newItem); }); if (!this.dialect.supportInsertWithDefault) { @@ -841,7 +821,7 @@ export abstract class BaseOperationHandler { return { baseEntities, remainingFieldRows }; } - private fillGeneratedValues(modelDef: ModelDef, data: object) { + private fillGeneratedAndDefaultValues(modelDef: ModelDef, data: object) { const fields = modelDef.fields; const values: any = clone(data); for (const [field, fieldDef] of Object.entries(fields)) { @@ -858,6 +838,21 @@ export abstract class BaseOperationHandler { } else if (fields[field]?.updatedAt) { // TODO: should this work at kysely level instead? values[field] = this.dialect.transformPrimitive(new Date(), 'DateTime', false); + } else if (fields[field]?.default !== undefined) { + let value = fields[field].default; + if (fieldDef.type === 'Json') { + // Schema uses JSON string for default value of Json fields + if (fieldDef.array && Array.isArray(value)) { + value = value.map((v) => (typeof v === 'string' ? JSON.parse(v) : v)); + } else if (typeof value === 'string') { + value = JSON.parse(value); + } + } + values[field] = this.dialect.transformPrimitive( + value, + fields[field].type as BuiltinType, + !!fields[field].array, + ); } } } @@ -1026,7 +1021,7 @@ export abstract class BaseOperationHandler { } } } - await this.processRelationUpdates( + const parentUpdates = await this.processRelationUpdates( kysely, model, field, @@ -1035,6 +1030,11 @@ export abstract class BaseOperationHandler { finalData[field], throwIfNotFound, ); + + if (Object.keys(parentUpdates).length > 0) { + // merge field updates propagated from nested relation processing + Object.assign(updateFields, parentUpdates); + } } } @@ -1354,10 +1354,11 @@ export abstract class BaseOperationHandler { ) { const tasks: Promise[] = []; const fieldModel = fieldDef.type as GetModels; - const fromRelationContext = { + const fromRelationContext: FromRelationContext = { model, field, ids: parentIds, + parentUpdates: {}, }; for (const [key, value] of Object.entries(args)) { @@ -1488,6 +1489,8 @@ export abstract class BaseOperationHandler { } await Promise.all(tasks); + + return fromRelationContext.parentUpdates; } // #region relation manipulation @@ -1532,10 +1535,9 @@ export abstract class BaseOperationHandler { fromRelation.model, fromRelation.field, ); - let updateResult: QueryResult; if (ownedByModel) { - // set parent fk directly + // record parent fk update invariant(_data.length === 1, 'only one entity can be connected'); const target = await this.readUnique(kysely, model, { where: _data[0], @@ -1543,25 +1545,10 @@ export abstract class BaseOperationHandler { if (!target) { throw new NotFoundError(model); } - const query = kysely - .updateTable(fromRelation.model) - .where((eb) => eb.and(fromRelation.ids)) - .set( - keyPairs.reduce( - (acc, { fk, pk }) => ({ - ...acc, - [fk]: target[pk], - }), - {} as any, - ), - ) - .modifyEnd( - this.makeContextComment({ - model: fromRelation.model, - operation: 'update', - }), - ); - updateResult = await this.executeQuery(kysely, query, 'connect'); + + for (const { fk, pk } of keyPairs) { + fromRelation.parentUpdates[fk] = target[pk]; + } } else { // disconnect current if it's a one-one relation const relationFieldDef = this.requireField(fromRelation.model, fromRelation.field); @@ -1599,13 +1586,13 @@ export abstract class BaseOperationHandler { operation: 'update', }), ); - updateResult = await this.executeQuery(kysely, query, 'connect'); - } + const updateResult = await this.executeQuery(kysely, query, 'connect'); - // validate connect result - if (_data.length > updateResult.numAffectedRows!) { - // some entities were not connected - throw new NotFoundError(model); + // validate connect result + if (!updateResult.numAffectedRows || _data.length > updateResult.numAffectedRows) { + // some entities were not connected + throw new NotFoundError(model); + } } } } @@ -1689,35 +1676,44 @@ export abstract class BaseOperationHandler { const eb = expressionBuilder(); if (ownedByModel) { - // set parent fk directly + // record parent fk update invariant(disconnectConditions.length === 1, 'only one entity can be disconnected'); const condition = disconnectConditions[0]; - const query = kysely - .updateTable(fromRelation.model) - // id filter - .where(eb.and(fromRelation.ids)) - // merge extra disconnect conditions - .$if(condition !== true, (qb) => - qb.where( - eb( - // @ts-ignore - eb.refTuple(...keyPairs.map(({ fk }) => fk)), - 'in', - eb - .selectFrom(model) - .select(keyPairs.map(({ pk }) => pk)) - .where(this.dialect.buildFilter(eb, model, model, condition)), - ), - ), - ) - .set(keyPairs.reduce((acc, { fk }) => ({ ...acc, [fk]: null }), {} as any)) - .modifyEnd( - this.makeContextComment({ - model: fromRelation.model, - operation: 'update', - }), - ); - await this.executeQuery(kysely, query, 'disconnect'); + + if (condition === true) { + // just disconnect, record parent fk update + for (const { fk } of keyPairs) { + fromRelation.parentUpdates[fk] = null; + } + } else { + // disconnect with a filter + + // read parent's fk + const fromEntity = await this.readUnique(kysely, fromRelation.model, { + where: fromRelation.ids, + select: fieldsToSelectObject(keyPairs.map(({ fk }) => fk)), + }); + if (!fromEntity || keyPairs.some(({ fk }) => fromEntity[fk] == null)) { + return; + } + + // check if the disconnect target exists under parent fk and the filter condition + const relationFilter = { + AND: [condition, Object.fromEntries(keyPairs.map(({ fk, pk }) => [pk, fromEntity[fk]]))], + }; + + // if the target exists, record parent fk update, otherwise do nothing + const targetExists = await this.read(kysely, model, { + where: relationFilter, + take: 1, + select: this.makeIdSelect(model), + } as any); + if (targetExists.length > 0) { + for (const { fk } of keyPairs) { + fromRelation.parentUpdates[fk] = null; + } + } + } } else { // disconnect const query = kysely @@ -1841,7 +1837,7 @@ export abstract class BaseOperationHandler { const r = await this.executeQuery(kysely, query, 'connect'); // validate result - if (_data.length > r.numAffectedRows!) { + if (!r.numAffectedRows || _data.length > r.numAffectedRows) { // some entities were not connected throw new NotFoundError(model); } @@ -1874,9 +1870,12 @@ export abstract class BaseOperationHandler { } let deleteResult: { count: number }; + let deleteFromModel: GetModels; const m2m = getManyToManyRelation(this.schema, fromRelation.model, fromRelation.field); if (m2m) { + deleteFromModel = model; + // handle many-to-many relation const fieldDef = this.requireField(fromRelation.model, fromRelation.field); invariant(fieldDef.relation?.opposite); @@ -1901,11 +1900,13 @@ export abstract class BaseOperationHandler { ); if (ownedByModel) { + deleteFromModel = fromRelation.model; + const fromEntity = await this.readUnique(kysely, fromRelation.model as GetModels, { where: fromRelation.ids, }); if (!fromEntity) { - throw new NotFoundError(model); + throw new NotFoundError(fromRelation.model); } const fieldDef = this.requireField(fromRelation.model, fromRelation.field); @@ -1920,6 +1921,7 @@ export abstract class BaseOperationHandler { ], }); } else { + deleteFromModel = model; deleteResult = await this.delete(kysely, model, { AND: [ Object.fromEntries(keyPairs.map(({ fk, pk }) => [fk, fromRelation.ids[pk]])), @@ -1934,7 +1936,7 @@ export abstract class BaseOperationHandler { // validate result if (throwForNotFound && expectedDeleteCount > deleteResult.count) { // some entities were not deleted - throw new NotFoundError(model); + throw new NotFoundError(deleteFromModel); } } diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index b4097dea..372129ff 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -22,12 +22,11 @@ import { type UpdateManyArgs, type UpsertArgs, } from '../crud-types'; -import { InputValidationError, InternalError, QueryError } from '../errors'; +import { InputValidationError, InternalError } from '../errors'; import { fieldHasDefaultValue, getDiscriminatorField, getEnum, - getModel, getUniqueFields, requireField, requireModel, @@ -279,10 +278,7 @@ export class InputValidator { withoutRelationFields = false, withAggregations = false, ): ZodType { - const modelDef = getModel(this.schema, model); - if (!modelDef) { - throw new QueryError(`Model "${model}" not found in schema`); - } + const modelDef = requireModel(this.schema, model); const fields: Record = {}; for (const field of Object.keys(modelDef.fields)) { diff --git a/packages/runtime/src/client/errors.ts b/packages/runtime/src/client/errors.ts index 38c5077b..1d6134e9 100644 --- a/packages/runtime/src/client/errors.ts +++ b/packages/runtime/src/client/errors.ts @@ -25,7 +25,7 @@ export class InternalError extends Error {} * Error thrown when an entity is not found. */ export class NotFoundError extends Error { - constructor(model: string) { - super(`Entity not found for model "${model}"`); + constructor(model: string, details?: string) { + super(`Entity not found for model "${model}"${details ? `: ${details}` : ''}`); } } diff --git a/packages/runtime/src/client/executor/kysely-utils.ts b/packages/runtime/src/client/executor/kysely-utils.ts index 5ae92d39..fb9ec845 100644 --- a/packages/runtime/src/client/executor/kysely-utils.ts +++ b/packages/runtime/src/client/executor/kysely-utils.ts @@ -1,13 +1,11 @@ -import { invariant } from '@zenstackhq/common-helpers'; -import { type OperationNode, AliasNode, IdentifierNode } from 'kysely'; +import { type OperationNode, AliasNode } from 'kysely'; /** * Strips alias from the node if it exists. */ export function stripAlias(node: OperationNode) { if (AliasNode.is(node)) { - invariant(IdentifierNode.is(node.alias), 'Expected identifier as alias'); - return { alias: node.alias.name, node: node.node }; + return { alias: node.alias, node: node.node }; } else { return { alias: undefined, node }; } diff --git a/packages/runtime/src/client/executor/name-mapper.ts b/packages/runtime/src/client/executor/name-mapper.ts index cc8163c1..c839bc75 100644 --- a/packages/runtime/src/client/executor/name-mapper.ts +++ b/packages/runtime/src/client/executor/name-mapper.ts @@ -22,7 +22,7 @@ import { stripAlias } from './kysely-utils'; type Scope = { model?: string; - alias?: string; + alias?: OperationNode; namesMapped?: boolean; // true means fields referring to this scope have their names already mapped }; @@ -120,7 +120,7 @@ export class QueryNameMapper extends OperationNodeTransformer { // map table name depending on how it is resolved let mappedTableName = node.table?.table.identifier.name; if (mappedTableName) { - if (scope.alias === mappedTableName) { + if (scope.alias && IdentifierNode.is(scope.alias) && scope.alias.name === mappedTableName) { // table name is resolved to an alias, no mapping needed } else if (scope.model === mappedTableName) { // table name is resolved to a model, map the name as needed @@ -222,7 +222,14 @@ export class QueryNameMapper extends OperationNodeTransformer { const origFieldName = this.extractFieldName(selection.selection); const fieldName = this.extractFieldName(transformed); if (fieldName !== origFieldName) { - selections.push(SelectionNode.create(this.wrapAlias(transformed, origFieldName))); + selections.push( + SelectionNode.create( + this.wrapAlias( + transformed, + origFieldName ? IdentifierNode.create(origFieldName) : undefined, + ), + ), + ); } else { selections.push(SelectionNode.create(transformed)); } @@ -241,7 +248,7 @@ export class QueryNameMapper extends OperationNodeTransformer { // if the field as a qualifier, the qualifier must match the scope's // alias if any, or model if no alias if (scope.alias) { - if (scope.alias === qualifier) { + if (scope.alias && IdentifierNode.is(scope.alias) && scope.alias.name === qualifier) { // scope has an alias that matches the qualifier return scope; } else { @@ -295,8 +302,8 @@ export class QueryNameMapper extends OperationNodeTransformer { } } - private wrapAlias(node: T, alias: string | undefined) { - return alias ? AliasNode.create(node, IdentifierNode.create(alias)) : node; + private wrapAlias(node: T, alias: OperationNode | undefined) { + return alias ? AliasNode.create(node, alias) : node; } private processTableRef(node: TableNode) { @@ -351,11 +358,11 @@ export class QueryNameMapper extends OperationNodeTransformer { // inner transformations will map column names const modelName = innerNode.table.identifier.name; const mappedName = this.mapTableName(modelName); - const finalAlias = alias ?? (mappedName !== modelName ? modelName : undefined); + const finalAlias = alias ?? (mappedName !== modelName ? IdentifierNode.create(modelName) : undefined); return { node: this.wrapAlias(TableNode.create(mappedName), finalAlias), scope: { - alias: alias ?? modelName, + alias: alias ?? IdentifierNode.create(modelName), model: modelName, namesMapped: !this.hasMappedColumns(modelName), }, @@ -374,13 +381,13 @@ export class QueryNameMapper extends OperationNodeTransformer { } } - private createSelectAllFields(model: string, alias: string | undefined) { + private createSelectAllFields(model: string, alias: OperationNode | undefined) { const modelDef = requireModel(this.schema, model); return this.getModelFields(modelDef).map((fieldDef) => { const columnName = this.mapFieldName(model, fieldDef.name); const columnRef = ReferenceNode.create( ColumnNode.create(columnName), - alias ? TableNode.create(alias) : undefined, + alias && IdentifierNode.is(alias) ? TableNode.create(alias.name) : undefined, ); if (columnName !== fieldDef.name) { const aliased = AliasNode.create(columnRef, IdentifierNode.create(fieldDef.name)); @@ -421,7 +428,7 @@ export class QueryNameMapper extends OperationNodeTransformer { alias = this.extractFieldName(node); } const result = super.transformNode(node); - return this.wrapAlias(result, alias); + return this.wrapAlias(result, alias ? IdentifierNode.create(alias) : undefined); } private processSelectAll(node: SelectAllNode) { @@ -438,7 +445,9 @@ export class QueryNameMapper extends OperationNodeTransformer { return this.getModelFields(modelDef).map((fieldDef) => { const columnName = this.mapFieldName(modelDef.name, fieldDef.name); const columnRef = ReferenceNode.create(ColumnNode.create(columnName)); - return columnName !== fieldDef.name ? this.wrapAlias(columnRef, fieldDef.name) : columnRef; + return columnName !== fieldDef.name + ? this.wrapAlias(columnRef, IdentifierNode.create(fieldDef.name)) + : columnRef; }); } diff --git a/packages/runtime/src/client/executor/zenstack-query-executor.ts b/packages/runtime/src/client/executor/zenstack-query-executor.ts index 768f65ae..be317924 100644 --- a/packages/runtime/src/client/executor/zenstack-query-executor.ts +++ b/packages/runtime/src/client/executor/zenstack-query-executor.ts @@ -100,7 +100,6 @@ export class ZenStackQueryExecutor extends DefaultQuer const hookResult = await hook!({ client: this.client as ClientContract, schema: this.client.$schema, - kysely: this.kysely, query, proceed: _p, }); diff --git a/packages/runtime/src/client/options.ts b/packages/runtime/src/client/options.ts index 3146a402..7c90e330 100644 --- a/packages/runtime/src/client/options.ts +++ b/packages/runtime/src/client/options.ts @@ -2,7 +2,7 @@ import type { Dialect, Expression, ExpressionBuilder, KyselyConfig } from 'kysel import type { GetModel, GetModels, ProcedureDef, SchemaDef } from '../schema'; import type { PrependParameter } from '../utils/type-utils'; import type { ClientContract, CRUD, ProcedureFunc } from './contract'; -import type { BaseCrudDialect } from './crud/dialects/base'; +import type { BaseCrudDialect } from './crud/dialects/base-dialect'; import type { RuntimePlugin } from './plugin'; import type { ToKyselySchema } from './query-builder'; diff --git a/packages/runtime/src/client/plugin.ts b/packages/runtime/src/client/plugin.ts index 0a4c4a7f..62216a3d 100644 --- a/packages/runtime/src/client/plugin.ts +++ b/packages/runtime/src/client/plugin.ts @@ -1,5 +1,5 @@ import type { OperationNode, QueryResult, RootOperationNode, UnknownRow } from 'kysely'; -import type { ClientContract, ToKysely } from '.'; +import type { ClientContract } from '.'; import type { GetModels, SchemaDef } from '../schema'; import type { MaybePromise } from '../utils/type-utils'; import type { AllCrudOperation } from './crud/operations/base'; @@ -180,7 +180,6 @@ export type PluginAfterEntityMutationArgs = MutationHo // #region OnKyselyQuery hooks export type OnKyselyQueryArgs = { - kysely: ToKysely; schema: SchemaDef; client: ClientContract; query: RootOperationNode; diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index 6f961029..fdce2aaf 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -14,15 +14,19 @@ export function hasModel(schema: SchemaDef, model: string) { } export function getModel(schema: SchemaDef, model: string) { - return schema.models[model]; + return Object.values(schema.models).find((m) => m.name.toLowerCase() === model.toLowerCase()); +} + +export function getTypeDef(schema: SchemaDef, type: string) { + return schema.typeDefs?.[type]; } export function requireModel(schema: SchemaDef, model: string) { - const matchedName = Object.keys(schema.models).find((k) => k.toLowerCase() === model.toLowerCase()); - if (!matchedName) { + const modelDef = getModel(schema, model); + if (!modelDef) { throw new QueryError(`Model "${model}" not found in schema`); } - return schema.models[matchedName]!; + return modelDef; } export function getField(schema: SchemaDef, model: string, field: string) { @@ -30,12 +34,24 @@ export function getField(schema: SchemaDef, model: string, field: string) { return modelDef?.fields[field]; } -export function requireField(schema: SchemaDef, model: string, field: string) { - const modelDef = requireModel(schema, model); - if (!modelDef.fields[field]) { - throw new QueryError(`Field "${field}" not found in model "${model}"`); +export function requireField(schema: SchemaDef, modelOrType: string, field: string) { + const modelDef = getModel(schema, modelOrType); + if (modelDef) { + if (!modelDef.fields[field]) { + throw new QueryError(`Field "${field}" not found in model "${modelOrType}"`); + } else { + return modelDef.fields[field]; + } + } + const typeDef = getTypeDef(schema, modelOrType); + if (typeDef) { + if (!typeDef.fields[field]) { + throw new QueryError(`Field "${field}" not found in type "${modelOrType}"`); + } else { + return typeDef.fields[field]; + } } - return modelDef.fields[field]; + throw new QueryError(`Model or type "${modelOrType}" not found in schema`); } export function getIdFields(schema: SchemaDef, model: GetModels) { diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index bbc98881..9cf81ccc 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -22,10 +22,10 @@ import { import { match } from 'ts-pattern'; import type { CRUD } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; -import type { BaseCrudDialect } from '../../client/crud/dialects/base'; +import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError, QueryError } from '../../client/errors'; import type { ClientOptions } from '../../client/options'; -import { getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils'; +import { getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils'; import type { BinaryExpression, BinaryOperator, @@ -51,8 +51,6 @@ export type ExpressionTransformerContext = { model: GetModels; alias?: string; operation: CRUD; - thisEntity?: Record; - thisEntityRaw?: Record; auth?: any; memberFilter?: OperationNode; memberSelect?: SelectionNode; @@ -86,7 +84,7 @@ export class ExpressionTransformer { if (!this.schema.authType) { throw new InternalError('Schema does not have an "authType" specified'); } - return this.schema.authType; + return this.schema.authType!; } transform(expression: Expression, context: ExpressionTransformerContext): OperationNode { @@ -117,11 +115,7 @@ export class ExpressionTransformer { private _field(expr: FieldExpression, context: ExpressionTransformerContext) { const fieldDef = requireField(this.schema, context.model, expr.field); if (!fieldDef.relation) { - if (context.thisEntity) { - return context.thisEntity[expr.field]; - } else { - return this.createColumnRef(expr.field, context); - } + return this.createColumnRef(expr.field, context); } else { const { memberFilter, memberSelect, ...restContext } = context; const relation = this.transformRelationAccess(expr.field, fieldDef.type, restContext); @@ -159,7 +153,7 @@ export class ExpressionTransformer { } if (this.isAuthCall(expr.left) || this.isAuthCall(expr.right)) { - return this.transformAuthBinary(expr); + return this.transformAuthBinary(expr, context); } const op = expr.op; @@ -234,11 +228,10 @@ export class ExpressionTransformer { ...context, model: newContextModel as GetModels, alias: undefined, - thisEntity: undefined, }); if (expr.op === '!') { - predicateFilter = logicalNot(predicateFilter); + predicateFilter = logicalNot(this.dialect, predicateFilter); } const count = FunctionNode.create('count', [ValueNode.createImmediate(1)]); @@ -256,21 +249,50 @@ export class ExpressionTransformer { }); } - private transformAuthBinary(expr: BinaryExpression) { + private transformAuthBinary(expr: BinaryExpression, context: ExpressionTransformerContext) { if (expr.op !== '==' && expr.op !== '!=') { - throw new Error(`Unsupported operator for auth call: ${expr.op}`); + throw new QueryError( + `Unsupported operator for \`auth()\` in policy of model "${context.model}": ${expr.op}`, + ); } + + let authExpr: Expression; let other: Expression; if (this.isAuthCall(expr.left)) { + authExpr = expr.left; other = expr.right; } else { + authExpr = expr.right; other = expr.left; } if (ExpressionUtils.isNull(other)) { return this.transformValue(expr.op === '==' ? !this.auth : !!this.auth, 'Boolean'); } else { - throw new Error('Unsupported binary expression with `auth()`'); + const authModel = getModel(this.schema, this.authType); + if (!authModel) { + throw new QueryError( + `Unsupported use of \`auth()\` in policy of model "${context.model}", comparing with \`auth()\` is only possible when auth type is a model`, + ); + } + + const idFields = Object.values(authModel.fields) + .filter((f) => f.id) + .map((f) => f.name); + invariant(idFields.length > 0, 'auth type model must have at least one id field'); + + const conditions = idFields.map((fieldName) => + ExpressionUtils.binary( + ExpressionUtils.member(authExpr, [fieldName]), + '==', + ExpressionUtils.member(other, [fieldName]), + ), + ); + let result = this.buildAnd(conditions); + if (expr.op === '!=') { + result = this.buildLogicalNot(result); + } + return this.transform(result, context); } } @@ -283,11 +305,7 @@ export class ExpressionTransformer { private _unary(expr: UnaryExpression, context: ExpressionTransformerContext) { // only '!' operator for now invariant(expr.op === '!', 'only "!" operator is supported'); - return BinaryOperationNode.create( - this.transform(expr.operand, context), - this.transformOperator('!='), - trueNode(this.dialect), - ); + return logicalNot(this.dialect, this.transform(expr.operand, context)); } private transformOperator(op: Exclude) { @@ -331,7 +349,7 @@ export class ExpressionTransformer { } if (ExpressionUtils.isField(arg)) { - return context.thisEntityRaw ? eb.val(context.thisEntityRaw[arg.field]) : eb.ref(arg.field); + return eb.ref(arg.field); } if (ExpressionUtils.isCall(arg)) { @@ -358,20 +376,46 @@ export class ExpressionTransformer { return this.valueMemberAccess(this.auth, expr, this.authType); } - invariant(ExpressionUtils.isField(expr.receiver), 'expect receiver to be field expression'); + invariant( + ExpressionUtils.isField(expr.receiver) || ExpressionUtils.isThis(expr.receiver), + 'expect receiver to be field expression or "this"', + ); + let members = expr.members; + let receiver: OperationNode; const { memberFilter, memberSelect, ...restContext } = context; - const receiver = this.transform(expr.receiver, restContext); + if (ExpressionUtils.isThis(expr.receiver)) { + if (expr.members.length === 1) { + // optimize for the simple this.scalar case + const fieldDef = requireField(this.schema, context.model, expr.members[0]!); + invariant(!fieldDef.relation, 'this.relation access should have been transformed into relation access'); + return this.createColumnRef(expr.members[0]!, restContext); + } + + // transform the first segment into a relation access, then continue with the rest of the members + const firstMemberFieldDef = requireField(this.schema, context.model, expr.members[0]!); + receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext); + members = expr.members.slice(1); + } else { + receiver = this.transform(expr.receiver, restContext); + } + invariant(SelectQueryNode.is(receiver), 'expected receiver to be select query'); - // relation member access - const receiverField = requireField(this.schema, context.model, expr.receiver.field); + let startType: string; + if (ExpressionUtils.isField(expr.receiver)) { + const receiverField = requireField(this.schema, context.model, expr.receiver.field); + startType = receiverField.type; + } else { + // "this." case, start type is the model of the context + startType = context.model; + } // traverse forward to collect member types const memberFields: { fromModel: string; fieldDef: FieldDef }[] = []; - let currType = receiverField.type; - for (const member of expr.members) { + let currType = startType; + for (const member of members) { const fieldDef = requireField(this.schema, currType, member); memberFields.push({ fieldDef, fromModel: currType }); currType = fieldDef.type; @@ -379,8 +423,8 @@ export class ExpressionTransformer { let currNode: SelectQueryNode | ColumnNode | ReferenceNode | undefined = undefined; - for (let i = expr.members.length - 1; i >= 0; i--) { - const member = expr.members[i]!; + for (let i = members.length - 1; i >= 0; i--) { + const member = members[i]!; const { fieldDef, fromModel } = memberFields[i]!; if (fieldDef.relation) { @@ -388,7 +432,6 @@ export class ExpressionTransformer { ...restContext, model: fromModel as GetModels, alias: undefined, - thisEntity: undefined, }); if (currNode) { @@ -396,9 +439,7 @@ export class ExpressionTransformer { currNode = { ...relation, selections: [ - SelectionNode.create( - AliasNode.create(currNode, IdentifierNode.create(expr.members[i + 1]!)), - ), + SelectionNode.create(AliasNode.create(currNode, IdentifierNode.create(members[i + 1]!))), ], }; } else { @@ -410,7 +451,7 @@ export class ExpressionTransformer { }; } } else { - invariant(i === expr.members.length - 1, 'plain field access must be the last segment'); + invariant(i === members.length - 1, 'plain field access must be the last segment'); invariant(!currNode, 'plain field access must be the last segment'); currNode = ColumnNode.create(member); @@ -446,71 +487,38 @@ export class ExpressionTransformer { const fromModel = context.model; const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs(this.schema, fromModel, field); - if (context.thisEntity) { - let condition: OperationNode; - if (ownedByModel) { - condition = conjunction( - this.dialect, - keyPairs.map(({ fk, pk }) => - BinaryOperationNode.create( - ReferenceNode.create(ColumnNode.create(pk), TableNode.create(relationModel)), - OperatorNode.create('='), - context.thisEntity![fk]!, - ), - ), - ); - } else { - condition = conjunction( - this.dialect, - keyPairs.map(({ fk, pk }) => - BinaryOperationNode.create( - ReferenceNode.create(ColumnNode.create(fk), TableNode.create(relationModel)), - OperatorNode.create('='), - context.thisEntity![pk]!, - ), + let condition: OperationNode; + if (ownedByModel) { + // `fromModel` owns the fk + condition = conjunction( + this.dialect, + keyPairs.map(({ fk, pk }) => + BinaryOperationNode.create( + ReferenceNode.create(ColumnNode.create(fk), TableNode.create(context.alias ?? fromModel)), + OperatorNode.create('='), + ReferenceNode.create(ColumnNode.create(pk), TableNode.create(relationModel)), ), - ); - } - - return { - kind: 'SelectQueryNode', - from: FromNode.create([TableNode.create(relationModel)]), - where: WhereNode.create(condition), - }; + ), + ); } else { - let condition: OperationNode; - if (ownedByModel) { - // `fromModel` owns the fk - condition = conjunction( - this.dialect, - keyPairs.map(({ fk, pk }) => - BinaryOperationNode.create( - ReferenceNode.create(ColumnNode.create(fk), TableNode.create(context.alias ?? fromModel)), - OperatorNode.create('='), - ReferenceNode.create(ColumnNode.create(pk), TableNode.create(relationModel)), - ), - ), - ); - } else { - // `relationModel` owns the fk - condition = conjunction( - this.dialect, - keyPairs.map(({ fk, pk }) => - BinaryOperationNode.create( - ReferenceNode.create(ColumnNode.create(pk), TableNode.create(context.alias ?? fromModel)), - OperatorNode.create('='), - ReferenceNode.create(ColumnNode.create(fk), TableNode.create(relationModel)), - ), + // `relationModel` owns the fk + condition = conjunction( + this.dialect, + keyPairs.map(({ fk, pk }) => + BinaryOperationNode.create( + ReferenceNode.create(ColumnNode.create(pk), TableNode.create(context.alias ?? fromModel)), + OperatorNode.create('='), + ReferenceNode.create(ColumnNode.create(fk), TableNode.create(relationModel)), ), - ); - } - - return { - kind: 'SelectQueryNode', - from: FromNode.create([TableNode.create(relationModel)]), - where: WhereNode.create(condition), - }; + ), + ); } + + return { + kind: 'SelectQueryNode', + from: FromNode.create([TableNode.create(relationModel)]), + where: WhereNode.create(condition), + }; } private createColumnRef(column: string, context: ExpressionTransformerContext): ReferenceNode { @@ -528,4 +536,18 @@ export class ExpressionTransformer { private isNullNode(node: OperationNode) { return ValueNode.is(node) && node.value === null; } + + private buildLogicalNot(result: Expression): Expression { + return ExpressionUtils.unary('!', result); + } + + private buildAnd(conditions: BinaryExpression[]): Expression { + if (conditions.length === 0) { + return ExpressionUtils.literal(true); + } else if (conditions.length === 1) { + return conditions[0]!; + } else { + return conditions.reduce((acc, condition) => ExpressionUtils.binary(acc, '&&', condition)); + } + } } diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index 7cb672c2..f26c2038 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -5,10 +5,12 @@ import { ColumnNode, DeleteQueryNode, FromNode, + FunctionNode, IdentifierNode, InsertQueryNode, OperationNodeTransformer, OperatorNode, + ParensNode, PrimitiveValueListNode, RawNode, ReturningNode, @@ -16,6 +18,7 @@ import { SelectQueryNode, TableNode, UpdateQueryNode, + ValueListNode, ValueNode, ValuesNode, WhereNode, @@ -27,7 +30,7 @@ import { match } from 'ts-pattern'; import type { ClientContract } from '../../client'; import type { CRUD } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; -import type { BaseCrudDialect } from '../../client/crud/dialects/base'; +import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError } from '../../client/errors'; import type { ProceedKyselyQueryFunction } from '../../client/plugin'; import { getIdFields, requireField, requireModel } from '../../client/query-utils'; @@ -103,7 +106,7 @@ export class PolicyHandler extends OperationNodeTransf } // TODO: run in transaction - //let readBackError = false; + // let readBackError = false; // transform and post-process in a transaction // const result = await transaction(async (txProceed) => { @@ -142,19 +145,14 @@ export class PolicyHandler extends OperationNodeTransf } private async enforcePreCreatePolicy(node: InsertQueryNode, proceed: ProceedKyselyQueryFunction) { - if (!node.columns || !node.values) { - return; - } - const model = this.getMutationModel(node); - const fields = node.columns.map((c) => c.column.name); - const valueRows = this.unwrapCreateValueRows(node.values, model, fields); + const fields = node.columns?.map((c) => c.column.name) ?? []; + const valueRows = node.values ? this.unwrapCreateValueRows(node.values, model, fields) : [[]]; for (const values of valueRows) { await this.enforcePreCreatePolicyForOne( model, fields, values.map((v) => v.node), - values.map((v) => v.raw), proceed, ); } @@ -164,23 +162,59 @@ export class PolicyHandler extends OperationNodeTransf model: GetModels, fields: string[], values: OperationNode[], - valuesRaw: unknown[], proceed: ProceedKyselyQueryFunction, ) { - const thisEntity: Record = {}; - const thisEntityRaw: Record = {}; - for (let i = 0; i < fields.length; i++) { - thisEntity[fields[i]!] = values[i]!; - thisEntityRaw[fields[i]!] = valuesRaw[i]!; + const allFields = Object.keys(requireModel(this.client.$schema, model).fields); + const allValues: OperationNode[] = []; + + for (const fieldName of allFields) { + const index = fields.indexOf(fieldName); + if (index >= 0) { + allValues.push(values[index]!); + } else { + // set non-provided fields to null + allValues.push(ValueNode.createImmediate(null)); + } } - const filter = this.buildPolicyFilter(model, undefined, 'create', thisEntity, thisEntityRaw); + // create a `SELECT column1 as field1, column2 as field2, ... FROM (VALUES (...))` table for policy evaluation + const constTable: SelectQueryNode = { + kind: 'SelectQueryNode', + from: FromNode.create([ + AliasNode.create( + ParensNode.create(ValuesNode.create([ValueListNode.create(allValues)])), + IdentifierNode.create('$t'), + ), + ]), + selections: allFields.map((field, index) => + SelectionNode.create( + AliasNode.create(ColumnNode.create(`column${index + 1}`), IdentifierNode.create(field)), + ), + ), + }; + + const filter = this.buildPolicyFilter(model, undefined, 'create'); + const preCreateCheck: SelectQueryNode = { kind: 'SelectQueryNode', - selections: [SelectionNode.create(AliasNode.create(filter, IdentifierNode.create('$condition')))], + from: FromNode.create([AliasNode.create(constTable, IdentifierNode.create(model))]), + selections: [ + SelectionNode.create( + AliasNode.create( + BinaryOperationNode.create( + FunctionNode.create('COUNT', [ValueNode.createImmediate(1)]), + OperatorNode.create('>'), + ValueNode.createImmediate(0), + ), + IdentifierNode.create('$condition'), + ), + ), + ], + where: WhereNode.create(filter), }; + const result = await proceed(preCreateCheck); - if (!(result.rows[0] as any)?.$condition) { + if (!result.rows[0]?.$condition) { throw new RejectedByPolicyError(model); } } @@ -327,13 +361,7 @@ export class PolicyHandler extends OperationNodeTransf return InsertQueryNode.is(node) || UpdateQueryNode.is(node) || DeleteQueryNode.is(node); } - private buildPolicyFilter( - model: GetModels, - alias: string | undefined, - operation: CRUD, - thisEntity?: Record, - thisEntityRaw?: Record, - ) { + private buildPolicyFilter(model: GetModels, alias: string | undefined, operation: CRUD) { const policies = this.getModelPolicies(model, operation); if (policies.length === 0) { return falseNode(this.dialect); @@ -341,11 +369,11 @@ export class PolicyHandler extends OperationNodeTransf const allows = policies .filter((policy) => policy.kind === 'allow') - .map((policy) => this.transformPolicyCondition(model, alias, operation, policy, thisEntity, thisEntityRaw)); + .map((policy) => this.transformPolicyCondition(model, alias, operation, policy)); const denies = policies .filter((policy) => policy.kind === 'deny') - .map((policy) => this.transformPolicyCondition(model, alias, operation, policy, thisEntity, thisEntityRaw)); + .map((policy) => this.transformPolicyCondition(model, alias, operation, policy)); let combinedPolicy: OperationNode; @@ -458,8 +486,6 @@ export class PolicyHandler extends OperationNodeTransf alias: string | undefined, operation: CRUD, policy: Policy, - thisEntity?: Record, - thisEntityRaw?: Record, ) { return new ExpressionTransformer(this.client.$schema, this.client.$options, this.client.$auth).transform( policy.condition, @@ -467,8 +493,6 @@ export class PolicyHandler extends OperationNodeTransf model, alias, operation, - thisEntity, - thisEntityRaw, auth: this.client.$auth, }, ); diff --git a/packages/runtime/src/plugins/policy/utils.ts b/packages/runtime/src/plugins/policy/utils.ts index 3c2e641d..1113cb7e 100644 --- a/packages/runtime/src/plugins/policy/utils.ts +++ b/packages/runtime/src/plugins/policy/utils.ts @@ -12,7 +12,7 @@ import { UnaryOperationNode, ValueNode, } from 'kysely'; -import type { BaseCrudDialect } from '../../client/crud/dialects/base'; +import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import type { SchemaDef } from '../../schema'; /** @@ -50,6 +50,12 @@ export function conjunction( dialect: BaseCrudDialect, nodes: OperationNode[], ): OperationNode { + if (nodes.length === 0) { + return trueNode(dialect); + } + if (nodes.length === 1) { + return nodes[0]!; + } if (nodes.some(isFalseNode)) { return falseNode(dialect); } @@ -57,17 +63,19 @@ export function conjunction( if (items.length === 0) { return trueNode(dialect); } - return items.reduce((acc, node) => - OrNode.is(node) - ? AndNode.create(acc, ParensNode.create(node)) // wraps parentheses - : AndNode.create(acc, node), - ); + return items.reduce((acc, node) => AndNode.create(wrapParensIf(acc, OrNode.is), wrapParensIf(node, OrNode.is))); } export function disjunction( dialect: BaseCrudDialect, nodes: OperationNode[], ): OperationNode { + if (nodes.length === 0) { + return falseNode(dialect); + } + if (nodes.length === 1) { + return nodes[0]!; + } if (nodes.some(isTrueNode)) { return trueNode(dialect); } @@ -75,25 +83,32 @@ export function disjunction( if (items.length === 0) { return falseNode(dialect); } - return items.reduce((acc, node) => - AndNode.is(node) - ? OrNode.create(acc, ParensNode.create(node)) // wraps parentheses - : OrNode.create(acc, node), - ); + return items.reduce((acc, node) => OrNode.create(wrapParensIf(acc, AndNode.is), wrapParensIf(node, AndNode.is))); } /** * Negates a logical expression. */ -export function logicalNot(node: OperationNode): OperationNode { +export function logicalNot( + dialect: BaseCrudDialect, + node: OperationNode, +): OperationNode { + if (isTrueNode(node)) { + return falseNode(dialect); + } + if (isFalseNode(node)) { + return trueNode(dialect); + } return UnaryOperationNode.create( OperatorNode.create('not'), - AndNode.is(node) || OrNode.is(node) - ? ParensNode.create(node) // wraps parentheses - : node, + wrapParensIf(node, (n) => AndNode.is(n) || OrNode.is(n)), ); } +function wrapParensIf(node: OperationNode, predicate: (node: OperationNode) => boolean): OperationNode { + return predicate(node) ? ParensNode.create(node) : node; +} + /** * Builds an expression node that checks if a node is true. */ diff --git a/packages/runtime/src/schema/expression.ts b/packages/runtime/src/schema/expression.ts index a650391a..2e2337fa 100644 --- a/packages/runtime/src/schema/expression.ts +++ b/packages/runtime/src/schema/expression.ts @@ -88,6 +88,10 @@ export const ExpressionUtils = { return expressions.reduce((acc, exp) => ExpressionUtils.binary(acc, '||', exp), expr); }, + not: (expr: Expression) => { + return ExpressionUtils.unary('!', expr); + }, + is: (value: unknown, kind: Expression['kind']): value is Expression => { return !!value && typeof value === 'object' && 'kind' in value && value.kind === kind; }, diff --git a/packages/runtime/src/utils/type-utils.ts b/packages/runtime/src/utils/type-utils.ts index f6d784f0..e5cd1f33 100644 --- a/packages/runtime/src/utils/type-utils.ts +++ b/packages/runtime/src/utils/type-utils.ts @@ -16,10 +16,12 @@ export type Simplify = D extends 0 : { [K in keyof T]: Simplify } & {} : T; -export type WrapType = Optional extends true - ? T | null - : Array extends true - ? T[] +export type WrapType = Array extends true + ? Optional extends true + ? T[] | null + : T[] + : Optional extends true + ? T | null : T; type TypeMap = { diff --git a/packages/runtime/test/client-api/update.test.ts b/packages/runtime/test/client-api/update.test.ts index 2fc75fb8..a82a87bc 100644 --- a/packages/runtime/test/client-api/update.test.ts +++ b/packages/runtime/test/client-api/update.test.ts @@ -1815,6 +1815,21 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client update tests', ({ createCli user: { connect: { id: '1' } }, }, }); + // not matching filter, no-op + await expect( + client.profile.update({ + where: { id: profile.id }, + data: { + user: { + disconnect: { id: '2' }, + }, + }, + include: { user: true }, + }), + ).resolves.toMatchObject({ + user: { id: '1' }, + }); + // connected, disconnect await expect( client.profile.update({ where: { id: profile.id }, @@ -1828,8 +1843,7 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client update tests', ({ createCli ).resolves.toMatchObject({ user: null, }); - - // non-existing + // not connected, no-op await expect( client.profile.update({ where: { id: profile.id }, diff --git a/packages/runtime/test/plugin/on-kysely-query.test.ts b/packages/runtime/test/plugin/on-kysely-query.test.ts index 7e0ac024..75105927 100644 --- a/packages/runtime/test/plugin/on-kysely-query.test.ts +++ b/packages/runtime/test/plugin/on-kysely-query.test.ts @@ -84,7 +84,7 @@ describe('On kysely query tests', () => { it('supports spawning multiple queries', async () => { const client = _client.$use({ id: 'test-plugin', - async onKyselyQuery({ kysely, proceed, query }) { + async onKyselyQuery({ client, proceed, query }) { if (query.kind !== 'InsertQueryNode') { return proceed(query); } @@ -92,7 +92,7 @@ describe('On kysely query tests', () => { const result = await proceed(query); // create a post for the user - await proceed(createPost(kysely, result)); + await proceed(createPost(client.$qb, result)); return result; }, diff --git a/packages/runtime/test/policy/auth-equality.test.ts b/packages/runtime/test/policy/auth-equality.test.ts new file mode 100644 index 00000000..2baf46ed --- /dev/null +++ b/packages/runtime/test/policy/auth-equality.test.ts @@ -0,0 +1,109 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from './utils'; + +describe('Reference Equality Tests', () => { + it('works with create and auth equality', async () => { + const db = await createPolicyTestClient( + ` +model User { + id1 Int + id2 Int + posts Post[] + @@id([id1, id2]) + @@allow('all', auth() == this) + @@allow('read', true) +} + +model Post { + id Int @id @default(autoincrement()) + title String + authorId1 Int + authorId2 Int + author User @relation(fields: [authorId1, authorId2], references: [id1, id2]) + @@allow('all', auth() == author) +} + `, + ); + + await expect( + db.user.create({ + data: { id1: 1, id2: 2 }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.$setAuth({ id1: 1, id2: 2 }).user.create({ + data: { id1: 1, id2: 2 }, + }), + ).resolves.toMatchObject({ id1: 1, id2: 2 }); + + await expect( + db.post.create({ + data: { authorId1: 1, authorId2: 2, title: 'Post 1' }, + }), + ).toBeRejectedByPolicy(); + await expect( + db.post.create({ + data: { author: { connect: { id1_id2: { id1: 1, id2: 2 } } }, title: 'Post 1' }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.$setAuth({ id1: 1, id2: 2 }).post.create({ + data: { authorId1: 1, authorId2: 2, title: 'Post 1' }, + }), + ).resolves.toMatchObject({ title: 'Post 1' }); + await expect( + db.$setAuth({ id1: 1, id2: 2 }).post.create({ + data: { author: { connect: { id1_id2: { id1: 1, id2: 2 } } }, title: 'Post 2' }, + }), + ).resolves.toMatchObject({ title: 'Post 2' }); + }); + + it('works with create and auth inequality', async () => { + const db = await createPolicyTestClient( + ` +model User { + id1 Int + id2 Int + posts Post[] + @@id([id1, id2]) + @@allow('all', auth() != this) + @@allow('read', true) +} + +model Post { + id Int @id @default(autoincrement()) + title String + authorId1 Int + authorId2 Int + author User @relation(fields: [authorId1, authorId2], references: [id1, id2]) + @@allow('all', auth() != author) + @@allow('read', true) +} + `, + ); + + await expect( + db.$setAuth({ id1: 1, id2: 2 }).user.create({ + data: { id1: 1, id2: 2 }, + }), + ).toBeRejectedByPolicy(); + await expect( + db.$setAuth({ id1: 2, id2: 2 }).user.create({ + data: { id1: 1, id2: 2 }, + }), + ).toResolveTruthy(); + + await expect( + db.$setAuth({ id1: 1, id2: 2 }).post.create({ + data: { authorId1: 1, authorId2: 2, title: 'Post 1' }, + }), + ).toBeRejectedByPolicy(); + await expect( + db.$setAuth({ id1: 2, id2: 2 }).post.create({ + data: { authorId1: 1, authorId2: 2, title: 'Post 1' }, + }), + ).resolves.toMatchObject({ title: 'Post 1' }); + }); +}); diff --git a/packages/runtime/test/policy/crud/create.test.ts b/packages/runtime/test/policy/crud/create.test.ts new file mode 100644 index 00000000..dbd7a414 --- /dev/null +++ b/packages/runtime/test/policy/crud/create.test.ts @@ -0,0 +1,276 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy create tests', () => { + it('works with scalar field check', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + x Int + @@allow('create', x > 0) + @@allow('read', true) +} +`, + ); + await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { x: 1 } })).resolves.toMatchObject({ x: 1 }); + + await expect( + db.$qb.insertInto('Foo').values({ x: 0 }).returningAll().executeTakeFirst(), + ).toBeRejectedByPolicy(); + await expect( + db.$qb.insertInto('Foo').values({ x: 1 }).returningAll().executeTakeFirst(), + ).resolves.toMatchObject({ x: 1 }); + }); + + it('works with this scalar member check', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + x Int + @@allow('create', this.x > 0) + @@allow('read', true) +} +`, + ); + await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { x: 1 } })).resolves.toMatchObject({ x: 1 }); + }); + + it('denies by default', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + x Int +} +`, + ); + await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + }); + + it('works with deny rule', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + x Int + @@deny('create', x <= 0) + @@allow('create,read', true) +} +`, + ); + await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { x: 1 } })).resolves.toMatchObject({ x: 1 }); + }); + + it('works with mixed allow and deny rules', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + x Int + @@deny('create', x <= 0) + @@allow('create', x <= 0 || x > 1) + @@allow('read', true) +} +`, + ); + await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { x: 1 } })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + }); + + it('works with non-provided fields', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + x Int @default(0) + @@allow('create', x > 0) + @@allow('read', true) +} +`, + ); + await expect(db.foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { x: 1 } })).toResolveTruthy(); + }); + + it('works with db-generated fields', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + @@allow('create', id > 0) + @@allow('read', true) +} +`, + ); + await expect(db.foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { id: 1 } })).toResolveTruthy(); + }); + + it('rejects non-owned relation reference', async () => { + await expect( + createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@allow('create', profile == null) + @@allow('read', true) +} + +model Profile { + id Int @id + name String + user User @relation(fields: [userId], references: [id]) + userId Int @unique +} + `, + ), + ).rejects.toThrow('non-owned relation fields are not allowed in "create" rules'); + }); + + it('works with auth check', async () => { + const db = await createPolicyTestClient( + ` +type Auth { + x Int + @@auth +} + +model Foo { + id Int @id @default(autoincrement()) + x Int + @@allow('create', x == auth().x) + @@allow('read', true) +} +`, + ); + await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + await expect(db.$setAuth({ x: 0 }).foo.create({ data: { x: 1 } })).toBeRejectedByPolicy(); + await expect(db.$setAuth({ x: 1 }).foo.create({ data: { x: 1 } })).resolves.toMatchObject({ x: 1 }); + }); + + it('works with owned to-one relation reference', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@allow('all', true) +} + +model Profile { + id Int @id + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + + @@deny('all', auth() == null) + @@allow('create', user.id == auth().id) + @@allow('read', true) +} + `, + ); + + await db.user.create({ data: { id: 1 } }); + await expect(db.profile.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + await expect(db.$setAuth({ id: 0 }).profile.create({ data: { id: 1, userId: 1 } })).toBeRejectedByPolicy(); + await expect(db.$setAuth({ id: 1 }).profile.create({ data: { id: 1, userId: 1 } })).resolves.toMatchObject({ + id: 1, + }); + + await expect(db.profile.create({ data: { id: 2, user: { create: { id: 2 } } } })).toBeRejectedByPolicy(); + await expect(db.user.findUnique({ where: { id: 2 } })).toResolveNull(); + await expect( + db + .$setAuth({ id: 2 }) + .profile.create({ data: { id: 2, user: { create: { id: 2 } } }, include: { user: true } }), + ).resolves.toMatchObject({ + id: 2, + user: { + id: 2, + }, + }); + + await db.user.create({ data: { id: 3 } }); + await expect( + db.$setAuth({ id: 2 }).profile.create({ data: { id: 3, user: { connect: { id: 3 } } } }), + ).toBeRejectedByPolicy(); + await expect( + db.$setAuth({ id: 3 }).profile.create({ data: { id: 3, user: { connect: { id: 3 } } } }), + ).toResolveTruthy(); + + await expect(db.$setAuth({ id: 4 }).profile.create({ data: { id: 2, userId: 4 } })).toBeRejectedByPolicy(); + }); + + it('works with nested create owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@allow('all', true) +} + +model Profile { + id Int @id + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + + @@deny('all', auth() == null) + @@allow('create', user.id == auth().id) + @@allow('read', true) +} + `, + ); + + await expect(db.user.create({ data: { id: 1, profile: { create: { id: 1 } } } })).toBeRejectedByPolicy(); + await expect( + db + .$setAuth({ id: 1 }) + .user.create({ data: { id: 1, profile: { create: { id: 1 } } }, include: { profile: true } }), + ).resolves.toMatchObject({ + id: 1, + profile: { + id: 1, + }, + }); + }); + + it('works with nested create non-owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@deny('all', auth() == null) + @@allow('create', this.id == auth().id) + @@allow('read', true) +} + +model Profile { + id Int @id + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + @@allow('all', true) +} + `, + ); + + await expect(db.profile.create({ data: { id: 1, user: { create: { id: 1 } } } })).toBeRejectedByPolicy(); + await expect( + db + .$setAuth({ id: 1 }) + .profile.create({ data: { id: 1, user: { create: { id: 1 } } }, include: { user: true } }), + ).resolves.toMatchObject({ + id: 1, + user: { + id: 1, + }, + }); + }); +}); diff --git a/packages/runtime/test/policy/crud/dumb-rules.test.ts b/packages/runtime/test/policy/crud/dumb-rules.test.ts new file mode 100644 index 00000000..b169e3a0 --- /dev/null +++ b/packages/runtime/test/policy/crud/dumb-rules.test.ts @@ -0,0 +1,42 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy dumb rules tests', () => { + it('works with create dumb rules', async () => { + const db = await createPolicyTestClient( + ` +model A { + id Int @id @default(autoincrement()) + x Int + @@allow('create', 1 > 0) + @@allow('read', true) +} + +model B { + id Int @id @default(autoincrement()) + x Int + @@allow('create', 0 > 1) + @@allow('read', true) +} + +model C { + id Int @id @default(autoincrement()) + x Int + @@allow('create', true) + @@allow('read', true) +} + +model D { + id Int @id @default(autoincrement()) + x Int + @@allow('create', false) + @@allow('read', true) +} +`, + ); + await expect(db.a.create({ data: { x: 0 } })).resolves.toMatchObject({ x: 0 }); + await expect(db.b.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + await expect(db.c.create({ data: { x: 0 } })).resolves.toMatchObject({ x: 0 }); + await expect(db.d.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + }); +}); diff --git a/packages/runtime/test/policy/crud/update.test.ts b/packages/runtime/test/policy/crud/update.test.ts new file mode 100644 index 00000000..e0082a49 --- /dev/null +++ b/packages/runtime/test/policy/crud/update.test.ts @@ -0,0 +1,584 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Update policy tests', () => { + describe('Scalar condition tests', () => { + it('works with scalar field check', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('update', x > 0) + @@allow('create,read', true) +} +`, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + + await expect( + db.$qb.updateTable('Foo').set({ x: 1 }).where('id', '=', 1).executeTakeFirst(), + ).resolves.toMatchObject({ numUpdatedRows: 0n }); + await expect( + db.$qb.updateTable('Foo').set({ x: 3 }).where('id', '=', 2).returningAll().execute(), + ).resolves.toMatchObject([{ id: 2, x: 3 }]); + }); + + it('works with this scalar member check', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('update', this.x > 0) + @@allow('create,read', true) +} +`, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + }); + + it('denies by default', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create,read', true) +} +`, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + }); + + it('works with deny rule', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@deny('update', x <= 0) + @@allow('create,read,update', true) +} +`, + ); + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + }); + + it('works with mixed allow and deny rules', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@deny('update', x <= 0) + @@allow('update', x <= 0 || x > 1) + @@allow('create,read', true) +} +`, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 3, x: 2 } }); + await expect(db.foo.update({ where: { id: 3 }, data: { x: 3 } })).resolves.toMatchObject({ x: 3 }); + }); + + it('works with auth check', async () => { + const db = await createPolicyTestClient( + ` +type Auth { + x Int + @@auth +} + +model Foo { + id Int @id + x Int + @@allow('update', x == auth().x) + @@allow('create,read', true) +} +`, + ); + await db.foo.create({ data: { id: 1, x: 1 } }); + await expect(db.$setAuth({ x: 0 }).foo.update({ where: { id: 1 }, data: { x: 2 } })).toBeRejectedNotFound(); + await expect(db.$setAuth({ x: 1 }).foo.update({ where: { id: 1 }, data: { x: 2 } })).resolves.toMatchObject( + { + x: 2, + }, + ); + }); + }); + + describe('Relation condition tests', () => { + it('works with to-one relation check owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + profile Profile? + @@allow('all', true) +} + +model Profile { + id Int @id + bio String + user User @relation(fields: [userId], references: [id]) + userId Int @unique + @@allow('create,read', true) + @@allow('update', user.name == 'User2') +} +`, + ); + + await db.user.create({ data: { id: 1, name: 'User1', profile: { create: { id: 1, bio: 'Bio1' } } } }); + await expect(db.profile.update({ where: { id: 1 }, data: { bio: 'UpdatedBio1' } })).toBeRejectedNotFound(); + + await db.user.create({ data: { id: 2, name: 'User2', profile: { create: { id: 2, bio: 'Bio2' } } } }); + await expect(db.profile.update({ where: { id: 2 }, data: { bio: 'UpdatedBio2' } })).resolves.toMatchObject({ + bio: 'UpdatedBio2', + }); + }); + + it('works with to-one relation check owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + @@allow('all', true) +} + +model Profile { + id Int @id + bio String + user User? + @@allow('create,read', true) + @@allow('update', user.name == 'User2') +} +`, + ); + + await db.user.create({ data: { id: 1, name: 'User1', profile: { create: { id: 1, bio: 'Bio1' } } } }); + await expect(db.profile.update({ where: { id: 1 }, data: { bio: 'UpdatedBio1' } })).toBeRejectedNotFound(); + + await db.user.create({ data: { id: 2, name: 'User2', profile: { create: { id: 2, bio: 'Bio2' } } } }); + await expect(db.profile.update({ where: { id: 2 }, data: { bio: 'UpdatedBio2' } })).resolves.toMatchObject({ + bio: 'UpdatedBio2', + }); + }); + + it('works with to-many relation check some', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + posts Post[] + @@allow('create,read', true) + @@allow('update', posts?[published]) +} + +model Post { + id Int @id + title String + published Boolean + author User @relation(fields: [authorId], references: [id]) + authorId Int + @@allow('all', true) +} +`, + ); + + await db.user.create({ data: { id: 1, name: 'User1' } }); + await expect(db.user.update({ where: { id: 1 }, data: { name: 'UpdatedUser1' } })).toBeRejectedNotFound(); + + await db.user.create({ + data: { id: 2, name: 'User2', posts: { create: { id: 1, title: 'Post1', published: false } } }, + }); + await expect(db.user.update({ where: { id: 2 }, data: { name: 'UpdatedUser2' } })).toBeRejectedNotFound(); + + await db.user.create({ + data: { + id: 3, + name: 'User3', + posts: { + create: [ + { id: 2, title: 'Post2', published: false }, + { id: 3, title: 'Post3', published: true }, + ], + }, + }, + }); + await expect(db.user.update({ where: { id: 3 }, data: { name: 'UpdatedUser3' } })).toResolveTruthy(); + }); + + it('works with to-many relation check all', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + posts Post[] + @@allow('create,read', true) + @@allow('update', posts![published]) +} + +model Post { + id Int @id + title String + published Boolean + author User @relation(fields: [authorId], references: [id]) + authorId Int + @@allow('all', true) +} +`, + ); + + await db.user.create({ data: { id: 1, name: 'User1' } }); + await expect(db.user.update({ where: { id: 1 }, data: { name: 'UpdatedUser1' } })).toResolveTruthy(); + + await db.user.create({ + data: { + id: 2, + name: 'User2', + posts: { + create: [ + { id: 1, title: 'Post1', published: false }, + { id: 2, title: 'Post2', published: true }, + ], + }, + }, + }); + await expect(db.user.update({ where: { id: 2 }, data: { name: 'UpdatedUser2' } })).toBeRejectedNotFound(); + + await db.user.create({ + data: { + id: 3, + name: 'User3', + posts: { + create: [ + { id: 3, title: 'Post3', published: true }, + { id: 4, title: 'Post4', published: true }, + ], + }, + }, + }); + await expect(db.user.update({ where: { id: 3 }, data: { name: 'UpdatedUser3' } })).toResolveTruthy(); + }); + + it('works with to-many relation check none', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + posts Post[] + @@allow('create,read', true) + @@allow('update', posts^[published]) +} + +model Post { + id Int @id + title String + published Boolean + author User @relation(fields: [authorId], references: [id]) + authorId Int + @@allow('all', true) +} +`, + ); + + await db.user.create({ data: { id: 1, name: 'User1' } }); + await expect(db.user.update({ where: { id: 1 }, data: { name: 'UpdatedUser1' } })).toResolveTruthy(); + + await db.user.create({ + data: { + id: 2, + name: 'User2', + posts: { + create: [ + { id: 1, title: 'Post1', published: false }, + { id: 2, title: 'Post2', published: true }, + ], + }, + }, + }); + await expect(db.user.update({ where: { id: 2 }, data: { name: 'UpdatedUser2' } })).toBeRejectedNotFound(); + + await db.user.create({ + data: { + id: 3, + name: 'User3', + posts: { + create: [ + { id: 3, title: 'Post3', published: false }, + { id: 4, title: 'Post4', published: false }, + ], + }, + }, + }); + await expect(db.user.update({ where: { id: 3 }, data: { name: 'UpdatedUser3' } })).toResolveTruthy(); + }); + }); + + describe('Nested update tests', () => { + it('works with nested update owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@allow('all', true) +} + +model Profile { + id Int @id + bio String + private Boolean + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + @@allow('create,read', true) + @@allow('update', !private) +} +`, + ); + + await db.user.create({ data: { id: 1, profile: { create: { id: 1, bio: 'Bio1', private: true } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { update: { bio: 'UpdatedBio1' } } }, + }), + ).toBeRejectedNotFound(); + + await db.user.create({ data: { id: 2, profile: { create: { id: 2, bio: 'Bio2', private: false } } } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { profile: { update: { bio: 'UpdatedBio2' } } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: { + bio: 'UpdatedBio2', + }, + }); + }); + + it('works with nested update non-owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + @@allow('all', true) +} + +model Profile { + id Int @id + bio String + private Boolean + user User? + @@allow('create,read', true) + @@allow('update', !private) +} +`, + ); + + await db.user.create({ data: { id: 1, profile: { create: { id: 1, bio: 'Bio1', private: true } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { update: { bio: 'UpdatedBio1' } } }, + }), + ).toBeRejectedNotFound(); + + await db.user.create({ data: { id: 2, profile: { create: { id: 2, bio: 'Bio2', private: false } } } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { profile: { update: { bio: 'UpdatedBio2' } } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: { + bio: 'UpdatedBio2', + }, + }); + }); + }); + + describe('Relation manipulation tests', () => { + it('works with connect/disconnect/create owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@allow('all', true) +} + +model Profile { + id Int @id + private Boolean + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + @@allow('create,read', true) + @@allow('update', !private) +} +`, + ); + + await db.user.create({ data: { id: 1 } }); + + await db.profile.create({ data: { id: 1, private: true } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { connect: { id: 1 } } }, + include: { profile: true }, + }), + ).toBeRejectedNotFound(); + + await db.profile.create({ data: { id: 2, private: false } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { connect: { id: 2 } } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: { + id: 2, + }, + }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { disconnect: true } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: null, + }); + // reconnect + await db.user.update({ where: { id: 1 }, data: { profile: { connect: { id: 2 } } } }); + // set private + await db.profile.update({ where: { id: 2 }, data: { private: true } }); + // disconnect should have no effect since update is not allowed + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { disconnect: true } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ profile: { id: 2 } }); + + await db.profile.create({ data: { id: 3, private: true } }); + await expect( + db.profile.update({ + where: { id: 3 }, + data: { user: { create: { id: 2 } } }, + }), + ).toBeRejectedNotFound(); + }); + + it('works with connect/disconnect/create non-owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? @relation(fields: [profileId], references: [id]) + profileId Int? @unique + private Boolean + @@allow('create,read', true) + @@allow('update', !private) +} + +model Profile { + id Int @id + user User? + @@allow('all', true) +} +`, + ); + + await db.user.create({ data: { id: 1, private: true } }); + await db.profile.create({ data: { id: 1 } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { connect: { id: 1 } } }, + include: { profile: true }, + }), + ).toBeRejectedNotFound(); + + await db.user.create({ data: { id: 2, private: false } }); + await db.profile.create({ data: { id: 2 } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { profile: { connect: { id: 2 } } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: { + id: 2, + }, + }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { profile: { disconnect: true } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: null, + }); + // reconnect + await db.user.update({ where: { id: 2 }, data: { profile: { connect: { id: 2 } } } }); + // set private + await db.user.update({ where: { id: 2 }, data: { private: true } }); + // disconnect should be rejected since update is not allowed + await expect( + db.user.update({ + where: { id: 2 }, + data: { profile: { disconnect: true } }, + include: { profile: true }, + }), + ).toBeRejectedNotFound(); + + await db.profile.create({ data: { id: 3 } }); + await expect( + db.profile.update({ + where: { id: 3 }, + data: { user: { create: { id: 3, private: true } } }, + }), + ).toResolveTruthy(); + }); + }); + + // describe('Upsert tests', () => {}); + + // describe('Update many tests', () => {}); +}); diff --git a/packages/runtime/test/policy/deep-nested.test.ts b/packages/runtime/test/policy/deep-nested.test.ts index 0be59e24..a35e34b8 100644 --- a/packages/runtime/test/policy/deep-nested.test.ts +++ b/packages/runtime/test/policy/deep-nested.test.ts @@ -7,7 +7,8 @@ describe('deep nested operations tests', () => { // -* M4 model M1 { myId String @id @default(cuid()) - m2 M2? + m2 M2? @relation(fields: [m2Id], references: [id], onDelete: Cascade) + m2Id Int? @unique value Int @default(0) @@allow('all', true) @@ -19,8 +20,7 @@ describe('deep nested operations tests', () => { model M2 { id Int @id @default(autoincrement()) value Int - m1 M1 @relation(fields: [m1Id], references: [myId], onDelete: Cascade) - m1Id String @unique + m1 M1? m3 M3? m4 M4[] @@ -616,7 +616,8 @@ describe('deep nested operations tests', () => { myId: '1', m2: { create: { - value: 1, + id: 1, + value: 3, m4: { create: [{ value: 200 }, { value: 22 }], }, @@ -628,10 +629,14 @@ describe('deep nested operations tests', () => { // delete read-back filtered: M4 @@deny('read', value == 200) const r = await db.m1.delete({ where: { myId: '1' }, - include: { m2: { select: { m4: true } } }, + include: { m2: { select: { id: true, m4: true } } }, }); expect(r.m2.m4).toHaveLength(1); + await expect(db.m2.findMany()).resolves.toHaveLength(1); + await expect(db.m4.findMany()).resolves.toHaveLength(1); + + await db.m2.delete({ where: { id: 1 } }); await expect(db.m4.findMany()).resolves.toHaveLength(0); await db.m1.create({ diff --git a/packages/runtime/test/scripts/generate.ts b/packages/runtime/test/scripts/generate.ts index a1393e30..7e5f1293 100644 --- a/packages/runtime/test/scripts/generate.ts +++ b/packages/runtime/test/scripts/generate.ts @@ -8,7 +8,6 @@ import { fileURLToPath } from 'node:url'; const dir = path.dirname(fileURLToPath(import.meta.url)); async function main() { - // glob all zmodel files in "e2e" directory const zmodelFiles = glob.sync(path.resolve(dir, '../schemas/**/*.zmodel')); for (const file of zmodelFiles) { console.log(`Generating TS schema for: ${file}`); diff --git a/packages/runtime/test/utils.ts b/packages/runtime/test/utils.ts index 4654fccc..64484593 100644 --- a/packages/runtime/test/utils.ts +++ b/packages/runtime/test/utils.ts @@ -3,7 +3,7 @@ import { loadDocument } from '@zenstackhq/language'; import { PrismaSchemaGenerator } from '@zenstackhq/sdk'; import { createTestProject, generateTsSchema } from '@zenstackhq/testtools'; import SQLite from 'better-sqlite3'; -import { PostgresDialect, SqliteDialect } from 'kysely'; +import { PostgresDialect, SqliteDialect, type LogEvent } from 'kysely'; import { execSync } from 'node:child_process'; import fs from 'node:fs'; import path from 'node:path'; @@ -192,3 +192,7 @@ export async function createTestClient( return client; } + +export function testLogger(e: LogEvent) { + console.log(e.query.sql, e.query.parameters); +} diff --git a/packages/sdk/package.json b/packages/sdk/package.json index c5aec65b..01e8af3b 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "description": "ZenStack SDK", "type": "module", "scripts": { diff --git a/packages/sdk/src/model-utils.ts b/packages/sdk/src/model-utils.ts index 3ab4a01e..7b54aa96 100644 --- a/packages/sdk/src/model-utils.ts +++ b/packages/sdk/src/model-utils.ts @@ -2,6 +2,7 @@ import { isDataModel, isLiteralExpr, isModel, + isTypeDef, Model, type AstNode, type Attribute, @@ -102,7 +103,7 @@ export function resolved(ref: Reference): T { export function getAuthDecl(model: Model) { let found = model.declarations.find( - (d) => isDataModel(d) && d.attributes.some((attr) => attr.decl.$refText === '@@auth'), + (d) => (isDataModel(d) || isTypeDef(d)) && d.attributes.some((attr) => attr.decl.$refText === '@@auth'), ); if (!found) { found = model.declarations.find((d) => isDataModel(d) && d.name === 'User'); diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index d2e8ba64..1d558300 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -52,9 +52,14 @@ import { } from './model-utils'; export class TsSchemaGenerator { + private usedExpressionUtils = false; + async generate(model: Model, outputDir: string) { fs.mkdirSync(outputDir, { recursive: true }); + // Reset the flag for each generation + this.usedExpressionUtils = false; + // the schema itself this.generateSchema(model, outputDir); @@ -82,6 +87,10 @@ export class TsSchemaGenerator { (d) => isDataModel(d) && d.fields.some((f) => hasAttribute(f, '@computed')), ); + // Generate schema content first to determine if ExpressionUtils is needed + const schemaObject = this.createSchemaObject(model); + + // Now generate the import declaration with the correct imports const runtimeImportDecl = ts.factory.createImportDeclaration( undefined, ts.factory.createImportClause( @@ -98,7 +107,15 @@ export class TsSchemaGenerator { ), ] : []), - ts.factory.createImportSpecifier(false, undefined, ts.factory.createIdentifier('ExpressionUtils')), + ...(this.usedExpressionUtils + ? [ + ts.factory.createImportSpecifier( + false, + undefined, + ts.factory.createIdentifier('ExpressionUtils'), + ), + ] + : []), ]), ), ts.factory.createStringLiteral('@zenstackhq/runtime/schema'), @@ -114,10 +131,7 @@ export class TsSchemaGenerator { undefined, undefined, ts.factory.createSatisfiesExpression( - ts.factory.createAsExpression( - this.createSchemaObject(model), - ts.factory.createTypeReferenceNode('const'), - ), + ts.factory.createAsExpression(schemaObject, ts.factory.createTypeReferenceNode('const')), ts.factory.createTypeReferenceNode('SchemaDef'), ), ), @@ -137,6 +151,15 @@ export class TsSchemaGenerator { statements.push(typeDeclaration); } + private createExpressionUtilsCall(method: string, args?: ts.Expression[]): ts.CallExpression { + this.usedExpressionUtils = true; + return ts.factory.createCallExpression( + ts.factory.createPropertyAccessExpression(ts.factory.createIdentifier('ExpressionUtils'), method), + undefined, + args || [], + ); + } + private createSchemaObject(model: Model) { const properties: ts.PropertyAssignment[] = [ // provider @@ -477,40 +500,28 @@ export class TsSchemaGenerator { ts.factory.createPropertyAssignment( 'default', - ts.factory.createCallExpression( - ts.factory.createIdentifier('ExpressionUtils.call'), - undefined, - [ - ts.factory.createStringLiteral(defaultValue.call), - ...(defaultValue.args.length > 0 - ? [ - ts.factory.createArrayLiteralExpression( - defaultValue.args.map((arg) => this.createLiteralNode(arg)), - ), - ] - : []), - ], - ), + this.createExpressionUtilsCall('call', [ + ts.factory.createStringLiteral(defaultValue.call), + ...(defaultValue.args.length > 0 + ? [ + ts.factory.createArrayLiteralExpression( + defaultValue.args.map((arg) => this.createLiteralNode(arg)), + ), + ] + : []), + ]), ), ); } else if ('authMember' in defaultValue) { objectFields.push( ts.factory.createPropertyAssignment( 'default', - ts.factory.createCallExpression( - ts.factory.createIdentifier('ExpressionUtils.member'), - undefined, - [ - ts.factory.createCallExpression( - ts.factory.createIdentifier('ExpressionUtils.call'), - undefined, - [ts.factory.createStringLiteral('auth')], - ), - ts.factory.createArrayLiteralExpression( - defaultValue.authMember.map((m) => ts.factory.createStringLiteral(m)), - ), - ], - ), + this.createExpressionUtilsCall('member', [ + this.createExpressionUtilsCall('call', [ts.factory.createStringLiteral('auth')]), + ts.factory.createArrayLiteralExpression( + defaultValue.authMember.map((m) => ts.factory.createStringLiteral(m)), + ), + ]), ), ); } else { @@ -1015,7 +1026,7 @@ export class TsSchemaGenerator { } private createThisExpression() { - return ts.factory.createCallExpression(ts.factory.createIdentifier('ExpressionUtils._this'), undefined, []); + return this.createExpressionUtilsCall('_this'); } private createMemberExpression(expr: MemberAccessExpr) { @@ -1034,15 +1045,15 @@ export class TsSchemaGenerator { ts.factory.createArrayLiteralExpression(members.map((m) => ts.factory.createStringLiteral(m))), ]; - return ts.factory.createCallExpression(ts.factory.createIdentifier('ExpressionUtils.member'), undefined, args); + return this.createExpressionUtilsCall('member', args); } private createNullExpression() { - return ts.factory.createCallExpression(ts.factory.createIdentifier('ExpressionUtils._null'), undefined, []); + return this.createExpressionUtilsCall('_null'); } private createBinaryExpression(expr: BinaryExpr) { - return ts.factory.createCallExpression(ts.factory.createIdentifier('ExpressionUtils.binary'), undefined, [ + return this.createExpressionUtilsCall('binary', [ this.createExpression(expr.left), this.createLiteralNode(expr.operator), this.createExpression(expr.right), @@ -1050,23 +1061,21 @@ export class TsSchemaGenerator { } private createUnaryExpression(expr: UnaryExpr) { - return ts.factory.createCallExpression(ts.factory.createIdentifier('ExpressionUtils.unary'), undefined, [ + return this.createExpressionUtilsCall('unary', [ this.createLiteralNode(expr.operator), this.createExpression(expr.operand), ]); } private createArrayExpression(expr: ArrayExpr): any { - return ts.factory.createCallExpression(ts.factory.createIdentifier('ExpressionUtils.array'), undefined, [ + return this.createExpressionUtilsCall('array', [ ts.factory.createArrayLiteralExpression(expr.items.map((item) => this.createExpression(item))), ]); } private createRefExpression(expr: ReferenceExpr): any { if (isDataField(expr.target.ref)) { - return ts.factory.createCallExpression(ts.factory.createIdentifier('ExpressionUtils.field'), undefined, [ - this.createLiteralNode(expr.target.$refText), - ]); + return this.createExpressionUtilsCall('field', [this.createLiteralNode(expr.target.$refText)]); } else if (isEnumField(expr.target.ref)) { return this.createLiteralExpression('StringLiteral', expr.target.$refText); } else { @@ -1075,7 +1084,7 @@ export class TsSchemaGenerator { } private createCallExpression(expr: InvocationExpr) { - return ts.factory.createCallExpression(ts.factory.createIdentifier('ExpressionUtils.call'), undefined, [ + return this.createExpressionUtilsCall('call', [ ts.factory.createStringLiteral(expr.function.$refText), ...(expr.args.length > 0 ? [ts.factory.createArrayLiteralExpression(expr.args.map((arg) => this.createExpression(arg.value)))] @@ -1085,21 +1094,11 @@ export class TsSchemaGenerator { private createLiteralExpression(type: string, value: string | boolean) { return match(type) - .with('BooleanLiteral', () => - ts.factory.createCallExpression(ts.factory.createIdentifier('ExpressionUtils.literal'), undefined, [ - this.createLiteralNode(value), - ]), - ) + .with('BooleanLiteral', () => this.createExpressionUtilsCall('literal', [this.createLiteralNode(value)])) .with('NumberLiteral', () => - ts.factory.createCallExpression(ts.factory.createIdentifier('ExpressionUtils.literal'), undefined, [ - ts.factory.createIdentifier(value as string), - ]), - ) - .with('StringLiteral', () => - ts.factory.createCallExpression(ts.factory.createIdentifier('ExpressionUtils.literal'), undefined, [ - this.createLiteralNode(value), - ]), + this.createExpressionUtilsCall('literal', [ts.factory.createIdentifier(value as string)]), ) + .with('StringLiteral', () => this.createExpressionUtilsCall('literal', [this.createLiteralNode(value)])) .otherwise(() => { throw new Error(`Unsupported literal type: ${type}`); }); diff --git a/packages/tanstack-query/package.json b/packages/tanstack-query/package.json index 82385793..86a9f6ae 100644 --- a/packages/tanstack-query/package.json +++ b/packages/tanstack-query/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/tanstack-query", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "description": "", "main": "index.js", "type": "module", diff --git a/packages/testtools/package.json b/packages/testtools/package.json index b951b9c2..81e91b7c 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "description": "ZenStack Test Tools", "type": "module", "scripts": { diff --git a/packages/typescript-config/package.json b/packages/typescript-config/package.json index 3437f29a..75109c2e 100644 --- a/packages/typescript-config/package.json +++ b/packages/typescript-config/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/typescript-config", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "private": true, "license": "MIT" } diff --git a/packages/vitest-config/package.json b/packages/vitest-config/package.json index 61c8e722..878e8fbd 100644 --- a/packages/vitest-config/package.json +++ b/packages/vitest-config/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/vitest-config", "type": "module", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "private": true, "license": "MIT", "exports": { diff --git a/packages/zod/package.json b/packages/zod/package.json index 7b8434fb..7bc82864 100644 --- a/packages/zod/package.json +++ b/packages/zod/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/zod", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "description": "", "type": "module", "main": "index.js", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 82877877..c339a56b 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -251,6 +251,9 @@ importers: '@zenstackhq/vitest-config': specifier: workspace:* version: link:../vitest-config + glob: + specifier: ^11.0.2 + version: 11.0.2 langium-cli: specifier: 'catalog:' version: 3.5.0 @@ -482,6 +485,31 @@ importers: specifier: workspace:* version: link:../../packages/vitest-config + tests/regression: + dependencies: + '@zenstackhq/testtools': + specifier: workspace:* + version: link:../../packages/testtools + devDependencies: + '@zenstackhq/cli': + specifier: workspace:* + version: link:../../packages/cli + '@zenstackhq/language': + specifier: workspace:* + version: link:../../packages/language + '@zenstackhq/runtime': + specifier: workspace:* + version: link:../../packages/runtime + '@zenstackhq/sdk': + specifier: workspace:* + version: link:../../packages/sdk + '@zenstackhq/typescript-config': + specifier: workspace:* + version: link:../../packages/typescript-config + '@zenstackhq/vitest-config': + specifier: workspace:* + version: link:../../packages/vitest-config + packages: '@chevrotain/cst-dts-gen@11.0.3': diff --git a/samples/blog/package.json b/samples/blog/package.json index 1168c32b..a30dd08e 100644 --- a/samples/blog/package.json +++ b/samples/blog/package.json @@ -1,6 +1,6 @@ { "name": "sample-blog", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "description": "", "main": "index.js", "scripts": { diff --git a/tests/e2e/package.json b/tests/e2e/package.json index d4e7e77c..e3129a69 100644 --- a/tests/e2e/package.json +++ b/tests/e2e/package.json @@ -1,6 +1,6 @@ { "name": "e2e", - "version": "3.0.0-beta.2", + "version": "3.0.0-beta.4", "private": true, "type": "module", "scripts": { diff --git a/tests/regression/generate.ts b/tests/regression/generate.ts new file mode 100644 index 00000000..86993b96 --- /dev/null +++ b/tests/regression/generate.ts @@ -0,0 +1,30 @@ +import { loadDocument } from '@zenstackhq/language'; +import { TsSchemaGenerator } from '@zenstackhq/sdk'; +import { glob } from 'glob'; +import fs from 'node:fs'; +import path from 'node:path'; +import { fileURLToPath } from 'node:url'; + +const dir = path.dirname(fileURLToPath(import.meta.url)); + +async function main() { + const zmodelFiles = glob.sync(path.resolve(dir, './test/**/*.zmodel')); + for (const file of zmodelFiles) { + console.log(`Generating TS schema for: ${file}`); + await generate(file); + } +} + +async function generate(schemaPath: string) { + const generator = new TsSchemaGenerator(); + const outputDir = path.dirname(schemaPath); + const tsPath = path.join(outputDir, 'schema.ts'); + const pluginModelFiles = glob.sync(path.resolve(dir, '../../packages/runtime/dist/**/plugin.zmodel')); + const result = await loadDocument(schemaPath, pluginModelFiles); + if (!result.success) { + throw new Error(`Failed to load schema from ${schemaPath}: ${result.errors}`); + } + await generator.generate(result.model, outputDir); +} + +main(); diff --git a/tests/regression/package.json b/tests/regression/package.json new file mode 100644 index 00000000..1d54ca4f --- /dev/null +++ b/tests/regression/package.json @@ -0,0 +1,21 @@ +{ + "name": "regression", + "version": "3.0.0-beta.3", + "private": true, + "type": "module", + "scripts": { + "generate": "tsx generate.ts", + "test": "pnpm generate && tsc && vitest run" + }, + "dependencies": { + "@zenstackhq/testtools": "workspace:*" + }, + "devDependencies": { + "@zenstackhq/cli": "workspace:*", + "@zenstackhq/sdk": "workspace:*", + "@zenstackhq/language": "workspace:*", + "@zenstackhq/runtime": "workspace:*", + "@zenstackhq/typescript-config": "workspace:*", + "@zenstackhq/vitest-config": "workspace:*" + } +} diff --git a/tests/regression/test/issue-204/input.ts b/tests/regression/test/issue-204/input.ts new file mode 100644 index 00000000..3916c070 --- /dev/null +++ b/tests/regression/test/issue-204/input.ts @@ -0,0 +1,30 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { type SchemaType as $Schema } from "./schema"; +import type { FindManyArgs as $FindManyArgs, FindUniqueArgs as $FindUniqueArgs, FindFirstArgs as $FindFirstArgs, CreateArgs as $CreateArgs, CreateManyArgs as $CreateManyArgs, CreateManyAndReturnArgs as $CreateManyAndReturnArgs, UpdateArgs as $UpdateArgs, UpdateManyArgs as $UpdateManyArgs, UpdateManyAndReturnArgs as $UpdateManyAndReturnArgs, UpsertArgs as $UpsertArgs, DeleteArgs as $DeleteArgs, DeleteManyArgs as $DeleteManyArgs, CountArgs as $CountArgs, AggregateArgs as $AggregateArgs, GroupByArgs as $GroupByArgs, WhereInput as $WhereInput, SelectInput as $SelectInput, IncludeInput as $IncludeInput, OmitInput as $OmitInput } from "@zenstackhq/runtime"; +import type { SimplifiedModelResult as $SimplifiedModelResult, SelectIncludeOmit as $SelectIncludeOmit } from "@zenstackhq/runtime"; +export type FooFindManyArgs = $FindManyArgs<$Schema, "Foo">; +export type FooFindUniqueArgs = $FindUniqueArgs<$Schema, "Foo">; +export type FooFindFirstArgs = $FindFirstArgs<$Schema, "Foo">; +export type FooCreateArgs = $CreateArgs<$Schema, "Foo">; +export type FooCreateManyArgs = $CreateManyArgs<$Schema, "Foo">; +export type FooCreateManyAndReturnArgs = $CreateManyAndReturnArgs<$Schema, "Foo">; +export type FooUpdateArgs = $UpdateArgs<$Schema, "Foo">; +export type FooUpdateManyArgs = $UpdateManyArgs<$Schema, "Foo">; +export type FooUpdateManyAndReturnArgs = $UpdateManyAndReturnArgs<$Schema, "Foo">; +export type FooUpsertArgs = $UpsertArgs<$Schema, "Foo">; +export type FooDeleteArgs = $DeleteArgs<$Schema, "Foo">; +export type FooDeleteManyArgs = $DeleteManyArgs<$Schema, "Foo">; +export type FooCountArgs = $CountArgs<$Schema, "Foo">; +export type FooAggregateArgs = $AggregateArgs<$Schema, "Foo">; +export type FooGroupByArgs = $GroupByArgs<$Schema, "Foo">; +export type FooWhereInput = $WhereInput<$Schema, "Foo">; +export type FooSelect = $SelectInput<$Schema, "Foo">; +export type FooInclude = $IncludeInput<$Schema, "Foo">; +export type FooOmit = $OmitInput<$Schema, "Foo">; +export type FooGetPayload> = $SimplifiedModelResult<$Schema, "Foo", Args>; diff --git a/tests/regression/test/issue-204/models.ts b/tests/regression/test/issue-204/models.ts new file mode 100644 index 00000000..c03d254e --- /dev/null +++ b/tests/regression/test/issue-204/models.ts @@ -0,0 +1,13 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { schema as $schema, type SchemaType as $Schema } from "./schema"; +import { type ModelResult as $ModelResult, type TypeDefResult as $TypeDefResult } from "@zenstackhq/runtime"; +export type Foo = $ModelResult<$Schema, "Foo">; +export type Configuration = $TypeDefResult<$Schema, "Configuration">; +export const ShirtColor = $schema.enums.ShirtColor; +export type ShirtColor = (typeof ShirtColor)[keyof typeof ShirtColor]; diff --git a/tests/regression/test/issue-204/regression.test.ts b/tests/regression/test/issue-204/regression.test.ts new file mode 100644 index 00000000..24a43e3b --- /dev/null +++ b/tests/regression/test/issue-204/regression.test.ts @@ -0,0 +1,11 @@ +import { describe, it } from 'vitest'; +import { type Configuration, ShirtColor } from './models'; + +describe('Issue 204 regression tests', () => { + it('tests issue 204', () => { + const config: Configuration = { teamColors: [ShirtColor.Black, ShirtColor.Blue] }; + console.log(config.teamColors?.[0]); + const config1: Configuration = {}; + console.log(config1); + }); +}); diff --git a/tests/regression/test/issue-204/regression.zmodel b/tests/regression/test/issue-204/regression.zmodel new file mode 100644 index 00000000..95309329 --- /dev/null +++ b/tests/regression/test/issue-204/regression.zmodel @@ -0,0 +1,21 @@ +datasource db { + provider = "sqlite" + url = "file:./dev.db" +} + +enum ShirtColor { + Black + White + Red + Green + Blue +} + +type Configuration { + teamColors ShirtColor[]? // This should be an optional array +} + +model Foo { + id Int @id + config Configuration @json +} diff --git a/tests/regression/test/issue-204/schema.ts b/tests/regression/test/issue-204/schema.ts new file mode 100644 index 00000000..b214a272 --- /dev/null +++ b/tests/regression/test/issue-204/schema.ts @@ -0,0 +1,59 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { type SchemaDef } from "@zenstackhq/runtime/schema"; +export const schema = { + provider: { + type: "sqlite" + }, + models: { + Foo: { + name: "Foo", + fields: { + id: { + name: "id", + type: "Int", + id: true, + attributes: [{ name: "@id" }] + }, + config: { + name: "config", + type: "Configuration", + attributes: [{ name: "@json" }] + } + }, + idFields: ["id"], + uniqueFields: { + id: { type: "Int" } + } + } + }, + typeDefs: { + Configuration: { + name: "Configuration", + fields: { + teamColors: { + name: "teamColors", + type: "ShirtColor", + optional: true, + array: true + } + } + } + }, + enums: { + ShirtColor: { + Black: "Black", + White: "White", + Red: "Red", + Green: "Green", + Blue: "Blue" + } + }, + plugins: {} +} as const satisfies SchemaDef; +export type SchemaType = typeof schema; diff --git a/tests/regression/tsconfig.json b/tests/regression/tsconfig.json new file mode 100644 index 00000000..f3a2dbcb --- /dev/null +++ b/tests/regression/tsconfig.json @@ -0,0 +1,7 @@ +{ + "extends": "@zenstackhq/typescript-config/base.json", + "compilerOptions": { + "noEmit": true + }, + "include": ["src/**/*.ts", "test/**/*.ts"] +} diff --git a/tests/regression/vitest.config.ts b/tests/regression/vitest.config.ts new file mode 100644 index 00000000..75a9f709 --- /dev/null +++ b/tests/regression/vitest.config.ts @@ -0,0 +1,4 @@ +import base from '@zenstackhq/vitest-config/base'; +import { defineConfig, mergeConfig } from 'vitest/config'; + +export default mergeConfig(base, defineConfig({}));