feat: allow to use customized GraphRAG settings.yaml (#387) bump:patch
* allow to use customized GraphRAG settings.yaml

* adjust import style

* fix typo

* Added GraphRAG original documentation reference.

* feat: allow to use customized GraphRAG settings.yaml (#387)

---------

Co-authored-by: Chen, Ron Gang <[email protected]>
ronchengang and amaler authored Oct 14, 2024
1 parent f0f3b4b commit 8188760
Showing 3 changed files with 193 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .env.example
@@ -25,6 +25,9 @@ GRAPHRAG_API_KEY=<YOUR_OPENAI_KEY>
 GRAPHRAG_LLM_MODEL=gpt-4o-mini
 GRAPHRAG_EMBEDDING_MODEL=text-embedding-3-small
 
+# set to true if you want to use customized GraphRAG config file
+USE_CUSTOMIZED_GRAPHRAG_SETTING=false
+
 # settings for Azure DI
 AZURE_DI_ENDPOINT=
 AZURE_DI_CREDENTIAL=
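Note: this flag is read with python-decouple in the indexing pipeline below. As a minimal, illustrative sketch of how such a flag is parsed (the default value here is an assumption, not the committed code):

from decouple import config  # python-decouple reads .env files and the process environment

# Anything other than "true" (case-insensitive) keeps the stock GraphRAG config.
use_custom_settings = (
    config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="false").lower() == "true"
)

Parsing the flag by string comparison keeps behavior predictable, since every value read from a .env file arrives as a string.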
33 changes: 31 additions & 2 deletions libs/ktem/ktem/index/file/graph/pipelines.py
@@ -1,4 +1,5 @@
 import os
+import shutil
 import subprocess
 from pathlib import Path
 from shutil import rmtree
@@ -7,6 +8,8 @@
 
 import pandas as pd
 import tiktoken
+import yaml
+from decouple import config
 from ktem.db.models import engine
 from sqlalchemy.orm import Session
 from theflow.settings import settings
@@ -116,6 +119,16 @@ def call_graphrag_index(self, input_path: str):
         print(result.stdout)
         command = command[:-1]
 
+        # copy customized GraphRAG config file if it exists
+        if config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="value").lower() == "true":
+            setting_file_path = os.path.join(os.getcwd(), "settings.yaml.example")
+            destination_file_path = os.path.join(input_path, "settings.yaml")
+            try:
+                shutil.copy(setting_file_path, destination_file_path)
+            except shutil.Error:
+                # Handle the error if the file copy fails
+                print("failed to copy customized GraphRAG config file. ")
+
         # Run the command and stream stdout
         with subprocess.Popen(command, stdout=subprocess.PIPE, text=True) as process:
             if process.stdout:
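For orientation, the copy sits between two indexer invocations: the first run scaffolds the workspace, the copy then overwrites the generated settings.yaml, and the second run indexes with it. A condensed sketch of that sequence, assuming the typical GraphRAG 0.x CLI invocation (the command construction itself is outside this hunk, so "--init" is an assumption):

import shutil
import subprocess
from pathlib import Path


def init_then_index(input_path: str, use_custom: bool) -> None:
    # Assumed invocation: "--init" scaffolds settings.yaml and prompts/
    # under input_path; dropping it re-runs the actual indexing.
    command = ["python", "-m", "graphrag.index", "--root", input_path, "--init"]
    subprocess.run(command, check=False)

    if use_custom:
        # Overwrite the scaffolded settings.yaml with the customized copy
        # shipped as settings.yaml.example at the repo root.
        shutil.copy(Path.cwd() / "settings.yaml.example",
                    Path(input_path) / "settings.yaml")

    subprocess.run(command[:-1], check=False)  # index with the final settings.yaml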
@@ -221,12 +234,28 @@ def _build_graph_search(self):
         text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
         text_units = read_indexer_text_units(text_unit_df)
 
+        # initialize default settings
         embedding_model = os.getenv(
             "GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small"
         )
+        embedding_api_key = os.getenv("GRAPHRAG_API_KEY")
+        embedding_api_base = None
+
+        # use customized GraphRAG settings if the flag is set
+        if config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="value").lower() == "true":
+            settings_yaml_path = Path(root_path) / "settings.yaml"
+            with open(settings_yaml_path, "r") as f:
+                settings = yaml.safe_load(f)
+            if settings["embeddings"]["llm"]["model"]:
+                embedding_model = settings["embeddings"]["llm"]["model"]
+            if settings["embeddings"]["llm"]["api_key"]:
+                embedding_api_key = settings["embeddings"]["llm"]["api_key"]
+            if settings["embeddings"]["llm"]["api_base"]:
+                embedding_api_base = settings["embeddings"]["llm"]["api_base"]
+
         text_embedder = OpenAIEmbedding(
-            api_key=os.getenv("GRAPHRAG_API_KEY"),
-            api_base=None,
+            api_key=embedding_api_key,
+            api_base=embedding_api_base,
             api_type=OpenaiApiType.OpenAI,
             model=embedding_model,
             deployment_name=embedding_model,
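On the query side, the same flag decides whether the embedding client is configured from environment variables or from the copied settings.yaml. A condensed sketch of that override logic follows; the function name and the dict.get fallbacks for missing keys are illustrative, whereas the committed code above indexes the keys directly:

import os
from pathlib import Path

import yaml


def resolve_embedding_settings(root_path):
    """Return (model, api_key, api_base), preferring settings.yaml values."""
    model = os.getenv("GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small")
    api_key = os.getenv("GRAPHRAG_API_KEY")
    api_base = None

    settings_path = Path(root_path) / "settings.yaml"
    if settings_path.exists():
        with open(settings_path) as f:
            settings = yaml.safe_load(f) or {}
        llm = settings.get("embeddings", {}).get("llm", {})
        # Each field falls back to the env-derived default when absent or empty.
        model = llm.get("model") or model
        api_key = llm.get("api_key") or api_key
        api_base = llm.get("api_base") or api_base

    return model, api_key, api_base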
159 changes: 159 additions & 0 deletions settings.yaml.example
@@ -0,0 +1,159 @@
# This is a sample GraphRAG settings.yaml file that lets users run the GraphRAG index process with their own customized parameters.
# The parameters in this file only take effect when USE_CUSTOMIZED_GRAPHRAG_SETTING is set to true in the .env file.
# For a comprehensive understanding of GraphRAG parameters, please refer to: https://microsoft.github.io/graphrag/config/json_yaml/.

encoding_model: cl100k_base
skip_workflows: []
llm:
  api_key: ${GRAPHRAG_API_KEY}
  type: openai_chat # or azure_openai_chat
  api_base: http://127.0.0.1:11434/v1
  model: qwen2
  model_supports_json: true # recommended if this is available for your model.
  # max_tokens: 4000
  request_timeout: 1800.0
  # api_base: https://<instance>.openai.azure.com
  # api_version: 2024-02-15-preview
  # organization: <organization_id>
  # deployment_name: <azure_model_deployment_name>
  # tokens_per_minute: 150_000 # set a leaky bucket throttle
  # requests_per_minute: 10_000 # set a leaky bucket throttle
  # max_retries: 10
  # max_retry_wait: 10.0
  # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
  concurrent_requests: 5 # the number of parallel inflight requests that may be made
  # temperature: 0 # temperature for sampling
  # top_p: 1 # top-p sampling
  # n: 1 # Number of completions to generate

parallelization:
  stagger: 0.3
  # num_threads: 50 # the number of threads to use for parallel processing

async_mode: threaded # or asyncio

embeddings:
  ## parallelization: override the global parallelization settings for embeddings
  async_mode: threaded # or asyncio
  # target: required # or all
  # batch_size: 16 # the number of documents to send in a single request
  # batch_max_tokens: 8191 # the maximum number of tokens to send in a single request
  llm:
    api_base: http://localhost:11434/v1
    api_key: ${GRAPHRAG_API_KEY}
    model: nomic-embed-text
    type: openai_embedding
    # api_base: https://<instance>.openai.azure.com
    # api_version: 2024-02-15-preview
    # organization: <organization_id>
    # deployment_name: <azure_model_deployment_name>
    # tokens_per_minute: 150_000 # set a leaky bucket throttle
    # requests_per_minute: 10_000 # set a leaky bucket throttle
    # max_retries: 10
    # max_retry_wait: 10.0
    # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
    # concurrent_requests: 25 # the number of parallel inflight requests that may be made

chunks:
  size: 1200
  overlap: 100
  group_by_columns: [id] # by default, we don't allow chunks to cross documents

input:
  type: file # or blob
  file_type: text # or csv
  base_dir: "input"
  file_encoding: utf-8
  file_pattern: ".*\\.txt$"

cache:
  type: file # or blob
  base_dir: "cache"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

storage:
  type: file # or blob
  base_dir: "output"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

reporting:
  type: file # or console, blob
  base_dir: "output"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

entity_extraction:
  ## strategy: fully override the entity extraction strategy.
  ##   type: one of graph_intelligence, graph_intelligence_json and nltk
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/entity_extraction.txt"
  entity_types: [organization,person,geo,event]
  max_gleanings: 1

summarize_descriptions:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/summarize_descriptions.txt"
  max_length: 500

claim_extraction:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  # enabled: true
  prompt: "prompts/claim_extraction.txt"
  description: "Any claims or facts that could be relevant to information discovery."
  max_gleanings: 1

community_reports:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/community_report.txt"
  max_length: 2000
  max_input_length: 8000

cluster_graph:
  max_cluster_size: 10

embed_graph:
  enabled: false # if true, will generate node2vec embeddings for nodes
  # num_walks: 10
  # walk_length: 40
  # window_size: 2
  # iterations: 3
  # random_seed: 597832

umap:
  enabled: false # if true, will generate UMAP embeddings for nodes

snapshots:
  graphml: false
  raw_entities: false
  top_level_nodes: false

local_search:
  # text_unit_prop: 0.5
  # community_prop: 0.1
  # conversation_history_max_turns: 5
  # top_k_mapped_entities: 10
  # top_k_relationships: 10
  # llm_temperature: 0 # temperature for sampling
  # llm_top_p: 1 # top-p sampling
  # llm_n: 1 # Number of completions to generate
  # max_tokens: 12000

global_search:
  # llm_temperature: 0 # temperature for sampling
  # llm_top_p: 1 # top-p sampling
  # llm_n: 1 # Number of completions to generate
  # max_tokens: 12000
  # data_max_tokens: 12000
  # map_max_tokens: 1000
  # reduce_max_tokens: 2000
  # concurrency: 32
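GraphRAG substitutes ${VAR} references such as ${GRAPHRAG_API_KEY} from the environment when it loads this file. The loader below is only a sketch of that interpolation idea, not GraphRAG's actual implementation:

import os
import re

import yaml

_ENV_REF = re.compile(r"\$\{(\w+)\}")


def load_settings_with_env(path):
    """Load a settings.yaml, expanding ${VAR} references from the environment."""
    with open(path) as f:
        raw = f.read()
    # Replace each ${VAR} with its environment value; unset vars become "".
    expanded = _ENV_REF.sub(lambda m: os.environ.get(m.group(1), ""), raw)
    return yaml.safe_load(expanded)


# e.g. load_settings_with_env("settings.yaml")["llm"]["model"] -> "qwen2"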
