feat: enhance Obsidian API configuration with path support and implement async lifespan for server
This commit is contained in:
@@ -11,6 +11,7 @@ class Obsidian():
|
|||||||
protocol: str = os.getenv('OBSIDIAN_PROTOCOL', 'https').lower(),
|
protocol: str = os.getenv('OBSIDIAN_PROTOCOL', 'https').lower(),
|
||||||
host: str = str(os.getenv('OBSIDIAN_HOST', '127.0.0.1')),
|
host: str = str(os.getenv('OBSIDIAN_HOST', '127.0.0.1')),
|
||||||
port: int = int(os.getenv('OBSIDIAN_PORT', '27124')),
|
port: int = int(os.getenv('OBSIDIAN_PORT', '27124')),
|
||||||
|
path: str = '',
|
||||||
verify_ssl: bool = False,
|
verify_ssl: bool = False,
|
||||||
):
|
):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
@@ -22,6 +23,7 @@ class Obsidian():
|
|||||||
|
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
self.path = path.rstrip('/')
|
||||||
self.verify_ssl = verify_ssl
|
self.verify_ssl = verify_ssl
|
||||||
self.timeout = (3, 6)
|
self.timeout = (3, 6)
|
||||||
|
|
||||||
@@ -52,6 +54,7 @@ class Obsidian():
|
|||||||
protocol = parsed.scheme
|
protocol = parsed.scheme
|
||||||
host = parsed.hostname
|
host = parsed.hostname
|
||||||
port = parsed.port
|
port = parsed.port
|
||||||
|
path = parsed.path
|
||||||
|
|
||||||
# Set default ports based on protocol if not specified
|
# Set default ports based on protocol if not specified
|
||||||
if port is None:
|
if port is None:
|
||||||
@@ -62,20 +65,21 @@ class Obsidian():
|
|||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
|
path=path,
|
||||||
verify_ssl=verify_ssl
|
verify_ssl=verify_ssl
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Failed to parse OBSIDIAN_HOST URL '{url}': {str(e)}")
|
raise ValueError(f"Failed to parse OBSIDIAN_HOST URL '{url}': {str(e)}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_host_config(host_config: str) -> Tuple[str, str, int]:
|
def parse_host_config(host_config: str) -> Tuple[str, str, int, str]:
|
||||||
"""Parse host configuration string.
|
"""Parse host configuration string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
host_config: Either a full URL (http://host:port) or just hostname/IP
|
host_config: Either a full URL (http://host:port) or just hostname/IP
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (protocol, host, port)
|
Tuple of (protocol, host, port, path)
|
||||||
"""
|
"""
|
||||||
if '://' in host_config:
|
if '://' in host_config:
|
||||||
# Full URL format
|
# Full URL format
|
||||||
@@ -83,6 +87,7 @@ class Obsidian():
|
|||||||
protocol = parsed.scheme or 'https'
|
protocol = parsed.scheme or 'https'
|
||||||
host = parsed.hostname or '127.0.0.1'
|
host = parsed.hostname or '127.0.0.1'
|
||||||
port = parsed.port or (27124 if protocol == 'https' else 27123)
|
port = parsed.port or (27124 if protocol == 'https' else 27123)
|
||||||
|
path = parsed.path
|
||||||
else:
|
else:
|
||||||
# Support legacy formats
|
# Support legacy formats
|
||||||
# 1) hostname/IP only
|
# 1) hostname/IP only
|
||||||
@@ -93,15 +98,17 @@ class Obsidian():
|
|||||||
protocol = 'https'
|
protocol = 'https'
|
||||||
host = parsed.hostname or '127.0.0.1'
|
host = parsed.hostname or '127.0.0.1'
|
||||||
port = parsed.port or 27124
|
port = parsed.port or 27124
|
||||||
|
path = ''
|
||||||
else:
|
else:
|
||||||
protocol = 'https'
|
protocol = 'https'
|
||||||
host = host_config
|
host = host_config
|
||||||
port = 27124
|
port = 27124
|
||||||
|
path = ''
|
||||||
|
|
||||||
return protocol, host, port
|
return protocol, host, port, path
|
||||||
|
|
||||||
def get_base_url(self) -> str:
|
def get_base_url(self) -> str:
|
||||||
return f'{self.protocol}://{self.host}:{self.port}'
|
return f'{self.protocol}://{self.host}:{self.port}{self.path}'
|
||||||
|
|
||||||
def _get_headers(self) -> dict:
|
def _get_headers(self) -> dict:
|
||||||
headers = {
|
headers = {
|
||||||
|
@@ -5,7 +5,8 @@ from functools import lru_cache
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from mcp.server import Server
|
from mcp.server import Server as MCPServer
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
Tool,
|
Tool,
|
||||||
TextContent,
|
TextContent,
|
||||||
@@ -15,6 +16,7 @@ from mcp.types import (
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
from . import obsidian
|
||||||
from . import tools
|
from . import tools
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
@@ -23,11 +25,17 @@ from . import tools
|
|||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger("mcp-obsidian")
|
logger = logging.getLogger("mcp-obsidian")
|
||||||
|
|
||||||
api_key = os.getenv("OBSIDIAN_API_KEY")
|
@asynccontextmanager
|
||||||
if not api_key:
|
async def lifespan(app: MCPServer):
|
||||||
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
|
api_key = os.getenv("OBSIDIAN_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError(f"OBSIDSIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
|
||||||
|
yield
|
||||||
|
|
||||||
app = Server("mcp-obsidian")
|
app = MCPServer(
|
||||||
|
"mcp-obsidian",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
tool_handlers = {}
|
tool_handlers = {}
|
||||||
def add_tool_handler(tool_class: tools.ToolHandler):
|
def add_tool_handler(tool_class: tools.ToolHandler):
|
||||||
|
@@ -9,19 +9,6 @@ import json
|
|||||||
import os
|
import os
|
||||||
from . import obsidian
|
from . import obsidian
|
||||||
|
|
||||||
# Load environment variables
|
|
||||||
api_key = os.getenv("OBSIDIAN_API_KEY", "")
|
|
||||||
obsidian_host = os.getenv("OBSIDIAN_HOST", "https://127.0.0.1:27124")
|
|
||||||
|
|
||||||
if api_key == "":
|
|
||||||
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
|
|
||||||
|
|
||||||
# Parse the OBSIDIAN_HOST configuration at module level for validation
|
|
||||||
try:
|
|
||||||
protocol, host, port = obsidian.Obsidian.parse_host_config(obsidian_host)
|
|
||||||
except ValueError as e:
|
|
||||||
raise ValueError(f"Invalid OBSIDIAN_HOST configuration: {str(e)}")
|
|
||||||
|
|
||||||
def create_obsidian_api() -> obsidian.Obsidian:
|
def create_obsidian_api() -> obsidian.Obsidian:
|
||||||
"""Factory function to create Obsidian API instances.
|
"""Factory function to create Obsidian API instances.
|
||||||
|
|
||||||
@@ -34,12 +21,26 @@ def create_obsidian_api() -> obsidian.Obsidian:
|
|||||||
Raises:
|
Raises:
|
||||||
Exception: If configuration is invalid or instance creation fails
|
Exception: If configuration is invalid or instance creation fails
|
||||||
"""
|
"""
|
||||||
|
# Load environment variables
|
||||||
|
api_key = os.getenv("OBSIDIAN_API_KEY", "")
|
||||||
|
obsidian_host = os.getenv("OBSIDIAN_HOST", "https://127.0.0.1:27124")
|
||||||
|
|
||||||
|
if api_key == "":
|
||||||
|
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
|
||||||
|
|
||||||
|
# Parse the OBSIDIAN_HOST configuration at module level for validation
|
||||||
|
try:
|
||||||
|
protocol, host, port, path = obsidian.Obsidian.parse_host_config(obsidian_host)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"Invalid OBSIDIAN_HOST configuration: {str(e)}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return obsidian.Obsidian(
|
return obsidian.Obsidian(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
|
path=path,
|
||||||
verify_ssl=False # Default to False for local development
|
verify_ssl=False # Default to False for local development
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
40
tests/integration_test.py
Normal file
40
tests/integration_test.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
# Add the src directory to the Python path
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
|
||||||
|
|
||||||
|
from mcp_obsidian import obsidian
|
||||||
|
|
||||||
|
class TestObsidianIntegration(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up the environment variables for the test."""
|
||||||
|
os.environ['OBSIDIAN_API_KEY'] = 'REDACTED_API_KEY'
|
||||||
|
os.environ['OBSIDIAN_HOST'] = 'http://obsidian.obsidian.svc.cluster.local:27123'
|
||||||
|
|
||||||
|
def test_connection(self):
|
||||||
|
"""Test the connection to the Obsidian API."""
|
||||||
|
try:
|
||||||
|
protocol, host, port, path = obsidian.Obsidian.parse_host_config(os.environ['OBSIDIAN_HOST'])
|
||||||
|
|
||||||
|
api = obsidian.Obsidian(
|
||||||
|
api_key=os.environ['OBSIDIAN_API_KEY'],
|
||||||
|
protocol=protocol,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
path=path,
|
||||||
|
verify_ssl=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use a basic API call to verify the connection
|
||||||
|
files = api.list_files_in_vault()
|
||||||
|
|
||||||
|
self.assertIsNotNone(files)
|
||||||
|
print("Successfully connected to Obsidian API and listed files.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.fail(f"Failed to connect to Obsidian API: {e}")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Reference in New Issue
Block a user