-
Notifications
You must be signed in to change notification settings - Fork 35
/
client_forward_refs.py
396 lines (308 loc) · 14.4 KB
/
client_forward_refs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
"""
Plugin that delays imports of Pydantic models in the client module.

Puts all imports of model types used in annotations under the
`typing.TYPE_CHECKING` flag, turning the type annotations of the generated
client's methods into forward references. This greatly improves the import
time of the generated `client` module when there are many Pydantic models.

Because the generated client's methods need the models they use to process
the server response at runtime, imports of those models are also inserted
into the method bodies.

This massively reduces import times for larger projects: loading the client
no longer loads every model up front, and the types used to process the
server response are only imported when the method is called.
"""
import ast
from typing import Dict, List, Optional, Set, Union
from graphql import GraphQLSchema
from ariadne_codegen import Plugin
# Flag name and the module it comes from; together they are used to emit
# `from typing import TYPE_CHECKING` in the generated client module.
TYPE_CHECKING_FLAG: str = "TYPE_CHECKING"
TYPE_CHECKING_MODULE: str = "typing"
class ClientForwardRefsPlugin(Plugin):
    """Only import types when you call an endpoint needing it.

    Annotations on the generated client's methods that reference locally
    imported classes are rewritten to string forward references, and those
    classes are imported only under `if TYPE_CHECKING:`. The class used to
    build each method's result is imported inside that method instead.
    """

    def __init__(self, schema: GraphQLSchema, config_dict: Dict) -> None:
        """Constructor.

        :param schema: The GraphQL schema, forwarded to `Plugin`
        :param config_dict: The plugin configuration, forwarded to `Plugin`
        """
        # Types that should only be imported in a `TYPE_CHECKING` context. This
        # is all the types used as arguments to a method or as a return type,
        # i.e. for type checking.
        self.input_and_return_types: Set[str] = set()
        # Imported classes are classes imported from local (relative) imports.
        # We keep a map between name and module path, including its leading
        # dots (e.g. `.` + module for a level-1 import), so we know how to
        # import them in each method.
        self.imported_classes: Dict[str, str] = {}
        # Classes imported inside a method definition.
        self.imported_in_method: Set[str] = set()
        super().__init__(schema, config_dict)

    def generate_client_module(self, module: ast.Module) -> ast.Module:
        """Update the generated client.

        This parses all current imports to map them to a module path. It then
        traverses all methods of the first class in the module: every
        annotation naming a locally imported class is converted to an
        `ast.Constant` (a forward reference) and that class is imported only
        under `if TYPE_CHECKING`.

        The class used to build each method's result is imported inside that
        method, since that's the only place where it's used at runtime.
        Imports of names that are neither used in converted annotations nor
        imported inside a method stay in the global scope.

        :param module: The ast for the module
        :returns: The result of `Plugin.generate_client_module` on the
            modified module
        """
        self._store_imported_classes(module.body)

        # Find the actual client class so we can grab all input and output
        # types. We also manipulate the ast while we do this.
        client_class_def = next(
            (node for node in module.body if isinstance(node, ast.ClassDef)), None
        )
        if client_class_def is None:
            return super().generate_client_module(module)

        for method_def in client_class_def.body:
            if not isinstance(method_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            self._rewrite_input_args_to_constants(method_def)
            # If the method has a return annotation, update it as well.
            if method_def.returns is not None:
                method_def.returns = self._update_name_to_constant(method_def.returns)
            self._insert_import_statement_in_method(method_def)

        self._update_imports(module)
        return super().generate_client_module(module)

    def _store_imported_classes(self, module_body: List[ast.stmt]) -> None:
        """Fetch and store imported classes.

        Grab all names imported with a relative import (level >= 1) or from a
        module starting with `.` because these are the ones generated by us.
        We store a map between the class and the module it was imported from
        so we can easily import it when needed. This can be in a
        `TYPE_CHECKING` condition or inside a method.

        :param module_body: The body of an `ast.Module`
        """
        for node in module_body:
            if not isinstance(node, ast.ImportFrom):
                continue
            module_path = self._local_module_path(node)
            # We only care about local imports from our generated code.
            if module_path is None:
                continue
            for alias in node.names:
                self.imported_classes[alias.name] = module_path

    @staticmethod
    def _local_module_path(node: ast.ImportFrom) -> Optional[str]:
        """Return the dotted module path of a local import.

        :param node: An `ast.ImportFrom` node
        :returns: The module path including its leading dots (`.foo` for
            `from .foo import X`), or `None` if the import is not relative or
            has no module name (`from . import X`)
        """
        if node.module is None:
            return None
        # `level` is optional on hand-built nodes and may be `None`.
        level = node.level or 0
        if level < 1 and not node.module.startswith("."):
            return None
        return "." * level + node.module

    @staticmethod
    def _build_import_from(module_path: str, names: List[ast.alias]) -> ast.ImportFrom:
        """Build `from <module_path> import <names>`.

        The leading dots of `module_path` are expressed through `level`, so
        the node both unparses and compiles correctly. Passing a module that
        already starts with `.` together with `level=1` would render as
        `from ..module import ...`.

        :param module_path: Module path, possibly with leading dots
        :param names: Aliases to import
        :returns: The new `ast.ImportFrom` node
        """
        module = module_path.lstrip(".")
        return ast.ImportFrom(
            module=module or None,
            names=names,
            level=len(module_path) - len(module),
        )

    def _rewrite_input_args_to_constants(
        self, method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> Union[ast.FunctionDef, ast.AsyncFunctionDef]:
        """Rewrite the argument annotations of a method.

        For any annotation naming a locally imported class, convert that name
        to an `ast.Constant` instead. The actual class will be noted and
        imported in a `TYPE_CHECKING` context. Positional-only, regular,
        keyword-only, `*args` and `**kwargs` parameters are all handled: an
        unconverted annotation would raise `NameError` at definition time once
        its global import is removed.

        :param method_def: Method definition, updated in place
        :returns: The same definition
        """
        if not isinstance(method_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
            return method_def

        arguments = method_def.args
        params = [*arguments.posonlyargs, *arguments.args, *arguments.kwonlyargs]
        # `*args` / `**kwargs` are `None` when the method doesn't declare them.
        params.extend(p for p in (arguments.vararg, arguments.kwarg) if p is not None)
        for param in params:
            if param.annotation is not None:
                param.annotation = self._update_name_to_constant(param.annotation)
        return method_def

    def _insert_import_statement_in_method(
        self, method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Insert import statement in method.

        The last statement of each method may pass its value to a class we've
        generated (`return Cls.attr(...)` or `yield Cls.attr(...)` inside an
        `async for`). Since that class is only needed in the scope of the
        method, import it at the top of the method (after a docstring, if
        any). It will be removed from the global scope.

        :param method_def: The method definition to update
        """
        if not method_def.body:
            return

        # The last statement builds the result; the class called there is
        # what we need to import.
        last_stmt = method_def.body[-1]
        if isinstance(last_stmt, ast.Return):
            call = self._get_call_arg_from_return(last_stmt)
        elif isinstance(last_stmt, ast.AsyncFor):
            call = self._get_call_arg_from_async_for(last_stmt)
        else:
            return

        import_class = None if call is None else self._get_class_from_call(call)
        if import_class is None:
            return

        module_path = self.imported_classes.get(import_class.name)
        if module_path is None:
            # Not one of our locally imported classes (e.g. a method called on
            # a local variable), so there is nothing to import.
            return

        # We add the class to our set of imported in methods - these classes
        # don't need to be imported at all in the global scope.
        self.imported_in_method.add(import_class.name)
        # Keep a docstring as the first statement so it stays a docstring.
        has_docstring = ast.get_docstring(method_def, clean=False) is not None
        method_def.body.insert(
            1 if has_docstring else 0,
            self._build_import_from(module_path, [import_class]),
        )

    def _get_call_arg_from_return(self, return_stmt: ast.Return) -> Optional[ast.Call]:
        """Get the call building the value of a return statement.

        :param return_stmt: The statement used for return
        :returns: The `ast.Call`, or `None` if the value isn't a call
        """
        return self._get_call_from_expr(return_stmt.value)

    def _get_call_arg_from_async_for(
        self, last_stmt: ast.AsyncFor
    ) -> Optional[ast.Call]:
        """Get the call building the value yielded in an `async for`.

        Only the first statement of the loop body is inspected; it must be a
        bare `yield <expr>` expression statement.

        :param last_stmt: The `ast.AsyncFor` statement
        :returns: The `ast.Call`, or `None` if none is found
        """
        first_stmt = last_stmt.body[0] if last_stmt.body else None
        if not isinstance(first_stmt, ast.Expr):
            return None
        if not isinstance(first_stmt.value, ast.Yield):
            return None
        return self._get_call_from_expr(first_stmt.value.value)

    @staticmethod
    def _get_call_from_expr(value: Optional[ast.expr]) -> Optional[ast.Call]:
        """Unwrap the call producing a returned or yielded expression.

        :param value: The expression (`None` for a bare `return` / `yield`)
        :returns: The `ast.Call`, or `None` if `value` isn't a call
        """
        # If it's a call of the class like produced by `ShorterResultsPlugin`
        # we have an attribute access on the call's result.
        if isinstance(value, ast.Attribute) and isinstance(value.value, ast.Call):
            return value.value
        # If not it's just a call statement to the generated class.
        if isinstance(value, ast.Call):
            return value
        return None

    def _get_class_from_call(self, call: ast.Call) -> Optional[ast.alias]:
        """Get the class from an `ast.Call`.

        Only calls of the form `Name.attr(...)` are recognized; `Name` is
        returned as an `ast.alias` ready to be imported.

        :param call: The `ast.Call` arg
        :returns: `ast.alias` or `None`
        """
        func = call.func
        if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
            return ast.alias(name=func.value.id)
        return None

    def _update_imports(self, module: ast.Module) -> None:
        """Update all imports.

        Remove from the local imports every name that is used in a converted
        annotation or imported inside a method, drop imports left without
        names, then import `TYPE_CHECKING` and add the annotation types in an
        `if TYPE_CHECKING` block.

        :param module: The ast for the whole module.
        """
        # Types used in annotations are moved to an `if TYPE_CHECKING` block,
        # and the classes imported inside methods don't need a global import
        # at all. The two sets can differ if you use a plugin such as
        # `ShorterResultsPlugin` that imports a type in the method that is
        # different from the type returned.
        names_to_move = self.input_and_return_types | self.imported_in_method
        if not names_to_move:
            return

        non_empty_imports = self._update_existing_imports(module, names_to_move)
        self._add_forward_ref_imports(module, non_empty_imports)

    def _update_existing_imports(
        self, module: ast.Module, return_types_not_used_as_input: Set[str]
    ) -> List[Union[ast.Import, ast.ImportFrom]]:
        """Update existing imports.

        Remove the given names from local `from` imports and drop every `from`
        import that ends up without names. It's very important that we get
        this right: if we keep any `ImportFrom` that ends up without any
        names, the formatting will not work! It will only remove the empty
        `import from` but not other unused imports.

        The module body is rebuilt as: module docstring (if any), the
        remaining imports in their original order, then every other statement
        in its original order. No non-import statement is dropped.

        :param module: The ast module to update in place
        :param return_types_not_used_as_input: Names whose local imports are
            removed
        :returns: The imports kept, in the order they now appear in the body
        """
        body = list(module.body)
        docstring: List[ast.stmt] = []
        if ast.get_docstring(module, clean=False) is not None:
            docstring, body = body[:1], body[1:]

        non_empty_imports: List[Union[ast.Import, ast.ImportFrom]] = []
        other_statements: List[ast.stmt] = []
        for node in body:
            if isinstance(node, ast.Import):
                non_empty_imports.append(node)
            elif isinstance(node, ast.ImportFrom):
                # Only strip names from our own (local) imports so an
                # unrelated import with a colliding name is left alone.
                if self._local_module_path(node) is not None:
                    node.names = [
                        alias
                        for alias in node.names
                        if alias.name not in return_types_not_used_as_input
                    ]
                if node.names:
                    non_empty_imports.append(node)
            else:
                other_statements.append(node)

        module.body = docstring + non_empty_imports + other_statements
        return non_empty_imports

    def _add_forward_ref_imports(
        self,
        module: ast.Module,
        non_empty_imports: List[Union[ast.Import, ast.ImportFrom]],
    ) -> None:
        """Add forward ref imports.

        Add all the forward ref imports, meaning all the types needed for type
        checking, under an `if TYPE_CHECKING` condition placed right after the
        existing imports and preceded by `from typing import TYPE_CHECKING`.

        Expects `module.body` to start with the (optional) module docstring
        followed by `non_empty_imports`, as left by `_update_existing_imports`.

        :param module: The ast module to update in place
        :param non_empty_imports: The imports at the top of the module body
        """
        if not self.input_and_return_types:
            # An `if` with an empty body is invalid Python, and with no
            # annotation types there is nothing to import for type checking.
            return

        # Group names per module. Sort both so the generated code is identical
        # between runs (iterating a `set` of `str` depends on hash
        # randomization).
        names_by_module: Dict[str, List[str]] = {}
        for cls in sorted(self.input_and_return_types):
            names_by_module.setdefault(self.imported_classes[cls], []).append(cls)
        type_checking_imports: List[ast.stmt] = [
            self._build_import_from(module_path, [ast.alias(name=n) for n in names])
            for module_path, names in sorted(names_by_module.items())
        ]

        import_if_type_checking = ast.If(
            test=ast.Name(id=TYPE_CHECKING_FLAG, ctx=ast.Load()),
            body=type_checking_imports,
            orelse=[],
        )
        # Import `TYPE_CHECKING` (an absolute import, e.g. from `typing`).
        import_type_checking = self._build_import_from(
            TYPE_CHECKING_MODULE, [ast.alias(name=TYPE_CHECKING_FLAG)]
        )

        has_docstring = ast.get_docstring(module, clean=False) is not None
        insert_at = len(non_empty_imports) + (1 if has_docstring else 0)
        module.body[insert_at:insert_at] = [
            import_type_checking,
            import_if_type_checking,
        ]

    def _update_name_to_constant(self, node: ast.expr) -> ast.expr:
        """Update an annotation to use forward references.

        Every `ast.Name` that refers to a locally imported class is converted
        to an `ast.Constant` holding its name and recorded in
        `input_and_return_types`. This applies whether the name is the whole
        annotation, a subscript slice (e.g. `Optional[X]`) or an element of a
        tuple slice (e.g. `Union[X, Y]`). We only need the type for type
        checking and can avoid importing it in the global scope. Other names
        and node types are returned unchanged.

        :param node: The ast node used as annotation
        :returns: The converted constant, or the (possibly updated) node
        """
        if isinstance(node, ast.Name):
            if node.id not in self.imported_classes:
                return node
            self.input_and_return_types.add(node.id)
            return ast.Constant(value=node.id)
        if isinstance(node, ast.Subscript):
            node.slice = self._update_name_to_constant(node.slice)
        elif isinstance(node, ast.Tuple):
            node.elts = [self._update_name_to_constant(elt) for elt in node.elts]
        return node