mirror of
https://github.com/openclaw/openclaw.git
synced 2026-06-30 19:59:35 +00:00
fix(tts): bound generated speech downloads
This commit is contained in:
@@ -98,6 +98,29 @@ describe("gradium speech provider", () => {
|
||||
expect(result.audioBuffer).toEqual(audioData);
|
||||
});
|
||||
|
||||
it("applies the configured media byte cap to synthesized audio", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue(new Response(new Uint8Array(2048), { status: 200 })),
|
||||
);
|
||||
|
||||
await expect(
|
||||
provider.synthesize({
|
||||
text: "OpenClaw test",
|
||||
cfg: {
|
||||
agents: {
|
||||
defaults: {
|
||||
mediaMaxMb: 0.001,
|
||||
},
|
||||
},
|
||||
} as never,
|
||||
providerConfig: { apiKey: "gsk_test123" },
|
||||
target: "audio-file",
|
||||
timeoutMs: 30_000,
|
||||
}),
|
||||
).rejects.toThrow("Gradium TTS audio response exceeds");
|
||||
});
|
||||
|
||||
it("uses ulaw_8000 for telephony synthesis", async () => {
|
||||
const audioData = Buffer.from("ulaw-audio-data");
|
||||
const fetchMock = vi.fn().mockResolvedValue(new Response(audioData, { status: 200 }));
|
||||
|
||||
@@ -8,6 +8,8 @@ import { asObject, trimToUndefined } from "openclaw/plugin-sdk/speech";
|
||||
import { DEFAULT_GRADIUM_VOICE_ID, GRADIUM_VOICES, normalizeGradiumBaseUrl } from "./shared.js";
|
||||
import { gradiumTTS } from "./tts.js";
|
||||
|
||||
const DEFAULT_GENERATED_AUDIO_MAX_BYTES = 16 * 1024 * 1024;
|
||||
|
||||
type GradiumProviderConfig = {
|
||||
apiKey?: string;
|
||||
baseUrl: string;
|
||||
@@ -36,6 +38,16 @@ function readGradiumProviderConfig(config: SpeechProviderConfig): GradiumProvide
|
||||
};
|
||||
}
|
||||
|
||||
function resolveGeneratedAudioMaxBytes(req: {
|
||||
cfg: { agents?: { defaults?: { mediaMaxMb?: number } } };
|
||||
}): number {
|
||||
const configured = req.cfg.agents?.defaults?.mediaMaxMb;
|
||||
if (typeof configured === "number" && Number.isFinite(configured) && configured > 0) {
|
||||
return Math.floor(configured * 1024 * 1024);
|
||||
}
|
||||
return DEFAULT_GENERATED_AUDIO_MAX_BYTES;
|
||||
}
|
||||
|
||||
function parseDirectiveToken(ctx: SpeechDirectiveTokenParseContext): {
|
||||
handled: boolean;
|
||||
overrides?: Record<string, unknown>;
|
||||
@@ -86,6 +98,7 @@ export function buildGradiumSpeechProvider(): SpeechProviderPlugin {
|
||||
voiceId: trimToUndefined(overrides.voiceId) ?? config.voiceId,
|
||||
outputFormat,
|
||||
timeoutMs: req.timeoutMs,
|
||||
maxBytes: resolveGeneratedAudioMaxBytes(req),
|
||||
});
|
||||
return {
|
||||
audioBuffer,
|
||||
@@ -110,6 +123,7 @@ export function buildGradiumSpeechProvider(): SpeechProviderPlugin {
|
||||
voiceId: trimToUndefined(overrides.voiceId) ?? config.voiceId,
|
||||
outputFormat,
|
||||
timeoutMs: req.timeoutMs,
|
||||
maxBytes: resolveGeneratedAudioMaxBytes(req),
|
||||
});
|
||||
return { audioBuffer, outputFormat, sampleRate };
|
||||
},
|
||||
|
||||
@@ -28,6 +28,14 @@ describe("gradium tts diagnostics", () => {
|
||||
};
|
||||
}
|
||||
|
||||
function createStreamingAudioResponse(params: {
|
||||
chunkCount: number;
|
||||
chunkSize: number;
|
||||
byte: number;
|
||||
}): { response: Response; getReadCount: () => number } {
|
||||
return createStreamingErrorResponse({ ...params, status: 200 });
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllGlobals();
|
||||
vi.restoreAllMocks();
|
||||
@@ -134,4 +142,27 @@ describe("gradium tts diagnostics", () => {
|
||||
});
|
||||
expect(result).toEqual(audioData);
|
||||
});
|
||||
|
||||
it("caps streamed audio responses instead of buffering oversized TTS output", async () => {
|
||||
const streamed = createStreamingAudioResponse({
|
||||
chunkCount: 20,
|
||||
chunkSize: 1024,
|
||||
byte: 121,
|
||||
});
|
||||
vi.stubGlobal("fetch", vi.fn().mockResolvedValue(streamed.response));
|
||||
|
||||
await expect(
|
||||
gradiumTTS({
|
||||
text: "hello",
|
||||
apiKey: "test-key",
|
||||
baseUrl: "https://api.gradium.ai",
|
||||
voiceId: "YTpq7expH9539ERJ",
|
||||
outputFormat: "wav",
|
||||
timeoutMs: 5_000,
|
||||
maxBytes: 2048,
|
||||
}),
|
||||
).rejects.toThrow("Gradium TTS audio response exceeds 2048 bytes");
|
||||
|
||||
expect(streamed.getReadCount()).toBeLessThan(20);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import { assertOkOrThrowProviderError } from "openclaw/plugin-sdk/provider-http";
|
||||
import { readResponseWithLimit } from "openclaw/plugin-sdk/response-limit-runtime";
|
||||
import { fetchWithSsrFGuard } from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
import { normalizeGradiumBaseUrl } from "./shared.js";
|
||||
|
||||
const DEFAULT_TTS_MAX_BYTES = 16 * 1024 * 1024;
|
||||
|
||||
export async function gradiumTTS(params: {
|
||||
text: string;
|
||||
apiKey: string;
|
||||
@@ -9,8 +12,17 @@ export async function gradiumTTS(params: {
|
||||
voiceId: string;
|
||||
outputFormat: "wav" | "opus" | "ulaw_8000" | "pcm" | "pcm_24000" | "alaw_8000";
|
||||
timeoutMs: number;
|
||||
maxBytes?: number;
|
||||
}): Promise<Buffer> {
|
||||
const { text, apiKey, baseUrl, voiceId, outputFormat, timeoutMs } = params;
|
||||
const {
|
||||
text,
|
||||
apiKey,
|
||||
baseUrl,
|
||||
voiceId,
|
||||
outputFormat,
|
||||
timeoutMs,
|
||||
maxBytes = DEFAULT_TTS_MAX_BYTES,
|
||||
} = params;
|
||||
const normalizedBaseUrl = normalizeGradiumBaseUrl(baseUrl);
|
||||
const url = `${normalizedBaseUrl}/api/post/speech/tts`;
|
||||
const hostname = new URL(normalizedBaseUrl).hostname;
|
||||
@@ -39,7 +51,10 @@ export async function gradiumTTS(params: {
|
||||
try {
|
||||
await assertOkOrThrowProviderError(response, "Gradium API error");
|
||||
|
||||
return Buffer.from(await response.arrayBuffer());
|
||||
return await readResponseWithLimit(response, maxBytes, {
|
||||
onOverflow: ({ maxBytes }) =>
|
||||
new Error(`Gradium TTS audio response exceeds ${maxBytes} bytes`),
|
||||
});
|
||||
} finally {
|
||||
await release();
|
||||
}
|
||||
|
||||
@@ -291,6 +291,33 @@ describe("buildOpenAISpeechProvider", () => {
|
||||
expect(result.voiceCompatible).toBe(false);
|
||||
});
|
||||
|
||||
it("applies the configured media byte cap to synthesized audio", async () => {
|
||||
const provider = buildOpenAISpeechProvider();
|
||||
globalThis.fetch = vi.fn(
|
||||
async () => new Response(new Uint8Array(2048), { status: 200 }),
|
||||
) as unknown as typeof fetch;
|
||||
|
||||
await expect(
|
||||
provider.synthesize({
|
||||
text: "hello",
|
||||
cfg: {
|
||||
agents: {
|
||||
defaults: {
|
||||
mediaMaxMb: 0.001,
|
||||
},
|
||||
},
|
||||
} as never,
|
||||
providerConfig: {
|
||||
apiKey: "sk-test",
|
||||
model: "gpt-4o-mini-tts",
|
||||
voice: "alloy",
|
||||
},
|
||||
target: "audio-file",
|
||||
timeoutMs: 1_000,
|
||||
}),
|
||||
).rejects.toThrow("OpenAI TTS audio response exceeds");
|
||||
});
|
||||
|
||||
it("applies provider overrides to telephony synthesis", async () => {
|
||||
const provider = buildOpenAISpeechProvider();
|
||||
const fetchMock = vi.fn(async (_url: string, init?: RequestInit) => {
|
||||
|
||||
@@ -26,6 +26,7 @@ import {
|
||||
} from "./tts.js";
|
||||
|
||||
const OPENAI_SPEECH_RESPONSE_FORMATS = ["mp3", "opus", "wav"] as const;
|
||||
const DEFAULT_GENERATED_AUDIO_MAX_BYTES = 16 * 1024 * 1024;
|
||||
|
||||
type OpenAiSpeechResponseFormat = (typeof OPENAI_SPEECH_RESPONSE_FORMATS)[number];
|
||||
|
||||
@@ -174,6 +175,16 @@ function readOpenAIOverrides(
|
||||
};
|
||||
}
|
||||
|
||||
function resolveGeneratedAudioMaxBytes(req: {
|
||||
cfg: { agents?: { defaults?: { mediaMaxMb?: number } } };
|
||||
}): number {
|
||||
const configured = req.cfg.agents?.defaults?.mediaMaxMb;
|
||||
if (typeof configured === "number" && Number.isFinite(configured) && configured > 0) {
|
||||
return Math.floor(configured * 1024 * 1024);
|
||||
}
|
||||
return DEFAULT_GENERATED_AUDIO_MAX_BYTES;
|
||||
}
|
||||
|
||||
function renderOpenAITtsPersonaInstructions(req: {
|
||||
label?: string;
|
||||
prompt?: {
|
||||
@@ -328,6 +339,7 @@ export function buildOpenAISpeechProvider(): SpeechProviderPlugin {
|
||||
responseFormat,
|
||||
extraBody: config.extraBody,
|
||||
timeoutMs: req.timeoutMs,
|
||||
maxBytes: resolveGeneratedAudioMaxBytes(req),
|
||||
});
|
||||
return {
|
||||
audioBuffer,
|
||||
@@ -356,6 +368,7 @@ export function buildOpenAISpeechProvider(): SpeechProviderPlugin {
|
||||
responseFormat: outputFormat,
|
||||
extraBody: config.extraBody,
|
||||
timeoutMs: req.timeoutMs,
|
||||
maxBytes: resolveGeneratedAudioMaxBytes(req),
|
||||
});
|
||||
return { audioBuffer, outputFormat, sampleRate };
|
||||
},
|
||||
|
||||
@@ -322,6 +322,32 @@ describe("openai tts", () => {
|
||||
).rejects.toThrow("OpenAI TTS API error (503): temporary upstream outage");
|
||||
});
|
||||
|
||||
it("caps streamed audio responses instead of buffering oversized TTS output", async () => {
|
||||
const streamed = createStreamingErrorResponse({
|
||||
status: 200,
|
||||
chunkCount: 20,
|
||||
chunkSize: 1024,
|
||||
byte: 121,
|
||||
});
|
||||
const fetchMock = vi.fn(async () => streamed.response);
|
||||
globalThis.fetch = fetchMock as unknown as typeof fetch;
|
||||
|
||||
await expect(
|
||||
openaiTTS({
|
||||
text: "hello",
|
||||
apiKey: "test-key",
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
model: "gpt-4o-mini-tts",
|
||||
voice: "alloy",
|
||||
responseFormat: "mp3",
|
||||
timeoutMs: 5_000,
|
||||
maxBytes: 2048,
|
||||
}),
|
||||
).rejects.toThrow("OpenAI TTS audio response exceeds 2048 bytes");
|
||||
|
||||
expect(streamed.getReadCount()).toBeLessThan(20);
|
||||
});
|
||||
|
||||
it("caps streamed non-JSON error reads instead of consuming full response bodies", async () => {
|
||||
const streamed = createStreamingErrorResponse({
|
||||
status: 503,
|
||||
|
||||
@@ -6,12 +6,14 @@ import {
|
||||
captureHttpExchange,
|
||||
isDebugProxyGlobalFetchPatchInstalled,
|
||||
} from "openclaw/plugin-sdk/proxy-capture";
|
||||
import { readResponseWithLimit } from "openclaw/plugin-sdk/response-limit-runtime";
|
||||
import {
|
||||
fetchWithSsrFGuard,
|
||||
ssrfPolicyFromHttpBaseUrlAllowedHostname,
|
||||
} from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
|
||||
export const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1";
|
||||
const DEFAULT_TTS_MAX_BYTES = 16 * 1024 * 1024;
|
||||
|
||||
export const OPENAI_TTS_MODELS = ["gpt-4o-mini-tts", "tts-1", "tts-1-hd"] as const;
|
||||
|
||||
@@ -100,6 +102,7 @@ export async function openaiTTS(params: {
|
||||
responseFormat: "mp3" | "opus" | "pcm" | "wav";
|
||||
extraBody?: Record<string, unknown>;
|
||||
timeoutMs: number;
|
||||
maxBytes?: number;
|
||||
}): Promise<Buffer> {
|
||||
const {
|
||||
text,
|
||||
@@ -112,6 +115,7 @@ export async function openaiTTS(params: {
|
||||
responseFormat,
|
||||
extraBody,
|
||||
timeoutMs,
|
||||
maxBytes = DEFAULT_TTS_MAX_BYTES,
|
||||
} = params;
|
||||
const effectiveInstructions = resolveOpenAITtsInstructions(model, instructions, baseUrl);
|
||||
|
||||
@@ -177,7 +181,10 @@ export async function openaiTTS(params: {
|
||||
|
||||
await assertOkOrThrowProviderError(response, "OpenAI TTS API error");
|
||||
|
||||
return Buffer.from(await response.arrayBuffer());
|
||||
return await readResponseWithLimit(response, maxBytes, {
|
||||
onOverflow: ({ maxBytes }) =>
|
||||
new Error(`OpenAI TTS audio response exceeds ${maxBytes} bytes`),
|
||||
});
|
||||
} finally {
|
||||
await release();
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ function requireLastTtsCall(): {
|
||||
language?: string;
|
||||
speed?: number;
|
||||
responseFormat?: string;
|
||||
maxBytes?: number;
|
||||
} {
|
||||
const params = (xaiTTSMock.mock.calls as unknown as Array<[unknown]>).at(-1)?.[0] as
|
||||
| {
|
||||
@@ -47,6 +48,7 @@ function requireLastTtsCall(): {
|
||||
language?: string;
|
||||
speed?: number;
|
||||
responseFormat?: string;
|
||||
maxBytes?: number;
|
||||
}
|
||||
| undefined;
|
||||
if (!params) {
|
||||
@@ -68,7 +70,13 @@ describe("xai speech provider", () => {
|
||||
const provider = buildXaiSpeechProvider();
|
||||
const result = await provider.synthesize({
|
||||
text: "hello",
|
||||
cfg: {},
|
||||
cfg: {
|
||||
agents: {
|
||||
defaults: {
|
||||
mediaMaxMb: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
providerConfig: {
|
||||
apiKey: "xai-key",
|
||||
voiceId: "eve",
|
||||
@@ -87,6 +95,7 @@ describe("xai speech provider", () => {
|
||||
expect(tts.baseUrl).toBe("https://api.x.ai/v1");
|
||||
expect(tts.voiceId).toBe("eve");
|
||||
expect(tts.responseFormat).toBe("mp3");
|
||||
expect(tts.maxBytes).toBe(2 * 1024 * 1024);
|
||||
});
|
||||
|
||||
it("honors configured response formats", async () => {
|
||||
|
||||
@@ -26,6 +26,7 @@ import {
|
||||
} from "./tts.js";
|
||||
|
||||
const XAI_SPEECH_RESPONSE_FORMATS = ["mp3", "wav", "pcm", "mulaw", "alaw"] as const;
|
||||
const DEFAULT_GENERATED_AUDIO_MAX_BYTES = 16 * 1024 * 1024;
|
||||
|
||||
type XaiSpeechResponseFormat = (typeof XAI_SPEECH_RESPONSE_FORMATS)[number];
|
||||
|
||||
@@ -130,6 +131,16 @@ function readXaiOverrides(overrides: SpeechProviderOverrides | undefined): XaiTt
|
||||
};
|
||||
}
|
||||
|
||||
function resolveGeneratedAudioMaxBytes(req: {
|
||||
cfg: { agents?: { defaults?: { mediaMaxMb?: number } } };
|
||||
}): number {
|
||||
const configured = req.cfg.agents?.defaults?.mediaMaxMb;
|
||||
if (typeof configured === "number" && Number.isFinite(configured) && configured > 0) {
|
||||
return Math.floor(configured * 1024 * 1024);
|
||||
}
|
||||
return DEFAULT_GENERATED_AUDIO_MAX_BYTES;
|
||||
}
|
||||
|
||||
function parseDirectiveToken(ctx: SpeechDirectiveTokenParseContext): {
|
||||
handled: boolean;
|
||||
overrides?: SpeechProviderOverrides;
|
||||
@@ -231,6 +242,7 @@ export function buildXaiSpeechProvider(): SpeechProviderPlugin {
|
||||
speed: overrides.speed ?? config.speed,
|
||||
responseFormat,
|
||||
timeoutMs: req.timeoutMs,
|
||||
maxBytes: resolveGeneratedAudioMaxBytes(req),
|
||||
});
|
||||
return {
|
||||
audioBuffer,
|
||||
@@ -254,6 +266,7 @@ export function buildXaiSpeechProvider(): SpeechProviderPlugin {
|
||||
speed: overrides.speed ?? config.speed,
|
||||
responseFormat: outputFormat,
|
||||
timeoutMs: req.timeoutMs,
|
||||
maxBytes: resolveGeneratedAudioMaxBytes(req),
|
||||
});
|
||||
return { audioBuffer, outputFormat, sampleRate };
|
||||
},
|
||||
|
||||
@@ -2,6 +2,31 @@ import { mockPinnedHostnameResolution } from "openclaw/plugin-sdk/test-env";
|
||||
import { beforeEach, afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { isValidXaiTtsVoice, XAI_BASE_URL, XAI_TTS_VOICES, xaiTTS } from "./tts.js";
|
||||
|
||||
function createStreamingAudioResponse(params: {
|
||||
chunkCount: number;
|
||||
chunkSize: number;
|
||||
byte: number;
|
||||
}): { response: Response; getReadCount: () => number } {
|
||||
let reads = 0;
|
||||
const stream = new ReadableStream<Uint8Array>({
|
||||
pull(controller) {
|
||||
if (reads >= params.chunkCount) {
|
||||
controller.close();
|
||||
return;
|
||||
}
|
||||
reads += 1;
|
||||
controller.enqueue(new Uint8Array(params.chunkSize).fill(params.byte));
|
||||
},
|
||||
});
|
||||
return {
|
||||
response: new Response(stream, {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "audio/mpeg" },
|
||||
}),
|
||||
getReadCount: () => reads,
|
||||
};
|
||||
}
|
||||
|
||||
describe("xai tts", () => {
|
||||
const originalFetch = globalThis.fetch;
|
||||
let ssrfMock: { mockRestore: () => void } | undefined;
|
||||
@@ -103,6 +128,31 @@ describe("xai tts", () => {
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
it("caps streamed audio responses instead of buffering oversized TTS output", async () => {
|
||||
const streamed = createStreamingAudioResponse({
|
||||
chunkCount: 20,
|
||||
chunkSize: 1024,
|
||||
byte: 121,
|
||||
});
|
||||
const fetchMock = vi.fn(async () => streamed.response);
|
||||
globalThis.fetch = fetchMock as unknown as typeof fetch;
|
||||
|
||||
await expect(
|
||||
xaiTTS({
|
||||
text: "hello",
|
||||
apiKey: "ok-key",
|
||||
baseUrl: XAI_BASE_URL,
|
||||
voiceId: "eve",
|
||||
language: "en",
|
||||
responseFormat: "mp3",
|
||||
timeoutMs: 5_000,
|
||||
maxBytes: 2048,
|
||||
}),
|
||||
).rejects.toThrow("xAI TTS audio response exceeds 2048 bytes");
|
||||
|
||||
expect(streamed.getReadCount()).toBeLessThan(20);
|
||||
});
|
||||
|
||||
it("falls back to raw body text when the error body is non-JSON", async () => {
|
||||
const fetchMock = vi.fn(
|
||||
async () => new Response("temporary upstream outage", { status: 503 }),
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import { assertOkOrThrowProviderError, postJsonRequest } from "openclaw/plugin-sdk/provider-http";
|
||||
import { readResponseWithLimit } from "openclaw/plugin-sdk/response-limit-runtime";
|
||||
import { trimToUndefined } from "openclaw/plugin-sdk/speech";
|
||||
import { XAI_BASE_URL } from "./api.js";
|
||||
import { xaiUserAgentHeaderFor } from "./src/xai-user-agent.js";
|
||||
export { XAI_BASE_URL };
|
||||
|
||||
const DEFAULT_TTS_MAX_BYTES = 16 * 1024 * 1024;
|
||||
export const XAI_TTS_VOICES = ["eve", "ara", "rex", "sal", "leo", "una"] as const;
|
||||
|
||||
type XaiTtsVoice = (typeof XAI_TTS_VOICES)[number];
|
||||
@@ -49,6 +51,7 @@ export async function xaiTTS(params: {
|
||||
speed?: number;
|
||||
responseFormat?: "mp3" | "wav" | "pcm" | "mulaw" | "alaw";
|
||||
timeoutMs: number;
|
||||
maxBytes?: number;
|
||||
}): Promise<Buffer> {
|
||||
const {
|
||||
text,
|
||||
@@ -59,6 +62,7 @@ export async function xaiTTS(params: {
|
||||
speed,
|
||||
responseFormat = "mp3",
|
||||
timeoutMs,
|
||||
maxBytes = DEFAULT_TTS_MAX_BYTES,
|
||||
} = params;
|
||||
const language = normalizeXaiLanguageCode(rawLanguage) ?? "en";
|
||||
|
||||
@@ -90,7 +94,9 @@ export async function xaiTTS(params: {
|
||||
try {
|
||||
await assertOkOrThrowProviderError(response, "xAI TTS API error");
|
||||
|
||||
return Buffer.from(await response.arrayBuffer());
|
||||
return await readResponseWithLimit(response, maxBytes, {
|
||||
onOverflow: ({ maxBytes }) => new Error(`xAI TTS audio response exceeds ${maxBytes} bytes`),
|
||||
});
|
||||
} finally {
|
||||
await release();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user