feat: enhance Obsidian API configuration with path support and implement async lifespan for server

This commit is contained in:
2025-08-17 05:29:37 +00:00
parent cf48f23e8e
commit e5e1c1e11c
4 changed files with 78 additions and 22 deletions

View File

@@ -11,6 +11,7 @@ class Obsidian():
protocol: str = os.getenv('OBSIDIAN_PROTOCOL', 'https').lower(), protocol: str = os.getenv('OBSIDIAN_PROTOCOL', 'https').lower(),
host: str = str(os.getenv('OBSIDIAN_HOST', '127.0.0.1')), host: str = str(os.getenv('OBSIDIAN_HOST', '127.0.0.1')),
port: int = int(os.getenv('OBSIDIAN_PORT', '27124')), port: int = int(os.getenv('OBSIDIAN_PORT', '27124')),
path: str = '',
verify_ssl: bool = False, verify_ssl: bool = False,
): ):
self.api_key = api_key self.api_key = api_key
@@ -22,6 +23,7 @@ class Obsidian():
self.host = host self.host = host
self.port = port self.port = port
self.path = path.rstrip('/')
self.verify_ssl = verify_ssl self.verify_ssl = verify_ssl
self.timeout = (3, 6) self.timeout = (3, 6)
@@ -52,6 +54,7 @@ class Obsidian():
protocol = parsed.scheme protocol = parsed.scheme
host = parsed.hostname host = parsed.hostname
port = parsed.port port = parsed.port
path = parsed.path
# Set default ports based on protocol if not specified # Set default ports based on protocol if not specified
if port is None: if port is None:
@@ -62,20 +65,21 @@ class Obsidian():
protocol=protocol, protocol=protocol,
host=host, host=host,
port=port, port=port,
path=path,
verify_ssl=verify_ssl verify_ssl=verify_ssl
) )
except Exception as e: except Exception as e:
raise ValueError(f"Failed to parse OBSIDIAN_HOST URL '{url}': {str(e)}") raise ValueError(f"Failed to parse OBSIDIAN_HOST URL '{url}': {str(e)}")
@staticmethod @staticmethod
def parse_host_config(host_config: str) -> Tuple[str, str, int]: def parse_host_config(host_config: str) -> Tuple[str, str, int, str]:
"""Parse host configuration string. """Parse host configuration string.
Args: Args:
host_config: Either a full URL (http://host:port) or just hostname/IP host_config: Either a full URL (http://host:port) or just hostname/IP
Returns: Returns:
Tuple of (protocol, host, port) Tuple of (protocol, host, port, path)
""" """
if '://' in host_config: if '://' in host_config:
# Full URL format # Full URL format
@@ -83,6 +87,7 @@ class Obsidian():
protocol = parsed.scheme or 'https' protocol = parsed.scheme or 'https'
host = parsed.hostname or '127.0.0.1' host = parsed.hostname or '127.0.0.1'
port = parsed.port or (27124 if protocol == 'https' else 27123) port = parsed.port or (27124 if protocol == 'https' else 27123)
path = parsed.path
else: else:
# Support legacy formats # Support legacy formats
# 1) hostname/IP only # 1) hostname/IP only
@@ -93,15 +98,17 @@ class Obsidian():
protocol = 'https' protocol = 'https'
host = parsed.hostname or '127.0.0.1' host = parsed.hostname or '127.0.0.1'
port = parsed.port or 27124 port = parsed.port or 27124
path = ''
else: else:
protocol = 'https' protocol = 'https'
host = host_config host = host_config
port = 27124 port = 27124
path = ''
return protocol, host, port return protocol, host, port, path
def get_base_url(self) -> str: def get_base_url(self) -> str:
return f'{self.protocol}://{self.host}:{self.port}' return f'{self.protocol}://{self.host}:{self.port}{self.path}'
def _get_headers(self) -> dict: def _get_headers(self) -> dict:
headers = { headers = {

View File

@@ -5,7 +5,8 @@ from functools import lru_cache
from typing import Any from typing import Any
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
from mcp.server import Server from mcp.server import Server as MCPServer
from contextlib import asynccontextmanager
from mcp.types import ( from mcp.types import (
Tool, Tool,
TextContent, TextContent,
@@ -15,6 +16,7 @@ from mcp.types import (
load_dotenv() load_dotenv()
from . import obsidian
from . import tools from . import tools
# Load environment variables # Load environment variables
@@ -23,11 +25,17 @@ from . import tools
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mcp-obsidian") logger = logging.getLogger("mcp-obsidian")
api_key = os.getenv("OBSIDIAN_API_KEY") @asynccontextmanager
if not api_key: async def lifespan(app: MCPServer):
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}") api_key = os.getenv("OBSIDIAN_API_KEY")
if not api_key:
raise ValueError(f"OBSIDSIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
yield
app = Server("mcp-obsidian") app = MCPServer(
"mcp-obsidian",
lifespan=lifespan,
)
tool_handlers = {} tool_handlers = {}
def add_tool_handler(tool_class: tools.ToolHandler): def add_tool_handler(tool_class: tools.ToolHandler):

View File

@@ -9,19 +9,6 @@ import json
import os import os
from . import obsidian from . import obsidian
# Load environment variables
api_key = os.getenv("OBSIDIAN_API_KEY", "")
obsidian_host = os.getenv("OBSIDIAN_HOST", "https://127.0.0.1:27124")
if api_key == "":
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
# Parse the OBSIDIAN_HOST configuration at module level for validation
try:
protocol, host, port = obsidian.Obsidian.parse_host_config(obsidian_host)
except ValueError as e:
raise ValueError(f"Invalid OBSIDIAN_HOST configuration: {str(e)}")
def create_obsidian_api() -> obsidian.Obsidian: def create_obsidian_api() -> obsidian.Obsidian:
"""Factory function to create Obsidian API instances. """Factory function to create Obsidian API instances.
@@ -34,12 +21,26 @@ def create_obsidian_api() -> obsidian.Obsidian:
Raises: Raises:
Exception: If configuration is invalid or instance creation fails Exception: If configuration is invalid or instance creation fails
""" """
# Load environment variables
api_key = os.getenv("OBSIDIAN_API_KEY", "")
obsidian_host = os.getenv("OBSIDIAN_HOST", "https://127.0.0.1:27124")
if api_key == "":
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
# Parse the OBSIDIAN_HOST configuration at module level for validation
try:
protocol, host, port, path = obsidian.Obsidian.parse_host_config(obsidian_host)
except ValueError as e:
raise ValueError(f"Invalid OBSIDIAN_HOST configuration: {str(e)}")
try: try:
return obsidian.Obsidian( return obsidian.Obsidian(
api_key=api_key, api_key=api_key,
protocol=protocol, protocol=protocol,
host=host, host=host,
port=port, port=port,
path=path,
verify_ssl=False # Default to False for local development verify_ssl=False # Default to False for local development
) )
except Exception as e: except Exception as e:

40
tests/integration_test.py Normal file
View File

@@ -0,0 +1,40 @@
import os
import sys
import unittest
# Add the src directory to the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
from mcp_obsidian import obsidian
class TestObsidianIntegration(unittest.TestCase):
def setUp(self):
"""Set up the environment variables for the test."""
os.environ['OBSIDIAN_API_KEY'] = 'REDACTED_API_KEY'
os.environ['OBSIDIAN_HOST'] = 'http://obsidian.obsidian.svc.cluster.local:27123'
def test_connection(self):
"""Test the connection to the Obsidian API."""
try:
protocol, host, port, path = obsidian.Obsidian.parse_host_config(os.environ['OBSIDIAN_HOST'])
api = obsidian.Obsidian(
api_key=os.environ['OBSIDIAN_API_KEY'],
protocol=protocol,
host=host,
port=port,
path=path,
verify_ssl=False
)
# Use a basic API call to verify the connection
files = api.list_files_in_vault()
self.assertIsNotNone(files)
print("Successfully connected to Obsidian API and listed files.")
except Exception as e:
self.fail(f"Failed to connect to Obsidian API: {e}")
if __name__ == '__main__':
unittest.main()