From 07f1b125f7da72b4a86a41b3cd95a5d041379df3 Mon Sep 17 00:00:00 2001 From: Daniel Rochetti Date: Mon, 15 Jan 2024 15:14:21 -0800 Subject: [PATCH] feat: fal.run url support (#40) * feat: fal.run url support * feat: new queue url format * fix: build issue * chore: change demo app ids --- .../app/camera-turbo/page.tsx | 2 +- apps/demo-nextjs-app-router/app/page.tsx | 2 +- .../app/realtime/page.tsx | 2 +- .../app/whisper/page.tsx | 2 +- apps/demo-nextjs-page-router/pages/index.tsx | 4 +- libs/client/package.json | 2 +- libs/client/src/config.spec.ts | 2 - libs/client/src/config.ts | 19 +----- libs/client/src/function.spec.ts | 12 ++-- libs/client/src/function.ts | 65 +++++++++++++------ libs/client/src/realtime.ts | 14 ++-- libs/client/src/storage.ts | 2 +- libs/client/src/utils.spec.ts | 17 ++++- libs/client/src/utils.ts | 18 ++++- libs/proxy/package.json | 2 +- libs/proxy/src/index.ts | 6 +- 16 files changed, 106 insertions(+), 65 deletions(-) diff --git a/apps/demo-nextjs-app-router/app/camera-turbo/page.tsx b/apps/demo-nextjs-app-router/app/camera-turbo/page.tsx index b70261e..716b400 100644 --- a/apps/demo-nextjs-app-router/app/camera-turbo/page.tsx +++ b/apps/demo-nextjs-app-router/app/camera-turbo/page.tsx @@ -137,7 +137,7 @@ export default function WebcamPage() { const previewRef = useRef(null); const { send } = fal.realtime.connect( - '110602490-sd-turbo-real-time-high-fps-msgpack', + 'fal-ai/sd-turbo-real-time-high-fps-msgpack', { connectionKey: 'camera-turbo-demo', // not throttling the client, handling throttling of the camera itself diff --git a/apps/demo-nextjs-app-router/app/page.tsx b/apps/demo-nextjs-app-router/app/page.tsx index e3a6062..a24c7c5 100644 --- a/apps/demo-nextjs-app-router/app/page.tsx +++ b/apps/demo-nextjs-app-router/app/page.tsx @@ -82,7 +82,7 @@ export default function Home() { const start = Date.now(); try { const result: Result = await fal.subscribe( - '54285744-illusion-diffusion', + '54285744/illusion-diffusion', { input: { prompt, diff --git a/apps/demo-nextjs-app-router/app/realtime/page.tsx b/apps/demo-nextjs-app-router/app/realtime/page.tsx index a9d9229..f39bd4c 100644 --- a/apps/demo-nextjs-app-router/app/realtime/page.tsx +++ b/apps/demo-nextjs-app-router/app/realtime/page.tsx @@ -14,7 +14,7 @@ const PROMPT = 'a moon in a starry night sky'; export default function RealtimePage() { const [image, setImage] = useState(null); - const { send } = fal.realtime.connect('110602490-lcm-sd15-i2i', { + const { send } = fal.realtime.connect('fal-ai/lcm-sd15-i2i', { connectionKey: 'realtime-demo', throttleInterval: 128, onResult(result) { diff --git a/apps/demo-nextjs-app-router/app/whisper/page.tsx b/apps/demo-nextjs-app-router/app/whisper/page.tsx index b79c091..daf02f8 100644 --- a/apps/demo-nextjs-app-router/app/whisper/page.tsx +++ b/apps/demo-nextjs-app-router/app/whisper/page.tsx @@ -110,7 +110,7 @@ export default function WhisperDemo() { setLoading(true); const start = Date.now(); try { - const result = await fal.subscribe('110602490-whisper', { + const result = await fal.subscribe('fal-ai/whisper', { input: { file_name: 'recording.wav', audio_url: audioFile, diff --git a/apps/demo-nextjs-page-router/pages/index.tsx b/apps/demo-nextjs-page-router/pages/index.tsx index 0f121ab..832aaaa 100644 --- a/apps/demo-nextjs-page-router/pages/index.tsx +++ b/apps/demo-nextjs-page-router/pages/index.tsx @@ -74,13 +74,13 @@ export function Index() { setLoading(true); const start = Date.now(); try { - const result: Result = await fal.subscribe('110602490-lora', { + const result: Result = await fal.subscribe('fal-ai/lora', { input: { prompt, model_name: 'stabilityai/stable-diffusion-xl-base-1.0', image_size: 'square_hd', }, - pollInterval: 5000, // Default is 1000 (every 1s) + pollInterval: 3000, // Default is 1000 (every 1s) logs: true, onQueueUpdate(update) { setElapsedTime(Date.now() - start); diff --git a/libs/client/package.json b/libs/client/package.json index 3892690..1b5027b 100644 --- a/libs/client/package.json +++ b/libs/client/package.json @@ -1,7 +1,7 @@ { "name": "@fal-ai/serverless-client", "description": "The fal serverless JS/TS client", - "version": "0.7.4", + "version": "0.8.0", "license": "MIT", "repository": { "type": "git", diff --git a/libs/client/src/config.spec.ts b/libs/client/src/config.spec.ts index 4f98ec0..1ef8e09 100644 --- a/libs/client/src/config.spec.ts +++ b/libs/client/src/config.spec.ts @@ -3,12 +3,10 @@ import { config, getConfig } from './config'; describe('The config test suite', () => { it('should set the config variables accordingly', () => { const newConfig = { - host: 'some-other-host', credentials: 'key-id:key-secret', }; config(newConfig); const currentConfig = getConfig(); - expect(currentConfig.host).toBe(newConfig.host); expect(currentConfig.credentials).toEqual(newConfig.credentials); }); }); diff --git a/libs/client/src/config.ts b/libs/client/src/config.ts index 9bfb6e8..5395082 100644 --- a/libs/client/src/config.ts +++ b/libs/client/src/config.ts @@ -10,7 +10,6 @@ export type CredentialsResolver = () => string | undefined; export type Config = { credentials?: undefined | string | CredentialsResolver; - host?: string; proxyUrl?: string; requestMiddleware?: RequestMiddleware; responseHandler?: ResponseHandler; @@ -46,22 +45,7 @@ export const credentialsFromEnv: CredentialsResolver = () => { return `${process.env.FAL_KEY_ID}:${process.env.FAL_KEY_SECRET}`; }; -/** - * Get the default host for the fal-serverless gateway endpoint. - * @private - * @returns the default host. Depending on the platform it can default to - * the environment variable `FAL_HOST`. - */ -function getDefaultHost(): string { - const host = 'gateway.alpha.fal.ai'; - if (typeof process !== 'undefined' && process.env) { - return process.env.FAL_HOST || host; - } - return host; -} - const DEFAULT_CONFIG: Partial = { - host: getDefaultHost(), credentials: credentialsFromEnv, requestMiddleware: (request) => Promise.resolve(request), responseHandler: defaultResponseHandler, @@ -104,6 +88,5 @@ export function getConfig(): RequiredConfig { * @returns the URL of the fal serverless rest api endpoint. */ export function getRestApiUrl(): string { - const { host } = getConfig(); - return host.replace('gateway', 'rest'); + return 'https://rest.alpha.fal.ai'; } diff --git a/libs/client/src/function.spec.ts b/libs/client/src/function.spec.ts index 6c0e5e5..6c47667 100644 --- a/libs/client/src/function.spec.ts +++ b/libs/client/src/function.spec.ts @@ -1,5 +1,4 @@ import uuid from 'uuid-random'; -import { getConfig } from './config'; import { buildUrl } from './function'; describe('The function test suite', () => { @@ -9,10 +8,15 @@ describe('The function test suite', () => { expect(url).toMatch(`trigger/12345/${id}`); }); - it('should build the URL with a function alias', () => { - const { host } = getConfig(); + it('should build the URL with a function user-id-app-alias', () => { const alias = '12345-some-alias'; const url = buildUrl(alias); - expect(url).toMatch(`${alias}.${host}`); + expect(url).toMatch(`fal.run/12345/some-alias`); + }); + + it('should build the URL with a function username/app-alias', () => { + const alias = 'fal-ai/text-to-image'; + const url = buildUrl(alias); + expect(url).toMatch(`fal.run/${alias}`); }); }); diff --git a/libs/client/src/function.ts b/libs/client/src/function.ts index 9c42f6e..f354ba6 100644 --- a/libs/client/src/function.ts +++ b/libs/client/src/function.ts @@ -1,8 +1,7 @@ -import { getConfig } from './config'; import { dispatchRequest } from './request'; import { storageImpl } from './storage'; import { EnqueueResult, QueueStatus } from './types'; -import { isUUIDv4, isValidUrl } from './utils'; +import { ensureAppIdFormat, isUUIDv4, isValidUrl } from './utils'; /** * The function input and other configuration when running @@ -36,6 +35,15 @@ type RunOptions = { readonly autoUpload?: boolean; }; +type ExtraOptions = { + /** + * If `true`, the function will use the queue to run the function + * asynchronously and return the result in a separate call. This + * influences how the URL is built. + */ + readonly subdomain?: string; +}; + /** * Builds the final url to run the function based on its `id` or alias and * a the options from `RunOptions`. @@ -47,9 +55,8 @@ type RunOptions = { */ export function buildUrl( id: string, - options: RunOptions = {} + options: RunOptions & ExtraOptions = {} ): string { - const { host } = getConfig(); const method = (options.method ?? 'post').toLowerCase(); const path = (options.path ?? '').replace(/^\//, '').replace(/\/{2,}/, '/'); const input = options.input; @@ -59,16 +66,37 @@ export function buildUrl( const queryParams = params ? `?${params.toString()}` : ''; const parts = id.split('/'); - // if a fal.ai url is passed, just use it + // if a fal url is passed, just use it if (isValidUrl(id)) { const url = id.endsWith('/') ? id : `${id}/`; return `${url}${path}${queryParams}`; } + // TODO remove this after some time, fal.run should be preferred if (parts.length === 2 && isUUIDv4(parts[1])) { + const host = 'gateway.shark.fal.ai'; return `https://${host}/trigger/${id}/${path}${queryParams}`; } - return `https://${id}.${host}/${path}${queryParams}`; + + const appId = ensureAppIdFormat(id); + const subdomain = options.subdomain ? `${options.subdomain}.` : ''; + const url = `https://${subdomain}fal.run/${appId}/${path}`; + return `${url.replace(/\/$/, '')}${queryParams}`; +} + +export async function send( + id: string, + options: RunOptions & ExtraOptions = {} +): Promise { + const input = + options.input && options.autoUpload !== false + ? await storageImpl.transformInput(options.input) + : options.input; + return dispatchRequest( + options.method ?? 'post', + buildUrl(id, options), + input as Input + ); } /** @@ -81,15 +109,7 @@ export async function run( id: string, options: RunOptions = {} ): Promise { - const input = - options.input && options.autoUpload !== false - ? await storageImpl.transformInput(options.input) - : options.input; - return dispatchRequest( - options.method ?? 'post', - buildUrl(id, options), - input as Input - ); + return send(id, options); } /** @@ -252,19 +272,21 @@ export const queue: Queue = { const query = webhookUrl ? '?' + new URLSearchParams({ fal_webhook: webhookUrl }).toString() : ''; - return run(id, { + return send(id, { ...runOptions, + subdomain: 'queue', method: 'post', - path: '/fal/queue/submit' + path + query, + path: path + query, }); }, async status( id: string, { requestId, logs = false }: QueueStatusOptions ): Promise { - return run(id, { + return send(id, { + subdomain: 'queue', method: 'get', - path: `/fal/queue/requests/${requestId}/status`, + path: `/requests/${requestId}/status`, input: { logs: logs ? '1' : '0', }, @@ -274,9 +296,10 @@ export const queue: Queue = { id: string, { requestId }: BaseQueueOptions ): Promise { - return run(id, { + return send(id, { + subdomain: 'queue', method: 'get', - path: `/fal/queue/requests/${requestId}/response`, + path: `/requests/${requestId}`, }); }, subscribe, diff --git a/libs/client/src/realtime.ts b/libs/client/src/realtime.ts index 403b23c..72afcc4 100644 --- a/libs/client/src/realtime.ts +++ b/libs/client/src/realtime.ts @@ -13,11 +13,11 @@ import { transition, } from 'robot3'; import uuid from 'uuid-random'; -import { getConfig, getRestApiUrl } from './config'; +import { getRestApiUrl } from './config'; import { dispatchRequest } from './request'; import { ApiError } from './response'; import { isBrowser } from './runtime'; -import { isReact, throttle } from './utils'; +import { ensureAppIdFormat, isReact, throttle } from './utils'; // Define the context interface Context { @@ -253,7 +253,6 @@ function buildRealtimeUrl( app: string, { token, maxBuffering }: RealtimeUrlParams ): string { - const { host } = getConfig(); if (maxBuffering !== undefined && (maxBuffering < 1 || maxBuffering > 60)) { throw new Error('The `maxBuffering` must be between 1 and 60 (inclusive)'); } @@ -263,7 +262,8 @@ function buildRealtimeUrl( if (maxBuffering !== undefined) { queryParams.set('max_buffering', maxBuffering.toFixed(0)); } - return `wss://${app}.${host}/ws?${queryParams.toString()}`; + const appId = ensureAppIdFormat(app); + return `wss://fal.run/${appId}/ws?${queryParams.toString()}`; } const TOKEN_EXPIRATION_SECONDS = 120; @@ -282,12 +282,12 @@ function shouldSendBinary(message: any): boolean { * Get a token to connect to the realtime endpoint. */ async function getToken(app: string): Promise { - const [_, ...appAlias] = app.split('-'); + const [, appAlias] = ensureAppIdFormat(app).split('/'); const token: string | object = await dispatchRequest( 'POST', - `https://${getRestApiUrl()}/tokens/`, + `${getRestApiUrl()}/tokens/`, { - allowed_apps: [appAlias.join('-')], + allowed_apps: [appAlias], token_expiration: TOKEN_EXPIRATION_SECONDS, } ); diff --git a/libs/client/src/storage.ts b/libs/client/src/storage.ts index d150114..203bbbd 100644 --- a/libs/client/src/storage.ts +++ b/libs/client/src/storage.ts @@ -75,7 +75,7 @@ async function initiateUpload(file: Blob): Promise { file.name || `${Date.now()}.${getExtensionFromContentType(contentType)}`; return await dispatchRequest( 'POST', - `https://${getRestApiUrl()}/storage/upload/initiate`, + `${getRestApiUrl()}/storage/upload/initiate`, { content_type: contentType, file_name: filename, diff --git a/libs/client/src/utils.spec.ts b/libs/client/src/utils.spec.ts index 0140d96..e995936 100644 --- a/libs/client/src/utils.spec.ts +++ b/libs/client/src/utils.spec.ts @@ -1,5 +1,5 @@ import uuid from 'uuid-random'; -import { isUUIDv4 } from './utils'; +import { ensureAppIdFormat, isUUIDv4 } from './utils'; describe('The utils test suite', () => { it('should match a valid v4 uuid', () => { @@ -11,4 +11,19 @@ describe('The utils test suite', () => { const id = 'e726b886-e2c2-11ed-b5ea-0242ac120002'; expect(isUUIDv4(id)).toBe(false); }); + + it('shoud match match a legacy appOwner-appId format', () => { + const id = '12345-abcde-fgh'; + expect(ensureAppIdFormat(id)).toBe('12345/abcde-fgh'); + }); + + it('shoud match a current appOwner/appId format', () => { + const id = 'fal-ai/fast-sdxl'; + expect(ensureAppIdFormat(id)).toBe(id); + }); + + it('should throw on an invalid app id format', () => { + const id = 'just-an-id'; + expect(() => ensureAppIdFormat(id)).toThrowError(); + }); }); diff --git a/libs/client/src/utils.ts b/libs/client/src/utils.ts index cb73b89..0d10fec 100644 --- a/libs/client/src/utils.ts +++ b/libs/client/src/utils.ts @@ -7,10 +7,24 @@ export function isUUIDv4(id: string): boolean { ); } +export function ensureAppIdFormat(id: string): string { + const parts = id.split('/'); + if (parts.length === 2) { + return id; + } + const [, appOwner, appId] = /^([0-9]+)-([a-zA-Z0-9-]+)$/.exec(id) || []; + if (appOwner && appId) { + return `${appOwner}/${appId}`; + } + throw new Error( + `Invalid app id: ${id}. Must be in the format /` + ); +} + export function isValidUrl(url: string) { try { - const parsedUrl = new URL(url); - return parsedUrl.hostname.endsWith('fal.ai'); + const { host } = new URL(url); + return /(fal\.(ai|run))$/.test(host); } catch (_) { return false; } diff --git a/libs/proxy/package.json b/libs/proxy/package.json index a03afdc..eaf6636 100644 --- a/libs/proxy/package.json +++ b/libs/proxy/package.json @@ -1,6 +1,6 @@ { "name": "@fal-ai/serverless-proxy", - "version": "0.6.0", + "version": "0.7.0", "license": "MIT", "repository": { "type": "git", diff --git a/libs/proxy/src/index.ts b/libs/proxy/src/index.ts index 2902876..351d8c3 100644 --- a/libs/proxy/src/index.ts +++ b/libs/proxy/src/index.ts @@ -9,6 +9,8 @@ const FAL_KEY_SECRET = export type HeaderValue = string | string[] | undefined | null; +const FAL_URL_REG_EXP = /(fal\.(run|ai))$/; + /** * The proxy behavior that is passed to the proxy handler. This is a subset of * request objects that are used by different frameworks, like Express and NextJS. @@ -69,7 +71,9 @@ export async function handleRequest( if (!targetUrl) { return behavior.respondWith(400, `Missing the ${TARGET_URL_HEADER} header`); } - if (targetUrl.indexOf('fal.ai') === -1) { + + const urlHost = new URL(targetUrl).host; + if (!FAL_URL_REG_EXP.test(urlHost)) { return behavior.respondWith(412, `Invalid ${TARGET_URL_HEADER} header`); }