# erlm/edge_rlm.py
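"""EdgeRLM: Edge Recursive Language Model.

An Agent LLM plans in JSON and writes Python that runs inside a sandboxed
exec() environment; a second, smaller REPL/sub-call LLM answers strictly
grounded questions over chunks of a large text corpus via llm_query().
Both models are served over llama.cpp-style HTTP endpoints
(/props, /tokenize, /completion).
"""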
import time
import requests
import json
import argparse
import io
import logging
import types
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from rich.json import JSON
# Local imports
from logging_config import setup_logging
import utils
import prompts as prompts
from templates import agent_template, repl_template
logger = logging.getLogger(__name__)
console = Console()
# Configuration
DEFAULT_AGENT_API = "http://localhost:8080"
DEFAULT_REPL_API = "http://localhost:8090"
DEFAULT_CONTEXT_FILE = "context.txt"
DEFAULT_TASK_FILE = "task.txt"
MAX_REPL_STEPS = 20
MAX_VIRTUAL_CONTEXT_RATIO = 0.85
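# LlamaClient wraps a llama.cpp-style HTTP server: it probes /props for the
# model's context window, counts tokens via /tokenize, and generates via
# /completion (optionally constrained by a JSON schema).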
class LlamaClient:
    def __init__(self, base_url, name="LlamaClient"):
        self.base_url = base_url
        self.name = name
        self.n_ctx = self._get_context_size()
        self.max_input_tokens = int(self.n_ctx * MAX_VIRTUAL_CONTEXT_RATIO)
        self.color = self._determine_color()
        if debug: logger.debug(f"Connected to {name} ({base_url}). Model Context: {self.n_ctx}. Max Input Safe Limit: {self.max_input_tokens}. Color: {self.color}")

    def _determine_color(self):
        # Colour-code console panels by which endpoint this client talks to.
        if self.base_url == DEFAULT_AGENT_API:
            return "dodger_blue1"
        elif self.base_url == DEFAULT_REPL_API:
            return "dodger_blue3"
        else:
            return "cyan1"  # Default color if base_url is unknown
    def _get_context_size(self):
        try:
            resp = requests.get(f"{self.base_url}/props")
            resp.raise_for_status()
            data = resp.json()
            if 'n_ctx' in data: return data['n_ctx']
            if 'default_n_ctx' in data: return data['default_n_ctx']
            if 'default_generation_settings' in data:
                settings = data['default_generation_settings']
                if 'n_ctx' in settings: return settings['n_ctx']
            return 4096
        except Exception as e:
            logger.error(f"[{self.name}] Failed to get props: {e}. Defaulting to 4096.")
            return 4096

    def tokenize(self, text):
        try:
            resp = requests.post(f"{self.base_url}/tokenize", json={"content": text})
            resp.raise_for_status()
            return len(resp.json().get('tokens', []))
        except Exception:
            return len(text) // 4
    def completion(self, prompt, schema=None, temperature=0.1):
        payload = {
            "prompt": prompt,
            "n_predict": -1,
            "temperature": temperature,
            "cache_prompt": True
        }
        if schema:
            payload["json_schema"] = schema
        else:
            payload["stop"] = ["<|eot_id|>", "<|im_end|>", "Observation:", "User:"]
        if debug:
            console.print(Panel(
                prompt[-500:],
                title=f"Last 500 Characters of {self.name} Call",
                title_align="left",
                border_style=self.color
            ))
        try:
            resp = requests.post(f"{self.base_url}/completion", json=payload)
            if debug:
                console.print(Panel(
                    JSON.from_data(resp.json().get('content', '').strip()),
                    title=f"{self.name} Response",
                    title_align="left",
                    border_style=self.color
                ))
            resp.raise_for_status()
            return resp.json().get('content', '').strip()
        except Exception as e:
            logger.error(f"[{self.name}] Error calling LLM: {e}")
            return f"Error: {e}"
class AgentTools:
    def __init__(self, repl_client: LlamaClient, data_content: str):
        self.client = repl_client
        self.RAW_CORPUS = data_content

    def llm_query(self, content_chunk, query):
        if content_chunk == "RAW_CORPUS":
            return ("ERROR: You passed the string 'RAW_CORPUS'. You must pass the CONTENT of the "
                    "variable (e.g., `chunk = RAW_CORPUS[:1000]`, then `llm_query(chunk, ...)`).")
        # 1. Heuristic check before any network call.
        # Assume roughly 3 chars per token; if the chunk is wildly larger than the
        # context window, fail fast to avoid a timeout on the /tokenize call.
        estimated_tokens = len(content_chunk) // 3
        if estimated_tokens > (self.client.n_ctx * 2):
            return f"ERROR: Chunk is massively too large (approx {estimated_tokens} tokens). Slice strictly."
        # 2. Precise safety check
        chunk_tokens = self.client.tokenize(content_chunk)
        query_tokens = self.client.tokenize(query)
        total = chunk_tokens + query_tokens + 150
        if debug: logger.debug(f"[Sub-LLM] Processing Query with {total} tokens.")
        if total > self.client.n_ctx:
            msg = f"ERROR: Chunk too large ({chunk_tokens} tokens). Limit is {self.client.n_ctx}. Slice smaller."
            logger.warning(msg)
            return msg
        # 3. Strict grounding prompt
        sub_messages = [
            {"role": repl_template.ROLE_SYSTEM, "content": (
                "You are a strict reading assistant. "
                "Answer the question based ONLY on the provided Context. "
                "Do not use outside training data. "
                "If the answer is not in the text, say 'NULL'."
            )},
            {"role": repl_template.ROLE_USER, "content": f"Context:\n{content_chunk}\n\nQuestion: {query}"}
        ]
        results = self.client.completion(utils.build_chat_prompt(sub_messages))
        result_tokens = self.client.tokenize(results)
        if debug: logger.debug(f"[Sub-LLM] Responded with {result_tokens} tokens.")
        return results
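# AgentOutputBuffer replaces print() inside the exec'd agent code. It enforces
# a per-print cap (to stop raw corpus dumps) and a global cap (to stop runaway
# loops from exhausting memory).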
class AgentOutputBuffer:
    def __init__(self, max_total_chars=20000, max_len_per_print=1000):
        self._io = io.StringIO()
        self.max_total_chars = max_total_chars        # Hard cap for infinite loop protection
        self.max_len_per_print = max_len_per_print    # Soft cap for raw data dumping protection
        self.current_chars = 0
        self.global_truncated = False

    def custom_print(self, *args, **kwargs):
        # 1. Capture the content of THIS specific print call
        temp_io = io.StringIO()
        print(*args, file=temp_io, **kwargs)
        text = temp_io.getvalue()
        # 2. Check PER-PRINT limit (the "density" check).
        # This prevents printing raw corpus data but allows short summaries to pass through.
        if len(text) > self.max_len_per_print:
            # Slice the text
            truncated_text = text[:self.max_len_per_print]
            # Create a localized warning that doesn't stop the whole stream
            text = (
                f"{truncated_text}\n"
                f"... [LINE TRUNCATED: Output exceeded {self.max_len_per_print} chars. "
                f"Use slicing or llm_query() to inspect data.] ...\n"
            )
        # 3. Check GLOBAL limit (the "sanity" check).
        # This prevents infinite loops (while True: print('a')) from crashing memory.
        if self.current_chars + len(text) > self.max_total_chars:
            remaining = self.max_total_chars - self.current_chars
            if remaining > 0:
                self._io.write(text[:remaining])
            if not self.global_truncated:
                self._io.write(f"\n... [SYSTEM HALT: Total output limit ({self.max_total_chars}) reached] ...\n")
                self.global_truncated = True
            self.current_chars += len(text)
        else:
            self._io.write(text)
            self.current_chars += len(text)

    def read_and_clear(self):
        value = self._io.getvalue()
        self._io = io.StringIO()
        self.current_chars = 0
        self.global_truncated = False
        return value
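# run_agent drives the ReAct-style loop: the Agent model returns JSON
# ({thought, action, content}); "execute_python" runs the code in a shared
# exec() environment whose print() output becomes the next Observation, and
# "final_answer" ends the loop with a polished report.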
def run_agent(agent_client, repl_client, context_text, task_text):
    tools = AgentTools(repl_client, context_text)
    agent_schema = {
        "type": "object",
        "properties": {
            "thought": {"type": "string", "description": "Reasoning about current state and what to do next."},
            "action": {"type": "string", "enum": ["execute_python", "final_answer"]},
            "content": {"type": "string", "description": "Python code or Final Answer text."}
        },
        "required": ["thought", "action", "content"]
    }
    # 1. Instantiate the output buffer
    out_buffer = AgentOutputBuffer()
    trace_filepath = utils.init_trace_file(debug)
    # 2. Add it to the execution environment
    exec_env = {
        "RAW_CORPUS": tools.RAW_CORPUS,
        "llm_query": tools.llm_query,
        # Standard libs
        "re": __import__("re"),
        "math": __import__("math"),
        "json": __import__("json"),
        "collections": __import__("collections"),
        "statistics": __import__("statistics"),
        "random": __import__("random"),
        "datetime": __import__("datetime"),
        "difflib": __import__("difflib"),
        "string": __import__("string"),
        # Overrides
        "print": out_buffer.custom_print
    }
    system_instruction = prompts.get_system_prompt()
    messages = [
        {"role": agent_template.ROLE_SYSTEM, "content": system_instruction},
        {"role": agent_template.ROLE_USER, "content": f"USER TASK: {task_text}"}
    ]
    step = 0
    while step < MAX_REPL_STEPS:
        step += 1
        if debug: logger.debug(f"Step {step} of {MAX_REPL_STEPS}")
        # 1. Snapshot the execution environment for the state reminder
        modules = []
        functions = []
        variables = []
        ACTIVE_VAR_SNIPPET_LEN = 100
        for name, val in exec_env.items():
            if name.startswith("__"): continue
            if name == "print": continue  # Hide print, it's implied
            if isinstance(val, types.ModuleType):
                modules.append(name)
            elif callable(val):
                functions.append(name)
            else:
                # For variables, provide a type and a short preview
                type_name = type(val).__name__
                s_val = str(val)
                # Truncate long values for display (e.g. RAW_CORPUS)
                snippet = (s_val[:ACTIVE_VAR_SNIPPET_LEN] + '...') if len(s_val) > ACTIVE_VAR_SNIPPET_LEN else s_val
                variables.append(f"{name} ({type_name}): {snippet}")
        # 2. Create the status message
        dynamic_state_msg = (
            f"[SYSTEM STATE REMINDER]\n"
            f"Current Step: {step}/{MAX_REPL_STEPS}\n"
            f"Available Libraries: {', '.join(modules)}\n"
            f"Available Tools: {', '.join(functions)}\n"
            f"Active Variables:\n" + ("\n".join([f" - {v}" for v in variables]) if variables else " (None)") + "\n---"
        )
        # 3. Create a temporary message list for this specific inference.
        # The state is appended at the very end so it benefits from recency bias.
        inference_messages = messages.copy()
        inference_messages.append({"role": agent_template.ROLE_USER, "content": dynamic_state_msg})
        # 4. Build the prompt from the INFERENCE messages (not the permanent history)
        full_prompt = utils.build_chat_prompt(inference_messages)
        usage = agent_client.tokenize(full_prompt)
        if debug: logger.debug(f"Context Usage: {usage} / {agent_client.max_input_tokens}")
        # Check context use and attempt compression
        if usage > agent_client.max_input_tokens:
            if debug: logger.warning("Context limit exceeded. Triggering History Compression.")
            messages = utils.compress_history(debug, agent_client, messages, keep_last_pairs=2)
            # Re-check usage after compression
            full_prompt = utils.build_chat_prompt(messages)
            new_usage = agent_client.tokenize(full_prompt)
            if debug: logger.debug(f"Context Usage after compression: {new_usage}")
            # Panic mode: if it's STILL too big (unlikely), truncate the summary
            if new_usage > agent_client.max_input_tokens:
                logger.error("Compression insufficient. Forcing hard truncation.")
                messages.pop(2)
        # Agent completion
        response_text = agent_client.completion(full_prompt, schema=agent_schema, temperature=0.5)
        try:
            response_json = json.loads(response_text)
        except json.JSONDecodeError:
            logger.error("JSON Parse Error")
            messages.append({"role": agent_template.ROLE_USER, "content": "System: Invalid JSON returned. Please retry."})
            continue
        thought = response_json.get("thought", "")
        action = response_json.get("action", "")
        content = response_json.get("content", "")
        if action == "execute_python" and content:
            # Run the safeguard. If the code is bad, 'content' gets replaced
            content = utils.safeguard_and_repair(debug, agent_client, messages, agent_schema, content)
        if debug:
            console.print(Panel(
                f"[italic]{thought}[/italic]",
                title="🧠 Agent Thought",
                title_align="left",
                border_style="magenta"
            ))
        messages.append({"role": agent_template.ROLE_ASSISTANT, "content": json.dumps(response_json, indent=2, ensure_ascii=False)})
        # Execution
        if action == "final_answer":
            # Capture the raw result (kept for logs/debugging)
            if debug: logger.debug(f"Raw Agent Output: {content}")
            # If content looks like JSON/structure, summarize it; even if it's already
            # text, a quick polish pass ensures a consistent tone.
            final_report = utils.generate_final_report(debug, agent_client, task_text, content)
            # Print the pretty version
            final_report_md = Markdown(final_report)
            print("\n\n")
            console.print(final_report_md)
            print("\n")
            break
        elif action == "execute_python":
            # Update the thought/log to reflect potential changes for the human observer
            if debug and content != response_json.get("content"):
                console.print(Panel(content, title="Executing Code via Safeguard", title_align="left", border_style="cyan"))
            elif debug and content == response_json.get("content"):
                console.print(Panel(content, title="Executing Code", title_align="left", border_style="yellow"))
            observation = ""
            try:
                # 1. Clear any leftover output from previous steps (safety)
                out_buffer.read_and_clear()
                # 2. Execute. The Agent calls 'print', which goes to out_buffer
                exec(content, exec_env)
                # 3. Extract the captured text
                observation = out_buffer.read_and_clear()
                if not observation:
                    observation = "Code executed successfully (no output)."
            except Exception as e:
                observation = f"Python Error: {e}"
                logger.error(f"Code Execution Error: {e}")
            if debug:
                console.print(Panel(
                    f"{observation.strip()}",
                    title="Observation",
                    title_align="left",
                    border_style="dark_green"
                ))
            messages.append({"role": agent_template.ROLE_USER, "content": f"Observation:\n{observation}"})
        else:
            messages.append({"role": agent_template.ROLE_USER, "content": f"System: Unknown action '{action}'."})
    utils.save_agent_trace(trace_filepath, messages)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="""Edge Recursive Language Model
A data extraction and analysis tool that mimics the process of a human data scientist, carefully exploring and structuring a large dataset before performing targeted queries.""")
    parser.add_argument("--context", default=DEFAULT_CONTEXT_FILE, help="Path to text file to process")
    parser.add_argument("--task", default=DEFAULT_TASK_FILE, help="Path to task instruction file")
    parser.add_argument("--override_task", help="Direct string override for the task")
    parser.add_argument("--agent_api", default=DEFAULT_AGENT_API, help="URL for the Main Agent LLM")
    parser.add_argument("--repl_api", default=DEFAULT_REPL_API, help="URL for the Sub-call/REPL LLM")
    parser.add_argument("--debug", action="store_true", help="Enable verbose debug logging")
    args = parser.parse_args()
    debug = args.debug
    log_level = logging.DEBUG if debug else logging.INFO
    setup_logging(level=log_level, debug=debug)
    if debug: logger.info("Starting EdgeRLM...")
    context_content = utils.load_file(args.context)
    if debug: logger.debug(f"Loaded Context: {len(context_content)} characters.")
    task_content = args.override_task if args.override_task else utils.load_file(args.task)
    agent_client = LlamaClient(args.agent_api, "Agent")
    repl_client = LlamaClient(args.repl_api, "REPL")
    run_agent(agent_client, repl_client, context_content, task_content)
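# Example invocation (a sketch: assumes llama.cpp-style servers are already
# listening on the default --agent_api / --repl_api URLs, and that context.txt
# and task.txt exist next to the script):
#   python edge_rlm.py --context context.txt --task task.txt --debug
#   python edge_rlm.py --override_task "List every person named in the corpus." --debug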