Skip to content

Commit 2d30689

Browse files
committed
Improves the relevances of codestral completion
1 parent 8043bf9 commit 2d30689

File tree

7 files changed

+91
-17
lines changed

7 files changed

+91
-17
lines changed

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
"@jupyterlab/settingregistry": "^4.2.0",
6363
"@langchain/core": "^0.3.13",
6464
"@langchain/mistralai": "^0.1.1",
65+
"@lumino/commands": "^2.1.2",
6566
"@lumino/coreutils": "^2.1.2",
6667
"@lumino/polling": "^2.1.2",
6768
"@lumino/signaling": "^2.1.2"

src/completion-provider.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ export class CompletionProvider implements IInlineCompletionProvider {
1616

1717
constructor(options: CompletionProvider.IOptions) {
1818
const { name, settings } = options;
19+
this._requestCompletion = options.requestCompletion;
1920
this.setCompleter(name, settings);
2021
}
2122

@@ -28,6 +29,9 @@ export class CompletionProvider implements IInlineCompletionProvider {
2829
setCompleter(name: string, settings: ReadonlyPartialJSONObject) {
2930
try {
3031
this._completer = getCompleter(name, settings);
32+
if (this._completer) {
33+
this._completer.requestCompletion = this._requestCompletion;
34+
}
3135
this._name = this._completer === null ? 'None' : name;
3236
} catch (e: any) {
3337
this._completer = null;
@@ -65,11 +69,13 @@ export class CompletionProvider implements IInlineCompletionProvider {
6569
}
6670

6771
private _name: string = 'None';
72+
private _requestCompletion: () => void;
6873
private _completer: IBaseCompleter | null = null;
6974
}
7075

7176
export namespace CompletionProvider {
7277
export interface IOptions extends BaseCompleter.IOptions {
7378
name: string;
79+
requestCompletion: () => void;
7480
}
7581
}

src/index.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ const aiProviderPlugin: JupyterFrontEndPlugin<IAIProvider> = {
100100
manager: ICompletionProviderManager,
101101
settingRegistry: ISettingRegistry
102102
): IAIProvider => {
103-
const aiProvider = new AIProvider({ completionProviderManager: manager });
103+
const aiProvider = new AIProvider({
104+
completionProviderManager: manager,
105+
requestCompletion: () => app.commands.execute('inline-completer:invoke')
106+
});
104107

105108
settingRegistry
106109
.load(aiProviderPlugin.id)

src/llm-models/base-completer.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ export interface IBaseCompleter {
1111
*/
1212
provider: LLM;
1313

14+
/**
15+
* The function to fetch a new completion.
16+
*/
17+
requestCompletion?: () => void;
18+
1419
/**
1520
* The fetch request for the LLM completer.
1621
*/

src/llm-models/codestral-completer.ts

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,70 @@ const INTERVAL = 1000;
1616

1717
export class CodestralCompleter implements IBaseCompleter {
1818
constructor(options: BaseCompleter.IOptions) {
19+
// this._requestCompletion = options.requestCompletion;
1920
this._mistralProvider = new MistralAI({ ...options.settings });
20-
this._throttler = new Throttler(async (data: CompletionRequest) => {
21-
const response = await this._mistralProvider.completionWithRetry(
22-
data,
23-
{},
24-
false
25-
);
26-
const items = response.choices.map((choice: any) => {
27-
return { insertText: choice.message.content as string };
28-
});
21+
this._throttler = new Throttler(
22+
async (data: CompletionRequest) => {
23+
this._invokedData = data;
24+
let fetchAgain = false;
2925

30-
return {
31-
items
32-
};
33-
}, INTERVAL);
26+
// Request completion.
27+
const response = await this._mistralProvider.completionWithRetry(
28+
data,
29+
{},
30+
false
31+
);
32+
33+
// Extract results of completion request.
34+
let items = response.choices.map((choice: any) => {
35+
return { insertText: choice.message.content as string };
36+
});
37+
38+
// Check if the prompt has changed during the request.
39+
if (this._invokedData.prompt !== this._currentData?.prompt) {
40+
// The current prompt does not include the invoked one, the result is
41+
// cancelled and a new completion will be requested.
42+
if (!this._currentData?.prompt.startsWith(this._invokedData.prompt)) {
43+
fetchAgain = true;
44+
items = [];
45+
} else {
46+
// Check if some results contain the current prompt, and return them if so,
47+
// otherwise request completion again.
48+
const newItems: { insertText: string }[] = [];
49+
items.forEach(item => {
50+
const result = this._invokedData!.prompt + item.insertText;
51+
if (result.startsWith(this._currentData!.prompt)) {
52+
const insertText = result.slice(
53+
this._currentData!.prompt.length
54+
);
55+
newItems.push({ insertText });
56+
}
57+
});
58+
if (newItems.length) {
59+
items = newItems;
60+
} else {
61+
fetchAgain = true;
62+
items = [];
63+
}
64+
}
65+
}
66+
return {
67+
items,
68+
fetchAgain
69+
};
70+
},
71+
{ limit: INTERVAL }
72+
);
3473
}
3574

3675
get provider(): LLM {
3776
return this._mistralProvider;
3877
}
3978

79+
set requestCompletion(value: () => void) {
80+
this._requestCompletion = value;
81+
}
82+
4083
async fetch(
4184
request: CompletionHandler.IRequest,
4285
context: IInlineCompletionContext
@@ -59,13 +102,23 @@ export class CodestralCompleter implements IBaseCompleter {
59102
};
60103

61104
try {
62-
return this._throttler.invoke(data);
105+
this._currentData = data;
106+
const completionResult = await this._throttler.invoke(data);
107+
if (completionResult.fetchAgain) {
108+
if (this._requestCompletion) {
109+
this._requestCompletion();
110+
}
111+
}
112+
return { items: completionResult.items };
63113
} catch (error) {
64114
console.error('Error fetching completions', error);
65115
return { items: [] };
66116
}
67117
}
68118

119+
private _requestCompletion?: () => void;
69120
private _throttler: Throttler;
70121
private _mistralProvider: MistralAI;
122+
private _invokedData: CompletionRequest | null = null;
123+
private _currentData: CompletionRequest | null = null;
71124
}

src/provider.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ export class AIProvider implements IAIProvider {
1212
constructor(options: AIProvider.IOptions) {
1313
this._completionProvider = new CompletionProvider({
1414
name: 'None',
15-
settings: {}
15+
settings: {},
16+
requestCompletion: options.requestCompletion
1617
});
1718
options.completionProviderManager.registerInlineProvider(
1819
this._completionProvider
@@ -103,6 +104,10 @@ export namespace AIProvider {
103104
* The completion provider manager in which register the LLM completer.
104105
*/
105106
completionProviderManager: ICompletionProviderManager;
107+
/**
108+
* The application commands registry.
109+
*/
110+
requestCompletion: () => void;
106111
}
107112

108113
/**

yarn.lock

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1740,7 +1740,7 @@ __metadata:
17401740
languageName: node
17411741
linkType: hard
17421742

1743-
"@lumino/commands@npm:^2.3.0":
1743+
"@lumino/commands@npm:^2.1.2, @lumino/commands@npm:^2.3.0":
17441744
version: 2.3.0
17451745
resolution: "@lumino/commands@npm:2.3.0"
17461746
dependencies:
@@ -4885,6 +4885,7 @@ __metadata:
48854885
"@jupyterlab/settingregistry": ^4.2.0
48864886
"@langchain/core": ^0.3.13
48874887
"@langchain/mistralai": ^0.1.1
4888+
"@lumino/commands": ^2.1.2
48884889
"@lumino/coreutils": ^2.1.2
48894890
"@lumino/polling": ^2.1.2
48904891
"@lumino/signaling": ^2.1.2

0 commit comments

Comments
 (0)