1
1
import json
2
2
import sys
3
+ from enum import Enum
3
4
from pathlib import Path
4
- from typing import Any , Dict , List , Optional
5
+ from typing import Any , Dict , List , Optional , Type , Union
5
6
6
7
import toml
7
8
import typer
8
- from datamodel_code_generator import Error , PythonVersion , chdir
9
+ from datamodel_code_generator import Error , chdir
9
10
from datamodel_code_generator .__main__ import Config , Exit
11
+ from datamodel_code_generator .format import CodeFormatter as BlackCodeFormatter
10
12
from datamodel_code_generator .imports import Import , Imports
11
13
from datamodel_code_generator .reference import Reference
12
14
from datamodel_code_generator .types import DataType
24
26
MODEL_PATH : Path = Path ('models.py' )
25
27
26
28
29
+ class Formatters (str , Enum ):
30
+ YAPF = 'yapf'
31
+ BLACK = 'black'
32
+
33
+
34
+ formaters = {
35
+ Formatters .YAPF .value : YapfCodeFormatter ,
36
+ Formatters .BLACK .value : BlackCodeFormatter ,
37
+ }
38
+
39
+
27
40
@app .command ()
28
41
def main (
29
42
input_file : typer .FileText = typer .Option (..., '--input' , '-i' ), # noqa: B008
@@ -42,6 +55,7 @@ def main(
42
55
'-a' ,
43
56
help = 'Base class for client class' ,
44
57
),
58
+ formater : Optional [Formatters ] = typer .Option (Formatters .YAPF .value , case_sensitive = False ), # noqa: B008
45
59
skip_deprecated : Optional [bool ] = typer .Option ( # noqa: B008
46
60
True ,
47
61
'--skip-deprecated' ,
@@ -75,6 +89,7 @@ def main(
75
89
prefix_api_cls ,
76
90
base_apiclient_cls ,
77
91
skip_deprecated ,
92
+ formaters .get (formater .value ) or YapfCodeFormatter ,
78
93
)
79
94
80
95
@@ -114,6 +129,7 @@ def generate_code(
114
129
prefix_api_cls : Optional [str ],
115
130
base_apiclient_cls : Optional [str ],
116
131
skip_deprecated : Optional [bool ],
132
+ code_formatter_cls : Type [Union [YapfCodeFormatter , BlackCodeFormatter ]],
117
133
) -> None :
118
134
output_dir .mkdir (parents = True , exist_ok = True )
119
135
template_dir = template_dir or BUILTIN_TEMPLATE_DIR
@@ -177,12 +193,15 @@ def generate_code(
177
193
if reference :
178
194
imports .append (data_type .all_imports )
179
195
imports .append (Import .from_full_path (f'.{ MODEL_PATH .stem } .{ reference .name } ' ))
196
+
180
197
for from_ , imports_ in parser .imports_for_endpoints .items ():
181
198
imports [from_ ].update (imports_ )
199
+
182
200
results : Dict [Path , str ] = {}
183
- # TODO: Choose formater from cli
184
- code_formatter = YapfCodeFormatter ( PythonVersion . PY_38 , Path ().resolve ())
201
+
202
+ code_formatter = code_formatter_cls ( data_config . target_python_version , Path ().resolve ())
185
203
sorted_operations : List [Operation ] = sorted (parser .operations .values (), key = lambda m : m .path )
204
+
186
205
for target in template_dir .rglob ('*' ):
187
206
relative_path = target .relative_to (template_dir )
188
207
result = environment .get_template (str (relative_path )).render (
0 commit comments