From ff818d2f6aba4a0c9c97fdcbf269011caca964a8 Mon Sep 17 00:00:00 2001 From: Quentin Date: Mon, 7 Jul 2025 16:39:59 +0200 Subject: [PATCH 1/7] adding DOI in citation file --- CITATION.cff | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CITATION.cff b/CITATION.cff index a59d066..eebe3a0 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -31,4 +31,6 @@ authors: identifiers: - type: swh value: 'swh:1:rev:66f8d295cc5fbc80f356d11be46571bfbb190609' + - type: doi + value: '10.5281/zenodo.15830087' license: GPL-3.0 From d6bc29835a2e7750f9f62ecc1576db25b2460f97 Mon Sep 17 00:00:00 2001 From: Quentin Date: Mon, 7 Jul 2025 16:36:58 +0200 Subject: [PATCH 2/7] Adding Zenodo DOI badge --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 88e5c1c..d6acad6 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Python version supported Codecov Binder - +DOI

From 90baab1bb165e7700e98af369a19b26a64a72431 Mon Sep 17 00:00:00 2001 From: Quentin Date: Tue, 8 Jul 2025 11:54:55 +0200 Subject: [PATCH 3/7] Start inclusion of Curgraph --- src/radius_clustering/iterative.py | 226 +++++++++++++++++++++++++++++ 1 file changed, 226 insertions(+) create mode 100644 src/radius_clustering/iterative.py diff --git a/src/radius_clustering/iterative.py b/src/radius_clustering/iterative.py new file mode 100644 index 0000000..056f394 --- /dev/null +++ b/src/radius_clustering/iterative.py @@ -0,0 +1,226 @@ +""" +This module contains the implementation of the CURGRAPH algorithm. + +The CURGRAPH algorithm is an iterative algorithm that takes as input: + + * a set of data and features + * a number of max clusters to find (optional) + * a threshold for the radius constraint (optional) + +The algorithm returns whether: + + * the clusters and the maximum radius optimally found for nb_cluster from 2 to max clusters + * the clusters and the maximum radius optimally found for a given threshold + * the clusters and the maximum radius optimally found for a given number of clusters + +The algorithm is based on the following steps: + + 1. Compute the distance matrix between the data + 2. Rank the dissimilarities in decreasing order + 3. If max number of clusters is given: + For each dissimilarity until max number of clusters is reached: + 1. Compute the corresponding input graph considering each + dissimilarity above the dissimilarity threshold as no edge + 2. Find the minimum dominating set of the graph + 3. If the cardinality of the MDS is above the previous + cardinality found, store the threshold,the MDS and the max radius + 4. If not, continue to the next dissimilarity + returns the cardinality of the MDS, the threshold associated + and the max radius + 4. If threshold is given: + 1. Compute the corresponding input graph considering each + dissimilarity above the dissimilarity threshold as no edge + 2. Find the minimum dominating set of the graph + 3. Compute max radius and cardinality of the MDS + For each dissimilarity until cardinality of the MDS is above the previous cardinality found: + 1. Compute the corresponding input graph considering each + dissimilarity above the dissimilarity threshold as no edge + 2. Find the minimum dominating set of the graph + 3. If the cardinality of the MDS is above the previous + cardinality found, store the previous threshold, MDS and max radius + 4. If not, continue to the next dissimilarity + returns the cardinality of the MDS, the MDS, the max radius and the threshold associated +""" + +from __future__ import annotations + +import os +import time +import numpy as np +import scipy as sp + +from copy import deepcopy +from joblib import Parallel, delayed, effective_n_jobs, parallel_backend +from sklearn.base import BaseEstimator, MetaEstimatorMixin +from sklearn.utils.validation import check_is_fitted +from sklearn.utils import check_random_state +from sklearn.metrics import pairwise_distances +from typing import Union, List, Dict, Any, Tuple + +from .radius_clustering import RadiusClustering + +DIR_PATH = os.path.dirname(os.path.realpath(__file__)) +ROOT_PATH = os.getcwd() + + +def gen_even_slices(n, n_packs, n_samples=None): + start = 0 + for pack_num in range(n_packs): + this_n = n // n_packs + if pack_num < n % n_packs: + this_n += 1 + if this_n > 0: + end = start + this_n + if n_samples is not None: + end = min(n_samples, end) + yield slice(start, end, None) + start = end + + +class Curgraph(object): + """ + CURGRAPH algorithm for clustering based on Minimum Dominating Set (MDS). + + The CURGRAPH algorithm is an iterative algorithm that have been designed upon + CLUSTERGRAPH algorithm, presented by Hansen and Delattre (1978) + in "Complete-link cluster analysis by graph coloring". + + Parameters: + ---------- + max_clusters : int, optional + The maximum number of clusters to consider. Default is None. + radius : Union[int, float], optional + The dissimilarity threshold for clustering. Default is None. + manner : {'approx', 'exact'}, optional + The manner in which to compute the clusters. 'approx' uses an approximate method, + while 'exact' uses an exact method. Default is 'approx'. + random_state : int, RandomState instance or None, optional + Controls the randomness of the estimator. Pass an int for reproducible output across multiple function calls. + + Attributes: + ---------- + results_ : dict + A dictionary to store the results of the CURGRAPH algorithm. + + Methods: + ------- + fit(X: ArrayLike, y: None) -> self: + Run the CURGRAPH algorithm. + predict(X: np.ndarray, k: int) -> Tuple[np.ndarray, float]: + Predict the clusters and maximum radius for the given data, for a specific number of clusters. + get_results() -> dict: + Return the results of the CURGRAPH algorithm. + """ + + def __init__( + self, + manner: str = "approx", + max_clusters: int = None, + radius: Union[int, float] = None, + ): + self.manner = manner + self.max_clusters = max_clusters + self.radius = radius + self.results_ = {} + self.solver = RadiusClustering(manner=self.manner) + + def _init_dist_list(self, X: np.ndarray) -> None: + """ + Initialize the list of dissimilarities based on the radius parameter. + """ + self.X_ = X + self.dist_mat_ = pairwise_distances(self.X_) + self._list_t = np.unique(self.dist_mat_)[::-1] + if self.radius is None: + t = self.dist_mat_.max(axis=1).min() + else: + t = self.radius + radius = t + arg_radius = np.where(self._list_t <= radius)[0] + self._list_t = self._list_t[arg_radius:] + + def _init_results(self): + if self.max_clusters is not None: + for i in range(2, self.max_clusters + 1): + self.results_[i] = {"radius": None, "centers": None} + + def fit(self, X: ArrayLike, y: None, n_jobs: int = -1) -> self: + """ + Run the CURGRAPH algorithm. + """ + self._init_results() + self._init_dist_list(X) + self.n_jobs_ = effective_n_jobs(n_jobs) + dissimilarity_index = 0 + first_t = self._list_t[0] + old_mds = self.solver.set_params(radius=first_t).fit(X).centers_ + cardinality_limit = ( + self.max_clusters + 1 if self.max_clusters else len(old_mds) + 1 + ) + tasks = [ + delayed(self._curgraph)( + dissimilarity_index, self._list_t[s], old_mds, cardinality_limit + ) + for s in gen_even_slices(len(self._list_t), self.n_jobs_) + ] + with parallel_backend("threading", n_jobs=self.n_jobs_): + Parallel()(tasks) + + return self + + def _curgraph( + self, + index_d: int, + list_t: List[float], + old_mds: np.ndarray, + cardinality_limit: int, + ) -> None: + while (len(old_mds) < cardinality_limit) and (index_d < len(list_t)): + old_mds = self._process_mds(index_d, list_t, old_mds) + index_d += 1 + + def _process_mds( + self, index_d: int, list_t: List[float], old_mds: np.ndarray + ) -> np.ndarray: + """ + Process the minimum dominating set (MDS) for a given dissimilarity index. + """ + t = list_t[index_d] + if self._is_dominating_set(t, old_mds): + return old_mds + new_mds = self.solver.set_params(radius=t).fit(self.X_).centers_ + if len(new_mds) > len(old_mds): + self._update_results(t, new_mds) + + return new_mds + + def _update_results(self, mds: np.ndarray, t: float) -> None: + """ + Update the results dictionary with the new MDS and radius. + """ + card = len(mds) + if self.results_[card]: + if t < self.results_[card]["radius"]: + self.results_[card] = {"radius": t, "centers": mds} + else: + self.results_[card] = {"radius": t, "centers": mds} + + def _is_dominating_set(self, t: float, mds: np.ndarray) -> bool: + """ + Check if the current MDS is a dominating set for the given threshold t. + """ + adj_mat = self.dist_mat_ <= t + return np.all(np.any(adj_mat[:, mds], axis=1)) + + def predict(self, X: np.ndarray, k: int) -> Tuple[np.ndarray, float]: + """ + Predict the clusters and maximum radius for the given data, for a specific number of clusters. + """ + # Implementation of the predict method + return np.array([]), 0.0 + + def get_results(self) -> dict: + """ + Return the results of the CURGRAPH algorithm. + """ + return self.results_ From f7c3e672637da35553bec84ce14153578c4c328a Mon Sep 17 00:00:00 2001 From: Quentin Date: Tue, 8 Jul 2025 19:33:53 +0200 Subject: [PATCH 4/7] iterative base finalized. Unit tests still to be made --- src/radius_clustering/iterative.py | 123 +++++++++++++++++++++-------- 1 file changed, 88 insertions(+), 35 deletions(-) diff --git a/src/radius_clustering/iterative.py b/src/radius_clustering/iterative.py index 056f394..90517b3 100644 --- a/src/radius_clustering/iterative.py +++ b/src/radius_clustering/iterative.py @@ -77,7 +77,7 @@ def gen_even_slices(n, n_packs, n_samples=None): start = end -class Curgraph(object): +class Curgraph(MetaEstimatorMixin, BaseEstimator): """ CURGRAPH algorithm for clustering based on Minimum Dominating Set (MDS). @@ -116,13 +116,16 @@ def __init__( self, manner: str = "approx", max_clusters: int = None, - radius: Union[int, float] = None, + radius: Union[int, float, None] = None, + random_state: Union[int, np.random.RandomState, None] = None, + n_jobs: int = -1, ): self.manner = manner self.max_clusters = max_clusters + self.solver = RadiusClustering(manner=self.manner, random_state=random_state) self.radius = radius - self.results_ = {} - self.solver = RadiusClustering(manner=self.manner) + self.random_state = random_state + self.n_jobs = n_jobs def _init_dist_list(self, X: np.ndarray) -> None: """ @@ -139,18 +142,12 @@ def _init_dist_list(self, X: np.ndarray) -> None: arg_radius = np.where(self._list_t <= radius)[0] self._list_t = self._list_t[arg_radius:] - def _init_results(self): - if self.max_clusters is not None: - for i in range(2, self.max_clusters + 1): - self.results_[i] = {"radius": None, "centers": None} - - def fit(self, X: ArrayLike, y: None, n_jobs: int = -1) -> self: + def fit(self, X: np.ndarray, y=None) -> "Curgraph": """ Run the CURGRAPH algorithm. """ - self._init_results() self._init_dist_list(X) - self.n_jobs_ = effective_n_jobs(n_jobs) + self.n_jobs = effective_n_jobs(self.n_jobs) dissimilarity_index = 0 first_t = self._list_t[0] old_mds = self.solver.set_params(radius=first_t).fit(X).centers_ @@ -158,66 +155,122 @@ def fit(self, X: ArrayLike, y: None, n_jobs: int = -1) -> self: self.max_clusters + 1 if self.max_clusters else len(old_mds) + 1 ) tasks = [ - delayed(self._curgraph)( - dissimilarity_index, self._list_t[s], old_mds, cardinality_limit + delayed(Curgraph._curgraph)( + dissimilarity_index, + self._list_t[s], + old_mds, + cardinality_limit, + self.dist_mat_, + self.solver, ) - for s in gen_even_slices(len(self._list_t), self.n_jobs_) + for s in gen_even_slices(len(self._list_t), self.n_jobs) ] - with parallel_backend("threading", n_jobs=self.n_jobs_): - Parallel()(tasks) + with parallel_backend("threading", n_jobs=self.n_jobs): + result_list = Parallel()(tasks) + self.results_ = {} + for local_results in result_list: + for card, result in local_results.items(): + if card not in self.results_: + self.results_[card] = result + else: + if result["radius"] < self.results_[card]["radius"]: + self.results_[card] = result + self.labels_ = None return self + def predict(self, n_clusters: int) -> np.ndarray: + check_is_fitted(self) + solution = self.results_.get(n_clusters) + if solution is None: + available = sorted(self.results_.keys()) + raise ValueError( + f"No solution found for {n_clusters} clusters. " + f"Available solutions are for k={available}." + ) + centers = solution["centers"] + distance_to_centers = self.dist_mat_[:, centers] + return np.argmin(distance_to_centers, axis=1) + + @property + def available_clusters(self) -> list[int]: + check_is_fitted(self) + return sorted(self.results_.keys()) + + @staticmethod def _curgraph( - self, index_d: int, list_t: List[float], old_mds: np.ndarray, cardinality_limit: int, - ) -> None: + dist_mat: np.ndarray, + solver: RadiusClustering, + ) -> Dict[int, Dict[str, Any]]: + local_results = {} while (len(old_mds) < cardinality_limit) and (index_d < len(list_t)): - old_mds = self._process_mds(index_d, list_t, old_mds) + old_mds = Curgraph._process_mds( + index_d, list_t, old_mds, local_results, solver, dist_mat + ) index_d += 1 + return local_results + @staticmethod def _process_mds( - self, index_d: int, list_t: List[float], old_mds: np.ndarray + index_d: int, + list_t: List[float], + old_mds: np.ndarray, + local_results: Dict[int, Dict[str, Any]], + solver: RadiusClustering, + dist_mat: np.ndarray, ) -> np.ndarray: """ Process the minimum dominating set (MDS) for a given dissimilarity index. """ t = list_t[index_d] - if self._is_dominating_set(t, old_mds): + if Curgraph._is_dominating_set(t, old_mds, dist_mat): return old_mds - new_mds = self.solver.set_params(radius=t).fit(self.X_).centers_ + new_mds = solver.set_params(radius=t).fit(dist_mat).centers_ if len(new_mds) > len(old_mds): - self._update_results(t, new_mds) + Curgraph._update_results(t, new_mds, local_results) return new_mds - def _update_results(self, mds: np.ndarray, t: float) -> None: + @staticmethod + def _update_results( + mds: np.ndarray, t: float, local_results: Dict[int, Dict[str, Any]] + ) -> None: """ Update the results dictionary with the new MDS and radius. """ card = len(mds) - if self.results_[card]: - if t < self.results_[card]["radius"]: - self.results_[card] = {"radius": t, "centers": mds} + if card not in local_results: + local_results[card] = {"radius": t, "centers": mds} else: - self.results_[card] = {"radius": t, "centers": mds} + if t < local_results[card]["radius"]: + local_results[card] = {"radius": t, "centers": mds} + else: + local_results[card] = {"radius": t, "centers": mds} - def _is_dominating_set(self, t: float, mds: np.ndarray) -> bool: + @staticmethod + def _is_dominating_set(t: float, mds: np.ndarray, dist_mat: np.ndarray) -> bool: """ Check if the current MDS is a dominating set for the given threshold t. """ - adj_mat = self.dist_mat_ <= t + adj_mat = dist_mat <= t return np.all(np.any(adj_mat[:, mds], axis=1)) - def predict(self, X: np.ndarray, k: int) -> Tuple[np.ndarray, float]: + def _compute_labels(self) -> None: """ - Predict the clusters and maximum radius for the given data, for a specific number of clusters. + Compute the cluster labels for each point in the dataset. """ - # Implementation of the predict method - return np.array([]), 0.0 + for _, result in self.results_.items(): + centers = result["centers"] + distances = self.dist_mat_[:, centers] + result["labels"] = np.argmin(distances, axis=1) + min_dist = np.min(distances, axis=1) + result["labels"][min_dist > result["radius"]] = -1 + + self.labels_ = self.results_[self.target_cardinality]["labels"] def get_results(self) -> dict: """ From fcb205e7f0e6ba48bb7109acee1a184d65b56dec Mon Sep 17 00:00:00 2001 From: Quentin Date: Thu, 10 Jul 2025 09:22:03 +0200 Subject: [PATCH 5/7] Structural tests and fix of SEGFAULT --- .coverage | Bin 53248 -> 69632 bytes src/radius_clustering/__init__.py | 3 +- src/radius_clustering/iterative.py | 56 ++++++------- src/radius_clustering/utils/mds_core.cpp | 18 +++-- tests/test_structural.py | 18 ++++- tests/test_unit.py | 95 +++++++++++++---------- 6 files changed, 113 insertions(+), 77 deletions(-) diff --git a/.coverage b/.coverage index 1f404fce329b9579ae56467f4a39332314e873d8..bee8848d14b6ff0eaf5a904e5c8e8624fe3c1b56 100644 GIT binary patch literal 69632 zcmeHQ37phKzOJh7bSl-Igy9+ifks6@fnkta1?9f)W8`pXW*QinnI7g41{N_<7gpG( zE}r|)U0n~n5S3jOuT@de^#Vaf6g;?8bio@%=KZTHRrDzM*yZhetFQXUG~bu1RFW^r zm;R@&F{6eSRVL%pQ{^Rz%J>e}vj z6pQ>8x#!y>+;*`Q_Nmbgu@dXD=#6H=Du4w2(F$kVx7}xSzjMd-i;q3G_t1WEu`J%2#o^PgUHsg5d1`LFELk3(UR0cn7nK$k z6(lN)N@v6?XC`Its)}SGf6&1)c+~NX?bGXXlwVW`YgHy^KoMo-MJ0*y>iFzrb=!1? zU~84hd6i-T6p}2OQOa|~TL-JRiI*p*C(Dzi1<4Ar7p;p5+psnlFJfkPy?S_YGJQ;m z@&fo5t#vf8;P9agnhpgOq{<8V-cKt}lorfPRU_a+oilZ%0mE=^W;>AU! z>BB6oD5@+sF?MgFWVRbA)qw_;_ob zPANG*qhhV9(xN$4Np|%H+I?ImPgMD?d?H znM(6~IQjXx@D;lxi<;TZo8!fE(iaMsCx2Q|WrYG(t&2u2KFX!T(~o}L@Vlh;MZ)i* z;^KH}dis{(YsbaOfvY+5p2?A7B+B`X7Ec8kVNd^0uQtB3un+NR#i?nvueJ0QQ;~!l z1iV0In_g8~z;BfbehdYv3zFrD8OioRs<4Is>kA1kom?xD-LLu4E**6V&gDu-JZJXkyd z;yhzv0x~D2rK&2^*Dt5(%acEQ4)3oh;4qikL#{UI>~MRJOjT4A!DE#_wkmqarSVoX z6BXQrDq6*-rBcPoM5%Z-z#4-pV4Fo1@$^3Y_^H(Ea2XZVyo?zav`X)je7JC%9}72~ zyuSdHlq5ePPW_RKyF)laL6{(l)I^KXn=Kvz4FA zzj>5Q`A<2oIX1ITKOK8j!ObXqCcgq6>Lm%e+W{rDE|MQ49=={ifZ@Wgoy_ZmYnLGh zr(z4-+{782CRC8Qg5NrmKr5gX&TZOj79D$cZ^fx+zwEGv;tZIt$HOwR6$h1vmL}5`?MSekXRYhg8 z9EQ2csVFZv!crI~r5pzFxgZJM`p1W%$lT*DvZ}JExFRP%zo@jRGCv>69UX#l+Z=bf ziQ*Zl@}kO_B^6NSc?8Prdfa9HNG2$HB!;3}9(U0h0`U5O!x@L%dtKy=bH_R>kLeWa zOj-f0fL1^&pcT*xXa%$aS^=$qRzNGD6+mco*lUH4J4(%r|BVS@FZcM1{89WL9uxN3 z9DljB;{VXZu-EnY%REy2Pez8lmez4An-Tx(@&EJ*dK3TOqi z0$KsBfL1^&pcT*xXa%$aS^=%VFH!*mjtaZq@?*<#zjfbsUvpn@pL8E|?{WX+UhiJ+ zUg|Dz=eRT6Def3|klV}c?6z^6x%FKJ;{|@@>~-FE-gI7e);VjP)y{3s3g>su66Zpv z(wXHNCqaWPNV!vEH#ZTF+UJTK8GES^sHW zVO?s?w^CN2b)Gfc>SJ}Y+F2)A*;dRV^c(sq-9g`=FVH9G8hQu4kuIl8=t5dfXVA%X z6zxyXq#bBW+L*dDY#uN_H9s)lG+#2GG9NVWHg7VoF_)SbnN{X2Gv6F*4l>U&bIg;> z1}2Ss7x^^uUSw0``N(6D`y;nURz$9h{5rBAQWi-@CPqd?`bN4&PKz{+)QyC zazm$tvO|n~Pxg@y$m`^J@+i5N+)S<}OUMFJTKlG9ep>@ugc#^ew2Aqr2jmS-P7BB; z^ak_W1mty|(psgw&iqyZ+00W;RVka9-%^ns%s)ku*O-5@B3qe%k|J+0zl9=Om>&@0eJ;M0nHS7nfXlv@-ljb`6mYCvzil`-y|TPbJ92O zCEf)Axe@(|c~t?q0o};FN|_Daz`P1UYG$Gp%qtI4u0u~UZ%#n2saeRpvVdHTu3=s( zAd7iQX+V~utC?34kjv0g<`oCzH|R3v%?`*C^c&{Q3dmx#gn31R)GS1cEpH|#*gzhv zKO@L>5%Mf=IuJ~ZMd%{xB{{~%g^Q?HDB!{inO6{GzZhM|ylDa1i9TXpLXi)tH&v{$ zpo)6=0?uDRz4HaEo=?3g0?wgdb zw()g_OM>2G-Y`YBGjC`>-a+p(Z-^q>s5cl0C^)Py^#%#JbRqQy3OKTcdIJO;RYAS} z!O~IG>nGsqk<{xeU=?WBCxD<`?*M{!=L$G<81>E(aPUy-^%8K7oh{&i zfz&%o!2ScMccy^-`ctomfPMQ>ue*SKVAs0|_;?@cbrrBT>{%B9d-bMXX90WmqFyHf z&+19NGX(5$7WFy`*u4k!as}+xoq8Pv?AncbIRbX+O1<_1cJ4yGb^>G4N{1({5b zKNqlN)R-QBC}0`}#GeQhhv5wI#{tHnviqsW9|izZWz^%30)WLS>hT8wz>;F>@y7sP zT0%Yk5CC{TEai^?fWx6M{r~_tbT9R|{R12VCySduz`;YP$E}}3IB(qe0S*AjZ6Dx3 zI9c5E`BFHA-0}hLg{9o^0oET*J#P0LHl!Xmdw>n=Q;%Cczy<)h(en%qsK;#{U|q<| zO&(xYUFva*2k2%|j~hHd+ciCI?-oGQTy%&a0>H; zr4xMKp)Jf4mQL^_PK2csJP}V~p0IR+^YD4h6P8Z!1f0i&qjNlw2}37%Jf6UWpA$R| zk7vTp2_B2bG2!L}kHKS^Fmr-O!sdjR6FdTsWWvfh9>au_6U+`BWWvbFp)GiXmG*H= zaI0NRxH!Qr4}HLdi4%MZZpnm)6MParg$WBMxCK6m2?r;KKEN%QFmQsK;#N%fH^EH~ z?PbEg$)UZt2@~#3aDCjA3G*g6h8-rno8Y>zYr?t-&cbzB+PMkAQJlqua}(^sa&FuN zmOI#G+_*Wmp&ahq9LKQDgmv?EV@xm@8*oiyj2U-tz-Si^Tim`OH`=kx;`YtqyUQ52Z=U0Qw1aW~ z=J-9ZP;THHzY9jo4V>fc=v^!A;IQgj%b4(Qg5E}3nXqqywxYM0aBq&cGhyBYZ9#7_ z;oSs&Vs2r=x(Rv%EI~LoL9fFPOkvyvZAPy%?%RB`o1uJe+#GL0n;Ca*j^XE^#jTqk z$*ap4w{E_81N<=L-p%nV=vBteTi^|hyEn)DCnC3Rj$eeIaooQ-egVD6xPf#09D0Fq z2j_S_Eaw(3man(Chx5&@TgJGD^Tq4Xb5xi(NiC;RDr}n+CaG|3k|?CYw8_*26`oDRp#!PI7 za$kmnkJ>ct%b3iVX>wnN-#}3_?aCIZF%1J2;7E-bP6H!ihSDHpm?RBI$b_GRQ19=z zCGa11_wg{zb@#bj+*g>#US?lL?}~1>uduf}&$?aRXWU2JTixHo+ki{lYPZxaaP!<@ zZVR`!+t&Hs`OMkvyzM;VthUx#tDFr^wR5Xe>ipih+_}UlaE3W~PH*dGr>oP}{?c9T zv~U_aQ3u%{LDav=o^L;IKLBq6R@y5d{x7k|+CA(6b~`(6*NYvnBe8eg`mwz(jlBl% z0Pc=G9Q#A;cd<)j^C0#=KQf?_pcnlWYyUo&BEu2b;)duwm>Ri1<%sjjTf~X6>_{wcfWj z(Un$x>so7>b&*wR&9o+4BdxYp-0BO_KBfEUL5T9-q%YFHLX^LPUP7;+^Jp<0PZN=5 zv?D1BC&HgZmPH;l^XL$I4n2X!C^Ek=_n2>+ubS)3halGfqj{CN#Qc?6W)_-5&2!97 zW^40AvyK^#d=uFlxhQg7WOd{g<16DG^$bBv^s zW1MU>G-&uc>=*6|QM?_h3GEMk5ZWB72t5~i zICNL&hS2394`TQUp+TW`p)*6xLs{e|vLzHEUy{F(-Q+Qd;g^%S zVWdZ7d}LsxZKO_w7!!;^aLw3}^i5$Ph;ngb#URYZjSe$NJ3%rY=;Cb25JbATp<)o~ z;s%OAtc&X_2Ei__rx--LxUOOl?&2)PAl}7w6oY^lyNW@?iyg%tMr*OH*hA=H98(N} zUK~{nqF&4tgRmD{ib33qsbUYHHP}=PB3~TIU@Nhq7{tCftQZ8p@QX|83DGYmib42` zv0@Pa;={^A0E}u>8brY8kYW%5qwf`i7#JN?41!?vonjCLqi+?1Fc`t@7o1y&gV8sN zK_HC2RtzFx^p#=|3ZpL-TY+vsUr5G-Vf49V2%=&1nPLzQqy36OJdE}!1_3epyJ8R# zqfZrskQjZU7{tVAuVN4sqmLDXs2F`D84rumhcSM95IjO$jP}SR1c5Qy&66+$kPsQ8 zT~SUT--V0NPKg&@h<2z|AUH-JCa%D%RwvE8H2zY-7XoAtl@1L+#LK3Ogng`#_to07&Y+P zj6+JoHSnGcAjKTs9Fz&+HM&XSR2llC#KkG}2Z>9H(Tx%pm!KOYzP}i)ka&0v`cH|6 z?nS?sc*sz6y~Klupz9PKjINbLd#GMS`y@dz#A== zIPZM)YlWwvOM`Qgj=j+(ib3#=E>;YpZ{#Zm;WzRWgZLX=lr}_gL>_@JQUpTq3l)JF ze4!!`g#Strh{6{r0%7=kMIa8Z&LCsvDFTuBTuFE+egQuwh%R|7UL_NFFkYz$MB^2T zKsa8m2*l%aB;f&hnIt?SPbmT+d8s5kCNGhM2j#_zKvX_k5eUm?DFShMks=V7&r}2= z^BIajXg*yLh|QCVKyY5D2t?-via>ZiErUc8ia>xqRT3Vd=S#vv^z#*gXncwyJW7&; zN9mI!;bHnjDK*6D=PCT|vOJXuBJ~N1K&U=m5s1~tDFVUzShfCJ%f_e#2-im|0`dB& z46=QsA`r2UPy|Bu;fg@aK1>k^+J`CvQTvb#vURW`5VsFf1OoSgYS&&}Hb5mn=-yuu zh~4{TkPUqmf#|)DA`rg!Rs`bra}|LA{v1Uhg72lsb7*}}wFB#xovjieh(9ZRi^Dle zEk9Gj(o_$g9j<>K$9I<)0{Lz-OZoDy5|)*Bk#NrP&Jvc*=_Fxk*%=ZpFYPE{NoB5t z$*CPAoV_wf!dbK1OIS3korE)s+DcfMJYB*>;b{_1O|+3Pe`;$5XSR~?{Jc{ooRrs6 z!ijmONH`_$WC{nL?rL05!p2SNN|@a^OTq@(btJ6Uz?HCW zJx9W0+!yamDW*8fMEp>e#Nf2RH))@{U_H2g#L|8xzxEG^i7TmKI^r=~CCZU0#P zKdjf9ACj*BhZp#u!@qz1Km1_0Abm~zMEyVH86DLB103)V)c-^F6F>)OVSn`cf2$js z%JUp={XdkG-qt@;{|`R{)9dQ`f2gtNZ3cAxKW~|@>;F~dsjmNrrv1A9U)TT3D*Jzn z`hVU!fP@H5w-c~x2p|9N1!DXVtK22-uiP@XFx}R_w^0*G zxSiZn-4k5q9CY?MJD}hHOU_fyUz|Ig8==?#ubuhO>VF#a`5)}`bUHexLXUsPA@*0s zqtMoWo4wIkZ9i+=Vz0GV**DwQ*q7NC+U53id!jwu?rnFmPqUlZS+)^75c?$dUTkyh z`PgHz`(n4ou8;jTc2TT6mW)k^4T_x!J^q`=vSOj=m(ktO-~VsX$D;Q}Z;CF5-u`o= zMbSyd`slD|FX-)ma>c)saSgQke}LV}u4TVr3s?!uXQNmj)|s_pjabzB z-uk=sp7omb3^ZE&obH4k{C}VqL!bR2)`RqU`Ut&;UP-IydGu_WLtD^#)G)uc?trlX zerqkV=2!*RIBS6Uq4|XQCv&km+v;weZf=EM{r8zyo7LtNbA)-Wd4}22%r=&omgQPa zp*R00k?oNUk*6aIBY%!uA4x%v{-u#Ik$%vNzfGh`#5VRD9~iIyJYO~Gz19k71^!V5 zf|j@j+K)cNQxtmuy2nfo+G`mD`{GFw_aBJUO*nbX*!OXKo?2ribVd=4I3WdRt;Q2% z5_S)LjmIkn4OH}DYgsJx+=C48t!(}A|cF=*|J zn<@t5x!@BOg8^M|6UC0u!$+n4L?@abcANQ$7iG1()EuGPjM`qUE*Zz(3-1SS+H3d1 ztKIMJr+V7`@NIf#zq_C6TUWbJUhSTJKh^Kzq^0UoA%7UeE*7mcgNh* z4%P0VckDgw{zv-Oq1yfR(tYF7?*!m+^w0LRlMlV4y+5z@K0Ko@U9G-x8U5;N^^?o! zW5@f<@qPW#d-7xHRhHR7k9W9}JI(vpNqn>(Y?-~>GT(&%Z}wFCx%N)u;q{Muph?9u z``=~u6XbmirHP8}po-P%s9L+aQuKLKIn%WweEY!*?sTLty1fP4KWbmMTJN_r`n_fL zbvvH-+qK?lXY_T;?BB-UT|@hS+!lsugZlqZ-JN>;KRy1R9{*2||7Rc=3qgn)g`!eKy?(Txt|1Y{vxc9?Yf7iK7-3wvd|3Y_yJJ>zj?ckp5Hgqkh`TskN z_`lg%599rWVSqh=CVtEp6EPSFTKF-a&&0s@4I20{ zpHIfX3AFEHKA(?)6KLMYd_E(CNZHMNJ|_bw(72EJd|C#PvWxkAW(Gk(9X>gOATXU7 zpPxaHz0BuRGzjuB^Z6_doItZa=JSafIDuAu%;$48Z~~3`n9rwc-~`(AF`v)azzHi=SsaZLLJat^ z;YU3p0AUrt-`WJe7#gJUll8W{kys_4Chj(St-HW_%)Z(@Zl#S+nO}?~j2_s44>W)V z&;T0v|82luZ$+x7P1E+?TS#=Lyh1L~}@-MK`jXTZxhL=(Xbd$KUj-32k* zTsGZQ$VOAiOmUi-e6o} zN6Ovh!F@2eOjcf?Bg6|AvKKq@V#gM;a$Vt`S(3-tPQR3Vz>; zCx#2zVjPE)A8&>yYkw(?RL`8L9T+Mu6s{-Uoh;-7FeWXUp!p(8hi||5y5U#JAv}jiD2qV3$tBns9#t`jJWxFS@wc?7&dvJk3 zE>OGm4QG1zrIP1!=*jN%a*2MgQEnA!;c_n|SUSx{N%gXs7h5{P5{Nek$oy#~k?Pu7 zZLd*WM)GXQzao>CkUEnWNTX^~+Dqhzb1ISixg@EM7|)XG8j)LIItwBt)k|tGHXXTN zyjcSEBjHGOWu@keD}v`G{xfCLBzVjCCGfr*!CyZ7!9-yodj8mXGYq_LN=fxCl@}X0 zA40sDKbP$5)QdwtFs*1m4N0H>geDD{w~P;YrMG9n&O{2b$zURxgG&~2u($(6`HsB_ z*g4Ui9WE5tFAs~$liz#B$16G*kCl7atG>89T;A)l`Fs*?tNgZ=zge|L>jn~eo`v#t z(e7+E(`-prqQl}!1sa{f#$ zQayjZ<`2t@QRJC;9&YOCgc`OC^+}85eH+Ku%UVdd@Y~MBb;8|E(7{9T3Ou=qGuSOG zsJ-OxI;@@-_?r7Jfj{^_184vZpaC?12G9T+Km%w14WI!ufCfHz211&yMR@(MyRQ-V zckqM{G=K)s02)98XaEhM0W^RH&;S}h18CsNG+^nW>8|=G9?h;6s;!p)2w+j;O^wYA zW~IAt5%(?k4_D>{anxu44WI!ufCkV28bAYR01co4G=K)sfHGi(rfce50G)+u4f!Sj z|NKAfJWkwix^2!F06x$F8bAYR01co4G=K)s02)98XaEgdJ_8%WVd6J8ZOFq{0yM!l zuVmoMeVXpb_6(;1^G&%#Z*n*v?@0~k3tkSsIj1R~>$yPda3Pt>H^t-0OtKJd)qqg@JS)`rNETpNlTJ&-3O4U9Z!i>$#WQ^#lR@^Z&4On7IGy61N=y zA7}s#paC?12G9T+Km%w14WI!ufCet3flXn*S-a%I_q_T)+!prhFTLMM^?#@%>@T|X zJ}*@N>+8b)T;q~8OsN0y`~NQEqQ;j-184vZpaC?12G9T+Km%w14WI!u@YgpG(n7@M zpZ{y_dj$UA0}Y@7G=K)s02)98XaEhM0W^RH&;T0v02&B|Y=eLPPo{nVCx=r)184vZ zpaC?12G9T+Km%w14WI!ufCd5s{PTaT{{tNeXaEhM0W^RH&;S}h184vZpaC?120oAm z_~-wJ44t@tbbssq%01;i?>_DRhx=XkU))3P!|p@we)n#7(Cu?O-SuvZd#k(9o#)PQ zD_rIh=N;#)^Sbl0^Ne%U`GNB-=WEVm&S#wmoco-Dv)$=(Hae@ErOu5`qcg{;ai%&Y z_Ivi*_FMKZ?C0#K>?iCa_T%=W_Cb4}z02NV$L%(|#lFREvahnMZPOaFes8^LowSZw zKeE1Oeck%9^*QSS>mF;+@~jSPjkVOe!MetpVU<}W=Kq?%HD5PhG@mw~FprpDGaoS@ zGWVK0&F$tkv(>!QyxD9tXPZ@~$;R04*_-SnJH~#*zQ?}KzQjh^0d_Y_vo6-gma|*g zb!;A+!O9sm{$#vuylI>=jvGHQerSBl_^R=UanRUr>@hM%j}bFg8%vEvMuRcSs4^^& zhz~S?2G9T+Km%yt3NaAzXHU@#jfTe}{wyG>g~zBL1auxw3JzgFN8Xsr>r20EK2aMfAre@NiURn)H*xMC&st0b_B92TmQkMxbW5pk2(;aZ zPXP@rWJml6Z!|(Mos!~ay{n)4Qv}AgQ9mqkEAvBux~^-R&z)etE`v@UXfo)~I+;&o z5YslZ`z3GEV(h+vJJ`LFH)@;MJ(4$Q8`;Mtw`&{N-ICk1cD6_II(W6+lGkeM*sg%v z*iOmp+O)3v%YEH8JNd+rUElRQV8%Z4P+*55rCJ&5muzaKtWUBF?Ve;u zb6Kxs8~W*yY-u*@mTYPkOGq{}#<~JFSzI#J47N@1x%V{6w#xRfR>8JN*2ijCr({jj zS%=`U8ck!H1(P#c*obk~bn?p)W0S;JN7zQW$Lr*mY=h+2$m^_K@~h-E)+YHhdDUp; z486QE!q&^?Q{*IDC;22f#nuLVnyrz1g1o|3OMaQ0V5=m*L|$eqCBH~sVk;#7oV>{H z3iu>jF8KxWbJimHdGZ3gQ}S{099tIf^Xv}E$H=qn_JEJGrIMc^$Ji3dKO@hu#gd;U zKV!E^ev&-RJ|_7oXumaRf6DkMXV7qTgxw;WkCG?p&A@ODvZM4SL7A*^BPU%SAxDix z0t@5}y+LfxjneA{4d&=a1PzVSg|cmkUMDCsNSg(XX6Uto(goTi$m^nwf_5CH3j}T7 zK^p`mx6}E829op|LA@TW7nJCw^8|GzXq}*V7rk220G%snTa3;T)ET2!3F?T^VzSos ztue}zHOzg>Rw|OU-noT}T&;I>Qjx0lSSJ;kT90*5k*M{nxc=ta=O5v>SgI{*Ct5KZv<|2OU%?g{oF zdzp=y-!cCiD)V1E$KC7QWA2aPe+T%A`vrJ6V4u6&O}pLhCU>U0*uBQ3&YzsOoj0AI zI6pLgY<%B2<8Ftt%byr%yw*KgrTfeZLwZ04Q0{pY}n03(F zYpt~gtYNDY-T`o}>#ccKtrargF;AJt%>OhGL*@S{yU+B^#b&{rX7-yM=4x}Hd9`V= zx7n}QPuUaf-`H2!!>pGLu}y3RRQ%VntC+#6jCY~Z|BZ2)e%qL795Oy{JZRi+|EM^XCXXpTk|K{PUNc4> zQM~Fb`GVq=tH{HOSF9wT4|oOn2gNNb$loho-a@1F48J=Fpbp;?1W<_Yo)EO{3IZs_cdCHb;=_DOI;;>V#tUkL z2Gw{z2%sF#1p(CKLqPxq`Jf7TMV?iG1|@kW2%siUtKb|I{y*GVc!pD9!tV0BUnD2%tFc4Fagndx8MU^X?#k`aCfqFuQ^PD)hJtc!|DE z1sc@oTY~_K^esUE9y?XQ%k&Nv@H&07QVoTAEZ|p1HU&FDslG7?pjO`y1W>HE2LV*; zZNc%cjI;(Dpk7}e1W>TAn-H8{8w60YuL%OE*;fYv6z!{m0IK$tK>%g@iV4BVyMh1; z_vJwVm3vDNK)@39TS4%w+8`~?@NOK>h~o<00sQwAb<+~wjh8K{$oM# zBzfxAK+e&Tj|Ll{iod0JiNi_Cj^3;=lf8-W4i^9~lBV=niZxC*DCb7niTFh+^BH-js*&n+Zz-PBfkgzH&UgVYsM-S*0xqCJX2Gyu%@;wIPTezQq^2tGgV<_ zwX3ka(otAeZYwM;vlLD(H5EEjnL^t!6hh&z5GsEIR!Z^hfD%r8)r?IsN_Z0#|6xht zDIrOESf|CU#+c4O{|~!oiTkemj{83V_&@__01co4G=K)s02)98XaEhM0W^RHu4n_H SFnr;ED8!$1{=~ok|Nj85o10_+ diff --git a/src/radius_clustering/__init__.py b/src/radius_clustering/__init__.py index 50d76bb..093043b 100644 --- a/src/radius_clustering/__init__.py +++ b/src/radius_clustering/__init__.py @@ -1,5 +1,6 @@ # Import the main clustering class from .radius_clustering import RadiusClustering +from .iterative import Curgraph -__all__ = ["RadiusClustering"] +__all__ = ["RadiusClustering", "Curgraph"] __version__ = "1.4.2" diff --git a/src/radius_clustering/iterative.py b/src/radius_clustering/iterative.py index 90517b3..eaa0178 100644 --- a/src/radius_clustering/iterative.py +++ b/src/radius_clustering/iterative.py @@ -51,8 +51,8 @@ from copy import deepcopy from joblib import Parallel, delayed, effective_n_jobs, parallel_backend -from sklearn.base import BaseEstimator, MetaEstimatorMixin -from sklearn.utils.validation import check_is_fitted +from sklearn.base import BaseEstimator, ClusterMixin +from sklearn.utils.validation import check_is_fitted, validate_data from sklearn.utils import check_random_state from sklearn.metrics import pairwise_distances from typing import Union, List, Dict, Any, Tuple @@ -77,7 +77,7 @@ def gen_even_slices(n, n_packs, n_samples=None): start = end -class Curgraph(MetaEstimatorMixin, BaseEstimator): +class Curgraph(ClusterMixin, BaseEstimator): """ CURGRAPH algorithm for clustering based on Minimum Dominating Set (MDS). @@ -112,6 +112,8 @@ class Curgraph(MetaEstimatorMixin, BaseEstimator): Return the results of the CURGRAPH algorithm. """ + _estimator_type = "clusterer" + def __init__( self, manner: str = "approx", @@ -122,7 +124,6 @@ def __init__( ): self.manner = manner self.max_clusters = max_clusters - self.solver = RadiusClustering(manner=self.manner, random_state=random_state) self.radius = radius self.random_state = random_state self.n_jobs = n_jobs @@ -131,15 +132,21 @@ def _init_dist_list(self, X: np.ndarray) -> None: """ Initialize the list of dissimilarities based on the radius parameter. """ - self.X_ = X + self.X_ = validate_data(self, X) self.dist_mat_ = pairwise_distances(self.X_) self._list_t = np.unique(self.dist_mat_)[::-1] if self.radius is None: t = self.dist_mat_.max(axis=1).min() else: + if not isinstance(self.radius, (int, float)): + raise ValueError( + f"Radius must be an int or float, got {type(self.radius)} instead." + ) + if self.radius < 0: + raise ValueError("Radius must be non-negative.") + t = self.radius - radius = t - arg_radius = np.where(self._list_t <= radius)[0] + arg_radius = np.where(self._list_t <= t)[0][0] self._list_t = self._list_t[arg_radius:] def fit(self, X: np.ndarray, y=None) -> "Curgraph": @@ -147,10 +154,16 @@ def fit(self, X: np.ndarray, y=None) -> "Curgraph": Run the CURGRAPH algorithm. """ self._init_dist_list(X) - self.n_jobs = effective_n_jobs(self.n_jobs) - dissimilarity_index = 0 - first_t = self._list_t[0] - old_mds = self.solver.set_params(radius=first_t).fit(X).centers_ + self.solver_ = RadiusClustering( + manner=self.manner, random_state=self.random_state + ) + if self.radius is not None: + dissimilarity_index = 0 + first_t = self._list_t[dissimilarity_index] + old_mds = self.solver_.set_params(radius=first_t).fit(X).centers_ + else: + dissimilarity_index = 1 + old_mds = [np.argmin(self.dist_mat_.max(axis=1))] cardinality_limit = ( self.max_clusters + 1 if self.max_clusters else len(old_mds) + 1 ) @@ -161,11 +174,11 @@ def fit(self, X: np.ndarray, y=None) -> "Curgraph": old_mds, cardinality_limit, self.dist_mat_, - self.solver, + self.solver_, ) - for s in gen_even_slices(len(self._list_t), self.n_jobs) + for s in gen_even_slices(len(self._list_t), effective_n_jobs(self.n_jobs)) ] - with parallel_backend("threading", n_jobs=self.n_jobs): + with parallel_backend("threading", n_jobs=effective_n_jobs(self.n_jobs)): result_list = Parallel()(tasks) self.results_ = {} @@ -231,7 +244,7 @@ def _process_mds( return old_mds new_mds = solver.set_params(radius=t).fit(dist_mat).centers_ if len(new_mds) > len(old_mds): - Curgraph._update_results(t, new_mds, local_results) + Curgraph._update_results(mds=new_mds, t=t, local_results=local_results) return new_mds @@ -259,19 +272,6 @@ def _is_dominating_set(t: float, mds: np.ndarray, dist_mat: np.ndarray) -> bool: adj_mat = dist_mat <= t return np.all(np.any(adj_mat[:, mds], axis=1)) - def _compute_labels(self) -> None: - """ - Compute the cluster labels for each point in the dataset. - """ - for _, result in self.results_.items(): - centers = result["centers"] - distances = self.dist_mat_[:, centers] - result["labels"] = np.argmin(distances, axis=1) - min_dist = np.min(distances, axis=1) - result["labels"][min_dist > result["radius"]] = -1 - - self.labels_ = self.results_[self.target_cardinality]["labels"] - def get_results(self) -> dict: """ Return the results of the CURGRAPH algorithm. diff --git a/src/radius_clustering/utils/mds_core.cpp b/src/radius_clustering/utils/mds_core.cpp index 9e44945..0ed3827 100644 --- a/src/radius_clustering/utils/mds_core.cpp +++ b/src/radius_clustering/utils/mds_core.cpp @@ -57,7 +57,7 @@ class Result { std::string name; float value; - + Tuple(std::string name, float value) : name(name), value(value) {} }; std::string instanceName; @@ -67,7 +67,7 @@ class Result { class Instance { public: - Instance(int n, const std::vector& edges_list, int nb_edges, std::string name) + Instance(int n, const std::vector& edges_list, int nb_edges, std::string name) : name(name), numNodes(n), adjacencyList(n) { for (int i = 0; i < numNodes; ++i) { unSelectedNodes.insert(i); @@ -93,9 +93,13 @@ class Instance { const bool supportAndLeafNodes = true; void constructAdjacencyList(const std::vector& edge_list, int nb_edges) { - for (int i = 0; i < 2 * nb_edges; i+=2) { + size_t edge_list_size = edge_list.size(); + for (int i = 0; i + 1 < edge_list_size; i += 2) { int u = edge_list[i]; - int v = edge_list[i+1]; + int v = edge_list[i + 1]; + if (u < 0 || v < 0 || u >= numNodes || v >= numNodes) { + throw std::out_of_range("Edge indices out of range"); + } adjacencyList[u].push_back(v); adjacencyList[v].push_back(u); } @@ -120,7 +124,7 @@ class Instance { class Solution { public: - Solution(const Instance& inst) + Solution(const Instance& inst) : instance(&inst), numCovered(0), watchers(inst.getNumNodes()) { unSelectedNodes = inst.getUnSelectedNodes(); } @@ -338,7 +342,7 @@ class IG { bool randomConstruct = false; public: - IG(GIP& constructive, LocalSearch& localSearch) + IG(GIP& constructive, LocalSearch& localSearch) : constructive(constructive), localSearch(localSearch) {} Result execute(const Instance& instance) { @@ -463,4 +467,4 @@ extern "C" { return main.execute(numNodes, edges_list, nb_edges, seed); } -} \ No newline at end of file +} diff --git a/tests/test_structural.py b/tests/test_structural.py index f081803..927906d 100644 --- a/tests/test_structural.py +++ b/tests/test_structural.py @@ -1,4 +1,6 @@ from sklearn.utils.estimator_checks import parametrize_with_checks + + def test_import(): import radius_clustering as rad @@ -9,9 +11,21 @@ def test_from_import(): from radius_clustering import RadiusClustering + @parametrize_with_checks([RadiusClustering()]) def test_check_estimator_api_consistency(estimator, check, request): + """Check the API consistency of the RadiusClustering estimator""" + check(estimator) + + +def test_curgraph_import(): + from radius_clustering import Curgraph + + +from radius_clustering import Curgraph + - """Check the API consistency of the RadiusClustering estimator - """ +@parametrize_with_checks([Curgraph()]) +def test_check_curgraph_api_consistency(estimator, check, request): + """Check the API consistency of the Curgraph estimator""" check(estimator) diff --git a/tests/test_unit.py b/tests/test_unit.py index bf846be..e4feaaa 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -1,7 +1,8 @@ -from radius_clustering import RadiusClustering +from radius_clustering import RadiusClustering, Curgraph import pytest import numpy as np + def test_symmetric(): """ Test that the RadiusClustering class can handle symmetric distance matrices. @@ -9,30 +10,28 @@ def test_symmetric(): # Check 1D array input - X = np.array([0,1]) + X = np.array([0, 1]) with pytest.raises(ValueError): RadiusClustering(manner="exact", radius=1.5)._check_symmetric(X) # Check a symmetric distance matrix - X = np.array([[0, 1, 2], - [1, 0, 1], - [2, 1, 0]]) + X = np.array([[0, 1, 2], [1, 0, 1], [2, 1, 0]]) clustering = RadiusClustering(manner="exact", radius=1.5) assert clustering._check_symmetric(X), "The matrix should be symmetric." # Check a non-symmetric distance matrix - X_assym = np.array([[0, 1, 2], - [1, 0, 1], - [2, 2, 3]]) # This is not symmetric - assert not clustering._check_symmetric(X_assym), "The matrix should not be symmetric." + X_assym = np.array([[0, 1, 2], [1, 0, 1], [2, 2, 3]]) # This is not symmetric + assert not clustering._check_symmetric( + X_assym + ), "The matrix should not be symmetric." # check a non-square matrix - X_non_square = np.array([[0, 1], - [1, 0], - [2, 1]]) # This is not square - - assert not clustering._check_symmetric(X_non_square), "The matrix should not be symmetric." + X_non_square = np.array([[0, 1], [1, 0], [2, 1]]) # This is not square + + assert not clustering._check_symmetric( + X_non_square + ), "The matrix should not be symmetric." def test_fit_distance_matrix(): @@ -42,24 +41,23 @@ def test_fit_distance_matrix(): """ # Create a symmetric distance matrix - X = np.array([[0, 1, 2], - [1, 0, 1], - [2, 1, 0]]) + X = np.array([[0, 1, 2], [1, 0, 1], [2, 1, 0]]) clustering = RadiusClustering(manner="exact", radius=1.5) clustering.fit(X) # Check that the labels are assigned correctly - assert len(clustering.labels_) == X.shape[0], "Labels length should match number of samples." + assert ( + len(clustering.labels_) == X.shape[0] + ), "Labels length should match number of samples." assert clustering.nb_edges_ > 0, "There should be edges in the graph." - assert np.array_equal(clustering.X_checked_, clustering.dist_mat_), "X_checked_ should be equal to dist_mat_ because X is a distance matrix." + assert np.array_equal( + clustering.X_checked_, clustering.dist_mat_ + ), "X_checked_ should be equal to dist_mat_ because X is a distance matrix." + @pytest.mark.parametrize( - "test_data", [ - ("euclidean",1.5), - ("manhattan", 2.1), - ("cosine", 1.0) - ] + "test_data", [("euclidean", 1.5), ("manhattan", 2.1), ("cosine", 1.0)] ) def test_fit_features(test_data): """ @@ -68,17 +66,20 @@ def test_fit_features(test_data): and multiple metrics methods. """ # Create a feature matrix - X_features = np.array([[0, 1], - [1, 0], - [2, 1]]) + X_features = np.array([[0, 1], [1, 0], [2, 1]]) metric, radius = test_data clustering = RadiusClustering(manner="approx", radius=radius) clustering.fit(X_features, metric=metric) # Check that the labels are assigned correctly - assert len(clustering.labels_) == X_features.shape[0], "Labels length should match number of samples." + assert ( + len(clustering.labels_) == X_features.shape[0] + ), "Labels length should match number of samples." assert clustering.nb_edges_ > 0, "There should be edges in the graph." - assert clustering._check_symmetric(clustering.dist_mat_), "Distance matrix should be symmetric after computed from features." + assert clustering._check_symmetric( + clustering.dist_mat_ + ), "Distance matrix should be symmetric after computed from features." + def test_radius_clustering_invalid_manner(): """ @@ -104,6 +105,7 @@ def test_radius_clustering_invalid_radius(): with pytest.raises(ValueError, match="Radius must be a positive float."): RadiusClustering(manner="exact", radius="invalid").fit([[0, 1], [1, 0], [2, 1]]) + def test_radius_clustering_fit_without_data(): """ Test that an error is raised when fitting without data. @@ -112,30 +114,43 @@ def test_radius_clustering_fit_without_data(): with pytest.raises(ValueError): clustering.fit(None) + def test_radius_clustering_new_clusterer(): """ Test that a custom clusterer can be set within the RadiusClustering class. """ + def custom_clusterer(n, edges, nb_edges, random_state=None): # A mock custom clusterer that returns a fixed set of centers # and a fixed execution time return [0, 1], 0.1 + clustering = RadiusClustering(manner="exact", radius=1.5) # Set the custom clusterer - assert hasattr(clustering, 'set_solver'), "RadiusClustering should have a set_solver method." + assert hasattr( + clustering, "set_solver" + ), "RadiusClustering should have a set_solver method." assert callable(clustering.set_solver), "set_solver should be callable." clustering.set_solver(custom_clusterer) # Fit the clustering with the custom clusterer - X = np.array([[0, 1], - [1, 0], - [2, 1]]) + X = np.array([[0, 1], [1, 0], [2, 1]]) clustering.fit(X) - assert clustering.clusterer_ == custom_clusterer, "The custom clusterer should be set correctly." + assert ( + clustering.clusterer_ == custom_clusterer + ), "The custom clusterer should be set correctly." # Check that the labels are assigned correctly - assert len(clustering.labels_) == X.shape[0], "Labels length should match number of samples." + assert ( + len(clustering.labels_) == X.shape[0] + ), "Labels length should match number of samples." assert clustering.nb_edges_ > 0, "There should be edges in the graph." - assert clustering.centers_ == [0, 1], "The centers should match the custom clusterer's output." - assert clustering.mds_exec_time_ == 0.1, "The MDS execution time should match the custom clusterer's output." + assert clustering.centers_ == [ + 0, + 1, + ], "The centers should match the custom clusterer's output." + assert ( + clustering.mds_exec_time_ == 0.1 + ), "The MDS execution time should match the custom clusterer's output." + def test_invalid_clusterer(): """ @@ -152,10 +167,12 @@ def test_invalid_clusterer(): def invalid_signature(): return [0, 1], 0.1 - + with pytest.raises(ValueError): clustering.set_solver(invalid_signature) + def invalid_clusterer(n, edges, nb_edges): return [0, 1], 0.1 + with pytest.raises(ValueError): - clustering.set_solver(invalid_clusterer) \ No newline at end of file + clustering.set_solver(invalid_clusterer) From 3888cb248174f87b509a57c387283ab66f71fdd3 Mon Sep 17 00:00:00 2001 From: Quentin Date: Thu, 10 Jul 2025 18:12:24 +0200 Subject: [PATCH 6/7] upgrading curgraph to match unit tests validation --- .coverage | Bin 69632 -> 69632 bytes src/radius_clustering/iterative.py | 120 ++++++++++++++++----- src/radius_clustering/radius_clustering.py | 56 ++++++---- tests/test_structural.py | 19 +--- tests/test_unit.py | 6 +- 5 files changed, 136 insertions(+), 65 deletions(-) diff --git a/.coverage b/.coverage index bee8848d14b6ff0eaf5a904e5c8e8624fe3c1b56..3d573d838879a3b8593c389fdc63cf4b7d597a50 100644 GIT binary patch literal 69632 zcmeHw2bdK_+I3a;?c3GWi8PD^We~|>NRmX7MHK2__#m2zrTJp4>O$8)m^vybl=nc zcD+@VH)UdRbuw#Vd1Yy$I;#(AgfK?Kva%3D7W_7bU-4rChzSD!<1~CYXrUhKo4EdT zuBl2Eu^SyMgIgV2u}5lqh6;*{V5{onBB-LGvbZ!+S(mjmS=TMK zLiB0X$y2KN0;nWeyr_(=k<}^MymMA%a$&MESyq^=;(O7lxTrJjymk#WGn+QWXC_m} zl&CC(e|CeTfdz*SW!yrjsIa`Ui0%D?%0yY=;$&5~tb{zm!gAQpQ>yvVR+g7$SC?m% z6qltAv#hGPy12Y7D|t$?u%7w$hcsXfRkU)3;sB#HzsCgn>P5&WvL5=$y2bPxVlOLtHDJh7a!)*VdaOvZrHb^ z;YGr}MI|L!622hd1!A>@HD!hDt5U^|p|E^;vNEwK*&{0Dw(xUbNO0-&vf`Pe+8plEk(WS0 zF9~VTB5r23YK229bs0rxEBeoDa0!Vs`2#j8ib|7-s+vk+BtK@7Vv9pADZV372^>y$ zGoCrT)!_~&TCbp21V~`U%*>1o9Pn=h^Ck9cG&l&hcv^#-*>hc zcD$nz&paXHa68U+p`aK08(%wHOYQQREmG>2g6Xqgy}>TCC5Lu-Lb9&=@zC2g<;m_nhxV6OaEJ@-VXe-o)#2-Xa(Pu%G2B+! zZL4aKSen&waiWU3P*umQ1?A->$wV2yH^3I-s^Bw=tFlu2@cp~e&_i`p)v-DjE$^7x zCvkIOHa``BN||Gc z`g>G>Cal)dL=peGvy+lmW9{PP_p0M}qmhZ~WOZ?AGK;N}+L_daRhy{FDy&RGA>2zl zchrCEGGJ$`pcnhgqGZAz<$^ZY%mL97)$>(I_F>?QWULGVxgCI+Q$WWWHd*|y`EMfs6ZoTklmbctrGQdEDWDWk3Md7X0!jg;fKosy zpcMFJD&T1P(Y{FW=;dR*Rc4d~(7(sP9=$0XmFB;X{P+C>zsv(vd#w~u3Md7X0!jg; zfKosypcGIFCEo`i1w(_>I&uOqIqrnGmR6y5Vi?#owM!&*?Q-`D!PzopolmbctrGQdEDWDWk3Md7X0!jg;fK-6B z|Eu_4%AtT#Kq;UUPzopolmbctrGQdEDWDWk3Md7Bl?oVeRM`KT9b04nYyVaMdH)Ij zL4T`%vww|$sehh-hQHEZ<|qBx{uF<-KiKc>ck++%oB0&R3;e>{@4e;i@}BXwdG~m? zde?bZco%qQd8c~S-V$%VH{F})jqnC|J-p+*mR=(-?*7aD+pt&3?%wa-;cjxT zbkBFsaO>PsH{nim$GIoEz1XV`UishzN=+2iaJ?Ot|!yM^u9 z8vTNPNcYm6^ig^b-Aw;TFQV(|YFa@H>2x}Q4x`7@uCy&}PHl>;1J;Mu>(&d_cI$rY zc59P$rFDUIrnSmiW-YX4TPItitU*?`b)410@-2gWLp~$#keA6*0A^*O-@>=a`{cYnGS^bE-Mc z9BlS7JD9D^My3(}D*kc&?f9U?c+zq zZR4QvnemSCvhjrRSL0UW8sj2kow3R&Hx?MvjIqWEMh~N{(abPoU&lU*y%Bpp_Gs+x z*iEskVi&~Dh@BE!8k-lJ5*r!oAL|l3I@Tzr>tE;}=&$NK^oR62^-cQa`nh_bSL=)P zS^6Y>nBGV4ptsN+tzP?F+o$c)p49HwZqxpxU97FwR%;bnA);Y-16!JAppVf9H0&0U zkI=_7>>82%=tCNIiO7fOBN}#=WIqi%MdSmvTE~dI&uZ%sk@wL1G;AM{chP$^JT4;p z(7QBj7m;_+J{leyk=M~XG;AA@H_+=e%#tN<(6CKJ_Og;=BJwiYOT(ih@*3Jj!`2ab z6}?8oRuS2aUZr8ni0ncy)38NE_ONY^lB?~ZVWwPdHw~Ld#5!2YQxW4c>+C2LlTkg>~qYBJcb^nVLT#_vJxXA+t5Qaj78*Ow2g*(L>@wq z(NK%XgXm!z;)whWJxD_ok*)Q4G*}and(q!$aB4*ELif^Obwut)chO*#EV-KoDY*`!E(9H?KG%~$Zc$!YEJ6&&}}rRib`%q zU(%p5BG=bf(O_9bZe*k)A~&ELX;3a!LpRW%EF$ls12ia&$ff9d8k9uj5_BmImPX`a zbO{ZXMC2lLF%60%av{2i28$zd0lJU|iz0G9x_|}?BeK4JB@L1hS%=ospoo+DDs(;# z3Zs&9(9JYh5S6S%>#QKb2sY5#bu^eCtreoRG$@EjfI=F~i^ypxu!6Zjux6}5r;%U| z!`N6|OM}_bI%lABNHB{pU$vS9GdWzjiUj!_u2@Ne864KFAi;DFYwJibjl)&7B$&$K zN`QGBt^kq#(%!@L?2jEun{!DplEbm%NpKQ}lSh+a z1cy13Nidwli8&+~#^HpCBpAx!_z5H!!r?fm;Y1F{jw8Wf4#$io!3i9W9z%jb9FBsI z7|7v$qew7-!;!E%{W%;yk_7!Y95$Q;$8$Jj7zz4vICuyN`fzx{U=sA^aL@@P=*8i{ zK_uwO;edf8$mX#B021`zuwQ=?u-gP&3w`^MfZZej_UTIkc8dVmyAKK24FX`V-XvhR z2Y@|$k$~MC0A}|j0lPH-?4C^nabtLaSw(bXzyl8MG6QyFuwW%KV0Q&<8TDqsZV8yg z0I~Z4!7(^l>}G(mUa^G)>{bAbDQ<=(9w{1~^O+`2W~ z1h=Bxxiw6&L%DHlxDo6c_iYXPxDn;Ht%1Tm<*uz^7nXC=*06(J$~{}dHq^r{o7HJk z?$~TwN>hfdY1qVsa>HiWqTH`F9LFZ*cC8&eh~t#GwFc|Mo>J!445Rwil-sq2>M^3s ztzmh6YgA8Xa*W==n#KGYwnT4ju$W&nxOW3(ehq7*H_%&@88*YOqc=Gko1T^lI#Y_@n0+D(~fGrSufyqIZoyoWN^W|%!fG23SNC3u2jzRmE9=q1XH zTSG6P7b){?w)}Z`u;Rwe@NUX{n=Rjoo~O*X8Ga7!w3u_V-FapMWzNkO??BH{X59?4 z=O*Ue9PgmaycuSXP|UpZViNTtZ5t2HY;V|npO=rA8z1UCwoqt^0g6YkSm{w%_6TH91WxJzsKrG%TbHZ6~EkJhG6BQ1D+d9w(2 zXf1at;Rda>tS8)`wN|-=+q3p=OTyh*YuUt6V4~d?!D)|>fP(zYCUe< z=RN1udDnVn-eul--sxVUH_6NQMp(Cc{k?AP0smI7t=HVMJ>-4}QU43>3U|AEJG=?l zsf2ljw9psL2ySrI#+->T7BW>n-vZ>saer>oV(HE3}qdORRYi<@d01t=0S!pgbXPG0+er6Z*nD`TB6EhzF8lwES;xEN7i*Jcv7hh<6 z6rUB(iI0f)i+7G6ZM+$8Y`hpZjP1r(5ZklJ{w!mKQDPJrw-_Cb7RE`29s4%+N$gFC z>K}=1jV+5^9y=$tCRP|56B`)Av91u&`})oLKlFY2cKv1jUi~Wle0{B61JV3UeS&_X zo~=Etx7C~KhW4fQf%b~_fOZ?{tzDy?sa>emX~VUp+FWfiymQzPyPgbV7Y%absY}~H z5ar^Ql0le@TS&GW?ZHP$whQgXnUaA``MmULF`EeAuo0%1MkL;WZ>A?mJFg^OeKS`7h94++>4215cpzKGKhR} zTrvoKu^}16zBncs1i$c@CG>>o7i*G1_=~Y*5dY#s%0mE*>SY;3!04c45CWrbCA$mV zjs7JW1i|PV$sh_wUrPpIF#1X|h=bApNCts0`ckqj=yvplWDp9Y1A?(w82wW)1i>)+ zhhz{9qt7LSa2S0i8N|crQ^~GJH=s`>gNPV?EE$Bv=p)G>CPw=ugP<7wT{4J@(T9>j zSd2c94B}$+zGM&>qxU3($QZpV8HC1YpI|IDM(+rQAUH;E3kF-QLvK0kY(kI;;W2ts z6d{O@(HoLMfQ(+33?gK-mu&_&a~2|_*K9_hj@7m3Rk<1j$>b69u)a(X$e-LOUd0 ziJp;o1$tWGnmY89JOl`u(UYQN-5T_SEP=2YZ5JiE_2_Yd^J>sz0_WzTM+M%Tiyjep z?0B?I;K`%W!vg0_Mh^)*F$X;;@PvuzZvu~>fF2Nd+<0`qz+=auzY08NEV@tN(PPlP z0*@Mv?h*LDQRr@gN9Llt1Rg#TZ54RfaP${}hYUk^3OslSxWyv?xMwePy};Q$(I$br zXQMyMvjmYfx=t__TBBd`eRNVHfz`Xhq`GkA4WC&bt24+580 zpsNHfDM!B-xU>XaDR4mn_67o8Ppsn?@3qctG%M!yj_e;ztR;<@Pb=$xd2Z?sl2h`v!M8HC>` zkPPB)bXv-+z^zyWz9x<2pDGE&;HxEpAbgc15QVRl1j6tYl0Y0@CkX`Nr=*c7wUR(6 zzFZI%i`TG2(!ls27_SxuEE=zp1j6x3Ngy6yCJ6-O6@svcyj&0#l9x#WF?p#VEGRD# zghk~`C4sPfi6ju07fS+x`C>^RGG8PKgyst+f!I7L2?Xawl0bA`C<%n;3nYQ~JdsAM z`I10{ULXhy(dP-mV)VI^K#)F15EiA+7KDZAvjkyr`b;4;1nT(`@7*v%t^}d_bV(pq zpC$r*9xXgyDEziUITEP!}@iX;%QpPWYaOqK+m1#={Un0=BY5VTK}1Rew@NCIK| z_%yP6oFovqkCg->_c3zUp4l*37C`JiN)mVg9GOOToFoZ6433Zl;`iZ_Kmb2X5{Tf3 zO7a+be2DzaZ5vLM1rWp!PJOB2@-E+af`Db^gV^eD8d)4aP+$n;2Z&WFH})5>qOzZW z%QhY_V8yb&0+v&Tvby`aS^JcXWFn8)v0=BHr6tGopa{=FNnIT}yR?XxoeRe}r zu{g7369F?a8_N?7!A2voylI9nV3VevfQ_5D0yb*w2q zw*SL+m!-BlqV|8NWh$%Xzt{dxa`bhn{rp+l|BWfgOC8@2wExFX!);S#|CRRt_-W|j z)Nvni`#;oF$Fd53=Jx;iWyoQ>@%{FHl7p{fC9M4)@CJre`#-$q2Oa)@ZU2XdiN)V< z|A(cyn797}9P@9q|HJA>!SO5hkM5VdTdjCJ`Z}ac+Z}R`(|IRi|d=*uicN`x7?T9Cm`E@ zi+i1WnR~7qy35@q?mRcw9SwQ^J=|m64A*kLbv|?UIlG;wod=yeo$H+|oO7MioJuF@ z%y7mzCqSS7Hcn$lw-4CwL683@?Z4W$*jL*dp}&8vKEN)vXWNtPq0ryIotlI_+&ax=M#oNsj{Ye@}RM4CdT{{-mQpH12l!~D|xzWVi_?r*3P`ptTX4D!_D?)Cgf~<6aOgwXT3}O_4vc_=i+z9 z&yHUj&x@~)m&X(FQMv~i|E}?(@tAS9@pinwj^eG1dg$4Iqj9Bio)H*TM$(vJ9c_$* z9{t^oHb&#vlZGC9FZS2if!OZY)v=Axo4+xEx>nP{ZK--~n!3M!_XFLR-oyI!h^@1JTh5O=lfx{eNJjuF1S*~QAaTh!#jrGJQOV$b2oGjU~ zxC73StOI;)CP{W2^rf9BSv!0jo*>yVxE&s!#ya3}k{yMQ!DA(BjgP`(Bx`|NfL!=0VxaBiVHqYG8o1Mx04LUalyw*1_Qa^wvs`vGR~3= zl9h2A$zUuOe2io;mJFO2~gFbC({oxw)Z%gYN_aE!)mfp#Z^?MULeYkhn>HXaPWBu0B`ojIYeb|0& z{nl9K{P*8Yi?fp6V=uioVS2j+@6ss@ykX~0-qgiDBYM4^O4R2sy`#@KY)`iaZ@1HW zy`}ebJCe8C4c=&{^>kys+aUiRcZWgRp#A@2f1euvPmTYl#{V-?qtU4G|J3+@YWzQg z-R9Kze`@?cK7N#;#{WwVyrjneOC_SH@&8zkgBt%&jsK^{|5M}t{j$dYV*>@$Ypnl& zf!!1N|M1_1?Ek0y2mCEC*54)m+5V|8?thU#!yoUT==b#7`OSUH`o7V=~`wjPh+`q#p|IfMG+^y~nFvkCH-L=rxU+m6t zC%Yru|KFV3o$H)SoeeO)KWpjFcP2PPoLLCc#lF_Q#6H_T)m~;7*)#0%Ft&fsBX1Q}H7Es?0!jg;fKosypcGIFCTHz%FwS~gTS5SBYg*Rs*NSqY>L(zR^BZXhgk zkgjEecC!*l9HeX6R0@nh-XLAeCR1Pp(gx{TT@Y3R3sWVKG)UL-=@e2qgLEyQP9c>t zNWwL+8cc>U`aT-+=@bk|5KPc;HQ$`&3tC}nLItQ}&So0&2^9>MEl9%^(T6~?APwsx z0=a@T5Mlxu)T&~uMhJ0EDBQ;e+n3`7s)POWW8uFPH49Ew|q#>VL z!C*;(G~{zDFp@(N|5IBa0e0~K3@&QNFu!OJ1(gSI@ zG};ET18KNKk|8uKjtJxi(r~dX8BW7R5rND=8ZMM2!)VCoSuh|?6iWC+3kHU?KpGZC zt3g&E4HrZNk^*U%hzR5a(vVNKU_c7s02&rVC6EzFLq6St0a>QU({OH70{MV6U9Of*Lq79@ zfgul&hJ5M;1JeG0@VOUQ1zl;FntTB&fFwW~^63{?Nf#RO`4<@JNyEue)sO*5Lp}q8 z0V#eRXvn8vFfio*(Qsn48l?Zxa6&{N`;Ug>BLc~PG#n=hR4_Iokorf%F%f~xKN^mX z2qgZ|kk7?nK%QR<8uG~);0uRa((oi%0hEl02qgW{aCk%@=Z}VbLIwjv${!7fMkSE( zM?*d(1G~tZ(eT8m1oHi8$Y*6>B~59_re$E?SjhIH;h=~>vL6ixMg(&GXvn5#V5>o@ z9}U?A4U9mh9}U?%c3#qmhHRP!Mj+3RhHRz=Mj*|PhHSD1PAnR-`5G8uQ+tGL$_5}T d$B%|=)&@o(#gB&B5rGUp8nU??SP3Nf{U0Oe}vj z6pQ>8x#!y>+;*`Q_Nmbgu@dXD=#6H=Du4w2(F$kVx7}xSzjMd-i;q3G_t1WEu`J%2#o^PgUHsg5d1`LFELk3(UR0cn7nK$k z6(lN)N@v6?XC`Its)}SGf6&1)c+~NX?bGXXlwVW`YgHy^KoMo-MJ0*y>iFzrb=!1? zU~84hd6i-T6p}2OQOa|~TL-JRiI*p*C(Dzi1<4Ar7p;p5+psnlFJfkPy?S_YGJQ;m z@&fo5t#vf8;P9agnhpgOq{<8V-cKt}lorfPRU_a+oilZ%0mE=^W;>AU! z>BB6oD5@+sF?MgFWVRbA)qw_;_ob zPANG*qhhV9(xN$4Np|%H+I?ImPgMD?d?H znM(6~IQjXx@D;lxi<;TZo8!fE(iaMsCx2Q|WrYG(t&2u2KFX!T(~o}L@Vlh;MZ)i* z;^KH}dis{(YsbaOfvY+5p2?A7B+B`X7Ec8kVNd^0uQtB3un+NR#i?nvueJ0QQ;~!l z1iV0In_g8~z;BfbehdYv3zFrD8OioRs<4Is>kA1kom?xD-LLu4E**6V&gDu-JZJXkyd z;yhzv0x~D2rK&2^*Dt5(%acEQ4)3oh;4qikL#{UI>~MRJOjT4A!DE#_wkmqarSVoX z6BXQrDq6*-rBcPoM5%Z-z#4-pV4Fo1@$^3Y_^H(Ea2XZVyo?zav`X)je7JC%9}72~ zyuSdHlq5ePPW_RKyF)laL6{(l)I^KXn=Kvz4FA zzj>5Q`A<2oIX1ITKOK8j!ObXqCcgq6>Lm%e+W{rDE|MQ49=={ifZ@Wgoy_ZmYnLGh zr(z4-+{782CRC8Qg5NrmKr5gX&TZOj79D$cZ^fx+zwEGv;tZIt$HOwR6$h1vmL}5`?MSekXRYhg8 z9EQ2csVFZv!crI~r5pzFxgZJM`p1W%$lT*DvZ}JExFRP%zo@jRGCv>69UX#l+Z=bf ziQ*Zl@}kO_B^6NSc?8Prdfa9HNG2$HB!;3}9(U0h0`U5O!x@L%dtKy=bH_R>kLeWa zOj-f0fL1^&pcT*xXa%$aS^=$qRzNGD6+mco*lUH4J4(%r|BVS@FZcM1{89WL9uxN3 z9DljB;{VXZu-EnY%REy2Pez8lmez4An-Tx(@&EJ*dK3TOqi z0$KsBfL1^&pcT*xXa%$aS^=%VFH!*mjtaZq@?*<#zjfbsUvpn@pL8E|?{WX+UhiJ+ zUg|Dz=eRT6Def3|klV}c?6z^6x%FKJ;{|@@>~-FE-gI7e);VjP)y{3s3g>su66Zpv z(wXHNCqaWPNV!vEH#ZTF+UJTK8GES^sHW zVO?s?w^CN2b)Gfc>SJ}Y+F2)A*;dRV^c(sq-9g`=FVH9G8hQu4kuIl8=t5dfXVA%X z6zxyXq#bBW+L*dDY#uN_H9s)lG+#2GG9NVWHg7VoF_)SbnN{X2Gv6F*4l>U&bIg;> z1}2Ss7x^^uUSw0``N(6D`y;nURz$9h{5rBAQWi-@CPqd?`bN4&PKz{+)QyC zazm$tvO|n~Pxg@y$m`^J@+i5N+)S<}OUMFJTKlG9ep>@ugc#^ew2Aqr2jmS-P7BB; z^ak_W1mty|(psgw&iqyZ+00W;RVka9-%^ns%s)ku*O-5@B3qe%k|J+0zl9=Om>&@0eJ;M0nHS7nfXlv@-ljb`6mYCvzil`-y|TPbJ92O zCEf)Axe@(|c~t?q0o};FN|_Daz`P1UYG$Gp%qtI4u0u~UZ%#n2saeRpvVdHTu3=s( zAd7iQX+V~utC?34kjv0g<`oCzH|R3v%?`*C^c&{Q3dmx#gn31R)GS1cEpH|#*gzhv zKO@L>5%Mf=IuJ~ZMd%{xB{{~%g^Q?HDB!{inO6{GzZhM|ylDa1i9TXpLXi)tH&v{$ zpo)6=0?uDRz4HaEo=?3g0?wgdb zw()g_OM>2G-Y`YBGjC`>-a+p(Z-^q>s5cl0C^)Py^#%#JbRqQy3OKTcdIJO;RYAS} z!O~IG>nGsqk<{xeU=?WBCxD<`?*M{!=L$G<81>E(aPUy-^%8K7oh{&i zfz&%o!2ScMccy^-`ctomfPMQ>ue*SKVAs0|_;?@cbrrBT>{%B9d-bMXX90WmqFyHf z&+19NGX(5$7WFy`*u4k!as}+xoq8Pv?AncbIRbX+O1<_1cJ4yGb^>G4N{1({5b zKNqlN)R-QBC}0`}#GeQhhv5wI#{tHnviqsW9|izZWz^%30)WLS>hT8wz>;F>@y7sP zT0%Yk5CC{TEai^?fWx6M{r~_tbT9R|{R12VCySduz`;YP$E}}3IB(qe0S*AjZ6Dx3 zI9c5E`BFHA-0}hLg{9o^0oET*J#P0LHl!Xmdw>n=Q;%Cczy<)h(en%qsK;#{U|q<| zO&(xYUFva*2k2%|j~hHd+ciCI?-oGQTy%&a0>H; zr4xMKp)Jf4mQL^_PK2csJP}V~p0IR+^YD4h6P8Z!1f0i&qjNlw2}37%Jf6UWpA$R| zk7vTp2_B2bG2!L}kHKS^Fmr-O!sdjR6FdTsWWvfh9>au_6U+`BWWvbFp)GiXmG*H= zaI0NRxH!Qr4}HLdi4%MZZpnm)6MParg$WBMxCK6m2?r;KKEN%QFmQsK;#N%fH^EH~ z?PbEg$)UZt2@~#3aDCjA3G*g6h8-rno8Y>zYr?t-&cbzB+PMkAQJlqua}(^sa&FuN zmOI#G+_*Wmp&ahq9LKQDgmv?EV@xm@8*oiyj2U-tz-Si^Tim`OH`=kx;`YtqyUQ52Z=U0Qw1aW~ z=J-9ZP;THHzY9jo4V>fc=v^!A;IQgj%b4(Qg5E}3nXqqywxYM0aBq&cGhyBYZ9#7_ z;oSs&Vs2r=x(Rv%EI~LoL9fFPOkvyvZAPy%?%RB`o1uJe+#GL0n;Ca*j^XE^#jTqk z$*ap4w{E_81N<=L-p%nV=vBteTi^|hyEn)DCnC3Rj$eeIaooQ-egVD6xPf#09D0Fq z2j_S_Eaw(3man(Chx5&@TgJGD^Tq4Xb5xi(NiC;RDr}n+CaG|3k|?CYw8_*26`oDRp#!PI7 za$kmnkJ>ct%b3iVX>wnN-#}3_?aCIZF%1J2;7E-bP6H!ihSDHpm?RBI$b_GRQ19=z zCGa11_wg{zb@#bj+*g>#US?lL?}~1>uduf}&$?aRXWU2JTixHo+ki{lYPZxaaP!<@ zZVR`!+t&Hs`OMkvyzM;VthUx#tDFr^wR5Xe>ipih+_}UlaE3W~PH*dGr>oP}{?c9T zv~U_aQ3u%{LDav=o^L;IKLBq6R@y5d{x7k|+CA(6b~`(6*NYvnBe8eg`mwz(jlBl% z0Pc=G9Q#A;cd<)j^C0#=KQf?_pcnlWYyUo&BEu2b;)duwm>Ri1<%sjjTf~X6>_{wcfWj z(Un$x>so7>b&*wR&9o+4BdxYp-0BO_KBfEUL5T9-q%YFHLX^LPUP7;+^Jp<0PZN=5 zv?D1BC&HgZmPH;l^XL$I4n2X!C^Ek=_n2>+ubS)3halGfqj{CN#Qc?6W)_-5&2!97 zW^40AvyK^#d=uFlxhQg7WOd{g<16DG^$bBv^s zW1MU>G-&uc>=*6|QM?_h3GEMk5ZWB72t5~i zICNL&hS2394`TQUp+TW`p)*6xLs{e|vLzHEUy{F(-Q+Qd;g^%S zVWdZ7d}LsxZKO_w7!!;^aLw3}^i5$Ph;ngb#URYZjSe$NJ3%rY=;Cb25JbATp<)o~ z;s%OAtc&X_2Ei__rx--LxUOOl?&2)PAl}7w6oY^lyNW@?iyg%tMr*OH*hA=H98(N} zUK~{nqF&4tgRmD{ib33qsbUYHHP}=PB3~TIU@Nhq7{tCftQZ8p@QX|83DGYmib42` zv0@Pa;={^A0E}u>8brY8kYW%5qwf`i7#JN?41!?vonjCLqi+?1Fc`t@7o1y&gV8sN zK_HC2RtzFx^p#=|3ZpL-TY+vsUr5G-Vf49V2%=&1nPLzQqy36OJdE}!1_3epyJ8R# zqfZrskQjZU7{tVAuVN4sqmLDXs2F`D84rumhcSM95IjO$jP}SR1c5Qy&66+$kPsQ8 zT~SUT--V0NPKg&@h<2z|AUH-JCa%D%RwvE8H2zY-7XoAtl@1L+#LK3Ogng`#_to07&Y+P zj6+JoHSnGcAjKTs9Fz&+HM&XSR2llC#KkG}2Z>9H(Tx%pm!KOYzP}i)ka&0v`cH|6 z?nS?sc*sz6y~Klupz9PKjINbLd#GMS`y@dz#A== zIPZM)YlWwvOM`Qgj=j+(ib3#=E>;YpZ{#Zm;WzRWgZLX=lr}_gL>_@JQUpTq3l)JF ze4!!`g#Strh{6{r0%7=kMIa8Z&LCsvDFTuBTuFE+egQuwh%R|7UL_NFFkYz$MB^2T zKsa8m2*l%aB;f&hnIt?SPbmT+d8s5kCNGhM2j#_zKvX_k5eUm?DFShMks=V7&r}2= z^BIajXg*yLh|QCVKyY5D2t?-via>ZiErUc8ia>xqRT3Vd=S#vv^z#*gXncwyJW7&; zN9mI!;bHnjDK*6D=PCT|vOJXuBJ~N1K&U=m5s1~tDFVUzShfCJ%f_e#2-im|0`dB& z46=QsA`r2UPy|Bu;fg@aK1>k^+J`CvQTvb#vURW`5VsFf1OoSgYS&&}Hb5mn=-yuu zh~4{TkPUqmf#|)DA`rg!Rs`bra}|LA{v1Uhg72lsb7*}}wFB#xovjieh(9ZRi^Dle zEk9Gj(o_$g9j<>K$9I<)0{Lz-OZoDy5|)*Bk#NrP&Jvc*=_Fxk*%=ZpFYPE{NoB5t z$*CPAoV_wf!dbK1OIS3korE)s+DcfMJYB*>;b{_1O|+3Pe`;$5XSR~?{Jc{ooRrs6 z!ijmONH`_$WC{nL?rL05!p2SNN|@a^OTq@(btJ6Uz?HCW zJx9W0+!yamDW*8fMEp>e#Nf2RH))@{U_H2g#L|8xzxEG^i7TmKI^r=~CCZU0#P zKdjf9ACj*BhZp#u!@qz1Km1_0Abm~zMEyVH86DLB103)V)c-^F6F>)OVSn`cf2$js z%JUp={XdkG-qt@;{|`R{)9dQ`f2gtNZ3cAxKW~|@>;F~dsjmNrrv1A9U)TT3D*Jzn z`hVU!fP@H5w-c~x2p|9N1!DXVtK22-uiP@XFx}R_w^0*G zxSiZn-4k5q9CY?MJD}hHOU_fyUz|Ig8==?#ubuhO>VF#a`5)}`bUHexLXUsPA@*0s zqtMoWo4wIkZ9i+=Vz0GV**DwQ*q7NC+U53id!jwu?rnFmPqUlZS+)^75c?$dUTkyh z`PgHz`(n4ou8;jTc2TT6mW)k^4T_x!J^q`=vSOj=m(ktO-~VsX$D;Q}Z;CF5-u`o= zMbSyd`slD|FX-)ma>c)saSgQke}LV}u4TVr3s?!uXQNmj)|s_pjabzB z-uk=sp7omb3^ZE&obH4k{C}VqL!bR2)`RqU`Ut&;UP-IydGu_WLtD^#)G)uc?trlX zerqkV=2!*RIBS6Uq4|XQCv&km+v;weZf=EM{r8zyo7LtNbA)-Wd4}22%r=&omgQPa zp*R00k?oNUk*6aIBY%!uA4x%v{-u#Ik$%vNzfGh`#5VRD9~iIyJYO~Gz19k71^!V5 zf|j@j+K)cNQxtmuy2nfo+G`mD`{GFw_aBJUO*nbX*!OXKo?2ribVd=4I3WdRt;Q2% z5_S)LjmIkn4OH}DYgsJx+=C48t!(}A|cF=*|J zn<@t5x!@BOg8^M|6UC0u!$+n4L?@abcANQ$7iG1()EuGPjM`qUE*Zz(3-1SS+H3d1 ztKIMJr+V7`@NIf#zq_C6TUWbJUhSTJKh^Kzq^0UoA%7UeE*7mcgNh* z4%P0VckDgw{zv-Oq1yfR(tYF7?*!m+^w0LRlMlV4y+5z@K0Ko@U9G-x8U5;N^^?o! zW5@f<@qPW#d-7xHRhHR7k9W9}JI(vpNqn>(Y?-~>GT(&%Z}wFCx%N)u;q{Muph?9u z``=~u6XbmirHP8}po-P%s9L+aQuKLKIn%WweEY!*?sTLty1fP4KWbmMTJN_r`n_fL zbvvH-+qK?lXY_T;?BB-UT|@hS+!lsugZlqZ-JN>;KRy1R9{*2||7Rc=3qgn)g`!eKy?(Txt|1Y{vxc9?Yf7iK7-3wvd|3Y_yJJ>zj?ckp5Hgqkh`TskN z_`lg%599rWVSqh=CVtEp6EPSFTKF-a&&0s@4I20{ zpHIfX3AFEHKA(?)6KLMYd_E(CNZHMNJ|_bw(72EJd|C#PvWxkAW(Gk(9X>gOATXU7 zpPxaHz0BuRGzjuB^Z6_doItZa=JSafIDuAu%;$48Z~~3`n9rwc-~`(AF`v)azzH None: + def _check_symmetric(self, X: np.ndarray) -> bool: + """ + Check if the input matrix is symmetric. + """ + if X.ndim != 2 or X.shape[0] != X.shape[1]: + return False + return np.allclose(X, X.T, atol=1e-8) + + def _init_dist_list(self, radius: Optional[int | float]) -> float: """ Initialize the list of dissimilarities based on the radius parameter. """ - self.X_ = validate_data(self, X) - self.dist_mat_ = pairwise_distances(self.X_) - self._list_t = np.unique(self.dist_mat_)[::-1] - if self.radius is None: + if not self._check_symmetric(self.X_): + self.dist_mat_ = pairwise_distances(self.X_) + else: + self.dist_mat_ = self.X_ + tril_mat = np.tril(self.dist_mat_, k=-1) + self._list_t = np.unique(tril_mat)[::-1][:-1] # Exclude the zero distance + if radius is None: t = self.dist_mat_.max(axis=1).min() else: - if not isinstance(self.radius, (int, float)): + if not isinstance(radius, (int, float)): raise ValueError( - f"Radius must be an int or float, got {type(self.radius)} instead." + f"Radius must be an int or float, got {type(radius)} instead." ) - if self.radius < 0: - raise ValueError("Radius must be non-negative.") - - t = self.radius + if radius <= 0: + warnings.warn( + f"Radius must be a positive float, got {radius}.\n" + "Defaulting radius to the MinMax in distance matrix.\n" + "See documentation for more details.", + UserWarning, + stacklevel=2, + ) + radius = self.dist_mat_.max(axis=1).min() + t = radius arg_radius = np.where(self._list_t <= t)[0][0] self._list_t = self._list_t[arg_radius:] + return t - def fit(self, X: np.ndarray, y=None) -> "Curgraph": + def fit( + self, X: np.ndarray, y=None, radius: Optional[int | float] = None + ) -> "Curgraph": """ Run the CURGRAPH algorithm. """ - self._init_dist_list(X) + self.results_ = {} + self.X_ = validate_data(self, X, ensure_all_finite=True) + first_t = self._init_dist_list(radius) self.solver_ = RadiusClustering( manner=self.manner, random_state=self.random_state ) - if self.radius is not None: + if radius is not None: dissimilarity_index = 0 first_t = self._list_t[dissimilarity_index] old_mds = self.solver_.set_params(radius=first_t).fit(X).centers_ else: dissimilarity_index = 1 old_mds = [np.argmin(self.dist_mat_.max(axis=1))] + self.results_[1] = {"radius": first_t, "centers": old_mds} cardinality_limit = ( self.max_clusters + 1 if self.max_clusters else len(old_mds) + 1 ) @@ -181,7 +206,6 @@ def fit(self, X: np.ndarray, y=None) -> "Curgraph": with parallel_backend("threading", n_jobs=effective_n_jobs(self.n_jobs)): result_list = Parallel()(tasks) - self.results_ = {} for local_results in result_list: for card, result in local_results.items(): if card not in self.results_: @@ -189,18 +213,43 @@ def fit(self, X: np.ndarray, y=None) -> "Curgraph": else: if result["radius"] < self.results_[card]["radius"]: self.results_[card] = result - self.labels_ = None + if self.n_clusters is None: + n_clusters = random.choice(list(self.results_.keys())) + else: + n_clusters = self.n_clusters + target_centers = self.results_.get(n_clusters, {}).get("centers", []) + try: + self.labels_ = self._compute_labels(target_centers) + except Exception as e: + if self.results_: + warnings.warn( + f"An error occurred while computing labels: {e}\n" + "Defaulting to n_cluster=1 for algorithm continuity\n" + f"NB clusters available for pickup : {sorted(self.results_.keys())}", + UserWarning, + stacklevel=2, + ) + self.labels_ = np.zeros(self.X_.shape[0], dtype=int) + else: + raise ValueError( + "No clusters found. Please check the input data and parameters." + ) from e return self - def predict(self, n_clusters: int) -> np.ndarray: + def predict_new_data(self, X: np.ndarray) -> np.ndarray: check_is_fitted(self) - solution = self.results_.get(n_clusters) + X = validate_data(self, X, ensure_all_finite=True, reset=False) + solution = self.results_.get(self.n_clusters, None) if solution is None: - available = sorted(self.results_.keys()) - raise ValueError( - f"No solution found for {n_clusters} clusters. " - f"Available solutions are for k={available}." + n_to_predict = random.choice(self.available_clusters) + warnings.warn( + f"No solution found for n_clusters={n_to_predict}.\n" + f"Available clusters: {self.available_clusters}\n" + f"Defaulting to {n_to_predict} clusters.\n", + UserWarning, + stacklevel=2, ) + solution = self.results_.get(n_to_predict, None) centers = solution["centers"] distance_to_centers = self.dist_mat_[:, centers] return np.argmin(distance_to_centers, axis=1) @@ -225,6 +274,15 @@ def _curgraph( index_d, list_t, old_mds, local_results, solver, dist_mat ) index_d += 1 + + if len(old_mds) <= cardinality_limit: + # If the MDS is smaller than the limit, we store it + # with the radius corresponding to the last dissimilarity index. + + local_results[len(old_mds)] = { + "radius": list_t[index_d - 1], + "centers": old_mds, + } return local_results @staticmethod @@ -240,6 +298,7 @@ def _process_mds( Process the minimum dominating set (MDS) for a given dissimilarity index. """ t = list_t[index_d] + print(f"THRESHOLD : {t}") if Curgraph._is_dominating_set(t, old_mds, dist_mat): return old_mds new_mds = solver.set_params(radius=t).fit(dist_mat).centers_ @@ -277,3 +336,14 @@ def get_results(self) -> dict: Return the results of the CURGRAPH algorithm. """ return self.results_ + + def _compute_labels(self, target_centers: np.ndarray) -> np.ndarray: + """ + Compute the labels for the data points based on the target clusters. + """ + distances_to_centers = self.dist_mat_[:, target_centers] + partition_radius = self.results_.get(self.n_clusters, {}).get("radius") + labels = np.argmin(distances_to_centers, axis=1) + min_distances = np.min(distances_to_centers, axis=1) + labels[min_distances > partition_radius] = -1 # Assign -1 for outliers + return labels diff --git a/src/radius_clustering/radius_clustering.py b/src/radius_clustering/radius_clustering.py index aa63fc0..4eee9d1 100644 --- a/src/radius_clustering/radius_clustering.py +++ b/src/radius_clustering/radius_clustering.py @@ -8,6 +8,8 @@ This module serves as the main interface for the Radius clustering library. """ +# TODO: Implement a check_radius_dtype function to catch when a numpy dtype is passed (counter example that should work as expected : catching np.float32 as Exception) + from __future__ import annotations import os @@ -52,12 +54,12 @@ class RadiusClustering(ClusterMixin, BaseEstimator): .. note:: The `random_state_` attribute is not used when the `manner` is set to "exact". - + .. versionchanged:: 1.4.0 The `RadiusClustering` class has been refactored. Clustering algorithms are now separated into their own module (`algorithms.py`) to improve maintainability and extensibility. - + .. versionadded:: 1.4.0 The `set_solver` method was added to allow users to set a custom solver for the MDS problem. This allows for flexibility in how the MDS problem is solved @@ -113,7 +115,9 @@ def _check_symmetric(self, a: np.ndarray, tol: float = 1e-8) -> bool: return False return np.allclose(a, a.T, atol=tol) - def fit(self, X: np.ndarray, y: None = None, metric: str | callable = "euclidean") -> "RadiusClustering": + def fit( + self, X: np.ndarray, y: None = None, metric: str | callable = "euclidean" + ) -> "RadiusClustering": """ Fit the MDS clustering model to the input data. @@ -147,7 +151,7 @@ def fit(self, X: np.ndarray, y: None = None, metric: str | callable = "euclidean This should be a valid metric string from `sklearn.metrics.pairwise_distances` or a callable that computes the distance between two points. - + .. note:: The metric parameter *MUST* be a valid metric string from `sklearn.metrics.pairwise_distances` or a callable that computes @@ -160,11 +164,11 @@ def fit(self, X: np.ndarray, y: None = None, metric: str | callable = "euclidean - and many more supported by scikit-learn. please refer to the `sklearn.metrics.pairwise_distances` documentation for a full list. - + .. attention:: If the input is a distance matrix, the metric parameter is ignored. The distance matrix should be symmetric and square. - + .. warning:: If the parameter is a callable, it should : - Accept two 1D arrays as input. @@ -200,14 +204,16 @@ def fit(self, X: np.ndarray, y: None = None, metric: str | callable = "euclidean dist_mat = pairwise_distances(self.X_checked_, metric=metric) else: dist_mat = self.X_checked_ - - if not self._check_symmetric(dist_mat): - raise ValueError("Input distance matrix must be symmetric. Got a non-symmetric matrix.") + self.dist_mat_ = dist_mat if not isinstance(self.radius, (float, int)): - raise ValueError("Radius must be a positive float.") + raise ValueError( + f"Radius must be an int or float.\n Got {self.radius} of type {type(self.radius)}." + ) if self.radius <= 0: - raise ValueError("Radius must be a positive float.") + raise ValueError( + f"Radius must be a positive int or float.\nGot {self.radius}." + ) adj_mask = np.triu((dist_mat <= self.radius), k=1) self.nb_edges_ = np.sum(adj_mask) if self.nb_edges_ == 0: @@ -227,7 +233,9 @@ def fit(self, X: np.ndarray, y: None = None, metric: str | callable = "euclidean return self - def fit_predict(self, X: np.ndarray, y: None = None, metric: str | callable = "euclidean") -> np.ndarray: + def fit_predict( + self, X: np.ndarray, y: None = None, metric: str | callable = "euclidean" + ) -> np.ndarray: """ Fit the model and return the cluster labels. @@ -243,7 +251,7 @@ def fit_predict(self, X: np.ndarray, y: None = None, metric: str | callable = "e the distance matrix will be computed. y : Ignored Not used, present here for API consistency by convention. - + metric : str | callable, optional (default="euclidean") The metric to use when computing the distance matrix. The default is "euclidean". @@ -263,7 +271,9 @@ def _clustering(self): """ n = self.X_checked_.shape[0] if self.manner not in self._algorithms: - raise ValueError(f"Invalid manner. Please choose in {list(self._algorithms.keys())}.") + raise ValueError( + f"Invalid manner. Please choose in {list(self._algorithms.keys())}." + ) if self.clusterer_ == clustering_approx: if self.random_state is None: self.random_state = 42 @@ -271,7 +281,9 @@ def _clustering(self): seed = self.random_state_.randint(np.iinfo(np.int32).max) else: seed = None - self.centers_, self.mds_exec_time_ = self.clusterer_(n, self.edges_, self.nb_edges_, seed) + self.centers_, self.mds_exec_time_ = self.clusterer_( + n, self.edges_, self.nb_edges_, seed + ) def _compute_effective_radius(self): """ @@ -297,14 +309,14 @@ def set_solver(self, solver: callable) -> None: Set a custom solver for resolving the MDS problem. This method allows users to replace the default MDS solver with a custom one. - An example is provided below and in the example gallery : + An example is provided below and in the example gallery : :ref:`sphx_glr_auto_examples_plot_benchmark_custom.py` .. important:: The custom solver must accept the same parameters as the default solvers and return a tuple containing the cluster centers and the execution time. e.g., it should have the signature: - + >>> def custom_solver( >>> n: int, >>> edges: np.ndarray, @@ -316,7 +328,7 @@ def set_solver(self, solver: callable) -> None: >>> exec_time = ... >>> # Return the centers and execution time >>> return centers, exec_time - + This allows for flexibility in how the MDS problem is solved. Parameters: @@ -334,7 +346,7 @@ def set_solver(self, solver: callable) -> None: """ if not callable(solver): raise ValueError("The provided solver must be callable.") - + # Check if the solver has the correct signature try: n = 3 @@ -342,6 +354,8 @@ def set_solver(self, solver: callable) -> None: nb_edges = edges.shape[0] solver(n, edges, nb_edges, random_state=None) except Exception as e: - raise ValueError(f"The provided solver does not have the correct signature: {e}") from e + raise ValueError( + f"The provided solver does not have the correct signature: {e}" + ) from e self.manner = "custom" - self._algorithms["custom"] = solver \ No newline at end of file + self._algorithms["custom"] = solver diff --git a/tests/test_structural.py b/tests/test_structural.py index 927906d..47c171f 100644 --- a/tests/test_structural.py +++ b/tests/test_structural.py @@ -6,26 +6,13 @@ def test_import(): def test_from_import(): - from radius_clustering import RadiusClustering + from radius_clustering import RadiusClustering, Curgraph -from radius_clustering import RadiusClustering +from radius_clustering import RadiusClustering, Curgraph -@parametrize_with_checks([RadiusClustering()]) +@parametrize_with_checks([RadiusClustering(), Curgraph()]) def test_check_estimator_api_consistency(estimator, check, request): """Check the API consistency of the RadiusClustering estimator""" check(estimator) - - -def test_curgraph_import(): - from radius_clustering import Curgraph - - -from radius_clustering import Curgraph - - -@parametrize_with_checks([Curgraph()]) -def test_check_curgraph_api_consistency(estimator, check, request): - """Check the API consistency of the Curgraph estimator""" - check(estimator) diff --git a/tests/test_unit.py b/tests/test_unit.py index e4feaaa..85385b6 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -96,13 +96,13 @@ def test_radius_clustering_invalid_radius(): """ Test that an error is raised when an invalid radius is provided. """ - with pytest.raises(ValueError, match="Radius must be a positive float."): + with pytest.raises(ValueError, match="Radius must be a positive int or float."): RadiusClustering(manner="exact", radius=-1.0).fit([[0, 1], [1, 0], [2, 1]]) - with pytest.raises(ValueError, match="Radius must be a positive float."): + with pytest.raises(ValueError, match="Radius must be a positive int or float."): RadiusClustering(manner="approx", radius=0.0).fit([[0, 1], [1, 0], [2, 1]]) - with pytest.raises(ValueError, match="Radius must be a positive float."): + with pytest.raises(ValueError, match="Radius must be an int or float."): RadiusClustering(manner="exact", radius="invalid").fit([[0, 1], [1, 0], [2, 1]]) From 54f66786353291ca3c81c48a6ef7bc57595b488e Mon Sep 17 00:00:00 2001 From: Quentin Date: Fri, 11 Jul 2025 09:53:30 +0200 Subject: [PATCH 7/7] fixing type match for numpy based types --- .coverage | Bin 69632 -> 69632 bytes src/radius_clustering/radius_clustering.py | 32 ++++++++++++++++++--- tests/test_unit.py | 5 +++- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/.coverage b/.coverage index 3d573d838879a3b8593c389fdc63cf4b7d597a50..706232cbcfde1bd315b72a7b957599b021ea533b 100644 GIT binary patch literal 69632 zcmeI5349gR+5gX(JL|bKcOWb&vUov}Js|-S5JW(B5M)!qC5CXpXf~3YgpeT4T+xC{ z#rJjJ-rCx#uVU+7wN|TAu~w^4t5qvnTx!Lw?n?4L&pmSvH!A(u7yI_F{}UkmzUR!G zJM+xUGxN-O&bf;g%&knPk}GTLtIN~LQKSo@luSw{2_X*r>k9wGj|D;(0RPh*y)WsI zf=zjze>8E;bBTYVH{T!OUf~@Qf6}dXHpQ>9%AE?RP(R87$^yy)$^!p?TOe5ATKTD*GE-WxdPTor&d7|b@i3i<@GJe zQ&TO&vkOF@mQJln^8#ojRk^ALmq-qdRv(hAPpwSVr)nxv4SXvGSFRYshGf<=E5BPe zdR8jCPv!L$@HgIRZ=hiBuFP5q4OP_EufVNeUSD2QadN6*c(PpXVMQ%$=bAL%+xpt- z!gOu2s7uU>B`!gWNJ;SqA{IX@s&F`9E~P^z<9ZP18tpzJ9H9XGy1F?>+D_m zE7pt_k5*m@8&^{+%({0){AgG?yUjI?)yu);r!-dA;lY75_y?zJ_qFD})=k$&JNvbd z%~u@Dz5Cnga0-*-5#?(&)>N);Ov!^kJh^wx;o04&sU4ogoq#i7VBE|f*NqZ5D^qJ4 zR#(CER#|ytx;D#YaPZ5D;Fo(wSKG=z@IZRj>g2%V_;(eStSikS{ z2CtINCkb9fRaME_mDx)M*G}?-17~yYHB%^zQC^Q1TCz4;5VrK+dbZ)l!ZswASJf`> ze6D5Bn1&QwAm9tcVk;YKD)3Tiz3pET%F0ZYvN|o2}djqU7s{uZhA-ch4&U2GLkqUCs%cPm zo5alp-TY#>=*0B}qPn`gW<_+r!rL%!6?{KlocKB1iNK1^5=lDJ5AW3ne!b9qCuqCN$ca4#J) zDE5y|13X$~Mfe+|WE&slf&;0QKWrEc8l#Jm`vWQQp>pr}G9Mf+ zJniJ3CuqAIb0G3hK2?I=AM=G zl?9Xqlm(Oplm(Oplm(Oplm(Oplm(Oplm))i7KqUp@v#4&`kxc{tA3OPlm(Oplm(Op zlm(Oplm(Oplm(Oplm(Oplm)(n7Kp_>2mAj?m+#=tsa;YQP!>=YP!>=YP!>=YP!>=Y zP!>=YP!>=YP!^CD!2W*~|4VZyp)8;*pe&#)pe&#)pe&#)pe&#)pe&#)pe*p6w15G7 zMg324-@5vr`p^5@{YU)={5$=d{2%$3`akeD_^tkGf0ciNzsR5EAMTIwhxz^eeBbk9 ziBA%5CtgWBmv}Pqo5a0|TN2kLu1fr8Vq;=`qA_u5qC9a-;>g4iiLr@-#KDOJ5?vCe z_nG&$x66CRd(6AvyWP9qyVBe2ZSY#WYOmZ|;?437^NPHIUN6t{sQZEYSN8??N%vv* z7w%2&58VshjqW;kwVQI6x<|T`-7)SEx3AmPHR7MdcgJ6f{~^9D{>%8!Vwc&A?OFCjdz3xc?rZ1Smi4*yuJx+*ob`L_Ve4+|X6wh+ zW!CxDnO2)sYc02qvF2KntrBaPm9%0pgiFvLWn$2dF zS#BO}&M_yMW6YuEfo6BpG1`rHjaQB5jNcm%8+RKw8P^yW8|N5-(PW%rlo^YSS;jy$J3gx3e%`*4)9uo)O%J zH9aD@mE6X{?h)KVZe?LN4jqfhEiBB7YPOIYS=cp#8_C-&>>^sjbTa0$7Ng?0oNlS^1=MQ{!X@hve5}nL!bu9Y-GW@ zXsM87SkM+hKtdL@MsPX_oM0^g)r|GzbUSE4O^tQUELamQvw@sz2hF^^ZJix7aoO5t z2aQ~=ZMB0mmo01Upn=Qg7CWftvaQ(;R&&`3vX0BOAZxj7X?B7dl$y~@HrPQmFK&j9 ztBNFi+^La-k2@uj@Ntz~HnyqWhJ4kW4sL>8qM5T-D zpn}US3t6xnB)0J!dEO4ndHKxQc5o7x^QYTE8JF|s+rf!k&YfomCvZ7ut{oiD|K(L!Ks2*}+U+KDxvXW^g%bv>i<6vUrpoOyjbs*bb(0IkLzOj^MIz zq#aD*azr5uCWC~Q>FD;IESMBQF&%9MhXXj&7-0nyadF4!u!6%-**d5RplqtcF&&R( zG*)+u9URJKZJiyAU1bNQQE9aulyG?ql#bzY9yB(Z%eimb!6+{Ah!%5+ zN3g}7I3!?HS{zVaJR0}({VBG)-}4Gp3Va9)-}2g*Wqp**ICB` z?$$NhiZ$G=YjiDbWdV2V2)ipYU&SU}i z>KdI!XRv^Kb&XD?(^$a0y4HR_oyr34)ipYa&SnAk>KdKczKsRkt84As=tLH9uddNj zI*A3`t7~)wg;P8RheOjeItmX0oa$K^O-HeSdv&OxqglYcx<-qkhC6kQ4yVN|>(gQN zkxNXA{dMhoCGxJSp&D`4o*@#BWm5eywV>Y;Q9;~pK(1v-?WL&x&LP>vp* z>){MtI_g1mFhidXy8Q(@h@n%*@`Gq8L$8jyfBPPWZXI=!_GjqVarr$o$r?sm{>x zqrQNg&(QN{%i(DTT|es0P=16g-^`{&`aJmD$&o&XoX5EH*T}h0j=mq)J%^mj(EFp_ zM9yL8|3P;&!!r@y0kC`{JQCqO0Ck3(;owa`qZ&Eu9ELXmEIyN*#qchGdILF=;cWo* z8DxWl_W@iyXk&OEz+!wZ%H9YxZ5=!o@f(2#Pe1%VpuxirzYS~`X>5pZamL?l*bo@}AUTdS*N7Z7Z z%Li4`VWYnXRa#@Cv&XVh8$CU!M@wvU^Pq}%+34dzjViX$!DHDd8@)TIf?^w8JE~na z`gKqv3v6`ipx};=9vy27ZFJ|Lb`{v@%R%)ivC)x(>ea_aFAgfdmyIqQRF8Zc9XKes z9isOJ)vddYt{YTdHyiyns4jUnI&DyX7aKh`D7Y!2y9O(G7P@LUxE+>ZKmj#7EIlg> z(~4z9%&@eq&|(%mKtsQOI81~8sQ(TIYQF!D|E&Ln)5ocGE@hk8eePCoNn&Z@MSqI_ zuz!#LGyhuuGI%3!hTr1X_!a)K{*iu9ztoTW!~On=H{kmNK1ke`*pb+l*qk^s(VD1B z%!BCvR(LCLh4)5c1iS|rkRaX%iG1(HMBLkvSm8b5-Q(Tt)q2->7kc0K+PvlN)!t0^ z0&l$6-z)ICdzQP~-R^$s?sV6<8{AXf1@4LNRJRn~2Ml$)Ku8m@k#Mf@k8P+>wu{Lx9mOk4Es6zF}s+pV->8HO3=+gaz7JByrYPN_50Ax^@6&)#J}Z9ih)ZC?ek z{`q!+9on<)v^~LIX&-AHVqIYOu$k3vy$4bLpRC8N8?39WUt4!TgkNK=g$Vy>YqmAf zDz*-=G%I0#Xue_YFkjTpjctrwtM4#$;~I0Dd7pWaxyd}utT$Jh%glM^XmhaH$LwP2 z#>d8+#!h1^ME6^ai;PXiX-2)#&sb?JGv*mnj1pt0zEl5Le^b9(pR8}uuYs6;lfF)0 zt&h=D`Vf7o-d8_T?;86k_D1aa*yFMLW4G$DSZl00c5$o>qWL3YV`771y|kBNacz(G zwzgfnQ@dV!7-ISL+G?#rTcXV{CmBx}UG%N+F5((IVRC*vGwW6i0&y<=o}}B!9W-AO zM7p%6BnWkB4@nT~((ZyV*rnYB5r}qao+Jo&Y1bUuLc2(UfS3A`ZX`F+gd_-gsV50y zUg}DMpqIuaT}Q5`OcI2>)R6>nFSR8>;7ct@5cyJ55`@0gkOZ+W)g?jj3s14aW)S_t z^&1g}zmy6h5dTslr~?9E(&6Hg5Jy4;Oxi^ife@JNkpwX?`CJk>d-55s2F?wY+nULHLbo=PcZFWt zOx_W?rJ1}fHVGnP@|LJ*UPpFI-A3M&x|RG*>b2wzp&MJsU*#@9TufdUH5=EH*Mwf! zL0%PlQ6qUp=!J{OE}^$9B!7`DL3B)B7Bw?xlb3{^Kb`Cpdft5UqR?~akv|JPXD)d` z=-G3~^Fq&>O`a2a<}C6jp=Zn_&k8+#26;y4Y17Fcg}!$hd0ObH3&|gZo;;Q85PH&N zvR&>0M9btUQ8RH8c~a;J6Upy|9zTIRA@rf+$>TzgJCyuR=&|F-V?vjXCBGHAq?9}= zKMlfWvQ5;CE+LNyJ!&-hjnKuT$X20?ipj%5k1Qe&30*jnJSg;tLh@_z2^c<;2P8rK z4Da#aG{#pdBj8;fp3#nGbikWBl(wdKz}q;GHbr@F)FuSc&V?gSJjeV2wh!8 z?h?AHn%pV$Emh;$IQF*+4afc#sb`a)3q4~t`I*o&XONqPhVx*H z&~MHlHwoQ+9=TEKp5(uU?%AE(AasxJ~7%xg>|ylZz!mq)jf81fe#$P!h!2D#y~!DpAowO3Nf3RLP!fdSB#;F0H#t4KLvSO(2z8VfsB8uWNp(y zLe|s`5bHsF-(M7`*CwTGIZ()wj($Rx*7Oy!q_mHakCyZnvUt}4LXIl#C1mla?+IB@ zoG<0BoHNhHb9c+EiaxG%M zZ@zB6U_NC&WZq$3XI^GpVx9y2{&nUG^H_7PIoT|M-2Q%Mo@pAN8gCgd8&4aL8uuBu z7(e;W_W$cUv~iG}Wc;)J|F9nPY(tOuH`xEL@1Qqj+x$Q3|Bqcu64`zKivEAtmj2ng ze@XwpzJolN-P8T<|A&@ZFstC--2ZRhM<2({_`3c7MnB@>@?Y-%hnCL38tnf!c0wWu z=>MDk|ER0ZX#bS@^eXYj_J7kuFV_7g`{eO0^q2eH#V6%vDxgh8IUWzOUp$pXQm2!? z^X1>Z6FMkGPq5McC_RjdPEMo!ZaRmGPF17*MtUs$0d8B=<)3|7&JzKmZP;h;f2SV* zoJ@e+-v9mS@$b~jpVRxF+vC5_-u!*^{(p15{K7hzb09m3JNMi}dVNm+Kkvtvi*`-` z_+qzyZaP5cPX5km1Yhj(&+X*zoB;4mb@}IX@?*X|*_`cDgL_G*PJZ71FKmQ4Agcdg z_5Wux^Hu-9>i@?Ke%1dE8TYFHAN%3q8S($!{{L)tfELs2Y@UGAp)q$W*`^I3?%nPd z_Zs&icawXXHruUtSGvpGdF~W#qFVyl_Wj&E*NlG}e=GiS{OS0k@%!Sp#D5aM>?_CF zcaNLwWA+Anfjt4E>~Ce)LcaY**2=13jQs+31RKK!v0g0h>~Y?9UUIfO4?8}Lv8T@U z&gIUz&U&YBZloP6gIN2pPSJLu-QVtJePkQfU#%AF=P;iBch--rUpcF-3$3%PQ=JOy z1ZRmgAI8#;a%NZqt$fIjxZV8B+-<%HW9c6>e`;Q8o^7r(Yhe`qBK;t9nmNSujCUbB z{&2lW?_rjjy^RjzFUB8?-x$9z{@b|HIL7#a(EwxVQ^t50N8i;LVH{}0^bfSV^w;#? z>No5Ep*L#-w2|5#{aO8&`VaLDdZm6GjGq4__P5vzu_yfZoN>;u*aMJle=B6$SI4e} zjDMMN5Bc_QYcD~*{Z{P;?FtwDkliB|=ZX zmx_EygPs7-!6F~hpvTh_sK|!|f;y28Y0%@KR^&q(bQ!S7hcxI?x{SZ-63^&-TXJYH zynx=Iq{kkXZs z`a!Url5_y=M^{MdLl2-8l6uoVbh)ISv^Oo!p?>rvNnL19S|+JG?LtqK)RlIpCrHYp zUFq?Xy3#y)oTM&Lw@i``tt^$4pcXw=k_U8*Bp2vtNpYYhk}MjhizP9rTO`S%j4qVq zK-~gK7Tn{Gl4L{Od`T9y={!j$)Ey~Fqb8l3LzK>u1ouojTM&%zp>&p{-5rbROi8bk z-E@Yem&ogMx};ahOLUqf$TFoy47Mo$R2+lAa>lX_2HSVabt_9*0&6B|QONszB1?ixzNFqU@8|nND|EDLJyGyley4=l3+F$daxvz&V?Q% z3FdR510=zOF0{WSn9+qMCBc*~^gu~4rwi>T2_|)+eI>!HF0_v%nAU~%mIU*<&;ta0 z)j$+chZA#<*Ck~K5lPAp$dQyCFeB-U6ORSo)JQq;xRX0JuG72joN;oU#=Yf?nd>wLE@%AQ zH#QzFcOV{)jT0NauTgKgBjdiYciTDR<^I`mZvSX(o7k5xj%FhyHXPs0@9$_f(f+sX zUAcPa{m+kVi;i_Hcbpx*+l}0Ac5;|~j$HfpzwkC7Kq#jF1^NGU7);a#{r_+Iuc-O| z)ck*H{y&_`z#wY=KZC$b2x|U6m@QDv|EK2vQ}h2}=SED;|EK2vQ}h3+`Tw%B$f^1N z@L5sK|A%)5HUD4M3g4Rf|1cGxL&Nd^jjTKI|L(sE+5eCEzw&Q^dHyc(&-PD)dH+}V z$N01T!~Bu{K`_^!llUBF{C_#|hr}a^dlFj`KZM!-&q}OKR3}bKEJ{pIj7tnp9GJ*U z7~aR;-@HG2PkIk|cX~I#EdS?w8L!E!^p5xDdsDnI-e9k{mw-9`-*;bg|K$G8z2Cjn z{fT?2dycyvat~7OvF;pqqFdx1?B=^Hz9;@p{4bDquq}RX{O0&I@eAYMkGIBa;^pzh z@fq<$VRrvyyc=X4e9CsSooqYX%6`FaWLL8b*jcQN)w1Pm37g5rvjWy1c2)f-3n&XH z3n&XH3n&XH3n&YG!2&FUvuR`cA#@S33{I$xFT^2*jb(67Z2%mW#40fc81TgZOoo-XZSSR#`Jqx2Itwvy@Cxi2E zP~m&LGR+B-{v)GRQJG4>tm246+PP#f|I3xFeRq*|@O=@&#E2C*%NZUo2{WEq^78v&99S%!awLN-^BW%ySpWK#uMhJS@Zc7PMhz~b3;AW@KIU=t$% zGm{7)O^}7_Wz9Gio)!Tl39@jVtQpI~wg@0akcF)gK!zX-*G2#df-G!_0P+J_xF!Pp ziiNN_0>}CTBm}asQq~;A!jmI_bU+rak~ITZxH1Ar24o@s zb_Ih%E+7l}*DDy1`qz(z{QDIQ$oxBih0CLLAQ6y-SJL+1d!Cn!V$9O2^J2I08;w? E2ZDf)od5s; literal 69632 zcmeHw2bdK_+I3a;?c3GWi8PD^We~|>NRmX7MHK2__#m2zrTJp4>O$8)m^vybl=nc zcD+@VH)UdRbuw#Vd1Yy$I;#(AgfK?Kva%3D7W_7bU-4rChzSD!<1~CYXrUhKo4EdT zuBl2Eu^SyMgIgV2u}5lqh6;*{V5{onBB-LGvbZ!+S(mjmS=TMK zLiB0X$y2KN0;nWeyr_(=k<}^MymMA%a$&MESyq^=;(O7lxTrJjymk#WGn+QWXC_m} zl&CC(e|CeTfdz*SW!yrjsIa`Ui0%D?%0yY=;$&5~tb{zm!gAQpQ>yvVR+g7$SC?m% z6qltAv#hGPy12Y7D|t$?u%7w$hcsXfRkU)3;sB#HzsCgn>P5&WvL5=$y2bPxVlOLtHDJh7a!)*VdaOvZrHb^ z;YGr}MI|L!622hd1!A>@HD!hDt5U^|p|E^;vNEwK*&{0Dw(xUbNO0-&vf`Pe+8plEk(WS0 zF9~VTB5r23YK229bs0rxEBeoDa0!Vs`2#j8ib|7-s+vk+BtK@7Vv9pADZV372^>y$ zGoCrT)!_~&TCbp21V~`U%*>1o9Pn=h^Ck9cG&l&hcv^#-*>hc zcD$nz&paXHa68U+p`aK08(%wHOYQQREmG>2g6Xqgy}>TCC5Lu-Lb9&=@zC2g<;m_nhxV6OaEJ@-VXe-o)#2-Xa(Pu%G2B+! zZL4aKSen&waiWU3P*umQ1?A->$wV2yH^3I-s^Bw=tFlu2@cp~e&_i`p)v-DjE$^7x zCvkIOHa``BN||Gc z`g>G>Cal)dL=peGvy+lmW9{PP_p0M}qmhZ~WOZ?AGK;N}+L_daRhy{FDy&RGA>2zl zchrCEGGJ$`pcnhgqGZAz<$^ZY%mL97)$>(I_F>?QWULGVxgCI+Q$WWWHd*|y`EMfs6ZoTklmbctrGQdEDWDWk3Md7X0!jg;fKosy zpcMFJD&T1P(Y{FW=;dR*Rc4d~(7(sP9=$0XmFB;X{P+C>zsv(vd#w~u3Md7X0!jg; zfKosypcGIFCEo`i1w(_>I&uOqIqrnGmR6y5Vi?#owM!&*?Q-`D!PzopolmbctrGQdEDWDWk3Md7X0!jg;fK-6B z|Eu_4%AtT#Kq;UUPzopolmbctrGQdEDWDWk3Md7Bl?oVeRM`KT9b04nYyVaMdH)Ij zL4T`%vww|$sehh-hQHEZ<|qBx{uF<-KiKc>ck++%oB0&R3;e>{@4e;i@}BXwdG~m? zde?bZco%qQd8c~S-V$%VH{F})jqnC|J-p+*mR=(-?*7aD+pt&3?%wa-;cjxT zbkBFsaO>PsH{nim$GIoEz1XV`UishzN=+2iaJ?Ot|!yM^u9 z8vTNPNcYm6^ig^b-Aw;TFQV(|YFa@H>2x}Q4x`7@uCy&}PHl>;1J;Mu>(&d_cI$rY zc59P$rFDUIrnSmiW-YX4TPItitU*?`b)410@-2gWLp~$#keA6*0A^*O-@>=a`{cYnGS^bE-Mc z9BlS7JD9D^My3(}D*kc&?f9U?c+zq zZR4QvnemSCvhjrRSL0UW8sj2kow3R&Hx?MvjIqWEMh~N{(abPoU&lU*y%Bpp_Gs+x z*iEskVi&~Dh@BE!8k-lJ5*r!oAL|l3I@Tzr>tE;}=&$NK^oR62^-cQa`nh_bSL=)P zS^6Y>nBGV4ptsN+tzP?F+o$c)p49HwZqxpxU97FwR%;bnA);Y-16!JAppVf9H0&0U zkI=_7>>82%=tCNIiO7fOBN}#=WIqi%MdSmvTE~dI&uZ%sk@wL1G;AM{chP$^JT4;p z(7QBj7m;_+J{leyk=M~XG;AA@H_+=e%#tN<(6CKJ_Og;=BJwiYOT(ih@*3Jj!`2ab z6}?8oRuS2aUZr8ni0ncy)38NE_ONY^lB?~ZVWwPdHw~Ld#5!2YQxW4c>+C2LlTkg>~qYBJcb^nVLT#_vJxXA+t5Qaj78*Ow2g*(L>@wq z(NK%XgXm!z;)whWJxD_ok*)Q4G*}and(q!$aB4*ELif^Obwut)chO*#EV-KoDY*`!E(9H?KG%~$Zc$!YEJ6&&}}rRib`%q zU(%p5BG=bf(O_9bZe*k)A~&ELX;3a!LpRW%EF$ls12ia&$ff9d8k9uj5_BmImPX`a zbO{ZXMC2lLF%60%av{2i28$zd0lJU|iz0G9x_|}?BeK4JB@L1hS%=ospoo+DDs(;# z3Zs&9(9JYh5S6S%>#QKb2sY5#bu^eCtreoRG$@EjfI=F~i^ypxu!6Zjux6}5r;%U| z!`N6|OM}_bI%lABNHB{pU$vS9GdWzjiUj!_u2@Ne864KFAi;DFYwJibjl)&7B$&$K zN`QGBt^kq#(%!@L?2jEun{!DplEbm%NpKQ}lSh+a z1cy13Nidwli8&+~#^HpCBpAx!_z5H!!r?fm;Y1F{jw8Wf4#$io!3i9W9z%jb9FBsI z7|7v$qew7-!;!E%{W%;yk_7!Y95$Q;$8$Jj7zz4vICuyN`fzx{U=sA^aL@@P=*8i{ zK_uwO;edf8$mX#B021`zuwQ=?u-gP&3w`^MfZZej_UTIkc8dVmyAKK24FX`V-XvhR z2Y@|$k$~MC0A}|j0lPH-?4C^nabtLaSw(bXzyl8MG6QyFuwW%KV0Q&<8TDqsZV8yg z0I~Z4!7(^l>}G(mUa^G)>{bAbDQ<=(9w{1~^O+`2W~ z1h=Bxxiw6&L%DHlxDo6c_iYXPxDn;Ht%1Tm<*uz^7nXC=*06(J$~{}dHq^r{o7HJk z?$~TwN>hfdY1qVsa>HiWqTH`F9LFZ*cC8&eh~t#GwFc|Mo>J!445Rwil-sq2>M^3s ztzmh6YgA8Xa*W==n#KGYwnT4ju$W&nxOW3(ehq7*H_%&@88*YOqc=Gko1T^lI#Y_@n0+D(~fGrSufyqIZoyoWN^W|%!fG23SNC3u2jzRmE9=q1XH zTSG6P7b){?w)}Z`u;Rwe@NUX{n=Rjoo~O*X8Ga7!w3u_V-FapMWzNkO??BH{X59?4 z=O*Ue9PgmaycuSXP|UpZViNTtZ5t2HY;V|npO=rA8z1UCwoqt^0g6YkSm{w%_6TH91WxJzsKrG%TbHZ6~EkJhG6BQ1D+d9w(2 zXf1at;Rda>tS8)`wN|-=+q3p=OTyh*YuUt6V4~d?!D)|>fP(zYCUe< z=RN1udDnVn-eul--sxVUH_6NQMp(Cc{k?AP0smI7t=HVMJ>-4}QU43>3U|AEJG=?l zsf2ljw9psL2ySrI#+->T7BW>n-vZ>saer>oV(HE3}qdORRYi<@d01t=0S!pgbXPG0+er6Z*nD`TB6EhzF8lwES;xEN7i*Jcv7hh<6 z6rUB(iI0f)i+7G6ZM+$8Y`hpZjP1r(5ZklJ{w!mKQDPJrw-_Cb7RE`29s4%+N$gFC z>K}=1jV+5^9y=$tCRP|56B`)Av91u&`})oLKlFY2cKv1jUi~Wle0{B61JV3UeS&_X zo~=Etx7C~KhW4fQf%b~_fOZ?{tzDy?sa>emX~VUp+FWfiymQzPyPgbV7Y%absY}~H z5ar^Ql0le@TS&GW?ZHP$whQgXnUaA``MmULF`EeAuo0%1MkL;WZ>A?mJFg^OeKS`7h94++>4215cpzKGKhR} zTrvoKu^}16zBncs1i$c@CG>>o7i*G1_=~Y*5dY#s%0mE*>SY;3!04c45CWrbCA$mV zjs7JW1i|PV$sh_wUrPpIF#1X|h=bApNCts0`ckqj=yvplWDp9Y1A?(w82wW)1i>)+ zhhz{9qt7LSa2S0i8N|crQ^~GJH=s`>gNPV?EE$Bv=p)G>CPw=ugP<7wT{4J@(T9>j zSd2c94B}$+zGM&>qxU3($QZpV8HC1YpI|IDM(+rQAUH;E3kF-QLvK0kY(kI;;W2ts z6d{O@(HoLMfQ(+33?gK-mu&_&a~2|_*K9_hj@7m3Rk<1j$>b69u)a(X$e-LOUd0 ziJp;o1$tWGnmY89JOl`u(UYQN-5T_SEP=2YZ5JiE_2_Yd^J>sz0_WzTM+M%Tiyjep z?0B?I;K`%W!vg0_Mh^)*F$X;;@PvuzZvu~>fF2Nd+<0`qz+=auzY08NEV@tN(PPlP z0*@Mv?h*LDQRr@gN9Llt1Rg#TZ54RfaP${}hYUk^3OslSxWyv?xMwePy};Q$(I$br zXQMyMvjmYfx=t__TBBd`eRNVHfz`Xhq`GkA4WC&bt24+580 zpsNHfDM!B-xU>XaDR4mn_67o8Ppsn?@3qctG%M!yj_e;ztR;<@Pb=$xd2Z?sl2h`v!M8HC>` zkPPB)bXv-+z^zyWz9x<2pDGE&;HxEpAbgc15QVRl1j6tYl0Y0@CkX`Nr=*c7wUR(6 zzFZI%i`TG2(!ls27_SxuEE=zp1j6x3Ngy6yCJ6-O6@svcyj&0#l9x#WF?p#VEGRD# zghk~`C4sPfi6ju07fS+x`C>^RGG8PKgyst+f!I7L2?Xawl0bA`C<%n;3nYQ~JdsAM z`I10{ULXhy(dP-mV)VI^K#)F15EiA+7KDZAvjkyr`b;4;1nT(`@7*v%t^}d_bV(pq zpC$r*9xXgyDEziUITEP!}@iX;%QpPWYaOqK+m1#={Un0=BY5VTK}1Rew@NCIK| z_%yP6oFovqkCg->_c3zUp4l*37C`JiN)mVg9GOOToFoZ6433Zl;`iZ_Kmb2X5{Tf3 zO7a+be2DzaZ5vLM1rWp!PJOB2@-E+af`Db^gV^eD8d)4aP+$n;2Z&WFH})5>qOzZW z%QhY_V8yb&0+v&Tvby`aS^JcXWFn8)v0=BHr6tGopa{=FNnIT}yR?XxoeRe}r zu{g7369F?a8_N?7!A2voylI9nV3VevfQ_5D0yb*w2q zw*SL+m!-BlqV|8NWh$%Xzt{dxa`bhn{rp+l|BWfgOC8@2wExFX!);S#|CRRt_-W|j z)Nvni`#;oF$Fd53=Jx;iWyoQ>@%{FHl7p{fC9M4)@CJre`#-$q2Oa)@ZU2XdiN)V< z|A(cyn797}9P@9q|HJA>!SO5hkM5VdTdjCJ`Z}ac+Z}R`(|IRi|d=*uicN`x7?T9Cm`E@ zi+i1WnR~7qy35@q?mRcw9SwQ^J=|m64A*kLbv|?UIlG;wod=yeo$H+|oO7MioJuF@ z%y7mzCqSS7Hcn$lw-4CwL683@?Z4W$*jL*dp}&8vKEN)vXWNtPq0ryIotlI_+&ax=M#oNsj{Ye@}RM4CdT{{-mQpH12l!~D|xzWVi_?r*3P`ptTX4D!_D?)Cgf~<6aOgwXT3}O_4vc_=i+z9 z&yHUj&x@~)m&X(FQMv~i|E}?(@tAS9@pinwj^eG1dg$4Iqj9Bio)H*TM$(vJ9c_$* z9{t^oHb&#vlZGC9FZS2if!OZY)v=Axo4+xEx>nP{ZK--~n!3M!_XFLR-oyI!h^@1JTh5O=lfx{eNJjuF1S*~QAaTh!#jrGJQOV$b2oGjU~ zxC73StOI;)CP{W2^rf9BSv!0jo*>yVxE&s!#ya3}k{yMQ!DA(BjgP`(Bx`|NfL!=0VxaBiVHqYG8o1Mx04LUalyw*1_Qa^wvs`vGR~3= zl9h2A$zUuOe2io;mJFO2~gFbC({oxw)Z%gYN_aE!)mfp#Z^?MULeYkhn>HXaPWBu0B`ojIYeb|0& z{nl9K{P*8Yi?fp6V=uioVS2j+@6ss@ykX~0-qgiDBYM4^O4R2sy`#@KY)`iaZ@1HW zy`}ebJCe8C4c=&{^>kys+aUiRcZWgRp#A@2f1euvPmTYl#{V-?qtU4G|J3+@YWzQg z-R9Kze`@?cK7N#;#{WwVyrjneOC_SH@&8zkgBt%&jsK^{|5M}t{j$dYV*>@$Ypnl& zf!!1N|M1_1?Ek0y2mCEC*54)m+5V|8?thU#!yoUT==b#7`OSUH`o7V=~`wjPh+`q#p|IfMG+^y~nFvkCH-L=rxU+m6t zC%Yru|KFV3o$H)SoeeO)KWpjFcP2PPoLLCc#lF_Q#6H_T)m~;7*)#0%Ft&fsBX1Q}H7Es?0!jg;fKosypcGIFCTHz%FwS~gTS5SBYg*Rs*NSqY>L(zR^BZXhgk zkgjEecC!*l9HeX6R0@nh-XLAeCR1Pp(gx{TT@Y3R3sWVKG)UL-=@e2qgLEyQP9c>t zNWwL+8cc>U`aT-+=@bk|5KPc;HQ$`&3tC}nLItQ}&So0&2^9>MEl9%^(T6~?APwsx z0=a@T5Mlxu)T&~uMhJ0EDBQ;e+n3`7s)POWW8uFPH49Ew|q#>VL z!C*;(G~{zDFp@(N|5IBa0e0~K3@&QNFu!OJ1(gSI@ zG};ET18KNKk|8uKjtJxi(r~dX8BW7R5rND=8ZMM2!)VCoSuh|?6iWC+3kHU?KpGZC zt3g&E4HrZNk^*U%hzR5a(vVNKU_c7s02&rVC6EzFLq6St0a>QU({OH70{MV6U9Of*Lq79@ zfgul&hJ5M;1JeG0@VOUQ1zl;FntTB&fFwW~^63{?Nf#RO`4<@JNyEue)sO*5Lp}q8 z0V#eRXvn8vFfio*(Qsn48l?Zxa6&{N`;Ug>BLc~PG#n=hR4_Iokorf%F%f~xKN^mX z2qgZ|kk7?nK%QR<8uG~);0uRa((oi%0hEl02qgW{aCk%@=Z}VbLIwjv${!7fMkSE( zM?*d(1G~tZ(eT8m1oHi8$Y*6>B~59_re$E?SjhIH;h=~>vL6ixMg(&GXvn5#V5>o@ z9}U?A4U9mh9}U?%c3#qmhHRP!Mj+3RhHRz=Mj*|PhHSD1PAnR-`5G8uQ+tGL$_5}T d$B%|=)&@o(#gB&B5rGUp8nU??SP3Nf{U0O None: + """ + Check if the radius is a valid type (int or float). + Utility function to also check for numpy based types : + + - np.float(32|64) + - np.int(8|16|32|64) + + Parameters: + ----------- + radius : float | int + The radius to check. + + Raises: + ------- + TypeError + If the radius is not an int or float. + """ + if not isinstance( + radius, + (int, float, np.float32, np.float64, np.int8, np.int16, np.int32, np.int64), + ): + raise TypeError( + f"Radius must be an int or float. Got {radius} of type {type(radius)}." + ) + + class RadiusClustering(ClusterMixin, BaseEstimator): r""" Radius Clustering algorithm. @@ -206,10 +233,7 @@ def fit( dist_mat = self.X_checked_ self.dist_mat_ = dist_mat - if not isinstance(self.radius, (float, int)): - raise ValueError( - f"Radius must be an int or float.\n Got {self.radius} of type {type(self.radius)}." - ) + check_radius_type(self.radius) if self.radius <= 0: raise ValueError( f"Radius must be a positive int or float.\nGot {self.radius}." diff --git a/tests/test_unit.py b/tests/test_unit.py index 85385b6..50d0c6a 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -102,7 +102,10 @@ def test_radius_clustering_invalid_radius(): with pytest.raises(ValueError, match="Radius must be a positive int or float."): RadiusClustering(manner="approx", radius=0.0).fit([[0, 1], [1, 0], [2, 1]]) - with pytest.raises(ValueError, match="Radius must be an int or float."): + with pytest.raises( + TypeError, + match="Radius must be an int or float. Got invalid of type .", + ): RadiusClustering(manner="exact", radius="invalid").fit([[0, 1], [1, 0], [2, 1]])