Having worked on a large-scale product, I frequently performed load testing before deploying my features to production. Most of the time, I used Locust and occasionally tried Gatling. I always wondered how these tools worked internally, so I decided — why not build a small load testing library for fun?
This blog is a step-by-step guide to the process I followed in creating a minimal load testing library. It's a basic local implementation and does not cover distributed testing. If you are interested in building one or want to understand how Locust works, follow along!
Features
I am writing this with reference to Locust and will cover the basic features it offers.
- Users can provide details of the APIs to be tested, including URL, payload, request type, headers, etc.
- Users can specify the following load testing parameters:
i. Number of virtual users — The number of concurrent simulated users.
ii. Test duration — The duration for which the load test will run.
iii. Spawn rate — The rate (per second) at which virtual users are introduced.
- Live monitoring of the load test with the following metrics:
i. Throughput — Requests per second (RPS).
ii. Response times — Measured in percentiles.
iii. Request statistics — Total number of requests along with failed requests.
Design
To understand how load testing tools work, I explored Locust's code base. Locust uses greenlets (lightweight coroutines provided by the gevent library) to simulate multiple concurrent users. Each greenlet represents a virtual user that establishes a socket connection and continuously sends requests to the target API.
For our library, we will take a similar approach, but instead of using gevent, we will leverage Python's built-in asyncio for asynchronous execution.
- Each virtual user is represented as an async task that runs in an event loop.
- A single HTTP session (using aiohttp) is maintained per user to avoid unnecessary connection overhead.
- Every user randomly selects a task (API request) from the list of tasks provided in the test configuration.
- The request is sent asynchronously, and response times are recorded (a rough sketch of this model follows below).
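As a rough sketch of this model (not the final Stormlight code, just the shape of it, with a placeholder host and path), each virtual user is a coroutine that owns one session and keeps sending requests until the deadline:
import asyncio
import random
import time
import aiohttp
# Rough sketch only: each virtual user is a coroutine that owns one aiohttp
# session and keeps firing requests until the deadline is reached.
async def virtual_user(host: str, paths: list[str], end_time: float):
    async with aiohttp.ClientSession() as session:
        while time.time() < end_time:
            path = random.choice(paths)  # pick a random "task"
            async with session.get(host + path) as resp:
                await resp.read()  # wait for the full response
async def run_test(host: str, paths: list[str], users: int, duration: float):
    end_time = time.time() + duration
    # every user runs concurrently in the same event loop
    await asyncio.gather(*(virtual_user(host, paths, end_time) for _ in range(users)))
# asyncio.run(run_test("http://127.0.0.1:8000", ["/hello"], users=5, duration=10))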
We will call our library Stormlight.
Let's start coding…
First, create a main project folder called stormlight. Inside it, create another directory named stormlight, where we will store our source code.
Handling Configuration and Command-Line Arguments
Before diving into the implementation, let's define how the user will provide API details for testing. Typically, in load testing tools like Locust, the user writes a script to define the test configuration. Similarly, in Stormlight, users will write a configuration file called stormlight_file.py.
In this file, users will specify API details such as method, endpoint, payload, and headers using a class called Task. To define this structure, let's create a new file named data_classes.py and write our Task data structure.
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
@dataclass
class Task:
method: str
path: str
data: Optional[Any] = None
headers: Dict[str, str] = field(default_factory=dict)
def __post_init__(self):
self.method = self.method.upper() # Ensure method is always uppercase
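A quick check of the __post_init__ normalization and the headers default (assuming the Task class above is in scope):
t = Task("get", "/hello")
print(t.method)   # "GET": __post_init__ upper-cases the method
print(t.headers)  # {}: default_factory gives each Task its own dict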
Create stormlight_file.py and write an example user configuration.
from stormlight import Task
# Define the API endpoints and requests to test
endpoints = [
Task("GET", "/hello", headers={"Content-Type": "application/json"}),
Task("POST", "/api/upload",data = {"name": "test", "price": 10}, headers={"Content-Type": "application/json"})
]
Since the test environment can vary, we should allow users to specify the domain (host) when starting the load test, rather than hardcoding it inside the script. This makes the tool more flexible and reusable across different environments.
To achieve this, we need to parse both the configuration file stormlight_file.py and command-line arguments. Let's write the code for this in parser.py.
import os
import argparse
def parse_script(script_path):
if not os.path.exists(script_path):
raise FileNotFoundError(f"Script file not found: {script_path}")
global_vars = {}
with open(script_path, "r") as file:
exec(file.read(), global_vars)
if "endpoints" not in global_vars:
raise ValueError("The script must define an 'endpoints' variable.")
return global_vars["endpoints"]
def parse_args():
parser = argparse.ArgumentParser(description="Configure the load test parameters.")
parser.add_argument("--users", type=int, required=True, help="Number of users to simulate.")
parser.add_argument("--spawn-rate", type=float, required=True, help="Users spawned per second.")
parser.add_argument("--host", type=str, required=True, help="Host/IP address.")
parser.add_argument("--duration", type=int, required=True, help="Duration in seconds for which the test will run.")
args = parser.parse_args()
return args
The parser will extract the endpoints variable from the user-defined configuration file stormlight_file.py, initialize the task objects, and return them.
For load test parameters like number of users, host, spawn rate, and test duration, we will use command-line arguments instead of hardcoding them. The parse_args function will handle parsing these arguments.
In our main.py, call parse_script and parse_args:
from .parser import parse_script, parse_args
def main():
test_paths = parse_script("stormlight_file.py")
config = parse_args()
if __name__ == '__main__':
main()
Let's create a data structure called Environment for the test parameters received through command-line arguments. In data_classes.py:
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
@dataclass
class Task:
method: str
path: str
data: Optional[Any] = None
headers: Dict[str, str] = field(default_factory=dict)
def __post_init__(self):
self.method = self.method.upper() # Ensure method is always uppercase
@dataclass
class Environment:
host: str
tasks: list[Task]
user_count: int
spawn_rate: float
duration: float
In main.py, add code to create the environment:
import asyncio
from .parser import parse_script, parse_args
from .core import Runner
from .data_classes import Environment
def create_environment(config, tasks):
    env = Environment(
        host=config.host,
        tasks=tasks,
        user_count=config.users,
        spawn_rate=config.spawn_rate,
        duration=config.duration,
    )
    return env
def main():
tasks = parse_script("stormlight_file.py")
config = parse_args()
environment = create_environment(config, tasks)
if __name__ == '__main__':
main()
Handling the load test
Let's get to the main logic. We will create classes for the different functionalities:
- Runner class: this class is responsible for orchestrating the load test execution. It manages the life cycle of user instances and coordinates the start/stop of the test, respecting the spawn rate and duration provided.
- User class: this class represents a virtual user in a load test. It defines the behaviour of the user and is responsible for performing the tasks configured in the stormlight_file.py file.
- Metrics: this class is responsible for collecting and aggregating statistics during a load test. It tracks metrics such as request counts, response times, failure rates, and more.
User and Runner
Let's create core.py and add:
import time
import random
import asyncio
import aiohttp
from .data_classes import Task, Environment
class User:
host: str
tasks: list[Task] = []
async def send_request(self, session, task):
url = f"{self.host}{task.path}"
await session.request(task.method, url, json=task.data)
async def run(self, end_time):
print('user_running...')
async with aiohttp.ClientSession() as session:
while time.time() < end_time:
# randomly select one task at a time from tasks
task = random.choice(self.tasks)
await self.send_request(session, task)
class Runner:
def __init__(self, environment: Environment):
self.environment = environment
User.tasks = self.environment.tasks
User.host = self.environment.host
async def spawn_user(self, end_time):
user = User()
await user.run(end_time)
async def spawn_users(self, end_time):
async with asyncio.TaskGroup() as tg:
for _ in range(self.environment.user_count):
tg.create_task(self.spawn_user(end_time))
await asyncio.sleep(1 / self.environment.spawn_rate)
async def start(self):
"""Starts the load testing process."""
end_time = time.time() + self.environment.duration
await self.spawn_users(end_time)
When the start method of the Runner class is called, it triggers the spawn_users method. This method is responsible for gradually spawning users while maintaining the specified spawn rate.
Each spawned user instance calls the run method of the User class, which keeps executing tasks until the test duration is reached.
Within the User class, the send_request method is responsible for making actual API requests. It sends requests asynchronously.
Let's first test whether this works.
We need the two APIs that we are calling in stormlight_file.py above. I am using FastAPI. Create a file called fast_api.py and add the following code:
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
class Item(BaseModel):
name: str
price: float
@app.get("/hello")
async def root():
return {"message": "Hello World"}
@app.post("/api/upload")
async def root(data: Item):
return data
Start the FastAPI server (make sure FastAPI is installed first) with:
fastapi dev fast_api.py
Run core.py and ensure that requests are hitting the FastAPI server by checking the FastAPI terminal. Also, verify in the load test terminal that the test runs for the specified duration.
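One quick way to do that is to temporarily append a small entry point to core.py (a throwaway sketch; it assumes the FastAPI server is on its default http://127.0.0.1:8000) and run it as a module from the outer project folder with python -m stormlight.core:
# Temporary entry point for a quick manual check; remove it once main.py is wired up.
# Assumes the FastAPI server from fast_api.py is running on http://127.0.0.1:8000.
if __name__ == "__main__":
    env = Environment(
        host="http://127.0.0.1:8000",
        tasks=[
            Task("GET", "/hello"),
            Task("POST", "/api/upload", data={"name": "test", "price": 10}),
        ],
        user_count=5,
        spawn_rate=2,
        duration=10,
    )
    asyncio.run(Runner(env).start())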
Metrics
Let's create metrics.py and add:
import asyncio
import time
from collections import defaultdict
from tabulate import tabulate
headers = [ "Method", "Endpoint", "RPS", "Median (ms)", "Average (ms)", "Min (ms)", "Max (ms)",
"Failed Requests", "Total Requests"]
class EntriesDict(dict):
def __init__(self, request_metrics):
self.request_metrics = request_metrics
def __missing__(self, key):
self[key] = MetricsEntry(self.request_metrics, key[0], key[1])
return self[key]
class Metrics:
def __init__(self):
# stores request details of individual tasks in a dictionary
self.entries: dict[tuple[str, str], MetricsEntry] = EntriesDict(self)
# stores aggregated request details of all the requests
self.total = MetricsEntry(self, None, None)
@property
def start_time(self):
return self.total.start_time
@property
def last_request_timestamp(self):
return self.total.last_request_timestamp
def log_request(self, method: str, endpoint: str, response_time: int):
self.total.log(response_time)
self.entries[(endpoint, method)].log(response_time)
def log_error(self, method, endpoint):
self.total.log_error()
self.entries[(endpoint, method)].log_error()
class MetricsEntry:
def __init__(self, metrics: Metrics, endpoint, method):
self.metrics = metrics
self.endpoint = endpoint
self.method = method
self.num_requests: int = 0
self.num_failures: int = 0
self.total_response_time = 0
self.max_response_time: int = 0
self.min_response_time: int | None = None
self.response_times: dict[int, int] = defaultdict(int)
self.start_time = time.time()
self.last_request_timestamp: float | None = None
def log(self, response_time):
current_time = time.time()
self.num_requests += 1
self._log_response_time(response_time)
self._log_request_time(current_time)
def log_error(self) -> None:
self.num_failures += 1
@property
def avg_response_time(self) -> float:
try:
return round(float(self.total_response_time) / self.num_requests, 2)
except ZeroDivisionError:
return 0.0
@property
def rps(self):
if not self.metrics.last_request_timestamp or not self.metrics.start_time:
return 0.0
try:
return round(self.num_requests / (self.metrics.last_request_timestamp - self.metrics.start_time), 2)
except ZeroDivisionError:
return 0.0
def get_percentile(self, percentile):
sorted_times = sorted(self.response_times.keys())
threshold = self.num_requests * (percentile / 100)
cumulative_count = 0
for response_time in sorted_times:
cumulative_count += self.response_times[response_time]
if cumulative_count >= threshold:
return response_time # Return response time at percentile
def _log_request_time(self, current_time: float) -> None:
self.last_request_timestamp = current_time
def _log_response_time(self, response_time):
self.total_response_time += response_time
if self.min_response_time is None:
self.min_response_time = response_time
else:
self.min_response_time = round(min(self.min_response_time, response_time), 2)
self.max_response_time = round(max(self.max_response_time, response_time), 2)
self.response_times[round(response_time)] += 1
def get_metrics_summary(metrics: Metrics):
"""
Get metrics data stored in Metrics object and arrange it in table format
:param metrics:
:return: metrics table
"""
table_data = []
for endpoint in sorted(metrics.entries.keys()):
method = endpoint[1]
path = endpoint[0]
rps = metrics.entries[endpoint].rps
median = metrics.entries[endpoint].get_percentile(50)
average = metrics.entries[endpoint].avg_response_time
max_response_time = metrics.entries[endpoint].max_response_time
min_response_time = metrics.entries[endpoint].min_response_time
failed_requests = metrics.entries[endpoint].num_failures
total_requests = metrics.entries[endpoint].num_requests
table_data.append([method, path, rps, median, average, min_response_time,
max_response_time, failed_requests, total_requests])
table_data.append(['-'*15 for i in range(9)])
total_rps = metrics.total.rps
total_median = metrics.total.get_percentile(50)
total_avg = metrics.total.avg_response_time
total_min = metrics.total.min_response_time
total_max = metrics.total.max_response_time
total_failed = metrics.total.num_failures
total_requests = metrics.total.num_requests
table_data.append(['total', '', total_rps, total_median, total_avg, total_min, total_max,
total_failed, total_requests])
return tabulate(table_data, headers=headers)
async def display_metrics(metrics, end_time):
"""
Display metrics in real time on terminal. This will run for the whole duration of a load test.
Metrics will be displayed every two seconds on terminal
:param metrics: Metrics object
:param end_time: end time of load test
:return:
"""
while time.time() < end_time:
metrics_table = get_metrics_summary(metrics)
print(metrics_table, '\n\n')
await asyncio.sleep(2)
The metrics code above is taken from the Locust code base itself (I liked the code style). I have tweaked it here and there to fit our minimal use case.
Overview
The metrics system consists of two key classes:
- Metrics — The main class responsible for collecting and aggregating API request statistics during the load test.
- MetricsEntry — Handles individual request statistics, including calculations like average response time, requests per second, and percentiles.
Additionally, we use a custom dictionary EntriesDict to store task-wise metrics efficiently. This ensures that missing keys are automatically created when an unknown (endpoint, method) pair is accessed.
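Here is a tiny standalone illustration of the __missing__ hook (a generic toy example, not Stormlight code):
class AutoDict(dict):
    """Same idea as EntriesDict: create and store the value on first access."""
    def __missing__(self, key):
        self[key] = f"entry for {key}"
        return self[key]
d = AutoDict()
print(d[("/hello", "GET")])  # "entry for ('/hello', 'GET')", created on the fly
print(len(d))                # 1, the new entry was stored in the dict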
Metrics Class
The Metrics class maintains:
- self.entries: A dictionary that stores request metrics categorized by (endpoint, method).
- self.total: A special MetricsEntry instance that tracks aggregated statistics for all requests.
When a request is logged via log_request, it updates both:
- The task-specific entry (self.entries[(endpoint, method)])
- The total aggregated entry (self.total)
Since both entries are instances of MetricsEntry, they follow the same logic for calculating statistics.
MetricsEntry Class
Each MetricsEntry instance calculates request statistics like average response time, requests per second, percentiles etc.
Efficient Response Time Storage
Storing response times in a list would be inefficient for high-load tests, consuming too much memory. Instead, we use a dictionary:
self.response_times: dict[int, int] = defaultdict(int)
In this dictionary, the keys represent the response times (in milliseconds) rounded off, and the values indicate how many requests had that specific response time.
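A tiny standalone illustration of the bucketing idea:
from collections import defaultdict
# key = response time in ms (rounded), value = how many requests landed in that bucket
response_times: dict[int, int] = defaultdict(int)
for latency_ms in [12.3, 12.4, 12.6, 98.7]:
    response_times[round(latency_ms)] += 1
print(dict(response_times))  # {12: 2, 13: 1, 99: 1}
Memory now grows with the number of distinct rounded response times, not with the total number of requests.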
When a user performs a request in the send_request method of the User class, log_request on the Metrics object will be called. log_request then logs the request for both the task-wise entry (endpoint, method) and the total aggregated entry. Both entries are objects of the same MetricsEntry class, so they share the same code.
Example of Request Logging
Step 1: First Request ('/hello', 'GET')
- The request is logged in self.entries under ('/hello', 'GET').
- If the key doesn't exist, it is added automatically.
- num_requests for ('/hello', 'GET') is incremented to 1.
- The self.total entry is also updated, incrementing num_requests to 1.
Step 2: Second Request ('/api/upload', 'POST')
- A new entry is created for ('/api/upload', 'POST'), with num_requests = 1.
- The self.total entry's num_requests is incremented to 2 (since it tracks all requests).
Other metrics (response time, failures, etc.) are updated similarly for both individual and total entries.
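The same walkthrough in code (a small sketch; it assumes metrics.py is importable, e.g. as stormlight.metrics when run from the outer project folder):
from stormlight.metrics import Metrics
metrics = Metrics()
metrics.log_request("GET", "/hello", 12.5)        # step 1
metrics.log_request("POST", "/api/upload", 48.0)  # step 2
metrics.log_error("POST", "/api/upload")          # one failed request, for illustration
print(metrics.entries[("/hello", "GET")].num_requests)        # 1
print(metrics.entries[("/api/upload", "POST")].num_requests)  # 1
print(metrics.entries[("/api/upload", "POST")].num_failures)  # 1
print(metrics.total.num_requests)                             # 2
print(metrics.total.num_failures)                             # 1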
Displaying Real-Time Metrics
As the load test runs, request details are stored in the Metrics instance.
To display the metrics data in real time as the test is running, a coroutine named display_metrics is run for the duration of the load test and will display metrics data every 2 seconds.
Percentile logic
The get_percentile method takes a percentile value as an argument and returns the response time at that percentile. Let's understand the logic with an example.
Steps for Calculating Percentile
1. Sort the response times
sorted_times = sorted(self.response_times.keys())
self.response_times is a dictionary where:
- Keys are response times (in milliseconds).
- Values represent the number of requests that had that response time.
2. Determine the threshold index
threshold = self.num_requests * (percentile / 100)
- self.num_requests is the total number of requests recorded.
- percentile / 100 converts the percentile into a fraction (e.g., 90 / 100 = 0.9 for the 90th percentile).
- Multiplying this fraction by self.num_requests gives the threshold index in the sorted list.
3. Find the response time at the threshold
cumulative_count = 0
for response_time in sorted_times:
cumulative_count += self.response_times[response_time]
if cumulative_count >= threshold:
return response_time
- Iterate through sorted_times, accumulating request counts.
- When the cumulative count reaches or exceeds the threshold, return the corresponding response time.
Example: let's find the 90th percentile for the following request details.
self.response_times = {
50: 3, # 3 requests took 50ms
75: 10, # 10 requests took 75ms
100: 20, # 20 requests took 100ms
150: 30, # 30 requests took 150ms
200: 37 # 37 requests took 200ms
}
self.num_requests = 100 # Total requests
According to the above logic:
sorted_times = [50, 75, 100, 150, 200]
Cumulative counts:
i. 50 ms → 3
ii. 75 ms → 3 + 10 = 13
iii. 100 ms → 13 + 20 = 33
iv. 150 ms → 33 + 30 = 63
v. 200 ms → 63 + 37 = 100 (crosses the threshold of 90)
Result: The 90th percentile response time is 200 ms, meaning 90% of requests completed within 200 ms.
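To sanity-check the walkthrough, here is the same logic as a small standalone snippet (same numbers as above):
from collections import defaultdict
response_times = defaultdict(int, {50: 3, 75: 10, 100: 20, 150: 30, 200: 37})
num_requests = 100
def percentile(p):
    threshold = num_requests * (p / 100)
    cumulative = 0
    for rt in sorted(response_times):
        cumulative += response_times[rt]
        if cumulative >= threshold:
            return rt
print(percentile(90))  # 200
print(percentile(50))  # 150 (the cumulative count reaches 63 at 150 ms, which is >= 50)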
Let's update our core.py:
import time
import random
import asyncio
import aiohttp
from .data_classes import Task, Environment
from .metrics import Metrics, display_metrics
class User:
host: str
tasks: list[Task] = []
async def send_request(self, session, task):
url = f"{self.host}{task.path}"
start_time = time.time()
async with session.request(task.method, url, json=task.data) as response:
latency = (time.time() - start_time) * 1000 # convert to ms
return {
"method": task.method,
"endpoint": task.path,
"response_time": latency
}
async def run(self, end_time, metrics):
async with aiohttp.ClientSession() as session:
while time.time() < end_time:
task = random.choice(self.tasks)
try:
result = await self.send_request(session, task)
metrics.log_request(**result)
except Exception:
metrics.log_error(task.method, task.path)
class Runner:
def __init__(self, environment: Environment):
self.metrics = Metrics()
self.environment = environment
User.tasks = self.environment.tasks
User.host = self.environment.host
async def spawn_user(self, end_time):
user = User()
await user.run(end_time, self.metrics)
async def spawn_users(self, end_time):
async with asyncio.TaskGroup() as tg:
for _ in range(self.environment.user_count):
tg.create_task(self.spawn_user(end_time))
await asyncio.sleep(1 / self.environment.spawn_rate)
async def start(self):
end_time = time.time() + self.environment.duration
async with asyncio.TaskGroup() as tg:
tg.create_task(self.spawn_users(end_time))
tg.create_task(display_metrics(self.metrics, end_time))
We have called log_request in the run method of the User class. For every successful request, log_request is invoked. Also, notice the log_error call in the exception block. If a request fails for any reason, it is caught there and logged as a failed request. (We are handling exceptions generically for simplicity, for now.)
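With the Runner and Metrics in place, main.py can build the Runner from the environment and drive the whole test with asyncio.run. A minimal sketch of that wiring, reusing the parsing code from earlier:
import asyncio
from .parser import parse_script, parse_args
from .core import Runner
from .data_classes import Environment
def create_environment(config, tasks):
    return Environment(
        host=config.host,
        tasks=tasks,
        user_count=config.users,
        spawn_rate=config.spawn_rate,
        duration=config.duration,
    )
def main():
    tasks = parse_script("stormlight_file.py")
    config = parse_args()
    environment = create_environment(config, tasks)
    runner = Runner(environment)
    # drive the whole async load test from the synchronous entry point
    asyncio.run(runner.start())
if __name__ == "__main__":
    main()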
Now let's test if it works.
Run this in Stormlight's terminal (from the outer project folder):
python -m stormlight.main --users 10 --spawn-rate 2 --host http://127.0.0.1:8000 --duration 10
We can see the metrics table printed to the terminal every 2 seconds, like this:
Method Endpoint RPS Median (ms) Average (ms) Min (ms) Max (ms) Failed Requests Total Requests
--------------- --------------- --------------- --------------- --------------- --------------- --------------- ----------------- ----------------
POST /api/upload 336.65 14 13.34 1.64 31.71 0 2699
GET /hello 329.54 14 13.15 1.41 31.82 0 2642
--------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- ---------------
total 666.19 14 13.25 1.41 31.82 0 5341
If you want to display more percentiles, you can modify get_metrics_summary. Just like we calculated the median, we can calculate other percentiles by simply passing the desired percentile value to get_percentile.
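For example, a rough sketch of a per-endpoint row builder with extra percentile columns (the headers list, the separator row's range(9), and the total row would all need matching entries):
def summary_row(entry, method: str, path: str) -> list:
    """Build one table row, adding 95th and 99th percentile columns (sketch)."""
    return [
        method, path, entry.rps,
        entry.get_percentile(50),
        entry.get_percentile(95),  # new column
        entry.get_percentile(99),  # new column
        entry.avg_response_time,
        entry.min_response_time, entry.max_response_time,
        entry.num_failures, entry.num_requests,
    ]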
GitHub repo: https://github.com/rahulsalgare/stormlight
That was fun! 🎉
Next, let's try adding a Web UI to visualise our load test results. Stay tuned! 🚀