Initial commit
lats.py (new file, 256 lines)

import json
import math
import requests
import io
import contextlib
import re
import textwrap
from rich.console import Console
from rich.markdown import Markdown

console = Console()

# --- CONFIGURATION ---
LLM_API_URL = "http://localhost:8090/v1/chat/completions"
MAX_ITERATIONS = 3
EXPANSION_N = 2
UCT_CONSTANT = 1.41
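
# MAX_ITERATIONS bounds the number of search rounds, EXPANSION_N is the
# number of candidate programs sampled per expansion, and UCT_CONSTANT
# (~sqrt(2)) is the conventional exploration weight in the UCT formula.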


class Node:
    _id_counter = 0

    def __init__(self, state, parent=None, action="", observation="", reflection=""):
        self.id = Node._id_counter
        Node._id_counter += 1
        self.state = state
        self.parent = parent
        self.action = action
        self.observation = observation
        self.reflection = reflection
        self.children = []
        self.visits = 0
        self.value = 0.0
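
    # UCT balances exploitation (value / visits) against exploration
    # (UCT_CONSTANT * sqrt(ln(parent visits) / visits)); unvisited nodes
    # score +inf so every child is tried at least once.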
    def uct_score(self):
        if self.visits == 0:
            return float('inf')
        return (self.value / self.visits) + UCT_CONSTANT * math.sqrt(math.log(self.parent.visits) / self.visits)

    def __repr__(self):
        return f"[Node {self.id} | Val: {self.value:.2f} | Visits: {self.visits}]"


def clean_code(text):
    """
    EXTRACTOR: Strips Markdown code blocks and normalizes indentation.
    """
    # Regex to find content inside ```python or ``` blocks
    pattern = r"```(?:python)?\n?(.*?)```"
    matches = re.findall(pattern, text, re.DOTALL)

    if matches:
        code = matches[0]
        # 1. dedent removes common leading whitespace from every line
        # 2. strip removes leading/trailing empty lines
        return textwrap.dedent(code).strip()

    # Fallback: if provided raw code without backticks, still clean indentation
    return textwrap.dedent(text).strip()
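
# e.g. clean_code("```python\n    print('hi')\n```") returns "print('hi')":
# the fence is stripped and the common indentation dedented away.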


def call_llm(prompt, system_prompt="You are a helpful coding assistant.", temperature=0.1, json_schema=None):
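    """Send one chat request to the local OpenAI-compatible endpoint.

    Note: the top-level "json_schema" payload field below is assumed to be
    a convention of the local server (llama.cpp-style grammar-constrained
    output); the official OpenAI API nests schemas under "response_format".
    """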
    # console.print(f"\nSENT:\n{system_prompt}\n{prompt}\nTemperature: {temperature}\n")
    payload = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ],
        "temperature": temperature,
        "max_tokens": 4096
    }
    if json_schema:
        payload["json_schema"] = json_schema
    try:
        response = requests.post(LLM_API_URL, json=payload, timeout=60)
        response_text = response.json()['choices'][0]['message']['content'].strip()
        # md = Markdown(response_text)
        # console.print("LLM Response:")
        # console.print(md)

        return response_text
    except Exception as e:
        return f"Error: {e}"


def execute_python(code, supplied_input=""):
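    """Execute `code` in a fresh namespace and return captured stdout.

    input() is replaced with a stub that pops lines from `supplied_input`,
    and runtime errors are printed into the captured stream so the agent
    receives them as part of the observation.
    """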
    f = io.StringIO()

    def fake_input(prompt=""):
        # behave like input(): read one line from supplied_input
        if not fake_input._lines:
            raise EOFError("No input provided via supplied_input")
        return fake_input._lines.pop(0)
    fake_input._lines = supplied_input.splitlines()

    # CRITICAL CHANGE: We use one dictionary for the environment
    env = {"__builtins__": __builtins__, "input": fake_input}

    with contextlib.redirect_stdout(f):
        try:
            # CRITICAL CHANGE: Pass 'env' as BOTH globals and locals.
            # This mimics running a standard script at the module level.
            exec(code, env, env)
        except Exception as e:
            # We print the error so it is captured in f.getvalue()
            # and returned to the agent as part of the observation.
            print(f"Runtime Error: {e}")

    return f.getvalue().strip()


# --- LATS OPERATIONS ---

def LATS_Search(task_description):
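    """LATS (Language Agent Tree Search): an MCTS-style loop over code drafts.

    Each iteration selects a leaf by UCT, expands it with EXPANSION_N fresh
    LLM samples, executes each candidate, has an LLM judge score it, keeps
    the judge's reasoning as the node's reflection, and backpropagates the
    best child score up the selected path.
    """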
    # Root state is empty (no code yet); the task itself travels in the prompts
    root = Node(state="")
    print(f"\n🚀 STARTING LATS SEARCH FOR: {task_description}")

    for i in range(MAX_ITERATIONS):
        print(f"\n{'=' * 60}\nITERATION {i + 1}\n{'=' * 60}")

        # --- 1. SELECTION ---
        print(f"🔍 [1/6 SELECTION] Choosing best path using UCT...")
        curr = root
        path = [str(curr.id)]
        while curr.children:
            curr = max(curr.children, key=lambda c: c.uct_score())
            path.append(str(curr.id))
        print(f"   Selected Node {curr.id} following path: {' -> '.join(path)}")

        # --- 2. EXPANSION ---
        print(f"🌱 [2/6 EXPANSION] Generating {EXPANSION_N} options...")

        # DYNAMIC PROMPTING
        if curr.id == root.id:
            # We are at the start: generate a first draft
            system_msg = "You are a Python coding assistant."
            prompt = (
                f"# Task: {task_description}\n\n"
                f"Write the complete Python code to solve this task."
            )
        else:
            # We are deep in the tree: fix/refine the code
            system_msg = "You are a debugging assistant. Rewrite the code to fix the error."
            prompt = (
                f"# Task: {task_description}\n\n"
                f"Current Code:\n```python\n{curr.state}\n```\n\n"
                f"Previous Output/Error:\n{curr.observation}\n\n"
                f"Insight/Feedback: {curr.reflection}\n\n"
                f"Action: Generate the FULL corrected Python code."
            )

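        # Sampling at temperature 1.0 (vs. the 0.1 default used elsewhere)
        # deliberately diversifies the EXPANSION_N candidates.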
        for n in range(EXPANSION_N):
            print(f"   Generating option {n + 1}/{EXPANSION_N}...")

            # clean_code strips markdown fences and normalizes indentation
            raw_response = call_llm(prompt, system_prompt=system_msg, temperature=1.0)
            generated_code = clean_code(raw_response)

            # The state is the NEW code, not appended code.
            # (If generated_code is empty due to a parse error, fall back to
            # the parent state to avoid a crash.)
            full_code = generated_code if generated_code else curr.state

            # --- 3. SIMULATION ---
            print(f"🧪 [3/6 SIMULATION] Executing code...")
            observation = execute_python(full_code)

            # Truncate long observations for the console (keep full for LLM)
            display_obs = (observation[:100] + '...') if len(observation) > 100 else observation
            print(f"   REPL Output: {display_obs}")
            if not observation:
                observation = "(No output)"

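            # Note: execute_python runs without supplied_input here, so a
            # candidate that calls input() surfaces in the observation as
            # "Runtime Error: No input provided via supplied_input".
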
            # --- 4. EVALUATION ---
            print(f"📊 [4/6 EVALUATION] Scoring...")

            json_schema = {
                "type": "object",
                "properties": {
                    "reward": {"type": "number", "minimum": 0.0, "maximum": 1.0},
                    "reasoning": {"type": "string"}
                },
                "required": ["reward", "reasoning"]
            }

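            # The judge returns a scalar reward in [0, 1] plus free-text
            # reasoning; the reasoning is reused below as the reflection.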
            eval_prompt = (
                f"Task: {task_description}\n"
                f"Code:\n```python\n{full_code}\n```\n"
                f"Execution Result: {observation}\n\n"
                "Evaluate if the code solves the task and runs without errors."
            )

            res = call_llm(eval_prompt, system_prompt="You are a code judge.", temperature=0.1, json_schema=json_schema)

            try:
                data = json.loads(res)
                lm_score = data["reward"]
                reasoning = data["reasoning"]
            except (json.JSONDecodeError, KeyError, TypeError):
                lm_score = 0.0
                reasoning = "JSON Parsing Failed"

            # Create Child
            child = Node(state=full_code, parent=curr, action=generated_code, observation=observation)
            child.value = lm_score
            child.visits = 1

            # --- 5. REFLECTION ---
            # Store the judge's reasoning as the reflection immediately,
            # which saves a separate reflection call.
            child.reflection = reasoning
            curr.children.append(child)

            print(f"   Node {child.id} Score: {lm_score} | {reasoning[:60]}...")

            if lm_score >= 1.0:
                print(f"\n🎯 PERFECT SCORE FOUND IN ITERATION {i + 1}!")
                return child.state

        # --- 6. BACKPROPAGATION ---
        print(f"⬆️ [6/6 BACKPROPAGATION]...")
        # LATS standard: backpropagate the best immediate child's value
        best_child_val = max(c.value for c in curr.children) if curr.children else 0

        temp = curr
        while temp:
            temp.visits += 1
            # Update value (running average or max, depending on preference;
            # LATS often uses max for coding)
            temp.value = max(temp.value, best_child_val)
            temp = temp.parent

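    # (The search may end without a perfect score; in that case pick the
    # best-scoring node anywhere in the tree, not just on the last path.)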
    # Final Select
    all_nodes = []

    def collect(n):
        all_nodes.append(n)
        for c in n.children:
            collect(c)

    collect(root)

    # Filter out empty root
    candidates = [n for n in all_nodes if n.state]
    if not candidates:
        return "No code generated."

    best = max(candidates, key=lambda n: n.value)
    return best.state


# --- TEST EXECUTION ---
if __name__ == "__main__":
    # Example task: find the shortest substring of S that contains T as a
    # subsequence (minimum-window subsequence)
    task = "Given strings S and T, find the shortest substring of S which has T as a subsequence. Return the substring or empty string if none. Use DP or two-pointer scan with backward trace. Input: ('abcdebdde', 'bde')."

    final_code = LATS_Search(task)
    print("\n" + "=" * 60)
    print("FINAL RESULTING CODE:")
    print("=" * 60)
    print(final_code)