"""Module for visualizing simulation results."""
from pymicrostructure.markets.base import Market
from pymicrostructure.traders.base import Trader
from pymicrostructure.metrics.trader import (
position_history,
profit_history,
)
import matplotlib.pyplot as plt
from typing import List
[docs]
def _plot_history(ax, timestamps, values, title: str) -> None:
    """Plot one history series on *ax* and fit the axis limits to the data."""
    ax.plot(timestamps, values)
    ax.set_title(title)
    ax.set_xlabel("Trade Number")

    # plot() requires both series to have the same length, so checking one
    # is enough. With no data there is nothing to fit the limits to.
    if len(timestamps) == 0:
        return

    lowest, highest = min(values), max(values)
    # Identical limits would make the axis transform singular; keep
    # matplotlib's autoscaling for a flat series instead.
    if lowest != highest:
        ax.set_ylim(lowest, highest)
    ax.set_xlim(0, len(timestamps))


def participant_comparison(participants: List[Trader]) -> None:
    """
    Compare the position and profit history of a list of participants.

    Draws a two-row grid of subplots with one column per included
    participant: the top row shows the position history and the bottom row
    the profit history. The figure is displayed with ``plt.show()``.

    Parameters
    ----------
    participants : List[Trader]
        A list of participants to compare. Only participants whose
        ``include_in_results`` attribute is truthy are plotted.

    Returns
    -------
    None
        Nothing is drawn when no participant is included.
    """
    included_participants = [p for p in participants if p.include_in_results]
    if not included_participants:
        # plt.subplots cannot build a grid with zero columns.
        return

    # squeeze=False keeps ``axs`` two-dimensional even for a single column,
    # so ``axs[row, col]`` indexing works for any number of participants.
    fig, axs = plt.subplots(2, len(included_participants), squeeze=False)
    fig.set_size_inches(15, 10)

    for col, participant in enumerate(included_participants):
        label = f"{type(participant).__name__} {participant.trader_id}"

        pos_ts, pos_hist = position_history(participant)
        _plot_history(axs[0, col], pos_ts, pos_hist, f"{label} Position")

        pnl_ts, pnl_hist = profit_history(participant)
        _plot_history(axs[1, col], pnl_ts, pnl_hist, f"{label} Profit")

    plt.tight_layout()
    plt.show()
[docs]
def price_path(market: Market) -> None:
    """
    Visualize the price path of a market.

    Plots the best bid and best ask recorded in the market's order book
    snapshots against the snapshot time and displays the figure with
    ``plt.show()``. A snapshot whose bid (or ask) side is empty contributes
    ``None`` for that side.

    Parameters
    ----------
    market : Market
        The market to visualize. Its ``ob_snapshots`` entries are read via
        the ``"time"``, ``"bid"`` and ``"ask"`` keys; the first entry of each
        side is taken as the best level.

    Returns
    -------
    None
    """
    snapshots = market.ob_snapshots

    ob_time = [snapshot["time"] for snapshot in snapshots]
    # An empty side of the book has no best price; record None for it.
    best_bid = [
        snapshot["bid"][0]["price"] if snapshot["bid"] else None
        for snapshot in snapshots
    ]
    best_ask = [
        snapshot["ask"][0]["price"] if snapshot["ask"] else None
        for snapshot in snapshots
    ]

    plt.figure(figsize=(15, 5))
    plt.plot(ob_time, best_bid, label="Best Bid", color="green")
    plt.plot(ob_time, best_ask, label="Best Ask", color="red")
    plt.xlabel("Time")
    plt.ylabel("Price")
    plt.title("Price Path")
    plt.legend()
    plt.show()