courtvision.models
BallDetector
Source code in courtvision/models.py
predict(image, frame_idx, clip_uid)
Predicts ball detections for a given frame.
Note
This method caches the detections on disk.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image | torch.Tensor | Image tensor of shape (1, 3, H, W). | required |
frame_idx | int | Frame index. | required |
clip_uid | str | Clip UID that uniquely identifies the clip. | required |
Returns:
Type | Description |
---|---|
dict[str, torch.Tensor] | A dict of ball detections as tensors. |
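The caching note above can be illustrated with a minimal stdlib sketch of per-frame disk caching. It assumes a hypothetical layout with one pickle file per (clip_uid, frame_idx) under a cache directory. `cached_predict`, `cache_dir`, and `compute` are illustrative names, not part of courtvision, and courtvision's actual cache format and location are not documented here.

```python
import pickle
from pathlib import Path
from typing import Any, Callable


def cached_predict(
    cache_dir: Path,
    clip_uid: str,
    frame_idx: int,
    compute: Callable[[], Any],
) -> Any:
    """Return cached detections for (clip_uid, frame_idx), computing them on a miss.

    Hypothetical sketch: courtvision's real cache layout is not documented here.
    """
    # One sub-directory per clip and one file per frame, so equal frame
    # indices from different clips never collide.
    cache_file = cache_dir / clip_uid / f"{frame_idx}.pkl"
    if cache_file.exists():
        with cache_file.open("rb") as f:
            return pickle.load(f)

    # Cache miss: run the (expensive) detector once and persist its output.
    detections = compute()
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    with cache_file.open("wb") as f:
        pickle.dump(detections, f)
    return detections
```

Keying on both clip_uid and frame_idx matters if frame indices restart at 0 for each clip: a key built from frame_idx alone would then return another clip's detections.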
PlayerDetector
predict(image, frame_idx, clip_uid)
Predicts player detections for a given frame.
Note
This method caches the detections on disk.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image | torch.Tensor | Image tensor of shape (1, 3, H, W). | required |
frame_idx | int | Frame index. | required |
clip_uid | str | Clip UID that uniquely identifies the clip. | required |
Returns:
Type | Description |
---|---|
dict[str, torch.Tensor] | A dict of player detections. |
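Callers often want to drop low-confidence entries from the returned dict. The sketch below assumes torchvision-style keys (`boxes`, `scores`, `labels`) with one aligned entry per detection; the actual keys returned by PlayerDetector.predict are not documented here, and `filter_by_score` and `threshold` are hypothetical names. It uses plain Python sequences so it runs without PyTorch.

```python
from typing import Any, Sequence


def filter_by_score(
    detections: dict[str, Sequence[Any]], threshold: float = 0.5
) -> dict[str, list[Any]]:
    """Keep only detections whose score is >= threshold.

    Assumes (hypothetically) that every value in `detections` is aligned
    entry-by-entry with detections["scores"].
    """
    # Indices of the detections that pass the confidence threshold.
    keep = [i for i, s in enumerate(detections["scores"]) if float(s) >= threshold]
    # Apply the same selection to every per-detection field.
    return {key: [values[i] for i in keep] for key, values in detections.items()}
```

With real tensors, boolean indexing does the same in one step: `mask = det["scores"] >= threshold`, then `{k: v[mask] for k, v in det.items()}`.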
get_ball_detection_model(model_path)
Loads a trained ball detection model from a checkpoint path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_path | Path | Path to the model weights (a .ckpt file). | required |
Returns:
Name | Type | Description |
---|---|---|
BallDetectorModel | BallDetectorModel | A trained BallDetectorModel loaded from a checkpoint. |
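Because model_path must point to a .ckpt file, a caller may want to fail fast with a clear error before any loading happens. The helper below is hypothetical (`validate_checkpoint_path` is not part of courtvision); it only checks the path and does not load any weights.

```python
from pathlib import Path


def validate_checkpoint_path(model_path: Path) -> Path:
    """Return model_path unchanged if it names an existing .ckpt file.

    Hypothetical pre-check that raises early instead of letting the loader fail later.
    """
    # The documented contract is a .ckpt file, so reject other extensions.
    if model_path.suffix != ".ckpt":
        raise ValueError(f"Expected a .ckpt file, got {model_path.name!r}")
    # Catch typos and missing downloads before any model code runs.
    if not model_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {model_path}")
    return model_path
```

For example, wrapping the argument as `get_ball_detection_model(validate_checkpoint_path(path))` surfaces a missing or misnamed file immediately.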
get_fasterrcnn_ball_detection_model(model_path=None)
Fetches a FasterRCNN model for ball detection. If model_path is None, the model is pretrained on COCO. If model_path is a Path, the model is loaded from the path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_path | None \| Path | Path to the model weights to load. Defaults to None. | None |
Returns:
Name | Type | Description |
---|---|---|
FasterRCNN | FasterRCNN | A ball detection model using FasterRCNN. |
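The None-versus-Path rule above amounts to a two-way dispatch on model_path. The stdlib sketch below mirrors only that decision. `WeightsSource` and `resolve_weights_source` are hypothetical names, and the construction of the FasterRCNN model and the loading of its weights are deliberately left out because their details are not documented here.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass(frozen=True)
class WeightsSource:
    kind: str  # "coco_pretrained" or "checkpoint"
    path: Optional[Path] = None


def resolve_weights_source(model_path: Optional[Path] = None) -> WeightsSource:
    """Mirror the documented rule: None -> COCO-pretrained, Path -> load from that path."""
    if model_path is None:
        return WeightsSource(kind="coco_pretrained")
    # Hypothetical guard: the documented contract only covers None or a Path.
    if not isinstance(model_path, Path):
        raise TypeError(
            f"model_path must be None or a Path, got {type(model_path).__name__}"
        )
    return WeightsSource(kind="checkpoint", path=model_path)
```

Keeping the decision separate from model construction makes the rule easy to test without downloading any weights.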
get_yolo_player_detection_model(model_path=None)
Fetches a pretrained YOLO model for player detection.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_path | None \| Path | Unused. Defaults to None. | None |
Returns:
Name | Type | Description |
---|---|---|
Any | Any | A YOLO model. |