Skip to content

Commit

Permalink
OIDC support
Browse files Browse the repository at this point in the history
  • Loading branch information
q committed Oct 28, 2024
1 parent 42b8cc9 commit 6420faf
Show file tree
Hide file tree
Showing 11 changed files with 1,094 additions and 615 deletions.
5 changes: 5 additions & 0 deletions irrd/conf/known_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@
+ [AUTH_SET_CREATION_COMMON_KEY]
},
"password_hashers": {hasher_name.lower(): {} for hasher_name in PASSWORD_HASHERS_ALL.keys()},
"oidc": {
"issuer": {},
"client_id": {},
"client_secret": {},
}
},
"rpki": {
"roa_source": {},
Expand Down
30 changes: 30 additions & 0 deletions irrd/storage/alembic/versions/efb8b67865ea_oidc_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""OIDC support
Revision ID: efb8b67865ea
Revises: f56387c94696
Create Date: 2024-10-28 16:24:50.826630
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'efb8b67865ea'
down_revision = 'f56387c94696'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('auth_user', sa.Column('oidc_sub', sa.String(), nullable=True))
op.create_index(op.f('ix_auth_user_oidc_sub'), 'auth_user', ['oidc_sub'], unique=False)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_auth_user_oidc_sub'), table_name='auth_user')
op.drop_column('auth_user', 'oidc_sub')
# ### end Alembic commands ###
3 changes: 2 additions & 1 deletion irrd/storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ class AuthUser(Base): # type: ignore
__tablename__ = "auth_user"

pk = sa.Column(pg.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), primary_key=True)
oidc_sub = sa.Column(sa.String, index=True, nullable=True)
email = sa.Column(sa.String, index=True, unique=True, nullable=False)
name = sa.Column(sa.String, nullable=False)
password = sa.Column(sa.String, nullable=False)
Expand Down Expand Up @@ -462,7 +463,7 @@ def get_display_name(self) -> str: # pragma: no cover
return self.name

def get_id(self) -> str:
return self.email
return str(self.pk)

def get_hashed_password(self) -> str:
return self.password
Expand Down
91 changes: 88 additions & 3 deletions irrd/webui/auth/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
import secrets
from urllib.parse import unquote_plus, urlparse

import oic
import oic.oic.message
import wtforms
from starlette.requests import Request
from starlette.responses import RedirectResponse, Response
from starlette_wtf import StarletteForm, csrf_protect

from imia import login_user
from irrd.storage.models import AuthUser
from irrd.storage.orm_provider import ORMSessionProvider, session_provider_manager
from irrd.webui import MFA_COMPLETE_SESSION_KEY
Expand All @@ -16,7 +17,7 @@
PasswordResetToken,
get_login_manager,
password_handler,
validate_password_strength,
validate_password_strength, get_oidc_client,
)
from irrd.webui.helpers import (
client_ip_str,
Expand All @@ -26,6 +27,7 @@
send_template_email,
)
from irrd.webui.rendering import render_form, template_context_render
from irrd.conf import get_setting

logger = logging.getLogger(__name__)

Expand All @@ -45,6 +47,13 @@ def clean_next_url(request: Request, default: str = DEFAULT_REDIRECT_URL):

@rate_limit_post
async def login(request: Request):
if get_setting("auth.oidc"):
return await login_oidc(request)
else:
return await login_local(request)


async def login_local(request: Request):
if request.method == "GET":
return template_context_render(
"login.html",
Expand Down Expand Up @@ -115,6 +124,9 @@ async def validate(self):
@csrf_protect
@session_provider_manager
async def create_account(request: Request, session_provider: ORMSessionProvider) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

form = await CreateAccountForm.from_formdata(request=request, session_provider=session_provider)
if not form.is_submitted() or not await form.validate():
return template_context_render("create_account_form.html", request, {"form_html": render_form(form)})
Expand Down Expand Up @@ -146,6 +158,9 @@ class ResetPasswordRequestForm(StarletteForm):
@csrf_protect
@session_provider_manager
async def reset_password_request(request: Request, session_provider: ORMSessionProvider) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

form = await ResetPasswordRequestForm.from_formdata(request=request)
if not form.is_submitted() or not await form.validate():
return template_context_render(
Expand Down Expand Up @@ -198,6 +213,9 @@ async def validate(self, current_user: AuthUser):
@session_provider_manager
@authentication_required
async def change_password(request: Request, session_provider: ORMSessionProvider) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

form = await ChangePasswordForm.from_formdata(request=request)
if not form.is_submitted() or not await form.validate(current_user=request.auth.user):
return template_context_render(
Expand Down Expand Up @@ -229,6 +247,9 @@ class ChangeProfileForm(CurrentPasswordForm):
@session_provider_manager
@authentication_required
async def change_profile(request: Request, session_provider: ORMSessionProvider) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

form = await ChangeProfileForm.from_formdata(
request=request, email=request.auth.user.email, name=request.auth.user.name
)
Expand Down Expand Up @@ -286,6 +307,9 @@ async def validate(self):
@csrf_protect
@session_provider_manager
async def set_password(request: Request, session_provider: ORMSessionProvider) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

query = session_provider.session.query(AuthUser).filter(
AuthUser.pk == request.path_params["pk"],
)
Expand All @@ -310,3 +334,64 @@ async def set_password(request: Request, session_provider: ORMSessionProvider) -
if not initial:
send_authentication_change_mail(user, request, "Your password was reset.")
return RedirectResponse(request.url_for("ui:auth:login"), status_code=302)


async def login_oidc(request: Request):
client = get_oidc_client()
request.session["state"] = oic.rndstr()
request.session["nonce"] = oic.rndstr()
request.session["next_url"] = str(clean_next_url(request, "ui:index"))
args = {
"client_id": client.client_id,
"response_type": "code",
"scope": ["openid"],
"state": request.session["state"],
"nonce": request.session["nonce"],
"redirect_uri": request.url_for("ui:auth:oauth_callback"),
}

auth_req = client.construct_AuthorizationRequest(request_args=args)
login_url = auth_req.request(client.authorization_endpoint)
return RedirectResponse(login_url, status_code=302)


async def oauth_callback(request: Request):
client = get_oidc_client()
resp = client.parse_response(
oic.oic.message.AuthorizationResponse,
info=dict(request.query_params), sformat="dict"
)

if resp["state"] != request.session["state"]:
return Response(status_code=403)

next_url = request.session["next_url"]

client.do_access_token_request(
state=resp["state"], request_args={
"code": resp["code"],
"redirect_uri": request.url_for("ui:auth:oauth_callback"),
}, authn_method="client_secret_basic"
)
userinfo = client.do_user_info_request(state=resp["state"])

session_provider = ORMSessionProvider()
target = session_provider.session.query(AuthUser).filter_by(oidc_sub=userinfo["sub"])
user = await session_provider.run(target.one)

if not user:
user = AuthUser(
email=userinfo["email"],
oidc_sub=userinfo["sub"],
name=userinfo["name"],
password=""
)
session_provider.session.add(user)
session_provider.session.commit()

await login_user(request, user, oic.rndstr())

session_provider.close()

request.session[MFA_COMPLETE_SESSION_KEY] = True
return RedirectResponse(next_url, status_code=302)
24 changes: 24 additions & 0 deletions irrd/webui/auth/endpoints_mfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def webauthn_challenge_override() -> Optional[bytes]:

@authentication_required
async def mfa_status(request: Request) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

context = {
"has_mfa": request.auth.user.has_mfa,
"has_totp": request.auth.user.has_totp,
Expand Down Expand Up @@ -113,6 +116,9 @@ async def validate(self, totp: pyotp.totp.TOTP, last_used: str):
@authentication_required(mfa_check=False)
@session_provider_manager
async def mfa_authenticate(request: Request, session_provider: ORMSessionProvider) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

"""
MFA authentication page for both TOTP and WebAuthn.
For the TOTP flow, this endpoint processes the form POST request and checks it.
Expand Down Expand Up @@ -177,6 +183,9 @@ async def mfa_authenticate(request: Request, session_provider: ORMSessionProvide
async def webauthn_verify_authentication_response(
request: Request, session_provider: ORMSessionProvider
) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

wn_origin, wn_rpid = get_webauthn_origin_rpid()
try:
expected_challenge = base64.b64decode(request.session[WN_CHALLENGE_SESSION_KEY])
Expand Down Expand Up @@ -218,6 +227,9 @@ async def webauthn_verify_authentication_response(

@authentication_required
async def webauthn_register(request: Request) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

existing_credentials = [
PublicKeyCredentialDescriptor(id=auth.credential_id) for auth in request.auth.user.webauthns
]
Expand Down Expand Up @@ -254,6 +266,9 @@ async def webauthn_register(request: Request) -> Response:
async def webauthn_verify_registration_response(
request: Request, session_provider: ORMSessionProvider
) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

wn_origin, wn_rpid = get_webauthn_origin_rpid()
try:
expected_challenge = base64.b64decode(request.session[WN_CHALLENGE_SESSION_KEY])
Expand Down Expand Up @@ -299,6 +314,9 @@ class WebAuthnRemoveForm(CurrentPasswordForm):
@session_provider_manager
@authentication_required
async def webauthn_remove(request: Request, session_provider: ORMSessionProvider) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

query = session_provider.session.query(AuthWebAuthn)
query = query.filter(
AuthWebAuthn.pk == request.path_params["webauthn"], AuthWebAuthn.user_id == str(request.auth.user.pk)
Expand Down Expand Up @@ -346,6 +364,9 @@ async def validate(self, current_user: AuthUser, totp: Optional[pyotp.totp.TOTP]
@authentication_required
@session_provider_manager
async def totp_register(request: Request, session_provider: ORMSessionProvider) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

form = await TOTPRegisterForm.from_formdata(request=request)
totp_secret = request.session.get(TOTP_REGISTRATION_SECRET_SESSION_KEY, pyotp.random_base32())
totp = pyotp.totp.TOTP(totp_secret)
Expand Down Expand Up @@ -383,6 +404,9 @@ class TOTPRemoveForm(CurrentPasswordForm):
@session_provider_manager
@authentication_required
async def totp_remove(request: Request, session_provider: ORMSessionProvider) -> Response:
if get_setting("auth.oidc"):
return Response(status_code=404)

form = await TOTPRemoveForm.from_formdata(request=request)
if not form.is_submitted() or not await form.validate(current_user=request.auth.user):
return template_context_render(
Expand Down
2 changes: 2 additions & 0 deletions irrd/webui/auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
logout,
reset_password_request,
set_password,
oauth_callback,
)
from .endpoints_mfa import (
mfa_authenticate,
Expand Down Expand Up @@ -56,4 +57,5 @@
name="webauthn_verify_authentication_response",
methods=["POST"],
),
Route("/oauth/callback", oauth_callback, name="oauth_callback", methods=["GET"]),
]
22 changes: 19 additions & 3 deletions irrd/webui/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from base64 import urlsafe_b64decode, urlsafe_b64encode
from datetime import date, timedelta
from typing import Optional, Tuple, Union

import oic.oic
import oic.utils.authn.client
import passlib
import wtforms
from imia import (
Expand All @@ -21,6 +22,7 @@
from starlette_wtf import StarletteForm
from zxcvbn import zxcvbn

from irrd.conf import get_setting
from irrd.storage.models import AuthUser
from irrd.storage.orm_provider import ORMSessionProvider
from irrd.webui.helpers import secret_key_derive
Expand All @@ -34,7 +36,7 @@
class AuthProvider(UserProvider):
async def find_by_id(self, connection: HTTPConnection, identifier: str) -> Optional[UserLike]:
session_provider = ORMSessionProvider()
target = session_provider.session.query(AuthUser).filter_by(email=identifier).options(joinedload("*"))
target = session_provider.session.query(AuthUser).filter_by(pk=identifier).options(joinedload("*"))
user = await session_provider.run(target.one)
session_provider.session.expunge_all()
session_provider.commit_close()
Expand All @@ -43,7 +45,12 @@ async def find_by_id(self, connection: HTTPConnection, identifier: str) -> Optio
async def find_by_username(
self, connection: HTTPConnection, username_or_email: str
) -> Optional[UserLike]:
return await self.find_by_id(connection, username_or_email)
session_provider = ORMSessionProvider()
target = session_provider.session.query(AuthUser).filter_by(email=username_or_email, oidc_sub=None).options(joinedload("*"))
user = await session_provider.run(target.one)
session_provider.session.expunge_all()
session_provider.commit_close()
return user

async def find_by_token(self, connection: HTTPConnection, token: str) -> Optional[UserLike]:
return None # pragma: no cover
Expand All @@ -70,6 +77,15 @@ def _get_hasher(self):
def get_login_manager() -> LoginManager:
return LoginManager(user_provider, password_handler, secret_key_derive("web.login_manager"))

def get_oidc_client() -> oic.oic.Client:
client = oic.oic.Client(
client_authn_method=oic.utils.authn.client.CLIENT_AUTHN_METHOD,
client_id=get_setting("auth.oidc.client_id"),
)
client.provider_config(get_setting("auth.oidc.issuer"))
client.client_secret = get_setting("auth.oidc.client_secret")
return client


authenticators = [
SessionAuthenticator(user_provider=user_provider),
Expand Down
1 change: 1 addition & 0 deletions irrd/webui/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def template_context_render(template_name, request, context) -> Response:
context["messages"] = get_messages(request)
context["readonly_standby"] = get_setting("readonly_standby")
context["irrd_internal_migration_enabled"] = get_setting("auth.irrd_internal_migration_enabled")
context["oidc_enabled"] = bool(get_setting("auth.oidc"))

context["auth_sources"] = [
name for name, settings in get_setting("sources", {}).items() if settings.get("authoritative")
Expand Down
Loading

0 comments on commit 6420faf

Please sign in to comment.