forked from codegen-sh/codegen
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvalidation.py
More file actions
151 lines (122 loc) · 5.87 KB
/
Copy pathvalidation.py
File metadata and controls
151 lines (122 loc) · 5.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from __future__ import annotations
import functools
import socket
from collections import Counter, defaultdict
from enum import StrEnum
from typing import TYPE_CHECKING
from tabulate import tabulate
from codegen.sdk.enums import NodeType
from codegen.sdk.utils import truncate_line
from codegen.shared.logging.get_logger import get_logger
# Module-level logger used by the validation helpers in this file.
logger = get_logger(__name__)

if TYPE_CHECKING:
    # Imported only for type annotations; TYPE_CHECKING is False at runtime, so these never execute.
    from rustworkx import PyDiGraph

    from codegen.sdk.core.codebase import CodebaseType
class PostInitValidationStatus(StrEnum):
    """Result status returned by ``post_init_validation`` for a freshly built codebase graph."""

    # The graph contains no nodes at all.
    NO_NODES = "NO_NODES"
    # NOTE(review): not returned by post_init_validation in this file; presumably produced elsewhere — confirm.
    NO_EDGES = "NO_EDGES"
    # The number of parsed files differs from the number of matching files found on disk.
    MISSING_FILES = "MISSING_FILES"
    # Too small a fraction of imports resolved to a symbol that is not an external module.
    LOW_IMPORT_RESOLUTION_RATE = "LOW_IMPORT_RESOLUTION_RATE"
    # Every check passed.
    SUCCESS = "SUCCESS"
def post_init_validation(codebase: CodebaseType, *, min_import_resolution_rate: float = 0.2) -> PostInitValidationStatus:
    """Verify that the graph built by ``codebase._init_graph`` is valid.

    Checks run in order and the first failing one determines the result:

    1. the graph has at least one node;
    2. the number of parsed files equals the number of files found on disk for the
       first project's subdirectories, filtered by the codebase's extensions and the
       global ignore list;
    3. at least ``min_import_resolution_rate`` of all imports resolve to a symbol whose
       node type is not ``NodeType.EXTERNAL`` (skipped when there are no imports).

    Args:
        codebase: The freshly initialised codebase to validate.
        min_import_resolution_rate: Minimum fraction (0-1) of imports that must resolve
            to a non-external symbol. Defaults to 0.2, the previously hard-coded value.

    Returns:
        The status of the first failing check, or ``PostInitValidationStatus.SUCCESS``.
    """
    from codegen.sdk.codebase.codebase_context import GLOBAL_FILE_IGNORE_LIST

    # Verify the graph has nodes
    if len(codebase.ctx.nodes) == 0:
        return PostInitValidationStatus.NO_NODES

    # Verify the graph has the same number of files as there are in the repo.
    # The files on disk are counted lazily instead of materialising a list just for its length.
    num_parsed_files = len(codebase.files)
    num_files_on_disk = sum(
        1
        for _ in codebase.op.iter_files(
            codebase.ctx.projects[0].subdirectories,
            extensions=codebase.ctx.extensions,
            ignore_list=GLOBAL_FILE_IGNORE_LIST,
        )
    )
    if num_parsed_files != num_files_on_disk:
        return PostInitValidationStatus.MISSING_FILES

    # Verify import resolution. ``codebase.imports`` is read once and reused, since it
    # may be a computed property that would otherwise be re-evaluated for every use.
    imports = codebase.imports
    if len(imports) > 0:
        num_resolved_imports = sum(1 for imp in imports if imp.imported_symbol and imp.imported_symbol.node_type != NodeType.EXTERNAL)
        resolution_rate = num_resolved_imports / len(imports)
        if resolution_rate < min_import_resolution_rate:
            logger.info(f"Codebase {codebase.repo_path} has {resolution_rate} < {min_import_resolution_rate} resolved imports")
            return PostInitValidationStatus.LOW_IMPORT_RESOLUTION_RATE
    return PostInitValidationStatus.SUCCESS
def post_reset_validation(init_nodes, nodes, init_edges, edges, repo_name: str, subdirectories: list[str] | None) -> None:
    """Compare the graph after a reset with its initial snapshot and report mismatches.

    Only the number of *distinct* nodes and the number of *distinct* edges are compared.
    When either count differs, a diagnostic report is built and passed to ``log_or_throw``
    together with a one-line summary.

    Args:
        init_nodes: Nodes captured before the reset.
        nodes: Nodes present after the reset.
        init_edges: Edges captured before the reset.
        edges: Edges present after the reset.
        repo_name: Repository name, included in the summary line.
        subdirectories: Subdirectories the codebase was restricted to, included in the summary line.
    """
    logger.info("Verifying graph state and alerting if necessary")
    host = socket.gethostname()

    if len(set(nodes)) != len(set(init_nodes)):
        summary = f"Reset graph: Nodes do not match for {repo_name} for subdirectories {subdirectories}. Hostname: {host}"
        log_or_throw(summary, get_nodes_error(init_nodes, nodes))

    if len(set(edges)) != len(set(init_edges)):
        summary = f"Reset graph: Edges do not match for {repo_name} for subdirectories {subdirectories}. Hostname: {host}"
        log_or_throw(summary, get_edges_error(edges, init_edges))
def post_sync_validation(codebase: CodebaseType) -> bool:
    """Check that ``codebase.sync`` left the graph consistent with ``ctx.old_graph``.

    The current graph is compared with the old graph by number of distinct nodes and
    by number of distinct edges.

    Args:
        codebase: A codebase with no recorded or pending syncs, no uncommitted
            transactions, and ``track_graph`` enabled.

    Returns:
        ``True`` if both distinct counts match the old graph, ``False`` otherwise.

    Raises:
        NotImplementedError: If syncs are recorded or pending, transactions are waiting
            to be committed, or graph tracking is disabled.
    """
    ctx = codebase.ctx
    # Only supported on a codebase with no queued or uncommitted work.
    if len(ctx.all_syncs) or len(ctx.pending_syncs) or len(ctx.transaction_manager.to_commit()):
        msg = "Can only be called on a reset codebase"
        raise NotImplementedError(msg)
    if not ctx.config.codebase.track_graph:
        msg = "Can only be called with track_graph=true"
        raise NotImplementedError(msg)

    old_graph = ctx.old_graph
    if len(set(old_graph.nodes())) != len(set(ctx.nodes)):
        return False
    return len(set(old_graph.weighted_edge_list())) == len(set(ctx.edges))
def log_or_throw(message, thread_message: str):
    """Log ``message`` as an error and raise unless running on the ``modal`` host.

    Args:
        message: Short summary; always logged at error level.
        thread_message: Detailed report, appended to ``message`` in the raised exception.

    Raises:
        Exception: Whenever ``socket.gethostname()`` is not exactly ``"modal"``.
    """
    running_on_modal = socket.gethostname() == "modal"
    logger.error(message)
    if running_on_modal:
        # On modal only the summary is logged; the detailed report is not emitted.
        return
    raise Exception(f"{message}\n{thread_message}")
def _edge_rows(edges):
    """Return a lazy iterable of table rows, one per edge, with every field truncated to 50 characters."""
    truncate = functools.partial(truncate_line, max_chars=50)
    return (map(truncate, edge) for edge in edges)


def get_edges_error(edges, init_edges):
    """Build a human-readable report of how ``edges`` differs from ``init_edges``.

    Args:
        edges: Edges observed after the operation under test, as ``(start, end, data)`` tuples.
        init_edges: Baseline edges, as ``(start, end, data)`` tuples.

    Returns:
        A possibly empty string containing, in order: a table of extra edges, a table of
        missing edges, one line per extra/missing pair sharing the same endpoints, and,
        when ``edges`` itself contains duplicates, the five most common edges with counts.
    """
    set_edges = set(edges)
    set_init_edges = set(init_edges)
    missing_edges = set_init_edges - set_edges
    extra_edges = set_edges - set_init_edges
    message = ""
    if extra_edges:
        extras = tabulate(_edge_rows(extra_edges), ["Start", "End", "Edge"], maxcolwidths=50)
        message += f"""
Extra edges
```
{extras}
```
"""
    if missing_edges:
        missing = tabulate(_edge_rows(missing_edges), ["Start", "End", "Edge"], maxcolwidths=50)
        message += f"""
Missing edges
```
{missing}
```
"""
    # An extra edge and a missing edge with identical endpoints differ only in their edge
    # data, so they are reported as a possible match.
    missing_by_endpoints = defaultdict(list)
    for u, v, data in missing_edges:
        missing_by_endpoints[u, v].append(data)
    for u, v, data in extra_edges:
        for match in missing_by_endpoints.get((u, v), ()):
            message += f"Possible match from {u} to {v}: {match} -> {data}\n"
    # Compare against the de-duplicated ``edges`` so this section only appears when
    # ``edges`` really contains duplicates.
    if len(edges) != len(set_edges):
        truncate = functools.partial(truncate_line, max_chars=50)
        message += f"{len(edges) - len(set_edges)} edges duplicated from {len(init_edges) - len(set_init_edges)}. Printing out up to 5 edges\n"
        message += tabulate(((*map(truncate, edge), count) for edge, count in Counter(edges).most_common(5)), ["Start", "End", "Edge", "Count"], maxcolwidths=50)
    return message
def get_nodes_error(init_nodes, nodes):
    """Build a human-readable report of how ``nodes`` differs from ``init_nodes``.

    Args:
        init_nodes: Baseline nodes.
        nodes: Nodes observed after the operation under test.

    Returns:
        A string listing the extra and missing nodes (as Python set reprs), followed by one
        line per extra node that is an ``ExternalModule``, listing ``(source node, edge)``
        pairs for that module's incoming edges.
    """
    set_nodes = set(nodes)
    set_init_nodes = set(init_nodes)
    extra_nodes = set_nodes - set_init_nodes
    missing_nodes = set_init_nodes - set_nodes
    message = f"""
Extra nodes
```
{extra_nodes}
```
Missing nodes
```
{missing_nodes}
```
"""
    if extra_nodes:
        # Imported locally rather than at module level, presumably to avoid an import cycle — confirm.
        from codegen.sdk.core.external_module import ExternalModule

        for node in extra_nodes:
            if isinstance(node, ExternalModule):
                dependencies = [(node.ctx.get_node(source), edge) for source, _, edge in node.ctx.in_edges(node.node_id)]
                # Newline-terminate each entry so multiple modules don't run together.
                message += f"External Module persisted with following dependencies: {dependencies}\n"
    return message
def get_edges(graph: PyDiGraph):
    """Return every edge of ``graph`` as a ``(start_data, end_data, weight)`` tuple.

    Node indices from ``graph.weighted_edge_list()`` are replaced by the payloads
    returned from ``graph.get_node_data``; order follows ``weighted_edge_list()``.
    """
    return [(graph.get_node_data(src), graph.get_node_data(dst), weight) for src, dst, weight in graph.weighted_edge_list()]