feat(client): stream queue status (#71)

* feat(client): stream queue status

* chore: remove console log

* fix: accumulative logs when streaming

* fix(client): stream logs on queue update

* chore(apps): remove pollInterval from sample apps
This commit is contained in:
Daniel Rochetti 2024-06-25 15:00:01 -07:00 committed by GitHub
parent c7910163a7
commit 4ea43b4cea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 231 additions and 57 deletions

View File

@ -90,7 +90,6 @@ export default function ComfyImageToImagePage() {
prompt: prompt, prompt: prompt,
loadimage_1: imageFile, loadimage_1: imageFile,
}, },
pollInterval: 3000, // Default is 1000 (every 1s)
logs: true, logs: true,
onQueueUpdate(update) { onQueueUpdate(update) {
setElapsedTime(Date.now() - start); setElapsedTime(Date.now() - start);

View File

@ -85,7 +85,6 @@ export default function ComfyImageToVideoPage() {
input: { input: {
loadimage_1: imageFile, loadimage_1: imageFile,
}, },
pollInterval: 3000, // Default is 1000 (every 1s)
logs: true, logs: true,
onQueueUpdate(update) { onQueueUpdate(update) {
setElapsedTime(Date.now() - start); setElapsedTime(Date.now() - start);

View File

@ -86,7 +86,6 @@ export default function ComfyTextToImagePage() {
input: { input: {
prompt: prompt, prompt: prompt,
}, },
pollInterval: 3000, // Default is 1000 (every 1s)
logs: true, logs: true,
onQueueUpdate(update) { onQueueUpdate(update) {
setElapsedTime(Date.now() - start); setElapsedTime(Date.now() - start);

View File

@ -85,7 +85,6 @@ export default function Home() {
image_url: imageFile, image_url: imageFile,
image_size: 'square_hd', image_size: 'square_hd',
}, },
pollInterval: 3000, // Default is 1000 (every 1s)
logs: true, logs: true,
onQueueUpdate(update) { onQueueUpdate(update) {
setElapsedTime(Date.now() - start); setElapsedTime(Date.now() - start);

View File

@ -0,0 +1,152 @@
'use client';
import * as fal from '@fal-ai/serverless-client';
import { useState } from 'react';
fal.config({
proxyUrl: '/api/fal/proxy',
});
type ErrorProps = {
error: any;
};
function Error(props: ErrorProps) {
if (!props.error) {
return null;
}
return (
<div
className="p-4 mb-4 text-sm text-red-800 rounded bg-red-50 dark:bg-gray-800 dark:text-red-400"
role="alert"
>
<span className="font-medium">Error</span> {props.error.message}
</div>
);
}
export default function Home() {
// Input state
const [endpointId, setEndpointId] = useState<string>('');
const [input, setInput] = useState<string>('{}');
// Result state
const [loading, setLoading] = useState(false);
const [error, setError] = useState<Error | null>(null);
const [result, setResult] = useState<any | null>(null);
const [logs, setLogs] = useState<string[]>([]);
const [elapsedTime, setElapsedTime] = useState<number>(0);
const reset = () => {
setLoading(false);
setError(null);
setResult(null);
setLogs([]);
setElapsedTime(0);
};
const run = async () => {
reset();
setLoading(true);
const start = Date.now();
try {
const result: any = await fal.subscribe(endpointId, {
input: JSON.parse(input),
logs: true,
onQueueUpdate(update) {
console.log('queue update');
console.log(update);
setElapsedTime(Date.now() - start);
if (
update.status === 'IN_PROGRESS' ||
update.status === 'COMPLETED'
) {
if (update.logs && update.logs.length > logs.length) {
setLogs((update.logs || []).map((log) => log.message));
}
}
},
});
setResult(result);
} catch (error: any) {
setError(error);
} finally {
setLoading(false);
setElapsedTime(Date.now() - start);
}
};
return (
<div className="min-h-screen dark:bg-gray-900 bg-gray-100">
<main className="container dark:text-gray-50 text-gray-900 flex flex-col items-center justify-center w-full flex-1 py-10 space-y-8">
<h1 className="text-4xl font-bold mb-8">
<code className="font-light text-pink-600">fal</code>
<code>queue</code>
</h1>
<div className="text-lg w-full">
<label htmlFor="prompt" className="block mb-2 text-current">
Endpoint ID
</label>
<input
className="w-full text-base p-2 rounded bg-black/10 dark:bg-white/5 border border-black/20 dark:border-white/10"
id="endpointId"
name="endpointId"
autoComplete="off"
placeholder="Endpoint ID"
value={endpointId}
spellCheck={false}
onChange={(e) => setEndpointId(e.target.value)}
/>
</div>
<div className="text-lg w-full">
<label htmlFor="prompt" className="block mb-2 text-current">
JSON Input
</label>
<textarea
className="w-full text-sm p-2 rounded bg-black/10 dark:bg-white/5 border border-black/20 dark:border-white/10 font-mono"
id="input"
name="Input"
placeholder="JSON"
value={input}
autoComplete="off"
spellCheck={false}
onChange={(e) => setInput(e.target.value)}
rows={6}
></textarea>
</div>
<button
onClick={(e) => {
e.preventDefault();
run();
}}
className="bg-indigo-600 hover:bg-indigo-700 text-white font-bold text-lg py-3 px-6 mx-auto rounded focus:outline-none focus:shadow-outline"
disabled={loading}
>
{loading ? 'Running...' : 'Run'}
</button>
<Error error={error} />
<div className="w-full flex flex-col space-y-4">
<div className="space-y-2">
<h3 className="text-xl font-light">JSON Result</h3>
<p className="text-sm text-current/80">
{`Elapsed Time (seconds): ${(elapsedTime / 1000).toFixed(2)}`}
</p>
<pre className="text-sm bg-black/70 text-white/80 font-mono h-60 rounded whitespace-pre overflow-auto w-full">
{result
? JSON.stringify(result, null, 2)
: '// result pending...'}
</pre>
</div>
<div className="space-y-2">
<h3 className="text-xl font-light">Logs</h3>
<pre className="text-sm bg-black/70 text-white/80 font-mono h-60 rounded whitespace-pre overflow-auto w-full">
{logs.join('\n')}
</pre>
</div>
</div>
</main>
</div>
);
}

View File

@ -113,7 +113,6 @@ export default function WhisperDemo() {
file_name: 'recording.wav', file_name: 'recording.wav',
audio_url: audioFile, audio_url: audioFile,
}, },
pollInterval: 1000,
logs: true, logs: true,
onQueueUpdate(update) { onQueueUpdate(update) {
setElapsedTime(Date.now() - start); setElapsedTime(Date.now() - start);

View File

@ -78,7 +78,6 @@ export function Index() {
model_name: 'stabilityai/stable-diffusion-xl-base-1.0', model_name: 'stabilityai/stable-diffusion-xl-base-1.0',
image_size: 'square_hd', image_size: 'square_hd',
}, },
pollInterval: 3000, // Default is 1000 (every 1s)
logs: true, logs: true,
onQueueUpdate(update) { onQueueUpdate(update) {
setElapsedTime(Date.now() - start); setElapsedTime(Date.now() - start);

View File

@ -1,7 +1,7 @@
{ {
"name": "@fal-ai/serverless-client", "name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client", "description": "The fal serverless JS/TS client",
"version": "0.11.0", "version": "0.12.0",
"license": "MIT", "license": "MIT",
"repository": { "repository": {
"type": "git", "type": "git",

View File

@ -1,6 +1,8 @@
import { getTemporaryAuthToken } from './auth';
import { dispatchRequest } from './request'; import { dispatchRequest } from './request';
import { storageImpl } from './storage'; import { storageImpl } from './storage';
import { EnqueueResult, QueueStatus } from './types'; import { FalStream } from './streaming';
import { EnqueueResult, QueueStatus, RequestLog } from './types';
import { ensureAppIdFormat, isUUIDv4, isValidUrl, parseAppId } from './utils'; import { ensureAppIdFormat, isUUIDv4, isValidUrl, parseAppId } from './utils';
/** /**
@ -138,36 +140,22 @@ export async function subscribe<Input, Output>(
if (options.onEnqueue) { if (options.onEnqueue) {
options.onEnqueue(requestId); options.onEnqueue(requestId);
} }
return new Promise<Output>((resolve, reject) => { const status = await queue.streamStatus(id, {
let timeoutId: ReturnType<typeof setTimeout>;
const pollInterval = options.pollInterval ?? 1000;
const poll = async () => {
try {
const requestStatus = await queue.status(id, {
requestId, requestId,
logs: options.logs ?? false, logs: options.logs,
}); });
const logs: RequestLog[] = [];
status.on('message', (data: QueueStatus) => {
if (options.onQueueUpdate) { if (options.onQueueUpdate) {
options.onQueueUpdate(requestStatus); // accumulate logs to match previous polling behavior
if ('logs' in data && Array.isArray(data.logs) && data.logs.length > 0) {
logs.push(...data.logs);
} }
if (requestStatus.status === 'COMPLETED') { options.onQueueUpdate('logs' in data ? { ...data, logs } : data);
clearTimeout(timeoutId);
try {
const result = await queue.result<Output>(id, { requestId });
resolve(result);
} catch (error) {
reject(error);
} }
return;
}
timeoutId = setTimeout(poll, pollInterval);
} catch (error) {
clearTimeout(timeoutId);
reject(error);
}
};
poll().catch(reject);
}); });
await status.done();
return queue.result<Output>(id, { requestId });
} }
/** /**
@ -177,6 +165,9 @@ type QueueSubscribeOptions = {
/** /**
* The interval (in milliseconds) at which to poll for updates. * The interval (in milliseconds) at which to poll for updates.
* If not provided, a default value of `1000` will be used. * If not provided, a default value of `1000` will be used.
*
* @deprecated starting from v0.12.0 the queue status is streamed
* using the `queue.subscribeToStatus` method.
*/ */
pollInterval?: number; pollInterval?: number;
@ -239,40 +230,48 @@ interface Queue {
/** /**
* Submits a request to the queue. * Submits a request to the queue.
* *
* @param id - The ID or URL of the function web endpoint. * @param endpointId - The ID or URL of the function web endpoint.
* @param options - Options to configure how the request is run. * @param options - Options to configure how the request is run.
* @returns A promise that resolves to the result of enqueuing the request. * @returns A promise that resolves to the result of enqueuing the request.
*/ */
submit<Input>( submit<Input>(
id: string, endpointId: string,
options: SubmitOptions<Input> options: SubmitOptions<Input>
): Promise<EnqueueResult>; ): Promise<EnqueueResult>;
/** /**
* Retrieves the status of a specific request in the queue. * Retrieves the status of a specific request in the queue.
* *
* @param id - The ID or URL of the function web endpoint. * @param endpointId - The ID or URL of the function web endpoint.
* @param options - Options to configure how the request is run. * @param options - Options to configure how the request is run.
* @returns A promise that resolves to the status of the request. * @returns A promise that resolves to the status of the request.
*/ */
status(id: string, options: QueueStatusOptions): Promise<QueueStatus>; status(endpointId: string, options: QueueStatusOptions): Promise<QueueStatus>;
/** /**
* Retrieves the result of a specific request from the queue. * Retrieves the result of a specific request from the queue.
* *
* @param id - The ID or URL of the function web endpoint. * @param endpointId - The ID or URL of the function web endpoint.
* @param options - Options to configure how the request is run. * @param options - Options to configure how the request is run.
* @returns A promise that resolves to the result of the request. * @returns A promise that resolves to the result of the request.
*/ */
result<Output>(id: string, options: BaseQueueOptions): Promise<Output>; result<Output>(
endpointId: string,
options: BaseQueueOptions
): Promise<Output>;
/** /**
* @deprecated Use `fal.subscribe` instead. * @deprecated Use `fal.subscribe` instead.
*/ */
subscribe<Input, Output>( subscribe<Input, Output>(
id: string, endpointId: string,
options: RunOptions<Input> & QueueSubscribeOptions options: RunOptions<Input> & QueueSubscribeOptions
): Promise<Output>; ): Promise<Output>;
streamStatus(
endpointId: string,
options: QueueStatusOptions
): Promise<FalStream<unknown, QueueStatus>>;
} }
/** /**
@ -282,11 +281,11 @@ interface Queue {
*/ */
export const queue: Queue = { export const queue: Queue = {
async submit<Input>( async submit<Input>(
id: string, endpointId: string,
options: SubmitOptions<Input> options: SubmitOptions<Input>
): Promise<EnqueueResult> { ): Promise<EnqueueResult> {
const { webhookUrl, path = '', ...runOptions } = options; const { webhookUrl, path = '', ...runOptions } = options;
return send(id, { return send(endpointId, {
...runOptions, ...runOptions,
subdomain: 'queue', subdomain: 'queue',
method: 'post', method: 'post',
@ -295,10 +294,10 @@ export const queue: Queue = {
}); });
}, },
async status( async status(
id: string, endpointId: string,
{ requestId, logs = false }: QueueStatusOptions { requestId, logs = false }: QueueStatusOptions
): Promise<QueueStatus> { ): Promise<QueueStatus> {
const appId = parseAppId(id); const appId = parseAppId(endpointId);
const prefix = appId.namespace ? `${appId.namespace}/` : ''; const prefix = appId.namespace ? `${appId.namespace}/` : '';
return send(`${prefix}${appId.owner}/${appId.alias}`, { return send(`${prefix}${appId.owner}/${appId.alias}`, {
subdomain: 'queue', subdomain: 'queue',
@ -309,11 +308,33 @@ export const queue: Queue = {
}, },
}); });
}, },
async streamStatus(
endpointId: string,
{ requestId, logs = false }: QueueStatusOptions
): Promise<FalStream<unknown, QueueStatus>> {
const appId = parseAppId(endpointId);
const prefix = appId.namespace ? `${appId.namespace}/` : '';
const token = await getTemporaryAuthToken(endpointId);
const url = buildUrl(`${prefix}${appId.owner}/${appId.alias}`, {
subdomain: 'queue',
path: `/requests/${requestId}/status/stream`,
});
const queryParams = new URLSearchParams({
fal_jwt_token: token,
logs: logs ? '1' : '0',
});
return new FalStream<unknown, QueueStatus>(`${url}?${queryParams}`, {
input: {},
method: 'get',
});
},
async result<Output>( async result<Output>(
id: string, endpointId: string,
{ requestId }: BaseQueueOptions { requestId }: BaseQueueOptions
): Promise<Output> { ): Promise<Output> {
const appId = parseAppId(id); const appId = parseAppId(endpointId);
const prefix = appId.namespace ? `${appId.namespace}/` : ''; const prefix = appId.namespace ? `${appId.namespace}/` : '';
return send(`${prefix}${appId.owner}/${appId.alias}`, { return send(`${prefix}${appId.owner}/${appId.alias}`, {
subdomain: 'queue', subdomain: 'queue',

View File

@ -12,18 +12,23 @@ type StreamOptions<Input> = {
/** /**
* The API input payload. * The API input payload.
*/ */
input: Input; readonly input?: Input;
/** /**
* The maximum time interval in milliseconds between stream chunks. Defaults to 15s. * The maximum time interval in milliseconds between stream chunks. Defaults to 15s.
*/ */
timeout?: number; readonly timeout?: number;
/** /**
* Whether it should auto-upload File-like types to fal's storage * Whether it should auto-upload File-like types to fal's storage
* or not. * or not.
*/ */
autoUpload?: boolean; readonly autoUpload?: boolean;
/**
* The HTTP method, defaults to `post`;
*/
readonly method?: 'get' | 'post' | 'put' | 'delete' | string;
}; };
const EVENT_STREAM_TIMEOUT = 15 * 1000; const EVENT_STREAM_TIMEOUT = 15 * 1000;
@ -35,7 +40,7 @@ type EventHandler = (event: any) => void;
/** /**
* The class representing a streaming response. With t * The class representing a streaming response. With t
*/ */
class FalStream<Input, Output> { export class FalStream<Input, Output> {
// properties // properties
url: string; url: string;
options: StreamOptions<Input>; options: StreamOptions<Input>;
@ -76,14 +81,16 @@ class FalStream<Input, Output> {
} }
private start = async () => { private start = async () => {
const { url, options } = this;
const { input, method = 'post' } = options;
try { try {
const response = await fetch(this.url, { const response = await fetch(url, {
method: 'POST', method: method.toUpperCase(),
headers: { headers: {
accept: 'text/event-stream', accept: 'text/event-stream',
'content-type': 'application/json', 'content-type': 'application/json',
}, },
body: JSON.stringify(this.options.input), body: input && method !== 'get' ? JSON.stringify(input) : undefined,
}); });
this.handleResponse(response); this.handleResponse(response);
} catch (error) { } catch (error) {