more work on TTS

This commit is contained in:
Jack Merrill 2025-05-05 11:07:55 -04:00
parent 16b552262e
commit 08f172544d
No known key found for this signature in database
GPG Key ID: FD574AFF96E99636
6 changed files with 130 additions and 99 deletions

View File

@ -54,7 +54,7 @@ export default async function DocumentPage(props: { params: { id: string } }) {
} }
const { data: documents, error: documentsError } = await supabase const { data: documents, error: documentsError } = await supabase
.from("documents") .from("documents")
.select("id, file_name, created_at, owner") .select("*")
.eq("owner", user.id) .eq("owner", user.id)
.order("created_at", { ascending: false }); .order("created_at", { ascending: false });
@ -76,6 +76,8 @@ export default async function DocumentPage(props: { params: { id: string } }) {
<AppSidebar <AppSidebar
documents={documents.map((d) => { documents={documents.map((d) => {
return { return {
id: d.id,
disabled: d.is_processing,
name: d.file_name, name: d.file_name,
url: `/dashboard/documents/${d.id}`, url: `/dashboard/documents/${d.id}`,
emoji: "📄", emoji: "📄",

View File

@ -47,6 +47,8 @@ export default async function Page() {
<AppSidebar <AppSidebar
documents={documents.map((d) => { documents={documents.map((d) => {
return { return {
id: d.id,
disabled: d.is_processing,
name: d.file_name, name: d.file_name,
url: `/dashboard/documents/${d.id}`, url: `/dashboard/documents/${d.id}`,
emoji: "📄", emoji: "📄",

View File

@ -20,7 +20,10 @@ export type OCRData = {
index: number; index: number;
images: string[]; images: string[];
markdown: string; markdown: string;
citations: Record<string, string>; citations: {
text: string;
number: string;
}[];
dimensions: { dimensions: {
dpi: number; dpi: number;
width: number; width: number;
@ -64,16 +67,15 @@ export default function MarkdownRenderer({
let totalCitations = 0; let totalCitations = 0;
ocr.forEach((page) => { ocr.forEach((page) => {
Object.entries(page.citations).forEach(([key, value]) => { // each page has its own citations (1-N), so we need to map them correctly
if (value) { page.citations.forEach((citation, index) => {
totalCitations++; totalCitations += 1;
citations.push({ citations.push({
text: value, text: citation.text,
page: page.index, page: page.index,
index: key, index: (totalCitations + index).toString(), // unique index across all pages
number: Number(totalCitations), number: totalCitations + index + 1, // 1-based numbering
}); });
}
}); });
}); });
@ -128,7 +130,8 @@ export default function MarkdownRenderer({
} }
const citation = citations.find( const citation = citations.find(
(c) => c.index === referenceNumber && c.page === page.index (c) =>
c.index === referenceNumber || c.number.toString() === referenceNumber
); );
if (!citation) { if (!citation) {
@ -146,7 +149,6 @@ export default function MarkdownRenderer({
</PopoverTrigger> </PopoverTrigger>
<PopoverContent className="w-56 overflow-hidden rounded-lg p-0"> <PopoverContent className="w-56 overflow-hidden rounded-lg p-0">
<div className="p-4"> <div className="p-4">
{/* Replace with actual reference content */}
<p>{citation.text}</p> <p>{citation.text}</p>
</div> </div>
</PopoverContent> </PopoverContent>

View File

@ -83,33 +83,35 @@ export const TTSProvider = ({
if (cached) { if (cached) {
return cached; return cached;
} }
worker.current!.postMessage({
type: "generate",
text: sentence,
voice: selectedSpeaker,
});
setStatus("running");
setLoadingMessage("Generating audio...");
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
worker.current!.addEventListener( const handleMessage = (e: MessageEvent) => {
"message", if (e.data.index !== index) return; // Ignore messages for other indices
(e: any) => {
if (e.data.status === "complete") { if (e.data.status === "complete") {
localStorage.setItem(key, e.data.audio); localStorage.setItem(key, e.data.audio);
resolve(e.data.audio); worker.current!.removeEventListener("message", handleMessage); // Clean up listener
} else if (e.data.status === "error") { resolve(e.data.audio);
toast.error(`Error generating audio: ${e.data.error}`); } else if (e.data.status === "error") {
reject(e.data.error); worker.current!.removeEventListener("message", handleMessage); // Clean up listener
} toast.error(`Error generating audio: ${e.data.error}`);
}, reject(e.data.error);
{ once: true } }
); };
worker.current!.addEventListener("message", handleMessage);
worker.current!.postMessage({
type: "generate",
index,
text: sentence,
voice: selectedSpeaker,
});
}); });
} }
// We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted. // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
useEffect(() => { useEffect(() => {
// Create the worker if it does not yet exist.
console.log("Initializing worker..."); console.log("Initializing worker...");
worker.current ??= new Worker("/workers/kokoro-worker.js", { worker.current ??= new Worker("/workers/kokoro-worker.js", {
type: "module", type: "module",
@ -117,7 +119,6 @@ export const TTSProvider = ({
console.log("Worker initialized"); console.log("Worker initialized");
// Create a callback function for messages from the worker thread.
const onMessageReceived = (e: any) => { const onMessageReceived = (e: any) => {
switch (e.data.status) { switch (e.data.status) {
case "device": case "device":
@ -132,56 +133,71 @@ export const TTSProvider = ({
break; break;
case "complete": case "complete":
const { audio, text } = e.data; const { audio, text } = e.data;
// Generation complete: re-enable the "Generate" button
setResults((prev) => [{ text, src: audio }, ...prev]); setResults((prev) => [{ text, src: audio }, ...prev]);
setStatus("ready"); setStatus("ready");
break; break;
} }
}; };
console.log("onmessagereceived");
const onErrorReceived = (e: any) => { const onErrorReceived = (e: any) => {
console.error("Worker error:", e); console.error("Worker error:", e);
setError(e.message); setError(e.message);
}; };
console.log("Attaching event listeners to worker");
// Attach the callback function as an event listener.
worker.current.addEventListener("message", onMessageReceived); worker.current.addEventListener("message", onMessageReceived);
worker.current.addEventListener("error", onErrorReceived); worker.current.addEventListener("error", onErrorReceived);
console.log(worker.current);
// Define a cleanup function for when the component is unmounted.
return () => { return () => {
worker.current!.removeEventListener("message", onMessageReceived); worker.current!.removeEventListener("message", onMessageReceived);
worker.current!.removeEventListener("error", onErrorReceived); worker.current!.removeEventListener("error", onErrorReceived);
}; };
}, []); }, []);
// Pre-buffer current and next 2 sentences. // Pre-buffer current and next 5 sentences.
useEffect(() => { useEffect(() => {
let isCancelled = false;
async function preloadBuffer() { async function preloadBuffer() {
const newBuffer = [...ttsBuffer]; const newBuffer = [...ttsBuffer];
const end = Math.min(sentences.length, currentSentence + 3); const end = Math.min(sentences.length, currentSentence + 5); // Preload 5 sentences ahead
for (let i = currentSentence; i < end; i++) { for (let i = currentSentence; i < end; i++) {
if (isCancelled) break;
if (!newBuffer[i]) { if (!newBuffer[i]) {
console.log("Preloading TTS for sentence:", i, sentences[i]); console.log("Preloading TTS for sentence:", i, sentences[i]);
newBuffer[i] = await generateTTSForIndex( try {
removeMarkdown(sentences[i]), newBuffer[i] = await generateTTSForIndex(
i removeMarkdown(sentences[i]),
); i
);
} catch (error) {
console.error("Error preloading TTS:", error);
}
} }
} }
setTtsBuffer(newBuffer);
if (!isCancelled) {
setTtsBuffer((prev) => {
// Only update state if the buffer has changed
if (JSON.stringify(prev) !== JSON.stringify(newBuffer)) {
return newBuffer;
}
return prev;
});
}
} }
preloadBuffer(); preloadBuffer();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentSentence, sentences.join(" ")]); return () => {
isCancelled = true; // Cancel preloading if the component unmounts or dependencies change
};
}, [currentSentence, sentences]);
const playSentence = async (index: number) => { const playSentence = async (index: number) => {
if (index === currentSentence) return; // Prevent redundant updates
setCurrentSentence(index); setCurrentSentence(index);
let audioUrl = ttsBuffer[index]; let audioUrl = ttsBuffer[index];
if (!audioUrl) { if (!audioUrl) {
audioUrl = await generateTTSForIndex( audioUrl = await generateTTSForIndex(
@ -194,6 +210,7 @@ export const TTSProvider = ({
return updated; return updated;
}); });
} }
if (audioRef.current) { if (audioRef.current) {
audioRef.current.src = audioUrl; audioRef.current.src = audioUrl;
await new Promise((res) => { await new Promise((res) => {
@ -211,16 +228,21 @@ export const TTSProvider = ({
const playInOrder = async (index: number) => { const playInOrder = async (index: number) => {
if (index < 0 || index >= sentences.length) return; if (index < 0 || index >= sentences.length) return;
console.log("Playing in order from index:", index); if (index === currentSentence && playing) return; // Prevent redundant playback
setCurrentSentence(index); setCurrentSentence(index);
setPlaying(true);
for (let i = index; i < sentences.length; i++) { for (let i = index; i < sentences.length; i++) {
console.log("Playing sentence:", i, sentences[i]); console.log("Playing sentence:", i, sentences[i]);
await playSentence(i); try {
if (i < sentences.length - 1) { await playSentence(i);
console.log("Waiting for next sentence..."); } catch (error) {
await new Promise((resolve) => setTimeout(resolve, 1000)); console.error("Error playing sentence:", error);
break; // Stop playback on error
} }
} }
setPlaying(false);
}; };
const pause = () => { const pause = () => {

View File

@ -1,6 +1,9 @@
console.log("Initializing Kokoro TTS Worker"); console.log("Initializing Kokoro TTS Worker");
import { KokoroTTS } from "https://cdn.jsdelivr.net/npm/kokoro-js@1.2.0/+esm"; import {
KokoroTTS,
TextSplitterStream,
} from "https://cdn.jsdelivr.net/npm/kokoro-js@1.2.0/+esm";
async function detectWebGPU() { async function detectWebGPU() {
try { try {
const adapter = await navigator.gpu.requestAdapter(); const adapter = await navigator.gpu.requestAdapter();
@ -35,29 +38,43 @@ const tts = await KokoroTTS.from_pretrained(model_id, {
}, },
}); });
const splitter = new TextSplitterStream();
const stream = tts.stream(splitter);
let index = 0;
// Listen for messages from the main thread
self.addEventListener("message", async (e) => {
const { text, voice, index } = e.data;
console.log(
`Generating speech for text: "${text}" with voice: ${voice}, index: ${index}`
);
// Push the text to the splitter
splitter.push(text);
splitter.push(""); // Signal the end of the text
// Process the stream and include the correct index
for await (const { text: processedText, phonemes, audio } of stream) {
console.log({ processedText, phonemes });
const blob = audio.toBlob();
const base64Audio = await blobToBase64(blob);
self.postMessage({
status: "complete",
audio: base64Audio,
text: processedText,
phonemes,
index, // Include the index from the original message
});
break; // Stop processing after the first chunk for this message
}
});
console.log("Kokoro TTS model loaded successfully"); console.log("Kokoro TTS model loaded successfully");
self.postMessage({ status: "ready", voices: tts.voices, device }); self.postMessage({ status: "ready", voices: tts.voices, device });
console.log("Available voices:", tts.voices); console.log("Available voices:", tts.voices);
// Listen for messages from the main thread
self.addEventListener("message", async (e) => {
const { text, voice } = e.data;
try {
// Generate speech
console.log(`Generating speech for text: "${text}" with voice: ${voice}`);
const audio = await tts.generate(text, { voice });
// Send the audio file back to the main thread
const blob = audio.toBlob();
self.postMessage({
status: "complete",
audio: await blobToBase64(blob),
text,
});
} catch (error) {
self.postMessage({ status: "error", error: error.message });
}
});

View File

@ -14,7 +14,7 @@ const client = new Mistral({
const PROCESSING_PROMPT = ` const PROCESSING_PROMPT = `
You are a document processing AI. Your task is to process the Markdown text scanned from a document page and return it in a clean and structured format. You are a document processing AI. Your task is to process the Markdown text scanned from a document page and return it in a clean and structured format.
The textual page data should only be returned in valid Markdown format. Use proper headings and subheadings to structure the content. The textual page data should only be returned in valid Markdown format. Use proper headings and subheadings to structure the content. **Do not add headings if they do not exist in the original text.**
Any images should be included. Any images should be included.
Do not return the Markdown as a code block, only as a raw string, without any new lines. Do not return the Markdown as a code block, only as a raw string, without any new lines.
@ -35,7 +35,7 @@ Return the final result as a text object with the following structure (without c
"citations": [ "citations": [
{ {
"number": 1, // The number as it appears in the text "number": 1, // The number as it appears in the text
"text": "Citation text 1" "text": "Citation text 1" // Ensure any JSON-breaking characters are properly escaped
}, },
{ {
"number": 2, "number": 2,
@ -138,7 +138,7 @@ Deno.serve(async (req) => {
message: "File ID found in form data.", message: "File ID found in form data.",
}); });
const docId = formData.get("id"); const docId = formData.get("id");
console.log("Document ID:", docId, formData);
const { data: documentData, error: documentError } = await supabase const { data: documentData, error: documentError } = await supabase
.from("documents") .from("documents")
.select("*") .select("*")
@ -170,24 +170,9 @@ Deno.serve(async (req) => {
throw new Error("Document record not found"); throw new Error("Document record not found");
} }
const { data: storageData, error: storageError } = await supabaseServer
.from("storage.objects")
.select("name")
.eq("id", documentData.raw_file)
.single();
if (storageError) {
console.error("Error fetching file name:", storageError);
sendEvent("error", {
message: "Error fetching file name",
error: storageError,
});
throw new Error("Storage data fetch failed");
}
const { data: fileData, error: fileError } = await supabase.storage const { data: fileData, error: fileError } = await supabase.storage
.from("documents") .from("documents")
.download(storageData.name); .download(`${user.id}/${uuid}.pdf`);
if (fileError) { if (fileError) {
console.error("Error downloading file from storage:", fileError); console.error("Error downloading file from storage:", fileError);
@ -425,6 +410,7 @@ Deno.serve(async (req) => {
const content = split[0].trim(); const content = split[0].trim();
const citationsStr = split[1]?.trim() || "{}"; const citationsStr = split[1]?.trim() || "{}";
console.log(`[${page.index}] Citations: ${citationsStr}`);
const citations = JSON.parse(citationsStr).citations || {}; const citations = JSON.parse(citationsStr).citations || {};
console.log("Generating Markdown for page:", page.index); console.log("Generating Markdown for page:", page.index);