Handles both thinking and non thinking models now
Browse files- agent/agent_runner.py +21 -4
- agent/llm.py +11 -1
- agent/state.py +32 -1
agent/agent_runner.py
CHANGED
|
@@ -16,7 +16,7 @@ from __future__ import annotations
|
|
| 16 |
|
| 17 |
import json
|
| 18 |
import logging
|
| 19 |
-
from typing import Any, Dict
|
| 20 |
|
| 21 |
from agent.state import AgentState
|
| 22 |
from agent.memory import get_memory, SQLBufferMemory
|
|
@@ -37,6 +37,23 @@ logger.setLevel(logging.DEBUG)
|
|
| 37 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
# Helper β normalise LLM / chain return types
|
| 39 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def _extract_text(res: Any) -> str:
|
| 41 |
if isinstance(res, str):
|
| 42 |
return res
|
|
@@ -97,7 +114,7 @@ def run_agent_step(state: AgentState, user_input: str, conversation_id: str):
|
|
| 97 |
field_fallback_raw = parsed.get("update_needed_values", {}) or {}
|
| 98 |
if field_fallback_raw:
|
| 99 |
if isinstance(field_fallback_raw, dict):
|
| 100 |
-
state.needed_fields.update(field_fallback_raw)
|
| 101 |
else:
|
| 102 |
# Handle cases where the LLM returns a list/tuple of keyβvalue pairs
|
| 103 |
try:
|
|
@@ -131,7 +148,7 @@ def run_agent_step(state: AgentState, user_input: str, conversation_id: str):
|
|
| 131 |
new_vals_raw = parsed.get("update_needed_values", {})
|
| 132 |
if new_vals_raw:
|
| 133 |
if isinstance(new_vals_raw, dict):
|
| 134 |
-
state.needed_fields.update(new_vals_raw)
|
| 135 |
else:
|
| 136 |
try:
|
| 137 |
state.needed_fields.update(dict(new_vals_raw))
|
|
@@ -223,7 +240,7 @@ def run_agent_step(state: AgentState, user_input: str, conversation_id: str):
|
|
| 223 |
|
| 224 |
if follow_parsed:
|
| 225 |
new_vals = follow_parsed.get("update_needed_values", {})
|
| 226 |
-
state.needed_fields.update(new_vals)
|
| 227 |
|
| 228 |
# We echo the missing description back to the user
|
| 229 |
reply_to_user = check.missing_desc
|
|
|
|
| 16 |
|
| 17 |
import json
|
| 18 |
import logging
|
| 19 |
+
from typing import Any, Dict, Union
|
| 20 |
|
| 21 |
from agent.state import AgentState
|
| 22 |
from agent.memory import get_memory, SQLBufferMemory
|
|
|
|
| 37 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
# Helper β normalise LLM / chain return types
|
| 39 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 40 |
+
def _stringify_values(d: Dict[str, Any]) -> Dict[str, str]:
|
| 41 |
+
"""Ensure all dict values are plain strings (JSON-encoded if needed)."""
|
| 42 |
+
out: Dict[str, str] = {}
|
| 43 |
+
for k, v in d.items():
|
| 44 |
+
if isinstance(v, str):
|
| 45 |
+
out[k] = v
|
| 46 |
+
else:
|
| 47 |
+
# Fallback β JSON-encode non-string values (e.g. nested dicts)
|
| 48 |
+
try:
|
| 49 |
+
out[k] = json.dumps(v, ensure_ascii=False)
|
| 50 |
+
except TypeError:
|
| 51 |
+
out[k] = str(v)
|
| 52 |
+
return out
|
| 53 |
+
|
| 54 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 55 |
+
# Helper β normalise LLM / chain return types
|
| 56 |
+
|
| 57 |
def _extract_text(res: Any) -> str:
|
| 58 |
if isinstance(res, str):
|
| 59 |
return res
|
|
|
|
| 114 |
field_fallback_raw = parsed.get("update_needed_values", {}) or {}
|
| 115 |
if field_fallback_raw:
|
| 116 |
if isinstance(field_fallback_raw, dict):
|
| 117 |
+
state.needed_fields.update(_stringify_values(field_fallback_raw))
|
| 118 |
else:
|
| 119 |
# Handle cases where the LLM returns a list/tuple of keyβvalue pairs
|
| 120 |
try:
|
|
|
|
| 148 |
new_vals_raw = parsed.get("update_needed_values", {})
|
| 149 |
if new_vals_raw:
|
| 150 |
if isinstance(new_vals_raw, dict):
|
| 151 |
+
state.needed_fields.update(_stringify_values(new_vals_raw))
|
| 152 |
else:
|
| 153 |
try:
|
| 154 |
state.needed_fields.update(dict(new_vals_raw))
|
|
|
|
| 240 |
|
| 241 |
if follow_parsed:
|
| 242 |
new_vals = follow_parsed.get("update_needed_values", {})
|
| 243 |
+
state.needed_fields.update(_stringify_values(new_vals))
|
| 244 |
|
| 245 |
# We echo the missing description back to the user
|
| 246 |
reply_to_user = check.missing_desc
|
agent/llm.py
CHANGED
|
@@ -19,6 +19,7 @@ logger.setLevel(logging.INFO)
|
|
| 19 |
# ββββββββββββββββββββββββββββββββ
|
| 20 |
load_dotenv()
|
| 21 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
|
| 22 |
if not HF_TOKEN:
|
| 23 |
raise RuntimeError("Missing HF_TOKEN environment variable")
|
| 24 |
|
|
@@ -35,8 +36,17 @@ llm_chain: Runnable = ChatOpenAI(
|
|
| 35 |
streaming=True,
|
| 36 |
)
|
| 37 |
|
| 38 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
|
|
|
| 40 |
# llm_chain: Runnable = ChatOpenAI(
|
| 41 |
# model="deepseek-chat", # Optional for HF OpenAI-compatible endpoints, but kept for clarity
|
| 42 |
# openai_api_base="https://qwryad273mlndckn.us-east-1.aws.endpoints.huggingface.cloud/v1/",
|
|
|
|
| 19 |
# ββββββββββββββββββββββββββββββββ
|
| 20 |
load_dotenv()
|
| 21 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 22 |
+
print("HF_TOKEN", HF_TOKEN)
|
| 23 |
if not HF_TOKEN:
|
| 24 |
raise RuntimeError("Missing HF_TOKEN environment variable")
|
| 25 |
|
|
|
|
| 36 |
streaming=True,
|
| 37 |
)
|
| 38 |
|
| 39 |
+
#DeepSeekOpenRouter
|
| 40 |
+
# llm_chain: Runnable = ChatOpenAI(
|
| 41 |
+
# model="deepseek/deepseek-chat-v3-0324:free", # Optional for HF OpenAI-compatible endpoints, but kept for clarity
|
| 42 |
+
# openai_api_base="https://openrouter.ai/api/v1",
|
| 43 |
+
# openai_api_key=os.getenv("OPENROUTER_API_KEY"),
|
| 44 |
+
# temperature=0.3,
|
| 45 |
+
# max_tokens=8192,
|
| 46 |
+
# streaming=True,
|
| 47 |
+
# )
|
| 48 |
|
| 49 |
+
#DeepSeekR1-1q1
|
| 50 |
# llm_chain: Runnable = ChatOpenAI(
|
| 51 |
# model="deepseek-chat", # Optional for HF OpenAI-compatible endpoints, but kept for clarity
|
| 52 |
# openai_api_base="https://qwryad273mlndckn.us-east-1.aws.endpoints.huggingface.cloud/v1/",
|
agent/state.py
CHANGED
|
@@ -4,7 +4,9 @@ agent/state.py
|
|
| 4 |
Canonical state shared by all chains & the runner.
|
| 5 |
"""
|
| 6 |
|
| 7 |
-
|
|
|
|
|
|
|
| 8 |
from pydantic import BaseModel
|
| 9 |
|
| 10 |
|
|
@@ -18,6 +20,35 @@ class AgentState(BaseModel):
|
|
| 18 |
is_drafted: bool = False # True β’ draft has been generated
|
| 19 |
missing_prompt_count: int = 0 # Times user has been asked for missing details
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# ββββββββββββββββββββββββββββββ
|
| 22 |
# Internal helpers (not exposed)
|
| 23 |
# ββββββββββββββββββββββββββββββ
|
|
|
|
| 4 |
Canonical state shared by all chains & the runner.
|
| 5 |
"""
|
| 6 |
|
| 7 |
+
import json
|
| 8 |
+
from typing import Dict, Any
|
| 9 |
+
from pydantic import BaseModel, field_validator
|
| 10 |
from pydantic import BaseModel
|
| 11 |
|
| 12 |
|
|
|
|
| 20 |
is_drafted: bool = False # True β’ draft has been generated
|
| 21 |
missing_prompt_count: int = 0 # Times user has been asked for missing details
|
| 22 |
|
| 23 |
+
# ββββββββββββββββββββββββββββββ
|
| 24 |
+
# Validation β ensure string values
|
| 25 |
+
# ββββββββββββββββββββββββββββββ
|
| 26 |
+
@field_validator('needed_fields', mode='before')
|
| 27 |
+
@classmethod
|
| 28 |
+
def _coerce_needed_fields(cls, v: Any):
|
| 29 |
+
"""Ensure needed_fields is Dict[str, str] regardless of input shapes."""
|
| 30 |
+
if v is None:
|
| 31 |
+
return {}
|
| 32 |
+
if isinstance(v, dict):
|
| 33 |
+
str_dict: Dict[str, str] = {}
|
| 34 |
+
for k, val in v.items():
|
| 35 |
+
if isinstance(val, str):
|
| 36 |
+
str_dict[str(k)] = val
|
| 37 |
+
else:
|
| 38 |
+
try:
|
| 39 |
+
str_dict[str(k)] = json.dumps(val, ensure_ascii=False)
|
| 40 |
+
except TypeError:
|
| 41 |
+
str_dict[str(k)] = str(val)
|
| 42 |
+
return str_dict
|
| 43 |
+
# Accept list of pairs etc.
|
| 44 |
+
try:
|
| 45 |
+
as_dict = dict(v)
|
| 46 |
+
return cls._coerce_needed_fields(as_dict)
|
| 47 |
+
except Exception:
|
| 48 |
+
# Fallback: ignore invalid format
|
| 49 |
+
return {}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
# ββββββββββββββββββββββββββββββ
|
| 53 |
# Internal helpers (not exposed)
|
| 54 |
# ββββββββββββββββββββββββββββββ
|