Initial commit
This commit is contained in:
312
utils.py
Normal file
312
utils.py
Normal file
@@ -0,0 +1,312 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
import ast
|
||||
import contextlib
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from templates import agent_template
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
console = Console()
|
||||
|
||||
|
||||
def init_trace_file(debug, log_dir="logs"):
|
||||
"""
|
||||
Creates the log directory and returns a unique filepath
|
||||
based on the current timestamp.
|
||||
"""
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
|
||||
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
||||
filename = os.path.join(log_dir, f"trace_{timestamp}.json")
|
||||
if debug: logger.debug(f"Trace logging initialized: {filename}")
|
||||
return filename
|
||||
|
||||
def save_agent_trace(filepath, messages, full_history=None):
|
||||
"""
|
||||
Dumps the current state of the conversation to a JSON file.
|
||||
Overwrites the file each step so the last write is always the complete history.
|
||||
"""
|
||||
try:
|
||||
data_to_save = {
|
||||
"timestamp": time.time(),
|
||||
# If you are using history compression, 'messages' might get cut.
|
||||
# If you want the RAW full history, pass full_history.
|
||||
# Otherwise, we log what the agent currently 'sees'.
|
||||
"context_window": messages
|
||||
}
|
||||
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(data_to_save, f, indent=2, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save trace file: {e}")
|
||||
|
||||
def build_chat_prompt(messages):
|
||||
prompt = ""
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
prompt += f"{agent_template.IM_START}{role}\n{content}{agent_template.IM_END}\n"
|
||||
prompt += f"{agent_template.IM_START}{agent_template.ROLE_ASSISTANT}" # Removed trailing newline
|
||||
return prompt
|
||||
|
||||
|
||||
def _analyze_code_safety(code_str):
|
||||
"""
|
||||
Returns: (is_safe: bool, error_msg: str, line_number: int | None)
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code_str)
|
||||
except SyntaxError as e:
|
||||
# e.lineno is the line where the parser failed
|
||||
return False, f"SyntaxError: {e.msg}", e.lineno
|
||||
|
||||
tainted_vars = {"RAW_CORPUS"}
|
||||
has_print = False
|
||||
|
||||
for node in ast.walk(tree):
|
||||
# 1. Track assignments
|
||||
if isinstance(node, ast.Assign):
|
||||
if isinstance(node.value, ast.Name) and node.value.id in tainted_vars:
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
tainted_vars.add(target.id)
|
||||
|
||||
# 2. Check Call nodes
|
||||
if isinstance(node, ast.Call):
|
||||
if isinstance(node.func, ast.Name) and node.func.id == 'print':
|
||||
has_print = True
|
||||
for arg in node.args:
|
||||
if isinstance(arg, ast.Name) and arg.id in tainted_vars:
|
||||
return False, f"Safety Violation: Printing '{arg.id}' (RAW_CORPUS). Use slicing.", node.lineno
|
||||
|
||||
# Check re.compile arguments
|
||||
is_re_compile = False
|
||||
if isinstance(node.func, ast.Attribute) and node.func.attr == 'compile':
|
||||
is_re_compile = True
|
||||
elif isinstance(node.func, ast.Name) and node.func.id == 'compile':
|
||||
is_re_compile = True
|
||||
|
||||
if is_re_compile and len(node.args) > 2:
|
||||
return False, "Library Usage Error: `re.compile` accepts max 2 args.", node.lineno
|
||||
|
||||
# 3. Global Check (No specific line number)
|
||||
if not has_print:
|
||||
return False, "Observability Error: No `print()` statements found.", None
|
||||
|
||||
return True, None, None
|
||||
|
||||
def _extract_context_block(code_str, target_lineno):
|
||||
"""
|
||||
Extracts lines surrounding target_lineno bounded by empty lines.
|
||||
Returns: (start_index, end_index, snippet_str)
|
||||
"""
|
||||
lines = code_str.split('\n')
|
||||
# target_lineno is 1-based, list is 0-based
|
||||
idx = target_lineno - 1
|
||||
|
||||
# Clamp index just in case
|
||||
if idx < 0: idx = 0
|
||||
if idx >= len(lines): idx = len(lines) - 1
|
||||
|
||||
start_idx = idx
|
||||
end_idx = idx
|
||||
|
||||
# Scan Up
|
||||
while start_idx > 0:
|
||||
if lines[start_idx - 1].strip() == "":
|
||||
break
|
||||
start_idx -= 1
|
||||
|
||||
# Scan Down
|
||||
while end_idx < len(lines) - 1:
|
||||
if lines[end_idx + 1].strip() == "":
|
||||
break
|
||||
end_idx += 1
|
||||
|
||||
# Extract the block including the found boundaries (or lack thereof)
|
||||
snippet_lines = lines[start_idx : end_idx + 1]
|
||||
return start_idx, end_idx, "\n".join(snippet_lines)
|
||||
|
||||
def safeguard_and_repair(debug, client, messages, schema, original_code):
|
||||
is_safe, error_msg, line_no = _analyze_code_safety(original_code)
|
||||
|
||||
if is_safe:
|
||||
return original_code
|
||||
|
||||
if debug:
|
||||
logger.warning(f"Safeguard triggered: {error_msg} (Line: {line_no})")
|
||||
console.print(Panel(f"{error_msg}", title="Safeguard Interrupt", style="bold red"))
|
||||
|
||||
console.print(Panel(
|
||||
f"[italic]{thought}[/italic]",
|
||||
title="Unsafe Code",
|
||||
title_align="left",
|
||||
border_style="hot_pink2"
|
||||
))
|
||||
|
||||
# STRATEGY 1: SNIPPET REPAIR (Optimization)
|
||||
# If we have a specific line number, we only send that block.
|
||||
if line_no is not None:
|
||||
start_idx, end_idx, snippet = _extract_context_block(original_code, line_no)
|
||||
|
||||
# We create a temporary "micro-agent" prompt just for fixing the snippet
|
||||
# We reuse the schema to ensure we get a clean content block back
|
||||
repair_prompt = [
|
||||
{"role": "system", "content": "You are a code repair assistant. Output only the fixed code snippet in the JSON content field."},
|
||||
{"role": "user", "content": (
|
||||
f"The following Python code snippet failed validation.\n"
|
||||
f"Error: {error_msg} (occurred around line {line_no})\n\n"
|
||||
f"```python\n{snippet}\n```\n\n"
|
||||
f"Return the JSON with the fixed snippet. "
|
||||
f"Maintain original indentation. Add a comment (# FIXED) to changed lines."
|
||||
)}
|
||||
]
|
||||
if debug:
|
||||
console.print(Panel(f"{snippet}", title="Attempting Snippet Repair", style="light_goldenrod1"))
|
||||
|
||||
response_text = client.completion(repair_prompt, schema=schema, temperature=0.0)
|
||||
|
||||
try:
|
||||
response_json = json.loads(response_text)
|
||||
fixed_snippet = response_json.get("content", "")
|
||||
|
||||
if debug:
|
||||
console.print(Panel(f"{fixed_snippet}", title="Repaired Snippet", style="yellow1"))
|
||||
|
||||
# Stitch the code back together
|
||||
all_lines = original_code.split('\n')
|
||||
# We replace the range we extracted with the new snippet
|
||||
# Note: fixed_snippet might have different line count, that's fine.
|
||||
|
||||
pre_block = all_lines[:start_idx]
|
||||
post_block = all_lines[end_idx + 1:]
|
||||
|
||||
# Reassemble
|
||||
full_fixed_code = "\n".join(pre_block + [fixed_snippet] + post_block)
|
||||
|
||||
return full_fixed_code
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# If the snippet repair fails to parse, fall through to full repair
|
||||
if debug: logger.error("Snippet repair failed to parse. Falling back to full repair.")
|
||||
pass
|
||||
|
||||
# STRATEGY 2: FULL REPAIR (Fallback)
|
||||
# Used for global errors (missing prints) or if snippet repair crashed
|
||||
repair_messages = messages + [
|
||||
{"role": agent_template.ROLE_ASSISTANT, "content": json.dumps({
|
||||
"thought": "Drafting code...",
|
||||
"action": "execute_python",
|
||||
"content": original_code
|
||||
})},
|
||||
{"role": agent_template.ROLE_USER, "content": (
|
||||
f"SYSTEM INTERRUPT: Your code failed pre-flight safety checks.\n"
|
||||
f"Error: {error_msg}\n\n"
|
||||
f"Generate the JSON response again with CORRECTED Python code.\n"
|
||||
f"IMPORTANT: You must add a comment (# FIXED: ...) to the corrected line."
|
||||
)}
|
||||
]
|
||||
|
||||
response_text = client.completion(build_chat_prompt(repair_messages), schema=schema, temperature=0.0)
|
||||
|
||||
try:
|
||||
response_json = json.loads(response_text)
|
||||
return response_json.get("content", "")
|
||||
except json.JSONDecodeError:
|
||||
return ""
|
||||
|
||||
def compress_history(debug, client, messages, keep_last_pairs=2):
|
||||
"""
|
||||
Compresses the middle of the conversation history.
|
||||
Preserves: System Prompt (0), User Task (1), and the last N pairs of interaction.
|
||||
"""
|
||||
# Calculate how many messages to keep at the end (pairs * 2)
|
||||
keep_count = keep_last_pairs * 2
|
||||
|
||||
# Check if we actually have enough history to compress
|
||||
# We need: System + Task + (At least 2 messages to compress) + Keep_Count
|
||||
if len(messages) < (2 + 2 + keep_count):
|
||||
if debug: logger.warning("History too short to compress, but context is full. Crashing safely.")
|
||||
return messages # Nothing we can do, let it fail or truncate manually
|
||||
|
||||
# Define the slice to compress
|
||||
# Start at 2 (after Task), End at -keep_count
|
||||
to_compress = messages[2:-keep_count]
|
||||
|
||||
# 1. format the text for the summarizer
|
||||
history_text = ""
|
||||
for msg in to_compress:
|
||||
role = msg['role'].upper()
|
||||
content = msg['content']
|
||||
history_text += f"[{role}]: {content}\n"
|
||||
|
||||
# 2. Build the summarization prompt
|
||||
summary_prompt = (
|
||||
"You are a technical documentation assistant. "
|
||||
"Summarize the following interaction history between an AI Agent and a System. "
|
||||
"Focus on: 1. Code executed, 2. Errors encountered, 3. Specific data/variables discovered. "
|
||||
"Be concise. Do not chat.\n\n"
|
||||
f"--- HISTORY START ---\n{history_text}\n--- HISTORY END ---"
|
||||
)
|
||||
|
||||
if debug: logger.debug(f"Compressing {len(to_compress)} messages...")
|
||||
|
||||
# 3. Call the LLM (We use the Agent Client for high-quality summaries)
|
||||
# We use a simple generation call here.
|
||||
summary_text = client.completion(
|
||||
build_chat_prompt([{"role": "user", "content": summary_prompt}])
|
||||
)
|
||||
|
||||
# 4. Create the new compressed message
|
||||
summary_message = {
|
||||
"role": "user",
|
||||
"content": f"[SYSTEM SUMMARY OF PREVIOUS ACTIONS]\n{summary_text}"
|
||||
}
|
||||
|
||||
# 5. Reconstruct the list
|
||||
new_messages = [messages[0], messages[1]] + [summary_message] + messages[-keep_count:]
|
||||
|
||||
if debug: logger.info(f"Compression complete. Reduced {len(messages)} msgs to {len(new_messages)}.")
|
||||
return new_messages
|
||||
|
||||
def generate_final_report(debug, client, task_text, raw_answer):
|
||||
"""
|
||||
Converts the Agent's raw (likely structured/technical) answer into
|
||||
a natural language response for the user.
|
||||
"""
|
||||
system_prompt = (
|
||||
"You are a professional report writer. "
|
||||
"Your goal is to convert the provided Raw Data into a clear, concise, "
|
||||
"and well-formatted response to the User's original request. "
|
||||
"Do not add new facts. Just format and explain the existing data."
|
||||
)
|
||||
|
||||
user_prompt = f"""### USER REQUEST
|
||||
{task_text}
|
||||
|
||||
### RAW DATA COLLECTED
|
||||
{raw_answer}
|
||||
|
||||
### INSTRUCTION
|
||||
Write the final response in natural language (Markdown).
|
||||
"""
|
||||
|
||||
if debug: logger.debug("Generating natural language report...")
|
||||
return client.completion(build_chat_prompt([
|
||||
{"role": agent_template.ROLE_SYSTEM, "content": system_prompt},
|
||||
{"role": agent_template.ROLE_USER, "content": user_prompt}
|
||||
]))
|
||||
|
||||
def load_file(filepath):
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
logger.error(f"File not found: {filepath}")
|
||||
sys.exit(1)
|
||||
Reference in New Issue
Block a user