Files
sundol/tts-server.py
joungmin 20210830cf Fix TTS: switch to 1.7B with ref_audio, speakable text on all lines
- Use 1.7B model (0.6B had tensor mismatch with cached prompts)
- Speak endpoint uses ref_audio directly (not cached pkl) as fallback
- Cache voice clone prompts in memory on startup
- Add SpeakableText component: 🔊 icon on each p and li element
- Remove old TTSReader sequential approach
- Add global exception handler to TTS server
- Fix profile localStorage caching
- inference_mode + bf16 optimization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-13 12:14:06 +00:00

217 lines
6.6 KiB
Python

"""
Qwen3-TTS Voice Clone API Server (최적화 버전)
- 0.6B 모델 사용 (A10 속도 최적화)
- 모델 1회 로드, voice clone prompt 캐시
- inference_mode, bf16
- 문장 단위 분할
"""
import os
import io
import json
import pickle
import re
import tempfile
import time
import uuid
import threading
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
# Name of the pretrained checkpoint handed to Qwen3TTSModel.from_pretrained.
MODEL_NAME = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"

# Voice profiles (wav / pkl / json) live in a "voice-profiles" directory next
# to this file; the directory is created at import time if it is missing.
PROFILES_DIR = os.path.join(os.path.dirname(__file__), "voice-profiles")
os.makedirs(PROFILES_DIR, exist_ok=True)

# Shared model instance; stays None until get_model() loads it.
model = None

# profile_id -> voice clone prompt
prompt_cache = {}

# The API accepts cross-origin requests from any origin, method and header.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_model():
    """Return the shared TTS model, loading it on the first call.

    The first call imports ``qwen_tts``, disables autograd globally, enables
    TF32 for CUDA matmul and cuDNN, loads ``MODEL_NAME`` onto ``cuda:0``
    (bfloat16 when the GPU reports support for it, float16 otherwise) and then
    fills the prompt cache through ``load_all_prompts``. Every later call
    returns the already-loaded instance.
    """
    global model
    if model is not None:
        return model

    # Deferred import: only needed when the model is actually loaded.
    from qwen_tts import Qwen3TTSModel

    print(f"Loading {MODEL_NAME}...")
    torch.set_grad_enabled(False)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    if torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
    model = Qwen3TTSModel.from_pretrained(MODEL_NAME, device_map="cuda:0", dtype=dtype)

    # Warm the in-memory cache with the prompts of all saved profiles.
    load_all_prompts()
    print("Model loaded!")
    return model
def load_all_prompts():
    """Load every saved voice clone prompt in ``PROFILES_DIR`` into ``prompt_cache``.

    Each ``<profile_id>.pkl`` file is unpickled and stored under ``profile_id``.
    A file that fails to load is reported on stdout and skipped, so one broken
    profile does not prevent the others from being cached.
    """
    suffix = ".pkl"
    for filename in os.listdir(PROFILES_DIR):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing extension: str.replace would also remove a
        # ".pkl" that occurs inside the profile id and produce a wrong key.
        pid = filename[: -len(suffix)]
        try:
            with open(os.path.join(PROFILES_DIR, filename), "rb") as fh:
                # NOTE(security): unpickling executes code from the file; only
                # files written by this server should ever be in PROFILES_DIR.
                prompt_cache[pid] = pickle.load(fh)
            print(f" Cached prompt: {pid}")
        except Exception as e:
            # Best effort: report the failure and continue with the next file.
            print(f" Failed to cache {pid}: {e}")
def get_prompt(profile_id: str):
    """Return the voice clone prompt for ``profile_id``, or ``None`` if there is none.

    The in-memory ``prompt_cache`` is consulted first. On a miss the prompt is
    unpickled from ``PROFILES_DIR/<profile_id>.pkl`` and stored in the cache.

    Args:
        profile_id: Profile identifier; it is used as a file name stem.

    Returns:
        The cached or freshly loaded prompt object, or ``None`` when no pickle
        exists for the profile or the id is not a plain file name.
    """
    if profile_id in prompt_cache:
        return prompt_cache[profile_id]

    # The id comes from request data. Refuse anything containing a path
    # separator so an arbitrary file outside PROFILES_DIR is never unpickled.
    if os.path.basename(profile_id) != profile_id:
        return None

    pkl_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl")
    try:
        # Open directly instead of checking existence first: the file can be
        # removed by a concurrent delete between the check and the open.
        with open(pkl_path, "rb") as f:
            # NOTE(security): pickle executes code on load; the file must have
            # been written by this server.
            prompt = pickle.load(f)
    except FileNotFoundError:
        return None

    prompt_cache[profile_id] = prompt
    return prompt
# === API ===
@app.get("/health")
@app.get("/api/tts/health")
def health():
return {"status": "ok", "model": MODEL_NAME, "model_loaded": model is not None}
@app.get("/api/tts/profiles")
def list_profiles():
profiles = []
for f in os.listdir(PROFILES_DIR):
if f.endswith(".json"):
with open(os.path.join(PROFILES_DIR, f)) as fh:
profiles.append(json.load(fh))
return profiles
@app.post("/api/tts/profiles")
async def create_profile(
name: str = Form(...),
ref_audio: UploadFile = File(...),
ref_text: str = Form(""),
):
m = get_model()
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
content = await ref_audio.read()
tmp.write(content)
tmp_path = tmp.name
try:
if not ref_audio.filename.endswith(".wav"):
wav_path = tmp_path + "_converted.wav"
os.system(f'ffmpeg -i "{tmp_path}" -ar 16000 -ac 1 -y "{wav_path}" 2>/dev/null')
os.unlink(tmp_path)
tmp_path = wav_path
kwargs = {"ref_audio": tmp_path}
if ref_text and ref_text.strip():
kwargs["ref_text"] = ref_text
else:
kwargs["x_vector_only_mode"] = True
with torch.inference_mode():
prompt = m.create_voice_clone_prompt(**kwargs)
profile_id = name.replace(" ", "_").lower()
# wav, pkl, json 저장
import shutil
shutil.copy2(tmp_path, os.path.join(PROFILES_DIR, f"{profile_id}.wav"))
with open(os.path.join(PROFILES_DIR, f"{profile_id}.pkl"), "wb") as f:
pickle.dump(prompt, f)
with open(os.path.join(PROFILES_DIR, f"{profile_id}.json"), "w") as f:
json.dump({"id": profile_id, "name": name, "ref_text": ref_text}, f, ensure_ascii=False)
# 캐시에 추가
prompt_cache[profile_id] = prompt
return {"id": profile_id, "name": name, "status": "created"}
finally:
if os.path.exists(tmp_path):
os.unlink(tmp_path)
@app.delete("/api/tts/profiles/{profile_id}")
def delete_profile(profile_id: str):
for ext in [".pkl", ".json", ".wav"]:
p = os.path.join(PROFILES_DIR, f"{profile_id}{ext}")
if os.path.exists(p):
os.unlink(p)
prompt_cache.pop(profile_id, None)
return {"status": "deleted"}
@app.post("/api/tts/speak")
async def speak(
text: str = Form(...),
profile_id: str = Form(...),
language: str = Form("Korean"),
):
"""한 문장 TTS — 캐시된 프롬프트 사용, 바로 wav 반환"""
m = get_model()
prompt = get_prompt(profile_id)
if prompt is None:
# 프롬프트가 없으면 ref_audio로 직접
meta_path = os.path.join(PROFILES_DIR, f"{profile_id}.json")
ref_audio_path = os.path.join(PROFILES_DIR, f"{profile_id}.wav")
if not os.path.exists(ref_audio_path):
return {"error": "Profile not found"}, 404
with open(meta_path) as f:
meta = json.load(f)
kwargs = {"text": text, "language": language, "ref_audio": ref_audio_path}
if meta.get("ref_text"):
kwargs["ref_text"] = meta["ref_text"]
else:
kwargs["x_vector_only_mode"] = True
start = time.perf_counter()
with torch.inference_mode():
wavs, sr = m.generate_voice_clone(**kwargs)
elapsed = time.perf_counter() - start
else:
start = time.perf_counter()
with torch.inference_mode():
wavs, sr = m.generate_voice_clone(
text=text, language=language, voice_clone_prompt=prompt,
)
elapsed = time.perf_counter() - start
audio_data = np.array(wavs[0], dtype=np.float32)
print(f"speak: {len(text)} chars → {len(audio_data)/sr:.1f}s audio in {elapsed:.1f}s")
buf = io.BytesIO()
sf.write(buf, audio_data, sr, format="WAV")
buf.seek(0)
return StreamingResponse(buf, media_type="audio/wav")
from fastapi.responses import JSONResponse
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Print the traceback of an unhandled exception and answer with a JSON 500.

    The response body is ``{"error": <str(exc)>}``.
    """
    import traceback

    traceback.print_exc()
    payload = {"error": str(exc)}
    return JSONResponse(content=payload, status_code=500)
if __name__ == "__main__":
import uvicorn
get_model()
uvicorn.run(app, host="0.0.0.0", port=8090)