feat: enhance Obsidian API configuration with path support and implement async lifespan for server

This commit is contained in:
2025-08-17 05:29:37 +00:00
parent cf48f23e8e
commit e5e1c1e11c
4 changed files with 78 additions and 22 deletions

View File

@@ -11,6 +11,7 @@ class Obsidian():
protocol: str = os.getenv('OBSIDIAN_PROTOCOL', 'https').lower(),
host: str = str(os.getenv('OBSIDIAN_HOST', '127.0.0.1')),
port: int = int(os.getenv('OBSIDIAN_PORT', '27124')),
path: str = '',
verify_ssl: bool = False,
):
self.api_key = api_key
@@ -22,6 +23,7 @@ class Obsidian():
self.host = host
self.port = port
self.path = path.rstrip('/')
self.verify_ssl = verify_ssl
self.timeout = (3, 6)
@@ -52,6 +54,7 @@ class Obsidian():
protocol = parsed.scheme
host = parsed.hostname
port = parsed.port
path = parsed.path
# Set default ports based on protocol if not specified
if port is None:
@@ -62,20 +65,21 @@ class Obsidian():
protocol=protocol,
host=host,
port=port,
path=path,
verify_ssl=verify_ssl
)
except Exception as e:
raise ValueError(f"Failed to parse OBSIDIAN_HOST URL '{url}': {str(e)}")
@staticmethod
def parse_host_config(host_config: str) -> Tuple[str, str, int]:
def parse_host_config(host_config: str) -> Tuple[str, str, int, str]:
"""Parse host configuration string.
Args:
host_config: Either a full URL (http://host:port) or just hostname/IP
Returns:
Tuple of (protocol, host, port)
Tuple of (protocol, host, port, path)
"""
if '://' in host_config:
# Full URL format
@@ -83,6 +87,7 @@ class Obsidian():
protocol = parsed.scheme or 'https'
host = parsed.hostname or '127.0.0.1'
port = parsed.port or (27124 if protocol == 'https' else 27123)
path = parsed.path
else:
# Support legacy formats
# 1) hostname/IP only
@@ -93,15 +98,17 @@ class Obsidian():
protocol = 'https'
host = parsed.hostname or '127.0.0.1'
port = parsed.port or 27124
path = ''
else:
protocol = 'https'
host = host_config
port = 27124
path = ''
return protocol, host, port
return protocol, host, port, path
def get_base_url(self) -> str:
return f'{self.protocol}://{self.host}:{self.port}'
return f'{self.protocol}://{self.host}:{self.port}{self.path}'
def _get_headers(self) -> dict:
headers = {

View File

@@ -5,7 +5,8 @@ from functools import lru_cache
from typing import Any
import os
from dotenv import load_dotenv
from mcp.server import Server
from mcp.server import Server as MCPServer
from contextlib import asynccontextmanager
from mcp.types import (
Tool,
TextContent,
@@ -15,6 +16,7 @@ from mcp.types import (
load_dotenv()
from . import obsidian
from . import tools
# Load environment variables
@@ -23,11 +25,17 @@ from . import tools
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mcp-obsidian")
@asynccontextmanager
async def lifespan(app: MCPServer):
api_key = os.getenv("OBSIDIAN_API_KEY")
if not api_key:
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
raise ValueError(f"OBSIDSIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
yield
app = Server("mcp-obsidian")
app = MCPServer(
"mcp-obsidian",
lifespan=lifespan,
)
tool_handlers = {}
def add_tool_handler(tool_class: tools.ToolHandler):

View File

@@ -9,19 +9,6 @@ import json
import os
from . import obsidian
# Load environment variables
api_key = os.getenv("OBSIDIAN_API_KEY", "")
obsidian_host = os.getenv("OBSIDIAN_HOST", "https://127.0.0.1:27124")
if api_key == "":
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
# Parse the OBSIDIAN_HOST configuration at module level for validation
try:
protocol, host, port = obsidian.Obsidian.parse_host_config(obsidian_host)
except ValueError as e:
raise ValueError(f"Invalid OBSIDIAN_HOST configuration: {str(e)}")
def create_obsidian_api() -> obsidian.Obsidian:
"""Factory function to create Obsidian API instances.
@@ -34,12 +21,26 @@ def create_obsidian_api() -> obsidian.Obsidian:
Raises:
Exception: If configuration is invalid or instance creation fails
"""
# Load environment variables
api_key = os.getenv("OBSIDIAN_API_KEY", "")
obsidian_host = os.getenv("OBSIDIAN_HOST", "https://127.0.0.1:27124")
if api_key == "":
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
# Parse the OBSIDIAN_HOST configuration at module level for validation
try:
protocol, host, port, path = obsidian.Obsidian.parse_host_config(obsidian_host)
except ValueError as e:
raise ValueError(f"Invalid OBSIDIAN_HOST configuration: {str(e)}")
try:
return obsidian.Obsidian(
api_key=api_key,
protocol=protocol,
host=host,
port=port,
path=path,
verify_ssl=False # Default to False for local development
)
except Exception as e:

40
tests/integration_test.py Normal file
View File

@@ -0,0 +1,40 @@
import os
import sys
import unittest
# Add the src directory to the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
from mcp_obsidian import obsidian
class TestObsidianIntegration(unittest.TestCase):
def setUp(self):
"""Set up the environment variables for the test."""
os.environ['OBSIDIAN_API_KEY'] = 'REDACTED_API_KEY'
os.environ['OBSIDIAN_HOST'] = 'http://obsidian.obsidian.svc.cluster.local:27123'
def test_connection(self):
"""Test the connection to the Obsidian API."""
try:
protocol, host, port, path = obsidian.Obsidian.parse_host_config(os.environ['OBSIDIAN_HOST'])
api = obsidian.Obsidian(
api_key=os.environ['OBSIDIAN_API_KEY'],
protocol=protocol,
host=host,
port=port,
path=path,
verify_ssl=False
)
# Use a basic API call to verify the connection
files = api.list_files_in_vault()
self.assertIsNotNone(files)
print("Successfully connected to Obsidian API and listed files.")
except Exception as e:
self.fail(f"Failed to connect to Obsidian API: {e}")
if __name__ == '__main__':
unittest.main()