- Use 1.7B model (0.6B had tensor mismatch with cached prompts) - Speak endpoint uses ref_audio directly (not cached pkl) as fallback - Cache voice clone prompts in memory on startup - Add SpeakableText component: 🔊 icon on each p and li element - Remove old TTSReader sequential approach - Add global exception handler to TTS server - Fix profile localStorage caching - inference_mode + bf16 optimization Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
217 lines
6.6 KiB
Python
217 lines
6.6 KiB
Python
"""
|
|
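Qwen3-TTS Voice Clone API Server (optimized version)

- Uses the 1.7B model (the 0.6B model, originally chosen for A10 speed,
  had a tensor mismatch with cached prompts)
- Loads the model once and caches voice clone prompts
- inference_mode, bf16
- Sentence-level splitting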
|
|
"""
|
|
import os
|
|
import io
|
|
import json
|
|
import pickle
|
|
import re
|
|
import tempfile
|
|
import time
|
|
import uuid
|
|
import threading
|
|
|
|
import numpy as np
|
|
import soundfile as sf
|
|
import torch
|
|
from fastapi import FastAPI, UploadFile, File, Form
|
|
from fastapi.responses import StreamingResponse, FileResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
# FastAPI application. CORS is opened to every origin, method, and header.
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# Directory next to this file that holds each profile's .wav/.pkl/.json
# files; it is created at import time if it does not exist yet.
PROFILES_DIR = os.path.join(os.path.dirname(__file__), "voice-profiles")
os.makedirs(PROFILES_DIR, exist_ok=True)

# Model identifier passed to Qwen3TTSModel.from_pretrained in get_model().
MODEL_NAME = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
# Stays None until get_model() loads the model on first use.
model = None
prompt_cache = {} # profile_id → voice_clone_prompt
|
|
|
|
|
|
def get_model():
    """Return the shared TTS model, loading it on the first call.

    The first call disables autograd globally, turns on TF32 for CUDA matmul
    and cuDNN, loads ``MODEL_NAME`` onto ``cuda:0`` in bfloat16 (float16 when
    bf16 is not supported), and then fills the prompt cache from disk.
    Later calls return the already-loaded instance unchanged.
    """
    global model
    if model is not None:
        return model

    from qwen_tts import Qwen3TTSModel

    print(f"Loading {MODEL_NAME}...")
    torch.set_grad_enabled(False)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    if torch.cuda.is_bf16_supported():
        weights_dtype = torch.bfloat16
    else:
        weights_dtype = torch.float16
    model = Qwen3TTSModel.from_pretrained(
        MODEL_NAME,
        device_map="cuda:0",
        dtype=weights_dtype,
    )
    # Load the cached profile prompts into memory.
    load_all_prompts()
    print("Model loaded!")
    return model
|
|
|
|
|
|
def load_all_prompts():
    """Cache every saved voice clone prompt from ``PROFILES_DIR`` in memory.

    Each ``<profile_id>.pkl`` file in the directory is unpickled and stored
    in ``prompt_cache`` under its profile id. A file that fails to load is
    reported and skipped, so one bad file does not stop the others.
    """
    suffix = ".pkl"
    for filename in os.listdir(PROFILES_DIR):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing extension: str.replace would also remove a
        # ".pkl" occurring earlier in the name and yield the wrong profile id.
        pid = filename[: -len(suffix)]
        try:
            with open(os.path.join(PROFILES_DIR, filename), "rb") as fh:
                # NOTE(security): unpickling runs arbitrary code, so these
                # files must only ever be ones this server wrote itself.
                prompt_cache[pid] = pickle.load(fh)
            print(f" Cached prompt: {pid}")
        except Exception as e:
            # Best effort: keep loading the remaining profiles.
            print(f" Failed to cache {pid}: {e}")
|
|
|
|
|
|
def get_prompt(profile_id: str):
    """Return the voice clone prompt for ``profile_id``, or ``None`` if absent.

    The in-memory cache is consulted first. On a miss the prompt is read from
    ``<PROFILES_DIR>/<profile_id>.pkl`` and cached for later calls.

    Args:
        profile_id: Profile identifier; expected to be a plain file name.

    Returns:
        The unpickled prompt object, or ``None`` when no such profile file
        exists or the id is not a plain file name.
    """
    if profile_id in prompt_cache:
        return prompt_cache[profile_id]

    # profile_id arrives straight from request data. Refuse anything that
    # carries a directory component so it cannot point outside PROFILES_DIR.
    if os.path.basename(profile_id) != profile_id:
        return None

    pkl_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl")
    if not os.path.exists(pkl_path):
        return None

    with open(pkl_path, "rb") as f:
        # NOTE(security): unpickling runs arbitrary code, so these files
        # must only ever be ones this server wrote itself.
        prompt = pickle.load(f)
    prompt_cache[profile_id] = prompt
    return prompt
|
|
|
|
|
|
# === API ===
|
|
|
|
@app.get("/health")
@app.get("/api/tts/health")
def health():
    """Report liveness, the configured model name, and whether it is loaded."""
    is_loaded = model is not None
    return {"status": "ok", "model": MODEL_NAME, "model_loaded": is_loaded}
|
|
|
|
|
|
@app.get("/api/tts/profiles")
def list_profiles():
    """Return the parsed contents of every ``*.json`` file in ``PROFILES_DIR``."""
    found = []
    for entry in os.listdir(PROFILES_DIR):
        if not entry.endswith(".json"):
            continue
        meta_file = os.path.join(PROFILES_DIR, entry)
        with open(meta_file) as fh:
            found.append(json.load(fh))
    return found
|
|
|
|
|
|
@app.post("/api/tts/profiles")
|
|
async def create_profile(
|
|
name: str = Form(...),
|
|
ref_audio: UploadFile = File(...),
|
|
ref_text: str = Form(""),
|
|
):
|
|
m = get_model()
|
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
|
content = await ref_audio.read()
|
|
tmp.write(content)
|
|
tmp_path = tmp.name
|
|
|
|
try:
|
|
if not ref_audio.filename.endswith(".wav"):
|
|
wav_path = tmp_path + "_converted.wav"
|
|
os.system(f'ffmpeg -i "{tmp_path}" -ar 16000 -ac 1 -y "{wav_path}" 2>/dev/null')
|
|
os.unlink(tmp_path)
|
|
tmp_path = wav_path
|
|
|
|
kwargs = {"ref_audio": tmp_path}
|
|
if ref_text and ref_text.strip():
|
|
kwargs["ref_text"] = ref_text
|
|
else:
|
|
kwargs["x_vector_only_mode"] = True
|
|
|
|
with torch.inference_mode():
|
|
prompt = m.create_voice_clone_prompt(**kwargs)
|
|
|
|
profile_id = name.replace(" ", "_").lower()
|
|
|
|
# wav, pkl, json 저장
|
|
import shutil
|
|
shutil.copy2(tmp_path, os.path.join(PROFILES_DIR, f"{profile_id}.wav"))
|
|
with open(os.path.join(PROFILES_DIR, f"{profile_id}.pkl"), "wb") as f:
|
|
pickle.dump(prompt, f)
|
|
with open(os.path.join(PROFILES_DIR, f"{profile_id}.json"), "w") as f:
|
|
json.dump({"id": profile_id, "name": name, "ref_text": ref_text}, f, ensure_ascii=False)
|
|
|
|
# 캐시에 추가
|
|
prompt_cache[profile_id] = prompt
|
|
|
|
return {"id": profile_id, "name": name, "status": "created"}
|
|
finally:
|
|
if os.path.exists(tmp_path):
|
|
os.unlink(tmp_path)
|
|
|
|
|
|
@app.delete("/api/tts/profiles/{profile_id}")
def delete_profile(profile_id: str):
    """Remove a profile's stored files and its cached prompt.

    Responds with ``{"status": "deleted"}`` whether or not any file existed.
    """
    candidates = (
        os.path.join(PROFILES_DIR, profile_id + suffix)
        for suffix in (".pkl", ".json", ".wav")
    )
    for path in candidates:
        if os.path.exists(path):
            os.unlink(path)
    prompt_cache.pop(profile_id, None)
    return {"status": "deleted"}
|
|
|
|
|
|
@app.post("/api/tts/speak")
async def speak(
    text: str = Form(...),
    profile_id: str = Form(...),
    language: str = Form("Korean"),
):
    """Synthesize ``text`` in the voice of ``profile_id`` and return WAV audio.

    The cached voice clone prompt is used when one exists. Otherwise the
    profile's stored reference wav is passed to the model directly, together
    with the reference transcript from the profile's JSON metadata when
    there is one (else ``x_vector_only_mode=True``).

    Args:
        text: Text to speak.
        profile_id: Id of the voice profile to use.
        language: Language name forwarded to the model.

    Returns:
        A streamed ``audio/wav`` response, or a JSON 404
        ``{"error": "Profile not found"}`` when the profile has neither a
        prompt nor reference audio.
    """
    from fastapi.responses import JSONResponse

    m = get_model()

    gen_kwargs = {"text": text, "language": language}
    prompt = get_prompt(profile_id)
    if prompt is not None:
        gen_kwargs["voice_clone_prompt"] = prompt
    else:
        # No prompt available: fall back to the raw reference audio.
        ref_audio_path = os.path.join(PROFILES_DIR, f"{profile_id}.wav")
        # profile_id is request data; an id with a directory component must
        # not be allowed to reach files outside PROFILES_DIR.
        is_plain_name = os.path.basename(profile_id) == profile_id
        if not is_plain_name or not os.path.exists(ref_audio_path):
            # Returning a ``(body, 404)`` tuple would be serialized as a JSON
            # array with status 200, so build the 404 response explicitly.
            return JSONResponse(status_code=404, content={"error": "Profile not found"})

        # The metadata file may be missing even though the wav exists; in
        # that case there is simply no reference transcript.
        meta = {}
        meta_path = os.path.join(PROFILES_DIR, f"{profile_id}.json")
        if os.path.exists(meta_path):
            with open(meta_path) as f:
                meta = json.load(f)

        gen_kwargs["ref_audio"] = ref_audio_path
        if meta.get("ref_text"):
            gen_kwargs["ref_text"] = meta["ref_text"]
        else:
            gen_kwargs["x_vector_only_mode"] = True

    start = time.perf_counter()
    with torch.inference_mode():
        wavs, sr = m.generate_voice_clone(**gen_kwargs)
    elapsed = time.perf_counter() - start

    audio_data = np.array(wavs[0], dtype=np.float32)
    print(f"speak: {len(text)} chars → {len(audio_data)/sr:.1f}s audio in {elapsed:.1f}s")

    # Encode to WAV in memory and stream it back.
    buf = io.BytesIO()
    sf.write(buf, audio_data, sr, format="WAV")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")
|
|
|
|
|
|
from fastapi.responses import JSONResponse
|
|
|
|
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Print the traceback of an unhandled error and answer with a JSON 500.

    The response body is ``{"error": str(exc)}``.
    """
    import traceback

    traceback.print_exc()
    payload = {"error": str(exc)}
    return JSONResponse(status_code=500, content=payload)
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Load the model (and the cached prompts) before serving, so the load
    # happens at startup instead of inside the first request.
    get_model()
    # Serve on all interfaces, port 8090.
    uvicorn.run(app, host="0.0.0.0", port=8090)
|