Skip to content

Commit

Permalink
Merge pull request #3852 from L-Mario564/zod-coerce
Browse files Browse the repository at this point in the history
Add type coercion support to `drizzle-zod`
  • Loading branch information
AndriiSherman authored Jan 23, 2025
2 parents 79df8c1 + 6201629 commit 3dff816
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 18 deletions.
43 changes: 30 additions & 13 deletions drizzle-zod/src/column.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ import type {
} from 'drizzle-orm/singlestore-core';
import type { SQLiteInteger, SQLiteReal, SQLiteText } from 'drizzle-orm/sqlite-core';
import { z } from 'zod';
import type { z as zod } from 'zod';
import { z as zod } from 'zod';
import { CONSTANTS } from './constants.ts';
import type { CreateSchemaFactoryOptions } from './schema.types.ts';
import { isColumnType, isWithEnum } from './utils.ts';
import type { Json } from './utils.ts';

Expand All @@ -65,7 +66,9 @@ export const jsonSchema: z.ZodType<Json> = z.lazy(() =>
);
export const bufferSchema: z.ZodType<Buffer> = z.custom<Buffer>((v) => v instanceof Buffer); // eslint-disable-line no-instanceof/no-instanceof

export function columnToSchema(column: Column, z: typeof zod): z.ZodTypeAny {
export function columnToSchema(column: Column, factory: CreateSchemaFactoryOptions | undefined): z.ZodTypeAny {
const z = factory?.zodInstance ?? zod;
const coerce = factory?.coerce ?? {};
let schema!: z.ZodTypeAny;

if (isWithEnum(column)) {
Expand Down Expand Up @@ -98,15 +101,15 @@ export function columnToSchema(column: Column, z: typeof zod): z.ZodTypeAny {
} else if (column.dataType === 'array') {
schema = z.array(z.any());
} else if (column.dataType === 'number') {
schema = numberColumnToSchema(column, z);
schema = numberColumnToSchema(column, z, coerce);
} else if (column.dataType === 'bigint') {
schema = bigintColumnToSchema(column, z);
schema = bigintColumnToSchema(column, z, coerce);
} else if (column.dataType === 'boolean') {
schema = z.boolean();
schema = coerce === true || coerce.boolean ? z.coerce.boolean() : z.boolean();
} else if (column.dataType === 'date') {
schema = z.date();
schema = coerce === true || coerce.date ? z.coerce.date() : z.date();
} else if (column.dataType === 'string') {
schema = stringColumnToSchema(column, z);
schema = stringColumnToSchema(column, z, coerce);
} else if (column.dataType === 'json') {
schema = jsonSchema;
} else if (column.dataType === 'custom') {
Expand All @@ -123,7 +126,11 @@ export function columnToSchema(column: Column, z: typeof zod): z.ZodTypeAny {
return schema;
}

function numberColumnToSchema(column: Column, z: typeof zod): z.ZodTypeAny {
function numberColumnToSchema(
column: Column,
z: typeof zod,
coerce: CreateSchemaFactoryOptions['coerce'],
): z.ZodTypeAny {
let unsigned = column.getSQLType().includes('unsigned');
let min!: number;
let max!: number;
Expand Down Expand Up @@ -223,19 +230,29 @@ function numberColumnToSchema(column: Column, z: typeof zod): z.ZodTypeAny {
max = Number.MAX_SAFE_INTEGER;
}

const schema = z.number().min(min).max(max);
let schema = coerce === true || coerce?.number ? z.coerce.number() : z.number();
schema = schema.min(min).max(max);
return integer ? schema.int() : schema;
}

function bigintColumnToSchema(column: Column, z: typeof zod): z.ZodTypeAny {
function bigintColumnToSchema(
column: Column,
z: typeof zod,
coerce: CreateSchemaFactoryOptions['coerce'],
): z.ZodTypeAny {
const unsigned = column.getSQLType().includes('unsigned');
const min = unsigned ? 0n : CONSTANTS.INT64_MIN;
const max = unsigned ? CONSTANTS.INT64_UNSIGNED_MAX : CONSTANTS.INT64_MAX;

return z.bigint().min(min).max(max);
const schema = coerce === true || coerce?.bigint ? z.coerce.bigint() : z.bigint();
return schema.min(min).max(max);
}

function stringColumnToSchema(column: Column, z: typeof zod): z.ZodTypeAny {
function stringColumnToSchema(
column: Column,
z: typeof zod,
coerce: CreateSchemaFactoryOptions['coerce'],
): z.ZodTypeAny {
if (isColumnType<PgUUID<ColumnBaseConfig<'string', 'PgUUID'>>>(column, ['PgUUID'])) {
return z.string().uuid();
}
Expand Down Expand Up @@ -278,7 +295,7 @@ function stringColumnToSchema(column: Column, z: typeof zod): z.ZodTypeAny {
max = column.dimensions;
}

let schema = z.string();
let schema = coerce === true || coerce?.string ? z.coerce.string() : z.string();
schema = regex ? schema.regex(regex) : schema;
return max && fixed ? schema.length(max) : max ? schema.max(max) : schema;
}
2 changes: 1 addition & 1 deletion drizzle-zod/src/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ function handleColumns(
}

const column = is(selected, Column) ? selected : undefined;
const schema = column ? columnToSchema(column, factory?.zodInstance ?? z) : z.any();
const schema = column ? columnToSchema(column, factory) : z.any();
const refined = typeof refinement === 'function' ? refinement(schema) : schema;

if (conditions.never(column)) {
Expand Down
1 change: 1 addition & 0 deletions drizzle-zod/src/schema.types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,5 @@ export interface CreateUpdateSchema {

export interface CreateSchemaFactoryOptions {
zodInstance?: any;
coerce?: Partial<Record<'bigint' | 'boolean' | 'date' | 'number' | 'string', true>> | true;
}
55 changes: 54 additions & 1 deletion drizzle-zod/tests/mysql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { test } from 'vitest';
import { z } from 'zod';
import { jsonSchema } from '~/column.ts';
import { CONSTANTS } from '~/constants.ts';
import { createInsertSchema, createSelectSchema, createUpdateSchema } from '../src';
import { createInsertSchema, createSchemaFactory, createSelectSchema, createUpdateSchema } from '../src';
import { Expect, expectSchemaShape } from './utils.ts';

const intSchema = z.number().min(CONSTANTS.INT32_MIN).max(CONSTANTS.INT32_MAX).int();
Expand Down Expand Up @@ -454,6 +454,59 @@ test('all data types', (t) => {
Expect<Equal<typeof result, typeof expected>>();
});

test('type coercion - all', (t) => {
const table = mysqlTable('test', ({
bigint,
boolean,
timestamp,
int,
text,
}) => ({
bigint: bigint({ mode: 'bigint' }).notNull(),
boolean: boolean().notNull(),
timestamp: timestamp().notNull(),
int: int().notNull(),
text: text().notNull(),
}));

const { createSelectSchema } = createSchemaFactory({
coerce: true,
});
const result = createSelectSchema(table);
const expected = z.object({
bigint: z.coerce.bigint().min(CONSTANTS.INT64_MIN).max(CONSTANTS.INT64_MAX),
boolean: z.coerce.boolean(),
timestamp: z.coerce.date(),
int: z.coerce.number().min(CONSTANTS.INT32_MIN).max(CONSTANTS.INT32_MAX).int(),
text: z.coerce.string().max(CONSTANTS.INT16_UNSIGNED_MAX),
});
expectSchemaShape(t, expected).from(result);
Expect<Equal<typeof result, typeof expected>>();
});

test('type coercion - mixed', (t) => {
const table = mysqlTable('test', ({
timestamp,
int,
}) => ({
timestamp: timestamp().notNull(),
int: int().notNull(),
}));

const { createSelectSchema } = createSchemaFactory({
coerce: {
date: true,
},
});
const result = createSelectSchema(table);
const expected = z.object({
timestamp: z.coerce.date(),
int: z.number().min(CONSTANTS.INT32_MIN).max(CONSTANTS.INT32_MAX).int(),
});
expectSchemaShape(t, expected).from(result);
Expect<Equal<typeof result, typeof expected>>();
});

/* Disallow unknown keys in table refinement - select */ {
const table = mysqlTable('test', { id: int() });
// @ts-expect-error
Expand Down
55 changes: 54 additions & 1 deletion drizzle-zod/tests/pg.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { test } from 'vitest';
import { z } from 'zod';
import { jsonSchema } from '~/column.ts';
import { CONSTANTS } from '~/constants.ts';
import { createInsertSchema, createSelectSchema, createUpdateSchema } from '../src';
import { createInsertSchema, createSchemaFactory, createSelectSchema, createUpdateSchema } from '../src';
import { Expect, expectEnumValues, expectSchemaShape } from './utils.ts';

const integerSchema = z.number().min(CONSTANTS.INT32_MIN).max(CONSTANTS.INT32_MAX).int();
Expand Down Expand Up @@ -500,6 +500,59 @@ test('all data types', (t) => {
Expect<Equal<typeof result, typeof expected>>();
});

test('type coercion - all', (t) => {
const table = pgTable('test', ({
bigint,
boolean,
timestamp,
integer,
text,
}) => ({
bigint: bigint({ mode: 'bigint' }).notNull(),
boolean: boolean().notNull(),
timestamp: timestamp().notNull(),
integer: integer().notNull(),
text: text().notNull(),
}));

const { createSelectSchema } = createSchemaFactory({
coerce: true,
});
const result = createSelectSchema(table);
const expected = z.object({
bigint: z.coerce.bigint().min(CONSTANTS.INT64_MIN).max(CONSTANTS.INT64_MAX),
boolean: z.coerce.boolean(),
timestamp: z.coerce.date(),
integer: z.coerce.number().min(CONSTANTS.INT32_MIN).max(CONSTANTS.INT32_MAX).int(),
text: z.coerce.string(),
});
expectSchemaShape(t, expected).from(result);
Expect<Equal<typeof result, typeof expected>>();
});

test('type coercion - mixed', (t) => {
const table = pgTable('test', ({
timestamp,
integer,
}) => ({
timestamp: timestamp().notNull(),
integer: integer().notNull(),
}));

const { createSelectSchema } = createSchemaFactory({
coerce: {
date: true,
},
});
const result = createSelectSchema(table);
const expected = z.object({
timestamp: z.coerce.date(),
integer: z.number().min(CONSTANTS.INT32_MIN).max(CONSTANTS.INT32_MAX).int(),
});
expectSchemaShape(t, expected).from(result);
Expect<Equal<typeof result, typeof expected>>();
});

/* Disallow unknown keys in table refinement - select */ {
const table = pgTable('test', { id: integer() });
// @ts-expect-error
Expand Down
55 changes: 54 additions & 1 deletion drizzle-zod/tests/singlestore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { test } from 'vitest';
import { z } from 'zod';
import { jsonSchema } from '~/column.ts';
import { CONSTANTS } from '~/constants.ts';
import { createInsertSchema, createSelectSchema, createUpdateSchema } from '../src';
import { createInsertSchema, createSchemaFactory, createSelectSchema, createUpdateSchema } from '../src';
import { Expect, expectSchemaShape } from './utils.ts';

const intSchema = z.number().min(CONSTANTS.INT32_MIN).max(CONSTANTS.INT32_MAX).int();
Expand Down Expand Up @@ -456,6 +456,59 @@ test('all data types', (t) => {
Expect<Equal<typeof result, typeof expected>>();
});

test('type coercion - all', (t) => {
const table = singlestoreTable('test', ({
bigint,
boolean,
timestamp,
int,
text,
}) => ({
bigint: bigint({ mode: 'bigint' }).notNull(),
boolean: boolean().notNull(),
timestamp: timestamp().notNull(),
int: int().notNull(),
text: text().notNull(),
}));

const { createSelectSchema } = createSchemaFactory({
coerce: true,
});
const result = createSelectSchema(table);
const expected = z.object({
bigint: z.coerce.bigint().min(CONSTANTS.INT64_MIN).max(CONSTANTS.INT64_MAX),
boolean: z.coerce.boolean(),
timestamp: z.coerce.date(),
int: z.coerce.number().min(CONSTANTS.INT32_MIN).max(CONSTANTS.INT32_MAX).int(),
text: z.coerce.string().max(CONSTANTS.INT16_UNSIGNED_MAX),
});
expectSchemaShape(t, expected).from(result);
Expect<Equal<typeof result, typeof expected>>();
});

test('type coercion - mixed', (t) => {
const table = singlestoreTable('test', ({
timestamp,
int,
}) => ({
timestamp: timestamp().notNull(),
int: int().notNull(),
}));

const { createSelectSchema } = createSchemaFactory({
coerce: {
date: true,
},
});
const result = createSelectSchema(table);
const expected = z.object({
timestamp: z.coerce.date(),
int: z.number().min(CONSTANTS.INT32_MIN).max(CONSTANTS.INT32_MAX).int(),
});
expectSchemaShape(t, expected).from(result);
Expect<Equal<typeof result, typeof expected>>();
});

/* Disallow unknown keys in table refinement - select */ {
const table = singlestoreTable('test', { id: int() });
// @ts-expect-error
Expand Down
52 changes: 51 additions & 1 deletion drizzle-zod/tests/sqlite.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { test } from 'vitest';
import { z } from 'zod';
import { bufferSchema, jsonSchema } from '~/column.ts';
import { CONSTANTS } from '~/constants.ts';
import { createInsertSchema, createSelectSchema, createUpdateSchema } from '../src';
import { createInsertSchema, createSchemaFactory, createSelectSchema, createUpdateSchema } from '../src';
import { Expect, expectSchemaShape } from './utils.ts';

const intSchema = z.number().min(Number.MIN_SAFE_INTEGER).max(Number.MAX_SAFE_INTEGER).int();
Expand Down Expand Up @@ -350,6 +350,56 @@ test('all data types', (t) => {
Expect<Equal<typeof result, typeof expected>>();
});

test('type coercion - all', (t) => {
const table = sqliteTable('test', ({
blob,
integer,
text,
}) => ({
blob: blob({ mode: 'bigint' }).notNull(),
integer1: integer({ mode: 'boolean' }).notNull(),
integer2: integer({ mode: 'timestamp' }).notNull(),
integer3: integer().notNull(),
text: text().notNull(),
}));

const { createSelectSchema } = createSchemaFactory({
coerce: true,
});
const result = createSelectSchema(table);
const expected = z.object({
blob: z.coerce.bigint().min(CONSTANTS.INT64_MIN).max(CONSTANTS.INT64_MAX),
integer1: z.coerce.boolean(),
integer2: z.coerce.date(),
integer3: z.coerce.number().min(Number.MIN_SAFE_INTEGER).max(Number.MAX_SAFE_INTEGER).int(),
text: z.coerce.string(),
});
expectSchemaShape(t, expected).from(result);
Expect<Equal<typeof result, typeof expected>>();
});

test('type coercion - mixed', (t) => {
const table = sqliteTable('test', ({
integer,
}) => ({
integer1: integer({ mode: 'timestamp' }).notNull(),
integer2: integer().notNull(),
}));

const { createSelectSchema } = createSchemaFactory({
coerce: {
date: true,
},
});
const result = createSelectSchema(table);
const expected = z.object({
integer1: z.coerce.date(),
integer2: z.number().min(Number.MIN_SAFE_INTEGER).max(Number.MAX_SAFE_INTEGER).int(),
});
expectSchemaShape(t, expected).from(result);
Expect<Equal<typeof result, typeof expected>>();
});

/* Disallow unknown keys in table refinement - select */ {
const table = sqliteTable('test', { id: int() });
// @ts-expect-error
Expand Down
1 change: 1 addition & 0 deletions drizzle-zod/tests/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export function expectSchemaShape<T extends z.ZodObject<z.ZodRawShape>>(t: TaskC
for (const key of Object.keys(actual.shape)) {
expect(actual.shape[key]!._def.typeName).toStrictEqual(expected.shape[key]?._def.typeName);
expect(actual.shape[key]!._def?.checks).toEqual(expected.shape[key]?._def?.checks);
expect(actual.shape[key]!._def?.coerce).toEqual(expected.shape[key]?._def?.coerce);
if (actual.shape[key]?._def.typeName === 'ZodOptional') {
expect(actual.shape[key]!._def.innerType._def.typeName).toStrictEqual(
actual.shape[key]!._def.innerType._def.typeName,
Expand Down

0 comments on commit 3dff816

Please sign in to comment.