-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy path__init__.py
executable file
·196 lines (163 loc) · 6.67 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import sys
from xml.dom import minidom
from tree_sitter import Language, Parser
# Configuration in which every rule collection is empty.
# "aliased" is a mapping (selector -> replacement type); the others are lists of selectors.
EMPTY_CONFIG = {
    "flattened": [],
    "aliased": {},
    "ignored": [],
    "label_ignored": [],
}

# Maximum number of characters `sanitize_label` keeps before truncating a label.
MAX_LABEL_SIZE = 75
def eprint(*args, **kwargs):
    """
    Print to the standard error stream.

    Takes the same positional and keyword arguments as `print`, except `file`,
    which is always bound to `sys.stderr` here.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)
def init_parsers(script_dir):
    """
    Build the shared parser library (if necessary) and load every supported language from it.

    `script_dir` is the directory holding the grammar checkouts and the `build` folder.
    Returns a dict mapping a language key (e.g. "python", "csharp") to its `Language` object.
    """
    library_path = script_dir + "/build/languages.so"

    # Grammar source directories, relative to `script_dir`.
    grammar_subdirs = (
        "/tree-sitter-c",
        "/tree-sitter-c-sharp",
        "/tree-sitter-cmake",
        "/tree-sitter-go",
        "/tree-sitter-java",
        "/tree-sitter-javascript",
        "/tree-sitter-ocaml/grammars/ocaml",
        "/tree-sitter-php/php",
        "/tree-sitter-python",
        "/tree-sitter-r",
        "/tree-sitter-ruby",
        "/tree-sitter-rust",
        "/tree-sitter-typescript/typescript",
        "/tree-sitter-kotlin",
    )

    # A truthy result is reported as a fresh compilation, a falsy one as a reuse.
    was_compiled = Language.build_library(
        library_path,
        [script_dir + subdir for subdir in grammar_subdirs],
    )
    if was_compiled:
        eprint("Compiled dynamic library of parsers.")
    else:
        eprint("Reusing dynamic library of parsers.")

    # Key exposed to callers -> language name inside the compiled library.
    library_names = {
        "c": "c",
        "csharp": "c_sharp",
        "cmake": "cmake",
        "go": "go",
        "java": "java",
        "javascript": "javascript",
        "ocaml": "ocaml",
        "php": "php",
        "python": "python",
        "r": "r",
        "ruby": "ruby",
        "rust": "rust",
        "typescript": "typescript",
        "kotlin": "kotlin",
    }
    return {
        key: Language(library_path, library_name)
        for key, library_name in library_names.items()
    }
def parse_and_translate(parser_lang, config, input: bytes):
    """
    Parse source bytes and convert the resulting AST into a GumTree XML document.

    `parser_lang` is the language handed to the parser, `config` holds the
    selector rules used during translation, and `input` is the raw file content.
    Returns the populated `minidom.Document`.
    """
    line_starts = create_newline_offsets(input)

    ts_parser = Parser()
    ts_parser.set_language(parser_lang)
    root = ts_parser.parse(input).root_node

    document = minidom.Document()
    xml_root = to_xml_node(document, root, config, line_starts)
    document.appendChild(xml_root)

    # Fill in the descendants below the root element.
    process(document, root, xml_root, config, line_starts)
    return document
def create_newline_offsets(input: bytes):
    """
    Compute the starting offset of every line in a file.

    The content is decoded as UTF-8, so offsets are counted in characters.
    Entry 0 is always 0; each further entry is the position just after a newline.
    The list translates `(line, column)` into `pos` via

        pos = offsets[line] + column

    NOTE(review): offsets are character-based; confirm the columns added to
    them by callers use the same unit for non-ASCII content.
    """
    text = input.decode("utf-8")
    line_starts = [0]
    line_starts.extend(
        index for index, char in enumerate(text, start=1) if char == "\n"
    )
    return line_starts
def get_selector(node, config, action):
    """
    Look up the first selector configured for `action` that matches `node`.

    Selectors are tried in the iteration order of `config[action]`.
    Returns the matching selector, or an empty string when none matches.
    """
    candidates = (selector for selector in config[action] if match(selector, node))
    return next(candidates, "")
def match(selector, node):
    """
    Tell whether `node` satisfies `selector`.

    A selector is a space-separated list of node types. It is compared,
    position by position, with the types collected for the node and its ancestors.
    """
    wanted = selector.split(' ')
    found = collect_ancestor_types(node, len(wanted))

    # Not enough ancestors to cover the whole selector.
    if len(found) < len(wanted):
        return False

    return all(actual == expected for actual, expected in zip(found, wanted))
def collect_ancestor_types(node, max_level):
    """
    Collect the types of a node and its ancestors, up to a given maximum level.

    Starting at `node` and walking up through `parent`, at most `max_level`
    types are gathered. The walk stops early when the root (a node whose
    `parent` is None) has been recorded.

    Returns the types ordered from the outermost collected ancestor to `node`
    itself, so the result lines up with a selector such as "parent child".
    """
    ancestor_types = []
    current = node
    for _ in range(max_level):
        ancestor_types.append(current.type)
        current = current.parent
        if current is None:
            # Root reached: stop walking, but still fall through to the
            # reversal below so the ordering is the same on every path.
            break
    ancestor_types.reverse()
    return ancestor_types
def process(doc, node, xml_node, config, newline_offsets):
    """
    Recursively add the children of an AST node beneath its XML element.

    Nodes matching a 'flattened' selector contribute no children; children
    matching an 'ignored' selector are left out together with their subtrees.
    """
    if get_selector(node, config, 'flattened'):
        return

    for child in node.children:
        if get_selector(child, config, 'ignored'):
            continue
        child_element = to_xml_node(doc, child, config, newline_offsets)
        xml_node.appendChild(child_element)
        process(doc, child, child_element, config, newline_offsets)
def to_xml_node(doc, node, config, newline_offsets):
    """
    Build the `tree` XML element describing a single AST node.

    The element carries `type` (possibly replaced through an 'aliased'
    selector), `pos` and `length`. A `label` holding the node text is added
    for childless nodes not matched by 'label_ignored', and for nodes matched
    by 'flattened'.
    """
    element = doc.createElement("tree")

    alias_selector = get_selector(node, config, "aliased")
    if alias_selector:
        node_type = config["aliased"][alias_selector]
    else:
        node_type = node.type
    element.setAttribute("type", node_type)

    # Convert (line, column) points to flat positions using the line-start table.
    start = node.start_point
    end = node.end_point
    start_pos = newline_offsets[start[0]] + start[1]
    end_pos = newline_offsets[end[0]] + end[1]
    element.setAttribute("pos", str(start_pos))
    element.setAttribute("length", str(end_pos - start_pos))

    labelled_leaf = node.child_count == 0 and not get_selector(node, config, 'label_ignored')
    if labelled_leaf or get_selector(node, config, 'flattened'):
        element.setAttribute("label", node.text.decode("utf8"))
    return element
def pretty_print_ast(elm, out, level=0):
    """
    Write a human-readable rendering of an XML AST element and its descendants to `out`.

    Each element is written on its own line, indented by `level`, as its type
    (bold), its sanitized label if present (blue) and its `[start,end]` range.
    """
    pieces = [f'\033[1m{elm.getAttribute("type")}\033[0m']
    if elm.hasAttribute("label"):
        pieces.append(f' \033[94m{sanitize_label(elm.getAttribute("label"))}\033[0m')

    start = int(elm.getAttribute("pos"))
    end = start + int(elm.getAttribute("length"))
    pieces.append(f" [{start},{end}]")

    # Only the top-level element is written without a preceding line break.
    prefix = ("" if level == 0 else "\n") + " " * level
    out.write(prefix + "".join(pieces))

    for child in elm.childNodes:
        pretty_print_ast(child, out, level + 1)
def sanitize_label(raw_label):
    """
    Make a label fit on a single line of output.

    Newlines and tabs are dropped. If the result is longer than
    `MAX_LABEL_SIZE` characters, it is cut to that length and ".." is appended.
    """
    cleaned = raw_label.replace("\n", "").replace("\t", "")
    if len(cleaned) > MAX_LABEL_SIZE:
        return cleaned[:MAX_LABEL_SIZE] + ".."
    return cleaned