Skip to content

Commit

Permalink
Adjust error handling in download to not return json
Browse files Browse the repository at this point in the history
  • Loading branch information
JustMaier committed Jun 7, 2023
1 parent aba1084 commit d446aa6
Showing 1 changed file with 18 additions and 23 deletions.
41 changes: 18 additions & 23 deletions src/pages/api/download/models/[modelVersionId].ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,35 +26,30 @@ const schema = z.object({
fp: z.enum(constants.modelFileFp).optional(),
});

const forbidden = (req: NextApiRequest, res: NextApiResponse) => {
res.status(403);
if (req.headers['content-type'] === 'application/json') return res.json({ error: 'Forbidden' });
else return res.send('Forbidden');
};

const notFound = (req: NextApiRequest, res: NextApiResponse, message = 'Not Found') => {
res.status(404);
if (req.headers['content-type'] === 'application/json') return res.json({ error: message });
else return res.send(message);
};

export default RateLimitedEndpoint(
async function downloadModel(req: NextApiRequest, res: NextApiResponse) {
function errorResponse(status: number, message: string) {
res.status(status);
if (req.headers['content-type'] === 'application/json') return res.json({ error: message });
else return res.send(message);
}

// Get ip so that we can block exploits we catch
const ip = requestIp.getClientIp(req);
const ipBlacklist = (
((await dbRead.keyValue.findUnique({ where: { key: 'ip-blacklist' } }))?.value as string) ??
''
).split(',');
if (ip && ipBlacklist.includes(ip)) return forbidden(req, res);
if (ip && ipBlacklist.includes(ip)) return errorResponse(403, 'Forbidden');

const session = await getServerAuthSession({ req, res });
if (!!session?.user) {
const userBlacklist = (
((await dbRead.keyValue.findUnique({ where: { key: 'user-blacklist' } }))
?.value as string) ?? ''
).split(',');
if (userBlacklist.includes(session.user.id.toString())) return forbidden(req, res);
if (userBlacklist.includes(session.user.id.toString()))
return errorResponse(403, 'Forbidden');
}

const queryResults = schema.safeParse(req.query);
Expand All @@ -64,7 +59,7 @@ export default RateLimitedEndpoint(
.json({ error: `Invalid id: ${queryResults.error.flatten().fieldErrors.modelVersionId}` });

const { type, modelVersionId, format, size, fp } = queryResults.data;
if (!modelVersionId) return res.status(400).json({ error: 'Missing modelVersionId' });
if (!modelVersionId) return errorResponse(400, 'Missing modelVersionId');

const fileWhere: Prisma.ModelFileWhereInput = {};
if (type) fileWhere.type = type;
Expand Down Expand Up @@ -102,7 +97,7 @@ export default RateLimitedEndpoint(
},
},
});
if (!modelVersion) return notFound(req, res, 'Model not found');
if (!modelVersion) return errorResponse(404, 'Model not found');

const { files } = modelVersion;
const metadata: FileMetadata = {
Expand All @@ -113,25 +108,24 @@ export default RateLimitedEndpoint(
Omit<(typeof files)[number], 'metadata'> & { metadata: FileMetadata }
>;
const file = getPrimaryFile(castedFiles, { metadata });
if (!file) return notFound(req, res, 'Model file not found');
if (!file) return errorResponse(404, 'Model file not found');

// Handle non-published models
const isMod = session?.user?.isModerator;
const userId = session?.user?.id;
const archived = modelVersion.model.mode === ModelModifier.Archived;
if (archived)
return res.status(410).json({ error: 'Model archived, not available for download' });
if (archived) return errorResponse(410, 'Model archived, not available for download');

const canDownload =
isMod ||
modelVersion?.model?.status === 'Published' ||
(userId && modelVersion?.model?.userId === userId);
if (!canDownload) return notFound(req, res, 'Model not found');
if (!canDownload) return errorResponse(404, 'Model not found');

// Handle unauthenticated downloads
if (!env.UNAUTHENTICATED_DOWNLOAD && !userId) {
if (req.headers['content-type'] === 'application/json')
return res.status(401).json({ error: 'Unauthorized' });
return errorResponse(401, 'Unauthorized');
else
return res.redirect(
getLoginLink({ reason: 'download-auth', returnUrl: `/models/${modelVersion.model.id}` })
Expand Down Expand Up @@ -195,7 +189,8 @@ export default RateLimitedEndpoint(
modelVersionId: modelVersion.id,
});
} catch (error) {
return res.status(500).json({ error: 'Invalid database operation', cause: error });
console.error(error);
return errorResponse(500, 'Invalid database operation');
}

const fileName = getDownloadFilename({ model: modelVersion.model, modelVersion, file });
Expand All @@ -205,7 +200,7 @@ export default RateLimitedEndpoint(
} catch (err: unknown) {
const error = err as Error;
console.error(`Error downloading file: ${file.url} - ${error.message}`);
return res.status(500).json({ error: 'Error downloading file' });
return errorResponse(500, 'Error downloading file');
}
},
['GET'],
Expand Down

0 comments on commit d446aa6

Please sign in to comment.