diff --git a/src/channel/ChannelReceiver.ts b/src/channel/ChannelReceiver.ts index b3c67ce..2183fb9 100644 --- a/src/channel/ChannelReceiver.ts +++ b/src/channel/ChannelReceiver.ts @@ -22,12 +22,14 @@ import { InternalEmitterRequestType, InternalReceiverRequestType } from "./types export type ChannelReceiverOptions = { readyTimeout: number + allowedOrigin: string | null } export const channelReceiverDefaultOptions: ChannelReceiverOptions & Partial = { readyTimeout: 20000, requestIDPrefix: "receiver-", + allowedOrigin: null, } export type AllChannelReceiverOptions = ChannelReceiverOptions & ChannelNetworkOptions @@ -65,7 +67,10 @@ export abstract class ChannelReceiver< >( request, (request) => { - window.parent.postMessage(request, "*") + window.parent.postMessage( + request, + this.options.allowedOrigin ? this.options.allowedOrigin : "*", + ) }, { timeout: this.options.readyTimeout, @@ -79,6 +84,14 @@ export abstract class ChannelReceiver< /** Handles public messages */ private _onPublicMessage(event: MessageEvent): void { + // Validate origin if allowedOrigin is configured + if ( + this.options.allowedOrigin && + this.options.allowedOrigin !== event.origin + ) { + return + } + try { const message = validateMessage(event.data) @@ -102,6 +115,7 @@ export abstract class ChannelReceiver< debug: this.options.debug, requestIDPrefix: this.options.requestIDPrefix, readyTimeout: this.options.readyTimeout, + allowedOrigin: this.options.allowedOrigin, } const response = createSuccessResponseMessage(message.requestID, undefined) diff --git a/src/channel/types.ts b/src/channel/types.ts index e8a5a9c..1b7415a 100644 --- a/src/channel/types.ts +++ b/src/channel/types.ts @@ -105,7 +105,13 @@ export type InternalEmitterTransactions< [InternalEmitterRequestType.Connect]: Transaction< RequestMessage< InternalEmitterRequestType.Connect, - Partial> | undefined + | Partial< + Omit< + TReceiverOptions, + "debug" | "requestIDPrefix" | "readyTimeout" | "allowedOrigin" + > + > + | undefined > > } diff --git a/src/kit/SimulatorManager.ts b/src/kit/SimulatorManager.ts index a2cdd4b..4fe0007 100644 --- a/src/kit/SimulatorManager.ts +++ b/src/kit/SimulatorManager.ts @@ -11,17 +11,20 @@ import { StateEventType } from "./types" type ManagerConstructorArgs = { slices?: SliceZone + allowedOrigin?: string } export class SimulatorManager { public state: State private _api: SimulatorAPI | null private _initialized: boolean + private _allowedOrigin: string | null constructor(args?: ManagerConstructorArgs) { this.state = new State(args) this._api = null this._initialized = false + this._allowedOrigin = args?.allowedOrigin ?? null } async init(): Promise { @@ -57,54 +60,59 @@ export class SimulatorManager { private async _initAPI(): Promise { // Register SimulatorAPI request handlers - this._api = new SimulatorAPI({ - [ClientRequestType.SetSliceZone]: (req, res) => { - this.state.setSliceZone(req.data) - - return res.success() - }, - [ClientRequestType.ScrollToSlice]: (req, res) => { - // Error if `sliceIndex` is invalid - if (req.data.sliceIndex < 0) { - return res.error("`sliceIndex` must be > 0", 400) - } else if (req.data.sliceIndex >= this.state.slices.length) { - return res.error( - `\`sliceIndex\` must be < ${this.state.slices.length} (\`\` current length)`, - 400, - ) - } + this._api = new SimulatorAPI( + { + [ClientRequestType.SetSliceZone]: (req, res) => { + this.state.setSliceZone(req.data) + + return res.success() + }, + [ClientRequestType.ScrollToSlice]: (req, res) => { + // Error if `sliceIndex` is invalid + if (req.data.sliceIndex < 0) { + return res.error("`sliceIndex` must be > 0", 400) + } else if (req.data.sliceIndex >= this.state.slices.length) { + return res.error( + `\`sliceIndex\` must be < ${this.state.slices.length} (\`\` current length)`, + 400, + ) + } - const $sliceZone = getSliceZoneDOM(this.state.slices.length) - if (!$sliceZone) { - return res.error("Failed to find ``", 500) - } + const $sliceZone = getSliceZoneDOM(this.state.slices.length) + if (!$sliceZone) { + return res.error("Failed to find ``", 500) + } - // Destroy existing active slice as we're about to scroll - this.state.activeSlice = null + // Destroy existing active slice as we're about to scroll + this.state.activeSlice = null - const $slice = $sliceZone.children[req.data.sliceIndex] - if (!$slice) { - return res.error( - `Failed fo find slice at index $\`{req.data.sliceIndex}\` in \`\``, - 500, - ) - } + const $slice = $sliceZone.children[req.data.sliceIndex] + if (!$slice) { + return res.error( + `Failed fo find slice at index $\`{req.data.sliceIndex}\` in \`\``, + 500, + ) + } - // Scroll to Slice - $slice.scrollIntoView({ - behavior: req.data.behavior, - block: req.data.block, - inline: req.data.inline, - }) + // Scroll to Slice + $slice.scrollIntoView({ + behavior: req.data.behavior, + block: req.data.block, + inline: req.data.inline, + }) - // Update active slice after scrolling - if (this._api?.options.activeSliceAPI) { - setTimeout(this.state.setActiveSlice, 750) - } + // Update active slice after scrolling + if (this._api?.options.activeSliceAPI) { + setTimeout(this.state.setActiveSlice, 750) + } - return res.success() + return res.success() + }, + }, + { + allowedOrigin: this._allowedOrigin, }, - }) + ) // Mark API as ready await this._api.ready() diff --git a/test/SimulatorAPI.test.ts b/test/SimulatorAPI.test.ts index 75d7ebe..6d07853 100644 --- a/test/SimulatorAPI.test.ts +++ b/test/SimulatorAPI.test.ts @@ -96,6 +96,35 @@ const callsPostFormattedRequestCorrectly = < }, ] +it("passes allowedOrigin through to receiver options", () => { + const simulatorAPI = new SimulatorAPI( + { + [ClientRequestType.SetSliceZone]: (_req, res) => { + return res.success() + }, + [ClientRequestType.ScrollToSlice]: (_req, res) => { + return res.success() + }, + }, + { allowedOrigin: "https://example.com" }, + ) + + expect(simulatorAPI.options.allowedOrigin).toBe("https://example.com") +}) + +it("defaults allowedOrigin to null when not provided", () => { + const simulatorAPI = new SimulatorAPI({ + [ClientRequestType.SetSliceZone]: (_req, res) => { + return res.success() + }, + [ClientRequestType.ScrollToSlice]: (_req, res) => { + return res.success() + }, + }) + + expect(simulatorAPI.options.allowedOrigin).toBeNull() +}) + it(...callsPostFormattedRequestCorrectly(APIRequestType.SetActiveSlice, null)) it( ...callsPostFormattedRequestCorrectly(APIRequestType.SetSliceZoneSize, { diff --git a/test/channel-ChannelReceiver-allowedOrigin.test.ts b/test/channel-ChannelReceiver-allowedOrigin.test.ts new file mode 100644 index 0000000..8fc1d23 --- /dev/null +++ b/test/channel-ChannelReceiver-allowedOrigin.test.ts @@ -0,0 +1,199 @@ +import { expect, it, vi } from "vitest" + +import type { UnknownRequestMessage } from "../src/channel" +import { + ChannelReceiver, + InternalEmitterRequestType, + createRequestMessage, + createSuccessResponseMessage, +} from "../src/channel" + +class StandaloneChannelReceiver extends ChannelReceiver {} + +// --- Inbound origin validation --- + +it("silently drops messages from non-matching origins when allowedOrigin is set", () => { + const channelReceiver = new StandaloneChannelReceiver( + {}, + { allowedOrigin: "https://example.com" }, + ) + // @ts-expect-error - taking a shortcut by accessing protected property + const postResponseStub = vi.spyOn(channelReceiver, "postResponse") + + const channel = new MessageChannel() + const request = createRequestMessage( + InternalEmitterRequestType.Connect, + undefined, + ) + + // @ts-expect-error - taking a shortcut by accessing private property + channelReceiver._onPublicMessage({ + data: request, + origin: "https://evil.com", + ports: [channel.port1], + }) + + expect(postResponseStub).not.toHaveBeenCalled() +}) + +it("accepts messages from matching origins when allowedOrigin is set", () => { + const channelReceiver = new StandaloneChannelReceiver( + {}, + { allowedOrigin: "https://example.com" }, + ) + // @ts-expect-error - taking a shortcut by accessing protected property + const postResponseStub = vi.spyOn(channelReceiver, "postResponse") + + const channel = new MessageChannel() + const request = createRequestMessage( + InternalEmitterRequestType.Connect, + undefined, + ) + const response = createSuccessResponseMessage(request.requestID, undefined) + + // @ts-expect-error - taking a shortcut by accessing private property + channelReceiver._onPublicMessage({ + data: request, + origin: "https://example.com", + ports: [channel.port1], + }) + + expect(postResponseStub).toHaveBeenCalledOnce() + expect(postResponseStub).toHaveBeenCalledWith(response) +}) + +it("accepts messages from any origin when allowedOrigin is null (default)", () => { + const channelReceiver = new StandaloneChannelReceiver({}, {}) + // @ts-expect-error - taking a shortcut by accessing protected property + const postResponseStub = vi.spyOn(channelReceiver, "postResponse") + + const channel = new MessageChannel() + const request = createRequestMessage( + InternalEmitterRequestType.Connect, + undefined, + ) + const response = createSuccessResponseMessage(request.requestID, undefined) + + // @ts-expect-error - taking a shortcut by accessing private property + channelReceiver._onPublicMessage({ + data: request, + origin: "https://any-origin.com", + ports: [channel.port1], + }) + + expect(postResponseStub).toHaveBeenCalledOnce() + expect(postResponseStub).toHaveBeenCalledWith(response) +}) + +// --- Outbound targetOrigin in ready() --- + +it("uses allowedOrigin as targetOrigin in ready() postMessage", async () => { + const channelReceiver = new StandaloneChannelReceiver( + {}, + { allowedOrigin: "https://example.com" }, + ) + + // Mock `window.parent.postMessage` + const windowParentBck = window.parent + // @ts-expect-error - deleting for test purpose + delete window.parent + const postMessageMock = vi.fn( + (request: UnknownRequestMessage, targetOrigin: string) => { + const response = createSuccessResponseMessage( + request.requestID, + undefined, + ) + // @ts-expect-error - taking a shortcut by accessing private property + channelReceiver._onPublicMessage({ + data: response, + origin: "https://example.com", + }) + }, + ) + window.parent = { + postMessage: postMessageMock as Window["postMessage"], + } as Window["parent"] + + await channelReceiver.ready() + + expect(postMessageMock).toHaveBeenCalledOnce() + expect(postMessageMock.mock.calls[0][1]).toBe("https://example.com") + + window.parent = windowParentBck +}) + +it("uses '*' as targetOrigin in ready() when allowedOrigin is null", async () => { + const channelReceiver = new StandaloneChannelReceiver({}, {}) + + // Mock `window.parent.postMessage` + const windowParentBck = window.parent + // @ts-expect-error - deleting for test purpose + delete window.parent + const postMessageMock = vi.fn( + (request: UnknownRequestMessage, targetOrigin: string) => { + const response = createSuccessResponseMessage( + request.requestID, + undefined, + ) + // @ts-expect-error - taking a shortcut by accessing private property + channelReceiver._onPublicMessage({ data: response }) + }, + ) + window.parent = { + postMessage: postMessageMock as Window["postMessage"], + } as Window["parent"] + + await channelReceiver.ready() + + expect(postMessageMock).toHaveBeenCalledOnce() + expect(postMessageMock.mock.calls[0][1]).toBe("*") + + window.parent = windowParentBck +}) + +// --- allowedOrigin preserved during Connect options merge --- + +it("preserves allowedOrigin when Connect request sends conflicting options", () => { + const channelReceiver = new StandaloneChannelReceiver( + {}, + { allowedOrigin: "https://example.com" }, + ) + + const channel = new MessageChannel() + + // Connect request data tries to overwrite allowedOrigin + const request = createRequestMessage(InternalEmitterRequestType.Connect, { + allowedOrigin: "https://evil.com", + someOtherOption: true, + }) + + // @ts-expect-error - taking a shortcut by accessing private property + channelReceiver._onPublicMessage({ + data: request, + origin: "https://example.com", + ports: [channel.port1], + }) + + expect(channelReceiver.options.allowedOrigin).toBe("https://example.com") + // Verify that non-protected options ARE merged + expect(channelReceiver.options.someOtherOption).toBe(true) +}) + +it("preserves null allowedOrigin during Connect options merge", () => { + const channelReceiver = new StandaloneChannelReceiver({}, {}) + + const channel = new MessageChannel() + + // Connect request data tries to set allowedOrigin + const request = createRequestMessage(InternalEmitterRequestType.Connect, { + allowedOrigin: "https://evil.com", + }) + + // @ts-expect-error - taking a shortcut by accessing private property + channelReceiver._onPublicMessage({ + data: request, + ports: [channel.port1], + }) + + expect(channelReceiver.options.allowedOrigin).toBeNull() +})