πŸ› fix: fix Azure OpenAI O1 models and refactor the Azure OpenAI implement #6079

Merged Β· 5 commits Β· Feb 15, 2025
2 changes: 0 additions & 2 deletions package.json
@@ -109,8 +109,6 @@
     "@aws-sdk/client-bedrock-runtime": "^3.723.0",
     "@aws-sdk/client-s3": "^3.723.0",
     "@aws-sdk/s3-request-presigner": "^3.723.0",
-    "@azure/core-rest-pipeline": "1.16.0",
-    "@azure/openai": "1.0.0-beta.12",
     "@cfworker/json-schema": "^4.1.0",
     "@clerk/localizations": "^3.9.6",
     "@clerk/nextjs": "^6.10.6",
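The two dropped packages are superseded by the `openai` dependency already in this list: the beta `@azure/openai` client is replaced by the `AzureOpenAI` entry point that ships with the `openai` v4 package. A sketch of the import migration, with names taken from the diffs below:

```ts
// Before: client and credential from the beta Azure SDK
// import { AzureKeyCredential, OpenAIClient } from '@azure/openai';

// After: the Azure-flavored client bundled with the openai package
import { AzureOpenAI } from 'openai';
```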
1 change: 1 addition & 0 deletions src/libs/agent-runtime/AgentRuntime.test.ts
@@ -107,6 +107,7 @@ describe('AgentRuntime', () => {
       const jwtPayload = {
         apiKey: 'user-azure-key',
         baseURL: 'user-azure-endpoint',
+        apiVersion: '2024-06-01',
       };
       const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.Azure, {
         azure: jwtPayload,
56 changes: 47 additions & 9 deletions src/libs/agent-runtime/azureOpenai/index.test.ts
@@ -1,9 +1,9 @@
 // @vitest-environment node
-import { AzureKeyCredential, OpenAIClient } from '@azure/openai';
-import OpenAI from 'openai';
+import { AzureOpenAI } from 'openai';
 import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
 
 import * as debugStreamModule from '../utils/debugStream';
+import * as openaiCompatibleFactoryModule from '../utils/openaiCompatibleFactory';
 import { LobeAzureOpenAI } from './index';
 
 const bizErrorType = 'ProviderBizError';
@@ -23,7 +23,7 @@ describe('LobeAzureOpenAI', () => {
     );
 
     // Use vi.spyOn to mock the streamChatCompletions method
-    vi.spyOn(instance['client'], 'streamChatCompletions').mockResolvedValue(
+    vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
       new ReadableStream() as any,
     );
   });
@@ -48,7 +48,7 @@ describe('LobeAzureOpenAI', () => {
 
       const instance = new LobeAzureOpenAI(endpoint, apikey, apiVersion);
 
-      expect(instance.client).toBeInstanceOf(OpenAIClient);
+      expect(instance.client).toBeInstanceOf(AzureOpenAI);
       expect(instance.baseURL).toBe(endpoint);
     });
   });
@@ -59,7 +59,7 @@ describe('LobeAzureOpenAI', () => {
       const mockStream = new ReadableStream();
       const mockResponse = Promise.resolve(mockStream);
 
-      (instance['client'].streamChatCompletions as Mock).mockResolvedValue(mockResponse);
+      (instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);
 
       // Act
       const result = await instance.chat({
@@ -164,7 +164,9 @@ describe('LobeAzureOpenAI', () => {
           controller.close();
         },
       });
-      vi.spyOn(instance['client'], 'streamChatCompletions').mockResolvedValue(mockStream as any);
+      vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
+        mockStream as any,
+      );
 
       const result = await instance.chat({
         stream: true,
@@ -204,6 +206,42 @@ describe('LobeAzureOpenAI', () => {
         ].map((item) => `${item}\n`),
       );
     });
+
+    it('should handle non-streaming response', async () => {
+      vi.spyOn(openaiCompatibleFactoryModule, 'transformResponseToStream').mockImplementation(
+        () => {
+          return new ReadableStream();
+        },
+      );
+      // Act
+      await instance.chat({
+        stream: false,
+        temperature: 0.6,
+        model: 'gpt-35-turbo-16k',
+        messages: [{ role: 'user', content: 'δ½ ε₯½' }],
+      });
+
+      // Assert
+      expect(openaiCompatibleFactoryModule.transformResponseToStream).toHaveBeenCalled();
+    });
+  });
+
+  it('should handle o1 series models without streaming', async () => {
+    vi.spyOn(openaiCompatibleFactoryModule, 'transformResponseToStream').mockImplementation(
+      () => {
+        return new ReadableStream();
+      },
+    );
+
+    // Act
+    await instance.chat({
+      temperature: 0.6,
+      model: 'o1-preview',
+      messages: [{ role: 'user', content: 'δ½ ε₯½' }],
+    });
+
+    // Assert
+    expect(openaiCompatibleFactoryModule.transformResponseToStream).toHaveBeenCalled();
+  });
+
   describe('Error', () => {
@@ -214,7 +252,7 @@ describe('LobeAzureOpenAI', () => {
         message: 'Deployment not found',
       };
 
-      (instance['client'].streamChatCompletions as Mock).mockRejectedValue(error);
+      (instance['client'].chat.completions.create as Mock).mockRejectedValue(error);
 
       // Act
       try {
@@ -242,7 +280,7 @@ describe('LobeAzureOpenAI', () => {
       // Arrange
       const genericError = new Error('Generic Error');
 
-      (instance['client'].streamChatCompletions as Mock).mockRejectedValue(genericError);
+      (instance['client'].chat.completions.create as Mock).mockRejectedValue(genericError);
 
       // Act
       try {
@@ -279,7 +317,7 @@ describe('LobeAzureOpenAI', () => {
         }) as any;
         mockDebugStream.toReadableStream = () => mockDebugStream;
 
-        (instance['client'].streamChatCompletions as Mock).mockResolvedValue({
+        (instance['client'].chat.completions.create as Mock).mockResolvedValue({
           tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
         });
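One practical payoff of the refactor, visible throughout these updated tests: both the streaming and non-streaming paths now funnel through `client.chat.completions.create`, so a single spy replaces the old `streamChatCompletions` stubs. A minimal vitest sketch of that pattern outside lobe-chat's suite (the credentials, model name, and stubbed return value are illustrative):

```ts
import { AzureOpenAI } from 'openai';
import { describe, expect, it, vi } from 'vitest';

describe('mocking the unified client surface', () => {
  it('stubs chat.completions.create once for all chat tests', async () => {
    const client = new AzureOpenAI({
      apiKey: 'test-key',
      apiVersion: '2024-06-01',
      endpoint: 'https://example.openai.azure.com',
    });

    // Streaming and non-streaming requests both go through this method,
    // so one spy covers every code path.
    const spy = vi
      .spyOn(client.chat.completions, 'create')
      .mockResolvedValue(new ReadableStream() as any);

    await client.chat.completions.create({ messages: [], model: 'gpt-35-turbo-16k' });

    expect(spy).toHaveBeenCalledOnce();
  });
});
```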
63 changes: 35 additions & 28 deletions src/libs/agent-runtime/azureOpenai/index.ts
@@ -1,55 +1,62 @@
-import {
-  AzureKeyCredential,
-  ChatRequestMessage,
-  GetChatCompletionsOptions,
-  OpenAIClient,
-} from '@azure/openai';
+import OpenAI, { AzureOpenAI } from 'openai';
+import type { Stream } from 'openai/streaming';
 
 import { LobeRuntimeAI } from '../BaseAI';
 import { AgentRuntimeErrorType } from '../error';
 import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
 import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
+import { transformResponseToStream } from '../utils/openaiCompatibleFactory';
 import { StreamingResponse } from '../utils/response';
-import { AzureOpenAIStream } from '../utils/streams';
+import { OpenAIStream } from '../utils/streams';
 
 export class LobeAzureOpenAI implements LobeRuntimeAI {
-  client: OpenAIClient;
+  client: AzureOpenAI;
 
   constructor(endpoint?: string, apikey?: string, apiVersion?: string) {
     if (!apikey || !endpoint)
       throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
 
-    this.client = new OpenAIClient(endpoint, new AzureKeyCredential(apikey), { apiVersion });
+    this.client = new AzureOpenAI({
+      apiKey: apikey,
+      apiVersion,
+      dangerouslyAllowBrowser: true,
+      endpoint,
+    });
 
     this.baseURL = endpoint;
   }
 
   baseURL: string;
 
   async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
-    // ============  1. preprocess messages   ============ //
-    const camelCasePayload = this.camelCaseKeys(payload);
-    const { messages, model, maxTokens = 2048, ...params } = camelCasePayload;
-
-    // ============  2. send api ============ //
-
+    const { messages, model, ...params } = payload;
+    // o1 series models on Azure OpenAI do not currently support streaming
+    const enableStreaming = model.startsWith('o1') ? false : (params.stream ?? true);
     try {
-      const response = await this.client.streamChatCompletions(
-        model,
-        messages as ChatRequestMessage[],
-        { ...params, abortSignal: options?.signal, maxTokens } as GetChatCompletionsOptions,
-      );
-
-      const [debug, prod] = response.tee();
-
-      if (process.env.DEBUG_AZURE_CHAT_COMPLETION === '1') {
-        debugStream(debug).catch(console.error);
-      }
-
-      return StreamingResponse(AzureOpenAIStream(prod, options?.callback), {
-        headers: options?.headers,
-      });
+      const response = await this.client.chat.completions.create({
+        messages: messages as OpenAI.ChatCompletionMessageParam[],
+        model,
+        ...params,
+        max_completion_tokens: 2048,
+        stream: enableStreaming,
+        tool_choice: params.tools ? 'auto' : undefined,
+      });
+      if (enableStreaming) {
+        const stream = response as Stream<OpenAI.ChatCompletionChunk>;
+        const [prod, debug] = stream.tee();
+        if (process.env.DEBUG_AZURE_CHAT_COMPLETION === '1') {
+          debugStream(debug.toReadableStream()).catch(console.error);
+        }
+        return StreamingResponse(OpenAIStream(prod, { callbacks: options?.callback }), {
+          headers: options?.headers,
+        });
+      } else {
+        const stream = transformResponseToStream(response as OpenAI.ChatCompletion);
+        return StreamingResponse(OpenAIStream(stream, { callbacks: options?.callback }), {
+          headers: options?.headers,
+        });
+      }
     } catch (e) {
       let error = e as { [key: string]: any; code: string; message: string };
 
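For readers who want the shape of the fix in isolation: o1-series deployments on Azure reject `stream: true`, so the runtime sends a plain request for those models and re-wraps the result as a stream, keeping a single return type for callers. A minimal standalone sketch under those assumptions — the `toStream` helper is a simplified stand-in for lobe-chat's `transformResponseToStream`, and the endpoint, key, and model handling are placeholders:

```ts
import OpenAI, { AzureOpenAI } from 'openai';

// Placeholder configuration; lobe-chat wires these from runtime options.
const client = new AzureOpenAI({
  apiKey: process.env.AZURE_OPENAI_API_KEY,
  apiVersion: '2024-06-01',
  endpoint: 'https://my-resource.openai.azure.com',
});

// Simplified stand-in for transformResponseToStream: emit the finished
// completion as a single chunk, then close the stream.
const toStream = (completion: OpenAI.ChatCompletion) =>
  new ReadableStream<OpenAI.ChatCompletion>({
    start(controller) {
      controller.enqueue(completion);
      controller.close();
    },
  });

async function chat(model: string, messages: OpenAI.ChatCompletionMessageParam[]) {
  // o1-series deployments reject stream: true, so fall back to a plain request.
  const enableStreaming = !model.startsWith('o1');

  const response = await client.chat.completions.create({
    messages,
    model,
    stream: enableStreaming,
  });

  // Streaming responses pass through; non-streaming ones are re-wrapped so
  // callers always consume a stream.
  return enableStreaming ? response : toStream(response as OpenAI.ChatCompletion);
}
```

The merged implementation additionally tees the live stream for debug logging and normalizes chunks through `OpenAIStream`, as the diff above shows.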