"""Vector embedding generation and storage for restaurant semantic search.""" from __future__ import annotations import array import json import logging import os import oci import oracledb from oci.generative_ai_inference import GenerativeAiInferenceClient from oci.generative_ai_inference.models import ( EmbedTextDetails, OnDemandServingMode, ) from core.db import conn logger = logging.getLogger(__name__) _EMBED_BATCH_SIZE = 96 # Cohere embed v4 max batch size def _embed_texts(texts: list[str]) -> list[list[float]]: config = oci.config.from_file() client = GenerativeAiInferenceClient( config, service_endpoint=os.environ["OCI_GENAI_ENDPOINT"], ) model_id = os.environ.get("OCI_EMBED_MODEL_ID", "cohere.embed-v4.0") compartment_id = os.environ["OCI_COMPARTMENT_ID"] details = EmbedTextDetails( inputs=texts, serving_mode=OnDemandServingMode(model_id=model_id), compartment_id=compartment_id, input_type="SEARCH_DOCUMENT", ) response = client.embed_text(details) return response.data.embeddings def _embed_texts_batched(texts: list[str]) -> list[list[float]]: """Embed texts in batches to respect API limits.""" all_embeddings: list[list[float]] = [] for i in range(0, len(texts), _EMBED_BATCH_SIZE): batch = texts[i : i + _EMBED_BATCH_SIZE] all_embeddings.extend(_embed_texts(batch)) return all_embeddings def _to_vec(embedding: list[float]) -> array.array: return array.array("f", embedding) def _parse_json_field(val, default): if val is None: return default if isinstance(val, (list, dict)): return val if hasattr(val, "read"): val = val.read() if isinstance(val, str): try: return json.loads(val) except (json.JSONDecodeError, ValueError): return default return default def _build_rich_chunk(rest: dict, video_links: list[dict]) -> str: """Build a single JSON chunk per restaurant with all relevant info.""" # Collect all foods, evaluations, video titles from linked videos all_foods: list[str] = [] all_evaluations: list[str] = [] video_titles: list[str] = [] channel_names: set[str] = set() for vl in video_links: if vl.get("title"): video_titles.append(vl["title"]) if vl.get("channel_name"): channel_names.add(vl["channel_name"]) foods = _parse_json_field(vl.get("foods_mentioned"), []) if foods: all_foods.extend(foods) ev = _parse_json_field(vl.get("evaluation"), {}) if isinstance(ev, dict) and ev.get("text"): all_evaluations.append(ev["text"]) elif isinstance(ev, str) and ev: all_evaluations.append(ev) doc = { "name": rest.get("name"), "cuisine_type": rest.get("cuisine_type"), "region": rest.get("region"), "address": rest.get("address"), "price_range": rest.get("price_range"), "menu": list(dict.fromkeys(all_foods)), # deduplicate, preserve order "summary": all_evaluations, "video_titles": video_titles, "channels": sorted(channel_names), } # Remove None/empty values doc = {k: v for k, v in doc.items() if v} return json.dumps(doc, ensure_ascii=False) def rebuild_all_vectors(): """Rebuild vector embeddings for ALL restaurants. Yields progress dicts: {"status": "progress", "current": N, "total": M, "name": "..."} Final yield: {"status": "done", "total": N} """ # 1. Get all restaurants with video links sql_restaurants = """ SELECT DISTINCT r.id, r.name, r.address, r.region, r.cuisine_type, r.price_range FROM restaurants r JOIN video_restaurants vr ON vr.restaurant_id = r.id WHERE r.latitude IS NOT NULL ORDER BY r.name """ sql_video_links = """ SELECT v.title, vr.foods_mentioned, vr.evaluation, c.channel_name FROM video_restaurants vr JOIN videos v ON v.id = vr.video_id JOIN channels c ON c.id = v.channel_id WHERE vr.restaurant_id = :rid """ # Load all restaurant data restaurants_data: list[tuple[dict, str]] = [] # (rest_dict, chunk_text) with conn() as c: cur = c.cursor() cur.execute(sql_restaurants) cols = [d[0].lower() for d in cur.description] all_rests = [dict(zip(cols, row)) for row in cur.fetchall()] total = len(all_rests) logger.info("Rebuilding vectors for %d restaurants", total) for i, rest in enumerate(all_rests): with conn() as c: cur = c.cursor() cur.execute(sql_video_links, {"rid": rest["id"]}) vl_cols = [d[0].lower() for d in cur.description] video_links = [dict(zip(vl_cols, row)) for row in cur.fetchall()] chunk = _build_rich_chunk(rest, video_links) restaurants_data.append((rest, chunk)) yield {"status": "progress", "current": i + 1, "total": total, "phase": "prepare", "name": rest["name"]} # 2. Delete all existing vectors with conn() as c: c.cursor().execute("DELETE FROM restaurant_vectors") logger.info("Cleared existing vectors") yield {"status": "progress", "current": 0, "total": total, "phase": "embed"} # 3. Embed in batches and insert chunks = [chunk for _, chunk in restaurants_data] rest_ids = [rest["id"] for rest, _ in restaurants_data] embeddings = _embed_texts_batched(chunks) logger.info("Generated %d embeddings", len(embeddings)) insert_sql = """ INSERT INTO restaurant_vectors (restaurant_id, chunk_text, embedding) VALUES (:rid, :chunk, :emb) """ with conn() as c: cur = c.cursor() for i, (rid, chunk, emb) in enumerate(zip(rest_ids, chunks, embeddings)): cur.execute(insert_sql, { "rid": rid, "chunk": chunk, "emb": _to_vec(emb), }) if (i + 1) % 50 == 0 or i + 1 == total: yield {"status": "progress", "current": i + 1, "total": total, "phase": "insert"} logger.info("Rebuilt vectors for %d restaurants", total) yield {"status": "done", "total": total} def save_restaurant_vectors(restaurant_id: str, chunks: list[str]) -> list[str]: """Embed and store text chunks for a restaurant. Returns list of inserted row IDs. """ if not chunks: return [] embeddings = _embed_texts(chunks) inserted: list[str] = [] sql = """ INSERT INTO restaurant_vectors (restaurant_id, chunk_text, embedding) VALUES (:rid, :chunk, :emb) RETURNING id INTO :out_id """ with conn() as c: cur = c.cursor() for chunk, emb in zip(chunks, embeddings): out_id = cur.var(oracledb.STRING) cur.execute(sql, { "rid": restaurant_id, "chunk": chunk, "emb": _to_vec(emb), "out_id": out_id, }) inserted.append(out_id.getvalue()[0]) return inserted def search_similar(query: str, top_k: int = 10, max_distance: float = 0.57) -> list[dict]: """Semantic search: find restaurants similar to query text. Returns list of dicts: restaurant_id, chunk_text, distance. Only results with cosine distance <= max_distance are returned. """ embeddings = _embed_texts([query]) query_vec = _to_vec(embeddings[0]) sql = """ SELECT rv.restaurant_id, rv.chunk_text, VECTOR_DISTANCE(rv.embedding, :qvec, COSINE) AS dist FROM restaurant_vectors rv WHERE VECTOR_DISTANCE(rv.embedding, :qvec2, COSINE) <= :max_dist ORDER BY dist FETCH FIRST :k ROWS ONLY """ with conn() as c: cur = c.cursor() cur.execute(sql, {"qvec": query_vec, "qvec2": query_vec, "k": top_k, "max_dist": max_distance}) return [ { "restaurant_id": r[0], "chunk_text": r[1].read() if hasattr(r[1], "read") else r[1], "distance": r[2], } for r in cur.fetchall() ]