diff --git a/cli/utils/env_loader.py b/cli/utils/env_loader.py
index 768fdce..6d03c3f 100644
--- a/cli/utils/env_loader.py
+++ b/cli/utils/env_loader.py
@@ -71,14 +71,14 @@ def set_vectordb(
"""VectorDB 타입과 위치를 설정합니다.
Args:
- vectordb_type (str): VectorDB 타입 ("faiss" 또는 "pgvector").
+ vectordb_type (str): VectorDB 타입 ("faiss" 또는 "pgvector" 또는 "qdrant").
vectordb_location (Optional[str]): 경로 또는 연결 URL.
Raises:
ValueError: 잘못된 타입이나 경로/URL일 경우.
"""
- if vectordb_type not in ("faiss", "pgvector"):
+ if vectordb_type not in ("faiss", "pgvector", "qdrant"):
raise ValueError(f"지원하지 않는 VectorDB 타입: {vectordb_type}")
os.environ["VECTORDB_TYPE"] = vectordb_type
diff --git a/docker/docker-compose-pgvector.yml b/docker/docker-compose-pgvector.yml
index 8ad5e16..443baf9 100644
--- a/docker/docker-compose-pgvector.yml
+++ b/docker/docker-compose-pgvector.yml
@@ -1,7 +1,13 @@
-# docker compose -f docker-compose-pgvector.yml up
-# docker compose -f docker-compose-pgvector.yml down
+# docker compose -f docker/docker-compose.yml -f docker/docker-compose-pgvector.yml up
+# docker compose -f docker/docker-compose.yml -f docker/docker-compose-pgvector.yml down
services:
+ streamlit:
+ environment:
+ - DATABASE_URL=postgresql://pgvector:pgvector@pgvector:5432/streamlit
+ depends_on:
+ - pgvector
+
pgvector:
image: pgvector/pgvector:pg17
hostname: pgvector
@@ -12,7 +18,7 @@ services:
environment:
POSTGRES_USER: pgvector
POSTGRES_PASSWORD: pgvector
- POSTGRES_DB: pgvector
+ POSTGRES_DB: streamlit
TZ: Asia/Seoul
LANG: en_US.utf8
volumes:
diff --git a/docker/docker-compose-postgres.yml b/docker/docker-compose-postgres.yml
index 696f7e1..b8b4903 100644
--- a/docker/docker-compose-postgres.yml
+++ b/docker/docker-compose-postgres.yml
@@ -1,7 +1,13 @@
-# docker compose -f docker-compose-postgres.yml up
-# docker compose -f docker-compose-postgres.yml down
+# docker compose -f docker/docker-compose.yml -f docker/docker-compose-postgres.yml up
+# docker compose -f docker/docker-compose.yml -f docker/docker-compose-postgres.yml down
services:
+ streamlit:
+ environment:
+ - DATABASE_URL=postgresql://postgres:postgres@postgres:5432/streamlit
+ depends_on:
+ - postgres
+
postgres:
image: postgres:15
hostname: postgres
@@ -12,7 +18,7 @@ services:
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
- POSTGRES_DB: postgres
+ POSTGRES_DB: streamlit
TZ: Asia/Seoul
LANG: en_US.utf8
volumes:
diff --git a/docker/docker-compose-qdrant.yml b/docker/docker-compose-qdrant.yml
new file mode 100644
index 0000000..98caca5
--- /dev/null
+++ b/docker/docker-compose-qdrant.yml
@@ -0,0 +1,23 @@
+# docker compose -f docker/docker-compose.yml -f docker/docker-compose-qdrant.yml up
+# docker compose -f docker/docker-compose.yml -f docker/docker-compose-qdrant.yml down
+
+services:
+ streamlit:
+ environment:
+ - QDRANT_HOST=qdrant
+ - QDRANT_PORT=6333
+ depends_on:
+ - qdrant
+
+ qdrant:
+ image: qdrant/qdrant:latest
+ hostname: qdrant
+ container_name: qdrant
+ restart: always
+ ports:
+ - "6333:6333"
+ volumes:
+ - qdrant_data:/qdrant/storage
+
+volumes:
+ qdrant_data:
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 115575a..66a83ce 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -13,22 +13,4 @@ services:
- ../.env
environment:
- STREAMLIT_SERVER_PORT=8501
- - DATABASE_URL=postgresql://pgvector:pgvector@localhost:5432/streamlit
- depends_on:
- - pgvector
- pgvector:
- image: pgvector/pgvector:pg17
- hostname: pgvector
- container_name: pgvector
- environment:
- POSTGRES_USER: pgvector
- POSTGRES_PASSWORD: pgvector
- POSTGRES_DB: streamlit
- ports:
- - "5432:5432"
- volumes:
- - pgdata:/var/lib/postgresql/data
-
-volumes:
- pgdata:
diff --git a/interface/app_pages/chatbot.py b/interface/app_pages/chatbot.py
index 9879147..17bffab 100644
--- a/interface/app_pages/chatbot.py
+++ b/interface/app_pages/chatbot.py
@@ -16,6 +16,8 @@
)
from interface.core.config import load_config
+from interface.app_pages.chatbot_components import render_context_cards
+
def initialize_session_state():
"""세션 상태 초기화 함수
@@ -123,6 +125,9 @@ def initialize_session_state():
for message in st.session_state.chatbot_messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
+ # 컨텍스트 정보가 있으면 렌더링
+ if "context" in message:
+ render_context_cards(message["context"])
# 사용자 입력 처리
if prompt := st.chat_input("메시지를 입력하세요"):
@@ -147,10 +152,28 @@ def initialize_session_state():
# 응답 표시
st.markdown(response_content)
+ # 컨텍스트 데이터 추출
+ context_data = {}
+ if response.get("table_schema_outputs"):
+ context_data["table_schema_outputs"] = response["table_schema_outputs"]
+ if response.get("glossary_outputs"):
+ context_data["glossary_outputs"] = response["glossary_outputs"]
+ if response.get("query_example_outputs"):
+ context_data["query_example_outputs"] = response[
+ "query_example_outputs"
+ ]
+
+ # 컨텍스트 카드가 있으면 렌더링
+ if context_data:
+ render_context_cards(context_data)
+
# AI 응답을 기록에 추가
- st.session_state.chatbot_messages.append(
- {"role": "assistant", "content": response_content}
- )
+ msg_data = {"role": "assistant", "content": response_content}
+ if context_data:
+ msg_data["context"] = context_data
+
+ st.session_state.chatbot_messages.append(msg_data)
+
except Exception as e:
error_msg = f"오류가 발생했습니다: {str(e)}"
st.error(error_msg)
diff --git a/interface/app_pages/chatbot_components/__init__.py b/interface/app_pages/chatbot_components/__init__.py
new file mode 100644
index 0000000..01fe149
--- /dev/null
+++ b/interface/app_pages/chatbot_components/__init__.py
@@ -0,0 +1,3 @@
+from .context_cards import render_context_cards
+
+__all__ = ["render_context_cards"]
diff --git a/interface/app_pages/chatbot_components/context_cards.py b/interface/app_pages/chatbot_components/context_cards.py
new file mode 100644
index 0000000..25f980e
--- /dev/null
+++ b/interface/app_pages/chatbot_components/context_cards.py
@@ -0,0 +1,114 @@
+import streamlit as st
+import json
+
+
+def render_context_cards(context):
+ """
+ 검색된 컨텍스트(테이블 스키마, 용어집, 쿼리 예제)를 가로 스크롤 가능한 카드 형태로 렌더링합니다.
+ """
+ if not context:
+ return
+
+ cards_html = ""
+
+ # 카드 HTML 생성 헬퍼 함수
+ def create_card(title, icon, content_html, color_border):
+ return f"""
+
+ {icon} {title}
+
+
+ {content_html}
+
+
"""
+
+ # 1. 테이블 스키마 처리
+ if "table_schema_outputs" in context:
+ for output in context["table_schema_outputs"]:
+ if isinstance(output, dict):
+ # 에러 처리
+ if output.get("error"):
+ cards_html += create_card(
+ "Error", "⚠️", f"오류: {output.get('message')}", "#F44336"
+ )
+ continue
+
+ # 테이블별로 카드 생성
+ for table_name, table_info in output.items():
+ if not isinstance(table_info, dict):
+ continue
+
+ desc = table_info.get("table_description", "설명 없음")
+ columns_html = "컬럼:
"
+ columns_html += ""
+
+ for col_name, col_desc in table_info.items():
+ if col_name == "table_description":
+ continue
+ columns_html += f"- {col_name}: {col_desc}
"
+ columns_html += "
"
+
+ content = f"{table_name}
{desc}
{columns_html}"
+ cards_html += create_card("Table Schema", "🗃️", content, "#4CAF50")
+ else:
+ # 문자열이나 기타 타입인 경우
+ cards_html += create_card("Table Schema", "🗃️", str(output), "#4CAF50")
+
+ # 2. 용어집 처리
+ if "glossary_outputs" in context:
+ for output in context["glossary_outputs"]:
+ if isinstance(output, list):
+ for term in output:
+ if isinstance(term, dict):
+ name = term.get("name", "이름 없음")
+ desc = term.get("description", "설명 없음")
+ content = f"{name}
{desc}
"
+ cards_html += create_card("Glossary", "📚", content, "#2196F3")
+ else:
+ cards_html += create_card(
+ "Glossary", "📚", str(term), "#2196F3"
+ )
+ elif isinstance(output, dict) and output.get("error"):
+ cards_html += create_card(
+ "Error", "⚠️", f"오류: {output.get('message')}", "#F44336"
+ )
+ else:
+ cards_html += create_card("Glossary", "📚", str(output), "#2196F3")
+
+ # 3. 쿼리 예제 처리
+ if "query_example_outputs" in context:
+ for output in context["query_example_outputs"]:
+ if isinstance(output, list):
+ for example in output:
+ if isinstance(example, dict):
+ name = example.get("name", "예제")
+ desc = example.get("description", "")
+ sql = example.get("statement", "")
+
+ content = f"{name}
"
+ if desc:
+ content += f"{desc}
"
+ content += f"{sql}
"
+
+ cards_html += create_card(
+ "Query Example", "💡", content, "#FF9800"
+ )
+ else:
+ cards_html += create_card(
+ "Query Example", "💡", str(example), "#FF9800"
+ )
+ elif isinstance(output, dict) and output.get("error"):
+ cards_html += create_card(
+ "Error", "⚠️", f"오류: {output.get('message')}", "#F44336"
+ )
+ else:
+ cards_html += create_card("Query Example", "💡", str(output), "#FF9800")
+
+ if not cards_html:
+ return
+
+ # 가로 스크롤 컨테이너 렌더링
+ st.markdown(
+ f"""{cards_html}
""",
+ unsafe_allow_html=True,
+ )
diff --git a/interface/app_pages/settings_sections/data_source_section.py b/interface/app_pages/settings_sections/data_source_section.py
index 2f63881..e1b3383 100644
--- a/interface/app_pages/settings_sections/data_source_section.py
+++ b/interface/app_pages/settings_sections/data_source_section.py
@@ -103,10 +103,36 @@ def render_data_source_section(config: Config | None = None) -> None:
new_url = st.text_input(
"URL", value=existing.url, key="dh_edit_url"
)
- new_faiss = st.text_input(
- "FAISS 저장 경로(선택)",
- value=existing.faiss_path or "",
- key="dh_edit_faiss",
+ new_vdb_type = st.selectbox(
+ "VectorDB 타입",
+ options=["faiss", "pgvector", "qdrant"],
+ index=(
+ 0
+ if existing.vectordb_type == "faiss"
+ else (1 if existing.vectordb_type == "pgvector" else 2)
+ ),
+ key="dh_edit_vdb_type",
+ )
+ new_vdb_loc_placeholder = (
+ "FAISS 디렉토리 경로 (예: ./dev/table_info_db)"
+ if new_vdb_type == "faiss"
+ else (
+ "pgvector 연결 문자열 (postgresql://...)"
+ if new_vdb_type == "pgvector"
+ else "Qdrant URL (예: http://localhost:6333)"
+ )
+ )
+ new_vdb_location = st.text_input(
+ "VectorDB 위치",
+ value=existing.vectordb_location or existing.faiss_path or "",
+ key="dh_edit_vdb_loc",
+ placeholder=new_vdb_loc_placeholder,
+ )
+ new_vdb_api_key = st.text_input(
+ "VectorDB API Key (선택)",
+ value=existing.vectordb_api_key or "",
+ type="password",
+ key="dh_edit_vdb_key",
)
new_note = st.text_input(
"메모", value=existing.note or "", key="dh_edit_note"
@@ -128,7 +154,14 @@ def render_data_source_section(config: Config | None = None) -> None:
update_datahub_source(
name=edit_dh,
url=new_url,
- faiss_path=(new_faiss or None),
+ faiss_path=(
+ new_vdb_location
+ if new_vdb_type == "faiss"
+ else None
+ ),
+ vectordb_type=new_vdb_type,
+ vectordb_location=(new_vdb_location or None),
+ vectordb_api_key=(new_vdb_api_key or None),
note=(new_note or None),
)
st.success("저장되었습니다.")
@@ -147,10 +180,29 @@ def render_data_source_section(config: Config | None = None) -> None:
dh_url = st.text_input(
"URL", key="dh_url", placeholder="http://localhost:8080"
)
- dh_faiss = st.text_input(
- "FAISS 저장 경로(선택)",
- key="dh_faiss",
- placeholder="예: ./dev/table_info_db",
+ dh_vdb_type = st.selectbox(
+ "VectorDB 타입",
+ options=["faiss", "pgvector", "qdrant"],
+ key="dh_new_vdb_type",
+ )
+ dh_vdb_loc_placeholder = (
+ "FAISS 디렉토리 경로 (예: ./dev/table_info_db)"
+ if dh_vdb_type == "faiss"
+ else (
+ "pgvector 연결 문자열 (postgresql://...)"
+ if dh_vdb_type == "pgvector"
+ else "Qdrant URL (예: http://localhost:6333)"
+ )
+ )
+ dh_vdb_location = st.text_input(
+ "VectorDB 위치",
+ key="dh_new_vdb_loc",
+ placeholder=dh_vdb_loc_placeholder,
+ )
+ dh_vdb_api_key = st.text_input(
+ "VectorDB API Key (선택)",
+ type="password",
+ key="dh_new_vdb_key",
)
dh_note = st.text_input("메모", key="dh_note", placeholder="선택")
@@ -174,7 +226,12 @@ def render_data_source_section(config: Config | None = None) -> None:
add_datahub_source(
name=dh_name,
url=dh_url,
- faiss_path=(dh_faiss or None),
+ faiss_path=(
+ dh_vdb_location if dh_vdb_type == "faiss" else None
+ ),
+ vectordb_type=dh_vdb_type,
+ vectordb_location=(dh_vdb_location or None),
+ vectordb_api_key=(dh_vdb_api_key or None),
note=dh_note or None,
)
st.success("추가되었습니다.")
@@ -216,14 +273,22 @@ def render_data_source_section(config: Config | None = None) -> None:
if existing:
new_type = st.selectbox(
"타입",
- options=["faiss", "pgvector"],
- index=(0 if existing.type == "faiss" else 1),
+ options=["faiss", "pgvector", "qdrant"],
+ index=(
+ 0
+ if existing.type == "faiss"
+ else (1 if existing.type == "pgvector" else 2)
+ ),
key="vdb_edit_type",
)
new_loc_placeholder = (
"FAISS 디렉토리 경로 (예: ./dev/table_info_db)"
if new_type == "faiss"
- else "pgvector 연결 문자열 (postgresql://user:pass@host:port/db)"
+ else (
+ "pgvector 연결 문자열 (postgresql://user:pass@host:port/db)"
+ if new_type == "pgvector"
+ else "Qdrant URL (예: http://localhost:6333)"
+ )
)
new_location = st.text_input(
"위치",
@@ -231,6 +296,12 @@ def render_data_source_section(config: Config | None = None) -> None:
key="vdb_edit_location",
placeholder=new_loc_placeholder,
)
+ new_api_key = st.text_input(
+ "API Key (선택)",
+ value=existing.api_key or "",
+ type="password",
+ key="vdb_edit_key",
+ )
new_prefix = st.text_input(
"컬렉션 접두사(선택)",
value=existing.collection_prefix or "",
@@ -258,6 +329,7 @@ def render_data_source_section(config: Config | None = None) -> None:
name=edit_vdb,
vtype=new_type,
location=new_location,
+ api_key=(new_api_key or None),
collection_prefix=(new_prefix or None),
note=(new_note or None),
)
@@ -275,16 +347,23 @@ def render_data_source_section(config: Config | None = None) -> None:
st.write("VectorDB 추가")
vdb_name = st.text_input("이름", key="vdb_name")
vdb_type = st.selectbox(
- "타입", options=["faiss", "pgvector"], key="vdb_type"
+ "타입", options=["faiss", "pgvector", "qdrant"], key="vdb_type"
)
vdb_loc_placeholder = (
"FAISS 디렉토리 경로 (예: ./dev/table_info_db)"
if vdb_type == "faiss"
- else "pgvector 연결 문자열 (postgresql://user:pass@host:port/db)"
+ else (
+ "pgvector 연결 문자열 (postgresql://user:pass@host:port/db)"
+ if vdb_type == "pgvector"
+ else "Qdrant URL (예: http://localhost:6333)"
+ )
)
vdb_location = st.text_input(
"위치", key="vdb_location", placeholder=vdb_loc_placeholder
)
+ vdb_api_key = st.text_input(
+ "API Key (선택)", type="password", key="vdb_new_key"
+ )
vdb_prefix = st.text_input(
"컬렉션 접두사(선택)", key="vdb_prefix", placeholder="예: app1_"
)
@@ -312,6 +391,7 @@ def render_data_source_section(config: Config | None = None) -> None:
name=vdb_name,
vtype=vdb_type,
location=vdb_location,
+ api_key=(vdb_api_key or None),
collection_prefix=(vdb_prefix or None),
note=(vdb_note or None),
)
diff --git a/interface/app_pages/settings_sections/db_section.py b/interface/app_pages/settings_sections/db_section.py
index 2ebeec5..b2236dd 100644
--- a/interface/app_pages/settings_sections/db_section.py
+++ b/interface/app_pages/settings_sections/db_section.py
@@ -168,14 +168,23 @@ def render_db_section() -> None:
default_secret = (existing.extra or {}).get(
k, ""
) or _prefill_from_env(new_type, k)
+
+ # 비밀키는 값을 미리 채우지 않음 (보안)
+ placeholder = (
+ "값 설정됨 (변경하려면 입력)" if default_secret else "값 없음"
+ )
sv = st.text_input(
label,
- value=str(default_secret or ""),
+ value="",
type="password",
key=f"db_edit_secret_{k}",
+ placeholder=placeholder,
)
- if sv != "":
+ # 입력값이 없으면 기존 값 유지, 있으면 새 값 사용
+ if sv:
secrets[k] = sv
+ elif default_secret:
+ secrets[k] = default_secret
cols = st.columns([1, 1, 2])
with cols[0]:
@@ -257,7 +266,7 @@ def render_db_section() -> None:
secrets_new: dict[str, str] = {}
for label, k in _secret_fields(db_type):
- sv = st.text_input(label, key=f"db_new_secret_{k}")
+ sv = st.text_input(label, key=f"db_new_secret_{k}", type="password")
if sv != "":
secrets_new[k] = sv
diff --git a/interface/app_pages/settings_sections/llm_section.py b/interface/app_pages/settings_sections/llm_section.py
index f5fe59f..3ee029b 100644
--- a/interface/app_pages/settings_sections/llm_section.py
+++ b/interface/app_pages/settings_sections/llm_section.py
@@ -139,20 +139,31 @@ def render_llm_section(config: Config | None = None) -> None:
values: dict[str, str | None] = {}
non_secret_values: dict[str, str | None] = {}
for label, env_key, is_secret in fields:
- prefill = st.session_state.get(env_key) or os.getenv(env_key) or ""
+ current_val = st.session_state.get(env_key) or os.getenv(env_key) or ""
if is_secret:
- values[env_key] = st.text_input(
- label, value=prefill, type="password", key=f"llm_{env_key}"
+ # 비밀키는 값을 미리 채우지 않음 (보안)
+ # 값이 설정되어 있으면 placeholder로 표시
+ placeholder = (
+ "값 설정됨 (변경하려면 입력)" if current_val else "값 없음"
)
+ new_val = st.text_input(
+ label,
+ value="",
+ type="password",
+ key=f"llm_{env_key}",
+ placeholder=placeholder,
+ )
+ # 입력값이 없으면 기존 값 유지, 있으면 새 값 사용
+ values[env_key] = new_val if new_val else current_val
else:
- v = st.text_input(label, value=prefill, key=f"llm_{env_key}")
+ v = st.text_input(label, value=current_val, key=f"llm_{env_key}")
values[env_key] = v
non_secret_values[env_key] = v
# 메시지 영역
llm_msg = st.empty()
- st.markdown("**프로파일 저장 (비밀키 제외)**")
+ st.markdown("**프로파일 저장**")
with st.form("llm_profile_save_form"):
prof_cols = st.columns([2, 2])
with prof_cols[0]:
@@ -185,10 +196,16 @@ def render_llm_section(config: Config | None = None) -> None:
with st.expander("저장된 LLM 프로파일", expanded=False):
for p in reg.profiles:
if p.fields:
- pairs = [
- f"{k}={p.fields.get(k, '')}"
- for k in sorted(p.fields.keys())
- ]
+ # 필드 정의 가져오기 (시크릿 여부 확인용)
+ fields_def = _llm_fields(p.provider)
+ secret_keys = {k for _, k, is_secret in fields_def if is_secret}
+
+ pairs = []
+ for k in sorted(p.fields.keys()):
+ val = p.fields.get(k, "")
+ if k in secret_keys:
+ val = "***"
+ pairs.append(f"{k}={val}")
fields_text = ", ".join(pairs)
else:
fields_text = "-"
@@ -217,20 +234,28 @@ def render_llm_section(config: Config | None = None) -> None:
e_fields = _embedding_fields(e_provider)
e_values: dict[str, str | None] = {}
for label, env_key, is_secret in e_fields:
- prefill = st.session_state.get(env_key) or os.getenv(env_key) or ""
+ current_val = st.session_state.get(env_key) or os.getenv(env_key) or ""
if is_secret:
- e_values[env_key] = st.text_input(
- label, value=prefill, type="password", key=f"emb_{env_key}"
+ placeholder = (
+ "값 설정됨 (변경하려면 입력)" if current_val else "값 없음"
)
+ new_val = st.text_input(
+ label,
+ value="",
+ type="password",
+ key=f"emb_{env_key}",
+ placeholder=placeholder,
+ )
+ e_values[env_key] = new_val if new_val else current_val
else:
e_values[env_key] = st.text_input(
- label, value=prefill, key=f"emb_{env_key}"
+ label, value=current_val, key=f"emb_{env_key}"
)
# 메시지 영역: 버튼 컬럼 밖(섹션 폭)
emb_msg = st.empty()
- st.markdown("**Embeddings 프로파일 저장 (시크릿 포함)**")
+ st.markdown("**Embeddings 프로파일 저장**")
with st.form("embedding_profile_save_form"):
e_prof_cols = st.columns([2, 2])
with e_prof_cols[0]:
@@ -263,10 +288,16 @@ def render_llm_section(config: Config | None = None) -> None:
with st.expander("저장된 Embeddings 프로파일", expanded=False):
for p in e_reg.profiles:
if p.fields:
- pairs = [
- f"{k}={p.fields.get(k, '')}"
- for k in sorted(p.fields.keys())
- ]
+ # 필드 정의 가져오기 (시크릿 여부 확인용)
+ fields_def = _embedding_fields(p.provider)
+ secret_keys = {k for _, k, is_secret in fields_def if is_secret}
+
+ pairs = []
+ for k in sorted(p.fields.keys()):
+ val = p.fields.get(k, "")
+ if k in secret_keys:
+ val = "***"
+ pairs.append(f"{k}={val}")
fields_text = ", ".join(pairs)
else:
fields_text = "-"
diff --git a/interface/app_pages/sidebar_components/data_source_selector.py b/interface/app_pages/sidebar_components/data_source_selector.py
index b32cf3f..fe54540 100644
--- a/interface/app_pages/sidebar_components/data_source_selector.py
+++ b/interface/app_pages/sidebar_components/data_source_selector.py
@@ -39,8 +39,18 @@ def render_sidebar_data_source_selector(config=None) -> None:
return
try:
update_datahub_server(config, selected.url)
- # DataHub 선택 시, FAISS 경로가 정의되어 있으면 기본 VectorDB 로케이션으로도 반영
- if selected.faiss_path:
+ # DataHub 선택 시, VectorDB 설정이 정의되어 있으면 기본 VectorDB 로케이션으로도 반영
+ if selected.vectordb_location:
+ try:
+ update_vectordb_settings(
+ config,
+ vectordb_type=selected.vectordb_type or "faiss",
+ vectordb_location=selected.vectordb_location,
+ )
+ except Exception as e:
+ st.sidebar.warning(f"VectorDB 설정 적용 경고: {e}")
+ elif selected.faiss_path:
+ # Backward compatibility
try:
update_vectordb_settings(
config,
diff --git a/interface/core/config/models.py b/interface/core/config/models.py
index 9ec02af..7be3d14 100644
--- a/interface/core/config/models.py
+++ b/interface/core/config/models.py
@@ -17,14 +17,18 @@ class DataHubSource:
name: str
url: str
faiss_path: Optional[str] = None
+ vectordb_type: str = "faiss"
+ vectordb_location: Optional[str] = None
+ vectordb_api_key: Optional[str] = None
note: Optional[str] = None
@dataclass
class VectorDBSource:
name: str
- type: str # 'faiss' | 'pgvector'
+ type: str # 'faiss' | 'pgvector' | 'qdrant'
location: str
+ api_key: Optional[str] = None
collection_prefix: Optional[str] = None
note: Optional[str] = None
diff --git a/interface/core/config/persist.py b/interface/core/config/persist.py
index 81ba144..ee077ba 100644
--- a/interface/core/config/persist.py
+++ b/interface/core/config/persist.py
@@ -63,11 +63,22 @@ def _parse_datahub_list(items: List[Dict[str, Any]]) -> List[DataHubSource]:
name = str(item.get("name", "")).strip()
url = str(item.get("url", "")).strip()
faiss_path = item.get("faiss_path")
+ vectordb_type = item.get("vectordb_type", "faiss")
+ vectordb_location = item.get("vectordb_location")
+ vectordb_api_key = item.get("vectordb_api_key")
note = item.get("note")
if not name or not url:
continue
parsed.append(
- DataHubSource(name=name, url=url, faiss_path=faiss_path, note=note)
+ DataHubSource(
+ name=name,
+ url=url,
+ faiss_path=faiss_path,
+ vectordb_type=vectordb_type,
+ vectordb_location=vectordb_location,
+ vectordb_api_key=vectordb_api_key,
+ note=note,
+ )
)
return parsed
@@ -81,12 +92,14 @@ def _parse_vectordb_list(items: List[Dict[str, Any]]) -> List[VectorDBSource]:
if not name or not vtype or not location:
continue
collection_prefix = item.get("collection_prefix")
+ api_key = item.get("api_key")
note = item.get("note")
parsed.append(
VectorDBSource(
name=name,
type=vtype,
location=location,
+ api_key=api_key,
collection_prefix=collection_prefix,
note=note,
)
diff --git a/interface/core/config/registry_data_sources.py b/interface/core/config/registry_data_sources.py
index 8e9b646..7f393f8 100644
--- a/interface/core/config/registry_data_sources.py
+++ b/interface/core/config/registry_data_sources.py
@@ -41,25 +41,53 @@ def _save_registry(registry: DataSourcesRegistry) -> None:
def add_datahub_source(
- *, name: str, url: str, faiss_path: Optional[str] = None, note: Optional[str] = None
+ *,
+ name: str,
+ url: str,
+ faiss_path: Optional[str] = None,
+ vectordb_type: str = "faiss",
+ vectordb_location: Optional[str] = None,
+ vectordb_api_key: Optional[str] = None,
+ note: Optional[str] = None,
) -> None:
registry = get_data_sources_registry()
if any(s.name == name for s in registry.datahub):
raise ValueError(f"이미 존재하는 DataHub 이름입니다: {name}")
registry.datahub.append(
- DataHubSource(name=name, url=url, faiss_path=faiss_path, note=note)
+ DataHubSource(
+ name=name,
+ url=url,
+ faiss_path=faiss_path,
+ vectordb_type=vectordb_type,
+ vectordb_location=vectordb_location,
+ vectordb_api_key=vectordb_api_key,
+ note=note,
+ )
)
_save_registry(registry)
def update_datahub_source(
- *, name: str, url: str, faiss_path: Optional[str], note: Optional[str]
+ *,
+ name: str,
+ url: str,
+ faiss_path: Optional[str],
+ vectordb_type: str = "faiss",
+ vectordb_location: Optional[str] = None,
+ vectordb_api_key: Optional[str] = None,
+ note: Optional[str],
) -> None:
registry = get_data_sources_registry()
for idx, s in enumerate(registry.datahub):
if s.name == name:
registry.datahub[idx] = DataHubSource(
- name=name, url=url, faiss_path=faiss_path, note=note
+ name=name,
+ url=url,
+ faiss_path=faiss_path,
+ vectordb_type=vectordb_type,
+ vectordb_location=vectordb_location,
+ vectordb_api_key=vectordb_api_key,
+ note=note,
)
_save_registry(registry)
return
@@ -77,12 +105,15 @@ def add_vectordb_source(
name: str,
vtype: str,
location: str,
+ api_key: Optional[str] = None,
collection_prefix: Optional[str] = None,
note: Optional[str] = None,
) -> None:
vtype = (vtype or "").lower()
- if vtype not in ("faiss", "pgvector"):
- raise ValueError("VectorDB 타입은 'faiss' 또는 'pgvector'여야 합니다")
+ if vtype not in ("faiss", "pgvector", "qdrant"):
+ raise ValueError(
+ "VectorDB 타입은 'faiss', 'pgvector', 'qdrant' 중 하나여야 합니다"
+ )
registry = get_data_sources_registry()
if any(s.name == name for s in registry.vectordb):
raise ValueError(f"이미 존재하는 VectorDB 이름입니다: {name}")
@@ -91,6 +122,7 @@ def add_vectordb_source(
name=name,
type=vtype,
location=location,
+ api_key=api_key,
collection_prefix=collection_prefix,
note=note,
)
@@ -103,12 +135,15 @@ def update_vectordb_source(
name: str,
vtype: str,
location: str,
+ api_key: Optional[str] = None,
collection_prefix: Optional[str],
note: Optional[str],
) -> None:
vtype = (vtype or "").lower()
- if vtype not in ("faiss", "pgvector"):
- raise ValueError("VectorDB 타입은 'faiss' 또는 'pgvector'여야 합니다")
+ if vtype not in ("faiss", "pgvector", "qdrant"):
+ raise ValueError(
+ "VectorDB 타입은 'faiss', 'pgvector', 'qdrant' 중 하나여야 합니다"
+ )
registry = get_data_sources_registry()
for idx, s in enumerate(registry.vectordb):
if s.name == name:
@@ -116,6 +151,7 @@ def update_vectordb_source(
name=name,
type=vtype,
location=location,
+ api_key=api_key,
collection_prefix=collection_prefix,
note=note,
)
diff --git a/interface/core/config/settings.py b/interface/core/config/settings.py
index 9b4eeb2..b0a02ee 100644
--- a/interface/core/config/settings.py
+++ b/interface/core/config/settings.py
@@ -153,12 +153,13 @@ def update_vectordb_settings(
"""Validate and update VectorDB settings into env and session.
Basic validation rules follow CLI's behavior:
- - vectordb_type must be 'faiss' or 'pgvector'
+ - vectordb_type must be 'faiss' or 'pgvector' or 'qdrant'
- if type == 'faiss' and location provided: must be an existing directory
- if type == 'pgvector' and location provided: must start with 'postgresql://'
+ - if type == 'qdrant' and location provided: must start with 'http://'
"""
vtype = (vectordb_type or "").lower()
- if vtype not in ("faiss", "pgvector"):
+ if vtype not in ("faiss", "pgvector", "qdrant"):
raise ValueError(f"지원하지 않는 VectorDB 타입: {vectordb_type}")
vloc = vectordb_location or ""
diff --git a/utils/llm/README.md b/utils/llm/README.md
index 993ee98..260e8f7 100644
--- a/utils/llm/README.md
+++ b/utils/llm/README.md
@@ -10,7 +10,10 @@ utils/llm/
├── chains.py # LangChain 체인 생성 모듈
├── retrieval.py # 테이블 메타 검색 및 재순위화
├── llm_response_parser.py # LLM 응답에서 SQL 블록 추출
-├── chatbot.py # LangGraph ChatBot 구현
+├── chatbot/ # LangGraph ChatBot 패키지
+│ ├── __init__.py
+│ ├── core.py # ChatBot 핵심 로직
+│ └── README.md # [상세 문서](./chatbot/README.md)
├── core/ # LLM/Embedding 팩토리 모듈
│ ├── __init__.py
│ ├── factory.py # LLM 및 Embedding 모델 생성 팩토리
@@ -143,7 +146,7 @@ utils/llm/
**목적**: DataHub 메타데이터 수집 및 LangGraph ChatBot용 Tool 함수 제공
**주요 기능:**
-- `get_info_from_db()`: DataHub에서 테이블 메타데이터를 LangChain Document로 수집
+- `get_table_schema()`: DataHub에서 테이블 메타데이터를 dictionary 형태로 반환
- `get_metadata_from_db()`: 전체 메타데이터 딕셔너리 반환
- `search_database_tables()`: 벡터 검색 기반 테이블 정보 검색 Tool
- `get_glossary_terms()`: 용어집 정보 조회 Tool
@@ -152,7 +155,7 @@ utils/llm/
**사용처:**
- `utils/llm/vectordb/faiss_db.py`: 벡터DB 초기화 시 메타데이터 수집
- `utils/llm/vectordb/pgvector_db.py`: 벡터DB 초기화 시 메타데이터 수집
-- `utils/llm/chatbot.py`: ChatBot 도구로 사용
+- `utils/llm/chatbot/`: ChatBot 도구로 사용
**상세 문서**: [tools/README.md](./tools/README.md)
@@ -298,7 +301,7 @@ engine/query_executor.py
│ └── utils/llm/retrieval.py
│ └── utils/llm/vectordb/get_vector_db()
│ ├── utils/llm/core/get_embeddings()
-│ └── utils/llm/tools/get_info_from_db()
+│ └── utils/llm/tools/get_table_schema()
└── utils/llm/llm_response_parser.py
```
@@ -316,8 +319,8 @@ engine/query_executor.py
- `retrieval.py` → `vectordb/get_vector_db()` 사용
**vectordb 모듈:**
-- `vectordb/faiss_db.py` → `core/get_embeddings()`, `tools/get_info_from_db()` 사용
-- `vectordb/pgvector_db.py` → `core/get_embeddings()`, `tools/get_info_from_db()` 사용
+- `vectordb/faiss_db.py` → `core/get_embeddings()`, `tools/get_table_schema()` 사용
+- `vectordb/pgvector_db.py` → `core/get_embeddings()`, `tools/get_table_schema()` 사용
**tools 모듈:**
- `tools/datahub.py` → DataHub 메타데이터 수집
diff --git a/utils/llm/chatbot.py b/utils/llm/chatbot.py
deleted file mode 100644
index 51bcab0..0000000
--- a/utils/llm/chatbot.py
+++ /dev/null
@@ -1,214 +0,0 @@
-"""
-LangGraph 기반 ChatBot 모델
-OpenAI의 ChatGPT 모델을 사용하여 대화 기록을 유지하는 챗봇 구현
-"""
-
-from typing import Annotated, Sequence, TypedDict
-
-from langchain_core.messages import BaseMessage, SystemMessage
-from langchain_openai import ChatOpenAI
-from langgraph.checkpoint.memory import MemorySaver
-from langgraph.graph import START, StateGraph
-from langgraph.graph.message import add_messages
-from langgraph.prebuilt import ToolNode
-
-from utils.llm.tools import (
- search_database_tables,
- get_glossary_terms,
- get_query_examples,
-)
-
-
-class ChatBotState(TypedDict):
- """
- 챗봇 상태 - 사용자 질문을 SQL로 변환 가능한 구체적인 질문으로 만들어가는 과정 추적
- """
-
- # 기본 메시지 (MessagesState와 동일)
- messages: Annotated[Sequence[BaseMessage], add_messages]
-
- # datahub 서버 정보
- gms_server: str
-
-
-class ChatBot:
- """
- LangGraph를 사용한 대화형 챗봇 클래스
- OpenAI API를 통해 다양한 GPT 모델을 사용할 수 있으며,
- MemorySaver를 통해 대화 기록을 관리합니다.
- """
-
- def __init__(
- self,
- openai_api_key: str,
- model_name: str = "gpt-4o-mini",
- gms_server: str = "http://localhost:8080",
- ):
- """
- ChatBot 인스턴스 초기화
-
- Args:
- openai_api_key: OpenAI API 키
- model_name: 사용할 모델명 (기본값: gpt-4o-mini)
- gms_server: DataHub GMS 서버 URL (기본값: http://localhost:8080)
- """
- self.openai_api_key = openai_api_key
- self.model_name = model_name
- self.gms_server = gms_server
- # SQL 생성을 위한 데이터베이스 메타데이터 조회 도구
- self.tools = [
- search_database_tables, # 데이터베이스 테이블 정보 검색
- get_glossary_terms, # 용어집 조회 도구
- get_query_examples, # 쿼리 예제 조회 도구
- ]
- self.llm = self._setup_llm() # LLM 인스턴스 설정
- self.app = self._setup_workflow() # LangGraph 워크플로우 설정
-
- def _setup_llm(self):
- """
- OpenAI ChatGPT LLM 인스턴스 생성
- Tool을 바인딩하여 LLM이 필요시 tool을 호출할 수 있도록 설정합니다.
-
- Returns:
- ChatOpenAI: Tool이 바인딩된 LLM 인스턴스
- """
- llm = ChatOpenAI(
- temperature=0.0, # SQL 생성은 정확성이 중요하므로 0으로 설정
- openai_api_key=self.openai_api_key,
- model_name=self.model_name,
- )
- # Tool을 LLM에 바인딩하여 함수 호출 기능 활성화
- return llm.bind_tools(self.tools)
-
- def _setup_workflow(self):
- """
- LangGraph 워크플로우 설정
- 대화 기록을 관리하고 LLM과 통신하는 그래프 구조를 생성합니다.
- Tool 호출 기능을 포함하여 LLM이 필요시 도구를 사용할 수 있도록 합니다.
-
- Returns:
- CompiledGraph: 컴파일된 LangGraph 워크플로우
- """
- # ChatBotState를 사용하는 StateGraph 생성
- workflow = StateGraph(state_schema=ChatBotState)
-
- def call_model(state: ChatBotState):
- """
- LLM 모델을 호출하는 노드 함수
- LLM이 응답을 생성하거나 tool 호출을 결정합니다.
-
- Args:
- state: 현재 메시지 상태
-
- Returns:
- dict: LLM 응답이 포함된 상태 업데이트
- """
- # 질문 구체화 전문 어시스턴트 시스템 메시지
- sys_msg = SystemMessage(
- content="""# 역할
-당신은 사용자의 모호한 질문을 명확하고 구체적인 질문으로 만드는 전문 AI 어시스턴트입니다.
-
-# 주요 임무
-- 사용자의 자연어 질문을 이해하고 의도를 정확히 파악합니다
-- 대화를 통해 날짜, 지표, 필터 조건 등 구체적인 정보를 수집합니다
-- 단계별로 사용자와 대화하며 명확하고 구체적인 질문으로 다듬어갑니다
-
-# 작업 프로세스
-1. 사용자의 최초 질문에서 의도 파악
-2. 질문을 명확히 하기 위해 필요한 정보 식별 (날짜, 지표, 대상, 조건 등)
-3. **도구를 적극 활용하여 데이터베이스 스키마, 테이블 정보, 용어집 등을 확인**
-4. 부족한 정보를 자연스럽게 질문하여 수집
-5. 수집된 정보를 바탕으로 질문을 점진적으로 구체화
-6. 충분히 구체화되면 최종 질문 확정
-
-# 도구 사용 가이드
-- **search_database_tables**: 사용자와의 대화를 데이터와 연관짓기 위해 관련 테이블을 적극적으로 확인할 수 있는 도구
-- **get_glossary_terms**: 사용자가 사용한 용어의 정확한 의미를 확인할 때 사용가능한 도구
-- **get_query_examples**: 조직내 저장된 쿼리 예제를 조회하여 참고할 수 있는 도구
-- 답변하기 전에 최대한 많은 도구를 적극 활용하여 정보를 수집하세요
-- 불확실한 정보가 있다면 추측하지 말고 도구를 사용하여 확인하세요
-
-# 예시
-- 모호한 질문: "KPI가 궁금해"
-- 대화 후 구체화: "2025-01-02 날짜의 신규 유저가 발생시킨 매출이 궁금해"
-
-# 주의사항
-- 항상 친절하고 명확하게 대화합니다
-- 이전 대화 맥락을 고려하여 일관성 있게 응답합니다
-- 한 번에 너무 많은 것을 물어보지 않고 단계적으로 진행합니다
-- **중요: 사용자가 말한 내용이 충분히 구체화되지 않거나 의도가 명확히 파악되지 않을 경우, 추측하지 말고 모든 도구(get_glossary_terms, get_query_examples, search_database_tables)를 적극적으로 사용하여 맥락을 파악하세요**
-- 도구를 통해 수집한 정보를 바탕으로 사용자에게 구체적인 방향성과 옵션을 제안하세요
-- 불확실한 정보가 있다면 추측하지 말고 도구를 사용하여 확인한 후 답변하세요
-
----
-다음은 사용자와의 대화입니다:"""
- )
- # 시스템 메시지를 대화의 맨 앞에 추가
- messages = [sys_msg] + state["messages"]
- response = self.llm.invoke(messages)
- return {"messages": response}
-
- def route_model_output(state: ChatBotState):
- """
- LLM 출력에 따라 다음 노드를 결정하는 라우팅 함수
- Tool 호출이 필요한 경우 'tools' 노드로, 아니면 대화를 종료합니다.
-
- Args:
- state: 현재 메시지 상태
-
- Returns:
- str: 다음에 실행할 노드 이름 ('tools' 또는 '__end__')
- """
- messages = state["messages"]
- last_message = messages[-1]
- # LLM이 tool을 호출하려고 하는 경우 (tool_calls가 있는 경우)
- if hasattr(last_message, "tool_calls") and last_message.tool_calls:
- return "tools"
- # Tool 호출이 없으면 대화 종료
- return "__end__"
-
- # 워크플로우 구조 정의
- workflow.add_edge(START, "model") # 시작 -> model 노드
- workflow.add_node("model", call_model) # LLM 호출 노드
- workflow.add_node("tools", ToolNode(self.tools)) # Tool 실행 노드
-
- # model 노드 이후 조건부 라우팅
- workflow.add_conditional_edges("model", route_model_output)
- # Tool 실행 후 다시 model로 돌아가서 최종 응답 생성
- workflow.add_edge("tools", "model")
-
- # MemorySaver를 사용하여 대화 기록 저장 기능 추가
- return workflow.compile(checkpointer=MemorySaver())
-
- def chat(self, message: str, thread_id: str):
- """
- 사용자 메시지에 대한 응답 생성
-
- Args:
- message: 사용자 입력 메시지
- thread_id: 대화 세션을 구분하는 고유 ID
-
- Returns:
- dict: LLM 응답을 포함한 결과 딕셔너리
- """
- config = {"configurable": {"thread_id": thread_id}}
-
- # 상태 준비
- input_state = {
- "messages": [{"role": "user", "content": message}],
- "gms_server": self.gms_server, # DataHub 서버 URL을 상태에 포함
- }
-
- return self.app.invoke(input_state, config)
-
- def update_model(self, model_name: str):
- """
- 사용 중인 LLM 모델 변경
- 모델 변경 시 LLM 인스턴스와 워크플로우를 재설정합니다.
-
- Args:
- model_name: 변경할 모델명
- """
- self.model_name = model_name
- self.llm = self._setup_llm() # 새 모델로 LLM 재설정
- self.app = self._setup_workflow() # 워크플로우 재생성
diff --git a/utils/llm/chatbot/README.md b/utils/llm/chatbot/README.md
new file mode 100644
index 0000000..061bb4f
--- /dev/null
+++ b/utils/llm/chatbot/README.md
@@ -0,0 +1,57 @@
+# ChatBot Module
+
+LangGraph 기반의 대화형 챗봇 모듈입니다. 사용자의 자연어 질문을 이해하고, 적절한 가이드라인과 도구를 선택하여 답변을 생성합니다.
+
+## 구조
+
+```
+utils/llm/chatbot/
+├── __init__.py # 패키지 초기화 및 ChatBot 클래스 export
+├── core.py # ChatBot 클래스 및 LangGraph 워크플로우 정의
+├── guidelines.py # 가이드라인 및 툴 래퍼 함수 정의
+├── matcher.py # LLM 기반 가이드라인 매칭 로직
+└── types.py # 데이터 타입 및 구조 정의
+```
+
+## 주요 컴포넌트
+
+### `ChatBot` (`core.py`)
+챗봇의 메인 클래스입니다. LangGraph를 사용하여 대화 흐름을 제어합니다.
+- **초기화**: OpenAI API 키, 모델명, GMS 서버 URL 등을 설정합니다.
+- **워크플로우**: `select_guidelines` -> `call_model` 순서로 실행됩니다.
+- **chat 메서드**: 사용자 메시지를 입력받아 응답을 생성합니다.
+
+### `LLMGuidelineMatcher` (`matcher.py`)
+사용자의 메시지를 분석하여 가장 적절한 가이드라인을 선택하는 클래스입니다.
+- LLM을 사용하여 사용자 의도를 파악하고, 미리 정의된 가이드라인 중 하나 이상을 매칭합니다.
+- JSON Schema를 사용하여 구조화된 출력을 보장합니다.
+
+### `Guideline` (`types.py`)
+챗봇이 따를 규칙과 도구를 정의하는 데이터 클래스입니다.
+- `id`: 가이드라인 식별자
+- `description`: 가이드라인 설명
+- `example_phrases`: 매칭에 사용될 예시 문구
+- `tools`: 해당 가이드라인에서 사용할 도구 함수 목록
+- `priority`: 매칭 우선순위
+
+### `GUIDELINES` (`guidelines.py`)
+기본적으로 제공되는 가이드라인 목록입니다.
+- `table_schema`: 데이터베이스 테이블 정보 검색
+- `glossary`: 용어집 조회
+- `query_examples`: 쿼리 예제 조회
+
+## 사용 예시
+
+```python
+from utils.llm.chatbot import ChatBot
+
+# 챗봇 인스턴스 생성
+bot = ChatBot(
+ openai_api_key="sk-...",
+ gms_server="http://localhost:8080"
+)
+
+# 대화하기
+response = bot.chat("매출 테이블 정보 알려줘", thread_id="session_1")
+print(response["messages"][-1].content)
+```
diff --git a/utils/llm/chatbot/__init__.py b/utils/llm/chatbot/__init__.py
new file mode 100644
index 0000000..d816b62
--- /dev/null
+++ b/utils/llm/chatbot/__init__.py
@@ -0,0 +1,7 @@
+"""
+ChatBot 패키지 초기화 모듈
+"""
+
+from utils.llm.chatbot.core import ChatBot
+
+__all__ = ["ChatBot"]
diff --git a/utils/llm/chatbot/core.py b/utils/llm/chatbot/core.py
new file mode 100644
index 0000000..b7e27fa
--- /dev/null
+++ b/utils/llm/chatbot/core.py
@@ -0,0 +1,311 @@
+"""
+ChatBot 핵심 로직 및 LangGraph 워크플로우 정의
+"""
+
+from typing import Any, Dict, List, Optional
+
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_openai import ChatOpenAI
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import END, START, StateGraph
+from openai import OpenAI
+
+from utils.llm.tools import filter_relevant_outputs
+from utils.llm.chatbot.guidelines import GUIDELINES
+from utils.llm.chatbot.matcher import LLMGuidelineMatcher
+from utils.llm.chatbot.types import ChatBotState, Guideline
+
+
+class ChatBot:
+ """
+ LangGraph를 사용한 대화형 챗봇 클래스 (Guideline 기반)
+ """
+
+ def __init__(
+ self,
+ openai_api_key: str,
+ model_name: str = "gpt-4o-mini",
+ gms_server: str = "http://localhost:8080",
+ guidelines: Optional[List[Guideline]] = None,
+ ):
+ """
+ ChatBot 인스턴스 초기화
+
+ Args:
+ openai_api_key: OpenAI API 키
+ model_name: 사용할 모델명 (기본값: gpt-4o-mini)
+ gms_server: DataHub GMS 서버 URL (기본값: http://localhost:8080)
+ guidelines: 사용할 가이드라인 목록 (없으면 기본값 사용)
+ """
+ self.openai_api_key = openai_api_key
+ self.model_name = model_name
+ self.gms_server = gms_server
+ self.guidelines = guidelines or GUIDELINES
+ self.guideline_map = {g.id: g for g in self.guidelines}
+
+ self._client = OpenAI(api_key=openai_api_key)
+ self.matcher = LLMGuidelineMatcher(
+ self.guidelines,
+ model=self.model_name,
+ client_obj=self._client,
+ )
+ self.llm = ChatOpenAI(
+ temperature=0.0,
+ model_name=self.model_name,
+ openai_api_key=openai_api_key,
+ )
+ self.app = self._setup_workflow()
+
+ def _setup_workflow(self):
+ """
+ LangGraph 워크플로우 설정
+ """
+ workflow = StateGraph(state_schema=ChatBotState)
+
+ def select_guidelines(state: ChatBotState):
+ user_text = ""
+ # 마지막 사용자 메시지 찾기
+ for msg in reversed(state["messages"]):
+ if isinstance(msg, HumanMessage) or (
+ hasattr(msg, "type") and msg.type == "human"
+ ):
+ user_text = msg.content
+ break
+
+ # 만약 메시지 객체 구조가 달라서 못 찾았을 경우를 대비해 마지막 메시지 내용 사용
+ if not user_text and state["messages"]:
+ user_text = state["messages"][-1].content
+
+ matched = self.matcher.match(str(user_text))
+
+ # 컨텍스트 업데이트 (현재 사용자 메시지 추가)
+ ctx = state.get("context") or {}
+ ctx["last_user_message"] = user_text
+ ctx["gms_server"] = self.gms_server
+ # search_database_tables_tool을 위해 query 키도 설정
+ ctx["query"] = user_text
+
+ # 결과 저장을 위한 임시 딕셔너리 (기존 상태 유지 + 추가)
+ updates = {
+ "table_schema_outputs": list(state.get("table_schema_outputs") or []),
+ "glossary_outputs": list(state.get("glossary_outputs") or []),
+ "query_example_outputs": list(state.get("query_example_outputs") or []),
+ }
+
+ for g in matched:
+ target_list = None
+ if g.id == "table_schema":
+ target_list = updates["table_schema_outputs"]
+ elif g.id == "glossary":
+ target_list = updates["glossary_outputs"]
+ elif g.id == "query_examples":
+ target_list = updates["query_example_outputs"]
+
+ # 매칭되는 카테고리가 없으면 스킵하거나 로깅 (현재는 스킵)
+ if target_list is None:
+ continue
+
+ for tool in g.tools or []:
+ try:
+ result = tool(ctx)
+ # 구조화된 데이터를 그대로 저장 (UI 렌더링용)
+ target_list.append(result)
+ except Exception as exc:
+ target_list.append({"error": str(exc), "tool": tool.__name__})
+
+ # 빈 리스트인 필드는 제거하여 State 업데이트 시 기존 값을 덮어쓰지 않도록 함
+ # (LangGraph State 업데이트 동작: 딕셔너리에 포함된 키만 업데이트됨)
+ final_updates = {k: v for k, v in updates.items() if v}
+
+ return {
+ "selected_ids": [g.id for g in matched],
+ "context": ctx,
+ **final_updates,
+ }
+
+ def filter_context(state: ChatBotState):
+ """
+ 수집된 컨텍스트를 LLM을 통해 필터링하는 노드
+ """
+ # HumanMessage만 필터링
+ human_messages = [
+ msg
+ for msg in state["messages"]
+ if isinstance(msg, HumanMessage)
+ or (hasattr(msg, "type") and msg.type == "human")
+ ]
+
+ table_outs = state.get("table_schema_outputs", [])
+ glossary_outs = state.get("glossary_outputs", [])
+ query_outs = state.get("query_example_outputs", [])
+
+ # 필터링 수행
+ filtered = filter_relevant_outputs(
+ messages=human_messages,
+ table_outputs=table_outs,
+ glossary_outputs=glossary_outs,
+ query_outputs=query_outs,
+ llm=self.llm,
+ )
+
+ return {
+ "table_schema_outputs": filtered.get("table_schema_outputs", []),
+ "glossary_outputs": filtered.get("glossary_outputs", []),
+ "query_example_outputs": filtered.get("query_example_outputs", []),
+ }
+
+ def generate_analysis_guide(state: ChatBotState):
+ """
+ 필터링된 컨텍스트를 바탕으로 분석 가이드를 생성하는 노드
+ """
+ user_text = ""
+ for msg in reversed(state["messages"]):
+ if isinstance(msg, HumanMessage) or (
+ hasattr(msg, "type") and msg.type == "human"
+ ):
+ user_text = msg.content
+ break
+
+ if not user_text and state["messages"]:
+ user_text = state["messages"][-1].content
+
+ table_outs = state.get("table_schema_outputs", [])
+ glossary_outs = state.get("glossary_outputs", [])
+ query_outs = state.get("query_example_outputs", [])
+
+ # 컨텍스트가 없으면 가이드 생성 생략
+ if not (table_outs or glossary_outs or query_outs):
+ return {"analysis_guide": None}
+
+ prompt = (
+ "당신은 데이터 분석 전문가입니다. 사용자의 질문과 제공된 컨텍스트(테이블 스키마, 용어집, 쿼리 예제)를 바탕으로 "
+ "데이터 분석 시나리오(Analysis Guide)를 작성해주세요.\n\n"
+ "다음 우선순위에 따라 분석 전략을 수립하세요:\n"
+ "1. Query Example 활용: 유사한 쿼리 예제가 있다면 이를 변형하여 분석하는 방법을 제안하세요.\n"
+ "2. Glossary 활용: 질문에 포함된 모호한 용어가 용어집에 있다면 그 정의를 바탕으로 분석 방법을 제안하세요.\n"
+ "3. Table Schema 활용: 위 정보가 부족하다면 테이블 스키마를 보고 어떤 컬럼을 조합하여 분석할지 제안하세요.\n\n"
+ f"# 사용자 질문: {user_text}\n\n"
+ f"# 테이블 스키마 정보: {table_outs}\n"
+ f"# 용어집 정보: {glossary_outs}\n"
+ f"# 쿼리 예제 정보: {query_outs}\n\n"
+ "분석 가이드는 명확하고 논리적인 단계로 작성해주세요."
+ )
+
+ response = self.llm.invoke([HumanMessage(content=prompt)])
+ return {"analysis_guide": response.content}
+
+ def call_model(state: ChatBotState):
+ selected_ids = state.get("selected_ids", [])
+
+ # 각 출력 필드 가져오기
+ table_outs = state.get("table_schema_outputs", [])
+ glossary_outs = state.get("glossary_outputs", [])
+ query_outs = state.get("query_example_outputs", [])
+ analysis_guide = state.get("analysis_guide")
+
+ guideline_lines = [
+ f"- {gid}: {self.guideline_map[gid].description}"
+ for gid in selected_ids
+ if gid in self.guideline_map
+ ] or ["- 적용 가능한 가이드라인 없음 (일반 대화 진행)"]
+
+ # 툴 실행 결과 통합 (LLM 프롬프트용 문자열 변환)
+ all_tool_lines = []
+ if table_outs:
+ all_tool_lines.append("## 테이블 스키마 정보")
+ for item in table_outs:
+ all_tool_lines.append(str(item))
+ if glossary_outs:
+ all_tool_lines.append("## 용어집 정보")
+ for item in glossary_outs:
+ all_tool_lines.append(str(item))
+ if query_outs:
+ all_tool_lines.append("## 쿼리 예제 정보")
+ for item in query_outs:
+ all_tool_lines.append(str(item))
+
+ if not all_tool_lines:
+ all_tool_lines = ["(툴 실행 결과 없음)"]
+
+ # 분석 가이드 추가
+ analysis_guide_text = ""
+ if analysis_guide:
+ analysis_guide_text = (
+ f"\n\n# 분석 가이드 (Analysis Guide)\n{analysis_guide}"
+ )
+
+ sys_msg = SystemMessage(
+ content=(
+ "# 역할\n"
+ "당신은 사용자의 비즈니스 질문을 구체적인 데이터 분석 시나리오로 연결해주는 '데이터 분석 컨설턴트'입니다.\n"
+ "단순히 질문을 구체화하는 것을 넘어, 제공된 데이터 자산(테이블, 용어, 쿼리)을 활용하여 '어떤 데이터를 어떻게 분석하면 답을 얻을 수 있는지'를 전문적으로 가이드해야 합니다.\n"
+ "# 적용된 가이드라인\n"
+ + "\n".join(guideline_lines)
+ + "\n\n# 툴 실행 결과 (참고 정보)\n"
+ + "\n".join(all_tool_lines)
+ + analysis_guide_text
+ + "\n\n# 지침\n"
+ "- 툴 실행 결과에 유용한 정보가 있다면 적극적으로 인용하여 답변하세요.\n"
+ "- 정보가 부족하다면 추가 질문을 통해 구체화하세요.\n"
+ "- 항상 친절하고 명확하게 대화하세요."
+ )
+ )
+
+ # 시스템 메시지를 대화의 맨 앞에 추가 (또는 매번 컨텍스트로 주입)
+ # LangGraph에서는 메시지 리스트가 계속 쌓이므로,
+ # 이번 턴의 시스템 메시지를 앞에 붙여서 invoke 하는 방식 사용
+ messages = [sys_msg] + list(state["messages"])
+ response = self.llm.invoke(messages)
+ return {"messages": response}
+
+ workflow.add_node("select", select_guidelines)
+ workflow.add_node("filter", filter_context)
+ workflow.add_node("generate_analysis_guide", generate_analysis_guide)
+ workflow.add_node("respond", call_model)
+
+ workflow.add_edge(START, "select")
+ workflow.add_edge("select", "filter")
+ workflow.add_edge("filter", "generate_analysis_guide")
+ workflow.add_edge("generate_analysis_guide", "respond")
+ workflow.add_edge("respond", END)
+
+ return workflow.compile(checkpointer=MemorySaver())
+
+ def chat(self, message: str, thread_id: str):
+ """
+ 사용자 메시지에 대한 응답 생성
+
+ Args:
+ message: 사용자 입력 메시지
+ thread_id: 대화 세션을 구분하는 고유 ID
+
+ Returns:
+ dict: LLM 응답을 포함한 결과 딕셔너리
+ """
+ config = {"configurable": {"thread_id": thread_id}}
+
+ # 초기 상태 설정
+ input_state = {
+ "messages": [HumanMessage(content=message)],
+ "context": {"gms_server": self.gms_server},
+ "selected_ids": [],
+ }
+
+ return self.app.invoke(input_state, config)
+
+ def update_model(self, model_name: str):
+ """
+ 사용 중인 LLM 모델 변경
+ """
+ self.model_name = model_name
+ self._client = OpenAI(api_key=self.openai_api_key)
+ self.matcher = LLMGuidelineMatcher(
+ self.guidelines,
+ model=self.model_name,
+ client_obj=self._client,
+ )
+ self.llm = ChatOpenAI(
+ temperature=0.0,
+ model_name=self.model_name,
+ openai_api_key=self.openai_api_key,
+ )
diff --git a/utils/llm/chatbot/guidelines.py b/utils/llm/chatbot/guidelines.py
new file mode 100644
index 0000000..6de1723
--- /dev/null
+++ b/utils/llm/chatbot/guidelines.py
@@ -0,0 +1,67 @@
+"""
+ChatBot 가이드라인 및 툴 정의
+"""
+
+from typing import Any, Dict, List
+
+from utils.llm.tools import (
+ search_database_tables,
+ get_glossary_terms,
+ get_query_examples,
+)
+from utils.llm.chatbot.types import Guideline
+
+
+def search_database_tables_tool(ctx: Dict[str, Any]) -> str:
+ query = ctx.get("query") or ctx.get("last_user_message", "")
+ return search_database_tables.invoke({"query": query})
+
+
+def get_glossary_terms_tool(ctx: Dict[str, Any]) -> str:
+ query = ctx.get("query") or ctx.get("last_user_message", "")
+ return get_glossary_terms.invoke({"query": query})
+
+
+def get_query_examples_tool(ctx: Dict[str, Any]) -> str:
+ query = ctx.get("query") or ctx.get("last_user_message", "")
+ return get_query_examples.invoke({"query": query})
+
+
+GUIDELINES: List[Guideline] = [
+ Guideline(
+ id="table_schema",
+ description="데이터베이스 테이블 정보나 스키마 확인이 필요할 때 사용",
+ example_phrases=[
+ "테이블 정보 알려줘",
+ "어떤 컬럼이 있어?",
+ "스키마 보여줘",
+ "데이터 구조가 궁금해",
+ ],
+ tools=[search_database_tables_tool],
+ priority=10,
+ ),
+ Guideline(
+ id="glossary",
+ description="용어의 정의나 비즈니스 의미 확인이 필요할 때 사용",
+ example_phrases=[
+ "용어집 보여줘",
+ "이 단어 뜻이 뭐야?",
+ "비즈니스 용어 설명해줘",
+ "KPI 정의가 뭐야?",
+ ],
+ tools=[get_glossary_terms_tool],
+ priority=8,
+ ),
+ Guideline(
+ id="query_examples",
+ description="쿼리 예제나 SQL 작성 패턴 확인이 필요할 때 사용",
+ example_phrases=[
+ "쿼리 예제 보여줘",
+ "비슷한 쿼리 있어?",
+ "SQL 어떻게 짜야해?",
+ "다른 사람들은 어떻게 쿼리했어?",
+ ],
+ tools=[get_query_examples_tool],
+ priority=9,
+ ),
+]
diff --git a/utils/llm/chatbot/matcher.py b/utils/llm/chatbot/matcher.py
new file mode 100644
index 0000000..ab1ba92
--- /dev/null
+++ b/utils/llm/chatbot/matcher.py
@@ -0,0 +1,84 @@
+"""
+LLM 기반 가이드라인 매칭 로직
+"""
+
+import json
+from typing import Any, Dict, List, Optional
+
+from openai import OpenAI
+
+from utils.llm.chatbot.types import Guideline
+
+
+class LLMGuidelineMatcher:
+ def __init__(
+ self,
+ guidelines: List[Guideline],
+ model: str,
+ client_obj: Optional[OpenAI] = None,
+ ):
+ self.guidelines = guidelines
+ self.model = model
+ self.client = client_obj or OpenAI()
+ self._id_set = {g.id for g in guidelines}
+
+ def _build_messages(self, message: str) -> List[Dict[str, str]]:
+ sys = (
+ "You are a strict GuidelineMatcher.\n"
+ "Return ONLY a JSON object that matches the provided JSON schema."
+ )
+ lines = [
+ "아래 사용자 메시지에 해당하는 모든 가이드라인 id를 선택하세요.",
+ f"[USER MESSAGE]\n{message}\n",
+ "[GUIDELINES]",
+ ]
+ for g in self.guidelines:
+ examples = ", ".join(g.example_phrases) if g.example_phrases else "-"
+ lines.append(
+ f"- id: {g.id}\n desc: {g.description}\n examples: {examples}"
+ )
+ return [
+ {"role": "system", "content": sys},
+ {"role": "user", "content": "\n".join(lines)},
+ ]
+
+ def _json_schema_spec(self) -> Dict[str, Any]:
+ return {
+ "name": "guideline_matches",
+ "schema": {
+ "type": "object",
+ "properties": {
+ "matches": {
+ "type": "array",
+ "items": {"type": "string", "enum": list(self._id_set)},
+ }
+ },
+ "required": ["matches"],
+ "additionalProperties": False,
+ },
+ "strict": True,
+ }
+
+ def match(self, message: str) -> List[Guideline]:
+ ids: List[str] = []
+ try:
+ completion = self.client.chat.completions.create(
+ model=self.model,
+ temperature=0,
+ messages=self._build_messages(message),
+ response_format={
+ "type": "json_schema",
+ "json_schema": self._json_schema_spec(),
+ },
+ )
+ raw = completion.choices[0].message.content
+ data = json.loads(raw) if isinstance(raw, str) else raw
+ ids = [i for i in (data.get("matches") or []) if i in self._id_set]
+ except Exception:
+ # LLM 호출 실패 시 빈 리스트 반환 (일반 대화로 처리)
+ ids = []
+
+ id_to_g = {g.id: g for g in self.guidelines}
+ selected = [id_to_g[i] for i in ids if i in id_to_g]
+ selected.sort(key=lambda g: g.priority, reverse=True)
+ return selected
diff --git a/utils/llm/chatbot/types.py b/utils/llm/chatbot/types.py
new file mode 100644
index 0000000..4be1c96
--- /dev/null
+++ b/utils/llm/chatbot/types.py
@@ -0,0 +1,38 @@
+"""
+ChatBot 관련 데이터 타입 및 구조 정의
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Sequence, TypedDict, Annotated
+import operator
+
+from langchain_core.messages import BaseMessage
+from langgraph.graph.message import add_messages
+
+ToolFn = Callable[[Dict[str, Any]], Any]
+
+
+@dataclass
+class Guideline:
+ id: str
+ description: str
+ example_phrases: List[str]
+ tools: Optional[List[ToolFn]] = None
+ priority: int = 0
+
+
+class ChatBotState(TypedDict):
+ """
+ 챗봇 상태
+ """
+
+ messages: Annotated[Sequence[BaseMessage], add_messages]
+ context: Dict[str, Any]
+ selected_ids: List[str]
+
+ table_schema_outputs: List[Optional[Dict[str, Any]]]
+ glossary_outputs: List[Optional[Dict[str, Any]]]
+ query_example_outputs: List[Optional[Dict[str, Any]]]
+ analysis_guide: Optional[str]
diff --git a/utils/llm/retrieval.py b/utils/llm/retrieval.py
index 0b5d916..5f85715 100644
--- a/utils/llm/retrieval.py
+++ b/utils/llm/retrieval.py
@@ -6,6 +6,8 @@
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from utils.llm.vectordb import get_vector_db
+from utils.llm.tools.datahub import get_glossary_vector_data, get_query_vector_data
+from utils.llm.core import get_embeddings
def load_reranker_model(device: str = "cpu"):
@@ -102,3 +104,86 @@ def search_tables(
}
return documents_dict
+
+
+def _prepare_vector_data(data_fetcher, text_fields):
+ """
+ 데이터를 가져와서 임베딩을 생성하는 헬퍼 함수
+ """
+ points = data_fetcher()
+ embeddings = get_embeddings()
+
+ for point in points:
+ payload = point["payload"]
+ # 텍스트 필드들을 결합하여 임베딩 생성
+ text_to_embed = " ".join([str(payload.get(field, "")) for field in text_fields])
+ vector = embeddings.embed_query(text_to_embed)
+ point["vector"] = {"dense": vector}
+
+ return points
+
+
+def search_glossary(query: str, force_update: bool = False, top_n: int = 5) -> list:
+ """
+ 용어집 검색 함수
+ """
+ collection_name = "lang2sql_glossary"
+ db = get_vector_db()
+
+ # 데이터 로더 정의 (임베딩 생성 포함)
+ def data_loader():
+ return _prepare_vector_data(get_glossary_vector_data, ["name", "description"])
+
+ # 컬렉션 초기화 (필요시)
+ db.initialize_collection_if_empty(
+ collection_name=collection_name,
+ force_update=force_update,
+ data_loader=data_loader,
+ )
+
+ # 검색 수행
+ embeddings = get_embeddings()
+ query_vector = embeddings.embed_query(query)
+
+ results = db.search(
+ collection_name=collection_name,
+ query_vector=("dense", query_vector),
+ limit=top_n,
+ )
+
+ # 결과 포맷팅
+ return [res.payload for res in results]
+
+
+def search_query_examples(
+ query: str, force_update: bool = False, top_n: int = 5
+) -> list:
+ """
+ 쿼리 예제 검색 함수
+ """
+ collection_name = "lang2sql_query_example"
+ db = get_vector_db()
+
+ # 데이터 로더 정의 (임베딩 생성 포함)
+ def data_loader():
+ return _prepare_vector_data(get_query_vector_data, ["name", "description"])
+
+ # 컬렉션 초기화 (필요시)
+ db.initialize_collection_if_empty(
+ collection_name=collection_name,
+ force_update=force_update,
+ data_loader=data_loader,
+ )
+
+ # 검색 수행
+ embeddings = get_embeddings()
+ query_vector = embeddings.embed_query(query)
+
+ results = db.search(
+ collection_name=collection_name,
+ query_vector=("dense", query_vector),
+ limit=top_n,
+ )
+
+ # 결과 포맷팅
+ return [res.payload for res in results]
diff --git a/utils/llm/tools/README.md b/utils/llm/tools/README.md
index fa7153d..9747422 100644
--- a/utils/llm/tools/README.md
+++ b/utils/llm/tools/README.md
@@ -21,7 +21,7 @@ utils/llm/tools/
**datahub 모듈에서**:
- `set_gms_server`: GMS 서버 설정
-- `get_info_from_db`: LangChain Document 리스트로 테이블/컬럼 정보 반환
+- `get_table_schema`: LangChain Document 리스트로 테이블/컬럼 정보 반환
- `get_metadata_from_db`: 전체 메타데이터 딕셔너리 리스트 반환
**chatbot_tool 모듈에서**:
@@ -39,7 +39,7 @@ utils/llm/tools/
- 환경변수 `DATAHUB_SERVER`를 설정하고 DatahubMetadataFetcher 초기화
- 유효하지 않은 서버 URL 시 ValueError 발생
-2. **`get_info_from_db(max_workers: int = 8) -> List[Document]`**
+2. **`get_table_schema(max_workers: int = 8) -> List[Document]`**
- DataHub에서 모든 테이블 메타데이터를 수집하여 LangChain Document 리스트 반환
- 각 Document에는 테이블명, 설명, 컬럼 정보가 포함
- 형식: `"{테이블명}: {설명}\nColumns:\n {컬럼명}: {컬럼설명}"`
@@ -157,10 +157,10 @@ utils/llm/tools/
#### 1. DataHub 메타데이터 수집 (vectorDB 초기화)
```python
-from utils.llm.tools import get_info_from_db
+from utils.llm.tools import get_table_schema
# 모든 테이블 메타데이터를 LangChain Document로 수집
-documents = get_info_from_db(max_workers=8)
+documents = get_table_schema(max_workers=8)
# 각 document는 다음과 같은 형식:
# "테이블명: 설명\nColumns:\n 컬럼1: 설명1\n 컬럼2: 설명2"
@@ -224,8 +224,8 @@ queries = get_query_examples(
**import하는 파일**:
- `utils/llm/chatbot.py`: `from utils.llm.tools import search_database_tables, get_glossary_terms, get_query_examples`
-- `utils/llm/vectordb/faiss_db.py`: `from utils.llm.tools import get_info_from_db`
-- `utils/llm/vectordb/pgvector_db.py`: `from utils.llm.tools import get_info_from_db`
+- `utils/llm/vectordb/faiss_db.py`: `from utils.llm.tools import get_table_schema`
+- `utils/llm/vectordb/pgvector_db.py`: `from utils.llm.tools import get_table_schema`
- `interface/core/config/settings.py`: `from utils.llm.tools import set_gms_server`
**내부 의존성**:
@@ -258,7 +258,7 @@ queries = get_query_examples(
#### 메타데이터 수집 흐름 (벡터DB 초기화 시)
-1. `get_info_from_db()` 호출
+1. `get_table_schema()` 호출
2. `_get_fetcher()`로 DatahubMetadataFetcher 인스턴스 생성
3. `parallel_process()`로 병렬 테이블 정보 수집
4. 각 테이블별로 컬럼 정보 추가 수집
diff --git a/utils/llm/tools/__init__.py b/utils/llm/tools/__init__.py
index f0dcb9d..c039860 100644
--- a/utils/llm/tools/__init__.py
+++ b/utils/llm/tools/__init__.py
@@ -1,5 +1,5 @@
from utils.llm.tools.datahub import (
- get_info_from_db,
+ get_table_schema,
get_metadata_from_db,
set_gms_server,
)
@@ -10,11 +10,14 @@
get_query_examples,
)
+from utils.llm.tools.chatbot_node import filter_relevant_outputs
+
__all__ = [
"set_gms_server",
- "get_info_from_db",
+ "get_table_schema",
"get_metadata_from_db",
"search_database_tables",
"get_glossary_terms",
"get_query_examples",
+ "filter_relevant_outputs",
]
diff --git a/utils/llm/tools/chatbot_node.py b/utils/llm/tools/chatbot_node.py
new file mode 100644
index 0000000..701ec65
--- /dev/null
+++ b/utils/llm/tools/chatbot_node.py
@@ -0,0 +1,82 @@
+from typing import Any, Dict, List
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import JsonOutputParser
+from langchain_core.messages import BaseMessage
+
+
+def filter_relevant_outputs(
+ messages: List[BaseMessage],
+ table_outputs: List[Dict[str, Any]],
+ glossary_outputs: List[Dict[str, Any]],
+ query_outputs: List[Dict[str, Any]],
+ llm: Any,
+) -> Dict[str, Any]:
+ """
+ LLM을 사용하여 툴 출력 결과 중 사용자 메시지 히스토리와 관련 있는 항목만 필터링합니다.
+ """
+ if not any([table_outputs, glossary_outputs, query_outputs]):
+ return {
+ "table_schema_outputs": table_outputs,
+ "glossary_outputs": glossary_outputs,
+ "query_example_outputs": query_outputs,
+ }
+
+ # 메시지 히스토리 포맷팅
+ history_text = ""
+ for i, msg in enumerate(messages):
+ # 마지막 메시지인지 확인
+ prefix = "User (Last Message)" if i == len(messages) - 1 else "User"
+ history_text += f"{prefix}: {msg.content}\n"
+
+ parser = JsonOutputParser()
+
+ prompt = ChatPromptTemplate.from_template(
+ """
+ 당신은 검색된 데이터베이스 정보 중 사용자의 질문과 **관련성이 낮은 정보**를 선별하는 전문가입니다.
+
+ # 대화 히스토리 (User 메시지)
+ {history_text}
+
+ # 검색된 정보
+ 1. 테이블 스키마: {table_outputs}
+ 2. 용어집: {glossary_outputs}
+ 3. 쿼리 예제: {query_outputs}
+
+ # 지침
+ - **목표**: 대화의 흐름을 고려하여, 답변에 도움이 될 수 있는 정보는 유지하고 **명확하게 관련 없는 정보**만 제거하세요.
+ - **기준**:
+ - 사용자의 질문과 직접적인 관련이 없더라도, 문맥상 유용한 정보라면 유지하세요.
+ - 대명사(그거, 저거 등)가 가리키는 대상이나, 질문의 의도를 파악하는 데 필요한 정보는 반드시 유지해야 합니다.
+ - 정말로 엉뚱하거나 불필요한 정보라고 확신할 때만 제거하세요.
+ - **형식**: 데이터 구조를 변경하지 말고, 리스트 내부의 불필요한 항목만 제거하세요.
+ - **결과**: 관련된 정보가 없다면 빈 리스트 `[]`를 반환하세요.
+ - 반드시 아래 JSON 형식으로만 응답하세요.
+
+ {{
+ "table_schema_outputs": [...],
+ "glossary_outputs": [...],
+ "query_example_outputs": [...]
+ }}
+ """
+ )
+
+ chain = prompt | llm | parser
+
+ try:
+ result = chain.invoke(
+ {
+ "history_text": history_text,
+ "table_outputs": str(table_outputs),
+ "glossary_outputs": str(glossary_outputs),
+ "query_outputs": str(query_outputs),
+ }
+ )
+ return result
+ except Exception as e:
+ # 에러 발생 시 원본 데이터 반환 (안전장치)
+ print(f"Filtering failed: {e}")
+ return {
+ "table_schema_outputs": table_outputs,
+ "glossary_outputs": glossary_outputs,
+ "query_example_outputs": query_outputs,
+ }
diff --git a/utils/llm/tools/chatbot_tool.py b/utils/llm/tools/chatbot_tool.py
index 9c496f0..85a125b 100644
--- a/utils/llm/tools/chatbot_tool.py
+++ b/utils/llm/tools/chatbot_tool.py
@@ -6,6 +6,7 @@
from utils.data.datahub_services.base_client import DataHubBaseClient
from utils.data.datahub_services.glossary_service import GlossaryService
from utils.data.datahub_services.query_service import QueryService
+from utils.llm.retrieval import search_glossary, search_query_examples
@tool
@@ -105,11 +106,11 @@ def _simplify_glossary_data(glossary_data):
@tool
-def get_glossary_terms(gms_server: str = "http://35.222.65.99:8080") -> list:
+def get_glossary_terms(query: str, force_update: bool = False) -> list:
"""
- DataHub에서 용어집(Glossary) 정보를 조회합니다.
+ DataHub에서 용어집(Glossary) 정보를 검색합니다.
- 이 함수는 DataHub 서버에 연결하여 전체 용어집 데이터를 가져옵니다.
+ 이 함수는 사용자의 질문과 관련된 용어 정의를 찾기 위해 Vector Search를 수행합니다.
용어집은 비즈니스 용어, 도메인 지식, 데이터 정의 등을 표준화하여 관리하는 곳입니다.
**중요**: 사용자의 질문이나 대화에서 다음과 같은 상황이 발생하면 반드시 이 도구를 사용하세요:
@@ -120,40 +121,26 @@ def get_glossary_terms(gms_server: str = "http://35.222.65.99:8080") -> list:
5. 표준 정의가 필요한 비즈니스 용어가 나왔을 때
Args:
- gms_server (str, optional): DataHub GMS 서버 URL입니다.
- 기본값은 "http://35.222.65.99:8080"
+ query (str): 검색할 용어 또는 관련 질문입니다.
+ force_update (bool, optional): True일 경우 데이터를 새로고침하여 검색 인덱스를 재생성합니다. 기본값은 False.
Returns:
- list: 간소화된 용어집 데이터 리스트입니다.
- 각 항목은 name, description, children(선택적) 필드를 포함합니다.
+ list: 검색된 용어집 데이터 리스트입니다.
+ 각 항목은 name, description 등을 포함합니다.
예시 형태:
[
{
"name": "가짜연구소",
"description": "스터디 단체 가짜연구소를 의미하며...",
- "children": [
- {
- "name": "빌더",
- "description": "가짜연구소 스터디 리더를 지칭..."
- }
- ]
+ "type": "term"
},
- {
- "name": "PII",
- "description": "개인 식별 정보...",
- "children": [
- {
- "name": "identifier",
- "description": "개인식별정보중 github 아이디..."
- }
- ]
- }
+ ...
]
Examples:
- >>> get_glossary_terms()
- [{'name': '가짜연구소', 'description': '...', 'children': [...]}]
+ >>> get_glossary_terms("PII가 뭐야?")
+ [{'name': 'PII', 'description': '개인 식별 정보...', ...}]
Note:
이 도구는 다음과 같은 경우에 **반드시** 사용하세요:
@@ -178,37 +165,22 @@ def get_glossary_terms(gms_server: str = "http://35.222.65.99:8080") -> list:
있는지 확인하고, 있다면 먼저 이 도구를 호출하여 정확한 정의를 파악하세요.
"""
try:
- # DataHub 클라이언트 초기화
- client = DataHubBaseClient(gms_server=gms_server)
-
- # GlossaryService 초기화
- glossary_service = GlossaryService(client)
-
- # 전체 용어집 데이터 가져오기
- glossary_data = glossary_service.get_glossary_data()
+ return search_glossary(query=query, force_update=force_update)
- # 간소화된 데이터 반환
- simplified_data = _simplify_glossary_data(glossary_data)
-
- return simplified_data
-
- except ValueError as e:
- return {"error": True, "message": f"DataHub 서버 연결 실패: {str(e)}"}
except Exception as e:
return {"error": True, "message": f"용어집 조회 중 오류 발생: {str(e)}"}
@tool
def get_query_examples(
- gms_server: str = "http://35.222.65.99:8080",
- start: int = 0,
- count: int = 10,
- query: str = "*",
+ query: str,
+ force_update: bool = False,
+ count: int = 5,
) -> list:
"""
- DataHub에서 저장된 쿼리 예제들을 조회합니다.
+ DataHub에서 저장된 쿼리 예제들을 검색합니다.
- 이 함수는 DataHub 서버에 연결하여 저장된 SQL 쿼리 목록을 가져옵니다.
+ 이 함수는 사용자의 질문과 관련된 SQL 쿼리 예제를 찾기 위해 Vector Search를 수행합니다.
조직에서 실제로 사용되고 검증된 쿼리 패턴을 참고하여 더 정확한 SQL을 생성할 수 있습니다.
**중요**: 사용자의 질문이나 대화에서 다음과 같은 상황이 발생하면 반드시 이 도구를 사용하세요:
@@ -220,11 +192,9 @@ def get_query_examples(
6. 조직 내에서 검증된 쿼리 작성 방식을 확인해야 할 때
Args:
- gms_server (str, optional): DataHub GMS 서버 URL입니다.
- 기본값은 "http://35.222.65.99:8080"
- start (int, optional): 조회 시작 위치입니다. 기본값은 0
- count (int, optional): 조회할 쿼리 개수입니다. 기본값은 10
- query (str, optional): 검색 쿼리입니다. 기본값은 "*" (모든 쿼리)
+ query (str): 검색할 쿼리 관련 질문이나 키워드입니다.
+ force_update (bool, optional): True일 경우 데이터를 새로고침하여 검색 인덱스를 재생성합니다. 기본값은 False.
+ count (int, optional): 반환할 검색 결과 개수입니다. 기본값은 5.
Returns:
list: 쿼리 정보 리스트입니다.
@@ -237,19 +207,12 @@ def get_query_examples(
"description": "각 고객별 주문 건수를 집계하는 쿼리",
"statement": "SELECT customer_id, COUNT(*) as order_count FROM orders GROUP BY customer_id"
},
- {
- "name": "월별 매출 현황",
- "description": "월별 총 매출을 계산하는 쿼리",
- "statement": "SELECT DATE_TRUNC('month', order_date) as month, SUM(amount) FROM orders GROUP BY month"
- }
+ ...
]
Examples:
- >>> get_query_examples()
- [{'name': '고객별 주문 수 조회', 'description': '...', 'statement': 'SELECT ...'}]
-
- >>> get_query_examples(count=5)
- # 5개의 쿼리 예제만 조회
+ >>> get_query_examples("매출 집계 쿼리 보여줘")
+ [{'name': '월별 매출 현황', 'description': '...', 'statement': 'SELECT ...'}]
Note:
이 도구는 다음과 같은 경우에 **반드시** 사용하세요:
@@ -280,33 +243,10 @@ def get_query_examples(
SQL을 생성하는 데 큰 도움이 됩니다.
"""
try:
- # DataHub 클라이언트 초기화
- client = DataHubBaseClient(gms_server=gms_server)
-
- # QueryService 초기화
- query_service = QueryService(client)
-
- # 쿼리 데이터 가져오기
- result = query_service.get_query_data(start=start, count=count, query=query)
-
- # 오류 체크
- if "error" in result and result["error"]:
- return {"error": True, "message": result.get("message")}
-
- # name, description, statement만 추출하여 리스트 생성
- simplified_queries = []
- for query_item in result.get("queries", []):
- simplified_query = {
- "name": query_item.get("name"),
- "description": query_item.get("description", ""),
- "statement": query_item.get("statement", ""),
- }
- simplified_queries.append(simplified_query)
-
- return simplified_queries
+ return search_query_examples(
+ query=query, force_update=force_update, top_n=count
+ )
- except ValueError as e:
- return {"error": True, "message": f"DataHub 서버 연결 실패: {str(e)}"}
except Exception as e:
return {
"error": True,
diff --git a/utils/llm/tools/datahub.py b/utils/llm/tools/datahub.py
index 42e564d..7fb87cb 100644
--- a/utils/llm/tools/datahub.py
+++ b/utils/llm/tools/datahub.py
@@ -1,12 +1,16 @@
import os
import re
+import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Iterable, List, Optional, TypeVar
from langchain.schema import Document
from tqdm import tqdm
+from utils.data.datahub_services.glossary_service import GlossaryService
+from utils.data.datahub_services.query_service import QueryService
from utils.data.datahub_source import DatahubMetadataFetcher
+from utils.data.datahub_services.base_client import DataHubBaseClient
T = TypeVar("T")
R = TypeVar("R")
@@ -76,7 +80,7 @@ def _get_table_info(max_workers: int = 8) -> Dict[str, str]:
def _get_column_info(
- table_name: str, urn_table_mapping: Dict[str, str], max_workers: int = 8
+ table_name: str, urn_table_mapping: Dict[str, str]
) -> List[Dict[str, str]]:
target_urn = urn_table_mapping.get(table_name)
if not target_urn:
@@ -103,7 +107,21 @@ def _extract_dataset_name_from_urn(urn: str) -> Optional[str]:
return None
-def get_info_from_db(max_workers: int = 8) -> List[Document]:
+def get_metadata_from_db() -> List[Dict]:
+ fetcher = _get_fetcher()
+ urns = list(fetcher.get_urns())
+
+ metadata = []
+ total = len(urns)
+ for idx, urn in enumerate(urns, 1):
+ print(f"[{idx}/{total}] Processing URN: {urn}")
+ table_metadata = fetcher.build_table_metadata(urn)
+ metadata.append(table_metadata)
+
+ return metadata
+
+
+def _prepare_datahub_metadata_mappings(max_workers: int = 8):
table_info = _get_table_info(max_workers=max_workers)
fetcher = _get_fetcher()
@@ -118,20 +136,31 @@ def get_info_from_db(max_workers: int = 8) -> List[Document]:
if parsed_name:
display_name_by_table[original_name] = parsed_name
- def process_table_info(item: tuple[str, str, str]) -> str:
- original_table_name, table_description, display_table_name = item
- # 컬럼 조회는 기존 테이블 이름으로 수행 (urn_table_mapping과 일치)
- column_info = _get_column_info(
- original_table_name, urn_table_mapping, max_workers=max_workers
- )
- column_info_str = "\n".join(
- [
- f"{col['column_name']}: {col['column_description']}"
- for col in column_info
- ]
- )
- used_name = display_table_name or original_table_name
- return f"{used_name}: {table_description}\nColumns:\n {column_info_str}"
+ return table_info, urn_table_mapping, display_name_by_table
+
+
+def _format_datahub_table_info(
+ item: tuple[str, str, str], urn_table_mapping: Dict[str, str]
+) -> Dict:
+ original_table_name, table_description, display_table_name = item
+ # 컬럼 조회는 기존 테이블 이름으로 수행 (urn_table_mapping과 일치)
+ column_info = _get_column_info(original_table_name, urn_table_mapping)
+
+ columns = {col["column_name"]: col["column_description"] for col in column_info}
+
+ used_name = display_table_name or original_table_name
+ return {
+ used_name: {
+ "table_description": table_description,
+ "columns": columns,
+ }
+ }
+
+
+def get_table_schema(max_workers: int = 8) -> List[Dict]:
+ table_info, urn_table_mapping, display_name_by_table = (
+ _prepare_datahub_metadata_mappings(max_workers)
+ )
# 표시용 이름을 세 번째 파라미터로 함께 전달
items_with_display = [
@@ -143,25 +172,116 @@ def process_table_info(item: tuple[str, str, str]) -> str:
for name, desc in table_info.items()
]
- table_info_str_list = parallel_process(
+ # parallel_process에 전달할 함수 래핑
+ def process_fn(item):
+ return _format_datahub_table_info(item, urn_table_mapping)
+
+ table_info_list = parallel_process(
items_with_display,
- process_table_info,
+ process_fn,
max_workers=max_workers,
desc="컬럼 정보 수집 중",
)
- return [Document(page_content=info) for info in table_info_str_list]
-
+ return table_info_list
-def get_metadata_from_db() -> List[Dict]:
- fetcher = _get_fetcher()
- urns = list(fetcher.get_urns())
-
- metadata = []
- total = len(urns)
- for idx, urn in enumerate(urns, 1):
- print(f"[{idx}/{total}] Processing URN: {urn}")
- table_metadata = fetcher.build_table_metadata(urn)
- metadata.append(table_metadata)
- return metadata
+def get_glossary_vector_data() -> List[Dict]:
+ """
+ Vector Search를 위한 용어집 데이터를 조회하고 포맷팅합니다.
+ """
+ gms_server = os.getenv("DATAHUB_SERVER", "http://35.222.65.99:8080")
+ client = DataHubBaseClient(gms_server=gms_server)
+ glossary_service = GlossaryService(client)
+
+ glossary_data = glossary_service.get_glossary_data()
+
+ points = []
+ if "error" in glossary_data:
+ print(f"Error fetching glossary data: {glossary_data.get('message')}")
+ return points
+
+ # Flatten the glossary structure
+ def process_node(node):
+ # Current node
+ name = node.get("name")
+ description = node.get("description", "")
+
+ # Create point for the node itself if it has meaningful content
+ if name:
+ # Generate deterministic UUID based on name
+ point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, name))
+ points.append(
+ {
+ "id": point_id,
+ "vector": {}, # Placeholder, will be embedded later
+ "payload": {
+ "name": name,
+ "description": description,
+ "type": "term", # or node
+ },
+ }
+ )
+
+ # Process children
+ if "details" in node and "children" in node["details"]:
+ for child in node["details"]["children"]:
+ child_name = child.get("name")
+ child_desc = child.get("description", "")
+ if child_name:
+ child_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, child_name))
+ points.append(
+ {
+ "id": child_id,
+ "vector": {},
+ "payload": {
+ "name": child_name,
+ "description": child_desc,
+ "type": "term",
+ },
+ }
+ )
+
+ for node in glossary_data.get("nodes", []):
+ process_node(node)
+
+ return points
+
+
+def get_query_vector_data() -> List[Dict]:
+ """
+ Vector Search를 위한 쿼리 예제 데이터를 조회하고 포맷팅합니다.
+ """
+ gms_server = os.getenv("DATAHUB_SERVER", "http://35.222.65.99:8080")
+ client = DataHubBaseClient(gms_server=gms_server)
+ query_service = QueryService(client)
+
+ # Fetch all queries (adjust count as needed)
+ query_data = query_service.get_query_data(count=1000)
+
+ points = []
+ if "error" in query_data:
+ print(f"Error fetching query data: {query_data.get('message')}")
+ return points
+
+ for query in query_data.get("queries", []):
+ name = query.get("name")
+ description = query.get("description", "")
+ statement = query.get("statement", "")
+
+ if name and statement:
+ # Generate deterministic UUID based on name
+ point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, name))
+ points.append(
+ {
+ "id": point_id,
+ "vector": {},
+ "payload": {
+ "name": name,
+ "description": description,
+ "statement": statement,
+ },
+ }
+ )
+
+ return points
diff --git a/utils/llm/vectordb/README.md b/utils/llm/vectordb/README.md
index 356ffcd..2a31981 100644
--- a/utils/llm/vectordb/README.md
+++ b/utils/llm/vectordb/README.md
@@ -52,13 +52,13 @@ utils/llm/vectordb/
- `vectordb_path`: 저장 경로 (기본: `dev/table_info_db`)
- 동작 방식:
- 기존 DB가 있으면 `FAISS.load_local()`로 로드
- - 없으면 `get_info_from_db()`로 문서 수집 후 `FAISS.from_documents()` 생성 및 저장
+ - 없으면 `get_table_schema()`로 문서 수집 후 `FAISS.from_documents()` 생성 및 저장
- 반환: FAISS 벡터스토어 인스턴스
**의존성**:
- `langchain_community.vectorstores.FAISS`: LangChain FAISS 래퍼
- `utils.llm.core.get_embeddings`: 임베딩 모델 로드
-- `utils.llm.tools.get_info_from_db`: DataHub에서 테이블 메타데이터 수집
+- `utils.llm.tools.get_table_schema`: DataHub에서 테이블 메타데이터 수집
**특징**:
- 로컬 디스크에 저장되어 네트워크 연결 불필요
@@ -84,7 +84,7 @@ utils/llm/vectordb/
- `PGVECTOR_COLLECTION`: "lang2sql_table_info_db"
- 동작 방식:
- 기존 컬렉션이 있고 비어있지 않으면 로드
- - 없거나 비어있으면 `get_info_from_db()`로 문서 수집 후 `PGVector.from_documents()` 생성
+ - 없거나 비어있으면 `get_table_schema()`로 문서 수집 후 `PGVector.from_documents()` 생성
- 반환: PGVector 벡터스토어 인스턴스
2. **`_check_collection_exists(connection_string, collection_name)`**
@@ -96,7 +96,7 @@ utils/llm/vectordb/
- `langchain_postgres.vectorstores.PGVector`: LangChain pgvector 래퍼
- `psycopg2`: PostgreSQL 연결
- `utils.llm.core.get_embeddings`: 임베딩 모델 로드
-- `utils.llm.tools.get_info_from_db`: DataHub에서 테이블 메타데이터 수집
+- `utils.llm.tools.get_table_schema`: DataHub에서 테이블 메타데이터 수집
**특징**:
- PostgreSQL 데이터베이스에 저장되어 다중 서버 환경에 적합
@@ -181,7 +181,7 @@ export PGVECTOR_COLLECTION=lang2sql_table_info_db
**내부 의존성**:
- `utils/llm/core/factory.py`: `get_embeddings()` - 임베딩 모델 로드
-- `utils/llm/tools/datahub.py`: `get_info_from_db()` - DataHub 메타데이터 수집
+- `utils/llm/tools/datahub.py`: `get_table_schema()` - DataHub 메타데이터 수집
**외부 의존성**:
- `langchain_community.vectorstores.FAISS`: FAISS 벡터스토어
diff --git a/utils/llm/vectordb/factory.py b/utils/llm/vectordb/factory.py
index 942a443..68a13b1 100644
--- a/utils/llm/vectordb/factory.py
+++ b/utils/llm/vectordb/factory.py
@@ -7,6 +7,20 @@
from utils.llm.vectordb.faiss_db import get_faiss_vector_db
from utils.llm.vectordb.pgvector_db import get_pgvector_db
+from utils.llm.vectordb.qdrant_db import QdrantDB
+
+
+def get_qdrant_vector_db(url: Optional[str] = None, api_key: Optional[str] = None):
+ """Qdrant VectorDB 인스턴스를 반환하고 초기화합니다."""
+ if url is None:
+ url = os.getenv("QDRANT_URL", "http://localhost:6333")
+
+ if api_key is None:
+ api_key = os.getenv("QDRANT_API_KEY")
+
+ db = QdrantDB(url=url, api_key=api_key)
+ db.initialize_collection_if_empty()
+ return db
def get_vector_db(
@@ -16,11 +30,11 @@ def get_vector_db(
VectorDB 타입과 위치에 따라 적절한 VectorDB 인스턴스를 반환합니다.
Args:
- vectordb_type: VectorDB 타입 ("faiss" 또는 "pgvector"). None인 경우 환경 변수에서 읽음.
+ vectordb_type: VectorDB 타입 ("faiss", "pgvector", "qdrant"). None인 경우 환경 변수에서 읽음.
vectordb_location: VectorDB 위치 (FAISS: 디렉토리 경로, pgvector: 연결 문자열). None인 경우 환경 변수에서 읽음.
Returns:
- VectorDB 인스턴스 (FAISS 또는 PGVector)
+ VectorDB 인스턴스 (FAISS, PGVector, 또는 Qdrant)
"""
if vectordb_type is None:
vectordb_type = os.getenv("VECTORDB_TYPE", "faiss").lower()
@@ -32,7 +46,9 @@ def get_vector_db(
return get_faiss_vector_db(vectordb_location)
elif vectordb_type == "pgvector":
return get_pgvector_db(vectordb_location)
+ elif vectordb_type == "qdrant":
+ return get_qdrant_vector_db(url=vectordb_location)
else:
raise ValueError(
- f"지원하지 않는 VectorDB 타입: {vectordb_type}. 'faiss' 또는 'pgvector'를 사용하세요."
+ f"지원하지 않는 VectorDB 타입: {vectordb_type}. 'faiss', 'pgvector', 또는 'qdrant'를 사용하세요."
)
diff --git a/utils/llm/vectordb/faiss_db.py b/utils/llm/vectordb/faiss_db.py
index d4754a5..0b48d01 100644
--- a/utils/llm/vectordb/faiss_db.py
+++ b/utils/llm/vectordb/faiss_db.py
@@ -6,9 +6,10 @@
from typing import Optional
from langchain_community.vectorstores import FAISS
+from langchain.schema import Document
from utils.llm.core import get_embeddings
-from utils.llm.tools import get_info_from_db
+from utils.llm.tools import get_table_schema
def get_faiss_vector_db(vectordb_path: Optional[str] = None):
@@ -26,7 +27,15 @@ def get_faiss_vector_db(vectordb_path: Optional[str] = None):
allow_dangerous_deserialization=True,
)
except:
- documents = get_info_from_db()
+ raw_data = get_table_schema()
+ documents = []
+ for item in raw_data:
+ for table_name, table_info in item.items():
+ column_info_str = "\n".join(
+ [f"{k}: {v}" for k, v in table_info["columns"].items()]
+ )
+ page_content = f"{table_name}: {table_info['table_description']}\nColumns:\n {column_info_str}"
+ documents.append(Document(page_content=page_content))
db = FAISS.from_documents(documents, embeddings)
db.save_local(vectordb_path)
print(f"VectorDB를 새로 생성했습니다: {vectordb_path}")
diff --git a/utils/llm/vectordb/pgvector_db.py b/utils/llm/vectordb/pgvector_db.py
index d03f034..edba041 100644
--- a/utils/llm/vectordb/pgvector_db.py
+++ b/utils/llm/vectordb/pgvector_db.py
@@ -7,9 +7,10 @@
import psycopg2
from langchain_postgres.vectorstores import PGVector
+from langchain.schema import Document
from utils.llm.core import get_embeddings
-from utils.llm.tools import get_info_from_db
+from utils.llm.tools import get_table_schema
def _check_collection_exists(connection_string: str, collection_name: str) -> bool:
@@ -71,7 +72,15 @@ def get_pgvector_db(
except Exception as e:
print(f"exception: {e}")
# 컬렉션이 없거나 불러오기에 실패한 경우, 문서를 다시 인덱싱
- documents = get_info_from_db()
+ raw_data = get_table_schema()
+ documents = []
+ for item in raw_data:
+ for table_name, table_info in item.items():
+ column_info_str = "\n".join(
+ [f"{k}: {v}" for k, v in table_info["columns"].items()]
+ )
+ page_content = f"{table_name}: {table_info['table_description']}\nColumns:\n {column_info_str}"
+ documents.append(Document(page_content=page_content))
vector_store = PGVector.from_documents(
documents=documents,
embedding=embeddings,
diff --git a/utils/llm/vectordb/qdrant_db.py b/utils/llm/vectordb/qdrant_db.py
new file mode 100644
index 0000000..6369ef8
--- /dev/null
+++ b/utils/llm/vectordb/qdrant_db.py
@@ -0,0 +1,262 @@
+from qdrant_client import QdrantClient, models
+from typing import List, Dict, Any, Optional, Union, Callable
+import os
+import uuid
+from dotenv import load_dotenv
+
+load_dotenv()
+
+
+class QdrantDB:
+ def __init__(
+ self, url: str = "http://localhost:6333", api_key: Optional[str] = None
+ ):
+ """
+ Qdrant 클라이언트를 초기화합니다.
+
+ Args:
+ url: Qdrant 서버 URL.
+ api_key: Qdrant 클라우드 또는 인증된 인스턴스를 위한 API 키.
+ """
+ self.client = QdrantClient(url=url, api_key=api_key)
+
+ def create_collection(
+ self, collection_name: str, dense_dim: int = 1536, colbert_dim: int = 128
+ ):
+ """
+ Dense, ColBERT, Sparse 벡터 구성을 포함한 컬렉션을 생성합니다.
+
+ Args:
+ collection_name: 생성할 컬렉션의 이름.
+ dense_dim: Dense 벡터의 차원 (기본값: OpenAI small 모델 기준 1536).
+ colbert_dim: ColBERT 벡터의 차원 (기본값: 128).
+ """
+ if not self.client.collection_exists(collection_name):
+ self.client.create_collection(
+ collection_name=collection_name,
+ vectors_config={
+ "dense": models.VectorParams(
+ size=dense_dim, distance=models.Distance.COSINE
+ ),
+ "colbert": models.VectorParams(
+ size=colbert_dim,
+ distance=models.Distance.COSINE,
+ multivector_config=models.MultiVectorConfig(
+ comparator=models.MultiVectorComparator.MAX_SIM
+ ),
+ hnsw_config=models.HnswConfigDiff(m=0),
+ ),
+ },
+ sparse_vectors_config={"sparse": models.SparseVectorParams()},
+ )
+ print(f"Collection '{collection_name}' created.")
+ else:
+ print(f"Collection '{collection_name}' already exists.")
+
+ def upsert(self, collection_name: str, points: List[Dict[str, Any]]):
+ """
+ 컬렉션에 포인트들을 업서트(Upsert)합니다.
+
+ Args:
+ collection_name: 컬렉션 이름.
+ points: 다음 항목들을 포함하는 딕셔너리 리스트:
+ - id: 고유 식별자 (int 또는 str)
+ - vector: 'dense', 'colbert', 'sparse' 키와 해당 벡터 값을 포함하는 딕셔너리.
+ - payload: 메타데이터를 포함하는 딕셔너리.
+ """
+ point_structs = []
+ for point in points:
+ if "id" not in point or "vector" not in point:
+ raise ValueError("Each point must contain 'id' and 'vector' keys.")
+
+ point_structs.append(
+ models.PointStruct(
+ id=point["id"],
+ vector=point["vector"],
+ payload=point.get("payload", {}),
+ )
+ )
+
+ self.client.upload_points(collection_name=collection_name, points=point_structs)
+ print(
+ f"Successfully upserted {len(point_structs)} points to '{collection_name}'."
+ )
+
+ def search(
+ self,
+ collection_name: str,
+ query_vector: Union[List[float], tuple],
+ query_filter: Optional[models.Filter] = None,
+ limit: int = 10,
+ with_payload: bool = True,
+ ) -> List[models.ScoredPoint]:
+ """
+ 특정 컬렉션에서 벡터 검색을 수행합니다.
+
+ Args:
+ collection_name: 검색할 컬렉션의 이름.
+ query_vector: 검색에 사용할 쿼리 벡터. 명명된 벡터를 사용하는 경우 ('vector_name', vector) 튜플로 전달해야 합니다.
+ query_filter: 검색 시 적용할 필터 (선택 사항).
+ limit: 반환할 결과의 최대 개수 (기본값: 10).
+ with_payload: 결과에 페이로드를 포함할지 여부 (기본값: True).
+
+ Returns:
+ 검색 결과 리스트 (ScoredPoint 객체들의 리스트).
+ """
+ return self.client.search(
+ collection_name=collection_name,
+ query_vector=query_vector,
+ query_filter=query_filter,
+ limit=limit,
+ with_payload=with_payload,
+ )
+
+ def similarity_search(
+ self, query: str, k: int = 5, collection_name: str = "lang2sql_table_schema"
+ ) -> List[Any]:
+ """
+ LangChain 호환성을 위한 유사도 검색 메서드.
+
+ Args:
+ query: 검색 쿼리 문자열.
+ k: 반환할 결과 개수.
+ collection_name: 검색할 컬렉션 이름.
+
+ Returns:
+ LangChain Document 객체 리스트.
+ """
+ from langchain.schema import Document
+ from utils.llm.core import get_embeddings
+
+ embeddings = get_embeddings()
+ query_vector = embeddings.embed_query(query)
+
+ results = self.search(
+ collection_name=collection_name,
+ query_vector=("dense", query_vector),
+ limit=k,
+ )
+
+ documents = []
+ for res in results:
+ payload = res.payload
+ # payload를 page_content와 metadata로 변환
+ # 여기서는 payload의 모든 내용을 metadata로 넣고,
+ # 특정 필드를 page_content로 구성하거나 payload 전체를 문자열로 변환
+
+ # 기존 faiss_db.py의 로직을 참고하여 page_content 구성
+ # table_name: table_description
+ # Columns:
+ # col1: desc1
+
+ table_name = payload.get("table_name", "Unknown Table")
+ table_description = payload.get("table_description", "")
+ columns = payload.get("columns", {})
+
+ column_info_str = "\n".join(
+ [f"{key}: {val}" for key, val in columns.items()]
+ )
+ page_content = (
+ f"{table_name}: {table_description}\nColumns:\n {column_info_str}"
+ )
+
+ documents.append(Document(page_content=page_content, metadata=payload))
+
+ return documents
+
+ def as_retriever(self, search_kwargs: Optional[Dict] = None):
+ """
+ LangChain Retriever 인터페이스 호환 메서드.
+ """
+ return self
+
+ def invoke(self, query: str):
+ """
+ Retriever 인터페이스의 invoke 메서드 구현.
+ """
+ # search_kwargs에서 k 값 가져오기 (기본값 5)
+ # as_retriever 호출 시 저장된 설정이 있다면 그것을 사용해야 하지만,
+ # 여기서는 간단하게 구현
+ return self.similarity_search(query)
+
+ def _get_table_schema_points(self) -> List[Dict[str, Any]]:
+ """
+ 기본 테이블 스키마 정보를 가져와서 포인트 리스트로 변환합니다.
+ """
+ from utils.llm.tools.datahub import get_table_schema
+ from utils.llm.core import get_embeddings
+
+ raw_data = get_table_schema()
+ embeddings = get_embeddings()
+
+ points = []
+ for idx, item in enumerate(raw_data):
+ for table_name, table_info in item.items():
+ # 벡터 생성을 위한 텍스트 구성
+ column_info_str = "\n".join(
+ [f"{k}: {v}" for k, v in table_info["columns"].items()]
+ )
+ text_to_embed = f"{table_name}: {table_info['table_description']}"
+
+ vector = embeddings.embed_query(text_to_embed)
+
+ # payload 구성
+ payload = {
+ "table_name": table_name,
+ "table_description": table_info["table_description"],
+ "columns": table_info["columns"],
+ }
+
+ # Generate deterministic UUID based on table_name
+ point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name))
+
+ points.append(
+ {
+ "id": point_id,
+ "vector": {"dense": vector}, # dense vector only for now
+ "payload": payload,
+ }
+ )
+ return points
+
+ def initialize_collection_if_empty(
+ self,
+ collection_name: str = "lang2sql_table_schema",
+ force_update: bool = False,
+ data_loader: Optional[Callable[[], List[Dict[str, Any]]]] = None,
+ ):
+ """
+ 컬렉션이 비어있거나 없으면 데이터를 채웁니다.
+
+ Args:
+ collection_name: 초기화할 컬렉션 이름.
+ force_update: 데이터가 있어도 강제로 업데이트할지 여부.
+ data_loader: 데이터를 가져오는 함수. 포인트 리스트(id, vector, payload)를 반환해야 합니다.
+ None인 경우 기본 테이블 스키마 로더를 사용합니다.
+ """
+ # 컬렉션 존재 여부 확인 및 생성
+ if not self.client.collection_exists(collection_name):
+ self.create_collection(collection_name)
+
+ # 데이터 존재 여부 확인
+ if not force_update:
+ count_result = self.client.count(collection_name=collection_name)
+ if count_result.count > 0:
+ print(
+ f"Collection '{collection_name}' is not empty. Skipping initialization."
+ )
+ return
+
+ print(f"Initializing collection '{collection_name}'...")
+
+ # 데이터 로드
+ if data_loader is None:
+ # 기본 동작: 테이블 스키마 정보 사용
+ points = self._get_table_schema_points()
+ else:
+ points = data_loader()
+
+ if points:
+ self.upsert(collection_name, points)
+ else:
+ print("No data found to initialize.")