-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheck_extension.py
197 lines (166 loc) · 6.81 KB
/
check_extension.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import json
import multiprocessing
import os
import subprocess
import sys
import tempfile
from functools import partial
from itertools import zip_longest
from pathlib import Path
def load_template(template_path):
    """Return the entire prompt template stored at *template_path*.

    The file is decoded as UTF-8 and read in text mode, so Windows-style
    line endings come back as ``\\n``.
    """
    with open(template_path, encoding='utf-8') as handle:
        text = handle.read()
    return text
def get_ai_classification(software, cve_id, description, gen_struct_path, template):
    """Ask the AI helper script whether a CVE affects core, an extension, or related software.

    The prompt is built by filling *template* with ``software``, ``cve_id``
    and ``description`` via ``str.format``. The helper at *gen_struct_path*
    is then run with the current Python interpreter as::

        <python> gen_struct_path <prompt.txt> <output.json> <schema.json>

    Its JSON output, expected to follow the schema below, is then read back.
    All three files live in a private temporary directory, which is always
    removed afterwards.

    Args:
        software: Software name substituted into the template.
        cve_id: CVE identifier substituted into the template (also logged).
        description: CVE description substituted into the template.
        gen_struct_path: Path of the helper script (``str`` or ``Path``).
        template: ``str.format`` template using the ``{software}``,
            ``{cve_id}`` and ``{description}`` placeholders.

    Returns:
        ``(category, reason)``, where *category* is the lower-cased
        ``affected_category`` reported by the helper. If the helper honours
        the schema, this is one of ``"core"``, ``"extension"``, ``"related"``
        or ``"notsure"``. Returns ``("unknown", "Error during classification")``
        when the helper fails or its output cannot be read or parsed.

    Raises:
        KeyError, IndexError, ValueError: propagated from ``str.format`` when
            *template* contains placeholders other than the three above. A
            broken template is a programming error, not a per-CVE failure.
    """
    # JSON schema that constrains the helper's structured output.
    schema = {
        "type": "object",
        "properties": {
            "affected_category": {
                "type": "string",
                "enum": ["Core", "Extension", "Related", "NotSure"],
                "description": "Whether the CVE affects an extension/plugin/module, the core software, or a separate related software"
            },
            "reason": {
                "type": "string",
                "description": "Explanation for the classification"
            }
        },
        "required": ["affected_category", "reason"],
        "additionalProperties": False
    }

    # Render before touching the filesystem, so a bad template cannot leave
    # stray temporary files behind.
    prompt = template.format(
        software=software,
        cve_id=cve_id,
        description=description,
    )
    print(f"Processing CVE: {cve_id}")

    # TemporaryDirectory removes every file in it on exit, including on
    # early return or exception. The original approach leaked files when a
    # later temp-file creation failed.
    with tempfile.TemporaryDirectory() as work_dir:
        work = Path(work_dir)
        input_path = work / 'prompt.txt'
        schema_path = work / 'schema.json'
        output_path = work / 'output.json'

        # Write UTF-8 explicitly, like every other file this script touches.
        # The locale default (e.g. cp1252 on Windows) can fail on non-ASCII
        # CVE text.
        # NOTE(review): assumes gen_struct.py reads its inputs as UTF-8 — confirm.
        input_path.write_text(prompt, encoding='utf-8')
        schema_path.write_text(json.dumps(schema), encoding='utf-8')
        # The helper has always been handed an existing, empty output file.
        output_path.touch()

        try:
            # sys.executable runs the helper under this same interpreter and
            # virtualenv, not whatever "python" happens to be on PATH.
            subprocess.run([
                sys.executable or 'python', str(gen_struct_path),
                str(input_path), str(output_path), str(schema_path),
            ], check=True)
            result = json.loads(output_path.read_text(encoding='utf-8'))
            return result["affected_category"].lower(), result["reason"]
        except Exception as e:
            # Best effort per CVE: the caller drops "unknown" results, so
            # this CVE stays unprocessed and is attempted again next run.
            print(f"Error during AI classification: {e}")
            return "unknown", "Error during classification"
def process_cve_batch(args):
    """Classify every CVE in one work item and keep only the successful ones.

    *args* is a single tuple ``(software, template, gen_struct_path,
    cve_batch)``, so the function can be handed directly to ``Pool.map``.
    ``None`` entries in ``cve_batch`` are skipped. Every other entry is
    indexed by its ``'cve_id'`` and ``'description'`` keys.

    Returns a list of dicts with the keys ``cve_id``, ``description``,
    ``affected_category`` and ``reason``. CVEs whose classification came
    back as ``'unknown'`` are left out.
    """
    software, template, gen_struct_path, cve_batch = args
    present = (item for item in cve_batch if item is not None)
    classified = []
    for item in present:
        the_id, text = item['cve_id'], item['description']
        category, why = get_ai_classification(
            software, the_id, text, gen_struct_path, template
        )
        if category == 'unknown':
            # Failed classification: omit it so it can be retried later.
            continue
        classified.append(dict(
            cve_id=the_id,
            description=text,
            affected_category=category,
            reason=why,
        ))
    return classified
def grouper(iterable, n):
    """Split *iterable* into consecutive tuples of length *n*, lazily.

    When the input length is not a multiple of *n*, the final tuple is
    padded with ``None`` (the ``zip_longest`` default fill value).
    """
    source = iter(iterable)
    # Passing the *same* iterator n times makes zip_longest pull n
    # consecutive items into each output tuple.
    return zip_longest(*(source for _ in range(n)))
def _write_json_atomic(path, data):
    """Write *data* as indented JSON to *path* atomically.

    The JSON is first written to a temporary file in the same directory,
    which is then moved over *path* with ``os.replace``. An interruption
    mid-write (Ctrl-C, crash, serialization error) therefore leaves the
    previous file intact, rather than a truncated one that would make the
    next resume fail in ``json.load``.
    """
    path = Path(path)
    fd, tmp_name = tempfile.mkstemp(
        dir=path.parent, prefix=f".{path.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_name, path)
    except BaseException:
        # Remove the half-written temp file, then re-raise the original error.
        try:
            os.unlink(tmp_name)
        except FileNotFoundError:
            pass
        raise


def process_software_cves(software_dir, template, gen_struct_path, batch_size=5, save_interval=1):
    """Classify the not-yet-analysed CVEs of one software directory.

    Reads ``<software_dir>/search_cves.json``, a list of CVE dicts keyed by
    ``cve_id``. CVEs already present in
    ``<software_dir>/extension_analysis.json`` are skipped. The rest are
    classified in parallel, ``batch_size`` at a time, by a process pool.
    Results are checkpointed to ``extension_analysis.json`` so an
    interrupted run can resume.

    Args:
        software_dir: ``Path`` of the software's directory. Its name is
            passed to the classifier as the software name.
        template: Prompt template forwarded to the classifier.
        gen_struct_path: Path of the ``gen_struct.py`` helper script.
        batch_size: Number of CVEs submitted per batch. It is also the
            upper bound on the number of worker processes.
        save_interval: Checkpoint after every this many batches. A value
            <= 0 disables intermediate checkpoints, leaving only the final
            save.

    Returns:
        The combined list of previously stored and newly classified results,
        or ``[]`` when no CVEs file exists. CVEs whose classification failed
        are not included, so a later run picks them up again.
    """
    cves_file = software_dir / 'search_cves.json'
    output_file = software_dir / 'extension_analysis.json'

    if not cves_file.exists():
        print(f"No CVEs file found in {software_dir}")
        return []

    # Resume support: reload whatever a previous run already classified.
    existing_results = []
    if output_file.exists():
        with open(output_file, 'r', encoding='utf-8') as f:
            existing_results = json.load(f)
        print(f"Loaded {len(existing_results)} existing results from {output_file}")
    processed_cves = {result['cve_id'] for result in existing_results}

    with open(cves_file, 'r', encoding='utf-8') as f:
        cves_data = json.load(f)
    unprocessed_cves = [cve for cve in cves_data if cve['cve_id'] not in processed_cves]

    if not unprocessed_cves:
        print(f"All CVEs already processed for {software_dir.name}")
        return existing_results

    print(f"Processing {len(unprocessed_cves)} new CVEs for {software_dir.name}")
    software_name = software_dir.name
    results = existing_results
    batches_processed = 0

    # Don't spawn more worker processes than there are CVEs left to process.
    with multiprocessing.Pool(min(batch_size, len(unprocessed_cves))) as pool:
        for batch in grouper(unprocessed_cves, batch_size):
            # grouper pads the final batch with None. Drop the padding and
            # submit each CVE as its own single-item task.
            batch_args = [
                (software_name, template, gen_struct_path, [cve])
                for cve in batch
                if cve is not None
            ]
            for batch_result in pool.map(process_cve_batch, batch_args):
                results.extend(batch_result)
            batches_processed += 1

            if save_interval > 0 and batches_processed % save_interval == 0:
                _write_json_atomic(output_file, results)
                print(f"Saved intermediate results after {batches_processed} batches")

    if results:
        _write_json_atomic(output_file, results)
        print(f"Saved {len(results)} total results to {output_file}")
    return results
def main(*, only=None):
    """Classify the CVEs of every software directory under ``search_results/``.

    Each sub-directory of ``search_results`` is visited in sorted order, so
    runs are reproducible, and passed to ``process_software_cves``. Every
    non-empty result list is collected under the directory's name, and the
    combined mapping is written to
    ``extension_analysis/extension_analysis.json``.

    All paths are relative to the current working directory. The script is
    presumably meant to be run from the repository root.

    Args:
        only: Optional collection of software directory names to restrict
            the run to (e.g. ``{"minecraft"}`` while debugging). The default
            ``None`` processes every directory.
    """
    template_path = Path('scripts/prompts/check_extension_general.template.md')
    gen_struct_path = Path('scripts/ai/gen_struct.py')
    search_results_dir = Path('search_results')
    output_file = Path('extension_analysis/extension_analysis.json')

    template = load_template(template_path)

    all_results = {}
    # sorted(): iterdir() order is filesystem-dependent.
    for software_dir in sorted(search_results_dir.iterdir()):
        if not software_dir.is_dir():
            continue
        if only is not None and software_dir.name not in only:
            continue
        print(f"\nProcessing {software_dir.name} CVEs...")
        results = process_software_cves(software_dir, template, gen_struct_path)
        if results:
            all_results[software_dir.name] = results
            print(f"Found {len(results)} classified CVEs for {software_dir.name}")

    # The output directory is not guaranteed to exist. Without this, the
    # final write would fail after all the classification work is done.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2)
    print(f"\nOverall results saved to {output_file}")


# The guard is required: multiprocessing's spawn start method re-imports
# this module in each worker.
if __name__ == "__main__":
    main()