"""LATS-style tree search that has a local LLM write, run and score Python code.

Each iteration selects a leaf with UCT, asks the LLM for ``EXPANSION_N``
candidate programs, executes every candidate in-process, has the LLM judge the
result, and backs the best score up the tree.

Reconstructed from the "Initial commit" patch (24908952) that added lats.py.
"""

import contextlib
import io
import json
import math
import re
import textwrap

import requests
from rich.console import Console
from rich.markdown import Markdown

console = Console()


# --- CONFIGURATION ---
LLM_API_URL = "http://localhost:8090/v1/chat/completions"
MAX_ITERATIONS = 3
EXPANSION_N = 2
UCT_CONSTANT = 1.41

# Schema sent with the judge request so the reply is a {"reward", "reasoning"} object.
EVALUATION_SCHEMA = {
    "type": "object",
    "properties": {
        "reward": {"type": "number", "minimum": 0.0, "maximum": 1.0},
        "reasoning": {"type": "string"},
    },
    "required": ["reward", "reasoning"],
}

# Opening fence with an optional python tag; the body runs to the closing fence
# or, when the reply was cut off before the fence was closed, to end of text.
_CODE_FENCE_RE = re.compile(
    r"```(?:python3?|py)?[ \t]*\n?(.*?)(?:```|\Z)",
    re.DOTALL | re.IGNORECASE,
)


class Node:
    """One candidate program in the search tree.

    Attributes:
        id: Sequential identifier taken from a class-wide counter.
        state: The full source code this node represents ("" for the root).
        parent: Parent node, or None for the root.
        action: The code the LLM generated for this node.
        observation: Captured output of executing ``state``.
        reflection: The judge's reasoning about this node.
        reward: The judge's score for this node's own code.
        children: Child nodes created by expanding this node.
        visits: Visit count used by UCT.
        value: Backed-up value used by UCT (max of own and descendant scores).
    """

    _id_counter = 0

    def __init__(self, state, parent=None, action="", observation="", reflection="", reward=0.0):
        self.id = Node._id_counter
        Node._id_counter += 1
        self.state = state
        self.parent = parent
        self.action = action
        self.observation = observation
        self.reflection = reflection
        # Kept separate from `value`: backpropagation overwrites `value` with
        # descendants' scores, so it cannot identify which node earned a score.
        self.reward = reward
        self.children = []
        self.visits = 0
        self.value = 0.0

    def uct_score(self):
        """Return the UCT score; unvisited nodes score infinity."""
        if self.visits == 0:
            return float("inf")
        # NOTE(review): `value` is backed up as a max, yet it is divided by
        # `visits` here as if it were a cumulative sum - confirm this is intended.
        exploitation = self.value / self.visits
        # A missing parent or a parent with no visits contributes no exploration
        # bonus instead of raising (log(1) == 0).
        parent_visits = self.parent.visits if self.parent is not None else 0
        exploration = math.sqrt(math.log(max(parent_visits, 1)) / self.visits)
        return exploitation + UCT_CONSTANT * exploration

    def __repr__(self):
        return f"[Node {self.id} | Val: {self.value:.2f} | Visits: {self.visits}]"


def clean_code(text):
    """Extract Python source from an LLM reply and normalise its indentation.

    The body of the first Markdown code fence is used. A fence that is never
    closed (for example a reply truncated by the token limit) is read to the
    end of the text. Without any fence the whole text is treated as code.

    Returns:
        The dedented code with leading and trailing blank space removed.
    """
    match = _CODE_FENCE_RE.search(text)
    code = match.group(1) if match else text
    # dedent removes the common leading whitespace; strip drops blank edges.
    return textwrap.dedent(code).strip()


def call_llm(prompt, system_prompt="You are a helpful coding assistant.", temperature=0.1, json_schema=None):
    """Send one chat-completion request to ``LLM_API_URL`` and return the reply text.

    Args:
        prompt: User message.
        system_prompt: System message.
        temperature: Sampling temperature forwarded to the server.
        json_schema: Optional schema forwarded as the ``json_schema`` field.

    Returns:
        The stripped message content, or a string starting with ``"Error: "``
        when the request fails or the response has an unexpected shape.
    """
    payload = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        "temperature": temperature,
        "max_tokens": 4096,
    }
    if json_schema:
        payload["json_schema"] = json_schema
    try:
        response = requests.post(LLM_API_URL, json=payload, timeout=60)
        # Report HTTP failures as such instead of as a confusing missing key.
        response.raise_for_status()
        response_text = response.json()["choices"][0]["message"]["content"].strip()
        # Debug aid: console.print(Markdown(response_text)) renders the reply.
        return response_text
    except (requests.RequestException, ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
        return f"Error: {e}"


def execute_python(code, supplied_input=""):
    """Run ``code`` in-process as a script and return everything it printed.

    SECURITY: this calls ``exec`` on LLM-generated code with full builtins and
    no sandbox. Only run it in an environment you are willing to lose.

    Args:
        code: Python source to execute.
        supplied_input: Text served line by line to ``input()`` calls.

    Returns:
        The captured stdout, stripped. Exceptions and non-zero exits are
        reported in the output as a ``Runtime Error: ...`` line.
    """
    captured = io.StringIO()
    pending_lines = supplied_input.splitlines()

    def fake_input(prompt=""):
        # Behave like input(): hand out one line of supplied_input per call.
        if not pending_lines:
            raise EOFError("No input provided via supplied_input")
        return pending_lines.pop(0)

    # One dict serves as both globals and locals, which mimics module-level
    # execution. __name__ must be "__main__", otherwise the lookup falls through
    # to builtins.__name__ and `if __name__ == "__main__":` bodies never run.
    env = {"__builtins__": __builtins__, "__name__": "__main__", "input": fake_input}

    with contextlib.redirect_stdout(captured):
        try:
            exec(code, env, env)
        except SystemExit as e:
            # sys.exit()/exit() in generated code must not terminate the search.
            if e.code not in (None, 0):
                print(f"Runtime Error: SystemExit({e.code!r})")
        except Exception as e:
            # Printed so that it is captured and fed back as the observation.
            print(f"Runtime Error: {type(e).__name__}: {e}")

    return captured.getvalue().strip()


# --- LATS OPERATIONS ---

def _select_leaf(root):
    """Descend from ``root`` by highest UCT score; return (leaf, path of ids)."""
    node = root
    path = [str(node.id)]
    while node.children:
        node = max(node.children, key=Node.uct_score)
        path.append(str(node.id))
    return node, path


def _build_expansion_prompt(task_description, node, is_root):
    """Return (system message, user prompt) for expanding ``node``."""
    if is_root:
        # Nothing has been written yet: ask for a first draft.
        system_msg = "You are a Python coding assistant."
        prompt = (
            f"# Task: {task_description}\n\n"
            f"Write the complete Python code to solve this task."
        )
    else:
        # There is a previous attempt: ask for a corrected rewrite.
        system_msg = "You are a debugging assistant. Rewrite the code to fix the error."
        prompt = (
            f"# Task: {task_description}\n\n"
            f"Current Code:\n```python\n{node.state}\n```\n\n"
            f"Previous Output/Error:\n{node.observation}\n\n"
            f"Insight/Feedback: {node.reflection}\n\n"
            f"Action: Generate the FULL corrected Python code."
        )
    return system_msg, prompt


def _parse_evaluation(response_text):
    """Parse the judge reply into (reward clamped to [0, 1], reasoning).

    Falls back to ``(0.0, "JSON Parsing Failed")`` when no usable JSON object
    with ``reward`` and ``reasoning`` can be found.
    """
    attempts = [response_text]
    # The judge may wrap the object in prose or a code fence; try the
    # outermost {...} span as well.
    embedded = re.search(r"\{.*\}", response_text, re.DOTALL)
    if embedded:
        attempts.append(embedded.group(0))

    for attempt in attempts:
        try:
            data = json.loads(attempt)
            reward = float(data["reward"])
            reasoning = str(data["reasoning"])
        except (ValueError, KeyError, TypeError):
            continue
        if math.isnan(reward):
            continue
        return min(max(reward, 0.0), 1.0), reasoning
    return 0.0, "JSON Parsing Failed"


def _backpropagate(node, reward):
    """Add a visit and back up ``reward`` as a max from ``node`` to the root."""
    while node is not None:
        node.visits += 1
        node.value = max(node.value, reward)
        node = node.parent


def _iter_nodes(root):
    """Yield every node of the tree in pre-order."""
    stack = [root]
    while stack:
        node = stack.pop()
        yield node
        stack.extend(reversed(node.children))


def LATS_Search(task_description):
    """Search for Python code that solves ``task_description``.

    Runs up to ``MAX_ITERATIONS`` rounds of selection, expansion, execution,
    LLM evaluation and backpropagation.

    Returns:
        The code of the first candidate judged 1.0, otherwise the code of the
        candidate with the highest own score, or ``"No code generated."`` when
        no candidate has any code.
    """
    # The task string is passed around separately; the root holds no code.
    root = Node(state="")
    print(f"\n🚀 STARTING LATS SEARCH FOR: {task_description}")

    for i in range(MAX_ITERATIONS):
        print(f"\n{'=' * 60}\nITERATION {i + 1}\n{'=' * 60}")

        # --- 1. SELECTION ---
        print("🔍 [1/6 SELECTION] Choosing best path using UCT...")
        curr, path = _select_leaf(root)
        print(f" Selected Node {curr.id} following path: {' -> '.join(path)}")

        # --- 2. EXPANSION ---
        print(f"🌱 [2/6 EXPANSION] Generating {EXPANSION_N} options...")
        system_msg, prompt = _build_expansion_prompt(task_description, curr, curr is root)

        for n in range(EXPANSION_N):
            print(f" Generating option {n + 1}/{EXPANSION_N}...")

            raw_response = call_llm(prompt, system_prompt=system_msg, temperature=1.0)
            generated_code = clean_code(raw_response)

            # The state is the NEW code, not appended code. If nothing could be
            # extracted, fall back to the parent's code.
            full_code = generated_code or curr.state

            # --- 3. SIMULATION ---
            print("🧪 [3/6 SIMULATION] Executing code...")
            observation = execute_python(full_code)

            # Truncate only for the console; the LLM gets the full observation.
            display_obs = (observation[:100] + '...') if len(observation) > 100 else observation
            print(f" REPL Output: {display_obs}")
            if not observation:
                observation = "(No output)"

            # --- 4. EVALUATION ---
            print("📊 [4/6 EVALUATION] Scoring...")
            eval_prompt = (
                f"Task: {task_description}\n"
                f"Code:\n```python\n{full_code}\n```\n"
                f"Execution Result: {observation}\n\n"
                "Evaluate if the code solves the task and runs without errors."
            )
            res = call_llm(
                eval_prompt,
                system_prompt="You are a code judge.",
                temperature=0.1,
                json_schema=EVALUATION_SCHEMA,
            )
            lm_score, reasoning = _parse_evaluation(res)

            # --- 5. REFLECTION ---
            # The judge's reasoning doubles as the reflection, saving an LLM call.
            child = Node(
                state=full_code,
                parent=curr,
                action=generated_code,
                observation=observation,
                reflection=reasoning,
                reward=lm_score,
            )
            child.value = lm_score
            child.visits = 1
            curr.children.append(child)

            print(f" Node {child.id} Score: {lm_score} | {reasoning[:60]}...")

            if lm_score >= 1.0:
                print(f"\n🎯 PERFECT SCORE FOUND IN ITERATION {i + 1}!")
                return child.state

        # --- 6. BACKPROPAGATION ---
        print("⬆️ [6/6 BACKPROPAGATION]...")
        # Back up the best immediate child's value along the selected path.
        best_child_val = max((c.value for c in curr.children), default=0)
        _backpropagate(curr, best_child_val)

    # Final select: skip nodes without code (the empty root).
    candidates = [node for node in _iter_nodes(root) if node.state]
    if not candidates:
        return "No code generated."

    # Rank by each node's OWN score. `value` also carries descendants' scores,
    # so ranking by it could return an ancestor's weaker code on a tie.
    best = max(candidates, key=lambda node: node.reward)
    return best.state


# --- TEST EXECUTION ---
if __name__ == "__main__":
    # Sample task: shortest substring of S that contains T as a subsequence.
    task = "Given strings S and T, find the shortest substring of S which has T as a subsequence. Return the substring or empty string if none. Use DP or two-pointer scan with backward trace. Input: ('abcdebdde', 'bde')."

    final_code = LATS_Search(task)
    print("\n" + "=" * 60)
    print("FINAL RESULTING CODE:")
    print("=" * 60)
    print(final_code)