feat(client): add streaming support (#56)

* feat(client): add streaming support

* fix: variable type definition

* chore: add docs

* chore: bump to client 0.9.0
This commit is contained in:
Daniel Rochetti 2024-03-06 08:44:04 -08:00 committed by GitHub
parent 6d112c89c5
commit 335b817e9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 379 additions and 26 deletions

View File

@ -0,0 +1,87 @@
'use client';
import * as fal from '@fal-ai/serverless-client';
import { useState } from 'react';
// Point the fal client at this app's `/api/fal/proxy` route.
const falProxyConfig = { proxyUrl: '/api/fal/proxy' };
fal.config(falProxyConfig);
/**
 * Input payload this demo sends to the `fal-ai/llavav15-13b` app.
 */
type LlavaInput = {
// Text prompt; the demo asks a question about the image.
prompt: string;
// URL of the image the prompt refers to.
image_url: string;
// Optional generation settings (the demo passes 100 / 0.2 / 1).
// NOTE(review): exact semantics are defined by the remote app — confirm there.
max_new_tokens?: number;
temperature?: number;
top_p?: number;
};
/**
 * Shape of each streamed chunk, and of the final result, as consumed by
 * this demo.
 */
type LlavaOutput = {
// Text rendered in the "Answer" box.
output: string;
// NOTE(review): presumably true while the answer is still incomplete — confirm.
partial: boolean;
// Token counts reported alongside the output (not used by this demo).
stats: {
num_input_tokens: number;
num_output_tokens: number;
};
};
/**
 * Demo page that streams an answer from the `fal-ai/llavav15-13b` app and
 * renders the partial output as it arrives.
 *
 * The status label cycles through `idle` -> `running` -> `done` (or `error`
 * when the request or the stream fails).
 */
export default function StreamingDemo() {
  const [answer, setAnswer] = useState<string>('');
  const [streamStatus, setStreamStatus] = useState<string>('idle');
  const isRunning = streamStatus === 'running';

  const runInference = async () => {
    // Clear the previous run and flag the request as in-flight right away so
    // the button is disabled before any network round-trip happens.
    setAnswer('');
    setStreamStatus('running');
    try {
      const stream = await fal.stream<LlavaInput, LlavaOutput>(
        'fal-ai/llavav15-13b',
        {
          input: {
            prompt:
              'Do you know who drew this picture and what is the name of it?',
            image_url: 'https://llava-vl.github.io/static/images/monalisa.jpg',
            max_new_tokens: 100,
            temperature: 0.2,
            top_p: 1,
          },
        }
      );
      for await (const partial of stream) {
        setAnswer(partial.output);
      }
      // `done()` resolves with the final result and rejects if the stream
      // failed, so it must be awaited inside the try block.
      const result = await stream.done();
      setAnswer(result.output);
      setStreamStatus('done');
    } catch (error) {
      // Without this the click handler would leave an unhandled rejection
      // and the status would stay stuck on its previous value.
      console.error(error);
      setStreamStatus('error');
    }
  };

  return (
    <div className="min-h-screen dark:bg-gray-900 bg-gray-100">
      <main className="container dark:text-gray-50 text-gray-900 flex flex-col items-center justify-center w-full flex-1 py-10 space-y-8">
        <h1 className="text-4xl font-bold mb-8">
          Hello <code className="text-pink-600">fal</code> +{' '}
          <code className="text-indigo-500">streaming</code>
        </h1>
        <div className="flex flex-row space-x-2">
          <button
            onClick={runInference}
            disabled={isRunning}
            className="bg-indigo-600 hover:bg-indigo-700 text-white font-bold text-lg py-3 px-6 mx-auto rounded focus:outline-none focus:shadow-outline disabled:opacity-70"
          >
            Run inference
          </button>
        </div>
        <div className="w-full flex flex-col space-y-4">
          <div className="flex flex-row items-center justify-between">
            <h2 className="text-2xl font-bold">Answer</h2>
            <span>
              streaming: <code className="font-semibold">{streamStatus}</code>
            </span>
          </div>
          <p className="text-lg p-4 border min-h-[12rem] border-gray-300 bg-gray-200 dark:bg-gray-800 dark:border-gray-700 rounded">
            {answer}
          </p>
        </div>
      </main>
    </div>
  );
}

View File

@ -1,7 +1,7 @@
{ {
"name": "@fal-ai/serverless-client", "name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client", "description": "The fal serverless JS/TS client",
"version": "0.8.6", "version": "0.9.0",
"license": "MIT", "license": "MIT",
"repository": { "repository": {
"type": "git", "type": "git",
@ -17,6 +17,7 @@
], ],
"dependencies": { "dependencies": {
"@msgpack/msgpack": "^3.0.0-beta2", "@msgpack/msgpack": "^3.0.0-beta2",
"eventsource-parser": "^1.1.2",
"robot3": "^0.4.1", "robot3": "^0.4.1",
"uuid-random": "^1.3.2" "uuid-random": "^1.3.2"
}, },

26
libs/client/src/auth.ts Normal file
View File

@ -0,0 +1,26 @@
import { getRestApiUrl } from './config';
import { dispatchRequest } from './request';
import { ensureAppIdFormat } from './utils';
/**
 * Value sent as `token_expiration` when requesting a temporary auth token.
 */
export const TOKEN_EXPIRATION_SECONDS = 120;
/**
 * Requests a temporary auth token scoped to the given app by POSTing to the
 * REST API `/tokens/` endpoint.
 *
 * @param app the app id; it is normalized with `ensureAppIdFormat` and the
 * segment after the first `/` is sent as the only allowed app.
 * @returns the token returned by the endpoint.
 */
export async function getTemporaryAuthToken(app: string): Promise<string> {
  const appAlias = ensureAppIdFormat(app).split('/')[1];
  const tokensEndpoint = `${getRestApiUrl()}/tokens/`;
  const payload = {
    allowed_apps: [appAlias],
    token_expiration: TOKEN_EXPIRATION_SECONDS,
  };
  const token: string | object = await dispatchRequest<any, string>(
    'POST',
    tokensEndpoint,
    payload
  );
  // Older proxy versions wrap the token as `{ detail: token }`; unwrap it.
  // This branch should be safe to drop once those proxies are gone.
  if (typeof token !== 'string' && token['detail']) {
    return token['detail'];
  }
  return token;
}

View File

@ -6,6 +6,7 @@ export { realtimeImpl as realtime } from './realtime';
export { ApiError, ValidationError } from './response'; export { ApiError, ValidationError } from './response';
export type { ResponseHandler } from './response'; export type { ResponseHandler } from './response';
export { storageImpl as storage } from './storage'; export { storageImpl as storage } from './storage';
export { stream } from './streaming';
export type { export type {
QueueStatus, QueueStatus,
ValidationErrorInfo, ValidationErrorInfo,

View File

@ -13,8 +13,7 @@ import {
transition, transition,
} from 'robot3'; } from 'robot3';
import uuid from 'uuid-random'; import uuid from 'uuid-random';
import { getRestApiUrl } from './config'; import { TOKEN_EXPIRATION_SECONDS, getTemporaryAuthToken } from './auth';
import { dispatchRequest } from './request';
import { ApiError } from './response'; import { ApiError } from './response';
import { isBrowser } from './runtime'; import { isBrowser } from './runtime';
import { ensureAppIdFormat, isReact, throttle } from './utils'; import { ensureAppIdFormat, isReact, throttle } from './utils';
@ -280,7 +279,6 @@ function buildRealtimeUrl(
return `wss://fal.run/${appId}/${suffix}?${queryParams.toString()}`; return `wss://fal.run/${appId}/${suffix}?${queryParams.toString()}`;
} }
const TOKEN_EXPIRATION_SECONDS = 120;
const DEFAULT_THROTTLE_INTERVAL = 128; const DEFAULT_THROTTLE_INTERVAL = 128;
function shouldSendBinary(message: any): boolean { function shouldSendBinary(message: any): boolean {
@ -292,27 +290,6 @@ function shouldSendBinary(message: any): boolean {
); );
} }
/**
* Get a token to connect to the realtime endpoint.
*/
async function getToken(app: string): Promise<string> {
const [, appAlias] = ensureAppIdFormat(app).split('/');
const token: string | object = await dispatchRequest<any, string>(
'POST',
`${getRestApiUrl()}/tokens/`,
{
allowed_apps: [appAlias],
token_expiration: TOKEN_EXPIRATION_SECONDS,
}
);
// keep this in case the response was wrapped (old versions of the proxy do that)
// should be safe to remove in the future
if (typeof token !== 'string' && token['detail']) {
return token['detail'];
}
return token;
}
function isUnauthorizedError(message: any): boolean { function isUnauthorizedError(message: any): boolean {
// TODO we need better protocol definition with error codes // TODO we need better protocol definition with error codes
return message['status'] === 'error' && message['error'] === 'Unauthorized'; return message['status'] === 'error' && message['error'] === 'Unauthorized';
@ -441,7 +418,7 @@ export const realtimeImpl: RealtimeClient = {
previousState !== machine.current previousState !== machine.current
) { ) {
send({ type: 'initiateAuth' }); send({ type: 'initiateAuth' });
getToken(app) getTemporaryAuthToken(app)
.then((token) => { .then((token) => {
send({ type: 'authenticated', token }); send({ type: 'authenticated', token });
const tokenExpirationTimeout = Math.round( const tokenExpirationTimeout = Math.round(

View File

@ -0,0 +1,251 @@
import { createParser } from 'eventsource-parser';
import { getTemporaryAuthToken } from './auth';
import { buildUrl } from './function';
import { ApiError, defaultResponseHandler } from './response';
import { storageImpl } from './storage';
/**
 * The stream API options. It requires the API input and also
 * offers configuration options.
 */
type StreamOptions<Input> = {
/**
 * The API input payload. It is serialized as JSON and sent as the
 * request body.
 */
input: Input;
/**
 * The maximum time interval in milliseconds between stream chunks.
 * Falls back to `EVENT_STREAM_TIMEOUT` (15s) when omitted.
 */
timeout?: number;
/**
 * Whether it should auto-upload File-like types to fal's storage
 * or not. The input is transformed unless this is explicitly `false`.
 */
autoUpload?: boolean;
};
// Default for `StreamOptions.timeout`, in milliseconds (15 seconds).
const EVENT_STREAM_TIMEOUT = 15 * 1000;
// Names of the events a stream can emit to listeners registered with `on`.
type FalStreamEventType = 'message' | 'error' | 'done';
// Listener signature; the payload is the parsed chunk for `message`, the
// last parsed chunk for `done`, and the error object for `error`.
type EventHandler = (event: any) => void;
/**
 * The class representing a streaming response. With this object, partial
 * results can be consumed either through the `AsyncIterator` protocol
 * (`for await (const chunk of stream)`) or through event listeners
 * (`on('message' | 'error' | 'done', ...)`), and the final result can be
 * awaited through `done()`.
 *
 * The request is started as soon as the instance is constructed.
 */
class FalStream<Input, Output> {
  // properties
  url: string;
  options: StreamOptions<Input>;

  // support for event listeners
  private listeners: Map<FalStreamEventType, EventHandler[]> = new Map();
  // chunks received but not yet handed out by the async iterator
  private buffer: Output[] = [];

  // local state
  private currentData: Output | undefined = undefined;
  private streamClosed = false;
  private donePromise: Promise<Output>;

  /**
   * @param url the fully-built endpoint URL the request is POSTed to.
   * @param options the stream options, including the input payload.
   */
  constructor(url: string, options: StreamOptions<Input>) {
    this.url = url;
    this.options = options;
    this.donePromise = new Promise<Output>((resolve, reject) => {
      this.on('done', (data) => {
        this.streamClosed = true;
        resolve(data);
      });
      this.on('error', (error) => {
        this.streamClosed = true;
        reject(error);
      });
    });
    // `start` handles its own failures, so the promise is intentionally
    // not awaited here.
    void this.start();
  }

  /**
   * Sends the request and hands the response over to `handleResponse`.
   * Any failure is reported through the `error` event.
   */
  private start = async () => {
    try {
      const response = await fetch(this.url, {
        method: 'POST',
        headers: {
          accept: 'text/event-stream',
          'content-type': 'application/json',
        },
        body: JSON.stringify(this.options.input),
      });
      // awaited so that a rejection is caught below instead of escaping
      // as an unhandled rejection
      await this.handleResponse(response);
    } catch (error) {
      this.handleError(error);
    }
  };

  /**
   * Reads the event stream from the response, emitting `message` for each
   * parsed event and `done` once the body has been fully consumed. Failed
   * responses, empty bodies, read errors and timeouts emit `error`.
   */
  private handleResponse = async (response: Response) => {
    if (!response.ok) {
      try {
        // we know the response failed, call the response handler
        // so the exception gets converted to ApiError correctly
        await defaultResponseHandler(response);
      } catch (error) {
        this.emit('error', error);
      }
      return;
    }

    const body = response.body;
    if (!body) {
      this.emit(
        'error',
        new ApiError({
          message: 'Response body is empty.',
          status: 400,
          body: undefined,
        })
      );
      return;
    }

    const decoder = new TextDecoder('utf-8');
    const reader = body.getReader();

    const parser = createParser((event) => {
      if (event.type !== 'event') {
        return;
      }
      try {
        const parsedData = JSON.parse(event.data);
        this.buffer.push(parsedData);
        this.currentData = parsedData;
        this.emit('message', parsedData);
      } catch (e) {
        this.emit('error', e);
      }
    });

    const timeout = this.options.timeout ?? EVENT_STREAM_TIMEOUT;

    // Reads the next chunk, rejecting with a 408 ApiError when nothing
    // arrives within `timeout` milliseconds.
    const readChunk = async () => {
      let timer: ReturnType<typeof setTimeout> | undefined;
      const timedOut = new Promise<never>((_, reject) => {
        timer = setTimeout(() => {
          reject(
            new ApiError({
              message: `Event stream timed out after ${
                timeout / 1000
              } seconds with no messages.`,
              status: 408,
            })
          );
        }, timeout);
      });
      try {
        return await Promise.race([reader.read(), timedOut]);
      } finally {
        clearTimeout(timer);
      }
    };

    try {
      for (;;) {
        const { value, done } = await readChunk();
        if (done) {
          break;
        }
        // `stream: true` keeps multi-byte characters that are split
        // across chunk boundaries intact
        parser.feed(decoder.decode(value, { stream: true }));
      }
      // flush whatever the decoder is still holding on to
      const trailing = decoder.decode();
      if (trailing) {
        parser.feed(trailing);
      }
    } catch (error) {
      // release the underlying connection; a failure to cancel is not
      // actionable at this point
      reader.cancel().catch(() => undefined);
      this.handleError(error);
      return;
    }

    this.emit('done', this.currentData);
  };

  /**
   * Normalizes any thrown value into an `ApiError` and emits it through
   * the `error` event.
   */
  private handleError = (error: any) => {
    const apiError =
      error instanceof ApiError
        ? error
        : new ApiError({
            message: error?.message ?? 'An unknown error occurred',
            status: 500,
          });
    this.emit('error', apiError);
  };

  /**
   * Registers a listener for the given event type.
   *
   * @param type the event type: `message`, `error` or `done`.
   * @param listener the function called with the event payload.
   */
  public on = (type: FalStreamEventType, listener: EventHandler) => {
    if (!this.listeners.has(type)) {
      this.listeners.set(type, []);
    }
    this.listeners.get(type)?.push(listener);
  };

  /**
   * Calls every listener registered for the given event type, in
   * registration order.
   */
  private emit = (type: FalStreamEventType, event: any) => {
    const listeners = this.listeners.get(type) || [];
    for (const listener of listeners) {
      listener(event);
    }
  };

  /**
   * Yields the buffered chunks in arrival order. Iteration ends once the
   * stream is closed (either `done` or `error`) and the buffer is drained.
   */
  async *[Symbol.asyncIterator]() {
    // the stream may already be closed by the time iteration starts
    let running = !this.streamClosed;
    const stopAsyncIterator = () => (running = false);
    this.on('error', stopAsyncIterator);
    this.on('done', stopAsyncIterator);

    // keep draining after the stream closes so the last chunks are not lost
    while (running || this.buffer.length > 0) {
      const data = this.buffer.shift();
      if (data !== undefined) {
        yield data;
        continue;
      }
      // the short timeout ensures the while loop doesn't block other
      // frames getting executed concurrently
      await new Promise((resolve) => setTimeout(resolve, 16));
    }
  }

  /**
   * Gets a reference to the `Promise` that indicates whether the streaming
   * is done or not. Developers should always call this in their apps to ensure
   * the request is over.
   *
   * An alternative to this, is to use `on('done')` in case your application
   * architecture works best with event listeners.
   *
   * @returns the promise that resolves with the last received chunk when the
   * request is done, and rejects when the stream emits an error.
   */
  public done = async () => this.donePromise;
}
/**
 * Calls a fal app that supports streaming and provides a streaming-capable
 * object as a result, that can be used to get partial results through either
 * `AsyncIterator` or through an event listener.
 *
 * Unless `options.autoUpload` is `false`, the input is first passed through
 * `storageImpl.transformInput`. A temporary auth token is then requested and
 * sent as the `fal_jwt_token` query parameter.
 *
 * @param appId the app id, e.g. `fal-ai/llavav15-13b`.
 * @param options the request options, including the input payload.
 * @returns the `FalStream` instance.
 */
export async function stream<Input = Record<string, any>, Output = any>(
  appId: string,
  options: StreamOptions<Input>
): Promise<FalStream<Input, Output>> {
  // Transform the input before requesting the token: uploads can be slow
  // and the temporary token should be as fresh as possible when it is used.
  const input =
    options.input && options.autoUpload !== false
      ? await storageImpl.transformInput(options.input)
      : options.input;

  const token = await getTemporaryAuthToken(appId);
  const url = buildUrl(appId, { path: '/stream' });
  const queryParams = new URLSearchParams({
    fal_jwt_token: token,
  });

  return new FalStream<Input, Output>(`${url}?${queryParams.toString()}`, {
    ...options,
    input: input as Input,
  });
}

9
package-lock.json generated
View File

@ -24,6 +24,7 @@
"cross-fetch": "^3.1.5", "cross-fetch": "^3.1.5",
"dotenv": "^16.3.1", "dotenv": "^16.3.1",
"encoding": "^0.1.13", "encoding": "^0.1.13",
"eventsource-parser": "^1.1.2",
"execa": "^8.0.1", "execa": "^8.0.1",
"express": "^4.18.2", "express": "^4.18.2",
"fast-glob": "^3.2.12", "fast-glob": "^3.2.12",
@ -14490,6 +14491,14 @@
"node": ">=0.8.x" "node": ">=0.8.x"
} }
}, },
"node_modules/eventsource-parser": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-1.1.2.tgz",
"integrity": "sha512-v0eOBUbiaFojBu2s2NPBfYUoRR9GjcDNvCXVaqEf5vVfpIAh9f8RCo4vXTP8c63QRKCFwoLpMpTdPwwhEKVgzA==",
"engines": {
"node": ">=14.18"
}
},
"node_modules/execa": { "node_modules/execa": {
"version": "8.0.1", "version": "8.0.1",
"resolved": "https://registry.npmjs.org/execa/-/execa-8.0.1.tgz", "resolved": "https://registry.npmjs.org/execa/-/execa-8.0.1.tgz",

View File

@ -40,6 +40,7 @@
"cross-fetch": "^3.1.5", "cross-fetch": "^3.1.5",
"dotenv": "^16.3.1", "dotenv": "^16.3.1",
"encoding": "^0.1.13", "encoding": "^0.1.13",
"eventsource-parser": "^1.1.2",
"execa": "^8.0.1", "execa": "^8.0.1",
"express": "^4.18.2", "express": "^4.18.2",
"fast-glob": "^3.2.12", "fast-glob": "^3.2.12",