/* eslint-disable @typescript-eslint/no-explicit-any */
import { decode, encode } from "@msgpack/msgpack";
import {
  ContextFunction,
  InterpretOnChangeFunction,
  Service,
  createMachine,
  guard,
  immediate,
  interpret,
  reduce,
  state,
  transition,
} from "robot3";
import { TOKEN_EXPIRATION_SECONDS, getTemporaryAuthToken } from "./auth";
import { RequiredConfig } from "./config";
import { ApiError } from "./response";
import { isBrowser } from "./runtime";
import { ensureEndpointIdFormat, isReact, throttle } from "./utils";

/**
 * See https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
 */
const WebSocketErrorCodes = {
  NORMAL_CLOSURE: 1000,
  GOING_AWAY: 1001,
};

/**
 * The data carried by the connection state machine between transitions.
 */
interface Context {
  /** The temporary auth token, once one has been obtained. */
  token?: string;
  /** The latest message waiting for the connection to become active. */
  enqueuedMessage?: any;
  /** The socket of the currently established connection, if any. */
  websocket?: WebSocket;
  error?: Error;
}

const initialState: ContextFunction<Context> = () => ({
  enqueuedMessage: undefined,
});

type SendEvent = { type: "send"; message: any };
type AuthenticatedEvent = { type: "authenticated"; token: string };
type InitiateAuthEvent = { type: "initiateAuth" };
type UnauthorizedEvent = { type: "unauthorized"; error: Error };
type ConnectedEvent = { type: "connected"; websocket: WebSocket };
type ConnectionClosedEvent = {
  type: "connectionClosed";
  code: number;
  reason: string;
};

type Event =
  | SendEvent
  | AuthenticatedEvent
  | InitiateAuthEvent
  | UnauthorizedEvent
  | ConnectedEvent
  | ConnectionClosedEvent;

/** Guard: `true` when the context already holds an auth token. */
function hasToken(context: Context): boolean {
  return context.token !== undefined;
}

/** Guard: `true` when the context holds no auth token yet. */
function noToken(context: Context): boolean {
  return !hasToken(context);
}

/**
 * Reducer: stores the event's message as the single pending message,
 * replacing any message that was enqueued before it.
 */
function enqueueMessage(context: Context, event: SendEvent): Context {
  return {
    ...context,
    enqueuedMessage: event.message,
  };
}

/**
 * Reducer: closes the socket held by the context (when it is open) and
 * removes it from the context.
 */
function closeConnection(context: Context): Context {
  if (context.websocket && context.websocket.readyState === WebSocket.OPEN) {
    // Close with an explicit "normal closure" status. The `onclose` handler
    // reports every code other than NORMAL_CLOSURE through `onError`, and a
    // close frame sent without a status code is typically surfaced as 1005,
    // which would make a deliberate close look like a failure.
    context.websocket.close(WebSocketErrorCodes.NORMAL_CLOSURE);
  }
  return {
    ...context,
    websocket: undefined,
  };
}

/**
 * Reducer: sends the event's message through the open socket and clears the
 * pending message. `Uint8Array` payloads are sent as-is, anything else is
 * msgpack-encoded first. When there is no open socket the message is kept
 * as the pending message instead.
 */
function sendMessage(context: Context, event: SendEvent): Context {
  if (context.websocket && context.websocket.readyState === WebSocket.OPEN) {
    if (event.message instanceof Uint8Array) {
      context.websocket.send(event.message);
    } else {
      context.websocket.send(encode(event.message));
    }
    return {
      ...context,
      enqueuedMessage: undefined,
    };
  }
  return {
    ...context,
    enqueuedMessage: event.message,
  };
}

/** Reducer: drops the auth token so the next connection re-authenticates. */
function expireToken(context: Context): Context {
  return {
    ...context,
    token: undefined,
  };
}

/** Reducer: stores the token delivered by the `authenticated` event. */
function setToken(context: Context, event: AuthenticatedEvent): Context {
  return {
    ...context,
    token: event.token,
  };
}

/** Reducer: stores the socket delivered by the `connected` event. */
function connectionEstablished(
  context: Context,
  event: ConnectedEvent,
): Context {
  return {
    ...context,
    websocket: event.websocket,
  };
}

// State machine
//
// idle -> connecting -> (authRequired -> authInProgress -> connecting) -> active
//
// Entering `connecting` without a token immediately moves to `authRequired`.
// Every state accepts `close`, which returns to `idle`.
const connectionStateMachine = createMachine(
  "idle",
  {
    idle: state(
      transition("send", "connecting", reduce(enqueueMessage)),
      transition("expireToken", "idle", reduce(expireToken)),
      transition("close", "idle", reduce(closeConnection)),
    ),
    connecting: state(
      transition("connecting", "connecting"),
      transition("connected", "active", reduce(connectionEstablished)),
      transition("connectionClosed", "idle", reduce(closeConnection)),
      transition("send", "connecting", reduce(enqueueMessage)),
      transition("close", "idle", reduce(closeConnection)),
      immediate("authRequired", guard(noToken)),
    ),
    authRequired: state(
      transition("initiateAuth", "authInProgress"),
      transition("send", "authRequired", reduce(enqueueMessage)),
      transition("close", "idle", reduce(closeConnection)),
    ),
    authInProgress: state(
      transition("authenticated", "connecting", reduce(setToken)),
      transition(
        "unauthorized",
        "idle",
        reduce(expireToken),
        reduce(closeConnection),
      ),
      transition("send", "authInProgress", reduce(enqueueMessage)),
      transition("close", "idle", reduce(closeConnection)),
    ),
    active: state(
      transition("send", "active", reduce(sendMessage)),
      transition("unauthorized", "idle", reduce(expireToken)),
      transition("connectionClosed", "idle", reduce(closeConnection)),
      transition("close", "idle", reduce(closeConnection)),
    ),
    // NOTE(review): no transition in this machine targets `failed`.
    failed: state(
      transition("send", "failed"),
      transition("close", "idle", reduce(closeConnection)),
    ),
  },
  initialState,
);

type WithRequestId = {
  request_id: string;
};

/**
 * A connection object that allows you to `send` request payloads to a
 * realtime endpoint.
 */
export interface RealtimeConnection<Input = any> {
  send(input: Input & Partial<WithRequestId>): void;

  close(): void;
}

/**
 * Options for connecting to the realtime endpoint.
 */
export interface RealtimeConnectionHandler<Output = any> {
  /**
   * The connection key. This is used to reuse the same connection
   * across multiple calls to `connect`. This is particularly useful in
   * contexts where the connection is established as part of a component
   * lifecycle (e.g. React) and the component is re-rendered multiple times.
   */
  connectionKey?: string;

  /**
   * If `true`, the connection will only be established on the client side.
   * This is useful for frameworks that reuse code for both server-side
   * rendering and client-side rendering (e.g. Next.js).
   *
   * This is set to `true` by default when running on React in the server.
   * Otherwise, it is set to `false`.
   *
   * Note that more SSR frameworks might be automatically detected
   * in the future. In the meantime, you can set this to `true` when needed.
   */
  clientOnly?: boolean;

  /**
   * The throttle duration in milliseconds. This is used to throttle the
   * calls to the `send` function. Realtime apps usually react to user
   * input, which can be very frequent (e.g. fast typing or mouse/drag movements).
   *
   * The default value is `128` milliseconds.
   */
  throttleInterval?: number;

  /**
   * Configures the maximum amount of frames to store in memory before starting to drop
   * old ones in favor of the newer ones. It must be between `1` and `60`.
   *
   * The recommended is `2`. The default is `undefined` so it can be determined
   * by the app (normally is set to the recommended setting).
   */
  maxBuffering?: number;

  /**
   * Callback function that is called when a result is received.
   * @param result - The result of the request.
   */
  onResult(result: Output & WithRequestId): void;

  /**
   * Callback function that is called when an error occurs.
   * @param error - The error that occurred.
   */
  onError?(error: ApiError<any>): void;
}

export interface RealtimeClient {
  /**
   * Connect to the realtime endpoint. The default implementation uses
   * WebSockets to connect to fal function endpoints that support WSS.
   *
   * @param app the app alias or identifier.
   * @param handler the connection handler.
   */
  connect<Input = any, Output = any>(
    app: string,
    handler: RealtimeConnectionHandler<Output>,
  ): RealtimeConnection<Input>;
}

type RealtimeUrlParams = {
  token: string;
  maxBuffering?: number;
};

/**
 * Builds the `wss://` URL of the realtime endpoint of the given app.
 *
 * @param app the app alias or identifier.
 * @param params the auth token and the optional `maxBuffering` setting.
 * @returns the realtime endpoint URL including its query string.
 * @throws Error when `maxBuffering` is provided and is not a number
 * between 1 and 60 (inclusive).
 */
function buildRealtimeUrl(
  app: string,
  { token, maxBuffering }: RealtimeUrlParams,
): string {
  // Written as a negated range check so that `NaN` is rejected as well:
  // `NaN < 1` and `NaN > 60` are both false.
  if (
    maxBuffering !== undefined &&
    !(maxBuffering >= 1 && maxBuffering <= 60)
  ) {
    throw new Error("The `maxBuffering` must be between 1 and 60 (inclusive)");
  }
  const queryParams = new URLSearchParams({
    fal_jwt_token: token,
  });
  if (maxBuffering !== undefined) {
    queryParams.set("max_buffering", maxBuffering.toFixed(0));
  }
  const appId = ensureEndpointIdFormat(app);
  return `wss://fal.run/${appId}/realtime?${queryParams.toString()}`;
}

const DEFAULT_THROTTLE_INTERVAL = 128;

/**
 * Checks whether a parsed JSON message is the "Unauthorized" error payload.
 * A `null`/`undefined` message is never an unauthorized error.
 */
function isUnauthorizedError(message: any): boolean {
  // TODO we need better protocol definition with error codes
  return (
    message != null &&
    message["status"] === "error" &&
    message["error"] === "Unauthorized"
  );
}

type ConnectionStateMachine = Service<typeof connectionStateMachine> & {
  throttledSend: (
    event: Event,
    payload?: any,
  ) => void | Promise<void> | undefined;
};

type ConnectionOnChange = InterpretOnChangeFunction<
  typeof connectionStateMachine
>;

type RealtimeConnectionCallback = Pick<
  RealtimeConnectionHandler,
  "onResult" | "onError"
>;

// Interpreters and the latest callbacks, both indexed by connection key.
const connectionCache = new Map<string, ConnectionStateMachine>();
const connectionCallbacks = new Map<string, RealtimeConnectionCallback>();

/**
 * Returns the cached interpreter for `key`, creating and caching it on the
 * first call. Note that `throttleInterval` and `onChange` are only used when
 * the interpreter is created; later calls with the same key get the
 * interpreter that was built with the first call's arguments.
 *
 * @param key the connection key.
 * @param throttleInterval the throttle interval in milliseconds; when it is
 * not greater than zero `throttledSend` is the plain `send`.
 * @param onChange the listener passed to `interpret`.
 */
function reuseInterpreter(
  key: string,
  throttleInterval: number,
  onChange: ConnectionOnChange,
) {
  if (!connectionCache.has(key)) {
    const machine = interpret(connectionStateMachine, onChange);
    connectionCache.set(key, {
      ...machine,
      throttledSend:
        throttleInterval > 0
          ? throttle(machine.send, throttleInterval, true)
          : machine.send,
    });
  }
  return connectionCache.get(key) as ConnectionStateMachine;
}

const noop = () => {
  /* No-op */
};

/**
 * A no-op connection that does not send any message.
 * Useful on the frameworks that reuse code for both ssr and csr (e.g. Next)
 * so the call when doing ssr has no side-effects.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const NoOpConnection: RealtimeConnection<any> = {
  send: noop,
  close: noop,
};

/**
 * Checks whether a parsed JSON message should be delivered to `onResult`:
 * it must not be an error status, an `x-fal-message` or an `x-fal-error`.
 * A `null`/`undefined` message is not considered a result.
 */
function isSuccessfulResult(data: any): boolean {
  return (
    data != null &&
    data.status !== "error" &&
    data.type !== "x-fal-message" &&
    !isFalErrorResult(data)
  );
}

type FalErrorResult = {
  type: "x-fal-error";
  error: string;
  reason: string;
};

/** Type guard for messages tagged as `x-fal-error`. */
function isFalErrorResult(data: any): data is FalErrorResult {
  return data != null && data.type === "x-fal-error";
}

type RealtimeClientDependencies = {
  config: RequiredConfig;
};

/**
 * Creates the realtime client. Its `connect` function drives a cached
 * connection state machine that authenticates, opens the WebSocket and
 * forwards incoming messages to the handler callbacks.
 *
 * @param dependencies the client configuration.
 * @returns the realtime client.
 */
export function createRealtimeClient({
  config,
}: RealtimeClientDependencies): RealtimeClient {
  return {
    connect<Input, Output>(
      app: string,
      handler: RealtimeConnectionHandler<Output>,
    ): RealtimeConnection<Input> {
      const {
        // if running on React in the server, set clientOnly to true by default
        clientOnly = isReact() && !isBrowser(),
        maxBuffering,
        throttleInterval = DEFAULT_THROTTLE_INTERVAL,
      } = handler;
      if (clientOnly && !isBrowser()) {
        return NoOpConnection;
      }
      // Generated only after the no-op early return so that the server-side
      // no-op path does not depend on a global `crypto` being available.
      const connectionKey = handler.connectionKey ?? crypto.randomUUID();

      let previousState: string | undefined;

      // Although the state machine is cached so we don't open multiple connections,
      // we still need to update the callbacks so we can call the correct references
      // when the state machine is reused. This is needed because the callbacks
      // are passed as part of the handler object, which can be different across
      // different calls to `connect`.
      connectionCallbacks.set(connectionKey, {
        onError: handler.onError,
        onResult: handler.onResult,
      });
      const getCallbacks = () =>
        connectionCallbacks.get(connectionKey) as RealtimeConnectionCallback;

      const stateMachine = reuseInterpreter(
        connectionKey,
        throttleInterval,
        ({ context, machine, send }) => {
          const { enqueuedMessage, token } = context;

          // Flush the pending message as soon as the connection is active.
          if (machine.current === "active" && enqueuedMessage) {
            send({ type: "send", message: enqueuedMessage });
          }

          // Start the authentication once, when `authRequired` is entered.
          if (
            machine.current === "authRequired" &&
            token === undefined &&
            previousState !== machine.current
          ) {
            send({ type: "initiateAuth" });
            getTemporaryAuthToken(app, config)
              .then((freshToken) => {
                send({ type: "authenticated", token: freshToken });
                // Expire the token on our side at 90% of its lifetime.
                const tokenExpirationTimeout = Math.round(
                  TOKEN_EXPIRATION_SECONDS * 0.9 * 1000,
                );
                setTimeout(() => {
                  send({ type: "expireToken" });
                }, tokenExpirationTimeout);
              })
              .catch((error) => {
                send({ type: "unauthorized", error });
              });
          }

          // Open the socket once, when `connecting` is entered with a token.
          if (
            machine.current === "connecting" &&
            previousState !== machine.current &&
            token !== undefined
          ) {
            const ws = new WebSocket(
              buildRealtimeUrl(app, { token, maxBuffering }),
            );
            ws.onopen = () => {
              send({ type: "connected", websocket: ws });
            };
            ws.onclose = (event) => {
              if (event.code !== WebSocketErrorCodes.NORMAL_CLOSURE) {
                const { onError = noop } = getCallbacks();
                onError(
                  new ApiError({
                    message: `Error closing the connection: ${event.reason}`,
                    status: event.code,
                  }),
                );
              }
              send({
                type: "connectionClosed",
                code: event.code,
                reason: event.reason,
              });
            };
            ws.onerror = () => {
              // TODO specify error protocol for identified errors
              const { onError = noop } = getCallbacks();
              onError(new ApiError({ message: "Unknown error", status: 500 }));
            };
            ws.onmessage = (event) => {
              const { onResult } = getCallbacks();

              // Handle binary messages as msgpack messages
              if (event.data instanceof ArrayBuffer) {
                const result = decode(new Uint8Array(event.data));
                onResult(result);
                return;
              }
              if (event.data instanceof Uint8Array) {
                const result = decode(event.data);
                onResult(result);
                return;
              }
              if (event.data instanceof Blob) {
                event.data.arrayBuffer().then((buffer) => {
                  const result = decode(new Uint8Array(buffer));
                  onResult(result);
                });
                return;
              }

              // Otherwise handle strings as plain JSON messages
              const data = JSON.parse(event.data);

              // Drop messages that are not related to the actual result.
              // In the future, we might want to handle other types of messages.
              // TODO: specify the fal ws protocol format
              if (isUnauthorizedError(data)) {
                send({
                  type: "unauthorized",
                  error: new Error("Unauthorized"),
                });
                return;
              }
              if (isSuccessfulResult(data)) {
                onResult(data);
                return;
              }
              if (isFalErrorResult(data)) {
                if (data.error === "TIMEOUT") {
                  // Timeout error messages just indicate that the connection hasn't
                  // received an incoming message for a while. We don't need to
                  // handle them as errors.
                  return;
                }
                const { onError = noop } = getCallbacks();
                onError(
                  new ApiError({
                    message: `${data.error}: ${data.reason}`,
                    // TODO better error status code
                    status: 400,
                    body: data,
                  }),
                );
                return;
              }
            };
          }
          previousState = machine.current;
        },
      );

      const send = (input: Input & Partial<WithRequestId>) => {
        // Use throttled send to avoid sending too many messages
        stateMachine.throttledSend({
          type: "send",
          message: input,
        });
      };

      const close = () => {
        stateMachine.send({ type: "close" });
      };

      return {
        send,
        close,
      };
    },
  };
}