Skip to content

courtvision.models

BallDetector

Source code in courtvision/models.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class BallDetector:
    """Runs a ball detection model on single frames and caches the results on disk.

    Each prediction is stored at
    ``<cache_dir>/<PIPELINE_NAME>/<clip_uid>/detections_at_<frame_idx>.pt``.
    """

    PIPELINE_NAME = "ball_detection"

    def __init__(self, model_file_or_dir: Path, cache_dir: Path):
        """Loads the ball detection model and remembers where to cache detections.

        Args:
            model_file_or_dir (Path): Either a model file, or a directory from
                which `get_latest_file` picks the file to load.
            cache_dir (Path): Root directory under which detections are cached.
        """
        if model_file_or_dir.is_dir():
            self.model_path = get_latest_file(model_file_or_dir)
        else:
            self.model_path = model_file_or_dir

        self.model = get_ball_detection_model(model_path=self.model_path)
        self.cache_dir = cache_dir
        self.model.eval()

    def predict(
        self, image: torch.Tensor, frame_idx: int, clip_uid: str
    ) -> dict[str, torch.Tensor]:
        """Predicts ball detections for a given frame.
        !!! note
            This method caches the detections on disk.
        Args:
            image (torch.Tensor): Image tensor of shape (1,3,H,W)
            frame_idx (int): frame index
            clip_uid (str): clip uid that identifies the clip uniquely.

        Returns:
            dict[str, torch.Tensor]: A dict of ball detection tensors.
        """
        cache_path = (
            self.cache_dir
            / self.PIPELINE_NAME
            / clip_uid
            / f"detections_at_{frame_idx}.pt"
        )
        # Cache hit: reuse the stored detections without touching the model.
        # NOTE(review): no map_location is passed, so tensors are restored to
        # the device they were saved from.
        if cache_path.is_file():
            return torch.load(cache_path)

        with torch.no_grad():
            detections = self.model(image)
        # The directory is only needed when writing. The previous guard tested
        # `cache_path.is_dir()`, i.e. the file path itself, which is not the
        # directory that has to exist.
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(detections, cache_path)
        return detections

predict(image, frame_idx, clip_uid)

Predicts ball detections for a given frame.

Note

This method caches the detections on disk.

Parameters:

Name Type Description Default
image torch.Tensor

Image tensor of shape (1,3,H,W)

required
frame_idx int

frame index

required
clip_uid str

clip uid that identifies the clip uniquely.

required

Returns:

Type Description
dict[str, torch.Tensor]

dict[str, torch.Tensor]: A dict of ball detection tensors.

Source code in courtvision/models.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def predict(
    self, image: torch.Tensor, frame_idx: int, clip_uid: str
) -> dict[str, torch.Tensor]:
    """Predicts ball detections for a given frame.
    !!! note
        This method caches the detections on disk.
    Args:
        image (torch.Tensor): Image tensor of shape (1,3,H,W)
        frame_idx (int): frame index
        clip_uid (str): clip uid that identifies the clip uniquely.

    Returns:
        dict[str, torch.Tensor]: A dict of ball detection tensors.
    """
    cache_path = (
        self.cache_dir
        / self.PIPELINE_NAME
        / clip_uid
        / f"detections_at_{frame_idx}.pt"
    )
    # Cache hit: reuse the stored detections without touching the model.
    if cache_path.is_file():
        return torch.load(cache_path)

    with torch.no_grad():
        detections = self.model(image)
    # The directory is only needed when writing. The previous guard tested
    # `cache_path.is_dir()`, i.e. the file path itself, which is not the
    # directory that has to exist.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(detections, cache_path)
    return detections

PlayerDetector

Source code in courtvision/models.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
class PlayerDetector:
    """Runs a player tracking model on single frames and caches the results on disk.

    Each prediction is stored at
    ``<cache_dir>/<PIPELINE_NAME>/<clip_uid>/detections_at_<frame_idx>.pt``.
    """

    PIPELINE_NAME = "player_detection"

    def __init__(self, model_dir: Path, cache_dir: Path):
        """Loads the player detection model and remembers where to cache detections.

        Args:
            model_dir (Path): Directory from which `get_latest_file` picks the
                model file to load.
            cache_dir (Path): Root directory under which detections are cached.
        """
        self.model_path = get_latest_file(model_dir)
        self.cache_dir = cache_dir
        self.model = get_yolov8_player_detection_model(model_path=self.model_path)
        # NOTE(review): `self.model.eval()` was left disabled here - confirm
        # whether the model wrapper needs it.

    def predict(
        self, image: torch.Tensor, frame_idx: int, clip_uid: str
    ) -> dict[str, torch.Tensor]:
        """Predicts player detections for a given frame.
        !!! note
            This method caches the detections on disk.
        Args:
            image (torch.Tensor): Image tensor of shape (1,3,H,W)
            frame_idx (int): frame index
            clip_uid (str): clip uid that identifies the clip uniquely.

        Returns:
            dict[str, torch.Tensor]: Dict of player detections.
                NOTE(review): the value is whatever `self.model.track` returns;
                confirm that it matches this annotation.
        """
        cache_path = (
            self.cache_dir
            / self.PIPELINE_NAME
            / clip_uid
            / f"detections_at_{frame_idx}.pt"
        )
        # Cache hit: reuse the stored detections without touching the model.
        if cache_path.is_file():
            return torch.load(cache_path)

        with torch.no_grad():
            # (1,3,H,W) -> (H,W,3) array. `.cpu()` is a no-op for CPU tensors
            # and keeps `.numpy()` from raising when the image is on another
            # device.
            frame = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
            detections = self.model.track(source=frame, persist=True)
        # The directory is only needed when writing. The previous guard tested
        # `cache_path.is_dir()`, i.e. the file path itself, which is not the
        # directory that has to exist.
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(detections, cache_path)
        return detections

predict(image, frame_idx, clip_uid)

Predicts player detections for a given frame.

Note

This method caches the detections on disk.

Parameters:

Name Type Description Default
image torch.Tensor

Image tensor of shape (1,3,H,W)

required
frame_idx int

frame index

required
clip_uid str

clip uid that identifies the clip uniquely.

required

Returns:

Type Description
dict[str, torch.Tensor]

dict[str, torch.Tensor]: Dict of player detections.

Source code in courtvision/models.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def predict(
    self, image: torch.Tensor, frame_idx: int, clip_uid: str
) -> dict[str, torch.Tensor]:
    """Predicts player detections for a given frame.
    !!! note
        This method caches the detections on disk.
    Args:
        image (torch.Tensor): Image tensor of shape (1,3,H,W)
        frame_idx (int): frame index
        clip_uid (str): clip uid that identifies the clip uniquely.

    Returns:
        dict[str, torch.Tensor]: Dict of player detections.
            NOTE(review): the value is whatever `self.model.track` returns;
            confirm that it matches this annotation.
    """
    cache_path = (
        self.cache_dir
        / self.PIPELINE_NAME
        / clip_uid
        / f"detections_at_{frame_idx}.pt"
    )
    # Cache hit: reuse the stored detections without touching the model.
    if cache_path.is_file():
        return torch.load(cache_path)

    with torch.no_grad():
        # (1,3,H,W) -> (H,W,3) array. `.cpu()` is a no-op for CPU tensors and
        # keeps `.numpy()` from raising when the image is on another device.
        frame = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
        detections = self.model.track(source=frame, persist=True)
    # The directory is only needed when writing. The previous guard tested
    # `cache_path.is_dir()`, i.e. the file path itself, which is not the
    # directory that has to exist.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(detections, cache_path)
    return detections

get_ball_detection_model(model_path)

Grabs a trained ball detection model from a path.

Parameters:

Name Type Description Default
model_path Path

Path to the model weights. A .ckpt file.

required

Returns:

Name Type Description
BallDetectorModel BallDetectorModel

A trained BallDetectorModel from a checkpoint.

Source code in courtvision/models.py
66
67
68
69
70
71
72
73
74
75
76
77
78
def get_ball_detection_model(model_path: Path) -> "BallDetectorModel":
    """Restore a trained ball detection model from a checkpoint.

    The checkpoint is mapped onto the CUDA device when one is available and
    onto the CPU otherwise.

    Args:
        model_path (Path): Location of the model weights. A .ckpt file.

    Returns:
        BallDetectorModel: The model returned by
            `BallDetectorModel.load_from_checkpoint` for `model_path`.
    """
    # Imported at call time rather than at module import.
    from courtvision.trainer import BallDetectorModel  # TODO: move to models.py

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return BallDetectorModel.load_from_checkpoint(
        model_path, map_location=target_device
    )

get_fasterrcnn_ball_detection_model(model_path=None)

Fetches a FasterRCNN model for ball detection. If model_path is None, the model is pretrained on COCO. If model_path is a Path, the model is loaded from the path.

Parameters:

Name Type Description Default
model_path None | Path

Path to the model weights that will be loaded. Defaults to None.

None

Returns:

Name Type Description
FasterRCNN FasterRCNN

A ball detection model using FasterRCNN

Source code in courtvision/models.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def get_fasterrcnn_ball_detection_model(model_path: None | Path = None) -> FasterRCNN:
    """Fetches a FasterRCNN model for ball detection.
    If model_path is None, the model is pretrained on COCO.
    If model_path is a Path, the model is loaded from the path.

    In both cases the box predictor head is replaced with a two-class head
    (ball + background) before any weights from `model_path` are loaded.

    Args:
        model_path (None | Path, optional): Path to the model weights that will be loaded. Defaults to None.

    Returns:
        FasterRCNN: A ball detection model using FasterRCNN
    """
    # Select the pretrained weights through the `weights` argument itself.
    # The previous call passed `weights=None` together with the legacy
    # `pretrained=` flag; an explicit `weights=None` means "no pretrained
    # weights", so the COCO weights promised above were not requested.
    # NOTE(review): confirm against the pinned torchvision version.
    weights = "DEFAULT" if model_path is None else None

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights=weights, weights_backbone=None
    )
    num_classes = 2  # 1 class (ball) + background
    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    if model_path is not None:
        # Load onto the CPU regardless of the device the weights were saved from.
        model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
    return model

get_yolo_player_detection_model(model_path=None)

Fetches a pretrained YOLO model for player detection.

Parameters:

Name Type Description Default
model_path None | Path

Unused. Defaults to None.

None

Returns:

Name Type Description
Any Any

Yolo model

Source code in courtvision/models.py
41
42
43
44
45
46
47
48
49
50
51
def get_yolo_player_detection_model(model_path: None | Path = None) -> Any:
    """Load the pretrained `yolov5s` model from the `ultralytics/yolov5` hub repo.

    Args:
        model_path (None | Path, optional): Not used by this function; the
            model always comes from `torch.hub`. Defaults to None.

    Returns:
        Any: The object returned by `torch.hub.load`.
    """
    return torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)