Fix TTS: switch to 1.7B with ref_audio, speakable text on all lines
- Use 1.7B model (0.6B had tensor mismatch with cached prompts) - Speak endpoint uses ref_audio directly (not cached pkl) as fallback - Cache voice clone prompts in memory on startup - Add SpeakableText component: 🔊 icon on each p and li element - Remove old TTSReader sequential approach - Add global exception handler to TTS server - Fix profile localStorage caching - inference_mode + bf16 optimization Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import AuthGuard from "@/components/auth-guard";
|
|||||||
import NavBar from "@/components/nav-bar";
|
import NavBar from "@/components/nav-bar";
|
||||||
import { useApi } from "@/lib/use-api";
|
import { useApi } from "@/lib/use-api";
|
||||||
import ReactMarkdown from "react-markdown";
|
import ReactMarkdown from "react-markdown";
|
||||||
|
import SpeakableText from "@/components/speakable-text";
|
||||||
|
|
||||||
interface Category {
|
interface Category {
|
||||||
ID: string;
|
ID: string;
|
||||||
@@ -314,10 +315,18 @@ export default function KnowledgeDetailPage() {
|
|||||||
h1: ({children}) => <h1 className="text-xl font-bold mt-6 mb-3">{children}</h1>,
|
h1: ({children}) => <h1 className="text-xl font-bold mt-6 mb-3">{children}</h1>,
|
||||||
h2: ({children}) => <h2 className="text-lg font-bold mt-5 mb-2">{children}</h2>,
|
h2: ({children}) => <h2 className="text-lg font-bold mt-5 mb-2">{children}</h2>,
|
||||||
h3: ({children}) => <h3 className="text-base font-bold mt-4 mb-2">{children}</h3>,
|
h3: ({children}) => <h3 className="text-base font-bold mt-4 mb-2">{children}</h3>,
|
||||||
p: ({children}) => <p className="mb-3">{children}</p>,
|
p: ({children, node}) => {
  // Build the text sent to TTS from the markdown AST node of this paragraph.
  // Text nested inside inline elements (strong, em, a, code, ...) is included,
  // otherwise formatted words would be missing from the spoken sentence.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const collectText = (n: any): string =>
    n?.type === 'text' ? (n.value ?? '') : (n?.children ?? []).map(collectText).join('');
  const txt = collectText(node);
  return <p className="mb-3"><SpeakableText text={txt}>{children}</SpeakableText></p>;
},
|
||||||
ul: ({children}) => <ul className="list-disc ml-5 mb-3 space-y-1">{children}</ul>,
|
ul: ({children}) => <ul className="list-disc ml-5 mb-3 space-y-1">{children}</ul>,
|
||||||
ol: ({children}) => <ol className="list-decimal ml-5 mb-3 space-y-1">{children}</ol>,
|
ol: ({children}) => <ol className="list-decimal ml-5 mb-3 space-y-1">{children}</ol>,
|
||||||
li: ({children}) => <li className="leading-relaxed">{children}</li>,
|
li: ({children, node}) => {
  // Build the text sent to TTS from the markdown AST node of this list item.
  // Inline formatting (strong, em, a, code, ...) is descended into so those
  // words are spoken too. Nested paragraphs and lists are skipped because the
  // p / li renderers of this same component map wrap them in their own
  // SpeakableText.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const collectText = (n: any): string => {
    if (n?.type === 'text') return n.value ?? '';
    if (n?.type === 'element' && ['p', 'ul', 'ol'].includes(n.tagName)) return '';
    return (n?.children ?? []).map(collectText).join('');
  };
  const txt = collectText(node);
  return <li className="leading-relaxed"><SpeakableText text={txt}>{children}</SpeakableText></li>;
},
|
||||||
strong: ({children}) => <strong className="font-bold">{children}</strong>,
|
strong: ({children}) => <strong className="font-bold">{children}</strong>,
|
||||||
blockquote: ({children}) => <blockquote className="border-l-2 border-[var(--color-primary)] pl-4 my-3 italic text-[var(--color-text-muted)]">{children}</blockquote>,
|
blockquote: ({children}) => <blockquote className="border-l-2 border-[var(--color-primary)] pl-4 my-3 italic text-[var(--color-text-muted)]">{children}</blockquote>,
|
||||||
code: ({children}) => <code className="bg-[var(--color-bg-hover)] px-1.5 py-0.5 rounded text-xs">{children}</code>,
|
code: ({children}) => <code className="bg-[var(--color-bg-hover)] px-1.5 py-0.5 rounded text-xs">{children}</code>,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import AuthGuard from "@/components/auth-guard";
|
|||||||
import NavBar from "@/components/nav-bar";
|
import NavBar from "@/components/nav-bar";
|
||||||
import { useApi } from "@/lib/use-api";
|
import { useApi } from "@/lib/use-api";
|
||||||
import ReactMarkdown from "react-markdown";
|
import ReactMarkdown from "react-markdown";
|
||||||
|
import SpeakableText from "@/components/speakable-text";
|
||||||
|
|
||||||
interface NoteDetail {
|
interface NoteDetail {
|
||||||
ID: string;
|
ID: string;
|
||||||
@@ -183,10 +184,18 @@ export default function NoteDetailPage() {
|
|||||||
h1: ({children}) => <h1 className="text-xl font-bold mt-6 mb-3">{children}</h1>,
|
h1: ({children}) => <h1 className="text-xl font-bold mt-6 mb-3">{children}</h1>,
|
||||||
h2: ({children}) => <h2 className="text-lg font-bold mt-5 mb-2">{children}</h2>,
|
h2: ({children}) => <h2 className="text-lg font-bold mt-5 mb-2">{children}</h2>,
|
||||||
h3: ({children}) => <h3 className="text-base font-bold mt-4 mb-2">{children}</h3>,
|
h3: ({children}) => <h3 className="text-base font-bold mt-4 mb-2">{children}</h3>,
|
||||||
p: ({children}) => <p className="mb-3">{children}</p>,
|
p: ({children, node}) => {
  // Build the text sent to TTS from the markdown AST node of this paragraph.
  // Text nested inside inline elements (strong, em, a, code, ...) is included,
  // otherwise formatted words would be missing from the spoken sentence.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const collectText = (n: any): string =>
    n?.type === 'text' ? (n.value ?? '') : (n?.children ?? []).map(collectText).join('');
  const txt = collectText(node);
  return <p className="mb-3"><SpeakableText text={txt}>{children}</SpeakableText></p>;
},
|
||||||
ul: ({children}) => <ul className="list-disc ml-5 mb-3 space-y-1">{children}</ul>,
|
ul: ({children}) => <ul className="list-disc ml-5 mb-3 space-y-1">{children}</ul>,
|
||||||
ol: ({children}) => <ol className="list-decimal ml-5 mb-3 space-y-1">{children}</ol>,
|
ol: ({children}) => <ol className="list-decimal ml-5 mb-3 space-y-1">{children}</ol>,
|
||||||
li: ({children}) => <li className="leading-relaxed">{children}</li>,
|
li: ({children, node}) => {
  // Build the text sent to TTS from the markdown AST node of this list item.
  // Inline formatting (strong, em, a, code, ...) is descended into so those
  // words are spoken too. Nested paragraphs and lists are skipped because the
  // p / li renderers of this same component map wrap them in their own
  // SpeakableText.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const collectText = (n: any): string => {
    if (n?.type === 'text') return n.value ?? '';
    if (n?.type === 'element' && ['p', 'ul', 'ol'].includes(n.tagName)) return '';
    return (n?.children ?? []).map(collectText).join('');
  };
  const txt = collectText(node);
  return <li className="leading-relaxed"><SpeakableText text={txt}>{children}</SpeakableText></li>;
},
|
||||||
strong: ({children}) => <strong className="font-bold">{children}</strong>,
|
strong: ({children}) => <strong className="font-bold">{children}</strong>,
|
||||||
blockquote: ({children}) => <blockquote className="border-l-2 border-[var(--color-primary)] pl-4 my-3 italic text-[var(--color-text-muted)]">{children}</blockquote>,
|
blockquote: ({children}) => <blockquote className="border-l-2 border-[var(--color-primary)] pl-4 my-3 italic text-[var(--color-text-muted)]">{children}</blockquote>,
|
||||||
}}
|
}}
|
||||||
|
|||||||
@@ -45,8 +45,21 @@ export default function TTSPage() {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const fetchProfiles = () => {
|
const fetchProfiles = () => {
|
||||||
|
// 캐시 먼저
|
||||||
|
const cached = localStorage.getItem("tts_profiles");
|
||||||
|
if (cached) {
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(cached);
|
||||||
|
setProfiles(data);
|
||||||
|
if (data.length > 0 && !selectedProfile) setSelectedProfile(data[0].id);
|
||||||
|
} catch {}
|
||||||
|
}
|
||||||
fetch("/api/tts/profiles").then(r => r.json())
|
fetch("/api/tts/profiles").then(r => r.json())
|
||||||
.then(setProfiles).catch(() => {});
|
.then(data => {
|
||||||
|
setProfiles(data);
|
||||||
|
localStorage.setItem("tts_profiles", JSON.stringify(data));
|
||||||
|
if (data.length > 0 && !selectedProfile) setSelectedProfile(data[0].id);
|
||||||
|
}).catch(() => {});
|
||||||
};
|
};
|
||||||
|
|
||||||
const startRecording = async () => {
|
const startRecording = async () => {
|
||||||
@@ -97,6 +110,7 @@ export default function TTSPage() {
|
|||||||
setRecordedUrl(null);
|
setRecordedUrl(null);
|
||||||
setUploadedFile(null);
|
setUploadedFile(null);
|
||||||
fetchProfiles();
|
fetchProfiles();
|
||||||
|
localStorage.removeItem("tts_profiles"); // 캐시 강제 갱신
|
||||||
setSelectedProfile(result.id);
|
setSelectedProfile(result.id);
|
||||||
setTab("generate");
|
setTab("generate");
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
@@ -125,9 +139,10 @@ export default function TTSPage() {
|
|||||||
fd.append("text", text);
|
fd.append("text", text);
|
||||||
fd.append("profile_id", selectedProfile);
|
fd.append("profile_id", selectedProfile);
|
||||||
fd.append("language", language);
|
fd.append("language", language);
|
||||||
const res = await fetch("/api/tts/generate", { method: "POST", body: fd });
|
const res = await fetch("/api/tts/speak", { method: "POST", body: fd });
|
||||||
if (!res.ok) throw new Error(`HTTP ${res.status}`);
|
if (!res.ok) throw new Error(`HTTP ${res.status}`);
|
||||||
const blob = await res.blob();
|
const blob = await res.blob();
|
||||||
|
if (blob.size < 100) throw new Error("Empty audio");
|
||||||
setOutputUrl(URL.createObjectURL(blob));
|
setOutputUrl(URL.createObjectURL(blob));
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
setError("생성 실패: " + (err instanceof Error ? err.message : ""));
|
setError("생성 실패: " + (err instanceof Error ? err.message : ""));
|
||||||
|
|||||||
85
sundol-frontend/src/components/speakable-text.tsx
Normal file
85
sundol-frontend/src/components/speakable-text.tsx
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useState, useRef, useEffect } from "react";
|
||||||
|
|
||||||
|
interface SpeakableProps {
  /** Rendered content the speaker button is appended to. */
  children: React.ReactNode;
  /** Plain text posted to the TTS endpoint when the button is clicked. */
  text: string;
}

/** Text shorter than this (after trimming) gets no speaker button. */
const MIN_SPEAKABLE_LENGTH = 5;
/** Responses smaller than this many bytes are treated as empty audio. */
const MIN_AUDIO_BYTES = 200;

// Id of the first profile found in the "tts_profiles" localStorage entry,
// shared by every SpeakableText instance. Only a successful lookup is
// memoized: if no profile is cached yet, later mounts look again, so the
// button appears once another page has stored the profile list.
let cachedProfileId: string | null = null;

/** Returns the memoized profile id, reading localStorage when none is known yet. */
function readCachedProfileId(): string | null {
  if (cachedProfileId) return cachedProfileId;
  try {
    const parsed: unknown = JSON.parse(localStorage.getItem("tts_profiles") ?? "[]");
    if (Array.isArray(parsed) && parsed.length > 0) {
      const id: unknown = parsed[0]?.id;
      if (typeof id === "string" && id.length > 0) cachedProfileId = id;
    }
  } catch {
    // Corrupt cache entry or storage unavailable: behave as "no profile".
  }
  return cachedProfileId;
}

/** Stops a clip, detaches its handlers and frees the blob URL backing it. */
function disposeClip(audio: HTMLAudioElement | null, url: string | null): void {
  if (audio) {
    audio.onended = null;
    audio.onerror = null;
    audio.pause();
  }
  if (url) URL.revokeObjectURL(url);
}

/**
 * Wraps inline content and appends a small button that reads `text` aloud
 * through POST /api/tts/speak using the first cached voice profile.
 *
 * Renders only `children` when no profile is cached or the text is too short.
 * Clicking while audio is playing stops it.
 */
export default function SpeakableText({ children, text }: SpeakableProps) {
  const [playing, setPlaying] = useState(false);
  const [loading, setLoading] = useState(false);
  const [hasProfile, setHasProfile] = useState(false);
  const audioRef = useRef<HTMLAudioElement | null>(null);
  const urlRef = useRef<string | null>(null);
  const mountedRef = useRef(true);

  const speakable = text.trim();

  useEffect(() => {
    mountedRef.current = true;
    setHasProfile(readCachedProfileId() !== null);
    return () => {
      // Reading the refs at cleanup time is intentional: whatever clip exists
      // when the component unmounts must be stopped and its URL released.
      mountedRef.current = false;
      disposeClip(audioRef.current, urlRef.current);
      audioRef.current = null;
      urlRef.current = null;
    };
  }, []);

  /** Stops the current clip (if any) and releases its blob URL. */
  const stopPlayback = () => {
    disposeClip(audioRef.current, urlRef.current);
    audioRef.current = null;
    urlRef.current = null;
  };

  const handleSpeak = async (e: React.MouseEvent) => {
    e.preventDefault();
    e.stopPropagation();

    if (playing) {
      stopPlayback();
      setPlaying(false);
      return;
    }

    const profileId = readCachedProfileId();
    if (!profileId || speakable.length < MIN_SPEAKABLE_LENGTH) return;

    setLoading(true);
    try {
      const fd = new FormData();
      fd.append("text", speakable);
      fd.append("profile_id", profileId);
      fd.append("language", "Korean");
      const res = await fetch("/api/tts/speak", { method: "POST", body: fd });
      if (!res.ok) return;
      const blob = await res.blob();
      // Skip empty responses, and do not start audio for an unmounted component.
      if (blob.size < MIN_AUDIO_BYTES || !mountedRef.current) return;

      stopPlayback();
      const url = URL.createObjectURL(blob);
      const audio = new Audio(url);
      audioRef.current = audio;
      urlRef.current = url;

      const finish = () => {
        // Only tear down if this clip is still the active one.
        if (audioRef.current === audio) stopPlayback();
        if (mountedRef.current) setPlaying(false);
      };
      audio.onended = finish;
      audio.onerror = finish;

      // play() returns a promise that rejects when playback cannot start;
      // awaiting it keeps the "playing" state from getting stuck.
      await audio.play();
      if (mountedRef.current && audioRef.current === audio) setPlaying(true);
    } catch {
      stopPlayback();
      if (mountedRef.current) setPlaying(false);
    } finally {
      if (mountedRef.current) setLoading(false);
    }
  };

  if (!hasProfile || speakable.length < MIN_SPEAKABLE_LENGTH) return <>{children}</>;

  const label = playing ? "중지" : "읽어주기";

  return (
    <>
      {children}
      <button
        type="button"
        onClick={handleSpeak}
        disabled={loading}
        className="inline-flex items-center ml-1 text-[var(--color-text-muted)] hover:text-[var(--color-primary)] disabled:opacity-30 align-middle"
        title={label}
        aria-label={label}
        style={{ fontSize: "0.85em", verticalAlign: "middle", cursor: "pointer" }}
      >
        {loading ? "⏳" : playing ? "⏹" : "🔊"}
      </button>
    </>
  );
}
|
||||||
159
sundol-frontend/src/components/tts-reader.tsx
Normal file
159
sundol-frontend/src/components/tts-reader.tsx
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useState, useEffect, useRef } from "react";
|
||||||
|
|
||||||
|
interface TTSReaderProps {
  /** Markdown source to read aloud. */
  text: string;
}

interface VoiceProfile {
  id: string;
  name: string;
}

/** Runtime check for one entry of the profile list (cached or fetched JSON). */
function isVoiceProfile(value: unknown): value is VoiceProfile {
  if (typeof value !== "object" || value === null) return false;
  const candidate = value as { id?: unknown; name?: unknown };
  return typeof candidate.id === "string" && typeof candidate.name === "string";
}

/**
 * Strips markdown syntax (headings, bold markers, list/quote prefixes, rules,
 * link targets) and returns the remaining lines of at least 10 characters.
 */
function toSentences(md: string): string[] {
  return md
    .replace(/^#+\s+.*$/gm, "")
    .replace(/\*\*/g, "")
    .replace(/^[-*]\s+/gm, "")
    .replace(/^>\s+/gm, "")
    .replace(/---+/g, "")
    .replace(/\[([^\]]+)\]\([^)]+\)/g, "$1")
    .split("\n")
    .map(s => s.trim())
    .filter(s => s.length >= 10);
}

/**
 * Reads a markdown document aloud line by line with a selectable voice
 * profile. Each line is synthesised through POST /api/tts/speak and queued;
 * playback starts as soon as the first clip is ready while later lines are
 * still being generated.
 *
 * Renders nothing while no voice profile is known.
 */
export default function TTSReader({ text }: TTSReaderProps) {
  const [profiles, setProfiles] = useState<VoiceProfile[]>([]);
  const [selectedProfile, setSelectedProfile] = useState("");
  const [generating, setGenerating] = useState(false);
  const [playing, setPlaying] = useState(false);
  const [progress, setProgress] = useState("");
  const audioRef = useRef<HTMLAudioElement | null>(null);
  const audioUrlsRef = useRef<string[]>([]);
  // Incremented whenever a generate/replay run starts or is stopped. Async
  // work compares its own number with this to learn it has been superseded.
  const runRef = useRef(0);

  /** Frees every blob URL of the current queue and empties it. */
  const releaseUrls = () => {
    audioUrlsRef.current.forEach(u => URL.revokeObjectURL(u));
    audioUrlsRef.current = [];
  };

  useEffect(() => {
    let cancelled = false;

    const applyProfiles = (data: unknown): boolean => {
      if (!Array.isArray(data)) return false;
      const list = data.filter(isVoiceProfile);
      setProfiles(list);
      // Functional update: keep the user's current choice when it still
      // exists instead of overwriting it from a stale closure value.
      setSelectedProfile(cur =>
        cur && list.some(p => p.id === cur) ? cur : (list[0]?.id ?? cur),
      );
      return true;
    };

    // Show the localStorage copy first so the control appears immediately.
    try {
      const cached = localStorage.getItem("tts_profiles");
      if (cached) applyProfiles(JSON.parse(cached));
    } catch {
      // Corrupt cache entry: ignore and rely on the fetch below.
    }

    // Refresh in the background without blocking rendering.
    fetch("/api/tts/profiles")
      .then(r => r.json())
      .then((data: unknown) => {
        if (cancelled) return;
        if (applyProfiles(data)) {
          localStorage.setItem("tts_profiles", JSON.stringify(data));
        }
      })
      .catch(() => {});

    return () => {
      // Reading the refs at cleanup time is intentional: stop whatever is
      // active at unmount and release the blob URLs created so far.
      cancelled = true;
      runRef.current += 1;
      audioRef.current?.pause();
      audioUrlsRef.current.forEach(u => URL.revokeObjectURL(u));
      audioUrlsRef.current = [];
    };
  }, []);

  /** Synthesises one chunk; resolves to a blob URL, or null on failure/empty audio. */
  const speak = async (chunk: string): Promise<string | null> => {
    const fd = new FormData();
    fd.append("text", chunk);
    fd.append("profile_id", selectedProfile);
    fd.append("language", "Korean");
    const res = await fetch("/api/tts/speak", { method: "POST", body: fd });
    if (!res.ok) return null;
    const blob = await res.blob();
    return blob.size > 100 ? URL.createObjectURL(blob) : null;
  };

  const handleGenerate = async () => {
    if (!selectedProfile || !text.trim()) return;

    const run = ++runRef.current;
    const isCurrent = () => runRef.current === run;

    audioRef.current?.pause();
    releaseUrls();
    const urls: string[] = [];
    audioUrlsRef.current = urls;

    setGenerating(true);
    setPlaying(true);

    const sentences = toSentences(text);
    let nextToPlay = 0;
    let audioActive = false;
    let generationDone = false;

    const playNext = () => {
      if (!isCurrent()) return;
      const src = urls[nextToPlay];
      if (src === undefined) {
        // Queue drained. If nothing more is coming, playback is over.
        audioActive = false;
        if (generationDone) setPlaying(false);
        return;
      }
      nextToPlay += 1;
      audioActive = true;
      const clip = new Audio(src);
      audioRef.current = clip;
      // A failing clip can report both an error event and a rejected play();
      // advance the queue only once per clip.
      let advanced = false;
      const advance = () => {
        if (advanced) return;
        advanced = true;
        playNext();
      };
      clip.onended = advance;
      clip.onerror = advance;
      clip.play().catch(advance);
    };

    try {
      for (const [i, sentence] of sentences.entries()) {
        if (!isCurrent()) break;
        setProgress(`${i + 1}/${sentences.length}`);
        const url = await speak(sentence);
        if (!url) continue;
        if (!isCurrent()) {
          // Stopped or restarted while this request was in flight.
          URL.revokeObjectURL(url);
          break;
        }
        urls.push(url);
        if (!audioActive) playNext();
      }
    } catch {
      // Network failure: keep the clips generated so far.
    } finally {
      generationDone = true;
      if (isCurrent()) {
        setGenerating(false);
        setProgress("");
        if (!audioActive) setPlaying(false);
      }
    }
  };

  const handleStop = () => {
    runRef.current += 1;
    audioRef.current?.pause();
    setPlaying(false);
    setGenerating(false);
    setProgress("");
  };

  const handleReplay = () => {
    const urls = audioUrlsRef.current;
    if (urls.length === 0) return;

    const run = ++runRef.current;
    setPlaying(true);
    let idx = 0;

    const play = () => {
      if (runRef.current !== run) return;
      const src = urls[idx];
      if (src === undefined) {
        setPlaying(false);
        return;
      }
      idx += 1;
      const clip = new Audio(src);
      audioRef.current = clip;
      let advanced = false;
      const advance = () => {
        if (advanced) return;
        advanced = true;
        play();
      };
      clip.onended = advance;
      clip.onerror = advance;
      clip.play().catch(advance);
    };
    play();
  };

  if (profiles.length === 0) return null;

  return (
    <div className="flex items-center gap-2 flex-wrap">
      <select value={selectedProfile} onChange={e => setSelectedProfile(e.target.value)}
        className="text-xs px-2 py-1 rounded bg-[var(--color-bg-hover)] border border-[var(--color-border)]">
        {profiles.map(p => <option key={p.id} value={p.id}>{p.name}</option>)}
      </select>

      {playing || generating ? (
        <button type="button" onClick={handleStop}
          className="text-xs px-3 py-1 bg-red-500/20 text-red-400 rounded hover:bg-red-500/30">
          {progress || "중지"}
        </button>
      ) : (
        <button type="button" onClick={handleGenerate} disabled={!selectedProfile}
          className="text-xs px-3 py-1 bg-[var(--color-primary)]/20 text-[var(--color-primary)] rounded hover:bg-[var(--color-primary)]/30 disabled:opacity-40">
          읽어주기
        </button>
      )}

      {audioUrlsRef.current.length > 0 && !playing && !generating && (
        <button type="button" onClick={handleReplay}
          className="text-xs px-3 py-1 bg-[var(--color-bg-hover)] border border-[var(--color-border)] rounded">
          다시 재생
        </button>
      )}
    </div>
  );
}
|
||||||
309
tts-server.py
309
tts-server.py
@@ -1,159 +1,94 @@
|
|||||||
"""
|
"""
|
||||||
Qwen3-TTS Voice Clone API Server
|
Qwen3-TTS Voice Clone API Server (최적화 버전)
|
||||||
별도 프로세스로 실행 (GPU 메모리 관리를 위해)
|
- 1.7B 모델 사용 (0.6B는 캐시된 프롬프트와 텐서 불일치 발생)
|
||||||
|
- 모델 1회 로드, voice clone prompt 캐시
|
||||||
|
- inference_mode, bf16
|
||||||
|
- 문장 단위 분할
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
import base64
|
|
||||||
import tempfile
|
|
||||||
import torch
|
|
||||||
import soundfile as sf
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import FastAPI, UploadFile, File, Form
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
from fastapi import FastAPI, UploadFile, File, Form
|
||||||
|
from fastapi.responses import StreamingResponse, FileResponse
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
||||||
|
|
||||||
model = None
|
|
||||||
PROFILES_DIR = os.path.join(os.path.dirname(__file__), "voice-profiles")
|
PROFILES_DIR = os.path.join(os.path.dirname(__file__), "voice-profiles")
|
||||||
os.makedirs(PROFILES_DIR, exist_ok=True)
|
os.makedirs(PROFILES_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
MODEL_NAME = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
|
||||||
|
model = None
|
||||||
|
prompt_cache = {} # profile_id → voice_clone_prompt
|
||||||
|
|
||||||
|
|
||||||
def get_model():
|
def get_model():
|
||||||
global model
|
global model
|
||||||
if model is None:
|
if model is None:
|
||||||
from qwen_tts import Qwen3TTSModel
|
from qwen_tts import Qwen3TTSModel
|
||||||
print("Loading Qwen3-TTS model...")
|
print(f"Loading {MODEL_NAME}...")
|
||||||
|
torch.set_grad_enabled(False)
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||||
model = Qwen3TTSModel.from_pretrained(
|
model = Qwen3TTSModel.from_pretrained(
|
||||||
"Qwen/Qwen3-TTS-12Hz-1.7B-Base",
|
MODEL_NAME, device_map="cuda:0", dtype=dtype,
|
||||||
device_map="cuda:0",
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
)
|
)
|
||||||
|
# 프로필 프롬프트 캐시 로드
|
||||||
|
load_all_prompts()
|
||||||
print("Model loaded!")
|
print("Model loaded!")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_all_prompts():
    """Load every saved voice-clone prompt in PROFILES_DIR into prompt_cache.

    Each ``<profile_id>.pkl`` file is unpickled and stored under its profile
    id. A file that fails to load is reported and skipped so one bad profile
    does not prevent the others from being cached.
    """
    suffix = ".pkl"
    for filename in os.listdir(PROFILES_DIR):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing extension; str.replace would also remove
        # ".pkl" occurring inside the profile id itself.
        pid = filename[: -len(suffix)]
        try:
            with open(os.path.join(PROFILES_DIR, filename), "rb") as fh:
                # NOTE(security): pickle.load can execute code embedded in the
                # file, so PROFILES_DIR must only contain files this server wrote.
                prompt_cache[pid] = pickle.load(fh)
            print(f" Cached prompt: {pid}")
        except Exception as e:
            print(f" Failed to cache {pid}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt(profile_id: str):
    """Return the voice-clone prompt for ``profile_id``, or None if unavailable.

    Looks in prompt_cache first, then falls back to ``<profile_id>.pkl`` in
    PROFILES_DIR and caches the loaded prompt. None is returned when the id is
    not a plain file name, the file does not exist, or it cannot be unpickled.
    """
    if profile_id in prompt_cache:
        return prompt_cache[profile_id]

    # profile_id is used to build a file path, so anything containing a path
    # component (e.g. "../x") must not reach os.path.join.
    if (
        not profile_id
        or profile_id in (".", "..")
        or os.path.basename(profile_id) != profile_id
    ):
        return None

    pkl_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl")
    if not os.path.exists(pkl_path):
        return None

    try:
        with open(pkl_path, "rb") as f:
            # NOTE(security): pickle.load can execute code embedded in the
            # file, so PROFILES_DIR must only contain files this server wrote.
            prompt = pickle.load(f)
    except Exception as e:
        # Same policy as the startup cache loader: report and carry on without
        # the prompt instead of failing the request.
        print(f" Failed to load prompt {profile_id}: {e}")
        return None

    prompt_cache[profile_id] = prompt
    return prompt
|
||||||
|
|
||||||
|
|
||||||
|
# === API ===
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
@app.get("/api/tts/health")
|
@app.get("/api/tts/health")
|
||||||
def health():
|
def health():
|
||||||
return {"status": "ok", "model_loaded": model is not None}
|
return {"status": "ok", "model": MODEL_NAME, "model_loaded": model is not None}
|
||||||
|
|
||||||
@app.post("/api/tts/clone")
|
|
||||||
async def voice_clone(
|
|
||||||
text: str = Form(...),
|
|
||||||
language: str = Form("korean"),
|
|
||||||
ref_audio: UploadFile = File(...),
|
|
||||||
ref_text: str = Form(""),
|
|
||||||
):
|
|
||||||
"""참조 음성으로 보이스 클로닝"""
|
|
||||||
m = get_model()
|
|
||||||
|
|
||||||
# 참조 음성 저장
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
|
||||||
content = await ref_audio.read()
|
|
||||||
tmp.write(content)
|
|
||||||
tmp_path = tmp.name
|
|
||||||
|
|
||||||
try:
|
|
||||||
# wav 변환 (필요 시)
|
|
||||||
if not ref_audio.filename.endswith(".wav"):
|
|
||||||
wav_path = tmp_path + "_converted.wav"
|
|
||||||
os.system(f'ffmpeg -i "{tmp_path}" -ar 16000 -ac 1 -y "{wav_path}" 2>/dev/null')
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
tmp_path = wav_path
|
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
"text": text,
|
|
||||||
"language": language,
|
|
||||||
"ref_audio": tmp_path,
|
|
||||||
}
|
|
||||||
if ref_text and ref_text.strip():
|
|
||||||
kwargs["ref_text"] = ref_text
|
|
||||||
else:
|
|
||||||
kwargs["x_vector_only_mode"] = True
|
|
||||||
|
|
||||||
wavs, sr = m.generate_voice_clone(**kwargs)
|
|
||||||
print(f"Clone generated: wavs={len(wavs)}, samples={len(wavs[0]) if len(wavs) > 0 else 0}, sr={sr}")
|
|
||||||
|
|
||||||
audio_data = np.array(wavs[0], dtype=np.float32)
|
|
||||||
buf = io.BytesIO()
|
|
||||||
sf.write(buf, audio_data, sr, format="WAV")
|
|
||||||
buf.seek(0)
|
|
||||||
|
|
||||||
return StreamingResponse(buf, media_type="audio/wav",
|
|
||||||
headers={"Content-Disposition": "attachment; filename=tts_output.wav"})
|
|
||||||
finally:
|
|
||||||
if os.path.exists(tmp_path):
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
|
|
||||||
@app.post("/api/tts/design")
|
|
||||||
async def voice_design(
|
|
||||||
text: str = Form(...),
|
|
||||||
language: str = Form("korean"),
|
|
||||||
instruct: str = Form("A calm, professional Korean male voice"),
|
|
||||||
):
|
|
||||||
"""음성 디자인으로 생성 (참조 음성 없이)"""
|
|
||||||
m = get_model()
|
|
||||||
wavs, sr = m.generate_voice_design(text=text, instruct=instruct, language=language)
|
|
||||||
|
|
||||||
buf = io.BytesIO()
|
|
||||||
sf.write(buf, wavs[0], sr, format="WAV")
|
|
||||||
buf.seek(0)
|
|
||||||
|
|
||||||
return StreamingResponse(buf, media_type="audio/wav",
|
|
||||||
headers={"Content-Disposition": "attachment; filename=tts_output.wav"})
|
|
||||||
|
|
||||||
@app.post("/api/tts/profiles")
|
|
||||||
async def create_profile(
|
|
||||||
name: str = Form(...),
|
|
||||||
ref_audio: UploadFile = File(...),
|
|
||||||
ref_text: str = Form(""),
|
|
||||||
):
|
|
||||||
"""음성 프로필 등록: 참조 음성으로 보이스 프로필 생성 후 저장"""
|
|
||||||
m = get_model()
|
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
|
||||||
content = await ref_audio.read()
|
|
||||||
tmp.write(content)
|
|
||||||
tmp_path = tmp.name
|
|
||||||
|
|
||||||
try:
|
|
||||||
if not ref_audio.filename.endswith(".wav"):
|
|
||||||
wav_path = tmp_path + "_converted.wav"
|
|
||||||
os.system(f'ffmpeg -i "{tmp_path}" -ar 16000 -ac 1 -y "{wav_path}" 2>/dev/null')
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
tmp_path = wav_path
|
|
||||||
|
|
||||||
# 프로필 생성
|
|
||||||
kwargs = {"ref_audio": tmp_path}
|
|
||||||
if ref_text and ref_text.strip():
|
|
||||||
kwargs["ref_text"] = ref_text
|
|
||||||
prompt = m.create_voice_clone_prompt(**kwargs)
|
|
||||||
else:
|
|
||||||
kwargs["x_vector_only_mode"] = True
|
|
||||||
prompt = m.create_voice_clone_prompt(**kwargs)
|
|
||||||
|
|
||||||
# 저장
|
|
||||||
profile_id = name.replace(" ", "_").lower()
|
|
||||||
profile_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl")
|
|
||||||
meta_path = os.path.join(PROFILES_DIR, f"{profile_id}.json")
|
|
||||||
|
|
||||||
with open(profile_path, "wb") as f:
|
|
||||||
pickle.dump(prompt, f)
|
|
||||||
with open(meta_path, "w") as f:
|
|
||||||
json.dump({"id": profile_id, "name": name, "ref_text": ref_text}, f, ensure_ascii=False)
|
|
||||||
|
|
||||||
return {"id": profile_id, "name": name, "status": "created"}
|
|
||||||
finally:
|
|
||||||
if os.path.exists(tmp_path):
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
|
|
||||||
@app.get("/api/tts/profiles")
|
@app.get("/api/tts/profiles")
|
||||||
def list_profiles():
|
def list_profiles():
|
||||||
"""저장된 음성 프로필 목록"""
|
|
||||||
profiles = []
|
profiles = []
|
||||||
for f in os.listdir(PROFILES_DIR):
|
for f in os.listdir(PROFILES_DIR):
|
||||||
if f.endswith(".json"):
|
if f.endswith(".json"):
|
||||||
@@ -161,49 +96,119 @@ def list_profiles():
|
|||||||
profiles.append(json.load(fh))
|
profiles.append(json.load(fh))
|
||||||
return profiles
|
return profiles
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/tts/profiles")
async def create_profile(
    name: str = Form(...),
    ref_audio: UploadFile = File(...),
    ref_text: str = Form(""),
):
    """Register a voice profile from an uploaded reference recording.

    The upload is written to a temporary file and, unless its file name ends
    in ".wav", converted with ffmpeg to 16 kHz mono WAV. A voice-clone prompt
    is built from it (using ``ref_text`` when given, x-vector-only mode
    otherwise). The reference WAV, the pickled prompt and a JSON metadata file
    are stored in PROFILES_DIR under the derived profile id, and the prompt is
    added to prompt_cache.

    Returns ``{"id", "name", "status": "created"}``.
    Raises HTTPException(400) when the name does not yield a usable file name
    or the audio cannot be converted.
    """
    import shutil
    import subprocess

    from fastapi import HTTPException

    profile_id = name.replace(" ", "_").lower()
    # The id becomes a file name inside PROFILES_DIR; reject anything that
    # would resolve to another directory before doing any expensive work.
    if (
        not profile_id
        or profile_id in (".", "..")
        or os.path.basename(profile_id) != profile_id
    ):
        raise HTTPException(status_code=400, detail="Invalid profile name")

    m = get_model()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        content = await ref_audio.read()
        tmp.write(content)
        tmp_path = tmp.name

    try:
        # filename may be None for uploads without one; treat that as non-WAV.
        if not (ref_audio.filename or "").endswith(".wav"):
            wav_path = tmp_path + "_converted.wav"
            # Argument list instead of a shell string: no quoting pitfalls, and
            # the exit status is available to detect a failed conversion.
            result = subprocess.run(
                ["ffmpeg", "-i", tmp_path, "-ar", "16000", "-ac", "1", "-y", wav_path],
                stderr=subprocess.DEVNULL,
            )
            os.unlink(tmp_path)
            tmp_path = wav_path
            if result.returncode != 0 or not os.path.exists(wav_path):
                raise HTTPException(
                    status_code=400, detail="Could not convert reference audio"
                )

        kwargs = {"ref_audio": tmp_path}
        if ref_text and ref_text.strip():
            kwargs["ref_text"] = ref_text
        else:
            kwargs["x_vector_only_mode"] = True

        with torch.inference_mode():
            prompt = m.create_voice_clone_prompt(**kwargs)

        # Persist the reference wav, the pickled prompt and the metadata.
        shutil.copy2(tmp_path, os.path.join(PROFILES_DIR, f"{profile_id}.wav"))
        with open(os.path.join(PROFILES_DIR, f"{profile_id}.pkl"), "wb") as f:
            pickle.dump(prompt, f)
        with open(os.path.join(PROFILES_DIR, f"{profile_id}.json"), "w") as f:
            json.dump({"id": profile_id, "name": name, "ref_text": ref_text}, f, ensure_ascii=False)

        # Make the new profile usable immediately, without re-reading the pkl.
        prompt_cache[profile_id] = prompt

        return {"id": profile_id, "name": name, "status": "created"}
    finally:
        # Covers both the original upload and the converted file, whichever
        # tmp_path currently points at.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/api/tts/profiles/{profile_id}")
|
@app.delete("/api/tts/profiles/{profile_id}")
|
||||||
def delete_profile(profile_id: str):
|
def delete_profile(profile_id: str):
|
||||||
"""음성 프로필 삭제"""
|
for ext in [".pkl", ".json", ".wav"]:
|
||||||
pkl = os.path.join(PROFILES_DIR, f"{profile_id}.pkl")
|
p = os.path.join(PROFILES_DIR, f"{profile_id}{ext}")
|
||||||
meta = os.path.join(PROFILES_DIR, f"{profile_id}.json")
|
if os.path.exists(p):
|
||||||
if os.path.exists(pkl): os.unlink(pkl)
|
os.unlink(p)
|
||||||
if os.path.exists(meta): os.unlink(meta)
|
prompt_cache.pop(profile_id, None)
|
||||||
return {"status": "deleted"}
|
return {"status": "deleted"}
|
||||||
|
|
||||||
@app.post("/api/tts/generate")
|
|
||||||
async def generate_from_profile(
|
@app.post("/api/tts/speak")
|
||||||
|
async def speak(
|
||||||
text: str = Form(...),
|
text: str = Form(...),
|
||||||
profile_id: str = Form(...),
|
profile_id: str = Form(...),
|
||||||
language: str = Form("korean"),
|
language: str = Form("Korean"),
|
||||||
):
|
):
|
||||||
"""저장된 음성 프로필로 TTS 생성"""
|
"""한 문장 TTS — 캐시된 프롬프트 사용, 바로 wav 반환"""
|
||||||
m = get_model()
|
m = get_model()
|
||||||
|
|
||||||
profile_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl")
|
prompt = get_prompt(profile_id)
|
||||||
if not os.path.exists(profile_path):
|
if prompt is None:
|
||||||
return {"error": f"Profile '{profile_id}' not found"}, 404
|
# 프롬프트가 없으면 ref_audio로 직접
|
||||||
|
meta_path = os.path.join(PROFILES_DIR, f"{profile_id}.json")
|
||||||
|
ref_audio_path = os.path.join(PROFILES_DIR, f"{profile_id}.wav")
|
||||||
|
if not os.path.exists(ref_audio_path):
|
||||||
|
return {"error": "Profile not found"}, 404
|
||||||
|
|
||||||
with open(profile_path, "rb") as f:
|
with open(meta_path) as f:
|
||||||
prompt = pickle.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
print(f"Generating with profile '{profile_id}', text='{text[:50]}...', language={language}")
|
kwargs = {"text": text, "language": language, "ref_audio": ref_audio_path}
|
||||||
wavs, sr = m.generate_voice_clone(
|
if meta.get("ref_text"):
|
||||||
text=text,
|
kwargs["ref_text"] = meta["ref_text"]
|
||||||
language=language,
|
else:
|
||||||
voice_clone_prompt=prompt,
|
kwargs["x_vector_only_mode"] = True
|
||||||
)
|
|
||||||
print(f"Generated: wavs={len(wavs)}, samples={len(wavs[0]) if len(wavs) > 0 else 0}, sr={sr}")
|
|
||||||
|
|
||||||
if len(wavs) == 0 or len(wavs[0]) == 0:
|
start = time.perf_counter()
|
||||||
return {"error": "Empty audio generated"}, 500
|
with torch.inference_mode():
|
||||||
|
wavs, sr = m.generate_voice_clone(**kwargs)
|
||||||
|
elapsed = time.perf_counter() - start
|
||||||
|
else:
|
||||||
|
start = time.perf_counter()
|
||||||
|
with torch.inference_mode():
|
||||||
|
wavs, sr = m.generate_voice_clone(
|
||||||
|
text=text, language=language, voice_clone_prompt=prompt,
|
||||||
|
)
|
||||||
|
elapsed = time.perf_counter() - start
|
||||||
|
|
||||||
audio_data = np.array(wavs[0], dtype=np.float32)
|
audio_data = np.array(wavs[0], dtype=np.float32)
|
||||||
|
print(f"speak: {len(text)} chars → {len(audio_data)/sr:.1f}s audio in {elapsed:.1f}s")
|
||||||
|
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
sf.write(buf, audio_data, sr, format="WAV")
|
sf.write(buf, audio_data, sr, format="WAV")
|
||||||
buf.seek(0)
|
buf.seek(0)
|
||||||
|
return StreamingResponse(buf, media_type="audio/wav")
|
||||||
|
|
||||||
|
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Catch-all handler registered for ``Exception``.

    Prints a traceback via ``traceback.print_exc`` and answers with HTTP 500
    and a JSON body of the form ``{"error": "<str(exc)>"}``.
    """
    from traceback import print_exc

    print_exc()
    body = {"error": str(exc)}
    return JSONResponse(content=body, status_code=500)
|
||||||
|
|
||||||
return StreamingResponse(buf, media_type="audio/wav",
|
|
||||||
headers={"Content-Disposition": "attachment; filename=tts_output.wav"})
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|||||||
Reference in New Issue
Block a user