feat: fal.run url support (#40)

* feat: fal.run url support

* feat: new queue url format

* fix: build issue

* chore: change demo app ids
This commit is contained in:
Daniel Rochetti 2024-01-15 15:14:21 -08:00 committed by GitHub
parent 208073ce17
commit 07f1b125f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 106 additions and 65 deletions

View File

@ -137,7 +137,7 @@ export default function WebcamPage() {
const previewRef = useRef<HTMLCanvasElement | null>(null);
const { send } = fal.realtime.connect<LCMInput, LCMOutput>(
'110602490-sd-turbo-real-time-high-fps-msgpack',
'fal-ai/sd-turbo-real-time-high-fps-msgpack',
{
connectionKey: 'camera-turbo-demo',
// not throttling the client, handling throttling of the camera itself

View File

@ -82,7 +82,7 @@ export default function Home() {
const start = Date.now();
try {
const result: Result = await fal.subscribe(
'54285744-illusion-diffusion',
'54285744/illusion-diffusion',
{
input: {
prompt,

View File

@ -14,7 +14,7 @@ const PROMPT = 'a moon in a starry night sky';
export default function RealtimePage() {
const [image, setImage] = useState<string | null>(null);
const { send } = fal.realtime.connect('110602490-lcm-sd15-i2i', {
const { send } = fal.realtime.connect('fal-ai/lcm-sd15-i2i', {
connectionKey: 'realtime-demo',
throttleInterval: 128,
onResult(result) {

View File

@ -110,7 +110,7 @@ export default function WhisperDemo() {
setLoading(true);
const start = Date.now();
try {
const result = await fal.subscribe('110602490-whisper', {
const result = await fal.subscribe('fal-ai/whisper', {
input: {
file_name: 'recording.wav',
audio_url: audioFile,

View File

@ -74,13 +74,13 @@ export function Index() {
setLoading(true);
const start = Date.now();
try {
const result: Result = await fal.subscribe('110602490-lora', {
const result: Result = await fal.subscribe('fal-ai/lora', {
input: {
prompt,
model_name: 'stabilityai/stable-diffusion-xl-base-1.0',
image_size: 'square_hd',
},
pollInterval: 5000, // Default is 1000 (every 1s)
pollInterval: 3000, // Default is 1000 (every 1s)
logs: true,
onQueueUpdate(update) {
setElapsedTime(Date.now() - start);

View File

@ -1,7 +1,7 @@
{
"name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client",
"version": "0.7.4",
"version": "0.8.0",
"license": "MIT",
"repository": {
"type": "git",

View File

@ -3,12 +3,10 @@ import { config, getConfig } from './config';
describe('The config test suite', () => {
it('should set the config variables accordingly', () => {
const newConfig = {
host: 'some-other-host',
credentials: 'key-id:key-secret',
};
config(newConfig);
const currentConfig = getConfig();
expect(currentConfig.host).toBe(newConfig.host);
expect(currentConfig.credentials).toEqual(newConfig.credentials);
});
});

View File

@ -10,7 +10,6 @@ export type CredentialsResolver = () => string | undefined;
export type Config = {
credentials?: undefined | string | CredentialsResolver;
host?: string;
proxyUrl?: string;
requestMiddleware?: RequestMiddleware;
responseHandler?: ResponseHandler<any>;
@ -46,22 +45,7 @@ export const credentialsFromEnv: CredentialsResolver = () => {
return `${process.env.FAL_KEY_ID}:${process.env.FAL_KEY_SECRET}`;
};
/**
* Get the default host for the fal-serverless gateway endpoint.
* @private
* @returns the default host. Depending on the platform it can default to
* the environment variable `FAL_HOST`.
*/
function getDefaultHost(): string {
const host = 'gateway.alpha.fal.ai';
if (typeof process !== 'undefined' && process.env) {
return process.env.FAL_HOST || host;
}
return host;
}
const DEFAULT_CONFIG: Partial<Config> = {
host: getDefaultHost(),
credentials: credentialsFromEnv,
requestMiddleware: (request) => Promise.resolve(request),
responseHandler: defaultResponseHandler,
@ -104,6 +88,5 @@ export function getConfig(): RequiredConfig {
* @returns the URL of the fal serverless rest api endpoint.
*/
export function getRestApiUrl(): string {
const { host } = getConfig();
return host.replace('gateway', 'rest');
return 'https://rest.alpha.fal.ai';
}

View File

@ -1,5 +1,4 @@
import uuid from 'uuid-random';
import { getConfig } from './config';
import { buildUrl } from './function';
describe('The function test suite', () => {
@ -9,10 +8,15 @@ describe('The function test suite', () => {
expect(url).toMatch(`trigger/12345/${id}`);
});
it('should build the URL with a function alias', () => {
const { host } = getConfig();
it('should build the URL with a function user-id-app-alias', () => {
const alias = '12345-some-alias';
const url = buildUrl(alias);
expect(url).toMatch(`${alias}.${host}`);
expect(url).toMatch(`fal.run/12345/some-alias`);
});
it('should build the URL with a function username/app-alias', () => {
const alias = 'fal-ai/text-to-image';
const url = buildUrl(alias);
expect(url).toMatch(`fal.run/${alias}`);
});
});

View File

@ -1,8 +1,7 @@
import { getConfig } from './config';
import { dispatchRequest } from './request';
import { storageImpl } from './storage';
import { EnqueueResult, QueueStatus } from './types';
import { isUUIDv4, isValidUrl } from './utils';
import { ensureAppIdFormat, isUUIDv4, isValidUrl } from './utils';
/**
* The function input and other configuration when running
@ -36,6 +35,15 @@ type RunOptions<Input> = {
readonly autoUpload?: boolean;
};
type ExtraOptions = {
/**
* If `true`, the function will use the queue to run the function
* asynchronously and return the result in a separate call. This
* influences how the URL is built.
*/
readonly subdomain?: string;
};
/**
* Builds the final url to run the function based on its `id` or alias and
* a the options from `RunOptions<Input>`.
@ -47,9 +55,8 @@ type RunOptions<Input> = {
*/
export function buildUrl<Input>(
id: string,
options: RunOptions<Input> = {}
options: RunOptions<Input> & ExtraOptions = {}
): string {
const { host } = getConfig();
const method = (options.method ?? 'post').toLowerCase();
const path = (options.path ?? '').replace(/^\//, '').replace(/\/{2,}/, '/');
const input = options.input;
@ -59,16 +66,37 @@ export function buildUrl<Input>(
const queryParams = params ? `?${params.toString()}` : '';
const parts = id.split('/');
// if a fal.ai url is passed, just use it
// if a fal url is passed, just use it
if (isValidUrl(id)) {
const url = id.endsWith('/') ? id : `${id}/`;
return `${url}${path}${queryParams}`;
}
// TODO remove this after some time, fal.run should be preferred
if (parts.length === 2 && isUUIDv4(parts[1])) {
const host = 'gateway.shark.fal.ai';
return `https://${host}/trigger/${id}/${path}${queryParams}`;
}
return `https://${id}.${host}/${path}${queryParams}`;
const appId = ensureAppIdFormat(id);
const subdomain = options.subdomain ? `${options.subdomain}.` : '';
const url = `https://${subdomain}fal.run/${appId}/${path}`;
return `${url.replace(/\/$/, '')}${queryParams}`;
}
export async function send<Input, Output>(
id: string,
options: RunOptions<Input> & ExtraOptions = {}
): Promise<Output> {
const input =
options.input && options.autoUpload !== false
? await storageImpl.transformInput(options.input)
: options.input;
return dispatchRequest<Input, Output>(
options.method ?? 'post',
buildUrl(id, options),
input as Input
);
}
/**
@ -81,15 +109,7 @@ export async function run<Input, Output>(
id: string,
options: RunOptions<Input> = {}
): Promise<Output> {
const input =
options.input && options.autoUpload !== false
? await storageImpl.transformInput(options.input)
: options.input;
return dispatchRequest<Input, Output>(
options.method ?? 'post',
buildUrl(id, options),
input as Input
);
return send(id, options);
}
/**
@ -252,19 +272,21 @@ export const queue: Queue = {
const query = webhookUrl
? '?' + new URLSearchParams({ fal_webhook: webhookUrl }).toString()
: '';
return run(id, {
return send(id, {
...runOptions,
subdomain: 'queue',
method: 'post',
path: '/fal/queue/submit' + path + query,
path: path + query,
});
},
async status(
id: string,
{ requestId, logs = false }: QueueStatusOptions
): Promise<QueueStatus> {
return run(id, {
return send(id, {
subdomain: 'queue',
method: 'get',
path: `/fal/queue/requests/${requestId}/status`,
path: `/requests/${requestId}/status`,
input: {
logs: logs ? '1' : '0',
},
@ -274,9 +296,10 @@ export const queue: Queue = {
id: string,
{ requestId }: BaseQueueOptions
): Promise<Output> {
return run(id, {
return send(id, {
subdomain: 'queue',
method: 'get',
path: `/fal/queue/requests/${requestId}/response`,
path: `/requests/${requestId}`,
});
},
subscribe,

View File

@ -13,11 +13,11 @@ import {
transition,
} from 'robot3';
import uuid from 'uuid-random';
import { getConfig, getRestApiUrl } from './config';
import { getRestApiUrl } from './config';
import { dispatchRequest } from './request';
import { ApiError } from './response';
import { isBrowser } from './runtime';
import { isReact, throttle } from './utils';
import { ensureAppIdFormat, isReact, throttle } from './utils';
// Define the context
interface Context {
@ -253,7 +253,6 @@ function buildRealtimeUrl(
app: string,
{ token, maxBuffering }: RealtimeUrlParams
): string {
const { host } = getConfig();
if (maxBuffering !== undefined && (maxBuffering < 1 || maxBuffering > 60)) {
throw new Error('The `maxBuffering` must be between 1 and 60 (inclusive)');
}
@ -263,7 +262,8 @@ function buildRealtimeUrl(
if (maxBuffering !== undefined) {
queryParams.set('max_buffering', maxBuffering.toFixed(0));
}
return `wss://${app}.${host}/ws?${queryParams.toString()}`;
const appId = ensureAppIdFormat(app);
return `wss://fal.run/${appId}/ws?${queryParams.toString()}`;
}
const TOKEN_EXPIRATION_SECONDS = 120;
@ -282,12 +282,12 @@ function shouldSendBinary(message: any): boolean {
* Get a token to connect to the realtime endpoint.
*/
async function getToken(app: string): Promise<string> {
const [_, ...appAlias] = app.split('-');
const [, appAlias] = ensureAppIdFormat(app).split('/');
const token: string | object = await dispatchRequest<any, string>(
'POST',
`https://${getRestApiUrl()}/tokens/`,
`${getRestApiUrl()}/tokens/`,
{
allowed_apps: [appAlias.join('-')],
allowed_apps: [appAlias],
token_expiration: TOKEN_EXPIRATION_SECONDS,
}
);

View File

@ -75,7 +75,7 @@ async function initiateUpload(file: Blob): Promise<InitiateUploadResult> {
file.name || `${Date.now()}.${getExtensionFromContentType(contentType)}`;
return await dispatchRequest<InitiateUploadData, InitiateUploadResult>(
'POST',
`https://${getRestApiUrl()}/storage/upload/initiate`,
`${getRestApiUrl()}/storage/upload/initiate`,
{
content_type: contentType,
file_name: filename,

View File

@ -1,5 +1,5 @@
import uuid from 'uuid-random';
import { isUUIDv4 } from './utils';
import { ensureAppIdFormat, isUUIDv4 } from './utils';
describe('The utils test suite', () => {
it('should match a valid v4 uuid', () => {
@ -11,4 +11,19 @@ describe('The utils test suite', () => {
const id = 'e726b886-e2c2-11ed-b5ea-0242ac120002';
expect(isUUIDv4(id)).toBe(false);
});
it('shoud match match a legacy appOwner-appId format', () => {
const id = '12345-abcde-fgh';
expect(ensureAppIdFormat(id)).toBe('12345/abcde-fgh');
});
it('shoud match a current appOwner/appId format', () => {
const id = 'fal-ai/fast-sdxl';
expect(ensureAppIdFormat(id)).toBe(id);
});
it('should throw on an invalid app id format', () => {
const id = 'just-an-id';
expect(() => ensureAppIdFormat(id)).toThrowError();
});
});

View File

@ -7,10 +7,24 @@ export function isUUIDv4(id: string): boolean {
);
}
export function ensureAppIdFormat(id: string): string {
const parts = id.split('/');
if (parts.length === 2) {
return id;
}
const [, appOwner, appId] = /^([0-9]+)-([a-zA-Z0-9-]+)$/.exec(id) || [];
if (appOwner && appId) {
return `${appOwner}/${appId}`;
}
throw new Error(
`Invalid app id: ${id}. Must be in the format <appOwner>/<appId>`
);
}
export function isValidUrl(url: string) {
try {
const parsedUrl = new URL(url);
return parsedUrl.hostname.endsWith('fal.ai');
const { host } = new URL(url);
return /(fal\.(ai|run))$/.test(host);
} catch (_) {
return false;
}

View File

@ -1,6 +1,6 @@
{
"name": "@fal-ai/serverless-proxy",
"version": "0.6.0",
"version": "0.7.0",
"license": "MIT",
"repository": {
"type": "git",

View File

@ -9,6 +9,8 @@ const FAL_KEY_SECRET =
export type HeaderValue = string | string[] | undefined | null;
const FAL_URL_REG_EXP = /(fal\.(run|ai))$/;
/**
* The proxy behavior that is passed to the proxy handler. This is a subset of
* request objects that are used by different frameworks, like Express and NextJS.
@ -69,7 +71,9 @@ export async function handleRequest<ResponseType>(
if (!targetUrl) {
return behavior.respondWith(400, `Missing the ${TARGET_URL_HEADER} header`);
}
if (targetUrl.indexOf('fal.ai') === -1) {
const urlHost = new URL(targetUrl).host;
if (!FAL_URL_REG_EXP.test(urlHost)) {
return behavior.respondWith(412, `Invalid ${TARGET_URL_HEADER} header`);
}