diff --git a/.travis.yml b/.travis.yml index a18521b..e77e4a2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,15 +1,37 @@ +dist: xenial language: python python: - "2.7" - - "3.3" - "3.4" + - "3.5" + - "3.6" + - "3.7" env: - - DJANGO_VERSION=Django==1.5 - - DJANGO_VERSION=Django==1.6 - - DJANGO_VERSION=Django==1.7 - - DJANGO_VERSION=Django==1.8 + - DJANGO_VERSION="Django<2" + - DJANGO_VERSION="Django<2.1" + - DJANGO_VERSION="Django<2.2" + - DJANGO_VERSION="Django<2.3" + + +matrix: + include: + - python: "3.7" + dist: xenial + exclude: + - python: "3.7" + env: DJANGO_VERSION="Django<2" + - python: "2.7" + env: DJANGO_VERSION="Django<2.1" + - python: "2.7" + env: DJANGO_VERSION="Django<2.2" + - python: "2.7" + env: DJANGO_VERSION="Django<2.3" + - python: "3.4" + env: DJANGO_VERSION="Django<2.2" + - python: "3.4" + env: DJANGO_VERSION="Django<2.3" # command to install dependencies install: diff --git a/cas/backends.py b/cas/backends.py index cecc62f..de69c23 100644 --- a/cas/backends.py +++ b/cas/backends.py @@ -26,7 +26,7 @@ from cas.exceptions import CasTicketException from cas.models import Tgt, PgtIOU -from cas.utils import cas_response_callbacks +from cas.utils import cas_response_callbacks, get_cas_server_url __all__ = ['CASBackend'] @@ -44,7 +44,7 @@ def _verify_cas1(ticket, service): """ params = {'ticket': ticket, 'service': service} - url = (urljoin(settings.CAS_SERVER_URL, 'validate') + '?' + + url = (urljoin(get_cas_server_url(service), 'validate') + '?' + urlencode(params)) page = urlopen(url) @@ -82,7 +82,7 @@ def _internal_verify_cas(ticket, service, suffix): if settings.CAS_PROXY_CALLBACK: params['pgtUrl'] = settings.CAS_PROXY_CALLBACK - url = (urljoin(settings.CAS_SERVER_URL, suffix) + '?' + + url = (urljoin(get_cas_server_url(service), suffix) + '?' + urlencode(params)) page = urlopen(url) @@ -149,7 +149,7 @@ def verify_proxy_ticket(ticket, service): params = {'ticket': ticket, 'service': service} - url = (urljoin(settings.CAS_SERVER_URL, 'proxyValidate') + '?' + + url = (urljoin(get_cas_server_url(service), 'proxyValidate') + '?' + urlencode(params)) page = urlopen(url) diff --git a/cas/middleware.py b/cas/middleware.py index 2a3ea5d..b513b91 100644 --- a/cas/middleware.py +++ b/cas/middleware.py @@ -11,8 +11,10 @@ try: from django.contrib.auth.views import login, logout -except: - from django.contrib.auth import login, logout +except ImportError: + from django.contrib.auth.views import LoginView, LogoutView + login = LoginView.as_view().view_class + logout = LogoutView.as_view().view_class from django.http import HttpResponseRedirect, HttpResponseForbidden from django.core.exceptions import ImproperlyConfigured @@ -57,6 +59,11 @@ def process_view(self, request, view_func, view_args, view_kwargs): logout. """ + try: + view_func = view_func.view_class + except AttributeError: + pass + if view_func == login: return cas_login(request, *view_args, **view_kwargs) elif view_func == logout: diff --git a/cas/models.py b/cas/models.py index 54b1606..5ec1759 100644 --- a/cas/models.py +++ b/cas/models.py @@ -24,6 +24,7 @@ from cas.exceptions import CasTicketException, CasConfigException +from cas.utils import get_cas_server_url logger = logging.getLogger(__name__) @@ -47,7 +48,7 @@ def get_proxy_ticket_for(self, service): params = {'pgt': self.tgt, 'targetService': service} - url = (urljoin(settings.CAS_SERVER_URL, 'proxy') + '?' + + url = (urljoin(get_cas_server_url(service), 'proxy') + '?' + urlencode(params)) page = urlopen(url) diff --git a/cas/tests/test_backend.py b/cas/tests/test_backend.py index d7ad9f0..4cf2022 100644 --- a/cas/tests/test_backend.py +++ b/cas/tests/test_backend.py @@ -1,4 +1,7 @@ -import mock +try: + from unittest import mock +except ImportError: + import mock from django.test import TestCase from cas.backends import CASBackend diff --git a/cas/tests/test_middleware.py b/cas/tests/test_middleware.py new file mode 100644 index 0000000..ea1b4c9 --- /dev/null +++ b/cas/tests/test_middleware.py @@ -0,0 +1,45 @@ +try: + from unittest import mock +except ImportError: + import mock + +from urllib.parse import quote_plus, urlencode + +from django.conf import settings +from django.test import TestCase, Client, override_settings, modify_settings + + + +@override_settings(MIDDLEWARE=[ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'cas.middleware.CASMiddleware' +]) +class CASBackendTest(TestCase): + + def setUp(self): + from cas.tests import factories + self.user = factories.UserFactory.create() + self.client = Client() + + def test_login_calls_cas_login(self): + resp = self.client.get('/login/') + self.assertTrue(resp.has_header('Location')) + expected_url = '{}/login?{}'.format( + settings.CAS_SERVER_URL, + urlencode({ + 'service': 'http://testserver/login/?next={}'.format(quote_plus('/')) + }) + ) + self.assertRedirects(resp, expected_url, fetch_redirect_response=False) + + def test_logout_calls_cas_logout(self): + resp = self.client.get('/logout/') + self.assertTrue(resp.has_header('Location')) + expected_url = '{}/logout?{}'.format( + settings.CAS_SERVER_URL, + urlencode({ + 'service': 'http://testserver/' + }) + ) + self.assertRedirects(resp, expected_url, fetch_redirect_response=False) \ No newline at end of file diff --git a/cas/tests/test_views.py b/cas/tests/test_views.py index a3b734f..b1d5831 100644 --- a/cas/tests/test_views.py +++ b/cas/tests/test_views.py @@ -1,9 +1,15 @@ -from django.test import TestCase, RequestFactory +from django.test import TestCase, RequestFactory, override_settings from django.test.utils import override_settings from cas.views import _redirect_url, _login_url, _logout_url, _service_url +def custom_cas_server_url(service): + if 'secret' in service: + return 'http://secret.cas.com' + return 'http://signin.cas.com/' + + class RequestFactoryRemix(RequestFactory): path = '/' @@ -61,5 +67,18 @@ def test_login_url(self): self.assertEqual(_login_url('http://localhost:8000/accounts/login/'), 'http://signin.cas.com/login?service=http%3A%2F%2Flocalhost%3A8000%2Faccounts%2Flogin%2F') - def test_logout_url(self): - self.assertEqual(_logout_url(self.request), 'http://signin.cas.com/logout') + + @override_settings(CAS_SERVER_URL_CALLBACK='cas.tests.test_views.custom_cas_server_url') + def test_login_url_custom(self): + self.assertEqual(_login_url('http://localhost:8000/accounts/login/?return_url=/secret/'), + 'http://secret.cas.com/login?service=http%3A%2F%2Flocalhost%3A8000%2Faccounts%2Flogin%2F%3Freturn_url%3D%2Fsecret%2F') + + @override_settings(CAS_SERVER_URL_CALLBACK='cas.tests.test_views.custom_cas_server_url') + def test_login_url_custom_normal(self): + self.assertEqual(_login_url('http://localhost:8000/accounts/login/?return_url=/normal/'), + 'http://signin.cas.com/login?service=http%3A%2F%2Flocalhost%3A8000%2Faccounts%2Flogin%2F%3Freturn_url%3D%2Fnormal%2F') + + @override_settings(CAS_SERVER_URL_CALLBACK='cas.nonexistent.callback') + def test_login_url_bad_callback_raises_exception(self): + with self.assertRaises(RuntimeError): + _ = _login_url('http://localhost:8000/accounts/login/?return_url=/normal/') diff --git a/cas/tests/urls.py b/cas/tests/urls.py new file mode 100644 index 0000000..6a3ae9a --- /dev/null +++ b/cas/tests/urls.py @@ -0,0 +1,5 @@ +from django.urls import path, include + +urlpatterns = [ + path('', include('django.contrib.auth.urls')), +] diff --git a/cas/utils.py b/cas/utils.py index f03e1c5..1617402 100644 --- a/cas/utils.py +++ b/cas/utils.py @@ -1,7 +1,7 @@ import logging from django.conf import settings - +from django.utils.module_loading import import_string logger = logging.getLogger(__name__) @@ -24,3 +24,20 @@ def cas_response_callbacks(tree): logger.error("Attribute Error: %s" % e) raise e func(tree) + +def get_cas_server_url(service): + try: + cas_server_url_callback = settings.CAS_SERVER_URL_CALLBACK + except AttributeError: + pass + else: + try: + callback = import_string(cas_server_url_callback) + except ImportError: + raise RuntimeError( + "Invalid callback for CAS_SERVER_URL_CALLBACK: {}".format( + cas_server_url_callback + ) + ) + return callback(service) + return settings.CAS_SERVER_URL diff --git a/cas/views.py b/cas/views.py index 14e250b..d4d61fc 100644 --- a/cas/views.py +++ b/cas/views.py @@ -23,6 +23,7 @@ from django.core.urlresolvers import reverse from cas.models import PgtIOU +from cas.utils import get_cas_server_url __all__ = ['login', 'logout'] @@ -130,7 +131,7 @@ def _login_url(service, ticket='ST', gateway=False): login_type = LOGINS.get(ticket[:2], 'login') - return urlparse.urljoin(settings.CAS_SERVER_URL, login_type) + '?' + urlencode(params) + return urlparse.urljoin(get_cas_server_url(service), login_type) + '?' + urlencode(params) def _logout_url(request, next_page=None): @@ -142,7 +143,10 @@ def _logout_url(request, next_page=None): """ - url = urlparse.urljoin(settings.CAS_SERVER_URL, 'logout') + protocol = ('http://', 'https://')[request.is_secure()] + service = protocol + request.get_host() + request.path + + url = urlparse.urljoin(get_cas_server_url(service), 'logout') if next_page and getattr(settings, 'CAS_PROVIDE_URL_TO_LOGOUT', True): parsed_url = urlparse.urlparse(next_page) diff --git a/run_tests.py b/run_tests.py index 1daa5f4..9efffae 100644 --- a/run_tests.py +++ b/run_tests.py @@ -16,11 +16,10 @@ 'ENGINE': 'django.db.backends.sqlite3', } }, - #ROOT_URLCONF='mailqueue.urls', + ROOT_URLCONF='cas.tests.urls', INSTALLED_APPS=('django.contrib.auth', 'django.contrib.contenttypes', 'django.contrib.sessions', - 'django.contrib.admin', 'cas',), CAS_SERVER_URL = 'http://signin.cas.com', ) @@ -31,11 +30,10 @@ 'ENGINE': 'django.db.backends.sqlite3', } }, - #ROOT_URLCONF='mailqueue.urls', + ROOT_URLCONF='cas.tests.urls', INSTALLED_APPS=('django.contrib.auth', 'django.contrib.contenttypes', 'django.contrib.sessions', - 'django.contrib.admin', 'cas',), USE_TZ=True, CAS_SERVER_URL = 'http://signin.cas.com',)