Fix TTS: switch to 1.7B with ref_audio, speakable text on all lines
- Use 1.7B model (0.6B had tensor mismatch with cached prompts)
- Speak endpoint uses ref_audio directly (not cached pkl) as fallback
- Cache voice clone prompts in memory on startup
- Add SpeakableText component: 🔊 icon on each p and li element
- Remove old TTSReader sequential approach
- Add global exception handler to TTS server
- Fix profile localStorage caching
- inference_mode + bf16 optimization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
309
tts-server.py
309
tts-server.py
@@ -1,159 +1,94 @@
|
||||
"""
|
||||
Qwen3-TTS Voice Clone API Server
|
||||
별도 프로세스로 실행 (GPU 메모리 관리를 위해)
|
||||
Qwen3-TTS Voice Clone API Server (최적화 버전)
|
||||
- 1.7B 모델 사용 (0.6B는 캐시된 프롬프트와 텐서 불일치 발생)
|
||||
- 모델 1회 로드, voice clone prompt 캐시
|
||||
- inference_mode, bf16
|
||||
- 문장 단위 분할
|
||||
"""
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
import tempfile
|
||||
import torch
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, UploadFile, File, Form
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
import json
|
||||
import pickle
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from fastapi import FastAPI, UploadFile, File, Form
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
||||
|
||||
model = None
|
||||
PROFILES_DIR = os.path.join(os.path.dirname(__file__), "voice-profiles")
|
||||
os.makedirs(PROFILES_DIR, exist_ok=True)
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
|
||||
model = None
|
||||
prompt_cache = {} # profile_id → voice_clone_prompt
|
||||
|
||||
|
||||
def get_model():
|
||||
global model
|
||||
if model is None:
|
||||
from qwen_tts import Qwen3TTSModel
|
||||
print("Loading Qwen3-TTS model...")
|
||||
print(f"Loading {MODEL_NAME}...")
|
||||
torch.set_grad_enabled(False)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
model = Qwen3TTSModel.from_pretrained(
|
||||
"Qwen/Qwen3-TTS-12Hz-1.7B-Base",
|
||||
device_map="cuda:0",
|
||||
dtype=torch.bfloat16,
|
||||
MODEL_NAME, device_map="cuda:0", dtype=dtype,
|
||||
)
|
||||
# 프로필 프롬프트 캐시 로드
|
||||
load_all_prompts()
|
||||
print("Model loaded!")
|
||||
return model
|
||||
|
||||
|
||||
def load_all_prompts():
    """Load every pickled voice-clone prompt in PROFILES_DIR into prompt_cache.

    Each ``<profile_id>.pkl`` file is unpickled and stored in ``prompt_cache``
    under its profile id. A file that fails to load is reported on stdout and
    skipped, so one bad profile does not prevent the others from being cached.
    """
    suffix = ".pkl"
    for filename in os.listdir(PROFILES_DIR):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing extension; str.replace() would also remove
        # any ".pkl" that occurs inside the profile id itself.
        pid = filename[: -len(suffix)]
        try:
            # NOTE(review): pickle.load runs code embedded in the file, so
            # PROFILES_DIR must only contain files written by this server.
            with open(os.path.join(PROFILES_DIR, filename), "rb") as fh:
                prompt_cache[pid] = pickle.load(fh)
            print(f" Cached prompt: {pid}")
        except Exception as e:
            # Deliberate best-effort: report and continue with the next file.
            print(f" Failed to cache {pid}: {e}")
|
||||
|
||||
|
||||
def get_prompt(profile_id: str):
    """Return the voice-clone prompt for *profile_id*, or None if unavailable.

    The in-memory ``prompt_cache`` is consulted first. On a miss the prompt is
    unpickled from ``<PROFILES_DIR>/<profile_id>.pkl`` and stored in the cache
    so later calls skip the disk read.

    Returns None when the id is not a plain file name or no such file exists.
    """
    if profile_id in prompt_cache:
        return prompt_cache[profile_id]

    # The id is supplied by API callers; refuse anything that is not a bare
    # file name so that "../x" cannot unpickle a file outside PROFILES_DIR.
    if not profile_id or os.path.basename(profile_id) != profile_id:
        return None

    pkl_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl")
    try:
        # NOTE(review): pickle.load runs code embedded in the file, so
        # PROFILES_DIR must only contain files written by this server.
        with open(pkl_path, "rb") as f:
            prompt = pickle.load(f)
    except FileNotFoundError:
        return None

    prompt_cache[profile_id] = prompt
    return prompt
|
||||
|
||||
|
||||
# === API ===
|
||||
|
||||
@app.get("/health")
|
||||
@app.get("/api/tts/health")
|
||||
def health():
|
||||
return {"status": "ok", "model_loaded": model is not None}
|
||||
return {"status": "ok", "model": MODEL_NAME, "model_loaded": model is not None}
|
||||
|
||||
@app.post("/api/tts/clone")
|
||||
async def voice_clone(
|
||||
text: str = Form(...),
|
||||
language: str = Form("korean"),
|
||||
ref_audio: UploadFile = File(...),
|
||||
ref_text: str = Form(""),
|
||||
):
|
||||
"""참조 음성으로 보이스 클로닝"""
|
||||
m = get_model()
|
||||
|
||||
# 참조 음성 저장
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||
content = await ref_audio.read()
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
# wav 변환 (필요 시)
|
||||
if not ref_audio.filename.endswith(".wav"):
|
||||
wav_path = tmp_path + "_converted.wav"
|
||||
os.system(f'ffmpeg -i "{tmp_path}" -ar 16000 -ac 1 -y "{wav_path}" 2>/dev/null')
|
||||
os.unlink(tmp_path)
|
||||
tmp_path = wav_path
|
||||
|
||||
kwargs = {
|
||||
"text": text,
|
||||
"language": language,
|
||||
"ref_audio": tmp_path,
|
||||
}
|
||||
if ref_text and ref_text.strip():
|
||||
kwargs["ref_text"] = ref_text
|
||||
else:
|
||||
kwargs["x_vector_only_mode"] = True
|
||||
|
||||
wavs, sr = m.generate_voice_clone(**kwargs)
|
||||
print(f"Clone generated: wavs={len(wavs)}, samples={len(wavs[0]) if len(wavs) > 0 else 0}, sr={sr}")
|
||||
|
||||
audio_data = np.array(wavs[0], dtype=np.float32)
|
||||
buf = io.BytesIO()
|
||||
sf.write(buf, audio_data, sr, format="WAV")
|
||||
buf.seek(0)
|
||||
|
||||
return StreamingResponse(buf, media_type="audio/wav",
|
||||
headers={"Content-Disposition": "attachment; filename=tts_output.wav"})
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@app.post("/api/tts/design")
|
||||
async def voice_design(
|
||||
text: str = Form(...),
|
||||
language: str = Form("korean"),
|
||||
instruct: str = Form("A calm, professional Korean male voice"),
|
||||
):
|
||||
"""음성 디자인으로 생성 (참조 음성 없이)"""
|
||||
m = get_model()
|
||||
wavs, sr = m.generate_voice_design(text=text, instruct=instruct, language=language)
|
||||
|
||||
buf = io.BytesIO()
|
||||
sf.write(buf, wavs[0], sr, format="WAV")
|
||||
buf.seek(0)
|
||||
|
||||
return StreamingResponse(buf, media_type="audio/wav",
|
||||
headers={"Content-Disposition": "attachment; filename=tts_output.wav"})
|
||||
|
||||
@app.post("/api/tts/profiles")
|
||||
async def create_profile(
|
||||
name: str = Form(...),
|
||||
ref_audio: UploadFile = File(...),
|
||||
ref_text: str = Form(""),
|
||||
):
|
||||
"""음성 프로필 등록: 참조 음성으로 보이스 프로필 생성 후 저장"""
|
||||
m = get_model()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||
content = await ref_audio.read()
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
if not ref_audio.filename.endswith(".wav"):
|
||||
wav_path = tmp_path + "_converted.wav"
|
||||
os.system(f'ffmpeg -i "{tmp_path}" -ar 16000 -ac 1 -y "{wav_path}" 2>/dev/null')
|
||||
os.unlink(tmp_path)
|
||||
tmp_path = wav_path
|
||||
|
||||
# 프로필 생성
|
||||
kwargs = {"ref_audio": tmp_path}
|
||||
if ref_text and ref_text.strip():
|
||||
kwargs["ref_text"] = ref_text
|
||||
prompt = m.create_voice_clone_prompt(**kwargs)
|
||||
else:
|
||||
kwargs["x_vector_only_mode"] = True
|
||||
prompt = m.create_voice_clone_prompt(**kwargs)
|
||||
|
||||
# 저장
|
||||
profile_id = name.replace(" ", "_").lower()
|
||||
profile_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl")
|
||||
meta_path = os.path.join(PROFILES_DIR, f"{profile_id}.json")
|
||||
|
||||
with open(profile_path, "wb") as f:
|
||||
pickle.dump(prompt, f)
|
||||
with open(meta_path, "w") as f:
|
||||
json.dump({"id": profile_id, "name": name, "ref_text": ref_text}, f, ensure_ascii=False)
|
||||
|
||||
return {"id": profile_id, "name": name, "status": "created"}
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@app.get("/api/tts/profiles")
|
||||
def list_profiles():
|
||||
"""저장된 음성 프로필 목록"""
|
||||
profiles = []
|
||||
for f in os.listdir(PROFILES_DIR):
|
||||
if f.endswith(".json"):
|
||||
@@ -161,49 +96,119 @@ def list_profiles():
|
||||
profiles.append(json.load(fh))
|
||||
return profiles
|
||||
|
||||
|
||||
@app.post("/api/tts/profiles")
async def create_profile(
    name: str = Form(...),
    ref_audio: UploadFile = File(...),
    ref_text: str = Form(""),
):
    """Register a voice profile built from an uploaded reference recording.

    The upload is written to a temporary file, converted to WAV with ffmpeg
    when its file name does not end in ``.wav``, and passed to the model's
    ``create_voice_clone_prompt``. The reference WAV, the pickled prompt and a
    JSON metadata file are then stored in PROFILES_DIR under the profile id
    (``name`` lower-cased, with spaces replaced by underscores), and the
    prompt is added to ``prompt_cache``.

    Returns ``{"id", "name", "status": "created"}`` on success, or a 400 JSON
    error when ``name`` does not yield a usable file name.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited with a non-zero status.
        FileNotFoundError: ffmpeg is not installed.
    """
    import shutil
    import subprocess

    from fastapi.responses import JSONResponse

    profile_id = name.replace(" ", "_").lower()
    # The id becomes part of three file names; reject path separators and
    # dot-names so a crafted name cannot write outside PROFILES_DIR.
    if (
        not profile_id
        or profile_id in (".", "..")
        or os.path.basename(profile_id) != profile_id
    ):
        return JSONResponse(status_code=400, content={"error": "Invalid profile name"})

    m = get_model()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(await ref_audio.read())
        tmp_path = tmp.name

    try:
        # filename may be None for some uploads; treat that as "not a WAV".
        if not (ref_audio.filename or "").endswith(".wav"):
            wav_path = tmp_path + "_converted.wav"
            try:
                # Argument list instead of a shell string, and check=True so a
                # failed conversion surfaces here rather than as a missing
                # file inside the model call.
                subprocess.run(
                    ["ffmpeg", "-i", tmp_path, "-ar", "16000", "-ac", "1", "-y", wav_path],
                    stderr=subprocess.DEVNULL,
                    check=True,
                )
            finally:
                # The raw upload is no longer needed; from here on the outer
                # ``finally`` cleans up the converted file instead.
                os.unlink(tmp_path)
                tmp_path = wav_path

        kwargs = {"ref_audio": tmp_path}
        if ref_text and ref_text.strip():
            kwargs["ref_text"] = ref_text
        else:
            # No transcript supplied: request x_vector_only_mode instead.
            kwargs["x_vector_only_mode"] = True

        with torch.inference_mode():
            prompt = m.create_voice_clone_prompt(**kwargs)

        # Persist the reference wav, the pickled prompt and the metadata.
        shutil.copy2(tmp_path, os.path.join(PROFILES_DIR, f"{profile_id}.wav"))
        with open(os.path.join(PROFILES_DIR, f"{profile_id}.pkl"), "wb") as f:
            pickle.dump(prompt, f)
        with open(os.path.join(PROFILES_DIR, f"{profile_id}.json"), "w") as f:
            json.dump({"id": profile_id, "name": name, "ref_text": ref_text}, f, ensure_ascii=False)

        # Make the new profile usable without a reload from disk.
        prompt_cache[profile_id] = prompt

        return {"id": profile_id, "name": name, "status": "created"}
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
||||
|
||||
|
||||
@app.delete("/api/tts/profiles/{profile_id}")
|
||||
def delete_profile(profile_id: str):
|
||||
"""음성 프로필 삭제"""
|
||||
pkl = os.path.join(PROFILES_DIR, f"{profile_id}.pkl")
|
||||
meta = os.path.join(PROFILES_DIR, f"{profile_id}.json")
|
||||
if os.path.exists(pkl): os.unlink(pkl)
|
||||
if os.path.exists(meta): os.unlink(meta)
|
||||
for ext in [".pkl", ".json", ".wav"]:
|
||||
p = os.path.join(PROFILES_DIR, f"{profile_id}{ext}")
|
||||
if os.path.exists(p):
|
||||
os.unlink(p)
|
||||
prompt_cache.pop(profile_id, None)
|
||||
return {"status": "deleted"}
|
||||
|
||||
@app.post("/api/tts/generate")
|
||||
async def generate_from_profile(
|
||||
|
||||
@app.post("/api/tts/speak")
|
||||
async def speak(
|
||||
text: str = Form(...),
|
||||
profile_id: str = Form(...),
|
||||
language: str = Form("korean"),
|
||||
language: str = Form("Korean"),
|
||||
):
|
||||
"""저장된 음성 프로필로 TTS 생성"""
|
||||
"""한 문장 TTS — 캐시된 프롬프트 사용, 바로 wav 반환"""
|
||||
m = get_model()
|
||||
|
||||
profile_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl")
|
||||
if not os.path.exists(profile_path):
|
||||
return {"error": f"Profile '{profile_id}' not found"}, 404
|
||||
prompt = get_prompt(profile_id)
|
||||
if prompt is None:
|
||||
# 프롬프트가 없으면 ref_audio로 직접
|
||||
meta_path = os.path.join(PROFILES_DIR, f"{profile_id}.json")
|
||||
ref_audio_path = os.path.join(PROFILES_DIR, f"{profile_id}.wav")
|
||||
if not os.path.exists(ref_audio_path):
|
||||
return {"error": "Profile not found"}, 404
|
||||
|
||||
with open(profile_path, "rb") as f:
|
||||
prompt = pickle.load(f)
|
||||
with open(meta_path) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
print(f"Generating with profile '{profile_id}', text='{text[:50]}...', language={language}")
|
||||
wavs, sr = m.generate_voice_clone(
|
||||
text=text,
|
||||
language=language,
|
||||
voice_clone_prompt=prompt,
|
||||
)
|
||||
print(f"Generated: wavs={len(wavs)}, samples={len(wavs[0]) if len(wavs) > 0 else 0}, sr={sr}")
|
||||
kwargs = {"text": text, "language": language, "ref_audio": ref_audio_path}
|
||||
if meta.get("ref_text"):
|
||||
kwargs["ref_text"] = meta["ref_text"]
|
||||
else:
|
||||
kwargs["x_vector_only_mode"] = True
|
||||
|
||||
if len(wavs) == 0 or len(wavs[0]) == 0:
|
||||
return {"error": "Empty audio generated"}, 500
|
||||
start = time.perf_counter()
|
||||
with torch.inference_mode():
|
||||
wavs, sr = m.generate_voice_clone(**kwargs)
|
||||
elapsed = time.perf_counter() - start
|
||||
else:
|
||||
start = time.perf_counter()
|
||||
with torch.inference_mode():
|
||||
wavs, sr = m.generate_voice_clone(
|
||||
text=text, language=language, voice_clone_prompt=prompt,
|
||||
)
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
audio_data = np.array(wavs[0], dtype=np.float32)
|
||||
print(f"speak: {len(text)} chars → {len(audio_data)/sr:.1f}s audio in {elapsed:.1f}s")
|
||||
|
||||
buf = io.BytesIO()
|
||||
sf.write(buf, audio_data, sr, format="WAV")
|
||||
buf.seek(0)
|
||||
return StreamingResponse(buf, media_type="audio/wav")
|
||||
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Turn any otherwise unhandled exception into a JSON 500 response.

    The traceback is printed to stderr, and the exception message is returned
    to the client as ``{"error": "<message>"}``.
    """
    import traceback

    # Print from the exception object that was passed in rather than the
    # ambient sys.exc_info(), which is empty outside an ``except`` block.
    traceback.print_exception(type(exc), exc, exc.__traceback__)
    return JSONResponse(status_code=500, content={"error": str(exc)})
|
||||
|
||||
return StreamingResponse(buf, media_type="audio/wav",
|
||||
headers={"Content-Disposition": "attachment; filename=tts_output.wav"})
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
Reference in New Issue
Block a user