MrTechie commited on
Commit
2932504
Β·
1 Parent(s): 8208a63

Handles both thinking and non thinking models now

Browse files
Files changed (3) hide show
  1. agent/agent_runner.py +21 -4
  2. agent/llm.py +11 -1
  3. agent/state.py +32 -1
agent/agent_runner.py CHANGED
@@ -16,7 +16,7 @@ from __future__ import annotations
16
 
17
  import json
18
  import logging
19
- from typing import Any, Dict
20
 
21
  from agent.state import AgentState
22
  from agent.memory import get_memory, SQLBufferMemory
@@ -37,6 +37,23 @@ logger.setLevel(logging.DEBUG)
37
  # ─────────────────────────────────────────────────────────────
38
  # Helper – normalise LLM / chain return types
39
  # ─────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def _extract_text(res: Any) -> str:
41
  if isinstance(res, str):
42
  return res
@@ -97,7 +114,7 @@ def run_agent_step(state: AgentState, user_input: str, conversation_id: str):
97
  field_fallback_raw = parsed.get("update_needed_values", {}) or {}
98
  if field_fallback_raw:
99
  if isinstance(field_fallback_raw, dict):
100
- state.needed_fields.update(field_fallback_raw)
101
  else:
102
  # Handle cases where the LLM returns a list/tuple of key–value pairs
103
  try:
@@ -131,7 +148,7 @@ def run_agent_step(state: AgentState, user_input: str, conversation_id: str):
131
  new_vals_raw = parsed.get("update_needed_values", {})
132
  if new_vals_raw:
133
  if isinstance(new_vals_raw, dict):
134
- state.needed_fields.update(new_vals_raw)
135
  else:
136
  try:
137
  state.needed_fields.update(dict(new_vals_raw))
@@ -223,7 +240,7 @@ def run_agent_step(state: AgentState, user_input: str, conversation_id: str):
223
 
224
  if follow_parsed:
225
  new_vals = follow_parsed.get("update_needed_values", {})
226
- state.needed_fields.update(new_vals)
227
 
228
  # We echo the missing description back to the user
229
  reply_to_user = check.missing_desc
 
16
 
17
  import json
18
  import logging
19
+ from typing import Any, Dict, Union
20
 
21
  from agent.state import AgentState
22
  from agent.memory import get_memory, SQLBufferMemory
 
37
  # ─────────────────────────────────────────────────────────────
38
  # Helper – normalise LLM / chain return types
39
  # ─────────────────────────────────────────────────────────────
40
+ def _stringify_values(d: Dict[str, Any]) -> Dict[str, str]:
41
+ """Ensure all dict values are plain strings (JSON-encoded if needed)."""
42
+ out: Dict[str, str] = {}
43
+ for k, v in d.items():
44
+ if isinstance(v, str):
45
+ out[k] = v
46
+ else:
47
+ # Fallback – JSON-encode non-string values (e.g. nested dicts)
48
+ try:
49
+ out[k] = json.dumps(v, ensure_ascii=False)
50
+ except TypeError:
51
+ out[k] = str(v)
52
+ return out
53
+
54
+ # ─────────────────────────────────────────────────────────────
55
+ # Helper – normalise LLM / chain return types
56
+
57
  def _extract_text(res: Any) -> str:
58
  if isinstance(res, str):
59
  return res
 
114
  field_fallback_raw = parsed.get("update_needed_values", {}) or {}
115
  if field_fallback_raw:
116
  if isinstance(field_fallback_raw, dict):
117
+ state.needed_fields.update(_stringify_values(field_fallback_raw))
118
  else:
119
  # Handle cases where the LLM returns a list/tuple of key–value pairs
120
  try:
 
148
  new_vals_raw = parsed.get("update_needed_values", {})
149
  if new_vals_raw:
150
  if isinstance(new_vals_raw, dict):
151
+ state.needed_fields.update(_stringify_values(new_vals_raw))
152
  else:
153
  try:
154
  state.needed_fields.update(dict(new_vals_raw))
 
240
 
241
  if follow_parsed:
242
  new_vals = follow_parsed.get("update_needed_values", {})
243
+ state.needed_fields.update(_stringify_values(new_vals))
244
 
245
  # We echo the missing description back to the user
246
  reply_to_user = check.missing_desc
agent/llm.py CHANGED
@@ -19,6 +19,7 @@ logger.setLevel(logging.INFO)
19
  # ────────────────────────────────
20
  load_dotenv()
21
  HF_TOKEN = os.getenv("HF_TOKEN")
 
22
  if not HF_TOKEN:
23
  raise RuntimeError("Missing HF_TOKEN environment variable")
24
 
@@ -35,8 +36,17 @@ llm_chain: Runnable = ChatOpenAI(
35
  streaming=True,
36
  )
37
 
38
- #DeepSeekR1-1q1
 
 
 
 
 
 
 
 
39
 
 
40
  # llm_chain: Runnable = ChatOpenAI(
41
  # model="deepseek-chat", # Optional for HF OpenAI-compatible endpoints, but kept for clarity
42
  # openai_api_base="https://qwryad273mlndckn.us-east-1.aws.endpoints.huggingface.cloud/v1/",
 
19
  # ────────────────────────────────
20
  load_dotenv()
21
  HF_TOKEN = os.getenv("HF_TOKEN")
22
+ print("HF_TOKEN", HF_TOKEN)
23
  if not HF_TOKEN:
24
  raise RuntimeError("Missing HF_TOKEN environment variable")
25
 
 
36
  streaming=True,
37
  )
38
 
39
+ #DeepSeekOpenRouter
40
+ # llm_chain: Runnable = ChatOpenAI(
41
+ # model="deepseek/deepseek-chat-v3-0324:free", # Optional for HF OpenAI-compatible endpoints, but kept for clarity
42
+ # openai_api_base="https://openrouter.ai/api/v1",
43
+ # openai_api_key=os.getenv("OPENROUTER_API_KEY"),
44
+ # temperature=0.3,
45
+ # max_tokens=8192,
46
+ # streaming=True,
47
+ # )
48
 
49
+ #DeepSeekR1-1q1
50
  # llm_chain: Runnable = ChatOpenAI(
51
  # model="deepseek-chat", # Optional for HF OpenAI-compatible endpoints, but kept for clarity
52
  # openai_api_base="https://qwryad273mlndckn.us-east-1.aws.endpoints.huggingface.cloud/v1/",
agent/state.py CHANGED
@@ -4,7 +4,9 @@ agent/state.py
4
  Canonical state shared by all chains & the runner.
5
  """
6
 
7
- from typing import Dict
 
 
8
  from pydantic import BaseModel
9
 
10
 
@@ -18,6 +20,35 @@ class AgentState(BaseModel):
18
  is_drafted: bool = False # True β‡’ draft has been generated
19
  missing_prompt_count: int = 0 # Times user has been asked for missing details
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # ──────────────────────────────
22
  # Internal helpers (not exposed)
23
  # ──────────────────────────────
 
4
  Canonical state shared by all chains & the runner.
5
  """
6
 
7
+ import json
8
+ from typing import Dict, Any
9
+ from pydantic import BaseModel, field_validator
10
  from pydantic import BaseModel
11
 
12
 
 
20
  is_drafted: bool = False # True β‡’ draft has been generated
21
  missing_prompt_count: int = 0 # Times user has been asked for missing details
22
 
23
+ # ──────────────────────────────
24
+ # Validation – ensure string values
25
+ # ──────────────────────────────
26
+ @field_validator('needed_fields', mode='before')
27
+ @classmethod
28
+ def _coerce_needed_fields(cls, v: Any):
29
+ """Ensure needed_fields is Dict[str, str] regardless of input shapes."""
30
+ if v is None:
31
+ return {}
32
+ if isinstance(v, dict):
33
+ str_dict: Dict[str, str] = {}
34
+ for k, val in v.items():
35
+ if isinstance(val, str):
36
+ str_dict[str(k)] = val
37
+ else:
38
+ try:
39
+ str_dict[str(k)] = json.dumps(val, ensure_ascii=False)
40
+ except TypeError:
41
+ str_dict[str(k)] = str(val)
42
+ return str_dict
43
+ # Accept list of pairs etc.
44
+ try:
45
+ as_dict = dict(v)
46
+ return cls._coerce_needed_fields(as_dict)
47
+ except Exception:
48
+ # Fallback: ignore invalid format
49
+ return {}
50
+
51
+
52
  # ──────────────────────────────
53
  # Internal helpers (not exposed)
54
  # ──────────────────────────────