Fix TTS: switch to 1.7B with ref_audio, speakable text on all lines

- Use 1.7B model (0.6B had tensor mismatch with cached prompts)
- Speak endpoint uses ref_audio directly (not cached pkl) as fallback
- Cache voice clone prompts in memory on startup
- Add SpeakableText component: 🔊 icon on each p and li element
- Remove old TTSReader sequential approach
- Add global exception handler to TTS server
- Fix profile localStorage caching
- inference_mode + bf16 optimization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-13 12:14:06 +00:00
parent 1088b23790
commit 20210830cf
6 changed files with 440 additions and 158 deletions

View File

@@ -6,6 +6,7 @@ import AuthGuard from "@/components/auth-guard";
import NavBar from "@/components/nav-bar"; import NavBar from "@/components/nav-bar";
import { useApi } from "@/lib/use-api"; import { useApi } from "@/lib/use-api";
import ReactMarkdown from "react-markdown"; import ReactMarkdown from "react-markdown";
import SpeakableText from "@/components/speakable-text";
interface Category { interface Category {
ID: string; ID: string;
@@ -314,10 +315,18 @@ export default function KnowledgeDetailPage() {
h1: ({children}) => <h1 className="text-xl font-bold mt-6 mb-3">{children}</h1>, h1: ({children}) => <h1 className="text-xl font-bold mt-6 mb-3">{children}</h1>,
h2: ({children}) => <h2 className="text-lg font-bold mt-5 mb-2">{children}</h2>, h2: ({children}) => <h2 className="text-lg font-bold mt-5 mb-2">{children}</h2>,
h3: ({children}) => <h3 className="text-base font-bold mt-4 mb-2">{children}</h3>, h3: ({children}) => <h3 className="text-base font-bold mt-4 mb-2">{children}</h3>,
p: ({children}) => <p className="mb-3">{children}</p>, p: ({children, node}) => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const txt = node?.children?.map((c: any) => c.type === 'text' ? c.value : '').join('') || '';
return <p className="mb-3"><SpeakableText text={txt}>{children}</SpeakableText></p>;
},
ul: ({children}) => <ul className="list-disc ml-5 mb-3 space-y-1">{children}</ul>, ul: ({children}) => <ul className="list-disc ml-5 mb-3 space-y-1">{children}</ul>,
ol: ({children}) => <ol className="list-decimal ml-5 mb-3 space-y-1">{children}</ol>, ol: ({children}) => <ol className="list-decimal ml-5 mb-3 space-y-1">{children}</ol>,
li: ({children}) => <li className="leading-relaxed">{children}</li>, li: ({children, node}) => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const txt = node?.children?.map((c: any) => c.type === 'text' ? c.value : '').join('') || '';
return <li className="leading-relaxed"><SpeakableText text={txt}>{children}</SpeakableText></li>;
},
strong: ({children}) => <strong className="font-bold">{children}</strong>, strong: ({children}) => <strong className="font-bold">{children}</strong>,
blockquote: ({children}) => <blockquote className="border-l-2 border-[var(--color-primary)] pl-4 my-3 italic text-[var(--color-text-muted)]">{children}</blockquote>, blockquote: ({children}) => <blockquote className="border-l-2 border-[var(--color-primary)] pl-4 my-3 italic text-[var(--color-text-muted)]">{children}</blockquote>,
code: ({children}) => <code className="bg-[var(--color-bg-hover)] px-1.5 py-0.5 rounded text-xs">{children}</code>, code: ({children}) => <code className="bg-[var(--color-bg-hover)] px-1.5 py-0.5 rounded text-xs">{children}</code>,

View File

@@ -6,6 +6,7 @@ import AuthGuard from "@/components/auth-guard";
import NavBar from "@/components/nav-bar"; import NavBar from "@/components/nav-bar";
import { useApi } from "@/lib/use-api"; import { useApi } from "@/lib/use-api";
import ReactMarkdown from "react-markdown"; import ReactMarkdown from "react-markdown";
import SpeakableText from "@/components/speakable-text";
interface NoteDetail { interface NoteDetail {
ID: string; ID: string;
@@ -183,10 +184,18 @@ export default function NoteDetailPage() {
h1: ({children}) => <h1 className="text-xl font-bold mt-6 mb-3">{children}</h1>, h1: ({children}) => <h1 className="text-xl font-bold mt-6 mb-3">{children}</h1>,
h2: ({children}) => <h2 className="text-lg font-bold mt-5 mb-2">{children}</h2>, h2: ({children}) => <h2 className="text-lg font-bold mt-5 mb-2">{children}</h2>,
h3: ({children}) => <h3 className="text-base font-bold mt-4 mb-2">{children}</h3>, h3: ({children}) => <h3 className="text-base font-bold mt-4 mb-2">{children}</h3>,
p: ({children}) => <p className="mb-3">{children}</p>, p: ({children, node}) => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const txt = node?.children?.map((c: any) => c.type === 'text' ? c.value : '').join('') || '';
return <p className="mb-3"><SpeakableText text={txt}>{children}</SpeakableText></p>;
},
ul: ({children}) => <ul className="list-disc ml-5 mb-3 space-y-1">{children}</ul>, ul: ({children}) => <ul className="list-disc ml-5 mb-3 space-y-1">{children}</ul>,
ol: ({children}) => <ol className="list-decimal ml-5 mb-3 space-y-1">{children}</ol>, ol: ({children}) => <ol className="list-decimal ml-5 mb-3 space-y-1">{children}</ol>,
li: ({children}) => <li className="leading-relaxed">{children}</li>, li: ({children, node}) => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const txt = node?.children?.map((c: any) => c.type === 'text' ? c.value : '').join('') || '';
return <li className="leading-relaxed"><SpeakableText text={txt}>{children}</SpeakableText></li>;
},
strong: ({children}) => <strong className="font-bold">{children}</strong>, strong: ({children}) => <strong className="font-bold">{children}</strong>,
blockquote: ({children}) => <blockquote className="border-l-2 border-[var(--color-primary)] pl-4 my-3 italic text-[var(--color-text-muted)]">{children}</blockquote>, blockquote: ({children}) => <blockquote className="border-l-2 border-[var(--color-primary)] pl-4 my-3 italic text-[var(--color-text-muted)]">{children}</blockquote>,
}} }}

View File

@@ -45,8 +45,21 @@ export default function TTSPage() {
}, []); }, []);
const fetchProfiles = () => { const fetchProfiles = () => {
// 캐시 먼저
const cached = localStorage.getItem("tts_profiles");
if (cached) {
try {
const data = JSON.parse(cached);
setProfiles(data);
if (data.length > 0 && !selectedProfile) setSelectedProfile(data[0].id);
} catch {}
}
fetch("/api/tts/profiles").then(r => r.json()) fetch("/api/tts/profiles").then(r => r.json())
.then(setProfiles).catch(() => {}); .then(data => {
setProfiles(data);
localStorage.setItem("tts_profiles", JSON.stringify(data));
if (data.length > 0 && !selectedProfile) setSelectedProfile(data[0].id);
}).catch(() => {});
}; };
const startRecording = async () => { const startRecording = async () => {
@@ -97,6 +110,7 @@ export default function TTSPage() {
setRecordedUrl(null); setRecordedUrl(null);
setUploadedFile(null); setUploadedFile(null);
fetchProfiles(); fetchProfiles();
localStorage.removeItem("tts_profiles"); // 캐시 강제 갱신
setSelectedProfile(result.id); setSelectedProfile(result.id);
setTab("generate"); setTab("generate");
} catch (err) { } catch (err) {
@@ -125,9 +139,10 @@ export default function TTSPage() {
fd.append("text", text); fd.append("text", text);
fd.append("profile_id", selectedProfile); fd.append("profile_id", selectedProfile);
fd.append("language", language); fd.append("language", language);
const res = await fetch("/api/tts/generate", { method: "POST", body: fd }); const res = await fetch("/api/tts/speak", { method: "POST", body: fd });
if (!res.ok) throw new Error(`HTTP ${res.status}`); if (!res.ok) throw new Error(`HTTP ${res.status}`);
const blob = await res.blob(); const blob = await res.blob();
if (blob.size < 100) throw new Error("Empty audio");
setOutputUrl(URL.createObjectURL(blob)); setOutputUrl(URL.createObjectURL(blob));
} catch (err) { } catch (err) {
setError("생성 실패: " + (err instanceof Error ? err.message : "")); setError("생성 실패: " + (err instanceof Error ? err.message : ""));

View File

@@ -0,0 +1,85 @@
"use client";
import { useState, useRef, useEffect } from "react";
/** Props accepted by SpeakableText. */
interface SpeakableProps {
// Content rendered unchanged; the speak button is appended after it.
children: React.ReactNode;
// Plain text posted to the TTS endpoint when the button is clicked.
text: string;
}
// Module-level state shared by every SpeakableText instance on the page:
// the id of the first profile found in localStorage["tts_profiles"], or null.
let cachedProfileId: string | null = null;
// Set to true once localStorage has been read and parsed without throwing,
// so that later instances can skip the lookup.
let profileChecked = false;
/**
 * Renders `children` followed by a small inline button that reads `text`
 * aloud through the `/api/tts/speak` endpoint.
 *
 * The voice profile is the first entry of localStorage["tts_profiles"]; its id
 * is cached at module level so that the many instances on one page do not each
 * re-parse localStorage. When no profile is available, or `text` is shorter
 * than 5 characters, only `children` is rendered.
 *
 * @param children Content to render unchanged.
 * @param text     Plain text sent to the TTS endpoint.
 */
export default function SpeakableText({ children, text }: SpeakableProps) {
  const [playing, setPlaying] = useState(false);
  const [loading, setLoading] = useState(false);
  const [hasProfile, setHasProfile] = useState(false);
  const audioRef = useRef<HTMLAudioElement | null>(null);
  // Object URL backing the current clip; kept so it can be revoked.
  const urlRef = useRef<string | null>(null);
  const mountedRef = useRef(true);

  // Stops the current clip (if any) and frees its blob URL.
  const releaseAudio = () => {
    const audio = audioRef.current;
    if (audio) {
      audio.onended = null;
      audio.pause();
    }
    audioRef.current = null;
    if (urlRef.current) {
      URL.revokeObjectURL(urlRef.current);
      urlRef.current = null;
    }
  };

  useEffect(() => {
    // Only a positive result is reused. A "no profile yet" result is re-checked
    // on every mount so the button appears once a profile has been created,
    // without requiring a full page reload.
    if (profileChecked && cachedProfileId) {
      setHasProfile(true);
      return;
    }
    try {
      const profiles: unknown = JSON.parse(localStorage.getItem("tts_profiles") || "[]");
      const first = Array.isArray(profiles) ? profiles[0] : undefined;
      if (first != null && first.id != null) {
        cachedProfileId = String(first.id);
        setHasProfile(true);
      }
      profileChecked = true;
    } catch {
      // Best effort: unreadable or corrupt cache simply means no speak button.
    }
  }, []);

  useEffect(() => {
    mountedRef.current = true;
    return () => {
      // Do not keep talking (or leak the blob URL) after the element is gone.
      mountedRef.current = false;
      releaseAudio();
    };
  }, []);

  const handleSpeak = async (e: React.MouseEvent) => {
    // The button may sit inside links or other clickable content.
    e.preventDefault();
    e.stopPropagation();
    if (playing) {
      releaseAudio();
      setPlaying(false);
      return;
    }
    if (!cachedProfileId || text.length < 5) return;
    setLoading(true);
    try {
      const fd = new FormData();
      fd.append("text", text);
      fd.append("profile_id", cachedProfileId);
      fd.append("language", "Korean");
      const res = await fetch("/api/tts/speak", { method: "POST", body: fd });
      if (!res.ok) return;
      const blob = await res.blob();
      // Very small bodies are treated as an empty or failed generation.
      if (blob.size < 200) return;
      if (!mountedRef.current) return;
      releaseAudio();
      const url = URL.createObjectURL(blob);
      urlRef.current = url;
      const audio = new Audio(url);
      audioRef.current = audio;
      audio.onended = () => {
        releaseAudio();
        setPlaying(false);
      };
      // play() returns a promise that rejects when playback is blocked or the
      // media cannot be decoded; awaiting it routes that into the catch below
      // instead of leaving the button stuck in the "playing" state.
      await audio.play();
      setPlaying(true);
    } catch {
      releaseAudio();
      setPlaying(false);
    } finally {
      setLoading(false);
    }
  };

  if (!hasProfile || text.length < 5) return <>{children}</>;

  return (
    <>
      {children}
      <button
        type="button"
        onClick={handleSpeak}
        disabled={loading}
        className="inline-flex items-center ml-1 text-[var(--color-text-muted)] hover:text-[var(--color-primary)] disabled:opacity-30 align-middle"
        title={playing ? "중지" : "읽어주기"}
        aria-label={playing ? "중지" : "읽어주기"}
        style={{ fontSize: "0.85em", verticalAlign: "middle", cursor: "pointer" }}
      >
        {loading ? "⏳" : playing ? "⏹" : "🔊"}
      </button>
    </>
  );
}

View File

@@ -0,0 +1,159 @@
"use client";
import { useState, useEffect, useRef } from "react";
/** Props accepted by TTSReader. */
interface TTSReaderProps {
// Markdown source; it is reduced to plain lines before being sent to the TTS endpoint.
text: string;
}
/** A voice profile entry as returned by /api/tts/profiles (only the fields used here). */
interface VoiceProfile {
// Sent as profile_id when requesting speech.
id: string;
// Label shown in the profile <select>.
name: string;
}
/**
 * Reads a markdown document aloud line by line.
 *
 * Each usable line is sent to `/api/tts/speak` sequentially; clips are queued
 * and played in order while later lines are still being generated. Generated
 * clips are kept so they can be replayed without another request.
 *
 * Renders nothing until at least one voice profile is known.
 *
 * NOTE(review): in the reviewed source the generate and replay buttons have no
 * visible label; confirm that their text/icon was not lost.
 *
 * @param text Markdown source to read.
 */
export default function TTSReader({ text }: TTSReaderProps) {
  const [profiles, setProfiles] = useState<VoiceProfile[]>([]);
  const [selectedProfile, setSelectedProfile] = useState("");
  const [generating, setGenerating] = useState(false);
  const [playing, setPlaying] = useState(false);
  const [progress, setProgress] = useState("");
  const audioRef = useRef<HTMLAudioElement | null>(null);
  const stoppedRef = useRef(false);
  const audioUrlsRef = useRef<string[]>([]);
  // Incremented per generation so a superseded run can detect it is stale.
  const runRef = useRef(0);

  // Frees every queued blob URL and empties the queue.
  const revokeAll = () => {
    audioUrlsRef.current.forEach(u => URL.revokeObjectURL(u));
    audioUrlsRef.current = [];
  };

  useEffect(() => {
    // Show cached profiles immediately.
    const cached = localStorage.getItem("tts_profiles");
    if (cached) {
      try {
        const data = JSON.parse(cached);
        if (Array.isArray(data)) {
          setProfiles(data);
          if (data.length > 0) setSelectedProfile(data[0].id);
        }
      } catch {
        // Corrupt cache: fall through to the network refresh.
      }
    }
    // Refresh in the background (non-blocking).
    fetch("/api/tts/profiles").then(r => r.json()).then(data => {
      // An error payload is not an array; storing it would break the render below.
      if (!Array.isArray(data)) return;
      setProfiles(data);
      if (data.length > 0) {
        // Functional update: this closure only ever sees the initial "" value,
        // so reading `selectedProfile` directly would always overwrite the
        // current selection.
        setSelectedProfile(prev =>
          prev && data.some(p => p.id === prev) ? prev : data[0].id
        );
      }
      localStorage.setItem("tts_profiles", JSON.stringify(data));
    }).catch(() => {});
    return () => {
      // Stop playback and release blob URLs when the reader is unmounted.
      stoppedRef.current = true;
      audioRef.current?.pause();
      revokeAll();
    };
  }, []);

  // Strips common markdown syntax and returns the lines worth reading
  // (headings are dropped, lines shorter than 10 characters are skipped).
  const toSentences = (md: string): string[] => {
    return md
      .replace(/^#+\s+.*$/gm, "")
      .replace(/\*\*/g, "")
      .replace(/^[-*]\s+/gm, "")
      .replace(/^>\s+/gm, "")
      .replace(/---+/g, "")
      .replace(/\[([^\]]+)\]\([^)]+\)/g, "$1")
      .split("\n")
      .map(s => s.trim())
      .filter(s => s.length >= 10);
  };

  // Synchronous endpoint: responds with the wav directly.
  // Resolves to an object URL, or null when no usable audio came back.
  const speak = async (chunk: string): Promise<string | null> => {
    const fd = new FormData();
    fd.append("text", chunk);
    fd.append("profile_id", selectedProfile);
    fd.append("language", "Korean");
    const res = await fetch("/api/tts/speak", { method: "POST", body: fd });
    if (!res.ok) return null;
    const blob = await res.blob();
    return blob.size > 100 ? URL.createObjectURL(blob) : null;
  };

  const handleGenerate = async () => {
    if (!selectedProfile || !text.trim()) return;
    const run = ++runRef.current;
    const cancelled = () => stoppedRef.current || runRef.current !== run;

    setGenerating(true);
    setPlaying(true);
    stoppedRef.current = false;
    revokeAll();

    const sentences = toSentences(text);
    let isAudioPlaying = false;
    let generationDone = false;
    let playIdx = 0;

    const playNext = () => {
      if (cancelled()) return;
      if (playIdx >= audioUrlsRef.current.length) {
        isAudioPlaying = false;
        // Nothing queued and nothing more coming: playback is finished.
        // Without this the stop button stays up after the last clip ends.
        if (generationDone) setPlaying(false);
        return;
      }
      isAudioPlaying = true;
      const a = new Audio(audioUrlsRef.current[playIdx++]);
      audioRef.current = a;
      a.onended = playNext;
      // A rejected play() (blocked or undecodable clip) skips to the next clip.
      a.play().catch(playNext);
    };

    for (let i = 0; i < sentences.length; i++) {
      if (cancelled()) break;
      setProgress(`${i + 1}/${sentences.length}`);
      let url: string | null = null;
      try {
        url = await speak(sentences[i]);
      } catch {
        // A failed request skips this line instead of aborting the whole run.
        url = null;
      }
      if (!url) continue;
      if (cancelled()) {
        URL.revokeObjectURL(url);
        break;
      }
      audioUrlsRef.current.push(url);
      if (!isAudioPlaying) playNext();
    }

    // A newer run owns the UI state now.
    if (runRef.current !== run) return;
    generationDone = true;
    setGenerating(false);
    setProgress("");
    if (!isAudioPlaying) setPlaying(false);
  };

  const handleStop = () => {
    stoppedRef.current = true;
    audioRef.current?.pause();
    setPlaying(false);
    setGenerating(false);
    setProgress("");
  };

  const handleReplay = () => {
    if (audioUrlsRef.current.length === 0) return;
    stoppedRef.current = false;
    setPlaying(true);
    let idx = 0;
    const play = () => {
      if (idx >= audioUrlsRef.current.length || stoppedRef.current) { setPlaying(false); return; }
      const audio = new Audio(audioUrlsRef.current[idx]);
      audioRef.current = audio;
      idx++;
      audio.onended = play;
      audio.play().catch(play);
    };
    play();
  };

  if (profiles.length === 0) return null;

  return (
    <div className="flex items-center gap-2 flex-wrap">
      <select value={selectedProfile} onChange={e => setSelectedProfile(e.target.value)}
        className="text-xs px-2 py-1 rounded bg-[var(--color-bg-hover)] border border-[var(--color-border)]">
        {profiles.map(p => <option key={p.id} value={p.id}>{p.name}</option>)}
      </select>
      {playing || generating ? (
        <button type="button" onClick={handleStop}
          className="text-xs px-3 py-1 bg-red-500/20 text-red-400 rounded hover:bg-red-500/30">
          {progress || "중지"}
        </button>
      ) : (
        <button type="button" onClick={handleGenerate} disabled={!selectedProfile}
          className="text-xs px-3 py-1 bg-[var(--color-primary)]/20 text-[var(--color-primary)] rounded hover:bg-[var(--color-primary)]/30 disabled:opacity-40">
        </button>
      )}
      {audioUrlsRef.current.length > 0 && !playing && !generating && (
        <button type="button" onClick={handleReplay}
          className="text-xs px-3 py-1 bg-[var(--color-bg-hover)] border border-[var(--color-border)] rounded">
        </button>
      )}
    </div>
  );
}

View File

@@ -1,159 +1,94 @@
""" """
Qwen3-TTS Voice Clone API Server Qwen3-TTS Voice Clone API Server (최적화 버전)
별도 프로세스로 실행 (GPU 메모리 관리를 위해) - 1.7B 모델 사용 (0.6B는 캐시된 프롬프트와 텐서 불일치)
- 모델 1회 로드, voice clone prompt 캐시
- inference_mode, bf16
- 문장 단위 분할
""" """
import os import os
import io import io
import base64
import tempfile
import torch
import soundfile as sf
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import json import json
import pickle import pickle
import re
import tempfile
import time
import uuid
import threading
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI() app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
model = None
PROFILES_DIR = os.path.join(os.path.dirname(__file__), "voice-profiles") PROFILES_DIR = os.path.join(os.path.dirname(__file__), "voice-profiles")
os.makedirs(PROFILES_DIR, exist_ok=True) os.makedirs(PROFILES_DIR, exist_ok=True)
MODEL_NAME = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
model = None
prompt_cache = {} # profile_id → voice_clone_prompt
def get_model(): def get_model():
global model global model
if model is None: if model is None:
from qwen_tts import Qwen3TTSModel from qwen_tts import Qwen3TTSModel
print("Loading Qwen3-TTS model...") print(f"Loading {MODEL_NAME}...")
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
model = Qwen3TTSModel.from_pretrained( model = Qwen3TTSModel.from_pretrained(
"Qwen/Qwen3-TTS-12Hz-1.7B-Base", MODEL_NAME, device_map="cuda:0", dtype=dtype,
device_map="cuda:0",
dtype=torch.bfloat16,
) )
# 프로필 프롬프트 캐시 로드
load_all_prompts()
print("Model loaded!") print("Model loaded!")
return model return model
def load_all_prompts():
    """Cache the voice-clone prompt of every stored profile in memory.

    Scans PROFILES_DIR for ``<profile_id>.pkl`` files and stores each unpickled
    prompt in ``prompt_cache`` under its profile id. A file that fails to load
    is reported and skipped so that one bad profile does not block the others.
    """
    global prompt_cache
    suffix = ".pkl"
    for filename in os.listdir(PROFILES_DIR):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing extension; str.replace would also remove
        # ".pkl" occurring inside the id and cache it under the wrong key.
        pid = filename[: -len(suffix)]
        try:
            # NOTE(security): pickle.load can execute arbitrary code, so
            # PROFILES_DIR must only contain files written by this server.
            with open(os.path.join(PROFILES_DIR, filename), "rb") as fh:
                prompt_cache[pid] = pickle.load(fh)
            print(f" Cached prompt: {pid}")
        except Exception as e:
            print(f" Failed to cache {pid}: {e}")
def get_prompt(profile_id: str):
    """Return the voice-clone prompt for ``profile_id``, or None if unavailable.

    The in-memory ``prompt_cache`` is consulted first. On a miss the prompt is
    loaded from ``<PROFILES_DIR>/<profile_id>.pkl`` and cached for next time.

    Args:
        profile_id: Profile identifier. It must be a bare file-name stem.

    Returns:
        The cached or freshly loaded prompt object. None when the id is not a
        plain name or no pickle exists for it.
    """
    if profile_id in prompt_cache:
        return prompt_cache[profile_id]

    # The id is interpolated into a file path that is then unpickled. Refuse
    # anything that could escape PROFILES_DIR (separators, "." / "..") so a
    # crafted id cannot make the server load an arbitrary pickle.
    if (
        not profile_id
        or profile_id in (".", "..")
        or "/" in profile_id
        or "\\" in profile_id
        or os.path.basename(profile_id) != profile_id
    ):
        return None

    pkl_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl")
    if os.path.exists(pkl_path):
        # NOTE(security): pickle.load can execute arbitrary code, so only
        # files written by this server may live in PROFILES_DIR.
        with open(pkl_path, "rb") as f:
            prompt = pickle.load(f)
        prompt_cache[profile_id] = prompt
        return prompt
    return None
# === API ===
@app.get("/health") @app.get("/health")
@app.get("/api/tts/health") @app.get("/api/tts/health")
def health(): def health():
return {"status": "ok", "model_loaded": model is not None} return {"status": "ok", "model": MODEL_NAME, "model_loaded": model is not None}
@app.post("/api/tts/clone")
async def voice_clone(
text: str = Form(...),
language: str = Form("korean"),
ref_audio: UploadFile = File(...),
ref_text: str = Form(""),
):
"""참조 음성으로 보이스 클로닝"""
m = get_model()
# 참조 음성 저장
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
content = await ref_audio.read()
tmp.write(content)
tmp_path = tmp.name
try:
# wav 변환 (필요 시)
if not ref_audio.filename.endswith(".wav"):
wav_path = tmp_path + "_converted.wav"
os.system(f'ffmpeg -i "{tmp_path}" -ar 16000 -ac 1 -y "{wav_path}" 2>/dev/null')
os.unlink(tmp_path)
tmp_path = wav_path
kwargs = {
"text": text,
"language": language,
"ref_audio": tmp_path,
}
if ref_text and ref_text.strip():
kwargs["ref_text"] = ref_text
else:
kwargs["x_vector_only_mode"] = True
wavs, sr = m.generate_voice_clone(**kwargs)
print(f"Clone generated: wavs={len(wavs)}, samples={len(wavs[0]) if len(wavs) > 0 else 0}, sr={sr}")
audio_data = np.array(wavs[0], dtype=np.float32)
buf = io.BytesIO()
sf.write(buf, audio_data, sr, format="WAV")
buf.seek(0)
return StreamingResponse(buf, media_type="audio/wav",
headers={"Content-Disposition": "attachment; filename=tts_output.wav"})
finally:
if os.path.exists(tmp_path):
os.unlink(tmp_path)
@app.post("/api/tts/design")
async def voice_design(
text: str = Form(...),
language: str = Form("korean"),
instruct: str = Form("A calm, professional Korean male voice"),
):
"""음성 디자인으로 생성 (참조 음성 없이)"""
m = get_model()
wavs, sr = m.generate_voice_design(text=text, instruct=instruct, language=language)
buf = io.BytesIO()
sf.write(buf, wavs[0], sr, format="WAV")
buf.seek(0)
return StreamingResponse(buf, media_type="audio/wav",
headers={"Content-Disposition": "attachment; filename=tts_output.wav"})
@app.post("/api/tts/profiles")
async def create_profile(
name: str = Form(...),
ref_audio: UploadFile = File(...),
ref_text: str = Form(""),
):
"""음성 프로필 등록: 참조 음성으로 보이스 프로필 생성 후 저장"""
m = get_model()
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
content = await ref_audio.read()
tmp.write(content)
tmp_path = tmp.name
try:
if not ref_audio.filename.endswith(".wav"):
wav_path = tmp_path + "_converted.wav"
os.system(f'ffmpeg -i "{tmp_path}" -ar 16000 -ac 1 -y "{wav_path}" 2>/dev/null')
os.unlink(tmp_path)
tmp_path = wav_path
# 프로필 생성
kwargs = {"ref_audio": tmp_path}
if ref_text and ref_text.strip():
kwargs["ref_text"] = ref_text
prompt = m.create_voice_clone_prompt(**kwargs)
else:
kwargs["x_vector_only_mode"] = True
prompt = m.create_voice_clone_prompt(**kwargs)
# 저장
profile_id = name.replace(" ", "_").lower()
profile_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl")
meta_path = os.path.join(PROFILES_DIR, f"{profile_id}.json")
with open(profile_path, "wb") as f:
pickle.dump(prompt, f)
with open(meta_path, "w") as f:
json.dump({"id": profile_id, "name": name, "ref_text": ref_text}, f, ensure_ascii=False)
return {"id": profile_id, "name": name, "status": "created"}
finally:
if os.path.exists(tmp_path):
os.unlink(tmp_path)
@app.get("/api/tts/profiles") @app.get("/api/tts/profiles")
def list_profiles(): def list_profiles():
"""저장된 음성 프로필 목록"""
profiles = [] profiles = []
for f in os.listdir(PROFILES_DIR): for f in os.listdir(PROFILES_DIR):
if f.endswith(".json"): if f.endswith(".json"):
@@ -161,49 +96,119 @@ def list_profiles():
profiles.append(json.load(fh)) profiles.append(json.load(fh))
return profiles return profiles
@app.post("/api/tts/profiles")
async def create_profile(
    name: str = Form(...),
    ref_audio: UploadFile = File(...),
    ref_text: str = Form(""),
):
    """Register a voice profile from an uploaded reference recording.

    The upload is converted to 16 kHz mono WAV with ffmpeg unless its file name
    already ends in ``.wav``. A voice-clone prompt is then built from it, and the
    reference WAV, the pickled prompt and a JSON metadata file are written to
    PROFILES_DIR. The prompt is also placed in ``prompt_cache``.

    Args:
        name: Display name; the profile id is derived from it.
        ref_audio: Reference recording of the voice.
        ref_text: Transcript of the recording. When empty, the prompt is built
            with ``x_vector_only_mode=True``.

    Returns:
        ``{"id": ..., "name": ..., "status": "created"}``.

    Raises:
        HTTPException: 400 when the name yields an empty id or the audio cannot
            be converted to WAV.
    """
    import re
    import shutil
    import subprocess
    from fastapi import HTTPException

    # The id becomes part of three file names. Map every character other than
    # word characters and "-" to "_" so a crafted name cannot point outside
    # PROFILES_DIR. Names made only of letters, digits, "_", "-" and spaces
    # produce the same id as before.
    profile_id = re.sub(r"[^\w\-]", "_", name).lower()
    if not profile_id:
        raise HTTPException(status_code=400, detail="Profile name must not be empty")

    m = get_model()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        content = await ref_audio.read()
        tmp.write(content)
        tmp_path = tmp.name

    try:
        # filename is optional on uploads; a missing name is treated as non-WAV.
        if not (ref_audio.filename or "").endswith(".wav"):
            wav_path = tmp_path + "_converted.wav"
            try:
                converted = subprocess.run(
                    ["ffmpeg", "-i", tmp_path, "-ar", "16000", "-ac", "1", "-y", wav_path],
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                ).returncode == 0
            except OSError:
                # ffmpeg is missing or could not be started.
                converted = False
            os.unlink(tmp_path)
            # From here on the finally block cleans up the converted file.
            tmp_path = wav_path
            if not converted or not os.path.exists(wav_path):
                raise HTTPException(
                    status_code=400,
                    detail="Could not convert reference audio to WAV",
                )

        kwargs = {"ref_audio": tmp_path}
        if ref_text and ref_text.strip():
            kwargs["ref_text"] = ref_text
        else:
            kwargs["x_vector_only_mode"] = True

        with torch.inference_mode():
            prompt = m.create_voice_clone_prompt(**kwargs)

        # Persist wav, pkl and json side by side under the profile id.
        shutil.copy2(tmp_path, os.path.join(PROFILES_DIR, f"{profile_id}.wav"))
        with open(os.path.join(PROFILES_DIR, f"{profile_id}.pkl"), "wb") as f:
            pickle.dump(prompt, f)
        with open(os.path.join(PROFILES_DIR, f"{profile_id}.json"), "w") as f:
            json.dump({"id": profile_id, "name": name, "ref_text": ref_text}, f, ensure_ascii=False)

        # Make the new prompt available without touching disk again.
        prompt_cache[profile_id] = prompt

        return {"id": profile_id, "name": name, "status": "created"}
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
@app.delete("/api/tts/profiles/{profile_id}") @app.delete("/api/tts/profiles/{profile_id}")
def delete_profile(profile_id: str): def delete_profile(profile_id: str):
"""음성 프로필 삭제""" for ext in [".pkl", ".json", ".wav"]:
pkl = os.path.join(PROFILES_DIR, f"{profile_id}.pkl") p = os.path.join(PROFILES_DIR, f"{profile_id}{ext}")
meta = os.path.join(PROFILES_DIR, f"{profile_id}.json") if os.path.exists(p):
if os.path.exists(pkl): os.unlink(pkl) os.unlink(p)
if os.path.exists(meta): os.unlink(meta) prompt_cache.pop(profile_id, None)
return {"status": "deleted"} return {"status": "deleted"}
@app.post("/api/tts/generate")
async def generate_from_profile( @app.post("/api/tts/speak")
async def speak(
text: str = Form(...), text: str = Form(...),
profile_id: str = Form(...), profile_id: str = Form(...),
language: str = Form("korean"), language: str = Form("Korean"),
): ):
"""저장된 음성 프로필로 TTS 생성""" """한 문장 TTS — 캐시된 프롬프트 사용, 바로 wav 반환"""
m = get_model() m = get_model()
profile_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl") prompt = get_prompt(profile_id)
if not os.path.exists(profile_path): if prompt is None:
return {"error": f"Profile '{profile_id}' not found"}, 404 # 프롬프트가 없으면 ref_audio로 직접
meta_path = os.path.join(PROFILES_DIR, f"{profile_id}.json")
ref_audio_path = os.path.join(PROFILES_DIR, f"{profile_id}.wav")
if not os.path.exists(ref_audio_path):
return {"error": "Profile not found"}, 404
with open(profile_path, "rb") as f: with open(meta_path) as f:
prompt = pickle.load(f) meta = json.load(f)
print(f"Generating with profile '{profile_id}', text='{text[:50]}...', language={language}") kwargs = {"text": text, "language": language, "ref_audio": ref_audio_path}
wavs, sr = m.generate_voice_clone( if meta.get("ref_text"):
text=text, kwargs["ref_text"] = meta["ref_text"]
language=language, else:
voice_clone_prompt=prompt, kwargs["x_vector_only_mode"] = True
)
print(f"Generated: wavs={len(wavs)}, samples={len(wavs[0]) if len(wavs) > 0 else 0}, sr={sr}")
if len(wavs) == 0 or len(wavs[0]) == 0: start = time.perf_counter()
return {"error": "Empty audio generated"}, 500 with torch.inference_mode():
wavs, sr = m.generate_voice_clone(**kwargs)
elapsed = time.perf_counter() - start
else:
start = time.perf_counter()
with torch.inference_mode():
wavs, sr = m.generate_voice_clone(
text=text, language=language, voice_clone_prompt=prompt,
)
elapsed = time.perf_counter() - start
audio_data = np.array(wavs[0], dtype=np.float32) audio_data = np.array(wavs[0], dtype=np.float32)
print(f"speak: {len(text)} chars → {len(audio_data)/sr:.1f}s audio in {elapsed:.1f}s")
buf = io.BytesIO() buf = io.BytesIO()
sf.write(buf, audio_data, sr, format="WAV") sf.write(buf, audio_data, sr, format="WAV")
buf.seek(0) buf.seek(0)
return StreamingResponse(buf, media_type="audio/wav")
from fastapi.responses import JSONResponse
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
import traceback
traceback.print_exc()
return JSONResponse(status_code=500, content={"error": str(exc)})
return StreamingResponse(buf, media_type="audio/wav",
headers={"Content-Disposition": "attachment; filename=tts_output.wav"})
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn