Compare commits
5 Commits
feat/suppo
...
main
Author | SHA1 | Date | |
---|---|---|---|
e5e1c1e11c | |||
cf48f23e8e | |||
0b746f65e9 | |||
897822ecaa | |||
00245ef40b |
@@ -11,6 +11,7 @@ class Obsidian():
|
||||
protocol: str = os.getenv('OBSIDIAN_PROTOCOL', 'https').lower(),
|
||||
host: str = str(os.getenv('OBSIDIAN_HOST', '127.0.0.1')),
|
||||
port: int = int(os.getenv('OBSIDIAN_PORT', '27124')),
|
||||
path: str = '',
|
||||
verify_ssl: bool = False,
|
||||
):
|
||||
self.api_key = api_key
|
||||
@@ -22,6 +23,7 @@ class Obsidian():
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.path = path.rstrip('/')
|
||||
self.verify_ssl = verify_ssl
|
||||
self.timeout = (3, 6)
|
||||
|
||||
@@ -52,6 +54,7 @@ class Obsidian():
|
||||
protocol = parsed.scheme
|
||||
host = parsed.hostname
|
||||
port = parsed.port
|
||||
path = parsed.path
|
||||
|
||||
# Set default ports based on protocol if not specified
|
||||
if port is None:
|
||||
@@ -62,20 +65,21 @@ class Obsidian():
|
||||
protocol=protocol,
|
||||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
verify_ssl=verify_ssl
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse OBSIDIAN_HOST URL '{url}': {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def parse_host_config(host_config: str) -> Tuple[str, str, int]:
|
||||
def parse_host_config(host_config: str) -> Tuple[str, str, int, str]:
|
||||
"""Parse host configuration string.
|
||||
|
||||
Args:
|
||||
host_config: Either a full URL (http://host:port) or just hostname/IP
|
||||
|
||||
Returns:
|
||||
Tuple of (protocol, host, port)
|
||||
Tuple of (protocol, host, port, path)
|
||||
"""
|
||||
if '://' in host_config:
|
||||
# Full URL format
|
||||
@@ -83,16 +87,28 @@ class Obsidian():
|
||||
protocol = parsed.scheme or 'https'
|
||||
host = parsed.hostname or '127.0.0.1'
|
||||
port = parsed.port or (27124 if protocol == 'https' else 27123)
|
||||
path = parsed.path
|
||||
else:
|
||||
# Legacy hostname/IP only format
|
||||
protocol = 'https'
|
||||
host = host_config
|
||||
port = 27124
|
||||
# Support legacy formats
|
||||
# 1) hostname/IP only
|
||||
# 2) hostname:port (no protocol)
|
||||
if ':' in host_config:
|
||||
# Treat as host:port and default protocol to https
|
||||
parsed = urlparse(f'https://{host_config}')
|
||||
protocol = 'https'
|
||||
host = parsed.hostname or '127.0.0.1'
|
||||
port = parsed.port or 27124
|
||||
path = ''
|
||||
else:
|
||||
protocol = 'https'
|
||||
host = host_config
|
||||
port = 27124
|
||||
path = ''
|
||||
|
||||
return protocol, host, port
|
||||
return protocol, host, port, path
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f'{self.protocol}://{self.host}:{self.port}'
|
||||
return f'{self.protocol}://{self.host}:{self.port}{self.path}'
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
headers = {
|
||||
@@ -297,7 +313,7 @@ class Obsidian():
|
||||
Returns:
|
||||
List of recent periodic notes
|
||||
"""
|
||||
url = f"{self.get_base_url()}/periodic/{period}/recent"
|
||||
url = f"{self.get_base_url()}/periodic/{period}/recent/"
|
||||
params = {
|
||||
"limit": limit,
|
||||
"includeContent": include_content
|
||||
|
@@ -5,7 +5,8 @@ from functools import lru_cache
|
||||
from typing import Any
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from mcp.server import Server
|
||||
from mcp.server import Server as MCPServer
|
||||
from contextlib import asynccontextmanager
|
||||
from mcp.types import (
|
||||
Tool,
|
||||
TextContent,
|
||||
@@ -15,6 +16,7 @@ from mcp.types import (
|
||||
|
||||
load_dotenv()
|
||||
|
||||
from . import obsidian
|
||||
from . import tools
|
||||
|
||||
# Load environment variables
|
||||
@@ -23,11 +25,17 @@ from . import tools
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("mcp-obsidian")
|
||||
|
||||
api_key = os.getenv("OBSIDIAN_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: MCPServer):
|
||||
api_key = os.getenv("OBSIDIAN_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(f"OBSIDSIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
|
||||
yield
|
||||
|
||||
app = Server("mcp-obsidian")
|
||||
app = MCPServer(
|
||||
"mcp-obsidian",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
tool_handlers = {}
|
||||
def add_tool_handler(tool_class: tools.ToolHandler):
|
||||
|
@@ -9,19 +9,6 @@ import json
|
||||
import os
|
||||
from . import obsidian
|
||||
|
||||
# Load environment variables
|
||||
api_key = os.getenv("OBSIDIAN_API_KEY", "")
|
||||
obsidian_host = os.getenv("OBSIDIAN_HOST", "https://127.0.0.1:27124")
|
||||
|
||||
if api_key == "":
|
||||
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
|
||||
|
||||
# Parse the OBSIDIAN_HOST configuration at module level for validation
|
||||
try:
|
||||
protocol, host, port = obsidian.Obsidian.parse_host_config(obsidian_host)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid OBSIDIAN_HOST configuration: {str(e)}")
|
||||
|
||||
def create_obsidian_api() -> obsidian.Obsidian:
|
||||
"""Factory function to create Obsidian API instances.
|
||||
|
||||
@@ -34,12 +21,26 @@ def create_obsidian_api() -> obsidian.Obsidian:
|
||||
Raises:
|
||||
Exception: If configuration is invalid or instance creation fails
|
||||
"""
|
||||
# Load environment variables
|
||||
api_key = os.getenv("OBSIDIAN_API_KEY", "")
|
||||
obsidian_host = os.getenv("OBSIDIAN_HOST", "https://127.0.0.1:27124")
|
||||
|
||||
if api_key == "":
|
||||
raise ValueError(f"OBSIDIAN_API_KEY environment variable required. Working directory: {os.getcwd()}")
|
||||
|
||||
# Parse the OBSIDIAN_HOST configuration at module level for validation
|
||||
try:
|
||||
protocol, host, port, path = obsidian.Obsidian.parse_host_config(obsidian_host)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid OBSIDIAN_HOST configuration: {str(e)}")
|
||||
|
||||
try:
|
||||
return obsidian.Obsidian(
|
||||
api_key=api_key,
|
||||
protocol=protocol,
|
||||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
verify_ssl=False # Default to False for local development
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -343,7 +344,7 @@ class PutContentToolHandler(ToolHandler):
|
||||
if "filepath" not in args or "content" not in args:
|
||||
raise RuntimeError("filepath and content arguments required")
|
||||
|
||||
api = obsidian.Obsidian(api_key=api_key, host=obsidian_host)
|
||||
api = create_obsidian_api()
|
||||
api.put_content(args.get("filepath", ""), args["content"])
|
||||
|
||||
return [
|
||||
@@ -541,13 +542,8 @@ class PeriodicNotesToolHandler(ToolHandler):
|
||||
if type not in valid_types:
|
||||
raise RuntimeError(f"Invalid type: {type}. Must be one of: {', '.join(valid_types)}")
|
||||
|
||||
<<<<<<< ours
|
||||
api = create_obsidian_api()
|
||||
content = api.get_periodic_note(period)
|
||||
=======
|
||||
api = create_obsidian_api()
|
||||
content = api.get_periodic_note(period)
|
||||
>>>>>>> theirs
|
||||
|
||||
return [
|
||||
TextContent(
|
||||
|
40
tests/integration_test.py
Normal file
40
tests/integration_test.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
# Add the src directory to the Python path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
|
||||
|
||||
from mcp_obsidian import obsidian
|
||||
|
||||
class TestObsidianIntegration(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Set up the environment variables for the test."""
|
||||
os.environ['OBSIDIAN_API_KEY'] = 'REDACTED_API_KEY'
|
||||
os.environ['OBSIDIAN_HOST'] = 'http://obsidian.obsidian.svc.cluster.local:27123'
|
||||
|
||||
def test_connection(self):
|
||||
"""Test the connection to the Obsidian API."""
|
||||
try:
|
||||
protocol, host, port, path = obsidian.Obsidian.parse_host_config(os.environ['OBSIDIAN_HOST'])
|
||||
|
||||
api = obsidian.Obsidian(
|
||||
api_key=os.environ['OBSIDIAN_API_KEY'],
|
||||
protocol=protocol,
|
||||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
verify_ssl=False
|
||||
)
|
||||
|
||||
# Use a basic API call to verify the connection
|
||||
files = api.list_files_in_vault()
|
||||
|
||||
self.assertIsNotNone(files)
|
||||
print("Successfully connected to Obsidian API and listed files.")
|
||||
|
||||
except Exception as e:
|
||||
self.fail(f"Failed to connect to Obsidian API: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Reference in New Issue
Block a user