# LATS (Language Agent Tree Search) coding agent: MCTS-style search over
# LLM-generated Python programs, with execution feedback and LLM scoring.
import json
|
|
import math
|
|
import requests
|
|
import io
|
|
import contextlib
|
|
import re
|
|
import textwrap
|
|
from rich.console import Console
|
|
from rich.markdown import Markdown
|
|
|
|
# Shared rich console used for formatted terminal output.
console = Console()


# --- CONFIGURATION ---
# Endpoint of the local OpenAI-compatible chat-completions server.
LLM_API_URL = "http://localhost:8090/v1/chat/completions"
# Number of select/expand/simulate/evaluate/backprop rounds to run.
MAX_ITERATIONS = 3
# Number of candidate programs generated per expansion step.
EXPANSION_N = 2
# Exploration weight in the UCT formula (~sqrt(2), the classic MCTS default).
UCT_CONSTANT = 1.41
|
|
|
|
|
|
class Node:
    """A search-tree node: one candidate program plus its MCTS statistics."""

    # Class-wide counter so every node gets a unique, increasing id.
    _id_counter = 0

    def __init__(self, state, parent=None, action="", observation="", reflection=""):
        self.id = Node._id_counter
        Node._id_counter += 1

        # Search payload: the code (state), how we got here, and feedback.
        self.state = state
        self.parent = parent
        self.action = action
        self.observation = observation
        self.reflection = reflection

        # MCTS bookkeeping.
        self.children = []
        self.visits = 0
        self.value = 0.0

    def uct_score(self):
        """Upper-Confidence-bound score used during selection.

        Unvisited nodes rank first (infinite score) so every child is
        tried at least once before exploitation kicks in.
        """
        if not self.visits:
            return float('inf')
        exploitation = self.value / self.visits
        exploration = math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation + UCT_CONSTANT * exploration

    def __repr__(self):
        return f"[Node {self.id} | Val: {self.value:.2f} | Visits: {self.visits}]"
|
|
|
|
|
|
def clean_code(text):
    """
    EXTRACTOR: pull the first Markdown-fenced code block out of an LLM
    reply and normalize its indentation. If no fence is present, the whole
    text is treated as raw code and cleaned the same way.
    """
    # First ```python (or bare ```) fenced block, matched across newlines.
    fence = re.search(r"```(?:python)?\n?(.*?)```", text, re.DOTALL)
    code = fence.group(1) if fence else text

    # dedent drops the common leading whitespace; strip trims blank edges.
    return textwrap.dedent(code).strip()
|
|
|
|
|
|
def call_llm(prompt, system_prompt="You are a helpful coding assistant.", temperature=0.1, json_schema=None):
    """POST a chat-completion request to the local LLM server.

    Returns the stripped message content on success. Any network or
    response-shape failure is returned as an "Error: ..." string instead of
    raising, so callers can treat it like a (bad) completion.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    payload = {
        "messages": messages,
        "temperature": temperature,
        "max_tokens": 4096,
    }
    if json_schema:
        # Ask the server for schema-constrained JSON output.
        payload["json_schema"] = json_schema

    try:
        reply = requests.post(LLM_API_URL, json=payload, timeout=60)
        return reply.json()['choices'][0]['message']['content'].strip()
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
|
|
def execute_python(code, supplied_input=""):
    """Execute *code* in an isolated namespace and return its captured stdout.

    stdin is simulated: each input() call pops one line from
    *supplied_input* (EOFError once lines run out). Runtime errors are
    printed into the captured stream so they come back to the agent as part
    of the observation rather than crashing the search loop.

    SECURITY NOTE: this exec()s LLM-generated code with full builtins —
    acceptable only because the agent runs locally by design.
    """
    # FIX: use the explicit builtins module. The bare name __builtins__ is a
    # CPython implementation detail (a module in __main__, a dict in imported
    # modules), so relying on it breaks when this file is imported.
    import builtins

    f = io.StringIO()

    def fake_input(prompt=""):
        # behave like input(): read one line from supplied_input
        if not fake_input._lines:
            raise EOFError("No input provided via supplied_input")
        return fake_input._lines.pop(0)
    fake_input._lines = supplied_input.splitlines()

    # One dictionary serves as BOTH globals and locals so the executed code
    # behaves like a normal module-level script (functions defined in the
    # code can see its top-level names).
    env = {"__builtins__": builtins, "input": fake_input}

    with contextlib.redirect_stdout(f):
        try:
            exec(code, env, env)
        except Exception as e:
            # Printed so it is captured in f.getvalue() and returned to the
            # agent as part of the observation.
            print(f"Runtime Error: {e}")

    return f.getvalue().strip()
|
|
|
|
|
|
# --- LATS OPERATIONS ---
|
|
|
|
def LATS_Search(task_description):
    """Run Language Agent Tree Search over candidate programs for one task.

    Each iteration: (1) SELECT the most promising leaf path via UCT,
    (2) EXPAND it with EXPANSION_N LLM-generated candidates,
    (3) SIMULATE each candidate with execute_python,
    (4) EVALUATE it with an LLM judge (JSON reward in [0, 1]),
    (5) store the judge's reasoning as the node's reflection, and
    (6) BACKPROPAGATE the best child score up the tree.
    Returns early with a candidate's code on a perfect (>= 1.0) score;
    otherwise returns the best-scoring code found, or a message string if
    nothing was generated.
    """
    # 1. Store the task string separately so we don't lose it in the tree
    root = Node(state="")  # Root state is empty (no code yet)
    print(f"\n🚀 STARTING LATS SEARCH FOR: {task_description}")

    for i in range(MAX_ITERATIONS):
        print(f"\n{'=' * 60}\nITERATION {i + 1}\n{'=' * 60}")

        # --- 1. SELECTION ---
        print(f"🔍 [1/6 SELECTION] Choosing best path using UCT...")
        curr = root
        path = [str(curr.id)]
        # Walk down, always taking the child with the highest UCT score,
        # until we reach a leaf (a node with no children yet).
        while curr.children:
            curr = max(curr.children, key=lambda c: c.uct_score())
            path.append(str(curr.id))
        print(f" Selected Node {curr.id} following path: {' -> '.join(path)}")

        # --- 2. EXPANSION ---
        print(f"🌱 [2/6 EXPANSION] Generating {EXPANSION_N} options...")

        # DYNAMIC PROMPTING
        if curr.id == root.id:
            # We are at the start: Generate first draft
            system_msg = "You are a Python coding assistant."
            prompt = (
                f"# Task: {task_description}\n\n"
                f"Write the complete Python code to solve this task."
            )
        else:
            # We are deep in the tree: Fix/Refine the code
            system_msg = "You are a debugging assistant. Rewrite the code to fix the error."
            prompt = (
                f"# Task: {task_description}\n\n"
                f"Current Code:\n```python\n{curr.state}\n```\n\n"
                f"Previous Output/Error:\n{curr.observation}\n\n"
                f"Insight/Feedback: {curr.reflection}\n\n"
                f"Action: Generate the FULL corrected Python code."
            )

        for n in range(EXPANSION_N):
            print(f" Generating option {n + 1}/{EXPANSION_N}...")

            # Use 'clean_code' from previous turn to handle markdown/indentation
            # temperature=1.0: sample diverse candidates during expansion.
            raw_response = call_llm(prompt, system_prompt=system_msg, temperature=1.0)
            generated_code = clean_code(raw_response)

            # FIX: The state is the NEW code, not appended code.
            # (If generated_code is empty due to parse error, fall back to parent state to avoid crash)
            full_code = generated_code if generated_code else curr.state

            # --- 3. SIMULATION ---
            print(f"🧪 [3/6 SIMULATION] Executing code...")
            # Use fixed 'execute_python' from previous turn
            observation = execute_python(full_code)

            # Truncate long observations for the console (keep full for LLM)
            display_obs = (observation[:100] + '...') if len(observation) > 100 else observation
            print(f" REPL Output: {display_obs}")
            if not observation: observation = "(No output)"

            # --- 4. EVALUATION ---
            print(f"📊 [4/6 EVALUATION] Scoring...")

            # Schema handed to call_llm so the judge answers in parseable JSON.
            json_schema = {
                "type": "object",
                "properties": {
                    "reward": {"type": "number", "minimum": 0.0, "maximum": 1.0},
                    "reasoning": {"type": "string"}
                },
                "required": ["reward", "reasoning"]
            }

            eval_prompt = (
                f"Task: {task_description}\n"
                f"Code:\n```python\n{full_code}\n```\n"
                f"Execution Result: {observation}\n\n"
                "Evaluate if the code solves the task and runs without errors."
            )

            # temperature=0.1: keep the judge near-deterministic.
            res = call_llm(eval_prompt, system_prompt="You are a code judge.", temperature=0.1, json_schema=json_schema)

            try:
                data = json.loads(res)
                lm_score = data["reward"]
                reasoning = data["reasoning"]
            # NOTE(review): bare except also swallows KeyboardInterrupt;
            # consider narrowing to (json.JSONDecodeError, KeyError, TypeError).
            except:
                lm_score = 0.0
                reasoning = "JSON Parsing Failed"

            # Create Child
            child = Node(state=full_code, parent=curr, action=generated_code, observation=observation)
            child.value = lm_score
            child.visits = 1

            # --- 5. REFLECTION ---
            # Store the reasoning as reflection immediately (Efficiency boost)
            child.reflection = reasoning
            curr.children.append(child)

            print(f" Node {child.id} Score: {lm_score} | {reasoning[:60]}...")

            # Early exit: a perfect judge score ends the whole search.
            if lm_score >= 1.0:
                print(f"\n🎯 PERFECT SCORE FOUND IN ITERATION {i + 1}!")
                return child.state

        # --- 6. BACKPROPAGATION ---
        print(f"⬆️ [6/6 BACKPROPAGATION]...")
        # LATS Standard: Backpropagate the best immediate child's value
        best_child_val = max(c.value for c in curr.children) if curr.children else 0

        temp = curr
        # Propagate from the expanded node up to the root.
        while temp:
            temp.visits += 1
            # Update value (Running Average or Max, depending on preference. LATS often uses Max for coding)
            temp.value = max(temp.value, best_child_val)
            temp = temp.parent

    # Final Select
    # No perfect score found: gather every node and return the best state.
    all_nodes = []
    def collect(n):
        all_nodes.append(n)
        for c in n.children: collect(c)
    collect(root)

    # Filter out empty root
    candidates = [n for n in all_nodes if n.state]
    if not candidates: return "No code generated."

    best = max(candidates, key=lambda n: n.value)
    return best.state
|
|
|
|
|
|
# --- TEST EXECUTION ---
|
|
if __name__ == "__main__":
    # Demo task: shortest substring of S containing T as a subsequence.
    task = (
        "Given strings S and T, find the shortest substring of S which has T as a subsequence. Return the substring or empty string if none. Use DP or two-pointer scan with backward trace. Input: ('abcdebdde', 'bde')."
    )

    final_code = LATS_Search(task)

    banner = "=" * 60
    print("\n" + banner)
    print("FINAL RESULTING CODE:")
    print(banner)
    print(final_code)