Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Feature/mira network integration #2

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
OPENAI_API_KEY=
GOOSEAI_API_KEY=
PERPLEXITYAI_API_KEY=
MIRANETWORK_API_KEY=

RETRY_PARAMS={"tries": 3, "delay": 3, "backoff": 2, "max_delay": 10, "jitter": [0.5, 1.5]}
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def create_app() -> Quart:
CSSProvider.GOOSEAI: os.getenv("GOOSEAI_API_KEY"),
CSSProvider.OPENAI: os.getenv("OPENAI_API_KEY"),
CSSProvider.PERPLEXITYAI: os.getenv("PERPLEXITYAI_API_KEY"),
CSSProvider.MIRANETWORK: os.getenv("MIRANETWORK_API_KEY"),
}
providers = sorted(api_keys.keys())

Expand Down
9 changes: 5 additions & 4 deletions infernet_services/tests/css_inference_service/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
log = logging.getLogger(__name__)

env_vars = {
"PERPLEXITYAI_API_KEY": os.environ["PERPLEXITYAI_API_KEY"],
"GOOSEAI_API_KEY": os.environ["GOOSEAI_API_KEY"],
"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
"PERPLEXITYAI_API_KEY": os.environ.get("PERPLEXITYAI_API_KEY", "dummy_key"),
"GOOSEAI_API_KEY": os.environ.get("GOOSEAI_API_KEY", "dummy_key"),
"OPENAI_API_KEY": os.environ.get("OPENAI_API_KEY", "dummy_key"),
"MIRANETWORK_API_KEY": os.environ.get("MIRANETWORK_API_KEY"),
"CSS_INF_WORKFLOW_POSITIONAL_ARGS": "[]",
"CSS_INF_WORKFLOW_KW_ARGS": json.dumps(
{
Expand Down Expand Up @@ -51,7 +52,7 @@
name=CSS_OPENAI_ONLY,
image_id=f"ritualnetwork/{SERVICE_NAME}:{SERVICE_VERSION}",
env_vars={
"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
"OPENAI_API_KEY": os.environ.get("OPENAI_API_KEY", "dummy_key"),
"CSS_INF_WORKFLOW_POSITIONAL_ARGS": "[]",
"CSS_INF_WORKFLOW_KW_ARGS": json.dumps(
{
Expand Down
65 changes: 65 additions & 0 deletions infernet_services/tests/css_inference_service/simple_mira_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Simple test for Mira Network integration.
This script tests the Mira Network integration by making direct API calls.
"""

import unittest
import os
import sys
from pathlib import Path
import json
import requests

# Make the project root importable so sibling packages resolve when this
# file is executed directly rather than via pytest from the repo root.
project_root = str(Path(__file__).parent.parent.parent.parent)
if project_root not in sys.path:
    sys.path.append(project_root)

# Read the API key from the environment (the same variable conftest.py
# consumes) instead of hard-coding an empty string, which would make every
# live request fail with 401 Unauthorized.
MIRANETWORK_API_KEY = os.environ.get("MIRANETWORK_API_KEY", "")

class TestMiraNetworkAPI(unittest.TestCase):
    """Integration test for the Mira Network chat-completions API.

    Hits the live endpoint directly (bypassing css_mux), so it is skipped
    when no API key is configured rather than failing with a 401.
    """

    def test_mira_network_api_direct(self):
        """POST a chat completion and validate the response envelope."""
        # Without a key the endpoint rejects the request — skip instead of
        # reporting a spurious failure.
        if not MIRANETWORK_API_KEY:
            self.skipTest("MIRANETWORK_API_KEY is not configured")

        # Define the API endpoint
        url = "https://api.mira.network/v1/chat/completions"

        # Define the headers
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {MIRANETWORK_API_KEY}",
        }

        # Define the payload (OpenAI-compatible chat-completions schema)
        payload = {
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": "Explain quantum computing in one paragraph.",
                }
            ],
            "temperature": 0.7,
            "max_tokens": 100,
        }

        # Make the API call; timeout prevents a hung endpoint from
        # stalling the whole test suite indefinitely.
        response = requests.post(url, headers=headers, json=payload, timeout=30)

        # Print the raw response so failures below are easy to diagnose
        print(f"Status code: {response.status_code}")
        print(f"Response: {response.text}")

        # Verify the response status and the chat-completion envelope
        self.assertEqual(response.status_code, 200)
        response_json = response.json()
        self.assertIn("choices", response_json)
        self.assertGreater(len(response_json["choices"]), 0)
        self.assertIn("message", response_json["choices"][0])
        self.assertIn("content", response_json["choices"][0]["message"])
        self.assertGreater(len(response_json["choices"][0]["message"]["content"]), 0)


if __name__ == "__main__":
    unittest.main()
2 changes: 1 addition & 1 deletion libraries/infernet_ml/src/infernet_ml/utils/codec/css.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class CSSProvider(IntEnum):
OPENAI = 0
GOOSEAI = 1
PERPLEXITYAI = 2

MIRANETWORK = 3

def encode_css_completion_request(
provider: CSSProvider,
Expand Down
7 changes: 7 additions & 0 deletions libraries/infernet_ml/src/infernet_ml/utils/css_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,11 @@
{"id": "GOOSEAI/gpt-neoNone-3b", "name": "GPT-Neo 1.3B", "parameters": "1.3B"},
{"id": "GOOSEAI/gpt-neo-2-7b", "name": "GPT-Neo 2.7B", "parameters": "2.7B"},
],
CSSProvider.MIRANETWORK: [
{"id": "MIRANETWORK/gpt-4o", "name": "GPT-4O", "parameters": None},
{"id": "MIRANETWORK/deepseek-r1", "name": "Deepseek R1", "parameters": None},
{"id": "MIRANETWORK/gpt-4o-mini", "name": "GPT-4O Mini", "parameters": None},
{"id": "MIRANETWORK/claude-3.5-sonnet", "name": "Claude 3.5 Sonnet", "parameters": None},
{"id": "MIRANETWORK/llama-3.3-70b-instruct", "name": "Llama 3.3 70B Instruct", "parameters": "70B"},
],
}
40 changes: 38 additions & 2 deletions libraries/infernet_ml/src/infernet_ml/utils/css_mux.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Library containing functions for accessing closed source models.

Currently, 3 APIs are supported: OPENAI, PERPLEXITYAI, and GOOSEAI.
Currently, 4 APIs are supported: OPENAI, PERPLEXITYAI, GOOSEAI, and MIRANETWORK.

"""

Expand Down Expand Up @@ -51,7 +51,7 @@ class CSSProvider(StrEnum):
OPENAI = "OPENAI"
PERPLEXITYAI = "PERPLEXITYAI"
GOOSEAI = "GOOSEAI"

MIRANETWORK = "MIRANETWORK"

ApiKeys = Dict[CSSProvider, Optional[str]]

Expand Down Expand Up @@ -174,6 +174,29 @@ def goose_ai_request_generator(req: CSSRequest) -> tuple[str, dict[str, Any]]:
raise InfernetMLException(f"Unsupported request {req}")


def mira_network_request_generator(req: CSSRequest) -> tuple[str, dict[str, Any]]:
"""Returns base url & json input for Miran Network API.

Args:
req: a CSSRequest object, containing provider, endpoint, model,
api keys & params.

Returns:
base_url: str
processed input: dict[str, Any]

Raises:
InfernetMLException: if an unsupported model or params specified.
"""
match req:
case CSSRequest(model=model_name, params=CSSCompletionParams(messages=msgs)):
return "https://api.mira.network/v1/", {
"model": model_name,
"messages": [msg.model_dump() for msg in msgs],
}
case _:
raise InfernetMLException(f"Unsupported request {req}")

def extract_completions(result: Dict[str, Any]) -> str:
    """Extract the assistant message text from a chat-completions response."""
    first_choice = result["choices"][0]
    message_content = first_choice["message"]["content"]
    return cast(str, message_content)

Expand Down Expand Up @@ -214,6 +237,15 @@ def extract_completions_gooseai(result: Dict[str, Any]) -> str:
}
},
},
CSSProvider.MIRANETWORK: {
"input_func": mira_network_request_generator,
"endpoints": {
"completions": {
"real_endpoint": "chat/completions",
"post_process": extract_completions,
}
},
},
}


Expand Down Expand Up @@ -289,6 +321,9 @@ def css_mux(req: CSSRequest) -> str:
case CSSProvider.PERPLEXITYAI:
if result.status_code == 429:
raise RetryableException(result.text)
case CSSProvider.MIRANETWORK:
if result.status_code == 429 or result.status_code == 500:
raise RetryableException(result.text)
case _:
raise InfernetMLException(result.text)

Expand All @@ -304,6 +339,7 @@ def css_mux(req: CSSRequest) -> str:
"content", ""
),
CSSProvider.GOOSEAI: lambda result: result["choices"][0]["text"],
CSSProvider.MIRANETWORK: lambda result: result["choices"][0]["delta"].get("content", ""),
}


Expand Down