Source code for swarmsort.drawing_utils

"""
Visualization utilities for SwarmSort tracking results.

This module provides tools for drawing tracking results, creating visualizations,
and generating video output from tracking sequences.
"""
import numpy as np
from typing import List, Optional, Tuple, Dict, Any
import colorsys
from dataclasses import dataclass

try:
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    from matplotlib.animation import FuncAnimation

    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False

try:
    import cv2

    OPENCV_AVAILABLE = True
except ImportError:
    OPENCV_AVAILABLE = False

from .data_classes import Detection, TrackedObject


[docs] @dataclass class VisualizationConfig: """Configuration for visualization settings.""" frame_width: int = 800 frame_height: int = 600 background_color: Tuple[int, int, int] = (0, 0, 0) # Black detection_color: Tuple[int, int, int] = (128, 128, 128) # Gray track_thickness: int = 2 detection_thickness: int = 1 font_scale: float = 0.5 font_thickness: int = 1 show_ids: bool = True show_confidences: bool = True show_velocities: bool = False show_trails: bool = True trail_length: int = 20 bbox_alpha: float = 0.3
[docs] class ColorManager: """Manages consistent colors for track IDs.""" def __init__(self, saturation: float = 0.8, value: float = 1.0): self.saturation = saturation self.value = value self.color_cache: Dict[int, Tuple[int, int, int]] = {}
[docs] def get_color(self, track_id: int) -> Tuple[int, int, int]: """Get a consistent color for a track ID.""" if track_id not in self.color_cache: # Generate color based on track ID hue = (track_id * 137.5) % 360 # Golden angle approximation rgb = colorsys.hsv_to_rgb(hue / 360, self.saturation, self.value) # Convert to 0-255 range self.color_cache[track_id] = tuple(int(c * 255) for c in rgb) return self.color_cache[track_id]
[docs] def reset(self): """Clear color cache.""" self.color_cache.clear()
[docs] class TrackingVisualizer: """Main visualization class for tracking results.""" def __init__(self, config: Optional[VisualizationConfig] = None): self.config = config or VisualizationConfig() self.color_manager = ColorManager() self.trail_history: Dict[int, List[Tuple[float, float]]] = {} if not MATPLOTLIB_AVAILABLE and not OPENCV_AVAILABLE: raise ImportError( "Either matplotlib or opencv-python must be installed for visualization" )
[docs] def draw_frame_matplotlib( self, detections: List[Detection], tracks: List[TrackedObject], frame_num: int, ax=None ) -> None: """Draw a single frame using matplotlib.""" if not MATPLOTLIB_AVAILABLE: raise ImportError("matplotlib is required for this function") if ax is None: fig, ax = plt.subplots(figsize=(10, 8)) ax.clear() ax.set_xlim(0, self.config.frame_width) ax.set_ylim(0, self.config.frame_height) ax.set_aspect("equal") ax.invert_yaxis() # Match image coordinates ax.set_title(f"Frame {frame_num} - {len(tracks)} tracks, {len(detections)} detections") # Draw detections for det in detections: x, y = det.position circle = plt.Circle((x, y), 2, color="gray", alpha=0.5) ax.add_patch(circle) if det.bbox is not None and len(det.bbox) >= 4: bbox = patches.Rectangle( (det.bbox[0], det.bbox[1]), det.bbox[2] - det.bbox[0], det.bbox[3] - det.bbox[1], linewidth=1, edgecolor="gray", facecolor="none", alpha=0.3, ) ax.add_patch(bbox) # Draw tracks for track in tracks: color = np.array(self.color_manager.get_color(track.id)) / 255.0 x, y = track.position # Update trail history if track.id not in self.trail_history: self.trail_history[track.id] = [] self.trail_history[track.id].append((x, y)) if len(self.trail_history[track.id]) > self.config.trail_length: self.trail_history[track.id].pop(0) # Draw trail if self.config.show_trails and len(self.trail_history[track.id]) > 1: trail_points = np.array(self.trail_history[track.id]) ax.plot(trail_points[:, 0], trail_points[:, 1], color=color, alpha=0.5, linewidth=1) # Draw current position circle = plt.Circle((x, y), 8, color=color, alpha=0.8) ax.add_patch(circle) # Draw bounding box if track.bbox is not None and len(track.bbox) >= 4: bbox = patches.Rectangle( (track.bbox[0], track.bbox[1]), track.bbox[2] - track.bbox[0], track.bbox[3] - track.bbox[1], linewidth=2, edgecolor=color, facecolor="none", alpha=0.7, ) ax.add_patch(bbox) # Draw ID and info if self.config.show_ids: text = f"ID:{track.id}" if self.config.show_confidences: text += f"\nConf:{track.confidence:.2f}" if self.config.show_velocities: text += f"\nVel:({track.velocity[0]:.1f},{track.velocity[1]:.1f})" ax.text( x + 10, y - 10, text, fontsize=8, color=color, weight="bold", bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7), )
[docs] def draw_frame_opencv( self, detections: List[Detection], tracks: List[TrackedObject], frame_num: int, frame: Optional[np.ndarray] = None, ) -> np.ndarray: """Draw a single frame using OpenCV.""" if not OPENCV_AVAILABLE: raise ImportError("opencv-python is required for this function") if frame is None: frame = np.zeros((self.config.frame_height, self.config.frame_width, 3), dtype=np.uint8) frame[:] = self.config.background_color else: frame = frame.copy() # Draw detections for det in detections: x, y = int(det.position[0]), int(det.position[1]) cv2.circle(frame, (x, y), 5, self.config.detection_color, -1) if det.bbox is not None and len(det.bbox) >= 4: x1, y1, x2, y2 = [int(coord) for coord in det.bbox] cv2.rectangle(frame, (x1, y1), (x2, y2), self.config.detection_color, 1) # Draw tracks for track in tracks: color = self.color_manager.get_color(track.id) x, y = int(track.position[0]), int(track.position[1]) # Update trail history if track.id not in self.trail_history: self.trail_history[track.id] = [] self.trail_history[track.id].append((x, y)) if len(self.trail_history[track.id]) > self.config.trail_length: self.trail_history[track.id].pop(0) # Draw trail if self.config.show_trails and len(self.trail_history[track.id]) > 1: trail_points = np.array(self.trail_history[track.id], dtype=np.int32) for i in range(1, len(trail_points)): alpha = i / len(trail_points) trail_color = tuple(int(c * alpha) for c in color) cv2.line( frame, tuple(trail_points[i - 1]), tuple(trail_points[i]), trail_color, 1 ) # Draw current position cv2.circle(frame, (x, y), 8, color, -1) cv2.circle(frame, (x, y), 8, (255, 255, 255), 1) # Draw bounding box if track.bbox is not None and len(track.bbox) >= 4: x1, y1, x2, y2 = [int(coord) for coord in track.bbox] cv2.rectangle(frame, (x1, y1), (x2, y2), color, self.config.track_thickness) # Draw ID and info if self.config.show_ids: text = f"ID:{track.id}" if self.config.show_confidences: text += f" C:{track.confidence:.2f}" # Text background text_size = cv2.getTextSize( text, cv2.FONT_HERSHEY_SIMPLEX, self.config.font_scale, self.config.font_thickness, )[0] cv2.rectangle( frame, (x + 10, y - 25), (x + 10 + text_size[0], y - 25 + text_size[1] + 5), (0, 0, 0), -1, ) # Text cv2.putText( frame, text, (x + 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, self.config.font_scale, color, self.config.font_thickness, ) # Frame info info_text = f"Frame: {frame_num} | Tracks: {len(tracks)} | Detections: {len(detections)}" cv2.putText(frame, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) return frame
[docs] def create_video_sequence( self, detection_sequence: List[List[Detection]], track_sequence: List[List[TrackedObject]], output_path: str = "tracking_output.mp4", fps: int = 30, ) -> None: """Create a video from detection and track sequences.""" if not OPENCV_AVAILABLE: raise ImportError("opencv-python is required for video creation") # Video writer setup fourcc = cv2.VideoWriter_fourcc(*"mp4v") video_writer = cv2.VideoWriter( output_path, fourcc, fps, (self.config.frame_width, self.config.frame_height) ) try: for frame_num, (detections, tracks) in enumerate( zip(detection_sequence, track_sequence) ): frame = self.draw_frame_opencv(detections, tracks, frame_num) video_writer.write(frame) # Progress indicator if frame_num % 50 == 0: print(f"Processing frame {frame_num}...") print(f"Video saved to {output_path}") finally: video_writer.release()
[docs] def show_sequence_matplotlib( self, detection_sequence: List[List[Detection]], track_sequence: List[List[TrackedObject]], interval: int = 100, save_path: Optional[str] = None, ) -> None: """Show animated sequence using matplotlib.""" if not MATPLOTLIB_AVAILABLE: raise ImportError("matplotlib is required for this function") fig, ax = plt.subplots(figsize=(12, 9)) def animate(frame_num): if frame_num < len(detection_sequence) and frame_num < len(track_sequence): self.draw_frame_matplotlib( detection_sequence[frame_num], track_sequence[frame_num], frame_num, ax ) anim = FuncAnimation( fig, animate, frames=len(detection_sequence), interval=interval, blit=False, repeat=True ) if save_path: anim.save(save_path, writer="pillow", fps=1000 // interval) print(f"Animation saved to {save_path}") else: plt.show() return anim
[docs] def reset(self): """Reset visualization state.""" self.trail_history.clear() self.color_manager.reset()
[docs] def quick_visualize( detections: List[Detection], tracks: List[TrackedObject], frame_num: int = 0, method: str = "matplotlib", ) -> None: """Quick visualization function for debugging.""" visualizer = TrackingVisualizer() if method == "matplotlib" and MATPLOTLIB_AVAILABLE: fig, ax = plt.subplots(figsize=(10, 8)) visualizer.draw_frame_matplotlib(detections, tracks, frame_num, ax) plt.show() elif method == "opencv" and OPENCV_AVAILABLE: frame = visualizer.draw_frame_opencv(detections, tracks, frame_num) cv2.imshow(f"Tracking Frame {frame_num}", frame) cv2.waitKey(0) cv2.destroyAllWindows() else: print("Requested visualization method not available. Install matplotlib or opencv-python.")