Nest JS Websockets - Guards and Authorization with CASL

Author: Austin Howard

Welcome to Part 4 of our series on building a realtime chat application with Nest JS and React: Guards and Authorization with CASL.

In this part of the series, we're going to introduce a new feature that allows room hosts to kick users out of their room. To support it, we need an authorization strategy on the server, because only room hosts should be able to kick others. Imagine a chat room where any user could kick any other user! That would be absolute chaos. We've also done some light refactoring on the server, added a new kick button to the client, and added a new socket event to handle kicking users.

This series will attempt to touch on as many features of Nest's websockets integration, and of socket.io in general, as possible.

The entire series builds on one repository that contains both the server and the client for the application, so for each part you can download the repository at a specific commit to follow along, or you can clone the entire repository in its finished state.

Download the repository at this commit to follow along. We're not going to cover every change, so that we can keep the focus on authorization.

Authorization

Authorization, not to be confused with authentication, is the process that determines what a user is allowed to do in an application. In this part of the series we are not covering authentication, which is more about user identity and general access to the application.

The Nest documentation on authorization covers three strategies:

  1. Basic RBAC implementation (role-based access control).
  2. Claims-based authorization.
  3. Attribute-based access control with the CASL library.

For our use case, allowing only room hosts to kick other users out of the room, the RBAC and claims-based strategies aren't going to be enough: being a host is not a global role or claim that a user carries, but a relationship between a particular user and a particular room.

Attribute-based access control is a more powerful strategy for building context-specific authorization rules, so that's what we'll use.

We will use the @casl/ability package (version ^6.3.3) to build out our new feature.

Guards

Guards in Nest JS are kind of like bouncers at a bar.

Before anyone goes into the bar, a bouncer might:

  • Check the person’s identification to make sure they’re old enough to drink alcohol (in the United States at least).
  • Check that the person is up to par with the dress code.
  • Make sure the person isn’t already too intoxicated.
  • Check any other arbitrary requirements they may have for entrance.

Guards work the same way, except that they inspect request data. Based on the rules of the authorization strategy, a guard will typically get a reference to the client/user and to the target resource, and run that relationship through the relevant rule checks. If one of those checks fails, the guard should fail with a forbidden exception, indicating that the client is not allowed to access the resource or to perform a certain action on it.
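To make the bouncer analogy concrete, here is a minimal, framework-free sketch. It does not use Nest at all: Client, Venue, Check, ForbiddenError, and admit are hypothetical names chosen for illustration. Each check looks at the client and the target resource, and the first failing check rejects the request.

```typescript
// Hypothetical, framework-free model of what a guard does.
// None of these names come from Nest; they only illustrate the idea.

interface Client {
  userId: string
  age: number
}

interface Venue {
  name: string
  minimumAge: number
  bannedUserIds: string[]
}

// A check inspects the client and the resource and returns true to allow.
type Check = (client: Client, venue: Venue) => boolean

class ForbiddenError extends Error {
  constructor(reason: string) {
    super(`Forbidden: ${reason}`)
    this.name = 'ForbiddenError'
  }
}

const checks: Record<string, Check> = {
  oldEnough: (client, venue) => client.age >= venue.minimumAge,
  notBanned: (client, venue) => !venue.bannedUserIds.includes(client.userId),
}

// Runs every check in order and throws on the first one that fails,
// the same way a guard rejects a request with a forbidden exception.
function admit(client: Client, venue: Venue): true {
  for (const [name, check] of Object.entries(checks)) {
    if (!check(client, venue)) {
      throw new ForbiddenError(name)
    }
  }
  return true
}
```

The rejection carries the name of the failed check, which mirrors how a real guard can attach a reason to its forbidden exception.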

Request Lifecycle

In the last part of this series, we created pipes for our chat websocket gateway. Since we're now adding a guard, we should review the request lifecycle for a Nest JS server.

The request lifecycle has a bit more nuance than the graphic below shows, but for brevity it covers the general flow. For more depth on the request lifecycle, check out the Nest docs.

The pipes we added in the last part execute just before our gateway’s handler functions to validate the form of the input data.

We have yet to add any middleware or interceptors, but what we want now is a guard to validate that a user can access/mutate resources.

Nest JS request lifecycle diagram
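The most important detail in the diagram for this part is ordering: Nest runs guards before pipes, so our guard will see the payload exactly as the client sent it, before the Zod validation pipe from the last part has checked it. The sketch below is a simplified, hypothetical model of that ordering (runHandler and its parameters are made up; it is not Nest's implementation and it leaves out interceptors and exception filters).

```typescript
// Simplified, hypothetical model of the order Nest applies enhancers
// to a gateway handler: guards first, then pipes, then the handler.
// Interceptors and exception filters are omitted for brevity.

type Guard<T> = (payload: T) => boolean
type Pipe<T> = (payload: T) => T
type Handler<T, R> = (payload: T) => R

function runHandler<T, R>(
  payload: T,
  guards: Guard<T>[],
  pipes: Pipe<T>[],
  handler: Handler<T, R>,
): R {
  // 1. Guards see the raw, unvalidated payload.
  for (const guard of guards) {
    if (!guard(payload)) {
      throw new Error('Forbidden')
    }
  }
  // 2. Pipes validate/transform it afterwards, in order.
  const validated = pipes.reduce((value, pipe) => pipe(value), payload)
  // 3. Only then does the handler run.
  return handler(validated)
}
```

A practical consequence: a guard should not assume that fields it reads from the payload have already passed schema validation.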

The Implementation

We will not be covering the entirety of the application to avoid rehashing all of the work done in the previous parts of the series - so let’s focus on what’s new.

Entities

Our chat application really has two entities - user and room. Even though we don’t have any kind of real database yet, we do have temporary data stores in the services themselves. If we were modeling a relational database for this application, we would most likely have at least two tables - a rooms table and a users table. Our entities would then describe the attributes that a row in each of the tables would consist of.

For this reason it makes sense for us to introduce these two entities.

First, we have the Room entity, which implements the shared Room type (imported as RoomType) that we had already written.

src/server/entities/room.entity.ts

import { RoomName, Room as RoomType, User } from '../../shared/interfaces/chat.interface'

export class Room implements RoomType {
  constructor(attrs: RoomType) {
    Object.assign(this, attrs)
  }
  name: RoomName
  host: User
  users: User[]
}

Second, we have the User entity, which implements the shared User type (imported as UserType).

src/server/entities/user.entity.ts

import {
  SocketId,
  User as UserType,
  UserId,
  UserName,
} from '../../shared/interfaces/chat.interface'

export class User implements UserType {
  constructor(attrs: UserType) {
    Object.assign(this, attrs)
  }
  userId: UserId
  userName: UserName
  socketId: SocketId
}
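Why classes rather than the plain types we already had? The ability factory later in this part tells CASL to detect a subject's type from object.constructor. The dependency-free sketch below shows what that means: an instance created with new Room(...) reports Room as its constructor, while an object literal with the same fields reports Object. The trimmed-down interfaces are simplified stand-ins for the shared Zod-inferred types, and the sketch assumes RoomService hands the guard Room instances rather than plain objects.

```typescript
// Simplified stand-ins for the shared interfaces (assumption: the real
// ones are inferred from Zod schemas and carry the same fields).
interface UserType {
  userId: string
  userName: string
  socketId: string
}

interface RoomType {
  name: string
  host: UserType
  users: UserType[]
}

class Room implements RoomType {
  name!: string
  host!: UserType
  users!: UserType[]
  constructor(attrs: RoomType) {
    Object.assign(this, attrs)
  }
}

// Mirrors the detectSubjectType callback used in the ability factory:
// the subject type is whichever class constructed the object.
const detectSubjectType = (object: object) => object.constructor

const host: UserType = { userId: 'u1', userName: 'Austin', socketId: 'a'.repeat(20) }
const attrs: RoomType = { name: 'lobby', host, users: [host] }

const entity = new Room(attrs) // constructor is Room
const plain = { ...attrs } // constructor is Object, so Room rules would not apply
```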

CASL Ability Factory

With the entities and the service refactoring out of the way, let's get to the main topic of this part of the series: using CASL to define some authorization rules.

As mentioned earlier, CASL employs an authorization strategy called attribute-based access control. This means we are able to build rules around specific attributes of entities.

src/server/casl/casl-ability.factory.ts

import {
  AbilityBuilder,
  createMongoAbility,
  ExtractSubjectType,
  InferSubjects,
  MongoAbility,
} from '@casl/ability'
import { Injectable } from '@nestjs/common'
import { Room } from '../entities/room.entity'
import { User } from '../entities/user.entity'

export enum Action {
  Kick = 'kick',
  Join = 'join',
  Message = 'message',
}

type Subjects = InferSubjects<typeof Room | typeof User> | 'all'
export type AppAbility = MongoAbility<[Action, Subjects]>
type FlatRoom = Room & {
  'host.userId': Room['host']['userId']
}

@Injectable()
export class CaslAbilityFactory {
  createForUser(user: User) {
    const { can, build } = new AbilityBuilder<AppAbility>(createMongoAbility)

    // Host can kick users from room
    can<FlatRoom>(Action.Kick, Room, {
      'host.userId': user.userId,
    })

    // Any user can join any room
    can(Action.Join, Room)

    // User can send messages in room given they are in the room
    can(Action.Message, Room, {
      users: { $elemMatch: { userId: user.userId } },
    })

    return build({
      detectSubjectType: (object) => object.constructor as ExtractSubjectType<Subjects>,
    })
  }
}

Let’s step through our ability factory.

With CASL we should define a set of actions that can be performed on a set of subjects.

First we start out by defining an enum of the possible actions a user can perform.

A user can join a room, kick another user from a room, and send a message to a room.

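Next, we define some key types to use in the factory. Subjects lists what rules can be written about (Room and User classes or instances, plus CASL's special 'all'), AppAbility ties the actions and subjects together, and FlatRoom adds the dotted key 'host.userId' to Room so that TypeScript accepts the nested-field condition used in the kick rule.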

export enum Action {
  Kick = 'kick',
  Join = 'join',
  Message = 'message',
}

type Subjects = InferSubjects<typeof Room | typeof User> | 'all'
export type AppAbility = MongoAbility<[Action, Subjects]>
type FlatRoom = Room & {
  'host.userId': Room['host']['userId']
}
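The dotted key 'host.userId' is read, MongoDB-style, as "the userId field inside host". The hypothetical getByPath helper below only illustrates how such a path resolves against a room object; CASL performs its own resolution internally and does not expose this function.

```typescript
// Hypothetical helper: resolves a dotted path like 'host.userId' against a
// nested object. Simplified: it ignores MongoDB's special matching rules
// for paths that pass through arrays.
function getByPath(source: unknown, path: string): unknown {
  return path.split('.').reduce<unknown>((value, key) => {
    if (value !== null && typeof value === 'object' && key in value) {
      return (value as Record<string, unknown>)[key]
    }
    return undefined
  }, source)
}

const room = {
  name: 'lobby',
  host: { userId: 'u1', userName: 'Austin' },
  users: [{ userId: 'u1' }, { userId: 'u2' }],
}

// The kick rule's condition { 'host.userId': user.userId } boils down to:
const hostMatches = (userId: string) => getByPath(room, 'host.userId') === userId
```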

Our factory implements a createForUser method that takes a user and builds the set of abilities that user has; some of those abilities depend on attributes of the room being acted on. CASL's AbilityBuilder class gives us can, cannot (although we aren't using it here), and build, which we use to define what a user can and cannot do.

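Under the hood, abilities created with createMongoAbility express rule conditions in MongoDB query syntax (for example $elemMatch, or dotted paths such as 'host.userId'). You don't need to understand MongoDB, and we of course aren't using a database; the query syntax is simply a powerful way to write rules against entity attributes.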

@Injectable()
export class CaslAbilityFactory {
  createForUser(user: User) {
    const { can, build } = new AbilityBuilder<AppAbility>(createMongoAbility)

    // Host can kick users from room
    can<FlatRoom>(Action.Kick, Room, {
      'host.userId': user.userId,
    })

    // Any user can join any room
    can(Action.Join, Room)

    // User can send messages in room given they are in the room
    can(Action.Message, Room, {
      users: { $elemMatch: { userId: user.userId } },
    })

    return build({
      detectSubjectType: (object) => object.constructor as ExtractSubjectType<Subjects>,
    })
  }
}
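To see these rules in action, here is a self-contained sketch that rebuilds the factory outside of Nest (no @Injectable(), simplified entity classes in place of the shared Zod types, and a const object standing in for the Action enum) and asks the resulting abilities a few questions. It assumes @casl/ability v6 is installed; the users and room are made up.

```typescript
import {
  AbilityBuilder,
  createMongoAbility,
  ExtractSubjectType,
  InferSubjects,
  MongoAbility,
} from '@casl/ability'

// Same values as the Action enum in the factory.
const Action = { Kick: 'kick', Join: 'join', Message: 'message' } as const
type Action = (typeof Action)[keyof typeof Action]

interface UserAttrs {
  userId: string
  userName: string
  socketId: string
}

class User implements UserAttrs {
  userId!: string
  userName!: string
  socketId!: string
  constructor(attrs: UserAttrs) {
    Object.assign(this, attrs)
  }
}

interface RoomAttrs {
  name: string
  host: UserAttrs
  users: UserAttrs[]
}

class Room implements RoomAttrs {
  name!: string
  host!: UserAttrs
  users!: UserAttrs[]
  constructor(attrs: RoomAttrs) {
    Object.assign(this, attrs)
  }
}

type Subjects = InferSubjects<typeof Room | typeof User> | 'all'
type AppAbility = MongoAbility<[Action, Subjects]>
type FlatRoom = Room & { 'host.userId': Room['host']['userId'] }

// Same rules as CaslAbilityFactory.createForUser, as a plain function.
function createForUser(user: User): AppAbility {
  const { can, build } = new AbilityBuilder<AppAbility>(createMongoAbility)
  can<FlatRoom>(Action.Kick, Room, { 'host.userId': user.userId })
  can(Action.Join, Room)
  can(Action.Message, Room, { users: { $elemMatch: { userId: user.userId } } })
  return build({
    detectSubjectType: (object) => object.constructor as ExtractSubjectType<Subjects>,
  })
}

const austin = new User({ userId: 'u1', userName: 'Austin', socketId: 's1' })
const meeko = new User({ userId: 'u2', userName: 'Meeko', socketId: 's2' })
const stranger = new User({ userId: 'u3', userName: 'Stranger', socketId: 's3' })
const lobby = new Room({ name: 'lobby', host: austin, users: [austin, meeko] })
```

Only Austin (the host) may kick, Meeko may message because he is in users, and anyone may join. A spread copy of the room ({ ...lobby }) is a plain object whose constructor is Object, so none of the Room rules match it.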

In our CASL module, we need to make sure the factory is included in both the providers and the exports arrays. Exporting it lets any module that imports CaslModule inject the factory into its own classes, which we will do shortly.

src/server/casl/casl.module.ts

import { Module } from '@nestjs/common'
import { CaslAbilityFactory } from './casl-ability.factory'

@Module({
  providers: [CaslAbilityFactory],
  exports: [CaslAbilityFactory],
})
export class CaslModule {}

Finally we also have some policy types and interfaces that we will be using shortly.

src/server/casl/interfaces/policy.interface.ts

import { AppAbility } from '../casl-ability.factory'

interface IPolicyHandler {
  handle(ability: AppAbility): boolean
}

type PolicyHandlerCallback = (ability: AppAbility) => boolean

export type PolicyHandler = IPolicyHandler | PolicyHandlerCallback
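A PolicyHandler can therefore be either an object with a handle method or a bare callback, and whatever runs the handlers has to tell them apart. Here is a dependency-free sketch of both forms. AbilityLike, KickPolicyHandler, runPolicyHandler, and joinOnly are hypothetical stand-ins (the real AppAbility comes from CASL); the dispatch mirrors the execPolicyHandler helper in the guard.

```typescript
// Hypothetical stand-in for AppAbility: just enough surface to call can().
interface AbilityLike {
  can(action: string, subject: unknown): boolean
}

interface IPolicyHandler {
  handle(ability: AbilityLike): boolean
}

type PolicyHandlerCallback = (ability: AbilityLike) => boolean

type PolicyHandler = IPolicyHandler | PolicyHandlerCallback

// Class form: useful when a policy needs its own data or dependencies.
class KickPolicyHandler implements IPolicyHandler {
  private readonly room: unknown
  constructor(room: unknown) {
    this.room = room
  }
  handle(ability: AbilityLike): boolean {
    return ability.can('kick', this.room)
  }
}

// Call plain functions directly, otherwise delegate to handle().
function runPolicyHandler(handler: PolicyHandler, ability: AbilityLike): boolean {
  return typeof handler === 'function' ? handler(ability) : handler.handle(ability)
}

// A fake ability that only allows joining.
const joinOnly: AbilityLike = {
  can: (action) => action === 'join',
}
```

Class instances report typeof 'object', so the typeof check is enough to route each form correctly.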

Chat Policy Guard

Now that we have our ability factory, we can build a guard for our chat websocket gateway that processes incoming payloads and either rejects the action or lets it pass through if the user is authorized.

src/server/chat/guards/chat.guard.ts

import { CanActivate, ExecutionContext, ForbiddenException, Injectable } from '@nestjs/common'
import { Action, AppAbility, CaslAbilityFactory } from '../../casl/casl-ability.factory'
import { RoomService } from '../../room/room.service'
import { PolicyHandler } from '../../casl/interfaces/policy.interface'
import {
  ClientToServerEvents,
  Room as RoomType,
  User,
} from '../../../shared/interfaces/chat.interface'
import { Room } from '../../entities/room.entity'

@Injectable()
export class ChatPoliciesGuard<
  CtxData extends {
    user: User
    roomName: RoomType['name']
    eventName: keyof ClientToServerEvents
  }
> implements CanActivate
{
  constructor(private caslAbilityFactory: CaslAbilityFactory, private roomService: RoomService) {}

  async canActivate(context: ExecutionContext): Promise<boolean> {
    const policyHandlers: PolicyHandler[] = []
    const ctx = context.switchToWs()
    const data = ctx.getData<CtxData>()
    const user = data.user
    const room = await this.roomService.getRoomByName(data.roomName)

    if (data.eventName === 'kick_user') {
      if (room === 'Not Exists') {
        throw `Room must exist to evaluate ${data.eventName} policy`
      }
      policyHandlers.push((ability) => ability.can(Action.Kick, room))
    }

    if (data.eventName === 'join_room') {
      policyHandlers.push((ability) => ability.can(Action.Join, Room))
    }

    if (data.eventName === 'chat') {
      if (room === 'Not Exists') {
        throw `Room must exist to evaluate ${data.eventName} policy`
      }
      policyHandlers.push((ability) => ability.can(Action.Message, room))
    }

    const ability = this.caslAbilityFactory.createForUser(user)
    policyHandlers.every((handler) => {
      const check = this.execPolicyHandler(handler, ability)
      if (check === false) {
        throw new ForbiddenException()
      }
      // Keep iterating: every() stops at the first falsy (including undefined) return
      return true
    })
    return true
  }

  private execPolicyHandler(handler: PolicyHandler, ability: AppAbility) {
    if (typeof handler === 'function') {
      return handler(ability)
    }
    return handler.handle(ability)
  }
}

Our chat guard looks complex, but if we look at the code piece by piece it’s actually quite simple!

First, guards are injectable, so we use the @Injectable() class decorator. We make ChatPoliciesGuard a generic class with a CtxData type parameter so that each usage can specify the shape of the payload it guards. To check rules with our caslAbilityFactory, we need the user (whose abilities we build), the roomName (to look up the room, which is the subject of most checks), and the eventName (which maps to the Action enum we created earlier). To ensure that CtxData includes these fields, we constrain the type parameter with extends.

@Injectable()
export class ChatPoliciesGuard<
  CtxData extends {
    user: User;
    roomName: RoomType['name'];
    eventName: keyof ClientToServerEvents;
  },
> implements CanActivate
{
  constructor(
    private caslAbilityFactory: CaslAbilityFactory,
    private roomService: RoomService,
  ) {}
//.....

Guards implement a canActivate method that receives the execution context, Nest's wrapper around the current request. The method returns (here, resolves to) a boolean: false means the user did not pass the authorization checks, and true means they did. policyHandlers is an array of CASL ability checks that we use to evaluate our authorization rules.

Since our application is a websockets application, we switch to the websockets context by calling const ctx = context.switchToWs();. This is important because it lets us extract the incoming event payload, which gives us all the data we need to check rules against the ability factory we created earlier. Based on the eventName, we push an ability check handler onto policyHandlers, and at the end of the method we loop through the handlers. If one of the handlers returns false, we throw a new ForbiddenException.

async canActivate(context: ExecutionContext): Promise<boolean> {
  const policyHandlers: PolicyHandler[] = [];
  const ctx = context.switchToWs();
  const data = ctx.getData<CtxData>();
  const user = data.user;
  const room = await this.roomService.getRoomByName(data.roomName);

  if (data.eventName === 'kick_user') {
    if (room === 'Not Exists') {
      throw `Room must exist to evaluate ${data.eventName} policy`;
    }
    policyHandlers.push((ability) => ability.can(Action.Kick, room));
  }

  if (data.eventName === 'join_room') {
    policyHandlers.push((ability) => ability.can(Action.Join, Room));
  }

  if (data.eventName === 'chat') {
    if (room === 'Not Exists') {
      throw `Room must exist to evaluate ${data.eventName} policy`;
    }
    policyHandlers.push((ability) => ability.can(Action.Message, room));
  }

  const ability = this.caslAbilityFactory.createForUser(user);
  policyHandlers.every((handler) => {
    const check = this.execPolicyHandler(handler, ability);
    if (check === false) {
      throw new ForbiddenException();
    }
    // Keep iterating: every() stops at the first falsy (including undefined) return
    return true;
  });
  return true;
}

Finally, here is the helper function we use to run the policy handler functions.

private execPolicyHandler(handler: PolicyHandler, ability: AppAbility) {
  if (typeof handler === 'function') {
    return handler(ability);
  }
  return handler.handle(ability);
}
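One subtlety in the loop at the end of canActivate: Array.prototype.every stops as soon as its callback returns a falsy value, and a callback with no return statement returns undefined. Without an explicit return true, every() only ever evaluates the first policy handler. That is harmless while each event pushes a single handler, but a second handler would be silently skipped. This dependency-free sketch demonstrates the difference with two made-up handlers, the second of which denies.

```typescript
// Two hypothetical policy checks: the first allows, the second denies.
const handlers: Array<() => boolean> = [() => true, () => false]

// Callback with no return value: every() sees undefined (falsy) after the
// first handler and stops, so the denying handler never runs.
function evaluateWithoutReturn(): { calls: number; denied: boolean } {
  let calls = 0
  let denied = false
  handlers.every((handler) => {
    calls++
    if (handler() === false) {
      denied = true
    }
  })
  return { calls, denied }
}

// Returning true keeps every() iterating, so every handler is checked.
function evaluateWithReturn(): { calls: number; denied: boolean } {
  let calls = 0
  let denied = false
  handlers.every((handler) => {
    calls++
    if (handler() === false) {
      denied = true
    }
    return true
  })
  return { calls, denied }
}
```

Another idiomatic option is to return the check itself, as in handlers.every((handler) => handler()), which stops at the first denial and yields a boolean you can act on.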

Chat Websocket Gateway

Again, we highly encourage going back to the last part of the series to get up to speed on how our chat gateway works, but here it is.

In our gateway event handler functions, we can now apply the ChatPoliciesGuard we created with @UseGuards(ChatPoliciesGuard<PayloadType>), where PayloadType is the payload type of that handler (Message, JoinRoom, or KickUser).

src/server/chat/chat.gateway.ts

import {
  MessageBody,
  SubscribeMessage,
  WebSocketGateway,
  WebSocketServer,
  OnGatewayConnection,
  OnGatewayDisconnect,
} from '@nestjs/websockets'
import { Logger, UseGuards, UsePipes } from '@nestjs/common'
import {
  ServerToClientEvents,
  ClientToServerEvents,
  Message,
  JoinRoom,
  KickUser,
} from '../../shared/interfaces/chat.interface'
import { Server, Socket } from 'socket.io'
import { RoomService } from '../room/room.service'
import { ZodValidationPipe } from '../pipes/zod.pipe'
import { ChatMessageSchema, JoinRoomSchema, KickUserSchema } from '../../shared/schemas/chat.schema'
import { UserService } from '../user/user.service'
import { ChatPoliciesGuard } from './guards/chat.guard'

@WebSocketGateway({
  cors: {
    origin: '*',
  },
})
export class ChatGateway implements OnGatewayConnection, OnGatewayDisconnect {
  constructor(private roomService: RoomService, private userService: UserService) {}

  @WebSocketServer() server: Server = new Server<ServerToClientEvents, ClientToServerEvents>()

  private logger = new Logger('ChatGateway')

  @UseGuards(ChatPoliciesGuard<Message>)
  @UsePipes(new ZodValidationPipe(ChatMessageSchema))
  @SubscribeMessage('chat')
  async handleChatEvent(
    @MessageBody()
    payload: Message
  ): Promise<void> {
    this.logger.log(payload)
    this.server.to(payload.roomName).emit('chat', payload) // broadcast messages
  }

  @UseGuards(ChatPoliciesGuard<JoinRoom>)
  @UsePipes(new ZodValidationPipe(JoinRoomSchema))
  @SubscribeMessage('join_room')
  async handleSetClientDataEvent(
    @MessageBody()
    payload: JoinRoom
  ): Promise<void> {
    if (payload.user.socketId) {
      this.logger.log(`${payload.user.socketId} is joining ${payload.roomName}`)
      await this.userService.addUser(payload.user)
      await this.server.in(payload.user.socketId).socketsJoin(payload.roomName)
      await this.roomService.addUserToRoom(payload.roomName, payload.user.userId)
    }
  }

  @UseGuards(ChatPoliciesGuard<KickUser>)
  @UsePipes(new ZodValidationPipe(KickUserSchema))
  @SubscribeMessage('kick_user')
  async handleKickUserEvent(@MessageBody() payload: KickUser): Promise<boolean> {
    this.logger.log(`${payload.userToKick.userName} is getting kicked from ${payload.roomName}`)
    await this.server.to(payload.roomName).emit('kick_user', payload)
    await this.server.in(payload.userToKick.socketId).socketsLeave(payload.roomName)
    await this.server.to(payload.roomName).emit('chat', {
      user: {
        userId: 'serverId',
        userName: 'TheServer',
        socketId: 'ServerSocketId',
      },
      timeSent: new Date(Date.now()).toLocaleString('en-US'),
      message: `${payload.userToKick.userName} was kicked.`,
      roomName: payload.roomName,
    })
    return true
  }

  async handleConnection(socket: Socket): Promise<void> {
    this.logger.log(`Socket connected: ${socket.id}`)
  }

  async handleDisconnect(socket: Socket): Promise<void> {
    const user = await this.roomService.getFirstInstanceOfUser(socket.id)
    if (user) {
      await this.userService.removeUserById(user.userId)
    }
    await this.roomService.removeUserFromAllRooms(socket.id)
    this.logger.log(`Socket disconnected: ${socket.id}`)
  }
}
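The ChatPoliciesGuard<Message> syntax inside @UseGuards() is a TypeScript instantiation expression (TypeScript 4.7 and later). The type argument is erased at compile time, so at runtime every handler receives the very same guard class; what the type argument buys is a compile-time check that the payload type satisfies the CtxData constraint. The dependency-free sketch below uses a hypothetical PoliciesGuard class to show both halves.

```typescript
// Hypothetical minimal version of the guard's generic constraint.
interface GuardData {
  user: { userId: string }
  roomName: string
  eventName: 'chat' | 'join_room' | 'kick_user'
}

class PoliciesGuard<CtxData extends GuardData> {
  describe(data: CtxData): string {
    return `${data.user.userId} -> ${data.eventName} in ${data.roomName}`
  }
}

type ChatPayload = GuardData & { message: string }

// Instantiation expression: compiles because ChatPayload satisfies GuardData.
const ChatGuard = PoliciesGuard<ChatPayload>

// This line would NOT compile, because { message: string } lacks
// user, roomName, and eventName:
// const BrokenGuard = PoliciesGuard<{ message: string }>
```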

Schemas and Interfaces

Now that our websocket message payloads must include the user, roomName, and eventName for our chat policies guard to work, we need to expand on our schemas and interfaces.

src/shared/schemas/chat.schema.ts

import { z } from 'zod'

export const UserIdSchema = z.string().min(1).max(24)

export const UserNameSchema = z
  .string()
  .min(1, { message: 'Must be at least 1 character.' })
  .max(16, { message: 'Must be at most 16 characters.' })

export const MessageSchema = z
  .string()
  .min(1, { message: 'Must be at least 1 character.' })
  .max(1000, { message: 'Must be at most 1000 characters.' })

export const TimeSentSchema = z.string()

export const RoomNameSchemaRegex = new RegExp('^\\S+\\w$')

export const RoomNameSchema = z
  .string()
  .min(2, { message: 'Must be at least 2 characters.' })
  .max(16, { message: 'Must be at most 16 characters.' })
  .regex(RoomNameSchemaRegex, {
    message: 'Must not contain spaces or special characters.',
  })

export const EventNameSchema = z.enum(['chat', 'kick_user', 'join_room'])

export const SocketIdSchema = z.string().length(20, { message: 'Must be 20 characters.' })

export const UserSchema = z.object({
  userId: UserIdSchema,
  userName: UserNameSchema,
  socketId: SocketIdSchema,
})

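export const ChatMessageSchema = z.object({
  user: UserSchema,
  timeSent: TimeSentSchema,
  message: MessageSchema,
  roomName: RoomNameSchema,
  // Pin eventName to this event: the guard picks its policy from eventName,
  // so a 'chat' payload must not be able to claim a different event
  eventName: z.literal('chat'),
})

export const RoomSchema = z.object({
  name: RoomNameSchema,
  host: UserSchema,
  users: UserSchema.array(),
})

export const JoinRoomSchema = z.object({
  user: UserSchema,
  roomName: RoomNameSchema,
  eventName: z.literal('join_room'),
})

export const KickUserSchema = z.object({
  user: UserSchema,
  userToKick: UserSchema,
  roomName: RoomNameSchema,
  // Without this, a non-host could label a kick_user payload 'join_room'
  // and have the permissive Join policy evaluated instead of Kick
  eventName: z.literal('kick_user'),
})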

export const ClientToServerEventsSchema = z.object({
  chat: z.function().args(ChatMessageSchema).returns(z.void()),
  join_room: z.function().args(JoinRoomSchema).returns(z.void()),
  kick_user: z
    .function()
    .args(KickUserSchema, z.function().args(z.boolean()).returns(z.void()))
    .returns(z.void()),
})

export const ServerToClientEventsSchema = z.object({
  chat: z.function().args(ChatMessageSchema).returns(z.void()),
  kick_user: z.function().args(KickUserSchema).returns(z.void()),
})
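It is worth being careful about where eventName comes from: the guard picks its policy from data.eventName, a value the client controls. Unless something else ties that value to the event actually being handled, a payload sent on kick_user but labeled join_room is evaluated against the permissive Join policy. One way to close that gap is to have each payload schema accept only its own event name (for example with z.literal); another is to compare against the subscribed event. The dependency-free sketch below models the guard's branching; Payload, guardAllows, and guardAllowsFor are hypothetical, and the boolean fields stand in for the real CASL checks.

```typescript
type EventName = 'chat' | 'join_room' | 'kick_user'

interface Payload {
  eventName: EventName
  isHost: boolean // stand-in for ability.can(Action.Kick, room)
  isMember: boolean // stand-in for ability.can(Action.Message, room)
}

// Hypothetical model of the guard's branching: the policy is chosen from
// the payload's label, not from the event the handler subscribed to.
function guardAllows(payload: Payload): boolean {
  if (payload.eventName === 'kick_user') {
    return payload.isHost
  }
  if (payload.eventName === 'chat') {
    return payload.isMember
  }
  return true // 'join_room': any user can join any room
}

// Possible hardening (an assumption, not the article's code): also require
// the payload's label to match the event being handled.
function guardAllowsFor(subscribedEvent: EventName, payload: Payload): boolean {
  return payload.eventName === subscribedEvent && guardAllows(payload)
}
```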

src/shared/interfaces/chat.interface.ts

import { z } from 'zod'
import {
  ChatMessageSchema,
  JoinRoomSchema,
  KickUserSchema,
  RoomNameSchema,
  RoomSchema,
  SocketIdSchema,
  UserIdSchema,
  UserNameSchema,
  UserSchema,
  ServerToClientEventsSchema,
  ClientToServerEventsSchema,
} from '../schemas/chat.schema'

export type UserId = z.infer<typeof UserIdSchema>
export type UserName = z.infer<typeof UserNameSchema>
export type SocketId = z.infer<typeof SocketIdSchema>
export type User = z.infer<typeof UserSchema>

export type RoomName = z.infer<typeof RoomNameSchema>
export type Room = z.infer<typeof RoomSchema>
export type Message = z.infer<typeof ChatMessageSchema>

export type JoinRoom = z.infer<typeof JoinRoomSchema>
export type KickUser = z.infer<typeof KickUserSchema>

export type ServerToClientEvents = z.infer<typeof ServerToClientEventsSchema>
export type ClientToServerEvents = z.infer<typeof ClientToServerEventsSchema>

Frontend Client Modifications

The client now also needs to include user, roomName, and eventName when emitting events.

We’ll just show snippets of the frontend for brevity.

src/client/pages/chat.tsx

// 'join_room' event on socket connection
socket.on('connect', () => {
  const joinRoom: JoinRoom = {
    roomName,
    user: { ...user, socketId: socket.id }, // spread first so the current socket.id wins
    eventName: 'join_room',
  }
  JoinRoomSchema.parse(joinRoom)
  socket.emit('join_room', joinRoom)
  setIsConnected(true)
})
// .....
// 'chat' event when client sends a chat message
const sendMessage = (message: string) => {
  if (user && socket && roomName) {
    const chatMessage: Message = {
      user: {
        userId: user.userId,
        userName: user.userName,
        socketId: socket.id,
      },
      timeSent: new Date(Date.now()).toLocaleString('en-US'),
      message,
      roomName: roomName,
      eventName: 'chat',
    }
    ChatMessageSchema.parse(chatMessage)
    socket.emit('chat', chatMessage)
  }
}
//.....
// 'kick_user' event when host kicks a user
const kickUser = (userToKick: User) => {
  if (!room) {
    throw 'No room'
  }
  if (!user) {
    throw 'No current user'
  }
  const kickUserData: KickUser = {
    user: { ...user, socketId: socket.id },
    userToKick: userToKick,
    roomName: room.name,
    eventName: 'kick_user',
  }
  KickUserSchema.parse(kickUserData)
  socket.emit('kick_user', kickUserData, (complete) => {
    if (complete) {
      roomRefetch()
    }
  })
}
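The third argument to socket.emit is a socket.io acknowledgement callback. With Nest's socket.io adapter, a non-null value returned from a @SubscribeMessage handler is passed to that callback, which is why handleKickUserEvent returns true. As far as we can tell from the adapter, if the guard rejects the message the handler never runs and the callback is never called, so the client should not wait on it to learn about a rejection. The sketch below is a hypothetical in-memory model of that round trip; FakeChannel is made up and no socket.io is involved.

```typescript
type Ack = (response: unknown) => void
type MessageHandler = (payload: unknown) => unknown

// Hypothetical in-memory stand-in for a socket connection, modelling how a
// handler's return value reaches the client's acknowledgement callback.
class FakeChannel {
  private handlers = new Map<string, MessageHandler>()

  subscribe(event: string, handler: MessageHandler): void {
    this.handlers.set(event, handler)
  }

  emit(event: string, payload: unknown, ack?: Ack): void {
    const handler = this.handlers.get(event)
    if (!handler) {
      return
    }
    try {
      const response = handler(payload)
      // Nothing is acknowledged when the handler returns null/undefined.
      if (response !== undefined && response !== null && ack) {
        ack(response)
      }
    } catch {
      // A rejected request (e.g. a guard throwing) never reaches the ack;
      // the error is reported to the client separately.
    }
  }
}
```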

Also, in the client's <UserList/> component, where we display the current users in the chat room, we've added a “Kick” button that allows room hosts to kick users out of the room.

src/client/components/list.tsx

import React from 'react'
import { Room, User } from '../../shared/interfaces/chat.interface'

export const UserList = ({
  room,
  currentUser,
  kickHandler,
}: {
  room: Room
  currentUser: User
  kickHandler: (user: User) => void
}) => {
  return (
    <div className="flex h-4/6 w-full flex-col-reverse overflow-y-scroll">
      {room.users.map((user) => {
        return (
          <div key={user.userId} className="mb-4 flex justify-between rounded px-4 py-2">
            <div className="flex items-center">
              <p className="text-white">{user.userName}</p>
              {room.host.userId === user.userId && <span className="ml-2">{'👑'}</span>}
            </div>
            {room.host.userId === currentUser.userId && user.userId !== currentUser.userId && (
              <button
                className="flex h-8 items-center justify-self-end rounded-xl bg-gray-800 px-4"
                onClick={() => kickHandler(user)}
              >
                <span className="mr-1 text-white">{'Kick'}</span>
              </button>
            )}
          </div>
        )
      })}
    </div>
  )
}
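The button's visibility rule is easy to pull out of the JSX into a plain function, which makes it simple to unit test. This is a hypothetical refactor (canSeeKickButton is not in the repository) with trimmed-down Room/User shapes. Keep in mind that it only decides what the UI shows; as the rest of this part shows, the server's guard is what actually enforces the rule.

```typescript
interface User {
  userId: string
  userName: string
}

interface Room {
  name: string
  host: User
  users: User[]
}

// Mirrors the condition in <UserList/>: only the host sees a Kick button,
// and never next to their own name.
function canSeeKickButton(room: Room, currentUser: User, listedUser: User): boolean {
  return room.host.userId === currentUser.userId && listedUser.userId !== currentUser.userId
}

const austin: User = { userId: 'u1', userName: 'Austin' }
const meeko: User = { userId: 'u2', userName: 'Meeko' }
const room: Room = { name: 'lobby', host: austin, users: [austin, meeko] }
```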

Here’s what that looks like now!

Host client connected user list screenshot

In the screenshot above, we're looking at the host's (Austin's) client. Meeko's client won't have a kick button. But what if this application grew, we hired a bunch of frontend developers, and they accidentally introduced a bug that gave Meeko a kick button?

Let’s simulate that and see what happens.

Non host client connected users list screenshot

Oops! Meeko can now kick the host. That’s bad. But if he tries to…

[1] [Nest] 87703  - 12/21/2022, 11:28:09 AM   ERROR [WsExceptionsHandler] Forbidden
[1] ForbiddenException: Forbidden
[1]     at /Users/austinhoward/code/nest-realtime/nest-react-websockets/src/server/chat/guards/chat.guard.ts:64:15
[1]     at Array.every (<anonymous>)
[1]     at ChatPoliciesGuard.<anonymous> (/Users/austinhoward/code/nest-realtime/nest-react-websockets/src/server/chat/guards/chat.guard.ts:61:20)
[1]     at Generator.next (<anonymous>)
[1]     at fulfilled (/Users/austinhoward/code/nest-realtime/nest-react-websockets/dist/server/server/chat/guards/chat.guard.js:14:58)
[1]     at processTicksAndRejections (node:internal/process/task_queues:96:5)

Perfect! Our server is protected even if there is a mistake on the client.

That's all for authorization. CASL gives us powerful tools to compose situational authorization rules in our Nest application, and guards let us pull those authorization checks out of our gateway in a clean and concise way.