"""백테스트 결과 Oracle DB 저장 모듈. 테이블: backtest_runs - 실행 단위 (실행시각, 설명, 파라미터) backtest_results - 조건별 집계 (run_id + label) backtest_trade_log - 개별 거래 (run_id + label + 종목 + pnl + fng + ...) """ from __future__ import annotations import json import os from contextlib import contextmanager from datetime import datetime from pathlib import Path from typing import Generator import oracledb from dotenv import load_dotenv load_dotenv(dotenv_path=Path(__file__).parent / ".env") _pool: oracledb.ConnectionPool | None = None def _get_pool() -> oracledb.ConnectionPool: global _pool if _pool is None: kwargs: dict = dict( user=os.environ["ORACLE_USER"], password=os.environ["ORACLE_PASSWORD"], dsn=os.environ["ORACLE_DSN"], min=1, max=3, increment=1, ) wallet = os.environ.get("ORACLE_WALLET") if wallet: kwargs["config_dir"] = wallet _pool = oracledb.create_pool(**kwargs) return _pool @contextmanager def _conn() -> Generator[oracledb.Connection, None, None]: pool = _get_pool() conn = pool.acquire() try: yield conn conn.commit() except Exception: conn.rollback() raise finally: pool.release(conn) # ── DDL ──────────────────────────────────────────────────────── _DDL_RUNS = """ CREATE TABLE backtest_runs ( run_id VARCHAR2(36) DEFAULT SYS_GUID() PRIMARY KEY, run_name VARCHAR2(200) NOT NULL, description VARCHAR2(1000), params_json CLOB, created_at TIMESTAMP DEFAULT SYSTIMESTAMP NOT NULL ) """ _DDL_RESULTS = """ CREATE TABLE backtest_results ( id NUMBER GENERATED ALWAYS AS IDENTITY PRIMARY KEY, run_id VARCHAR2(36) NOT NULL, label VARCHAR2(100) NOT NULL, n_trades NUMBER, win_rate NUMBER(6,3), avg_pnl NUMBER(10,4), total_pnl NUMBER(12,4), rr NUMBER(8,4), avg_win NUMBER(10,4), avg_loss NUMBER(10,4), max_dd NUMBER(10,4), fng_lo NUMBER, fng_hi NUMBER, created_at TIMESTAMP DEFAULT SYSTIMESTAMP NOT NULL, CONSTRAINT fk_br_run FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ) """ _DDL_TRADES = """ CREATE TABLE backtest_trade_log ( id NUMBER GENERATED ALWAYS AS IDENTITY PRIMARY KEY, run_id VARCHAR2(36) NOT NULL, label VARCHAR2(100), ticker VARCHAR2(20), pnl NUMBER(10,4), hold_h NUMBER, fng_val NUMBER, exit_type VARCHAR2(10), created_at TIMESTAMP DEFAULT SYSTIMESTAMP NOT NULL, CONSTRAINT fk_bt_run FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ) """ def ensure_tables() -> None: """백테스트 테이블이 없으면 생성.""" with _conn() as conn: cur = conn.cursor() for tbl_name, ddl in [ ("BACKTEST_RUNS", _DDL_RUNS), ("BACKTEST_RESULTS", _DDL_RESULTS), ("BACKTEST_TRADE_LOG", _DDL_TRADES), ]: cur.execute( "SELECT COUNT(*) FROM user_tables WHERE table_name=:1", [tbl_name] ) if cur.fetchone()[0] == 0: cur.execute(ddl) print(f" {tbl_name} 테이블 생성 완료") # ── 삽입 헬퍼 ────────────────────────────────────────────────── def insert_run(run_name: str, description: str = "", params: dict | None = None) -> str: """새 백테스트 실행 레코드 삽입. run_id 반환.""" sql = """ INSERT INTO backtest_runs (run_name, description, params_json) VALUES (:rname, :rdesc, :rparams) RETURNING run_id INTO :out_id """ with _conn() as conn: cur = conn.cursor() out = cur.var(oracledb.STRING) cur.execute(sql, { "rname": run_name, "rdesc": description, "rparams": json.dumps(params or {}, ensure_ascii=False), "out_id": out, }) return out.getvalue()[0] def insert_result( run_id: str, label: str, stats: dict, fng_lo: int | None = None, fng_hi: int | None = None, ) -> None: """조건별 집계 결과 삽입.""" sql = """ INSERT INTO backtest_results (run_id, label, n_trades, win_rate, avg_pnl, total_pnl, rr, avg_win, avg_loss, max_dd, fng_lo, fng_hi) VALUES (:run_id, :label, :n, :wr, :avg_pnl, :total_pnl, :rr, :avg_win, :avg_loss, :max_dd, :fng_lo, :fng_hi) """ with _conn() as conn: conn.cursor().execute(sql, { "run_id": run_id, "label": label, "n": stats.get("n", 0), "wr": round(stats.get("wr", 0), 3), "avg_pnl": round(stats.get("avg_pnl", 0), 4), "total_pnl": round(stats.get("total_pnl", 0), 4), "rr": round(stats.get("rr", 0), 4), "avg_win": round(stats.get("avg_win", 0), 4), "avg_loss": round(stats.get("avg_loss", 0), 4), "max_dd": round(stats.get("max_dd", 0), 4), "fng_lo": fng_lo, "fng_hi": fng_hi, }) def insert_trades_bulk( run_id: str, label: str, ticker: str, trades: list, ) -> None: """개별 거래 목록 일괄 삽입.""" if not trades: return sql = """ INSERT INTO backtest_trade_log (run_id, label, ticker, pnl, hold_h, fng_val, exit_type) VALUES (:run_id, :label, :ticker, :pnl, :hold_h, :fng_val, :exit_type) """ rows = [] for t in trades: rows.append({ "run_id": run_id, "label": label, "ticker": ticker, "pnl": round(float(getattr(t, "pnl", 0)), 4), "hold_h": int(getattr(t, "h", 0)), "fng_val": int(getattr(t, "fng", 0)), "exit_type": str(getattr(t, "exit", "")), }) with _conn() as conn: conn.cursor().executemany(sql, rows) # ── 조회 ─────────────────────────────────────────────────────── def list_runs(limit: int = 20) -> list[dict]: """최근 백테스트 실행 목록 반환.""" sql = """ SELECT run_id, run_name, description, created_at FROM backtest_runs ORDER BY created_at DESC FETCH FIRST :n ROWS ONLY """ with _conn() as conn: cur = conn.cursor() cur.execute(sql, {"n": limit}) rows = cur.fetchall() return [ {"run_id": r[0], "run_name": r[1], "description": r[2], "created_at": r[3].strftime("%Y-%m-%d %H:%M")} for r in rows ] def get_results(run_id: str) -> list[dict]: """특정 run_id의 조건별 결과 반환.""" sql = """ SELECT label, n_trades, win_rate, avg_pnl, total_pnl, rr, avg_win, avg_loss, max_dd, fng_lo, fng_hi FROM backtest_results WHERE run_id = :run_id ORDER BY avg_pnl DESC """ with _conn() as conn: cur = conn.cursor() cur.execute(sql, {"run_id": run_id}) cols = ["label", "n_trades", "win_rate", "avg_pnl", "total_pnl", "rr", "avg_win", "avg_loss", "max_dd", "fng_lo", "fng_hi"] return [dict(zip(cols, r)) for r in cur.fetchall()] if __name__ == "__main__": print("백테스트 DB 테이블 확인/생성...") ensure_tables() print("완료. 최근 실행 목록:") for r in list_runs(5): print(f" {r['created_at']} {r['run_name']}")