1313// / \brief a task to study matching MFT-[MCH-MID] in MC
1414// / \author daiki.sekihata@cern.ch
1515
16+ #include " PWGEM/Dilepton/Utils/MlResponseFwdTrack.h"
17+
1618#include " Common/CCDB/EventSelectionParams.h"
1719#include " Common/CCDB/RCTSelectionFlags.h"
20+ #include " Common/Core/RecoDecay.h"
1821#include " Common/Core/fwdtrackUtilities.h"
1922#include " Common/DataModel/Centrality.h"
2023#include " Common/DataModel/CollisionAssociationTables.h"
2124#include " Common/DataModel/EventSelection.h"
2225#include " Common/DataModel/Multiplicity.h"
26+ #include " Tools/ML/MlResponse.h"
2327
2428#include < CCDB/BasicCCDBManager.h>
2529#include < DataFormatsParameters/GRPMagField.h>
@@ -111,6 +115,18 @@ struct matchingMFT {
111115 Configurable<float > matchingZ{" matchingZ" , -77.5 , " z position where matching is performed" };
112116 Configurable<bool > cfgApplyPreselectionInBestMatch{" cfgApplyPreselectionInBestMatch" , false , " flag to apply preselection in find best match function" };
113117
118+ // configuration for matching with ML
119+ Configurable<bool > useMLmatching{" useMLmatching" , false , " Flag to use ML for matching between MFT and MCH-MID" };
120+ Configurable<std::vector<std::string>> onnxFileNames{" onnxFileNames" , std::vector<std::string>{" filename" }, " ONNX file names for each bin (if not from CCDB full path)" };
121+ Configurable<std::vector<std::string>> onnxPathsCCDB{" onnxPathsCCDB" , std::vector<std::string>{" path" }, " Paths of models on CCDB" };
122+ Configurable<std::vector<double >> binsMl{" binsMl" , std::vector<double >{0.1 , 0.15 , 0.2 , 0.25 , 0.4 , 0.8 , 1.6 , 2.0 , 20 }, " Bin limits for ML application" };
123+ Configurable<std::vector<double >> cutsMl{" cutsMl" , std::vector<double >{0.95 , 0.95 , 0.7 , 0.7 , 0.8 , 0.8 , 0.7 , 0.7 }, " ML cuts per bin" };
124+ Configurable<std::vector<std::string>> namesInputFeatures{" namesInputFeatures" , std::vector<std::string>{" multFT0C" , " ptMCHMID" , " rSigned1Pt" , " dEta" , " dPhi" , " dX" , " dY" , " chi2MatchMCHMFT" }, " Names of ML model input features" };
125+ Configurable<std::string> nameBinningFeature{" nameBinningFeature" , " multFT0C" , " Names of ML model binning feature" };
126+ Configurable<int64_t > timestampCCDB{" timestampCCDB" , -1 , " timestamp of the ONNX file for ML model used to query in CCDB. Exceptions: > 0 for the specific timestamp, 0 gets the run dependent timestamp" };
127+ Configurable<bool > loadModelsFromCCDB{" loadModelsFromCCDB" , false , " Flag to enable or disable the loading of models from CCDB" };
128+ Configurable<bool > enableOptimizations{" enableOptimizations" , false , " Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)" };
129+
114130 struct : ConfigurableGroup {
115131 std::string prefix = " eventcut_group" ;
116132 Configurable<float > cfgZvtxMin{" cfgZvtxMin" , -10 .f , " min. Zvtx" };
@@ -151,10 +167,28 @@ struct matchingMFT {
151167 } eventcuts;
152168
153169 o2::aod::rctsel::RCTFlagsChecker rctChecker;
170+ o2::analysis::MlResponseFwdTrack<float > mlResponseFwdTrack;
154171
155172 HistogramRegistry fRegistry {" fRegistry" };
156173 static constexpr std::string_view muon_types[5 ] = {" MFTMCHMID/" , " MFTMCHMIDOtherMatch/" , " MFTMCH/" , " MCHMID/" , " MCH/" };
157174
175+ struct matchedCandidate {
176+ float multFT0C{0 };
177+ float multMFT{0 };
178+ float ptMCHMID{0 };
179+ float rSigned1Pt{1e+10 };
180+ float dEta{1e+10 };
181+ float dPhi{1e+10 };
182+ float dX{1e+10 };
183+ float dY{1e+10 };
184+ float chi2MatchMCHMFT{1e+10 };
185+
186+ float sigmaPhiMFT{1e+10 };
187+ float sigmaTglMFT{1e+10 };
188+ float sigmaPhiMCHMID{1e+10 };
189+ float sigmaTglMCHMID{1e+10 };
190+ };
191+
158192 void init (o2::framework::InitContext&)
159193 {
160194 if (doprocessWithoutFTTCA && doprocessWithFTTCA) {
@@ -169,6 +203,31 @@ struct matchingMFT {
169203 rctChecker.init (eventcuts.cfgRCTLabel .value , eventcuts.cfgCheckZDC .value , eventcuts.cfgTreatLimitedAcceptanceAsBad .value );
170204
171205 addHistograms ();
206+
207+ if (useMLmatching) {
208+ static constexpr int nClassesMl = 2 ;
209+ const std::vector<int > cutDirMl = {o2::cuts_ml::CutNot, o2::cuts_ml::CutSmaller};
210+ const std::vector<std::string> labelsClasses = {" Background" , " Signal" };
211+ const uint32_t nBinsMl = binsMl.value .size () - 1 ;
212+ const std::vector<std::string> labelsBins (nBinsMl, " bin" );
213+ double cutsMlArr[nBinsMl][nClassesMl];
214+ for (uint32_t i = 0 ; i < nBinsMl; i++) {
215+ cutsMlArr[i][0 ] = 0.0 ;
216+ cutsMlArr[i][1 ] = cutsMl.value [i];
217+ }
218+ o2::framework::LabeledArray<double > cutsMl = {cutsMlArr[0 ], nBinsMl, nClassesMl, labelsBins, labelsClasses};
219+
220+ mlResponseFwdTrack.configure (binsMl.value , cutsMl, cutDirMl, nClassesMl);
221+ if (loadModelsFromCCDB) {
222+ ccdbApi.init (ccdburl);
223+ mlResponseFwdTrack.setModelPathsCCDB (onnxFileNames.value , ccdbApi, onnxPathsCCDB.value , timestampCCDB.value );
224+ } else {
225+ mlResponseFwdTrack.setModelPathsLocal (onnxFileNames.value );
226+ }
227+ mlResponseFwdTrack.cacheInputFeaturesIndices (namesInputFeatures);
228+ mlResponseFwdTrack.cacheBinningIndex (nameBinningFeature);
229+ mlResponseFwdTrack.init (enableOptimizations.value );
230+ } // end of ML configuration
172231 }
173232
174233 o2::ccdb::CcdbApi ccdbApi;
@@ -415,28 +474,6 @@ struct matchingMFT {
415474 return (clmap > 0 );
416475 }
417476
418- // template <typename T>
419- // float meanClusterSizeMFT(T const& track)
420- // {
421- // uint64_t mftClusterSizesAndTrackFlags = track.mftClusterSizesAndTrackFlags();
422- // uint16_t clsSize = 0;
423- // uint16_t n = 0;
424- // for (unsigned int layer = 0; layer < 10; layer++) {
425- // uint16_t size_per_layer = (mftClusterSizesAndTrackFlags >> (layer * 6)) & 0x3f;
426- // clsSize += size_per_layer;
427- // if (size_per_layer > 0) {
428- // n++;
429- // }
430- // // LOGF(info, "track.globalIndex() = %d, layer = %d, size_per_layer = %d", track.globalIndex(), layer, size_per_layer);
431- // }
432-
433- // if (n > 0) {
434- // return static_cast<float>(clsSize) / static_cast<float>(n) * std::fabs(std::sin(std::atan(track.tgl())));
435- // } else {
436- // return 0.f;
437- // }
438- // }
439-
440477 template <typename TFwdTracks, typename TMFTTracks, typename TCollision, typename TFwdTrack, typename TMFTrackCov>
441478 void getDxDyAtMatchingPlane (TCollision const & collision, TFwdTrack const & fwdtrack, TMFTrackCov const & mftCovs, float & dx, float & dy)
442479 {
@@ -817,8 +854,8 @@ struct matchingMFT {
817854 std::vector<std::tuple<int , int , int >> vec_min_chi2MatchMCHMFT; // std::pair<globalIndex of global muon, globalIndex of matched MCH-MID, globalIndex of MFT> -> chi2MatchMCHMFT;
818855 // std::map<std::tuple<int, int, int>, bool> mapCorrectMatch;
819856
820- template <typename TCollision, typename TFwdTrack, typename TFwdTracks, typename TMFTTracks>
821- void findBestMatchPerMCHMID (TCollision const & collision, TFwdTrack const & fwdtrack, TFwdTracks const & fwdtracks, TMFTTracks const &)
857+ template <bool withMFTCov = false , typename TCollision, typename TFwdTrack, typename TFwdTracks, typename TMFTTracks, typename TMFTTracksCov >
858+ void findBestMatchPerMCHMID (TCollision const & collision, TFwdTrack const & fwdtrack, TFwdTracks const & fwdtracks, TMFTTracks const & mfttracks, TMFTTracksCov const & mftCovs )
822859 {
823860 if (fwdtrack.trackType () != o2::aod::fwdtrack::ForwardTrackTypeEnum::MuonStandaloneTrack) {
824861 return ;
@@ -827,6 +864,8 @@ struct matchingMFT {
827864 return ;
828865 }
829866
867+ auto mfttracks_per_collision = mfttracks.sliceBy (perCollision_MFT, collision.globalIndex ());
868+
830869 std::tuple<int , int , int > tupleIds_at_min_chi2mftmch;
831870 float min_chi2MatchMCHMFT = 1e+10 ;
832871 auto muons_per_MCHMID = fwdtracks.sliceBy (fwdtracksPerMCHTrack, fwdtrack.globalIndex ());
@@ -843,6 +882,9 @@ struct matchingMFT {
843882 float dcaXY_Matched = std::sqrt (dcaX_Matched * dcaX_Matched + dcaY_Matched * dcaY_Matched);
844883 float pDCA = fwdtrack.p () * dcaXY_Matched;
845884
885+ o2::dataformats::GlobalFwdTrack muonAtMP = propagateMuon (fwdtrack, fwdtrack, collision, propagationPoint::kToMatchingPlane , matchingZ, mBz , mZShift ); // propagated to matching plane
886+ float phiMCHMIDatMP = RecoDecay::constrainAngle (muonAtMP.getPhi (), 0 , 1U );
887+
846888 for (const auto & muon_tmp : muons_per_MCHMID) {
847889 if (muon_tmp.trackType () == o2::aod::fwdtrack::ForwardTrackTypeEnum::GlobalMuonTrack) {
848890 auto tupleId = std::make_tuple (muon_tmp.globalIndex (), muon_tmp.matchMCHTrackId (), muon_tmp.matchMFTTrackId ());
@@ -885,6 +927,44 @@ struct matchingMFT {
885927 float dcaY = propmuonAtPV.getY () - collision.posY ();
886928 float dcaXY = std::sqrt (dcaX * dcaX + dcaY * dcaY);
887929
930+ if constexpr (withMFTCov) {
931+ if (useMLmatching) {
932+ matchedCandidate candidate;
933+ candidate.multFT0C = collision.multFT0C ();
934+ candidate.multMFT = static_cast <float >(mfttracks_per_collision.size ());
935+ candidate.chi2MatchMCHMFT = muon_tmp.chi2MatchMCHMFT ();
936+
937+ auto mfttrackcov = mftCovs.rawIteratorAt (map_mfttrackcovs[mfttrack.globalIndex ()]);
938+ o2::track::TrackParCovFwd mftsaAtMP = getTrackParCovFwdShift (mfttrack, mZShift , mfttrackcov); // values at innermost update
939+ mftsaAtMP.propagateToZhelix (matchingZ, mBz ); // propagated to matching plane
940+ float phiMFTatMP = RecoDecay::constrainAngle (mftsaAtMP.getPhi (), 0 , 1U );
941+
942+ candidate.rSigned1Pt = mftsaAtMP.getInvQPt () / muonAtMP.getInvQPt ();
943+ candidate.dEta = mftsaAtMP.getEta () - muonAtMP.getEta ();
944+ candidate.dPhi = RecoDecay::constrainAngle (phiMFTatMP - phiMCHMIDatMP, -o2::constants::math::PIHalf, 1U );
945+ candidate.dX = mftsaAtMP.getX () - muonAtMP.getX ();
946+ candidate.dY = mftsaAtMP.getY () - muonAtMP.getY ();
947+
948+ candidate.sigmaTglMCHMID = std::sqrt (muonAtMP.getSigma2Tanl ());
949+ candidate.sigmaPhiMCHMID = std::sqrt (muonAtMP.getSigma2Phi ());
950+ candidate.sigmaTglMFT = std::sqrt (mftsaAtMP.getSigma2Tanl ());
951+ candidate.sigmaPhiMFT = std::sqrt (mftsaAtMP.getSigma2Phi ());
952+
953+ std::vector<float > inputFeatures = mlResponseFwdTrack.getInputFeatures (candidate);
954+ float binningFeature = mlResponseFwdTrack.getBinningFeature (candidate);
955+ int pbin = lower_bound (binsMl.value .begin (), binsMl.value .end (), binningFeature) - binsMl.value .begin () - 1 ;
956+ if (pbin < 0 ) {
957+ pbin = 0 ;
958+ } else if (static_cast <int >(binsMl.value .size ()) - 2 < pbin) {
959+ pbin = static_cast <int >(binsMl.value .size ()) - 2 ;
960+ }
961+ float probaEl = mlResponseFwdTrack.getModelOutput (inputFeatures, pbin)[1 ]; // 0: wrong, 1:correct
962+ if (probaEl < cutsMl.value [pbin]) {
963+ continue ;
964+ }
965+ }
966+ }
967+
888968 if (isPrimary) {
889969 if (isMatched) {
890970 fRegistry .fill (HIST (" MFTMCHMID/primary/correct/hdR_Chi2MatchMCHMFT" ), muon_tmp.chi2MatchMCHMFT (), dr);
@@ -1092,7 +1172,7 @@ struct matchingMFT {
10921172 initCCDB (bc);
10931173 auto fwdtracks_per_coll = fwdtracks.sliceBy (perCollision, collision.globalIndex ());
10941174 for (const auto & fwdtrack : fwdtracks_per_coll) {
1095- findBestMatchPerMCHMID (collision, fwdtrack, fwdtracks, mfttracks);
1175+ findBestMatchPerMCHMID< false > (collision, fwdtrack, fwdtracks, mfttracks, nullptr );
10961176 } // end of fwdtrack loop
10971177 } // end of collision loop
10981178
@@ -1150,7 +1230,7 @@ struct matchingMFT {
11501230 auto fwdtrackIdsThisCollision = fwdtrackIndices.sliceBy (fwdtrackIndicesPerCollision, collision.globalIndex ());
11511231 for (const auto & fwdtrackId : fwdtrackIdsThisCollision) {
11521232 auto fwdtrack = fwdtrackId.template fwdtrack_as <MyFwdTracks>();
1153- findBestMatchPerMCHMID (collision, fwdtrack, fwdtracks, mfttracks);
1233+ findBestMatchPerMCHMID< false > (collision, fwdtrack, fwdtracks, mfttracks, nullptr );
11541234 } // end of fwdtrack loop
11551235 } // end of collision loop
11561236
@@ -1213,7 +1293,7 @@ struct matchingMFT {
12131293 auto fwdtrackIdsThisCollision = fwdtrackIndices.sliceBy (fwdtrackIndicesPerCollision, collision.globalIndex ());
12141294 for (const auto & fwdtrackId : fwdtrackIdsThisCollision) {
12151295 auto fwdtrack = fwdtrackId.template fwdtrack_as <MyFwdTracks>();
1216- findBestMatchPerMCHMID (collision, fwdtrack, fwdtracks, mfttracks);
1296+ findBestMatchPerMCHMID< true > (collision, fwdtrack, fwdtracks, mfttracks, mftCovs );
12171297 } // end of fwdtrack loop
12181298 } // end of collision loop
12191299
0 commit comments