Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 51 additions & 27 deletions graphql/server/src/middleware/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { getPgPool } from 'pg-cache';

import errorPage50x from '../errors/50x';
import errorPage404Message from '../errors/404-message';
import { ApiConfigResult, ApiError, ApiOptions, ApiStructure } from '../types';
import { ApiConfigResult, ApiError, ApiOptions, ApiStructure, RlsModule } from '../types';
import './types';

const log = new Logger('api');
Expand All @@ -20,6 +20,7 @@ const isDev = () => getNodeEnv() === 'development';

const DOMAIN_LOOKUP_SQL = `
SELECT
a.id as api_id,
a.database_id,
a.dbname,
a.role_name,
Expand All @@ -39,6 +40,7 @@ const DOMAIN_LOOKUP_SQL = `

const API_NAME_LOOKUP_SQL = `
SELECT
a.id as api_id,
a.database_id,
a.dbname,
a.role_name,
Expand Down Expand Up @@ -77,11 +79,23 @@ const API_LIST_SQL = `
LIMIT 100
`;

const RLS_MODULE_SQL = `
SELECT
rm.authenticate,
rm.authenticate_strict,
ps.schema_name as private_schema_name
FROM metaschema_modules_public.rls_module rm
LEFT JOIN metaschema_public.schema ps ON rm.private_schema_id = ps.id
WHERE rm.api_id = $1
LIMIT 1
`;

// =============================================================================
// Types
// =============================================================================

interface ApiRow {
api_id: string;
database_id: string;
dbname: string;
role_name: string;
Expand All @@ -90,6 +104,12 @@ interface ApiRow {
schemas: string[];
}

interface RlsModuleRow {
authenticate: string | null;
authenticate_strict: string | null;
private_schema_name: string | null;
}

interface ApiListRow {
id: string;
database_id: string;
Expand Down Expand Up @@ -164,12 +184,24 @@ export const getSvcKey = (opts: ApiOptions, req: Request): string => {
return baseKey;
};

const toApiStructure = (row: ApiRow, opts: ApiOptions): ApiStructure => ({
const toRlsModule = (row: RlsModuleRow | null): RlsModule | undefined => {
if (!row || !row.private_schema_name) return undefined;
return {
authenticate: row.authenticate ?? undefined,
authenticateStrict: row.authenticate_strict ?? undefined,
privateSchema: {
schemaName: row.private_schema_name,
},
};
};

const toApiStructure = (row: ApiRow, opts: ApiOptions, rlsModuleRow?: RlsModuleRow | null): ApiStructure => ({
dbname: row.dbname || opts.pg?.database || '',
anonRole: row.anon_role || 'anon',
roleName: row.role_name || 'authenticated',
schema: row.schemas || [],
apiModules: [],
rlsModule: toRlsModule(rlsModuleRow ?? null),
domains: [],
databaseId: row.database_id,
isPublic: row.is_public,
Expand Down Expand Up @@ -208,13 +240,8 @@ const queryByDomain = async (
subdomain: string | null,
isPublic: boolean
): Promise<ApiRow | null> => {
try {
const result = await pool.query<ApiRow>(DOMAIN_LOOKUP_SQL, [domain, subdomain, isPublic]);
return result.rows[0] ?? null;
} catch (err: unknown) {
if ((err as Error).message?.includes('does not exist')) return null;
throw err;
}
const result = await pool.query<ApiRow>(DOMAIN_LOOKUP_SQL, [domain, subdomain, isPublic]);
return result.rows[0] ?? null;
};

const queryByApiName = async (
Expand All @@ -223,23 +250,18 @@ const queryByApiName = async (
name: string,
isPublic: boolean
): Promise<ApiRow | null> => {
try {
const result = await pool.query<ApiRow>(API_NAME_LOOKUP_SQL, [databaseId, name, isPublic]);
return result.rows[0] ?? null;
} catch (err: unknown) {
if ((err as Error).message?.includes('does not exist')) return null;
throw err;
}
const result = await pool.query<ApiRow>(API_NAME_LOOKUP_SQL, [databaseId, name, isPublic]);
return result.rows[0] ?? null;
};

const queryApiList = async (pool: Pool, isPublic: boolean): Promise<ApiListRow[]> => {
try {
const result = await pool.query<ApiListRow>(API_LIST_SQL, [isPublic]);
return result.rows;
} catch (err: unknown) {
if ((err as Error).message?.includes('does not exist')) return [];
throw err;
}
const result = await pool.query<ApiListRow>(API_LIST_SQL, [isPublic]);
return result.rows;
};

const queryRlsModule = async (pool: Pool, apiId: string): Promise<RlsModuleRow | null> => {
const result = await pool.query<RlsModuleRow>(RLS_MODULE_SQL, [apiId]);
return result.rows[0] ?? null;
};

// =============================================================================
Expand Down Expand Up @@ -300,8 +322,9 @@ const resolveApiNameHeader = async (ctx: ResolveContext): Promise<ApiStructure |
return null;
}

log.debug(`[api-name-lookup] resolved schemas: [${row.schemas?.join(', ')}]`);
return toApiStructure(row, opts);
const rlsModule = await queryRlsModule(pool, row.api_id);
log.debug(`[api-name-lookup] resolved schemas: [${row.schemas?.join(', ')}], rlsModule: ${rlsModule ? 'found' : 'none'}`);
return toApiStructure(row, opts, rlsModule);
};

const resolveMetaSchemaHeader = (
Expand All @@ -324,8 +347,9 @@ const resolveDomainLookup = async (ctx: ResolveContext): Promise<ApiStructure |
return null;
}

log.debug(`[domain-lookup] resolved schemas: [${row.schemas?.join(', ')}]`);
return toApiStructure(row, opts);
const rlsModule = await queryRlsModule(pool, row.api_id);
log.debug(`[domain-lookup] resolved schemas: [${row.schemas?.join(', ')}], rlsModule: ${rlsModule ? 'found' : 'none'}`);
return toApiStructure(row, opts, rlsModule);
};

const buildDevFallbackError = async (
Expand Down
41 changes: 36 additions & 5 deletions graphql/server/src/middleware/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export const createAuthenticateMiddleware = (
next: NextFunction
): Promise<void> => {
const api = req.api;
log.info(`[auth] middleware called, api=${api ? 'present' : 'missing'}`);
if (!api) {
res.status(500).send('Missing API info');
return;
Expand All @@ -29,21 +30,38 @@ export const createAuthenticateMiddleware = (
});
const rlsModule = api.rlsModule;

log.info(
`[auth] rlsModule=${rlsModule ? 'present' : 'missing'}, ` +
`authenticate=${rlsModule?.authenticate ?? 'none'}, ` +
`authenticateStrict=${rlsModule?.authenticateStrict ?? 'none'}, ` +
`privateSchema=${rlsModule?.privateSchema?.schemaName ?? 'none'}`
);

if (!rlsModule) {
if (isDev()) log.debug('No RLS module configured, skipping auth');
log.info('[auth] No RLS module configured, skipping auth');
return next();
}

const authFn = opts.server.strictAuth
const authFn = opts.server?.strictAuth
? rlsModule.authenticateStrict
: rlsModule.authenticate;

log.info(
`[auth] strictAuth=${opts.server?.strictAuth ?? false}, authFn=${authFn ?? 'none'}`
);

if (authFn && rlsModule.privateSchema.schemaName) {
const { authorization = '' } = req.headers;
const [authType, authToken] = authorization.split(' ');
let token: any = {};

log.info(
`[auth] authorization header present=${!!authorization}, ` +
`authType=${authType ?? 'none'}, hasToken=${!!authToken}`
);

if (authType?.toLowerCase() === 'bearer' && authToken) {
log.info('[auth] Processing bearer token authentication');
const context: Record<string, any> = {
'jwt.claims.ip_address': req.clientIp,
};
Expand All @@ -55,25 +73,31 @@ export const createAuthenticateMiddleware = (
context['jwt.claims.user_agent'] = req.get('User-Agent');
}

const authQuery = `SELECT * FROM "${rlsModule.privateSchema.schemaName}"."${authFn}"($1)`;
log.info(`[auth] Executing auth query: ${authQuery}`);

try {
const result = await pgQueryContext({
client: pool,
context,
query: `SELECT * FROM "${rlsModule.privateSchema.schemaName}"."${authFn}"($1)`,
query: authQuery,
variables: [authToken],
});

log.info(`[auth] Query result: rowCount=${result?.rowCount}`);

if (result?.rowCount === 0) {
log.info('[auth] No rows returned, returning UNAUTHENTICATED');
res.status(200).json({
errors: [{ extensions: { code: 'UNAUTHENTICATED' } }],
});
return;
}

token = result.rows[0];
if (isDev()) log.debug(`Auth success: role=${token.role}`);
log.info(`[auth] Auth success: role=${token.role}, user_id=${token.user_id}`);
} catch (e: any) {
log.error('Auth error:', e.message);
log.error('[auth] Auth error:', e.message);
res.status(200).json({
errors: [
{
Expand All @@ -86,9 +110,16 @@ export const createAuthenticateMiddleware = (
});
return;
}
} else {
log.info('[auth] No bearer token provided, using anonymous auth');
}

req.token = token;
} else {
log.info(
`[auth] Skipping auth: authFn=${authFn ?? 'none'}, ` +
`privateSchema=${rlsModule.privateSchema?.schemaName ?? 'none'}`
);
}

next();
Expand Down
5 changes: 3 additions & 2 deletions graphql/server/src/middleware/graphile.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ const createGraphileInstance = async (
},
grafast: {
explain: process.env.NODE_ENV === 'development',
context: (ctx: unknown) => {
const req = (ctx as { node?: { req?: Request } } | undefined)?.node?.req;
context: (requestContext: Partial<Grafast.RequestContext>) => {
// In grafserv/express/v4, the request is available at requestContext.expressv4.req
const req = (requestContext as { expressv4?: { req?: Request } })?.expressv4?.req;
const context: Record<string, string> = {};

if (req) {
Expand Down