diff --git a/PWGEM/Dilepton/DataModel/lmeeMLTables.h b/PWGEM/Dilepton/DataModel/lmeeMLTables.h index 9506cd7ab7a..3036164b9ad 100644 --- a/PWGEM/Dilepton/DataModel/lmeeMLTables.h +++ b/PWGEM/Dilepton/DataModel/lmeeMLTables.h @@ -160,11 +160,11 @@ DECLARE_SOA_COLUMN(NClustersMFT, nClustersMFT, uint8_t); //! DECLARE_SOA_COLUMN(IsPrimary, isPrimary, bool); //! DECLARE_SOA_COLUMN(IsCorrectMatch, isCorrectMatch, bool); //! -DECLARE_SOA_COLUMN(NMFTs, nMFTs, uint16_t); //! number of MFTsa tracks per collision +DECLARE_SOA_COLUMN(MultMFT, multMFT, uint16_t); //! number of MFTsa tracks per collision } // namespace emmlfwdtrack DECLARE_SOA_TABLE(EMFwdTracksForML, "AOD", "EMFWDTRKML", //! - o2::soa::Index<>, collision::PosZ, /*collision::NumContrib,*/ mult::MultFT0C, /*evsel::NumTracksInTimeRange,*/ evsel::SumAmpFT0CInTimeRange, emmltrack::HadronicRate, emmlfwdtrack::NMFTs, + o2::soa::Index<>, collision::PosZ, /*collision::NumContrib,*/ mult::MultFT0C, /*evsel::NumTracksInTimeRange,*/ evsel::SumAmpFT0CInTimeRange, emmltrack::HadronicRate, emmlfwdtrack::MultMFT, // fwdtrack::TrackType, emmlfwdtrack::Signed1PtMFTatMP, emmlfwdtrack::TglMFTatMP, emmlfwdtrack::PhiMFTatMP, diff --git a/PWGEM/Dilepton/Tasks/CMakeLists.txt b/PWGEM/Dilepton/Tasks/CMakeLists.txt index a32c3768bd5..473e24c969d 100644 --- a/PWGEM/Dilepton/Tasks/CMakeLists.txt +++ b/PWGEM/Dilepton/Tasks/CMakeLists.txt @@ -117,7 +117,7 @@ o2physics_add_dpl_workflow(study-mc-truth o2physics_add_dpl_workflow(matching-mft SOURCES matchingMFT.cxx - PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore O2::GlobalTracking + PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore O2::GlobalTracking O2Physics::MLCore COMPONENT_NAME Analysis) o2physics_add_dpl_workflow(tagging-hfe diff --git a/PWGEM/Dilepton/Tasks/matchingMFT.cxx b/PWGEM/Dilepton/Tasks/matchingMFT.cxx index d962d51b94d..477fbfbd33d 100644 --- a/PWGEM/Dilepton/Tasks/matchingMFT.cxx +++ b/PWGEM/Dilepton/Tasks/matchingMFT.cxx @@ -13,13 +13,17 @@ /// \brief a 
task to study matching MFT-[MCH-MID] in MC /// \author daiki.sekihata@cern.ch +#include "PWGEM/Dilepton/Utils/MlResponseFwdTrack.h" + #include "Common/CCDB/EventSelectionParams.h" #include "Common/CCDB/RCTSelectionFlags.h" +#include "Common/Core/RecoDecay.h" #include "Common/Core/fwdtrackUtilities.h" #include "Common/DataModel/Centrality.h" #include "Common/DataModel/CollisionAssociationTables.h" #include "Common/DataModel/EventSelection.h" #include "Common/DataModel/Multiplicity.h" +#include "Tools/ML/MlResponse.h" #include #include @@ -111,6 +115,18 @@ struct matchingMFT { Configurable matchingZ{"matchingZ", -77.5, "z position where matching is performed"}; Configurable cfgApplyPreselectionInBestMatch{"cfgApplyPreselectionInBestMatch", false, "flag to apply preselection in find best match function"}; + // configuration for matching with ML + Configurable useMLmatching{"useMLmatching", false, "Flag to use ML for matching between MFT and MCH-MID"}; + Configurable> onnxFileNames{"onnxFileNames", std::vector{"filename"}, "ONNX file names for each bin (if not from CCDB full path)"}; + Configurable> onnxPathsCCDB{"onnxPathsCCDB", std::vector{"path"}, "Paths of models on CCDB"}; + Configurable> binsMl{"binsMl", std::vector{0.1, 0.15, 0.2, 0.25, 0.4, 0.8, 1.6, 2.0, 20}, "Bin limits for ML application"}; + Configurable> cutsMl{"cutsMl", std::vector{0.95, 0.95, 0.7, 0.7, 0.8, 0.8, 0.7, 0.7}, "ML cuts per bin"}; + Configurable> namesInputFeatures{"namesInputFeatures", std::vector{"multFT0C", "ptMCHMID", "rSigned1Pt", "dEta", "dPhi", "dX", "dY", "chi2MatchMCHMFT"}, "Names of ML model input features"}; + Configurable nameBinningFeature{"nameBinningFeature", "multFT0C", "Names of ML model binning feature"}; + Configurable timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB. 
Exceptions: > 0 for the specific timestamp, 0 gets the run dependent timestamp"}; + Configurable loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"}; + Configurable enableOptimizations{"enableOptimizations", false, "Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)"}; + struct : ConfigurableGroup { std::string prefix = "eventcut_group"; Configurable cfgZvtxMin{"cfgZvtxMin", -10.f, "min. Zvtx"}; @@ -151,10 +167,28 @@ struct matchingMFT { } eventcuts; o2::aod::rctsel::RCTFlagsChecker rctChecker; + o2::analysis::MlResponseFwdTrack mlResponseFwdTrack; HistogramRegistry fRegistry{"fRegistry"}; static constexpr std::string_view muon_types[5] = {"MFTMCHMID/", "MFTMCHMIDOtherMatch/", "MFTMCH/", "MCHMID/", "MCH/"}; + struct matchedCandidate { + float multFT0C{0}; + float multMFT{0}; + float ptMCHMID{0}; + float rSigned1Pt{1e+10}; + float dEta{1e+10}; + float dPhi{1e+10}; + float dX{1e+10}; + float dY{1e+10}; + float chi2MatchMCHMFT{1e+10}; + + float sigmaPhiMFT{1e+10}; + float sigmaTglMFT{1e+10}; + float sigmaPhiMCHMID{1e+10}; + float sigmaTglMCHMID{1e+10}; + }; + void init(o2::framework::InitContext&) { if (doprocessWithoutFTTCA && doprocessWithFTTCA) { @@ -169,6 +203,31 @@ struct matchingMFT { rctChecker.init(eventcuts.cfgRCTLabel.value, eventcuts.cfgCheckZDC.value, eventcuts.cfgTreatLimitedAcceptanceAsBad.value); addHistograms(); + + if (useMLmatching) { + static constexpr int nClassesMl = 2; + const std::vector cutDirMl = {o2::cuts_ml::CutNot, o2::cuts_ml::CutSmaller}; + const std::vector labelsClasses = {"Background", "Signal"}; + const uint32_t nBinsMl = binsMl.value.size() - 1; + const std::vector labelsBins(nBinsMl, "bin"); + double cutsMlArr[nBinsMl][nClassesMl]; + for (uint32_t i = 0; i < nBinsMl; i++) { + cutsMlArr[i][0] = 0.0; + cutsMlArr[i][1] = cutsMl.value[i]; + } + o2::framework::LabeledArray cutsMl = {cutsMlArr[0], 
nBinsMl, nClassesMl, labelsBins, labelsClasses}; + + mlResponseFwdTrack.configure(binsMl.value, cutsMl, cutDirMl, nClassesMl); + if (loadModelsFromCCDB) { + ccdbApi.init(ccdburl); + mlResponseFwdTrack.setModelPathsCCDB(onnxFileNames.value, ccdbApi, onnxPathsCCDB.value, timestampCCDB.value); + } else { + mlResponseFwdTrack.setModelPathsLocal(onnxFileNames.value); + } + mlResponseFwdTrack.cacheInputFeaturesIndices(namesInputFeatures); + mlResponseFwdTrack.cacheBinningIndex(nameBinningFeature); + mlResponseFwdTrack.init(enableOptimizations.value); + } // end of ML configuration } o2::ccdb::CcdbApi ccdbApi; @@ -415,28 +474,6 @@ struct matchingMFT { return (clmap > 0); } - // template - // float meanClusterSizeMFT(T const& track) - // { - // uint64_t mftClusterSizesAndTrackFlags = track.mftClusterSizesAndTrackFlags(); - // uint16_t clsSize = 0; - // uint16_t n = 0; - // for (unsigned int layer = 0; layer < 10; layer++) { - // uint16_t size_per_layer = (mftClusterSizesAndTrackFlags >> (layer * 6)) & 0x3f; - // clsSize += size_per_layer; - // if (size_per_layer > 0) { - // n++; - // } - // // LOGF(info, "track.globalIndex() = %d, layer = %d, size_per_layer = %d", track.globalIndex(), layer, size_per_layer); - // } - - // if (n > 0) { - // return static_cast(clsSize) / static_cast(n) * std::fabs(std::sin(std::atan(track.tgl()))); - // } else { - // return 0.f; - // } - // } - template void getDxDyAtMatchingPlane(TCollision const& collision, TFwdTrack const& fwdtrack, TMFTrackCov const& mftCovs, float& dx, float& dy) { @@ -817,8 +854,8 @@ struct matchingMFT { std::vector> vec_min_chi2MatchMCHMFT; // std::pair -> chi2MatchMCHMFT; // std::map, bool> mapCorrectMatch; - template - void findBestMatchPerMCHMID(TCollision const& collision, TFwdTrack const& fwdtrack, TFwdTracks const& fwdtracks, TMFTTracks const&) + template + void findBestMatchPerMCHMID(TCollision const& collision, TFwdTrack const& fwdtrack, TFwdTracks const& fwdtracks, TMFTTracks const& mfttracks, TMFTTracksCov 
const& mftCovs) { if (fwdtrack.trackType() != o2::aod::fwdtrack::ForwardTrackTypeEnum::MuonStandaloneTrack) { return; @@ -827,6 +864,8 @@ struct matchingMFT { return; } + auto mfttracks_per_collision = mfttracks.sliceBy(perCollision_MFT, collision.globalIndex()); + std::tuple tupleIds_at_min_chi2mftmch; float min_chi2MatchMCHMFT = 1e+10; auto muons_per_MCHMID = fwdtracks.sliceBy(fwdtracksPerMCHTrack, fwdtrack.globalIndex()); @@ -843,6 +882,9 @@ struct matchingMFT { float dcaXY_Matched = std::sqrt(dcaX_Matched * dcaX_Matched + dcaY_Matched * dcaY_Matched); float pDCA = fwdtrack.p() * dcaXY_Matched; + o2::dataformats::GlobalFwdTrack muonAtMP = propagateMuon(fwdtrack, fwdtrack, collision, propagationPoint::kToMatchingPlane, matchingZ, mBz, mZShift); // propagated to matching plane + float phiMCHMIDatMP = RecoDecay::constrainAngle(muonAtMP.getPhi(), 0, 1U); + for (const auto& muon_tmp : muons_per_MCHMID) { if (muon_tmp.trackType() == o2::aod::fwdtrack::ForwardTrackTypeEnum::GlobalMuonTrack) { auto tupleId = std::make_tuple(muon_tmp.globalIndex(), muon_tmp.matchMCHTrackId(), muon_tmp.matchMFTTrackId()); @@ -885,6 +927,44 @@ struct matchingMFT { float dcaY = propmuonAtPV.getY() - collision.posY(); float dcaXY = std::sqrt(dcaX * dcaX + dcaY * dcaY); + if constexpr (withMFTCov) { + if (useMLmatching) { + matchedCandidate candidate; + candidate.multFT0C = collision.multFT0C(); + candidate.multMFT = static_cast(mfttracks_per_collision.size()); + candidate.chi2MatchMCHMFT = muon_tmp.chi2MatchMCHMFT(); + + auto mfttrackcov = mftCovs.rawIteratorAt(map_mfttrackcovs[mfttrack.globalIndex()]); + o2::track::TrackParCovFwd mftsaAtMP = getTrackParCovFwdShift(mfttrack, mZShift, mfttrackcov); // values at innermost update + mftsaAtMP.propagateToZhelix(matchingZ, mBz); // propagated to matching plane + float phiMFTatMP = RecoDecay::constrainAngle(mftsaAtMP.getPhi(), 0, 1U); + + candidate.rSigned1Pt = mftsaAtMP.getInvQPt() / muonAtMP.getInvQPt(); + candidate.dEta = mftsaAtMP.getEta() - 
muonAtMP.getEta(); + candidate.dPhi = RecoDecay::constrainAngle(phiMFTatMP - phiMCHMIDatMP, -o2::constants::math::PIHalf, 1U); + candidate.dX = mftsaAtMP.getX() - muonAtMP.getX(); + candidate.dY = mftsaAtMP.getY() - muonAtMP.getY(); + + candidate.sigmaTglMCHMID = std::sqrt(muonAtMP.getSigma2Tanl()); + candidate.sigmaPhiMCHMID = std::sqrt(muonAtMP.getSigma2Phi()); + candidate.sigmaTglMFT = std::sqrt(mftsaAtMP.getSigma2Tanl()); + candidate.sigmaPhiMFT = std::sqrt(mftsaAtMP.getSigma2Phi()); + + std::vector inputFeatures = mlResponseFwdTrack.getInputFeatures(candidate); + float binningFeature = mlResponseFwdTrack.getBinningFeature(candidate); + int pbin = lower_bound(binsMl.value.begin(), binsMl.value.end(), binningFeature) - binsMl.value.begin() - 1; + if (pbin < 0) { + pbin = 0; + } else if (static_cast(binsMl.value.size()) - 2 < pbin) { + pbin = static_cast(binsMl.value.size()) - 2; + } + float probaEl = mlResponseFwdTrack.getModelOutput(inputFeatures, pbin)[1]; // 0: wrong, 1:correct + if (probaEl < cutsMl.value[pbin]) { + continue; + } + } + } + if (isPrimary) { if (isMatched) { fRegistry.fill(HIST("MFTMCHMID/primary/correct/hdR_Chi2MatchMCHMFT"), muon_tmp.chi2MatchMCHMFT(), dr); @@ -1092,7 +1172,7 @@ struct matchingMFT { initCCDB(bc); auto fwdtracks_per_coll = fwdtracks.sliceBy(perCollision, collision.globalIndex()); for (const auto& fwdtrack : fwdtracks_per_coll) { - findBestMatchPerMCHMID(collision, fwdtrack, fwdtracks, mfttracks); + findBestMatchPerMCHMID(collision, fwdtrack, fwdtracks, mfttracks, nullptr); } // end of fwdtrack loop } // end of collision loop @@ -1150,7 +1230,7 @@ struct matchingMFT { auto fwdtrackIdsThisCollision = fwdtrackIndices.sliceBy(fwdtrackIndicesPerCollision, collision.globalIndex()); for (const auto& fwdtrackId : fwdtrackIdsThisCollision) { auto fwdtrack = fwdtrackId.template fwdtrack_as(); - findBestMatchPerMCHMID(collision, fwdtrack, fwdtracks, mfttracks); + findBestMatchPerMCHMID(collision, fwdtrack, fwdtracks, mfttracks, 
nullptr); } // end of fwdtrack loop } // end of collision loop @@ -1213,7 +1293,7 @@ struct matchingMFT { auto fwdtrackIdsThisCollision = fwdtrackIndices.sliceBy(fwdtrackIndicesPerCollision, collision.globalIndex()); for (const auto& fwdtrackId : fwdtrackIdsThisCollision) { auto fwdtrack = fwdtrackId.template fwdtrack_as(); - findBestMatchPerMCHMID(collision, fwdtrack, fwdtracks, mfttracks); + findBestMatchPerMCHMID(collision, fwdtrack, fwdtracks, mfttracks, mftCovs); } // end of fwdtrack loop } // end of collision loop diff --git a/PWGEM/Dilepton/Utils/MlResponseFwdTrack.h b/PWGEM/Dilepton/Utils/MlResponseFwdTrack.h new file mode 100644 index 00000000000..073d3df9226 --- /dev/null +++ b/PWGEM/Dilepton/Utils/MlResponseFwdTrack.h @@ -0,0 +1,157 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. 
+ +/// \file MlResponseFwdTrack.h +/// \brief Class to compute the ML response for fwdtracks +/// \author Daiki Sekihata + +#ifndef PWGEM_DILEPTON_UTILS_MLRESPONSEFWDTRACK_H_ +#define PWGEM_DILEPTON_UTILS_MLRESPONSEFWDTRACK_H_ + +#include "Tools/ML/MlResponse.h" + +#include + +#include +#include +#include + +// Fill the map of available input features +// the key is the feature's name (std::string) +// the value is the matching enumerator of InputFeaturesFwdTrack, cast to an integer +#define FILL_MAP_TRACK(FEATURE) \ + { \ + #FEATURE, static_cast(InputFeaturesFwdTrack::FEATURE)} + +// Expands to one switch case inside return_feature(): +// the case label is the InputFeaturesFwdTrack enumerator named GETTER; +// when the switch index matches it, the local inputFeature is set +// by reading the data member track.GETTER (a member access, not a function call) +#define CHECK_AND_FILL_TRACK(GETTER) \ + case static_cast(InputFeaturesFwdTrack::GETTER): { \ + inputFeature = track.GETTER; \ + break; \ + } + +namespace o2::analysis +{ +// possible input features for ML; each name is also used as the member name read from the track by CHECK_AND_FILL_TRACK +enum class InputFeaturesFwdTrack : uint8_t { + multFT0C, + multMFT, + ptMCHMID, + rSigned1Pt, + dEta, + dPhi, + dX, + dY, + chi2MatchMCHMFT, + sigmaPhiMFT, + sigmaTglMFT, + sigmaPhiMCHMID, + sigmaTglMCHMID, +}; + +template +class MlResponseFwdTrack : public MlResponse +{ + public: + /// Default constructor + MlResponseFwdTrack() = default; + /// Virtual destructor (defaulted) + virtual ~MlResponseFwdTrack() = default; + + template + float return_feature(uint8_t idx, T const& track) // value of the feature whose enum index is idx, read from track + { + float inputFeature = 0.; // stays 0 when idx matches no case below (the switch has no default) + switch (idx) { + CHECK_AND_FILL_TRACK(multFT0C); + CHECK_AND_FILL_TRACK(multMFT); + CHECK_AND_FILL_TRACK(ptMCHMID); + CHECK_AND_FILL_TRACK(rSigned1Pt); + CHECK_AND_FILL_TRACK(dEta); + CHECK_AND_FILL_TRACK(dPhi); + CHECK_AND_FILL_TRACK(dX); + CHECK_AND_FILL_TRACK(dY); + CHECK_AND_FILL_TRACK(chi2MatchMCHMFT); + CHECK_AND_FILL_TRACK(sigmaPhiMFT); + CHECK_AND_FILL_TRACK(sigmaTglMFT); + CHECK_AND_FILL_TRACK(sigmaPhiMCHMID); + 
CHECK_AND_FILL_TRACK(sigmaTglMCHMID); + } + + return inputFeature; + } + + /// Method to get the input features vector needed for ML inference + /// \param track is the candidate object whose members carry the feature values + /// \return one value per entry of mCachedIndices, in the same order + template + std::vector getInputFeatures(T const& track) + { + std::vector inputFeatures; + for (const auto& idx : MlResponse::mCachedIndices) { + float inputFeature = return_feature(idx, track); + inputFeatures.emplace_back(inputFeature); + } + return inputFeatures; + } + + /// Method to get the value of variable chosen for binning + /// \param track is the candidate object whose members carry the feature values + /// \return value of the feature selected by mCachedIndexBinning + template + float getBinningFeature(T const& track) + { + return return_feature(mCachedIndexBinning, track); + } + + void cacheBinningIndex(std::string const& cfgBinningFeature) // stores the index of the named feature in mCachedIndexBinning; LOG(fatal) if the name is unknown + { + setAvailableInputFeatures(); // (re)fills the name -> index map before the lookup + if (MlResponse::mAvailableInputFeatures.count(cfgBinningFeature)) { + mCachedIndexBinning = MlResponse::mAvailableInputFeatures[cfgBinningFeature]; + } else { + LOG(fatal) << "Binning feature " << cfgBinningFeature << " not available! Please check your configurables."; + } + } + + protected: + /// Method to fill the map of available input features + void setAvailableInputFeatures() + { + MlResponse::mAvailableInputFeatures = { + FILL_MAP_TRACK(multFT0C), + FILL_MAP_TRACK(multMFT), + FILL_MAP_TRACK(ptMCHMID), + FILL_MAP_TRACK(rSigned1Pt), + FILL_MAP_TRACK(dEta), + FILL_MAP_TRACK(dPhi), + FILL_MAP_TRACK(dX), + FILL_MAP_TRACK(dY), + FILL_MAP_TRACK(chi2MatchMCHMFT), + FILL_MAP_TRACK(sigmaPhiMFT), + FILL_MAP_TRACK(sigmaTglMFT), + FILL_MAP_TRACK(sigmaPhiMCHMID), + FILL_MAP_TRACK(sigmaTglMCHMID), + }; + } + + uint8_t mCachedIndexBinning; // index of the binning feature in mAvailableInputFeatures; has no initializer here, assigned in cacheBinningIndex() +}; + +} // namespace o2::analysis + +#undef FILL_MAP_TRACK +#undef CHECK_AND_FILL_TRACK + +#endif // PWGEM_DILEPTON_UTILS_MLRESPONSEFWDTRACK_H_