103 lines
3.5 KiB
Python
103 lines
3.5 KiB
Python
from authlib.integrations.flask_oauth2 import (
|
|
AuthorizationServer,
|
|
ResourceProtector,
|
|
)
|
|
from authlib.integrations.sqla_oauth2 import (
|
|
create_query_client_func,
|
|
create_save_token_func,
|
|
create_revocation_endpoint,
|
|
create_bearer_token_validator,
|
|
)
|
|
from authlib.oauth2.rfc6749 import grants
|
|
from authlib.oauth2.rfc7636 import CodeChallenge
|
|
from werkzeug.security import gen_salt
|
|
from ..models import db, User
|
|
from ..models import OAuth2Client, OAuth2AuthorizationCode, OAuth2Token
|
|
|
|
|
|
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization-code grant persisted through the SQLAlchemy models.

    Issued codes are stored as ``OAuth2AuthorizationCode`` rows together with
    the PKCE parameters (``code_challenge`` / ``code_challenge_method``) read
    from the authorization request, and are deleted once consumed.
    """

    # Client authentication methods accepted at the token endpoint for this
    # grant. 'none' admits clients that present no secret (public clients).
    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    def save_authorization_code(self, code, request):
        """Store a newly issued authorization code and return the new row."""
        challenge = request.data.get('code_challenge')
        challenge_method = request.data.get('code_challenge_method')
        record = OAuth2AuthorizationCode(
            code=code,
            client_id=request.client.client_id,
            redirect_uri=request.redirect_uri,
            scope=request.scope,
            user_id=request.user.id,
            code_challenge=challenge,
            code_challenge_method=challenge_method,
        )
        db.session.add(record)
        db.session.commit()
        return record

    def query_authorization_code(self, code, client):
        """Return the unexpired code issued to *client*, or None."""
        record = OAuth2AuthorizationCode.query.filter_by(
            code=code,
            client_id=client.client_id,
        ).first()
        # Missing or expired codes are treated the same: no match.
        if not record or record.is_expired():
            return None
        return record

    def delete_authorization_code(self, authorization_code):
        """Remove a consumed authorization code from the database."""
        db.session.delete(authorization_code)
        db.session.commit()

    def authenticate_user(self, authorization_code):
        """Load the ``User`` the authorization code was issued for."""
        owner_id = authorization_code.user_id
        return User.query.get(owner_id)
|
|
|
|
|
|
class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant):
    """Resource-owner password grant checked against the ``User`` table."""

    def authenticate_user(self, username, password):
        """Return the matching user when *password* is correct, else None."""
        candidate = User.query.filter_by(username=username).first()
        if candidate is None or not candidate.check_password(password):
            return None
        return candidate
|
|
|
|
|
|
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh-token grant backed by the ``OAuth2Token`` table.

    ``revoke_old_credential`` flags the presented token as revoked; it is
    presumably invoked by Authlib once a replacement token has been issued.
    """

    def authenticate_refresh_token(self, refresh_token):
        """Return the stored token for *refresh_token* if it is still active."""
        stored = OAuth2Token.query.filter_by(refresh_token=refresh_token).first()
        if not stored or not stored.is_refresh_token_active():
            return None
        return stored

    def authenticate_user(self, credential):
        """Load the ``User`` that owns *credential*."""
        owner_id = credential.user_id
        return User.query.get(owner_id)

    def revoke_old_credential(self, credential):
        """Mark *credential* as revoked and commit the change."""
        credential.revoked = True
        session = db.session
        session.add(credential)
        session.commit()
|
|
|
|
|
|
# Authlib's SQLAlchemy helpers build the client-lookup and token-persistence
# callbacks from the shared session and the model classes.
query_client = create_query_client_func(db.session, OAuth2Client)
save_token = create_save_token_func(db.session, OAuth2Token)

# Module-level singletons; config_oauth() attaches them to the Flask app.
authorization = AuthorizationServer(
    query_client=query_client, save_token=save_token,
)
require_oauth = ResourceProtector()
|
|
|
|
|
|
def config_oauth(app):
    """Attach the OAuth2 authorization server and resource protector to *app*.

    Registers the implicit, client-credentials, authorization-code (with
    mandatory PKCE), password and refresh-token grants, the token revocation
    endpoint, and a bearer-token validator on ``require_oauth``.
    """
    authorization.init_app(app)

    # Grant types. Only the authorization-code grant carries an extension:
    # CodeChallenge(required=True), so PKCE cannot be skipped for that flow.
    # NOTE(review): the OAuth 2.0 Security BCP discourages the implicit and
    # password grants; they are kept here for existing clients.
    for grant_cls in (grants.ImplicitGrant, grants.ClientCredentialsGrant):
        authorization.register_grant(grant_cls)
    pkce = CodeChallenge(required=True)
    authorization.register_grant(AuthorizationCodeGrant, [pkce])
    for grant_cls in (PasswordGrant, RefreshTokenGrant):
        authorization.register_grant(grant_cls)

    # Token revocation endpoint backed by the OAuth2Token table.
    authorization.register_endpoint(
        create_revocation_endpoint(db.session, OAuth2Token)
    )

    # Bearer-token validation for resources protected by require_oauth.
    validator_cls = create_bearer_token_validator(db.session, OAuth2Token)
    require_oauth.register_token_validator(validator_cls())
|