""" Qwen3-TTS Voice Clone API Server (최적화 버전) - 0.6B 모델 사용 (A10 속도 최적화) - 모델 1회 로드, voice clone prompt 캐시 - inference_mode, bf16 - 문장 단위 분할 """ import os import io import json import pickle import re import tempfile import time import uuid import threading import numpy as np import soundfile as sf import torch from fastapi import FastAPI, UploadFile, File, Form from fastapi.responses import StreamingResponse, FileResponse from fastapi.middleware.cors import CORSMiddleware app = FastAPI() app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) PROFILES_DIR = os.path.join(os.path.dirname(__file__), "voice-profiles") os.makedirs(PROFILES_DIR, exist_ok=True) MODEL_NAME = "Qwen/Qwen3-TTS-12Hz-1.7B-Base" model = None prompt_cache = {} # profile_id → voice_clone_prompt def get_model(): global model if model is None: from qwen_tts import Qwen3TTSModel print(f"Loading {MODEL_NAME}...") torch.set_grad_enabled(False) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 model = Qwen3TTSModel.from_pretrained( MODEL_NAME, device_map="cuda:0", dtype=dtype, ) # 프로필 프롬프트 캐시 로드 load_all_prompts() print("Model loaded!") return model def load_all_prompts(): """모든 프로필의 voice clone prompt를 메모리에 캐시""" global prompt_cache for f in os.listdir(PROFILES_DIR): if f.endswith(".pkl"): pid = f.replace(".pkl", "") try: with open(os.path.join(PROFILES_DIR, f), "rb") as fh: prompt_cache[pid] = pickle.load(fh) print(f" Cached prompt: {pid}") except Exception as e: print(f" Failed to cache {pid}: {e}") def get_prompt(profile_id: str): """캐시에서 프롬프트 가져오기, 없으면 파일에서 로드""" if profile_id in prompt_cache: return prompt_cache[profile_id] pkl_path = os.path.join(PROFILES_DIR, f"{profile_id}.pkl") if os.path.exists(pkl_path): with open(pkl_path, "rb") as f: prompt = pickle.load(f) prompt_cache[profile_id] = prompt return prompt return None # === API === 
from fastapi.responses import JSONResponse


def _profile_file(profile_id: str, ext: str):
    """Return ``PROFILES_DIR/<profile_id><ext>``, or None if *profile_id* is unsafe.

    Profile ids become file names, so ids that are empty, "." / "..", contain
    NUL, or contain a path separator are rejected to keep every read, write and
    delete inside PROFILES_DIR.
    """
    if (
        not profile_id
        or profile_id in (".", "..")
        or "\x00" in profile_id
        or os.path.basename(profile_id) != profile_id
    ):
        return None
    return os.path.join(PROFILES_DIR, f"{profile_id}{ext}")


@app.get("/health")
@app.get("/api/tts/health")
def health():
    """Liveness probe; reports whether the model has been loaded yet."""
    return {"status": "ok", "model": MODEL_NAME, "model_loaded": model is not None}


@app.get("/api/tts/profiles")
def list_profiles():
    """Return the metadata dict of every profile (one per ``*.json`` in PROFILES_DIR)."""
    profiles = []
    for fname in sorted(os.listdir(PROFILES_DIR)):
        if fname.endswith(".json"):
            # Written with ensure_ascii=False, so read back explicitly as UTF-8.
            with open(os.path.join(PROFILES_DIR, fname), encoding="utf-8") as fh:
                profiles.append(json.load(fh))
    return profiles


@app.post("/api/tts/profiles")
async def create_profile(
    name: str = Form(...),
    ref_audio: UploadFile = File(...),
    ref_text: str = Form(""),
):
    """Create a voice profile from an uploaded reference clip.

    Non-``.wav`` uploads are converted by ffmpeg to 16 kHz mono WAV. A voice
    clone prompt is built (x-vector-only when *ref_text* is blank), then the
    reference wav, pickled prompt and JSON metadata are saved under
    ``PROFILES_DIR/<profile_id>.*`` and the prompt is cached. The profile id is
    *name* lower-cased with spaces replaced by underscores.

    Returns 400 if the derived id is not a safe file name or the audio cannot
    be converted.
    """
    import shutil
    import subprocess

    profile_id = name.replace(" ", "_").lower()
    if _profile_file(profile_id, "") is None:
        return JSONResponse(status_code=400, content={"error": "Invalid profile name"})

    m = get_model()

    filename = ref_audio.filename or ""
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(await ref_audio.read())
        tmp_path = tmp.name

    try:
        if not filename.lower().endswith(".wav"):
            wav_path = tmp_path + "_converted.wav"
            # Argument list (no shell) so paths are never shell-interpreted;
            # stderr is discarded as the original `2>/dev/null` did.
            result = subprocess.run(
                ["ffmpeg", "-i", tmp_path, "-ar", "16000", "-ac", "1", "-y", wav_path],
                stderr=subprocess.DEVNULL,
            )
            os.unlink(tmp_path)
            tmp_path = wav_path  # the finally-block now cleans up the converted file
            if result.returncode != 0 or not os.path.exists(wav_path):
                return JSONResponse(
                    status_code=400,
                    content={"error": "Failed to convert reference audio to WAV"},
                )

        kwargs = {"ref_audio": tmp_path}
        if ref_text and ref_text.strip():
            kwargs["ref_text"] = ref_text
        else:
            # Without a transcript, fall back to speaker-embedding-only cloning.
            kwargs["x_vector_only_mode"] = True

        with torch.inference_mode():
            prompt = m.create_voice_clone_prompt(**kwargs)

        # Persist wav, pkl and json alongside each other.
        shutil.copy2(tmp_path, _profile_file(profile_id, ".wav"))
        with open(_profile_file(profile_id, ".pkl"), "wb") as fh:
            pickle.dump(prompt, fh)
        with open(_profile_file(profile_id, ".json"), "w", encoding="utf-8") as fh:
            json.dump({"id": profile_id, "name": name, "ref_text": ref_text}, fh, ensure_ascii=False)

        prompt_cache[profile_id] = prompt
        return {"id": profile_id, "name": name, "status": "created"}
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)


@app.delete("/api/tts/profiles/{profile_id}")
def delete_profile(profile_id: str):
    """Delete a profile's .pkl/.json/.wav files (if present) and drop its cached prompt.

    Returns 400 for ids that are not a plain file name.
    """
    if _profile_file(profile_id, "") is None:
        return JSONResponse(status_code=400, content={"error": "Invalid profile id"})
    for ext in (".pkl", ".json", ".wav"):
        path = _profile_file(profile_id, ext)
        if os.path.exists(path):
            os.unlink(path)
    prompt_cache.pop(profile_id, None)
    return {"status": "deleted"}


@app.post("/api/tts/speak")
async def speak(
    text: str = Form(...),
    profile_id: str = Form(...),
    language: str = Form("Korean"),
):
    """Synthesize one sentence in a profile's voice and return it as a WAV stream.

    Uses the cached voice clone prompt when available; otherwise clones
    directly from the profile's reference wav (plus ``ref_text`` from its JSON
    metadata when present). Responds 404 when the profile cannot be found.
    """
    ref_audio_path = _profile_file(profile_id, ".wav")
    if ref_audio_path is None:
        return JSONResponse(status_code=404, content={"error": "Profile not found"})

    m = get_model()
    prompt = get_prompt(profile_id)
    if prompt is None:
        # No cached prompt: clone straight from the stored reference audio.
        if not os.path.exists(ref_audio_path):
            return JSONResponse(status_code=404, content={"error": "Profile not found"})
        meta_path = _profile_file(profile_id, ".json")
        meta = {}
        if os.path.exists(meta_path):
            with open(meta_path, encoding="utf-8") as fh:
                meta = json.load(fh)
        kwargs = {"text": text, "language": language, "ref_audio": ref_audio_path}
        if meta.get("ref_text"):
            kwargs["ref_text"] = meta["ref_text"]
        else:
            kwargs["x_vector_only_mode"] = True
    else:
        kwargs = {"text": text, "language": language, "voice_clone_prompt": prompt}

    start = time.perf_counter()
    with torch.inference_mode():
        wavs, sr = m.generate_voice_clone(**kwargs)
    elapsed = time.perf_counter() - start

    audio_data = np.array(wavs[0], dtype=np.float32)
    print(f"speak: {len(text)} chars → {len(audio_data)/sr:.1f}s audio in {elapsed:.1f}s")

    buf = io.BytesIO()
    sf.write(buf, audio_data, sr, format="WAV")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")


@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Log the traceback of any unhandled error and return it as a JSON 500."""
    import traceback

    traceback.print_exc()
    return JSONResponse(status_code=500, content={"error": str(exc)})


if __name__ == "__main__":
    import uvicorn

    # Load the model (and warm the prompt cache) before accepting requests.
    get_model()
    uvicorn.run(app, host="0.0.0.0", port=8090)