Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added type annotations for public API + flake8 fixes #627

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions firebase_admin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import json
import os
import threading
from typing import Any, Callable, Dict, Optional

from firebase_admin import credentials
from firebase_admin.__about__ import __version__
Expand All @@ -31,7 +32,8 @@
_CONFIG_VALID_KEYS = ['databaseAuthVariableOverride', 'databaseURL', 'httpTimeout', 'projectId',
'storageBucket']

def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME):

def initialize_app(credential: Optional[credentials.Base] = None, options: Optional[Dict[str, Any]] = None, name: str = _DEFAULT_APP_NAME) -> "App":
"""Initializes and returns a new App instance.

Creates a new App instance using the specified options
Expand Down Expand Up @@ -83,7 +85,7 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME):
'you call initialize_app().').format(name))


def delete_app(app):
def delete_app(app: "App"):
"""Gracefully deletes an App instance.

Args:
Expand All @@ -98,7 +100,7 @@ def delete_app(app):
with _apps_lock:
if _apps.get(app.name) is app:
del _apps[app.name]
app._cleanup() # pylint: disable=protected-access
app._cleanup() # pylint: disable=protected-access
return
if app.name == _DEFAULT_APP_NAME:
raise ValueError(
Expand All @@ -111,7 +113,7 @@ def delete_app(app):
'second argument.').format(app.name))


def get_app(name=_DEFAULT_APP_NAME):
def get_app(name: str = _DEFAULT_APP_NAME) -> "App":
"""Retrieves an App instance by name.

Args:
Expand Down Expand Up @@ -190,7 +192,7 @@ class App:
common to all Firebase APIs.
"""

def __init__(self, name, credential, options):
def __init__(self, name: str, credential: credentials.Base, options: Optional[Dict[str, Any]]):
"""Constructs a new App using the provided name and options.

Args:
Expand Down Expand Up @@ -265,7 +267,7 @@ def _lookup_project_id(self):
App._validate_project_id(self._options.get('projectId'))
return project_id

def _get_service(self, name, initializer):
def _get_service(self, name: str, initializer: Callable):
"""Returns the service instance identified by the given name.

Services are functional entities exposed by the Admin SDK (e.g. auth, database). Each
Expand Down
8 changes: 4 additions & 4 deletions firebase_admin/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Internal utilities common to all modules."""

import json
from typing import Callable, Optional

import google.auth
import requests
Expand Down Expand Up @@ -76,7 +77,7 @@
}


def _get_initialized_app(app):
def _get_initialized_app(app: Optional[firebase_admin.App]):
"""Returns a reference to an initialized App instance."""
if app is None:
return firebase_admin.get_app()
Expand All @@ -92,10 +93,9 @@ def _get_initialized_app(app):
' firebase_admin.App, but given "{0}".'.format(type(app)))



def get_app_service(app, name, initializer):
def get_app_service(app: Optional[firebase_admin.App], name: str, initializer: Callable):
app = _get_initialized_app(app)
return app._get_service(name, initializer) # pylint: disable=protected-access
return app._get_service(name, initializer) # pylint: disable=protected-access


def handle_platform_error_from_requests(error, handle_func=None):
Expand Down
9 changes: 6 additions & 3 deletions firebase_admin/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
import collections
import json
import pathlib
from typing import Any, Dict, Union

import google.auth
from google.auth.transport import requests
from google.oauth2 import credentials
from google.oauth2 import service_account
import google.auth.credentials


_request = requests.Request()
Expand All @@ -44,7 +46,7 @@
class Base:
"""Provides OAuth2 access tokens for accessing Firebase services."""

def get_access_token(self):
def get_access_token(self) -> AccessTokenInfo:
"""Fetches a Google OAuth2 access token using this credential instance.

Returns:
Expand All @@ -54,7 +56,7 @@ def get_access_token(self):
google_cred.refresh(_request)
return AccessTokenInfo(google_cred.token, google_cred.expiry)

def get_credential(self):
def get_credential(self) -> google.auth.credentials.Credentials:
"""Returns the Google credential instance used for authentication."""
raise NotImplementedError

Expand All @@ -64,7 +66,7 @@ class Certificate(Base):

_CREDENTIAL_TYPE = 'service_account'

def __init__(self, cert):
def __init__(self, cert: Union[str, Dict[str, Any]]):
"""Initializes a credential from a Google service account certificate.

Service account certificates can be downloaded as JSON files from the Firebase console.
Expand Down Expand Up @@ -158,6 +160,7 @@ def _load_credential(self):
if not self._g_credential:
self._g_credential, self._project_id = google.auth.default(scopes=_scopes)


class RefreshToken(Base):
"""A credential initialized from an existing refresh token."""

Expand Down
14 changes: 8 additions & 6 deletions firebase_admin/firestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""

try:
from google.cloud import firestore # pylint: disable=import-error,no-name-in-module
from google.cloud import firestore # pylint: disable=import-error,no-name-in-module
existing = globals().keys()
for key, value in firestore.__dict__.items():
if not key.startswith('_') and key not in existing:
Expand All @@ -28,13 +28,15 @@
raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure '
'to install the "google-cloud-firestore" module.')

from firebase_admin import _utils
from firebase_admin import _utils, App
import google.auth.credentials
from typing import Optional


_FIRESTORE_ATTRIBUTE = '_firestore'


def client(app=None):
def client(app: Optional[App] = None) -> firestore.Client:
"""Returns a client that can be used to interact with Google Cloud Firestore.

Args:
Expand All @@ -57,14 +59,14 @@ def client(app=None):
class _FirestoreClient:
"""Holds a Google Cloud Firestore client instance."""

def __init__(self, credentials, project):
def __init__(self, credentials: google.auth.credentials.Credentials, project: str):
self._client = firestore.Client(credentials=credentials, project=project)

def get(self):
def get(self) -> firestore.Client:
return self._client

@classmethod
def from_app(cls, app):
def from_app(cls, app: App):
"""Creates a new _FirestoreClient for the specified app."""
credentials = app.credential.get_credential()
project = app.project_id
Expand Down
5 changes: 5 additions & 0 deletions firebase_admin/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
def _get_messaging_service(app):
return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService)


def send(message, dry_run=False, app=None):
"""Sends the given message via Firebase Cloud Messaging (FCM).

Expand All @@ -115,6 +116,7 @@ def send(message, dry_run=False, app=None):
"""
return _get_messaging_service(app).send(message, dry_run)


def send_all(messages, dry_run=False, app=None):
"""Sends the given list of messages via Firebase Cloud Messaging as a single batch.

Expand All @@ -135,6 +137,7 @@ def send_all(messages, dry_run=False, app=None):
"""
return _get_messaging_service(app).send_all(messages, dry_run)


def send_multicast(multicast_message, dry_run=False, app=None):
"""Sends the given mutlicast message to all tokens via Firebase Cloud Messaging (FCM).

Expand Down Expand Up @@ -166,6 +169,7 @@ def send_multicast(multicast_message, dry_run=False, app=None):
) for token in multicast_message.tokens]
return _get_messaging_service(app).send_all(messages, dry_run)


def subscribe_to_topic(tokens, topic, app=None):
"""Subscribes a list of registration tokens to an FCM topic.

Expand All @@ -185,6 +189,7 @@ def subscribe_to_topic(tokens, topic, app=None):
return _get_messaging_service(app).make_topic_management_request(
tokens, topic, 'iid/v1:batchAdd')


def unsubscribe_from_topic(tokens, topic, app=None):
"""Unsubscribes a list of registration tokens from an FCM topic.

Expand Down
20 changes: 10 additions & 10 deletions firebase_admin/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,13 @@ def from_dict(cls, data, app=None):
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
model = Model(model_format=tflite_format)
model._data = data_copy # pylint: disable=protected-access
model._app = app # pylint: disable=protected-access
model._app = app # pylint: disable=protected-access
return model

def _update_from_dict(self, data):
copy = Model.from_dict(data)
self.model_format = copy.model_format
self._data = copy._data # pylint: disable=protected-access
self._data = copy._data # pylint: disable=protected-access

def __eq__(self, other):
if isinstance(other, self.__class__):
Expand Down Expand Up @@ -334,7 +334,7 @@ def model_format(self):
def model_format(self, model_format):
if model_format is not None:
_validate_model_format(model_format)
self._model_format = model_format #Can be None
self._model_format = model_format # Can be None
return self

def as_dict(self, for_upload=False):
Expand Down Expand Up @@ -370,7 +370,7 @@ def from_dict(cls, data):
"""Create an instance of the object from a dict."""
data_copy = dict(data)
tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy))
tflite_format._data = data_copy # pylint: disable=protected-access
tflite_format._data = data_copy # pylint: disable=protected-access
return tflite_format

def __eq__(self, other):
Expand Down Expand Up @@ -405,7 +405,7 @@ def model_source(self, model_source):
if model_source is not None:
if not isinstance(model_source, TFLiteModelSource):
raise TypeError('Model source must be a TFLiteModelSource object.')
self._model_source = model_source # Can be None
self._model_source = model_source # Can be None

@property
def size_bytes(self):
Expand Down Expand Up @@ -485,7 +485,7 @@ def __init__(self, gcs_tflite_uri, app=None):

def __eq__(self, other):
if isinstance(other, self.__class__):
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
return False

def __ne__(self, other):
Expand Down Expand Up @@ -775,7 +775,7 @@ def _validate_display_name(display_name):

def _validate_tags(tags):
if not isinstance(tags, list) or not \
all(isinstance(tag, str) for tag in tags):
all(isinstance(tag, str) for tag in tags):
raise TypeError('Tags must be a list of strings.')
if not all(_TAG_PATTERN.match(tag) for tag in tags):
raise ValueError('Tag format is invalid.')
Expand All @@ -789,6 +789,7 @@ def _validate_gcs_tflite_uri(uri):
raise ValueError('GCS TFLite URI format is invalid.')
return uri


def _validate_auto_ml_model(model):
if not _AUTO_ML_MODEL_PATTERN.match(model):
raise ValueError('Model resource name format is invalid.')
Expand All @@ -809,7 +810,7 @@ def _validate_list_filter(list_filter):

def _validate_page_size(page_size):
if page_size is not None:
if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
# Specifically type() to disallow boolean which is a subtype of int
raise TypeError('Page size must be a number or None.')
if page_size < 1 or page_size > _MAX_PAGE_SIZE:
Expand Down Expand Up @@ -864,7 +865,7 @@ def _exponential_backoff(self, current_attempt, stop_time):

if stop_time is not None:
max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds()
if max_seconds_left < 1: # allow a bit of time for rpc
if max_seconds_left < 1: # allow a bit of time for rpc
raise exceptions.DeadlineExceededError('Polling max time exceeded.')
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
time.sleep(wait_time_seconds)
Expand Down Expand Up @@ -925,7 +926,6 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
# If the operation is not complete or timed out, return a (locked) model instead
return get_model(model_id).as_dict()


def create_model(self, model):
_validate_model(model)
try:
Expand Down
8 changes: 5 additions & 3 deletions firebase_admin/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
raise ImportError('Failed to import the Cloud Storage library for Python. Make sure '
'to install the "google-cloud-storage" module.')

from firebase_admin import _utils
from firebase_admin import _utils, App
from typing import Optional


_STORAGE_ATTRIBUTE = '_storage'

def bucket(name=None, app=None) -> storage.Bucket:

def bucket(name: Optional[str] = None, app: Optional[App] = None) -> storage.Bucket:
"""Returns a handle to a Google Cloud Storage bucket.

If the name argument is not provided, uses the 'storageBucket' option specified when
Expand Down Expand Up @@ -67,7 +69,7 @@ def from_app(cls, app):
# significantly speeds up the initialization of the storage client.
return _StorageClient(credentials, app.project_id, default_bucket)

def bucket(self, name=None):
def bucket(self, name: Optional[str] = None):
"""Returns a handle to the specified Cloud Storage Bucket."""
bucket_name = name if name is not None else self._default_bucket
if bucket_name is None:
Expand Down
5 changes: 3 additions & 2 deletions firebase_admin/tenant_mgt.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=Non
FirebaseError: If an error occurs while retrieving the user accounts.
"""
tenant_mgt_service = _get_tenant_mgt_service(app)

def download(page_token, max_results):
return tenant_mgt_service.list_tenants(page_token, max_results)
return ListTenantsPage(download, page_token, max_results)
Expand All @@ -206,7 +207,7 @@ class Tenant:
def __init__(self, data):
if not isinstance(data, dict):
raise ValueError('Invalid data argument in Tenant constructor: {0}'.format(data))
if not 'name' in data:
if 'name' not in data:
raise ValueError('Tenant response missing required keys.')

self._data = data
Expand Down Expand Up @@ -256,7 +257,7 @@ def auth_for_tenant(self, tenant_id):

client = auth.Client(self.app, tenant_id=tenant_id)
self.tenant_clients[tenant_id] = client
return client
return client

def get_tenant(self, tenant_id):
"""Gets the tenant corresponding to the given ``tenant_id``."""
Expand Down