Skip to content

Commit

Permalink
test optional for api key
Browse files Browse the repository at this point in the history
  • Loading branch information
antoinebou12 authored Apr 19, 2024
1 parent fc4b5d0 commit 34e5608
Showing 1 changed file with 54 additions and 50 deletions.
104 changes: 54 additions & 50 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,6 @@
from Crypto.PublicKey import RSA

from pydantic import BaseModel
import logging
import asyncio
import datetime
import logging
import time
from base64 import b64encode
from threading import Timer
from typing import Callable, Dict, Final, List, Optional, Union, Any
from contextlib import asynccontextmanager

import aiohttp
from Crypto.Cipher import PKCS1_v1_5
from Crypto.PublicKey import RSA

from pydantic import BaseModel
import logging

METRIC_TYPE_WEIGHT: Final = "weight"
METRIC_TYPE_GROWTH_RECORD: Final = "growth_record"
Expand Down Expand Up @@ -1021,23 +1005,23 @@ class ClientSSLError(Exception):

from starlette.responses import Response
import httpx
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.security import HTTPBasic, HTTPBasicCredentials, APIKeyHeader
from fastapi.responses import JSONResponse
from starlette.middleware.cors import CORSMiddleware
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND
from datetime import datetime
import hashlib
import os
from Crypto.PublicKey import RSA
from Crypto.Signature import pkcs1_15
from Crypto.Hash import SHA256
from fastapi import HTTPException, Security
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel
from typing import Optional

security_basic = HTTPBasic()
API_KEY_NAME = "access_token"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True)

# Load RSA keys from environment variables
PRIVATE_KEY = RSA.import_key(os.getenv("RSA_PRIVATE_KEY"))
PUBLIC_KEY = RSA.import_key(os.getenv("RSA_PUBLIC_KEY"))

api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

# Initialize FastAPI and Jinja2
app = FastAPI(docs_url="/docs", redoc_url=None)
Expand All @@ -1059,6 +1043,17 @@ class APIResponse(BaseModel):
message: str
data: Optional[Any] = None

def load_rsa_keys():
try:
private_key = RSA.import_key(os.getenv("RSA_PRIVATE_KEY"))
public_key = RSA.import_key(os.getenv("RSA_PUBLIC_KEY"))
return private_key, public_key
except Exception as e:
print(f"Error loading RSA keys: {e}")
raise HTTPException(status_code=500, detail="Failed to load RSA keys")

PRIVATE_KEY, PUBLIC_KEY = load_rsa_keys()


def generate_api_key(email: str, password: str) -> str:
"""Generate a signed API key."""
Expand All @@ -1074,37 +1069,46 @@ def generate_api_key(email: str, password: str) -> str:

return signed_api_key

def decrypt_api_key(api_key: str) -> Optional[Dict[str, str]]:
def decrypt_api_key(api_key: str):
"""Decrypt API key to extract the email and password."""
try:
api_key_part, signature = api_key.rsplit(':', 1)
key_hash = SHA256.new(api_key_part.encode())
# Verify the signature
pkcs1_15.new(PUBLIC_KEY).verify(key_hash, bytes.fromhex(signature))

# Decode the base string
decoded_bytes = bytes.fromhex(api_key_part)
decoded_string = decoded_bytes.decode('utf-8') # Assuming the input was UTF-8-encoded
email, password, timestamp, salt = decoded_string.split(':')

return {"email": email, "password": password}
except ValueError: # Catches all errors related to cryptographic operations
raise HTTPException(status_code=403, detail="Invalid API key")

def verify_api_key(api_key: str) -> bool:
try:
api_key, signature = api_key.rsplit(':', 1)
key_hash = SHA256.new(api_key.encode())
# Verify the signature
pkcs1_15.new(PUBLIC_KEY).verify(key_hash, bytes.fromhex(signature))
# If verification is successful, decode the email and password
decoded_data = hashlib.sha256().new(bytes.fromhex(api_key)).hexdigest().split(":")
return {"email": decoded_data[0], "password": decoded_data[1]}
except (ValueError, pkcs1_15.PKCS115_SigSchemeError) as e:
return None

async def get_api_key(api_key: str = Depends(api_key_header)):
"""Dependency that validates the API key."""
user_credentials = decrypt_api_key(api_key)
if user_credentials:
return user_credentials
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key")

async def get_current_user(credentials: Optional[HTTPBasicCredentials] = Depends(security_basic), api_key: Optional[str] = Depends(get_api_key)):
"""Dependency that authenticates user either by basic auth or API key."""
if credentials:
user = RenphoWeight(email=credentials.username, password=credentials.password)
if await user.auth():
return user
elif api_key:
user = RenphoWeight(email=api_key["email"], password=api_key["password"])
if await user.auth():
return user
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials or API key")
return True
except Exception:
return False

def get_api_key(api_key: str = Depends(api_key_header)):
if verify_api_key(api_key):
return api_key
else:
raise HTTPException(status_code=403, detail="Invalid API key")


def get_current_user(credentials: Optional[HTTPBasicCredentials] = Depends(security_basic), api_key: Optional[str] = Depends(get_api_key)):
email = credentials.username if credentials else decrypt_api_key(api_key)['email']
password = credentials.password if credentials else decrypt_api_key(api_key)['password']
user = RenphoWeight(email=email, password=password)
if not user.auth():
raise HTTPException(status_code=403, detail="Invalid credentials")
return user

@app.get("/")
def read_root(request: Request):
Expand Down

0 comments on commit 34e5608

Please sign in to comment.