Creating an API with python: Part 9: Authentication Scopes

In my previous post, Creating an API with python: Part 8: Multiple Account Support, I added support for multiple accounts to the FastAPI API. In this post, I will add authentication scopes, which allow a caller of the API to authenticate with either an admin or an account scope. The admin scope has access to all resources, whereas the account scope only has access to the resources owned by the authenticated account.

Prerequisites

These prerequisites are assumed for this post:

  1. Creating an API with python: Part 1: GET Endpoints
  2. Creating an API with python: Part 2: MariaDB Database
  3. Creating an API with python: Part 3: POST Endpoints
  4. Creating an API with python: Part 4: DELETE Endpoints
  5. Creating an API with python: Part 5: Authentication
  6. Creating an API with python: Part 6: HTTPS and Proxying
  7. Creating an API with python: Part 7: CORS
  8. Creating an API with python: Part 8: Multiple Account Support

Step 1: Update schemas.py

Update schemas.py with a new schema class, AuthEntity, and a new Enum, EntityType. Also change the account_id fields in the other classes to be Optional, as they are now only required when authenticating with the admin scope.

  1. Change directory to the manager directory:
    $ cd manager
    
  2. Open the schemas.py file and replace the code with the following:
    from typing import Optional, Union, List
    from enum import Enum
    
    from pydantic import BaseModel, Field
    
    from fastapi import HTTPException
    
    
    class PostLink(BaseModel):
        link: str = Field(..., description="The link URL")
        tag: Optional[str] = Field(None,
                                   description="Tag name to associate with the link (will be created if it doesn't exist)")
        tag_id: Optional[str] = Field(None, description="Tag ID to associate with the link (must already exist)")
        account_id: Optional[str] = Field(
            None,
            description="The account ID. Not required for account scope. Required for admin scope.")
    
    
    class PostTag(BaseModel):
        tag: str = Field(..., description="Tag name")
        account_id: Optional[str] = Field(
            None,
            description="The account ID. Not required for account scope. Required for admin scope.")
    
    
    class PostTagLink(BaseModel):
        tag_id: str = Field(..., description="Tag ID (must already exist)")
        link_id: str = Field(..., description="Link ID (must already exist)")
        account_id: Optional[str] = Field(
            None,
            description="The account ID. Not required for account scope. Required for admin scope.")
    
    
    class PostAccount(BaseModel):
        email: str = Field(..., description="Email Address for the new account")
        password: str = Field(..., description="Password for the new account")
    
    
    class Token(BaseModel):
        access_token: str
        token_type: str
        expires: str
    
    
    class TokenData(BaseModel):
        username: Union[str, None] = None
        scopes: List[str] = []
        account_id: Union[str, None] = None
    
    
    class User(BaseModel):
        user_id: str
        username: str
    
        class Config:
            orm_mode = True
    
    
    class EntityType(str, Enum):
        USER = 'user'
        ACCOUNT = 'account'
    
    
    class AuthEntity(BaseModel):
        entity_type: EntityType
        entity_id: str
        entity_identifier: str
        hashed_password: str
    
        def get_account_id(self):
            # Get the account id if scope is account, otherwise None
            if self.entity_type == EntityType.ACCOUNT:
                return self.entity_id
            return None
    
        def assert_account_id(self, required: bool = True, account_id: Optional[str] = None, code: Optional[int] = 422):
            # Check if a provided account_id matches the AuthEntity account id, if scope account
            valid = True
            msg_422 = "Invalid account_id"
            if account_id is not None and self.entity_type == EntityType.ACCOUNT:
                if account_id != self.entity_id:
                    msg_422 = "The account_id field should not be provided for account scope"
                    valid = False
            # If account_id is required and this is scope user, check we have an account_id
            if self.entity_type == EntityType.USER and required and account_id is None:
                valid = False
                msg_422 = "The account_id field is required for the admin scope"
    
            if not valid:
                if code == 422:
                    raise HTTPException(status_code=422, detail=msg_422)
                elif code == 404:
                    raise HTTPException(status_code=404, detail=f"Account with account_id '{account_id}' not found")
                else:
                    raise HTTPException(status_code=code, detail="Invalid account_id")
    
            return account_id if account_id is not None else self.get_account_id()
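
To make the rules in assert_account_id easier to see, here is a dependency-free sketch of the same decision logic. It is a simplification, not the schema class itself: ScopeError and resolve_account_id are hypothetical names standing in for HTTPException and the method, and the entity type is passed as a plain string.

```python
from typing import Optional


class ScopeError(Exception):
    """Hypothetical stand-in for fastapi.HTTPException in this sketch."""

    def __init__(self, status_code: int, detail: str):
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


def resolve_account_id(entity_type: str, entity_id: str,
                       account_id: Optional[str] = None, required: bool = True) -> Optional[str]:
    # Account scope: a supplied account_id must match the authenticated account
    if entity_type == "account":
        if account_id is not None and account_id != entity_id:
            raise ScopeError(422, "The account_id field should not be provided for account scope")
        return entity_id
    # Admin scope (entity type "user"): the caller has to name the account when required
    if required and account_id is None:
        raise ScopeError(422, "The account_id field is required for the admin scope")
    return account_id


# Account scope with no account_id: falls back to the authenticated account
print(resolve_account_id("account", "acc-1"))
# Admin scope with an optional filter omitted: no account filter is applied
print(resolve_account_id("user", "usr-1", required=False))
```

In short, an account-scoped caller never needs to send account_id, while an admin-scoped caller must send it wherever the endpoint requires one.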
    

Step 2: Update manager.py

Update manager.py with a version of get_accounts that accepts optional email and account_id filters.

  1. Open the manager.py file and replace the get_accounts function with the following:
    def get_accounts(db: Session, email: Optional[str] = None, account_id: Optional[str] = None):
        filters = []
        if email is None and account_id is None:
            # TODO: Implement offset and limit
            return db.query(models.Account).all()
        if email is not None:
            filters.append(models.Account.email == email)
        if account_id is not None:
            filters.append(models.Account.account_id == account_id)
        return db.query(models.Account).filter(*filters).all()
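
The collected filters are combined with AND, so passing both email and account_id only returns an account that matches both. A minimal in-memory sketch of that behaviour, using plain dictionaries instead of SQLAlchemy models (the sample rows and the filter_accounts name are made up for illustration):

```python
from typing import Dict, List, Optional

ACCOUNTS = [
    {"account_id": "a1", "email": "one@example.com"},
    {"account_id": "a2", "email": "two@example.com"},
]


def filter_accounts(rows: List[Dict[str, str]], email: Optional[str] = None,
                    account_id: Optional[str] = None) -> List[Dict[str, str]]:
    # Collect one predicate per supplied argument, like the filters list above
    predicates = []
    if email is not None:
        predicates.append(lambda row: row["email"] == email)
    if account_id is not None:
        predicates.append(lambda row: row["account_id"] == account_id)
    # all() over an empty list is True, so no arguments means every row is returned
    return [row for row in rows if all(p(row) for p in predicates)]


# An account-scoped caller (a1) asking for someone else's email gets an empty list
print(filter_accounts(ACCOUNTS, email="two@example.com", account_id="a1"))
```

This is why an account-scoped caller cannot look up another account by email: the account_id filter is always applied alongside it.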
    
    

Step 3: Update authentication.py

Update authentication.py to retrieve and use security scopes and the new AuthEntity class to determine authentication and authorisation.

  1. Open the authentication.py file and replace the code with the following:
    from datetime import datetime, timedelta
    from typing import Union, Optional, List
    
    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes
    from jose import JWTError, jwt
    from passlib.context import CryptContext
    
    from sqlalchemy.orm import Session
    
    from manager import CONFIG, schemas, models
    from manager.database import get_db
    from manager.schemas import EntityType, AuthEntity
    
    
    AUTH_CONFIG = CONFIG['authentication']
    
    SECRET_KEY = AUTH_CONFIG['secret_key']
    ALGORITHM = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES = 30
    SCOPE_ACCOUNT = 'account'
    SCOPE_ADMIN = 'admin'
    SCOPES = {SCOPE_ACCOUNT: 'API actions for a specific account', SCOPE_ADMIN: 'All API actions'}
    
    
    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
    
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", scopes=SCOPES)
    
    app = FastAPI()
    
    
    class AuthenticationException(Exception):
        def __init__(self, msg: str, *args):
            super().__init__(msg, *args)
    
    
    def verify_password(plain_password, hashed_password):
        return pwd_context.verify(plain_password, hashed_password)
    
    
    def get_password_hash(password):
        return pwd_context.hash(password)
    
    
    def get_user(db: Session, username: str):
        return db.query(models.User).filter(models.User.username == username).first()
    
    
    def get_account(db: Session, email: Optional[str] = None, account_id: Optional[str] = None):
        if email is not None:
            return db.query(models.Account).filter(models.Account.email == email).first()
        elif account_id is not None:
            return db.query(models.Account).filter(models.Account.account_id == account_id).first()
        return None
    
    
    def get_auth_entity(db: Session, identifier: str, security_scopes: List[str], entity_id: Optional[str] = None):
        auth_entity = None
        if SCOPE_ADMIN in security_scopes:
            user = get_user(db, identifier)
            if user:
                auth_entity = AuthEntity(entity_type=EntityType.USER, entity_id=user.user_id, entity_identifier=identifier,
                                         hashed_password=user.hashed_password)
        elif SCOPE_ACCOUNT in security_scopes:
            account = get_account(db, email=identifier, account_id=entity_id)
            if account:
                auth_entity = AuthEntity(entity_type=EntityType.ACCOUNT, entity_id=account.account_id,
                                         entity_identifier=identifier, hashed_password=account.hashed_password)
    
        if not auth_entity:
            scopes = ", ".join(security_scopes)
            msg = f"Cannot get AuthEntity for scopes {scopes} with identifier {identifier}"
            raise AuthenticationException(msg)
        return auth_entity
    
    
    def authenticate(db: Session, identifier: str, password: str, security_scopes: List[str]):
        # Only one scope should be set
        if len(security_scopes) != 1:
            print("Exactly one security scope must be set")
            return False
        try:
            auth_entity = get_auth_entity(db, identifier=identifier, security_scopes=security_scopes)
        except AuthenticationException as ex:
            print('Caught AuthenticationException: ' + str(ex))
            return False
        if not auth_entity:
            print("No auth_entity found")
            return False
        if not verify_password(password, auth_entity.hashed_password):
            print("Password not verified")
            return False
        return auth_entity
    
    
    def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None):
        to_encode = data.copy()
        if expires_delta:
            expire = datetime.utcnow() + expires_delta
        else:
            expire = datetime.utcnow() + timedelta(minutes=15)
        to_encode.update({"exp": expire})
        encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
        return encoded_jwt
    
    
    async def get_current_auth_entity(security_scopes: SecurityScopes, db: Session = Depends(get_db),
                                      token: str = Depends(oauth2_scheme)):
        if security_scopes.scopes:
            authenticate_value = f'Bearer scope="{security_scopes.scope_str}"'
        else:
            authenticate_value = "Bearer"
        credentials_exception = HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": authenticate_value},
        )
        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
            username: str = payload.get("sub")
            if username is None:
                raise credentials_exception
            token_scopes = payload.get("scopes", [])
            account_id = payload.get("account_id", None)
            token_data = schemas.TokenData(username=username, scopes=token_scopes, account_id=account_id)
        except JWTError:
            raise credentials_exception
        # Only one scope should be set
        if len(token_scopes) != 1:
            raise credentials_exception
    
        try:
            auth_entity = get_auth_entity(db, identifier=token_data.username, security_scopes=token_data.scopes,
                                          entity_id=token_data.account_id)
        except AuthenticationException:
            raise credentials_exception
    
        if auth_entity is None:
            raise credentials_exception
        print(security_scopes.scopes)
        print(token_data.scopes)
        token_scope = token_scopes[0]
        if token_scope not in security_scopes.scopes:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Not enough permissions",
                headers={"WWW-Authenticate": authenticate_value},
            )
        return auth_entity
    
    
    async def get_current_active_auth_entity(current_auth_entity: schemas.AuthEntity = Depends(get_current_auth_entity)):
        return current_auth_entity
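
The authorisation decision in get_current_auth_entity comes down to two checks: the token must carry exactly one scope, and that scope must be among the scopes the endpoint declares. A stripped-down sketch of just that decision, with no JWT or database involved (is_scope_allowed is a hypothetical helper, not part of authentication.py):

```python
from typing import List


def is_scope_allowed(token_scopes: List[str], endpoint_scopes: List[str]) -> bool:
    # The token must carry exactly one scope
    if len(token_scopes) != 1:
        return False
    # That scope must be one the endpoint was declared with
    return token_scopes[0] in endpoint_scopes


# An endpoint declared with scopes=["admin"] rejects an account-scoped token
print(is_scope_allowed(["account"], ["admin"]))           # False
# An endpoint declared with both scopes accepts either token
print(is_scope_allowed(["admin"], ["admin", "account"]))  # True
```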
    

Step 4: Update main.py

Update main.py so that each endpoint uses the AuthEntity object instead of the User object and declares which security scopes can access it. Each endpoint also asserts that the account_id has been provided (if required) and, where it has been provided, that it matches the authenticated account. If it hasn’t been provided and the authenticated entity is an account, it is determined automatically from the account details.

  1. Open the main.py file and replace the code with the following:
    from typing import Optional
    from datetime import timedelta, datetime
    
    from sqlalchemy.orm import Session
    
    from fastapi import FastAPI, HTTPException, Depends, status, Security
    
    from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes
    
    from fastapi.middleware.cors import CORSMiddleware
    
    from manager import manager, schemas, authentication, CONFIG
    
    from manager.database import get_db
    
    from manager.authentication import SCOPE_ACCOUNT
    
    from manager.schemas import EntityType
    
    
    app = FastAPI()
    
    origins = [origin for origin in CONFIG['origins']]
    
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    
    
    # Get link by link_id
    @app.get("/link/{link_id}")
    async def get_link(link_id: str, db: Session = Depends(get_db),
                       current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                          scopes=["admin", "account"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        # If account scope, additionally filter by account_id
        db_link = manager.get_link(db, link_id, account_id=current_auth_entity.get_account_id())
        if db_link is None:
            raise HTTPException(status_code=404, detail="Link not found")
        return db_link
    
    
    # Get links by query params
    @app.get("/link/")
    async def get_links(tag_id: Optional[str] = None, tag: Optional[str] = None, account_id: Optional[str] = None,
                        db: Session = Depends(get_db),
                        current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                           scopes=["admin", "account"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        filter_account_id = current_auth_entity.assert_account_id(required=False, account_id=account_id)
        return manager.get_links(db, tag_id, tag, account_id=filter_account_id)
    
    
    # Get tag by tag_id
    @app.get("/tag/{tag_id}")
    async def get_tag(tag_id: str, db: Session = Depends(get_db),
                      current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                         scopes=["admin", "account"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        db_tag = manager.get_tag(db, tag_id, account_id=current_auth_entity.get_account_id())
        if db_tag is None:
            raise HTTPException(status_code=404, detail="Tag not found")
        return db_tag
    
    
    # Get tags by query params
    @app.get("/tag/")
    async def get_tags(tag: Optional[str] = None, account_id: Optional[str] = None, db: Session = Depends(get_db),
                       current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                          scopes=["admin", "account"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        filter_account_id = current_auth_entity.assert_account_id(required=False, account_id=account_id)
        return manager.get_tags(db, tag, account_id=filter_account_id)
    
    
    # Get taglinks by query params
    @app.get("/taglink/")
    async def get_taglinks(link_id: Optional[str] = None, tag_id: Optional[str] = None, account_id: Optional[str] = None,
                           db: Session = Depends(get_db),
                           current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                              scopes=["admin", "account"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        filter_account_id = current_auth_entity.assert_account_id(required=False, account_id=account_id)
        return manager.get_taglinks(db, tag_id, link_id, account_id=filter_account_id)
    
    
    # Post a link
    @app.post("/link/")
    async def post_link(link: schemas.PostLink, db: Session = Depends(get_db),
                        current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                           scopes=["admin", "account"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        current_auth_entity.assert_account_id(required=True, account_id=link.account_id)
        if current_auth_entity.entity_type == EntityType.ACCOUNT:
            link.account_id = current_auth_entity.get_account_id()
    
        if link.tag is None and link.tag_id is None:
            raise HTTPException(status_code=422, detail="One of tag_id or tag must be specified")
    
        if link.tag is not None and link.tag_id is not None:
            raise HTTPException(status_code=422, detail="Only one of tag_id or tag must be specified")
    
        db_link = manager.create_link(db, link)
    
        return db_link
    
    
    # Post a tag
    @app.post("/tag/")
    async def post_tag(tag: schemas.PostTag, db: Session = Depends(get_db),
                       current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                          scopes=["admin", "account"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        current_auth_entity.assert_account_id(required=True, account_id=tag.account_id)
        if current_auth_entity.entity_type == EntityType.ACCOUNT:
            tag.account_id = current_auth_entity.get_account_id()
        db_tag = manager.create_tag(db, tag)
    
        return db_tag
    
    
    # Post a taglink
    @app.post("/taglink/")
    async def post_taglink(taglink: schemas.PostTagLink, db: Session = Depends(get_db),
                           current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                              scopes=["admin", "account"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        current_auth_entity.assert_account_id(required=True, account_id=taglink.account_id)
        if current_auth_entity.entity_type == EntityType.ACCOUNT:
            taglink.account_id = current_auth_entity.get_account_id()
        db_taglink = manager.create_taglink(db, taglink)
    
        return db_taglink
    
    
    # Delete link by link_id
    @app.delete("/link/{link_id}")
    async def delete_link(link_id: str, db: Session = Depends(get_db),
                          current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                             scopes=["admin", "account"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        return manager.delete_link(db, link_id, account_id=current_auth_entity.get_account_id())
    
    
    # Delete tag by tag_id
    @app.delete("/tag/{tag_id}")
    async def delete_tag(tag_id: str, db: Session = Depends(get_db),
                         current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                            scopes=["admin", "account"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        return manager.delete_tag(db, tag_id, account_id=current_auth_entity.get_account_id())
    
    
    # Delete taglinks by query params
    @app.delete("/taglink/")
    async def delete_taglinks(link_id: Optional[str] = None, tag_id: Optional[str] = None,  db: Session = Depends(get_db),
                              current_auth_entity: schemas.AuthEntity = Security(
                                  authentication.get_current_active_auth_entity, scopes=["admin", "account"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        if link_id is None and tag_id is None:
            raise HTTPException(status_code=422, detail="One or both of tag_id and link_id must be specified")
        return manager.delete_taglinks(db, tag_id, link_id, account_id=current_auth_entity.get_account_id())
    
    
    @app.post("/token/", response_model=schemas.Token)
    async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
        auth_entity = authentication.authenticate(db, form_data.username, form_data.password, form_data.scopes)
        if not auth_entity:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )
        access_token_expires = timedelta(minutes=authentication.ACCESS_TOKEN_EXPIRE_MINUTES)
        token_data = {"sub": auth_entity.entity_identifier, "scopes": form_data.scopes, "account_id": None}
        if SCOPE_ACCOUNT in form_data.scopes:
            token_data['account_id'] = auth_entity.entity_id
    
        access_token = authentication.create_access_token(
            data=token_data, expires_delta=access_token_expires
        )
        expires = datetime.utcnow() + access_token_expires
        return {"access_token": access_token, "token_type": "bearer", "expires": expires.isoformat()}
    
    
    # Create a new account
    @app.post("/account/")
    async def post_account(account: schemas.PostAccount, db: Session = Depends(get_db),
                           current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                              scopes=["admin"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        db_account = manager.create_account(db, account)
    
        return db_account
    
    
    # Get accounts by query params
    @app.get("/account/")
    async def get_accounts(email: Optional[str] = None, db: Session = Depends(get_db),
                           current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                              scopes=["admin", "account"])):
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        # Allow account scope to retrieve own account only. Return empty list if email and account_id do not match.
        return manager.get_accounts(db, email, account_id=current_auth_entity.get_account_id())
    
    
    # Get account by account_id
    @app.get("/account/{account_id}")
    async def get_account(account_id: str, db: Session = Depends(get_db),
                          current_auth_entity: schemas.AuthEntity = Security(authentication.get_current_active_auth_entity,
                                                                             scopes=["admin", "account"])):
        # Allow account scope to get own account only
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        current_auth_entity.assert_account_id(required=True, account_id=account_id, code=404)
        db_account = manager.get_account(db, account_id)
        if db_account is None:
            raise HTTPException(status_code=404, detail="Account not found")
        return db_account
    
    
    # Delete account by account_id
    @app.delete("/account/{account_id}")
    async def delete_account(account_id: str, db: Session = Depends(get_db),
                             current_auth_entity: schemas.AuthEntity = Security(
                                 authentication.get_current_active_auth_entity, scopes=["admin", "account"])):
        # Allow account scope to delete own account only. Return 404 for other accounts.
        print(f"authenticated as {current_auth_entity.entity_identifier}")
        current_auth_entity.assert_account_id(required=True, account_id=account_id, code=404)
        return manager.delete_account(db, account_id)
    
    

Step 5: Start FastAPI

  1. On your server, change to the code directory (~/vboxshare/fastapi should be replaced with the path to your FastAPI python code):
    $ cd ~/vboxshare/fastapi
    
  2. Run the FastAPI server:
    $ . ~/.venv-fastapi/bin/activate
    (.venv-fastapi) $ uvicorn --host 0.0.0.0 main:app --root-path /api --reload

Step 6: Test the API

Navigate to https://<YOUR_IP>/api/docs (replace YOUR_IP with your server IP). Authenticate as the admin user (the original user you created), selecting only the admin scope, and create an account (POST /account/). Then log out and authenticate as the account user (with the email and password that you set for the account), selecting only the account scope. The API rejects a login that requests no scope or more than one. You should find that, as the account user, you can only retrieve and delete resources belonging to that account.
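
If you would rather test from a script than from the docs page, the token request is a standard form-encoded OAuth2 password request in which the scope field names the single scope you want. A sketch of building that request body with the standard library (the email and password are placeholders, build_token_form is a made-up helper name, and nothing is actually sent):

```python
from urllib.parse import urlencode


def build_token_form(username: str, password: str, scope: str) -> str:
    # OAuth2PasswordRequestForm reads scope as a space-separated string;
    # this API expects exactly one scope, "admin" or "account"
    return urlencode({
        "grant_type": "password",
        "username": username,
        "password": password,
        "scope": scope,
    })


body = build_token_form("me@example.com", "secret", "account")
print(body)
```

POST this body to the token endpoint with a Content-Type of application/x-www-form-urlencoded, then send the returned access_token in an Authorization: Bearer header on later requests.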

Conclusion

You should now have a FastAPI API running with authentication scopes, which means that different levels of access are enabled depending on who is authenticating.

If you want to find out about adding integration tests to the API, see my follow-on post, Creating an API with python: Part 10: Integration Tests.

Thanks for reading!