Initial commit

commit 2490895254
Date: 2026-01-25 21:13:37 +00:00

lats.py (new file, 256 lines)

@@ -0,0 +1,256 @@
import json
import math
import requests
import io
import contextlib
import re
import textwrap
from rich.console import Console
from rich.markdown import Markdown
console = Console()
# --- CONFIGURATION ---
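# LLM_API_URL points at a locally hosted OpenAI-compatible server. MAX_ITERATIONS bounds
# the outer search loop, EXPANSION_N is the number of candidate programs generated per
# expansion, and UCT_CONSTANT (~sqrt(2)) balances exploration against exploitation.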
LLM_API_URL = "http://localhost:8090/v1/chat/completions"
MAX_ITERATIONS = 3
EXPANSION_N = 2
UCT_CONSTANT = 1.41
class Node:
    _id_counter = 0

    def __init__(self, state, parent=None, action="", observation="", reflection=""):
        self.id = Node._id_counter
        Node._id_counter += 1
        self.state = state
        self.parent = parent
        self.action = action
        self.observation = observation
        self.reflection = reflection
        self.children = []
        self.visits = 0
        self.value = 0.0
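    # UCT: mean value so far (exploitation) plus an exploration bonus that shrinks as a node
    # accumulates visits; unvisited nodes score infinity so every child is tried at least once.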
    def uct_score(self):
        if self.visits == 0:
            return float('inf')
        return (self.value / self.visits) + UCT_CONSTANT * math.sqrt(math.log(self.parent.visits) / self.visits)

    def __repr__(self):
        return f"[Node {self.id} | Val: {self.value:.2f} | Visits: {self.visits}]"
def clean_code(text):
    """
    EXTRACTOR: Strips Markdown code blocks and normalizes indentation.
    """
    # Regex to find content inside ```python or ``` blocks
    pattern = r"```(?:python)?\n?(.*?)```"
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        code = matches[0]
        # 1. dedent removes common leading whitespace from every line
        # 2. strip removes leading/trailing empty lines
        return textwrap.dedent(code).strip()
    # Fallback: if provided raw code without backticks, still clean indentation
    return textwrap.dedent(text).strip()
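# call_llm posts to an OpenAI-compatible /v1/chat/completions endpoint. The optional
# json_schema key is forwarded verbatim and assumes the local server supports
# schema-constrained (structured) output; any request failure is returned as an
# "Error: ..." string rather than raised.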
def call_llm(prompt, system_prompt="You are a helpful coding assistant.", temperature=0.1, json_schema=None):
    # console.print(f"\nSENT:\n{system_prompt}\n{prompt}\nTemperature: {temperature}\n")
    payload = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ],
        "temperature": temperature,
        "max_tokens": 4096
    }
    if json_schema:
        payload["json_schema"] = json_schema
    try:
        response = requests.post(LLM_API_URL, json=payload, timeout=60)
        response_text = response.json()['choices'][0]['message']['content'].strip()
        # md = Markdown(response_text)
        # console.print("LLM Response:")
        # console.print(md)
        return response_text
    except Exception as e:
        return f"Error: {e}"
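# execute_python is a minimal in-process harness, not a real sandbox: it captures stdout,
# replaces input() with lines drawn from supplied_input, and folds any exception into the
# returned observation string. Generated code runs with full interpreter privileges.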
def execute_python(code, supplied_input=""):
    f = io.StringIO()

    def fake_input(prompt=""):
        # Behave like input(): read one line from supplied_input
        if not fake_input._lines:
            raise EOFError("No input provided via supplied_input")
        return fake_input._lines.pop(0)

    fake_input._lines = supplied_input.splitlines()
    # Use a single dictionary for the execution environment
    env = {"__builtins__": __builtins__, "input": fake_input}
    with contextlib.redirect_stdout(f):
        try:
            # Pass 'env' as BOTH globals and locals.
            # This mimics running a standard script at the module level.
            exec(code, env, env)
        except Exception as e:
            # Print the error so it is captured in f.getvalue()
            # and returned to the agent as part of the observation.
            print(f"Runtime Error: {e}")
    return f.getvalue().strip()
# --- LATS OPERATIONS ---
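# One LATS iteration runs six phases: (1) select a leaf via UCT, (2) expand it into
# EXPANSION_N candidate programs, (3) execute each candidate, (4) score it with an LLM
# judge, (5) keep the judge's reasoning as the node's reflection, and (6) backpropagate
# the best child value along the selected path. A judge score of 1.0 ends the search early.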
def LATS_Search(task_description):
    # The root holds no code yet; the task description is threaded through the prompts
    # rather than being stored in the tree.
    root = Node(state="")
    print(f"\n🚀 STARTING LATS SEARCH FOR: {task_description}")

    for i in range(MAX_ITERATIONS):
        print(f"\n{'=' * 60}\nITERATION {i + 1}\n{'=' * 60}")

        # --- 1. SELECTION ---
        print("🔍 [1/6 SELECTION] Choosing best path using UCT...")
        curr = root
        path = [str(curr.id)]
        while curr.children:
            curr = max(curr.children, key=lambda c: c.uct_score())
            path.append(str(curr.id))
        print(f" Selected Node {curr.id} following path: {' -> '.join(path)}")

        # --- 2. EXPANSION ---
        print(f"🌱 [2/6 EXPANSION] Generating {EXPANSION_N} options...")
        # DYNAMIC PROMPTING
        if curr.id == root.id:
            # We are at the start: generate a first draft
            system_msg = "You are a Python coding assistant."
            prompt = (
                f"# Task: {task_description}\n\n"
                f"Write the complete Python code to solve this task."
            )
        else:
            # We are deeper in the tree: fix/refine the code
            system_msg = "You are a debugging assistant. Rewrite the code to fix the error."
            prompt = (
                f"# Task: {task_description}\n\n"
                f"Current Code:\n```python\n{curr.state}\n```\n\n"
                f"Previous Output/Error:\n{curr.observation}\n\n"
                f"Insight/Feedback: {curr.reflection}\n\n"
                f"Action: Generate the FULL corrected Python code."
            )

        for n in range(EXPANSION_N):
            print(f" Generating option {n + 1}/{EXPANSION_N}...")
            # clean_code strips markdown fences and normalizes indentation
            raw_response = call_llm(prompt, system_prompt=system_msg, temperature=1.0)
            generated_code = clean_code(raw_response)
            # The node state is the full rewritten program, not a diff appended to the parent.
            # If generated_code is empty due to a parse error, fall back to the parent state to avoid a crash.
            full_code = generated_code if generated_code else curr.state

            # --- 3. SIMULATION ---
            print("🧪 [3/6 SIMULATION] Executing code...")
            observation = execute_python(full_code)
            # Truncate long observations for the console (keep the full text for the LLM)
            display_obs = (observation[:100] + '...') if len(observation) > 100 else observation
            print(f" REPL Output: {display_obs}")
            if not observation:
                observation = "(No output)"
            # --- 4. EVALUATION ---
            print("📊 [4/6 EVALUATION] Scoring...")
            json_schema = {
                "type": "object",
                "properties": {
                    "reward": {"type": "number", "minimum": 0.0, "maximum": 1.0},
                    "reasoning": {"type": "string"}
                },
                "required": ["reward", "reasoning"]
            }
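            # The schema above constrains the judge to a scalar reward in [0, 1] plus a short
            # natural-language justification.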
            eval_prompt = (
                f"Task: {task_description}\n"
                f"Code:\n```python\n{full_code}\n```\n"
                f"Execution Result: {observation}\n\n"
                "Evaluate if the code solves the task and runs without errors."
            )
            res = call_llm(eval_prompt, system_prompt="You are a code judge.", temperature=0.1, json_schema=json_schema)
            try:
                data = json.loads(res)
                lm_score = data["reward"]
                reasoning = data["reasoning"]
            except (json.JSONDecodeError, KeyError, TypeError):
                lm_score = 0.0
                reasoning = "JSON Parsing Failed"

            # Create the child node for this candidate
            child = Node(state=full_code, parent=curr, action=generated_code, observation=observation)
            child.value = lm_score
            child.visits = 1

            # --- 5. REFLECTION ---
            # Store the judge's reasoning as the reflection immediately (efficiency boost)
            child.reflection = reasoning
            curr.children.append(child)
            print(f" Node {child.id} Score: {lm_score} | {reasoning[:60]}...")

            if lm_score >= 1.0:
                print(f"\n🎯 PERFECT SCORE FOUND IN ITERATION {i + 1}!")
                return child.state
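        # Note: a perfect score returns from inside the expansion loop, so that iteration
        # never reaches backpropagation; the search ends with the winning program.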
        # --- 6. BACKPROPAGATION ---
        print("⬆️ [6/6 BACKPROPAGATION]...")
        # LATS standard: backpropagate the best immediate child's value
        best_child_val = max(c.value for c in curr.children) if curr.children else 0
        temp = curr
        while temp:
            temp.visits += 1
            # Update value (running average or max, depending on preference; LATS often uses max for coding)
            temp.value = max(temp.value, best_child_val)
            temp = temp.parent
    # Final Select
    all_nodes = []

    def collect(n):
        all_nodes.append(n)
        for c in n.children:
            collect(c)

    collect(root)
    # Filter out empty root
    candidates = [n for n in all_nodes if n.state]
    if not candidates:
        return "No code generated."
    best = max(candidates, key=lambda n: n.value)
    return best.state
# --- TEST EXECUTION ---
if __name__ == "__main__":
    # Example task: find the shortest substring of S that contains T as a subsequence
    task = "Given strings S and T, find the shortest substring of S which has T as a subsequence. Return the substring or empty string if none. Use DP or two-pointer scan with backward trace. Input: ('abcdebdde', 'bde')."
    final_code = LATS_Search(task)
    print("\n" + "=" * 60)
    print("FINAL RESULTING CODE:")
    print("=" * 60)
    print(final_code)