You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

74 lines
2.3 KiB

import hashlib
import hmac
from datetime import datetime, timedelta, timezone

import pymongo
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from joserfc import jwt
from joserfc.errors import JoseError
from joserfc.jwk import OctKey

import app.config as config
from app.models import User
from app.models.token_data import TokenData
from app.serializers import user_serialize
# Database setup: a single module-level Mongo client, the configured database,
# and the "users" collection handle used by the auth helpers below.
client = pymongo.MongoClient(
    config.MONGODB_URL,
    username=config.MONGODB_USERNAME,
    password=config.MONGODB_PASSWORD,
)
db = client[config.MONGODB_DATABASE]
users_collection = db["users"]
def verify_password(plain_password, hashed_password):
    """Check *plain_password* against a stored hex SHA-256 digest.

    The candidate digest is compared with ``hmac.compare_digest`` so the
    comparison time does not depend on how many leading characters match.

    NOTE(review): unsalted SHA-256 is a fast hash and is weak for password
    storage; migrating to a salted KDF (e.g. ``hashlib.scrypt`` /
    ``hashlib.pbkdf2_hmac``) would require re-hashing stored passwords.

    Args:
        plain_password: The password supplied by the client.
        hashed_password: The stored hex digest produced by the SHA-256 scheme.

    Returns:
        True if the digests match, False otherwise, including when
        *hashed_password* is not a string (e.g. ``None``).
    """
    if not isinstance(hashed_password, str):
        # Mirrors the original `str == non_str` result without raising.
        return False
    candidate = hashlib.sha256(plain_password.encode()).hexdigest()
    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(candidate.encode(), hashed_password.encode())
def get_password_hash(password):
    """Return the hex SHA-256 digest of *password* encoded with ``str.encode()``.

    NOTE(review): this is an unsalted fast hash; see the storage caveat on
    password verification before relying on it for new accounts.
    """
    hasher = hashlib.sha256()
    hasher.update(password.encode())
    return hasher.hexdigest()
# OAuth2 password-flow bearer scheme, used below as a FastAPI dependency to
# extract the bearer token; `tokenUrl` names the token endpoint (config.TOKEN_URL).
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl=config.TOKEN_URL,
)
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """Resolve the user identified by a bearer JWT.

    The token is verified with ``config.SECRET_KEY`` (accepting only
    ``config.ALGORITHM``, the algorithm tokens are signed with in this
    module). Its ``sub`` claim is taken as the username, and its ``exp``
    claim is checked against the current time. The matching document from
    ``users_collection`` is returned through ``user_serialize``.

    Args:
        token: Raw bearer token extracted by ``oauth2_scheme``.

    Returns:
        The serialized user record.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            the token fails verification, lacks a usable ``sub`` or ``exp``
            claim, has expired, or names a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Pin the algorithm so only tokens signed the way this module issues
        # them are accepted.
        payload = jwt.decode(
            token,
            OctKey.import_key(config.SECRET_KEY),
            algorithms=[config.ALGORITHM],
        )
    except JoseError:
        raise credentials_exception from None

    claims = payload.claims
    # .get() instead of indexing: a token missing these claims must be
    # rejected with 401, not crash with KeyError (surfacing as a 500).
    username = claims.get("sub")
    expire_date = claims.get("exp")
    if not isinstance(username, str) or not username:
        raise credentials_exception
    # A non-numeric "exp" would otherwise raise TypeError in the comparison.
    if not isinstance(expire_date, (int, float)):
        raise credentials_exception
    if int(datetime.now(timezone.utc).timestamp()) > expire_date:
        raise credentials_exception
    token_data = TokenData(username=username)

    # NOTE(review): find_one is a blocking pymongo call inside an async
    # handler; it stalls the event loop for the duration of the query.
    user = users_collection.find_one({"username": token_data.username})
    if user is None:
        raise credentials_exception
    return user_serialize(user)
def create_access_token(data: dict, expires_delta: timedelta):
    """Issue a signed JWT carrying *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed. Presumably ``{"sub": username}``; confirm
            with callers. The dict is copied and never mutated; an existing
            ``exp`` key in it is overwritten.
        expires_delta: Token lifetime, added to the current time.

    Returns:
        The encoded JWT, signed with ``config.SECRET_KEY`` using
        ``config.ALGORITHM``.
    """
    to_encode = data.copy()
    # Build the expiry from an aware UTC clock and store it as integer epoch
    # seconds (RFC 7519 NumericDate). A naive local datetime is ambiguous:
    # joserfc appears to convert datetime claims via utctimetuple(), which
    # treats naive values as UTC and so shifts the real expiry by the
    # server's UTC offset.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode["exp"] = int(expire.timestamp())
    header = {"alg": config.ALGORITHM}
    return jwt.encode(header, to_encode, OctKey.import_key(config.SECRET_KEY))
# Exceptions
def friend_not_found():
    """Abort the current request with a 404 "Friend not found" error.

    Always raises; never returns normally.

    Raises:
        HTTPException: status 404 with detail "Friend not found".
    """
    error = HTTPException(
        detail="Friend not found",
        status_code=status.HTTP_404_NOT_FOUND,
    )
    raise error
def objectid_misformatted():
    """Abort the current request with a 422 error for a malformed ObjectID.

    Always raises; never returns normally.

    Raises:
        HTTPException: status 422 with detail "The ObjectID is misformatted".
    """
    error = HTTPException(
        detail="The ObjectID is misformatted",
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
    )
    raise error