Skip to content

Commit c03e8dc

Browse files
With the new DFT schema, identify examples and point to main fields in the schema. (#55)
* Refine example.json; add script.py for local testing. * Add multiple-file support to the CLI prompt command. * Add DFT constraints and target fields to PROMPT_REGISTRY; improve StructuredPrompt. * Fix test expectations for the updated StructuredPrompt instruction format (#56). * Fix the failing test for the StructuredPrompt instructions format. --------- Co-authored-by: Copilot <[email protected]>
1 parent c084c3c commit c03e8dc

File tree

6 files changed

+76
-33
lines changed

6 files changed

+76
-33
lines changed

nerxiv/cli/cli.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ def cli():
6464
"-path",
6565
type=str,
6666
required=True,
67+
multiple=True,
6768
help="""
68-
The path to the HDF5 file used to prompt the LLM.
69+
The path to the HDF5 file or files used to prompt the LLM.
6970
""",
7071
)
7172
@click.option(
@@ -167,22 +168,23 @@ def prompt(
167168
chunker_kwargs = parse_llm_option_to_args(chunker_option)
168169

169170
# Transform to Path and get the hdf5 data
170-
paper = Path(file_path)
171-
paper_time = run_prompt_paper(
172-
paper=paper,
173-
chunker=chunker,
174-
retriever_model=retriever_model,
175-
n_top_chunks=n_top_chunks,
176-
model=model,
177-
retriever_query=retriever_query,
178-
prompt=prompt,
179-
query=query,
180-
paper_time=start_time,
181-
logger=logger,
182-
**chunker_kwargs,
183-
**llm_kwargs,
184-
)
185-
click.echo(f"Processed arXiv papers in {paper_time:.2f} seconds\n\n")
171+
for file in file_path:
172+
paper = Path(file)
173+
paper_time = run_prompt_paper(
174+
paper=paper,
175+
chunker=chunker,
176+
retriever_model=retriever_model,
177+
n_top_chunks=n_top_chunks,
178+
model=model,
179+
retriever_query=retriever_query,
180+
prompt=prompt,
181+
query=query,
182+
paper_time=start_time,
183+
logger=logger,
184+
**chunker_kwargs,
185+
**llm_kwargs,
186+
)
187+
click.echo(f"Processed arXiv paper {file} in {paper_time:.2f} seconds\n\n")
186188

187189

188190
@cli.command(

nerxiv/datamodel/example.json

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{
2+
"crystal_structure": [],
23
"dft": [
34
{
45
"code": "VASP",
@@ -92,5 +93,9 @@
9293
"spin_treatment": "unrestricted",
9394
"relativistic_treatment": "atomic ZORA"
9495
}
95-
]
96+
],
97+
"projection": [],
98+
"interactions": [],
99+
"dmft": [],
100+
"analytical_continuation": []
96101
}

nerxiv/prompts/prompts.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ class StructuredPrompt(BasePrompt):
203203
target_fields: list[str] = Field(
204204
...,
205205
description="""
206-
The fields within `output_schema` that the prompt should extract.
206+
The fields within `output_schema` that the prompt should extract. If set to `all`, all fields defined in `output_schema` will be extracted.
207207
""",
208208
)
209209

@@ -221,6 +221,9 @@ def validate_target_fields_in_schema(cls, data: Any) -> Any:
221221
"""
222222
model_properties = data.output_schema.model_json_schema().get("properties", {})
223223
for field in data.target_fields:
224+
if field == "all":
225+
data.target_fields = list(model_properties.keys())
226+
break
224227
if field not in model_properties:
225228
raise ValueError(
226229
f"Field '{field}' is not defined in the output schema '{data.output_schema.__name__}'."
@@ -243,8 +246,8 @@ def _build_instructions(self) -> str:
243246
description = clean_description(
244247
model.get("description", "<<no definition provided>>")
245248
)
246-
instruction_lines = f"Given the following scientific text, your task is: to identify all mentions of the {name}. "
247-
instruction_lines += f"This is defined as {description}. "
249+
instruction_lines = f"Given the following scientific text, your task is: to identify all mentions of the {name} section. "
250+
instruction_lines += f"This is defined as a {description} "
248251

249252
instruction_lines += "You must extract the values of the following fields:"
250253
# getting the fields defined for the class and matching them with `target_fields`
@@ -255,18 +258,26 @@ def _build_instructions(self) -> str:
255258
prop_types = [
256259
p.get("type") for p in prop.get("anyOf", []) if p.get("type") != "null"
257260
] # only non-null types
258-
instruction_lines += f"\n- {field} defined as '{prop_description}' and which is of type {prop_types[0]}"
261+
if not prop_types:
262+
instruction_lines += f"\n- {field} defined as {prop_description}"
263+
else:
264+
prop_type = prop_types[0]
265+
if prop_type == "object":
266+
prop_type = "dictionary"
267+
instruction_lines += f"\n- {field} defined as {prop_description} and which is of type {prop_type}"
259268
# TODO add data type
260269

261270
instruction_lines += (
262-
"\nYou must return the extracted values in the following format:"
271+
"\nYou must return the extracted values in JSON format:"
263272
"\n```json\n"
264-
f"'{name}': " + "{\n"
273+
"{\n"
274+
f" '{name}': " + "{\n"
265275
)
266276
for field in self.target_fields:
267277
instruction_lines += f" '{field}': <parsed-value>,\n"
268278

269-
instruction_lines += "}\n```\n"
279+
instruction_lines += " }\n}\n```\n"
280+
instruction_lines += "Note that <parsed-value> means a value of the correct type defined for that field."
270281
return instruction_lines
271282

272283
def build(self, text: str) -> str:

nerxiv/prompts/prompts_registry.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,14 @@
3737
prompt=StructuredPrompt(
3838
expert="Condensed Matter Physics",
3939
output_schema=DFT,
40-
target_fields=[],
41-
constraints=[],
40+
target_fields=["all"],
41+
constraints=[
42+
"Return ONLY the requested JSON object without any additional text or explanation.",
43+
"If you do NOT find the value of a field in the text, do NOT make up a value. Leave it as null in the JSON output.",
44+
"Do NOT infere values of fields that are not explicitly mentioned in the text.",
45+
"Return the JSON as specified in the prompt. Do NOT make up a new JSON with different field names or structure.",
46+
"Ensure that all parsed values are of the correct data type as defined in the DFT schema.",
47+
],
4248
examples=[],
4349
),
4450
),

tests/prompts/test_prompts.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,18 +167,19 @@ def test_build_instructions(self):
167167
)
168168
assert prompt._build_instructions() == (
169169
"Given the following scientific text, your task is: to identify all mentions of "
170-
"the ChemicalFormulation. This is defined as A ChemicalFormulation is a descriptive "
170+
"the ChemicalFormulation section. This is defined as a A ChemicalFormulation is a descriptive "
171171
"representation of the chemical composition of a material system, expressed in one or "
172172
"more standardized formula formats (e.g., IUPAC, anonymous, Hill, or reduced), each "
173173
"encoding the stoichiometry and elemental ordering according to specific conventions. "
174174
"For the compound H2O2 (hydrogen peroxide), the different formulations would be: iupac: "
175-
"H2O2 anonymous: AB hill: H2O2 reduced: H2O2. You must extract the values of the following "
176-
"fields:\n- iupac defined as 'Chemical formula where the elements are ordered using a "
175+
"H2O2 anonymous: AB hill: H2O2 reduced: H2O2 You must extract the values of the following "
176+
"fields:\n- iupac defined as Chemical formula where the elements are ordered using a "
177177
"formal list based on electronegativity as defined in the IUPAC nomenclature of inorganic "
178178
"chemistry (2005): - https://en.wikipedia.org/wiki/List_of_inorganic_compounds Contains "
179179
"reduced integer chemical proportion numbers where the proportion number is omitted if it "
180-
"is 1.' and which is of type string\n- reduced defined as 'Alphabetically sorted chemical "
180+
"is 1. and which is of type string\n- reduced defined as Alphabetically sorted chemical "
181181
"formula with reduced integer chemical proportion numbers. The proportion number is omitted "
182-
"if it is 1.' and which is of type string\nYou must return the extracted values in the "
183-
"following format:\n```json\n'ChemicalFormulation': {\n 'iupac': <parsed-value>,\n 'reduced': <parsed-value>,\n}\n```\n"
182+
"if it is 1. and which is of type string\nYou must return the extracted values in JSON "
183+
"format:\n```json\n{\n 'ChemicalFormulation': {\n 'iupac': <parsed-value>,\n 'reduced': <parsed-value>,\n }\n}\n```\n"
184+
"Note that <parsed-value> means a value of the correct type defined for that field."
184185
)

tutorials/script.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from pathlib import Path
2+
3+
import h5py
4+
5+
# Walk every "*.hdf5" file under "./data" and move each "run_*" entry of its
# "raw_llm_answers" group into the "20251028_OLD_raw_llm_answers" sub-group of
# that same group (copy first, then delete the original entry).
# NOTE(review): indentation was reconstructed from a flattened diff view;
# confirm against the repository copy, in particular the nesting level of the
# final print.
for path in Path("./data").glob("*.hdf5"):
    with h5py.File(path, "r+") as h5_file:  # "r+": the file is edited in place
        if "raw_llm_answers" not in h5_file:
            # Nothing to archive here; the file is skipped without a message.
            continue
        answers = h5_file["raw_llm_answers"]
        # Requested before the keys are listed, as in the original ordering, so
        # the archive group is present even when no run ends up being moved.
        archive = answers.require_group("20251028_OLD_raw_llm_answers")
        # Snapshot the matching names up front: entries are deleted in the loop.
        run_names = [name for name in answers.keys() if name.startswith("run_")]
        for run_name in run_names:
            archive.copy(answers[run_name], run_name)
            del answers[run_name]
    print(f"Processed file: {path}")

0 commit comments

Comments
 (0)