Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callback function for server url #96

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
32 changes: 27 additions & 5 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
dist: xenial
language: python

python:
- "2.7"
- "3.3"
- "3.4"
- "3.5"
- "3.6"
- "3.7"

env:
- DJANGO_VERSION=Django==1.5
- DJANGO_VERSION=Django==1.6
- DJANGO_VERSION=Django==1.7
- DJANGO_VERSION=Django==1.8
- DJANGO_VERSION="Django<2"
- DJANGO_VERSION="Django<2.1"
- DJANGO_VERSION="Django<2.2"
- DJANGO_VERSION="Django<2.3"


matrix:
include:
- python: "3.7"
dist: xenial
exclude:
- python: "3.7"
env: DJANGO_VERSION="Django<2"
- python: "2.7"
env: DJANGO_VERSION="Django<2.1"
- python: "2.7"
env: DJANGO_VERSION="Django<2.2"
- python: "2.7"
env: DJANGO_VERSION="Django<2.3"
- python: "3.4"
env: DJANGO_VERSION="Django<2.2"
- python: "3.4"
env: DJANGO_VERSION="Django<2.3"

# command to install dependencies
install:
Expand Down
8 changes: 4 additions & 4 deletions cas/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from cas.exceptions import CasTicketException
from cas.models import Tgt, PgtIOU
from cas.utils import cas_response_callbacks
from cas.utils import cas_response_callbacks, get_cas_server_url

__all__ = ['CASBackend']

Expand All @@ -44,7 +44,7 @@ def _verify_cas1(ticket, service):
"""

params = {'ticket': ticket, 'service': service}
url = (urljoin(settings.CAS_SERVER_URL, 'validate') + '?' +
url = (urljoin(get_cas_server_url(service), 'validate') + '?' +
urlencode(params))
page = urlopen(url)

Expand Down Expand Up @@ -82,7 +82,7 @@ def _internal_verify_cas(ticket, service, suffix):
if settings.CAS_PROXY_CALLBACK:
params['pgtUrl'] = settings.CAS_PROXY_CALLBACK

url = (urljoin(settings.CAS_SERVER_URL, suffix) + '?' +
url = (urljoin(get_cas_server_url(service), suffix) + '?' +
urlencode(params))

page = urlopen(url)
Expand Down Expand Up @@ -149,7 +149,7 @@ def verify_proxy_ticket(ticket, service):

params = {'ticket': ticket, 'service': service}

url = (urljoin(settings.CAS_SERVER_URL, 'proxyValidate') + '?' +
url = (urljoin(get_cas_server_url(service), 'proxyValidate') + '?' +
urlencode(params))

page = urlopen(url)
Expand Down
11 changes: 9 additions & 2 deletions cas/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

try:
from django.contrib.auth.views import login, logout
except:
from django.contrib.auth import login, logout
except ImportError:
from django.contrib.auth.views import LoginView, LogoutView
login = LoginView.as_view().view_class
logout = LogoutView.as_view().view_class

from django.http import HttpResponseRedirect, HttpResponseForbidden
from django.core.exceptions import ImproperlyConfigured
Expand Down Expand Up @@ -57,6 +59,11 @@ def process_view(self, request, view_func, view_args, view_kwargs):
logout.
"""

try:
view_func = view_func.view_class
except AttributeError:
pass

if view_func == login:
return cas_login(request, *view_args, **view_kwargs)
elif view_func == logout:
Expand Down
3 changes: 2 additions & 1 deletion cas/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@


from cas.exceptions import CasTicketException, CasConfigException
from cas.utils import get_cas_server_url


logger = logging.getLogger(__name__)
Expand All @@ -47,7 +48,7 @@ def get_proxy_ticket_for(self, service):

params = {'pgt': self.tgt, 'targetService': service}

url = (urljoin(settings.CAS_SERVER_URL, 'proxy') + '?' +
url = (urljoin(get_cas_server_url(service), 'proxy') + '?' +
urlencode(params))

page = urlopen(url)
Expand Down
5 changes: 4 additions & 1 deletion cas/tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import mock
try:
from unittest import mock
except ImportError:
import mock
from django.test import TestCase

from cas.backends import CASBackend
Expand Down
45 changes: 45 additions & 0 deletions cas/tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
try:
from unittest import mock
except ImportError:
import mock

from urllib.parse import quote_plus, urlencode

from django.conf import settings
from django.test import TestCase, Client, override_settings, modify_settings



@override_settings(MIDDLEWARE=[
'django.contrib.sessions.middleware.SessionMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'cas.middleware.CASMiddleware'
])
class CASBackendTest(TestCase):

def setUp(self):
from cas.tests import factories
self.user = factories.UserFactory.create()
self.client = Client()

def test_login_calls_cas_login(self):
resp = self.client.get('/login/')
self.assertTrue(resp.has_header('Location'))
expected_url = '{}/login?{}'.format(
settings.CAS_SERVER_URL,
urlencode({
'service': 'http://testserver/login/?next={}'.format(quote_plus('/'))
})
)
self.assertRedirects(resp, expected_url, fetch_redirect_response=False)

def test_logout_calls_cas_logout(self):
resp = self.client.get('/logout/')
self.assertTrue(resp.has_header('Location'))
expected_url = '{}/logout?{}'.format(
settings.CAS_SERVER_URL,
urlencode({
'service': 'http://testserver/'
})
)
self.assertRedirects(resp, expected_url, fetch_redirect_response=False)
25 changes: 22 additions & 3 deletions cas/tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from django.test import TestCase, RequestFactory
from django.test import TestCase, RequestFactory, override_settings
from django.test.utils import override_settings

from cas.views import _redirect_url, _login_url, _logout_url, _service_url


def custom_cas_server_url(service):
if 'secret' in service:
return 'http://secret.cas.com'
return 'http://signin.cas.com/'


class RequestFactoryRemix(RequestFactory):

path = '/'
Expand Down Expand Up @@ -61,5 +67,18 @@ def test_login_url(self):
self.assertEqual(_login_url('http://localhost:8000/accounts/login/'),
'http://signin.cas.com/login?service=http%3A%2F%2Flocalhost%3A8000%2Faccounts%2Flogin%2F')

def test_logout_url(self):
self.assertEqual(_logout_url(self.request), 'http://signin.cas.com/logout')

@override_settings(CAS_SERVER_URL_CALLBACK='cas.tests.test_views.custom_cas_server_url')
def test_login_url_custom(self):
self.assertEqual(_login_url('http://localhost:8000/accounts/login/?return_url=/secret/'),
'http://secret.cas.com/login?service=http%3A%2F%2Flocalhost%3A8000%2Faccounts%2Flogin%2F%3Freturn_url%3D%2Fsecret%2F')

@override_settings(CAS_SERVER_URL_CALLBACK='cas.tests.test_views.custom_cas_server_url')
def test_login_url_custom_normal(self):
self.assertEqual(_login_url('http://localhost:8000/accounts/login/?return_url=/normal/'),
'http://signin.cas.com/login?service=http%3A%2F%2Flocalhost%3A8000%2Faccounts%2Flogin%2F%3Freturn_url%3D%2Fnormal%2F')

@override_settings(CAS_SERVER_URL_CALLBACK='cas.nonexistent.callback')
def test_login_url_bad_callback_raises_exception(self):
with self.assertRaises(RuntimeError):
_ = _login_url('http://localhost:8000/accounts/login/?return_url=/normal/')
5 changes: 5 additions & 0 deletions cas/tests/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.urls import path, include

urlpatterns = [
path('', include('django.contrib.auth.urls')),
]
19 changes: 18 additions & 1 deletion cas/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from django.conf import settings

from django.utils.module_loading import import_string

logger = logging.getLogger(__name__)

Expand All @@ -24,3 +24,20 @@ def cas_response_callbacks(tree):
logger.error("Attribute Error: %s" % e)
raise e
func(tree)

def get_cas_server_url(service):
try:
cas_server_url_callback = settings.CAS_SERVER_URL_CALLBACK
except AttributeError:
pass
else:
try:
callback = import_string(cas_server_url_callback)
except ImportError:
raise RuntimeError(
"Invalid callback for CAS_SERVER_URL_CALLBACK: {}".format(
cas_server_url_callback
)
)
return callback(service)
return settings.CAS_SERVER_URL
8 changes: 6 additions & 2 deletions cas/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from django.core.urlresolvers import reverse

from cas.models import PgtIOU
from cas.utils import get_cas_server_url

__all__ = ['login', 'logout']

Expand Down Expand Up @@ -130,7 +131,7 @@ def _login_url(service, ticket='ST', gateway=False):

login_type = LOGINS.get(ticket[:2], 'login')

return urlparse.urljoin(settings.CAS_SERVER_URL, login_type) + '?' + urlencode(params)
return urlparse.urljoin(get_cas_server_url(service), login_type) + '?' + urlencode(params)


def _logout_url(request, next_page=None):
Expand All @@ -142,7 +143,10 @@ def _logout_url(request, next_page=None):

"""

url = urlparse.urljoin(settings.CAS_SERVER_URL, 'logout')
protocol = ('http://', 'https://')[request.is_secure()]
service = protocol + request.get_host() + request.path

url = urlparse.urljoin(get_cas_server_url(service), 'logout')

if next_page and getattr(settings, 'CAS_PROVIDE_URL_TO_LOGOUT', True):
parsed_url = urlparse.urlparse(next_page)
Expand Down
6 changes: 2 additions & 4 deletions run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
'ENGINE': 'django.db.backends.sqlite3',
}
},
#ROOT_URLCONF='mailqueue.urls',
ROOT_URLCONF='cas.tests.urls',
INSTALLED_APPS=('django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.admin',
'cas',),
CAS_SERVER_URL = 'http://signin.cas.com',
)
Expand All @@ -31,11 +30,10 @@
'ENGINE': 'django.db.backends.sqlite3',
}
},
#ROOT_URLCONF='mailqueue.urls',
ROOT_URLCONF='cas.tests.urls',
INSTALLED_APPS=('django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.admin',
'cas',),
USE_TZ=True,
CAS_SERVER_URL = 'http://signin.cas.com',)
Expand Down