Add token streaming for text generation #130

Merged
merged 8 commits on Apr 7, 2023

Changes from all commits

3 changes: 2 additions & 1 deletion packages/inference/.eslintignore
@@ -1,2 +1,3 @@
dist
tapes.json
tapes.json
src/vendor
3 changes: 2 additions & 1 deletion packages/inference/.prettierignore
@@ -2,4 +2,5 @@ pnpm-lock.yaml
# In order to avoid code samples having tabs, as they don't display well on npm
README.md
dist
tapes.json
tapes.json
src/vendor
7 changes: 7 additions & 0 deletions packages/inference/README.md
@@ -76,6 +76,13 @@ await hf.textGeneration({
inputs: 'The answer to the universe is'
})

for await (const output of hf.textGenerationStream({
model: "google/flan-t5-xxl",
inputs: 'repeat "one two three four"'
})) {
console.log(output.token.text, output.generated_text);
}

await hf.tokenClassification({
model: 'dbmdz/bert-large-cased-finetuned-conll03-english',
inputs: 'My name is Sarah Jessica Parker but you can call me Jessica'
187 changes: 182 additions & 5 deletions packages/inference/src/HfInference.ts
@@ -1,4 +1,8 @@
import { toArray } from "./utils/to-array";
import type { EventSourceMessage } from "./vendor/fetch-event-source/parse";
import { getLines, getMessages } from "./vendor/fetch-event-source/parse";

const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co/models/";

export interface Options {
/**
@@ -223,6 +227,86 @@ export interface TextGenerationReturn {
generated_text: string;
}

export interface TextGenerationStreamToken {
/** Token ID from the model tokenizer */
id: number;
/** Token text */
text: string;
/** Logprob */
logprob: number;
/**
* Is the token a special token
* Can be used to ignore tokens when concatenating
*/
special: boolean;
}

export interface TextGenerationStreamPrefillToken {
/** Token ID from the model tokenizer */
id: number;
/** Token text */
text: string;
/**
* Logprob
* Optional since the logprob of the first token cannot be computed
*/
logprob?: number;
}

export interface TextGenerationStreamBestOfSequence {
/** Generated text */
generated_text: string;
/** Generation finish reason */
finish_reason: TextGenerationStreamFinishReason;
/** Number of generated tokens */
generated_tokens: number;
/** Sampling seed if sampling was activated */
seed?: number;
/** Prompt tokens */
prefill: TextGenerationStreamPrefillToken[];
/** Generated tokens */
tokens: TextGenerationStreamToken[];
}

export enum TextGenerationStreamFinishReason {
/** number of generated tokens == `max_new_tokens` */
Length = "length",
/** the model generated its end of sequence token */
EndOfSequenceToken = "eos_token",
/** the model generated a text included in `stop_sequences` */
StopSequence = "stop_sequence",
}

export interface TextGenerationStreamDetails {
/** Generation finish reason */
finish_reason: TextGenerationStreamFinishReason;
/** Number of generated tokens */
generated_tokens: number;
/** Sampling seed if sampling was activated */
seed?: number;
/** Prompt tokens */
prefill: TextGenerationStreamPrefillToken[];
/** Generated tokens */
tokens: TextGenerationStreamToken[];
/** Additional sequences when using the `best_of` parameter */
best_of_sequences?: TextGenerationStreamBestOfSequence[];
}

export interface TextGenerationStreamReturn {
/** Generated token, one at a time */
token: TextGenerationStreamToken;
/**
* Complete generated text
* Only available when the generation is finished
*/
generated_text?: string;
/**
* Generation details
* Only available when the generation is finished
*/
details?: TextGenerationStreamDetails;
}

export type TokenClassificationArgs = Args & {
/**
* A string to be classified
@@ -615,6 +699,16 @@ export class HfInference {
return res?.[0];
}

/**
* Use to continue text from a prompt. Same as `textGeneration` but returns a generator that can be read one token at a time
*/
public async *textGenerationStream(
args: TextGenerationArgs,
options?: Options
): AsyncGenerator<TextGenerationStreamReturn> {
yield* this.streamingRequest<TextGenerationStreamReturn>(args, options);
}
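
For reviewers, a minimal consumption sketch of the new generator (not part of this diff; the prompt and model are illustrative, and it assumes an `HfInference` instance named `hf` inside an async context). Intermediate chunks only carry `token`; `generated_text` and `details` are set on the final chunk, matching `TextGenerationStreamReturn` above.

```ts
// Sketch: stream tokens one at a time, accumulate text ourselves, and read the
// aggregate fields from the final chunk.
let acc = "";
for await (const chunk of hf.textGenerationStream({
  model: "google/flan-t5-xxl",
  inputs: "The answer to the universe is",
})) {
  if (!chunk.token.special) {
    acc += chunk.token.text; // every chunk carries the newly generated token
  }
  if (chunk.generated_text !== undefined) {
    // only the last chunk includes the full text and generation details
    console.log(chunk.generated_text, chunk.details?.finish_reason);
  }
}
```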

/**
* Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english
*/
@@ -834,15 +928,21 @@ export class HfInference {
return res;
}

public async request<T>(
args: Args & { data?: Blob | ArrayBuffer },
/**
* Helper that prepares request arguments
*/
private makeRequestOptions(
args: Args & {
data?: Blob | ArrayBuffer;
stream?: boolean;
},
options?: Options & {
binary?: boolean;
blob?: boolean;
/** For internal HF use, which is why it's not exposed in {@link Options} */
includeCredentials?: boolean;
}
): Promise<T> {
) {
const mergedOptions = { ...this.defaultOptions, ...options };
const { model, ...otherArgs } = args;

@@ -867,7 +967,8 @@
}
}

const response = await fetch(`https://api-inference.huggingface.co/models/${model}`, {
const url = `${HF_INFERENCE_API_BASE_URL}${model}`;
const info: RequestInit = {
headers,
method: "POST",
body: options?.binary
@@ -877,7 +978,22 @@
options: mergedOptions,
}),
credentials: options?.includeCredentials ? "include" : "same-origin",
});
};

return { url, info, mergedOptions };
}

public async request<T>(
args: Args & { data?: Blob | ArrayBuffer },
options?: Options & {
binary?: boolean;
blob?: boolean;
/** For internal HF use, which is why it's not exposed in {@link Options} */
includeCredentials?: boolean;
}
): Promise<T> {
const { url, info, mergedOptions } = this.makeRequestOptions(args, options);
const response = await fetch(url, info);

if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
return this.request(args, {
@@ -899,4 +1015,65 @@
}
return output;
}

/**
* Make a request that uses server-sent events and returns the response as a generator
*/
public async *streamingRequest<T>(
args: Args & { data?: Blob | ArrayBuffer },
options?: Options & {
binary?: boolean;
blob?: boolean;
/** For internal HF use, which is why it's not exposed in {@link Options} */
includeCredentials?: boolean;
}
): AsyncGenerator<T> {
const { url, info, mergedOptions } = this.makeRequestOptions({ ...args, stream: true }, options);
const response = await fetch(url, info);

if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
return yield* this.streamingRequest(args, {
...mergedOptions,
wait_for_model: true,
});
}
if (!response.ok) {
throw new Error(`Server response contains error: ${response.status}`);
}
if (response.headers.get("content-type") !== "text/event-stream") {
throw new Error(`Server does not support event stream content type`);
}

const reader = response.body.getReader();
const events: EventSourceMessage[] = [];

const onEvent = (event: EventSourceMessage) => {
// accumulate events in array
events.push(event);
};

const onChunk = getLines(
getMessages(
() => {},
() => {},
onEvent
)
);

try {
while (true) {
const { done, value } = await reader.read();
if (done) return;
onChunk(value);
while (events.length > 0) {
const event = events.shift();
if (event.data.length > 0) {
yield JSON.parse(event.data) as T;
}
}
}
} finally {
reader.releaseLock();
}
}
}
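
As a reviewer note (not part of this diff), the SSE plumbing in `streamingRequest` composes the two vendored fetch-event-source helpers; the following standalone sketch drives that same composition directly, under the assumption that the helpers behave exactly as they are used above.

```ts
import type { EventSourceMessage } from "./vendor/fetch-event-source/parse";
import { getLines, getMessages } from "./vendor/fetch-event-source/parse";

// getLines() splits raw bytes into SSE lines; getMessages() assembles those lines
// into EventSourceMessage objects and hands each one to the onMessage callback.
const events: EventSourceMessage[] = [];
const onChunk = getLines(
  getMessages(
    () => {}, // onId: not needed here
    () => {}, // onRetry: not needed here
    (msg) => events.push(msg)
  )
);

// Feed one complete SSE frame; the trailing blank line terminates the event.
onChunk(new TextEncoder().encode('data: {"token":{"text":"Hello"}}\n\n'));
console.log(events[0]?.data); // -> '{"token":{"text":"Hello"}}'
```

This is the same accumulate-then-drain pattern used in the read loop above: raw bytes go in through `onChunk`, parsed events come out through the callback, and the generator yields each event's JSON payload.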