From e5e1c1e11cd13c86af2bdd3b45964c343fc7c36b Mon Sep 17 00:00:00 2001 From: Eric Liu Date: Sun, 17 Aug 2025 05:29:37 +0000 Subject: [PATCH] feat: enhance Obsidian API configuration with path support and implement async lifespan for server --- src/mcp_obsidian/obsidian.py | 15 ++++++++++---- src/mcp_obsidian/server.py | 18 +++++++++++----- src/mcp_obsidian/tools.py | 27 ++++++++++++------------ tests/integration_test.py | 40 ++++++++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 22 deletions(-) create mode 100644 tests/integration_test.py diff --git a/src/mcp_obsidian/obsidian.py b/src/mcp_obsidian/obsidian.py index 051bd06..198ba11 100644 --- a/src/mcp_obsidian/obsidian.py +++ b/src/mcp_obsidian/obsidian.py @@ -11,6 +11,7 @@ class Obsidian(): protocol: str = os.getenv('OBSIDIAN_PROTOCOL', 'https').lower(), host: str = str(os.getenv('OBSIDIAN_HOST', '127.0.0.1')), port: int = int(os.getenv('OBSIDIAN_PORT', '27124')), + path: str = '', verify_ssl: bool = False, ): self.api_key = api_key @@ -22,6 +23,7 @@ class Obsidian(): self.host = host self.port = port + self.path = path.rstrip('/') self.verify_ssl = verify_ssl self.timeout = (3, 6) @@ -52,6 +54,7 @@ class Obsidian(): protocol = parsed.scheme host = parsed.hostname port = parsed.port + path = parsed.path # Set default ports based on protocol if not specified if port is None: @@ -62,20 +65,21 @@ class Obsidian(): protocol=protocol, host=host, port=port, + path=path, verify_ssl=verify_ssl ) except Exception as e: raise ValueError(f"Failed to parse OBSIDIAN_HOST URL '{url}': {str(e)}") @staticmethod - def parse_host_config(host_config: str) -> Tuple[str, str, int]: + def parse_host_config(host_config: str) -> Tuple[str, str, int, str]: """Parse host configuration string. Args: host_config: Either a full URL (http://host:port) or just hostname/IP Returns: - Tuple of (protocol, host, port) + Tuple of (protocol, host, port, path) """ if '://' in host_config: # Full URL format @@ -83,6 +87,7 @@ class Obsidian(): protocol = parsed.scheme or 'https' host = parsed.hostname or '127.0.0.1' port = parsed.port or (27124 if protocol == 'https' else 27123) + path = parsed.path else: # Support legacy formats # 1) hostname/IP only @@ -93,15 +98,17 @@ class Obsidian(): protocol = 'https' host = parsed.hostname or '127.0.0.1' port = parsed.port or 27124 + path = '' else: protocol = 'https' host = host_config port = 27124 + path = '' - return protocol, host, port + return protocol, host, port, path def get_base_url(self) -> str: - return f'{self.protocol}://{self.host}:{self.port}' + return f'{self.protocol}://{self.host}:{self.port}{self.path}' def _get_headers(self) -> dict: headers = { diff --git a/src/mcp_obsidian/server.py b/src/mcp_obsidian/server.py index 41a4f76..7ff2850 100644 --- a/src/mcp_obsidian/server.py +++ b/src/mcp_obsidian/server.py @@ -5,7 +5,8 @@ from functools import lru_cache from typing import Any import os from dotenv import load_dotenv -from mcp.server import Server +from mcp.server import Server as MCPServer +from contextlib import asynccontextmanager from mcp.types import ( Tool, TextContent, @@ -15,6 +16,7 @@ from mcp.types import ( load_dotenv() +from . import obsidian from . import tools # Load environment variables @@ -23,11 +25,17 @@ from . import tools logging.basicConfig(level=logging.INFO) logger = logging.getLogger("mcp-obsidian") -api_key = os.getenv("OBSIDIAN_API_KEY") -if not api_key: - raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}") +@asynccontextmanager +async def lifespan(app: MCPServer): + api_key = os.getenv("OBSIDIAN_API_KEY") + if not api_key: + raise ValueError(f"OBSIDSIAN_API_KEY environment variable required. Working directory: {os.getcwd()}") + yield -app = Server("mcp-obsidian") +app = MCPServer( + "mcp-obsidian", + lifespan=lifespan, +) tool_handlers = {} def add_tool_handler(tool_class: tools.ToolHandler): diff --git a/src/mcp_obsidian/tools.py b/src/mcp_obsidian/tools.py index e49c061..4d4a69c 100644 --- a/src/mcp_obsidian/tools.py +++ b/src/mcp_obsidian/tools.py @@ -9,19 +9,6 @@ import json import os from . import obsidian -# Load environment variables -api_key = os.getenv("OBSIDIAN_API_KEY", "") -obsidian_host = os.getenv("OBSIDIAN_HOST", "https://127.0.0.1:27124") - -if api_key == "": - raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}") - -# Parse the OBSIDIAN_HOST configuration at module level for validation -try: - protocol, host, port = obsidian.Obsidian.parse_host_config(obsidian_host) -except ValueError as e: - raise ValueError(f"Invalid OBSIDIAN_HOST configuration: {str(e)}") - def create_obsidian_api() -> obsidian.Obsidian: """Factory function to create Obsidian API instances. @@ -34,12 +21,26 @@ def create_obsidian_api() -> obsidian.Obsidian: Raises: Exception: If configuration is invalid or instance creation fails """ + # Load environment variables + api_key = os.getenv("OBSIDIAN_API_KEY", "") + obsidian_host = os.getenv("OBSIDIAN_HOST", "https://127.0.0.1:27124") + + if api_key == "": + raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}") + + # Parse the OBSIDIAN_HOST configuration at module level for validation + try: + protocol, host, port, path = obsidian.Obsidian.parse_host_config(obsidian_host) + except ValueError as e: + raise ValueError(f"Invalid OBSIDIAN_HOST configuration: {str(e)}") + try: return obsidian.Obsidian( api_key=api_key, protocol=protocol, host=host, port=port, + path=path, verify_ssl=False # Default to False for local development ) except Exception as e: diff --git a/tests/integration_test.py b/tests/integration_test.py new file mode 100644 index 0000000..1f269a6 --- /dev/null +++ b/tests/integration_test.py @@ -0,0 +1,40 @@ +import os +import sys +import unittest + +# Add the src directory to the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) + +from mcp_obsidian import obsidian + +class TestObsidianIntegration(unittest.TestCase): + def setUp(self): + """Set up the environment variables for the test.""" + os.environ['OBSIDIAN_API_KEY'] = 'REDACTED_API_KEY' + os.environ['OBSIDIAN_HOST'] = 'http://obsidian.obsidian.svc.cluster.local:27123' + + def test_connection(self): + """Test the connection to the Obsidian API.""" + try: + protocol, host, port, path = obsidian.Obsidian.parse_host_config(os.environ['OBSIDIAN_HOST']) + + api = obsidian.Obsidian( + api_key=os.environ['OBSIDIAN_API_KEY'], + protocol=protocol, + host=host, + port=port, + path=path, + verify_ssl=False + ) + + # Use a basic API call to verify the connection + files = api.list_files_in_vault() + + self.assertIsNotNone(files) + print("Successfully connected to Obsidian API and listed files.") + + except Exception as e: + self.fail(f"Failed to connect to Obsidian API: {e}") + +if __name__ == '__main__': + unittest.main()