feat(client): realtime msgpack payload (#65)

* feat(client): realtime msgpack payload

* chore: update realtime samples
This commit is contained in:
Daniel Rochetti 2024-05-08 15:16:01 -07:00 committed by GitHub
parent 9f4f70517f
commit 5f15da9d83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 96 additions and 68 deletions

View File

@ -31,8 +31,7 @@ This client library is crafted as a lightweight layer atop platform standards li
fal.config({ fal.config({
// Can also be auto-configured using environment variables: // Can also be auto-configured using environment variables:
// Either a single FAL_KEY or a combination of FAL_KEY_ID and FAL_KEY_SECRET credentials: 'FAL_KEY',
credentials: 'FAL_KEY_ID:FAL_KEY_SECRET',
}); });
``` ```

View File

@ -109,7 +109,7 @@ const useWebcam = ({
type LCMInput = { type LCMInput = {
prompt: string; prompt: string;
image: Uint8Array; image_bytes: Uint8Array;
strength?: number; strength?: number;
negative_prompt?: string; negative_prompt?: string;
seed?: number | null; seed?: number | null;
@ -121,8 +121,14 @@ type LCMInput = {
width?: number; width?: number;
}; };
/**
 * A single generated image, as carried in the `images` array of `LCMOutput`.
 */
type ImageOutput = {
// Encoded image bytes; the `onResult` handler wraps them in a Blob typed 'image/jpeg'.
content: Uint8Array;
// Image dimensions. NOTE(review): presumably pixels — confirm against the API schema.
width: number;
height: number;
};
type LCMOutput = { type LCMOutput = {
image: Uint8Array; images: ImageOutput[];
timings: Record<string, number>; timings: Record<string, number>;
seed: number; seed: number;
num_inference_steps: number; num_inference_steps: number;
@ -137,15 +143,17 @@ export default function WebcamPage() {
const previewRef = useRef<HTMLCanvasElement | null>(null); const previewRef = useRef<HTMLCanvasElement | null>(null);
const { send } = fal.realtime.connect<LCMInput, LCMOutput>( const { send } = fal.realtime.connect<LCMInput, LCMOutput>(
'fal-ai/sd-turbo-real-time-high-fps-msgpack-a10g', 'fal-ai/fast-turbo-diffusion/image-to-image',
{ {
connectionKey: 'camera-turbo-demo', connectionKey: 'camera-turbo-demo',
// not throttling the client, handling throttling of the camera itself // not throttling the client, handling throttling of the camera itself
// and letting all requests through in real-time // and letting all requests through in real-time
throttleInterval: 0, throttleInterval: 0,
onResult(result) { onResult(result) {
if (processedImageRef.current && result.image) { if (processedImageRef.current && result.images && result.images[0]) {
const blob = new Blob([result.image], { type: 'image/jpeg' }); const blob = new Blob([result.images[0].content], {
type: 'image/jpeg',
});
const url = URL.createObjectURL(blob); const url = URL.createObjectURL(blob);
processedImageRef.current.src = url; processedImageRef.current.src = url;
} }
@ -158,10 +166,10 @@ export default function WebcamPage() {
return; return;
} }
send({ send({
prompt: 'a picture of leonardo di caprio, elegant, in a suit, 8k, uhd', prompt: 'a picture of george clooney, elegant, in a suit, 8k, uhd',
image: data, image_bytes: data,
num_inference_steps: 3, num_inference_steps: 3,
strength: 0.44, strength: 0.6,
guidance_scale: 1, guidance_scale: 1,
seed: 6252023, seed: 6252023,
}); });

View File

@ -2,27 +2,64 @@
/* eslint-disable @next/next/no-img-element */ /* eslint-disable @next/next/no-img-element */
import * as fal from '@fal-ai/serverless-client'; import * as fal from '@fal-ai/serverless-client';
import { useState } from 'react'; import { ChangeEvent, useRef, useState } from 'react';
import { DrawingCanvas } from '../../components/drawing'; import { DrawingCanvas } from '../../components/drawing';
fal.config({ fal.config({
proxyUrl: '/api/fal/proxy', proxyUrl: '/api/fal/proxy',
}); });
const PROMPT = 'a moon in a starry night sky'; const PROMPT_EXPANDED =
', beautiful, colorful, highly detailed, best quality, uhd';
const PROMPT = 'a moon in the night sky';
// Request parameters shared by the realtime `send` calls on this page; they are
// spread into each payload alongside `prompt` and `image_bytes`.
const defaults = {
model_name: 'runwayml/stable-diffusion-v1-5',
image_size: 'square',
num_inference_steps: 4,
// Fixed seed. NOTE(review): presumably keeps results stable between requests — confirm.
seed: 6252023,
};
export default function RealtimePage() { export default function RealtimePage() {
const [image, setImage] = useState<string | null>(null); const [prompt, setPrompt] = useState(PROMPT);
const { send } = fal.realtime.connect('fal-ai/lcm-sd15-i2i', { const currentDrawing = useRef<Uint8Array | null>(null);
connectionKey: 'realtime-demo', const outputCanvasRef = useRef<HTMLCanvasElement | null>(null);
throttleInterval: 128,
onResult(result) { const { send } = fal.realtime.connect(
if (result.images && result.images[0]) { 'fal-ai/fast-lcm-diffusion/image-to-image',
setImage(result.images[0].url); {
} connectionKey: 'realtime-demo',
}, throttleInterval: 128,
}); onResult(result) {
if (result.images && result.images[0] && result.images[0].content) {
const canvas = outputCanvasRef.current;
const context = canvas?.getContext('2d');
if (canvas && context) {
const imageBytes: Uint8Array = result.images[0].content;
const blob = new Blob([imageBytes], { type: 'image/png' });
createImageBitmap(blob)
.then((bitmap) => {
context.drawImage(bitmap, 0, 0);
})
.catch(console.error);
}
}
},
}
);
// Keeps the prompt input controlled and, once a drawing has been captured,
// re-sends that drawing together with the newly typed prompt.
const handlePromptChange = (e: ChangeEvent<HTMLInputElement>) => {
  const typed = e.target.value;
  setPrompt(typed);

  const drawing = currentDrawing.current;
  if (!drawing) {
    // Nothing drawn yet, so there is no image to regenerate.
    return;
  }

  send({
    prompt: typed.trim() + PROMPT_EXPANDED,
    image_bytes: drawing,
    ...defaults,
  });
};
return ( return (
<div className="min-h-screen bg-neutral-900 text-neutral-50"> <div className="min-h-screen bg-neutral-900 text-neutral-50">
@ -30,31 +67,34 @@ export default function RealtimePage() {
<h1 className="text-4xl font-mono mb-8 text-neutral-50"> <h1 className="text-4xl font-mono mb-8 text-neutral-50">
fal<code className="font-light text-pink-600">realtime</code> fal<code className="font-light text-pink-600">realtime</code>
</h1> </h1>
<div className="prose text-neutral-400"> <div className="w-full max-w-full text-neutral-400">
<blockquote className="italic text-xl">{PROMPT}</blockquote> <input
className="italic text-xl px-3 py-2 border border-white/10 rounded-md bg-white/5 w-full"
value={prompt}
onChange={handlePromptChange}
/>
</div> </div>
<div className="flex flex-col md:flex-row space-x-4"> <div className="flex flex-col md:flex-row space-x-4">
<div className="flex-1"> <div className="flex-1">
<DrawingCanvas <DrawingCanvas
onCanvasChange={({ imageData }) => { onCanvasChange={({ imageData }) => {
currentDrawing.current = imageData;
send({ send({
prompt: PROMPT, prompt: prompt + PROMPT_EXPANDED,
image_url: imageData, image_bytes: imageData,
sync_mode: true, ...defaults,
seed: 6252023,
}); });
}} }}
/> />
</div> </div>
<div className="flex-1"> <div className="flex-1">
<div className="w-[512px] h-[512px]"> <div>
{image && ( <canvas
<img className="w-[512px] h-[512px]"
src={image} width="512"
alt={`${PROMPT} generated by fal.ai`} height="512"
className="object-contain w-full h-full" ref={outputCanvasRef}
/> />
)}
</div> </div>
</div> </div>
</div> </div>

View File

@ -10,14 +10,14 @@ import initialDrawing from './drawingState.json';
export type CanvasChangeEvent = { export type CanvasChangeEvent = {
elements: readonly ExcalidrawElement[]; elements: readonly ExcalidrawElement[];
appState: AppState; appState: AppState;
imageData: string; imageData: Uint8Array;
}; };
export type DrawingCanvasProps = { export type DrawingCanvasProps = {
onCanvasChange: (event: CanvasChangeEvent) => void; onCanvasChange: (event: CanvasChangeEvent) => void;
}; };
async function blobToBase64(blob: Blob): Promise<string> { export async function blobToBase64(blob: Blob): Promise<string> {
const reader = new FileReader(); const reader = new FileReader();
reader.readAsDataURL(blob); reader.readAsDataURL(blob);
return new Promise<string>((resolve) => { return new Promise<string>((resolve) => {
@ -27,6 +27,11 @@ async function blobToBase64(blob: Blob): Promise<string> {
}); });
} }
/**
 * Reads the full contents of a Blob into a byte array.
 *
 * @param blob the Blob whose bytes should be read
 * @returns a `Uint8Array` over the buffer produced by `blob.arrayBuffer()`
 */
export async function blobToUint8Array(blob: Blob): Promise<Uint8Array> {
  return new Uint8Array(await blob.arrayBuffer());
}
export function DrawingCanvas({ onCanvasChange }: DrawingCanvasProps) { export function DrawingCanvas({ onCanvasChange }: DrawingCanvasProps) {
const [ExcalidrawComponent, setExcalidrawComponent] = useState< const [ExcalidrawComponent, setExcalidrawComponent] = useState<
typeof Excalidraw | null typeof Excalidraw | null
@ -95,7 +100,7 @@ export function DrawingCanvas({ onCanvasChange }: DrawingCanvasProps) {
return { width: 512, height: 512 }; return { width: 512, height: 512 };
}, },
}); });
const imageData = await blobToBase64(blob); const imageData = await blobToUint8Array(blob);
onCanvasChange({ elements, appState, imageData }); onCanvasChange({ elements, appState, imageData });
} }
}, },

View File

@ -1,7 +1,7 @@
{ {
"name": "@fal-ai/serverless-client", "name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client", "description": "The fal serverless JS/TS client",
"version": "0.9.3", "version": "0.10.0",
"license": "MIT", "license": "MIT",
"repository": { "repository": {
"type": "git", "type": "git",

View File

@ -12,3 +12,4 @@ export type {
ValidationErrorInfo, ValidationErrorInfo,
WebHookResponse, WebHookResponse,
} from './types'; } from './types';
export { parseAppId } from './utils';

View File

@ -16,7 +16,7 @@ import uuid from 'uuid-random';
import { TOKEN_EXPIRATION_SECONDS, getTemporaryAuthToken } from './auth'; import { TOKEN_EXPIRATION_SECONDS, getTemporaryAuthToken } from './auth';
import { ApiError } from './response'; import { ApiError } from './response';
import { isBrowser } from './runtime'; import { isBrowser } from './runtime';
import { ensureAppIdFormat, isReact, parseAppId, throttle } from './utils'; import { ensureAppIdFormat, isReact, throttle } from './utils';
// Define the context // Define the context
interface Context { interface Context {
@ -78,10 +78,8 @@ function sendMessage(context: Context, event: SendEvent): Context {
if (context.websocket && context.websocket.readyState === WebSocket.OPEN) { if (context.websocket && context.websocket.readyState === WebSocket.OPEN) {
if (event.message instanceof Uint8Array) { if (event.message instanceof Uint8Array) {
context.websocket.send(event.message); context.websocket.send(event.message);
} else if (shouldSendBinary(event.message)) {
context.websocket.send(encode(event.message));
} else { } else {
context.websocket.send(JSON.stringify(event.message)); context.websocket.send(encode(event.message));
} }
return { return {
@ -248,17 +246,6 @@ type RealtimeUrlParams = {
maxBuffering?: number; maxBuffering?: number;
}; };
// This is a list of apps deployed before formal realtime support. Their URLs follow
// a different pattern and will be kept here until we fully sunset them.
// `buildRealtimeUrl` checks the parsed app alias against this list and, on a
// match, uses the `ws` path suffix instead of `realtime`.
const LEGACY_APPS = [
'lcm-sd15-i2i',
'lcm',
'sdxl-turbo-realtime',
'sd-turbo-real-time-high-fps-msgpack-a10g',
'lcm-plexed-sd15-i2i',
'sd-turbo-real-time-high-fps-msgpack',
];
function buildRealtimeUrl( function buildRealtimeUrl(
app: string, app: string,
{ token, maxBuffering }: RealtimeUrlParams { token, maxBuffering }: RealtimeUrlParams
@ -273,23 +260,11 @@ function buildRealtimeUrl(
queryParams.set('max_buffering', maxBuffering.toFixed(0)); queryParams.set('max_buffering', maxBuffering.toFixed(0));
} }
const appId = ensureAppIdFormat(app); const appId = ensureAppIdFormat(app);
const { alias } = parseAppId(appId); return `wss://fal.run/${appId}/realtime?${queryParams.toString()}`;
const suffix =
LEGACY_APPS.includes(alias) || !app.includes('/') ? 'ws' : 'realtime';
return `wss://fal.run/${appId}/${suffix}?${queryParams.toString()}`;
} }
const DEFAULT_THROTTLE_INTERVAL = 128; const DEFAULT_THROTTLE_INTERVAL = 128;
/**
 * Tells whether a message carries binary data and should therefore be sent
 * in encoded (binary) form rather than as JSON text.
 *
 * Only the message's own top-level values are inspected; nested objects are
 * not traversed.
 *
 * @param message the outgoing payload; any value is accepted
 * @returns `true` when at least one top-level value is a `Blob`,
 *   `ArrayBuffer` or `Uint8Array`; `false` otherwise, including for
 *   `null`, `undefined` and primitives
 */
function shouldSendBinary(message: any): boolean {
  // Object.values throws on null/undefined, and primitives never hold binary values.
  if (message === null || typeof message !== 'object') {
    return false;
  }
  // Blob is not a global in every runtime; `instanceof` against a missing
  // global would raise a ReferenceError.
  const hasBlob = typeof Blob !== 'undefined';
  return Object.values(message).some(
    (value) =>
      value instanceof ArrayBuffer ||
      value instanceof Uint8Array ||
      (hasBlob && value instanceof Blob)
  );
}
function isUnauthorizedError(message: any): boolean { function isUnauthorizedError(message: any): boolean {
// TODO we need better protocol definition with error codes // TODO we need better protocol definition with error codes
return message['status'] === 'error' && message['error'] === 'Unauthorized'; return message['status'] === 'error' && message['error'] === 'Unauthorized';