"""Embedding generation and Oracle vector store insertion."""

import array
import os
from contextlib import contextmanager
from typing import Generator

import oci
import oracledb
from oci.generative_ai_inference import GenerativeAiInferenceClient
from oci.generative_ai_inference.models import (
    EmbedTextDetails,
    OnDemandServingMode,
)

# Module-level connection pool, created lazily by _get_pool().
# NOTE(review): the earlier comment said this "reuses the same pool as
# queue_db", but this module builds its own pool; presumably both modules
# simply target the same ADB instance -- confirm against queue_db.
_pool: oracledb.ConnectionPool | None = None

# Maximum number of texts sent in a single embed request. OCI Generative AI
# documents a cap of 96 inputs per embedding call, so larger lists are split.
_EMBED_BATCH_SIZE = 96


def _get_pool() -> oracledb.ConnectionPool:
    """Return (or lazily create) the module-level connection pool.

    Connection settings are read from the ORACLE_USER, ORACLE_PASSWORD and
    ORACLE_DSN environment variables. If ORACLE_WALLET is set, it is passed
    to the driver as ``config_dir``.

    Raises:
        KeyError: If a required environment variable is missing.
    """
    global _pool
    if _pool is None:
        kwargs: dict[str, object] = dict(
            user=os.environ["ORACLE_USER"],
            password=os.environ["ORACLE_PASSWORD"],
            dsn=os.environ["ORACLE_DSN"],
            min=1,
            max=5,
            increment=1,
        )
        wallet = os.environ.get("ORACLE_WALLET")
        if wallet:
            # NOTE(review): only config_dir is set here; thin-mode mTLS
            # wallets may also need wallet_location/wallet_password -- confirm.
            kwargs["config_dir"] = wallet
        _pool = oracledb.create_pool(**kwargs)
    return _pool


@contextmanager
def _conn() -> Generator[oracledb.Connection, None, None]:
    """Context manager that acquires and releases a pooled connection.

    Commits when the managed block finishes normally, rolls back and
    re-raises when it raises an ``Exception``, and always returns the
    connection to the pool.
    """
    pool = _get_pool()
    conn = pool.acquire()
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        pool.release(conn)


def _to_vector_param(embedding: list[float]) -> array.array:
    """Convert an embedding into a float32 ``array.array`` for binding."""
    return array.array("f", embedding)


def _embed_texts(texts: list[str]) -> list[list[float]]:
    """Generate embeddings using Cohere Embed v4 via OCI GenAI.

    The texts are sent in batches of at most ``_EMBED_BATCH_SIZE`` so that
    long documents do not exceed the per-request input limit. Results are
    concatenated in input order.

    Args:
        texts: Texts to embed.

    Returns:
        One embedding per input text, in the same order. An empty input
        yields an empty list without calling the service.

    Raises:
        KeyError: If OCI_GENAI_ENDPOINT or OCI_COMPARTMENT_ID is not set.
    """
    config = oci.config.from_file()
    client = GenerativeAiInferenceClient(
        config,
        service_endpoint=os.environ["OCI_GENAI_ENDPOINT"],
    )
    model_id = os.environ.get("OCI_EMBED_MODEL_ID", "cohere.embed-v4.0")
    compartment_id = os.environ["OCI_COMPARTMENT_ID"]
    serving_mode = OnDemandServingMode(model_id=model_id)

    embeddings: list[list[float]] = []
    for start in range(0, len(texts), _EMBED_BATCH_SIZE):
        details = EmbedTextDetails(
            inputs=texts[start:start + _EMBED_BATCH_SIZE],
            serving_mode=serving_mode,
            compartment_id=compartment_id,
            input_type="SEARCH_DOCUMENT",
        )
        response = client.embed_text(details)
        embeddings.extend(response.data.embeddings)
    return embeddings


def save_to_vector(doc_id: str, chunks: list[str]) -> list[str]:
    """Embed chunks and insert them into the Oracle vector store.

    All rows are inserted in a single transaction: either every chunk is
    stored or none is.

    Args:
        doc_id: Document identifier (e.g. 'youtube:abc12345').
        chunks: List of text chunks to embed and store.

    Returns:
        List of inserted row UUIDs, in chunk order. Empty if ``chunks`` is
        empty.

    Raises:
        ValueError: If the embedding service returns a different number of
            embeddings than chunks were submitted.
    """
    if not chunks:
        return []

    embeddings = _embed_texts(chunks)
    # Fail loudly before touching the database: a plain zip() would silently
    # drop the chunks that have no matching embedding.
    if len(embeddings) != len(chunks):
        raise ValueError(
            f"Expected {len(chunks)} embeddings, got {len(embeddings)}"
        )

    sql = """
        INSERT INTO vector_store (doc_id, chunk_text, embedding)
        VALUES (:doc_id, :chunk_text, :embedding)
        RETURNING id INTO :out_id
    """
    inserted_ids: list[str] = []
    with _conn() as conn:
        with conn.cursor() as cursor:
            for chunk, embedding in zip(chunks, embeddings):
                out_id_var = cursor.var(oracledb.STRING)
                cursor.execute(
                    sql,
                    {
                        "doc_id": doc_id,
                        "chunk_text": chunk,
                        "embedding": _to_vector_param(embedding),
                        "out_id": out_id_var,
                    },
                )
                # For DML RETURNING the bind variable holds a list of
                # values, one per affected row; a single-row insert has one.
                inserted_ids.append(out_id_var.getvalue()[0])
    return inserted_ids