feat: fal.run url support (#40)

* feat: fal.run url support

* feat: new queue url format

* fix: build issue

* chore: change demo app ids
This commit is contained in:
Daniel Rochetti 2024-01-15 15:14:21 -08:00 committed by GitHub
parent 208073ce17
commit 07f1b125f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 106 additions and 65 deletions

View File

@ -137,7 +137,7 @@ export default function WebcamPage() {
const previewRef = useRef<HTMLCanvasElement | null>(null); const previewRef = useRef<HTMLCanvasElement | null>(null);
const { send } = fal.realtime.connect<LCMInput, LCMOutput>( const { send } = fal.realtime.connect<LCMInput, LCMOutput>(
'110602490-sd-turbo-real-time-high-fps-msgpack', 'fal-ai/sd-turbo-real-time-high-fps-msgpack',
{ {
connectionKey: 'camera-turbo-demo', connectionKey: 'camera-turbo-demo',
// not throttling the client, handling throttling of the camera itself // not throttling the client, handling throttling of the camera itself

View File

@ -82,7 +82,7 @@ export default function Home() {
const start = Date.now(); const start = Date.now();
try { try {
const result: Result = await fal.subscribe( const result: Result = await fal.subscribe(
'54285744-illusion-diffusion', '54285744/illusion-diffusion',
{ {
input: { input: {
prompt, prompt,

View File

@ -14,7 +14,7 @@ const PROMPT = 'a moon in a starry night sky';
export default function RealtimePage() { export default function RealtimePage() {
const [image, setImage] = useState<string | null>(null); const [image, setImage] = useState<string | null>(null);
const { send } = fal.realtime.connect('110602490-lcm-sd15-i2i', { const { send } = fal.realtime.connect('fal-ai/lcm-sd15-i2i', {
connectionKey: 'realtime-demo', connectionKey: 'realtime-demo',
throttleInterval: 128, throttleInterval: 128,
onResult(result) { onResult(result) {

View File

@ -110,7 +110,7 @@ export default function WhisperDemo() {
setLoading(true); setLoading(true);
const start = Date.now(); const start = Date.now();
try { try {
const result = await fal.subscribe('110602490-whisper', { const result = await fal.subscribe('fal-ai/whisper', {
input: { input: {
file_name: 'recording.wav', file_name: 'recording.wav',
audio_url: audioFile, audio_url: audioFile,

View File

@ -74,13 +74,13 @@ export function Index() {
setLoading(true); setLoading(true);
const start = Date.now(); const start = Date.now();
try { try {
const result: Result = await fal.subscribe('110602490-lora', { const result: Result = await fal.subscribe('fal-ai/lora', {
input: { input: {
prompt, prompt,
model_name: 'stabilityai/stable-diffusion-xl-base-1.0', model_name: 'stabilityai/stable-diffusion-xl-base-1.0',
image_size: 'square_hd', image_size: 'square_hd',
}, },
pollInterval: 5000, // Default is 1000 (every 1s) pollInterval: 3000, // Default is 1000 (every 1s)
logs: true, logs: true,
onQueueUpdate(update) { onQueueUpdate(update) {
setElapsedTime(Date.now() - start); setElapsedTime(Date.now() - start);

View File

@ -1,7 +1,7 @@
{ {
"name": "@fal-ai/serverless-client", "name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client", "description": "The fal serverless JS/TS client",
"version": "0.7.4", "version": "0.8.0",
"license": "MIT", "license": "MIT",
"repository": { "repository": {
"type": "git", "type": "git",

View File

@ -3,12 +3,10 @@ import { config, getConfig } from './config';
describe('The config test suite', () => { describe('The config test suite', () => {
it('should set the config variables accordingly', () => { it('should set the config variables accordingly', () => {
const newConfig = { const newConfig = {
host: 'some-other-host',
credentials: 'key-id:key-secret', credentials: 'key-id:key-secret',
}; };
config(newConfig); config(newConfig);
const currentConfig = getConfig(); const currentConfig = getConfig();
expect(currentConfig.host).toBe(newConfig.host);
expect(currentConfig.credentials).toEqual(newConfig.credentials); expect(currentConfig.credentials).toEqual(newConfig.credentials);
}); });
}); });

View File

@ -10,7 +10,6 @@ export type CredentialsResolver = () => string | undefined;
export type Config = { export type Config = {
credentials?: undefined | string | CredentialsResolver; credentials?: undefined | string | CredentialsResolver;
host?: string;
proxyUrl?: string; proxyUrl?: string;
requestMiddleware?: RequestMiddleware; requestMiddleware?: RequestMiddleware;
responseHandler?: ResponseHandler<any>; responseHandler?: ResponseHandler<any>;
@ -46,22 +45,7 @@ export const credentialsFromEnv: CredentialsResolver = () => {
return `${process.env.FAL_KEY_ID}:${process.env.FAL_KEY_SECRET}`; return `${process.env.FAL_KEY_ID}:${process.env.FAL_KEY_SECRET}`;
}; };
/**
* Get the default host for the fal-serverless gateway endpoint.
* @private
* @returns the default host. Depending on the platform it can default to
* the environment variable `FAL_HOST`.
*/
function getDefaultHost(): string {
const host = 'gateway.alpha.fal.ai';
if (typeof process !== 'undefined' && process.env) {
return process.env.FAL_HOST || host;
}
return host;
}
const DEFAULT_CONFIG: Partial<Config> = { const DEFAULT_CONFIG: Partial<Config> = {
host: getDefaultHost(),
credentials: credentialsFromEnv, credentials: credentialsFromEnv,
requestMiddleware: (request) => Promise.resolve(request), requestMiddleware: (request) => Promise.resolve(request),
responseHandler: defaultResponseHandler, responseHandler: defaultResponseHandler,
@ -104,6 +88,5 @@ export function getConfig(): RequiredConfig {
* @returns the URL of the fal serverless rest api endpoint. * @returns the URL of the fal serverless rest api endpoint.
*/ */
export function getRestApiUrl(): string { export function getRestApiUrl(): string {
const { host } = getConfig(); return 'https://rest.alpha.fal.ai';
return host.replace('gateway', 'rest');
} }

View File

@ -1,5 +1,4 @@
import uuid from 'uuid-random'; import uuid from 'uuid-random';
import { getConfig } from './config';
import { buildUrl } from './function'; import { buildUrl } from './function';
describe('The function test suite', () => { describe('The function test suite', () => {
@ -9,10 +8,15 @@ describe('The function test suite', () => {
expect(url).toMatch(`trigger/12345/${id}`); expect(url).toMatch(`trigger/12345/${id}`);
}); });
it('should build the URL with a function alias', () => { it('should build the URL with a function user-id-app-alias', () => {
const { host } = getConfig();
const alias = '12345-some-alias'; const alias = '12345-some-alias';
const url = buildUrl(alias); const url = buildUrl(alias);
expect(url).toMatch(`${alias}.${host}`); expect(url).toMatch(`fal.run/12345/some-alias`);
});
it('should build the URL with a function username/app-alias', () => {
const alias = 'fal-ai/text-to-image';
const url = buildUrl(alias);
expect(url).toMatch(`fal.run/${alias}`);
}); });
}); });

View File

@ -1,8 +1,7 @@
import { getConfig } from './config';
import { dispatchRequest } from './request'; import { dispatchRequest } from './request';
import { storageImpl } from './storage'; import { storageImpl } from './storage';
import { EnqueueResult, QueueStatus } from './types'; import { EnqueueResult, QueueStatus } from './types';
import { isUUIDv4, isValidUrl } from './utils'; import { ensureAppIdFormat, isUUIDv4, isValidUrl } from './utils';
/** /**
* The function input and other configuration when running * The function input and other configuration when running
@ -36,6 +35,15 @@ type RunOptions<Input> = {
readonly autoUpload?: boolean; readonly autoUpload?: boolean;
}; };
type ExtraOptions = {
/**
* If `true`, the function will use the queue to run the function
* asynchronously and return the result in a separate call. This
* influences how the URL is built.
*/
readonly subdomain?: string;
};
/** /**
* Builds the final url to run the function based on its `id` or alias and * Builds the final url to run the function based on its `id` or alias and
* a the options from `RunOptions<Input>`. * a the options from `RunOptions<Input>`.
@ -47,9 +55,8 @@ type RunOptions<Input> = {
*/ */
export function buildUrl<Input>( export function buildUrl<Input>(
id: string, id: string,
options: RunOptions<Input> = {} options: RunOptions<Input> & ExtraOptions = {}
): string { ): string {
const { host } = getConfig();
const method = (options.method ?? 'post').toLowerCase(); const method = (options.method ?? 'post').toLowerCase();
const path = (options.path ?? '').replace(/^\//, '').replace(/\/{2,}/, '/'); const path = (options.path ?? '').replace(/^\//, '').replace(/\/{2,}/, '/');
const input = options.input; const input = options.input;
@ -59,16 +66,37 @@ export function buildUrl<Input>(
const queryParams = params ? `?${params.toString()}` : ''; const queryParams = params ? `?${params.toString()}` : '';
const parts = id.split('/'); const parts = id.split('/');
// if a fal.ai url is passed, just use it // if a fal url is passed, just use it
if (isValidUrl(id)) { if (isValidUrl(id)) {
const url = id.endsWith('/') ? id : `${id}/`; const url = id.endsWith('/') ? id : `${id}/`;
return `${url}${path}${queryParams}`; return `${url}${path}${queryParams}`;
} }
// TODO remove this after some time, fal.run should be preferred
if (parts.length === 2 && isUUIDv4(parts[1])) { if (parts.length === 2 && isUUIDv4(parts[1])) {
const host = 'gateway.shark.fal.ai';
return `https://${host}/trigger/${id}/${path}${queryParams}`; return `https://${host}/trigger/${id}/${path}${queryParams}`;
} }
return `https://${id}.${host}/${path}${queryParams}`;
const appId = ensureAppIdFormat(id);
const subdomain = options.subdomain ? `${options.subdomain}.` : '';
const url = `https://${subdomain}fal.run/${appId}/${path}`;
return `${url.replace(/\/$/, '')}${queryParams}`;
}
export async function send<Input, Output>(
id: string,
options: RunOptions<Input> & ExtraOptions = {}
): Promise<Output> {
const input =
options.input && options.autoUpload !== false
? await storageImpl.transformInput(options.input)
: options.input;
return dispatchRequest<Input, Output>(
options.method ?? 'post',
buildUrl(id, options),
input as Input
);
} }
/** /**
@ -81,15 +109,7 @@ export async function run<Input, Output>(
id: string, id: string,
options: RunOptions<Input> = {} options: RunOptions<Input> = {}
): Promise<Output> { ): Promise<Output> {
const input = return send(id, options);
options.input && options.autoUpload !== false
? await storageImpl.transformInput(options.input)
: options.input;
return dispatchRequest<Input, Output>(
options.method ?? 'post',
buildUrl(id, options),
input as Input
);
} }
/** /**
@ -252,19 +272,21 @@ export const queue: Queue = {
const query = webhookUrl const query = webhookUrl
? '?' + new URLSearchParams({ fal_webhook: webhookUrl }).toString() ? '?' + new URLSearchParams({ fal_webhook: webhookUrl }).toString()
: ''; : '';
return run(id, { return send(id, {
...runOptions, ...runOptions,
subdomain: 'queue',
method: 'post', method: 'post',
path: '/fal/queue/submit' + path + query, path: path + query,
}); });
}, },
async status( async status(
id: string, id: string,
{ requestId, logs = false }: QueueStatusOptions { requestId, logs = false }: QueueStatusOptions
): Promise<QueueStatus> { ): Promise<QueueStatus> {
return run(id, { return send(id, {
subdomain: 'queue',
method: 'get', method: 'get',
path: `/fal/queue/requests/${requestId}/status`, path: `/requests/${requestId}/status`,
input: { input: {
logs: logs ? '1' : '0', logs: logs ? '1' : '0',
}, },
@ -274,9 +296,10 @@ export const queue: Queue = {
id: string, id: string,
{ requestId }: BaseQueueOptions { requestId }: BaseQueueOptions
): Promise<Output> { ): Promise<Output> {
return run(id, { return send(id, {
subdomain: 'queue',
method: 'get', method: 'get',
path: `/fal/queue/requests/${requestId}/response`, path: `/requests/${requestId}`,
}); });
}, },
subscribe, subscribe,

View File

@ -13,11 +13,11 @@ import {
transition, transition,
} from 'robot3'; } from 'robot3';
import uuid from 'uuid-random'; import uuid from 'uuid-random';
import { getConfig, getRestApiUrl } from './config'; import { getRestApiUrl } from './config';
import { dispatchRequest } from './request'; import { dispatchRequest } from './request';
import { ApiError } from './response'; import { ApiError } from './response';
import { isBrowser } from './runtime'; import { isBrowser } from './runtime';
import { isReact, throttle } from './utils'; import { ensureAppIdFormat, isReact, throttle } from './utils';
// Define the context // Define the context
interface Context { interface Context {
@ -253,7 +253,6 @@ function buildRealtimeUrl(
app: string, app: string,
{ token, maxBuffering }: RealtimeUrlParams { token, maxBuffering }: RealtimeUrlParams
): string { ): string {
const { host } = getConfig();
if (maxBuffering !== undefined && (maxBuffering < 1 || maxBuffering > 60)) { if (maxBuffering !== undefined && (maxBuffering < 1 || maxBuffering > 60)) {
throw new Error('The `maxBuffering` must be between 1 and 60 (inclusive)'); throw new Error('The `maxBuffering` must be between 1 and 60 (inclusive)');
} }
@ -263,7 +262,8 @@ function buildRealtimeUrl(
if (maxBuffering !== undefined) { if (maxBuffering !== undefined) {
queryParams.set('max_buffering', maxBuffering.toFixed(0)); queryParams.set('max_buffering', maxBuffering.toFixed(0));
} }
return `wss://${app}.${host}/ws?${queryParams.toString()}`; const appId = ensureAppIdFormat(app);
return `wss://fal.run/${appId}/ws?${queryParams.toString()}`;
} }
const TOKEN_EXPIRATION_SECONDS = 120; const TOKEN_EXPIRATION_SECONDS = 120;
@ -282,12 +282,12 @@ function shouldSendBinary(message: any): boolean {
* Get a token to connect to the realtime endpoint. * Get a token to connect to the realtime endpoint.
*/ */
async function getToken(app: string): Promise<string> { async function getToken(app: string): Promise<string> {
const [_, ...appAlias] = app.split('-'); const [, appAlias] = ensureAppIdFormat(app).split('/');
const token: string | object = await dispatchRequest<any, string>( const token: string | object = await dispatchRequest<any, string>(
'POST', 'POST',
`https://${getRestApiUrl()}/tokens/`, `${getRestApiUrl()}/tokens/`,
{ {
allowed_apps: [appAlias.join('-')], allowed_apps: [appAlias],
token_expiration: TOKEN_EXPIRATION_SECONDS, token_expiration: TOKEN_EXPIRATION_SECONDS,
} }
); );

View File

@ -75,7 +75,7 @@ async function initiateUpload(file: Blob): Promise<InitiateUploadResult> {
file.name || `${Date.now()}.${getExtensionFromContentType(contentType)}`; file.name || `${Date.now()}.${getExtensionFromContentType(contentType)}`;
return await dispatchRequest<InitiateUploadData, InitiateUploadResult>( return await dispatchRequest<InitiateUploadData, InitiateUploadResult>(
'POST', 'POST',
`https://${getRestApiUrl()}/storage/upload/initiate`, `${getRestApiUrl()}/storage/upload/initiate`,
{ {
content_type: contentType, content_type: contentType,
file_name: filename, file_name: filename,

View File

@ -1,5 +1,5 @@
import uuid from 'uuid-random'; import uuid from 'uuid-random';
import { isUUIDv4 } from './utils'; import { ensureAppIdFormat, isUUIDv4 } from './utils';
describe('The utils test suite', () => { describe('The utils test suite', () => {
it('should match a valid v4 uuid', () => { it('should match a valid v4 uuid', () => {
@ -11,4 +11,19 @@ describe('The utils test suite', () => {
const id = 'e726b886-e2c2-11ed-b5ea-0242ac120002'; const id = 'e726b886-e2c2-11ed-b5ea-0242ac120002';
expect(isUUIDv4(id)).toBe(false); expect(isUUIDv4(id)).toBe(false);
}); });
it('shoud match match a legacy appOwner-appId format', () => {
const id = '12345-abcde-fgh';
expect(ensureAppIdFormat(id)).toBe('12345/abcde-fgh');
});
it('shoud match a current appOwner/appId format', () => {
const id = 'fal-ai/fast-sdxl';
expect(ensureAppIdFormat(id)).toBe(id);
});
it('should throw on an invalid app id format', () => {
const id = 'just-an-id';
expect(() => ensureAppIdFormat(id)).toThrowError();
});
}); });

View File

@ -7,10 +7,24 @@ export function isUUIDv4(id: string): boolean {
); );
} }
export function ensureAppIdFormat(id: string): string {
const parts = id.split('/');
if (parts.length === 2) {
return id;
}
const [, appOwner, appId] = /^([0-9]+)-([a-zA-Z0-9-]+)$/.exec(id) || [];
if (appOwner && appId) {
return `${appOwner}/${appId}`;
}
throw new Error(
`Invalid app id: ${id}. Must be in the format <appOwner>/<appId>`
);
}
export function isValidUrl(url: string) { export function isValidUrl(url: string) {
try { try {
const parsedUrl = new URL(url); const { host } = new URL(url);
return parsedUrl.hostname.endsWith('fal.ai'); return /(fal\.(ai|run))$/.test(host);
} catch (_) { } catch (_) {
return false; return false;
} }

View File

@ -1,6 +1,6 @@
{ {
"name": "@fal-ai/serverless-proxy", "name": "@fal-ai/serverless-proxy",
"version": "0.6.0", "version": "0.7.0",
"license": "MIT", "license": "MIT",
"repository": { "repository": {
"type": "git", "type": "git",

View File

@ -9,6 +9,8 @@ const FAL_KEY_SECRET =
export type HeaderValue = string | string[] | undefined | null; export type HeaderValue = string | string[] | undefined | null;
const FAL_URL_REG_EXP = /(fal\.(run|ai))$/;
/** /**
* The proxy behavior that is passed to the proxy handler. This is a subset of * The proxy behavior that is passed to the proxy handler. This is a subset of
* request objects that are used by different frameworks, like Express and NextJS. * request objects that are used by different frameworks, like Express and NextJS.
@ -69,7 +71,9 @@ export async function handleRequest<ResponseType>(
if (!targetUrl) { if (!targetUrl) {
return behavior.respondWith(400, `Missing the ${TARGET_URL_HEADER} header`); return behavior.respondWith(400, `Missing the ${TARGET_URL_HEADER} header`);
} }
if (targetUrl.indexOf('fal.ai') === -1) {
const urlHost = new URL(targetUrl).host;
if (!FAL_URL_REG_EXP.test(urlHost)) {
return behavior.respondWith(412, `Invalid ${TARGET_URL_HEADER} header`); return behavior.respondWith(412, `Invalid ${TARGET_URL_HEADER} header`);
} }