🐛 fix: fix Azure OpenAI O1 models and refactor the Azure OpenAI implementation (#6079)

* 🐛 fix: fix Azure OpenAI O1 models

* fix: remove @azure/core-rest-pipeline

* fix: fix import & type assertion

* fix: fix import
faceCutWall authored Feb 15, 2025
1 parent 40df6c2 commit 6a89a8c
Showing 8 changed files with 86 additions and 660 deletions.
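
The heart of the refactor: the beta @azure/openai SDK (since deprecated upstream in favor of the official openai package) is dropped, and Azure traffic now flows through the same chat.completions.create surface as the other OpenAI-compatible providers. Condensed from the index.ts diff below:

    // Before: beta Azure SDK with its own client type and streaming method
    this.client = new OpenAIClient(endpoint, new AzureKeyCredential(apikey), { apiVersion });
    const response = await this.client.streamChatCompletions(model, messages, options);

    // After: the AzureOpenAI class shipped with the official openai package
    this.client = new AzureOpenAI({ apiKey: apikey, apiVersion, dangerouslyAllowBrowser: true, endpoint });
    const response = await this.client.chat.completions.create({ messages, model, stream: enableStreaming, ...params });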
package.json (2 changes: 0 additions & 2 deletions)

@@ -109,8 +109,6 @@
     "@aws-sdk/client-bedrock-runtime": "^3.723.0",
     "@aws-sdk/client-s3": "^3.723.0",
     "@aws-sdk/s3-request-presigner": "^3.723.0",
-    "@azure/core-rest-pipeline": "1.16.0",
-    "@azure/openai": "1.0.0-beta.12",
     "@cfworker/json-schema": "^4.1.0",
     "@clerk/localizations": "^3.9.6",
     "@clerk/nextjs": "^6.10.6",
src/libs/agent-runtime/AgentRuntime.test.ts (1 change: 1 addition & 0 deletions)

@@ -107,6 +107,7 @@ describe('AgentRuntime', () => {
       const jwtPayload = {
         apiKey: 'user-azure-key',
         baseURL: 'user-azure-endpoint',
+        apiVersion: '2024-06-01',
       };
       const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.Azure, {
         azure: jwtPayload,
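
The added apiVersion is load-bearing: the rewritten runtime (see the index.ts diff below) forwards it directly to the AzureOpenAI constructor, and the openai SDK throws at construction time when no api-version is available, either as the apiVersion option or the OPENAI_API_VERSION environment variable. Illustrative of that SDK behavior (not code from this repo):

    import { AzureOpenAI } from 'openai';

    new AzureOpenAI({ apiKey: 'key', endpoint: 'https://res.openai.azure.com' });
    // -> throws: an apiVersion option (or OPENAI_API_VERSION) is required for Azure

    new AzureOpenAI({ apiKey: 'key', apiVersion: '2024-06-01', endpoint: 'https://res.openai.azure.com' });
    // -> OK; requests target {endpoint}/openai/deployments/{model}/...?api-version=2024-06-01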
src/libs/agent-runtime/azureOpenai/index.test.ts (56 changes: 47 additions & 9 deletions)

@@ -1,9 +1,9 @@
 // @vitest-environment node
-import { AzureKeyCredential, OpenAIClient } from '@azure/openai';
-import OpenAI from 'openai';
+import { AzureOpenAI } from 'openai';
 import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
 
 import * as debugStreamModule from '../utils/debugStream';
+import * as openaiCompatibleFactoryModule from '../utils/openaiCompatibleFactory';
 import { LobeAzureOpenAI } from './index';
 
 const bizErrorType = 'ProviderBizError';
@@ -23,7 +23,7 @@ describe('LobeAzureOpenAI', () => {
     );
 
     // Use vi.spyOn to mock the streamChatCompletions method
-    vi.spyOn(instance['client'], 'streamChatCompletions').mockResolvedValue(
+    vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
       new ReadableStream() as any,
     );
   });
@@ -48,7 +48,7 @@ describe('LobeAzureOpenAI', () => {
 
      const instance = new LobeAzureOpenAI(endpoint, apikey, apiVersion);
 
-      expect(instance.client).toBeInstanceOf(OpenAIClient);
+      expect(instance.client).toBeInstanceOf(AzureOpenAI);
      expect(instance.baseURL).toBe(endpoint);
    });
  });
@@ -59,7 +59,7 @@ describe('LobeAzureOpenAI', () => {
      const mockStream = new ReadableStream();
      const mockResponse = Promise.resolve(mockStream);
 
-      (instance['client'].streamChatCompletions as Mock).mockResolvedValue(mockResponse);
+      (instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);
 
      // Act
      const result = await instance.chat({
@@ -164,7 +164,9 @@ describe('LobeAzureOpenAI', () => {
          controller.close();
        },
      });
-      vi.spyOn(instance['client'], 'streamChatCompletions').mockResolvedValue(mockStream as any);
+      vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
+        mockStream as any,
+      );
 
      const result = await instance.chat({
        stream: true,
@@ -204,6 +206,42 @@ describe('LobeAzureOpenAI', () => {
        ].map((item) => `${item}\n`),
      );
    });
+
+    it('should handle non-streaming response', async () => {
+      vi.spyOn(openaiCompatibleFactoryModule, 'transformResponseToStream').mockImplementation(
+        () => {
+          return new ReadableStream();
+        },
+      );
+      // Act
+      await instance.chat({
+        stream: false,
+        temperature: 0.6,
+        model: 'gpt-35-turbo-16k',
+        messages: [{ role: 'user', content: '你好' }],
+      });
+
+      // Assert
+      expect(openaiCompatibleFactoryModule.transformResponseToStream).toHaveBeenCalled();
+    });
+  });
+
+  it('should handle o1 series models without streaming', async () => {
+    vi.spyOn(openaiCompatibleFactoryModule, 'transformResponseToStream').mockImplementation(
+      () => {
+        return new ReadableStream();
+      },
+    );
+
+    // Act
+    await instance.chat({
+      temperature: 0.6,
+      model: 'o1-preview',
+      messages: [{ role: 'user', content: '你好' }],
+    });
+
+    // Assert
+    expect(openaiCompatibleFactoryModule.transformResponseToStream).toHaveBeenCalled();
+  });
 
  describe('Error', () => {
@@ -214,7 +252,7 @@
        message: 'Deployment not found',
      };
 
-      (instance['client'].streamChatCompletions as Mock).mockRejectedValue(error);
+      (instance['client'].chat.completions.create as Mock).mockRejectedValue(error);
 
      // Act
      try {
@@ -242,7 +280,7 @@ describe('LobeAzureOpenAI', () => {
      // Arrange
      const genericError = new Error('Generic Error');
 
-      (instance['client'].streamChatCompletions as Mock).mockRejectedValue(genericError);
+      (instance['client'].chat.completions.create as Mock).mockRejectedValue(genericError);
 
      // Act
      try {
@@ -279,7 +317,7 @@
      }) as any;
      mockDebugStream.toReadableStream = () => mockDebugStream;
 
-      (instance['client'].streamChatCompletions as Mock).mockResolvedValue({
+      (instance['client'].chat.completions.create as Mock).mockResolvedValue({
        tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
      });
 
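Note what the new o1 test actually proves: the chat call passes no stream flag, yet the assertion expects the non-streaming transformResponseToStream path to run. That only holds because the runtime forces streaming off for o1 models. The gate, copied from the index.ts diff below, with its behavior spelled out:

    // model 'o1-preview',       stream omitted -> enableStreaming = false (forced off)
    // model 'o1-mini',          stream: true   -> enableStreaming = false (forced off)
    // model 'gpt-35-turbo-16k', stream omitted -> enableStreaming = true  (streaming is the default)
    const enableStreaming = model.startsWith('o1') ? false : (params.stream ?? true);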
src/libs/agent-runtime/azureOpenai/index.ts (63 changes: 35 additions & 28 deletions)

@@ -1,55 +1,62 @@
-import {
-  AzureKeyCredential,
-  ChatRequestMessage,
-  GetChatCompletionsOptions,
-  OpenAIClient,
-} from '@azure/openai';
+import OpenAI, { AzureOpenAI } from 'openai';
+import type { Stream } from 'openai/streaming';
 
 import { LobeRuntimeAI } from '../BaseAI';
 import { AgentRuntimeErrorType } from '../error';
 import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
 import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
+import { transformResponseToStream } from '../utils/openaiCompatibleFactory';
 import { StreamingResponse } from '../utils/response';
-import { AzureOpenAIStream } from '../utils/streams';
+import { OpenAIStream } from '../utils/streams';
 
 export class LobeAzureOpenAI implements LobeRuntimeAI {
-  client: OpenAIClient;
+  client: AzureOpenAI;
 
   constructor(endpoint?: string, apikey?: string, apiVersion?: string) {
     if (!apikey || !endpoint)
       throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
 
-    this.client = new OpenAIClient(endpoint, new AzureKeyCredential(apikey), { apiVersion });
+    this.client = new AzureOpenAI({
+      apiKey: apikey,
+      apiVersion,
+      dangerouslyAllowBrowser: true,
+      endpoint,
+    });
 
     this.baseURL = endpoint;
   }
 
   baseURL: string;
 
   async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
-    // ============ 1. preprocess messages ============ //
-    const camelCasePayload = this.camelCaseKeys(payload);
-    const { messages, model, maxTokens = 2048, ...params } = camelCasePayload;
-
-    // ============ 2. send api ============ //
-
+    const { messages, model, ...params } = payload;
+    // o1 series models on Azure OpenAI does not support streaming currently
+    const enableStreaming = model.startsWith('o1') ? false : (params.stream ?? true);
     try {
-      const response = await this.client.streamChatCompletions(
-        model,
-        messages as ChatRequestMessage[],
-        { ...params, abortSignal: options?.signal, maxTokens } as GetChatCompletionsOptions,
-      );
-
-      const [debug, prod] = response.tee();
-
-      if (process.env.DEBUG_AZURE_CHAT_COMPLETION === '1') {
-        debugStream(debug).catch(console.error);
-      }
-
-      return StreamingResponse(AzureOpenAIStream(prod, options?.callback), {
-        headers: options?.headers,
+      const response = await this.client.chat.completions.create({
+        messages: messages as OpenAI.ChatCompletionMessageParam[],
+        model,
+        ...params,
+        max_completion_tokens: 2048,
+        stream: enableStreaming,
+        tool_choice: params.tools ? 'auto' : undefined,
       });
+      if (enableStreaming) {
+        const stream = response as Stream<OpenAI.ChatCompletionChunk>;
+        const [prod, debug] = stream.tee();
+        if (process.env.DEBUG_AZURE_CHAT_COMPLETION === '1') {
+          debugStream(debug.toReadableStream()).catch(console.error);
+        }
+        return StreamingResponse(OpenAIStream(prod, { callbacks: options?.callback }), {
+          headers: options?.headers,
+        });
+      } else {
+        const stream = transformResponseToStream(response as OpenAI.ChatCompletion);
+        return StreamingResponse(OpenAIStream(stream, { callbacks: options?.callback }), {
+          headers: options?.headers,
+        });
+      }
     } catch (e) {
       let error = e as { [key: string]: any; code: string; message: string };
 
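transformResponseToStream comes from ../utils/openaiCompatibleFactory, and its implementation is not part of this diff. For orientation only, here is a minimal sketch of what such a helper plausibly does, assuming all it must satisfy is OpenAIStream's expectation of ChatCompletionChunk objects (names and shape are illustrative, not the project's actual code):

    import OpenAI from 'openai';

    // Illustrative sketch, not the real implementation: wrap a completed
    // (non-streaming) ChatCompletion in a one-chunk ReadableStream so the
    // same streaming pipeline can consume both response shapes.
    const transformResponseToStreamSketch = (response: OpenAI.ChatCompletion) =>
      new ReadableStream<OpenAI.ChatCompletionChunk>({
        start(controller) {
          controller.enqueue({
            choices: response.choices.map((choice) => ({
              // Re-emit each choice's full message as a single delta.
              delta: {
                content: choice.message.content,
                role: choice.message.role,
              },
              finish_reason: choice.finish_reason,
              index: choice.index,
            })),
            created: response.created,
            id: response.id,
            model: response.model,
            object: 'chat.completion.chunk',
          });
          // One chunk carrying the whole message, then end the stream.
          controller.close();
        },
      });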
[Diffs for the remaining 4 of the 8 changed files are not shown in this view.]
