feat: enhance Obsidian API configuration with path support and implement async lifespan for server
This commit is contained in:
@@ -11,6 +11,7 @@ class Obsidian():
|
||||
protocol: str = os.getenv('OBSIDIAN_PROTOCOL', 'https').lower(),
|
||||
host: str = str(os.getenv('OBSIDIAN_HOST', '127.0.0.1')),
|
||||
port: int = int(os.getenv('OBSIDIAN_PORT', '27124')),
|
||||
path: str = '',
|
||||
verify_ssl: bool = False,
|
||||
):
|
||||
self.api_key = api_key
|
||||
@@ -22,6 +23,7 @@ class Obsidian():
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.path = path.rstrip('/')
|
||||
self.verify_ssl = verify_ssl
|
||||
self.timeout = (3, 6)
|
||||
|
||||
@@ -52,6 +54,7 @@ class Obsidian():
|
||||
protocol = parsed.scheme
|
||||
host = parsed.hostname
|
||||
port = parsed.port
|
||||
path = parsed.path
|
||||
|
||||
# Set default ports based on protocol if not specified
|
||||
if port is None:
|
||||
@@ -62,20 +65,21 @@ class Obsidian():
|
||||
protocol=protocol,
|
||||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
verify_ssl=verify_ssl
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse OBSIDIAN_HOST URL '{url}': {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def parse_host_config(host_config: str) -> Tuple[str, str, int]:
|
||||
def parse_host_config(host_config: str) -> Tuple[str, str, int, str]:
|
||||
"""Parse host configuration string.
|
||||
|
||||
Args:
|
||||
host_config: Either a full URL (http://host:port) or just hostname/IP
|
||||
|
||||
Returns:
|
||||
Tuple of (protocol, host, port)
|
||||
Tuple of (protocol, host, port, path)
|
||||
"""
|
||||
if '://' in host_config:
|
||||
# Full URL format
|
||||
@@ -83,6 +87,7 @@ class Obsidian():
|
||||
protocol = parsed.scheme or 'https'
|
||||
host = parsed.hostname or '127.0.0.1'
|
||||
port = parsed.port or (27124 if protocol == 'https' else 27123)
|
||||
path = parsed.path
|
||||
else:
|
||||
# Support legacy formats
|
||||
# 1) hostname/IP only
|
||||
@@ -93,15 +98,17 @@ class Obsidian():
|
||||
protocol = 'https'
|
||||
host = parsed.hostname or '127.0.0.1'
|
||||
port = parsed.port or 27124
|
||||
path = ''
|
||||
else:
|
||||
protocol = 'https'
|
||||
host = host_config
|
||||
port = 27124
|
||||
path = ''
|
||||
|
||||
return protocol, host, port
|
||||
return protocol, host, port, path
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f'{self.protocol}://{self.host}:{self.port}'
|
||||
return f'{self.protocol}://{self.host}:{self.port}{self.path}'
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
headers = {
|
||||
|
@@ -5,7 +5,8 @@ from functools import lru_cache
|
||||
from typing import Any
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from mcp.server import Server
|
||||
from mcp.server import Server as MCPServer
|
||||
from contextlib import asynccontextmanager
|
||||
from mcp.types import (
|
||||
Tool,
|
||||
TextContent,
|
||||
@@ -15,6 +16,7 @@ from mcp.types import (
|
||||
|
||||
load_dotenv()
|
||||
|
||||
from . import obsidian
|
||||
from . import tools
|
||||
|
||||
# Load environment variables
|
||||
@@ -23,11 +25,17 @@ from . import tools
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("mcp-obsidian")
|
||||
|
||||
api_key = os.getenv("OBSIDIAN_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: MCPServer):
|
||||
api_key = os.getenv("OBSIDIAN_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(f"OBSIDSIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
|
||||
yield
|
||||
|
||||
app = Server("mcp-obsidian")
|
||||
app = MCPServer(
|
||||
"mcp-obsidian",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
tool_handlers = {}
|
||||
def add_tool_handler(tool_class: tools.ToolHandler):
|
||||
|
@@ -9,19 +9,6 @@ import json
|
||||
import os
|
||||
from . import obsidian
|
||||
|
||||
# Load environment variables
|
||||
api_key = os.getenv("OBSIDIAN_API_KEY", "")
|
||||
obsidian_host = os.getenv("OBSIDIAN_HOST", "https://127.0.0.1:27124")
|
||||
|
||||
if api_key == "":
|
||||
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
|
||||
|
||||
# Parse the OBSIDIAN_HOST configuration at module level for validation
|
||||
try:
|
||||
protocol, host, port = obsidian.Obsidian.parse_host_config(obsidian_host)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid OBSIDIAN_HOST configuration: {str(e)}")
|
||||
|
||||
def create_obsidian_api() -> obsidian.Obsidian:
|
||||
"""Factory function to create Obsidian API instances.
|
||||
|
||||
@@ -34,12 +21,26 @@ def create_obsidian_api() -> obsidian.Obsidian:
|
||||
Raises:
|
||||
Exception: If configuration is invalid or instance creation fails
|
||||
"""
|
||||
# Load environment variables
|
||||
api_key = os.getenv("OBSIDIAN_API_KEY", "")
|
||||
obsidian_host = os.getenv("OBSIDIAN_HOST", "https://127.0.0.1:27124")
|
||||
|
||||
if api_key == "":
|
||||
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
|
||||
|
||||
# Parse the OBSIDIAN_HOST configuration at module level for validation
|
||||
try:
|
||||
protocol, host, port, path = obsidian.Obsidian.parse_host_config(obsidian_host)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid OBSIDIAN_HOST configuration: {str(e)}")
|
||||
|
||||
try:
|
||||
return obsidian.Obsidian(
|
||||
api_key=api_key,
|
||||
protocol=protocol,
|
||||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
verify_ssl=False # Default to False for local development
|
||||
)
|
||||
except Exception as e:
|
||||
|
40
tests/integration_test.py
Normal file
40
tests/integration_test.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
# Add the src directory to the Python path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
|
||||
|
||||
from mcp_obsidian import obsidian
|
||||
|
||||
class TestObsidianIntegration(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Set up the environment variables for the test."""
|
||||
os.environ['OBSIDIAN_API_KEY'] = 'REDACTED_API_KEY'
|
||||
os.environ['OBSIDIAN_HOST'] = 'http://obsidian.obsidian.svc.cluster.local:27123'
|
||||
|
||||
def test_connection(self):
|
||||
"""Test the connection to the Obsidian API."""
|
||||
try:
|
||||
protocol, host, port, path = obsidian.Obsidian.parse_host_config(os.environ['OBSIDIAN_HOST'])
|
||||
|
||||
api = obsidian.Obsidian(
|
||||
api_key=os.environ['OBSIDIAN_API_KEY'],
|
||||
protocol=protocol,
|
||||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
verify_ssl=False
|
||||
)
|
||||
|
||||
# Use a basic API call to verify the connection
|
||||
files = api.list_files_in_vault()
|
||||
|
||||
self.assertIsNotNone(files)
|
||||
print("Successfully connected to Obsidian API and listed files.")
|
||||
|
||||
except Exception as e:
|
||||
self.fail(f"Failed to connect to Obsidian API: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Reference in New Issue
Block a user