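"""EdgeRLM: Edge Recursive Language Model driver.

A data extraction and analysis tool that mimics the process of a human data
scientist: an Agent LLM iteratively writes Python against a large text corpus,
using a smaller Sub-call/REPL LLM (`llm_query`) to read chunks that do not fit
in its own context window. Both models are served over llama.cpp-style HTTP
endpoints (/props, /tokenize, /completion).
"""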
import time
import requests
import json
import argparse
import io
import logging
import types
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from rich.json import JSON


# Local imports
from logging_config import setup_logging
import utils
import prompts
from templates import agent_template, repl_template

logger = logging.getLogger(__name__)
console = Console()

# Configuration
DEFAULT_AGENT_API = "http://localhost:8080"
DEFAULT_REPL_API = "http://localhost:8090"

DEFAULT_CONTEXT_FILE = "context.txt"
DEFAULT_TASK_FILE = "task.txt"
MAX_REPL_STEPS = 20
MAX_VIRTUAL_CONTEXT_RATIO = 0.85


class LlamaClient:
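    """Minimal client for one llama.cpp-style HTTP server.

    Discovers the model's context window via /props, counts tokens via
    /tokenize, and requests (optionally JSON-schema constrained) completions
    via /completion.
    """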
    def __init__(self, base_url, name="LlamaClient"):
        self.base_url = base_url
        self.name = name
        self.n_ctx = self._get_context_size()
        self.max_input_tokens = int(self.n_ctx * MAX_VIRTUAL_CONTEXT_RATIO)
        self.color = self._determine_color()
        if debug: logger.debug(f"Connected to {name} ({base_url}). Model Context: {self.n_ctx}. Max Input Safe Limit: {self.max_input_tokens}. Color: {self.color}")

    def _determine_color(self):
        if self.base_url == DEFAULT_AGENT_API:
            return "dodger_blue1"
        elif self.base_url == DEFAULT_REPL_API:
            return "dodger_blue3"
        else:
            return "cyan1"  # Default color if base_url is unknown

    def _get_context_size(self):
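        """Probe /props for the server's context window; fall back to 4096 on any failure."""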
        try:
            resp = requests.get(f"{self.base_url}/props")
            resp.raise_for_status()
            data = resp.json()

            if 'n_ctx' in data: return data['n_ctx']
            if 'default_n_ctx' in data: return data['default_n_ctx']
            if 'default_generation_settings' in data:
                settings = data['default_generation_settings']
                if 'n_ctx' in settings: return settings['n_ctx']

            return 4096
        except Exception as e:
            logger.error(f"[{self.name}] Failed to get props: {e}. Defaulting to 4096.")
            return 4096

    def tokenize(self, text):
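        """Return the server's token count for `text`, or a rough len/4 estimate if the call fails."""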
        try:
            resp = requests.post(f"{self.base_url}/tokenize", json={"content": text})
            resp.raise_for_status()
            return len(resp.json().get('tokens', []))
        except Exception:
            return len(text) // 4

    def completion(self, prompt, schema=None, temperature=0.1):
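        """Run a completion. If `schema` is given, the server constrains output to that JSON schema;
        otherwise generation stops on common chat/ReAct stop strings."""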
        payload = {
            "prompt": prompt,
            "n_predict": -1,
            "temperature": temperature,
            "cache_prompt": True
        }
        if schema:
            payload["json_schema"] = schema
        else:
            payload["stop"] = ["<|eot_id|>", "<|im_end|>", "Observation:", "User:"]
        if debug:
            console.print(Panel(
                prompt[-500:],
                title=f"Last 500 Characters of {self.name} Call",
                title_align="left",
                border_style=self.color
            ))
        try:
            resp = requests.post(f"{self.base_url}/completion", json=payload)
            resp.raise_for_status()
            if debug:
                console.print(Panel(
                    JSON.from_data(resp.json().get('content', '').strip()),
                    title=f"{self.name} Response",
                    title_align="left",
                    border_style=self.color
                ))
            return resp.json().get('content', '').strip()
        except Exception as e:
            logger.error(f"[{self.name}] Error calling LLM: {e}")
            return f"Error: {e}"


class AgentTools:
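    """Tools exposed to the agent's Python environment: the raw corpus plus `llm_query`."""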
    def __init__(self, repl_client: LlamaClient, data_content: str):
        self.client = repl_client
        self.RAW_CORPUS = data_content

    def llm_query(self, content_chunk, query):
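        """Ask the Sub-call/REPL LLM a question grounded strictly in `content_chunk`.

        Rejects chunks that cannot fit in the sub-model's context window, so the
        agent is forced to slice the corpus before querying it.
        """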
        if content_chunk == "RAW_CORPUS":
            return "ERROR: You passed the string 'RAW_CORPUS'. You must pass the CONTENT of the variable (e.g., `chunk = RAW_CORPUS[:1000]`, then `llm_query(chunk, ...)`)."

        # --- OPTIMIZATION FIX: Heuristic check before network call ---
        # Assume approx 3 chars per token. If the chunk is wildly larger than the context,
        # fail fast to prevent a network timeout on the /tokenize call.
        estimated_tokens = len(content_chunk) // 3
        if estimated_tokens > (self.client.n_ctx * 2):
            return f"ERROR: Chunk is massively too large (approx {estimated_tokens} tokens). Slice strictly."

        # 2. Precise safety check
        chunk_tokens = self.client.tokenize(content_chunk)
        query_tokens = self.client.tokenize(query)
        total = chunk_tokens + query_tokens + 150

        if debug: logger.debug(f"[Sub-LLM] Processing Query with {total} tokens.")

        if total > self.client.n_ctx:
            msg = f"ERROR: Chunk too large ({chunk_tokens} tokens). Limit is {self.client.n_ctx}. Slice smaller."
            logger.warning(msg)
            return msg

        # 3. Strict grounding prompt
        sub_messages = [
            {"role": repl_template.ROLE_SYSTEM, "content": (
                "You are a strict reading assistant. "
                "Answer the question based ONLY on the provided Context. "
                "Do not use outside training data. "
                "If the answer is not in the text, say 'NULL'."
            )},
            {"role": repl_template.ROLE_USER, "content": f"Context:\n{content_chunk}\n\nQuestion: {query}"}
        ]
        results = self.client.completion(utils.build_chat_prompt(sub_messages))
        result_tokens = self.client.tokenize(results)
        if debug: logger.debug(f"[Sub-LLM] Responded with {result_tokens} tokens.")
        return results


class AgentOutputBuffer:
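    """Captures the agent's print() output with per-print and total size caps,
    so raw corpus dumps and runaway loops cannot flood the conversation."""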
    def __init__(self, max_total_chars=20000, max_len_per_print=1000):
        self._io = io.StringIO()
        self.max_total_chars = max_total_chars  # Hard cap for infinite loop protection
        self.max_len_per_print = max_len_per_print  # Soft cap for raw data dumping protection
        self.current_chars = 0
        self.global_truncated = False

    def custom_print(self, *args, **kwargs):
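        """Drop-in replacement for print() inside the exec environment; writes to the buffer with truncation."""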
        # 1. Capture the content of THIS specific print call
        temp_io = io.StringIO()
        print(*args, file=temp_io, **kwargs)
        text = temp_io.getvalue()

        # 2. Check PER-PRINT limit (the "density" check).
        # This prevents printing raw corpus data, but allows short summaries to pass through.
        if len(text) > self.max_len_per_print:
            # Slice the text
            truncated_text = text[:self.max_len_per_print]

            # Create a localized warning that doesn't stop the whole stream
            text = (
                f"{truncated_text}\n"
                f"... [LINE TRUNCATED: Output exceeded {self.max_len_per_print} chars. "
                f"Use slicing or llm_query() to inspect data.] ...\n"
            )

        # 3. Check GLOBAL limit (the "sanity" check).
        # This prevents infinite loops (while True: print('a')) from crashing memory.
        if self.current_chars + len(text) > self.max_total_chars:
            remaining = self.max_total_chars - self.current_chars
            if remaining > 0:
                self._io.write(text[:remaining])

            if not self.global_truncated:
                self._io.write(f"\n... [SYSTEM HALT: Total output limit ({self.max_total_chars}) reached] ...\n")
                self.global_truncated = True

            self.current_chars += len(text)
        else:
            self._io.write(text)
            self.current_chars += len(text)

    def read_and_clear(self):
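        """Return everything buffered so far and reset the buffer for the next step."""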
        value = self._io.getvalue()
        self._io = io.StringIO()
        self.current_chars = 0
        self.global_truncated = False
        return value


def run_agent(agent_client, repl_client, context_text, task_text):
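    """Main ReAct-style loop: the Agent LLM alternates between executing Python
    against the corpus and returning a final answer, for up to MAX_REPL_STEPS steps."""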
    tools = AgentTools(repl_client, context_text)

    agent_schema = {
        "type": "object",
        "properties": {
            "thought": {"type": "string", "description": "Reasoning about current state and what to do next."},
            "action": {"type": "string", "enum": ["execute_python", "final_answer"]},
            "content": {"type": "string", "description": "Python code or Final Answer text."}
        },
        "required": ["thought", "action", "content"]
    }

    # 1. Instantiate the buffer
    out_buffer = AgentOutputBuffer()

    trace_filepath = utils.init_trace_file(debug)

    # 2. Add it to the environment
    exec_env = {
        "RAW_CORPUS": tools.RAW_CORPUS,
        "llm_query": tools.llm_query,
        # Standard Libs
        "re": __import__("re"),
        "math": __import__("math"),
        "json": __import__("json"),
        "collections": __import__("collections"),
        "statistics": __import__("statistics"),
        "random": __import__("random"),
        "datetime": __import__("datetime"),
        "difflib": __import__("difflib"),
        "string": __import__("string"),

        # Overrides
        "print": out_buffer.custom_print
    }

    system_instruction = prompts.get_system_prompt()

    messages = [
        {"role": agent_template.ROLE_SYSTEM, "content": system_instruction},
        {"role": agent_template.ROLE_USER, "content": f"USER TASK: {task_text}"}
    ]

    step = 0
    while step < MAX_REPL_STEPS:
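        # Each step: inventory the exec environment, remind the agent of its state,
        # request a structured (thought/action/content) response, then act on it.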
        step += 1
        if debug: logger.debug(f"Step {step} of {MAX_REPL_STEPS}")

        modules = []
        functions = []
        variables = []
        ACTIVE_VAR_SNIPPET_LEN = 100

        for name, val in exec_env.items():
            if name.startswith("__"): continue
            if name == "print": continue  # Hide print, it's implied

            if isinstance(val, types.ModuleType):
                modules.append(name)
            elif callable(val):
                functions.append(name)
            else:
                # For variables, provide a type and a short preview
                type_name = type(val).__name__
                s_val = str(val)
                # Truncate long values for display (e.g. RAW_CORPUS)
                snippet = (s_val[:ACTIVE_VAR_SNIPPET_LEN] + '...') if len(s_val) > ACTIVE_VAR_SNIPPET_LEN else s_val
                variables.append(f"{name} ({type_name}): {snippet}")

        # 2. Create the status message
        dynamic_state_msg = (
            f"[SYSTEM STATE REMINDER]\n"
            f"Current Step: {step}/{MAX_REPL_STEPS}\n"
            f"Available Libraries: {', '.join(modules)}\n"
            f"Available Tools: {', '.join(functions)}\n"
            f"Active Variables:\n" + ("\n".join([f" - {v}" for v in variables]) if variables else " (None)") + "\n---"
        )

        # 3. Create a temporary message list for this specific inference.
        # We append the state to the very end so it has high 'recency' bias.
        inference_messages = messages.copy()
        inference_messages.append({"role": agent_template.ROLE_USER, "content": dynamic_state_msg})

        # 4. Build prompt using the INFERENCE messages (not the permanent history)
        full_prompt = utils.build_chat_prompt(inference_messages)

        usage = agent_client.tokenize(full_prompt)
        if debug: logger.debug(f"Context Usage: {usage} / {agent_client.max_input_tokens}")

        # Check context use and attempt compression
        if usage > agent_client.max_input_tokens:
            if debug: logger.warning("Context limit exceeded. Triggering History Compression.")

            messages = utils.compress_history(debug, agent_client, messages, keep_last_pairs=2)

            # Re-check usage after compression
            full_prompt = utils.build_chat_prompt(messages)
            new_usage = agent_client.tokenize(full_prompt)
            if debug: logger.debug(f"Context Usage after compression: {new_usage}")

            # Panic mode: if it's STILL too big (unlikely), truncate the summary
            if new_usage > agent_client.max_input_tokens:
                logger.error("Compression insufficient. Forcing hard truncation.")
                messages.pop(2)

        # Agent Completion
        response_text = agent_client.completion(full_prompt, schema=agent_schema, temperature=0.5)

        try:
            response_json = json.loads(response_text)
        except json.JSONDecodeError:
            logger.error("JSON Parse Error")
            messages.append({"role": agent_template.ROLE_USER, "content": "System: Invalid JSON returned. Please retry."})
            continue

        thought = response_json.get("thought", "")
        action = response_json.get("action", "")
        content = response_json.get("content", "")

        if action == "execute_python" and content:
            # Run the safeguard. If the code is bad, 'content' gets replaced
            content = utils.safeguard_and_repair(debug, agent_client, messages, agent_schema, content)

        if debug:
            console.print(Panel(
                f"[italic]{thought}[/italic]",
                title="🧠 Agent Thought",
                title_align="left",
                border_style="magenta"
            ))

        messages.append({"role": agent_template.ROLE_ASSISTANT, "content": json.dumps(response_json, indent=2, ensure_ascii=False)})

        # 3. Execution
        if action == "final_answer":
            # 1. Capture the raw result (keep this for logs/debugging)
            if debug: logger.debug(f"Raw Agent Output: {content}")

            # 2. If the content looks like JSON/structure, summarize it.
            # Even if it's already text, a quick polish pass ensures consistent tone.
            final_report = utils.generate_final_report(debug, agent_client, task_text, content)

            # 3. Print the pretty version
            final_report_md = Markdown(final_report)
            print("\n\n")
            console.print(final_report_md)
            print("\n")
            break

        elif action == "execute_python":
            # Update the thought/log to reflect potential changes for the human observer
            if debug and content != response_json.get("content"):
                console.print(Panel(content, title="Executing Code via Safeguard", title_align="left", border_style="cyan"))
            elif debug and content == response_json.get("content"):
                console.print(Panel(content, title="Executing Code", title_align="left", border_style="yellow"))

            observation = ""
            try:
                # 1. Clear any leftover junk from previous steps (safety)
                out_buffer.read_and_clear()

                # 2. Execute. The Agent calls 'print', which goes to out_buffer
                exec(content, exec_env)
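                # NOTE: exec() with a plain globals dict only scopes the convenience imports;
                # it is not a security sandbox (the generated code still has full builtins).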
                # 3. Extract the text
                observation = out_buffer.read_and_clear()

                if not observation:
                    observation = "Code executed successfully (no output)."
            except Exception as e:
                observation = f"Python Error: {e}"
                logger.error(f"Code Execution Error: {e}")

            if debug:
                console.print(Panel(
                    f"{observation.strip()}",
                    title="Observation",
                    title_align="left",
                    border_style="dark_green"
                ))
            messages.append({"role": agent_template.ROLE_USER, "content": f"Observation:\n{observation}"})

        else:
            messages.append({"role": agent_template.ROLE_USER, "content": f"System: Unknown action '{action}'."})

        utils.save_agent_trace(trace_filepath, messages)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="""Edge Recursive Language Model

A sophisticated data extraction and analysis tool that mimics the process of a human data scientist, carefully exploring and structuring a large dataset before performing targeted queries.""")
    parser.add_argument("--context", default=DEFAULT_CONTEXT_FILE, help="Path to text file to process")
    parser.add_argument("--task", default=DEFAULT_TASK_FILE, help="Path to task instruction file")
    parser.add_argument("--override_task", help="Direct string override for the task")
    parser.add_argument("--agent_api", default=DEFAULT_AGENT_API, help="URL for the Main Agent LLM")
    parser.add_argument("--repl_api", default=DEFAULT_REPL_API, help="URL for the Sub-call/REPL LLM")
    parser.add_argument("--debug", action="store_true", help="Enable verbose debug logging")
    args = parser.parse_args()
    debug = args.debug
    log_level = logging.DEBUG if debug else logging.INFO
    setup_logging(level=log_level, debug=debug)

    if debug: logger.info("Starting EdgeRLM...")
    context_content = utils.load_file(args.context)
    if debug: logger.debug(f"Loaded Context: {len(context_content)} characters.")
    task_content = args.override_task if args.override_task else utils.load_file(args.task)

    agent_client = LlamaClient(args.agent_api, "Agent")
    repl_client = LlamaClient(args.repl_api, "REPL")

    run_agent(agent_client, repl_client, context_content, task_content)