Skip to content

Commit 1d59638

Browse files
authored
Add support for restricting detection classes (#45)
* Add support for restricting detection classes in `Options`
1 parent 0102c15 commit 1d59638

File tree

4 files changed

+40
-2
lines changed

4 files changed

+40
-2
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "usls"
3-
version = "0.0.17"
3+
version = "0.0.18"
44
edition = "2021"
55
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
66
repository = "https://github.com/jamjamjon/usls"

examples/yolo/main.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ fn main() -> Result<()> {
160160
// .with_names(&COCO_CLASS_NAMES_80)
161161
// .with_names2(&COCO_KEYPOINTS_17)
162162
.with_find_contours(!args.no_contours) // find contours or not
163+
.exclude_classes(&[0])
164+
// .retain_classes(&[0, 5])
163165
.with_profile(args.profile);
164166

165167
// build model

src/core/options.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ pub struct Options {
4848
pub sam_kind: Option<SamKind>,
4949
pub use_low_res_mask: Option<bool>,
5050
pub sapiens_task: Option<SapiensTask>,
51+
pub classes_excluded: Vec<isize>,
52+
pub classes_retained: Vec<isize>,
5153
}
5254

5355
impl Default for Options {
@@ -88,6 +90,8 @@ impl Default for Options {
8890
use_low_res_mask: None,
8991
sapiens_task: None,
9092
task: Task::Untitled,
93+
classes_excluded: vec![],
94+
classes_retained: vec![],
9195
}
9296
}
9397
}
@@ -276,4 +280,16 @@ impl Options {
276280
self.iiixs.push(Iiix::from((i, ii, x)));
277281
self
278282
}
283+
284+
pub fn exclude_classes(mut self, xs: &[isize]) -> Self {
285+
self.classes_retained.clear();
286+
self.classes_excluded.extend_from_slice(xs);
287+
self
288+
}
289+
290+
pub fn retain_classes(mut self, xs: &[isize]) -> Self {
291+
self.classes_excluded.clear();
292+
self.classes_retained.extend_from_slice(xs);
293+
self
294+
}
279295
}

src/models/yolo.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ pub struct YOLO {
2626
layout: YOLOPreds,
2727
find_contours: bool,
2828
version: Option<YOLOVersion>,
29+
classes_excluded: Vec<isize>,
30+
classes_retained: Vec<isize>,
2931
}
3032

3133
impl Vision for YOLO {
@@ -157,6 +159,10 @@ impl Vision for YOLO {
157159
let kconfs = DynConf::new(&options.kconfs, nk);
158160
let iou = options.iou.unwrap_or(0.45);
159161

162+
// Classes excluded and retained
163+
let classes_excluded = options.classes_excluded;
164+
let classes_retained = options.classes_retained;
165+
160166
// Summary
161167
tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version);
162168

@@ -179,6 +185,8 @@ impl Vision for YOLO {
179185
layout,
180186
version,
181187
find_contours: options.find_contours,
188+
classes_excluded,
189+
classes_retained,
182190
})
183191
}
184192

@@ -276,7 +284,19 @@ impl Vision for YOLO {
276284
}
277285
};
278286

279-
// filtering
287+
// filtering by class id
288+
if !self.classes_excluded.is_empty()
289+
&& self.classes_excluded.contains(&(class_id as isize))
290+
{
291+
return None;
292+
}
293+
if !self.classes_retained.is_empty()
294+
&& !self.classes_retained.contains(&(class_id as isize))
295+
{
296+
return None;
297+
}
298+
299+
// filtering by conf
280300
if confidence < self.confs[class_id] {
281301
return None;
282302
}

0 commit comments

Comments
 (0)