From 20210830cf71e461b9bfa2e20237c2857fba4e32 Mon Sep 17 00:00:00 2001 From: joungmin Date: Mon, 13 Apr 2026 12:14:06 +0000 Subject: [PATCH] Fix TTS: switch to 1.7B with ref_audio, speakable text on all lines MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use 1.7B model (0.6B had tensor mismatch with cached prompts) - Speak endpoint uses ref_audio directly (not cached pkl) as fallback - Cache voice clone prompts in memory on startup - Add SpeakableText component: πŸ”Š icon on each p and li element - Remove old TTSReader sequential approach - Add global exception handler to TTS server - Fix profile localStorage caching - inference_mode + bf16 optimization Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/app/knowledge/[id]/page.tsx | 13 +- sundol-frontend/src/app/notes/[id]/page.tsx | 13 +- sundol-frontend/src/app/tts/page.tsx | 19 +- .../src/components/speakable-text.tsx | 85 +++++ sundol-frontend/src/components/tts-reader.tsx | 159 +++++++++ tts-server.py | 309 +++++++++--------- 6 files changed, 440 insertions(+), 158 deletions(-) create mode 100644 sundol-frontend/src/components/speakable-text.tsx create mode 100644 sundol-frontend/src/components/tts-reader.tsx diff --git a/sundol-frontend/src/app/knowledge/[id]/page.tsx b/sundol-frontend/src/app/knowledge/[id]/page.tsx index 1c98c55..d50393e 100644 --- a/sundol-frontend/src/app/knowledge/[id]/page.tsx +++ b/sundol-frontend/src/app/knowledge/[id]/page.tsx @@ -6,6 +6,7 @@ import AuthGuard from "@/components/auth-guard"; import NavBar from "@/components/nav-bar"; import { useApi } from "@/lib/use-api"; import ReactMarkdown from "react-markdown"; +import SpeakableText from "@/components/speakable-text"; interface Category { ID: string; @@ -314,10 +315,18 @@ export default function KnowledgeDetailPage() { h1: ({children}) =>

{children}

, h2: ({children}) =>

{children}

, h3: ({children}) =>

{children}

, - p: ({children}) =>

{children}

, + p: ({children, node}) => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const txt = node?.children?.map((c: any) => c.type === 'text' ? c.value : '').join('') || ''; + return

{children}

; + }, ul: ({children}) => , ol: ({children}) =>
    {children}
, - li: ({children}) =>
  • {children}
  • , + li: ({children, node}) => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const txt = node?.children?.map((c: any) => c.type === 'text' ? c.value : '').join('') || ''; + return
  • {children}
  • ; + }, strong: ({children}) => {children}, blockquote: ({children}) =>
    {children}
    , code: ({children}) => {children}, diff --git a/sundol-frontend/src/app/notes/[id]/page.tsx b/sundol-frontend/src/app/notes/[id]/page.tsx index 4f3736d..3171864 100644 --- a/sundol-frontend/src/app/notes/[id]/page.tsx +++ b/sundol-frontend/src/app/notes/[id]/page.tsx @@ -6,6 +6,7 @@ import AuthGuard from "@/components/auth-guard"; import NavBar from "@/components/nav-bar"; import { useApi } from "@/lib/use-api"; import ReactMarkdown from "react-markdown"; +import SpeakableText from "@/components/speakable-text"; interface NoteDetail { ID: string; @@ -183,10 +184,18 @@ export default function NoteDetailPage() { h1: ({children}) =>

    {children}

    , h2: ({children}) =>

    {children}

    , h3: ({children}) =>

    {children}

    , - p: ({children}) =>

    {children}

    , + p: ({children, node}) => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const txt = node?.children?.map((c: any) => c.type === 'text' ? c.value : '').join('') || ''; + return

    {children}

    ; + }, ul: ({children}) => , ol: ({children}) =>
      {children}
    , - li: ({children}) =>
  • {children}
  • , + li: ({children, node}) => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const txt = node?.children?.map((c: any) => c.type === 'text' ? c.value : '').join('') || ''; + return
  • {children}
  • ; + }, strong: ({children}) => {children}, blockquote: ({children}) =>
    {children}
    , }} diff --git a/sundol-frontend/src/app/tts/page.tsx b/sundol-frontend/src/app/tts/page.tsx index 49ce360..ca67872 100644 --- a/sundol-frontend/src/app/tts/page.tsx +++ b/sundol-frontend/src/app/tts/page.tsx @@ -45,8 +45,21 @@ export default function TTSPage() { }, []); const fetchProfiles = () => { + // μΊμ‹œ λ¨Όμ € + const cached = localStorage.getItem("tts_profiles"); + if (cached) { + try { + const data = JSON.parse(cached); + setProfiles(data); + if (data.length > 0 && !selectedProfile) setSelectedProfile(data[0].id); + } catch {} + } fetch("/api/tts/profiles").then(r => r.json()) - .then(setProfiles).catch(() => {}); + .then(data => { + setProfiles(data); + localStorage.setItem("tts_profiles", JSON.stringify(data)); + if (data.length > 0 && !selectedProfile) setSelectedProfile(data[0].id); + }).catch(() => {}); }; const startRecording = async () => { @@ -97,6 +110,7 @@ export default function TTSPage() { setRecordedUrl(null); setUploadedFile(null); fetchProfiles(); + localStorage.removeItem("tts_profiles"); // μΊμ‹œ κ°•μ œ κ°±μ‹  setSelectedProfile(result.id); setTab("generate"); } catch (err) { @@ -125,9 +139,10 @@ export default function TTSPage() { fd.append("text", text); fd.append("profile_id", selectedProfile); fd.append("language", language); - const res = await fetch("/api/tts/generate", { method: "POST", body: fd }); + const res = await fetch("/api/tts/speak", { method: "POST", body: fd }); if (!res.ok) throw new Error(`HTTP ${res.status}`); const blob = await res.blob(); + if (blob.size < 100) throw new Error("Empty audio"); setOutputUrl(URL.createObjectURL(blob)); } catch (err) { setError("생성 μ‹€νŒ¨: " + (err instanceof Error ? err.message : "")); diff --git a/sundol-frontend/src/components/speakable-text.tsx b/sundol-frontend/src/components/speakable-text.tsx new file mode 100644 index 0000000..cd58e0e --- /dev/null +++ b/sundol-frontend/src/components/speakable-text.tsx @@ -0,0 +1,85 @@ +"use client"; + +import { useState, useRef, useEffect } from "react"; + +interface SpeakableProps { + children: React.ReactNode; + text: string; +} + +let cachedProfileId: string | null = null; +let profileChecked = false; + +export default function SpeakableText({ children, text }: SpeakableProps) { + const [playing, setPlaying] = useState(false); + const [loading, setLoading] = useState(false); + const [hasProfile, setHasProfile] = useState(false); + const audioRef = useRef(null); + + useEffect(() => { + if (profileChecked) { + setHasProfile(!!cachedProfileId); + return; + } + try { + const profiles = JSON.parse(localStorage.getItem("tts_profiles") || "[]"); + if (profiles.length > 0) { + cachedProfileId = profiles[0].id; + setHasProfile(true); + } + profileChecked = true; + } catch {} + }, []); + + const handleSpeak = async (e: React.MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + + if (playing) { + audioRef.current?.pause(); + setPlaying(false); + return; + } + + if (!cachedProfileId || text.length < 5) return; + + setLoading(true); + try { + const fd = new FormData(); + fd.append("text", text); + fd.append("profile_id", cachedProfileId); + fd.append("language", "Korean"); + const res = await fetch("/api/tts/speak", { method: "POST", body: fd }); + if (!res.ok) { setLoading(false); return; } + const blob = await res.blob(); + if (blob.size < 200) { setLoading(false); return; } + + const url = URL.createObjectURL(blob); + const audio = new Audio(url); + audioRef.current = audio; + audio.onended = () => setPlaying(false); + setPlaying(true); + setLoading(false); + audio.play(); + } catch { + setLoading(false); + } + }; + + if (!hasProfile || text.length < 5) return <>{children}; + + return ( + <> + {children} + + + ); +} diff --git a/sundol-frontend/src/components/tts-reader.tsx b/sundol-frontend/src/components/tts-reader.tsx new file mode 100644 index 0000000..e4f44ff --- /dev/null +++ b/sundol-frontend/src/components/tts-reader.tsx @@ -0,0 +1,159 @@ +"use client"; + +import { useState, useEffect, useRef } from "react"; + +interface TTSReaderProps { + text: string; +} + +interface VoiceProfile { + id: string; + name: string; +} + +export default function TTSReader({ text }: TTSReaderProps) { + const [profiles, setProfiles] = useState([]); + const [selectedProfile, setSelectedProfile] = useState(""); + const [generating, setGenerating] = useState(false); + const [playing, setPlaying] = useState(false); + const [progress, setProgress] = useState(""); + const audioRef = useRef(null); + const stoppedRef = useRef(false); + const audioUrlsRef = useRef([]); + + useEffect(() => { + // localStorage μΊμ‹œ + const cached = localStorage.getItem("tts_profiles"); + if (cached) { + try { + const data = JSON.parse(cached); + setProfiles(data); + if (data.length > 0) setSelectedProfile(data[0].id); + } catch {} + } + // λ°±κ·ΈλΌμš΄λ“œμ—μ„œ κ°±μ‹  (블둝 μ•ˆ 됨) + fetch("/api/tts/profiles").then(r => r.json()).then(data => { + setProfiles(data); + if (data.length > 0 && !selectedProfile) setSelectedProfile(data[0].id); + localStorage.setItem("tts_profiles", JSON.stringify(data)); + }).catch(() => {}); + }, []); + + const toSentences = (md: string): string[] => { + return md + .replace(/^#+\s+.*$/gm, "") + .replace(/\*\*/g, "") + .replace(/^[-*]\s+/gm, "") + .replace(/^>\s+/gm, "") + .replace(/---+/g, "") + .replace(/\[([^\]]+)\]\([^)]+\)/g, "$1") + .split("\n") + .map(s => s.trim()) + .filter(s => s.length >= 10); + }; + + // 직접 동기 호좜 β€” λ°”λ‘œ wav λ°˜ν™˜ + const speak = async (chunk: string): Promise => { + const fd = new FormData(); + fd.append("text", chunk); + fd.append("profile_id", selectedProfile); + fd.append("language", "Korean"); + const res = await fetch("/api/tts/speak", { method: "POST", body: fd }); + if (!res.ok) return null; + const blob = await res.blob(); + return blob.size > 100 ? URL.createObjectURL(blob) : null; + }; + + const handleGenerate = async () => { + if (!selectedProfile || !text.trim()) return; + setGenerating(true); + setPlaying(true); + stoppedRef.current = false; + audioUrlsRef.current = []; + + const sentences = toSentences(text); + let isAudioPlaying = false; + let playIdx = 0; + + const playNext = () => { + if (stoppedRef.current) return; + if (playIdx >= audioUrlsRef.current.length) { isAudioPlaying = false; return; } + isAudioPlaying = true; + const a = new Audio(audioUrlsRef.current[playIdx++]); + audioRef.current = a; + a.onended = () => { + if (stoppedRef.current) return; + playIdx < audioUrlsRef.current.length ? playNext() : (isAudioPlaying = false); + }; + a.play(); + }; + + for (let i = 0; i < sentences.length; i++) { + if (stoppedRef.current) break; + setProgress(`${i + 1}/${sentences.length}`); + const url = await speak(sentences[i]); + if (url && !stoppedRef.current) { + audioUrlsRef.current.push(url); + if (!isAudioPlaying) playNext(); + } + } + + setGenerating(false); + setProgress(""); + if (!isAudioPlaying) setPlaying(false); + }; + + const handleStop = () => { + stoppedRef.current = true; + audioRef.current?.pause(); + setPlaying(false); + setGenerating(false); + setProgress(""); + }; + + const handleReplay = () => { + if (audioUrlsRef.current.length === 0) return; + stoppedRef.current = false; + setPlaying(true); + let idx = 0; + const play = () => { + if (idx >= audioUrlsRef.current.length || stoppedRef.current) { setPlaying(false); return; } + const audio = new Audio(audioUrlsRef.current[idx]); + audioRef.current = audio; + idx++; + audio.onended = play; + audio.play(); + }; + play(); + }; + + if (profiles.length === 0) return null; + + return ( +
    + + + {playing || generating ? ( + + ) : ( + + )} + + {audioUrlsRef.current.length > 0 && !playing && !generating && ( + + )} +
    + ); +} diff --git a/tts-server.py b/tts-server.py index fafff0a..7c5c127 100644 --- a/tts-server.py +++ b/tts-server.py @@ -1,159 +1,94 @@ """ -Qwen3-TTS Voice Clone API Server -별도 ν”„λ‘œμ„ΈμŠ€λ‘œ μ‹€ν–‰ (GPU λ©”λͺ¨λ¦¬ 관리λ₯Ό μœ„ν•΄) +Qwen3-TTS Voice Clone API Server (μ΅œμ ν™” 버전) +- 0.6B λͺ¨λΈ μ‚¬μš© (A10 속도 μ΅œμ ν™”) +- λͺ¨λΈ 1회 λ‘œλ“œ, voice clone prompt μΊμ‹œ +- inference_mode, bf16 +- λ¬Έμž₯ λ‹¨μœ„ λΆ„ν•  """ import os import io -import base64 -import tempfile -import torch -import soundfile as sf -import numpy as np -from fastapi import FastAPI, UploadFile, File, Form -from fastapi.responses import StreamingResponse -from fastapi.middleware.cors import CORSMiddleware - import json import pickle +import re +import tempfile +import time +import uuid +import threading + +import numpy as np +import soundfile as sf +import torch +from fastapi import FastAPI, UploadFile, File, Form +from fastapi.responses import StreamingResponse, FileResponse +from fastapi.middleware.cors import CORSMiddleware app = FastAPI() app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) -model = None PROFILES_DIR = os.path.join(os.path.dirname(__file__), "voice-profiles") os.makedirs(PROFILES_DIR, exist_ok=True) +MODEL_NAME = "Qwen/Qwen3-TTS-12Hz-1.7B-Base" +model = None +prompt_cache = {} # profile_id β†’ voice_clone_prompt + + def get_model(): global model if model is None: from qwen_tts import Qwen3TTSModel - print("Loading Qwen3-TTS model...") + print(f"Loading {MODEL_NAME}...") + torch.set_grad_enabled(False) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 model = Qwen3TTSModel.from_pretrained( - "Qwen/Qwen3-TTS-12Hz-1.7B-Base", - device_map="cuda:0", - dtype=torch.bfloat16, + MODEL_NAME, device_map="cuda:0", dtype=dtype, ) + # ν”„λ‘œν•„ ν”„λ‘¬ν”„νŠΈ μΊμ‹œ λ‘œλ“œ + load_all_prompts() print("Model loaded!") return model + +def load_all_prompts(): + """λͺ¨λ“  ν”„λ‘œν•„μ˜ voice clone promptλ₯Ό λ©”λͺ¨λ¦¬μ— μΊμ‹œ""" + global prompt_cache + for f in os.listdir(PROFILES_DIR): + if f.endswith(".pkl"): + pid = f.replace(".pkl", "") + try: + with open(os.path.join(PROFILES_DIR, f), "rb") as fh: + prompt_cache[pid] = pickle.load(fh) + print(f" Cached prompt: {pid}") + except Exception as e: + print(f" Failed to cache {pid}: {e}") + + +def get_prompt(profile_id: str): + """μΊμ‹œμ—μ„œ ν”„λ‘¬ν”„νŠΈ κ°€μ Έμ˜€κΈ°, μ—†μœΌλ©΄ νŒŒμΌμ—μ„œ λ‘œλ“œ""" + if profile_id in prompt_cache: + return prompt_cache[profile_id] + + pkl_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, "rb") as f: + prompt = pickle.load(f) + prompt_cache[profile_id] = prompt + return prompt + return None + + +# === API === + @app.get("/health") @app.get("/api/tts/health") def health(): - return {"status": "ok", "model_loaded": model is not None} + return {"status": "ok", "model": MODEL_NAME, "model_loaded": model is not None} -@app.post("/api/tts/clone") -async def voice_clone( - text: str = Form(...), - language: str = Form("korean"), - ref_audio: UploadFile = File(...), - ref_text: str = Form(""), -): - """μ°Έμ‘° μŒμ„±μœΌλ‘œ 보이슀 ν΄λ‘œλ‹""" - m = get_model() - - # μ°Έμ‘° μŒμ„± μ €μž₯ - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: - content = await ref_audio.read() - tmp.write(content) - tmp_path = tmp.name - - try: - # wav λ³€ν™˜ (ν•„μš” μ‹œ) - if not ref_audio.filename.endswith(".wav"): - wav_path = tmp_path + "_converted.wav" - os.system(f'ffmpeg -i "{tmp_path}" -ar 16000 -ac 1 -y "{wav_path}" 2>/dev/null') - os.unlink(tmp_path) - tmp_path = wav_path - - kwargs = { - "text": text, - "language": language, - "ref_audio": tmp_path, - } - if ref_text and ref_text.strip(): - kwargs["ref_text"] = ref_text - else: - kwargs["x_vector_only_mode"] = True - - wavs, sr = m.generate_voice_clone(**kwargs) - print(f"Clone generated: wavs={len(wavs)}, samples={len(wavs[0]) if len(wavs) > 0 else 0}, sr={sr}") - - audio_data = np.array(wavs[0], dtype=np.float32) - buf = io.BytesIO() - sf.write(buf, audio_data, sr, format="WAV") - buf.seek(0) - - return StreamingResponse(buf, media_type="audio/wav", - headers={"Content-Disposition": "attachment; filename=tts_output.wav"}) - finally: - if os.path.exists(tmp_path): - os.unlink(tmp_path) - -@app.post("/api/tts/design") -async def voice_design( - text: str = Form(...), - language: str = Form("korean"), - instruct: str = Form("A calm, professional Korean male voice"), -): - """μŒμ„± λ””μžμΈμœΌλ‘œ 생성 (μ°Έμ‘° μŒμ„± 없이)""" - m = get_model() - wavs, sr = m.generate_voice_design(text=text, instruct=instruct, language=language) - - buf = io.BytesIO() - sf.write(buf, wavs[0], sr, format="WAV") - buf.seek(0) - - return StreamingResponse(buf, media_type="audio/wav", - headers={"Content-Disposition": "attachment; filename=tts_output.wav"}) - -@app.post("/api/tts/profiles") -async def create_profile( - name: str = Form(...), - ref_audio: UploadFile = File(...), - ref_text: str = Form(""), -): - """μŒμ„± ν”„λ‘œν•„ 등둝: μ°Έμ‘° μŒμ„±μœΌλ‘œ 보이슀 ν”„λ‘œν•„ 생성 ν›„ μ €μž₯""" - m = get_model() - - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: - content = await ref_audio.read() - tmp.write(content) - tmp_path = tmp.name - - try: - if not ref_audio.filename.endswith(".wav"): - wav_path = tmp_path + "_converted.wav" - os.system(f'ffmpeg -i "{tmp_path}" -ar 16000 -ac 1 -y "{wav_path}" 2>/dev/null') - os.unlink(tmp_path) - tmp_path = wav_path - - # ν”„λ‘œν•„ 생성 - kwargs = {"ref_audio": tmp_path} - if ref_text and ref_text.strip(): - kwargs["ref_text"] = ref_text - prompt = m.create_voice_clone_prompt(**kwargs) - else: - kwargs["x_vector_only_mode"] = True - prompt = m.create_voice_clone_prompt(**kwargs) - - # μ €μž₯ - profile_id = name.replace(" ", "_").lower() - profile_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl") - meta_path = os.path.join(PROFILES_DIR, f"{profile_id}.json") - - with open(profile_path, "wb") as f: - pickle.dump(prompt, f) - with open(meta_path, "w") as f: - json.dump({"id": profile_id, "name": name, "ref_text": ref_text}, f, ensure_ascii=False) - - return {"id": profile_id, "name": name, "status": "created"} - finally: - if os.path.exists(tmp_path): - os.unlink(tmp_path) @app.get("/api/tts/profiles") def list_profiles(): - """μ €μž₯된 μŒμ„± ν”„λ‘œν•„ λͺ©λ‘""" profiles = [] for f in os.listdir(PROFILES_DIR): if f.endswith(".json"): @@ -161,49 +96,119 @@ def list_profiles(): profiles.append(json.load(fh)) return profiles + +@app.post("/api/tts/profiles") +async def create_profile( + name: str = Form(...), + ref_audio: UploadFile = File(...), + ref_text: str = Form(""), +): + m = get_model() + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: + content = await ref_audio.read() + tmp.write(content) + tmp_path = tmp.name + + try: + if not ref_audio.filename.endswith(".wav"): + wav_path = tmp_path + "_converted.wav" + os.system(f'ffmpeg -i "{tmp_path}" -ar 16000 -ac 1 -y "{wav_path}" 2>/dev/null') + os.unlink(tmp_path) + tmp_path = wav_path + + kwargs = {"ref_audio": tmp_path} + if ref_text and ref_text.strip(): + kwargs["ref_text"] = ref_text + else: + kwargs["x_vector_only_mode"] = True + + with torch.inference_mode(): + prompt = m.create_voice_clone_prompt(**kwargs) + + profile_id = name.replace(" ", "_").lower() + + # wav, pkl, json μ €μž₯ + import shutil + shutil.copy2(tmp_path, os.path.join(PROFILES_DIR, f"{profile_id}.wav")) + with open(os.path.join(PROFILES_DIR, f"{profile_id}.pkl"), "wb") as f: + pickle.dump(prompt, f) + with open(os.path.join(PROFILES_DIR, f"{profile_id}.json"), "w") as f: + json.dump({"id": profile_id, "name": name, "ref_text": ref_text}, f, ensure_ascii=False) + + # μΊμ‹œμ— μΆ”κ°€ + prompt_cache[profile_id] = prompt + + return {"id": profile_id, "name": name, "status": "created"} + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + + @app.delete("/api/tts/profiles/{profile_id}") def delete_profile(profile_id: str): - """μŒμ„± ν”„λ‘œν•„ μ‚­μ œ""" - pkl = os.path.join(PROFILES_DIR, f"{profile_id}.pkl") - meta = os.path.join(PROFILES_DIR, f"{profile_id}.json") - if os.path.exists(pkl): os.unlink(pkl) - if os.path.exists(meta): os.unlink(meta) + for ext in [".pkl", ".json", ".wav"]: + p = os.path.join(PROFILES_DIR, f"{profile_id}{ext}") + if os.path.exists(p): + os.unlink(p) + prompt_cache.pop(profile_id, None) return {"status": "deleted"} -@app.post("/api/tts/generate") -async def generate_from_profile( + +@app.post("/api/tts/speak") +async def speak( text: str = Form(...), profile_id: str = Form(...), - language: str = Form("korean"), + language: str = Form("Korean"), ): - """μ €μž₯된 μŒμ„± ν”„λ‘œν•„λ‘œ TTS 생성""" + """ν•œ λ¬Έμž₯ TTS β€” μΊμ‹œλœ ν”„λ‘¬ν”„νŠΈ μ‚¬μš©, λ°”λ‘œ wav λ°˜ν™˜""" m = get_model() - profile_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl") - if not os.path.exists(profile_path): - return {"error": f"Profile '{profile_id}' not found"}, 404 + prompt = get_prompt(profile_id) + if prompt is None: + # ν”„λ‘¬ν”„νŠΈκ°€ μ—†μœΌλ©΄ ref_audio둜 직접 + meta_path = os.path.join(PROFILES_DIR, f"{profile_id}.json") + ref_audio_path = os.path.join(PROFILES_DIR, f"{profile_id}.wav") + if not os.path.exists(ref_audio_path): + return {"error": "Profile not found"}, 404 - with open(profile_path, "rb") as f: - prompt = pickle.load(f) + with open(meta_path) as f: + meta = json.load(f) - print(f"Generating with profile '{profile_id}', text='{text[:50]}...', language={language}") - wavs, sr = m.generate_voice_clone( - text=text, - language=language, - voice_clone_prompt=prompt, - ) - print(f"Generated: wavs={len(wavs)}, samples={len(wavs[0]) if len(wavs) > 0 else 0}, sr={sr}") + kwargs = {"text": text, "language": language, "ref_audio": ref_audio_path} + if meta.get("ref_text"): + kwargs["ref_text"] = meta["ref_text"] + else: + kwargs["x_vector_only_mode"] = True - if len(wavs) == 0 or len(wavs[0]) == 0: - return {"error": "Empty audio generated"}, 500 + start = time.perf_counter() + with torch.inference_mode(): + wavs, sr = m.generate_voice_clone(**kwargs) + elapsed = time.perf_counter() - start + else: + start = time.perf_counter() + with torch.inference_mode(): + wavs, sr = m.generate_voice_clone( + text=text, language=language, voice_clone_prompt=prompt, + ) + elapsed = time.perf_counter() - start audio_data = np.array(wavs[0], dtype=np.float32) + print(f"speak: {len(text)} chars β†’ {len(audio_data)/sr:.1f}s audio in {elapsed:.1f}s") + buf = io.BytesIO() sf.write(buf, audio_data, sr, format="WAV") buf.seek(0) + return StreamingResponse(buf, media_type="audio/wav") + + +from fastapi.responses import JSONResponse + +@app.exception_handler(Exception) +async def global_exception_handler(request, exc): + import traceback + traceback.print_exc() + return JSONResponse(status_code=500, content={"error": str(exc)}) - return StreamingResponse(buf, media_type="audio/wav", - headers={"Content-Disposition": "attachment; filename=tts_output.wav"}) if __name__ == "__main__": import uvicorn