diff --git a/app/data_source/sources/notion/__init__.py b/app/data_source/sources/notion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/data_source/sources/notion/notion.py b/app/data_source/sources/notion/notion.py new file mode 100644 index 0000000..1721196 --- /dev/null +++ b/app/data_source/sources/notion/notion.py @@ -0,0 +1,242 @@ +import logging +from datetime import datetime +from enum import Enum +from typing import Dict, List + +import requests +from data_source.api.base_data_source import BaseDataSource, BaseDataSourceConfig, ConfigField, HTMLInputType +from data_source.api.basic_document import BasicDocument, DocumentType +from data_source.api.exception import InvalidDataSourceConfig +from queues.index_queue import IndexQueue +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +logger = logging.getLogger(__name__) + +# Notion API Status codes https://developers.notion.com/reference/status-codes + +HTTP_OK = 200 +HTTP_BAD_REQUEST = 400 +HTTP_UNAUTHORIZED = 401 +HTTP_FORBIDDEN = 403 +HTTP_NOT_FOUND = 404 +HTTP_CONFLICT = 409 +HTTP_TOO_MANY_REQUESTS = 429 + +# 5xx Server Errors +HTTP_INTERNAL_SERVER_ERROR = 500 +HTTP_SERVICE_UNAVAILABLE = 503 +HTTP_GATEWAY_TIMEOUT = 504 + +RETRY_AFTER_STATUS_CODES = frozenset( + { + HTTP_TOO_MANY_REQUESTS, + HTTP_INTERNAL_SERVER_ERROR, + HTTP_SERVICE_UNAVAILABLE, + HTTP_GATEWAY_TIMEOUT, + } +) + + +def _notion_retry_session(token, retries=10, backoff_factor=2.0, status_forcelist=RETRY_AFTER_STATUS_CODES): + """Creates a retry session""" + session = requests.Session() + retry = Retry( + total=retries, + connect=retries, + read=retries, + status=retries, + backoff_factor=backoff_factor, + status_forcelist=status_forcelist, + raise_on_redirect=False, + raise_on_status=False, + respect_retry_after_header=True, + ) + adapter = HTTPAdapter() + adapter.max_retries = retry + session.mount("http://", adapter) + session.mount("https://", adapter) + session.headers.update({"Notion-Version": "2022-06-28", "Authorization": f"Bearer {token}"}) + return session + + +class NotionObject(str, Enum): + page = "page" + database = "database" + + +class NotionClient: + def __init__(self, token): + self.api_url = "https://api.notion.com/v1" + self.session = _notion_retry_session(token) + + def auth_check(self): + url = f"{self.api_url}/users/me" + response = self.session.get(url) + response.raise_for_status() + + def get_user(self, user_id): + url = f"{self.api_url}/users/{user_id}" + response = self.session.get(url) + try: + return response.json() + except requests.exceptions.JSONDecodeError: + return {} + + def list_objects(self, notion_object: NotionObject): + url = f"{self.api_url}/search" + filter_data = { + "filter": {"value": notion_object, "property": "object"}, + "sort": {"direction": "ascending", "timestamp": "last_edited_time"}, + } + response = self.session.post(url, json=filter_data) + results = response.json()["results"] + while response.json()["has_more"] is True: + response = self.session.post(url, json={"start_cursor": response.json()["next_cursor"], **filter_data}) + results.extend(response.json()["results"]) + return results + + def list_pages(self): + return self.list_objects(NotionObject.page) + + def list_databases(self): + return self.list_objects(NotionObject.database) + + def list_blocks(self, block_id: str): + url = f"{self.api_url}/blocks/{block_id}/children" + params = {"page_size": 100} + response = self.session.get(url, params=params) + if not response.json()["results"]: + return [] + 
        results = response.json()["results"]
+        while response.json()["has_more"] is True:
+            response = self.session.get(url, params={"start_cursor": response.json()["next_cursor"], **params})
+            results.extend(response.json()["results"])
+        return results
+
+    def list_database_pages(self, database_id: str):
+        url = f"{self.api_url}/databases/{database_id}/query"
+        filter_data = {"page_size": 100}
+        response = self.session.post(url, json=filter_data)
+        results = response.json()["results"]
+        while response.json()["has_more"] is True:
+            response = self.session.post(
+                url,
+                json={"start_cursor": response.json()["next_cursor"], **filter_data},
+            )
+            results.extend(response.json()["results"])
+        return results
+
+
+class NotionConfig(BaseDataSourceConfig):
+    token: str
+
+
+class NotionDataSource(BaseDataSource):
+    @staticmethod
+    def get_config_fields() -> List[ConfigField]:
+        """
+        list of the config fields which should be the same fields as in NotionConfig, for dynamic UI generation
+        """
+        return [
+            ConfigField(
+                label="Notion Integration Token",
+                name="token",
+                placeholder="secret_AZefAeAZqsfDAZE",
+                input_type=HTMLInputType.PASSWORD,
+            )
+        ]
+
+    @staticmethod
+    async def validate_config(config: Dict) -> None:
+        """
+        Validate the configuration and raise an exception if it's invalid.
+        You should try to actually connect to the data source and verify that it's working
+        """
+        try:
+            parsed_config = NotionConfig(**config)
+            notion_client = NotionClient(token=parsed_config.token)
+            notion_client.auth_check()
+        except Exception as e:
+            raise InvalidDataSourceConfig from e
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        notion_config = NotionConfig(**self._raw_config)
+        self._notion_client = NotionClient(
+            token=notion_config.token,
+        )
+        self.data_source_id = "DUMMY_SOURCE_ID"
+
+    def _parse_rich_text(self, rich_text: list):
+        return "\n".join([text["plain_text"] for text in rich_text])
+
+    def _parse_content_from_blocks(self, notion_blocks):
+        return "\n".join(
+            [
+                self._parse_rich_text(block[block["type"]]["rich_text"])
+                for block in notion_blocks
+                if block[block["type"]].get("rich_text")
+            ]
+        )
+
+    def _parse_title(self, page):
+        title_prop = next(prop for prop in page["properties"] if page["properties"][prop]["type"] == "title")
+        return self._parse_rich_text(page["properties"][title_prop]["title"])
+
+    def _parse_content_from_page(self, page):
+        metadata_list = [
+            f"{prop}: {self._parse_rich_text(page['properties'][prop].get('rich_text', []))}"
+            for prop in page["properties"]
+            if prop != "Name"
+        ]
+        title = f"{self._parse_title(page)}"
+        metadata = "\n".join([f"Title: {title}"] + metadata_list)
+        page_blocks = self._notion_client.list_blocks(page["id"])
+        blocks_content = self._parse_content_from_blocks(page_blocks)
+        author = self._notion_client.get_user(page["created_by"]["id"])
+        return {
+            "id": page["id"],
+            "author": author.get("name", ""),
+            "author_image_url": author.get("avatar_url", ""),
+            "url": page["url"],
+            "title": title,
+            "location": title,
+            "content": metadata + "\n" + blocks_content,
+            "timestamp": datetime.strptime(page["last_edited_time"], "%Y-%m-%dT%H:%M:%S.%fZ"),
+        }
+
+    def _feed_new_documents(self) -> None:
+        logger.info("Fetching non database pages ...")
+        single_pages = self._notion_client.list_pages()
+        logger.info(f"Found {len(single_pages)} non database pages ...")
+
+        logger.info("Fetching databases ...")
+        databases = self._notion_client.list_databases()
+        logger.info(f"Found {len(databases)} databases ...")
+
+        all_database_pages
= [] + for database in databases: + database_pages = self._notion_client.list_database_pages(database["id"]) + logger.info(f"Found {len(database_pages)} pages to index in database {database['id']} ...") + all_database_pages.extend(database_pages) + + pages = single_pages + all_database_pages + logger.info(f"Found {len(pages)} pages in total ...") + + for page in pages: + last_updated_at = datetime.strptime(page["last_edited_time"], "%Y-%m-%dT%H:%M:%S.%fZ") + if last_updated_at < self._last_index_time: + # skipping already indexed pages + continue + try: + page_data = self._parse_content_from_page(page) + logger.info(f"Indexing page {page_data['id']}") + document = BasicDocument( + data_source_id=self._data_source_id, + type=DocumentType.DOCUMENT, + **page_data, + ) + IndexQueue.get_instance().put_single(document) + except Exception as e: + logger.error(f"Failed to index page {page['id']}", exc_info=e) diff --git a/app/data_source/sources/slack/slack.py b/app/data_source/sources/slack/slack.py index 332700b..d9e7593 100644 --- a/app/data_source/sources/slack/slack.py +++ b/app/data_source/sources/slack/slack.py @@ -3,16 +3,15 @@ import time from dataclasses import dataclass from http.client import IncompleteRead -from typing import Optional, Dict, List +from typing import Dict, List, Optional +from data_source.api.base_data_source import BaseDataSource, BaseDataSourceConfig, ConfigField, HTMLInputType +from data_source.api.basic_document import BasicDocument, DocumentType +from queues.index_queue import IndexQueue from retry import retry from slack_sdk import WebClient from slack_sdk.errors import SlackApiError -from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType, BaseDataSourceConfig -from data_source.api.basic_document import DocumentType, BasicDocument -from queues.index_queue import IndexQueue - logger = logging.getLogger(__name__) @@ -37,9 +36,7 @@ class SlackDataSource(BaseDataSource): @staticmethod def get_config_fields() -> List[ConfigField]: - return [ - ConfigField(label="Bot User OAuth Token", name="token", type=HTMLInputType.PASSWORD) - ] + return [ConfigField(label="Bot User OAuth Token", name="token", type=HTMLInputType.PASSWORD)] @staticmethod async def validate_config(config: Dict) -> None: @@ -49,7 +46,7 @@ async def validate_config(config: Dict) -> None: @staticmethod def _is_valid_message(message: Dict) -> bool: - return 'client_msg_id' in message or 'bot_id' in message + return "client_msg_id" in message or "bot_id" in message def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -59,8 +56,7 @@ def __init__(self, *args, **kwargs): def _list_conversations(self) -> List[SlackConversation]: conversations = self._slack.conversations_list(exclude_archived=True, limit=1000) - return [SlackConversation(id=conv['id'], name=conv['name']) - for conv in conversations['channels']] + return [SlackConversation(id=conv["id"], name=conv["name"]) for conv in conversations["channels"]] def _feed_conversations(self, conversations: List[SlackConversation]) -> List[SlackConversation]: joined_conversations = [] @@ -68,11 +64,11 @@ def _feed_conversations(self, conversations: List[SlackConversation]) -> List[Sl for conv in conversations: try: result = self._slack.conversations_join(channel=conv.id) - if result['ok']: - logger.info(f'Joined channel {conv.name}, adding a fetching task...') + if result["ok"]: + logger.info(f"Joined channel {conv.name}, adding a fetching task...") self.add_task_to_queue(self._feed_conversation, conv=conv) 
except Exception as e: - logger.warning(f'Could not join channel {conv.name}: {e}') + logger.warning(f"Could not join channel {conv.name}: {e}") return joined_conversations @@ -80,22 +76,21 @@ def _get_author_details(self, author_id: str) -> SlackAuthor: author = self._authors_cache.get(author_id, None) if author is None: author_info = self._slack.users_info(user=author_id) - user = author_info['user'] - name = user.get('real_name') or user.get('name') or user.get('profile', {}).get('display_name') or 'Unknown' - author = SlackAuthor(name=name, - image_url=author_info['user']['profile']['image_72']) + user = author_info["user"] + name = user.get("real_name") or user.get("name") or user.get("profile", {}).get("display_name") or "Unknown" + author = SlackAuthor(name=name, image_url=author_info["user"]["profile"]["image_72"]) self._authors_cache[author_id] = author return author def _feed_new_documents(self) -> None: conversations = self._list_conversations() - logger.info(f'Found {len(conversations)} conversations') + logger.info(f"Found {len(conversations)} conversations") self._feed_conversations(conversations) def _feed_conversation(self, conv: SlackConversation): - logger.info(f'Feeding conversation {conv.name}') + logger.info(f"Feeding conversation {conv.name}") last_msg: Optional[BasicDocument] = None @@ -107,13 +102,13 @@ def _feed_conversation(self, conv: SlackConversation): last_msg = None continue - text = message['text'] - if author_id := message.get('user'): + text = message["text"] + if author_id := message.get("user"): author = self._get_author_details(author_id) - elif message.get('bot_id'): - author = SlackAuthor(name=message.get('username'), image_url=message.get('icons', {}).get('image_48')) + elif message.get("bot_id"): + author = SlackAuthor(name=message.get("username"), image_url=message.get("icons", {}).get("image_48")) else: - logger.warning(f'Unknown message author: {message}') + logger.warning(f"Unknown message author: {message}") continue if last_msg is not None: @@ -124,15 +119,22 @@ def _feed_conversation(self, conv: SlackConversation): IndexQueue.get_instance().put_single(doc=last_msg) last_msg = None - timestamp = message['ts'] - message_id = message.get('client_msg_id') or timestamp + timestamp = message["ts"] + message_id = message.get("client_msg_id") or timestamp readable_timestamp = datetime.datetime.fromtimestamp(float(timestamp)) message_url = f"https://slack.com/app_redirect?channel={conv.id}&message_ts={timestamp}" - last_msg = BasicDocument(title=author.name, content=text, author=author.name, - timestamp=readable_timestamp, id=message_id, - data_source_id=self._data_source_id, location=conv.name, - url=message_url, author_image_url=author.image_url, - type=DocumentType.MESSAGE) + last_msg = BasicDocument( + title=author.name, + content=text, + author=author.name, + timestamp=readable_timestamp, + id=message_id, + data_source_id=self._data_source_id, + location=conv.name, + url=message_url, + author_image_url=author.image_url, + type=DocumentType.MESSAGE, + ) if last_msg is not None: IndexQueue.get_instance().put_single(doc=last_msg) @@ -140,19 +142,19 @@ def _feed_conversation(self, conv: SlackConversation): @retry(tries=5, delay=1, backoff=2, logger=logger) def _get_conversation_history(self, conv: SlackConversation, cursor: str, last_index_unix: str): try: - return self._slack.conversations_history(channel=conv.id, oldest=last_index_unix, - limit=1000, cursor=cursor) + return self._slack.conversations_history(channel=conv.id, oldest=last_index_unix, 
limit=1000, cursor=cursor) except SlackApiError as e: - logger.warning(f'SlackApi error while fetching messages for conversation {conv.name}: {e}') + logger.warning(f"SlackApi error while fetching messages for conversation {conv.name}: {e}") response = e.response - if response['error'] == 'ratelimited': - retry_after_seconds = int(response['headers']['Retry-After']) - logger.warning(f'Rate-limited: Slack API rate limit exceeded,' - f' retrying after {retry_after_seconds} seconds') + if response["error"] == "ratelimited": + retry_after_seconds = int(response["headers"]["Retry-After"]) + logger.warning( + f"Rate-limited: Slack API rate limit exceeded," f" retrying after {retry_after_seconds} seconds" + ) time.sleep(retry_after_seconds) raise e except IncompleteRead as e: - logger.warning(f'IncompleteRead error while fetching messages for conversation {conv.name}') + logger.warning(f"IncompleteRead error while fetching messages for conversation {conv.name}") raise e def _fetch_conversation_messages(self, conv: SlackConversation): @@ -160,19 +162,22 @@ def _fetch_conversation_messages(self, conv: SlackConversation): cursor = None has_more = True last_index_unix = self._last_index_time.timestamp() - logger.info(f'Fetching messages for conversation {conv.name}') + logger.info(f"Fetching messages for conversation {conv.name}") while has_more: try: - response = self._get_conversation_history(conv=conv, cursor=cursor, - last_index_unix=str(last_index_unix)) + response = self._get_conversation_history( + conv=conv, cursor=cursor, last_index_unix=str(last_index_unix) + ) except Exception as e: - logger.warning(f'Error fetching all messages for conversation {conv.name},' - f' returning {len(messages)} messages. Error: {e}') + logger.warning( + f"Error fetching all messages for conversation {conv.name}," + f" returning {len(messages)} messages. 
Error: {e}" + ) return messages logger.info(f'Fetched {len(response["messages"])} messages for conversation {conv.name}') - messages.extend(response['messages']) + messages.extend(response["messages"]) if has_more := response["has_more"]: cursor = response["response_metadata"]["next_cursor"] diff --git a/app/indexing/index_documents.py b/app/indexing/index_documents.py index 8722983..c4b3bec 100644 --- a/app/indexing/index_documents.py +++ b/app/indexing/index_documents.py @@ -7,12 +7,10 @@ from db_engine import Session from indexing.bm25_index import Bm25Index from indexing.faiss_index import FaissIndex +from langchain.schema import Document as PDFDocument from models import bi_encoder -from parsers.pdf import split_PDF_into_paragraphs from paths import IS_IN_DOCKER from schemas import Document, Paragraph -from langchain.schema import Document as PDFDocument - logger = logging.getLogger(__name__) @@ -25,7 +23,6 @@ def get_enum_value_or_none(enum: Optional[Enum]) -> Optional[str]: class Indexer: - @staticmethod def basic_to_document(document: BasicDocument, parent: Document = None) -> Document: paragraphs = Indexer._split_into_paragraphs(document.content) @@ -43,11 +40,8 @@ def basic_to_document(document: BasicDocument, parent: Document = None) -> Docum location=document.location, url=document.url, timestamp=document.timestamp, - paragraphs=[ - Paragraph(content=content) - for content in paragraphs - ], - parent=parent + paragraphs=[Paragraph(content=content) for content in paragraphs], + parent=parent, ) @staticmethod @@ -57,10 +51,11 @@ def index_documents(documents: List[BasicDocument]): ids_in_data_source = [document.id_in_data_source for document in documents] with Session() as session: - documents_to_delete = session.query(Document).filter( - Document.id_in_data_source.in_(ids_in_data_source)).all() + documents_to_delete = ( + session.query(Document).filter(Document.id_in_data_source.in_(ids_in_data_source)).all() + ) if documents_to_delete: - logging.info(f'removing documents that were updated and need to be re-indexed.') + logging.info(f"removing documents that were updated and need to be re-indexed.") Indexer.remove_documents(documents_to_delete, session) for document in documents_to_delete: # Currently bulk deleting doesn't cascade. So we need to delete them one by one. 
@@ -120,15 +115,15 @@ def _split_into_paragraphs(text, minimum_length=256):
         if text is None:
             return []
         paragraphs = []
-        current_paragraph = ''
-        for paragraph in re.split(r'\n\s*\n', text):
+        current_paragraph = ""
+        for paragraph in re.split(r"\n\s*\n", text):
             if len(current_paragraph) > 0:
-                current_paragraph += ' '
+                current_paragraph += " "
             current_paragraph += paragraph.strip()
 
             if len(current_paragraph) > minimum_length:
                 paragraphs.append(current_paragraph)
-                current_paragraph = ''
+                current_paragraph = ""
 
         if len(current_paragraph) > 0:
             paragraphs.append(current_paragraph)
@@ -138,7 +133,7 @@ def _split_into_paragraphs(text, minimum_length=256):
     def _add_metadata_for_indexing(paragraph: Paragraph) -> str:
         result = paragraph.content
         if paragraph.document.title is not None:
-            result += '; ' + paragraph.document.title
+            result += "; " + paragraph.document.title
         return result
 
     @staticmethod
diff --git a/app/static/data_source_icons/notion.png b/app/static/data_source_icons/notion.png
new file mode 100644
index 0000000..956fcff
Binary files /dev/null and b/app/static/data_source_icons/notion.png differ
diff --git a/ui/src/components/data-source-panel.tsx b/ui/src/components/data-source-panel.tsx
index 867b70b..4a2e6b4 100644
--- a/ui/src/components/data-source-panel.tsx
+++ b/ui/src/components/data-source-panel.tsx
@@ -393,6 +393,15 @@ export default class DataSourcePanel extends React.Component
           )
         }
+        {
+          this.state.selectedDataSource.value === 'notion' && (
+
+            1. {'Go to notion.so/my-integrations -> Create an integration and copy its token -> Give it read content capabilities (including user information)'}
+            2. {'Add your integration to a top-level page so that it has access to all of its child pages.'}
+            3. {'Paste the token here'}
+
+          )
+        }
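
A minimal sketch of how the new NotionClient could be smoke-tested outside the indexing loop, assuming the `app` directory is importable (the in-repo modules already import it that way) and that a valid integration token is available in a NOTION_TOKEN environment variable; the script and the variable name are illustrative and not part of this diff:

```python
# Illustrative smoke test -- not part of the change set above.
# Assumes NOTION_TOKEN holds a valid Notion integration token.
import os

from data_source.sources.notion.notion import NotionClient

client = NotionClient(token=os.environ["NOTION_TOKEN"])
client.auth_check()  # raises requests.HTTPError if the token is rejected

pages = client.list_pages()
databases = client.list_databases()
print(f"Integration can see {len(pages)} pages and {len(databases)} databases")

# Exercise cursor-based pagination end to end on the first database, if any.
if databases:
    database_pages = client.list_database_pages(databases[0]["id"])
    print(f"First database contains {len(database_pages)} pages")
```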