#!/usr/bin/env python3
"""Check and optionally fix test order in parser_test.zig to match upstream."""

import os
import re
import sys

# File whose test order is checked (and rewritten with --fix). Relative
# paths resolve against the current working directory, not this script's
# location.
OURS = "parser_test.zig"
# Reference copy whose test order is treated as canonical; expected in a
# `zig` checkout next to the current directory.
UPSTREAM = "../zig/lib/std/zig/parser_test.zig"


def extract_test_names(path):
    """Return the names of all top-level tests in *path*, in file order.

    A test is a line that starts at column 0 with ``test "<name>" {``;
    indented occurrences (for example inside multiline string literals in a
    test body) are not matched. The file is decoded as UTF-8, Zig's source
    encoding, instead of the platform's locale default.
    """
    with open(path, encoding="utf-8") as f:
        return re.findall(r'^test "(.+?)" \{', f.read(), re.M)


def extract_test_blocks(path):
    """Split a Zig test file into header, test blocks, and footer.

    A test block starts at a column-0 ``test "<name>" {`` line and ends at
    the next line consisting solely of ``}`` at column 0 (trailing
    whitespace ignored), which is where zig fmt puts the closing brace of a
    top-level declaration. Braces are deliberately NOT counted: test bodies
    embed Zig snippets in multiline string literals (lines starting with a
    double backslash) whose braces need not balance. A one-line
    ``test "x" {}`` is closed on its own line.

    Non-test lines between two tests (comments, declarations) become the
    leading part of the FOLLOWING test, so a comment describing a test moves
    with it. A run of ``//`` comment lines sitting directly on top of the
    first test is treated the same way. Blank lines that merely separate a
    test from what precedes it are dropped; callers insert one blank line
    between chunks when writing the file back.

    Returns:
        ``(header, blocks, footer)``: ``header`` is the text before the
        first test without trailing blank lines; ``blocks`` is a list of
        ``(name, content)`` tuples in file order; ``footer`` is the text
        after the last test's closing brace, verbatim. A file without tests
        is returned entirely as ``header``.
    """
    start_re = re.compile(r'^test "(.+?)" \{')

    with open(path, encoding="utf-8") as f:
        lines = f.readlines()

    header = None        # lines before the first test; None until one is seen
    blocks = []
    pending = []         # non-test lines seen since the last test closed
    current_name = None  # name of the test being collected, if any
    current_lines = []

    for line in lines:
        if current_name is not None:
            current_lines.append(line)
            # Nothing inside a zig-fmt'd test body sits at column 0, so a
            # bare "}" there is the test's own closing brace.
            if line.rstrip() == "}":
                blocks.append((current_name, "".join(current_lines)))
                current_name = None
            continue

        m = start_re.match(line)
        if m is None:
            pending.append(line)
            continue

        if header is None:
            # A comment run directly above the first test belongs to it;
            # the rest (minus the blank separator) is the file header.
            split = len(pending)
            while split and pending[split - 1].startswith("//") \
                    and not pending[split - 1].startswith("//!"):
                split -= 1
            leading = pending[split:]
            header = pending[:split]
            while header and not header[-1].strip():
                header.pop()
        else:
            # Drop the blank separator after the previous test but keep any
            # real content so it travels with this test.
            first = next(
                (i for i, text in enumerate(pending) if text.strip()),
                len(pending),
            )
            leading = pending[first:]
        pending = []

        current_name = m.group(1)
        current_lines = leading + [line]
        # `test "x" {}` opens and closes on the same line.
        if line[m.end():].strip() == "}":
            blocks.append((current_name, "".join(current_lines)))
            current_name = None

    if current_name is not None:
        # File ended inside a test: keep its text rather than dropping it.
        blocks.append((current_name, "".join(current_lines)))

    if header is None:
        # No tests at all: the whole file is header.
        return "".join(pending), [], ""

    return "".join(header), blocks, "".join(pending)


def _trim_blank_lines(text):
    """Return *text* without leading/trailing blank lines or a final newline.

    Splits on newline characters only (not ``str.splitlines``), so form
    feeds or other exotic separators inside the text are never treated as
    line breaks when the pieces are joined back together.
    """
    lines = text.split("\n")
    start = 0
    while start < len(lines) and not lines[start].strip():
        start += 1
    end = len(lines)
    while end > start and not lines[end - 1].strip():
        end -= 1
    return "\n".join(lines[start:end])


def _find_out_of_order(names, upstream_pos):
    """Return the names that come after a later-upstream name in *names*.

    *names* is scanned in order against the highest upstream position seen
    so far; names absent from *upstream_pos* are ignored. An empty result
    means the known names already follow upstream order.
    """
    out_of_order = []
    highest = -1
    for name in names:
        pos = upstream_pos.get(name)
        if pos is None:
            continue
        if pos < highest:
            out_of_order.append(name)
        else:
            highest = pos
    return out_of_order


def main():
    """Check OURS against UPSTREAM test order; with ``--fix``, rewrite OURS.

    Tests that do not exist upstream are ignored by the check and, when
    fixing, kept in their relative order after all upstream-known tests.

    Returns:
        Process exit status: 0 if the order was already correct or has been
        fixed; 1 if tests are out of order without ``--fix``, an input file
        is missing, or OURS could not be split into test blocks safely.
    """
    fix = "--fix" in sys.argv

    try:
        upstream_order = extract_test_names(UPSTREAM)
        our_names = extract_test_names(OURS)
    except FileNotFoundError as err:
        print(f"ERROR: {err}", file=sys.stderr)
        return 1

    upstream_pos = {name: i for i, name in enumerate(upstream_order)}

    out_of_order = _find_out_of_order(our_names, upstream_pos)
    if not out_of_order:
        print(f"OK: {len(our_names)} tests in correct order")
        return 0

    print(f"WARN: {len(out_of_order)} tests out of order:")
    for name in out_of_order[:10]:
        print(f"  - {name}")
    if len(out_of_order) > 10:
        print(f"  ... and {len(out_of_order) - 10} more")

    if not fix:
        print("\nRun with --fix to reorder")
        return 1

    header, blocks, footer = extract_test_blocks(OURS)

    # The block splitter must find exactly the tests the plain name scan
    # found; a mismatch means tests were glued together or missed, and
    # writing the result back could scramble the file.
    if [name for name, _ in blocks] != our_names:
        print(
            f"ERROR: could not reliably split {OURS} into test blocks; "
            "not rewriting",
            file=sys.stderr,
        )
        return 1

    # Stable sort: upstream tests by upstream position, then tests unknown
    # upstream in their current relative order.
    unknown = len(upstream_order)
    ordered = sorted(blocks, key=lambda block: upstream_pos.get(block[0], unknown))

    # Exactly one blank line between header, each test, and footer, and a
    # single final newline, regardless of how the pieces were delimited.
    chunks = [_trim_blank_lines(header)]
    chunks.extend(_trim_blank_lines(content) for _, content in ordered)
    chunks.append(_trim_blank_lines(footer))
    text = "\n\n".join(chunk for chunk in chunks if chunk) + "\n"

    # Write a sibling temp file and swap it in, so a failure mid-write
    # cannot leave OURS truncated.
    tmp_path = OURS + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp_path, OURS)

    print(f"Fixed: {len(ordered)} tests written in upstream order")
    return 0


if __name__ == "__main__":
    # main() returns the exit status; raising SystemExit hands it to the OS.
    raise SystemExit(main())
