Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve session handling #682

Merged
merged 9 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"bcryptjs": "^2.4.3",
"better-sqlite3": "^11.2.1",
"body-parser": "^1.19.0",
"connect-redis": "^8.0.1",
"cookie": "^1.0.1",
"cookie-parser": "^1.4.4",
"cors": "^2.8.5",
Expand Down
20 changes: 16 additions & 4 deletions src/AuthRegistry.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ class AuthRegistry {

/**
* @param {import('./schema.js').User} user
* @param {string} sessionID
*/
async createAuthToken(user) {
async createAuthToken(user, sessionID) {
const token = (await randomBytes(64)).toString('hex');
await this.#redis.set(`http-api:socketAuth:${token}`, user.id, 'EX', 60);
await this.#redis.set(`http-api:socketAuth:${token}`, `${user.id}/${sessionID}`, 'EX', 60);
return token;
}

Expand All @@ -37,12 +38,23 @@ class AuthRegistry {
.exec();
assert(result);

const [err, userID] = result[0];
const [err, authParts] = result[0];
if (err) {
throw err;
}
if (typeof authParts !== 'string') {
throw new Error('Invalid auth parts');
}

const index = authParts.indexOf('/');
if (index === -1) {
throw new Error('Invalid auth parts');
}

const userID = /** @type {import('./schema.js').UserID} */ (authParts.slice(0, index));
const sessionID = authParts.slice(index + 1);

return /** @type {import('./schema.js').UserID} */ (userID);
return { userID, sessionID };
}
}

Expand Down
8 changes: 7 additions & 1 deletion src/HttpApi.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import cookieParser from 'cookie-parser';
import cors from 'cors';
import helmet from 'helmet';
import session from 'express-session';
import { RedisStore } from 'connect-redis';
import qs from 'qs';
import { pinoHttp } from 'pino-http';

Expand Down Expand Up @@ -86,6 +87,7 @@ async function httpApi(uw, options) {

const logger = uw.logger.child({
ns: 'uwave:http-api',
level: 'warn',
});

uw.config.register(optionsSchema['uw:key'], optionsSchema);
Expand Down Expand Up @@ -121,11 +123,15 @@ async function httpApi(uw, options) {
secure: uw.express.get('env') === 'production',
httpOnly: true,
},
store: new RedisStore({
client: uw.redis,
}),
}))
.use(uw.passport.initialize())
.use(addFullUrl())
.use(attachUwaveMeta(uw.httpApi, uw))
.use(uw.passport.authenticate('jwt'))
.use(uw.passport.authenticate('jwt', { session: false }))
.use(uw.passport.session())
.use(rateLimit('api-http', { max: 500, duration: 60 * 1000 }));

uw.httpApi
Expand Down
42 changes: 25 additions & 17 deletions src/SocketServer.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import { serializeUser } from './utils/serialize.js';

const { debounce, isEmpty } = lodash;

export const REDIS_ACTIVE_SESSIONS = 'users';

/**
* @typedef {import('./schema.js').User} User
*/
Expand Down Expand Up @@ -76,7 +78,7 @@ class SocketServer {
// We do need to clear the `users` list because the lost connection handlers
// will not do so.
uw.socketServer.#logger.warn({ err }, 'could not initialise lost connections');
await uw.redis.del('users');
await uw.redis.del(REDIS_ACTIVE_SESSIONS);
}
});

Expand Down Expand Up @@ -394,7 +396,7 @@ class SocketServer {
const user = await users.getUser(userID);
if (user) {
// TODO this should not be the socket server code's responsibility
await redis.rpush('users', user.id);
await redis.rpush(REDIS_ACTIVE_SESSIONS, user.id);
this.broadcast('join', serializeUser(user));
}
},
Expand Down Expand Up @@ -453,7 +455,9 @@ class SocketServer {
*/
async initLostConnections() {
const { db, redis } = this.#uw;
const userIDs = /** @type {import('./schema').UserID[]} */ (await redis.lrange('users', 0, -1));
const userIDs = /** @type {import('./schema').UserID[]} */ (
await redis.lrange(REDIS_ACTIVE_SESSIONS, 0, -1)
);
const disconnectedIDs = userIDs.filter((userID) => !this.connection(userID));

if (disconnectedIDs.length === 0) {
Expand All @@ -465,7 +469,7 @@ class SocketServer {
.selectAll()
.execute();
disconnectedUsers.forEach((user) => {
this.add(this.createLostConnection(user));
this.add(this.createLostConnection(user, 'TODO: Actual session ID!!'));
});
}

Expand Down Expand Up @@ -529,15 +533,15 @@ class SocketServer {
connection.on('close', () => {
this.remove(connection);
});
connection.on('authenticate', async (user) => {
const isReconnect = await connection.isReconnect(user);
connection.on('authenticate', async (user, sessionID) => {
const isReconnect = await connection.isReconnect(sessionID);
this.#logger.info({ userId: user.id, isReconnect }, 'authenticated socket');
if (isReconnect) {
const previousConnection = this.getLostConnection(user);
const previousConnection = this.getLostConnection(sessionID);
if (previousConnection) this.remove(previousConnection);
}

this.replace(connection, this.createAuthedConnection(socket, user));
this.replace(connection, this.createAuthedConnection(socket, user, sessionID));

if (!isReconnect) {
this.#uw.publish('user:join', { userID: user.id });
Expand All @@ -551,18 +555,19 @@ class SocketServer {
*
* @param {import('ws').WebSocket} socket
* @param {User} user
* @param {string} sessionID
* @returns {AuthedConnection}
* @private
*/
createAuthedConnection(socket, user) {
const connection = new AuthedConnection(this.#uw, socket, user);
createAuthedConnection(socket, user, sessionID) {
const connection = new AuthedConnection(this.#uw, socket, user, sessionID);
connection.on('close', ({ banned }) => {
if (banned) {
this.#logger.info({ userId: user.id }, 'removing connection after ban');
disconnectUser(this.#uw, user.id);
} else if (!this.#closing) {
this.#logger.info({ userId: user.id }, 'lost connection');
this.add(this.createLostConnection(user));
this.add(this.createLostConnection(user, sessionID));
}
this.remove(connection);
});
Expand Down Expand Up @@ -594,11 +599,12 @@ class SocketServer {
* Create a connection instance for a user who disconnected.
*
* @param {User} user
* @param {string} sessionID
* @returns {LostConnection}
* @private
*/
createLostConnection(user) {
const connection = new LostConnection(this.#uw, user, this.options.timeout);
createLostConnection(user, sessionID) {
const connection = new LostConnection(this.#uw, user, sessionID, this.options.timeout);
connection.on('close', () => {
this.#logger.info({ userId: user.id }, 'user left');
this.remove(connection);
Expand All @@ -618,8 +624,9 @@ class SocketServer {
* @private
*/
add(connection) {
const userId = 'user' in connection ? connection.user.id : null;
this.#logger.trace({ type: connection.constructor.name, userId }, 'add connection');
const userID = 'user' in connection ? connection.user.id : null;
const sessionID = 'sessionID' in connection ? connection.sessionID : null;
this.#logger.trace({ type: connection.constructor.name, userID, sessionID }, 'add connection');

this.#connections.push(connection);
this.#recountGuests();
Expand All @@ -632,8 +639,9 @@ class SocketServer {
* @private
*/
remove(connection) {
const userId = 'user' in connection ? connection.user.id : null;
this.#logger.trace({ type: connection.constructor.name, userId }, 'remove connection');
const userID = 'user' in connection ? connection.user.id : null;
const sessionID = 'sessionID' in connection ? connection.sessionID : null;
this.#logger.trace({ type: connection.constructor.name, userID, sessionID }, 'remove connection');

const i = this.#connections.indexOf(connection);
this.#connections.splice(i, 1);
Expand Down
6 changes: 5 additions & 1 deletion src/auth/JWTStrategy.js
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,11 @@ class JWTStrategy extends Strategy {
throw new BannedError();
}

return this.success(user);
const info = 'sessionID' in value && typeof value.sessionID === 'string'
? { sessionID: value.sessionID }
: undefined;

return this.success(user, info);
}
}

Expand Down
77 changes: 37 additions & 40 deletions src/controllers/authenticate.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { randomUUID } from 'node:crypto';
import { promisify } from 'node:util';
import cookie from 'cookie';
import jwt from 'jsonwebtoken';
import randomString from 'random-string';
import nodeFetch from 'node-fetch';
import ms from 'ms';
import htmlescape from 'htmlescape';
import httpErrors from 'http-errors';
import nodemailer from 'nodemailer';
Expand Down Expand Up @@ -37,13 +37,6 @@ const { BadRequest } = httpErrors;
* @prop {AuthenticateOptions} authOptions
*/

/**
* @param {string} str
*/
function seconds(str) {
return Math.floor(ms(str) / 1000);
}

/**
* @type {import('../types.js').Controller}
*/
Expand All @@ -68,28 +61,23 @@ async function getAuthStrategies(req) {
}

/**
* @param {import('express').Response} res
* @param {import('../HttpApi.js').HttpApi} api
* @param {import('../schema.js').User} user
* @param {import('../types.js').Request} req
* @param {import('../schema').User} user
* @param {AuthenticateOptions & { session: 'cookie' | 'token' }} options
*/
async function refreshSession(res, api, user, options) {
async function refreshSession(req, user, options) {
const { authRegistry } = req.uwaveHttp;
const sessionID = req.authInfo?.sessionID ?? req.sessionID;

const token = jwt.sign(
{ id: user.id },
{ id: user.id, sessionID: randomUUID() },
options.secret,
{ expiresIn: '31d' },
);

const socketToken = await api.authRegistry.createAuthToken(user);
const socketToken = await authRegistry.createAuthToken(user, sessionID);

if (options.session === 'cookie') {
const serialized = cookie.serialize('uwsession', token, {
httpOnly: true,
secure: !!options.cookieSecure,
path: options.cookiePath ?? '/',
maxAge: seconds('31 days'),
});
res.setHeader('Set-Cookie', serialized);
return { token: 'cookie', socketToken };
}

Expand All @@ -103,9 +91,8 @@ async function refreshSession(res, api, user, options) {
* @typedef {object} LoginQuery
* @prop {'cookie'|'token'} [session]
* @param {import('../types.js').AuthenticatedRequest<{}, LoginQuery, {}> & WithAuthOptions} req
* @param {import('express').Response} res
*/
async function login(req, res) {
async function login(req) {
const options = req.authOptions;
const { user } = req;
const { session } = req.query;
Expand All @@ -117,10 +104,11 @@ async function login(req, res) {
throw new BannedError();
}

const { token, socketToken } = await refreshSession(res, req.uwaveHttp, user, {
...options,
session: sessionType,
});
const { token, socketToken } = await refreshSession(
req,
user,
{ ...options, session: sessionType },
);

return toItemResponse(serializeCurrentUser(user), {
meta: {
Expand Down Expand Up @@ -189,10 +177,7 @@ async function socialLoginCallback(service, req, res) {
window.close();
`;

await refreshSession(res, req.uwaveHttp, user, {
...req.authOptions,
session: 'cookie',
});
await refreshSession(req, user, { ...req.authOptions, session: 'cookie' });

res.end(`
<!DOCTYPE html>
Expand Down Expand Up @@ -221,9 +206,8 @@ async function socialLoginCallback(service, req, res) {
* @param {string} service
* @param {import('../types.js').Request<{}, SocialLoginFinishQuery, SocialLoginFinishBody> &
* WithAuthOptions} req
* @param {import('express').Response} res
*/
async function socialLoginFinish(service, req, res) {
async function socialLoginFinish(service, req) {
const options = req.authOptions;
const { pendingUser: user } = req;
const sessionType = req.query.session === 'cookie' ? 'cookie' : 'token';
Expand Down Expand Up @@ -262,10 +246,23 @@ async function socialLoginFinish(service, req, res) {

Object.assign(user, updates);

const { token, socketToken } = await refreshSession(res, req.uwaveHttp, user, {
...options,
session: sessionType,
});
const passportLogin = promisify(
/**
* @type {(
* user: Express.User,
* options: import('passport').LogInOptions,
* callback: (err: any) => void,
* ) => void}
*/
(req.login),
);
await passportLogin(user, { session: sessionType === 'cookie' });

const { token, socketToken } = await refreshSession(
req,
user,
{ ...options, session: sessionType },
);

return toItemResponse(user, {
meta: {
Expand All @@ -279,10 +276,10 @@ async function socialLoginFinish(service, req, res) {
* @type {import('../types.js').AuthenticatedController}
*/
async function getSocketToken(req) {
const { user } = req;
const { user, sessionID } = req;
const { authRegistry } = req.uwaveHttp;

const socketToken = await authRegistry.createAuthToken(user);
const socketToken = await authRegistry.createAuthToken(user, sessionID);

return toItemResponse({ socketToken }, {
url: req.fullUrl,
Expand Down Expand Up @@ -449,6 +446,7 @@ async function logout(req, res) {
userID: user.id,
});

// Clear the legacy `uwsession` cookie.
if (cookies && cookies.uwsession) {
const serialized = cookie.serialize('uwsession', '', {
httpOnly: true,
Expand Down Expand Up @@ -479,7 +477,6 @@ export {
getSocketToken,
login,
logout,
refreshSession,
register,
removeSession,
reset,
Expand Down
Loading
Loading