Use Optional[T] + from __future__ import annotations instead of T | None syntax which requires Python 3.10+. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
117 lines
3.3 KiB
Python
117 lines
3.3 KiB
Python
"""Embedding generation and Oracle vector store insertion."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import array
|
|
import os
|
|
from contextlib import contextmanager
|
|
from typing import Generator, Optional
|
|
|
|
import oci
|
|
import oracledb
|
|
from oci.generative_ai_inference import GenerativeAiInferenceClient
|
|
from oci.generative_ai_inference.models import (
|
|
EmbedTextDetails,
|
|
OnDemandServingMode,
|
|
)
|
|
|
|
# Module-level pool owned by this module, created lazily by _get_pool().
# NOTE(review): presumably targets the same ADB instance as queue_db — confirm.
|
|
_pool: Optional[oracledb.ConnectionPool] = None
|
|
|
|
|
|
def _get_pool() -> oracledb.ConnectionPool:
    """Return the shared connection pool, building it on first use.

    Connection settings are read from the ORACLE_USER, ORACLE_PASSWORD and
    ORACLE_DSN environment variables (all required; a missing one raises
    KeyError). When ORACLE_WALLET is set to a non-empty value it is passed
    to the driver as ``config_dir``.
    """
    global _pool

    # Fast path: the pool was already created by an earlier call.
    if _pool is not None:
        return _pool

    pool_params: dict = {
        "user": os.environ["ORACLE_USER"],
        "password": os.environ["ORACLE_PASSWORD"],
        "dsn": os.environ["ORACLE_DSN"],
        "min": 1,
        "max": 5,
        "increment": 1,
    }

    # The wallet location is optional; only forward it when configured.
    wallet_dir = os.environ.get("ORACLE_WALLET")
    if wallet_dir:
        pool_params["config_dir"] = wallet_dir

    _pool = oracledb.create_pool(**pool_params)
    return _pool
|
|
|
|
|
|
@contextmanager
def _conn() -> Generator[oracledb.Connection, None, None]:
    """Yield a pooled connection wrapped in a single transaction.

    The connection is committed when the ``with`` body finishes normally.
    If the body (or the commit itself) raises an ``Exception``, the
    transaction is rolled back and the exception is re-raised. In every
    case the connection is handed back to the pool afterwards.
    """
    connection_pool = _get_pool()
    connection = connection_pool.acquire()
    try:
        try:
            yield connection
            # Reached only when the caller's block completed without raising.
            connection.commit()
        except Exception:
            connection.rollback()
            raise
    finally:
        # Always return the connection, whatever happened above.
        connection_pool.release(connection)
|
|
|
|
|
|
def _to_vector_param(embedding: list[float]) -> array.array:
|
|
return array.array("f", embedding)
|
|
|
|
|
|
def _embed_texts(texts: list[str]) -> list[list[float]]:
|
|
"""Generate embeddings using Cohere Embed v4 via OCI GenAI."""
|
|
config = oci.config.from_file()
|
|
client = GenerativeAiInferenceClient(
|
|
config,
|
|
service_endpoint=os.environ["OCI_GENAI_ENDPOINT"],
|
|
)
|
|
model_id = os.environ.get("OCI_EMBED_MODEL_ID", "cohere.embed-v4.0")
|
|
compartment_id = os.environ["OCI_COMPARTMENT_ID"]
|
|
|
|
details = EmbedTextDetails(
|
|
inputs=texts,
|
|
serving_mode=OnDemandServingMode(model_id=model_id),
|
|
compartment_id=compartment_id,
|
|
input_type="SEARCH_DOCUMENT",
|
|
)
|
|
response = client.embed_text(details)
|
|
return response.data.embeddings
|
|
|
|
|
|
def save_to_vector(doc_id: str, chunks: list[str]) -> list[str]:
    """Embed chunks and insert them into the Oracle vector store.

    All rows are inserted on one pooled connection inside a single
    transaction (see ``_conn``): either every chunk is stored or none is.

    Args:
        doc_id: Document identifier (e.g. 'youtube:abc12345').
        chunks: List of text chunks to embed and store.

    Returns:
        List of inserted row UUIDs, in the same order as ``chunks``.
        Empty when ``chunks`` is empty (no embedding or database call is made).

    Raises:
        RuntimeError: If the embedding service returns a different number
            of embeddings than chunks were submitted.
    """
    if not chunks:
        return []

    embeddings = _embed_texts(chunks)
    # zip() would silently drop the unmatched tail, storing only part of the
    # document while reporting success; fail before touching the database.
    if len(embeddings) != len(chunks):
        raise RuntimeError(
            f"Embedding count mismatch: got {len(embeddings)} embeddings "
            f"for {len(chunks)} chunks"
        )

    sql = """
        INSERT INTO vector_store (doc_id, chunk_text, embedding)
        VALUES (:doc_id, :chunk_text, :embedding)
        RETURNING id INTO :out_id
    """
    inserted_ids: list[str] = []

    with _conn() as conn:
        # Close the cursor deterministically instead of leaving it to GC.
        with conn.cursor() as cursor:
            for chunk, embedding in zip(chunks, embeddings):
                # Fresh output bind per row to receive the generated id.
                out_id_var = cursor.var(oracledb.STRING)
                cursor.execute(
                    sql,
                    {
                        "doc_id": doc_id,
                        "chunk_text": chunk,
                        "embedding": _to_vector_param(embedding),
                        "out_id": out_id_var,
                    },
                )
                # For DML RETURNING, getvalue() yields a list of returned
                # values; a single-row INSERT has exactly one entry.
                inserted_ids.append(out_id_var.getvalue()[0])

    return inserted_ids
|