feat(client): realtime state machine impl (#32)

* fix: connection state handling

* chore: reset token expiration

* feat: state machine experiment

* feat: new realtime state machine impl

* chore: update client to 0.7.0 before release

* fix: error handling x-fal-error

* chore(client): release v0.7.0

* fix(client): strict type check error
This commit is contained in:
Daniel Rochetti 2023-12-04 14:44:56 -08:00 committed by GitHub
parent c020d97acd
commit 6ad41e1bfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 366 additions and 148 deletions

View File

@ -3,6 +3,7 @@
* This is only a minimal backend to get started.
*/
import * as fal from '@fal-ai/serverless-client';
import * as falProxy from '@fal-ai/serverless-proxy/express';
import cors from 'cors';
import { configDotenv } from 'dotenv';
@ -25,6 +26,16 @@ app.get('/api', (req, res) => {
res.send({ message: 'Welcome to demo-express-app!' });
});
/**
 * Demo route: runs the `110602490-lcm` app from the server with a fixed
 * prompt and sends the result back to the caller as-is.
 */
app.get('/fal-on-server', async (req, res, next) => {
  try {
    const result = await fal.run('110602490-lcm', {
      input: {
        prompt:
          'a black cat with glowing eyes, cute, adorable, disney, pixar, highly detailed, 8k',
      },
    });
    res.send(result);
  } catch (error) {
    // Express versions before 5 do not catch rejections from async handlers:
    // without forwarding the error the request would hang with no response.
    next(error);
  }
});
const port = process.env.PORT || 3333;
const server = app.listen(port, () => {
console.log(`Listening at http://localhost:${port}/api`);

View File

@ -2,8 +2,8 @@
/* eslint-disable @next/next/no-img-element */
import * as fal from '@fal-ai/serverless-client';
import { DrawingCanvas } from '../../components/drawing';
import { useState } from 'react';
import { DrawingCanvas } from '../../components/drawing';
fal.config({
proxyUrl: '/api/fal/proxy',
@ -14,8 +14,9 @@ const PROMPT = 'a moon in a starry night sky';
export default function RealtimePage() {
const [image, setImage] = useState<string | null>(null);
const { send } = fal.realtime.connect('110602490-shared-lcm-test', {
const { send } = fal.realtime.connect('110602490-lcm-sd15-i2i', {
connectionKey: 'realtime-demo',
throttleInterval: 128,
onResult(result) {
if (result.images && result.images[0]) {
setImage(result.images[0].url);

View File

@ -1,7 +1,7 @@
{
"name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client",
"version": "0.6.1",
"version": "0.7.0",
"license": "MIT",
"repository": {
"type": "git",
@ -14,5 +14,11 @@
"client",
"ai",
"ml"
]
],
"dependencies": {
"robot3": "^0.4.1"
},
"engines": {
"node": ">=18.0.0"
}
}

View File

@ -1,7 +1,7 @@
import {
withMiddleware,
withProxy,
type RequestMiddleware,
withMiddleware,
} from './middleware';
import type { ResponseHandler } from './response';
import { defaultResponseHandler } from './response';

View File

@ -1,6 +1,6 @@
import { getConfig } from './config';
import { storageImpl } from './storage';
import { dispatchRequest } from './request';
import { storageImpl } from './storage';
import { EnqueueResult, QueueStatus } from './types';
import { isUUIDv4, isValidUrl } from './utils';

View File

@ -1,23 +1,170 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import {
ContextFunction,
createMachine,
guard,
immediate,
interpret,
InterpretOnChangeFunction,
reduce,
Service,
state,
transition,
} from 'robot3';
import { getConfig, getRestApiUrl } from './config';
import { dispatchRequest } from './request';
import { ApiError } from './response';
import { isBrowser } from './runtime';
import { isReact, throttle } from './utils';
// Define the context: the data carried by the connection state machine
// between transitions. The reducers in this file return updated copies of it.
interface Context {
// Auth token; stored by `setToken` and reset to `undefined` by `expireToken`.
token?: string;
// Most recent message waiting to be sent. Written by `enqueueMessage` (and by
// `sendMessage` when the socket is not open); cleared by `sendMessage` once sent.
enqueuedMessage?: any;
// Socket stored by `connectionEstablished` and reset by `closeConnection`.
websocket?: WebSocket;
// NOTE(review): no reducer visible in this file assigns this field — confirm usage.
error?: Error;
}
/**
 * Builds the starting context of the connection state machine:
 * no message is waiting to be sent.
 */
const initialState: ContextFunction<Context> = () => {
  const startingContext: Context = { enqueuedMessage: undefined };
  return startingContext;
};
// Events understood by the connection state machine.

// Request to deliver `message`; enqueued or sent depending on the current state.
type SendEvent = { type: 'send'; message: any };
// A token was obtained; `setToken` stores it in the context.
type AuthenticatedEvent = { type: 'authenticated'; token: string };
// Moves the machine from `authRequired` to `authInProgress`.
type InitiateAuthEvent = { type: 'initiateAuth' };
// Authentication failed or was rejected; the stored token is expired.
type UnauthorizedEvent = { type: 'unauthorized'; error: Error };
// The socket opened; `connectionEstablished` stores it in the context.
type ConnectedEvent = { type: 'connected'; websocket: WebSocket };
// The socket was closed, with the close `code` and `reason`.
type ConnectionClosedEvent = {
type: 'connectionClosed';
code: number;
reason: string;
};
// Union of the typed events above.
// NOTE(review): `expireToken` and `close` events are also dispatched in this
// file but are not part of this union.
type Event =
| SendEvent
| AuthenticatedEvent
| InitiateAuthEvent
| UnauthorizedEvent
| ConnectedEvent
| ConnectionClosedEvent;
/** Guard: `true` when the context holds a token (anything but `undefined`). */
function hasToken(context: Context): boolean {
  const { token } = context;
  return !(token === undefined);
}
/** Guard: the exact negation of `hasToken`. */
function noToken(context: Context): boolean {
  const tokenPresent = hasToken(context);
  return tokenPresent === false;
}
/**
 * Reducer: remembers the event's message as the one waiting to be sent,
 * replacing any previously enqueued message. Returns a new context.
 */
function enqueueMessage(context: Context, event: SendEvent): Context {
  const next: Context = { ...context };
  next.enqueuedMessage = event.message;
  return next;
}
/**
 * Reducer: closes the socket held in the context (when it is still open)
 * and returns a new context without it.
 *
 * The socket is closed with an explicit "normal closure" status code. When
 * `close()` is called without a code, browsers typically report the close
 * event with code 1005, and the `onclose` handler in this file reports any
 * code other than the normal closure through `onError`.
 *
 * @param context the current machine context.
 * @returns a copy of the context with `websocket` set to `undefined`.
 */
function closeConnection(context: Context): Context {
  // See https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
  const NORMAL_CLOSURE = 1000;
  const { websocket } = context;
  if (websocket && websocket.readyState === WebSocket.OPEN) {
    websocket.close(NORMAL_CLOSURE);
  }
  return {
    ...context,
    websocket: undefined,
  };
}
/**
 * Reducer: delivers the event's message over the socket when it is open and
 * clears the enqueued message; otherwise keeps the message enqueued so it
 * can be sent later. Messages are serialized as JSON.
 */
function sendMessage(context: Context, event: SendEvent): Context {
  const socket = context.websocket;
  // Not connected (or not open yet): hold on to the message instead.
  if (!socket || socket.readyState !== WebSocket.OPEN) {
    return { ...context, enqueuedMessage: event.message };
  }
  socket.send(JSON.stringify(event.message));
  return { ...context, enqueuedMessage: undefined };
}
/** Reducer: forgets the stored token. Returns a new context. */
function expireToken(context: Context): Context {
  const next: Context = { ...context };
  next.token = undefined;
  return next;
}
/** Reducer: stores the token carried by the `authenticated` event. */
function setToken(context: Context, event: AuthenticatedEvent): Context {
  const { token } = event;
  return { ...context, token };
}
/** Reducer: stores the socket carried by the `connected` event. */
function connectionEstablished(
  context: Context,
  event: ConnectedEvent
): Context {
  const next: Context = { ...context };
  next.websocket = event.websocket;
  return next;
}
/**
 * State machine that drives a realtime connection.
 *
 * - `idle`: nothing is open; the first `send` enqueues the message and
 *   starts connecting.
 * - `connecting`: entered with a message enqueued; jumps straight to
 *   `authRequired` when the context has no token, otherwise waits for
 *   `connected` (or `connectionClosed`).
 * - `authRequired` / `authInProgress`: a token is being requested; messages
 *   sent meanwhile replace the enqueued one.
 * - `active`: the socket is stored in the context and `send` delivers
 *   messages through it.
 * - `failed`: no transition in this machine leads here.
 *
 * The `close` event, dispatched by the connection's `close()`, brings the
 * machine back to `idle` and closes the socket held in the context.
 */
const connectionStateMachine = createMachine(
  'idle',
  {
    idle: state(
      transition('send', 'connecting', reduce(enqueueMessage)),
      transition('expireToken', 'idle', reduce(expireToken)),
      transition('close', 'idle', reduce(closeConnection))
    ),
    // `close` is intentionally not handled while connecting: the socket being
    // opened is only stored in the context on `connected`, so it could not be
    // closed from here and would be left open and untracked.
    connecting: state(
      transition('connecting', 'connecting'),
      transition('connected', 'active', reduce(connectionEstablished)),
      transition('connectionClosed', 'idle', reduce(closeConnection)),
      transition('send', 'connecting', reduce(enqueueMessage)),
      immediate('authRequired', guard(noToken))
    ),
    authRequired: state(
      transition('initiateAuth', 'authInProgress'),
      transition('send', 'authRequired', reduce(enqueueMessage)),
      transition('close', 'idle', reduce(closeConnection))
    ),
    authInProgress: state(
      transition('authenticated', 'connecting', reduce(setToken)),
      transition(
        'unauthorized',
        'idle',
        reduce(expireToken),
        reduce(closeConnection)
      ),
      transition('send', 'authInProgress', reduce(enqueueMessage)),
      transition('close', 'idle', reduce(closeConnection))
    ),
    active: state(
      transition('send', 'active', reduce(sendMessage)),
      transition('unauthorized', 'idle', reduce(expireToken)),
      transition('connectionClosed', 'idle', reduce(closeConnection)),
      transition('close', 'idle', reduce(closeConnection))
    ),
    failed: state(
      transition('send', 'failed'),
      transition('close', 'idle', reduce(closeConnection))
    ),
  },
  initialState
);
// Shape of the `request_id` field: optional on `send` inputs
// (`Partial<WithRequestId>`) and present on the payloads given to `onResult`.
type WithRequestId = {
request_id: string;
};
/**
* A connection object that allows you to `send` request payloads to a
* realtime endpoint.
*/
export interface RealtimeConnection<Input> {
send(input: Input): void;
send(input: Input & Partial<WithRequestId>): void;
close(): void;
}
type ResultWithRequestId = {
request_id: string;
};
/**
* Options for connecting to the realtime endpoint.
*/
@ -46,23 +193,31 @@ export interface RealtimeConnectionHandler<Output> {
/**
* The throttle duration in milliseconds. This is used to throttle the
* calls to the `send` function. Realtime apps usually react to user
* input, which can be very frequesnt (e.g. fast typing or mouse/drag movements).
* input, which can be very frequent (e.g. fast typing or mouse/drag movements).
*
* The default value is `64` milliseconds.
* The default value is `128` milliseconds.
*/
throttleInterval?: number;
/**
* Configures the maximum number of frames to store in memory before starting to drop
* old ones in favor of the newer ones. It must be between `1` and `60`.
*
* The recommended value is `2`. The default is `undefined` so it can be determined
* by the app (normally it is set to the recommended setting).
*/
maxBuffering?: number;
/**
* Callback function that is called when a result is received.
* @param result - The result of the request.
*/
onResult(result: Output & ResultWithRequestId): void;
onResult(result: Output & WithRequestId): void;
/**
* Callback function that is called when an error occurs.
* @param error - The error that occurred.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
onError?(error: ApiError<any>): void;
}
@ -74,19 +229,36 @@ export interface RealtimeClient {
* @param app the app alias or identifier.
* @param handler the connection handler.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
connect<Input = any, Output = any>(
app: string,
handler: RealtimeConnectionHandler<Output>
): RealtimeConnection<Input>;
}
function buildRealtimeUrl(app: string): string {
type RealtimeUrlParams = {
token: string;
maxBuffering?: number;
};
function buildRealtimeUrl(
app: string,
{ token, maxBuffering }: RealtimeUrlParams
): string {
const { host } = getConfig();
return `wss://${app}.${host}/ws`;
if (maxBuffering !== undefined && (maxBuffering < 1 || maxBuffering > 60)) {
throw new Error('The `maxBuffering` must be between 1 and 60 (inclusive)');
}
const queryParams = new URLSearchParams({
fal_jwt_token: token,
});
if (maxBuffering !== undefined) {
queryParams.set('max_buffering', maxBuffering.toFixed(0));
}
return `wss://${app}.${host}/ws?${queryParams.toString()}`;
}
// Token lifetime, in seconds, requested as `token_expiration` when fetching a
// token. The client expires its own copy after 90% of this time.
const TOKEN_EXPIRATION_SECONDS = 120;
// Default `throttleInterval`, in milliseconds, used when the handler sets none.
const DEFAULT_THROTTLE_INTERVAL = 128;
/**
* Get a token to connect to the realtime endpoint.
@ -98,7 +270,7 @@ async function getToken(app: string): Promise<string> {
`https://${getRestApiUrl()}/tokens/`,
{
allowed_apps: [appAlias.join('-')],
token_expiration: 120,
token_expiration: TOKEN_EXPIRATION_SECONDS,
}
);
// keep this in case the response was wrapped (old versions of the proxy do that)
@ -109,6 +281,11 @@ async function getToken(app: string): Promise<string> {
return token;
}
/**
 * Tells whether an incoming message is the server's "Unauthorized" error,
 * i.e. `{ status: 'error', error: 'Unauthorized' }`.
 *
 * Messages are parsed from JSON, so `null` is a possible input: it is
 * treated as "not an unauthorized error" instead of throwing.
 *
 * @param message the parsed message payload.
 * @returns `true` only for the unauthorized error message.
 */
function isUnauthorizedError(message: any): boolean {
  // TODO we need better protocol definition with error codes
  return message?.status === 'error' && message?.error === 'Unauthorized';
}
/**
* See https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
*/
@ -117,71 +294,40 @@ const WebSocketErrorCodes = {
GOING_AWAY: 1001,
};
const connectionManager = (() => {
const connections = new Map<string, WebSocket>();
const tokens = new Map<string, string>();
const isAuthInProgress = new Map<string, true>();
// The interpreted `connectionStateMachine` service extended with
// `throttledSend`: a rate-limited variant of `send` when a throttle interval
// is configured, or `send` itself otherwise (see `reuseInterpreter`).
type ConnectionStateMachine = Service<typeof connectionStateMachine> & {
throttledSend: (
event: Event,
payload?: any
) => void | Promise<void> | undefined;
};
return {
token(app: string) {
return tokens.get(app);
},
expireToken(app: string) {
tokens.delete(app);
},
async refreshToken(app: string) {
const token = await getToken(app);
tokens.set(app, token);
// Very simple token expiration mechanism.
// We should make it more robust in the future.
setTimeout(() => {
tokens.delete(app);
}, TOKEN_EXPIRATION_SECONDS * 0.9 * 1000);
return token;
},
has(connectionKey: string): boolean {
return connections.has(connectionKey);
},
get(connectionKey: string): WebSocket | undefined {
return connections.get(connectionKey);
},
set(connectionKey: string, ws: WebSocket) {
connections.set(connectionKey, ws);
},
remove(connectionKey: string) {
connections.delete(connectionKey);
},
isAuthInProgress(app: string) {
return isAuthInProgress.has(app);
},
setAuthInProgress(app: string, inProgress: boolean) {
if (inProgress) {
isAuthInProgress.set(app, true);
} else {
isAuthInProgress.delete(app);
}
},
};
})();
// Type of the change callback handed to `interpret` for `connectionStateMachine`.
type ConnectionOnChange = InterpretOnChangeFunction<
typeof connectionStateMachine
>;
async function getConnection(app: string, key: string): Promise<WebSocket> {
if (connectionManager.isAuthInProgress(app)) {
throw new Error('Authentication in progress');
}
const url = buildRealtimeUrl(app);
// The subset of the connection handler kept per connection key:
// only the `onResult` and `onError` callbacks.
type RealtimeConnectionCallback = Pick<
RealtimeConnectionHandler<any>,
'onResult' | 'onError'
>;
if (connectionManager.has(key)) {
return connectionManager.get(key) as WebSocket;
// Interpreted state machines, keyed by connection key, so that connecting
// again with the same key reuses the existing machine.
const connectionCache = new Map<string, ConnectionStateMachine>();
// Latest `onResult`/`onError` callbacks registered for each connection key.
const connectionCallbacks = new Map<string, RealtimeConnectionCallback>();
function reuseInterpreter(
key: string,
throttleInterval: number,
onChange: ConnectionOnChange
) {
if (!connectionCache.has(key)) {
const machine = interpret(connectionStateMachine, onChange);
connectionCache.set(key, {
...machine,
throttledSend:
throttleInterval > 0
? throttle(machine.send, throttleInterval, true)
: machine.send,
});
}
let token = connectionManager.token(app);
if (!token) {
connectionManager.setAuthInProgress(app, true);
token = await connectionManager.refreshToken(app);
connectionManager.setAuthInProgress(app, false);
}
const ws = new WebSocket(`${url}?fal_jwt_token=${token}`);
connectionManager.set(key, ws);
return ws;
return connectionCache.get(key) as ConnectionStateMachine;
}
const noop = () => {
@ -199,6 +345,24 @@ const NoOpConnection: RealtimeConnection<any> = {
close: noop,
};
/**
 * Tells whether an incoming message is an actual result, as opposed to an
 * error (`status: 'error'`), an `x-fal-message` or an `x-fal-error`.
 *
 * Messages are parsed from JSON, so `null` is a possible input: an absent
 * payload is never a result (and must not throw on property access).
 *
 * @param data the parsed message payload.
 * @returns `true` when the payload should be delivered as a result.
 */
function isSuccessfulResult(data: any): boolean {
  if (data === null || data === undefined) {
    return false;
  }
  return (
    data.status !== 'error' &&
    data.type !== 'x-fal-message' &&
    !isFalErrorResult(data)
  );
}
/** Error message sent by the server, tagged with `type: 'x-fal-error'`. */
type FalErrorResult = {
  type: 'x-fal-error';
  error: string;
  reason: string;
};

/**
 * Type guard for `x-fal-error` messages.
 *
 * Messages are parsed from JSON, so `null` is a possible input: it is
 * reported as "not an error result" instead of throwing.
 *
 * @param data the parsed message payload.
 * @returns `true` when the payload is tagged as an `x-fal-error`.
 */
function isFalErrorResult(data: any): data is FalErrorResult {
  return data?.type === 'x-fal-error';
}
/**
* The default implementation of the realtime client.
*/
@ -211,58 +375,68 @@ export const realtimeImpl: RealtimeClient = {
// if running on React in the server, set clientOnly to true by default
clientOnly = isReact() && !isBrowser(),
connectionKey = crypto.randomUUID(),
throttleInterval = 64,
onError = noop,
onResult,
maxBuffering,
throttleInterval = DEFAULT_THROTTLE_INTERVAL,
} = handler;
if (clientOnly && typeof window === 'undefined') {
if (clientOnly && !isBrowser()) {
return NoOpConnection;
}
let pendingMessage: Input | undefined = undefined;
let previousState: string | undefined;
let reconnecting = false;
let ws: WebSocket | null = null;
const _send = (input: Input) => {
const requestId = crypto.randomUUID();
if (ws && ws.readyState === WebSocket.OPEN) {
ws.send(
JSON.stringify({
request_id: requestId,
...input,
})
);
} else {
pendingMessage = input;
if (!reconnecting) {
reconnecting = true;
reconnect();
// Although the state machine is cached so we don't open multiple connections,
// we still need to update the callbacks so we can call the correct references
// when the state machine is reused. This is needed because the callbacks
// are passed as part of the handler object, which can be different across
// different calls to `connect`.
connectionCallbacks.set(connectionKey, {
onError: handler.onError,
onResult: handler.onResult,
});
const getCallbacks = () =>
connectionCallbacks.get(connectionKey) as RealtimeConnectionCallback;
const stateMachine = reuseInterpreter(
connectionKey,
throttleInterval,
({ context, machine, send }) => {
const { enqueuedMessage, token } = context;
if (machine.current === 'active' && enqueuedMessage) {
send({ type: 'send', message: enqueuedMessage });
}
}
};
const send =
throttleInterval > 0 ? throttle(_send, throttleInterval) : _send;
const reconnect = () => {
if (ws && ws.readyState === WebSocket.OPEN) {
return;
}
if (connectionManager.isAuthInProgress(app)) {
return;
}
getConnection(app, connectionKey)
.then((connection) => {
ws = connection;
if (
machine.current === 'authRequired' &&
token === undefined &&
previousState !== machine.current
) {
send({ type: 'initiateAuth' });
getToken(app)
.then((token) => {
send({ type: 'authenticated', token });
const tokenExpirationTimeout = Math.round(
TOKEN_EXPIRATION_SECONDS * 0.9 * 1000
);
setTimeout(() => {
send({ type: 'expireToken' });
}, tokenExpirationTimeout);
})
.catch((error) => {
send({ type: 'unauthorized', error });
});
}
if (
machine.current === 'connecting' &&
previousState !== machine.current &&
token !== undefined
) {
const ws = new WebSocket(
buildRealtimeUrl(app, { token, maxBuffering })
);
ws.onopen = () => {
reconnecting = false;
if (pendingMessage) {
send(pendingMessage);
pendingMessage = undefined;
}
send({ type: 'connected', websocket: ws });
};
ws.onclose = (event) => {
connectionManager.remove(connectionKey);
if (event.code !== WebSocketErrorCodes.NORMAL_CLOSURE) {
const { onError = noop } = getCallbacks();
onError(
new ApiError({
message: `Error closing the connection: ${event.reason}`,
@ -270,16 +444,11 @@ export const realtimeImpl: RealtimeClient = {
})
);
}
ws = null;
send({ type: 'connectionClosed', code: event.code });
};
ws.onerror = (event) => {
// TODO handle errors once server specify them
// if error 401, refresh token and retry
// if error 403, refresh token and retry
connectionManager.expireToken(app);
connectionManager.remove(connectionKey);
ws = null;
// if any of those are failed again, call onError
// TODO specify error protocol for identified errors
const { onError = noop } = getCallbacks();
onError(new ApiError({ message: 'Unknown error', status: 500 }));
};
ws.onmessage = (event) => {
@ -287,28 +456,51 @@ export const realtimeImpl: RealtimeClient = {
// Drop messages that are not related to the actual result.
// In the future, we might want to handle other types of messages.
// TODO: specify the fal ws protocol format
if (data.status !== 'error' && data.type !== 'x-fal-message') {
if (isUnauthorizedError(data)) {
send({ type: 'unauthorized', error: new Error('Unauthorized') });
return;
}
if (isSuccessfulResult(data)) {
const { onResult } = getCallbacks();
onResult(data);
return;
}
if (isFalErrorResult(data)) {
const { onError = noop } = getCallbacks();
onError(
new ApiError({
message: `${data.error}: ${data.reason}`,
// TODO better error status code
status: 400,
body: data,
})
);
return;
}
};
})
.catch((error) => {
onError(
new ApiError({ message: 'Error opening connection', status: 500 })
);
});
}
previousState = machine.current;
}
);
const send = (input: Input & Partial<WithRequestId>) => {
// Use throttled send to avoid sending too many messages
stateMachine.throttledSend({
type: 'send',
message: {
...input,
request_id: input['request_id'] ?? crypto.randomUUID(),
},
});
};
const close = () => {
stateMachine.send({ type: 'close' });
};
return {
send,
close() {
if (ws && ws.readyState === WebSocket.CLOSED) {
ws.close(
WebSocketErrorCodes.GOING_AWAY,
'Client manually closed the connection.'
);
}
},
close,
};
},
};

View File

@ -19,13 +19,14 @@ export function isValidUrl(url: string) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function throttle<T extends (...args: any[]) => any>(
func: T,
limit: number
limit: number,
leading = false
): (...funcArgs: Parameters<T>) => ReturnType<T> | void {
let lastFunc: NodeJS.Timeout | null;
let lastRan: number;
return (...args: Parameters<T>): ReturnType<T> | void => {
if (!lastRan) {
if (!lastRan && leading) {
func(...args);
lastRan = Date.now();
} else {

View File

@ -1,5 +1,5 @@
import { NextResponse, type NextRequest } from 'next/server';
import type { NextApiHandler } from 'next/types';
import { type NextRequest, NextResponse } from 'next/server';
import { DEFAULT_PROXY_ROUTE, handleRequest } from './index';
/**

6
package-lock.json generated
View File

@ -28,6 +28,7 @@
"react": "^18.2.0",
"react-dom": "^18.2.0",
"regenerator-runtime": "0.13.7",
"robot3": "^0.4.1",
"ts-morph": "^17.0.1",
"tslib": "^2.3.0"
},
@ -24588,6 +24589,11 @@
"url": "https://github.com/sponsors/isaacs"
}
},
"node_modules/robot3": {
"version": "0.4.1",
"resolved": "https://registry.npmjs.org/robot3/-/robot3-0.4.1.tgz",
"integrity": "sha512-hzjy826lrxzx8eRgv80idkf8ua1JAepRc9Efdtj03N3KNJuznQCPlyCJ7gnUmDFwZCLQjxy567mQVKmdv2BsXQ=="
},
"node_modules/run-parallel": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz",

View File

@ -44,6 +44,7 @@
"react": "^18.2.0",
"react-dom": "^18.2.0",
"regenerator-runtime": "0.13.7",
"robot3": "^0.4.1",
"ts-morph": "^17.0.1",
"tslib": "^2.3.0"
},