diff --git a/src/http-proxy-middleware.ts b/src/http-proxy-middleware.ts index 54f054fb..597c3cb3 100644 --- a/src/http-proxy-middleware.ts +++ b/src/http-proxy-middleware.ts @@ -98,10 +98,16 @@ export class HttpProxyMiddleware { }; private handleUpgrade = async (req: http.IncomingMessage, socket: net.Socket, head: Buffer) => { - if (this.shouldProxy(this.proxyOptions.pathFilter, req)) { - const activeProxyOptions = await this.prepareProxyRequest(req); - this.proxy.ws(req, socket, head, activeProxyOptions); - debug('server upgrade event received. Proxying WebSocket'); + try { + if (this.shouldProxy(this.proxyOptions.pathFilter, req)) { + const activeProxyOptions = await this.prepareProxyRequest(req); + this.proxy.ws(req, socket, head, activeProxyOptions); + debug('server upgrade event received. Proxying WebSocket'); + } + } catch (err) { + // This error does not include the URL as the fourth argument as we won't + // have the URL if `this.prepareProxyRequest` throws an error. + this.proxy.emit('error', err, req, socket); } }; diff --git a/src/plugins/default/error-response-plugin.ts b/src/plugins/default/error-response-plugin.ts index 4d9113db..f13a815a 100644 --- a/src/plugins/default/error-response-plugin.ts +++ b/src/plugins/default/error-response-plugin.ts @@ -1,6 +1,16 @@ +import type * as http from 'http'; +import type { Socket } from 'net'; import { getStatusCode } from '../../status-code'; import { Plugin } from '../../types'; +function isResponseLike(obj: any): obj is http.ServerResponse { + return obj && typeof obj.writeHead === 'function'; +} + +function isSocketLike(obj: any): obj is Socket { + return obj && typeof obj.write === 'function' && !('writeHead' in obj); +} + export const errorResponsePlugin: Plugin = (proxyServer, options) => { proxyServer.on('error', (err, req, res, target?) => { // Re-throw error. Not recoverable since req & res are empty. @@ -8,12 +18,16 @@ export const errorResponsePlugin: Plugin = (proxyServer, options) => { throw err; // "Error: Must provide a proper URL as target" } - if ('writeHead' in res && !res.headersSent) { - const statusCode = getStatusCode((err as unknown as any).code); - res.writeHead(statusCode); - } + if (isResponseLike(res)) { + if (!res.headersSent) { + const statusCode = getStatusCode((err as unknown as any).code); + res.writeHead(statusCode); + } - const host = req.headers && req.headers.host; - res.end(`Error occurred while trying to proxy: ${host}${req.url}`); + const host = req.headers && req.headers.host; + res.end(`Error occurred while trying to proxy: ${host}${req.url}`); + } else if (isSocketLike(res)) { + res.destroy(); + } }); }; diff --git a/test/e2e/websocket.spec.ts b/test/e2e/websocket.spec.ts index 9dbd9856..7dde4769 100644 --- a/test/e2e/websocket.spec.ts +++ b/test/e2e/websocket.spec.ts @@ -97,15 +97,14 @@ describe('E2E WebSocket proxy', () => { describe('with router and pathRewrite', () => { beforeEach(() => { - // override - proxyServer = createApp( + const proxyMiddleware = createProxyMiddleware({ // cSpell:ignore notworkinghost - createProxyMiddleware({ - target: 'ws://notworkinghost:6789', - router: { '/socket': `ws://localhost:${WS_SERVER_PORT}` }, - pathRewrite: { '^/socket': '' }, - }), - ).listen(SERVER_PORT); + target: 'ws://notworkinghost:6789', + router: { '/socket': `ws://localhost:${WS_SERVER_PORT}` }, + pathRewrite: { '^/socket': '' }, + }); + + proxyServer = createApp(proxyMiddleware).listen(SERVER_PORT); proxyServer.on('upgrade', proxyMiddleware.upgrade); }); @@ -124,4 +123,29 @@ describe('E2E WebSocket proxy', () => { ws.send('foobar'); }); }); + + describe('with error in router', () => { + beforeEach(() => { + const proxyMiddleware = createProxyMiddleware({ + // cSpell:ignore notworkinghost + target: `http://notworkinghost:6789`, + router: async () => { + throw new Error('error'); + }, + }); + + proxyServer = createApp(proxyMiddleware).listen(SERVER_PORT); + + proxyServer.on('upgrade', proxyMiddleware.upgrade); + }); + + it('should handle error', (done) => { + ws = new WebSocket(`ws://localhost:${SERVER_PORT}/socket`); + + ws.on('error', (err) => { + expect(err).toBeTruthy(); + done(); + }); + }); + }); });