feat(client): add streaming support (#56)
* feat(client): add streaming support * fix: variable type definition * chore: add docs * chore: bump to client 0.9.0
This commit is contained in:
parent
6d112c89c5
commit
335b817e9c
87
apps/demo-nextjs-app-router/app/streaming/page.tsx
Normal file
87
apps/demo-nextjs-app-router/app/streaming/page.tsx
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
'use client';
|
||||||
|
|
||||||
|
import * as fal from '@fal-ai/serverless-client';
|
||||||
|
import { useState } from 'react';
|
||||||
|
|
||||||
|
fal.config({
|
||||||
|
proxyUrl: '/api/fal/proxy',
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
 * Input payload this demo sends to the `fal-ai/llavav15-13b` app.
 * Only `prompt` and `image_url` are required; the remaining fields
 * are optional generation settings.
 */
type LlavaInput = {
  /** The question or instruction for the model. */
  prompt: string;
  /** URL of the image the prompt refers to. */
  image_url: string;
  max_new_tokens?: number;
  temperature?: number;
  top_p?: number;
};
|
||||||
|
|
||||||
|
/**
 * Shape of each chunk (and of the final result) this demo reads from the
 * `fal-ai/llavav15-13b` stream.
 */
type LlavaOutput = {
  /** The answer text; the demo renders it as-is on every chunk. */
  output: string;
  partial: boolean;
  /** Token counters reported alongside the output. */
  stats: {
    num_input_tokens: number;
    num_output_tokens: number;
  };
};
|
||||||
|
|
||||||
|
/**
 * Demo page for the streaming client API.
 *
 * Clicking the button sends a fixed prompt and image to the
 * `fal-ai/llavav15-13b` app through `fal.stream`, renders every partial
 * answer as it arrives and shows the current stream status
 * (`idle`, `running`, `done` or `error`).
 */
export default function StreamingDemo() {
  const [answer, setAnswer] = useState<string>('');
  const [streamStatus, setStreamStatus] = useState<string>('idle');

  const runInference = async () => {
    // Reflect the in-flight state immediately (before the token round-trip)
    // and clear the answer left over from a previous run.
    setAnswer('');
    setStreamStatus('running');

    try {
      const stream = await fal.stream<LlavaInput, LlavaOutput>(
        'fal-ai/llavav15-13b',
        {
          input: {
            prompt:
              'Do you know who drew this picture and what is the name of it?',
            image_url: 'https://llava-vl.github.io/static/images/monalisa.jpg',
            max_new_tokens: 100,
            temperature: 0.2,
            top_p: 1,
          },
        }
      );

      for await (const partial of stream) {
        setAnswer(partial.output);
      }

      const result = await stream.done();
      setAnswer(result.output);
      setStreamStatus('done');
    } catch (error) {
      // Without this the rejection would go unhandled and the status would
      // stay on "running" forever.
      console.error(error);
      setStreamStatus('error');
    }
  };

  return (
    <div className="min-h-screen dark:bg-gray-900 bg-gray-100">
      <main className="container dark:text-gray-50 text-gray-900 flex flex-col items-center justify-center w-full flex-1 py-10 space-y-8">
        <h1 className="text-4xl font-bold mb-8">
          Hello <code className="text-pink-600">fal</code> +{' '}
          <code className="text-indigo-500">streaming</code>
        </h1>

        <div className="flex flex-row space-x-2">
          <button
            onClick={runInference}
            // prevents overlapping runs from interleaving their updates
            disabled={streamStatus === 'running'}
            className="bg-indigo-600 hover:bg-indigo-700 text-white font-bold text-lg py-3 px-6 mx-auto rounded focus:outline-none focus:shadow-outline disabled:opacity-70"
          >
            Run inference
          </button>
        </div>

        <div className="w-full flex flex-col space-y-4">
          <div className="flex flex-row items-center justify-between">
            <h2 className="text-2xl font-bold">Answer</h2>
            <span>
              streaming: <code className="font-semibold">{streamStatus}</code>
            </span>
          </div>
          <p className="text-lg p-4 border min-h-[12rem] border-gray-300 bg-gray-200 dark:bg-gray-800 dark:border-gray-700 rounded">
            {answer}
          </p>
        </div>
      </main>
    </div>
  );
}
|
||||||
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"name": "@fal-ai/serverless-client",
|
"name": "@fal-ai/serverless-client",
|
||||||
"description": "The fal serverless JS/TS client",
|
"description": "The fal serverless JS/TS client",
|
||||||
"version": "0.8.6",
|
"version": "0.9.0",
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"repository": {
|
"repository": {
|
||||||
"type": "git",
|
"type": "git",
|
||||||
@ -17,6 +17,7 @@
|
|||||||
],
|
],
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@msgpack/msgpack": "^3.0.0-beta2",
|
"@msgpack/msgpack": "^3.0.0-beta2",
|
||||||
|
"eventsource-parser": "^1.1.2",
|
||||||
"robot3": "^0.4.1",
|
"robot3": "^0.4.1",
|
||||||
"uuid-random": "^1.3.2"
|
"uuid-random": "^1.3.2"
|
||||||
},
|
},
|
||||||
|
|||||||
26
libs/client/src/auth.ts
Normal file
26
libs/client/src/auth.ts
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import { getRestApiUrl } from './config';
|
||||||
|
import { dispatchRequest } from './request';
|
||||||
|
import { ensureAppIdFormat } from './utils';
|
||||||
|
|
||||||
|
/**
 * Lifetime, in seconds, requested for temporary auth tokens (two minutes).
 */
export const TOKEN_EXPIRATION_SECONDS = 2 * 60;
|
||||||
|
|
||||||
|
/**
 * Requests a short-lived token that authorizes direct connections to the
 * given app (used by the realtime and the streaming clients).
 *
 * The token is restricted to the app alias (the part of the app id after the
 * `/`) and is requested with a lifetime of `TOKEN_EXPIRATION_SECONDS`.
 *
 * @param app the app id the token should grant access to.
 * @returns the token string.
 * @throws Error when the response is neither a string nor an object with a
 * string `detail` property.
 */
export async function getTemporaryAuthToken(app: string): Promise<string> {
  const [, appAlias] = ensureAppIdFormat(app).split('/');
  // The declared output type is `string`, but the value is treated as
  // `unknown` because the payload may arrive wrapped (see below).
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const response: unknown = await dispatchRequest<any, string>(
    'POST',
    `${getRestApiUrl()}/tokens/`,
    {
      allowed_apps: [appAlias],
      token_expiration: TOKEN_EXPIRATION_SECONDS,
    }
  );

  if (typeof response === 'string') {
    return response;
  }

  // keep this in case the response was wrapped (old versions of the proxy do that)
  // should be safe to remove in the future
  if (typeof response === 'object' && response !== null) {
    const detail = (response as { detail?: unknown }).detail;
    if (typeof detail === 'string') {
      return detail;
    }
  }

  // Fail here instead of handing a non-string value to callers that will
  // put it into a URL.
  throw new Error(
    'Unexpected response format while requesting a temporary auth token.'
  );
}
|
||||||
@ -6,6 +6,7 @@ export { realtimeImpl as realtime } from './realtime';
|
|||||||
export { ApiError, ValidationError } from './response';
|
export { ApiError, ValidationError } from './response';
|
||||||
export type { ResponseHandler } from './response';
|
export type { ResponseHandler } from './response';
|
||||||
export { storageImpl as storage } from './storage';
|
export { storageImpl as storage } from './storage';
|
||||||
|
export { stream } from './streaming';
|
||||||
export type {
|
export type {
|
||||||
QueueStatus,
|
QueueStatus,
|
||||||
ValidationErrorInfo,
|
ValidationErrorInfo,
|
||||||
|
|||||||
@ -13,8 +13,7 @@ import {
|
|||||||
transition,
|
transition,
|
||||||
} from 'robot3';
|
} from 'robot3';
|
||||||
import uuid from 'uuid-random';
|
import uuid from 'uuid-random';
|
||||||
import { getRestApiUrl } from './config';
|
import { TOKEN_EXPIRATION_SECONDS, getTemporaryAuthToken } from './auth';
|
||||||
import { dispatchRequest } from './request';
|
|
||||||
import { ApiError } from './response';
|
import { ApiError } from './response';
|
||||||
import { isBrowser } from './runtime';
|
import { isBrowser } from './runtime';
|
||||||
import { ensureAppIdFormat, isReact, throttle } from './utils';
|
import { ensureAppIdFormat, isReact, throttle } from './utils';
|
||||||
@ -280,7 +279,6 @@ function buildRealtimeUrl(
|
|||||||
return `wss://fal.run/${appId}/${suffix}?${queryParams.toString()}`;
|
return `wss://fal.run/${appId}/${suffix}?${queryParams.toString()}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
const TOKEN_EXPIRATION_SECONDS = 120;
|
|
||||||
const DEFAULT_THROTTLE_INTERVAL = 128;
|
const DEFAULT_THROTTLE_INTERVAL = 128;
|
||||||
|
|
||||||
function shouldSendBinary(message: any): boolean {
|
function shouldSendBinary(message: any): boolean {
|
||||||
@ -292,27 +290,6 @@ function shouldSendBinary(message: any): boolean {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Get a token to connect to the realtime endpoint.
|
|
||||||
*/
|
|
||||||
async function getToken(app: string): Promise<string> {
|
|
||||||
const [, appAlias] = ensureAppIdFormat(app).split('/');
|
|
||||||
const token: string | object = await dispatchRequest<any, string>(
|
|
||||||
'POST',
|
|
||||||
`${getRestApiUrl()}/tokens/`,
|
|
||||||
{
|
|
||||||
allowed_apps: [appAlias],
|
|
||||||
token_expiration: TOKEN_EXPIRATION_SECONDS,
|
|
||||||
}
|
|
||||||
);
|
|
||||||
// keep this in case the response was wrapped (old versions of the proxy do that)
|
|
||||||
// should be safe to remove in the future
|
|
||||||
if (typeof token !== 'string' && token['detail']) {
|
|
||||||
return token['detail'];
|
|
||||||
}
|
|
||||||
return token;
|
|
||||||
}
|
|
||||||
|
|
||||||
function isUnauthorizedError(message: any): boolean {
|
function isUnauthorizedError(message: any): boolean {
|
||||||
// TODO we need better protocol definition with error codes
|
// TODO we need better protocol definition with error codes
|
||||||
return message['status'] === 'error' && message['error'] === 'Unauthorized';
|
return message['status'] === 'error' && message['error'] === 'Unauthorized';
|
||||||
@ -441,7 +418,7 @@ export const realtimeImpl: RealtimeClient = {
|
|||||||
previousState !== machine.current
|
previousState !== machine.current
|
||||||
) {
|
) {
|
||||||
send({ type: 'initiateAuth' });
|
send({ type: 'initiateAuth' });
|
||||||
getToken(app)
|
getTemporaryAuthToken(app)
|
||||||
.then((token) => {
|
.then((token) => {
|
||||||
send({ type: 'authenticated', token });
|
send({ type: 'authenticated', token });
|
||||||
const tokenExpirationTimeout = Math.round(
|
const tokenExpirationTimeout = Math.round(
|
||||||
|
|||||||
251
libs/client/src/streaming.ts
Normal file
251
libs/client/src/streaming.ts
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
import { createParser } from 'eventsource-parser';
|
||||||
|
import { getTemporaryAuthToken } from './auth';
|
||||||
|
import { buildUrl } from './function';
|
||||||
|
import { ApiError, defaultResponseHandler } from './response';
|
||||||
|
import { storageImpl } from './storage';
|
||||||
|
|
||||||
|
/**
 * Options accepted by the stream API: the mandatory API input plus
 * optional settings that tune how the stream is consumed.
 */
type StreamOptions<Input> = {
  /** The API input payload. */
  input: Input;

  /**
   * The maximum time interval in milliseconds between stream chunks.
   * Defaults to 15s.
   */
  timeout?: number;

  /**
   * Whether File-like values in the input should be auto-uploaded to fal's
   * storage or not.
   */
  autoUpload?: boolean;
};
|
||||||
|
|
||||||
|
/**
 * Default for the `timeout` stream option, in milliseconds (15 seconds).
 */
const EVENT_STREAM_TIMEOUT = 15_000;

/**
 * Names of the events a stream instance can emit.
 */
type FalStreamEventType = 'message' | 'error' | 'done';

/**
 * Callback signature for stream event listeners; the payload type depends
 * on the event it is registered for.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type EventHandler = (event: any) => void;
|
||||||
|
|
||||||
|
/**
 * The class representing a streaming response. With it, partial results can
 * be consumed either through the async iterator protocol (`for await`) or
 * through listeners registered with `on()`, and the final result can be
 * awaited through `done()`.
 *
 * The request is started as soon as the instance is constructed.
 */
class FalStream<Input, Output> {
  // properties
  url: string;
  options: StreamOptions<Input>;

  // support for event listeners
  private listeners: Map<FalStreamEventType, EventHandler[]> = new Map();
  // chunks received but not yet handed out by the async iterator
  private buffer: Output[] = [];

  // local state
  private currentData: Output | undefined = undefined;
  private streamClosed = false;
  private donePromise: Promise<Output>;

  /**
   * @param url the fully-built endpoint URL the request is POSTed to.
   * @param options the stream options; `options.input` is sent as the JSON body.
   */
  constructor(url: string, options: StreamOptions<Input>) {
    this.url = url;
    this.options = options;
    this.donePromise = new Promise<Output>((resolve, reject) => {
      this.on('done', (data) => {
        this.streamClosed = true;
        resolve(data);
      });
      this.on('error', (error) => {
        this.streamClosed = true;
        reject(error);
      });
    });
    // The stream can fail before anyone calls `done()`. Attach a no-op
    // handler so that case isn't reported as an unhandled rejection;
    // callers awaiting `done()` still observe the error.
    this.donePromise.catch(() => undefined);
    // `start` reports every failure through the 'error' event.
    void this.start();
  }

  /**
   * Sends the POST request and hands the response over to `handleResponse`.
   * Any failure is converted and emitted as an 'error' event.
   */
  private start = async () => {
    try {
      const response = await fetch(this.url, {
        method: 'POST',
        headers: {
          accept: 'text/event-stream',
          'content-type': 'application/json',
        },
        body: JSON.stringify(this.options.input),
      });
      await this.handleResponse(response);
    } catch (error) {
      this.handleError(error);
    }
  };

  /**
   * Reads the response body chunk by chunk, feeds it to the event-stream
   * parser and emits 'message' for every parsed event, followed by 'done'
   * (with the last parsed payload) once the body ends.
   */
  private handleResponse = async (response: Response) => {
    if (!response.ok) {
      try {
        // we know the response failed, call the response handler
        // so the exception gets converted to ApiError correctly
        await defaultResponseHandler(response);
      } catch (error) {
        this.emit('error', error);
      }
      return;
    }

    const body = response.body;
    if (!body) {
      this.emit(
        'error',
        new ApiError({
          message: 'Response body is empty.',
          status: 400,
          body: undefined,
        })
      );
      return;
    }
    const decoder = new TextDecoder('utf-8');
    const reader = body.getReader();

    const parser = createParser((event) => {
      if (event.type === 'event') {
        const data = event.data;

        try {
          const parsedData = JSON.parse(data);
          this.buffer.push(parsedData);
          this.currentData = parsedData;
          this.emit('message', parsedData);
        } catch (e) {
          this.emit('error', e);
        }
      }
    });

    const timeout = this.options.timeout ?? EVENT_STREAM_TIMEOUT;

    try {
      // Stop reading as soon as the stream is marked closed (e.g. after a
      // payload failed to parse).
      while (!this.streamClosed) {
        const { value, done } = await this.readChunk(reader, timeout);
        if (done) {
          // flush bytes the decoder may still hold from the last chunk
          parser.feed(decoder.decode());
          if (!this.streamClosed) {
            this.emit('done', this.currentData);
          }
          return;
        }
        // `stream: true` keeps multi-byte characters that are split across
        // chunk boundaries intact.
        parser.feed(decoder.decode(value, { stream: true }));
      }
    } catch (error) {
      this.handleError(error);
    }
    // Reached only on failure: release the underlying connection.
    reader.cancel().catch(() => undefined);
  };

  /**
   * Reads the next chunk, rejecting with a 408 `ApiError` when no chunk
   * arrives within `timeout` milliseconds.
   */
  private readChunk = async (
    reader: ReadableStreamDefaultReader<Uint8Array>,
    timeout: number
  ) => {
    let timer: ReturnType<typeof setTimeout> | undefined;
    const expired = new Promise<never>((_resolve, reject) => {
      timer = setTimeout(() => {
        reject(
          new ApiError({
            message: `Event stream timed out after ${
              timeout / 1000
            } seconds with no messages.`,
            status: 408,
          })
        );
      }, timeout);
    });
    try {
      return await Promise.race([reader.read(), expired]);
    } finally {
      clearTimeout(timer);
    }
  };

  /**
   * Emits an 'error' event, wrapping anything that is not already an
   * `ApiError` into one with status 500.
   */
  private handleError = (error: unknown) => {
    if (error instanceof ApiError) {
      this.emit('error', error);
      return;
    }
    let message: string | undefined;
    if (typeof error === 'string') {
      message = error;
    } else if (typeof error === 'object' && error !== null) {
      const candidate = (error as { message?: unknown }).message;
      if (typeof candidate === 'string') {
        message = candidate;
      }
    }
    this.emit(
      'error',
      new ApiError({
        message: message ?? 'An unknown error occurred',
        status: 500,
      })
    );
  };

  /**
   * Registers a listener for the given event type.
   *
   * @param type one of 'message', 'error' or 'done'.
   * @param listener called with the event payload every time the event is emitted.
   */
  public on = (type: FalStreamEventType, listener: EventHandler) => {
    if (!this.listeners.has(type)) {
      this.listeners.set(type, []);
    }
    this.listeners.get(type)?.push(listener);
  };

  /** Calls every listener registered for `type`, in registration order. */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private emit = (type: FalStreamEventType, event: any) => {
    const listeners = this.listeners.get(type) || [];
    for (const listener of listeners) {
      listener(event);
    }
  };

  /**
   * Yields every buffered chunk in arrival order and finishes once the
   * stream is closed and the buffer has been drained. Errors are not thrown
   * from the iterator; they surface through `done()` and the 'error' event.
   */
  async *[Symbol.asyncIterator]() {
    while (true) {
      // Drain by length (not truthiness) so falsy payloads are not skipped
      // and chunks that arrived right before 'done' are not lost.
      while (this.buffer.length > 0) {
        yield this.buffer.shift() as Output;
      }
      // Checking the closed flag (instead of listening for events) also
      // terminates iteration that starts after the stream has finished.
      if (this.streamClosed) {
        return;
      }

      // the short timeout ensures the while loop doesn't block other
      // frames getting executed concurrently
      await new Promise((resolve) => setTimeout(resolve, 16));
    }
  }

  /**
   * Gets a reference to the `Promise` that indicates whether the streaming
   * is done or not. Developers should always call this in their apps to ensure
   * the request is over.
   *
   * An alternative to this, is to use `on('done')` in case your application
   * architecture works best with event listeners.
   *
   * @returns the promise that resolves when the request is done.
   */
  public done = async () => this.donePromise;
}
|
||||||
|
|
||||||
|
/**
 * Calls a fal app that supports streaming and provides a streaming-capable
 * object as a result, that can be used to get partial results through either
 * `AsyncIterator` or through an event listener.
 *
 * Unless `options.autoUpload` is `false`, the input is first passed through
 * the storage layer's `transformInput`. A temporary auth token is then
 * requested and sent as the `fal_jwt_token` query parameter.
 *
 * @param appId the app id, e.g. `fal-ai/llavav15-13b`.
 * @param options the request options, including the input payload.
 * @returns the `FalStream` instance.
 */
export async function stream<Input = Record<string, any>, Output = any>(
  appId: string,
  options: StreamOptions<Input>
): Promise<FalStream<Input, Output>> {
  const input =
    options.input && options.autoUpload !== false
      ? await storageImpl.transformInput(options.input)
      : options.input;

  // Request the token only after the input has been prepared: the token is
  // short-lived, so time spent preparing the input must not count against
  // its lifetime.
  const token = await getTemporaryAuthToken(appId);
  const url = buildUrl(appId, { path: '/stream' });

  const queryParams = new URLSearchParams({
    fal_jwt_token: token,
  });

  return new FalStream<Input, Output>(`${url}?${queryParams}`, {
    ...options,
    input: input as Input,
  });
}
|
||||||
9
package-lock.json
generated
9
package-lock.json
generated
@ -24,6 +24,7 @@
|
|||||||
"cross-fetch": "^3.1.5",
|
"cross-fetch": "^3.1.5",
|
||||||
"dotenv": "^16.3.1",
|
"dotenv": "^16.3.1",
|
||||||
"encoding": "^0.1.13",
|
"encoding": "^0.1.13",
|
||||||
|
"eventsource-parser": "^1.1.2",
|
||||||
"execa": "^8.0.1",
|
"execa": "^8.0.1",
|
||||||
"express": "^4.18.2",
|
"express": "^4.18.2",
|
||||||
"fast-glob": "^3.2.12",
|
"fast-glob": "^3.2.12",
|
||||||
@ -14490,6 +14491,14 @@
|
|||||||
"node": ">=0.8.x"
|
"node": ">=0.8.x"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/eventsource-parser": {
|
||||||
|
"version": "1.1.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-1.1.2.tgz",
|
||||||
|
"integrity": "sha512-v0eOBUbiaFojBu2s2NPBfYUoRR9GjcDNvCXVaqEf5vVfpIAh9f8RCo4vXTP8c63QRKCFwoLpMpTdPwwhEKVgzA==",
|
||||||
|
"engines": {
|
||||||
|
"node": ">=14.18"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/execa": {
|
"node_modules/execa": {
|
||||||
"version": "8.0.1",
|
"version": "8.0.1",
|
||||||
"resolved": "https://registry.npmjs.org/execa/-/execa-8.0.1.tgz",
|
"resolved": "https://registry.npmjs.org/execa/-/execa-8.0.1.tgz",
|
||||||
|
|||||||
@ -40,6 +40,7 @@
|
|||||||
"cross-fetch": "^3.1.5",
|
"cross-fetch": "^3.1.5",
|
||||||
"dotenv": "^16.3.1",
|
"dotenv": "^16.3.1",
|
||||||
"encoding": "^0.1.13",
|
"encoding": "^0.1.13",
|
||||||
|
"eventsource-parser": "^1.1.2",
|
||||||
"execa": "^8.0.1",
|
"execa": "^8.0.1",
|
||||||
"express": "^4.18.2",
|
"express": "^4.18.2",
|
||||||
"fast-glob": "^3.2.12",
|
"fast-glob": "^3.2.12",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user