# Source code for swarmsort.track_state

"""
SwarmSort Track State Management

This module contains the core data structures for managing track states in the
SwarmSort tracking system. It includes both active track states and pending
detections that are waiting to be confirmed as tracks.

Classes:
    PendingDetection: Detection waiting to become a confirmed track
    FastTrackState: State representation of an active tracked object
"""

# ============================================================================
# STANDARD IMPORTS
# ============================================================================
from dataclasses import dataclass, field
from typing import Optional, Literal
from collections import deque
import numpy as np


# ============================================================================
# PENDING DETECTION CLASS
# ============================================================================
@dataclass
class PendingDetection:
    """Represents a detection waiting to become a track.

    A pending detection accumulates observations across frames. Once it has
    been seen in enough consecutive frames, and recently enough (see
    ``is_ready_for_promotion``), it is considered ready to be promoted to a
    confirmed track.
    """

    position: np.ndarray
    embedding: Optional[np.ndarray] = None
    bbox: np.ndarray = field(
        default_factory=lambda: np.zeros(4, dtype=np.float32)
    )
    class_id: Optional[int] = None
    confidence: float = 1.0
    first_seen_frame: int = 0
    last_seen_frame: int = 0
    consecutive_frames: int = 1
    total_detections: int = 1
    # Running mean of every observed position. Note: __post_init__ always
    # overwrites it with a copy of ``position``.
    average_position: np.ndarray = field(
        default_factory=lambda: np.zeros(2, dtype=np.float32)
    )

    def __post_init__(self):
        # The first observation is the whole history, so it is also the mean.
        self.average_position = self.position.copy()

    def update(
        self,
        position: np.ndarray,
        embedding: Optional[np.ndarray] = None,
        bbox: Optional[np.ndarray] = None,
        confidence: float = 1.0,
        frame: Optional[int] = None,
    ):
        """Update pending detection with a new observation.

        Args:
            position: Latest observed position.
            embedding: Optional appearance embedding; replaces the stored one.
            bbox: Optional bounding box; replaces the stored one.
            confidence: Detection confidence; the maximum seen is kept.
            frame: Optional frame index of this observation. When given, it is
                stored as ``last_seen_frame`` so that
                ``is_ready_for_promotion`` measures the gap from this
                observation. When omitted, ``last_seen_frame`` is left
                untouched (original behavior), and the caller is responsible
                for maintaining it.
        """
        self.total_detections += 1
        self.consecutive_frames += 1

        # Incremental mean: fold the new sample into the previous average.
        self.average_position = (
            self.average_position * (self.total_detections - 1) + position
        ) / self.total_detections

        # Update latest position
        self.position = position

        if embedding is not None:
            self.embedding = embedding
        if bbox is not None:
            self.bbox = bbox

        # Keep the most confident observation's score.
        self.confidence = max(self.confidence, confidence)

        if frame is not None:
            self.last_seen_frame = frame

    def is_ready_for_promotion(
        self, min_consecutive: int, max_gap: int, current_frame: int
    ) -> bool:
        """Check if pending detection should become a track.

        Args:
            min_consecutive: Minimum ``consecutive_frames`` required.
            max_gap: Maximum allowed ``current_frame - last_seen_frame``.
            current_frame: The current frame index.

        Returns:
            True when both the consecutive-frame and gap criteria are met.
        """
        gap = current_frame - self.last_seen_frame
        return self.consecutive_frames >= min_consecutive and gap <= max_gap
# ============================================================================
# FAST TRACK STATE CLASS
# ============================================================================
@dataclass
class FastTrackState:
    """Enhanced track state with N-embedding history and kalman_type support.

    Motion is handled according to ``kalman_type``:

    * ``"simple"``: a 4-element ``kalman_state`` whose first two entries
      mirror the position and last two the velocity, updated through
      ``kalman_filters.simple_kalman_update`` and
      ``simple_kalman_predict_with_damping``.
    * ``"oc"``: OC-SORT style observation arrays
      (``observation_history_array`` / ``observation_frames_array``) fed to
      ``kalman_filters.oc_sort_predict``.

    Appearance is kept as a bounded history of L2-normalized embeddings that
    can be frozen (see ``freeze_embeddings``) to protect it during collisions.
    """

    id: int
    class_id: Optional[int] = None
    position: np.ndarray = field(default_factory=lambda: np.zeros(2, dtype=np.float32))
    velocity: np.ndarray = field(default_factory=lambda: np.zeros(2, dtype=np.float32))
    predicted_position: np.ndarray = field(default_factory=lambda: np.zeros(2, dtype=np.float32))
    bbox: np.ndarray = field(default_factory=lambda: np.zeros(4, dtype=np.float32))

    # Kalman state (for "simple" type): [:2] position, [2:] velocity.
    kalman_state: np.ndarray = field(default_factory=lambda: np.zeros(4, dtype=np.float32))
    last_detection_pos: np.ndarray = field(default_factory=lambda: np.zeros(2, dtype=np.float32))
    last_detection_frame: int = 0

    # Observation history for both types (used by get_observation_prediction)
    observation_history: deque = field(default_factory=lambda: deque(maxlen=5))
    observation_frames: deque = field(default_factory=lambda: deque(maxlen=5))

    # OC-SORT specific arrays (for "oc" type); rows are kept in sync.
    observation_history_array: np.ndarray = field(
        default_factory=lambda: np.zeros((0, 2), dtype=np.float32)
    )
    observation_frames_array: np.ndarray = field(
        default_factory=lambda: np.zeros(0, dtype=np.int32)
    )

    # Track type: "simple" or "oc"
    kalman_type: str = "simple"

    # Kalman filter tuning
    velocity_damping: float = 0.95

    # Embedding freeze tracking
    embedding_frozen: bool = False
    last_safe_embedding: Optional[np.ndarray] = None

    # Embedding history with configurable size
    embedding_history: deque = field(default_factory=lambda: deque(maxlen=5))
    embedding_method: Literal["average", "best_match", "weighted_average", "last", "median"] = "average"

    # Embedding match score history (cosine similarity from recent matches)
    embedding_score_history: deque = field(default_factory=lambda: deque(maxlen=5))

    # Cache for average embedding
    _cached_avg_embedding: Optional[np.ndarray] = None
    _cache_valid: bool = False

    # Cache for multi-embedding computation
    _cached_representative_embedding: Optional[np.ndarray] = None
    _representative_cache_valid: bool = False

    # Keep for backward compatibility
    avg_embedding: Optional[np.ndarray] = None
    embedding_update_count: int = 0

    age: int = 0
    hits: int = 0
    misses: int = 0
    confirmed: bool = False
    detection_confidence: float = 0.0
    confidence_score: float = 0.5
    lost_frames: int = 0

    # Recent hit/miss history for uncertainty computation
    # True = hit (matched), False = miss (unmatched)
    recent_match_history: deque = field(default_factory=lambda: deque(maxlen=10))

    def __post_init__(self):
        # Seed every position-like field from the initial position.
        self.kalman_state[:2] = self.position
        self.last_detection_pos = self.position.copy()
        self.predicted_position = self.position.copy()

        # Initialize OC-SORT arrays properly if needed: an empty history is
        # normalized to the canonical (0, 2) float32 / (0,) int32 shapes.
        if self.kalman_type == "oc":
            if self.observation_history_array.shape[0] == 0:
                self.observation_history_array = np.zeros((0, 2), dtype=np.float32)
            if self.observation_frames_array.shape[0] == 0:
                self.observation_frames_array = np.zeros(0, dtype=np.int32)

    def get_observation_prediction(self, current_frame: int, max_history: int = 5) -> np.ndarray:
        """Get observation-based prediction using recent detection history.

        Linearly extrapolates from the last two entries of
        ``observation_history`` to ``current_frame``.

        Args:
            current_frame: Frame to predict for.
            max_history: Currently unused; kept for interface compatibility.

        Returns:
            A new float32 array. With fewer than two observations, this is a
            copy of ``predicted_position``; with two observations at the same
            frame, a copy of the latest one.
        """
        if len(self.observation_history) < 2:
            # Return a copy so callers cannot mutate the track's own state.
            return self.predicted_position.copy()

        # Simple linear extrapolation from last two observations
        pos1 = self.observation_history[-2]
        pos2 = self.observation_history[-1]
        frame1 = self.observation_frames[-2]
        frame2 = self.observation_frames[-1]

        if frame2 == frame1:
            return pos2.copy()

        dt = current_frame - frame2
        velocity = (pos2 - pos1) / (frame2 - frame1)
        predicted = pos2 + velocity * dt
        return predicted.astype(np.float32)

    def update_observation_history(self, position: np.ndarray, frame: int):
        """Append a copy of ``position`` and its ``frame`` to the bounded history."""
        self.observation_history.append(position.copy())
        self.observation_frames.append(frame)

    def set_embedding_params(
        self,
        max_embeddings: int = 5,
        method: Literal["average", "best_match", "weighted_average", "last", "median"] = "average",
        score_history_length: int = 5,
    ):
        """Configure embedding storage parameters.

        Note: this replaces ``embedding_history`` and
        ``embedding_score_history`` with new, empty deques.

        Args:
            max_embeddings: Maximum number of embeddings to keep in history
            method: Method for computing representative embedding
            score_history_length: Number of recent match scores to keep
        """
        self.embedding_history = deque(maxlen=max_embeddings)
        self.embedding_method = method
        self.embedding_score_history = deque(maxlen=score_history_length)
        # Both caches depend on the history and method, so invalidate both.
        self._cache_valid = False
        self._representative_cache_valid = False

    def add_embedding(self, embedding: np.ndarray):
        """Add new embedding to history with safe normalization.

        When the track is frozen (embedding_frozen=True), this method returns
        immediately without modifying the embedding_history. This protects the
        appearance model during collisions or crowded areas.

        Zero-norm (or non-finite-norm) embeddings are ignored.

        Args:
            embedding: The embedding vector to add (will be L2-normalized)
        """
        if self.embedding_frozen:
            # Track is in collision zone - don't update embeddings
            return

        if embedding is not None:
            embedding = np.asarray(embedding, dtype=np.float32)
            norm = np.linalg.norm(embedding)
            if norm > 0:
                # Only normalize if not already normalized
                if abs(norm - 1.0) > 0.01:
                    normalized_emb = embedding / (norm + 1e-8)
                else:
                    normalized_emb = embedding

                if self.last_safe_embedding is None:
                    self.last_safe_embedding = normalized_emb.copy()

                self.embedding_history.append(normalized_emb.copy())
                self.embedding_update_count += 1

                # Only invalidate caches - don't recompute (lazy evaluation)
                self._cache_valid = False
                self._representative_cache_valid = False
                # REMOVED: self._update_avg_embedding() - avg_embedding is never read

    def _update_avg_embedding(self):
        """Update avg_embedding with caching.

        "average" uses (and fills) the average cache; "weighted_average" uses
        linearly increasing weights (newest heaviest); any other method uses
        the most recent embedding. No-op when the history is empty.
        """
        if len(self.embedding_history) > 0:
            if self.embedding_method == "average":
                if not self._cache_valid:
                    # OPTIMIZATION: Use np.array() directly on deque (more efficient than list())
                    self._cached_avg_embedding = np.mean(np.array(self.embedding_history), axis=0)
                    self._cache_valid = True
                self.avg_embedding = self._cached_avg_embedding
            elif self.embedding_method == "weighted_average":
                weights = np.arange(1, len(self.embedding_history) + 1, dtype=np.float32)
                weights = weights / weights.sum()
                # OPTIMIZATION: Use np.array() directly on deque
                self.avg_embedding = np.average(np.array(self.embedding_history), axis=0, weights=weights)
            else:
                self.avg_embedding = self.embedding_history[-1]

    def get_representative_embedding(self) -> Optional[np.ndarray]:
        """Get representative embedding based on configured method.

        Returns:
            None for an empty history. Otherwise:

            * "last": the newest embedding;
            * "average": the element-wise mean;
            * "weighted_average": a linearly weighted mean (newest heaviest);
            * "median": the element-wise median;
            * any other method (e.g. "best_match"): the newest embedding.

            The value is recomputed on every call.
        """
        if len(self.embedding_history) == 0:
            return None

        if self.embedding_method == "last":
            return self.embedding_history[-1]
        elif self.embedding_method == "average":
            # OPTIMIZATION: Use np.array() directly on deque
            return np.mean(np.array(self.embedding_history), axis=0)
        elif self.embedding_method == "weighted_average":
            weights = np.arange(1, len(self.embedding_history) + 1, dtype=np.float32)
            weights = weights / weights.sum()
            # OPTIMIZATION: Use np.array() directly on deque
            return np.average(np.array(self.embedding_history), axis=0, weights=weights)
        elif self.embedding_method == "median":
            # Median embedding (element-wise median across history)
            return np.median(np.array(self.embedding_history), axis=0)
        else:
            return self.embedding_history[-1]

    def freeze_embeddings(self):
        """Freeze embeddings when collision/crowded area detected.

        When frozen, add_embedding() will return early without modifying the
        embedding_history. This protects the track's appearance model from
        being corrupted by mixed-object detections during collisions.

        The freeze/unfreeze is controlled by core.py's _update_collision_states()
        which uses embedding_freeze_density with hysteresis to prevent oscillation.

        Note: last_safe_embedding stores a reference embedding from before the
        collision. This can be useful for debugging or future ReID improvements.
        """
        if not self.embedding_frozen:
            self.embedding_frozen = True
            # Save last embedding before freeze for potential debugging/recovery
            if len(self.embedding_history) > 0 and self.last_safe_embedding is None:
                self.last_safe_embedding = self.embedding_history[-1].copy()

    def unfreeze_embeddings(self):
        """Unfreeze embeddings when collision area is cleared.

        Since add_embedding() returns early when frozen, the embedding_history
        should still contain only pre-collision embeddings at this point.
        No restoration is needed - the history is already clean.

        Note: We don't clear last_safe_embedding here in case it's useful
        for debugging or comparison purposes.
        """
        if self.embedding_frozen:
            self.embedding_frozen = False

    def update_with_detection(
        self,
        position: np.ndarray,
        embedding: Optional[np.ndarray] = None,
        bbox: Optional[np.ndarray] = None,
        frame: int = 0,
        det_conf: float = 0.0,
        is_reid: bool = False,
    ):
        """Update track state with a matched detection.

        Args:
            position: Detected position (two coordinates).
            embedding: Optional appearance embedding, passed to add_embedding().
            bbox: Optional bounding box, stored as float32.
            frame: Frame index of the detection.
            det_conf: Detector confidence for this detection.
            is_reid: Forwarded to simple_kalman_update ("simple" type only).

        Side effects: resets ``misses``/``lost_frames`` to 0, increments
        ``hits`` and ``age``, and records a hit in ``recent_match_history``.
        """
        # Import here to avoid circular dependency
        from .kalman_filters import simple_kalman_update

        self.position = position.astype(np.float32)
        self.last_detection_pos = position.copy()
        self.last_detection_frame = frame
        self.detection_confidence = det_conf

        self.update_observation_history(position, frame)

        if self.kalman_type == "simple":
            self.kalman_state = simple_kalman_update(self.kalman_state, position, is_reid=is_reid)
            self.velocity = self.kalman_state[2:].copy()
            # predicted_position is WHERE we expect the object in the NEXT frame
            self.predicted_position = self.position + self.velocity
        elif self.kalman_type == "oc":
            # astype() copies, so the history never aliases the caller's array.
            new_observation = position.reshape(1, 2).astype(np.float32)
            new_frame = np.array([frame], dtype=np.int32)
            # Max history size for OC-SORT arrays (prevents unbounded growth)
            max_oc_history = 30

            if len(self.observation_history_array) == 0:
                self.observation_history_array = new_observation
                self.observation_frames_array = new_frame
            else:
                self.observation_history_array = np.vstack(
                    [self.observation_history_array, new_observation]
                )
                # Concatenate the int32 array (not a bare Python int) so the
                # frames array keeps its int32 dtype on every NumPy version.
                self.observation_frames_array = np.concatenate(
                    [self.observation_frames_array, new_frame]
                )
                # Trim to max history size (keep most recent)
                if len(self.observation_history_array) > max_oc_history:
                    self.observation_history_array = self.observation_history_array[-max_oc_history:]
                    self.observation_frames_array = self.observation_frames_array[-max_oc_history:]

            if len(self.observation_history_array) >= 2:
                dt = self.observation_frames_array[-1] - self.observation_frames_array[-2]
                if dt > 0:
                    # Finite-difference velocity between the last two observations.
                    self.velocity[:2] = (
                        self.observation_history_array[-1] - self.observation_history_array[-2]
                    ) / dt

        if embedding is not None:
            self.add_embedding(embedding)

        if bbox is not None:
            self.bbox = np.asarray(bbox, dtype=np.float32)

        self.hits += 1
        self.age += 1
        self.misses = 0
        self.lost_frames = 0
        # Record hit for uncertainty tracking
        self.recent_match_history.append(True)

    def predict_position(self, current_frame: Optional[int] = None):
        """Update predicted_position using Kalman filter WITHOUT modifying counters.

        This should be called for ALL tracks BEFORE assignment to ensure
        predicted_position is up-to-date for cost matrix computation.

        Args:
            current_frame: Frame to predict for ("oc" type). When None, the
                frame after the last observation is used.
        """
        # Import here to avoid circular dependency
        from .kalman_filters import simple_kalman_predict_with_damping, oc_sort_predict

        if self.kalman_type == "simple":
            self.kalman_state = simple_kalman_predict_with_damping(
                self.kalman_state, self.velocity_damping
            )
            self.velocity = self.kalman_state[2:].copy()
            self.predicted_position = self.kalman_state[:2].copy()
        elif self.kalman_type == "oc":
            if len(self.observation_history_array) > 0:
                if current_frame is not None:
                    frame_to_use = current_frame
                elif len(self.observation_frames_array) > 0:
                    frame_to_use = self.observation_frames_array[-1] + 1
                else:
                    frame_to_use = 0
                pred_state = oc_sort_predict(
                    self.observation_history_array, self.observation_frames_array, frame_to_use
                )
                self.predicted_position = pred_state[:2].copy()
            else:
                # No observations yet: fall back to the current position.
                self.predicted_position = self.position.copy()

    def predict_only(self, current_frame: Optional[int] = None):
        """Prediction step for UNMATCHED tracks - updates position AND counters.

        This is called for tracks that were NOT matched to a detection.
        For matched tracks, use update_with_detection() instead.
        """
        # Update predicted position
        self.predict_position(current_frame)

        # Increment counters for unmatched track
        self.age += 1
        # Record miss for uncertainty tracking
        self.recent_match_history.append(False)
        self.misses += 1
        self.lost_frames += 1

    def get_predicted_position(self, current_frame: int) -> np.ndarray:
        """Get predicted position based on kalman_type.

        For "oc" tracks with at least one observation, the prediction is
        computed with oc_sort_predict for ``current_frame``. In every other
        case, including an "oc" track with no observations yet (mirroring the
        guard in predict_position), ``predicted_position`` is returned.
        """
        if self.kalman_type == "oc" and len(self.observation_history_array) > 0:
            # Import here to avoid circular dependency
            from .kalman_filters import oc_sort_predict

            pred_state = oc_sort_predict(
                self.observation_history_array, self.observation_frames_array, current_frame
            )
            return pred_state[:2]
        return self.predicted_position

    def add_embedding_fast(self, embedding: np.ndarray, pre_normalized: bool = True):
        """Fast embedding addition without expensive checks.

        Unlike add_embedding(), the array is not cast or copied before being
        stored. When ``pre_normalized`` is False, the embedding is
        L2-normalized, and zero-norm (or non-finite-norm) vectors are skipped,
        as add_embedding() does.
        """
        if self.embedding_frozen or embedding is None:
            return

        if not pre_normalized:
            norm = np.linalg.norm(embedding)
            if not norm > 0:
                # A zero/NaN vector has no direction; don't pollute the history.
                return
            embedding = embedding / norm

        if self.last_safe_embedding is None:
            self.last_safe_embedding = embedding.copy()

        self.embedding_history.append(embedding)
        self.embedding_update_count += 1

        # Only invalidate caches - don't recompute (lazy evaluation)
        self._cache_valid = False
        self._representative_cache_valid = False
        # REMOVED: self._update_avg_embedding() - avg_embedding is never read

    def get_recent_miss_ratio(self) -> float:
        """Get the ratio of misses in recent frames.

        Returns:
            Float in [0, 1] where 0 = all hits, 1 = all misses.
            Returns 0 if no history yet.
        """
        if len(self.recent_match_history) == 0:
            return 0.0
        # Count False values (misses)
        miss_count = sum(1 for hit in self.recent_match_history if not hit)
        return miss_count / len(self.recent_match_history)

    def set_uncertainty_window(self, window_size: int):
        """Set the window size for uncertainty tracking.

        Args:
            window_size: Number of recent frames to track (default: 10)
        """
        # Create new deque with updated maxlen, preserving recent history
        old_history = list(self.recent_match_history)
        self.recent_match_history = deque(old_history[-window_size:], maxlen=window_size)